diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py
index 2c57ceb9baf..b0bc2ea9458 100644
--- a/torch/_library/custom_ops.py
+++ b/torch/_library/custom_ops.py
@@ -594,19 +594,13 @@ class CustomOpDef:
 
         schema = self._opoverload._schema
         if schema.is_mutable:
+            mutated_idxs, mutated_keys = utils.mutated_args_kwargs(schema)
 
             def adinplaceorview_impl(keyset, *args, **kwargs):
-                for arg, val in utils.zip_schema(schema, args, kwargs):
-                    if not arg.alias_info:
-                        continue
-                    if not arg.alias_info.is_write:
-                        continue
-                    if isinstance(val, Tensor):
-                        torch.autograd.graph.increment_version(val)
-                    elif isinstance(val, (tuple, list)):
-                        for v in val:
-                            if isinstance(v, Tensor):
-                                torch.autograd.graph.increment_version(v)
+                for idx in mutated_idxs:
+                    increment_version(args[idx])
+                for key in mutated_keys:
+                    increment_version(kwargs[key])
                 with _C._AutoDispatchBelowADInplaceOrView():
                     return self._opoverload.redispatch(
                         keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
@@ -740,6 +734,15 @@ class CustomOpDef:
         return register(func)
 
 
+def increment_version(val: Any) -> None:
+    if isinstance(val, Tensor):
+        torch.autograd.graph.increment_version(val)
+    elif isinstance(val, (tuple, list)):
+        for v in val:
+            if isinstance(v, Tensor):
+                torch.autograd.graph.increment_version(v)
+
+
 # NOTE: [Supporting decorator and non-decorator usage]
 #
 # Some APIs may be both used as a decorator and not as a decorator.
diff --git a/torch/_library/utils.py b/torch/_library/utils.py
index 82ebdad018d..9c12ec9ebb0 100644
--- a/torch/_library/utils.py
+++ b/torch/_library/utils.py
@@ -3,7 +3,7 @@ import dataclasses
 import inspect
 import sys
 import warnings
-from typing import Any, Callable, Dict, Iterable, Iterator, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Tuple, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -463,3 +463,15 @@ def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
         if opdef._abstract_fn is not None:
             return True
     return False
+
+
+def mutated_args_kwargs(schema: _C.FunctionSchema) -> Tuple[List[int], List[str]]:
+    idxs = []
+    keys = []
+    for i, info in enumerate(schema.arguments):
+        if info.alias_info is not None and info.alias_info.is_write:
+            if info.kwarg_only:
+                keys.append(info.name)
+            else:
+                idxs.append(i)
+    return idxs, keys
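
Not part of the patch, just a minimal sketch of the path this change touches: any custom op registered with mutates_args gets a mutable schema, so the adinplaceorview_impl wrapper above runs on every call and bumps the version counter of each mutated argument. The op name mylib::fill_twos and the surrounding script are invented for illustration.

    import torch
    from torch import Tensor

    # Hypothetical example op; mutates_args={"x"} makes the inferred schema mutable,
    # so CustomOpDef registers the ADInplaceOrView kernel shown in the diff.
    @torch.library.custom_op("mylib::fill_twos", mutates_args={"x"})
    def fill_twos(x: Tensor) -> None:
        x.fill_(2.0)

    t = torch.randn(3)
    before = t._version
    fill_twos(t)
    assert torch.equal(t, torch.full((3,), 2.0))
    assert t._version > before  # mutated arg's version counter was bumped

    # The new helper reports argument 0 ("x") as positionally mutated:
    schema = fill_twos._opoverload._schema
    print(torch._library.utils.mutated_args_kwargs(schema))  # ([0], [])

The point of the refactor: mutated_args_kwargs walks the schema once at registration time, so the per-call closure only indexes into args/kwargs instead of re-zipping the schema against the arguments on every invocation.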