mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51490 Mutable Tensor ref is a source of endless confusion for kernel writers; if we're going to make everyone rewrite their kernels, might as well also get rid of mutable Tensor& while we're at it. This is a refactor-then-small-update double whammy. The refactor is to separate tools.codegen.api.structured from api.native for describing the type signatures of structured kernels (previously, I was naughtily reusing native for this purpose--now I need it to behave differently as Tensor). This started off as a copy paste, but since there are not that many structured kernels so far I could delete all of the legacy logic from native that didn't make sense (without having to go out and fix all the use sites all at once). One more small addition was teaching translate to convert Tensor& to const Tensor&. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: bhosmer Differential Revision: D26182413 Pulled By: ezyang fbshipit-source-id: ed636866add3581179669cf9283f9835fcaddc06
92 lines
3.9 KiB
Python
92 lines
3.9 KiB
Python
from tools.codegen.model import *
|
|
|
|
from tools.codegen.api.types import *
|
|
import tools.codegen.api.cpp as cpp
|
|
|
|
from typing import Union, List
|
|
|
|
# This file describes the translation of JIT schema to the structured functions API.
|
|
# This is similar to native API, but a number of historical problems with native
|
|
# API have been fixed.
|
|
|
|
# Translation of types occurring in JIT arguments to a C++ argument type.
|
|
# NB: For now, mutable doesn't do anything; but it could if we make
|
|
# some more nominal types
|
|
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    """Map a JIT schema type onto the C++ argument type used by structured kernels.

    Value types are delegated to ``cpp.valuetype_type``.  For everything else:

    * ``Tensor`` becomes a const reference to ``Tensor``.
    * ``T?`` becomes ``OptionalCType`` wrapping the translation of ``T``.
    * ``int[]`` becomes ``IntArrayRef`` and ``Dimname[]`` becomes ``DimnameList``.
    * Any other ``T[]`` becomes ``ArrayRef<...>`` of the translated element type.

    ``mutable`` is only forwarded to the recursive calls; it does not
    influence the result of this function.

    Raises:
        AssertionError: for an optional tensor, a list of tensors, a base type
            that is neither a value type nor ``Tensor``, or a type that is not
            a ``BaseType``, ``OptionalType`` or ``ListType``.
    """
    # Value types have a dedicated translation; use it whenever it applies.
    value_ctype = cpp.valuetype_type(t, binds=binds)
    if value_ctype is not None:
        return value_ctype

    if isinstance(t, BaseType):
        # Tensor is the only base type that is not handled as a value type.
        if t.name != BaseTy.Tensor:
            raise AssertionError(f"base type should have been value type {t}")
        return ConstRefCType(BaseCType('Tensor', binds))

    if isinstance(t, OptionalType):
        if t.elem == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "optional tensor not supported by structured yet; to implement this "
                "add OptionalTensor c.f. https://github.com/pytorch/pytorch/issues/51456"
            )
        return OptionalCType(
            argumenttype_type(t.elem, mutable=mutable, binds=binds)
        )

    if isinstance(t, ListType):
        if t.elem == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "list of tensor not supported by structured yet; to implement this "
                "resolve torch::List issue, see "
                "https://fb.workplace.com/groups/894363187646754/permalink/1149276442155426"
            )
        # TODO: delete these special cases; see tools.codegen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        elem_name = str(t.elem)
        if elem_name == 'int':
            return BaseCType("IntArrayRef", binds)
        if elem_name == 'Dimname':
            return BaseCType("DimnameList", binds)
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return BaseCType(f"ArrayRef<{inner.cpp_type()}>", binds)

    raise AssertionError(f"unrecognized type {repr(t)}")
|
|
|
|
def argument_type(a: Argument, *, binds: ArgName) -> CType:
    """Return the structured-kernel C++ type for the schema argument ``a``.

    The argument's ``type`` is translated by ``argumenttype_type``, with
    ``a.is_write`` passed as the ``mutable`` flag and ``binds`` forwarded
    unchanged.
    """
    schema_type = a.type
    writes = a.is_write
    return argumenttype_type(schema_type, mutable=writes, binds=binds)
|
|
|
|
# returns_type intentionally omitted, because structured kernels never "return";
|
|
# instead, they always indirectly report their outputs (in the case of a meta
|
|
# function, by calling set_output; in the case of an impl function, by writing
|
|
# directly into the provided out argument).
|
|
|
|
# Structured kernels are never defaulted
|
|
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]:
    """Convert one schema-level argument into a list of C++ bindings.

    * ``Argument``: yields a single ``Binding`` named after the argument,
      typed via ``argument_type`` and with ``default`` set to ``None``.
    * ``SelfArgument``: unwrapped to its inner ``argument`` and processed
      recursively.
    * ``TensorOptionsArguments``: rejected with an ``AssertionError``.

    Any other input is passed to ``assert_never``.
    """
    if isinstance(a, Argument):
        binding = Binding(
            ctype=argument_type(a, binds=a.name),
            name=a.name,
            default=None,
            argument=a,
        )
        return [binding]

    if isinstance(a, SelfArgument):
        # A self argument only wraps a regular argument; translate that one.
        return argument(a.argument)

    if isinstance(a, TensorOptionsArguments):
        raise AssertionError("structured kernels don't support TensorOptions yet")

    assert_never(a)
|
|
|
|
def impl_arguments(g: StructuredNativeFunctions) -> List[Binding]:
    """Return the bindings for the impl function of structured group ``g``.

    The bindings are built from the out variant's schema (``g.out.func``):
    first its ``non_out`` arguments, then its ``out`` arguments, each
    translated through ``argument``.
    """
    schema_args = g.out.func.arguments
    ordered: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [
        *schema_args.non_out,
        *schema_args.out,
    ]
    bindings: List[Binding] = []
    for schema_arg in ordered:
        bindings.extend(argument(schema_arg))
    return bindings
|
|
|
|
def meta_arguments(g: StructuredNativeFunctions) -> List[Binding]:
    """Return the bindings for the meta function of structured group ``g``.

    Only the ``non_out`` arguments of the functional variant's schema
    (``g.functional.func``) are used, each translated through ``argument``.
    """
    inputs: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = list(
        g.functional.func.arguments.non_out
    )
    bindings: List[Binding] = []
    for schema_arg in inputs:
        bindings.extend(argument(schema_arg))
    return bindings
|
|
|
|
def out_arguments(g: StructuredNativeFunctions) -> List[Binding]:
    """Return the bindings for the out arguments of structured group ``g``.

    Only the ``out`` arguments of the out variant's schema (``g.out.func``)
    are used, each translated through ``argument``.
    """
    outputs: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = list(
        g.out.func.arguments.out
    )
    bindings: List[Binding] = []
    for schema_arg in outputs:
        bindings.extend(argument(schema_arg))
    return bindings
|