mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Adds a ruff lint rule to ban raising raw exceptions. Most of these should at the very least be runtime errors, value errors, type errors, or some other specific error type. There are hundreds of instances of these bad exception types already in the codebase, so I have noqa'd most of them. Hopefully this error code will get committers to rethink what exception type they should raise when they submit a PR. I also encourage people to gradually go and fix all the existing noqas that have been added so they can be removed over time and our exception typing can be improved. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124570 Approved by: https://github.com/ezyang
145 lines
5.3 KiB
Python
145 lines
5.3 KiB
Python
from collections import defaultdict
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Dict, List, Optional, Sequence, Tuple
|
|
|
|
from torchgen import dest
|
|
|
|
# disable import sorting to avoid circular dependency.
|
|
from torchgen.api.types import DispatcherSignature # isort:skip
|
|
from torchgen.context import method_with_native_function
|
|
from torchgen.executorch.model import ETKernelIndex
|
|
from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
|
|
from torchgen.selective_build.selector import SelectiveBuilder
|
|
from torchgen.utils import concatMap, Target
|
|
|
|
|
|
# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom
# operators. These stubs are meant to be used on the model authoring side.
@dataclass(frozen=True)
class ComputeNativeFunctionStub:
    """Generate a placeholder C++ kernel ("stub") for a native function.

    The stub has the operator's dispatcher signature (name prefixed with
    ``wrapper_CPU_<overload_name>_``) and a body that does nothing except,
    when the schema has returns, return a placeholder value of the right C++
    type: an out argument, a same-typed input argument, a default-constructed
    ``at::Tensor``, or a ``::std::tuple`` built from these.
    """

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        """Return the C++ stub definition for ``f``.

        :param f: the native function to generate a stub for.
        :return: the stub's C++ source, or ``None`` when ``Variant.function``
            is not among ``f.variants``.
        :raises RuntimeError: if ``f`` has a single return that is not a plain
            ``Tensor`` and there is no out argument or same-typed input
            argument to return instead.
        :raises AssertionError: if ``f`` has several returns that do not match
            the out arguments one-to-one and are not all plain ``Tensor``s.
        """
        if Variant.function not in f.variants:
            return None

        sig = DispatcherSignature.from_schema(
            f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
        )
        assert sig is not None

        returns = f.func.returns
        out_args = f.func.arguments.out
        if len(returns) == 0:
            ret_name = ""
        elif len(returns) == 1:
            if out_args:
                # Out variant: hand back the first out argument.
                ret_name = out_args[0].name
            else:
                # Otherwise forward the first non-out argument whose type
                # matches the return type ("" when there is none).
                ret_name = next(
                    (
                        a.name
                        for a in f.func.arguments.flat_non_out
                        if a.type == returns[0].type
                    ),
                    "",
                )
            if not ret_name:
                if returns[0].type == BaseType(BaseTy.Tensor):
                    # Nothing to forward: return an empty tensor.
                    ret_name = "at::Tensor()"
                else:
                    # RuntimeError (not a bare Exception) so callers get a
                    # meaningful type; still caught by `except Exception`.
                    raise RuntimeError(f"Can't handle this return type {f.func}")
        elif len(out_args) == len(returns):
            # Returns a tuple of references to the out arguments.
            tensor_type = "at::Tensor &"
            comma = ", "
            ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(returns))}>(
                {comma.join([r.name for r in out_args])}
            )"""
        else:
            assert all(
                a.type == BaseType(BaseTy.Tensor) for a in returns
            ), f"Only support tensor returns but got {returns}"
            # Returns a tuple of empty tensors.
            tensor_type = "at::Tensor"
            comma = ", "
            ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(returns))}>(
                {comma.join(["at::Tensor()" for _ in returns])}
            )"""
        ret_str = f"return {ret_name};" if len(returns) > 0 else ""
        return f"""
{sig.defn()} {{
    {ret_str}
}}
    """
|
|
|
|
|
|
def gen_custom_ops_registration(
    *,
    native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    kernel_index: ETKernelIndex,
    rocm: bool,
) -> Tuple[str, str]:
    """
    Generate custom ops registration code for dest.RegisterDispatchKey.

    :param native_functions: a sequence of `NativeFunction`
    :param selector: for selective build.
    :param kernel_index: kernels for all the ops.
    :param rocm: bool for dest.RegisterDispatchKey.
    :return: a pair ``(anonymous_definition, static_init_dispatch_registrations)``
        of generated C++ code:
        - ``anonymous_definition``: output of ``dest.RegisterDispatchKey`` with
          ``Target.ANONYMOUS_DEFINITION`` for all ``native_functions``, joined
          by newlines;
        - ``static_init_dispatch_registrations``: one ``TORCH_LIBRARY_IMPL``
          block per operator namespace (dispatch key ``DispatchKey.CPU``)
          containing the ``Target.REGISTRATION`` output for that namespace's
          functions; these register the custom operators into PyTorch.
    """

    # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet.
    # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex.
    dispatch_key = DispatchKey.CPU
    backend_index = kernel_index._to_backend_index()

    def _generate(target: Target, functions: Sequence[NativeFunction]) -> str:
        # Run dest.RegisterDispatchKey for `target` over `functions` and join
        # the emitted C++ fragments with newlines.
        return "\n".join(
            concatMap(
                dest.RegisterDispatchKey(
                    backend_index,
                    target,
                    selector,
                    rocm=rocm,
                    symint=False,
                    class_method_name=None,
                    skip_dispatcher_op_registration=False,
                ),
                functions,
            )
        )

    # Group operators by namespace; one TORCH_LIBRARY_IMPL block is emitted per
    # namespace. Groups are only created via `append`, so none is empty.
    ns_grouped_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
    for native_function in native_functions:
        ns_grouped_native_functions[native_function.namespace].append(native_function)

    registration_blocks: List[str] = []
    for namespace, functions in ns_grouped_native_functions.items():
        dispatch_registrations_body = _generate(Target.REGISTRATION, functions)
        registration_blocks.append(
            f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{dispatch_registrations_body}
}};"""
        )
    # Join once instead of repeated `+=` concatenation; output is identical.
    static_init_dispatch_registrations = "".join(registration_blocks)

    anonymous_definition = _generate(Target.ANONYMOUS_DEFINITION, native_functions)
    return anonymous_definition, static_init_dispatch_registrations
|