mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
**Summary:** This commit simplifies the existing decomposition hierarchy of batch norm ops by adding a single, backend-agnostic op: `batch_norm_with_update`. The existing hierarchy looks like: ``` aten.batch_norm -> aten._batch_norm_impl_index -> [ aten.native_batch_norm -> aten._native_batch_norm_legit (export only) -> _batch_norm_legit_cpu/cuda (kernels, export only) -> _batch_norm_cpu/cuda (kernels) ] OR [ aten.cudnn_batch_norm ] OR [ aten.miopen_batch_norm ] ``` Aside from complexity, an important problem with the above decomposition hierarchy is cuda numerics in export flows. We observed significantly worse convergence when training a mobilenetv2-like model when using the `_batch_norm_cuda` kernel instead of the `cudnn_batch_norm` kernel. This means users who export their models on CPU first, then move the models to cuda later, may silently see worse accuracies even when cudnn is installed, because they are using the worse kernel. This issue is summarized in https://github.com/pytorch/pytorch/issues/111384. Instead, the new hierarchy proposed by consolidating existing batch norm ops will look like: ``` aten.batch_norm -> aten.batch_norm_with_update -> [ _batch_norm_cpu (kernel) ] OR [ _batch_norm_cuda (kernel) ] OR [ cudnn_batch_norm (kernel) ] OR [ miopen_batch_norm (kernel) ] ``` The new op `batch_norm_with_update` hides backend implementation details and automatically picks the right kernel based on what is installed. This commit also adds the following variants to this op: ``` batch_norm_with_update_functional batch_norm_with_update.out batch_norm_no_update batch_norm_no_update.out batch_norm_backward ``` Note that this commit only adds this op and its variants, but does not actually change the decomps to produce these ops in the graph. This will be done after the 2-week FC window, and the ops used in the old stack are planned to be removed after the 6-month BC window. Test Plan: `OpInfo` tests for `batch_norm_with_update`.
Reviewers: albanD, bdhirsh Subscribers: albanD, bdhirsh, supriyar Tasks: https://github.com/pytorch/pytorch/issues/111384 Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092 Approved by: https://github.com/bdhirsh, https://github.com/albanD
1465 lines
44 KiB
Python
1465 lines
44 KiB
Python
import math
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
# Scalar type for the arange_* shape functions, whose start/end/step bounds
# may be given as either int or float.
number = Union[int, float]
|
|
# flake8: noqa
|
|
|
|
###
|
|
# There are generated files that depend on this file
|
|
# To re-generate, please run from the root of the repo:
|
|
# python torchgen/shape_functions/gen_jit_shape_functions.py
|
|
|
|
# How to test:
|
|
# After regenerating files, compile PyTorch.
|
|
# Then run: ./build/bin/test_jit --gtest_filter=TestShapeGraphLinting.Basic
|
|
# If you have enabled opinfo testing for the op, also run:
|
|
# python test/test_ops_jit.py TestJitCPU.test_variant_consistency_jit_[FAILING_OP]_cpu_float32
|
|
# to reproduce errors from opinfo tests.
|
|
|
|
# Example PR: https://github.com/pytorch/pytorch/pull/80860/files
|
|
####
|
|
|
|
import torch
|
|
|
|
|
|
def broadcast(a: List[int], b: List[int]):
    """Return the shape produced by broadcasting shapes ``a`` and ``b``.

    Shapes are aligned at their trailing dimension; a dimension missing from
    the shorter shape is treated as size 1. Two aligned sizes are compatible
    when they are equal or when either of them is 1.

    Raises:
        AssertionError: if a pair of aligned sizes is incompatible.
    """
    dimsA = len(a)
    dimsB = len(b)
    ndim = max(dimsA, dimsB)
    expandedSizes: List[int] = []

    for i in range(ndim):
        offset = ndim - 1 - i
        dimA = dimsA - 1 - offset
        dimB = dimsB - 1 - offset
        sizeA = a[dimA] if (dimA >= 0) else 1
        sizeB = b[dimB] if (dimB >= 0) else 1

        if sizeA != sizeB and sizeA != 1 and sizeB != 1:
            # TODO: only assertion error is bound in C++ compilation right now
            # Both sizes are parenthesized, matching the eager-mode message.
            raise AssertionError(
                f"The size of tensor a ({sizeA}) must match the size of tensor b ({sizeB}) at non-singleton dimension {i}"
            )

        expandedSizes.append(sizeB if sizeA == 1 else sizeA)

    return expandedSizes
|
|
|
|
|
|
def broadcast_three(a: List[int], b: List[int], c: List[int]):
    """Broadcast three shapes together: ``a`` with ``b``, then the result with ``c``."""
    ab = broadcast(a, b)
    return broadcast(ab, c)
|
|
|
|
|
|
def broadcast_one_three(a: List[int], b: Any, c: List[int]):
    """Broadcast shape ``a`` with shape ``c``.

    The middle argument ``b`` is typed ``Any`` and never read.
    """
    result = broadcast(a, c)
    return result
|
|
|
|
|
|
def adaptive_avg_pool2d(self: List[int], out: List[int]):
    """Shape of adaptive 2-d average pooling to the spatial size ``out``.

    ``self`` must be 3-d or 4-d, and every dimension after the first must be
    non-zero. All but the last two dimensions are kept; the last two are
    replaced by the two entries of ``out``.
    """
    assert len(out) == 2
    assert len(self) == 3 or len(self) == 4
    for d in range(1, len(self)):
        assert self[d] != 0

    leading = len(self) - 2
    shape: List[int] = [self[d] for d in range(leading)]
    shape.append(out[0])
    shape.append(out[1])
    return shape
|
|
|
|
|
|
def _copy(self: List[int]):
|
|
out: List[int] = []
|
|
for elem in self:
|
|
out.append(elem)
|
|
return out
|
|
|
|
|
|
def unary(self: List[int]):
    """Shape of an elementwise unary op: a fresh copy of the input shape."""
    result = _copy(self)
    return result
|
|
|
|
|
|
def broadcast_inplace(a: List[int], b: List[int]):
    """Shape check for an in-place binary op ``a op= b``.

    ``b`` must broadcast to ``a`` without changing ``a``'s shape: it may not
    have more dimensions than ``a``, and each aligned size of ``b`` must
    equal ``a``'s size or be 1.

    Returns:
        A copy of ``a``.

    Raises:
        AssertionError: if ``b`` cannot be broadcast into ``a``.
    """
    dimsA = len(a)
    dimsB = len(b)
    if dimsB > dimsA:
        raise AssertionError(
            f"The dims of tensor b ({dimsB}) must be less than or equal to the dims of tensor a ({dimsA})"
        )
    for dimA in range(dimsA):
        dimB = dimsB - dimsA + dimA
        sizeA = a[dimA]
        sizeB = b[dimB] if (dimB >= 0) else 1
        if sizeA != sizeB and sizeB != 1:
            # TODO: only assertion error is bound in C++ compilation right now
            raise AssertionError(
                f"The size of tensor a ({sizeA}) must match the size of tensor b ({sizeB}) at non-singleton dimension {dimA}"
            )
    return _copy(a)
