Structured Kernel Precompute codegen handle fields without replacement (#71368)

Summary:
I've added parsing of an optional last line in native_functions.yaml, under the `precomputed` keyword, for arguments that are precomputed without replacing any kernel argument. This line is optional, must be the last entry in the list, and does not contain an arrow (` -> `).

These new fields are precomputed as before in the meta function and added to the precompute struct returned by the meta function. In the impl function they are passed after the non-out arguments and just before the out arguments, where they can be reused.

example:

native_functions.yaml:
```
  ...
  precomputed:
  - kernel_size -> int poolSizeT, int poolSizeH, int poolSizeW
  - output_size -> int outputT, int outputH, int outputW
  - int numBatch, int numPlanes, int inputT, int inputH, int inputW   <- new (must be the last line)
```

meta:
```
TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)(
  const at::Tensor& input_,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& randomSamples
) {
    ...

return TORCH_PRECOMPUTE_STRUCT(fractional_max_pool3d)().set_numBatch(numBatch).set_numPlanes(numPlanes).set_inputT(inputT).set_inputH(inputH).set_inputW(inputW)
  .set_poolSizeT(poolSizeT) ...
}
```

impl:
```
TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)(
  const at::Tensor& input_,
  int64_t poolSizeT,
  int64_t poolSizeH,
  int64_t poolSizeW,
  int64_t outputT,
  int64_t outputH,
  int64_t outputW,
  const at::Tensor& randomSamples,
  int64_t numBatch,    <- new args go here, just before the out args
  int64_t numPlanes,
  int64_t inputT,
  int64_t inputH,
  int64_t inputW,
  const at::Tensor& output,
  const at::Tensor& indices) {
```

Fixes https://github.com/pytorch/pytorch/issues/71314

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71368

Reviewed By: zou3519

Differential Revision: D33683984

Pulled By: bdhirsh

fbshipit-source-id: 33066dd92b8743aadf0dc8102f6bf0689f843242
(cherry picked from commit 64e46af6a4)
This commit is contained in:
francescocastelli 2022-02-07 19:51:51 -08:00 committed by PyTorch MergeBot
parent 8bf3179f6e
commit 5e6f296612
7 changed files with 38 additions and 48 deletions

View File

@ -80,7 +80,8 @@ TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)(
set_output(1, {numBatch, numPlanes, outputT, outputH, outputW}, input_.options().dtype(kLong)); set_output(1, {numBatch, numPlanes, outputT, outputH, outputW}, input_.options().dtype(kLong));
} }
return TORCH_PRECOMPUTE_STRUCT(fractional_max_pool3d)().set_poolSizeT(poolSizeT).set_poolSizeH(poolSizeH).set_poolSizeW(poolSizeW) return TORCH_PRECOMPUTE_STRUCT(fractional_max_pool3d)().set_numBatch(numBatch).set_numPlanes(numPlanes).set_inputT(inputT).set_inputH(inputH).set_inputW(inputW)
.set_poolSizeT(poolSizeT).set_poolSizeH(poolSizeH).set_poolSizeW(poolSizeW)
.set_outputT(outputT).set_outputH(outputH).set_outputW(outputW); .set_outputT(outputT).set_outputH(outputH).set_outputW(outputW);
} }
@ -230,30 +231,14 @@ TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)(
int64_t outputH, int64_t outputH,
int64_t outputW, int64_t outputW,
const at::Tensor& randomSamples, const at::Tensor& randomSamples,
int64_t numBatch,
int64_t numPlanes,
int64_t inputT,
int64_t inputH,
int64_t inputW,
const at::Tensor& output, const at::Tensor& output,
const at::Tensor& indices) { const at::Tensor& indices) {
int64_t numBatch = 1;
int64_t planeDim = 0;
int64_t timeDim = 1;
int64_t heightDim = 2;
int64_t widthDim = 3;
int64_t ndims = input_.ndimension();
if (ndims == 5) {
numBatch = input_.size(0);
planeDim++;
timeDim++;
heightDim++;
widthDim++;
}
/* sizes */
int64_t numPlanes = input_.size(planeDim);
int64_t inputT = input_.size(timeDim);
int64_t inputH = input_.size(heightDim);
int64_t inputW = input_.size(widthDim);
/* get contiguous input */ /* get contiguous input */
auto input = input_.contiguous(); auto input = input_.contiguous();

View File

@ -241,32 +241,19 @@ TORCH_IMPL_FUNC(fractional_max_pool3d_out_cuda) (
int64_t outputH, int64_t outputH,
int64_t outputW, int64_t outputW,
const Tensor& randomSamples, const Tensor& randomSamples,
int64_t numBatch,
int64_t numPlanes,
int64_t inputT,
int64_t inputH,
int64_t inputW,
const Tensor& output, const Tensor& output,
const Tensor& indices const Tensor& indices) {
) {
int64_t planeDim = 0;
int64_t dimt = 1;
int64_t dimh = 2;
int64_t dimw = 3;
int64_t ndims = input.ndimension();
if (ndims == 5) {
planeDim++;
dimt++;
dimh++;
dimw++;
}
/* sizes */
int64_t numPlanes = input.size(planeDim);
int64_t inputT = input.size(dimt);
int64_t inputH = input.size(dimh);
int64_t inputW = input.size(dimw);
auto output_ = output; auto output_ = output;
auto indices_ = indices; auto indices_ = indices;
auto input_ = input; auto input_ = input;
int64_t ndims = input_.ndimension();
if(ndims == 4) { if(ndims == 4) {
output_ = output_.reshape({1, numPlanes, outputT, outputH, outputW}); output_ = output_.reshape({1, numPlanes, outputT, outputH, outputW});
indices_ = indices_.reshape({1, numPlanes, outputT, outputH, outputW}); indices_ = indices_.reshape({1, numPlanes, outputT, outputH, outputW});

View File

@ -9283,6 +9283,7 @@
precomputed: precomputed:
- kernel_size -> int poolSizeT, int poolSizeH, int poolSizeW - kernel_size -> int poolSizeT, int poolSizeH, int poolSizeW
- output_size -> int outputT, int outputH, int outputW - output_size -> int outputT, int outputH, int outputW
- int numBatch, int numPlanes, int inputT, int inputH, int inputW
dispatch: dispatch:
CPU: fractional_max_pool3d_out_cpu CPU: fractional_max_pool3d_out_cpu
CUDA: fractional_max_pool3d_out_cuda CUDA: fractional_max_pool3d_out_cuda

View File

@ -90,7 +90,6 @@ def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
# certain parameters replaced with precomputed counterparts # certain parameters replaced with precomputed counterparts
# as specified in native_functions.yaml. # as specified in native_functions.yaml.
non_out_args_replaced: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] non_out_args_replaced: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
for a in g.out.func.arguments.non_out: for a in g.out.func.arguments.non_out:
if isinstance(a, Argument) and a.name in g.out.precomputed.replace: if isinstance(a, Argument) and a.name in g.out.precomputed.replace:
# If a is in precompute.replace, append the parameters # If a is in precompute.replace, append the parameters
@ -102,6 +101,9 @@ def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
non_out_args_replaced.append(a) non_out_args_replaced.append(a)
args.extend(non_out_args_replaced) args.extend(non_out_args_replaced)
# g.out.precomputed.add is the list of parameters that are added
# without replacement after the non out args and just before the out args
args.extend(g.out.precomputed.add)
else: else:
args.extend(g.out.func.arguments.non_out) args.extend(g.out.func.arguments.non_out)

View File

@ -645,7 +645,8 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
# Put all of the contents of the precompute struct into the context # Put all of the contents of the precompute struct into the context
# so that translate will be able to return the correct args for the # so that translate will be able to return the correct args for the
# call to the impl. # call to the impl.
for precomputed_elems in self.g.out.precomputed.replace.values(): precomputed_values = [*self.g.out.precomputed.replace.values(), self.g.out.precomputed.add]
for precomputed_elems in precomputed_values:
for arg in precomputed_elems: for arg in precomputed_elems:
context.append(Expr( context.append(Expr(
expr=f"precompute.{arg.name}", expr=f"precompute.{arg.name}",

View File

@ -508,7 +508,8 @@ def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]:
# Generate the template declaration with one bool parameter for each # Generate the template declaration with one bool parameter for each
# precomputed element. Each parameter is true if the corresponding (in # precomputed element. Each parameter is true if the corresponding (in
# terms of position) precomputed element has been set. # terms of position) precomputed element has been set.
precomputed_elements = [elem for replace_list in precomputed.replace.values() for elem in replace_list] precomputed_values = [*precomputed.replace.values(), precomputed.add]
precomputed_elements = [elem for replace_list in precomputed_values for elem in replace_list]
precomputed_template_parameters = [elem.name.upper() for elem in precomputed_elements] precomputed_template_parameters = [elem.name.upper() for elem in precomputed_elements]
precomputed_template_params_str = ", ".join(f"bool {param} = false" for param in precomputed_template_parameters) precomputed_template_params_str = ", ".join(f"bool {param} = false" for param in precomputed_template_parameters)
precompute_template_decl = f"template <{precomputed_template_params_str}>" precompute_template_decl = f"template <{precomputed_template_params_str}>"

View File

@ -1595,6 +1595,8 @@ class Precompute:
# A map from kernel argument name -> a list of precomputed # A map from kernel argument name -> a list of precomputed
# elements that replaces/supersedes it. # elements that replaces/supersedes it.
replace: Dict[str, List[Argument]] replace: Dict[str, List[Argument]]
# List of precomputed args added without replacement
add: List[Argument]
@staticmethod @staticmethod
def parse(src: object) -> 'Precompute': def parse(src: object) -> 'Precompute':
@ -1602,18 +1604,29 @@ class Precompute:
# src is a list of strings of the format: # src is a list of strings of the format:
# {kernel param name} -> {replacement decl}[, {replacement decl}, ...] # {kernel param name} -> {replacement decl}[, {replacement decl}, ...]
# Parse this list to get the names of which precomputed elements # [{add decl}[, {add decl}, ...]]
# The last line is optional and contains the precomputed parameters that are
# added without replacement.
# The other lines are parsed to get the names of which precomputed elements
# should replace which kernel arguments. # should replace which kernel arguments.
add_args = []
if ' -> ' not in src[-1]:
add_list = src[-1].split(',')
add_args = [Argument.parse(name.strip()) for name in add_list]
src = src[:-1]
replace = {} replace = {}
for raw_replace_item in src: for raw_replace_item in src:
assert isinstance(raw_replace_item, str) assert isinstance(raw_replace_item, str)
assert ' -> ' in raw_replace_item, 'precomputed parameters without replacement' \
' are allowed only in the last line'
arg, with_list_raw = raw_replace_item.split(' -> ') arg, with_list_raw = raw_replace_item.split(' -> ')
with_list = with_list_raw.split(',') with_list = with_list_raw.split(',')
with_list_args = [Argument.parse(name.strip()) for name in with_list] with_list_args = [Argument.parse(name.strip()) for name in with_list]
replace[arg] = with_list_args replace[arg] = with_list_args
r = Precompute(replace=replace) r = Precompute(replace=replace, add=add_args)
assert r.to_list() == src, 'r.to_list() != src' assert r.to_list() == src, 'r.to_list() != src'
return r return r