Optimize mutable torch.library.custom_op overhead (#139513)

We don't need to loop over all of the args and kwargs in the
ADInplaceOrView key; we only need to bump the version counter on the
args and kwargs that the schema marks as mutable.
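
For context, a minimal sketch (hypothetical op and names, not code from this PR) of a mutable custom op and the version bump the ADInplaceOrView kernel is responsible for:

```
import torch

# Hypothetical mutable custom op: mutates_args declares which inputs are
# written to in-place.
@torch.library.custom_op("mylib::add_one_", mutates_args=("x",))
def add_one_(x: torch.Tensor) -> None:
    x.add_(1)

t = torch.zeros(3)
v = t._version
torch.ops.mylib.add_one_(t)
# The ADInplaceOrView kernel bumps the version counter of every mutated
# input so autograd can detect stale saved tensors.
assert t._version > v
```

Previously the kernel re-walked the whole schema on every call to find the mutated arguments; with this change, the mutated positions and keyword names are computed once at registration time.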

On the benchmark mentioned in
https://github.com/pytorch/pytorch/issues/139494
this made the time go from
```
mutate2 = 61.72943878173828
no_mutate2 = 36.89440155029297
mutate = 236.3092498779297
no_mutate = 59.31964874267578
```
to
```
mutate2 = 47.976478576660156
no_mutate2 = 38.37468719482422
mutate = 71.21315002441406
no_mutate = 59.7432975769043
```
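
A rough sketch of how per-call overhead like this can be measured (hypothetical ops and names; the actual script lives in the linked issue):

```
import torch
from torch.utils.benchmark import Timer

@torch.library.custom_op("mylib::mutate_", mutates_args=("x",))
def mutate_(x: torch.Tensor) -> None:
    x.add_(1)

@torch.library.custom_op("mylib::no_mutate", mutates_args=())
def no_mutate(x: torch.Tensor) -> torch.Tensor:
    return x + 1

x = torch.zeros(8)
for name, stmt in [
    ("mutate", "torch.ops.mylib.mutate_(x)"),
    ("no_mutate", "torch.ops.mylib.no_mutate(x)"),
]:
    # Measure steady-state per-call time for each op.
    print(name, Timer(stmt=stmt, globals={"torch": torch, "x": x}).blocked_autorange())
```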

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139513
Approved by: https://github.com/bdhirsh
ghstack dependencies: #139509
Author: rzou
Date: 2024-11-01 13:21:12 -07:00 (committed by PyTorch MergeBot)
Commit: 27ec3921bc (parent: 9dc5851f5d)
2 changed files with 27 additions and 12 deletions

torch/_library/custom_ops.py

```
@@ -594,19 +594,13 @@ class CustomOpDef:
         schema = self._opoverload._schema
         if schema.is_mutable:
+            mutated_idxs, mutated_keys = utils.mutated_args_kwargs(schema)
 
             def adinplaceorview_impl(keyset, *args, **kwargs):
-                for arg, val in utils.zip_schema(schema, args, kwargs):
-                    if not arg.alias_info:
-                        continue
-                    if not arg.alias_info.is_write:
-                        continue
-                    if isinstance(val, Tensor):
-                        torch.autograd.graph.increment_version(val)
-                    elif isinstance(val, (tuple, list)):
-                        for v in val:
-                            if isinstance(v, Tensor):
-                                torch.autograd.graph.increment_version(v)
+                for idx in mutated_idxs:
+                    increment_version(args[idx])
+                for key in mutated_keys:
+                    increment_version(kwargs[key])
                 with _C._AutoDispatchBelowADInplaceOrView():
                     return self._opoverload.redispatch(
                         keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
                     )
@@ -740,6 +734,15 @@ class CustomOpDef:
         return register(func)
 
 
+def increment_version(val: Any) -> None:
+    if isinstance(val, Tensor):
+        torch.autograd.graph.increment_version(val)
+    elif isinstance(val, (tuple, list)):
+        for v in val:
+            if isinstance(v, Tensor):
+                torch.autograd.graph.increment_version(v)
+
+
 # NOTE: [Supporting decorator and non-decorator usage]
 #
 # Some APIs may be both used as a decorator and not as a decorator.
```

torch/_library/utils.py

```
@@ -3,7 +3,7 @@ import dataclasses
 import inspect
 import sys
 import warnings
-from typing import Any, Callable, Dict, Iterable, Iterator, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Tuple, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -463,3 +463,15 @@ def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
     if opdef._abstract_fn is not None:
         return True
     return False
+
+
+def mutated_args_kwargs(schema: _C.FunctionSchema) -> Tuple[List[int], List[str]]:
+    idxs = []
+    keys = []
+    for i, info in enumerate(schema.arguments):
+        if info.alias_info is not None and info.alias_info.is_write:
+            if info.kwarg_only:
+                keys.append(info.name)
+            else:
+                idxs.append(i)
+    return idxs, keys
```
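
For illustration, a hedged sketch (hypothetical op name and schema) of the schema information this helper reads: alias_info.is_write marks mutated arguments, and kwarg_only decides whether they land in the index list or the key list.

```
import torch

# Hypothetical schema: positional arg `x` and kwarg-only arg `z` are mutated.
lib = torch.library.Library("mylib_demo", "FRAGMENT")
lib.define("foo(Tensor(a!) x, Tensor y, *, Tensor(b!) z) -> ()")

schema = torch.ops.mylib_demo.foo.default._schema
for i, arg in enumerate(schema.arguments):
    if arg.alias_info is not None and arg.alias_info.is_write:
        kind = "kwarg" if arg.kwarg_only else f"positional index {i}"
        print(arg.name, "->", kind)
# Expected: x -> positional index 0, z -> kwarg
```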