|
|
|
|
|
|
def expand(self: List[int], sizes: List[int]):
    """Shape of ``self.expand(sizes)``.

    ``sizes`` is right-aligned against ``self``. An entry of -1 keeps the
    existing size and is only allowed where ``self`` has that dimension.
    Any other entry must equal the existing size or replace a size of 1.
    Leading dimensions that ``self`` lacks are taken from ``sizes``.
    """
    assert len(sizes) >= len(self)
    target_rank = len(sizes)
    source_rank = len(self)
    if target_rank == 0:
        return _copy(sizes)

    result: List[int] = []
    for i in range(target_rank):
        # Index into ``self`` after right-aligning the two shapes.
        src_dim = i - (target_rank - source_rank)
        current = self[src_dim] if src_dim >= 0 else 1
        wanted = sizes[i]
        if wanted == -1:
            assert src_dim >= 0
            wanted = current
        if current != wanted:
            assert current == 1
            current = wanted
        result.append(current)
    return result
|
|
|
|
|
|
def expand_one_unused(self: List[int], sizes: List[int], inp0: Any):
    """Same as ``expand``; the extra argument ``inp0`` is ignored."""
    expanded = expand(self, sizes)
    return expanded
|
|
|
|
|
|
def infer_size_impl(shape: List[int], numel: int) -> List[int]:
|
|
newsize = 1
|
|
infer_dim: Optional[int] = None
|
|
for dim in range(len(shape)):
|
|
if shape[dim] == -1:
|
|
if infer_dim is not None:
|
|
raise AssertionError("only one dimension can be inferred")
|
|
infer_dim = dim
|
|
elif shape[dim] >= 0:
|
|
newsize *= shape[dim]
|
|
else:
|
|
raise AssertionError("invalid shape dimensions")
|
|
if not (
|
|
numel == newsize
|
|
or (infer_dim is not None and newsize > 0 and numel % newsize == 0)
|
|
):
|
|
raise AssertionError("invalid shape")
|
|
out = _copy(shape)
|
|
if infer_dim is not None:
|
|
out[infer_dim] = numel // newsize
|
|
return out
|
|
|
|
|
|
def numel(sizes: List[int]):
    """Number of elements of a tensor with the given sizes (1 for an empty shape)."""
    total = 1
    for extent in sizes:
        total = total * extent
    return total
|
|
|
|
|
|
def view(self: List[int], sizes: List[int]):
    """Shape of ``self.view(sizes)``; a single -1 in ``sizes`` is inferred."""
    element_count = numel(self)
    return infer_size_impl(sizes, element_count)
|
|
|
|
|
|
def view_one_unused(self: List[int], sizes: List[int], *, implicit: bool = False):
    """Same as ``view``; the keyword-only ``implicit`` flag is ignored."""
    reshaped = view(self, sizes)
    return reshaped
|
|
|
|
|
|
def sum_mean_dim(
    self: List[int], opt_dims: Optional[List[int]], keep_dim: bool, dt: Any
):
    """Shape of a sum/mean reduction over ``opt_dims``.

    ``None`` or an empty list reduces over every dimension. Reduced
    dimensions become 1 when ``keep_dim`` is set and are dropped otherwise.
    ``dt`` is accepted but not read.
    """
    rank = len(self)
    if opt_dims is None or len(opt_dims) == 0:
        dims: List[int] = list(range(rank))
    else:
        dims = opt_dims

    out: List[int] = []
    for idx in range(rank):
        reduced = False
        for reduce_dim in dims:
            if maybe_wrap_dim(reduce_dim, rank) == idx:
                reduced = True
        if not reduced:
            out.append(self[idx])
        elif keep_dim:
            out.append(1)
    return out
|
|
|
|
|
|
def max_dim(self: List[int], dim: int, keep_dim: bool):
    """Shapes of ``max`` along ``dim``.

    Both outputs share one reduced shape, which is the same list object.
    """
    reduced = sum_mean_dim(self, [dim], keep_dim, None)
    return reduced, reduced
|
|
|
|
|
|
def div_rtn(x: int, y: int):
    """Integer division that rounds toward negative infinity.

    Python's ``//`` on ints already floors, so no extra adjustment is needed.
    """
    quotient = x // y
    return quotient
|
|
|
|
|
|
def pooling_output_shape_pad_lr(
    inputSize: int,
    kernelSize: int,
    pad_l: int,
    pad_r: int,
    stride: int,
    dilation: int,
    ceil_mode: bool,
):
    """Output length of a 1-d pooling window with separate left/right padding.

    Without ``ceil_mode`` the division floors. With ``ceil_mode`` it rounds
    up, but a last window that would start at or beyond
    ``inputSize + pad_l`` (i.e. inside the right padding) is dropped.
    """
    effective_kernel = dilation * (kernelSize - 1) + 1
    numerator = inputSize + pad_l + pad_r - effective_kernel
    if ceil_mode:
        numerator = numerator + stride - 1
    # ``//`` floors toward negative infinity, which is the rounding div_rtn
    # performs.
    outputSize = numerator // stride + 1
    if ceil_mode and (outputSize - 1) * stride >= inputSize + pad_l:
        outputSize = outputSize - 1
    return outputSize
|
|
|
|
|
|
def pooling_output_shape(
    inputSize: int,
    kernelSize: int,
    pad_l: int,
    stride: int,
    dilation: int,
    ceil_mode: bool,
):
    """Output length of a 1-d pooling window with symmetric padding ``pad_l``.

    Delegates to ``pooling_output_shape_pad_lr`` using ``pad_l`` for both sides.

    Raises:
        AssertionError: if ``stride`` is zero.
    """
    assert stride != 0, "stride should not be zero"
    return pooling_output_shape_pad_lr(
        inputSize, kernelSize, pad_l, pad_l, stride, dilation, ceil_mode
    )
|
|
|
|
|
|
def pool2d_shape_check(
    input: List[int],
    kH: int,
    kW: int,
    dH: int,
    dW: int,
    padH: int,
    padW: int,
    dilationH: int,
    dilationW: int,
    nInputPlane: int,
    inputHeight: int,
    inputWidth: int,
    outputHeight: int,
    outputWidth: int,
):
    """Validate the arguments of a 2-d pooling op.

    Checks:
      * kernel, stride and dilation are positive;
      * ``input`` is 3-d with all sizes non-zero, or 4-d with sizes 1-3 non-zero;
      * each padding is at most half its kernel size;
      * both output sizes are at least 1.

    ``nInputPlane``, ``inputHeight`` and ``inputWidth`` are accepted but not
    read here.

    Raises:
        AssertionError: when any check fails.
    """
    ndim = len(input)

    assert kW > 0 and kH > 0
    assert dW > 0 and dH > 0
    assert dilationH > 0 and dilationW > 0

    valid_dims = input[1] != 0 and input[2] != 0
    # Explicit grouping: (3-d case) or (4-d case).
    assert (ndim == 3 and input[0] != 0 and valid_dims) or (
        ndim == 4 and valid_dims and input[3] != 0
    )

    assert kW // 2 >= padW and kH // 2 >= padH
    assert outputWidth >= 1 and outputHeight >= 1
