Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45277

Implements structured kernels as per https://github.com/pytorch/rfcs/pull/9 and ports upsample_nearest1d to use the framework. The general structure of this diff:

- Define a new syntax for specifying structured kernels in `native_functions.yaml`. You put `structured: True` on the `out` function (that's what you implement) and `structured_delegate: foo.out` on the functional/inplace variants to define them in terms of the `out` function. There's a bunch of new consistency checking to see if you've done this right, though the error messages are of varying quality. This is most of what's going on in tools.codegen.model.
- NativeFunctionGroup turns into StructuredNativeFunctions. Previously I thought that maybe we would use this grouping mechanism for both structured and unstructured kernels, but it turned out that Jiakai needed to make his own grouping structure. So now I've specialized it for structured kernels, which also means I get to add a bunch of invariants, like requiring structured kernels to have both a functional and an out variant. This is the lower bundle of changes in tools.codegen.model.
- When you make an out kernel structured, this induces us to generate a new meta function signature for you to write shape checking and output allocation code against. The signatures of these are defined by `tools.codegen.api.meta` and generated into `MetaFunctions.h`. Coverage here is very bare bones and will be driven by the actual operators we port as we go.
- The meaty part of code generation is what we do when we have some grouped StructuredNativeFunctions. We continue to generate a wrapper per function type, but they are a bit different, as they call your meta functions and make reference to the actual implementations in `out`.
- Then there's a port of `upsample_nearest1d`; easiest to review by just looking at what the final code looks like.

Missing pieces:

- Stride calculation in TensorMeta
- Sufficient sanity checking for inplace/out variants
- Enough rope to make TensorIterator work

This PR improves instruction counts on `upsample_nearest1d` because it eliminates an extra redispatch. Testing `at::upsample_nearest1d(x, {10});`:

* Functional: before 1314105, after 1150705
* Out: before 915705, after 838405

These numbers may be jittered by up to +-16400 (which is the difference when I tested against an unaffected operator, `at::upsample_linear1d`), though that may also be because unrelated changes affected all operators globally.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D24253555

Test Plan: Imported from OSS

Reviewed By: smessmer

Pulled By: ezyang

fbshipit-source-id: 4ef58dd911991060f13576864c8171f9cc614456
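For orientation, here is a minimal, hypothetical sketch of how the helpers in the tools.codegen.api.meta module shown below could be used to render a single MetaFunctions.h declaration. The simplified schema string (the real operator also takes a scales argument) and the exact C++ types in the final comment are assumptions; the actual codegen drives this from parsed native_functions.yaml entries rather than a hand-written schema.

# Hypothetical usage sketch, not part of the diff: render a meta declaration
# for a simplified schema. The exact C++ argument types come from
# tools.codegen.api.dispatcher and may differ from the comment below.
from tools.codegen.model import FunctionSchema
import tools.codegen.api.meta as meta

schema = FunctionSchema.parse(
    'upsample_nearest1d(Tensor self, int[1] output_size) -> Tensor')
cpp_args = ', '.join(f'{a.type} {a.name}' for a in meta.arguments(schema))
decl = f'{meta.returns_type(schema.returns)} {meta.name(schema)}({cpp_args});'
# decl is roughly:
#   TensorMeta upsample_nearest1d(const Tensor & self, IntArrayRef output_size);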
60 lines
1.7 KiB
Python
from tools.codegen.model import *
from tools.codegen.api.types import MetaArgument

import tools.codegen.api.cpp as cpp
import tools.codegen.api.dispatcher as dispatcher

from typing import Sequence
import itertools

# Follows dispatcher calling convention, but:
# - Mutable arguments not allowed. Meta functions are always
#   written in functional form. Look at FunctionSchema.signature()
# - No tensor returns; instead we return a TensorMeta describing
#   the tensor in question

def name(f: FunctionSchema) -> str:
    assert f.name.overload_name == ""
    return str(f.name.name)

def argument_type(a: Argument) -> str:
    assert not a.is_write
    return dispatcher.argumenttype_type(a.type, mutable=False)

def returntype_type(t: Type) -> str:
    r = cpp.valuetype_type(t)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return 'TensorMeta'
    elif isinstance(t, ListType):
        raise NotImplementedError("list returns not supported yet")

    raise AssertionError(f"unrecognized return type {t}")

def return_type(r: Return) -> str:
    assert not r.is_write
    return returntype_type(r.type)

def returns_type(rs: Sequence[Return]) -> str:
    if len(rs) == 0:
        return 'void'
    elif len(rs) == 1:
        return return_type(rs[0])
    else:
        args = ','.join(map(return_type, rs))
        return f'std::tuple<{args}>'

def argument(a: Argument) -> MetaArgument:
    return MetaArgument(
        type=argument_type(a),
        name=a.name,
        argument=a,
    )

def arguments(func: FunctionSchema) -> Sequence[MetaArgument]:
    assert not func.out_arguments
    return list(map(argument, itertools.chain(func.arguments, func.kwarg_only_arguments)))