[executorch] Always generate CustomOpsNativeFunctions.h if custom_ops.yaml is present (#95084)

To match the build system logic, enforce CustomOpsNativeFunctions.h to be generated if we have custom_ops.yaml, even if we don't select any custom ops.

Added unit test.

Differential Revision: [D43402718](https://our.internmc.facebook.com/intern/diff/D43402718)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95084
Approved by: https://github.com/iseeyuan
This commit is contained in:
Mengwei Liu 2023-02-19 08:32:47 +00:00 committed by PyTorch MergeBot
parent da41003b5f
commit 679e5dbfa1
2 changed files with 65 additions and 1 deletions

View File

@ -1,9 +1,16 @@
import tempfile
import unittest
from typing import Any, Dict
from unittest.mock import ANY, Mock, patch
import expecttest
import torchgen
from torchgen.executorch.api.custom_ops import ComputeNativeFunctionStub
from torchgen.gen_executorch import gen_headers
from torchgen.model import Location, NativeFunction
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import FileManager
SPACES = " "
@ -72,3 +79,45 @@ void wrapper_CPU__foo_out(const at::Tensor & self, at::TensorList out) {{
gen = ComputeNativeFunctionStub()
with self.assertRaisesRegex(Exception, "Can't handle this return type"):
gen(func)
class TestGenCustomOpsHeader(unittest.TestCase):
@patch.object(torchgen.utils.FileManager, "write_with_template")
@patch.object(torchgen.utils.FileManager, "write")
def test_fm_writes_custom_ops_header_when_boolean_is_true(
self, unused: Mock, mock_method: Mock
) -> None:
with tempfile.TemporaryDirectory() as tempdir:
fm = FileManager(tempdir, tempdir, False)
gen_headers(
native_functions=[],
gen_custom_ops_header=True,
custom_ops_native_functions=[],
static_dispatch_idx=[],
selector=SelectiveBuilder.get_nop_selector(),
backend_indices={},
cpu_fm=fm,
use_aten_lib=False,
)
mock_method.assert_called_once_with(
"CustomOpsNativeFunctions.h", "NativeFunctions.h", ANY
)
@patch.object(torchgen.utils.FileManager, "write_with_template")
@patch.object(torchgen.utils.FileManager, "write")
def test_fm_doesnot_writes_custom_ops_header_when_boolean_is_false(
self, unused: Mock, mock_method: Mock
) -> None:
with tempfile.TemporaryDirectory() as tempdir:
fm = FileManager(tempdir, tempdir, False)
gen_headers(
native_functions=[],
gen_custom_ops_header=False,
custom_ops_native_functions=[],
static_dispatch_idx=[],
selector=SelectiveBuilder.get_nop_selector(),
backend_indices={},
cpu_fm=fm,
use_aten_lib=False,
)
mock_method.assert_not_called()

View File

@ -291,6 +291,7 @@ def gen_functions_declarations(
def gen_headers(
*,
native_functions: Sequence[NativeFunction],
gen_custom_ops_header: bool,
custom_ops_native_functions: Sequence[NativeFunction],
static_dispatch_idx: List[BackendIndex],
selector: SelectiveBuilder,
@ -298,8 +299,20 @@ def gen_headers(
cpu_fm: FileManager,
use_aten_lib: bool,
) -> None:
"""Generate headers.
Args:
native_functions (Sequence[NativeFunction]): a collection of NativeFunction for ATen ops.
gen_custom_ops_header (bool): whether we should generate CustomOpsNativeFunctions.h
custom_ops_native_functions (Sequence[NativeFunction]): a collection of NativeFunction for custom ops.
static_dispatch_idx (List[BackendIndex]): kernel collection
selector (SelectiveBuilder): for selective build
backend_indices (Dict[DispatchKey, BackendIndex]): kernel collection TODO (larryliu): merge with static_dispatch_idx
cpu_fm (FileManager): file manager manages output stream
use_aten_lib (bool): whether we are generating for PyTorch types or Executorch types.
"""
aten_headers = ["#include <ATen/Functions.h>"]
if custom_ops_native_functions:
if gen_custom_ops_header:
cpu_fm.write_with_template(
"CustomOpsNativeFunctions.h",
"NativeFunctions.h",
@ -744,8 +757,10 @@ def main() -> None:
static_dispatch_idx: List[BackendIndex] = [backend_indices[DispatchKey.CPU]]
if "headers" in options.generate:
# generate CustomOpsNativeFunctions.h when custom_ops.yaml is present, to match the build system.
gen_headers(
native_functions=native_functions,
gen_custom_ops_header=options.custom_ops_yaml_path,
custom_ops_native_functions=custom_ops_native_functions,
static_dispatch_idx=static_dispatch_idx,
selector=selector,