From d7040e6d7515cea485824d2b810bea94e5958dea Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Wed, 29 Oct 2025 16:52:31 +0000
Subject: [PATCH] Revert "[dynamo][guards] 1/N Guard selectively for DTensor
 (#165824)"

This reverts commit ee7434be822cf6e75b4566d8159f550ee233d8ae.

Reverted https://github.com/pytorch/pytorch/pull/165824 on behalf of https://github.com/anijain2305 due to internal job failed ([comment](https://github.com/pytorch/pytorch/pull/165824#issuecomment-3462667536))
---
 .../tensor/test_dtensor_compile.py | 19 -----
 torch/_dynamo/guards.py            | 13 ----
 torch/_dynamo/variables/builder.py | 73 ++++---------------
 torch/distributed/tensor/_api.py   |  2 -
 4 files changed, 14 insertions(+), 93 deletions(-)

diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py
index ddba3150b05..b82e9c97b57 100644
--- a/test/distributed/tensor/test_dtensor_compile.py
+++ b/test/distributed/tensor/test_dtensor_compile.py
@@ -464,25 +464,6 @@ def forward(self, b_parametrizations_buffer_original0, x):
         run(g, 64, 8)
         self.assertEqual(cnt.frame_count, 2)
 
-    def test_dtensor_requires_grad_recompile(self):
-        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-
-        @torch.compile(backend=cnt, fullgraph=True)
-        def f(x):
-            y = x * x
-            return y.to_local()
-
-        full_x = torch.randn(8, 8, requires_grad=False)
-        x = distribute_tensor(full_x, mesh, [Shard(0)])
-        f(x)
-
-        full_x = torch.randn(8, 8, requires_grad=True)
-        x = distribute_tensor(full_x, mesh, [Shard(0)])
-        f(x)
-
-        self.assertEqual(cnt.frame_count, 2)
-
     def test_dtensor_attribute_access_on_intermediate(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
 
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index 2e15d8b7530..d5869b9b29f 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -2150,19 +2150,6 @@ class GuardBuilder(GuardBuilderBase):
             metadata_checker, get_verbose_code_parts(global_name, guard)
         )
 
-    def DTENSOR_SPEC_MATCH(self, guard: Guard) -> None:
-        # Copied from DTensor __metadata_guard__
-        # TODO - Consider moving this to C++ if stable
-        value = deepcopy(self.get(guard.name))
-
-        def guard_fn(x: Any) -> bool:
-            return x._check_equals(value, skip_shapes=True)
-
-        code = f"__dtensor_spec_{id(guard_fn)}"
-        self.get_guard_manager(guard).add_lambda_guard(
-            guard_fn, get_verbose_code_parts(code, guard)
-        )
-
     def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
         ref = self.arg_ref(guard)
         val = self.get(guard.name)
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index f5d851fafac..2a1cff0211f 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -2229,70 +2229,25 @@ class VariableBuilder:
         if isinstance(source, GradSource) and is_from_optimizer_source(source):
             guard_type = GuardBuilder.NOT_NONE_MATCH
 
-        is_dtensor = torch.distributed.is_available() and isinstance(
-            value, torch.distributed.tensor.DTensor
-        )
-        if not is_dtensor:
-            # We guard on the _local_tensor and the _spec, and therefore we dont
-            # have to guard on the outer DTensor.
-            self.install_guards(
-                functools.partial(
-                    guard_type,
-                    value=(
-                        value
-                        if isinstance(source, NumpyTensorSource)
-                        else TensorWeakRef(value)
-                    ),
-                )
+        self.install_guards(
+            functools.partial(
+                guard_type,
+                value=(
+                    value
+                    if isinstance(source, NumpyTensorSource)
+                    else TensorWeakRef(value)
+                ),
             )
+        )
 
         # We install TYPE_MATCH guards for traceable wrapper subclass object,
         # and recursively install corresponding guard for each inner attribute.
         if is_traceable_wrapper_subclass(value):
-            # Tensor subclass guards are very expensive because they are
-            # implemented in Python. Since DTensor is PyTorch-maintained class,
-            # we can skip a lot of these guards.
-            if is_dtensor:
-                self.install_guards(GuardBuilder.TYPE_MATCH)
-
-                # The inner tensor name is always _local_tensor. If its not, we
-                # raise assertion to update the check accordingly.
-                inner_tensor_name = value.__tensor_flatten__()[0][0]
-                if inner_tensor_name != "_local_tensor":
-                    raise RuntimeError(
-                        "Expecting Dtensor inner tensor name to be _local_tensor"
-                    )
-
-                # Now selectively guard on the flattening context
-                flattening_ctx = value.__tensor_flatten__()[1]
-                # This is supposed to be (self._spec, self.requires_grad)
-                if not (
-                    len(flattening_ctx) == 2
-                    and flattening_ctx[0] == value._spec
-                    and flattening_ctx[1] == value.requires_grad
-                ):
-                    # If not, raise an assertion to update to the new guards
-                    raise RuntimeError(
-                        "Expecting Dtensor flattening ctx to be _spec, requires_grad"
-                    )
-                # Guard on the dtensor spec
-                install_guard(
-                    AttrSource(self.source, "_spec").make_guard(
-                        GuardBuilder.DTENSOR_SPEC_MATCH
-                    )
-                )
-                # Move this to C++
-                install_guard(
-                    AttrSource(self.source, "requires_grad").make_guard(
-                        GuardBuilder.EQUALS_MATCH
-                    )
-                )
-            else:
-                self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
-                self.install_guards(GuardBuilder.TYPE_MATCH)
-                install_guard(
-                    SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
-                )
+            self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
+            self.install_guards(GuardBuilder.TYPE_MATCH)
+            install_guard(
+                SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
+            )
 
             attrs, _ = value.__tensor_flatten__()
             for attr in attrs:
diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py
index 7d730feb3e0..865de11dacc 100644
--- a/torch/distributed/tensor/_api.py
+++ b/torch/distributed/tensor/_api.py
@@ -671,8 +671,6 @@ class DTensor(torch.Tensor):
     def __metadata_guard__(
         cls, orig: tuple[DTensorSpec, bool], other: tuple[DTensorSpec, bool]
     ) -> bool:
-        # TODO - delete this - This is now unused after the PR -
-        # https://github.com/pytorch/pytorch/pull/165824
         orig_spec, orig_requires_grad = orig
         other_spec, other_requires_grad = other
         return (