|
|
|
|
|
|
def max_pool2d(
    input: List[int],
    kernel_size: List[int],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    ceil_mode: bool,
):
    """Shape of ``max_pool2d`` on a 3-d or 4-d input.

    The input is read as ``[..., channels, height, width]`` with an optional
    leading batch size. ``kernel_size``, ``padding`` and ``dilation`` hold one
    value (used for both height and width) or two (height, width). ``stride``
    may also be empty, in which case it defaults to the kernel size.
    """
    assert (
        len(kernel_size) == 1 or len(kernel_size) == 2
    ), "max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
    # With a single entry, index -1 and index 0 refer to the same value.
    kH = kernel_size[0]
    kW = kernel_size[-1]

    assert (
        len(stride) == 0 or len(stride) == 1 or len(stride) == 2
    ), "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
    if len(stride) == 0:
        dH = kH
        dW = kW
    else:
        dH = stride[0]
        dW = stride[-1]

    assert (
        len(padding) == 1 or len(padding) == 2
    ), "max_pool2d: padding must either be a single int, or a tuple of two ints"
    padH = padding[0]
    padW = padding[-1]

    assert (
        len(dilation) == 1 or len(dilation) == 2
    ), "max_pool2d: dilation must be either a single int, or a tuple of two ints"
    dilationH = dilation[0]
    dilationW = dilation[-1]

    assert len(input) == 3 or len(input) == 4
    is_batched = len(input) == 4
    nbatch = input[-4] if is_batched else 1
    nInputPlane = input[-3]
    inputHeight = input[-2]
    inputWidth = input[-1]

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)

    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        dilationH,
        dilationW,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
    )

    if is_batched:
        return [nbatch, nInputPlane, outputHeight, outputWidth]
    return [nInputPlane, outputHeight, outputWidth]
|
|
|
|
|
|
def max_pool2d_with_indices(
    input: List[int],
    kernel_size: List[int],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    ceil_mode: bool,
):
    """Shapes of ``max_pool2d_with_indices``.

    Both outputs share the ``max_pool2d`` shape (the same list object).
    """
    pooled = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
    return (pooled, pooled)
|
|
|
|
|
|
def upsample_nearest2d(
    input: List[int],
    output_size: Optional[List[int]],
    scale_factors: Optional[List[float]],
):
    """Shape of 2-d nearest-neighbour upsampling.

    The first two sizes of ``input`` are kept. The spatial sizes come from
    exactly one of:

    * ``output_size`` — two ints, used as-is;
    * ``scale_factors`` — two floats; each is multiplied with ``input[2]`` or
      ``input[3]`` and truncated with ``int``.
    """
    batch = input[0]
    channels = input[1]

    if scale_factors is None and output_size is None:
        assert 0, "Either output_size or scale_factors must be presented"

    if output_size is not None:
        assert (
            scale_factors is None
        ), "Must specify exactly one of output_size and scale_factors"
        assert len(output_size) == 2
        return [batch, channels, output_size[0], output_size[1]]

    if scale_factors is not None:
        assert (
            output_size is None
        ), "Must specify exactly one of output_size and scale_factors"
        assert len(scale_factors) == 2
        new_h = int(input[2] * scale_factors[0])
        new_w = int(input[3] * scale_factors[1])
        return [batch, channels, new_h, new_w]

    return [batch, channels]
|
|
|
|
|
|
def mm(self: List[int], mat2: List[int]):
    """Shape of a matrix product ``(n, m) @ (m, p)``: ``[n, p]``."""
    assert len(self) == 2, "self must be a matrix"
    assert len(mat2) == 2, "mat2 must be a matrix"
    rows, inner = self[0], self[1]
    assert inner == mat2[0]
    return [rows, mat2[1]]
|
|
|
|
|
|
def dot(self: List[int], tensor: List[int]):
    """Shape of the dot product of two equal-length 1-d tensors: ``[]``."""
    assert len(self) == 1 and len(tensor) == 1
    assert self[0] == tensor[0]
    scalar: List[int] = []
    return scalar
|
|
|
|
|
|
def mv(self: List[int], vec: List[int]):
    """Shape of a matrix ``(n, m)`` times a vector ``(m,)``: ``[n]``."""
    assert len(self) == 2 and len(vec) == 1
    rows = self[0]
    assert self[1] == vec[0]
    # TODO: return self
    return [rows]
|
|
|
|
|
|
def unsqueeze(li: List[int], dim: int):
    """Shape after inserting a size-1 dimension at ``dim``.

    ``dim`` is wrapped against ``len(li) + 1`` so the new dimension can also
    be placed after the current last one.
    """
    position = maybe_wrap_dim(dim, len(li) + 1)
    result = _copy(li)
    result.insert(position, 1)
    return result
|
|
|
|
|
|
def squeeze_nodim(li: List[int]):
    """Shape after removing every size-1 dimension."""
    kept: List[int] = []
    for extent in li:
        if extent != 1:
            kept.append(extent)
    return kept
|
|
|
|
|
|
def squeeze(li: List[int], dim: int):
    """Shape after removing dimension ``dim``, but only when its size is 1."""
    wrapped_dim = maybe_wrap_dim(dim, len(li))
    out: List[int] = []
    for i in range(len(li)):
        drop = i == wrapped_dim and li[i] == 1
        if not drop:
            out.append(li[i])
    return out
|
|
|
|
|
|
def squeeze_dims(li: List[int], dims: List[int]):
    """Shape after removing each listed dimension that has size 1.

    With an empty ``dims`` the input list itself (not a copy) is returned.
    """
    if len(dims) == 0:
        return li
    wrapped_dims = [maybe_wrap_dim(d, len(li)) for d in dims]
    result: List[int] = []
    for i in range(len(li)):
        if li[i] != 1 or i not in wrapped_dims:
            result.append(li[i])
    return result
|
|
|
|
|
|
def index_select(self: List[int], dim: int, index: List[int]):
    """Shape of ``index_select``.

    ``index`` must be 0-d or 1-d. The result is ``self`` with the size at
    ``dim`` replaced by the number of indices.
    """
    dim = maybe_wrap_dim(dim, len(self))
    num_indices = multiply_integers(index)
    assert len(index) <= 1
    assert dim == 0 or dim < len(self)
    result_size: List[int] = []
    for i in range(len(self)):
        result_size.append(num_indices if i == dim else self[i])
    return result_size
|
|
|
|
|
|
def embedding(
    weight: List[int],
    indices: List[int],
    padding_idx: int = -1,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
):
    """Shape of an embedding lookup.

    ``weight`` must be 2-d. The result is the ``indices`` shape with
    ``weight[1]`` appended. ``padding_idx``, ``scale_grad_by_freq`` and
    ``sparse`` are not read.
    """
    assert len(weight) == 2
    if len(indices) == 1:
        return index_select(weight, 0, indices)
    looked_up = _copy(indices)
    looked_up.append(weight[1])
    return looked_up
|
|
|
|
|
|
def max_int():
    """Largest signed 64-bit integer, 2**63 - 1.

    ``slice`` uses it as the default end bound.
    """
    return 0x7FFFFFFFFFFFFFFF
|
|
|
|
|
|
def slice(
    self: List[int], dim: int, start: Optional[int], end: Optional[int], step: int
):
    """Shape of ``self`` sliced as ``start:end:step`` along ``dim``.

    Bounds:
      * Missing bounds default to 0 and ``max_int()``.
      * A start equal to ``max_int()`` is treated as 0.
      * A negative bound is offset by the dimension size.
      * Start is then clamped into ``[0, size]`` and end into ``[start, size]``.

    ``step`` must be positive. The new size is ``ceil((end - start) / step)``.
    """
    rank = len(self)
    assert rank != 0
    dim = maybe_wrap_dim(dim, rank)
    lo = start if start is not None else 0
    hi = end if end is not None else max_int()
    assert step > 0
    extent = self[dim]
    if lo == max_int():
        lo = 0
    if lo < 0:
        lo += extent
    if hi < 0:
        hi += extent
    if lo < 0:
        lo = 0
    elif lo > extent:
        lo = extent
    if hi < lo:
        hi = lo
    elif hi >= extent:
        hi = extent
    out = _copy(self)
    out[dim] = (hi - lo + step - 1) // step
    return out
