pytorch/tools/codegen/dest/native_functions.py
Edward Yang bf2ca35f35 Rejigger to use NativeFunctionsGroup even without structured: True (#54426)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54426

Previously, we only put NativeFunctions in StructuredNativeFunctions
if the out variant advertised that the kernel was structured. However,
a few code generation passes can take advantage of this trio structure
(functional, inplace and out variants), even if the kernel itself
hasn't been ported to be structured. So it is better to always group
related functions, and then let clients decide whether to use the
structure or throw it away.
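As a rough illustration of "always group, let clients decide", here is a toy Python model. `Fn`, `Group` and `describe` are hypothetical names, and the classes are far simpler than the real `NativeFunction` and `StructuredNativeFunctions` in `tools.codegen.model`. The sketch only assumes, as described above, that whether a group is structured is read from its out variant.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class Fn:
    # Hypothetical stand-in for a NativeFunction: a schema name plus the
    # `structured: True` flag.
    name: str
    structured: bool = False


@dataclass(frozen=True)
class Group:
    # Hypothetical stand-in for StructuredNativeFunctions: built whenever
    # functional/inplace/out variants are related, structured or not.
    functional: Fn
    inplace: Optional[Fn]
    out: Fn

    @property
    def structured(self) -> bool:
        # Being grouped no longer implies being structured; only the out
        # variant's flag decides.
        return self.out.structured

    def functions(self) -> List[Fn]:
        # Flatten the trio back into individual functions.
        return [f for f in (self.functional, self.inplace, self.out) if f is not None]


def describe(g: Group) -> List[str]:
    # A client that keeps the trio when it is structured and throws the
    # grouping away otherwise.
    if g.structured:
        return [f"structured group for {g.out.name}"]
    return [f"standalone {f.name}" for f in g.functions()]


foo = Group(Fn("foo"), Fn("foo_"), Fn("foo.out", structured=True))
bar = Group(Fn("bar"), None, Fn("bar.out"))
print(describe(foo))  # ['structured group for foo.out']
print(describe(bar))  # ['standalone bar', 'standalone bar.out']
```

The point of the design is that both `foo` and `bar` are grouped; only the client (`describe` here) chooses whether the grouping matters.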

While doing this, I had hoped that every functional/inplace pair also
had an out variant. This turned out not to be true. These are probably
all oversights and should get fixed at some point.
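A toy sketch of bucketing related variants and spotting functional/inplace pairs that lack an out variant. The naming convention (`foo`, `foo_`, `foo.out`) and the helpers `base_name`, `group_variants` and `pairs_missing_out` are assumptions for illustration; leaving such pairs ungrouped is this toy's choice, and the sketch does not reproduce `StructuredNativeFunctions.from_dict`.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def base_name(name: str) -> str:
    # Toy naming convention (an assumption): "foo", "foo_" and "foo.out"
    # all share the base name "foo".
    name = name.split(".", 1)[0]
    return name[:-1] if name.endswith("_") else name


def group_variants(names: List[str]) -> Tuple[Dict[str, List[str]], List[str]]:
    # Bucket related variants by base name, then keep a bucket as a group
    # only if it has an out variant; everything else stays ungrouped.
    buckets: Dict[str, List[str]] = defaultdict(list)
    for n in names:
        buckets[base_name(n)].append(n)
    groups: Dict[str, List[str]] = {}
    loose: List[str] = []
    for base, members in buckets.items():
        if any(m.endswith(".out") for m in members):
            groups[base] = members
        else:
            loose.extend(members)
    return groups, loose


def pairs_missing_out(names: List[str]) -> List[str]:
    # Report functional names whose inplace twin exists but which have no
    # out variant, i.e. the likely oversights.
    _, loose = group_variants(names)
    present = set(loose)
    return sorted(n for n in present if not n.endswith("_") and n + "_" in present)


names = ["foo", "foo_", "foo.out", "bar", "bar_"]
print(group_variants(names))  # ({'foo': ['foo', 'foo_', 'foo.out']}, ['bar', 'bar_'])
print(pairs_missing_out(names))  # ['bar']
```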

Bill of changes:

- The actual operational change happens in
  StructuredNativeFunctions.from_dict; I then need to relax some
  __post_init__ invariants. To tell whether a StructuredNativeFunctions
  is actually structured, there is a new structured property, which
  is queried from a few new locations in the code.
- Refactor native_functions.py into gen_structured/gen_unstructured
  functions so I can easily call gen_unstructured from two contexts
  (an unstructured group, and a lone NativeFunction).
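The routing that `compute_native_function_declaration` performs after this refactor can be mimicked with a self-contained toy. `Fn`, `Group`, `compute_declaration` and the simplified struct output are hypothetical: dispatch keys, `meta.name` and the `impl` signature are omitted, the meta class name is assumed to be the functional kernel's name, and `itertools.chain.from_iterable` stands in for the codebase's `concatMap`.

```python
from dataclasses import dataclass
from itertools import chain
from typing import List, Optional, Union


@dataclass(frozen=True)
class Fn:
    # Hypothetical stand-in for NativeFunction: kernel name and C++ signature pieces.
    kernel: str
    returns: str
    args: str


@dataclass(frozen=True)
class Group:
    # Hypothetical stand-in for StructuredNativeFunctions.
    functional: Fn
    inplace: Optional[Fn]
    out: Fn
    structured: bool

    def functions(self) -> List[Fn]:
        return [f for f in (self.functional, self.inplace, self.out) if f is not None]


def gen_unstructured(f: Fn) -> List[str]:
    # One plain forward declaration per function.
    return [f"TORCH_API {f.returns} {f.kernel}({f.args});"]


def gen_structured(g: Group) -> List[str]:
    # Only the out kernel gets a structured_<kernel> class (toy meta name).
    return [f"struct TORCH_API structured_{g.out.kernel} : public at::meta::{g.functional.kernel} {{}};"]


def compute_declaration(g: Union[Group, Fn]) -> List[str]:
    # Groups always arrive here; `structured` decides whether the trio is kept.
    if isinstance(g, Group):
        if g.structured:
            return gen_structured(g)
        # Grouped but unstructured: reuse the per-function generator.
        return list(chain.from_iterable(gen_unstructured(f) for f in g.functions()))
    # A lone function (no group) is the second context for gen_unstructured.
    return gen_unstructured(g)


foo_fns = (
    Fn("foo", "Tensor", "const Tensor & self"),
    Fn("foo_", "Tensor &", "Tensor & self"),
    Fn("foo_out", "Tensor &", "const Tensor & self, Tensor & out"),
)
print(compute_declaration(Group(*foo_fns, structured=False)))
print(compute_declaration(Group(*foo_fns, structured=True)))
print(compute_declaration(Fn("bar", "Tensor", "const Tensor & self")))
```

Keeping `gen_unstructured` as a standalone function is what lets both the unstructured-group branch and the lone-function branch share it.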

I intend to s/StructuredNativeFunctions/NativeFunctionsGroup/ but
for ease of review this rename hasn't been done in this PR.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: ailzhang

Differential Revision: D27235379

Pulled By: ezyang

fbshipit-source-id: d8a15de9abb75b365348ab94e67b830704e30cf0
2021-03-23 00:43:54 -07:00


from typing import List, Union, Set, Any

from tools.codegen.context import *
from tools.codegen.utils import *
from tools.codegen.model import *
from tools.codegen.api.types import *
import tools.codegen.api.meta as meta
import tools.codegen.api.native as native
import tools.codegen.api.structured as structured


@with_native_function
def gen_unstructured(f: NativeFunction) -> List[str]:
    ns = list(f.dispatch.values())

    rs = []
    # Sometimes a function name shows up multiple times; only generate
    # it once!
    seen = set()
    for n in ns:
        if n in seen:
            continue
        if "legacy::" in n:
            continue
        seen.add(n)
        returns_type = native.returns_type(f.func.returns)
        args = native.arguments(f.func)
        rs.append(f"TORCH_API {returns_type} {n}({', '.join(a.decl() for a in args)});")

    return rs


@with_native_function
def gen_structured(g: StructuredNativeFunctions) -> List[str]:
    # only out has dispatch
    meta_name = meta.name(g)
    rs = []
    seen: Set[Any] = set()
    out_args = structured.impl_arguments(g)
    # For each structured dispatch key on the out variant, declare a
    # structured_<kernel> class deriving from the meta class.
    for k, n in g.out.dispatch.items():
        if n in seen:
            continue
        if not is_structured_dispatch_key(k):
            continue
        seen.add(n)
        rs.append(f"""\
struct TORCH_API structured_{n} : public at::meta::{meta_name} {{
    void impl({', '.join(a.decl() for a in out_args)});
}};
""")

    # Dispatch entries with non-structured keys, on any function in the
    # group, get plain forward declarations instead.
    seen = set()
    for f in g.functions():
        returns_type = native.returns_type(f.func.returns)
        args = native.arguments(f.func)
        for k, n in f.dispatch.items():
            if n in seen:
                continue
            if is_structured_dispatch_key(k):
                continue
            seen.add(n)
            args_str = ', '.join(a.decl() for a in args)
            rs.append(f"TORCH_API {returns_type} {n}({args_str});")

    return rs


# Generates NativeFunctions.h, a list of forward declarations of all
# actual kernel definitions we keep in aten/src/ATen/native/
@with_native_function
def compute_native_function_declaration(g: Union[StructuredNativeFunctions, NativeFunction]) -> List[str]:
    if isinstance(g, StructuredNativeFunctions):
        if g.structured:
            return gen_structured(g)
        else:
            # Grouped but not structured: discard the grouping and
            # declare each function individually.
            return list(concatMap(gen_unstructured, g.functions()))
    else:
        return gen_unstructured(g)
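
The "only generate it once" deduplication in `gen_unstructured` can be checked in isolation. The function `unique_kernels` and the kernel names below are hypothetical; only the first-seen-wins rule and the `"legacy::"` filter mirror the code above.

```python
from typing import Dict, List


def unique_kernels(dispatch: Dict[str, str]) -> List[str]:
    # Several dispatch keys may share one kernel name: keep each name once,
    # in first-seen order, and drop names containing "legacy::".
    seen = set()
    kept: List[str] = []
    for name in dispatch.values():
        if name in seen or "legacy::" in name:
            continue
        seen.add(name)
        kept.append(name)
    return kept


dispatch = {
    "CPU": "foo_out",
    "CUDA": "foo_out",
    "QuantizedCPU": "foo_quantized",
    "SparseCPU": "legacy::cpu::_th_foo",
}
print(unique_kernels(dispatch))  # ['foo_out', 'foo_quantized']
```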