mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This PR adds experimental implicit replication support so that DTensor can interoperate with torch.Tensor: under this context manager, DTensor works together with torch.Tensor by assuming the torch.Tensor's sharding layout is replicated. Note that this is risky for DTensor, so it is not turned on by default; but in cases where the tensor is known to be replicated, users can enable it to let DTensor and Tensor computations work together. Pull Request resolved: https://github.com/pytorch/pytorch/pull/115297 Approved by: https://github.com/awgu
402 lines
16 KiB
Python
402 lines
16 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
import functools
|
|
import operator
|
|
from typing import cast, Dict, List, Optional, Sequence, Tuple
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
import torch.distributed._tensor.api as dtensor
|
|
import torch.distributed._tensor.random as random
|
|
from torch.distributed._tensor._utils import try_find_mesh_from_args
|
|
from torch.distributed._tensor.op_schema import (
|
|
_is_inplace_op,
|
|
_is_out_variant_op,
|
|
OpInfo,
|
|
OpSchema,
|
|
OutputSpecType,
|
|
)
|
|
from torch.distributed._tensor.placement_types import DTensorSpec, Replicate, TensorMeta
|
|
from torch.distributed._tensor.random import is_rng_supported_mesh
|
|
from torch.distributed._tensor.redistribute import redistribute_local_tensor
|
|
from torch.distributed._tensor.sharding_prop import ShardingPropagator
|
|
from torch.distributed._tensor.tp_conv import (
|
|
convolution_backward_handler,
|
|
convolution_handler,
|
|
)
|
|
from torch.distributed.device_mesh import DeviceMesh
|
|
|
|
try:
|
|
from torch.utils import _cxx_pytree as pytree
|
|
except ImportError:
|
|
from torch.utils import _pytree as pytree # type: ignore[no-redef]
|
|
|
|
# Shorthand for the ``torch.ops.aten`` operator namespace (e.g. ``aten.linear.default``).
aten = torch.ops.aten
|
|
|
|
|
|
def decompose_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    """
    Run ``op_call`` through its decomposition into core ATen ops.

    This handler is mostly here for inference-mode usage, where the ops that
    reach DTensor are not core ATen ops.

    Returns:
        Whatever ``op_call.decompose(*args, **kwargs)`` produced.

    Raises:
        RuntimeError: if the op reports ``NotImplemented`` for decomposition.
    """
    decomposed = op_call.decompose(*args, **kwargs)
    if decomposed is NotImplemented:
        raise RuntimeError("Decomposition failed")
    return decomposed
|
|
|
|
|
|
def is_same_size_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> bool:
    """
    Report whether the first two positional args have identical shapes.

    ``op_call`` and ``kwargs`` are unused; they are accepted only so this
    function matches the custom op-handler signature.
    """
    tensors = [cast(torch.Tensor, args[idx]) for idx in (0, 1)]
    return tensors[0].shape == tensors[1].shape