|
|
|
|
|
|
def check_cat_no_zero_dim(tensors: List[List[int]]):
    """Assert that no input to ``cat`` is zero-dimensional."""
    for shape in tensors:
        assert len(shape) > 0
|
|
|
|
|
|
def legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]):
    """Wrap ``cat``'s ``dim`` against the first non-legacy-empty input.

    A shape of exactly ``[0]`` counts as legacy-empty and is ignored. If every
    input is legacy-empty, ``dim`` is returned unchanged.
    """
    out_dim: Optional[int] = None
    for size in tensor_sizes:
        is_legacy_empty = len(size) == 1 and size[0] == 0
        if out_dim is None and not is_legacy_empty:
            out_dim = maybe_wrap_dim(dim, len(size))
    if out_dim is None:
        out_dim = dim
    return out_dim
|
|
|
|
|
|
def should_skip(tensor: List[int]):
    """True for a 1-d shape with zero elements, which ``cat`` ignores."""
    return len(tensor) == 1 and numel(tensor) == 0
|
|
|
|
|
|
def check_cat_shape_except_dim(
    first: List[int], second: List[int], dimension: int, index: int
):
    """Assert two ``cat`` inputs agree in rank and in every size except ``dimension``.

    ``index`` is accepted but not read.
    """
    assert len(first) == len(second), "Tensors must have same number of dimensions"
    for d in range(len(first)):
        assert (
            d == dimension or first[d] == second[d]
        ), "Sizes of tensors must match except in dimension"
|
|
|
|
|
|
def cat(tensors: List[List[int]], dim: int):
    """Shape of ``torch.cat(tensors, dim)``.

    Inputs for which ``should_skip`` is true (1-d with zero elements) are
    ignored. If every input is skipped the result is ``[0]``. Otherwise the
    result copies the last non-skipped input, with ``dim`` set to the sum of
    the non-skipped sizes along ``dim``.
    """
    check_cat_no_zero_dim(tensors)
    dim = legacy_cat_wrap_dim(dim, tensors)
    assert len(tensors) > 0

    reference: Optional[List[int]] = None
    for tensor in tensors:
        if not should_skip(tensor):
            reference = tensor
    if reference is None:
        return [0]

    total = 0
    for i, tensor in enumerate(tensors):
        if not should_skip(tensor):
            check_cat_shape_except_dim(reference, tensor, dim, i)
            total = total + tensor[dim]

    result = _copy(reference)
    result[dim] = total
    return result
|
|
|
|
|
|
def stack(tensors: List[List[int]], dim: int):
    """Shape of ``torch.stack``: unsqueeze every input at ``dim``, then ``cat``."""
    unsqueezed_tensors = [unsqueeze(tensor, dim) for tensor in tensors]
    return cat(unsqueezed_tensors, dim)
|
|
|
|
|
|
def select(self: List[int], dim: int, index: int):
    """Shape of ``self.select(dim, index)``: ``self`` without dimension ``dim``.

    ``self`` must not be 0-d, and ``index`` must lie in ``[-size, size)`` for
    that dimension.
    """
    ndim = len(self)
    assert ndim != 0
    dim = maybe_wrap_dim(dim, ndim)
    size = self[dim]
    assert -size <= index < size
    # Once validated, the index value does not affect the output shape.
    out: List[int] = []
    for i in range(ndim):
        if i != dim:
            out.append(self[i])
    return out
|
|
|
|
|
|
def matmul(tensor1: List[int], tensor2: List[int]):
    """Shape of ``torch.matmul(tensor1, tensor2)``.

    Cases:
      * 1-d/1-d → ``dot``.
      * 2-d/1-d → ``mv``.
      * 1-d/2-d → ``mm`` on an unsqueezed vector.
      * 2-d/2-d → ``mm``.
      * Otherwise the batch dimensions are broadcast and the matrix
        dimensions ``n`` (from ``tensor1``, if it is not a vector) and ``p``
        (from ``tensor2``, if it is not a vector) are appended.

    Raises:
        AssertionError: if either input is 0-d, or the contracted dimensions
            of a batched product differ.
    """
    dim_tensor1 = len(tensor1)
    dim_tensor2 = len(tensor2)
    if dim_tensor1 == 1 and dim_tensor2 == 1:
        return dot(tensor1, tensor2)
    elif dim_tensor1 == 2 and dim_tensor2 == 1:
        return mv(tensor1, tensor2)
    elif dim_tensor1 == 1 and dim_tensor2 == 2:
        return squeeze(mm(unsqueeze(tensor1, 0), tensor2), 0)
    elif dim_tensor1 == 2 and dim_tensor2 == 2:
        return mm(tensor1, tensor2)
    elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
        # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
        # we track m1 vs m2 separately even though they must match for nicer error messages
        n = tensor1[-2] if dim_tensor1 > 1 else 1
        m1 = tensor1[-1]
        batch_tensor1: List[int] = []
        # TODO: handling of slice
        for i in range(dim_tensor1 - 2):
            batch_tensor1.append(tensor1[i])
        # The contracted dimension of tensor2 is its second-to-last dimension,
        # or its only dimension when tensor2 is a vector. (Previously this
        # read tensor2[-1], which is p, and m1/m2 were never compared.)
        m2 = tensor2[-2] if dim_tensor2 > 1 else tensor2[0]
        assert (
            m1 == m2
        ), f"matmul: contracted dimensions must match, got {m1} and {m2}"
        p = tensor2[-1]
        batch_tensor2: List[int] = []
        # TODO: handling of slice
        for i in range(dim_tensor2 - 2):
            batch_tensor2.append(tensor2[i])

        # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
        expand_batch_portion = broadcast(batch_tensor1, batch_tensor2)

        # broadcast() returns a fresh list, so appending to it is safe.
        output_shape = expand_batch_portion
        if dim_tensor1 > 1:
            output_shape.append(n)

        if dim_tensor2 > 1:
            output_shape.append(p)

        return output_shape
    else:
        assert False, "both arguments to matmul need to be at least 1D"
|
|
|
|
|
|
def t(self: List[int]):
    """Shape of ``self.t()``: swaps the two sizes of a 2-d shape.

    0-d and 1-d shapes are returned as (new) lists unchanged.
    """
    assert len(self) <= 2
    self_len = len(self)
    if self_len == 2:
        return [self[1], self[0]]
    if self_len == 1:
        return [self[0]]
    out: List[int] = []
    return out
|
|
|
|
|
|
def transpose(self: List[int], dim0: int, dim1: int):
    """Shape of ``self.transpose(dim0, dim1)``: a copy with the two sizes swapped."""
    ndims = len(self)
    dim0 = maybe_wrap_dim(dim0, ndims)
    dim1 = maybe_wrap_dim(dim1, ndims)
    out = _copy(self)
    # Identical dims (including the 0-d case) need no swap.
    if dim0 != dim1:
        out[dim0] = self[dim1]
        out[dim1] = self[dim0]
    return out
|
|
|
|
|
|
def linear(input: List[int], weight: List[int], bias: Optional[List[int]]):
    """Shape of ``input @ weight.t()`` (+ ``bias``).

    A bias must broadcast to the output shape without changing it.
    """
    weight_t = t(weight)
    out = matmul(input, weight_t)
    if bias is not None:
        assert broadcast(bias, out) == out
    return out
