[custom_ops][perf] Move expensive pytree traversals of tensors to C++ (#148555)
(benchmark for 1 call)

Before:
```
└─ $ python ~/task_custom_ops_perf/test_custom_ops_perf_repro.py
DO_BENCH mutate: 77.72445678710938 us
PROFILE:/home/ivankobzarev/task_custom_ops_perf/mutate.json
DO_BENCH no_mutate: 64.61143493652344 us
PROFILE:/home/ivankobzarev/task_custom_ops_perf/no_mutate.json
DO_BENCH direct_mutate: 11.682510375976562 us
PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_mutate.json
DO_BENCH direct_no_mutate: 18.596649169921875 us
PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_no_mutate.json
```

After:
```
└─ $ python ~/task_custom_ops_perf/test_custom_ops_perf_repro.py
DO_BENCH mutate: 47.6837158203125 us
PROFILE:/home/ivankobzarev/task_custom_ops_perf/mutate.json
DO_BENCH no_mutate: 31.709671020507812 us
PROFILE:/home/ivankobzarev/task_custom_ops_perf/no_mutate.json
DO_BENCH direct_mutate: 10.967254638671875 us
PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_mutate.json
DO_BENCH direct_no_mutate: 10.728836059570312 us
PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_no_mutate.json
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148555
Approved by: https://github.com/zou3519
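The repro script referenced above (`test_custom_ops_perf_repro.py`) is not part of this commit. A minimal sketch of how such a per-call benchmark could be gathered — with a hypothetical op name and iteration count, not the author's actual script — might look like this:

```
import time
import torch

# Hypothetical custom op used only to time the Python-side custom-op overhead.
@torch.library.custom_op("bench_demo::add_one", mutates_args=())
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

x = torch.randn(8)

def bench(fn, iters=10_000):
    for _ in range(100):  # warm-up
        fn(x)
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    return (time.perf_counter() - start) / iters * 1e6  # average us per call

print(f"DO_BENCH custom op: {bench(add_one):.3f} us")
print(f"DO_BENCH direct:    {bench(lambda t: t + 1):.3f} us")
```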
parent 518563d6ef
commit d686d04c2f
@@ -3910,6 +3910,115 @@ Please use `add.register_fake` to add an fake impl.""",
        self.assertTrue(called)
        self.assertEqual(result, x + y)

    def test_any_requires_grad(self):
        test_fn = torch._C._any_requires_grad

        self.assertTrue(
            test_fn(
                torch.zeros(1, requires_grad=True), torch.ones(1, requires_grad=True)
            )
        )
        self.assertFalse(test_fn(torch.ones(1), torch.zeros(1)))
        self.assertTrue(
            test_fn(
                [torch.zeros(1, requires_grad=True), torch.ones(1, requires_grad=True)]
            )
        )
        # _C_any_requires_grad supports only List[Tensor] in args, not List[List[Tensor]]
        self.assertFalse(test_fn([[torch.zeros(1, requires_grad=True)]], torch.ones(1)))
        self.assertFalse(test_fn([torch.zeros(1), torch.ones(1)]))
        self.assertTrue(test_fn(torch.zeros(1), a=torch.ones(1, requires_grad=True)))
        self.assertFalse(test_fn(torch.zeros(1), a=torch.ones(1)))
        self.assertTrue(
            test_fn([torch.zeros(1, requires_grad=True), torch.ones(1)], torch.zeros(1))
        )
        self.assertFalse(test_fn([torch.zeros(1), torch.ones(1)], torch.zeros(1)))

    def test_any_output_is_alias_to_input_or_output(self):
        test_fn = torch._C._any_output_is_alias_to_input_or_output

        x = torch.randn(2, 2)
        y = torch.randn(2, 2)
        self.assertTrue(
            test_fn(
                (x,),
                {},
                (x.t(),),
            )
        )
        self.assertFalse(test_fn((x,), None, (2 * x,)))
        self.assertTrue(
            test_fn(
                (),
                {"a": x.view(-1)},
                (x,),
            )
        )
        self.assertTrue(
            test_fn(
                (),
                {"a": x.view(-1)},
                (x.t(),),
            )
        )
        self.assertTrue(test_fn((y,), {}, (y[1:],)))
        self.assertFalse(
            test_fn(
                (x,),
                {"a": x},
                (),
            )
        )
        self.assertFalse(
            test_fn(
                (torch.tensor([]),),
                {},
                (torch.tensor([]),),
            )
        )
        self.assertTrue(
            test_fn(
                ([x], x + 1),
                {},
                (x.t(),),
            )
        )
        self.assertTrue(
            test_fn(
                ([x], x + 1),
                {},
                ([x.t()], x + 1),
            )
        )
        self.assertTrue(
            test_fn(
                ([x], x),
                {},
                ([x.t()], x + 1),
            )
        )
        self.assertTrue(
            test_fn(
                ([x, 1], x),
                {},
                ([x.t()], x + 1),
            )
        )
        self.assertTrue(
            test_fn(
                ([[x]], x),
                {},
                ([x.t()], x + 1),
            )
        )
        self.assertTrue(
            test_fn(
                ([[1, x], 2], 3),
                {},
                ([x.t()], x + 1),
            )
        )


class MiniOpTestOther(CustomOpTestCaseBase):
    test_ns = "mini_op_test"

@@ -1357,6 +1357,8 @@ def _set_grad_enabled(enabled: _bool) -> None: ...
def is_grad_enabled() -> _bool: ...
def _set_fwd_grad_enabled(enabled: _bool) -> None: ...
def _is_fwd_grad_enabled() -> _bool: ...
def _any_requires_grad(*args, **kwargs) -> _bool: ...
def _any_output_is_alias_to_input_or_output(*args, **kwargs) -> _bool: ...
def is_inference_mode_enabled() -> _bool: ...
@overload
def set_autocast_enabled(device_type: str, enabled: _bool) -> None: ...

@@ -105,9 +105,7 @@ def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable:
     # The dispatcher passes any keyword-only-args as kwargs and the
     # rest of the args (even if specified as kwargs) as args.
     def autograd_impl(keyset, *args, **keyword_only_args):
-        if _C.is_grad_enabled() and _pytree.tree_any_only(
-            Tensor, lambda x: x.requires_grad, args, not_list_of_tensor
-        ):
+        if _C.is_grad_enabled() and _C._any_requires_grad(*args):
             result = Generated.apply(*args, Metadata(keyset, keyword_only_args))  # type: ignore[attr-defined]
         else:
             result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))

@@ -338,9 +338,10 @@ class CustomOpDef:
                 fn = self._backend_fns[device_type]
                 return inspect.getmodule(fn)

-            utils.check_aliasing_constraint(
+            utils._c_check_aliasing_constraint(
                 self._name,
-                utils.iter_tensors(args, kwargs),
+                args,
+                kwargs,
                 result,
                 get_module,
             )

@@ -373,6 +373,31 @@ def check_aliasing_constraint(name, prev, result, get_module=lambda: "???"):
        storages.add(key)


def _c_check_aliasing_constraint(name, args, kwargs, result, get_module=lambda: "???"):
    """
    Custom operators' outputs must not alias their inputs or each other.
    This version uses a C++ implementation for performance.
    Only list containers are traversed; tensors inside lists that also
    contain non-Tensor elements are still checked.
    """
    tuple_result = result
    if not isinstance(result, tuple):
        tuple_result = (result,)
    if _C._any_output_is_alias_to_input_or_output(args, kwargs, tuple_result):
        raise RuntimeError(
            f"{name} (with implementation in {get_module()}): "
            f"The output of this custom operator (1) must not "
            f"also be an input to this custom operator and "
            f"(2) may not alias any inputs to this custom operator "
            f"or other returns. "
            f"The most common way to trigger this error is if "
            f"we have y = custom_op(x) and y and x are the same Tensor. "
            f"Please instead return a clone of the offending output "
            f"tensor(s) (e.g. return x.clone()) or refactor the custom "
            f"operator to not return y."
        )


class MutationChecker:
    """
    Check if an operator mutated its arguments.

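The C++-backed check fires in the same situations as the old pytree-based one, most commonly when a custom op returns one of its inputs or a view of one. A small sketch of the failure mode it guards against, using a hypothetical op name:

```
import torch

@torch.library.custom_op("demo::bad_identity", mutates_args=())
def bad_identity(x: torch.Tensor) -> torch.Tensor:
    return x  # returning an input directly violates the aliasing constraint

try:
    bad_identity(torch.randn(3))
except RuntimeError as e:
    print(e)  # "... must not also be an input to this custom operator ..."
```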
@@ -955,6 +955,133 @@ static PyObject* is_fwd_grad_enabled(PyObject* _unused, PyObject* arg) {
  END_HANDLE_TH_ERRORS
}

template <bool skip_tensors_in_non_tensorlist>
static bool visit(
    PyObject* o,
    const std::function<bool(at::Tensor&)>& visit_tensor) {
  if (THPVariable_Check(o)) {
    auto t = THPVariable_Unpack(o);
    if (visit_tensor(t)) {
      return true;
    }
  } else if (PyList_Check(o)) {
    // Check that this List is TensorList
    if constexpr (skip_tensors_in_non_tensorlist) {
      for (const auto i : c10::irange(PyList_GET_SIZE(o))) {
        if (!THPVariable_Check(PyList_GET_ITEM(o, i))) {
          return false;
        }
      }
    }
    for (const auto i : c10::irange(PyList_GET_SIZE(o))) {
      if (visit<skip_tensors_in_non_tensorlist>(
              PyList_GET_ITEM(o, i), visit_tensor)) {
        return true;
      };
    }
  }
  return false;
}

// Visits tensors in args and kwargs; only list containers are traversed.
// skip_tensors_in_non_tensorlist skips any list that contains a non-Tensor.
// If the lambda returns true, traversal short-circuits and stops there.
template <bool skip_tensors_in_non_tensorlist>
static void visit_tensors(
    PyObject* args,
    PyObject* kwargs,
    const std::function<bool(at::Tensor&)>& visit_tensor) {
  if (args && PyTuple_Check(args)) {
    for (const auto i : c10::irange(PyTuple_GET_SIZE(args))) {
      if (visit<skip_tensors_in_non_tensorlist>(
              PyTuple_GET_ITEM(args, i), visit_tensor)) {
        return;
      }
    }
  }
  if (kwargs && PyDict_Check(kwargs)) {
    auto vals = PyDict_Values(kwargs);
    for (const auto i : c10::irange(PyList_GET_SIZE(vals))) {
      if (visit<skip_tensors_in_non_tensorlist>(
              PyList_GET_ITEM(vals, i), visit_tensor)) {
        return;
      }
    }
  }
}

// Returns true if any tensor leaf in args or kwargs has requires_grad set.
// Only List[Tensor] containers in args are supported.
static PyObject* any_requires_grad(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  bool has_requires_grad = false;
  visit_tensors<true>(args, kwargs, [&has_requires_grad](at::Tensor& t) {
    if (t.requires_grad()) {
      has_requires_grad = true;
      return true;
    }
    return false;
  });
  if (has_requires_grad) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
  END_HANDLE_TH_ERRORS
}

// Checks the aliasing constraint for custom ops:
// returns true if any output aliases any input or another output.
// Args:
//   args[0] - input args
//   args[1] - input kwargs
//   args[2] - outputs
// Only list containers are traversed; tensors in lists that also contain
// non-Tensor elements are still checked.
static PyObject* any_output_is_alias_to_input_or_output(
    PyObject* _unused,
    PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject* inps = PyTuple_GET_ITEM(args, 0);
  PyObject* inps_kwargs = PyTuple_GET_ITEM(args, 1);
  PyObject* outs = PyTuple_GET_ITEM(args, 2);
  std::unordered_set<void*> s;
  visit_tensors<false>(inps, inps_kwargs, [&s](at::Tensor& t) {
    if (!t.storage()) {
      return false;
    }
    auto* cp = t.storage().data_ptr().get_context();
    if (cp) {
      s.insert(cp);
    }
    return false;
  });
  bool ret = false;
  visit_tensors<false>(outs, nullptr, [&s, &ret](at::Tensor& t) {
    if (!t.storage()) {
      return false;
    }
    auto* cp = t.storage().data_ptr().get_context();
    if (!cp) {
      return false;
    }
    if (s.find(cp) != s.end()) {
      ret = true;
      return true;
    }
    s.insert(cp);
    return false;
  });
  if (ret) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
  END_HANDLE_TH_ERRORS
}

static PyObject* set_multithreading_enabled(
    PyObject* self,
    PyObject* args,

@@ -1326,6 +1453,14 @@ static PyMethodDef methods[] = {
     nullptr},
    {"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr},
    {"_set_fwd_grad_enabled", set_fwd_grad_enabled, METH_O, nullptr},
    {"_any_requires_grad",
     castPyCFunctionWithKeywords(any_requires_grad),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_any_output_is_alias_to_input_or_output",
     any_output_is_alias_to_input_or_output,
     METH_VARARGS,
     nullptr},
    {"_is_fwd_grad_enabled", is_fwd_grad_enabled, METH_NOARGS, nullptr},
    {"is_inference_mode_enabled",
     is_inference_mode_enabled,
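The two functions are registered as private `torch._C` helpers rather than public API. A small usage sketch mirroring the new tests above (internal bindings, subject to change):

```
import torch

x = torch.randn(2, 2, requires_grad=True)
y = torch.randn(2, 2)

# True: at least one tensor leaf in args/kwargs requires grad.
print(torch._C._any_requires_grad(x, y))        # True
print(torch._C._any_requires_grad(y, a=y + 1))  # False

# Arguments: (input args tuple, input kwargs dict or None, outputs tuple).
# True: x.t() shares storage with the input x.
print(torch._C._any_output_is_alias_to_input_or_output((x,), {}, (x.t(),)))    # True
print(torch._C._any_output_is_alias_to_input_or_output((y,), None, (2 * y,)))  # False
```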