[DTensor] Check if tracing for sharding propagation to handle unhashable keys (#160798)
Fixes #159590

This is similar to the reverted commit #156868, except it resolves an issue with two caches becoming misaligned, which led to incorrect objects for stateful placements (e.g. `_MaskPartial`) as in issue #159601. This adds little to no overhead in eager ([see past benchmarks](https://github.com/pytorch/pytorch/pull/156868#issuecomment-3047831149)). It also handles cases such as #159590, where dynamo is disabled during tracing, by entering the Python Dispatcher ahead of sharding propagation during compile. Tests are added/modified to cover these cases, as well as list/tuple inputs with the cat op.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160798
Approved by: https://github.com/bdhirsh
This commit is contained in:
parent 1641606aa4
commit 4c45090cf7
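For orientation, here is a minimal, self-contained sketch of the tracing check the diff below builds on. The name are_we_tracing_sketch is illustrative only; the torch._C calls it wraps are the ones the diff itself uses, and the behavior inside enable_python_dispatcher follows from the change to _are_we_tracing further down.

import torch
from torch._dispatch.python import enable_python_dispatcher


def are_we_tracing_sketch() -> bool:
    # Fake tensor mode active => almost certainly compiling/tracing.
    if torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is not None:
        return True
    # The compile path now enters the Python Dispatcher ahead of sharding
    # propagation, so this dispatch key doubles as a "tracing" signal.
    return torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.PythonDispatcher
    )


print(are_we_tracing_sketch())      # False in plain eager code
with enable_python_dispatcher():
    print(are_we_tracing_sketch())  # True once the dispatcher is entered

Once a caller can tell it is tracing, it can skip the sharding-propagation caches, whose keys may contain unhashable SymInts or list/tuple arguments (as with cat).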
@@ -20,7 +20,15 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
 )
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
+from torch.distributed.tensor import (
+    DeviceMesh,
+    distribute_module,
+    distribute_tensor,
+    DTensor,
+    Partial,
+    Replicate,
+    Shard,
+)
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -88,6 +96,33 @@ aot_eager_graph = aot_autograd(
 )


+def _apply_sharding(mod: nn.Module, shard_dim: int, device_mesh: DeviceMesh):
+    """
+    Shards on the given dimension if possible, else replicate
+    Args:
+        mod: (nn.Module) Module to shard or replicate
+        shard_dim: (int) Dimension to shard on if possible
+        device_mesh: (DeviceMesh) 1D Device Mesh
+
+    Returns:
+        Sharded DTensor
+    """
+
+    def shard_module_params(name, module, device_mesh):
+        for name, param in module.named_parameters():
+            placement = Replicate()
+            if shard_dim < len(param.size()):
+                placement = Shard(shard_dim)
+            dist_param = torch.nn.Parameter(
+                distribute_tensor(param, device_mesh, [placement])
+            )
+            name = name.split(".")[-1]
+            module.register_parameter(name, dist_param)
+
+    sharded_mod = distribute_module(mod, device_mesh, shard_module_params)
+    return sharded_mod
+
+
 class TestDTensorCompile(torch._dynamo.test_case.TestCase):
     def setUp(self):
         super(
@@ -167,6 +202,8 @@ def forward(self, b_buffer, x):
     return (view_as_1,)""",  # noqa: B950
         )

+        # During tracing, sharding propagation cache is skipped, so an extra dry run for
+        # add is performed in _propagate_tensor_meta_non_cached, hence add_1 instead of add
         self.assertExpectedInline(
             str(ep.run_decompositions({}).graph_module.code).strip(),
             """\
@@ -174,8 +211,8 @@ def forward(self, b_parametrizations_buffer_original0, x):
     _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None
     _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None
     view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None
-    add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
-    view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None
+    add_1 = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
+    view_1 = torch.ops.aten.view.default(add_1, [4, 4]); add_1 = None
     return (view_1,)""",  # noqa: B950
         )

@@ -269,7 +306,9 @@ def forward(self, b_parametrizations_buffer_original0, x):
                 .to_local()[0]
             )

-        x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False)
+        x = DTensor.from_local(
+            torch.rand(4, 4, requires_grad=True), mesh, [Shard(0)], run_check=False
+        )
         torch._dynamo.mark_dynamic(x, 0)
         ref = fn(x)

@@ -290,7 +329,9 @@ def forward(self, b_parametrizations_buffer_original0, x):
                 for t in torch.tensor_split(x, 2)
             ]

-        x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False)
+        x = DTensor.from_local(
+            torch.rand(4, 4, requires_grad=True), mesh, [Shard(0)], run_check=False
+        )
         ref = fn(x)

         opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True, dynamic=True)
@@ -317,6 +358,30 @@ def forward(self, b_parametrizations_buffer_original0, x):
         res = opt_fn(x)
         self.assertEqual(res, ref)

+    def test_dtensor_dynamic_cat(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        # test passing in tuple of DTensors as
+        def fn(x, y):
+            return (
+                torch.cat((x, y), dim=0)
+                .redistribute(device_mesh=x.device_mesh, placements=[Replicate()])
+                .to_local()[0]
+            )
+
+        x = DTensor.from_local(
+            torch.rand(4, 4, requires_grad=True), mesh, [Shard(0)], run_check=False
+        )
+        y = DTensor.from_local(
+            torch.rand(4, 4, requires_grad=True), mesh, [Shard(0)], run_check=False
+        )
+        torch._dynamo.mark_dynamic(x, 0)
+        ref = fn(x, y)
+
+        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
+        res = opt_fn(x, y)
+        self.assertEqual(res, ref)
+
     def test_dtensor_attribute_access_on_intermediate(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

@@ -1150,6 +1215,29 @@ class TestDTensorCompileE2E(DTensorTestBase):
         self.assertEqual(x_ref.grad, x.grad)
         self.assertEqual(y_ref.grad, y.grad)

+    @with_comms
+    def test_compile_embedding_redistribute(self):
+        mesh = self.build_device_mesh()
+
+        class Network(nn.Module):
+            def __init__(self, embedding, mesh):
+                super().__init__()
+                self.mesh = mesh
+                self.embedding = _apply_sharding(embedding, 0, self.mesh)
+
+            def forward(self, x):
+                x = self.embedding(x)
+                x = x.redistribute(self.mesh, [Shard(1)])
+                return x
+
+        embedding = torch.nn.Embedding(10, 20, device=self.device_type)
+        inp = torch.randint(0, 10, (8,), device=self.device_type)
+        ref_out = embedding(inp)
+        sharded_net = torch.compile(Network(embedding, mesh))
+        replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
+        output = sharded_net(replicated_inp)
+        self.assertEqual(output.full_tensor(), ref_out)
+

 if __name__ == "__main__":
     run_tests()
@@ -25,6 +25,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union

 import torch
 import torch.utils._pytree as pytree
+from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.external_utils import (
     call_accumulate_grad,
     call_backward,
@@ -344,6 +345,10 @@ class AutogradCompilerInstance:
         self.stack.enter_context(preserve_node_meta())
         inputs_origins, sizes_origins, scalars_origins = origins

+        # Turn on PythonDispatcher during initial trace to make it identifiable
+        # that tracing is happening, which is needed to prevent hashing symints
+        self.stack.enter_context(enable_python_dispatcher())
+
         # tensor inputs to fake tensors
         x = inputs[0]  # mypy will complain about unbound x
         try:
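The comment in the hunk above explains why the Python Dispatcher is turned on for the whole initial trace. Below is a small, hedged illustration of the ExitStack pattern used there; it is a standalone sketch, not compiled-autograd code, and the "expected True" note relies on the check added to _are_we_tracing later in this diff.

from contextlib import ExitStack

import torch
from torch._dispatch.python import enable_python_dispatcher

stack = ExitStack()
# Contexts pushed onto an ExitStack stay active until the stack is closed,
# so the PythonDispatcher key stays set for the lifetime of the trace
# rather than for a single `with` block.
stack.enter_context(enable_python_dispatcher())
print(
    torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.PythonDispatcher
    )
)  # expected True while the stack is open
# ... tracing work would happen here ...
stack.close()  # exits every entered context, restoring the dispatcher state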
@@ -44,6 +44,7 @@ import sympy

 import torch
 from torch import SymInt
+from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.utils import (
     get_metrics_context,
     is_int_specialization_case,
@@ -3497,6 +3498,12 @@ def wrap_to_fake_tensor_and_record(
                 type(e),
             )

+        # Note [enable_python_dispatcher in dynamo]
+        # Dynamo disables itself when it runs fake tensor prop, which means that tensor subclasses
+        # have no way to know (purely based off of global state) if they are currently being run under compile or not.
+        # we use enable_python_dispatcher mainly to tweak the DispatchKeyState so that subclass authors
+        # can check it to know if they are running in an eager context or not
+        with enable_python_dispatcher():
             fake_e = wrap_fake_exception(
                 lambda: tx.fake_mode.from_tensor(
                     e,
@@ -819,6 +819,11 @@ def _are_we_tracing() -> bool:
     # If fake mode is turned on, we are almost definitely compiling/tracing.
     if torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is not None:
         return True
+    # See Note [enable_python_dispatcher in dynamo]
+    if torch._C._dispatch_tls_is_dispatch_key_included(
+        torch._C.DispatchKey.PythonDispatcher
+    ):
+        return True
     return get_proxy_mode() is not None

@@ -7,6 +7,7 @@ from typing import cast, NamedTuple, Optional
 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.tensor._api as dtensor
+from torch.distributed._functional_collectives import _are_we_tracing
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.tensor.device_mesh import DeviceMesh
 from torch.distributed.tensor.placement_types import (
@@ -181,10 +182,7 @@ def redistribute_local_tensor(
         # which should be an empty tensor
         return local_tensor

-    has_symints = any(isinstance(s, torch.SymInt) for s in current_spec.shape) or any(
-        isinstance(s, torch.SymInt) for s in target_spec.shape
-    )
-    if has_symints:
+    if _are_we_tracing():
         transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec)
     else:
         transform_infos = _gen_transform_infos(current_spec, target_spec)
@@ -320,7 +320,7 @@ class ShardingPropagator:
         # because SymInts are not hashable.
         # This is generally ok because this only happens during tracing in torch.compile,
         # and tracing does not need to be as fast as eagermode DTensor usages.
-        if op_info.schema.has_symints:
+        if _are_we_tracing():
             output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)
         else:
             output_sharding = cast(
@@ -338,7 +338,6 @@ class ShardingPropagator:
             return OutputSharding(None, op_schema)

         out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
-
         if op_schema.op in self.op_strategy_funcs:
             # wrap the op_schema with op strategy for sharding strategy propagation
             strategy_schema = self._wrap_with_op_strategy(op_schema)
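To make the "unhashable keys" motivation concrete, here is a hedged, self-contained illustration using stand-in names (this is not the real ShardingPropagator API): functools.lru_cache raises TypeError on unhashable arguments, which is why the cached lookup must be bypassed whenever we are tracing and the schema may carry SymInts or other unhashable pieces, such as the list/tuple inputs to cat exercised in the new test.

from functools import lru_cache


class FakeSymInt:
    # Stand-in for a symbolic size; per the comment in the hunks above,
    # SymInts are not hashable.
    __hash__ = None


@lru_cache(maxsize=None)
def propagate_cached(schema):
    return ("sharding-for", schema)


def propagate_non_cached(schema):
    return ("sharding-for", schema)


def propagate(schema, tracing: bool):
    # When tracing, skip the cache instead of hitting "unhashable type".
    if tracing:
        return propagate_non_cached(schema)
    return propagate_cached(schema)


print(propagate("aten.add", tracing=False))        # served via the cache
print(propagate((FakeSymInt(), 4), tracing=True))  # cache safely skipped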