mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
With ufmt in place https://github.com/pytorch/pytorch/pull/81157, we can now use it to gradually format all files. I'm breaking this down into multiple smaller batches to avoid too many merge conflicts later on. This batch (as copied from the current BLACK linter config): * `tools/**/*.py` Upcoming batches: * `torchgen/**/*.py` * `torch/package/**/*.py` * `torch/onnx/**/*.py` * `torch/_refs/**/*.py` * `torch/_prims/**/*.py` * `torch/_meta_registrations.py` * `torch/_decomp/**/*.py` * `test/onnx/**/*.py` Once they are all formatted, BLACK linter will be removed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81285 Approved by: https://github.com/suo
19 lines
574 B
Python
19 lines
574 B
Python
import functools
|
|
from typing import Callable
|
|
|
|
from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI
|
|
from torchgen.context import native_function_manager
|
|
from torchgen.utils import T
|
|
|
|
# Like torchgen.context.with_native_function, but for
# NativeFunctionWithDifferentiabilityInfo.
|
|
def with_native_function_with_differentiability_info(
    func: Callable[[NFWDI], T]
) -> Callable[[NFWDI], T]:
    """Decorate ``func`` so it runs inside ``native_function_manager``.

    The returned callable takes a single NativeFunctionWithDifferentiabilityInfo
    ``f``. It enters ``native_function_manager(f.func)`` and then calls
    ``func(f)`` inside that context, returning ``func``'s result unchanged.
    ``functools.wraps`` copies ``func``'s name, docstring and other metadata
    onto the returned callable.

    NOTE(review): whatever setup/teardown the manager performs is defined in
    ``torchgen.context``; it is not visible here.
    """

    def _run_in_manager(f: NFWDI) -> T:
        # The manager is keyed on the wrapped NativeFunction (``f.func``),
        # not on the differentiability-info wrapper object itself.
        manager = native_function_manager(f.func)
        with manager:
            return func(f)

    # Equivalent to applying ``@functools.wraps(func)`` to the inner function.
    return functools.wraps(func)(_run_in_manager)
|