pytorch/tools/codegen/api/structured.py
francescocastelli 5e6f296612 Structured Kernel Precompute codegen handle fields without replacement (#71368)
Summary:
I've added the parsing of an optional first line in native_functions.yaml after the precomputed keyword for arguments that will be precomputed without replacement. This line is optional, must be the first and does not contain any arrow.

These new fields are precomputed as before in the meta function and added to the precompute struct returned by the meta function. For now I've put them as last args of the impl function where they can be reused.

example:

native_functions.yaml:
```
  ...
  precomputed:
  - int numBatch, int numPlanes, int inputT, int inputH, int inputW   <- new
  - kernel_size -> int poolSizeT, int poolSizeH, int poolSizeW
  - output_size -> int outputT, int outputH, int outputW
```

meta:
```
TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)(
  const at::Tensor& input_,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& randomSamples
) {
    ...

return TORCH_PRECOMPUTE_STRUCT(fractional_max_pool3d)().set_numBatch(numBatch).set_numPlanes(numPlanes).set_inputT(inputT).set_inputH(inputH).set_inputW(inputW)
  .set_poolSizeT(poolSizeT) ...
}
```

impl:
```
TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)(
  const at::Tensor& input_,
  int64_t poolSizeT,
  int64_t poolSizeH,
  int64_t poolSizeW,
  int64_t outputT,
  int64_t outputH,
  int64_t outputW,
  const at::Tensor& randomSamples,
  const at::Tensor& output,
  const at::Tensor& indices,
  int64_t numBatch,    <- for now I've put them here
  int64_t numPlanes,
  int64_t inputT,
  int64_t inputH,
  int64_t inputW) {
```

Fixes https://github.com/pytorch/pytorch/issues/71314

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71368

Reviewed By: zou3519

Differential Revision: D33683984

Pulled By: bdhirsh

fbshipit-source-id: 33066dd92b8743aadf0dc8102f6bf0689f843242
(cherry picked from commit 64e46af6a4)
2022-02-08 03:56:56 +00:00

122 lines
5.6 KiB
Python

from tools.codegen.model import (Argument, BaseTy, BaseType, ListType,
NativeFunctionsGroup, OptionalType,
SelfArgument, TensorOptionsArguments, Type)
from tools.codegen.api.types import (ArgName, BaseCType, Binding, ArrayRefCType,
ConstRefCType, OptionalCType, NamedCType,
tensorT, scalarT, intArrayRefT, dimnameListT,
optionalTensorRefT, optionalScalarRefT)
from tools.codegen.api import cpp
from tools.codegen.utils import assert_never
from typing import Union, List
# This file describes the translation of JIT schema to the structured functions API.
# This is similar to native API, but a number of historical problems with native
# API have been fixed.
# Translation of types occurring in JIT arguments to a C++ argument type.
# NB: For now, mutable doesn't do anything; but it could if we make
# some more nominal types
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Translate the JIT schema type ``t`` into a structured-API C++ argument type.

    The returned ``NamedCType`` is bound to the name ``binds``.

    ``mutable`` is currently only forwarded to the recursive calls made for
    element types; it does not influence the result yet.

    Raises ``AssertionError`` for a base type that is neither a value type,
    Tensor nor Scalar, for a list of tensors, and for any ``t`` that is not
    a ``BaseType``, ``OptionalType`` or ``ListType``.
    """
    # Value types are handled by the shared cpp translation.
    value_ctype = cpp.valuetype_type(t, binds=binds)
    if value_ctype is not None:
        return value_ctype

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        # Optional Tensor / Scalar have dedicated reference types.
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(optionalTensorRefT))
        if t.elem == BaseType(BaseTy.Scalar):
            return NamedCType(binds, BaseCType(optionalScalarRefT))
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        if t.elem == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "list of tensor not supported by structured yet; to implement this "
                "resolve torch::List issue, see "
                "https://fb.workplace.com/groups/894363187646754/permalink/1149276442155426"
            )
        # TODO: delete these special cases; see tools.codegen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        special_list_types = {
            'int': intArrayRefT,
            'Dimname': dimnameListT,
        }
        elem_name = str(t.elem)
        if elem_name in special_list_types:
            return NamedCType(binds, BaseCType(special_list_types[elem_name]))
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(inner.type))

    raise AssertionError(f"unrecognized type {repr(t)}")
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
    """Return the structured-API C++ type of argument ``a``, bound to ``binds``.

    The argument's ``is_write`` flag is passed along as ``mutable``.
    """
    schema_type = a.type
    is_mutable = a.is_write
    return argumenttype_type(schema_type, mutable=is_mutable, binds=binds)
# returns_type intentionally omitted, because structured kernels never "return";
# instead, they always indirectly report their outputs (in the case of a meta
# function, by calling set_output; in the case of an impl function, by writing
# directly into the provided out argument).
# Structured kernels are never defaulted
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]:
    """Convert one schema argument into its list of structured-API bindings.

    A plain ``Argument`` yields a single ``Binding`` with no default
    (structured kernels are never defaulted). A ``SelfArgument`` is
    unwrapped and its inner argument is converted instead.

    Raises ``AssertionError`` for ``TensorOptionsArguments``, which
    structured kernels do not support yet.
    """
    if isinstance(a, Argument):
        nctype = argument_type(a, binds=a.name)
        binding = Binding(
            nctype=nctype,
            name=a.name,
            default=None,
            argument=a,
        )
        return [binding]

    if isinstance(a, SelfArgument):
        return argument(a.argument)

    if isinstance(a, TensorOptionsArguments):
        raise AssertionError("structured kernels don't support TensorOptions yet")

    # Every member of the Union is handled above.
    assert_never(a)
def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
    """Build the bindings for the impl function of the structured group ``g``.

    The order is: the non-out arguments of ``g.out`` (with any argument named
    in ``precomputed.replace`` swapped for its precomputed replacements),
    then the ``precomputed.add`` parameters, then the out arguments.
    Without precomputed information the non-out arguments are used as they are.
    """
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []

    precomputed = g.out.precomputed
    schema_args = g.out.func.arguments

    if precomputed:
        for a in schema_args.non_out:
            if isinstance(a, Argument) and a.name in precomputed.replace:
                # Substitute the precomputed counterparts specified in
                # native_functions.yaml for this parameter.
                args.extend(precomputed.replace[a.name])
            else:
                # Not replaced: keep the parameter as it is.
                args.append(a)
        # Parameters added without replacement go after the non-out args
        # and just before the out args.
        args.extend(precomputed.add)
    else:
        args.extend(schema_args.non_out)

    args.extend(schema_args.out)

    bindings: List[Binding] = []
    for arg in args:
        bindings.extend(argument(arg))
    return bindings
def meta_arguments(g: NativeFunctionsGroup) -> List[Binding]:
    """Build the bindings for the meta function of the structured group ``g``.

    Only the non-out arguments of the functional variant are included.
    """
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = list(
        g.functional.func.arguments.non_out
    )
    bindings: List[Binding] = []
    for arg in args:
        bindings.extend(argument(arg))
    return bindings
def out_arguments(g: NativeFunctionsGroup) -> List[Binding]:
    """Build the bindings for the out arguments of the structured group ``g``.

    Only the out arguments of the out variant are included.
    """
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = list(
        g.out.func.arguments.out
    )
    bindings: List[Binding] = []
    for arg in args:
        bindings.extend(argument(arg))
    return bindings