diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py
index eaef0716ba3..3f4ddfce781 100644
--- a/test/distributed/_tensor/test_dtensor_compile.py
+++ b/test/distributed/_tensor/test_dtensor_compile.py
@@ -176,6 +176,25 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
         res = opt_fn(x)
         self.assertEqual(res, ref)
 
+    def test_dtensor_dynamic(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        # test passing in DTensor as inputs/outputs and run some tensor computation
+        def fn(x):
+            return (
+                torch.mul(x, x)
+                .redistribute(device_mesh=x.device_mesh, placements=[Replicate()])
+                .to_local()[0]
+            )
+
+        x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False)
+        torch._dynamo.mark_dynamic(x, 0)
+        ref = fn(x)
+
+        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
+        res = opt_fn(x)
+        self.assertEqual(res, ref)
+
     def test_dtensor_attribute_access_on_intermediate(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
 
diff --git a/torch/distributed/tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py
index 6b4e37e5474..88414081a17 100644
--- a/torch/distributed/tensor/_redistribute.py
+++ b/torch/distributed/tensor/_redistribute.py
@@ -27,8 +27,7 @@ class _TransformInfo(NamedTuple):
     logical_shape: List[int]
 
 
-@lru_cache(maxsize=None)
-def _gen_transform_infos(
+def _gen_transform_infos_non_cached(
     src_spec: DTensorSpec,
     dst_spec: DTensorSpec,
 ) -> List[_TransformInfo]:
@@ -146,6 +145,14 @@ def _gen_transform_infos(
     return transform_infos
 
 
+@lru_cache(maxsize=None)
+def _gen_transform_infos(
+    src_spec: DTensorSpec,
+    dst_spec: DTensorSpec,
+) -> List[_TransformInfo]:
+    return _gen_transform_infos_non_cached(src_spec, dst_spec)
+
+
 def redistribute_local_tensor(
     local_tensor: torch.Tensor,
     current_spec: DTensorSpec,
@@ -174,7 +181,13 @@ def redistribute_local_tensor(
         # which should be an empty tensor
         return local_tensor
 
-    transform_infos = _gen_transform_infos(current_spec, target_spec)
+    has_symints = any(isinstance(s, torch.SymInt) for s in current_spec.shape) or any(
+        isinstance(s, torch.SymInt) for s in target_spec.shape
+    )
+    if has_symints:
+        transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec)
+    else:
+        transform_infos = _gen_transform_infos(current_spec, target_spec)
 
     for transform_info in transform_infos:
         i = transform_info.mesh_dim
diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py
index ec538d35ed6..2b87d79a342 100644
--- a/torch/distributed/tensor/_sharding_prop.py
+++ b/torch/distributed/tensor/_sharding_prop.py
@@ -104,8 +104,7 @@ class ShardingPropagator:
         if schema_info is not None:
             self.op_to_schema_info[op_overload] = schema_info
 
-    @lru_cache  # noqa: B019
-    def _propagate_tensor_meta(
+    def _propagate_tensor_meta_non_cached(
         self, op_schema: OpSchema
     ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
         """
@@ -150,6 +149,12 @@ class ShardingPropagator:
         # if fake is not a tensor or tuple of tensor, return as none
         return None
 
+    @lru_cache  # noqa: B019
+    def _propagate_tensor_meta(
+        self, op_schema: OpSchema
+    ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
+        return self._propagate_tensor_meta_non_cached(op_schema)
+
     def _wrap_output_spec_tensor_meta(
         self,
         op: OpOverload,
@@ -211,7 +216,7 @@ class ShardingPropagator:
         if op_schema.op is aten._local_scalar_dense.default:
             return OutputSharding(None, op_schema)
 
-        out_tensor_meta = self._propagate_tensor_meta(op_schema)
+        out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
 
         def spec_to_strategy(spec: object) -> object:
             if isinstance(spec, DTensorSpec):