DTensor: dont hash symint tensor input in propagate_tensor_meta (#136266)
This fixes a subset of issues for dynamic shapes + DTensor. It is still easy to run into other issues; it's likely that https://github.com/pytorch/pytorch/pull/125941 needs to land for DTensor + dynamic shapes to work more generally. I ended up writing a test that has dynamic shape inputs but not dynamic shape outputs in order to properly test this fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136266
Approved by: https://github.com/ezyang, https://github.com/yf225
This commit is contained in:
parent 7bbdf87517
commit 172ecf78b7
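For context, the pattern this commit applies in both the transform-info generation and the sharding propagator below can be sketched in isolation: keep an uncached implementation, expose an lru_cache wrapper for the common static-shape path, and route any input whose shape contains a torch.SymInt around the cache so symbolic sizes are never hashed. This is a minimal sketch; the names compute_plan / _compute_plan_non_cached are illustrative and not part of the change.

from functools import lru_cache
from typing import Tuple

import torch


def _compute_plan_non_cached(shape: Tuple[object, ...]) -> int:
    # Stand-in for an expensive, shape-dependent computation
    # (e.g. building redistribute transform infos).
    return len(shape)


@lru_cache(maxsize=None)
def _compute_plan(shape: Tuple[object, ...]) -> int:
    # Cached wrapper: lru_cache hashes its arguments, which is only
    # safe when every size is a plain int.
    return _compute_plan_non_cached(shape)


def compute_plan(shape: Tuple[object, ...]) -> int:
    # Bypass the cache whenever a dimension is symbolic, mirroring the
    # has_symints check introduced in redistribute_local_tensor below.
    if any(isinstance(s, torch.SymInt) for s in shape):
        return _compute_plan_non_cached(shape)
    return _compute_plan(shape)

The design choice mirrored here is to keep the cached entry point for the static-shape fast path while giving symbolic-shape callers an explicit uncached route, rather than dropping the cache entirely.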
@@ -176,6 +176,25 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
         res = opt_fn(x)
         self.assertEqual(res, ref)
 
+    def test_dtensor_dynamic(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        # test passing in DTensor as inputs/outputs and run some tensor computation
+        def fn(x):
+            return (
+                torch.mul(x, x)
+                .redistribute(device_mesh=x.device_mesh, placements=[Replicate()])
+                .to_local()[0]
+            )
+
+        x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False)
+        torch._dynamo.mark_dynamic(x, 0)
+        ref = fn(x)
+
+        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
+        res = opt_fn(x)
+        self.assertEqual(res, ref)
+
     def test_dtensor_attribute_access_on_intermediate(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
 
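The new test relies on torch._dynamo.mark_dynamic to make the DTensor input dynamic along dim 0. As a minimal single-process sketch of that mechanism (plain tensors only, no DeviceMesh or DTensor, so it runs without initializing a process group):

import torch


def fn(x):
    # Simple shape-dependent computation; with dim 0 marked dynamic,
    # torch.compile traces that size as a SymInt instead of specializing on 4.
    return (x * x).sum(dim=0)


x = torch.rand(4, 4)
torch._dynamo.mark_dynamic(x, 0)
ref = fn(x)

opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
res = opt_fn(x)
assert torch.allclose(res, ref)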
@@ -27,8 +27,7 @@ class _TransformInfo(NamedTuple):
     logical_shape: List[int]
 
 
-@lru_cache(maxsize=None)
-def _gen_transform_infos(
+def _gen_transform_infos_non_cached(
     src_spec: DTensorSpec,
     dst_spec: DTensorSpec,
 ) -> List[_TransformInfo]:
@@ -146,6 +145,14 @@ def _gen_transform_infos(
     return transform_infos
 
 
+@lru_cache(maxsize=None)
+def _gen_transform_infos(
+    src_spec: DTensorSpec,
+    dst_spec: DTensorSpec,
+) -> List[_TransformInfo]:
+    return _gen_transform_infos_non_cached(src_spec, dst_spec)
+
+
 def redistribute_local_tensor(
     local_tensor: torch.Tensor,
     current_spec: DTensorSpec,
@@ -174,7 +181,13 @@ def redistribute_local_tensor(
         # which should be an empty tensor
         return local_tensor
 
-    transform_infos = _gen_transform_infos(current_spec, target_spec)
+    has_symints = any(isinstance(s, torch.SymInt) for s in current_spec.shape) or any(
+        isinstance(s, torch.SymInt) for s in target_spec.shape
+    )
+    if has_symints:
+        transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec)
+    else:
+        transform_infos = _gen_transform_infos(current_spec, target_spec)
 
     for transform_info in transform_infos:
         i = transform_info.mesh_dim
@@ -104,8 +104,7 @@ class ShardingPropagator:
         if schema_info is not None:
             self.op_to_schema_info[op_overload] = schema_info
 
-    @lru_cache  # noqa: B019
-    def _propagate_tensor_meta(
+    def _propagate_tensor_meta_non_cached(
         self, op_schema: OpSchema
     ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
         """
@@ -150,6 +149,12 @@ class ShardingPropagator:
             # if fake is not a tensor or tuple of tensor, return as none
             return None
 
+    @lru_cache  # noqa: B019
+    def _propagate_tensor_meta(
+        self, op_schema: OpSchema
+    ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
+        return self._propagate_tensor_meta_non_cached(op_schema)
+
     def _wrap_output_spec_tensor_meta(
         self,
         op: OpOverload,
@@ -211,7 +216,7 @@ class ShardingPropagator:
         if op_schema.op is aten._local_scalar_dense.default:
             return OutputSharding(None, op_schema)
 
-        out_tensor_meta = self._propagate_tensor_meta(op_schema)
+        out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
 
         def spec_to_strategy(spec: object) -> object:
             if isinstance(spec, DTensorSpec):