mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Structured Kernel Precompute codegen handle fields without replacement (#71368)
Summary:
I've added parsing of an optional last line under the `precomputed` keyword in native_functions.yaml, listing arguments that are precomputed without replacing any existing argument. This line is optional, must be the last entry, and must not contain an arrow (` -> `).
These new fields are precomputed in the meta function as before and are added to the precompute struct that the meta function returns. For now they are passed to the impl function as additional arguments (see the impl example below for their position), where they can be reused.
example:
native_functions.yaml:
```
...
precomputed:
- kernel_size -> int poolSizeT, int poolSizeH, int poolSizeW
- output_size -> int outputT, int outputH, int outputW
- int numBatch, int numPlanes, int inputT, int inputH, int inputW <- new
```
meta:
```
TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)(
const at::Tensor& input_,
IntArrayRef pool_size,
IntArrayRef output_size,
const at::Tensor& randomSamples
) {
...
return TORCH_PRECOMPUTE_STRUCT(fractional_max_pool3d)().set_numBatch(numBatch).set_numPlanes(numPlanes).set_inputT(inputT).set_inputH(inputH).set_inputW(inputW)
.set_poolSizeT(poolSizeT) ...
}
```
impl:
```
TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)(
const at::Tensor& input_,
int64_t poolSizeT,
int64_t poolSizeH,
int64_t poolSizeW,
int64_t outputT,
int64_t outputH,
int64_t outputW,
const at::Tensor& randomSamples,
int64_t numBatch, <- for now I've put them here
int64_t numPlanes,
int64_t inputT,
int64_t inputH,
int64_t inputW,
const at::Tensor& output,
const at::Tensor& indices) {
```
Fixes https://github.com/pytorch/pytorch/issues/71314
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71368
Reviewed By: zou3519
Differential Revision: D33683984
Pulled By: bdhirsh
fbshipit-source-id: 33066dd92b8743aadf0dc8102f6bf0689f843242
(cherry picked from commit 64e46af6a4)
This commit is contained in:
parent
8bf3179f6e
commit
5e6f296612
|
|
@ -80,7 +80,8 @@ TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)(
|
||||||
set_output(1, {numBatch, numPlanes, outputT, outputH, outputW}, input_.options().dtype(kLong));
|
set_output(1, {numBatch, numPlanes, outputT, outputH, outputW}, input_.options().dtype(kLong));
|
||||||
}
|
}
|
||||||
|
|
||||||
return TORCH_PRECOMPUTE_STRUCT(fractional_max_pool3d)().set_poolSizeT(poolSizeT).set_poolSizeH(poolSizeH).set_poolSizeW(poolSizeW)
|
return TORCH_PRECOMPUTE_STRUCT(fractional_max_pool3d)().set_numBatch(numBatch).set_numPlanes(numPlanes).set_inputT(inputT).set_inputH(inputH).set_inputW(inputW)
|
||||||
|
.set_poolSizeT(poolSizeT).set_poolSizeH(poolSizeH).set_poolSizeW(poolSizeW)
|
||||||
.set_outputT(outputT).set_outputH(outputH).set_outputW(outputW);
|
.set_outputT(outputT).set_outputH(outputH).set_outputW(outputW);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -230,30 +231,14 @@ TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)(
|
||||||
int64_t outputH,
|
int64_t outputH,
|
||||||
int64_t outputW,
|
int64_t outputW,
|
||||||
const at::Tensor& randomSamples,
|
const at::Tensor& randomSamples,
|
||||||
|
int64_t numBatch,
|
||||||
|
int64_t numPlanes,
|
||||||
|
int64_t inputT,
|
||||||
|
int64_t inputH,
|
||||||
|
int64_t inputW,
|
||||||
const at::Tensor& output,
|
const at::Tensor& output,
|
||||||
const at::Tensor& indices) {
|
const at::Tensor& indices) {
|
||||||
|
|
||||||
int64_t numBatch = 1;
|
|
||||||
int64_t planeDim = 0;
|
|
||||||
int64_t timeDim = 1;
|
|
||||||
int64_t heightDim = 2;
|
|
||||||
int64_t widthDim = 3;
|
|
||||||
|
|
||||||
int64_t ndims = input_.ndimension();
|
|
||||||
if (ndims == 5) {
|
|
||||||
numBatch = input_.size(0);
|
|
||||||
planeDim++;
|
|
||||||
timeDim++;
|
|
||||||
heightDim++;
|
|
||||||
widthDim++;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* sizes */
|
|
||||||
int64_t numPlanes = input_.size(planeDim);
|
|
||||||
int64_t inputT = input_.size(timeDim);
|
|
||||||
int64_t inputH = input_.size(heightDim);
|
|
||||||
int64_t inputW = input_.size(widthDim);
|
|
||||||
|
|
||||||
/* get contiguous input */
|
/* get contiguous input */
|
||||||
auto input = input_.contiguous();
|
auto input = input_.contiguous();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -241,32 +241,19 @@ TORCH_IMPL_FUNC(fractional_max_pool3d_out_cuda) (
|
||||||
int64_t outputH,
|
int64_t outputH,
|
||||||
int64_t outputW,
|
int64_t outputW,
|
||||||
const Tensor& randomSamples,
|
const Tensor& randomSamples,
|
||||||
|
int64_t numBatch,
|
||||||
|
int64_t numPlanes,
|
||||||
|
int64_t inputT,
|
||||||
|
int64_t inputH,
|
||||||
|
int64_t inputW,
|
||||||
const Tensor& output,
|
const Tensor& output,
|
||||||
const Tensor& indices
|
const Tensor& indices) {
|
||||||
) {
|
|
||||||
|
|
||||||
int64_t planeDim = 0;
|
|
||||||
int64_t dimt = 1;
|
|
||||||
int64_t dimh = 2;
|
|
||||||
int64_t dimw = 3;
|
|
||||||
|
|
||||||
int64_t ndims = input.ndimension();
|
|
||||||
if (ndims == 5) {
|
|
||||||
planeDim++;
|
|
||||||
dimt++;
|
|
||||||
dimh++;
|
|
||||||
dimw++;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* sizes */
|
|
||||||
int64_t numPlanes = input.size(planeDim);
|
|
||||||
int64_t inputT = input.size(dimt);
|
|
||||||
int64_t inputH = input.size(dimh);
|
|
||||||
int64_t inputW = input.size(dimw);
|
|
||||||
|
|
||||||
auto output_ = output;
|
auto output_ = output;
|
||||||
auto indices_ = indices;
|
auto indices_ = indices;
|
||||||
auto input_ = input;
|
auto input_ = input;
|
||||||
|
|
||||||
|
int64_t ndims = input_.ndimension();
|
||||||
if(ndims == 4) {
|
if(ndims == 4) {
|
||||||
output_ = output_.reshape({1, numPlanes, outputT, outputH, outputW});
|
output_ = output_.reshape({1, numPlanes, outputT, outputH, outputW});
|
||||||
indices_ = indices_.reshape({1, numPlanes, outputT, outputH, outputW});
|
indices_ = indices_.reshape({1, numPlanes, outputT, outputH, outputW});
|
||||||
|
|
|
||||||
|
|
@ -9283,6 +9283,7 @@
|
||||||
precomputed:
|
precomputed:
|
||||||
- kernel_size -> int poolSizeT, int poolSizeH, int poolSizeW
|
- kernel_size -> int poolSizeT, int poolSizeH, int poolSizeW
|
||||||
- output_size -> int outputT, int outputH, int outputW
|
- output_size -> int outputT, int outputH, int outputW
|
||||||
|
- int numBatch, int numPlanes, int inputT, int inputH, int inputW
|
||||||
dispatch:
|
dispatch:
|
||||||
CPU: fractional_max_pool3d_out_cpu
|
CPU: fractional_max_pool3d_out_cpu
|
||||||
CUDA: fractional_max_pool3d_out_cuda
|
CUDA: fractional_max_pool3d_out_cuda
|
||||||
|
|
|
||||||
|
|
@ -90,7 +90,6 @@ def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
|
||||||
# certain parameters replaced with precomputed counterparts
|
# certain parameters replaced with precomputed counterparts
|
||||||
# as specified in native_functions.yaml.
|
# as specified in native_functions.yaml.
|
||||||
non_out_args_replaced: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
|
non_out_args_replaced: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
|
||||||
|
|
||||||
for a in g.out.func.arguments.non_out:
|
for a in g.out.func.arguments.non_out:
|
||||||
if isinstance(a, Argument) and a.name in g.out.precomputed.replace:
|
if isinstance(a, Argument) and a.name in g.out.precomputed.replace:
|
||||||
# If a is in precompute.replace, append the parameters
|
# If a is in precompute.replace, append the parameters
|
||||||
|
|
@ -102,6 +101,9 @@ def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
|
||||||
non_out_args_replaced.append(a)
|
non_out_args_replaced.append(a)
|
||||||
|
|
||||||
args.extend(non_out_args_replaced)
|
args.extend(non_out_args_replaced)
|
||||||
|
# g.out.precomputed.add is the list of parameters that are added
|
||||||
|
# without replacement after the non out args and just before the out args
|
||||||
|
args.extend(g.out.precomputed.add)
|
||||||
else:
|
else:
|
||||||
args.extend(g.out.func.arguments.non_out)
|
args.extend(g.out.func.arguments.non_out)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -645,7 +645,8 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
|
||||||
# Put all of the contents of the precompute struct into the context
|
# Put all of the contents of the precompute struct into the context
|
||||||
# so that translate will be able to return the correct args for the
|
# so that translate will be able to return the correct args for the
|
||||||
# call to the impl.
|
# call to the impl.
|
||||||
for precomputed_elems in self.g.out.precomputed.replace.values():
|
precomputed_values = [*self.g.out.precomputed.replace.values(), self.g.out.precomputed.add]
|
||||||
|
for precomputed_elems in precomputed_values:
|
||||||
for arg in precomputed_elems:
|
for arg in precomputed_elems:
|
||||||
context.append(Expr(
|
context.append(Expr(
|
||||||
expr=f"precompute.{arg.name}",
|
expr=f"precompute.{arg.name}",
|
||||||
|
|
|
||||||
|
|
@ -508,7 +508,8 @@ def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]:
|
||||||
# Generate the template declaration with one bool parameter for each
|
# Generate the template declaration with one bool parameter for each
|
||||||
# precomputed element. Each parameter is true if the corresponding (in
|
# precomputed element. Each parameter is true if the corresponding (in
|
||||||
# terms of position) precomputed element has been set.
|
# terms of position) precomputed element has been set.
|
||||||
precomputed_elements = [elem for replace_list in precomputed.replace.values() for elem in replace_list]
|
precomputed_values = [*precomputed.replace.values(), precomputed.add]
|
||||||
|
precomputed_elements = [elem for replace_list in precomputed_values for elem in replace_list]
|
||||||
precomputed_template_parameters = [elem.name.upper() for elem in precomputed_elements]
|
precomputed_template_parameters = [elem.name.upper() for elem in precomputed_elements]
|
||||||
precomputed_template_params_str = ", ".join(f"bool {param} = false" for param in precomputed_template_parameters)
|
precomputed_template_params_str = ", ".join(f"bool {param} = false" for param in precomputed_template_parameters)
|
||||||
precompute_template_decl = f"template <{precomputed_template_params_str}>"
|
precompute_template_decl = f"template <{precomputed_template_params_str}>"
|
||||||
|
|
|
||||||
|
|
@ -1595,6 +1595,8 @@ class Precompute:
|
||||||
# A map from kernel argument name -> a list of precomputed
|
# A map from kernel argument name -> a list of precomputed
|
||||||
# elements that replaces/supersedes it.
|
# elements that replaces/supersedes it.
|
||||||
replace: Dict[str, List[Argument]]
|
replace: Dict[str, List[Argument]]
|
||||||
|
# List of precomputed args added without replacement
|
||||||
|
add: List[Argument]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse(src: object) -> 'Precompute':
|
def parse(src: object) -> 'Precompute':
|
||||||
|
|
@ -1602,18 +1604,29 @@ class Precompute:
|
||||||
|
|
||||||
# src is a list of strings of the format:
|
# src is a list of strings of the format:
|
||||||
# {kernel param name} -> {replacement decl}[, {replacement decl}, ...]
|
# {kernel param name} -> {replacement decl}[, {replacement decl}, ...]
|
||||||
# Parse this list to get the names of which precomputed elements
|
# [{add decl}[, {add decl}, ...]]
|
||||||
|
# The last line is optional and contains the precomputed parameters that are
|
||||||
|
# added without replacement.
|
||||||
|
# The other lines are parsed to get the names of which precomputed elements
|
||||||
# should replace which kernel arguments.
|
# should replace which kernel arguments.
|
||||||
|
add_args = []
|
||||||
|
if ' -> ' not in src[-1]:
|
||||||
|
add_list = src[-1].split(',')
|
||||||
|
add_args = [Argument.parse(name.strip()) for name in add_list]
|
||||||
|
src = src[:-1]
|
||||||
|
|
||||||
replace = {}
|
replace = {}
|
||||||
for raw_replace_item in src:
|
for raw_replace_item in src:
|
||||||
assert isinstance(raw_replace_item, str)
|
assert isinstance(raw_replace_item, str)
|
||||||
|
assert ' -> ' in raw_replace_item, 'precomputed parameters without replacement' \
|
||||||
|
' are allowed only in the last line'
|
||||||
|
|
||||||
arg, with_list_raw = raw_replace_item.split(' -> ')
|
arg, with_list_raw = raw_replace_item.split(' -> ')
|
||||||
with_list = with_list_raw.split(',')
|
with_list = with_list_raw.split(',')
|
||||||
with_list_args = [Argument.parse(name.strip()) for name in with_list]
|
with_list_args = [Argument.parse(name.strip()) for name in with_list]
|
||||||
replace[arg] = with_list_args
|
replace[arg] = with_list_args
|
||||||
|
|
||||||
r = Precompute(replace=replace)
|
r = Precompute(replace=replace, add=add_args)
|
||||||
assert r.to_list() == src, 'r.to_list() != src'
|
assert r.to_list() == src, 'r.to_list() != src'
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user