pytorch/tools/codegen/dest/native_functions.py
Edward Yang c00d66f73c Move compute_native_function_declaration to its own dest module (#54419)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54419

I'm planning to break it into some helper functions, so let's put it in its own module first.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: ailzhang

Differential Revision: D27235378

Pulled By: ezyang

fbshipit-source-id: c03c5440d2d753859e2c5ec2b2c8b1b82870f03a
2021-03-23 00:43:50 -07:00

67 lines
2.2 KiB
Python

from typing import List, Union, Set, Any
from tools.codegen.context import *
from tools.codegen.utils import *
from tools.codegen.model import *
from tools.codegen.api.types import *
import tools.codegen.api.meta as meta
import tools.codegen.api.native as native
import tools.codegen.api.structured as structured
def _structured_kernel_decls(g: StructuredNativeFunctions) -> List[str]:
    """Declare one ``structured_<kernel>`` struct per structured kernel of ``g``.

    Only ``g.out`` is consulted: only the out variant carries the dispatch
    entries for structured kernels. Entries whose dispatch key is not a
    structured dispatch key are skipped here; the caller declares those as
    plain functions. A kernel name that appears under several keys yields a
    single struct.

    Each struct derives from ``at::meta::<meta.name(g)>`` and declares an
    ``impl`` method taking ``structured.impl_arguments(g)``.
    """
    meta_name = meta.name(g)
    # The impl signature is the same for every kernel of the group, so build
    # the argument list once instead of once per emitted struct.
    impl_args_str = ', '.join(a.decl() for a in structured.impl_arguments(g))
    rs: List[str] = []
    seen: Set[str] = set()
    for k, n in g.out.dispatch.items():
        if n in seen or not is_structured_dispatch_key(k):
            continue
        seen.add(n)
        rs.append(f"""\
struct TORCH_API structured_{n} : public at::meta::{meta_name} {{
void impl({impl_args_str});
}};
""")
    return rs


def _unstructured_kernel_decls(
    f: NativeFunction,
    seen: Set[str],
    *,
    skip_structured: bool,
) -> List[str]:
    """Declare ``f``'s dispatch kernels as ``TORCH_API`` free functions.

    Args:
        f: the function whose ``dispatch`` table is scanned, in table order.
        seen: kernel names already declared. It is updated in place so that a
            kernel shared by several functions is declared only once.
        skip_structured: when true, entries under a structured dispatch key
            are skipped, because they are declared as ``structured_*`` structs
            instead.

    Returns:
        One ``TORCH_API <returns> <kernel>(<args>);`` line per newly seen
        kernel name.

    Kernel names containing ``legacy::`` are never declared here. Presumably
    they are declared elsewhere, and a qualified name cannot be forward-declared
    from this header.
    """
    # Sometimes a function name shows up multiple times; only generate
    # it once!
    names: List[str] = []
    for k, n in f.dispatch.items():
        if n in seen or "legacy::" in n:
            continue
        if skip_structured and is_structured_dispatch_key(k):
            continue
        seen.add(n)
        names.append(n)
    if not names:
        return []
    # Every kernel of ``f`` shares one signature; compute it once, and only
    # when at least one declaration is actually emitted.
    returns_type = native.returns_type(f.func.returns)
    args_str = ', '.join(a.decl() for a in native.arguments(f.func))
    return [f"TORCH_API {returns_type} {n}({args_str});" for n in names]


# Generates NativeFunctions.h, a list of forward declarations of all
# actual kernel definitions we keep in aten/src/ATen/native/
@with_native_function
def compute_native_function_declaration(g: Union[StructuredNativeFunctions, NativeFunction]) -> List[str]:
    """Return the C++ declarations NativeFunctions.h needs for ``g``.

    For a ``StructuredNativeFunctions`` group:
      * First come the ``structured_<kernel>`` structs for the out variant's
        structured kernels.
      * Then come ``TORCH_API`` free-function declarations for kernels every
        function in the group registers under non-structured dispatch keys.

    For a plain ``NativeFunction``, only the free-function declarations are
    returned, one per distinct kernel name in its dispatch table.

    In all cases duplicate kernel names are declared once, and ``legacy::``
    kernels are omitted.
    """
    if isinstance(g, StructuredNativeFunctions):
        rs = _structured_kernel_decls(g)
        # Struct names and free-function names are distinct C++ entities, so
        # the free functions get their own dedupe set, shared across the group.
        seen: Set[str] = set()
        for f in g.functions():
            rs.extend(_unstructured_kernel_decls(f, seen, skip_structured=True))
        return rs
    return _unstructured_kernel_decls(g, set(), skip_structured=False)