mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/76275 In preparation for addressing https://github.com/pytorch/pytorch/issues/73212 Diff was generated with: ``` git mv tools/codegen torchgen git grep -l 'tools.codegen' | xargs sed -i 's/tools.codegen/torchgen/g' sed -i "s/\${TOOLS_PATH}\/codegen/\${TORCH_ROOT}\/torchgen/g" caffe2/CMakeLists.txt ``` and manual edits to: * tools/test/test_gen_backend_stubs.py * torchgen/build.bzl * torchgen/gen_backend_stubs.py aka this diff: ``` diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py index 3dc26c6d2d..104054575e 100644 --- a/tools/test/test_gen_backend_stubs.py +++ b/tools/test/test_gen_backend_stubs.py @@ -9,7 +9,7 @@ from torchgen.gen_backend_stubs import run from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE # noqa: F401 path = os.path.dirname(os.path.realpath(__file__)) -gen_backend_stubs_path = os.path.join(path, '../torchgen/gen_backend_stubs.py') +gen_backend_stubs_path = os.path.join(path, '../../torchgen/gen_backend_stubs.py') # gen_backend_stubs.py is an integration point that is called directly by external backends. # The tests here are to confirm that badly formed inputs result in reasonable error messages. 
diff --git a/torchgen/build.bzl b/torchgen/build.bzl index ed04e35a43..d00078a3cf 100644 --- a/torchgen/build.bzl +++ b/torchgen/build.bzl @@ -1,6 +1,6 @@ def define_targets(rules): rules.py_library( - name = "codegen", + name = "torchgen", srcs = rules.glob(["**/*.py"]), deps = [ rules.requirement("PyYAML"), @@ -11,6 +11,6 @@ def define_targets(rules): rules.py_binary( name = "gen", - srcs = [":codegen"], + srcs = [":torchgen"], visibility = ["//visibility:public"], ) diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index c1a672a655..beee7a15e0 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -474,7 +474,7 @@ def run( ) -> None: # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py - pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute() + pytorch_root = pathlib.Path(__file__).parent.parent.absolute() template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates") def make_file_manager(install_dir: str) -> FileManager: ``` run_all_fbandroid_tests Test Plan: sandcastle Reviewed By: albanD, ngimel Differential Revision: D35770317 fbshipit-source-id: 153ac4a7fef15b1e750812a90bfafdbc8f1ebcdf (cherry picked from commit c6d485d1d4648fa1c8a4c14c5bf3d8e899b9b4dd)
64 lines
2.3 KiB
Python
64 lines
2.3 KiB
Python
from typing import List, Union, Optional
|
|
|
|
from torchgen.context import with_native_function_and_index
|
|
from torchgen.utils import mapMaybe
|
|
from torchgen.model import NativeFunction, NativeFunctionsGroup, BackendIndex
|
|
from torchgen.api.types import kernel_signature
|
|
import torchgen.api.meta as meta
|
|
import torchgen.api.structured as structured
|
|
|
|
|
|
@with_native_function_and_index
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional[str]:
    """Return the forward declaration for ``f``'s unstructured kernel, or None.

    No declaration is produced (None is returned) when ``backend_index`` has
    no kernel registered for ``f``, or when the registered kernel name
    contains ``"legacy::"``.

    Otherwise the declaration is the kernel signature of ``f``, named after
    the backend's kernel and terminated with ``;``. It is prefixed with
    ``static`` when ``backend_index.external`` is set and with ``TORCH_API``
    otherwise.
    """
    metadata = backend_index.get_kernel(f)
    # Check the guards first: the signature is only needed when a
    # declaration is actually emitted.
    if metadata is None or "legacy::" in metadata.kernel:
        return None
    sig = kernel_signature(f, backend_index)
    prefix = "static" if backend_index.external else "TORCH_API"
    return f"{prefix} {sig.decl(name=metadata.kernel)};"
|
|
|
|
|
|
@with_native_function_and_index
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]:
    """Return the C++ struct declaration for the structured kernel of ``g``.

    An empty list is returned when ``backend_index`` registers no kernel for
    ``g``.

    Otherwise the result is a one-element list. Its element is a ``struct``
    named ``structured_<kernel>`` that publicly derives from
    ``at::meta::structured_<name>``, where ``<name>`` is ``meta.name(g)``.
    The struct declares an ``impl`` method whose parameters come from
    ``structured.impl_arguments(g)``. It is tagged ``TORCH_API`` unless
    ``backend_index.external`` is set.
    """
    metadata = backend_index.get_kernel(g)
    if metadata is None:
        return []
    # Derive the base-class name and impl arguments only once we know a
    # declaration will actually be emitted.
    meta_name = meta.name(g)
    out_args = structured.impl_arguments(g)
    prefix = "" if backend_index.external else "TORCH_API "
    impl_params = ", ".join(a.decl() for a in out_args)
    # ``{{`` / ``}}`` are f-string escapes for literal braces in the output.
    return [
        f"""\
struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{
void impl({impl_params});
}};
"""
    ]
|
|
|
|
|
|
# Generates NativeFunctions.h: forward declarations for every actual kernel
# definition we keep in aten/src/ATen/native/.
@with_native_function_and_index
def compute_native_function_declaration(
    g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex
) -> List[str]:
    """Return the kernel declarations ``backend_index`` needs for ``g``.

    - A lone ``NativeFunction`` yields at most one declaration, produced by
      ``gen_unstructured``.
    - A ``NativeFunctionsGroup`` whose backend kernel is marked structured
      yields whatever ``gen_structured`` produces.
    - Any other group yields one ``gen_unstructured`` declaration per member
      function, skipping members for which it returns None.

    Raises:
        AssertionError: if ``g`` is a structured group and ``backend_index``
            is external, which is not supported yet.
    """
    kernel_info = backend_index.get_kernel(g)

    if not isinstance(g, NativeFunctionsGroup):
        single = gen_unstructured(g, backend_index)
        return [] if single is None else [single]

    wants_structured = kernel_info is not None and kernel_info.structured
    if not wants_structured:
        # Unstructured group: declare each member separately and drop the
        # ones gen_unstructured declines to declare.
        per_function = (gen_unstructured(fn, backend_index) for fn in g.functions())
        return [decl for decl in per_function if decl is not None]

    if backend_index.external:
        # Structured hasn't been tested with external backends yet.
        raise AssertionError(
            "Structured external backend functions are not implemented yet."
        )
    return gen_structured(g, backend_index)
|