mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary:
I've added the parsing of an optional first line in native_functions.yaml after the precomputed keyword for arguments that will be precomputed without replacement. This line is optional; when present it must come first, and it does not contain any arrow.
These new fields are precomputed as before in the meta function and added to the precompute struct returned by the meta function. For now I've put them as last args of the impl function where they can be reused.
example:
native_function.yaml:
```
...
precomputed:
- int numBatch, int numPlanes, int inputT, int inputH, int inputW <- new
- kernel_size -> int poolSizeT, int poolSizeH, int poolSizeW
- output_size -> int outputT, int outputH, int outputW
```
meta:
```
TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)(
const at::Tensor& input_,
IntArrayRef pool_size,
IntArrayRef output_size,
const at::Tensor& randomSamples
) {
...
return TORCH_PRECOMPUTE_STRUCT(fractional_max_pool3d)().set_numBatch(numBatch).set_numPlanes(numPlanes).set_inputT(inputT).set_inputH(inputH).set_inputW(inputW)
.set_poolSizeT(poolSizeT) ...
}
```
impl:
```
TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)(
const at::Tensor& input_,
int64_t poolSizeT,
int64_t poolSizeH,
int64_t poolSizeW,
int64_t outputT,
int64_t outputH,
int64_t outputW,
const at::Tensor& randomSamples,
const at::Tensor& output,
const at::Tensor& indices,
int64_t numBatch, <- for now I've put them here
int64_t numPlanes,
int64_t inputT,
int64_t inputH,
int64_t inputW) {
```
Fixes https://github.com/pytorch/pytorch/issues/71314
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71368
Reviewed By: zou3519
Differential Revision: D33683984
Pulled By: bdhirsh
fbshipit-source-id: 33066dd92b8743aadf0dc8102f6bf0689f843242
(cherry picked from commit 64e46af6a4)
122 lines
5.6 KiB
Python
122 lines
5.6 KiB
Python
from tools.codegen.model import (Argument, BaseTy, BaseType, ListType,
|
|
NativeFunctionsGroup, OptionalType,
|
|
SelfArgument, TensorOptionsArguments, Type)
|
|
|
|
from tools.codegen.api.types import (ArgName, BaseCType, Binding, ArrayRefCType,
|
|
ConstRefCType, OptionalCType, NamedCType,
|
|
tensorT, scalarT, intArrayRefT, dimnameListT,
|
|
optionalTensorRefT, optionalScalarRefT)
|
|
|
|
from tools.codegen.api import cpp
|
|
from tools.codegen.utils import assert_never
|
|
|
|
from typing import Union, List
|
|
|
|
# This file describes the translation of JIT schema to the structured functions API.
|
|
# This is similar to native API, but a number of historical problems with native
|
|
# API have been fixed.
|
|
|
|
# Translation of types occurring in JIT arguments to a C++ argument type.
|
|
# NB: For now, mutable doesn't do anything; but it could if we make
|
|
# some more nominal types
|
|
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Translate a JIT schema type into the C++ type used by the structured
    functions API, naming the result after *binds*.

    Value types are handled exactly as in the plain C++ API; reference types
    (Tensor, Scalar, optionals, lists) get structured-specific translations.
    """
    # Value types (int, bool, float, ...) translate identically to the
    # native C++ API, so defer to that translation first.
    value_type = cpp.valuetype_type(t, binds=binds)
    if value_type is not None:
        return value_type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        # Any other base type should already have been caught above.
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        # Optional Tensor/Scalar use dedicated lightweight ref wrappers
        # rather than the generic optional translation below.
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(optionalTensorRefT))
        if t.elem == BaseType(BaseTy.Scalar):
            return NamedCType(binds, BaseCType(optionalScalarRefT))
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        if t.elem == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "list of tensor not supported by structured yet; to implement this "
                "resolve torch::List issue, see "
                "https://fb.workplace.com/groups/894363187646754/permalink/1149276442155426"
            )
        # TODO: delete these special cases; see tools.codegen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        if str(t.elem) == 'int':
            return NamedCType(binds, BaseCType(intArrayRefT))
        if str(t.elem) == 'Dimname':
            return NamedCType(binds, BaseCType(dimnameListT))
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(inner.type))

    raise AssertionError(f"unrecognized type {repr(t)}")
|
|
|
|
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
    """Compute the structured-API C++ type for a single schema argument."""
    # Mutability follows the schema's write annotation (e.g. Tensor(a!)).
    is_mutable = a.is_write
    return argumenttype_type(a.type, mutable=is_mutable, binds=binds)
|
|
|
|
# returns_type intentionally omitted, because structured kernels never "return";
|
|
# instead, they always indirectly report their outputs (in the case of a meta
|
|
# function, by calling set_output; in the case of an impl function, by writing
|
|
# directly into the provided out argument).
|
|
|
|
# Structured kernels are never defaulted
|
|
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]:
    """Lower one schema argument into its structured-API binding(s).

    Structured kernels are never defaulted, so every binding carries
    default=None. TensorOptions arguments are not supported.
    """
    if isinstance(a, Argument):
        binding = Binding(
            nctype=argument_type(a, binds=a.name),
            name=a.name,
            default=None,  # structured kernels are never defaulted
            argument=a,
        )
        return [binding]
    if isinstance(a, SelfArgument):
        # `self` lowers exactly like its underlying positional argument.
        return argument(a.argument)
    if isinstance(a, TensorOptionsArguments):
        raise AssertionError("structured kernels don't support TensorOptions yet")
    assert_never(a)
|
|
|
|
def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
    """Compute the argument bindings for a structured impl function.

    When the out variant declares precomputed elements, schema arguments
    listed in `precomputed.replace` are substituted (in place, in order) by
    their precomputed counterparts, and the `precomputed.add` arguments are
    appended after the non-out arguments, just before the out arguments.
    """
    precomputed = g.out.precomputed
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []

    if precomputed:
        # Walk the non-out arguments, swapping each replaced argument for
        # the parameters that precompute it and keeping the rest verbatim.
        for a in g.out.func.arguments.non_out:
            if isinstance(a, Argument) and a.name in precomputed.replace:
                args.extend(precomputed.replace[a.name])
            else:
                args.append(a)
        # Arguments precomputed without replacement go after the non-out
        # args and just before the out args.
        args.extend(precomputed.add)
    else:
        args.extend(g.out.func.arguments.non_out)

    args.extend(g.out.func.arguments.out)
    return [binding for a in args for binding in argument(a)]
|
|
|
|
def meta_arguments(g: NativeFunctionsGroup) -> List[Binding]:
    """Compute the argument bindings for a structured meta function.

    The meta function takes the functional variant's non-out arguments only.
    """
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = list(
        g.functional.func.arguments.non_out
    )
    return [binding for a in args for binding in argument(a)]
|
|
|
|
def out_arguments(g: NativeFunctionsGroup) -> List[Binding]:
    """Compute the bindings for the out arguments of a structured kernel."""
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = list(
        g.out.func.arguments.out
    )
    return [binding for a in args for binding in argument(a)]
|