DTensor: dont hash symint tensor input in propagate_tensor_meta (#136266)

This fixes a subset of issues for dynamic shapes + DTensor.

It's still easy to run into other issues; it's likely that https://github.com/pytorch/pytorch/pull/125941 needs to land for DTensor + dynamic shapes to work more generally. To properly exercise this fix, I ended up writing a test with dynamic-shape inputs but not dynamic-shape outputs.
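The fix follows the same shape in both non-test files below: keep the lru_cache-decorated function as the cached entry point, move the real work into a _non_cached variant, and route anything whose shape contains SymInts around the cache so the SymInts are never hashed. A minimal, self-contained sketch of that pattern (FakeSpec and the function names here are illustrative stand-ins, not the actual DTensor internals):

    from dataclasses import dataclass
    from functools import lru_cache
    from typing import Tuple

    @dataclass(frozen=True)
    class FakeSpec:
        # Stand-in for DTensorSpec: hashable as long as its sizes are plain ints.
        shape: Tuple[object, ...]

    def _propagate_non_cached(spec: FakeSpec):
        # The actual propagation logic; works for static and symbolic sizes alike.
        return tuple(spec.shape)

    @lru_cache(maxsize=None)
    def _propagate(spec: FakeSpec):
        # Cached wrapper: only reached when the spec is safe to hash,
        # i.e. none of its sizes are symbolic.
        return _propagate_non_cached(spec)

    def propagate(spec: FakeSpec):
        # Symbolic sizes (torch.SymInt under dynamic-shape tracing) should not
        # be hashed into an lru_cache key, so fall back to the non-cached path.
        has_symbolic = any(not isinstance(d, int) for d in spec.shape)
        return _propagate_non_cached(spec) if has_symbolic else _propagate(spec)

    print(propagate(FakeSpec((4, 4))))  # static shapes still hit the cache

The trade-off is that eager-mode callers, whose specs always carry concrete int sizes, keep the existing cache behavior, while only the dynamic-shape path pays for recomputation.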

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136266
Approved by: https://github.com/ezyang, https://github.com/yf225
Brian Hirsh 2024-09-18 11:57:01 -07:00 committed by PyTorch MergeBot
parent 7bbdf87517
commit 172ecf78b7
3 changed files with 43 additions and 6 deletions

@@ -176,6 +176,25 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
         res = opt_fn(x)
         self.assertEqual(res, ref)
 
+    def test_dtensor_dynamic(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        # test passing in DTensor as inputs/outputs and run some tensor computation
+        def fn(x):
+            return (
+                torch.mul(x, x)
+                .redistribute(device_mesh=x.device_mesh, placements=[Replicate()])
+                .to_local()[0]
+            )
+
+        x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False)
+        torch._dynamo.mark_dynamic(x, 0)
+        ref = fn(x)
+
+        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
+        res = opt_fn(x)
+        self.assertEqual(res, ref)
+
     def test_dtensor_attribute_access_on_intermediate(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

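For context on the new test: torch._dynamo.mark_dynamic(x, 0) is what makes dim 0 symbolic under compilation, so the DTensor operations above see SymInt sizes even though the output of fn has only static sizes (the "dynamic shape inputs but not dynamic shape outputs" mentioned in the commit message). A minimal illustration of the same mechanism outside of DTensor (a sketch; it only assumes a PyTorch build where torch.compile works):

    import torch

    def g(x):
        return torch.mul(x, x)

    x = torch.rand(4, 4)
    torch._dynamo.mark_dynamic(x, 0)  # dim 0 is traced as a SymInt, not the constant 4
    opt_g = torch.compile(g, backend="aot_eager", fullgraph=True)
    print(opt_g(x).shape)  # same result as eager; other dim-0 sizes reuse the compiled graph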

@@ -27,8 +27,7 @@ class _TransformInfo(NamedTuple):
     logical_shape: List[int]
 
 
-@lru_cache(maxsize=None)
-def _gen_transform_infos(
+def _gen_transform_infos_non_cached(
     src_spec: DTensorSpec,
     dst_spec: DTensorSpec,
 ) -> List[_TransformInfo]:
@@ -146,6 +145,14 @@ def _gen_transform_infos(
     return transform_infos
 
 
+@lru_cache(maxsize=None)
+def _gen_transform_infos(
+    src_spec: DTensorSpec,
+    dst_spec: DTensorSpec,
+) -> List[_TransformInfo]:
+    return _gen_transform_infos_non_cached(src_spec, dst_spec)
+
+
 def redistribute_local_tensor(
     local_tensor: torch.Tensor,
     current_spec: DTensorSpec,
@@ -174,7 +181,13 @@ def redistribute_local_tensor(
         # which should be an empty tensor
         return local_tensor
 
-    transform_infos = _gen_transform_infos(current_spec, target_spec)
+    has_symints = any(isinstance(s, torch.SymInt) for s in current_spec.shape) or any(
+        isinstance(s, torch.SymInt) for s in target_spec.shape
+    )
+    if has_symints:
+        transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec)
+    else:
+        transform_infos = _gen_transform_infos(current_spec, target_spec)
 
     for transform_info in transform_infos:
         i = transform_info.mesh_dim

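The check added to redistribute_local_tensor reads: if either spec's shape carries a symbolic dimension, build the transform infos without going through the lru_cache, since hashing a DTensorSpec whose sizes include torch.SymInt is exactly what this PR is trying to avoid. The same condition as a standalone helper (a sketch, not code from this PR):

    import torch
    from typing import Sequence

    def shape_has_symints(shape: Sequence[object]) -> bool:
        # True if any size is a torch.SymInt, i.e. it came from dynamic-shape
        # tracing instead of being a concrete Python int.
        return any(isinstance(dim, torch.SymInt) for dim in shape)

    # Eager tensors always carry concrete int sizes, so they keep hitting the cache.
    assert not shape_has_symints(torch.rand(4, 4).shape)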

@@ -104,8 +104,7 @@ class ShardingPropagator:
         if schema_info is not None:
             self.op_to_schema_info[op_overload] = schema_info
 
-    @lru_cache  # noqa: B019
-    def _propagate_tensor_meta(
+    def _propagate_tensor_meta_non_cached(
         self, op_schema: OpSchema
     ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
         """
@@ -150,6 +149,12 @@ class ShardingPropagator:
             # if fake is not a tensor or tuple of tensor, return as none
             return None
 
+    @lru_cache  # noqa: B019
+    def _propagate_tensor_meta(
+        self, op_schema: OpSchema
+    ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
+        return self._propagate_tensor_meta_non_cached(op_schema)
+
     def _wrap_output_spec_tensor_meta(
         self,
         op: OpOverload,
@@ -211,7 +216,7 @@
         if op_schema.op is aten._local_scalar_dense.default:
             return OutputSharding(None, op_schema)
 
-        out_tensor_meta = self._propagate_tensor_meta(op_schema)
+        out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
 
         def spec_to_strategy(spec: object) -> object:
             if isinstance(spec, DTensorSpec):
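The ShardingPropagator change mirrors the redistribute one: _propagate_tensor_meta keeps its name and its lru_cache but now just delegates, while the call site shown in the last hunk uses the non-cached variant directly, so an OpSchema holding SymInts is never used as a cache key. The same delegation shape reduced to a toy class (a sketch, not the real ShardingPropagator API):

    from functools import lru_cache

    class Propagator:
        def _meta_non_cached(self, key):
            # The real work; never needs to hash its argument.
            return ("meta-for", key)

        # lru_cache on a method keeps self alive inside the cache (flake8 B019);
        # that is deliberate here, mirroring the noqa in the diff above.
        @lru_cache  # noqa: B019
        def _meta(self, key):
            # Cached entry point, only for callers whose keys are safely hashable.
            return self._meta_non_cached(key)

        def propagate(self, key, hashable):
            # Callers that might hold symbolic or otherwise unhashable keys take
            # the non-cached path; everyone else keeps the cache.
            return self._meta(key) if hashable else self._meta_non_cached(key)

    p = Propagator()
    print(p.propagate("add", hashable=True))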