[custom_ops][perf] Move expensive pytree traversals of tensors to C++ (#148555)

(benchmark for 1 call)

Before:
```
└─ $ python ~/task_custom_ops_perf/test_custom_ops_perf_repro.py
DO_BENCH mutate: 77.72445678710938 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/mutate.json
DO_BENCH no_mutate: 64.61143493652344 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/no_mutate.json
DO_BENCH direct_mutate: 11.682510375976562 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_mutate.json
DO_BENCH direct_no_mutate: 18.596649169921875 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_no_mutate.json
```

After:
```
└─ $ python ~/task_custom_ops_perf/test_custom_ops_perf_repro.py
DO_BENCH mutate: 47.6837158203125 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/mutate.json
DO_BENCH no_mutate: 31.709671020507812 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/no_mutate.json
DO_BENCH direct_mutate: 10.967254638671875 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_mutate.json
DO_BENCH direct_no_mutate: 10.728836059570312 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_no_mutate.json
```
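The benchmark script itself is not part of this commit. As a rough illustration only, a harness along these lines produces comparable `DO_BENCH` output; the op name, shapes, and iteration counts here are assumptions, not the original `test_custom_ops_perf_repro.py`:

```python
# Hypothetical sketch of such a benchmark harness; not the original script.
import time

import torch


# A trivial non-mutating custom op; the "mutate" variants above would
# declare mutates_args accordingly.
@torch.library.custom_op("bench::no_mutate", mutates_args=())
def no_mutate(x: torch.Tensor) -> torch.Tensor:
    return x.clone()


def do_bench(name: str, fn, *args, warmup: int = 10, iters: int = 1000) -> None:
    # Warm up, then report the average wall-clock time of one call.
    for _ in range(warmup):
        fn(*args)
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    per_call_us = (time.perf_counter() - start) / iters * 1e6
    print(f"DO_BENCH {name}: {per_call_us} us")


if __name__ == "__main__":
    x = torch.randn(8)
    do_bench("no_mutate", no_mutate, x)  # goes through the custom op wrapper
    do_bench("direct_no_mutate", torch.clone, x)  # baseline: direct ATen call
```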

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148555
Approved by: https://github.com/zou3519
Author: IvanKobzarev, 2025-03-18 14:03:16 -07:00 (committed by PyTorch MergeBot)
parent 518563d6ef
commit d686d04c2f
6 changed files with 275 additions and 5 deletions

View File

@@ -3910,6 +3910,115 @@ Please use `add.register_fake` to add an fake impl.""",
        self.assertTrue(called)
        self.assertEqual(result, x + y)

    def test_any_requires_grad(self):
        test_fn = torch._C._any_requires_grad
        self.assertTrue(
            test_fn(
                torch.zeros(1, requires_grad=True), torch.ones(1, requires_grad=True)
            )
        )
        self.assertFalse(test_fn(torch.ones(1), torch.zeros(1)))
        self.assertTrue(
            test_fn(
                [torch.zeros(1, requires_grad=True), torch.ones(1, requires_grad=True)]
            )
        )
        # torch._C._any_requires_grad supports only List[Tensor] in args, not List[List[Tensor]]
        self.assertFalse(test_fn([[torch.zeros(1, requires_grad=True)]], torch.ones(1)))
        self.assertFalse(test_fn([torch.zeros(1), torch.ones(1)]))
        self.assertTrue(test_fn(torch.zeros(1), a=torch.ones(1, requires_grad=True)))
        self.assertFalse(test_fn(torch.zeros(1), a=torch.ones(1)))
        self.assertTrue(
            test_fn([torch.zeros(1, requires_grad=True), torch.ones(1)], torch.zeros(1))
        )
        self.assertFalse(test_fn([torch.zeros(1), torch.ones(1)], torch.zeros(1)))

    def test_any_output_is_alias_to_input_or_output(self):
        test_fn = torch._C._any_output_is_alias_to_input_or_output
        x = torch.randn(2, 2)
        y = torch.randn(2, 2)
        self.assertTrue(
            test_fn(
                (x,),
                {},
                (x.t(),),
            )
        )
        self.assertFalse(test_fn((x,), None, (2 * x,)))
        self.assertTrue(
            test_fn(
                (),
                {"a": x.view(-1)},
                (x,),
            )
        )
        self.assertTrue(
            test_fn(
                (),
                {"a": x.view(-1)},
                (x.t(),),
            )
        )
        self.assertTrue(test_fn((y,), {}, (y[1:],)))
        self.assertFalse(
            test_fn(
                (x,),
                {"a": x},
                (),
            )
        )
        self.assertFalse(
            test_fn(
                (torch.tensor([]),),
                {},
                (torch.tensor([]),),
            )
        )
        self.assertTrue(
            test_fn(
                ([x], x + 1),
                {},
                (x.t(),),
            )
        )
        self.assertTrue(
            test_fn(
                ([x], x + 1),
                {},
                ([x.t()], x + 1),
            )
        )
        self.assertTrue(
            test_fn(
                ([x], x),
                {},
                ([x.t()], x + 1),
            )
        )
        self.assertTrue(
            test_fn(
                ([x, 1], x),
                {},
                ([x.t()], x + 1),
            )
        )
        self.assertTrue(
            test_fn(
                ([[x]], x),
                {},
                ([x.t()], x + 1),
            )
        )
        self.assertTrue(
            test_fn(
                ([[1, x], 2], 3),
                {},
                ([x.t()], x + 1),
            )
        )

class MiniOpTestOther(CustomOpTestCaseBase):
    test_ns = "mini_op_test"

View File

@@ -1357,6 +1357,8 @@ def _set_grad_enabled(enabled: _bool) -> None: ...
def is_grad_enabled() -> _bool: ...
def _set_fwd_grad_enabled(enabled: _bool) -> None: ...
def _is_fwd_grad_enabled() -> _bool: ...
def _any_requires_grad(*args, **kwargs) -> _bool: ...
def _any_output_is_alias_to_input_or_output(*args, **kwargs) -> _bool: ...
def is_inference_mode_enabled() -> _bool: ...
@overload
def set_autocast_enabled(device_type: str, enabled: _bool) -> None: ...

View File

@@ -105,9 +105,7 @@ def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable:
    # The dispatcher passes any keyword-only-args as kwargs and the
    # rest of the args (even if specified as kwargs) as args.
    def autograd_impl(keyset, *args, **keyword_only_args):
        if _C.is_grad_enabled() and _pytree.tree_any_only(
            Tensor, lambda x: x.requires_grad, args, not_list_of_tensor
        ):
        if _C.is_grad_enabled() and _C._any_requires_grad(*args):
            result = Generated.apply(*args, Metadata(keyset, keyword_only_args))  # type: ignore[attr-defined]
        else:
            result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
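
For reference, the Python-side check that `_C._any_requires_grad(*args)` replaces here can be approximated as follows; this is a sketch of the intended semantics under the same "only flat List[Tensor] containers" restriction, not the removed pytree call itself:

```python
import torch


def any_requires_grad_py(*args, **kwargs) -> bool:
    # Pure-Python approximation of torch._C._any_requires_grad: only
    # top-level Tensors and lists made up entirely of Tensors are inspected;
    # any other container (e.g. nested lists) is skipped.
    for arg in list(args) + list(kwargs.values()):
        if isinstance(arg, torch.Tensor):
            if arg.requires_grad:
                return True
        elif isinstance(arg, list) and all(isinstance(x, torch.Tensor) for x in arg):
            if any(x.requires_grad for x in arg):
                return True
    return False
```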

View File

@@ -338,9 +338,10 @@ class CustomOpDef:
                fn = self._backend_fns[device_type]
                return inspect.getmodule(fn)

            utils.check_aliasing_constraint(
            utils._c_check_aliasing_constraint(
                self._name,
                utils.iter_tensors(args, kwargs),
                args,
                kwargs,
                result,
                get_module,
            )

View File

@@ -373,6 +373,31 @@ def check_aliasing_constraint(name, prev, result, get_module=lambda: "???"):
        storages.add(key)


def _c_check_aliasing_constraint(name, args, kwargs, result, get_module=lambda: "???"):
    """
    Custom operators' outputs must not alias their inputs or each other.
    This version uses a C++ implementation for performance.
    Only list containers are traversed, and tensors inside lists that also
    contain non-Tensor elements are still checked.
    """
    tuple_result = result
    if not isinstance(result, tuple):
        tuple_result = (result,)
    if _C._any_output_is_alias_to_input_or_output(args, kwargs, tuple_result):
        raise RuntimeError(
            f"{name} (with implementation in {get_module()}): "
            f"The output of this custom operator (1) must not "
            f"also be an input to this custom operator and "
            f"(2) may not alias any inputs to this custom operator "
            f"or other returns. "
            f"The most common way to trigger this error is if "
            f"we have y = custom_op(x) and y and x are the same Tensor. "
            f"Please instead return a clone of the offending output "
            f"tensor(s) (e.g. return x.clone()) or refactor the custom "
            f"operator to not return y."
        )


class MutationChecker:
    """
    Check if an operator mutated its arguments.
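
For context, this is the kind of aliasing violation the check above is meant to reject; the op names are hypothetical, and the error raised is the RuntimeError shown in `_c_check_aliasing_constraint`:

```python
import torch


@torch.library.custom_op("mylib::bad_identity", mutates_args=())
def bad_identity(x: torch.Tensor) -> torch.Tensor:
    # The output shares storage with the input, so calling this op
    # fails the aliasing check.
    return x.view_as(x)


@torch.library.custom_op("mylib::good_identity", mutates_args=())
def good_identity(x: torch.Tensor) -> torch.Tensor:
    # Returning a clone gives the output fresh storage.
    return x.clone()
```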

View File

@@ -955,6 +955,133 @@ static PyObject* is_fwd_grad_enabled(PyObject* _unused, PyObject* arg) {
  END_HANDLE_TH_ERRORS
}

template <bool skip_tensors_in_non_tensorlist>
static bool visit(
    PyObject* o,
    const std::function<bool(at::Tensor&)>& visit_tensor) {
  if (THPVariable_Check(o)) {
    auto t = THPVariable_Unpack(o);
    if (visit_tensor(t)) {
      return true;
    }
  } else if (PyList_Check(o)) {
    // Check that this List is a TensorList
    if constexpr (skip_tensors_in_non_tensorlist) {
      for (const auto i : c10::irange(PyList_GET_SIZE(o))) {
        if (!THPVariable_Check(PyList_GET_ITEM(o, i))) {
          return false;
        }
      }
    }
    for (const auto i : c10::irange(PyList_GET_SIZE(o))) {
      if (visit<skip_tensors_in_non_tensorlist>(
              PyList_GET_ITEM(o, i), visit_tensor)) {
        return true;
      }
    }
  }
  return false;
}
// Visits tensors in args and kwargs; only list containers are traversed.
// If skip_tensors_in_non_tensorlist is true, any list that contains a
// non-Tensor element is skipped entirely.
// The visitor returning true means short circuit: traversal stops immediately.
template <bool skip_tensors_in_non_tensorlist>
static void visit_tensors(
    PyObject* args,
    PyObject* kwargs,
    const std::function<bool(at::Tensor&)>& visit_tensor) {
  if (args && PyTuple_Check(args)) {
    for (const auto i : c10::irange(PyTuple_GET_SIZE(args))) {
      if (visit<skip_tensors_in_non_tensorlist>(
              PyTuple_GET_ITEM(args, i), visit_tensor)) {
        return;
      }
    }
  }
  if (kwargs && PyDict_Check(kwargs)) {
    auto vals = PyDict_Values(kwargs);
    for (const auto i : c10::irange(PyList_GET_SIZE(vals))) {
      if (visit<skip_tensors_in_non_tensorlist>(
              PyList_GET_ITEM(vals, i), visit_tensor)) {
        return;
      }
    }
  }
}
// Returns true if any tensor leaf in args or kwargs has requires_grad set.
// Only List[Tensor] containers in args are supported.
static PyObject* any_requires_grad(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  bool has_requires_grad = false;
  visit_tensors<true>(args, kwargs, [&has_requires_grad](at::Tensor& t) {
    if (t.requires_grad()) {
      has_requires_grad = true;
      return true;
    }
    return false;
  });
  if (has_requires_grad) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
  END_HANDLE_TH_ERRORS
}
// Checks the aliasing constraint for custom ops:
// returns true if any output aliases any input or another output.
// Args:
//   args[0] - input args
//   args[1] - input kwargs
//   args[2] - outputs
// Only list containers are traversed.
// Tensors inside lists that also contain non-Tensor elements are still
// checked.
static PyObject* any_output_is_alias_to_input_or_output(
    PyObject* _unused,
    PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject* inps = PyTuple_GET_ITEM(args, 0);
  PyObject* inps_kwargs = PyTuple_GET_ITEM(args, 1);
  PyObject* outs = PyTuple_GET_ITEM(args, 2);
  std::unordered_set<void*> s;
  visit_tensors<false>(inps, inps_kwargs, [&s](at::Tensor& t) {
    if (!t.storage()) {
      return false;
    }
    auto* cp = t.storage().data_ptr().get_context();
    if (cp) {
      s.insert(cp);
    }
    return false;
  });
  bool ret = false;
  visit_tensors<false>(outs, nullptr, [&s, &ret](at::Tensor& t) {
    if (!t.storage()) {
      return false;
    }
    auto* cp = t.storage().data_ptr().get_context();
    if (!cp) {
      return false;
    }
    if (s.find(cp) != s.end()) {
      ret = true;
      return true;
    }
    s.insert(cp);
    return false;
  });
  if (ret) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
  END_HANDLE_TH_ERRORS
}
static PyObject* set_multithreading_enabled(
    PyObject* self,
    PyObject* args,
@@ -1326,6 +1453,14 @@ static PyMethodDef methods[] = {
     nullptr},
    {"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr},
    {"_set_fwd_grad_enabled", set_fwd_grad_enabled, METH_O, nullptr},
    {"_any_requires_grad",
     castPyCFunctionWithKeywords(any_requires_grad),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_any_output_is_alias_to_input_or_output",
     any_output_is_alias_to_input_or_output,
     METH_VARARGS,
     nullptr},
    {"_is_fwd_grad_enabled", is_fwd_grad_enabled, METH_NOARGS, nullptr},
    {"is_inference_mode_enabled",
     is_inference_mode_enabled,
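
For intuition, the aliasing check added above reduces to comparing storage identities of the outputs against those of the inputs and of earlier outputs. A simplified Python analogue follows; note the C++ code keys on the storage DataPtr context rather than the raw address, and only traverses list containers:

```python
import torch


def any_output_aliases(inputs, outputs) -> bool:
    # Record the storage of every input tensor, then flag any output tensor
    # whose storage was already seen among the inputs or earlier outputs.
    def storage_key(t: torch.Tensor):
        st = t.untyped_storage()
        # Skip empty storages, mirroring the C++ handling of tensors whose
        # storage has no real allocation.
        return st.data_ptr() if st.nbytes() > 0 else None

    seen = set()
    for t in inputs:
        if isinstance(t, torch.Tensor) and storage_key(t) is not None:
            seen.add(storage_key(t))
    for t in outputs:
        if not isinstance(t, torch.Tensor):
            continue
        key = storage_key(t)
        if key is None:
            continue
        if key in seen:
            return True
        seen.add(key)
    return False


x = torch.randn(2, 2)
print(any_output_aliases([x], [x.t()]))      # True: the output is a view of the input
print(any_output_aliases([x], [x.clone()]))  # False: the clone owns fresh storage
```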