[DTensor] Check if tracing for sharding propagation to handle unhashable keys (#160798)

Fixes #159590

This is similar to the reverted commit #156868, except it resolves an issue where two caches could become misaligned, producing incorrect objects for stateful placements (e.g. `_MaskPartial`), as reported in issue #159601. This adds little to no overhead in eager ([see past benchmarks](https://github.com/pytorch/pytorch/pull/156868#issuecomment-3047831149)).

This also handles cases such as #159590, where dynamo is disabled during tracing, by entering the Python Dispatcher ahead of sharding propagation during compile. Tests are added/modified to cover these cases, as well as list/tuple inputs to the cat op.
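Condensed, the gating introduced here looks roughly like the sketch below (the wrapper name `choose_propagation_path` is made up for illustration; `_are_we_tracing` and the `ShardingPropagator` methods are the ones this PR actually touches):

```python
from torch.distributed._functional_collectives import _are_we_tracing

def choose_propagation_path(propagator, op_schema):
    # Illustrative wrapper only. During compile, fake tensor mode and/or the
    # Python Dispatcher are active, so _are_we_tracing() returns True and we
    # take the non-cached path; unhashable SymInt shapes and stateful
    # placements such as _MaskPartial never reach the hash-based cache.
    if _are_we_tracing():
        return propagator.propagate_op_sharding_non_cached(op_schema)
    # Eager: keep using the cached fast path.
    return propagator.propagate_op_sharding(op_schema)
```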

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160798
Approved by: https://github.com/bdhirsh
Arsh Zahed 2025-09-09 03:52:05 +00:00 committed by PyTorch MergeBot
parent 1641606aa4
commit 4c45090cf7
6 changed files with 119 additions and 17 deletions


@ -20,7 +20,15 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
+from torch.distributed.tensor import (
+    DeviceMesh,
+    distribute_module,
+    distribute_tensor,
+    DTensor,
+    Partial,
+    Replicate,
+    Shard,
+)
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.parallel import (
ColwiseParallel,
@ -88,6 +96,33 @@ aot_eager_graph = aot_autograd(
)
def _apply_sharding(mod: nn.Module, shard_dim: int, device_mesh: DeviceMesh):
"""
Shards on the given dimension if possible, else replicate
Args:
mod: (nn.Module) Module to shard or replicate
shard_dim: (int) Dimension to shard on if possible
device_mesh: (DeviceMesh) 1D Device Mesh
Returns:
Sharded DTensor
"""
def shard_module_params(name, module, device_mesh):
for name, param in module.named_parameters():
placement = Replicate()
if shard_dim < len(param.size()):
placement = Shard(shard_dim)
dist_param = torch.nn.Parameter(
distribute_tensor(param, device_mesh, [placement])
)
name = name.split(".")[-1]
module.register_parameter(name, dist_param)
sharded_mod = distribute_module(mod, device_mesh, shard_module_params)
return sharded_mod
class TestDTensorCompile(torch._dynamo.test_case.TestCase):
def setUp(self):
super(
@ -167,6 +202,8 @@ def forward(self, b_buffer, x):
return (view_as_1,)""", # noqa: B950
)
# During tracing, sharding propagation cache is skipped, so an extra dry run for
# add is performed in _propagate_tensor_meta_non_cached, hence add_1 instead of add
self.assertExpectedInline(
str(ep.run_decompositions({}).graph_module.code).strip(),
"""\
@ -174,8 +211,8 @@ def forward(self, b_parametrizations_buffer_original0, x):
_assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None
_to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None
view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None
-add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
-view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None
+add_1 = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
+view_1 = torch.ops.aten.view.default(add_1, [4, 4]); add_1 = None
return (view_1,)""", # noqa: B950
)
@ -269,7 +306,9 @@ def forward(self, b_parametrizations_buffer_original0, x):
.to_local()[0]
)
-x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False)
+x = DTensor.from_local(
+    torch.rand(4, 4, requires_grad=True), mesh, [Shard(0)], run_check=False
+)
torch._dynamo.mark_dynamic(x, 0)
ref = fn(x)
@ -290,7 +329,9 @@ def forward(self, b_parametrizations_buffer_original0, x):
for t in torch.tensor_split(x, 2)
]
-x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False)
+x = DTensor.from_local(
+    torch.rand(4, 4, requires_grad=True), mesh, [Shard(0)], run_check=False
+)
ref = fn(x)
opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True, dynamic=True)
@ -317,6 +358,30 @@ def forward(self, b_parametrizations_buffer_original0, x):
res = opt_fn(x)
self.assertEqual(res, ref)
def test_dtensor_dynamic_cat(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
# test passing in a tuple of DTensors as input to torch.cat
def fn(x, y):
return (
torch.cat((x, y), dim=0)
.redistribute(device_mesh=x.device_mesh, placements=[Replicate()])
.to_local()[0]
)
x = DTensor.from_local(
torch.rand(4, 4, requires_grad=True), mesh, [Shard(0)], run_check=False
)
y = DTensor.from_local(
torch.rand(4, 4, requires_grad=True), mesh, [Shard(0)], run_check=False
)
torch._dynamo.mark_dynamic(x, 0)
ref = fn(x, y)
opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
res = opt_fn(x, y)
self.assertEqual(res, ref)
def test_dtensor_attribute_access_on_intermediate(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@ -1150,6 +1215,29 @@ class TestDTensorCompileE2E(DTensorTestBase):
self.assertEqual(x_ref.grad, x.grad)
self.assertEqual(y_ref.grad, y.grad)
@with_comms
def test_compile_embedding_redistribute(self):
mesh = self.build_device_mesh()
class Network(nn.Module):
def __init__(self, embedding, mesh):
super().__init__()
self.mesh = mesh
self.embedding = _apply_sharding(embedding, 0, self.mesh)
def forward(self, x):
x = self.embedding(x)
x = x.redistribute(self.mesh, [Shard(1)])
return x
embedding = torch.nn.Embedding(10, 20, device=self.device_type)
inp = torch.randint(0, 10, (8,), device=self.device_type)
ref_out = embedding(inp)
sharded_net = torch.compile(Network(embedding, mesh))
replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
output = sharded_net(replicated_inp)
self.assertEqual(output.full_tensor(), ref_out)
if __name__ == "__main__":
run_tests()


@ -25,6 +25,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.external_utils import (
call_accumulate_grad,
call_backward,
@ -344,6 +345,10 @@ class AutogradCompilerInstance:
self.stack.enter_context(preserve_node_meta())
inputs_origins, sizes_origins, scalars_origins = origins
# Turn on PythonDispatcher during initial trace to make it identifiable
# that tracing is happening, which is needed to prevent hashing symints
self.stack.enter_context(enable_python_dispatcher())
# tensor inputs to fake tensors
x = inputs[0] # mypy will complain about unbound x
try:


@ -44,6 +44,7 @@ import sympy
import torch
from torch import SymInt
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import (
get_metrics_context,
is_int_specialization_case,
@ -3497,6 +3498,12 @@ def wrap_to_fake_tensor_and_record(
type(e),
)
# Note [enable_python_dispatcher in dynamo]
# Dynamo disables itself when it runs fake tensor prop, which means that tensor subclasses
# have no way to know (purely based off of global state) if they are currently being run under compile or not.
# we use enable_python_dispatcher mainly to tweak the DispatchKeyState so that subclass authors
# can check it to know if they are running in an eager context or not
with enable_python_dispatcher():
fake_e = wrap_fake_exception(
lambda: tx.fake_mode.from_tensor(
e,

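For illustration, the signal this note relies on can be observed standalone (a sketch, not part of the diff; the helper name `python_dispatcher_active` is invented, but the TLS check is the same one added to `_are_we_tracing` below): entering `enable_python_dispatcher()` flips the dispatch-key bit that downstream code, such as DTensor's sharding propagation, can read even when dynamo is disabled.

```python
import torch
from torch._dispatch.python import enable_python_dispatcher

def python_dispatcher_active() -> bool:
    # Same TLS check that this PR adds to _are_we_tracing()
    return torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.PythonDispatcher
    )

print(python_dispatcher_active())      # False in plain eager mode
with enable_python_dispatcher():
    print(python_dispatcher_active())  # True while the python dispatcher is enabled
print(python_dispatcher_active())      # False again
```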

@ -819,6 +819,11 @@ def _are_we_tracing() -> bool:
# If fake mode is turned on, we are almost definitely compiling/tracing.
if torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is not None:
return True
# See Note [enable_python_dispatcher in dynamo]
if torch._C._dispatch_tls_is_dispatch_key_included(
torch._C.DispatchKey.PythonDispatcher
):
return True
return get_proxy_mode() is not None


@ -7,6 +7,7 @@ from typing import cast, NamedTuple, Optional
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._api as dtensor
from torch.distributed._functional_collectives import _are_we_tracing
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
@ -181,10 +182,7 @@ def redistribute_local_tensor(
# which should be an empty tensor
return local_tensor
-has_symints = any(isinstance(s, torch.SymInt) for s in current_spec.shape) or any(
-    isinstance(s, torch.SymInt) for s in target_spec.shape
-)
-if has_symints:
+if _are_we_tracing():
transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec)
else:
transform_infos = _gen_transform_infos(current_spec, target_spec)


@ -320,7 +320,7 @@ class ShardingPropagator:
# because SymInts are not hashable.
# This is generally ok because this only happens during tracing in torch.compile,
# and tracing does not need to be as fast as eagermode DTensor usages.
-if op_info.schema.has_symints:
+if _are_we_tracing():
output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)
else:
output_sharding = cast(
@ -338,7 +338,6 @@ class ShardingPropagator:
return OutputSharding(None, op_schema)
out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
if op_schema.op in self.op_strategy_funcs:
# wrap the op_schema with op strategy for sharding strategy propagation
strategy_schema = self._wrap_with_op_strategy(op_schema)