Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00
Fully native DTensor.__new__ (#162508)
Move the entirety of `__new__` into C++, saving a layer of disable_dynamo and making progress toward all-C++.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162508
Approved by: https://github.com/ezyang
ghstack dependencies: #161695
This commit is contained in: commit 5599f487ef (parent 51152efa67)
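The core trick in this commit is the "typing stub, then rebind" pattern visible in the torch/distributed/tensor/_api.py hunks below: DTensor keeps a fully annotated `__new__` purely so mypy sees the signature and believes `_spec` and `_local_tensor` are initialized, then immediately rebinds `__new__` to the native `torch.Tensor._dtensor__new__`. A minimal pure-Python sketch of the same pattern; every name below is illustrative, not a PyTorch API:

def _fast_new(cls, payload, *, requires_grad):
    # Stand-in for the C++ fast path (torch.Tensor._dtensor__new__):
    # allocate and attach cached attributes in one step, with no
    # Python-level indirection left on the hot path.
    self = object.__new__(cls)
    self._payload = payload
    self._requires_grad = requires_grad
    return self


class Widget:
    # Typing stub only: convinces a type checker that __new__ has this
    # signature and that _payload exists. It is rebound immediately below
    # and never actually runs.
    def __new__(cls, payload: object, *, requires_grad: bool) -> "Widget":
        return _fast_new(cls, payload, requires_grad=requires_grad)

    __new__ = _fast_new  # type: ignore[assignment]  # noqa: F811


w = Widget("data", requires_grad=False)
assert w._payload == "data"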
tools/pyi/gen_pyi.py
@@ -1539,15 +1539,14 @@ def gen_pyi(
                 "S",
             )
         ],
-        "_make_dtensor": [
+        "_dtensor__new__": [
             "@staticmethod\n"
             + defs(
-                "_make_dtensor",
+                "_dtensor__new__",
                 [
                     "cls: type[S]",
-                    "size: Sequence[_int | SymInt]",
-                    "strides: Sequence[_int | SymInt]",
                     "local_tensor: Tensor",
+                    "spec: torch.distributed.tensor._dtensor_spec.DTensorSpec",
                     "requires_grad: _bool",
                 ],
                 "S",
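For context, gen_pyi turns each of these entries into a method stub on the generated tensor type in the `.pyi` file. Assuming `defs(...)` joins the name, parameter list, and return annotation in the usual way, the regenerated stub for the new binding would read roughly:

@staticmethod
def _dtensor__new__(
    cls: type[S],
    local_tensor: Tensor,
    spec: torch.distributed.tensor._dtensor_spec.DTensorSpec,
    requires_grad: _bool,
) -> S: ...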
torch/csrc/autograd/python_variable.cpp
@@ -626,7 +626,7 @@ static PyObject* THPVariable_make_subclass(
 }
 
 // Shared code factored out of THPVariable_make_wrapper_subclass and
-// THPVariable_make_dtensor.
+// THPVariable_dtensor__new__.
 static Tensor make_tensor_for_subclass_helper(
     SymIntArrayRef sym_sizes,
     OptionalSymIntArrayRef sym_strides,
@@ -785,73 +785,6 @@ static PyObject* THPVariable_make_wrapper_subclass(
   END_HANDLE_TH_ERRORS
 }
 
-// DTensor-specific variant of make_wrapper_subclass to minimize DTensor
-// overhead.
-static PyObject* THPVariable_make_dtensor(
-    PyObject*,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser({
-      "_make_dtensor(PyObject* cls, SymIntArrayRef size, SymIntArrayRef strides, "
-      "Tensor local_tensor, bool requires_grad)",
-  });
-  ParsedArgs<5> parsed_args{};
-  auto r = parser.parse(args, kwargs, parsed_args);
-  PyObject* cls = r.pyobject(0);
-
-  TORCH_CHECK_TYPE(
-      PyType_Check(cls),
-      "cls must be a type (got ",
-      Py_TYPE(cls)->tp_name,
-      ")");
-
-#ifndef NDEBUG
-  // This is specifically for making a DTensor, which we know defines
-  // __torch_dispatch__. Check anyway in debug builds in case somebody
-  // removes it.
-  py::object attr = PyObject_FastGetAttrString(cls, "__torch_dispatch__");
-  TORCH_CHECK_TYPE(
-      attr.ptr() != nullptr &&
-          attr.ptr() != torch::disabled_torch_dispatch_impl(),
-      ((PyTypeObject*)cls)->tp_name,
-      " must define __torch_dispatch__");
-#endif
-
-  const auto& local_tensor = r.tensor(3);
-  const auto options = TensorOptions()
-                           .dtype(local_tensor.dtype())
-                           .device(local_tensor.device())
-                           .layout(local_tensor.layout());
-
-  DispatchKeySet extra_dispatch_keys;
-  const auto tensor_keys = local_tensor.key_set();
-  if (tensor_keys.has(c10::DispatchKey::Conjugate)) {
-    extra_dispatch_keys = extra_dispatch_keys.add(c10::DispatchKey::Conjugate);
-  }
-  if (tensor_keys.has(c10::DispatchKey::Negative)) {
-    extra_dispatch_keys = extra_dispatch_keys.add(c10::DispatchKey::Negative);
-  }
-
-  Tensor tensor = make_tensor_for_subclass_helper(
-      /*sym_sizes=*/r.symintlist(1),
-      /*sym_strides=*/r.symintlist(2),
-      /*sym_storage_offset=*/std::nullopt,
-      options,
-      /*storage_size=*/std::nullopt,
-      extra_dispatch_keys);
-  tensor.set_requires_grad(r.toBool(4));
-  return THPVariable_NewWithVar(
-      (PyTypeObject*)cls,
-      tensor,
-      // false is the default
-      /*allow_preexisting_pyobj=*/false,
-      // we know DTensor has __torch_dispatch__ and we double-checked
-      // above; avoid checking again.
-      /*has_torch_dispatch_if_known=*/true);
-  END_HANDLE_TH_ERRORS
-}
-
 static py::handle get_dtensor_spec_class() {
 #if IS_PYBIND_2_13_PLUS
   PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
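The Conjugate/Negative handling in the function removed above (and retained in its replacement later in this diff) exists because conj() and neg() can be lazy views whose dispatch keys must be mirrored onto the wrapper, or the subclass would read un-materialized values. The flags the C++ code inspects via key_set() have public Python probes:

import torch

base = torch.randn(2, 2, dtype=torch.complex64)
lazy = base.conj()     # lazy view: sets the Conjugate dispatch key
print(lazy.is_conj())  # True: a wrapper over `lazy` must carry the key
print(lazy.is_neg())   # False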
@@ -895,6 +828,8 @@ static bool arg_type_tensor_or_tensor_list_like(py::handle arg) {
 
 #define FOR_EACH_DTENSOR_INTERNED_STRING(_) \
   _(_comparison_key)                        \
+  _(_local_tensor)                          \
+  _(_spec)                                  \
   _(args_schema)                            \
   _(has_symints)                            \
   _(kwargs_schema)                          \
@@ -903,6 +838,7 @@ static bool arg_type_tensor_or_tensor_list_like(py::handle arg) {
   _(shape)                                  \
   _(static_argnum)                          \
   _(static_kwargkey)                        \
+  _(stride)                                 \
   _(tensor_meta)
 
 struct DTensorInternedStrings {
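These macros maintain a table of attribute-name strings interned once at startup, so hot-path lookups such as `spec.tensor_meta` skip building and hashing a fresh string object on every call. A rough Python-level analogue of the idea (the real mechanism is the C++ table above):

import sys

# Intern hot attribute names once; repeated getattr() then reuses the same
# string object, letting dict lookups hit the pointer-equality fast path.
_TENSOR_META = sys.intern("tensor_meta")
_SHAPE = sys.intern("shape")

def shape_of(spec):
    return getattr(getattr(spec, _TENSOR_META), _SHAPE)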
@@ -934,6 +870,111 @@ static bool checked_not(PyObject* obj) {
   return result;
 }
 
+static c10::SymDimVector tuple_to_symintlist(PyObject* obj) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(PyTuple_Check(obj));
+  c10::SymDimVector res;
+  const auto size = PyTuple_GET_SIZE(obj);
+  res.reserve(size);
+  for (const auto idx : c10::irange(size)) {
+    PyObject* item = PyTuple_GET_ITEM(obj, idx);
+    if (THPUtils_checkLongExact(item)) {
+      res.emplace_back(THPUtils_unpackLong(item));
+    } else if (torch::is_symint(py::handle(item))) {
+      res.push_back(py::handle(item).cast<c10::SymInt>());
+    } else {
+      // N.B. torch.Tensor.__index__ exists, so this should handle
+      // scalar Tensors fine.
+      res.emplace_back(THPUtils_unpackIndex(item));
+    }
+  }
+  return res;
+}
+
+// DTensor-specific variant of make_wrapper_subclass to minimize DTensor
+// overhead.
+static PyObject* THPVariable_dtensor_new(
+    PyObject*,
+    PyObject* args,
+    PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+      "_dtensor__new__(PyObject* cls, Tensor local_tensor, PyObject* spec, bool requires_grad)",
+  });
+  ParsedArgs<4> parsed_args{};
+  auto r = parser.parse(args, kwargs, parsed_args);
+  PyObject* cls = r.pyobject(0);
+
+  TORCH_CHECK_TYPE(
+      PyType_Check(cls),
+      "cls must be a type (got ",
+      Py_TYPE(cls)->tp_name,
+      ")");
+
+#ifndef NDEBUG
+  // This is specifically for making a DTensor, which we know defines
+  // __torch_dispatch__. Check anyway in debug builds in case somebody
+  // removes it.
+  py::object attr = PyObject_FastGetAttrString(cls, "__torch_dispatch__");
+  TORCH_CHECK_TYPE(
+      attr.ptr() != nullptr &&
+          attr.ptr() != torch::disabled_torch_dispatch_impl(),
+      ((PyTypeObject*)cls)->tp_name,
+      " must define __torch_dispatch__");
+#endif
+
+  const auto& local_tensor = r.tensor(1);
+  const bool requires_grad = r.toBool(3);
+  if (local_tensor.requires_grad() && !requires_grad) {
+    TORCH_WARN(
+        "To construct DTensor from torch.Tensor, it's recommended to use "
+        "local_tensor.detach() and make requires_grad consistent.");
+  }
+  const auto options = TensorOptions()
+                           .dtype(local_tensor.dtype())
+                           .device(local_tensor.device())
+                           .layout(local_tensor.layout());
+
+  DispatchKeySet extra_dispatch_keys;
+  const auto tensor_keys = local_tensor.key_set();
+  if (tensor_keys.has(c10::DispatchKey::Conjugate)) {
+    extra_dispatch_keys = extra_dispatch_keys.add(c10::DispatchKey::Conjugate);
+  }
+  if (tensor_keys.has(c10::DispatchKey::Negative)) {
+    extra_dispatch_keys = extra_dispatch_keys.add(c10::DispatchKey::Negative);
+  }
+
+  py::handle spec = py::handle(r.pyobject(2));
+  const auto tensor_meta = spec.attr(dtensor_interned_strings.tensor_meta);
+  TORCH_CHECK(!tensor_meta.is_none());
+  const auto sizes = tensor_meta.attr(dtensor_interned_strings.shape);
+  TORCH_CHECK(
+      PyTuple_Check(sizes.ptr()), "spec.tensor_meta.shape must be a tuple");
+  const auto stride = tensor_meta.attr(dtensor_interned_strings.stride);
+  TORCH_CHECK(
+      PyTuple_Check(stride.ptr()), "spec.tensor_meta.stride must be a tuple");
+
+  Tensor tensor = make_tensor_for_subclass_helper(
+      /*sym_sizes=*/tuple_to_symintlist(sizes.ptr()),
+      /*sym_strides=*/tuple_to_symintlist(stride.ptr()),
+      /*sym_storage_offset=*/std::nullopt,
+      options,
+      /*storage_size=*/std::nullopt,
+      extra_dispatch_keys);
+  tensor.set_requires_grad(requires_grad);
+  py::object py_tensor =
+      py::reinterpret_steal<py::object>(THPVariable_NewWithVar(
+          (PyTypeObject*)cls,
+          tensor,
+          // false is the default
+          /*allow_preexisting_pyobj=*/false,
+          // we know DTensor has __torch_dispatch__; avoid checking again.
+          /*has_torch_dispatch_if_known=*/true));
+  py_tensor.attr(dtensor_interned_strings._spec) = spec;
+  py_tensor.attr(dtensor_interned_strings._local_tensor) = local_tensor;
+  return py_tensor.release().ptr();
+  END_HANDLE_TH_ERRORS
+}
+
 static bool DTensor_OpSchema_recompute_comparison_key_impl(
     PyObject* self,
     const py::tuple& args_schema) {
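tuple_to_symintlist above accepts three kinds of elements: exact Python ints, SymInts, and anything implementing `__index__` (which covers scalar Tensors). A Python rendering of the same fallthrough, for illustration only:

import operator

import torch

def tuple_to_symintlist(t: tuple):
    out = []
    for item in t:
        if type(item) is int:                 # exact int: the common fast path
            out.append(item)
        elif isinstance(item, torch.SymInt):  # symbolic size from tracing
            out.append(item)
        else:
            # Anything indexable, e.g. a scalar Tensor via Tensor.__index__.
            out.append(operator.index(item))
    return out

print(tuple_to_symintlist((2, torch.tensor(3))))  # [2, 3]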
@@ -1979,8 +2020,8 @@ static PyMethodDef extra_methods[] = {
      castPyCFunctionWithKeywords(THPVariable_make_wrapper_subclass),
      METH_STATIC | METH_VARARGS | METH_KEYWORDS,
      nullptr},
-    {"_make_dtensor",
-     castPyCFunctionWithKeywords(THPVariable_make_dtensor),
+    {"_dtensor__new__",
+     castPyCFunctionWithKeywords(THPVariable_dtensor_new),
      METH_STATIC | METH_VARARGS | METH_KEYWORDS,
      nullptr},
     {"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr},
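After the method-table change above, the binding is exposed as a static method named `_dtensor__new__` on the tensor base class. A quick way to check which name a given build exposes (the result depends on whether the build includes this commit):

import torch

print(getattr(torch.Tensor, "_make_dtensor", None))    # pre-commit name
print(getattr(torch.Tensor, "_dtensor__new__", None))  # post-commit name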
torch/distributed/tensor/_api.py
@@ -240,8 +240,8 @@ class DTensor(torch.Tensor):
     # _op_dispatcher instance as a class attribute to handle runtime dispatching logic
     _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher()
 
-    @staticmethod
-    @torch._disable_dynamo
+    # This implementation is just to convince mypy _spec and _local_tensor are
+    # initialized; it is immediately overridden below.
     def __new__(
         cls,
         local_tensor: torch.Tensor,
@@ -249,10 +249,21 @@ class DTensor(torch.Tensor):
         *,
         requires_grad: bool,
     ) -> "DTensor":
+        r = torch.Tensor._dtensor__new__(
+            cls, local_tensor, spec, requires_grad=requires_grad
+        )
+        r._spec = spec
+        r._local_tensor = local_tensor
+        return r
+
+    __new__ = torch.Tensor._dtensor__new__  # type: ignore[assignment]  # noqa: F811
+
+    @torch._disable_dynamo
+    @mark_subclass_constructor_exportable_experimental
+    def __init__(self, *args, **kwargs):
         """
         Construct a DTensor from a local tensor, device mesh, and placement and
         other tensor properties (i.e. shape, requires_grad, strides, etc).
 
         .. note:: This is not a public API and it's only supposed to be used by the
             operator implementations and internals. If you want to construct a
             DTensor from a local tensor, consider using ``DTensor.from_local``, if
@@ -260,31 +271,6 @@ class DTensor(torch.Tensor):
             already have tensor initialized and want to shard this tensor),
             consider using ``distribute_tensor``.
         """
-        if local_tensor.requires_grad and not requires_grad:
-            warnings.warn(
-                "To construct DTensor from torch.Tensor, it's recommended to "
-                "use local_tensor.detach() and make requires_grad consistent."
-            )
-
-        # new method instruct wrapper tensor from local_tensor and add
-        # placement spec, it does not do actual distribution
-        assert spec.tensor_meta is not None, "TensorMeta should not be None!"
-
-        r = torch.Tensor._make_dtensor(
-            cls,
-            spec.tensor_meta.shape,
-            spec.tensor_meta.stride,
-            local_tensor,
-            requires_grad,
-        )
-
-        r._spec = spec
-        r._local_tensor = local_tensor
-        return r
-
-    @torch._disable_dynamo
-    @mark_subclass_constructor_exportable_experimental
-    def __init__(self, *args, **kwargs):
         super().__init__()
 
     # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently.
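As the docstring above notes, users should reach for ``DTensor.from_local`` or ``distribute_tensor`` rather than the private constructor. A single-process sketch of the public path (module paths as in recent PyTorch releases, where torch.distributed.tensor is public; assumes the gloo backend is available):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor

# A one-rank "cluster" is enough to exercise the public constructor.
dist.init_process_group(
    "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)
mesh = init_device_mesh("cpu", (1,))
dt = distribute_tensor(torch.randn(4, 4), mesh, [Replicate()])
print(type(dt).__name__, dt.shape)  # DTensor torch.Size([4, 4])
dist.destroy_process_group()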
torch/overrides.py
@@ -362,7 +362,7 @@ def get_ignored_functions() -> set[Callable]:
         Tensor._view_func,
         Tensor._view_func_unsafe,
         Tensor._rev_view_func_unsafe,
-        Tensor._make_dtensor,
+        Tensor._dtensor__new__,
         Tensor._make_wrapper_subclass,
         Tensor._python_dispatch.__get__,
         Tensor._has_symbolic_sizes_strides.__get__,
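get_ignored_functions lists callables that the __torch_function__ override machinery skips, so renaming the binding keeps the ignore list in sync. A quick membership check (again build-dependent):

import torch
from torch.overrides import get_ignored_functions

fn = getattr(torch.Tensor, "_dtensor__new__", None)
print(fn is not None and fn in get_ignored_functions())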