mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
DTensor: C++ compute_global_tensor_info (#162990)
compute_global_tensor_info is on the hot path for DTensor.{from,to}_local. More incremental progress toward C++.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162990
Approved by: https://github.com/ezyang
This commit is contained in:
parent
ad559072db
commit
6a5a436624
|
|
@ -43,7 +43,9 @@ from torch._C import (
|
|||
from torch._prims_common import DeviceLikeType
|
||||
from torch.autograd.graph import Node as _Node
|
||||
from torch.cuda import _POOL_HANDLE
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.tensor._op_schema import OpSchema
|
||||
from torch.distributed.tensor.placement_types import Placement
|
||||
from torch.fx.node import Node as FxNode
|
||||
from torch.package import PackageExporter
|
||||
from torch.storage import TypedStorage, UntypedStorage
|
||||
|
|
@ -1956,6 +1958,9 @@ _TensorBase = TensorBase
|
|||
|
||||
def _DTensor_OpSchema_post_init(self: OpSchema) -> None: ...
|
||||
def _DTensor_OpSchema_recompute_comparison_key(self: OpSchema) -> None: ...
|
||||
def _DTensor_compute_global_tensor_info(
|
||||
tensor: Tensor, mesh: DeviceMesh, placements: Sequence[Placement]
|
||||
) -> tuple[list[_int], list[_int]]: ...
|
||||
|
||||
# Defined in torch/csrc/multiprocessing/init.cpp
|
||||
def _multiprocessing_init() -> None: ...
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@
|
|||
#include <torch/csrc/autograd/utils/error_messages.h>
|
||||
#include <torch/csrc/autograd/utils/wrap_outputs.h>
|
||||
#include <torch/csrc/autograd/variable.h>
|
||||
#include <torch/csrc/distributed/Placement.h>
|
||||
#include <torch/csrc/jit/frontend/tracer.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
#include <torch/csrc/tensor/python_tensor.h>
|
||||
|
|
@ -29,6 +30,7 @@
|
|||
#include <torch/csrc/utils/pycfunction_helpers.h>
|
||||
#include <torch/csrc/utils/pyobject_preservation.h>
|
||||
#include <torch/csrc/utils/python_arg_parser.h>
|
||||
#include <torch/csrc/utils/python_compat.h>
|
||||
#include <torch/csrc/utils/python_dispatch.h>
|
||||
#include <torch/csrc/utils/python_strings.h>
|
||||
#include <torch/csrc/utils/tensor_new.h>
|
||||
|
|
@ -828,19 +830,27 @@ static bool arg_type_tensor_or_tensor_list_like(py::handle arg) {
|
|||
return true;
|
||||
}
|
||||
|
||||
#define FOR_EACH_DTENSOR_INTERNED_STRING(_) \
|
||||
_(_comparison_key) \
|
||||
_(_local_tensor) \
|
||||
_(_spec) \
|
||||
_(args_schema) \
|
||||
_(has_symints) \
|
||||
_(kwargs_schema) \
|
||||
_(op) \
|
||||
_(schema_info) \
|
||||
_(shape) \
|
||||
_(static_argnum) \
|
||||
_(static_kwargkey) \
|
||||
_(stride) \
|
||||
#if IS_PYTHON_3_11_PLUS
|
||||
#define MAYBE_FOR_EACH_PYTHON_3_10_MINUS_DTENSOR_INTERNED_STRING(_)
|
||||
#else
|
||||
#define MAYBE_FOR_EACH_PYTHON_3_10_MINUS_DTENSOR_INTERNED_STRING(_) _(__name__)
|
||||
#endif
|
||||
|
||||
#define FOR_EACH_DTENSOR_INTERNED_STRING(_) \
|
||||
MAYBE_FOR_EACH_PYTHON_3_10_MINUS_DTENSOR_INTERNED_STRING(_) \
|
||||
_(_comparison_key) \
|
||||
_(_local_tensor) \
|
||||
_(_spec) \
|
||||
_(args_schema) \
|
||||
_(has_symints) \
|
||||
_(kwargs_schema) \
|
||||
_(op) \
|
||||
_(schema_info) \
|
||||
_(shape) \
|
||||
_(size) \
|
||||
_(static_argnum) \
|
||||
_(static_kwargkey) \
|
||||
_(stride) \
|
||||
_(tensor_meta)
|
||||
|
||||
struct DTensorInternedStrings {
|
||||
|
|
@ -1132,6 +1142,132 @@ static PyObject* DTensor_OpSchema_post_init(PyObject* mod, PyObject* self) {
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// Convert a SymIntArrayRef into a Python list, casting each (possibly
// symbolic) integer element to its Python representation.
static py::list symint_array_to_list(SymIntArrayRef arr) {
  const auto n = arr.size();
  py::list out(n);
  for (size_t i = 0; i < n; ++i) {
    out[i] = py::cast(arr[i]);
  }
  return out;
}
|
||||
|
||||
// Core of torch._C._DTensor_compute_global_tensor_info: given a local
// tensor, a DeviceMesh (still a Python object) and a sequence of Placement
// objects, compute the global (shape, stride) the DTensor would have.
// Returns a new reference to a 2-tuple (list[int], list[int]) on success;
// on an unsupported placement type it sets a Python RuntimeError and
// returns nullptr (via PyErr_Format).
static PyObject* DTensor_compute_global_tensor_info_impl(
    const Tensor& tensor,
    py::handle mesh,
    const py::sequence& placements) {
  // Index of the current placement; doubles as the mesh dim it applies to.
  Py_ssize_t idx = 0;
  // Start from the local sizes/strides; Shard placements scale them below.
  c10::SymDimVector tensor_shape(
      tensor.sym_sizes().begin(), tensor.sym_sizes().end());
  c10::SymDimVector tensor_strides(
      tensor.sym_strides().begin(), tensor.sym_strides().end());
  // NOTE: if this is a py::handle then this code stops working;
  // apparently we can't rely on the bound method to stick around.
  // (py::object keeps a strong reference to mesh.size for the loop.)
  py::object mesh_size;
  for (const auto& placement : placements) {
    // TODO: C++ify DeviceMesh somehow; profiling seems
    // to say that nearly all our remaining time spent is spent
    // calling back into Python.
    const auto& cpp_placement = placement.cast<const distributed::Placement&>();
    if (const auto* cpp_shard =
            dynamic_cast<const distributed::Shard*>(&cpp_placement)) {
      const auto shard_dim = cpp_shard->dim;
      TORCH_CHECK(
          shard_dim >= 0,
          "Shard placements should have negative dims normalized in the user-facing APIs: ",
          py::cast<std::string>(py::str(placement)));
      const auto tensor_ndim = tensor.dim();
      TORCH_CHECK(
          shard_dim < tensor_ndim,
          "Sharding dim ",
          shard_dim,
          " greater than tensor ndim ",
          tensor_ndim,
          " for placement number ",
          idx);

      // Lazily bind mesh.size only when a Shard placement is actually seen,
      // to avoid the Python attribute lookup on the replicate/partial path.
      if (!mesh_size) {
        mesh_size = mesh.attr(dtensor_interned_strings.size);
      }
      // Global size on the sharded dim = local size * mesh dim size.
      const auto mesh_dim_size = py::cast<int64_t>(mesh_size(idx));
      tensor_shape[shard_dim] *= mesh_dim_size;
      // recover tensor stride by modifying the strides that are
      // larger than the current stride on the shard_dim.
      for (const auto i : c10::irange(tensor_strides.size())) {
        if (static_cast<int64_t>(i) != shard_dim &&
            tensor_strides[i] >= tensor_strides[shard_dim]) {
          tensor_strides[i] *= mesh_dim_size;
        }
      }
    } else if (!cpp_placement.is_replicate() && !cpp_placement.is_partial()) {
      // Unknown placement type: raise RuntimeError with the Python-level
      // type name. PyType_GetName is only available on 3.11+.
#if IS_PYTHON_3_11_PLUS
      const auto placement_type_name =
          py::str(py::handle(PyType_GetName(Py_TYPE(placement.ptr()))));
#else
      const auto placement_type_name =
          py::str(py::handle((PyObject*)Py_TYPE(placement.ptr()))
                      .attr(dtensor_interned_strings.__name__));
#endif
      // PyErr_Format always returns nullptr, signalling the error to Python.
      return PyErr_Format(
          PyExc_RuntimeError,
          "placement type %s not supported!",
          py::cast<std::string>(placement_type_name).c_str());
    }
    idx++;
  }
  // Return ownership of the (shape, stride) tuple to the caller.
  return py::make_tuple(
             symint_array_to_list(tensor_shape),
             symint_array_to_list(tensor_strides))
      .release()
      .ptr();
}
|
||||
|
||||
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
|
||||
static constexpr const char compute_global_tensor_info_doc[] =
|
||||
"Compute the global size and stride of a DTensor from the given local tensor.\n"
|
||||
"The local size is multiplied by `world_size` per Sharding dim.\n"
|
||||
"The local stride is multiplied by `world_size` per Sharding dim, as long as the\n"
|
||||
"dimension is outside sharding dim.\n"
|
||||
"\n"
|
||||
"For example, if we have a local tensor with size (4, 8, 2) and stride (16, 1, 8).\n"
|
||||
"If the DTensor placements are [Shard(2)] and world_size is 2;\n"
|
||||
"then the global size is (4, 8, 4) and stride is (16 * 2, 1, 8).\n"
|
||||
"\n"
|
||||
"Args:\n"
|
||||
" tensor (:class:`torch.Tensor`):\n"
|
||||
" Local tensor which DTensor will be constructed from.\n"
|
||||
" mesh (:class:`DeviceMesh`):\n"
|
||||
" Object which describes the mesh topology\n"
|
||||
" of devices for the DTensor.\n"
|
||||
" placements (Sequence[:class:`Placement`]]):\n"
|
||||
" The attribute of the DTensor that describes its layout\n"
|
||||
" on the mesh topology.\n"
|
||||
"\n"
|
||||
"Return:\n"
|
||||
" tensor_shape: A List of int which specifies the size of DTensor which build\n"
|
||||
" on top of the local tensor.\n"
|
||||
" tensor_stride: A List of int which specifies the stride of DTensor.\n";
|
||||
|
||||
// METH_FASTCALL entry point for torch._C._DTensor_compute_global_tensor_info.
// Validates the (tensor, mesh, placements) arguments, then delegates to the
// C++ implementation above.
static PyObject* DTensor_compute_global_tensor_info(
    PyObject* self,
    PyObject* const* args,
    Py_ssize_t nargs) {
  HANDLE_TH_ERRORS
  TORCH_CHECK_VALUE(
      nargs == 3,
      "compute_global_tensor_info expects 3 arguments, got ",
      nargs);
  TORCH_CHECK_TYPE(
      THPVariable_Check(args[0]),
      "compute_global_tensor_info 1st argument must be Tensor!");
  TORCH_CHECK_TYPE(
      PySequence_Check(args[2]),
      "compute_global_tensor_info 3rd argument must be sequence!");
  const auto& local_tensor = THPVariable_Unpack(args[0]);
  const py::handle device_mesh = args[1];
  const auto placement_seq = py::reinterpret_borrow<py::sequence>(args[2]);
  return DTensor_compute_global_tensor_info_impl(
      local_tensor, device_mesh, placement_seq);
  END_HANDLE_TH_ERRORS
}
|
||||
|
||||
using getter = PyObject* (*)(PyObject*, void*);
|
||||
using setter = int (*)(PyObject*, PyObject*, void*);
|
||||
|
||||
|
|
@ -2114,6 +2250,10 @@ static PyMethodDef extra_functions[] = {
|
|||
DTensor_OpSchema_recompute_comparison_key,
|
||||
METH_O,
|
||||
nullptr},
|
||||
{"_DTensor_compute_global_tensor_info",
|
||||
castPyCFunctionFast(DTensor_compute_global_tensor_info),
|
||||
METH_FASTCALL,
|
||||
compute_global_tensor_info_doc},
|
||||
{nullptr}};
|
||||
|
||||
struct THPVariableMeta {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <torch/csrc/utils/python_compat.h>
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
|
|
@ -11,3 +12,15 @@ inline PyCFunction castPyCFunctionWithKeywords(PyCFunctionWithKeywords func) {
|
|||
C10_DIAGNOSTIC_POP()
|
||||
C10_DIAGNOSTIC_POP()
|
||||
}
|
||||
|
||||
#if !IS_PYTHON_3_13_PLUS
|
||||
using PyCFunctionFast = _PyCFunctionFast;
|
||||
#endif
|
||||
|
||||
// Cast a METH_FASTCALL-style handler (PyCFunctionFast) to the generic
// PyCFunction type stored in PyMethodDef.ml_meth. The interpreter calls it
// back with the fastcall convention, so the round-trip through the wrong
// function-pointer type is the sanctioned CPython pattern; the pragmas
// suppress the -Wcast-function-type(-strict) warnings it triggers.
inline PyCFunction castPyCFunctionFast(PyCFunctionFast func) {
  C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wcast-function-type")
  C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wcast-function-type-strict")
  return reinterpret_cast<PyCFunction>(func);
  C10_DIAGNOSTIC_POP()
  C10_DIAGNOSTIC_POP()
}
|
||||
|
|
|
|||
|
|
@ -9,11 +9,11 @@ extern "C" {
|
|||
|
||||
// PyTorch-only compat functions
|
||||
|
||||
#define IS_PYTHON_3_11_PLUS PY_VERSION_HEX >= 0x030B00C1
|
||||
#define IS_PYTHON_3_12_PLUS PY_VERSION_HEX >= 0x030C0000
|
||||
#define IS_PYTHON_3_13_PLUS PY_VERSION_HEX >= 0x030D0000
|
||||
#define IS_PYTHON_3_14_PLUS PY_VERSION_HEX >= 0x030E0000
|
||||
#define IS_PYTHON_3_15_PLUS PY_VERSION_HEX >= 0x030F0000
|
||||
#define IS_PYTHON_3_11_PLUS (PY_VERSION_HEX >= 0x030B00C1)
|
||||
#define IS_PYTHON_3_12_PLUS (PY_VERSION_HEX >= 0x030C0000)
|
||||
#define IS_PYTHON_3_13_PLUS (PY_VERSION_HEX >= 0x030D0000)
|
||||
#define IS_PYTHON_3_14_PLUS (PY_VERSION_HEX >= 0x030E0000)
|
||||
#define IS_PYTHON_3_15_PLUS (PY_VERSION_HEX >= 0x030F0000)
|
||||
|
||||
static inline int PyCode_GetNCellvars(PyCodeObject* code) {
|
||||
// gh-26364 added co_ncellvars to Python 3.11.0rc1
|
||||
|
|
|
|||
|
|
@ -220,63 +220,7 @@ def _compute_local_shape_and_global_offset(
|
|||
return tuple(local_shape), tuple(global_offset)
|
||||
|
||||
|
||||
def compute_global_tensor_info(
    tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement]
) -> tuple[list[int], list[int]]:
    """
    Compute the global size and stride of a DTensor from the given local tensor.
    The local size is multiplied by `world_size` per Sharding dim.
    The local stride is multiplied by `world_size` per Sharding dim, as long as the
    dimension is outside sharding dim.

    For example, if we have a local tensor with size (4, 8, 2) and stride (16, 1, 8).
    If the DTensor placements are [Shard(2)] and world_size is 2;
    then the global size is (4, 8, 4) and stride is (16 * 2, 1, 8).

    Args:
        tensor (:class:`torch.Tensor`):
            Local tensor which DTensor will be constructed from.
        mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for the DTensor.
        placements (Sequence[:class:`Placement`]]):
            The attribute of the DTensor that describes its layout
            on the mesh topology.

    Return:
        tensor_shape: A List of int which specifies the size of DTensor which build
            on top of the local tensor.
        tensor_stride: A List of int which specifies the stride of DTensor.
    """
    # Start from the local sizes/strides; Shard placements scale them below.
    tensor_shape = list(tensor.size())
    tensor_stride = list(tensor.stride())
    # idx is both the placement index and the mesh dim it applies to.
    for idx, placement in enumerate(placements):
        mesh_dim_size = mesh.size(idx)
        if placement.is_shard():
            shard_placement = cast(Shard, placement)
            if shard_placement.dim < 0:
                raise AssertionError(
                    "Shard placements should have negative dims normalized in "
                    f"the user-facing APIs: {shard_placement}"
                )
            shard_dim = shard_placement.dim

            assert shard_dim < tensor.ndim, (
                f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."
            )

            # Global size on the sharded dim = local size * mesh dim size.
            local_dim_size = tensor_shape[shard_dim]
            tensor_shape[shard_dim] = local_dim_size * mesh_dim_size

            # recover tensor stride by modifying the strides that are larger
            # than the current stride on the shard_dim
            for i in range(len(tensor_stride)):
                if i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim]:
                    # rescale the stride by the shard size
                    tensor_stride[i] = tensor_stride[i] * mesh_dim_size
        elif not isinstance(placement, (Replicate, Partial)):
            # Replicate/Partial leave the global shape/stride unchanged;
            # anything else is an unknown placement type.
            raise RuntimeError(f"placement type {type(placement)} not supported!")
    return tensor_shape, tensor_stride
|
||||
compute_global_tensor_info = torch._C._DTensor_compute_global_tensor_info
|
||||
|
||||
|
||||
def compute_local_tensor_info(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user