[executorch] Always generate CustomOpsNativeFunctions.h if custom_ops.yaml is present (#95084)
To match the build system logic, enforce that CustomOpsNativeFunctions.h is generated when custom_ops.yaml is present, even if no custom ops are selected. Added a unit test.

Differential Revision: [D43402718](https://our.internmc.facebook.com/intern/diff/D43402718)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95084
Approved by: https://github.com/iseeyuan
This commit is contained in:
parent da41003b5f
commit 679e5dbfa1
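For context, the policy this commit enforces can be sketched in a few lines of Python. The helper below is illustrative only (the function name and path are invented, not part of torchgen): header generation is keyed off the presence of custom_ops.yaml, not off the selected-op set.

import os
from typing import List

def needs_custom_ops_header(custom_ops_yaml_path: str, selected_custom_ops: List[str]) -> bool:
    # Presence of the yaml alone decides; the selected-op list is intentionally ignored,
    # matching the build-system behavior described in the commit message.
    return os.path.exists(custom_ops_yaml_path)

# Even with no custom ops selected, an existing custom_ops.yaml means the header is generated.
print(needs_custom_ops_header("custom_ops.yaml", selected_custom_ops=[]))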
@@ -1,9 +1,16 @@
import tempfile
import unittest
from typing import Any, Dict
from unittest.mock import ANY, Mock, patch

import expecttest

import torchgen
from torchgen.executorch.api.custom_ops import ComputeNativeFunctionStub
from torchgen.gen_executorch import gen_headers
from torchgen.model import Location, NativeFunction
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import FileManager

SPACES = " "
@@ -72,3 +79,45 @@ void wrapper_CPU__foo_out(const at::Tensor & self, at::TensorList out) {{
        gen = ComputeNativeFunctionStub()
        with self.assertRaisesRegex(Exception, "Can't handle this return type"):
            gen(func)


class TestGenCustomOpsHeader(unittest.TestCase):
    @patch.object(torchgen.utils.FileManager, "write_with_template")
    @patch.object(torchgen.utils.FileManager, "write")
    def test_fm_writes_custom_ops_header_when_boolean_is_true(
        self, unused: Mock, mock_method: Mock
    ) -> None:
        with tempfile.TemporaryDirectory() as tempdir:
            fm = FileManager(tempdir, tempdir, False)
            gen_headers(
                native_functions=[],
                gen_custom_ops_header=True,
                custom_ops_native_functions=[],
                static_dispatch_idx=[],
                selector=SelectiveBuilder.get_nop_selector(),
                backend_indices={},
                cpu_fm=fm,
                use_aten_lib=False,
            )
            mock_method.assert_called_once_with(
                "CustomOpsNativeFunctions.h", "NativeFunctions.h", ANY
            )

    @patch.object(torchgen.utils.FileManager, "write_with_template")
    @patch.object(torchgen.utils.FileManager, "write")
    def test_fm_doesnot_writes_custom_ops_header_when_boolean_is_false(
        self, unused: Mock, mock_method: Mock
    ) -> None:
        with tempfile.TemporaryDirectory() as tempdir:
            fm = FileManager(tempdir, tempdir, False)
            gen_headers(
                native_functions=[],
                gen_custom_ops_header=False,
                custom_ops_native_functions=[],
                static_dispatch_idx=[],
                selector=SelectiveBuilder.get_nop_selector(),
                backend_indices={},
                cpu_fm=fm,
                use_aten_lib=False,
            )
            mock_method.assert_not_called()
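A note on the mock arguments in these tests: @patch.object decorators are applied bottom-up, so the mock bound to the bottom-most decorator arrives first; here unused receives the write mock and mock_method receives the write_with_template mock, which is why the assertions target write_with_template. A minimal standalone sketch of that ordering (the Thing class is invented for illustration):

from unittest.mock import patch

class Thing:
    def write(self) -> None: ...
    def write_with_template(self) -> None: ...

@patch.object(Thing, "write_with_template")
@patch.object(Thing, "write")
def demo(write_mock, write_with_template_mock) -> None:
    # Bottom decorator (write) supplies the first mock argument,
    # top decorator (write_with_template) supplies the second.
    Thing().write()
    write_mock.assert_called_once()
    write_with_template_mock.assert_not_called()

demo()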
@@ -291,6 +291,7 @@ def gen_functions_declarations(
def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    gen_custom_ops_header: bool,
    custom_ops_native_functions: Sequence[NativeFunction],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
@@ -298,8 +299,20 @@ def gen_headers(
    cpu_fm: FileManager,
    use_aten_lib: bool,
) -> None:
    """Generate headers.

    Args:
        native_functions (Sequence[NativeFunction]): a collection of NativeFunction for ATen ops.
        gen_custom_ops_header (bool): whether we should generate CustomOpsNativeFunctions.h
        custom_ops_native_functions (Sequence[NativeFunction]): a collection of NativeFunction for custom ops.
        static_dispatch_idx (List[BackendIndex]): kernel collection
        selector (SelectiveBuilder): for selective build
        backend_indices (Dict[DispatchKey, BackendIndex]): kernel collection TODO (larryliu): merge with static_dispatch_idx
        cpu_fm (FileManager): file manager manages output stream
        use_aten_lib (bool): whether we are generating for PyTorch types or Executorch types.
    """
    aten_headers = ["#include <ATen/Functions.h>"]
    if custom_ops_native_functions:
    if gen_custom_ops_header:
        cpu_fm.write_with_template(
            "CustomOpsNativeFunctions.h",
            "NativeFunctions.h",
@@ -744,8 +757,10 @@ def main() -> None:
    static_dispatch_idx: List[BackendIndex] = [backend_indices[DispatchKey.CPU]]

    if "headers" in options.generate:
        # generate CustomOpsNativeFunctions.h when custom_ops.yaml is present, to match the build system.
        gen_headers(
            native_functions=native_functions,
            gen_custom_ops_header=options.custom_ops_yaml_path,
            custom_ops_native_functions=custom_ops_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,