Fully native DTensor.__new__ (#162508)

Move the entirety of `__new__` into C++, saving a layer of disable_dynamo and making progress toward all-C++.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162508
Approved by: https://github.com/ezyang
ghstack dependencies: #161695
This commit is contained in:
Scott Wolchok 2025-09-18 13:19:20 -07:00 committed by PyTorch MergeBot
parent 51152efa67
commit 5599f487ef
4 changed files with 129 additions and 103 deletions

View File

@@ -1539,15 +1539,14 @@ def gen_pyi(
"S",
)
],
"_make_dtensor": [
"_dtensor__new__": [
"@staticmethod\n"
+ defs(
"_make_dtensor",
"_dtensor__new__",
[
"cls: type[S]",
"size: Sequence[_int | SymInt]",
"strides: Sequence[_int | SymInt]",
"local_tensor: Tensor",
"spec: torch.distributed.tensor._dtensor_spec.DTensorSpec",
"requires_grad: _bool",
],
"S",

View File

@@ -626,7 +626,7 @@ static PyObject* THPVariable_make_subclass(
}
// Shared code factored out of THPVariable_make_wrapper_subclass and
// THPVariable_make_dtensor.
// THPVariable_dtensor__new__.
static Tensor make_tensor_for_subclass_helper(
SymIntArrayRef sym_sizes,
OptionalSymIntArrayRef sym_strides,
@@ -785,73 +785,6 @@ static PyObject* THPVariable_make_wrapper_subclass(
END_HANDLE_TH_ERRORS
}
// DTensor-specific variant of make_wrapper_subclass to minimize DTensor
// overhead.
// Exposed to Python as the static method torch.Tensor._make_dtensor.
// Builds the wrapper tensor from explicit sizes/strides plus the local
// tensor's dtype/device/layout; the Python caller is responsible for
// attaching _spec/_local_tensor afterwards.
static PyObject* THPVariable_make_dtensor(
PyObject*,
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
// Parses (cls, size, strides, local_tensor, requires_grad).
static PythonArgParser parser({
"_make_dtensor(PyObject* cls, SymIntArrayRef size, SymIntArrayRef strides, "
"Tensor local_tensor, bool requires_grad)",
});
ParsedArgs<5> parsed_args{};
auto r = parser.parse(args, kwargs, parsed_args);
PyObject* cls = r.pyobject(0);
TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
#ifndef NDEBUG
// This is specifically for making a DTensor, which we know defines
// __torch_dispatch__. Check anyway in debug builds in case somebody
// removes it.
py::object attr = PyObject_FastGetAttrString(cls, "__torch_dispatch__");
TORCH_CHECK_TYPE(
attr.ptr() != nullptr &&
attr.ptr() != torch::disabled_torch_dispatch_impl(),
((PyTypeObject*)cls)->tp_name,
" must define __torch_dispatch__");
#endif
const auto& local_tensor = r.tensor(3);
// The wrapper mirrors the local tensor's dtype/device/layout.
const auto options = TensorOptions()
.dtype(local_tensor.dtype())
.device(local_tensor.device())
.layout(local_tensor.layout());
// Carry over the Conjugate/Negative dispatch keys if the local tensor
// has them set.
DispatchKeySet extra_dispatch_keys;
const auto tensor_keys = local_tensor.key_set();
if (tensor_keys.has(c10::DispatchKey::Conjugate)) {
extra_dispatch_keys = extra_dispatch_keys.add(c10::DispatchKey::Conjugate);
}
if (tensor_keys.has(c10::DispatchKey::Negative)) {
extra_dispatch_keys = extra_dispatch_keys.add(c10::DispatchKey::Negative);
}
// Sizes/strides come straight from the parsed arguments; no storage
// offset or explicit storage size is supplied.
Tensor tensor = make_tensor_for_subclass_helper(
/*sym_sizes=*/r.symintlist(1),
/*sym_strides=*/r.symintlist(2),
/*sym_storage_offset=*/std::nullopt,
options,
/*storage_size=*/std::nullopt,
extra_dispatch_keys);
tensor.set_requires_grad(r.toBool(4));
// Returns a new reference to the freshly created subclass instance.
return THPVariable_NewWithVar(
(PyTypeObject*)cls,
tensor,
// false is the default
/*allow_preexisting_pyobj=*/false,
// we know DTensor has __torch_dispatch__ and we double-checked
// above; avoid checking again.
/*has_torch_dispatch_if_known=*/true);
END_HANDLE_TH_ERRORS
}
static py::handle get_dtensor_spec_class() {
#if IS_PYBIND_2_13_PLUS
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
@@ -895,6 +828,8 @@ static bool arg_type_tensor_or_tensor_list_like(py::handle arg) {
#define FOR_EACH_DTENSOR_INTERNED_STRING(_) \
_(_comparison_key) \
_(_local_tensor) \
_(_spec) \
_(args_schema) \
_(has_symints) \
_(kwargs_schema) \
@@ -903,6 +838,7 @@ static bool arg_type_tensor_or_tensor_list_like(py::handle arg) {
_(shape) \
_(static_argnum) \
_(static_kwargkey) \
_(stride) \
_(tensor_meta)
struct DTensorInternedStrings {
@@ -934,6 +870,111 @@ static bool checked_not(PyObject* obj) {
return result;
}
// Convert a Python tuple of dimension values into a c10::SymDimVector.
// Each element may be an exact Python int, a torch SymInt, or any object
// implementing __index__ (which covers scalar Tensors, since
// torch.Tensor.__index__ exists).
static c10::SymDimVector tuple_to_symintlist(PyObject* obj) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(PyTuple_Check(obj));
  const Py_ssize_t num_items = PyTuple_GET_SIZE(obj);
  c10::SymDimVector out;
  out.reserve(num_items);
  for (Py_ssize_t i = 0; i < num_items; ++i) {
    PyObject* const elem = PyTuple_GET_ITEM(obj, i);
    if (THPUtils_checkLongExact(elem)) {
      // Fast path: exact Python int.
      out.emplace_back(THPUtils_unpackLong(elem));
    } else if (torch::is_symint(py::handle(elem))) {
      out.push_back(py::handle(elem).cast<c10::SymInt>());
    } else {
      // Fallback: anything index-like, including scalar Tensors.
      out.emplace_back(THPUtils_unpackIndex(elem));
    }
  }
  return out;
}
// DTensor-specific variant of make_wrapper_subclass to minimize DTensor
// overhead.
// Exposed to Python as torch.Tensor._dtensor__new__ and used directly as
// DTensor.__new__: it builds the wrapper tensor sized from
// spec.tensor_meta and attaches _spec/_local_tensor on the result, so no
// Python-level __new__ body is needed.
static PyObject* THPVariable_dtensor_new(
PyObject*,
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
// Parses (cls, local_tensor, spec, requires_grad).
static PythonArgParser parser({
"_dtensor__new__(PyObject* cls, Tensor local_tensor, PyObject* spec, bool requires_grad)",
});
ParsedArgs<4> parsed_args{};
auto r = parser.parse(args, kwargs, parsed_args);
PyObject* cls = r.pyobject(0);
TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
#ifndef NDEBUG
// This is specifically for making a DTensor, which we know defines
// __torch_dispatch__. Check anyway in debug builds in case somebody
// removes it.
py::object attr = PyObject_FastGetAttrString(cls, "__torch_dispatch__");
TORCH_CHECK_TYPE(
attr.ptr() != nullptr &&
attr.ptr() != torch::disabled_torch_dispatch_impl(),
((PyTypeObject*)cls)->tp_name,
" must define __torch_dispatch__");
#endif
const auto& local_tensor = r.tensor(1);
const bool requires_grad = r.toBool(3);
// Same warning the pure-Python DTensor.__new__ used to emit when the
// local tensor requires grad but the DTensor is requested without it.
if (local_tensor.requires_grad() && !requires_grad) {
TORCH_WARN(
"To construct DTensor from torch.Tensor, it's recommended to use "
"local_tensor.detach() and make requires_grad consistent.");
}
// The wrapper mirrors the local tensor's dtype/device/layout.
const auto options = TensorOptions()
.dtype(local_tensor.dtype())
.device(local_tensor.device())
.layout(local_tensor.layout());
// Carry over the Conjugate/Negative dispatch keys if the local tensor
// has them set.
DispatchKeySet extra_dispatch_keys;
const auto tensor_keys = local_tensor.key_set();
if (tensor_keys.has(c10::DispatchKey::Conjugate)) {
extra_dispatch_keys = extra_dispatch_keys.add(c10::DispatchKey::Conjugate);
}
if (tensor_keys.has(c10::DispatchKey::Negative)) {
extra_dispatch_keys = extra_dispatch_keys.add(c10::DispatchKey::Negative);
}
// Global sizes/strides are read from spec.tensor_meta via interned
// attribute names; both must be tuples of ints/SymInts.
py::handle spec = py::handle(r.pyobject(2));
const auto tensor_meta = spec.attr(dtensor_interned_strings.tensor_meta);
TORCH_CHECK(!tensor_meta.is_none());
const auto sizes = tensor_meta.attr(dtensor_interned_strings.shape);
TORCH_CHECK(
PyTuple_Check(sizes.ptr()), "spec.tensor_meta.shape must be a tuple");
const auto stride = tensor_meta.attr(dtensor_interned_strings.stride);
TORCH_CHECK(
PyTuple_Check(stride.ptr()), "spec.tensor_meta.stride must be a tuple");
Tensor tensor = make_tensor_for_subclass_helper(
/*sym_sizes=*/tuple_to_symintlist(sizes.ptr()),
/*sym_strides=*/tuple_to_symintlist(stride.ptr()),
/*sym_storage_offset=*/std::nullopt,
options,
/*storage_size=*/std::nullopt,
extra_dispatch_keys);
tensor.set_requires_grad(requires_grad);
// Take ownership of the new Python object so the attribute stores below
// are refcount-safe even if they throw.
py::object py_tensor =
py::reinterpret_steal<py::object>(THPVariable_NewWithVar(
(PyTypeObject*)cls,
tensor,
// false is the default
/*allow_preexisting_pyobj=*/false,
// we know DTensor has __torch_dispatch__; avoid checking again.
/*has_torch_dispatch_if_known=*/true));
// Attach the attributes the Python-side DTensor code reads; this is what
// the old Python __new__ did after calling _make_dtensor.
py_tensor.attr(dtensor_interned_strings._spec) = spec;
py_tensor.attr(dtensor_interned_strings._local_tensor) = local_tensor;
// Release ownership to the caller (returns a new reference).
return py_tensor.release().ptr();
END_HANDLE_TH_ERRORS
}
static bool DTensor_OpSchema_recompute_comparison_key_impl(
PyObject* self,
const py::tuple& args_schema) {
@@ -1979,8 +2020,8 @@ static PyMethodDef extra_methods[] = {
castPyCFunctionWithKeywords(THPVariable_make_wrapper_subclass),
METH_STATIC | METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_make_dtensor",
castPyCFunctionWithKeywords(THPVariable_make_dtensor),
{"_dtensor__new__",
castPyCFunctionWithKeywords(THPVariable_dtensor_new),
METH_STATIC | METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr},

View File

@@ -240,8 +240,8 @@ class DTensor(torch.Tensor):
# _op_dispatcher instance as a class attribute to handle runtime dispatching logic
_op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher()
@staticmethod
@torch._disable_dynamo
# This implementation is just to convince mypy _spec and _local_tensor are
# initialized; it is immediately overridden below.
def __new__(
cls,
local_tensor: torch.Tensor,
@@ -249,10 +249,21 @@ class DTensor(torch.Tensor):
*,
requires_grad: bool,
) -> "DTensor":
r = torch.Tensor._dtensor__new__(
cls, local_tensor, spec, requires_grad=requires_grad
)
r._spec = spec
r._local_tensor = local_tensor
return r
__new__ = torch.Tensor._dtensor__new__ # type: ignore[assignment] # noqa: F811
@torch._disable_dynamo
@mark_subclass_constructor_exportable_experimental
def __init__(self, *args, **kwargs):
"""
Construct a DTensor from a local tensor, device mesh, and placement and
other tensor properties (i.e. shape, requires_grad, strides, etc).
.. note:: This is not a public API and it's only supposed to be used by the
operator implementations and internals. If you want to construct a
DTensor from a local tensor, consider using ``DTensor.from_local``, if
@@ -260,31 +271,6 @@ class DTensor(torch.Tensor):
already have tensor initialized and want to shard this tensor),
consider using ``distribute_tensor``.
"""
if local_tensor.requires_grad and not requires_grad:
warnings.warn(
"To construct DTensor from torch.Tensor, it's recommended to "
"use local_tensor.detach() and make requires_grad consistent."
)
# new method instruct wrapper tensor from local_tensor and add
# placement spec, it does not do actual distribution
assert spec.tensor_meta is not None, "TensorMeta should not be None!"
r = torch.Tensor._make_dtensor(
cls,
spec.tensor_meta.shape,
spec.tensor_meta.stride,
local_tensor,
requires_grad,
)
r._spec = spec
r._local_tensor = local_tensor
return r
@torch._disable_dynamo
@mark_subclass_constructor_exportable_experimental
def __init__(self, *args, **kwargs):
super().__init__()
# pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently.

View File

@@ -362,7 +362,7 @@ def get_ignored_functions() -> set[Callable]:
Tensor._view_func,
Tensor._view_func_unsafe,
Tensor._rev_view_func_unsafe,
Tensor._make_dtensor,
Tensor._dtensor__new__,
Tensor._make_wrapper_subclass,
Tensor._python_dispatch.__get__,
Tensor._has_symbolic_sizes_strides.__get__,