Optimize mutable torch.library.custom_op overhead (#139513)

We don't need to loop over all of the args and kwargs in the ADInplaceOrView key; we just need to bump the version counter on the args and kwargs that are mutable. On the benchmark mentioned in https://github.com/pytorch/pytorch/issues/139494, this took the timings from

```
mutate2 = 61.72943878173828
no_mutate2 = 36.89440155029297
mutate = 236.3092498779297
no_mutate = 59.31964874267578
```

to

```
mutate2 = 47.976478576660156
no_mutate2 = 38.37468719482422
mutate = 71.21315002441406
no_mutate = 59.7432975769043
```

Test Plan:
- existing tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139513
Approved by: https://github.com/bdhirsh
ghstack dependencies: #139509
parent 9dc5851f5d
commit 27ec3921bc
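For context, the version-counter bookkeeping being optimized here is what lets autograd detect in-place mutation of tensors saved for backward. A minimal sketch of the kind of mutable custom op involved (the op name `mylib::fill_twos` and the shapes are made up for illustration):

```python
import torch
from torch import Tensor


# A mutable custom op: mutates_args tells PyTorch that `x` is written to,
# so the ADInplaceOrView kernel must bump x's version counter on every call.
@torch.library.custom_op("mylib::fill_twos", mutates_args={"x"})
def fill_twos(x: Tensor) -> None:
    x.fill_(2.0)


x = torch.zeros(3)
before = x._version
fill_twos(x)
# The version counter advanced, so autograd can tell x was mutated in-place.
assert x._version > before
```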
torch/_library/custom_ops.py:

```diff
@@ -594,19 +594,13 @@ class CustomOpDef:
         schema = self._opoverload._schema
         if schema.is_mutable:
+            mutated_idxs, mutated_keys = utils.mutated_args_kwargs(schema)

             def adinplaceorview_impl(keyset, *args, **kwargs):
-                for arg, val in utils.zip_schema(schema, args, kwargs):
-                    if not arg.alias_info:
-                        continue
-                    if not arg.alias_info.is_write:
-                        continue
-                    if isinstance(val, Tensor):
-                        torch.autograd.graph.increment_version(val)
-                    elif isinstance(val, (tuple, list)):
-                        for v in val:
-                            if isinstance(v, Tensor):
-                                torch.autograd.graph.increment_version(v)
+                for idx in mutated_idxs:
+                    increment_version(args[idx])
+                for key in mutated_keys:
+                    increment_version(kwargs[key])
                 with _C._AutoDispatchBelowADInplaceOrView():
                     return self._opoverload.redispatch(
                         keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
                     )
```
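The hunk above is the core of the change: instead of walking the whole schema with `utils.zip_schema` and checking `alias_info` on every call, the mutable positions and kwarg names are computed once, and only those arguments get their versions bumped. A rough timing harness in the spirit of the benchmark from issue #139494 (the exact script is not reproduced here; the op names are illustrative):

```python
import timeit

import torch
from torch import Tensor


# A mutable op: takes the version-bumping path in adinplaceorview_impl.
@torch.library.custom_op("mylib::mutate", mutates_args={"x"})
def mutate(x: Tensor) -> None:
    x.add_(1)


# A non-mutable op: skips the version-bumping path entirely.
@torch.library.custom_op("mylib::no_mutate", mutates_args=())
def no_mutate(x: Tensor) -> Tensor:
    return x + 1


x = torch.ones(4)
print("mutate   :", timeit.timeit(lambda: mutate(x), number=10_000))
print("no_mutate:", timeit.timeit(lambda: no_mutate(x), number=10_000))
```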
torch/_library/custom_ops.py (continued):

```diff
@@ -740,6 +734,15 @@ class CustomOpDef:
         return register(func)


+def increment_version(val: Any) -> None:
+    if isinstance(val, Tensor):
+        torch.autograd.graph.increment_version(val)
+    elif isinstance(val, (tuple, list)):
+        for v in val:
+            if isinstance(v, Tensor):
+                torch.autograd.graph.increment_version(v)
+
+
 # NOTE: [Supporting decorator and non-decorator usage]
 #
 # Some APIs may be both used as a decorator and not as a decorator.
```
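The new `increment_version` helper wraps `torch.autograd.graph.increment_version`, applying it to a single `Tensor` or to each `Tensor` inside a tuple or list (other values are ignored). A small sketch of the underlying primitive it delegates to:

```python
import torch

t = torch.zeros(2)
v0 = t._version
# Bump the autograd version counter without touching the data; this is
# what the helper does for each mutated argument of the custom op.
torch.autograd.graph.increment_version(t)
assert t._version == v0 + 1
```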
torch/_library/utils.py:

```diff
@@ -3,7 +3,7 @@ import dataclasses
 import inspect
 import sys
 import warnings
-from typing import Any, Callable, Dict, Iterable, Iterator, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Tuple, Union

 import torch
 import torch.utils._pytree as pytree
```
torch/_library/utils.py (continued):

```diff
@@ -463,3 +463,15 @@ def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
     if opdef._abstract_fn is not None:
         return True
     return False
+
+
+def mutated_args_kwargs(schema: _C.FunctionSchema) -> Tuple[List[int], List[str]]:
+    idxs = []
+    keys = []
+    for i, info in enumerate(schema.arguments):
+        if info.alias_info is not None and info.alias_info.is_write:
+            if info.kwarg_only:
+                keys.append(info.name)
+            else:
+                idxs.append(i)
+    return idxs, keys
```
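`mutated_args_kwargs` splits a schema's mutable arguments into positional indices and kwarg-only names, matching how `adinplaceorview_impl` receives `args` and `kwargs`. A quick check of what it computes, assuming the helper lands in `torch._library.utils` as the surrounding hunk suggests (the schema string is made up for illustration):

```python
import torch
from torch._library.utils import mutated_args_kwargs

# `Tensor(a!)` marks an argument as mutable; `*` starts kwarg-only arguments.
schema = torch._C.parse_schema(
    "mylib::foo(Tensor(a!) x, Tensor y, *, Tensor(b!) out) -> ()"
)
idxs, keys = mutated_args_kwargs(schema)
assert idxs == [0]       # x is a mutable positional argument
assert keys == ["out"]   # out is a mutable kwarg-only argument
```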