pytorch/tools/codegen/api/structured.py
Edward Yang 81c7c3bae5 Add api.structured; switch structured kernels to use const Tensor& everywhere (#51490)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51490

Mutable Tensor ref is a source of endless confusion for kernel writers;
if we're going to make everyone rewrite their kernels, might as well
also get rid of mutable Tensor& while we're at it.

This is a refactor-then-small-update double whammy.  The refactor
is to separate tools.codegen.api.structured from api.native for
describing the type signatures of structured kernels (previously,
I was naughtily reusing native for this purpose--now I need it to
behave differently as Tensor).  This started off as a copy paste, but
since there are not that many structured kernels so far I could delete
all of the legacy logic from native that didn't make sense (without
having to go out and fix all the use sites all at once).

One more small addition was teaching translate to convert Tensor& to const Tensor&.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: bhosmer

Differential Revision: D26182413

Pulled By: ezyang

fbshipit-source-id: ed636866add3581179669cf9283f9835fcaddc06
2021-02-03 14:03:46 -08:00

92 lines
3.9 KiB
Python

from tools.codegen.model import *
from tools.codegen.api.types import *
import tools.codegen.api.cpp as cpp
from typing import Union, List
# This file describes the translation of JIT schema to the structured functions API.
# This is similar to native API, but a number of historical problems with native
# API have been fixed.
# Translation of types occurring in JIT arguments to a C++ argument type.
# NB: For now, mutable doesn't do anything; but it could if we make
# some more nominal types
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    """Translate a JIT schema type into a structured-kernel C++ argument type.

    Value types are delegated to ``cpp.valuetype_type``.  Otherwise a Tensor
    becomes a const reference, an optional wraps the translation of its
    element, and a list becomes ``IntArrayRef``, ``DimnameList`` or a generic
    ``ArrayRef<...>`` of the translated element.

    ``mutable`` is only threaded through the recursive calls; it does not
    influence the result for now.

    Raises:
        AssertionError: for optional Tensor, list of Tensor, a base type that
            is neither a value type nor Tensor, or an unrecognized type.
    """
    # Value types take priority over every other translation rule.
    value_ctype = cpp.valuetype_type(t, binds=binds)
    if value_ctype is not None:
        return value_ctype

    if isinstance(t, BaseType):
        # Tensor is the only base type that is not a value type.
        if t.name == BaseTy.Tensor:
            return ConstRefCType(BaseCType('Tensor', binds))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        if t.elem == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "optional tensor not supported by structured yet; to implement this "
                "add OptionalTensor c.f. https://github.com/pytorch/pytorch/issues/51456"
            )
        inner_ctype = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return OptionalCType(inner_ctype)

    if isinstance(t, ListType):
        if t.elem == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "list of tensor not supported by structured yet; to implement this "
                "resolve torch::List issue, see "
                "https://fb.workplace.com/groups/894363187646754/permalink/1149276442155426"
            )
        # TODO: delete these special cases; see tools.codegen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        elem_name = str(t.elem)
        if elem_name == 'int':
            return BaseCType("IntArrayRef", binds)
        if elem_name == 'Dimname':
            return BaseCType("DimnameList", binds)
        # Every other list is an ArrayRef over the translated element type.
        item_ctype = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return BaseCType(f"ArrayRef<{item_ctype.cpp_type()}>", binds)

    raise AssertionError(f"unrecognized type {repr(t)}")
def argument_type(a: Argument, *, binds: ArgName) -> CType:
    """Return the structured-kernel C++ type for the schema argument ``a``.

    The argument's ``is_write`` flag is forwarded to ``argumenttype_type``
    as ``mutable``.
    """
    schema_type = a.type
    is_written = a.is_write
    return argumenttype_type(schema_type, mutable=is_written, binds=binds)
# returns_type intentionally omitted, because structured kernels never "return";
# instead, they always indirectly report their outputs (in the case of a meta
# function, by calling set_output; in the case of an impl function, by writing
# directly into the provided out argument).
# Structured kernels are never defaulted
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]:
    """Convert one schema argument into its list of structured-kernel bindings.

    A plain ``Argument`` yields a single ``Binding`` named after the argument,
    always with ``default=None``.  A ``SelfArgument`` is unwrapped and its
    inner argument is converted instead.

    Raises:
        AssertionError: if ``a`` is a ``TensorOptionsArguments``.
    """
    if isinstance(a, Argument):
        binding = Binding(
            ctype=argument_type(a, binds=a.name),
            name=a.name,
            default=None,
            argument=a,
        )
        return [binding]
    if isinstance(a, SelfArgument):
        # Delegate to the wrapped argument.
        return argument(a.argument)
    if isinstance(a, TensorOptionsArguments):
        raise AssertionError("structured kernels don't support TensorOptions yet")
    # All members of the Union are handled above.
    assert_never(a)
def impl_arguments(g: StructuredNativeFunctions) -> List[Binding]:
    """Return the bindings for the impl function of ``g``.

    The bindings come from the out variant's schema: its non-out arguments
    first, followed by its out arguments.
    """
    ordered: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [
        *g.out.func.arguments.non_out,
        *g.out.func.arguments.out,
    ]
    bindings: List[Binding] = []
    for schema_arg in ordered:
        bindings.extend(argument(schema_arg))
    return bindings
def meta_arguments(g: StructuredNativeFunctions) -> List[Binding]:
    """Return the bindings for the meta function of ``g``.

    Only the non-out arguments of the functional variant's schema are used.
    """
    ordered: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [
        *g.functional.func.arguments.non_out,
    ]
    bindings: List[Binding] = []
    for schema_arg in ordered:
        bindings.extend(argument(schema_arg))
    return bindings
def out_arguments(g: StructuredNativeFunctions) -> List[Binding]:
    """Return the bindings for the out arguments of ``g``.

    Only the out arguments of the out variant's schema are used.
    """
    ordered: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [
        *g.out.func.arguments.out,
    ]
    bindings: List[Binding] = []
    for schema_arg in ordered:
        bindings.extend(argument(schema_arg))
    return bindings