[torchgen] Let native function declaration generation logic take a callable (#90780)

Retry of #90590, which is a retry of #89594. Original PR reverted due to internal breakage.
This PR fixes the breakage by adding a default value to the new argument.

This PR allows `get_native_function_declarations` API to take a function as argument. This function should take `NativeFunction` as input and emit code for native function declaration. By default it is `dest.compute_native_function_declaration`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90780
Approved by: https://github.com/ezyang
This commit is contained in:
Larry Liu 2022-12-13 16:55:38 -08:00 committed by PyTorch MergeBot
parent df58020bb6
commit 4adffe6d51
2 changed files with 29 additions and 3 deletions

View File

@ -9,6 +9,7 @@ import torchgen.model
import yaml
from tools.autograd import gen_autograd_functions, load_derivatives
from torchgen import dest
from torchgen.api.types import CppSignatureGroup, DispatcherSignature
from torchgen.context import native_function_manager
from torchgen.gen import (
@ -356,6 +357,7 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
self.op_2_native_function,
],
backend_indices=self.backend_indices,
native_function_decl_gen=dest.compute_native_function_declaration,
)
def test_native_function_declaration_1_op_1_ns_valid(self) -> None:
@ -365,6 +367,7 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
self.op_1_native_function,
],
backend_indices=self.backend_indices,
native_function_decl_gen=dest.compute_native_function_declaration,
)
target = """
namespace at {

View File

@ -5,7 +5,18 @@ import os
import pathlib
from collections import defaultdict, namedtuple, OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
)
import yaml
from typing_extensions import Literal
@ -1406,7 +1417,17 @@ def get_native_function_declarations(
*,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
backend_indices: Dict[DispatchKey, BackendIndex],
native_function_decl_gen: Callable[
[Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str]
] = dest.compute_native_function_declaration,
) -> List[str]:
"""
Generate kernel declarations, in `NativeFunction(s).h`.
:param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`.
:param backend_indices: kernel collections grouped by dispatch key.
:param native_function_decl_gen: callable to generate kernel declaration for each `NativeFunction`.
:return: a list of string, from the string with all declarations, grouped by namespaces, split by newline.
"""
declarations: List[str] = []
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
newline = "\n"
@ -1425,7 +1446,7 @@ def get_native_function_declarations(
len(native_function_namespaces) <= 1
), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
ns_grouped_kernels[namespace].extend(
dest.compute_native_function_declaration(f, backend_idx)
native_function_decl_gen(f, backend_idx)
)
for namespace, kernels in ns_grouped_kernels.items():
@ -1863,7 +1884,9 @@ def gen_per_operator_headers(
},
)
declarations = get_native_function_declarations(
grouped_native_functions=grouped_functions, backend_indices=backend_indices
grouped_native_functions=grouped_functions,
backend_indices=backend_indices,
native_function_decl_gen=dest.compute_native_function_declaration,
)
ops_fm.write_with_template(
f"{name}_native.h",