mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Ensure outer aliasing on DTensor matches inner aliasing (#158954)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/158954 Approved by: https://github.com/albanD, https://github.com/wconstab
This commit is contained in:
parent
ee9f8ba11d
commit
3cec82a7e9
|
|
@ -23,6 +23,7 @@ from torch.distributed.tensor._tp_conv import (
|
|||
)
|
||||
from torch.distributed.tensor._utils import try_find_mesh_from_args
|
||||
from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
|
||||
from torch.utils._python_dispatch import return_and_correct_aliasing
|
||||
|
||||
|
||||
try:
|
||||
|
|
@ -164,7 +165,8 @@ class OpDispatcher:
|
|||
assert output_sharding is not None, "output sharding should not be None"
|
||||
|
||||
mesh = op_info.compute_mesh
|
||||
if mesh.get_coordinate() is not None:
|
||||
participating = mesh.get_coordinate() is not None
|
||||
if participating:
|
||||
# computation that happens in the current rank of the mesh, normal case
|
||||
if output_sharding.needs_redistribute:
|
||||
# If sharding propagation decision needs redistribute, perform redistribute
|
||||
|
|
@ -299,7 +301,11 @@ class OpDispatcher:
|
|||
assert len(out_dts) >= 1, "out variant should have at least one out arg"
|
||||
return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
|
||||
else:
|
||||
return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined]
|
||||
ret = self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined]
|
||||
if participating and op_info.schema.is_view_op():
|
||||
return return_and_correct_aliasing(op_call, args, kwargs, ret)
|
||||
else:
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def redistribute_local_args(
|
||||
|
|
|
|||
|
|
@ -450,6 +450,12 @@ class OpSchema:
|
|||
# be entirely correct, but it's good enough for now.
|
||||
return "out" in self.op._schema.overload_name
|
||||
|
||||
def is_view_op(self) -> bool:
|
||||
return any(
|
||||
a.alias_info is not None and not a.alias_info.is_write
|
||||
for a in self.op._schema.arguments
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
# Only hash args and kwargs that op indicates to hash
|
||||
if not self.schema_info:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user