pytorch/tools/codegen/api/meta.py
Edward Yang cdc2d2843b Structured kernel definitions (#45277)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45277

Implements structured kernels as per https://github.com/pytorch/rfcs/pull/9 and ports upsample_nearest1d to use the framework.

The general structure of this diff:

- Define a new syntax for specifying structured kernels in `native_functions.yaml`. You put `structured: True` on the `out` function (that's what you implement) and `structured_delegate: foo.out` on the functional/inplace variants to define them in terms of the `out` function. There's a bunch of new consistency checking to see if you've done this right, though the error messages are of varying quality. This is most of what's going on in `tools.codegen.model`.
- NativeFunctionGroup turns into StructuredNativeFunctions. Previously I thought that maybe we would use this grouping mechanism for both structured and unstructured kernels, but it turned out that Jiakai needed to make his own grouping structure. So now I've specialized it for structured kernels, which also means I get to add a bunch of invariants, like requiring structured kernels to have both a functional and an out variant (see the sketch after this list). This is the lower bundle of changes in `tools.codegen.model`.
- When you make an out kernel structured, this induces us to generate a new meta function signature against which you write shape checking and output allocation code. The signatures of these are defined by `tools.codegen.api.meta` and generated into `MetaFunctions.h`. Coverage here is very bare bones and will be driven by the actual operators we port as we go.
- The meaty part of code generation is what we do when we have some grouped StructuredNativeFunctions. We continue to generate a wrapper per function type, but they're a bit different, as they call your meta functions and make reference to the actual implementations in out.
- Then there's a port of `upsample_nearest1d`; easiest to review by just looking at what the final code looks like.
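
To make the grouping invariant concrete, here is a minimal sketch (not part of this diff) of the pairing check described above: the functional and out variants of a structured kernel must agree once out arguments and mutability annotations are stripped, which is what `FunctionSchema.signature()` computes. The schema strings are copied from `native_functions.yaml`; the exact comparison StructuredNativeFunctions uses for grouping may differ in detail.

```python
from tools.codegen.model import FunctionSchema

# Schemas as they appear in native_functions.yaml for upsample_nearest1d.
functional = FunctionSchema.parse(
    'upsample_nearest1d(Tensor self, int[1] output_size, float? scales=None) -> Tensor')
out = FunctionSchema.parse(
    'upsample_nearest1d.out(Tensor self, int[1] output_size, float? scales=None, '
    '*, Tensor(a!) out) -> Tensor(a!)')

# Both variants reduce to the same "pure" signature; this is the sort of
# consistency that is enforced when the `structured: True` out kernel is
# bundled with its `structured_delegate` counterparts.
assert functional.signature() == out.signature()
```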

Missing pieces:

- Stride calculation in TensorMeta
- Sufficient sanity checking for inplace/out variants
- Enough rope to make TensorIterator work

This PR improves instruction counts on `upsample_nearest1d` because it eliminates an extra redispatch. Testing with `at::upsample_nearest1d(x, {10});`:

* Functional: before 1314105, after 1150705
* Out: before 915705, after 838405

These numbers may jitter by up to ±16400 (the difference I measured against an unaffected operator, `at::upsample_linear1d`), though that may also be because unrelated changes affected all operators globally.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D24253555

Test Plan: Imported from OSS

Reviewed By: smessmer

Pulled By: ezyang

fbshipit-source-id: 4ef58dd911991060f13576864c8171f9cc614456
2020-11-17 15:24:43 -08:00


from tools.codegen.model import *

from tools.codegen.api.types import MetaArgument
import tools.codegen.api.cpp as cpp
import tools.codegen.api.dispatcher as dispatcher

from typing import Sequence

import itertools

# Follows dispatcher calling convention, but:
#   - Mutable arguments not allowed.  Meta functions are always
#     written in functional form.  Look at FunctionSchema.signature()
#   - No tensor returns; instead we return a TensorMeta describing
#     the tensor in question

def name(f: FunctionSchema) -> str:
    assert f.name.overload_name == ""
    return str(f.name.name)

def argument_type(a: Argument) -> str:
    assert not a.is_write
    return dispatcher.argumenttype_type(a.type, mutable=False)

def returntype_type(t: Type) -> str:
    # Value types (int64_t, double, bool, ...) render the same way as in
    # the C++ API.
    r = cpp.valuetype_type(t)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            # Tensor returns become TensorMeta, a description of the output
            # tensor rather than an allocated tensor.
            return 'TensorMeta'
    elif isinstance(t, ListType):
        raise NotImplementedError("list returns not supported yet")

    raise AssertionError(f"unrecognized return type {t}")

def return_type(r: Return) -> str:
    assert not r.is_write
    return returntype_type(r.type)

def returns_type(rs: Sequence[Return]) -> str:
    if len(rs) == 0:
        return 'void'
    elif len(rs) == 1:
        return return_type(rs[0])
    else:
        args = ','.join(map(return_type, rs))
        return f'std::tuple<{args}>'

def argument(a: Argument) -> MetaArgument:
    return MetaArgument(
        type=argument_type(a),
        name=a.name,
        argument=a,
    )

def arguments(func: FunctionSchema) -> Sequence[MetaArgument]:
    # Meta functions are written against the functional signature, so
    # out= arguments must already have been stripped.
    assert not func.out_arguments
    return list(map(argument, itertools.chain(func.arguments, func.kwarg_only_arguments)))
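
As a quick illustration of the calling convention above, here is a hypothetical snippet (not part of this file) showing what these helpers produce for the functional `upsample_nearest1d` schema. It only exercises the pieces that don't depend on the generator's per-function context (`arguments()` goes through `tools.codegen.api.dispatcher`, which normally runs inside the code generator):

```python
from tools.codegen.model import FunctionSchema
import tools.codegen.api.meta as meta

f = FunctionSchema.parse(
    'upsample_nearest1d(Tensor self, int[1] output_size, float? scales=None) -> Tensor')

# The meta function keeps the base operator name and returns a TensorMeta
# describing the output instead of an allocated Tensor.
assert meta.name(f) == 'upsample_nearest1d'
assert meta.returns_type(f.returns) == 'TensorMeta'
```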