Revert "[dynamo][guards] 1/N Guard selectively for DTensor (#165824)"

This reverts commit ee7434be82.

Reverted https://github.com/pytorch/pytorch/pull/165824 on behalf of https://github.com/anijain2305 due to internal job failed ([comment](https://github.com/pytorch/pytorch/pull/165824#issuecomment-3462667536))
commit d7040e6d75
parent 35f3572fa4
Author: PyTorch MergeBot
Date:   2025-10-29 16:52:31 +00:00
4 changed files with 14 additions and 93 deletions


@@ -464,25 +464,6 @@ def forward(self, b_parametrizations_buffer_original0, x):
         run(g, 64, 8)
         self.assertEqual(cnt.frame_count, 2)
 
-    def test_dtensor_requires_grad_recompile(self):
-        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-
-        @torch.compile(backend=cnt, fullgraph=True)
-        def f(x):
-            y = x * x
-            return y.to_local()
-
-        full_x = torch.randn(8, 8, requires_grad=False)
-        x = distribute_tensor(full_x, mesh, [Shard(0)])
-        f(x)
-
-        full_x = torch.randn(8, 8, requires_grad=True)
-        x = distribute_tensor(full_x, mesh, [Shard(0)])
-        f(x)
-
-        self.assertEqual(cnt.frame_count, 2)
-
     def test_dtensor_attribute_access_on_intermediate(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
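For context (not part of the diff): the test removed above asserted that toggling requires_grad on a DTensor input forces a second Dynamo frame. A minimal sketch of the same recompile pattern, assuming plain tensors and the stock CompileCounter backend so it runs without any distributed setup:

import torch
import torch._dynamo.testing

cnt = torch._dynamo.testing.CompileCounter()

@torch.compile(backend=cnt, fullgraph=True)
def f(x):
    return x * x

f(torch.randn(8, 8, requires_grad=False))
f(torch.randn(8, 8, requires_grad=True))
# requires_grad is part of the tensor guards, so the second call recompiles.
assert cnt.frame_count == 2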


@@ -2150,19 +2150,6 @@ class GuardBuilder(GuardBuilderBase):
             metadata_checker, get_verbose_code_parts(global_name, guard)
         )
 
-    def DTENSOR_SPEC_MATCH(self, guard: Guard) -> None:
-        # Copied from DTensor __metadata_guard__
-        # TODO - Consider moving this to C++ if stable
-        value = deepcopy(self.get(guard.name))
-
-        def guard_fn(x: Any) -> bool:
-            return x._check_equals(value, skip_shapes=True)
-
-        code = f"__dtensor_spec_{id(guard_fn)}"
-        self.get_guard_manager(guard).add_lambda_guard(
-            guard_fn, get_verbose_code_parts(code, guard)
-        )
-
     def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
         ref = self.arg_ref(guard)
         val = self.get(guard.name)
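For context (not part of the diff): the deleted DTENSOR_SPEC_MATCH installed a closure-based lambda guard over a deep copy of the compile-time spec. A minimal sketch of that pattern, assuming a hypothetical spec object that exposes the _check_equals(other, skip_shapes=True) method used above:

from copy import deepcopy
from typing import Any, Callable

def make_spec_guard(compile_time_spec: Any) -> Callable[[Any], bool]:
    # Snapshot the spec seen at compile time so later mutation cannot weaken the guard.
    frozen = deepcopy(compile_time_spec)

    def guard_fn(runtime_spec: Any) -> bool:
        # Mirrors the deleted guard: compare specs while ignoring shape differences.
        return runtime_spec._check_equals(frozen, skip_shapes=True)

    return guard_fn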


@@ -2229,70 +2229,25 @@ class VariableBuilder:
         if isinstance(source, GradSource) and is_from_optimizer_source(source):
             guard_type = GuardBuilder.NOT_NONE_MATCH
 
-        is_dtensor = torch.distributed.is_available() and isinstance(
-            value, torch.distributed.tensor.DTensor
-        )
-        if not is_dtensor:
-            # We guard on the _local_tensor and the _spec, and therefore we dont
-            # have to guard on the outer DTensor.
-            self.install_guards(
-                functools.partial(
-                    guard_type,
-                    value=(
-                        value
-                        if isinstance(source, NumpyTensorSource)
-                        else TensorWeakRef(value)
-                    ),
-                )
-            )
+        self.install_guards(
+            functools.partial(
+                guard_type,
+                value=(
+                    value
+                    if isinstance(source, NumpyTensorSource)
+                    else TensorWeakRef(value)
+                ),
+            )
+        )
 
         # We install TYPE_MATCH guards for traceable wrapper subclass object,
         # and recursively install corresponding guard for each inner attribute.
         if is_traceable_wrapper_subclass(value):
-            # Tensor subclass guards are very expensive because they are
-            # implemented in Python. Since DTensor is PyTorch-maintained class,
-            # we can skip a lot of these guards.
-            if is_dtensor:
-                self.install_guards(GuardBuilder.TYPE_MATCH)
-
-                # The inner tensor name is always _local_tensor. If its not, we
-                # raise assertion to update the check accordingly.
-                inner_tensor_name = value.__tensor_flatten__()[0][0]
-                if inner_tensor_name != "_local_tensor":
-                    raise RuntimeError(
-                        "Expecting Dtensor inner tensor name to be _local_tensor"
-                    )
-
-                # Now selectively guard on the flattening context
-                flattening_ctx = value.__tensor_flatten__()[1]
-                # This is supposed to be (self._spec, self.requires_grad)
-                if not (
-                    len(flattening_ctx) == 2
-                    and flattening_ctx[0] == value._spec
-                    and flattening_ctx[1] == value.requires_grad
-                ):
-                    # If not, raise an assertion to update to the new guards
-                    raise RuntimeError(
-                        "Expecting Dtensor flattening ctx to be _spec, requires_grad"
-                    )
-
-                # Guard on the dtensor spec
-                install_guard(
-                    AttrSource(self.source, "_spec").make_guard(
-                        GuardBuilder.DTENSOR_SPEC_MATCH
-                    )
-                )
-                # Move this to C++
-                install_guard(
-                    AttrSource(self.source, "requires_grad").make_guard(
-                        GuardBuilder.EQUALS_MATCH
-                    )
-                )
-            else:
-                self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
-                self.install_guards(GuardBuilder.TYPE_MATCH)
-                install_guard(
-                    SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
-                )
+            self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
+            self.install_guards(GuardBuilder.TYPE_MATCH)
+            install_guard(
+                SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
+            )
 
             attrs, _ = value.__tensor_flatten__()
             for attr in attrs:

@@ -671,8 +671,6 @@ class DTensor(torch.Tensor):
     def __metadata_guard__(
         cls, orig: tuple[DTensorSpec, bool], other: tuple[DTensorSpec, bool]
    ) -> bool:
-        # TODO - delete this - This is now unused after the PR -
-        # https://github.com/pytorch/pytorch/pull/165824
         orig_spec, orig_requires_grad = orig
         other_spec, other_requires_grad = other
         return (
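For context (not part of the diff): __metadata_guard__ takes the (DTensorSpec, requires_grad) flattening contexts of the compile-time and runtime tensors and reports whether cached metadata still matches. A hypothetical sketch of driving the check by hand, assuming compile_time_dt and runtime_dt are DTensors:

def metadata_matches(compile_time_dt, runtime_dt):
    orig_ctx = compile_time_dt.__tensor_flatten__()[1]    # (spec, requires_grad)
    other_ctx = runtime_dt.__tensor_flatten__()[1]
    return type(compile_time_dt).__metadata_guard__(orig_ctx, other_ctx)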