pytorch/torch/_functorch/utils.py
Xuehai Pan e7eeee473c [BE][Easy][14/19] enforce style for empty lines in import segments in torch/_[a-c]*/ and torch/_[e-h]*/ and torch/_[j-z]*/ (#129765)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129765
Approved by: https://github.com/ezyang
2024-07-31 10:42:50 +00:00

41 lines
976 B
Python

# mypy: allow-untyped-defs
import contextlib
from typing import Tuple, Union
import torch
from torch._C._functorch import (
get_single_level_autograd_function_allowed,
set_single_level_autograd_function_allowed,
unwrap_if_dead,
)
from torch.utils._exposed_in import exposed_in
__all__ = [
    # Re-exported from torch.utils._exposed_in.
    "exposed_in",
    # Defined in this module.
    "argnums_t",
    "enable_single_level_autograd_function",
    "unwrap_dead_wrappers",
]
@contextlib.contextmanager
def enable_single_level_autograd_function():
    """Temporarily allow single-level autograd.Function inside a ``with`` block.

    Records the value returned by ``get_single_level_autograd_function_allowed()``
    and sets it to ``True`` for the body of the ``with`` statement. The recorded
    value is restored on exit, including when the body raises.

    Yields:
        None.
    """
    # Read the previous state *before* entering ``try``. If this read happened
    # inside ``try`` and raised, the ``finally`` clause would reference an
    # unbound ``prev_state``. The resulting UnboundLocalError would hide the
    # original exception.
    prev_state = get_single_level_autograd_function_allowed()
    try:
        set_single_level_autograd_function_allowed(True)
        yield
    finally:
        # Restore unconditionally so an exception in the body cannot leave
        # the setting switched on.
        set_single_level_autograd_function_allowed(prev_state)
def unwrap_dead_wrappers(args):
    """Apply ``unwrap_if_dead`` to every tensor in ``args``.

    Args:
        args: An iterable of arbitrary values.

    Returns:
        A tuple with the same length and order as ``args``. Each
        ``torch.Tensor`` element is replaced by ``unwrap_if_dead(element)``;
        every other element is kept as-is. Only top-level elements are
        inspected; nested containers are not traversed.
    """
    # NB: a flat generator is used instead of tree_map_only for performance
    # reasons.
    return tuple(
        value if not isinstance(value, torch.Tensor) else unwrap_if_dead(value)
        for value in args
    )
# Type alias for "a single int or a tuple of ints". Judging by the name, this
# is presumably the accepted type of ``argnums`` parameters elsewhere in
# functorch; verify against callers.
argnums_t = Union[
    int,
    Tuple[int, ...],
]