Ensure outer aliasing on DTensor matches inner aliasing (#158954)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158954
Approved by: https://github.com/albanD, https://github.com/wconstab
This commit is contained in:
parent ee9f8ba11d
commit 3cec82a7e9
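In effect: a view op's local (inner) result aliases the local input tensor, so the DTensor (outer) wrapper returned from dispatch should report the same aliasing, which the dispatcher change below wires up via return_and_correct_aliasing. A plain-tensor sketch of the aliasing relationship in question (illustrative only, not DTensor code):

import torch

# A view op returns an alias of its input: same storage, new sizes/strides.
x = torch.randn(4, 4)
y = x.view(16)
assert y._base is x
assert y.untyped_storage().data_ptr() == x.untyped_storage().data_ptr()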
@@ -23,6 +23,7 @@ from torch.distributed.tensor._tp_conv import (
 )
 from torch.distributed.tensor._utils import try_find_mesh_from_args
 from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
+from torch.utils._python_dispatch import return_and_correct_aliasing
 
 
 try:
@@ -164,7 +165,8 @@ class OpDispatcher:
         assert output_sharding is not None, "output sharding should not be None"
 
         mesh = op_info.compute_mesh
-        if mesh.get_coordinate() is not None:
+        participating = mesh.get_coordinate() is not None
+        if participating:
             # computation that happens in the current rank of the mesh, normal case
             if output_sharding.needs_redistribute:
                 # If sharding propagation decision needs redistribute, perform redistribute
@@ -299,7 +301,11 @@ class OpDispatcher:
             assert len(out_dts) >= 1, "out variant should have at least one out arg"
             return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
         else:
-            return self.wrap(local_results, output_sharding.output_spec)  # type: ignore[possibly-undefined]
+            ret = self.wrap(local_results, output_sharding.output_spec)  # type: ignore[possibly-undefined]
+            if participating and op_info.schema.is_view_op():
+                return return_and_correct_aliasing(op_call, args, kwargs, ret)
+            else:
+                return ret
 
     @staticmethod
     def redistribute_local_args(
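return_and_correct_aliasing (from torch.utils._python_dispatch) is the utility wrapper tensor subclasses use to make the outer tensor's aliasing agree with what the op schema declares, given the op, its original wrapped args/kwargs, and the freshly wrapped result. A minimal toy wrapper-subclass sketch of the same pattern, separate from DTensor (the InnerWrapper class and its field names are hypothetical):

import torch
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import return_and_correct_aliasing


class InnerWrapper(torch.Tensor):
    # Toy wrapper subclass holding a single inner tensor (hypothetical example,
    # not the DTensor implementation).
    @staticmethod
    def __new__(cls, inner):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            inner.shape,
            strides=inner.stride(),
            storage_offset=inner.storage_offset(),
            dtype=inner.dtype,
            device=inner.device,
            requires_grad=inner.requires_grad,
        )

    def __init__(self, inner):
        self.inner = inner

    def __repr__(self):
        return f"InnerWrapper({self.inner!r})"

    def __tensor_flatten__(self):
        return ["inner"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        return InnerWrapper(inner_tensors["inner"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrap = lambda t: t.inner if isinstance(t, InnerWrapper) else t
        out = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs))
        wrapped = pytree.tree_map(
            lambda t: InnerWrapper(t) if isinstance(t, torch.Tensor) else t, out
        )
        # Fix up the outer wrapper so its storage aliasing matches the aliasing
        # declared in func's schema (the same call the dispatcher makes above).
        return return_and_correct_aliasing(func, args, kwargs, wrapped)


x = InnerWrapper(torch.randn(4, 4))
y = x.view(16)
# Inner aliasing: the view's inner tensor is a view of x's inner tensor.
assert y.inner._base is x.inner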
@@ -450,6 +450,12 @@ class OpSchema:
         # be entirely correct, but it's good enough for now.
         return "out" in self.op._schema.overload_name
 
+    def is_view_op(self) -> bool:
+        return any(
+            a.alias_info is not None and not a.alias_info.is_write
+            for a in self.op._schema.arguments
+        )
+
     def __hash__(self) -> int:
         # Only hash args and kwargs that op indicates to hash
         if not self.schema_info:
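The new OpSchema.is_view_op() keys off the alias annotations in the operator schema: an argument whose alias_info is present but not a write (schema notation "(a)" rather than "(a!)") marks a view/alias op. A quick way to inspect the data it reads, using aten.view as an example (illustrative, not part of the commit):

import torch

schema = torch.ops.aten.view.default._schema
for a in schema.arguments:
    # For aten.view, `self` carries a non-write alias annotation; in-place
    # and out= ops carry write annotations (is_write=True) instead.
    print(a.name, a.alias_info)

is_view = any(
    a.alias_info is not None and not a.alias_info.is_write
    for a in schema.arguments
)
print(is_view)  # True for aten.view.default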