Optimize mutable torch.library.custom_op overhead (#139513)

We don't need to loop over all of the args and kwargs in the ADInplaceOrView key; we only need to bump the version counter on the args and kwargs that are actually mutable. On the benchmark mentioned in https://github.com/pytorch/pytorch/issues/139494 this took the timings from

```
mutate2 = 61.72943878173828
no_mutate2 = 36.89440155029297
mutate = 236.3092498779297
no_mutate = 59.31964874267578
```

to

```
mutate2 = 47.976478576660156
no_mutate2 = 38.37468719482422
mutate = 71.21315002441406
no_mutate = 59.7432975769043
```

Test Plan:
- existing tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139513
Approved by: https://github.com/bdhirsh
ghstack dependencies: #139509
This commit is contained in:
parent 9dc5851f5d
commit 27ec3921bc
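
For context (not code from the PR or the linked benchmark), here is a minimal sketch of the kind of mutable custom op this dispatch path serves; the op name and body are illustrative, and `mutates_args` is how the schema learns which inputs are written:

```python
import torch
from torch import Tensor

# Hypothetical mutable custom op ("mylib::inplace_add" is an illustrative name,
# not something taken from the PR).
@torch.library.custom_op("mylib::inplace_add", mutates_args={"x"})
def inplace_add(x: Tensor, y: Tensor) -> None:
    x.add_(y)

a = torch.zeros(3)
before = a._version
inplace_add(a, torch.ones(3))
# The ADInplaceOrView wrapper bumps the version counter of mutated inputs so
# autograd can detect in-place changes; this PR makes that bump touch only the
# args/kwargs the schema marks as mutable instead of iterating over all of them.
assert a._version > before
```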
The relevant hunks from the commit:

```diff
@@ -594,19 +594,13 @@ class CustomOpDef:
         schema = self._opoverload._schema
         if schema.is_mutable:
+            mutated_idxs, mutated_keys = utils.mutated_args_kwargs(schema)

             def adinplaceorview_impl(keyset, *args, **kwargs):
-                for arg, val in utils.zip_schema(schema, args, kwargs):
-                    if not arg.alias_info:
-                        continue
-                    if not arg.alias_info.is_write:
-                        continue
-                    if isinstance(val, Tensor):
-                        torch.autograd.graph.increment_version(val)
-                    elif isinstance(val, (tuple, list)):
-                        for v in val:
-                            if isinstance(v, Tensor):
-                                torch.autograd.graph.increment_version(v)
+                for idx in mutated_idxs:
+                    increment_version(args[idx])
+                for key in mutated_keys:
+                    increment_version(kwargs[key])
                 with _C._AutoDispatchBelowADInplaceOrView():
                     return self._opoverload.redispatch(
                         keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
                     )
```
```diff
@@ -740,6 +734,15 @@ class CustomOpDef:
         return register(func)


+def increment_version(val: Any) -> None:
+    if isinstance(val, Tensor):
+        torch.autograd.graph.increment_version(val)
+    elif isinstance(val, (tuple, list)):
+        for v in val:
+            if isinstance(v, Tensor):
+                torch.autograd.graph.increment_version(v)
+
+
 # NOTE: [Supporting decorator and non-decorator usage]
 #
 # Some APIs may be both used as a decorator and not as a decorator.
```
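
As a quick illustration of the primitive the new helper wraps (a standalone sketch, not code from the PR): `torch.autograd.graph.increment_version` bumps a tensor's version counter, which is what lets autograd detect that a tensor it saved for backward was later mutated.

```python
import torch

t = torch.zeros(2)
v0 = t._version
# Bump the version counter without writing to the tensor's storage.
torch.autograd.graph.increment_version(t)
assert t._version == v0 + 1
```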
```diff
@@ -3,7 +3,7 @@ import dataclasses
 import inspect
 import sys
 import warnings
-from typing import Any, Callable, Dict, Iterable, Iterator, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Tuple, Union

 import torch
 import torch.utils._pytree as pytree
```
```diff
@@ -463,3 +463,15 @@ def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
     if opdef._abstract_fn is not None:
         return True
     return False
+
+
+def mutated_args_kwargs(schema: _C.FunctionSchema) -> Tuple[List[int], List[str]]:
+    idxs = []
+    keys = []
+    for i, info in enumerate(schema.arguments):
+        if info.alias_info is not None and info.alias_info.is_write:
+            if info.kwarg_only:
+                keys.append(info.name)
+            else:
+                idxs.append(i)
+    return idxs, keys
```
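
A small sketch of what `mutated_args_kwargs` computes, assuming the helper lands in `torch._library.utils` (the module path is inferred from the surrounding helpers, not stated in the diff); the schema string below is made up for illustration:

```python
import torch
from torch._library import utils  # assumed location of the new helper

# Hand-written schema: x is a mutated positional arg, out is a mutated
# kwarg-only arg, and y is read-only.
schema = torch._C.parse_schema(
    "mylib::foo(Tensor(a!) x, Tensor y, *, Tensor(b!) out) -> ()"
)
idxs, keys = utils.mutated_args_kwargs(schema)
print(idxs, keys)  # expected output: [0] ['out']
```

Positional mutable arguments are reported by index (so the hot loop can do `args[idx]`), while kwarg-only ones are reported by name for a `kwargs[key]` lookup.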