|
|
|
|
|
|
def addmm(self: List[int], mat1: List[int], mat2: List[int], beta: Any, alpha: Any):
    """Shape of ``addmm``: ``self`` broadcast with the ``mat1 @ mat2`` shape.

    ``beta`` and ``alpha`` are not read.
    """
    product = mm(mat1, mat2)
    return broadcast(self, product)
|
|
|
|
|
|
def check_non_negative(array: List[int]) -> bool:
    """Return True if ANY entry of ``array`` is negative.

    Despite its name, this reports the presence of a negative value.
    ``check_shape_forward`` uses it as ``assert not check_non_negative(...)``.
    """
    # TODO: look into rewriting with early return and getting loop unrolling to fire
    found_negative = False
    for val in array:
        found_negative = found_negative or val < 0
    return found_negative
|
|
|
|
|
|
def check_shape_forward(
    input: List[int],
    weight_sizes: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
):
    """Validate the shapes passed to a non-transposed convolution.

    Raises AssertionError when any of these hold:
      * ``padding`` or ``stride`` has a negative entry;
      * input and weight ranks differ;
      * ``weight_sizes[0]`` is below ``groups`` or not divisible by it;
      * ``input[1]`` is not ``weight_sizes[1] * groups``;
      * a bias is given that is not 1-d of length ``weight_sizes[0]``;
      * a padded spatial size is smaller than the dilated kernel.
    """
    rank = len(input)

    # TODO: assertions could be expanded with the error messages
    assert not check_non_negative(padding)
    assert not check_non_negative(stride)

    assert len(weight_sizes) == rank
    out_channels = weight_sizes[0]
    assert out_channels >= groups
    assert (out_channels % groups) == 0
    # only handling not transposed
    assert input[1] == weight_sizes[1] * groups
    assert bias is None or (len(bias) == 1 and bias[0] == out_channels)

    for i in range(2, rank):
        spatial = i - 2
        padded = input[i] + 2 * padding[spatial]
        dilated_kernel = dilation[spatial] * (weight_sizes[i] - 1) + 1
        assert padded >= dilated_kernel
|
|
|
|
# this is not handling transposed convolution yet
def conv_output_size(
    input_size: List[int],
    weight_size: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
):
    """Output shape of a non-transposed convolution, after validating inputs.

    The result is ``[input_size[0], weight_size[0], *spatial]``. Each spatial
    size is ``(in + 2 * pad - (dilation * (k - 1) + 1)) // stride + 1``; an
    empty ``dilation`` list is treated as 1 in this formula.
    """
    check_shape_forward(
        input_size, weight_size, bias, stride, padding, dilation, groups
    )

    has_dilation = len(dilation) > 0
    output_size: List[int] = [input_size[0], weight_size[0]]
    for d in range(2, len(input_size)):
        i = d - 2
        dil = dilation[i] if has_dilation else 1
        kernel = dil * (weight_size[d] - 1) + 1
        output_size.append((input_size[d] + 2 * padding[i] - kernel) // stride[i] + 1)
    return output_size
|
|
|
|
|
|
def conv1d(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
):
    """Shape of a 1-d convolution; ``weight`` and ``input`` must both be 3-d."""
    assert len(weight) == 3
    assert len(input) == 3
    result = conv_output_size(input, weight, bias, stride, padding, dilation, groups)
    return result
|
|
|
|
|
|
def conv2d(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
):
    """Shape of a 2-d convolution; ``weight`` and ``input`` must both be 4-d."""
    assert len(weight) == 4
    assert len(input) == 4
    result = conv_output_size(input, weight, bias, stride, padding, dilation, groups)
    return result
|
|
|
|
|
|
def conv_backwards(
    grad_output: List[int],
    input: List[int],
    weight: List[int],
    biases: Optional[List[int]],
):
    """Shapes of the three convolution-backward outputs.

    Returns fresh copies of ``input`` and ``weight``, plus the 1-d bias
    gradient shape ``[grad_output[1]]``.
    """
    # Bias gradient is always generated regardless of whether biases is supplied
    input_shape = [extent for extent in input]
    weight_shape = [extent for extent in weight]
    bias_shape = [grad_output[1]]
    return input_shape, weight_shape, bias_shape
|
|
|
|
|
|
def conv_transpose2d_input(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]] = None,
    stride: Optional[List[int]] = None,
    padding: Optional[List[int]] = None,
    output_padding: Optional[List[int]] = None,
    groups: int = 1,
    dilation: Optional[List[int]] = None,
) -> List[int]:
    """Output shape of a transposed convolution.

    Omitted arguments default as follows:
      * ``stride`` → ``[1, 1]``
      * ``padding`` → ``[0, 0]``
      * ``output_padding`` → ``[0, 0]``
      * ``dilation`` → ``[1, 1]``

    Output channels are ``weight[1] * groups``. Each spatial size is
    ``(in - 1) * stride - 2 * pad + dilation * (k - 1) + output_padding + 1``.
    ``bias`` is not read.
    """
    if stride is None:
        stride = [1, 1]
    if padding is None:
        padding = [0, 0]
    if output_padding is None:
        output_padding = [0, 0]
    if dilation is None:
        dilation = [1, 1]

    use_dilation = len(dilation) > 0
    output_size: List[int] = [input[0], weight[1] * groups]
    for d in range(2, len(input)):
        k = d - 2
        dil = dilation[k] if use_dilation else 1
        output_size.append(
            (input[d] - 1) * stride[k]
            - 2 * padding[k]
            + dil * (weight[d] - 1)
            + output_padding[k]
            + 1
        )
    return output_size