|
|
|
|
|
|
class OpDispatcher:
    """
    Op dispatching class instance to handle args/kwargs pre-processing (un-wrapping), sharding
    propagation, redistribute local args, local compute, and post-processing (re-wrapping). It
    also handles any op specific logic if necessary.

    Plain torch.Tensor inputs (positional or keyword) are accepted next to DTensors only when
    they are 0-dim, or when ``_allow_implicit_replication`` is set; they are then described by
    a fully replicated DTensorSpec on the op's mesh. Any other torch.Tensor input raises a
    RuntimeError.
    """

    def __init__(self) -> None:
        self.sharding_propagator = ShardingPropagator()
        # Random ops: when the mesh supports RNG, ``dispatch`` runs these inside
        # ``random._rng_tracker._distribute_region``.
        self._random_ops = {
            aten.native_dropout.default,
            aten.normal_.default,
            aten.rand_like.default,
            aten.randn_like.default,
            aten.randint_like.default,
            aten.randint_like.low_dtype,
            aten.randint_like.low_dtype_out,
            aten.uniform_.default,
            aten.bernoulli.default,
            aten.bernoulli_.float,
        }
        # Ops that skip sharding propagation: ``dispatch`` hands them straight to the handler.
        self._custom_op_handlers = {
            aten.linear.default: decompose_handler,
            aten.is_same_size.default: is_same_size_handler,
            aten.convolution.default: convolution_handler,
            aten.convolution_backward.default: convolution_backward_handler,
        }

        # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor)
        # as implicitly replicated or we throw error to user.
        # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave
        # it as False by default.
        self._allow_implicit_replication = False

    def dispatch(
        self,
        op_call: torch._ops.OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> object:
        """
        Main dispatching logic.

        1. Ops registered in ``_custom_op_handlers`` are handed to their handler directly.
        2. Otherwise the inputs are unwrapped into an OpInfo and sharding is propagated.
        3. Ranks that have a coordinate in the mesh redistribute inputs if the propagator
           asks for it and run the op on their local tensors; ranks without a coordinate
           (non-participating) build placeholder local results instead.
        4. For ``aten.equal`` the scalar result is all-gathered and AND-reduced.
        5. In-place ops return ``args[0]`` (or None when there is no output spec), out
           variants return their ``out`` DTensors with updated specs, and everything else
           is re-wrapped via ``wrap``.
        """
        # operators that does not need to go through sharding propagation
        if op_call in self._custom_op_handlers:
            return self._custom_op_handlers[op_call](op_call, args, kwargs)  # type: ignore[operator]

        # extract local tensor and sharding infos to a OpInfo
        op_info = self.unwrap_to_op_info(op_call, args, kwargs)

        self.sharding_propagator.propagate(op_info)
        output_sharding = op_info.output_sharding
        assert output_sharding is not None, "output sharding should not be None"

        mesh = op_info.mesh
        if mesh.get_coordinate() is None:
            # For a non-participating device, we do:
            # 1. if the return type is scalar, set the local result to None.
            #    The local results from all devices will then be all-gathered
            #    and a reduce op will be performed on the list of results
            #    with appropriate operators:
            #       for bool type, we by default use AND to reduce;
            #       we can extend for more ops if necessary.
            # 2. if the return type is Tensor or List[Tensor], return empty
            #    tensor(s) with correct dtype.
            spec = output_sharding.output_spec

            if spec is None:
                # For a scalar return type, the non-participating device has None
                # as its local result
                local_results: object = None
            else:

                def default_tensor(spec: DTensorSpec) -> torch.Tensor:
                    # Placeholder with the right dtype: a 0-dim zero for scalar
                    # tensors, an empty 1-D tensor otherwise.
                    if spec.tensor_meta is not None:
                        shape = spec.tensor_meta.shape
                        dtype = spec.tensor_meta.dtype
                        if len(shape) == 0:
                            # scalar tensor
                            return torch.zeros((), dtype=dtype)
                        else:
                            # non-scalar tensor
                            return torch.tensor([], dtype=dtype)
                    else:
                        raise RuntimeError(f"{spec} has no tensor metadata.")

                if isinstance(spec, DTensorSpec):
                    # return a Tensor value
                    local_results = default_tensor(spec)
                elif isinstance(spec, Sequence):
                    # return a List[Tensor] value
                    local_results = [
                        default_tensor(s) if s is not None else None for s in spec
                    ]
                    assert isinstance(local_results, list)
                    # Optional entries cannot be represented by a placeholder tensor.
                    # (Identity check: avoids comparing tensors with ``==``.)
                    if any(r is None for r in local_results):
                        ret_type = str(op_info.schema.op._schema.returns[0].type)
                        raise NotImplementedError(
                            f"return type {ret_type} in DTensor op is not supported"
                        )
        else:
            if output_sharding.needs_redistribute:
                # compute locally with redistribute first if needed
                assert output_sharding.schema_suggestions is not None
                self.redistribute_local_args(
                    op_info, output_sharding.schema_suggestions[0]
                )

            local_tensor_args = (
                pytree.tree_unflatten(
                    cast(List[object], op_info.local_args), op_info.args_tree_spec
                )
                if op_info.args_tree_spec
                else op_info.local_args
            )

            # run local op computation with potentially modified args/kwargs
            local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
            if op_call in self._random_ops and is_rng_supported_mesh(mesh):
                if not random._rng_tracker:
                    raise RuntimeError(
                        "A CudaRNGStateTracker instance must be instantiated "
                        "before executing a random op over a DTensor. "
                        "Try calling random.manual_seed() or distribute_tensor() "
                        "before executing a DTensor random op."
                    )
                # For DTensor random operator, run it within a distribute region
                with random._rng_tracker._distribute_region(
                    cast(dtensor.DTensor, args[0])._spec
                ):
                    local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
            else:
                local_results = op_call(*local_tensor_args, **op_info.local_kwargs)

        # communicate the result to all ranks for some operators that return scalar value.
        # This is deliberately outside the participating/non-participating branch: every
        # rank joins the all-gather, and non-participating ranks contribute None, which
        # is filtered out before reducing.
        if output_sharding.output_spec is None:
            if op_call == aten.equal.default:
                obj_list = [None for _ in range(dist.get_world_size())]
                dist.all_gather_object(obj_list, local_results)
                obj_list = [obj for obj in obj_list if obj is not None]
                # perform reduce on the collection with AND op
                local_results = functools.reduce(operator.and_, obj_list, True)

        if _is_inplace_op(op_call):
            # inplace op should return self instead of re-wrapping
            if output_sharding.output_spec is not None:
                return args[0]
            else:
                return None
        elif _is_out_variant_op(op_call):
            # out variant could possibly have multiple out args (i.e. lu_unpack.out)
            output_specs = (
                (output_sharding.output_spec,)
                if not isinstance(output_sharding.output_spec, tuple)
                else output_sharding.output_spec
            )
            out_dts = []
            spec_idx = 0
            for argument in op_call._schema.arguments:
                if argument.is_out:
                    # the out DTensor is updated in place with its propagated spec
                    out_dt = cast(dtensor.DTensor, kwargs[argument.name])
                    out_dt._spec = cast(DTensorSpec, output_specs[spec_idx])
                    out_dts.append(out_dt)
                    spec_idx += 1

            assert len(out_dts) >= 1, "out variant should have at least one out arg"
            return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
        else:
            return self.wrap(local_results, output_sharding.output_spec)

    @staticmethod
    def redistribute_local_args(
        op_info: OpInfo,
        suggested_input_schema: OpSchema,
    ) -> None:
        """
        Replace ``op_info.local_args`` with local args matching ``suggested_input_schema``.

        For every flattened positional arg whose current spec is a DTensorSpec, the local
        tensor is passed through ``redistribute_local_tensor`` when the suggested spec
        differs (and kept as-is otherwise). Non-DTensor positions take the value from the
        suggested schema. Keyword args are not resharded.
        """
        # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it

        # TODO: the op schema should probably just remain flattened so that we can avoid this tree flatten
        # Need to fix all the ops before doing this.
        if op_info.args_tree_spec is not None:
            flatten_args_schema_to_reshard = tuple(
                pytree.tree_leaves(suggested_input_schema.args_schema)
            )
        else:
            flatten_args_schema_to_reshard = suggested_input_schema.args_schema

        new_local_args: List[object] = []
        for i, arg_spec in enumerate(op_info.flat_args_schema):
            reshard_arg_spec = flatten_args_schema_to_reshard[i]
            if isinstance(arg_spec, DTensorSpec):
                local_tensor = cast(torch.Tensor, op_info.local_args[i])
                if arg_spec != reshard_arg_spec:
                    resharded_local_tensor = redistribute_local_tensor(
                        local_tensor, arg_spec, reshard_arg_spec
                    )
                    new_local_args.append(resharded_local_tensor)
                else:
                    new_local_args.append(local_tensor)
            else:
                new_local_args.append(reshard_arg_spec)

        op_info.local_args = tuple(new_local_args)

    @staticmethod
    def _replicate_spec(mesh: DeviceMesh, tensor: torch.Tensor) -> DTensorSpec:
        """Describe a plain torch.Tensor as fully replicated on every dim of ``mesh``."""
        return DTensorSpec(
            mesh,
            (Replicate(),) * mesh.ndim,
            tensor_meta=TensorMeta(
                shape=tensor.shape, stride=tensor.stride(), dtype=tensor.dtype
            ),
        )

    def unwrap_to_op_info(
        self,
        op_call: torch._ops.OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> OpInfo:
        """
        Split the op's inputs into sharding specs and local values and bundle them in an OpInfo.

        Positional args are pytree-flattened first when the op's runtime schema info asks
        for it. DTensors contribute their ``_spec`` and ``_local_tensor``. Plain
        torch.Tensors (positional or keyword) are described as replicated when they are
        0-dim or when ``_allow_implicit_replication`` is set. Other values are passed
        through unchanged.

        Raises:
            NotImplementedError: if DTensor inputs live on different device meshes.
            RuntimeError: if a plain torch.Tensor input cannot be treated as replicated.
        """
        # get runtime schema to determine whether to use pytree to flatten inputs
        runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
            op_call, None
        )

        if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
            # flatten args/kwargs when necessary
            tree_args, args_spec = pytree.tree_flatten(args)
            args_list: Sequence[object] = tree_args
        else:
            args_list, args_spec = args, None

        args_schema: List[object] = []
        kwargs_schema: Dict[str, object] = {}
        local_args: List[object] = []
        local_kwargs: Dict[str, object] = {}
        mesh: Optional[DeviceMesh] = None
        # keyword torch.Tensors whose replicated spec is filled in once the mesh is final
        replicated_kwarg_names: List[str] = []

        for arg in args_list:
            if isinstance(arg, dtensor.DTensor):
                args_schema.append(arg._spec)
                local_args.append(arg._local_tensor)
                if mesh is None:
                    mesh = arg.device_mesh
                elif mesh != arg.device_mesh:
                    raise NotImplementedError(
                        f"{op_call}: DTensor does not support cross-mesh operation yet!"
                    )
            elif isinstance(arg, torch.Tensor):
                if arg.ndim == 0 or self._allow_implicit_replication:
                    # 0-dim tensors can always be safely treated as replicated; other
                    # tensors only when implicit replication was explicitly enabled.
                    if mesh is None:
                        # no DTensor seen yet: look the mesh up from the positional args
                        mesh = try_find_mesh_from_args(op_call, args_list)
                    args_schema.append(self._replicate_spec(mesh, arg))
                    local_args.append(arg)
                else:
                    raise RuntimeError(
                        f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
                        " torch.Tensor to DTensor before calling distributed operators!"
                    )
            else:
                args_schema.append(arg)
                local_args.append(arg)

        for k, v in kwargs.items():
            if isinstance(v, dtensor.DTensor):
                kwargs_schema[k] = v._spec
                local_kwargs[k] = v._local_tensor
                if mesh is None:
                    mesh = v.device_mesh
                elif mesh != v.device_mesh:
                    raise NotImplementedError(
                        f"{op_call}: DTensor does not support cross-mesh operation yet!"
                    )
            elif isinstance(v, torch.Tensor):
                if v.ndim == 0 or self._allow_implicit_replication:
                    # Same rule as for positional tensors. The mesh may only be found
                    # from a later DTensor kwarg, so reserve the slot now (keeping the
                    # kwarg order) and build the replicated spec after this loop.
                    kwargs_schema[k] = None
                    local_kwargs[k] = v
                    replicated_kwarg_names.append(k)
                else:
                    raise RuntimeError(
                        f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
                        " torch.Tensor to DTensor before calling distributed operators!"
                    )
            else:
                kwargs_schema[k] = v
                local_kwargs[k] = v

        if replicated_kwarg_names:
            if mesh is None:
                # no DTensor found directly in args/kwargs: search the positional args
                mesh = try_find_mesh_from_args(op_call, args_list)
            for k in replicated_kwarg_names:
                kwargs_schema[k] = self._replicate_spec(
                    mesh, cast(torch.Tensor, kwargs[k])
                )

        assert mesh is not None, f"found no DeviceMesh from dtensor args for {op_call}!"
        op_info = OpInfo(
            mesh,
            OpSchema(
                op_call,
                pytree.tree_unflatten(args_schema, args_spec)
                if args_spec
                else tuple(args_schema),
                kwargs_schema,
                schema_info=runtime_schema_info,
            ),
            args_schema,
            tuple(local_args),
            local_kwargs,
            args_spec,
        )
        return op_info

    @staticmethod
    def wrap(res: object, spec: OutputSpecType) -> object:
        """
        Re-wrap local op results into DTensors according to the propagated output spec.

        - a Tensor is wrapped with ``spec``, which must be a DTensorSpec with tensor_meta;
        - a list/tuple is wrapped element-wise against a list/tuple ``spec`` (one level of
          nesting is supported); elements whose result or spec is None stay None, and the
          container type (tuple vs list) of ``res`` is preserved;
        - anything else (non-tensor results) is returned unchanged.
        """

        def to_dt(res, spec):
            assert spec is not None and isinstance(
                spec, DTensorSpec
            ), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
            assert spec.tensor_meta is not None
            return dtensor.DTensor(
                res,
                spec.mesh,
                spec.placements,
                shape=spec.tensor_meta.shape,
                dtype=spec.tensor_meta.dtype,
                requires_grad=res.requires_grad,
                stride=spec.tensor_meta.stride,
            )

        if isinstance(res, torch.Tensor):
            return to_dt(res, spec)
        elif isinstance(res, (list, tuple)):
            assert spec is not None and isinstance(
                spec, (list, tuple)
            ), f"output spec does not match with output! Expected list/tuple, got {spec}."
            res_list = []
            for e, s in zip(res, spec):
                # NOTE: local results might return Optional Tensor from ATen op, so we need
                # to handle that case and make sure we don't wrap None with DTensor.
                # (i.e. native_layer_norm.backward)
                if isinstance(e, (list, tuple)) and isinstance(s, (list, tuple)):
                    res_list.append(type(e)([to_dt(ee, ss) for ee, ss in zip(e, s)]))
                elif e is not None and s is not None:
                    res_list.append(to_dt(e, s))
                else:
                    res_list.append(None)  # type: ignore[arg-type]

            return tuple(res_list) if isinstance(res, tuple) else res_list
        else:
            # if the res contains only non tensor values, we simply return it without rewrapping
            return res
|