DTensor: C++ compute_global_tensor_info (#162990)

compute_global_tensor_info is on the hot path for DTensor.{from,to}_local. More incremental progress toward C++.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162990
Approved by: https://github.com/ezyang
This commit is contained in:
Scott Wolchok 2025-10-30 02:56:54 +00:00 committed by PyTorch MergeBot
parent ad559072db
commit 6a5a436624
5 changed files with 177 additions and 75 deletions

View File

@ -43,7 +43,9 @@ from torch._C import (
from torch._prims_common import DeviceLikeType
from torch.autograd.graph import Node as _Node
from torch.cuda import _POOL_HANDLE
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._op_schema import OpSchema
from torch.distributed.tensor.placement_types import Placement
from torch.fx.node import Node as FxNode
from torch.package import PackageExporter
from torch.storage import TypedStorage, UntypedStorage
@ -1956,6 +1958,9 @@ _TensorBase = TensorBase
# Type stubs for DTensor helpers that are implemented in C++ (the DTensor_*
# functions); each of the two OpSchema helpers takes the schema and returns None.
def _DTensor_OpSchema_post_init(self: OpSchema) -> None: ...
def _DTensor_OpSchema_recompute_comparison_key(self: OpSchema) -> None: ...
# Returns a 2-tuple of lists: (global sizes, global strides) derived from the
# local `tensor`, the `mesh`, and one placement per mesh dimension.
# The C++ entry point is registered as a positional-only fastcall function
# expecting exactly three arguments.
def _DTensor_compute_global_tensor_info(
tensor: Tensor, mesh: DeviceMesh, placements: Sequence[Placement]
) -> tuple[list[_int], list[_int]]: ...
# Defined in torch/csrc/multiprocessing/init.cpp
def _multiprocessing_init() -> None: ...

View File

@ -22,6 +22,7 @@
#include <torch/csrc/autograd/utils/error_messages.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/distributed/Placement.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/tensor/python_tensor.h>
@ -29,6 +30,7 @@
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/pyobject_preservation.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_dispatch.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_new.h>
@ -828,19 +830,27 @@ static bool arg_type_tensor_or_tensor_list_like(py::handle arg) {
return true;
}
#define FOR_EACH_DTENSOR_INTERNED_STRING(_) \
_(_comparison_key) \
_(_local_tensor) \
_(_spec) \
_(args_schema) \
_(has_symints) \
_(kwargs_schema) \
_(op) \
_(schema_info) \
_(shape) \
_(static_argnum) \
_(static_kwargkey) \
_(stride) \
// `__name__` is only looked up through an interned string on Python builds
// older than 3.11 (newer builds use PyType_GetName instead), so the extra
// entry expands to nothing on 3.11+.
#if IS_PYTHON_3_11_PLUS
#define MAYBE_FOR_EACH_PYTHON_3_10_MINUS_DTENSOR_INTERNED_STRING(_)
#else
#define MAYBE_FOR_EACH_PYTHON_3_10_MINUS_DTENSOR_INTERNED_STRING(_) _(__name__)
#endif
// X-macro: invokes `_` once per attribute name used by the DTensor C++ code.
// Entries are referenced as `dtensor_interned_strings.<name>` (for example
// `.size` and `.__name__`); presumably DTensorInternedStrings expands this
// list to declare one member per name -- confirm against the struct body.
#define FOR_EACH_DTENSOR_INTERNED_STRING(_) \
MAYBE_FOR_EACH_PYTHON_3_10_MINUS_DTENSOR_INTERNED_STRING(_) \
_(_comparison_key) \
_(_local_tensor) \
_(_spec) \
_(args_schema) \
_(has_symints) \
_(kwargs_schema) \
_(op) \
_(schema_info) \
_(shape) \
_(size) \
_(static_argnum) \
_(static_kwargkey) \
_(stride) \
_(tensor_meta)
struct DTensorInternedStrings {
@ -1132,6 +1142,132 @@ static PyObject* DTensor_OpSchema_post_init(PyObject* mod, PyObject* self) {
END_HANDLE_TH_ERRORS
}
// Builds a Python list with one element per entry of `arr`, converting each
// SymInt to a Python object via py::cast and preserving order.
static py::list symint_array_to_list(SymIntArrayRef arr) {
  const auto count = arr.size();
  py::list converted(count);
  size_t pos = 0;
  for (const auto& sym : arr) {
    converted[pos] = py::cast(sym);
    ++pos;
  }
  return converted;
}
// Computes the global sizes and strides of a DTensor from its local shard.
//
// Starts from the local tensor's symbolic sizes/strides and, for every Shard
// placement (placement number `idx` corresponds to mesh dimension `idx`),
// multiplies the sharded dimension's size by `mesh.size(idx)` and rescales
// every other stride that is >= the stride of the sharded dimension.
// Replicate and Partial placements leave sizes and strides untouched.
//
// Args:
//   tensor:     local tensor the DTensor is built from.
//   mesh:       Python DeviceMesh object; only its `size` method is used, and
//               only if at least one Shard placement is present.
//   placements: Python sequence of Placement objects.
//
// Returns:
//   A new reference to a 2-tuple `(sizes_list, strides_list)`, or nullptr
//   with a Python RuntimeError set when a placement is neither Shard,
//   Replicate nor Partial. Invalid shard dims are reported via TORCH_CHECK.
static PyObject* DTensor_compute_global_tensor_info_impl(
    const Tensor& tensor,
    py::handle mesh,
    const py::sequence& placements) {
  Py_ssize_t idx = 0;
  c10::SymDimVector tensor_shape(
      tensor.sym_sizes().begin(), tensor.sym_sizes().end());
  c10::SymDimVector tensor_strides(
      tensor.sym_strides().begin(), tensor.sym_strides().end());
  // NOTE: if this is a py::handle then this code stops working;
  // apparently we can't rely on the bound method to stick around.
  py::object mesh_size;
  for (const auto& placement : placements) {
    // TODO: C++ify DeviceMesh somehow; profiling seems
    // to say that nearly all our remaining time spent is spent
    // calling back into Python.
    const auto& cpp_placement = placement.cast<const distributed::Placement&>();
    if (const auto* cpp_shard =
            dynamic_cast<const distributed::Shard*>(&cpp_placement)) {
      const auto shard_dim = cpp_shard->dim;
      TORCH_CHECK(
          shard_dim >= 0,
          "Shard placements should have negative dims normalized in the user-facing APIs: ",
          py::cast<std::string>(py::str(placement)));
      const auto tensor_ndim = tensor.dim();
      TORCH_CHECK(
          shard_dim < tensor_ndim,
          "Sharding dim ",
          shard_dim,
          " greater than tensor ndim ",
          tensor_ndim,
          " for placement number ",
          idx);
      // Fetch the bound `mesh.size` method lazily, once, and only when a
      // Shard placement actually needs it.
      if (!mesh_size) {
        mesh_size = mesh.attr(dtensor_interned_strings.size);
      }
      const auto mesh_dim_size = py::cast<int64_t>(mesh_size(idx));
      tensor_shape[shard_dim] *= mesh_dim_size;
      // recover tensor stride by modifying the strides that are
      // larger than the current stride on the shard_dim.
      for (const auto i : c10::irange(tensor_strides.size())) {
        if (static_cast<int64_t>(i) != shard_dim &&
            tensor_strides[i] >= tensor_strides[shard_dim]) {
          tensor_strides[i] *= mesh_dim_size;
        }
      }
    } else if (!cpp_placement.is_replicate() && !cpp_placement.is_partial()) {
#if IS_PYTHON_3_11_PLUS
      // PyType_GetName returns a NEW reference (or nullptr with an exception
      // set). Steal it so it is released when this scope exits; wrapping it
      // in a borrowing py::handle would leak the reference.
      const auto placement_type_name = py::reinterpret_steal<py::str>(
          PyType_GetName(Py_TYPE(placement.ptr())));
      if (!placement_type_name) {
        // Propagate the exception already set by PyType_GetName.
        return nullptr;
      }
#else
      const auto placement_type_name = py::str(
          py::handle(reinterpret_cast<PyObject*>(Py_TYPE(placement.ptr())))
              .attr(dtensor_interned_strings.__name__));
#endif
      // PyErr_Format sets the exception and returns nullptr.
      return PyErr_Format(
          PyExc_RuntimeError,
          "placement type %s not supported!",
          py::cast<std::string>(placement_type_name).c_str());
    }
    idx++;
  }
  return py::make_tuple(
             symint_array_to_list(tensor_shape),
             symint_array_to_list(tensor_strides))
      .release()
      .ptr();
}
// Python-visible docstring for the `_DTensor_compute_global_tensor_info`
// entry in the method table.
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
static constexpr const char compute_global_tensor_info_doc[] =
    "Compute the global size and stride of a DTensor from the given local tensor.\n"
    "The local size is multiplied by `world_size` per Sharding dim.\n"
    "The local stride is multiplied by `world_size` per Sharding dim, as long as the\n"
    "dimension is outside sharding dim.\n"
    "\n"
    "For example, if we have a local tensor with size (4, 8, 2) and stride (16, 1, 8).\n"
    "If the DTensor placements are [Shard(2)] and world_size is 2;\n"
    "then the global size is (4, 8, 4) and stride is (16 * 2, 1, 8).\n"
    "\n"
    "Args:\n"
    " tensor (:class:`torch.Tensor`):\n"
    " Local tensor which DTensor will be constructed from.\n"
    " mesh (:class:`DeviceMesh`):\n"
    " Object which describes the mesh topology\n"
    " of devices for the DTensor.\n"
    " placements (Sequence[:class:`Placement`]):\n"
    " The attribute of the DTensor that describes its layout\n"
    " on the mesh topology.\n"
    "\n"
    "Return:\n"
    " tensor_shape: A List of int which specifies the size of DTensor which build\n"
    " on top of the local tensor.\n"
    " tensor_stride: A List of int which specifies the stride of DTensor.\n";
// Fastcall entry point: validates the positional arguments
// (tensor, mesh, placements) and forwards them to
// DTensor_compute_global_tensor_info_impl. The mesh argument is passed
// through without a type check.
static PyObject* DTensor_compute_global_tensor_info(
    PyObject* self,
    PyObject* const* args,
    Py_ssize_t nargs) {
  HANDLE_TH_ERRORS
  TORCH_CHECK_VALUE(
      nargs == 3,
      "compute_global_tensor_info expects 3 arguments, got ",
      nargs);
  // Only index into `args` after the arity check above has passed.
  PyObject* const tensor_arg = args[0];
  PyObject* const mesh_arg = args[1];
  PyObject* const placements_arg = args[2];

  TORCH_CHECK_TYPE(
      THPVariable_Check(tensor_arg),
      "compute_global_tensor_info 1st argument must be Tensor!");
  const auto& local_tensor = THPVariable_Unpack(tensor_arg);

  TORCH_CHECK_TYPE(
      PySequence_Check(placements_arg),
      "compute_global_tensor_info 3rd argument must be sequence!");

  const py::handle mesh_handle = mesh_arg;
  const py::sequence placement_seq =
      py::reinterpret_borrow<py::sequence>(placements_arg);
  return DTensor_compute_global_tensor_info_impl(
      local_tensor, mesh_handle, placement_seq);
  END_HANDLE_TH_ERRORS
}
// Function-pointer aliases for attribute accessors: `getter` returns a
// PyObject* for (object, closure-style void*), `setter` returns an int status
// for (object, value, void*). Presumably these mirror the PyGetSetDef
// callback signatures -- confirm against their uses further down.
using getter = PyObject* (*)(PyObject*, void*);
using setter = int (*)(PyObject*, PyObject*, void*);
@ -2114,6 +2250,10 @@ static PyMethodDef extra_functions[] = {
DTensor_OpSchema_recompute_comparison_key,
METH_O,
nullptr},
{"_DTensor_compute_global_tensor_info",
castPyCFunctionFast(DTensor_compute_global_tensor_info),
METH_FASTCALL,
compute_global_tensor_info_doc},
{nullptr}};
struct THPVariableMeta {

View File

@ -1,6 +1,7 @@
#pragma once
#include <c10/macros/Macros.h>
#include <torch/csrc/utils/python_compat.h>
#include <Python.h>
@ -11,3 +12,15 @@ inline PyCFunction castPyCFunctionWithKeywords(PyCFunctionWithKeywords func) {
C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()
}
// Before Python 3.13 only the underscore-prefixed `_PyCFunctionFast` name is
// used here, so alias it; on 3.13+ `PyCFunctionFast` is expected to be
// provided by Python.h itself.
#if !IS_PYTHON_3_13_PLUS
using PyCFunctionFast = _PyCFunctionFast;
#endif
// Reinterprets a fastcall-style function pointer as a generic PyCFunction so
// it can be stored in a PyMethodDef entry (paired with METH_FASTCALL by the
// caller). The cast-function-type warnings are suppressed around the cast
// because the two pointer types intentionally differ.
inline PyCFunction castPyCFunctionFast(PyCFunctionFast func) {
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wcast-function-type")
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wcast-function-type-strict")
return reinterpret_cast<PyCFunction>(func);
// The matching POPs sit after the return; presumably they expand to
// diagnostic pragmas rather than executable statements.
C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()
}

View File

@ -9,11 +9,11 @@ extern "C" {
// PyTorch-only compat functions
#define IS_PYTHON_3_11_PLUS PY_VERSION_HEX >= 0x030B00C1
#define IS_PYTHON_3_12_PLUS PY_VERSION_HEX >= 0x030C0000
#define IS_PYTHON_3_13_PLUS PY_VERSION_HEX >= 0x030D0000
#define IS_PYTHON_3_14_PLUS PY_VERSION_HEX >= 0x030E0000
#define IS_PYTHON_3_15_PLUS PY_VERSION_HEX >= 0x030F0000
// Python version predicates for use in #if. Each expansion is parenthesized
// so that the macro composes safely with other operators: without the
// parentheses, `!IS_PYTHON_3_13_PLUS` would expand to
// `!PY_VERSION_HEX >= 0x030D0000`, which negates PY_VERSION_HEX first.
// The 3.11 threshold is 0x030B00C1, i.e. 3.11.0rc1 rather than 3.11.0 final.
#define IS_PYTHON_3_11_PLUS (PY_VERSION_HEX >= 0x030B00C1)
#define IS_PYTHON_3_12_PLUS (PY_VERSION_HEX >= 0x030C0000)
#define IS_PYTHON_3_13_PLUS (PY_VERSION_HEX >= 0x030D0000)
#define IS_PYTHON_3_14_PLUS (PY_VERSION_HEX >= 0x030E0000)
#define IS_PYTHON_3_15_PLUS (PY_VERSION_HEX >= 0x030F0000)
static inline int PyCode_GetNCellvars(PyCodeObject* code) {
// gh-26364 added co_ncellvars to Python 3.11.0rc1

View File

@ -220,63 +220,7 @@ def _compute_local_shape_and_global_offset(
return tuple(local_shape), tuple(global_offset)
def compute_global_tensor_info(
tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement]
) -> tuple[list[int], list[int]]:
"""
Compute the global size and stride of a DTensor from the given local tensor.
The local size is multiplited by `world_size` per Sharding dim.
The local stride is multiplited by `world_size` per Sharding dim, as long as the
dimension is outside sharding dim.
For example, if we have a local tensor with size (4, 8, 2) and stride (16, 1, 8).
If the DTensor placements are [Shard(2)] and world_size is 2;
then the global size is (4, 8, 4) and stride is (16 * 2, 1, 8).
Args:
tensor (:class:`torch.Tensor`):
Local tensor which DTensor will be constructed from.
mesh (:class:`DeviceMesh`):
Object which describes the mesh topology
of devices for the DTensor.
placements (Sequence[:class:`Placement`]]):
The attribute of the DTensor that describes its layout
on the mesh topology.
Return:
tensor_shape: A List of int which specifies the size of DTensor which build
on top of the local tensor.
tensor_stride: A List of int which specifies the stride of DTensor.
"""
tensor_shape = list(tensor.size())
tensor_stride = list(tensor.stride())
for idx, placement in enumerate(placements):
mesh_dim_size = mesh.size(idx)
if placement.is_shard():
shard_placement = cast(Shard, placement)
if shard_placement.dim < 0:
raise AssertionError(
"Shard placements should have negative dims normalized in "
f"the user-facing APIs: {shard_placement}"
)
shard_dim = shard_placement.dim
assert shard_dim < tensor.ndim, (
f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."
)
local_dim_size = tensor_shape[shard_dim]
tensor_shape[shard_dim] = local_dim_size * mesh_dim_size
# recover tensor stride by modifying the stride that larger than
# the current stride on the shard_dim
for i in range(len(tensor_stride)):
if i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim]:
# rescale the stride by the shard size
tensor_stride[i] = tensor_stride[i] * mesh_dim_size
elif not isinstance(placement, (Replicate, Partial)):
raise RuntimeError(f"placement type {type(placement)} not supported!")
return tensor_shape, tensor_stride
# Bind the public name to the C++ implementation exposed on torch._C; it takes
# (tensor, mesh, placements) positionally and returns (sizes, strides) as lists.
compute_global_tensor_info = torch._C._DTensor_compute_global_tensor_info
def compute_local_tensor_info(