|
|
|
|
|
|
def conv_forwards(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    transposed: bool,
    output_padding: List[int],
    groups: int,
) -> List[int]:
    """Output shape of a regular or transposed convolution.

    Channel dimension:
      * regular: ``weight[0]``;
      * transposed: ``weight[1] * groups``.

    Spatial sizes:
      * regular: ``(in + 2 * pad - (dil * (k - 1) + 1)) // stride + 1``;
      * transposed: ``(in - 1) * stride - 2 * pad + dil * (k - 1) + out_pad + 1``.

    Empty ``dilation``/``output_padding`` lists are treated as 1/0. ``bias``
    is not read.
    """
    has_dilation = len(dilation) > 0
    has_output_padding = len(output_padding) > 0
    output_size: List[int] = [input[0]]
    if transposed:
        output_size.append(weight[1] * groups)
    else:
        output_size.append(weight[0])

    for d in range(2, len(input)):
        i = d - 2
        dil = dilation[i] if has_dilation else 1
        out_pad = output_padding[i] if has_output_padding else 0
        span = dil * (weight[d] - 1)
        if transposed:
            output_size.append(
                (input[d] - 1) * stride[i] - 2 * padding[i] + span + out_pad + 1
            )
        else:
            output_size.append((input[d] + 2 * padding[i] - (span + 1)) // stride[i] + 1)
    return output_size
|
|
|
|
|
|
def _conv_forwards(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    transposed: bool,
    output_padding: List[int],
    groups: int,
    benchmark: bool,
    deterministic: bool,
    cudnn_enabled: bool,
    allow_tf32: bool,
) -> List[int]:
    """Same shape as ``conv_forwards``.

    ``benchmark``, ``deterministic``, ``cudnn_enabled`` and ``allow_tf32``
    are not forwarded and do not affect the result.
    """
    return conv_forwards(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        transposed=transposed,
        output_padding=output_padding,
        groups=groups,
    )
|
|
|
|
|
|
def batch_norm(
    input: List[int],
    weight: Optional[List[int]],
    bias: Optional[List[int]],
    running_mean: Optional[List[int]],
    running_var: Optional[List[int]],
    training: bool,
    momentum: float,
    eps: float,
    cudnn_enabled: bool,
):
    """Shape of ``batch_norm``: a fresh copy of ``input``.

    No argument other than ``input`` is read.
    """
    return [extent for extent in input]
|
|
|
|
|
|
def conv3d(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
):
    """Shape of a 3-d convolution; ``weight`` and ``input`` must both be 5-d."""
    assert len(weight) == 5
    assert len(input) == 5
    result = conv_output_size(input, weight, bias, stride, padding, dilation, groups)
    return result
|
|
|
|
|
|
def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
    """Normalize a possibly negative dimension index for rank ``dim_post_expr``.

    A rank <= 0 is treated as rank 1 when ``wrap_scalar`` is true; otherwise
    it fails the assertion. Valid indices lie in ``[-rank, rank - 1]``, and
    negative ones are shifted up by ``rank``.
    """
    rank = dim_post_expr
    if rank <= 0:
        assert wrap_scalar
        rank = 1
    assert -rank <= dim <= rank - 1
    return dim + rank if dim < 0 else dim
|
|
|
|
|
|
def zero_dim_tensor(input: Any):
    """Shape of a 0-dimensional result: always ``[]``; ``input`` is not read."""
    scalar_shape: List[int] = []
    return scalar_shape
|
|
|
|
|
|
def multiply_integers(li: List[int]):
    """Product of all entries of ``li`` (1 for an empty list)."""
    product = 1
    for factor in li:
        product *= factor
    return product
|
|
|
|
|
|
def arange_end(end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any):
    """Shape of ``arange(end)``: ``[ceil(end)]``; ``end`` must be non-negative.

    The four trailing arguments are not read.
    """
    assert end >= 0
    length = int(math.ceil(end))
    return [length]
|
|
|
|
|
|
def arange_start(
    start: number, end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
):
    """Shape of ``arange(start, end)``: ``[ceil(end - start)]``.

    Only ``end >= start`` is required. Negative bounds such as
    ``arange(-5, -2)`` are valid, so ``end`` itself is not required to be
    non-negative (the previous ``assert end >= 0`` wrongly rejected them).
    The four trailing arguments are not read.

    Raises:
        AssertionError: if ``end < start``.
    """
    assert end >= start
    return [int(math.ceil(end - start))]
|
|
|
|
|
|
def arange_start_step(
    start: number, end: number, step: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
):
    """Shape of ``arange(start, end, step)``: ``[ceil((end - start) / step)]``.

    ``step`` must be non-zero and point from ``start`` toward ``end``. The four
    trailing arguments are not read.
    """
    assert step != 0
    if step > 0:
        assert end >= start
    else:
        assert start >= end
    count = int(math.ceil((end - start) / step))
    return [count]
|
|
|
|
|
|
def permute(input: List[int], dims: List[int]):
    """Shape of ``input.permute(dims)``.

    ``dims`` must list one (wrappable) index per dimension, with no repeats.
    """
    assert len(input) == len(dims)
    ndim = len(dims)
    seen_dims: List[int] = []
    newSizes: List[int] = []
    for requested in dims:
        d = maybe_wrap_dim(requested, ndim)
        seen_dims.append(d)
        newSizes.append(input[d])
    # Pairwise uniqueness check on the wrapped indices.
    for i in range(1, ndim):
        for j in range(i):
            assert seen_dims[i] != seen_dims[j]
    return newSizes
|
|
|
|
|
|
def movedim(self: List[int], source: List[int], destination: List[int]) -> List[int]:
    """Shape of ``movedim``.

    Dimension ``source[i]`` moves to position ``destination[i]``. The
    remaining dimensions keep their relative order in the leftover positions.
    Shapes of rank <= 1 are returned as-is (the same list).

    Raises:
        AssertionError: if ``source`` and ``destination`` differ in length, a
            dim is out of range, or either list repeats a dim.
    """
    # Without this check a longer ``source`` hit an IndexError and a longer
    # ``destination`` had its extra entries silently ignored.
    assert len(source) == len(
        destination
    ), "movedim: source and destination must contain the same number of dims"
    self_dim = len(self)
    if self_dim <= 1:
        return self
    normalized_src: List[int] = []
    normalized_dst: List[int] = []
    for i in range(len(source)):
        normalized_src.append(maybe_wrap_dim(source[i], self_dim))
        normalized_dst.append(maybe_wrap_dim(destination[i], self_dim))
    # Repeated dims could otherwise leave a -1 in ``order`` that permute()
    # wraps to a valid-looking (but wrong) permutation.
    for i in range(1, len(source)):
        for j in range(i):
            assert (
                normalized_src[i] != normalized_src[j]
            ), "movedim: repeated dim in `source`"
            assert (
                normalized_dst[i] != normalized_dst[j]
            ), "movedim: repeated dim in `destination`"
    # order[d] = input dimension placed at output position d (-1 = unassigned).
    order = [-1 for i in range(self_dim)]
    src_dims = [i for i in range(self_dim)]
    dst_dims = [i for i in range(self_dim)]

    for i in range(len(source)):
        order[normalized_dst[i]] = normalized_src[i]
        src_dims[normalized_src[i]] = -1
        dst_dims[normalized_dst[i]] = -1

    # Unmoved input dims, in order, fill the unassigned output positions.
    source_dims: List[int] = []
    destination_dims: List[int] = []
    for ele in src_dims:
        if ele != -1:
            source_dims.append(ele)
    for ele in dst_dims:
        if ele != -1:
            destination_dims.append(ele)

    rest_dim = self_dim - len(source)
    for i in range(rest_dim):
        order[destination_dims[i]] = source_dims[i]
    return permute(self, order)
|
|
|
|
|
|
def flatten(input: List[int], start_dim: int, end_dim: int):
    """Shape of ``input.flatten(start_dim, end_dim)``.

    Dimensions ``start_dim..end_dim`` (inclusive, after wrapping) are merged
    into their product. A 0-d input yields ``[1]``; when the range covers a
    single dimension, a copy of ``input`` is returned.
    """
    start_dim = maybe_wrap_dim(start_dim, len(input))
    end_dim = maybe_wrap_dim(end_dim, len(input))
    assert start_dim <= end_dim
    if len(input) == 0:
        return [1]
    if start_dim == end_dim:
        # TODO: return self
        return [extent for extent in input]
    slice_numel = 1
    for i in range(start_dim, end_dim + 1):
        slice_numel *= input[i]
    # TODO: use slicing when slice optimization has landed
    # slice_numel = multiply_integers(input[start_dim : end_dim + 1])
    shape: List[int] = []
    for i in range(len(input)):
        if i < start_dim or i > end_dim:
            shape.append(input[i])
        elif i == start_dim:
            shape.append(slice_numel)
    return shape
|
|
|
|
|
|
def nonzero_lower_bound(input: List[int]):
    """Smallest shape ``nonzero`` can return: 0 rows, one column per input dim."""
    rank = len(input)
    return [0, rank]
|
|
|
|
|
|
def nonzero_upper_bound(input: List[int]):
    """Largest shape ``nonzero`` can return: one row per element, one column per dim."""
    max_rows = numel(input)
    return [max_rows, len(input)]
|
|
|
|
|
|
def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
    """Shape after reducing along ``dim``.

    The dimension is kept as 1 when ``keepdim`` is set and dropped otherwise.
    """
    dim = maybe_wrap_dim(dim, len(self))
    out: List[int] = []
    for i in range(len(self)):
        if i != dim:
            out.append(self[i])
        elif keepdim:
            out.append(1)
    return out
|
|
|
|
|
|
def argmax(
    self: List[int], dim: Optional[int] = None, keepdim: bool = False
) -> List[int]:
    """Shape of ``argmax``.

    Without ``dim`` the result is a scalar (``[]``). Otherwise it is the shape
    reduced along ``dim``.
    """
    if dim is not None:
        return _reduce_along_dim(self, dim, keepdim)
    return []
|
|
|
|
|
|
def bmm(self: List[int], mat2: List[int]) -> List[int]:
    """Shape of a batched matrix product ``(b, n, m) @ (b, m, p)``: ``[b, n, p]``."""
    assert len(self) == 3, "bmm only supports 3D tensors"
    assert len(mat2) == 3, "bmm only supports 3D tensors"
    batch, rows, inner = self[0], self[1], self[2]
    assert batch == mat2[0], "mismatching batch dimension"
    assert inner == mat2[1], "mismatching contracting dimension"
    return [batch, rows, mat2[2]]
|
|
|
|
|
|
def _shape_as_tensor(self: List[int]) -> List[int]:
|
|
return [len(self)]
|
|
|
|
|
|
def topk(self: List[int], k: int, dim: int = -1) -> Tuple[List[int], List[int]]:
    """Shapes of ``topk``.

    Both outputs share one list: ``self`` with the size at ``dim`` replaced by
    ``k``. A 0-d input yields ``[]``. ``k`` may not exceed the size at ``dim``.
    """
    if len(self) == 0:
        result: List[int] = []
    else:
        assert (
            k <= self[dim]
        ), f"k ({k}) is too big for dimension {dim} of size {self[dim]}"
        result = [extent for extent in self]
        result[dim] = k
    return result, result
|
|
|
|
|
|
def nll_loss_forward(
    self: List[int], target: List[int], weight: Optional[List[int]], reduction: int
) -> Tuple[List[int], List[int]]:
    """Shape function for ``aten::nll_loss_forward``.

    Returns ``(output shape, total_weight shape)``. ``total_weight`` is always
    0-dim. ``output`` is ``[self[0]]`` only for unreduced (``reduction == 0``)
    2-D input, and 0-dim otherwise; in that case both entries are the same
    list object.
    """
    # Mirrors the meta function in LossNLL.cpp.
    input_rank = len(self)
    target_rank = len(target)
    assert 0 < input_rank <= 2
    assert target_rank <= 1
    unbatched = input_rank == 1 and target_rank == 0
    assert unbatched or (self[0] == target[0])
    num_classes = self[-1]
    total_weight_shape: List[int] = []
    # An optional weight must be 1-D with one entry per class.
    assert weight is None or (len(weight) == 1 and weight[0] == num_classes)
    if reduction != 0 or input_rank != 2:
        return total_weight_shape, total_weight_shape
    return [self[0]], total_weight_shape
|
|
|
|
|
|
def native_layer_norm(
    input: List[int], normalized_shape: List[int]
) -> Tuple[List[int], List[int], List[int]]:
    """Shape function for ``aten::native_layer_norm``.

    The first output is a copy (via ``_copy``) of the input shape. The two
    statistics outputs share one shape list. That list keeps the leading,
    un-normalized dimensions of ``input`` and has size 1 in every trailing
    dimension covered by ``normalized_shape``.
    """
    leading_rank = len(input) - len(normalized_shape)
    assert leading_rank >= 0
    stats_shape: List[int] = []
    for axis in range(len(input)):
        if axis < leading_rank:
            stats_shape.append(input[axis])
        else:
            stats_shape.append(1)
    return _copy(input), stats_shape, stats_shape
|
|
|
|
|
|
def native_batch_norm(
    input: List[int],
    weight: Optional[List[int]],
    bias: Optional[List[int]],
    running_mean: Optional[List[int]],
    running_var: Optional[List[int]],
    training: bool,
) -> Tuple[List[int], List[int], List[int]]:
    """Output shapes for batch norm: ``(output, stat, stat)``.

    The output copies the input shape. Both statistics outputs share one
    shape list. In training mode it is ``[input[1]]`` (one entry per channel,
    assuming an ``(N, C, ...)`` input); otherwise it is ``[0]``.
    """
    stats_shape = [input[1]] if training else [0]
    return _copy(input), stats_shape, stats_shape
|
|
|
|
|
|
def cross_entropy_loss(
    self: List[int],
    target: List[int],
    weight: Optional[List[int]] = None,
    reduction: int = 1,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
) -> List[int]:
    """Shape function for ``aten::cross_entropy_loss``.

    The loss has the shape of the ``output`` result of ``nll_loss_forward``
    for the same inputs. ``ignore_index`` and ``label_smoothing`` do not
    influence the shape.
    """
    output_shape, _ = nll_loss_forward(self, target, weight, reduction)
    return output_shape
|
|
|
|
|
|
"""
|
|
Currently deferring the enabling of this, as part of the proposal to suspend
|
|
adding ops.
|
|
There are currently cases in the test case where this is being called
|
|
in the SSA opinfo tests with unexpected values (e.g. a list of two ints; see the first
|
|
opinfo test). The behavior of index is significantly dependent on the inputs.
|
|
|
|
This could be an error with how we are matching up shape functions, or that this
|
|
function needs to just implement everything.
|
|
|
|
def index_Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
|
|
assert len(indices) <= len(self), "More indices than dimensions to index"
|
|
broadcasted_shape: List[int] = []
|
|
for index_tensor_shape in indices:
|
|
if index_tensor_shape is not None:
|
|
broadcasted_shape = broadcast(broadcasted_shape, index_tensor_shape)
|
|
return broadcasted_shape
|
|
"""
|
|
|
|
# Type of the compiled functions produced by ``torch.jit.script``.
ScriptFn = torch._C.ScriptFunction

# operator schema string -> scripted shape compute function
shape_compute_graph_mapping: Dict[str, ScriptFn] = dict()
# operator schema string -> (scripted lower-bound fn, scripted upper-bound fn)
bounded_compute_graph_mapping: Dict[str, Tuple[ScriptFn, ScriptFn]] = dict()
# memo of Python shape function -> its scripted ScriptFunction (see process_func)
script_func_map: Dict[Callable, ScriptFn] = dict()
|
|
|
|
|
|
def process_func(func: Callable):
    """Compile ``func`` with TorchScript and return the resulting ScriptFunction.

    Results are memoized in ``script_func_map``, so a Python shape function
    reused for several schemas is scripted only once. After scripting, the
    graph is inlined and then given two rounds of peephole optimization, each
    followed by constant propagation.
    """
    cached = script_func_map.get(func)
    if cached is not None:
        return cached

    compiled = torch.jit.script(func)
    torch._C._jit_pass_inline(compiled.graph)
    # Run the pair twice, presumably so simplifications exposed by one pass
    # can be picked up by the other.
    for _round in range(2):
        torch._C._jit_pass_peephole(compiled.graph)
        torch._C._jit_pass_constant_propagation(compiled.graph)

    script_func_map[func] = compiled
    return compiled
|
|
|
|
|
|
def add_shape_compute_mapping(operator_schema: str, func: Callable) -> None:
    """Register ``func`` as the shape compute function for ``operator_schema``.

    ``func`` is scripted (and cached) through ``process_func``. The result is
    stored in ``shape_compute_graph_mapping`` keyed by the full schema string.
    Registering the same schema string again replaces the earlier entry.
    """
    # Only an item of the module-level dict is assigned, so no ``global``
    # declaration is needed.
    shape_compute_graph_mapping[operator_schema] = process_func(func)
|
|
|
|
|
|
def add_bounded_compute_mapping(
    operator_schema: str, lower_bound_func: Callable, upper_bound_func: Callable
):
    """Register lower- and upper-bound shape functions for ``operator_schema``.

    Both functions are scripted via ``process_func`` (lower bound first). The
    pair is stored in ``bounded_compute_graph_mapping`` as ``(lower, upper)``.
    """
    lower = process_func(lower_bound_func)
    upper = process_func(upper_bound_func)
    bounded_compute_graph_mapping[operator_schema] = (lower, upper)
|
|
|
|
|
|
# Memory-format, cast, dropout, pooling and tensor-creation ops. Each
# (schema, shape function) pair is registered in table order; the loop
# variables are deleted afterwards so they do not linger as module globals.
for _schema, _shape_fn in (
    ("aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)", unary),
    ("aten::rsub.Tensor(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", unary),
    ("aten::dropout(Tensor input, float p, bool train) -> Tensor", unary),
    ("aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor", adaptive_avg_pool2d),
    ("prim::NumToTensor.Scalar(Scalar a) -> Tensor", zero_dim_tensor),
    ("prim::NumToTensor.bool(bool a) -> Tensor", zero_dim_tensor),
    ("aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", unary),
    ("aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))", unary),
    ("aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", arange_end),
    ("aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", arange_start),
    ("aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", arange_start_step),
):
    add_shape_compute_mapping(_schema, _shape_fn)
del _schema, _shape_fn
|
|
# View, indexing, normalization and embedding ops, registered in table order.
for _schema, _shape_fn in (
    ("aten::squeeze(Tensor(a) self) -> Tensor(a)", squeeze_nodim),
    ("aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", squeeze),
    ("aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)", squeeze_dims),
    ("aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", unsqueeze),
    ("aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)", slice),
    ("aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)", select),
    ("aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", index_select),
    (
        "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, "
        "float eps=1e-05, bool cudnn_enable=True) -> Tensor",
        unary,
    ),
    ("aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", unary),
    ("aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor", unary),
    ("aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)", unary),
    ("aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor", embedding),
):
    add_shape_compute_mapping(_schema, _shape_fn)
del _schema, _shape_fn
|
|
# Matrix products, pooling, transposes, convolutions, batch_norm and flatten,
# registered in table order.
for _schema, _shape_fn in (
    ("aten::mm(Tensor self, Tensor mat2) -> Tensor", mm),
    ("aten::dot(Tensor self, Tensor tensor) -> Tensor", dot),
    ("aten::mv(Tensor self, Tensor vec) -> Tensor", mv),
    ("aten::matmul(Tensor self, Tensor other) -> Tensor", matmul),
    ("aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", linear),
    ("aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", max_pool2d),
    ("aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", max_pool2d_with_indices),
    ("aten::t(Tensor(a) self) -> Tensor(a)", t),
    ("aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", transpose),
    ("aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor", conv1d),
    ("aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", conv2d),
    ("aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", batch_norm),
    ("aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", conv3d),
    ("aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", conv_backwards),
    ("aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor", conv_forwards),
    ("aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor", _conv_forwards),
    ("aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor", conv_transpose2d_input),
    ("aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", flatten),
):
    add_shape_compute_mapping(_schema, _shape_fn)
del _schema, _shape_fn
|
|
# Concatenation, permutation/view, reductions, quantization and misc ops,
# registered in table order.
for _schema, _shape_fn in (
    ("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat),
    ("aten::stack(Tensor[] tensors, int dim=0) -> Tensor", stack),
    ("aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute),
    ("aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", movedim),
    ("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", view),
    ("aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand),
    ("aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", expand_one_unused),
    ("aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", sum_mean_dim),
    ("aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", sum_mean_dim),
    ("aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", max_dim),
    ("aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor),
    ("aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor),
    ("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", addmm),
    ("aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)", upsample_nearest2d),
    ("aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor", unary),
    ("aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor", unary),
    ("aten::dequantize(Tensor self) -> Tensor", unary),
    ("quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc", broadcast),
    ("aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", argmax),
    ("aten::bmm(Tensor self, Tensor mat2) -> Tensor", bmm),
    ("aten::_shape_as_tensor(Tensor self) -> Tensor", _shape_as_tensor),
):
    add_shape_compute_mapping(_schema, _shape_fn)
del _schema, _shape_fn
|
|
# topk, NLL loss and the layer/batch-norm family, registered in table order.
for _schema, _shape_fn in (
    ("aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)", topk),
    ("aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)", nll_loss_forward),
    ("aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)", native_layer_norm),
    ("aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm),
    ("aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm),
    ("aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm),
):
    add_shape_compute_mapping(_schema, _shape_fn)
del _schema, _shape_fn
|
|
def _batch_norm_with_update(
    input: List[int],
    weight: Optional[List[int]],
    bias: Optional[List[int]],
    running_mean: List[int],
    running_var: List[int],
    momentum: float,
    eps: float,
) -> Tuple[List[int], List[int], List[int], List[int]]:
    """Shape function for ``aten::_batch_norm_with_update``.

    The signature mirrors the schema argument-for-argument. The shared
    ``native_batch_norm`` shape function cannot be reused here for two
    reasons: its sixth parameter is ``training: bool`` while this schema's
    sixth argument is ``float momentum``, and it yields three shapes where
    this schema declares four outputs.

    The schema has no ``training`` flag and marks ``running_mean`` and
    ``running_var`` as mutated (``a!``/``b!``), i.e. training-mode batch
    norm. The two saved statistics therefore have one entry per channel,
    like the training branch of ``native_batch_norm`` (assumes an
    ``(N, C, ...)`` input).
    """
    per_channel = [input[1]]
    # NOTE(review): the fourth output looks like a backend-specific reserve
    # buffer; an empty ``[0]`` shape is assumed here — confirm against the
    # cudnn/miopen kernels.
    return _copy(input), per_channel, per_channel, [0]


add_shape_compute_mapping(
    "aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)",
    _batch_norm_with_update,
)
|
|
|
|
for _schema, _shape_fn in (
    (
        "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor",
        cross_entropy_loss,
    ),
    # Disabled; see the index_Tensor note earlier in this file:
    # add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor)
    #
    # TODO: migrate over all of symbolic_shape_registry_util.cpp
    # These are duplicated here so that the functions will be serialized
    ("aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", broadcast_three),
    ("aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", broadcast_one_three),
    ("aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", broadcast_inplace),
):
    add_shape_compute_mapping(_schema, _shape_fn)
del _schema, _shape_fn

# quantized_conv_prepack TODO

# Shape Compute Fn with upper and lower bounds
add_bounded_compute_mapping(
    "aten::nonzero(Tensor self) -> (Tensor)",
    nonzero_lower_bound,
    nonzero_upper_bound,
)
|