Optimize mutable torch.library.custom_op overhead (#139513)

We don't need to loop over all of the args and kwargs in the
ADInplaceOrView key; we only need to bump the version counter on the
args and kwargs that are mutable.
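
For context, this path is hit by mutable custom ops declared via `mutates_args`.
A minimal sketch of the kind of op involved (the op name and body are
illustrative, not from this PR):
```python
import torch
from torch import Tensor

# Because "x" is declared in mutates_args, every call to this op goes
# through the ADInplaceOrView kernel, which must bump x's version
# counter so autograd can detect the in-place mutation.
@torch.library.custom_op("mylib::my_sin_", mutates_args={"x"})
def my_sin_(x: Tensor) -> None:
    x.sin_()

x = torch.randn(3)
my_sin_(x)
```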

On the benchmark mentioned in
https://github.com/pytorch/pytorch/issues/139494,
this change took the timings from
```
mutate2 = 61.72943878173828
no_mutate2 = 36.89440155029297
mutate = 236.3092498779297
no_mutate = 59.31964874267578
```
to
```
mutate2 = 47.976478576660156
no_mutate2 = 38.37468719482422
mutate = 71.21315002441406
no_mutate = 59.7432975769043
```
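
(The issue's script isn't reproduced here; below is a hedged sketch of the
kind of micro-benchmark that produces numbers like these. The op names,
tensor shape, and iteration count are assumptions, not the issue's code.)
```python
import torch
from torch import Tensor
from torch.utils.benchmark import Timer

# Hypothetical pair of ops: one declares a mutation, one does not.
@torch.library.custom_op("bench::mutate", mutates_args={"x"})
def mutate(x: Tensor) -> None:
    x.add_(1)

@torch.library.custom_op("bench::no_mutate", mutates_args=())
def no_mutate(x: Tensor) -> Tensor:
    return x + 1

x = torch.randn(8)
for name, fn in [("mutate", mutate), ("no_mutate", no_mutate)]:
    t = Timer(stmt="fn(x)", globals={"fn": fn, "x": x})
    # Measurement.mean is in seconds; print microseconds per call.
    print(f"{name} = {t.timeit(10_000).mean * 1e6}")
```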

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139513
Approved by: https://github.com/bdhirsh
ghstack dependencies: #139509
rzou authored 2024-11-01 13:21:12 -07:00; committed by PyTorch MergeBot
parent 9dc5851f5d
commit 27ec3921bc
2 changed files with 27 additions and 12 deletions

torch/_library/custom_ops.py

```diff
@@ -594,19 +594,13 @@ class CustomOpDef:
         schema = self._opoverload._schema
         if schema.is_mutable:
+            mutated_idxs, mutated_keys = utils.mutated_args_kwargs(schema)

             def adinplaceorview_impl(keyset, *args, **kwargs):
-                for arg, val in utils.zip_schema(schema, args, kwargs):
-                    if not arg.alias_info:
-                        continue
-                    if not arg.alias_info.is_write:
-                        continue
-                    if isinstance(val, Tensor):
-                        torch.autograd.graph.increment_version(val)
-                    elif isinstance(val, (tuple, list)):
-                        for v in val:
-                            if isinstance(v, Tensor):
-                                torch.autograd.graph.increment_version(v)
+                for idx in mutated_idxs:
+                    increment_version(args[idx])
+                for key in mutated_keys:
+                    increment_version(kwargs[key])
                 with _C._AutoDispatchBelowADInplaceOrView():
                     return self._opoverload.redispatch(
                         keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
                     )
@@ -740,6 +734,15 @@ class CustomOpDef:
         return register(func)


+def increment_version(val: Any) -> None:
+    if isinstance(val, Tensor):
+        torch.autograd.graph.increment_version(val)
+    elif isinstance(val, (tuple, list)):
+        for v in val:
+            if isinstance(v, Tensor):
+                torch.autograd.graph.increment_version(v)
+
+
 # NOTE: [Supporting decorator and non-decorator usage]
 #
 # Some APIs may be both used as a decorator and not as a decorator.
```
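
(For anyone unfamiliar with version counters: bumping a tensor's version is
how autograd notices in-place mutation, so it can error on stale saved
tensors instead of silently computing wrong gradients. A small illustration
using the internal `_version` attribute:)
```python
import torch

t = torch.randn(3)
before = t._version                        # autograd's per-tensor version counter
torch.autograd.graph.increment_version(t)  # what the new helper calls per mutated arg
assert t._version == before + 1
```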

torch/_library/utils.py

```diff
@@ -3,7 +3,7 @@ import dataclasses
 import inspect
 import sys
 import warnings
-from typing import Any, Callable, Dict, Iterable, Iterator, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Tuple, Union

 import torch
 import torch.utils._pytree as pytree
@@ -463,3 +463,15 @@ def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
     if opdef._abstract_fn is not None:
         return True
     return False
+
+
+def mutated_args_kwargs(schema: _C.FunctionSchema) -> Tuple[List[int], List[str]]:
+    idxs = []
+    keys = []
+    for i, info in enumerate(schema.arguments):
+        if info.alias_info is not None and info.alias_info.is_write:
+            if info.kwarg_only:
+                keys.append(info.name)
+            else:
+                idxs.append(i)
+    return idxs, keys
```
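
(A quick illustration of what the new helper computes, on a build that
includes this PR. `torch._C.parse_schema` builds a `FunctionSchema` from a
schema string; `(a!)` marks a mutable argument and `*` starts the
keyword-only section. The schema string is made up for the example:)
```python
import torch
from torch._library.utils import mutated_args_kwargs

schema = torch._C.parse_schema(
    "mylib::foo(Tensor(a!) x, Tensor y, *, Tensor(b!) z) -> ()"
)
idxs, keys = mutated_args_kwargs(schema)
print(idxs, keys)  # [0] ['z']: positional index 0 and kwarg-only name "z"
```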