Ensure outer aliasing on DTensor matches inner aliasing (#158954)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158954
Approved by: https://github.com/albanD, https://github.com/wconstab
Edward Z. Yang authored 2025-08-12 06:23:03 -07:00; committed by PyTorch MergeBot
parent ee9f8ba11d · commit 3cec82a7e9
2 changed files with 14 additions and 2 deletions
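In short: when a view op (e.g. aten.view) is dispatched on a DTensor, the local (inner) tensors correctly alias each other, but the returned outer DTensor wrapper previously did not alias its input. The dispatcher now routes view ops through return_and_correct_aliasing so the outer wrappers alias consistently with the inner tensors. A minimal sketch of the intended behavior, not part of the commit (assumes torch.distributed is already initialized for a single rank, e.g. via torchrun):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate

mesh = init_device_mesh("cpu", (1,))
x = distribute_tensor(torch.randn(4, 4), mesh, [Replicate()])

v = x.view(16)  # aten.view: its schema carries a non-write alias annotation

# The local shards always aliased: v's local tensor is a view of x's.
assert v.to_local().data_ptr() == x.to_local().data_ptr()

# After this change, the outer DTensor `v` is also fixed up to alias `x`
# at the wrapper level, so autograd and functionalization observe the
# same aliasing outside the subclass as inside it.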


@@ -23,6 +23,7 @@ from torch.distributed.tensor._tp_conv import (
 )
 from torch.distributed.tensor._utils import try_find_mesh_from_args
 from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
+from torch.utils._python_dispatch import return_and_correct_aliasing


 try:
@@ -164,7 +165,8 @@ class OpDispatcher:
         assert output_sharding is not None, "output sharding should not be None"

         mesh = op_info.compute_mesh
-        if mesh.get_coordinate() is not None:
+        participating = mesh.get_coordinate() is not None
+        if participating:
             # computation that happens in the current rank of the mesh, normal case
             if output_sharding.needs_redistribute:
                 # If sharding propagation decision needs redistribute, perform redistribute
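A note on the hoisted flag: DeviceMesh.get_coordinate() returns this rank's coordinate within the mesh, or None when the rank is not part of the mesh, so `participating` simply means "this rank takes part in the computation". A toy illustration, not from the commit (assumes a 2-rank gloo world; the mesh below covers only global rank 0):

from torch.distributed.device_mesh import DeviceMesh

mesh = DeviceMesh("cpu", [0])  # 1-D mesh containing only global rank 0
participating = mesh.get_coordinate() is not None
# On rank 0: get_coordinate() == [0], so participating is True.
# On rank 1: get_coordinate() is None, so participating is False,
# and the new code skips the outer-aliasing fixup on such ranks.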
@@ -299,7 +301,11 @@ class OpDispatcher:
             assert len(out_dts) >= 1, "out variant should have at least one out arg"
             return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
         else:
-            return self.wrap(local_results, output_sharding.output_spec)  # type: ignore[possibly-undefined]
+            ret = self.wrap(local_results, output_sharding.output_spec)  # type: ignore[possibly-undefined]
+            if participating and op_info.schema.is_view_op():
+                return return_and_correct_aliasing(op_call, args, kwargs, ret)
+            else:
+                return ret

     @staticmethod
     def redistribute_local_args(
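The call follows the standard pattern for wrapper tensor subclasses: run the op on the inner tensors, rewrap the results, then let return_and_correct_aliasing fix up the outer outputs (storage aliasing for view ops, plus correct out= return semantics) to match the op's schema. A stripped-down sketch of that pattern, not DTensor's actual implementation; `InnerWrapper` and its fields are hypothetical:

import torch
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import return_and_correct_aliasing


class InnerWrapper(torch.Tensor):
    """Toy wrapper subclass holding one inner tensor."""

    @staticmethod
    def __new__(cls, inner):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )

    def __init__(self, inner):
        self.inner = inner

    # Flatten/unflatten hooks make this a traceable wrapper subclass,
    # which the aliasing fixup relies on.
    def __tensor_flatten__(self):
        return ["inner"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        return InnerWrapper(inner_tensors["inner"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(t):
            return t.inner if isinstance(t, InnerWrapper) else t

        def wrap(t):
            return InnerWrapper(t) if isinstance(t, torch.Tensor) else t

        raw = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs))
        out = pytree.tree_map(wrap, raw)
        # Without this, a view op would return an outer wrapper that does
        # not alias its input, even though the inner tensors do alias.
        return return_and_correct_aliasing(func, args, kwargs, out)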


@@ -450,6 +450,12 @@ class OpSchema:
         # be entirely correct, but it's good enough for now.
         return "out" in self.op._schema.overload_name

+    def is_view_op(self) -> bool:
+        return any(
+            a.alias_info is not None and not a.alias_info.is_write
+            for a in self.op._schema.arguments
+        )
+
     def __hash__(self) -> int:
         # Only hash args and kwargs that op indicates to hash
         if not self.schema_info:
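The new is_view_op() relies on the schema's alias annotations: a view op like aten::view has an argument annotated Tensor(a) (a non-write alias), while an in-place op like aten::add_ carries a write alias, Tensor(a!). A quick standalone check of the predicate against a few aten schemas (illustrative only, not part of the commit):

import torch

def is_view_like(schema):
    # Same predicate as OpSchema.is_view_op above.
    return any(
        a.alias_info is not None and not a.alias_info.is_write
        for a in schema.arguments
    )

print(is_view_like(torch.ops.aten.view.default._schema))  # True:  view(Tensor(a) self, ...) -> Tensor(a)
print(is_view_like(torch.ops.aten.add.Tensor._schema))    # False: no alias annotations
print(is_view_like(torch.ops.aten.add_.Tensor._schema))   # False: add_'s Tensor(a!) is a write alias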