mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[torchgen] Let native function declaration generation logic take a callable (#90780)
Retry of #90590, which is a retry of #89594. Original PR reverted due to internal breakage. This PR fixes the breakage by adding a default value to the new argument. This PR allows `get_native_function_declarations` API to take a function as argument. This function should take `NativeFunction` as input and emit code for native function declaration. By default it is `dest.compute_native_function_declaration`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90780 Approved by: https://github.com/ezyang
This commit is contained in:
parent
df58020bb6
commit
4adffe6d51
|
|
@ -9,6 +9,7 @@ import torchgen.model
|
|||
import yaml
|
||||
|
||||
from tools.autograd import gen_autograd_functions, load_derivatives
|
||||
from torchgen import dest
|
||||
from torchgen.api.types import CppSignatureGroup, DispatcherSignature
|
||||
from torchgen.context import native_function_manager
|
||||
from torchgen.gen import (
|
||||
|
|
@ -356,6 +357,7 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
|
|||
self.op_2_native_function,
|
||||
],
|
||||
backend_indices=self.backend_indices,
|
||||
native_function_decl_gen=dest.compute_native_function_declaration,
|
||||
)
|
||||
|
||||
def test_native_function_declaration_1_op_1_ns_valid(self) -> None:
|
||||
|
|
@ -365,6 +367,7 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
|
|||
self.op_1_native_function,
|
||||
],
|
||||
backend_indices=self.backend_indices,
|
||||
native_function_decl_gen=dest.compute_native_function_declaration,
|
||||
)
|
||||
target = """
|
||||
namespace at {
|
||||
|
|
|
|||
|
|
@ -5,7 +5,18 @@ import os
|
|||
import pathlib
|
||||
from collections import defaultdict, namedtuple, OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import yaml
|
||||
from typing_extensions import Literal
|
||||
|
|
@ -1406,7 +1417,17 @@ def get_native_function_declarations(
|
|||
*,
|
||||
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
|
||||
backend_indices: Dict[DispatchKey, BackendIndex],
|
||||
native_function_decl_gen: Callable[
|
||||
[Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str]
|
||||
] = dest.compute_native_function_declaration,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate kernel declarations, in `NativeFunction(s).h`.
|
||||
:param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`.
|
||||
:param backend_indices: kernel collections grouped by dispatch key.
|
||||
:param native_function_decl_gen: callable to generate kernel declaration for each `NativeFunction`.
|
||||
:return: a list of string, from the string with all declarations, grouped by namespaces, split by newline.
|
||||
"""
|
||||
declarations: List[str] = []
|
||||
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
|
||||
newline = "\n"
|
||||
|
|
@ -1425,7 +1446,7 @@ def get_native_function_declarations(
|
|||
len(native_function_namespaces) <= 1
|
||||
), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
|
||||
ns_grouped_kernels[namespace].extend(
|
||||
dest.compute_native_function_declaration(f, backend_idx)
|
||||
native_function_decl_gen(f, backend_idx)
|
||||
)
|
||||
|
||||
for namespace, kernels in ns_grouped_kernels.items():
|
||||
|
|
@ -1863,7 +1884,9 @@ def gen_per_operator_headers(
|
|||
},
|
||||
)
|
||||
declarations = get_native_function_declarations(
|
||||
grouped_native_functions=grouped_functions, backend_indices=backend_indices
|
||||
grouped_native_functions=grouped_functions,
|
||||
backend_indices=backend_indices,
|
||||
native_function_decl_gen=dest.compute_native_function_declaration,
|
||||
)
|
||||
ops_fm.write_with_template(
|
||||
f"{name}_native.h",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user