pytorch/torch/_meta_registrations.py
Peter Bell 5817695bfa [pt2] Fix arange to match ATen behavior (#93353)
Fixes #92676

`arange` infers the output dtype from the argument types, but in order to reduce
falling back to ATen, inductor preferred to cast whole number float arguments to
int which gave the wrong output dtype. Instead, this decomposes floating point
arange into the prim equivalent for integers.

This also changes the signature of `prims.arange` to

```python
prims.iota(length, *, start, step, **factory_kwargs)
```

which only supports integer arguments. This is done because calculating the
output size from `start, end, step` is surprisingly complex and liable to
off-by-one errors, so it should not be duplicated in each backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93353
Approved by: https://github.com/ngimel, https://github.com/lezcano
2023-02-03 00:44:32 +00:00

2706 lines
83 KiB
Python

import math
from typing import List, Optional, Union
import torch
import torch._prims_common as utils
from torch import Tensor
from torch._decomp import _add_op_to_registry, global_decomposition_table, meta_table
from torch._ops import OpOverload
from torch._prims import _elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
from torch._prims_common import (
check,
corresponding_complex_dtype,
corresponding_real_dtype,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
IntLike,
make_contiguous_strides_for,
)
from torch._prims_common.wrappers import out_wrapper
from torch._refs import _broadcast_shapes
from torch._subclasses.fake_tensor import check_no_bool_index_tensors
from torch.utils._pytree import tree_map
# Shorthand for the ATen operator namespace used by the registrations below.
aten = torch.ops.aten
# Library handle for "aten" IMPL registrations under the "Meta" dispatch key.
# As the name says: do not use it directly, use ``register_meta`` instead.
_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")
def register_meta(op):
    """Return a decorator that records a function as the meta kernel for ``op``.

    ``op`` may be a single operator or any pytree (for example a list) of
    operators; every leaf is added to ``meta_table``. The decorated function
    is returned unchanged.
    """

    def wrapper(fn):
        # Visit each leaf of ``op`` and register ``fn`` for it.
        tree_map(lambda leaf: _add_op_to_registry(meta_table, leaf, fn), op)
        return fn

    return wrapper
def toRealValueType(dtype):
    """Map a complex dtype to its real counterpart; return other dtypes as-is."""
    real_of = {
        torch.complex32: torch.half,
        torch.cfloat: torch.float,
        torch.cdouble: torch.double,
    }
    try:
        return real_of[dtype]
    except KeyError:
        # Not one of the listed complex dtypes: pass it through unchanged.
        return dtype
@register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
@out_wrapper()
def meta_fft_c2c(self, dim, normalization, forward):
    """Meta kernel for ``_fft_c2c``: output has the same size as the complex input."""
    assert self.dtype.is_complex
    out_size = self.size()
    return self.new_empty(out_size)
@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
@out_wrapper()
def meta_fft_r2c(self, dim, normalization, onesided):
    """Meta kernel for ``_fft_r2c``.

    The output uses the complex dtype corresponding to the input dtype. When
    ``onesided`` is set, the last transformed dimension (``dim[-1]``) shrinks
    to ``size // 2 + 1``.
    """
    assert self.dtype.is_floating_point
    shape = list(self.size())
    if onesided:
        axis = dim[-1]
        shape[axis] = (shape[axis] // 2) + 1
    complex_dtype = utils.corresponding_complex_dtype(self.dtype)
    return self.new_empty(shape, dtype=complex_dtype)
@register_meta(aten.randperm.generator_out)
def meta_randperm(n, *, generator=None, out):
    """Meta kernel for ``randperm.generator_out``: validate ``out`` and return it."""
    # ``out`` must already be a 1-D tensor holding exactly ``n`` elements.
    assert out.ndim == 1
    assert out.size(0) == n
    return out
@register_meta(aten.randint.default)
def meta_randint(
    high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None
):
    """Meta kernel for ``randint``: allocate an empty tensor of ``size``."""
    factory_kwargs = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
    }
    return torch.empty(size, **factory_kwargs)
@register_meta(aten.randint.low)
def meta_randint_low(
    low, high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None
):
    """Meta kernel for ``randint.low``: allocate an empty tensor of ``size``."""
    factory_kwargs = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
    }
    return torch.empty(size, **factory_kwargs)
@register_meta(aten.rand.default)
def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Meta kernel for ``rand``: allocate an empty tensor of ``size``."""
    factory_kwargs = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
    }
    return torch.empty(size, **factory_kwargs)
@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
@out_wrapper()
def meta_fft_c2r(self, dim, normalization, lastdim):
    """Meta kernel for ``_fft_c2r``.

    The output has the real dtype matching the complex input, and dimension
    ``dim[-1]`` is set to ``lastdim``.
    """
    assert self.dtype.is_complex
    shape = list(self.size())
    shape[dim[-1]] = lastdim
    real_dtype = toRealValueType(self.dtype)
    return self.new_empty(shape, dtype=real_dtype)
@register_meta(aten.copy_.default)
def meta_copy_(self, src, non_blocking=False):
    """Meta kernel for in-place ``copy_``: return ``self`` without touching ``src``."""
    return self
def inferUnsqueezeGeometry(tensor, dim):
    """Compute the sizes and strides of ``tensor`` with a size-1 dim inserted at ``dim``.

    Returns:
        A ``(sizes, strides)`` pair of lists.
    """
    sizes = list(tensor.size())
    strides = list(tensor.stride())
    if dim >= tensor.dim():
        # Appending past the last dimension: the new dim gets unit stride.
        inserted_stride = 1
    else:
        inserted_stride = sizes[dim] * strides[dim]
    sizes.insert(dim, 1)
    strides.insert(dim, inserted_stride)
    return sizes, strides
@register_meta(aten.unsqueeze_.default)
def meta_unsqueeze_(self, dim):
    """Meta kernel for in-place ``unsqueeze_``: restride ``self`` with a new size-1 dim."""
    # NOTE(review): ``maybe_wrap_dim`` is not defined or imported in the visible
    # part of this file -- confirm it is provided elsewhere.
    wrapped_dim = maybe_wrap_dim(dim, self.dim() + 1)
    new_sizes, new_strides = inferUnsqueezeGeometry(self, wrapped_dim)
    self.as_strided_(new_sizes, new_strides)
    return self
# Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
@register_meta(aten.index_select.default)
def meta_index_select(self, dim, index):
    """Meta kernel for ``index_select``: ``dim`` takes the length ``index.numel()``."""
    out_shape = list(self.size())
    if self.dim() <= 0:
        # 0-d input: there is no dimension to resize.
        return self.new_empty(out_shape)
    out_shape[dim] = index.numel()
    return self.new_empty(out_shape)
@register_meta(aten.index_select.out)
def meta_index_select_out(self, dim, index, out):
    """Meta kernel for ``index_select.out``.

    Computes the result of ``torch.index_select(self, dim, index)``, resizes
    ``out`` to that result's shape, and copies the result into ``out``.

    Returns:
        ``out`` (the return value of ``out.copy_``).
    """
    result = torch.index_select(self, dim, index)
    # Resize to the *result* shape, not ``self.size()``: along ``dim`` the
    # output has ``index.numel()`` entries, which generally differs from the
    # input's extent in that dimension.
    torch._resize_output_(out, result.size(), self.device)
    return out.copy_(result)
@register_meta([aten.max.default, aten.max.unary_out])
@out_wrapper()
def meta_max(self):
    """Meta kernel for full-reduction ``max``: the result is a 0-d tensor."""
    scalar_shape = ()
    return self.new_empty(scalar_shape)
@register_meta(aten.max.dim)
def meta_max_dim(self, dim, keepdim=False):
    """Meta kernel for ``max.dim``.

    Returns:
        A pair ``(values, indices)`` sharing the reduced shape; ``indices``
        has dtype ``torch.long``.
    """
    reduced_dims = utils.reduction_dims(self.shape, (dim,))
    out_shape = _compute_reduction_shape(self, reduced_dims, keepdim)
    values = self.new_empty(out_shape)
    indices = self.new_empty(out_shape, dtype=torch.long)
    return (values, indices)
@register_meta([aten.min.default])
def meta_min(self):
    """Meta kernel for full-reduction ``min``: the result is a 0-d tensor."""
    scalar_shape = ()
    return self.new_empty(scalar_shape)
@register_meta(aten.angle.default)
def meta_angle(self):
    """Meta kernel for ``angle``.

    Complex inputs produce the corresponding real dtype; other inputs use the
    INT_TO_FLOAT elementwise promotion result dtype.
    """
    if not self.is_complex():
        _, out_dtype = elementwise_dtypes(
            self, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
        )
    else:
        out_dtype = corresponding_real_dtype(self.dtype)
    return torch.empty_like(self, dtype=out_dtype)
@register_meta(aten.angle.out)
def meta_angle_out(self, out):
    """Meta kernel for ``angle.out``: resize ``out`` to ``self.size()`` and fill it."""
    torch._resize_output_(out, self.size(), self.device)
    angles = torch.angle(self)
    return out.copy_(angles)
# From aten/src/ATen/native/LinearAlgebraUtils.h
def squareCheckInputs(self: Tensor, f_name: str):
    """Assert that ``self`` is a batch of square matrices.

    Args:
        self: tensor to validate.
        f_name: operator name used as the prefix of the assertion messages.
    """
    has_matrix_dims = self.dim() >= 2
    assert (
        has_matrix_dims
    ), f"{f_name}: The input tensor must have at least 2 dimensions."
    n_cols = self.size(-1)
    n_rows = self.size(-2)
    assert (
        n_cols == n_rows
    ), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"
# From aten/src/ATen/native/LinearAlgebraUtils.h
def checkFloatingOrComplex(
    t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True
):
    """Validate that ``t`` has a floating point or complex dtype.

    Python port of ``checkFloatingOrComplex`` from
    aten/src/ATen/native/LinearAlgebraUtils.h.

    Args:
        t: tensor whose dtype is validated.
        f_name: operator name used as the prefix of the error messages.
        allow_low_precision_dtypes: when False, additionally require the dtype
            to be one of float, double, cfloat or cdouble.

    Raises:
        Whatever ``check`` raises when a requirement is not met.
    """
    dtype = t.dtype
    check(
        t.is_floating_point() or t.is_complex(),
        lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}",
    )
    # The restriction applies only when low precision dtypes are NOT allowed;
    # the previous condition was inverted.
    if not allow_low_precision_dtypes:
        check(
            dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
            lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}",
        )
# From aten/src/ATen/native/LinearAlgebraUtils.h
def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
    """Check, via ``check``, that ``A`` has at least two dimensions.

    Args:
        A: tensor to validate.
        f_name: operator name used as the prefix of the error message.
        arg_name: name of the argument as it appears in the error message.
    """
    has_matrix_dims = A.dim() >= 2

    def _message():
        return f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions."

    check(has_matrix_dims, _message)
def checkUplo(uplo: str):
    """Assert that ``uplo`` is ``'L'`` or ``'U'`` (case-insensitive)."""
    is_valid = uplo.upper() in ("U", "L")
    assert is_valid, f"Expected UPLO argument to be 'L' or 'U', but got {uplo}"
# @register_meta(aten.linalg_eigh.default)
def meta_linalg_eigh(self, uplo="L"):
    """Meta implementation of ``linalg_eigh`` (its registration is commented out above).

    Args:
        self: batch of square matrices.
        uplo: ``'L'`` or ``'U'`` (case-insensitive).

    Returns:
        ``(values, vectors)``: ``values`` has the real dtype of the input and
        shape ``self.shape[:-1]``; ``vectors`` has the input's dtype and shape,
        with its last two dimensions transposed in place.
    """
    # squareCheckInputs already asserts that there are at least 2 dimensions.
    squareCheckInputs(self, "linalg_eigh")
    checkUplo(uplo)
    # One real eigenvalue per row of each matrix. Previously the shapes and
    # dtypes of the two outputs were swapped.
    values = self.new_empty(self.shape[:-1], dtype=toRealValueType(self.dtype))
    vectors = self.new_empty(self.shape)
    vectors.transpose_(-2, -1)
    return (values, vectors)
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
@register_meta(aten.linalg_cholesky_ex.default)
def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
    """Meta kernel for ``linalg_cholesky_ex``.

    Returns:
        ``(L, infos)``: ``L`` has the shape of ``A`` with strides from
        ``make_contiguous_strides_for(shape, False)``; ``infos`` is an int32
        tensor over the batch dimensions.
    """
    squareCheckInputs(A, "linalg.cholesky")
    checkFloatingOrComplex(A, "linalg.cholesky")
    shape = A.shape
    # Factor output
    factor = A.new_empty(shape)
    factor.as_strided_(shape, make_contiguous_strides_for(shape, False))
    # One status entry per matrix in the batch
    batch_shape = shape[0 : len(shape) - 2]
    infos = A.new_empty(batch_shape, dtype=torch.int32)
    return factor, infos
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
@register_meta(aten.linalg_inv_ex.default)
def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False):
    """Meta kernel for ``linalg_inv_ex``.

    Returns:
        ``(inverse, infos)``: ``inverse`` has the shape of ``A`` with
        ``row_major=False`` strides; ``infos`` is an int32 tensor over the
        batch dimensions.
    """
    squareCheckInputs(A, "linalg.inv_ex")
    checkFloatingOrComplex(A, "linalg.inv_ex", allow_low_precision_dtypes=False)
    inverse = A.new_empty(A.shape)
    column_major_strides = make_contiguous_strides_for(A.shape, row_major=False)
    inverse.as_strided_(A.shape, column_major_strides)
    infos = A.new_empty(A.shape[:-2], dtype=torch.int32)
    return inverse, infos
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
# NOTE: matching defaults in aten/src/ATen/native/native_functions.yaml
@register_meta(aten._linalg_svd.default)
def _linalg_svd_meta(
    A: Tensor, full_matrices: bool = False, compute_uv: bool = True, driver: str = None
):
    """Meta kernel for ``_linalg_svd``.

    Returns:
        ``(U, S, V)``. With ``compute_uv`` False, ``U`` and ``V`` are
        placeholder tensors of shape ``[0]``. ``S`` always has the real dtype
        of ``A`` and ``min(m, n)`` entries per matrix.
    """
    checkIsMatrix(A, "linalg.svd")
    checkFloatingOrComplex(A, "linalg.svd")

    batch_dims = list(A.shape[:-2])
    m, n = A.shape[-2], A.shape[-1]
    k = min(m, n)

    if not compute_uv:
        # doesn't matter
        U = A.new_empty([0])
        V = A.new_empty([0])
    else:
        U_shape = batch_dims + [m, m if full_matrices else k]
        U = A.new_empty(U_shape)
        U.as_strided_(U_shape, make_contiguous_strides_for(U_shape, row_major=False))

        V_shape = batch_dims + [n if full_matrices else k, n]
        V = A.new_empty(V_shape)
        # TODO: need to distinguish cuSOLVER case? (see original code)
        V.as_strided_(V_shape, make_contiguous_strides_for(V_shape, row_major=False))

    # S is always real, even when A is complex.
    S_shape = batch_dims + [k]
    S = A.new_empty(S_shape, dtype=toRealValueType(A.dtype))
    return U, S, V
# From aten/src/ATen/native/LinearAlgebra.cpp
@register_meta(aten._linalg_det.default)
def _linalg_det_meta(A):
    """Meta kernel for ``_linalg_det``.

    Returns:
        ``(det, LU, pivots)``: ``det`` over the batch dimensions, ``LU`` with
        the shape of ``A`` and ``row_major=False`` strides, and int32
        ``pivots`` of shape ``A.shape[:-1]``.
    """
    squareCheckInputs(A, "linalg.det")
    checkFloatingOrComplex(A, "linalg.det")

    shape = A.shape
    det = A.new_empty(shape[:-2])

    LU = A.new_empty(shape)
    LU.as_strided_(shape, make_contiguous_strides_for(shape, row_major=False))

    pivots = A.new_empty(shape[:-1], dtype=torch.int32)
    return det, LU, pivots
# From aten/src/ATen/native/ReflectionPad.cpp
@register_meta(
    [aten.reflection_pad2d_backward.default, aten.replication_pad2d_backward.default]
)
def meta_pad2d_backward(grad_output, self, padding):
    """Meta kernel for the reflection/replication ``pad2d`` backward ops.

    Checks that ``grad_output`` has the padded height and width, then returns
    an empty tensor shaped like ``self``.
    """
    dim_plane, dim_h, dim_w = 0, 1, 2
    nbatch = 1  # not used below
    self_shape = self.shape
    if self.dim() == 4:
        # Batched input: every spatial index shifts right by one.
        nbatch = self_shape[0]
        dim_plane, dim_h, dim_w = dim_plane + 1, dim_h + 1, dim_w + 1

    pad_l = padding[0]
    pad_r = padding[1]
    pad_t = padding[2]
    pad_b = padding[3]

    nplane = self_shape[dim_plane]  # not used below
    input_h = self_shape[dim_h]
    input_w = self_shape[dim_w]
    output_h = input_h + pad_t + pad_b
    output_w = input_w + pad_l + pad_r

    width_matches = output_w == grad_output.shape[dim_w]
    check(
        width_matches,
        lambda: f"gradOutput width unexpected. Expected: {output_w}, Got: {grad_output.shape[dim_w]}",
    )
    height_matches = output_h == grad_output.shape[dim_h]
    check(
        height_matches,
        lambda: f"gradOutput height unexpected. Expected: {output_h}, Got: {grad_output.shape[dim_h]}",
    )
    return self.new_empty(self.shape)
@register_meta(aten.reflection_pad2d.default)
def meta_pad2d(self, padding):
    """Meta kernel for ``reflection_pad2d``.

    Accepts a 3D or 4D (batched) input whose checked dimensions are non-empty
    and returns an empty tensor whose height and width are grown by the
    padding amounts.
    """
    valid_dims = self.size(1) != 0 and self.size(2) != 0
    check(
        (self.ndim == 3 and valid_dims)
        or (self.ndim == 4 and valid_dims and self.size(3) != 0),
        lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}",
    )

    if self.ndim == 4:
        nbatch, nplane, input_h, input_w = self.shape
        leading = (nbatch, nplane)
    else:
        nplane, input_h, input_w = self.shape
        leading = (nplane,)

    pad_l, pad_r, pad_t, pad_b = padding
    output_h = input_h + pad_t + pad_b
    output_w = input_w + pad_l + pad_r
    return self.new_empty(leading + (output_h, output_w))
@register_meta([aten.bernoulli.default, aten.bernoulli.out])
@out_wrapper()
def meta_bernoulli(self, *, generator=None):
    """Meta kernel for ``bernoulli``: a contiguous empty tensor shaped like ``self``."""
    # https://github.com/pytorch/pytorch/issues/88612
    result = torch.empty_like(self)
    return result.contiguous()
@register_meta(aten.bernoulli_.float)
def meta_bernoulli_(self, p=0.5, generator=None):
    """Meta kernel for in-place ``bernoulli_``: return ``self`` unchanged."""
    return self
@register_meta(aten.bernoulli.p)
def meta_bernoulli_p(self, p=0.5, generator=None):
    """Meta kernel for ``bernoulli.p``: a contiguous empty tensor shaped like ``self``."""
    # https://github.com/pytorch/pytorch/issues/88612
    result = torch.empty_like(self)
    return result.contiguous()
@register_meta(aten._fused_moving_avg_obs_fq_helper.default)
def meta__fused_moving_avg_obs_fq_helper(
    self,
    observer_on,
    fake_quant_on,
    running_min,
    running_max,
    scale,
    zero_point,
    averaging_const,
    quant_min,
    quant_max,
    ch_axis,
    per_row_fake_quant=False,
    symmetric_quant=False,
):
    """Meta kernel for ``_fused_moving_avg_obs_fq_helper``.

    Returns:
        ``(output, mask)``: both shaped like ``self``; ``mask`` has dtype bool.
    """
    axis_in_range = ch_axis < self.dim()
    check(
        axis_in_range,
        lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()",
    )
    mask = torch.empty_like(self, dtype=torch.bool)
    output = torch.empty_like(self)
    return (output, mask)
def dot_check(self, other):
    """Check, via ``check``, that both operands of a dot product are 1D tensors."""
    both_1d = self.dim() == 1 and other.dim() == 1

    def _message():
        return f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors"

    check(both_1d, _message)
@register_meta(aten.dot.default)
def meta_dot(self, tensor):
    """Meta kernel for ``dot``: validate the 1D operands and return a 0-d tensor."""
    dot_check(self, tensor)
    scalar_shape = ()
    return self.new_empty(scalar_shape)
@register_meta([aten.mm.default])
def meta_mm(a, b):
    """Meta kernel for ``mm``: ``(N, M) @ (M, P)`` yields an empty ``(N, P)`` tensor."""
    check(a.dim() == 2, lambda: "a must be 2D")
    check(b.dim() == 2, lambda: "b must be 2D")
    rows, inner_a = a.shape
    inner_b, cols = b.shape
    check(inner_a == inner_b, lambda: "a and b must have same reduction dim")
    return a.new_empty(rows, cols)
def _compute_reduction_shape(self, dims, keepdim):
if keepdim:
return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))
return utils.compute_reduction_output_shape(self.shape, dims)
# FakeTensors (meta tensors with a device) will report device as meta
# when running meta kernels. Here, access the "fake device" of FakeTensor if it
# exists so meta kernels whose behavior diverges per device will be more
# accurate when run with FakeTensors
def device_hint(tensor) -> "str":
    """Return the device type to assume for ``tensor`` inside a meta kernel.

    FakeTensors report the type of their ``fake_device``; anything else
    falls back to ``"cuda"``.
    """
    if not isinstance(tensor, torch._subclasses.FakeTensor):
        return "cuda"  # default to cuda
    return tensor.fake_device.type
def calc_conv_nd_return_shape(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    stride: Union[List[int], int],
    padding: Union[List[int], int],
    dilation: Union[List[int], int],
    is_transposed: bool,
    groups: int,
    output_padding: Optional[Union[List[int], int]] = None,
):
    """Compute the output shape of an N-d (optionally transposed) convolution.

    Args:
        input_tensor: input with layout (batch, channels, *spatial).
        weight: convolution weight; its trailing dims are the kernel size.
        stride, padding, dilation: an int, a one-element sequence (broadcast
            to every spatial dim), or one value per spatial dim.
        is_transposed: selects how the output channel count is derived.
        groups: number of groups.
        output_padding: when truthy, the transposed-convolution length formula
            is used for every spatial dim.

    Returns:
        A list ``[batch, out_channels, *spatial_output_sizes]``.

    Raises:
        RuntimeError: for a non-transposed convolution whose weight and input
            channel counts do not agree.
    """

    def _forward_length(ln: int, p: int, d: int, k: int, s: int) -> int:
        """
        Output length of one dimension for a regular convolution.
        See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        ln: input length, p: padding, d: dilation, k: kernel size, s: stride.
        """
        return (ln + 2 * p - d * (k - 1) - 1) // s + 1

    def _transposed_length(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
        """
        Output length of one dimension for a transposed convolution.
        See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
        ln: input length, p: padding, d: dilation, k: kernel size, s: stride,
        op: output padding.
        """
        return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1

    kernel_size = weight.shape[2:]
    dims = input_tensor.shape[2:]

    def _per_dim(arg):
        # Broadcast an int or a one-element sequence to one entry per spatial
        # dim; longer sequences are used as given.
        if isinstance(arg, IntLike):
            return [arg] * len(dims)
        if len(arg) == 1:
            return [arg[0]] * len(dims)
        return arg

    if is_transposed:
        out_channels = groups * weight.shape[1]
    else:
        out_channels = weight.shape[0]
        if weight.shape[1] * groups != input_tensor.shape[1]:
            raise RuntimeError("Invalid channel dimensions")

    ret_shape = [input_tensor.shape[0], out_channels]

    stride = _per_dim(stride)
    padding = _per_dim(padding)
    dilation = _per_dim(dilation)
    output_padding_list = _per_dim(output_padding) if output_padding else None

    for i in range(len(dims)):
        # If output_padding is present, we are dealing with a transposed convolution
        if output_padding_list:
            extent = _transposed_length(
                dims[i],
                padding[i],
                dilation[i],
                kernel_size[i],
                stride[i],
                output_padding_list[i],
            )
        else:
            extent = _forward_length(
                dims[i], padding[i], dilation[i], kernel_size[i], stride[i]
            )
        ret_shape.append(extent)
    return ret_shape
def is_channels_last(ten):
    """Return whether the suggested memory format of ``ten`` is ``torch.channels_last``."""
    suggested = torch._prims_common.suggest_memory_format(ten)
    return suggested == torch.channels_last
@register_meta(aten.convolution.default)
def meta_conv(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    is_transposed: bool,
    output_padding: List[int],
    groups: int,
):
    """Meta kernel for ``convolution``.

    The output shape comes from ``calc_conv_nd_return_shape``; the memory
    format is chosen from the layouts of ``input_tensor`` and ``weight`` and
    the device hint.
    """

    def pick_memory_format():
        # On "cuda" a channels-last weight is enough to select channels_last;
        # elsewhere only the input's layout is consulted.
        on_cuda = device_hint(input_tensor) == "cuda"
        if is_channels_last(input_tensor) or (on_cuda and is_channels_last(weight)):
            return torch.channels_last
        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        if input_tensor.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format
        # Otherwise falls through and returns None.

    transposed_output_padding = output_padding if is_transposed else None
    shape_out = calc_conv_nd_return_shape(
        input_tensor,
        weight,
        stride,
        padding,
        dilation,
        is_transposed,
        groups,
        transposed_output_padding,
    )

    out = input_tensor.new_empty(shape_out)
    out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
    return out
# Meta kernels for the "mkldnn" (and, nested, "mkl") op namespaces. They are
# only defined when the build reports the corresponding support flag.
if torch._C.has_mkldnn:
    _meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library(
        "mkldnn", "IMPL", "Meta"
    )

    def pick_mkldnn_conv_memory_format(input_tensor, weight):
        """Choose a memory format from the layouts of ``input_tensor`` and ``weight``.

        Returns ``None`` implicitly when no branch matches.
        """
        if weight.is_mkldnn:
            return torch.channels_last
        if is_channels_last(input_tensor) or is_channels_last(weight):
            return torch.channels_last
        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format

    @register_meta(torch.ops.mkldnn._convolution_pointwise.default)
    def meta_mkldnn_convolution_default(
        input_tensor,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        attr,
        scalars,
        algorithm,
    ):
        """Meta kernel for ``mkldnn._convolution_pointwise``: channels-last conv output."""
        # Non-transposed convolution; the empty output_padding list is falsy,
        # so the regular length formula is used.
        shape_out = calc_conv_nd_return_shape(
            input_tensor, weight, stride, padding, dilation, False, groups, []
        )
        out = input_tensor.new_empty(shape_out)
        out_memory_format = torch.channels_last
        out = out.to(memory_format=out_memory_format)  # type: ignore[call-overload]
        return out

    @register_meta(torch.ops.mkldnn._convolution_pointwise.binary)
    def meta_mkldnn_convolution_binary(
        input_tensor,
        other,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        binary_attr,
        alpha,
        unary_attr,
        unary_scalars,
        unary_algorithm,
    ):
        """Meta kernel for the binary variant: channels-last output sized like ``other``."""
        out = input_tensor.new_empty(other.size())
        out = out.to(memory_format=torch.channels_last)  # type: ignore[call-overload]
        return out

    @register_meta(torch.ops.mkldnn._convolution_pointwise_.binary)
    def meta_mkldnn_convolution_binary_inplace(
        input_tensor,
        other,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        binary_attr,
        alpha,
        unary_attr,
        unary_scalars,
        unary_algorithm,
    ):
        """Meta kernel for the in-place binary variant: returns ``other``."""
        return other

    @register_meta(torch.ops.mkldnn._linear_pointwise.default)
    def meta_linear_pointwise_default(
        input_tensor, weight, bias, attr, scalars, algorithm
    ):
        """Meta kernel for ``mkldnn._linear_pointwise``: last dim becomes ``weight.shape[0]``."""
        return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0]))

    @register_meta(torch.ops.mkldnn._linear_pointwise.binary)
    def meta_linear_pointwise_binary(input_tensor, other, weight, bias, attr):
        """Meta kernel for the binary linear variant: output sized like ``other``."""
        out = input_tensor.new_empty(other.size())
        return out

    if torch._C.has_mkl:
        _meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library(
            "mkl", "IMPL", "Meta"
        )

        @register_meta(torch.ops.mkl._mkl_linear)
        def meta_mkl_linear(
            input_tensor,
            packed_weight,
            orig_weight,
            bias,
            batch_size,
        ):
            """Meta kernel for ``mkl._mkl_linear``: last dim becomes ``orig_weight.shape[0]``."""
            return input_tensor.new_empty(
                (*input_tensor.shape[:-1], orig_weight.shape[0])
            )
# from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
def check_dim_size(tensor, dim, dim_size, size):
    """Check that ``tensor`` has ``dim`` dimensions and ``tensor.shape[dim_size] == size``."""

    def _message():
        return (
            f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
            + f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}"
        )

    check(tensor.dim() == dim and tensor.shape[dim_size] == size, _message)
@register_meta(aten.avg_pool2d.default)
def meta_avg_pool2d(
    input,
    kernel_size,
    stride=(),
    padding=(0,),
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """Meta kernel for ``avg_pool2d``.

    Validates the kernel/stride/padding arguments, derives the pooled height
    and width, and returns an empty tensor in the input's suggested memory
    format.
    """

    def unpack(name, val):
        # Accept one value (applied to both H and W) or an (H, W) pair.
        check(
            len(val) in [1, 2],
            lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
        )
        first = val[0]
        second = val[1] if len(val) != 1 else first
        return first, second

    kH, kW = unpack("kernel_size", kernel_size)

    check(
        len(stride) in [0, 1, 2],
        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    if len(stride) == 0:
        # Omitted stride defaults to the kernel size.
        dH, dW = kH, kW
    elif len(stride) == 1:
        dH = dW = stride[0]
    else:
        dH, dW = unpack("stride", stride)

    padH, padW = unpack("padding", padding)

    check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    nbatch = input.size(-4) if input.dim() == 4 else 1
    nInputPlane = input.size(-3)
    inputHeight = input.size(-2)
    inputWidth = input.size(-1)

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)

    memory_format = utils.suggest_memory_format(input)
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        memory_format,
    )

    spatial = [nInputPlane, outputHeight, outputWidth]
    size = spatial if input.dim() == 3 else [nbatch] + spatial
    return torch.empty(
        size, dtype=input.dtype, device=input.device, memory_format=memory_format
    )
# from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
def avg_pool2d_backward_shape_check(
    input,
    gradOutput,
    nbatch,
    kH,
    kW,
    dH,
    dW,
    padH,
    padW,
    nInputPlane,
    inputHeight,
    inputWidth,
    outputHeight,
    outputWidth,
    mem_format,
):
    """Validate the arguments of ``avg_pool2d_backward``.

    Runs the forward ``pool2d_shape_check`` (with the two positional 1s the
    original passes after the padding), then checks that the last three
    dimensions of ``gradOutput`` match the plane count, output height and
    output width.
    """
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        mem_format,
    )

    ndim = input.dim()
    nOutputPlane = nInputPlane

    # (offset from the end, expected size) for plane, height and width.
    expected_trailing = (
        (3, nOutputPlane),
        (2, outputHeight),
        (1, outputWidth),
    )
    for offset, expected in expected_trailing:
        check_dim_size(gradOutput, ndim, ndim - offset, expected)
# Don't override the C++ registration.
@register_meta(aten.avg_pool2d_backward.default)
def meta_avg_pool2d_backward(
    gradOutput_,
    input,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    """Meta kernel for ``avg_pool2d_backward``.

    Validates the pooling arguments and the shape of ``gradOutput_``, then
    returns an empty tensor shaped like ``input`` in its suggested memory
    format.
    """
    # From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func.
    check(
        len(kernel_size) in (1, 2),
        lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints",
    )
    kH = kernel_size[0]
    kW = kernel_size[1] if len(kernel_size) != 1 else kH

    check(
        len(stride) in (0, 1, 2),
        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    if len(stride) == 0:
        # Omitted stride defaults to the kernel size.
        dH, dW = kH, kW
    elif len(stride) == 1:
        dH = dW = stride[0]
    else:
        dH, dW = stride[0], stride[1]

    check(
        len(padding) in (1, 2),
        lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints",
    )
    padH = padding[0]
    padW = padding[1] if len(padding) != 1 else padH

    check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    input_size = input.shape
    nbatch = input_size[-4] if input.dim() == 4 else 1
    nInputPlane = input_size[-3]
    inputHeight = input_size[-2]
    inputWidth = input_size[-1]

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)

    mem_format = utils.suggest_memory_format(input)

    avg_pool2d_backward_shape_check(
        input,
        gradOutput_,
        nbatch,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        mem_format,
    )

    return torch.empty(
        input_size, dtype=input.dtype, device=input.device, memory_format=mem_format
    )
@register_meta(aten._adaptive_avg_pool2d.default)
def meta_adaptive_avg_pool2d(self, output_size):
    """Meta kernel for ``_adaptive_avg_pool2d``: replace the last two dims with ``output_size``."""
    check(
        self.ndim in (3, 4),
        lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
    )
    leading_dims = self.shape[:-2]
    output_shape = leading_dims + tuple(output_size)
    # Pass the input's suggested memory format explicitly so that a
    # channels-last input yields a channels-last output.
    memory_format = utils.suggest_memory_format(self)
    return torch.empty(
        output_shape, dtype=self.dtype, device=self.device, memory_format=memory_format
    )
@register_meta(aten._adaptive_avg_pool3d.default)
def meta_adaptive_avg_pool3d(self, output_size):
    """Meta kernel for ``_adaptive_avg_pool3d``: replace the last three dims with ``output_size``."""
    check(
        self.ndim in (4, 5),
        lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
    )
    leading_dims = self.shape[:-3]
    return self.new_empty(leading_dims + tuple(output_size))
@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta__adaptive_avg_pool2d_backward(grad_out, self):
    """Meta kernel for ``_adaptive_avg_pool2d_backward``.

    Checks that every non-batch dimension of ``grad_out`` is non-empty, that
    ``grad_out`` is 3D or 4D, and that its dtype matches ``self``; returns an
    empty tensor shaped like ``self``.
    """
    ndim = grad_out.ndim
    # Dimension 0 is skipped: only dimensions 1..ndim-1 must be non-empty.
    for i in range(1, ndim):
        check(
            grad_out.size(i) > 0,
            lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
        )
    check(
        ndim == 3 or ndim == 4,
        lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
    )
    check(
        self.dtype == grad_out.dtype,
        lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
    )
    return self.new_empty(self.shape)
@register_meta(aten.repeat_interleave.Tensor)
def meta_repeat_interleave_Tensor(repeats, output_size=None):
    """Meta kernel for ``repeat_interleave.Tensor``; requires an explicit ``output_size``.

    Raises:
        RuntimeError: if ``output_size`` is None.
    """
    if output_size is not None:
        return repeats.new_empty(output_size)
    raise RuntimeError("cannot repeat_interleave a meta tensor without output_size")
@register_meta([aten.complex.default, aten.complex.out])
@out_wrapper()
def meta_complex(real, imag):
    """Meta kernel for ``complex``.

    The output has the broadcast shape of ``real`` and ``imag`` and the
    complex dtype corresponding to ``real.dtype``.
    """
    for part in (real, imag):
        assert part.dtype.is_floating_point
    out_shape = _broadcast_shapes(real.shape, imag.shape)
    complex_dtype = corresponding_complex_dtype(real.dtype)
    return real.new_empty(out_shape, dtype=complex_dtype)
@register_meta(aten.vdot.default)
def vdot(self, other):
    """Meta kernel for ``vdot``.

    Non-complex inputs are forwarded to ``torch.dot``. For complex inputs,
    conjugate-flagged operands are rewritten into equivalent ``dot``/``vdot``
    calls; otherwise the operands are validated as 1D and a 0-d tensor is
    returned.
    """
    # ``is_complex`` is a method: it must be called. The bare attribute is a
    # bound method, which is always truthy, so ``not self.is_complex`` was
    # always False and this branch could never be taken.
    if not self.is_complex():
        return torch.dot(self, other)

    if self.is_conj():
        if other.is_conj():
            return torch.vdot(other.conj(), self.conj())
        else:
            return torch.dot(self.conj(), other)
    elif other.is_conj():
        return torch.dot(self, other.conj()).conj()

    dot_check(self, other)
    return self.new_empty(())
# Leaving this function around because a python implementation
# of indexing shape inference is useful,
# but not registering it to the dispatcher because we already
# get shape inference through structured kernels
@register_meta(aten.index.Tensor)
def meta_index_Tensor(self, indices):
    """Meta kernel for ``index.Tensor`` (advanced indexing shape inference).

    Args:
        self: tensor being indexed.
        indices: sequence with one optional index tensor per leading
            dimension; ``None`` entries leave that dimension untouched.

    Returns:
        An empty tensor whose shape is the un-indexed leading dims, followed
        by the broadcast index shape, followed by the un-indexed trailing
        dims (after indexed dims are moved to the front when they are not
        adjacent).
    """
    check_no_bool_index_tensors(aten.index.Tensor, self, indices)
    check(indices, lambda: "at least one index must be provided")
    # aten::index is the internal advanced indexing implementation
    # checkIndexTensorTypes and expandTensors
    result: List[Optional[Tensor]] = []
    for i, index in enumerate(indices):
        if index is not None:
            check(
                index.dtype in [torch.long, torch.int, torch.int8, torch.bool],
                lambda: "tensors used as indices must be long, int, byte or bool tensors",
            )
            if index.dtype in [torch.int8, torch.bool]:
                # Mask index: expand it into one index tensor per mask
                # dimension, taken from the columns of nonzero().
                nonzero = index.nonzero()
                k = len(result)
                check(
                    k + index.ndim <= self.ndim,
                    lambda: f"too many indices for tensor of dimension {self.ndim}",
                    IndexError,
                )
                for j in range(index.ndim):
                    check(
                        index.shape[j] == self.shape[k + j],
                        lambda: f"The shape of the mask {index.shape} at index {i} "
                        f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
                        IndexError,
                    )
                    result.append(nonzero.select(1, j))
            else:
                result.append(index)
        else:
            result.append(index)
    indices = result
    check(
        len(indices) <= self.ndim,
        lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
    )
    # expand_outplace
    import torch._refs as refs  # avoid import cycle in mypy

    indices = list(refs._maybe_broadcast(*indices))
    # add missing null tensors
    while len(indices) < self.ndim:
        indices.append(None)

    # hasContiguousSubspace
    #   true if all non-null tensors are adjacent
    # See:
    # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
    # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
    # Small state machine: 0 = before the first tensor, 1 = inside the run of
    # tensors, 2 = after the run. Meeting another tensor in state 2 breaks
    # out, which skips the loop's ``else`` clause and leaves the flag False.
    state = 0
    has_contiguous_subspace = False
    for index in indices:
        if state == 0:
            if index is not None:
                state = 1
        elif state == 1:
            if index is None:
                state = 2
        else:
            if index is not None:
                break
    else:
        has_contiguous_subspace = True

    # transposeToFront
    # This is the logic that causes the newly inserted dimensions to show up
    # at the beginning of the tensor, if they're not contiguous
    if not has_contiguous_subspace:
        dims = []
        transposed_indices = []
        # Indexed dimensions first ...
        for i, index in enumerate(indices):
            if index is not None:
                dims.append(i)
                transposed_indices.append(index)
        # ... then the un-indexed ones, keeping their relative order.
        for i, index in enumerate(indices):
            if index is None:
                dims.append(i)
                transposed_indices.append(index)
        self = self.permute(dims)
        indices = transposed_indices

    # AdvancedIndex::AdvancedIndex
    # Now we can assume the indices have contiguous subspace
    # This is simplified from AdvancedIndex which goes to more effort
    # to put the input and indices in a form so that TensorIterator can
    # take them.  If we write a ref for this, probably that logic should
    # get implemented
    before_shape: List[int] = []
    after_shape: List[int] = []
    replacement_shape: List[int] = []
    for dim, index in enumerate(indices):
        if index is None:
            # A non-empty replacement_shape means an index tensor was already
            # seen, so this un-indexed dim goes after the replacement.
            if replacement_shape:
                after_shape.append(self.shape[dim])
            else:
                before_shape.append(self.shape[dim])
        else:
            replacement_shape = list(index.shape)
    return self.new_empty(before_shape + replacement_shape + after_shape)
@register_meta([aten.convolution_backward.default])
def meta_convolution_backward(
    grad_output_,
    input_,
    weight_,
    bias_sizes_opt,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    output_mask,
):
    """Meta kernel for ``convolution_backward``.

    Returns:
        ``(grad_input, grad_weight, grad_bias)``; each entry is an empty
        tensor of the matching size when its ``output_mask`` flag is set and
        ``None`` otherwise.
    """
    # High level logic taken from slow_conv3d_backward_cpu which should
    # be representative of all convolution_backward impls
    want_input, want_weight, want_bias = output_mask[0], output_mask[1], output_mask[2]
    grad_input = grad_output_.new_empty(input_.size()) if want_input else None
    grad_weight = grad_output_.new_empty(weight_.size()) if want_weight else None
    grad_bias = grad_output_.new_empty(bias_sizes_opt) if want_bias else None
    return (grad_input, grad_weight, grad_bias)
@register_meta([aten.addbmm.default, aten.addbmm.out])
@out_wrapper()
def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
    """Meta kernel for ``addbmm``.

    ``self`` is expanded to ``(batch1.size(1), batch2.size(2))``, the batch
    operands are validated, and an empty tensor of that size is returned.
    """
    out_rows = batch1.size(1)
    out_cols = batch2.size(2)
    self = self.expand((out_rows, out_cols))

    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")

    same_batch_count = batch1.size(0) == batch2.size(0)
    check(
        same_batch_count,
        lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
    )

    inner_dims_match = batch1.size(2) == batch2.size(1)
    check(
        inner_dims_match,
        lambda: (
            f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
            f"and {batch2.size(1)}x{batch2.size(2)})"
        ),
    )

    check(
        self.size(0) == out_rows and self.size(1) == out_cols,
        lambda: "self tensor does not match matmul output shape",
    )
    return self.new_empty(self.size())
@register_meta(aten._cdist_forward.default)
def meta_cdist_forward(x1, x2, p, compute_mode):
    """Meta kernel for ``_cdist_forward``.

    Args:
        x1, x2: floating point tensors with at least 2 dimensions and the
            same size in the last dimension.
        p: non-negative norm order.
        compute_mode: one of ``None``, 1 or 2.

    Returns:
        An empty tensor whose shape is the broadcast of the batch dimensions
        followed by ``[x1.size(-2), x2.size(-2)]``.
    """
    check(
        x1.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
    )
    check(
        x2.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
    )
    check(
        x1.size(-1) == x2.size(-1),
        lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
    )
    # These two messages were plain strings, so the dtype placeholder was
    # printed literally; they need the ``f`` prefix to interpolate.
    check(
        utils.is_float_dtype(x1.dtype),
        lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
    )
    check(
        utils.is_float_dtype(x2.dtype),
        lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
    )
    check(p >= 0, lambda: "cdist only supports non-negative p values")
    check(
        compute_mode in (None, 1, 2),
        lambda: f"possible modes: None, 1, 2, but was: {compute_mode}",
    )
    r1 = x1.size(-2)
    r2 = x2.size(-2)
    batch_tensor1 = x1.shape[:-2]
    batch_tensor2 = x2.shape[:-2]
    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    output_shape.extend([r1, r2])
    return x1.new_empty(output_shape)
@register_meta(aten._embedding_bag.default)
def meta_embedding_bag(
    weight,
    indices,
    offsets,
    scale_grad_by_freq=False,
    mode=0,
    sparse=False,
    per_sample_weights=None,
    include_last_offset=False,
    padding_idx=-1,
):
    """Meta implementation of ``_embedding_bag``.

    Validates the dtypes of ``indices``, ``offsets`` and ``weight`` (and of
    ``per_sample_weights`` when given) and allocates the four outputs. The
    sizes of the auxiliary outputs depend on ``device_hint(offsets)`` and on
    ``mode``.

    Returns:
        A tuple ``(output, offset2bag, bag_size, max_indices)`` of empty tensors.
    """
    check(
        indices.dtype in (torch.long, torch.int),
        lambda: f"expected indices to be long or int, got {indices.dtype}",
    )
    check(
        offsets.dtype in (torch.long, torch.int),
        lambda: f"expected offsets to be long or int, got {offsets.dtype}",
    )
    check(
        utils.is_float_dtype(weight.dtype),
        lambda: f"expected weight to be floating point type, got {weight.dtype}",
    )

    # One bag per offset, minus the trailing offset when it is included.
    num_bags = offsets.size(0)
    if include_last_offset:
        check(
            num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1"
        )
        num_bags -= 1

    output = weight.new_empty(num_bags, weight.size(1))
    # Integer codes accepted for ``mode``.
    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)

    if per_sample_weights is not None:
        check(
            mode == MODE_SUM,
            lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
        )
        check(
            per_sample_weights.dtype == weight.dtype,
            lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
        )
        check(
            per_sample_weights.ndim == 1,
            lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
        )
        check(
            per_sample_weights.numel() == indices.numel(),
            lambda: (
                f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
                f"to be the same as indices.numel() ({indices.numel()})"
            ),
        )

    # The helpers below decide whether the "fast path" applies; that decision
    # only affects the size of ``offset2bag`` in the cpu branch further down.
    def is_fast_path_index_select_scale(src, scale, output, padding_idx):
        return (
            is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
        )

    def is_fast_path_index_select(src, output, padding_idx):
        return (
            (src.dtype == torch.float or src.dtype == torch.half)
            and src.stride(1) == 1
            and output.stride(1) == 1
            and padding_idx < 0
        )

    def is_fast_path(src, scale, output, padding_idx):
        if scale is not None:
            return is_fast_path_index_select_scale(src, scale, output, padding_idx)
        else:
            return is_fast_path_index_select(src, output, padding_idx)

    if device_hint(offsets) != "cpu":
        # Non-cpu: auxiliary outputs are allocated from ``indices``.
        offset2bag = indices.new_empty(indices.size(0))
        bag_size = indices.new_empty(offsets.size())
        if mode == MODE_MAX:
            max_indices = indices.new_empty(num_bags, weight.size(1))
        else:
            max_indices = indices.new_empty(0)
    else:
        # cpu: auxiliary outputs are allocated from ``offsets``.
        fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
        if mode == MODE_MEAN or mode == MODE_MAX or not fast_path_sum:
            offset2bag = offsets.new_empty(indices.size(0))
        else:
            offset2bag = offsets.new_empty(0)
        bag_size = offsets.new_empty(num_bags)
        # This part of the logic comes from make_max_indices_out in EmbeddingBag.cpp
        numBags = offsets.shape[0]
        if mode == MODE_MAX:
            if include_last_offset:
                check(
                    numBags >= 1,
                    lambda: "include_last_offset: numBags should be at least 1",
                )
                numBags -= 1
            max_indices = offsets.new_empty(numBags, weight.shape[1])
        else:
            max_indices = offsets.new_empty(bag_size.size())
    return output, offset2bag, bag_size, max_indices
@register_meta(aten._embedding_bag_forward_only.default)
def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
    """Meta implementation of ``_embedding_bag_forward_only``.

    Delegates to ``meta_embedding_bag`` and, when ``offsets`` is hinted to be
    on cpu, replaces ``bag_size`` with an empty tensor shaped like ``offsets``.
    """
    results = meta_embedding_bag(weight, indices, offsets, *args)
    output, offset2bag, bag_size, max_indices = results
    runs_on_cpu = device_hint(offsets) == "cpu"
    if runs_on_cpu:
        bag_size = offsets.new_empty(offsets.size())
    return output, offset2bag, bag_size, max_indices
def _get_reduction_dtype(input, dtype, promote_int_to_long=True):
# if specified, dtype takes precedence
if dtype:
return dtype
if input.dtype.is_floating_point or input.dtype.is_complex:
return input.dtype
elif promote_int_to_long:
return torch.long
return input.dtype
@register_meta([aten.nansum.default, aten.nansum.out])
@out_wrapper()
def meta_nansum(input, dims=None, keepdim=False, *, dtype=None):
    """Meta implementation of ``nansum``: allocate the reduced output tensor."""
    reduced_dims = utils.reduction_dims(input.shape, dims)
    result_shape = _compute_reduction_shape(input, reduced_dims, keepdim)
    result_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True)
    return input.new_empty(result_shape, dtype=result_dtype)
@register_meta(aten.nanmedian.default)
def meta_nanmedian(input):
    """Meta implementation of ``nanmedian`` reducing over every dimension."""
    every_dim = tuple(range(input.dim()))
    reduced_shape = utils.compute_reduction_output_shape(input.shape, every_dim)
    return input.new_empty(reduced_shape)
@register_meta([aten.nanmedian.dim, aten.nanmedian.dim_values])
@out_wrapper("values", "indices")
def meta_nanmedian_dim(input, dim=-1, keepdim=False):
    """Meta implementation of ``nanmedian`` along one dimension.

    Returns:
        ``(values, indices)``: two empty tensors of the reduced shape; the
        indices tensor has dtype ``torch.long``.
    """
    reduced_dims = utils.reduction_dims(input.shape, (dim,))
    result_shape = _compute_reduction_shape(input, reduced_dims, keepdim)
    values = input.new_empty(result_shape)
    indices = input.new_empty(result_shape, dtype=torch.long)
    return (values, indices)
@register_meta(aten.logical_not_.default)
def meta_logical_not_(self):
    """Meta implementation of in-place ``logical_not_``: returns ``self`` unchanged."""
    return self
@register_meta(aten.repeat.default)
def meta_repeat(self, repeats):
    """Meta implementation of ``repeat``: allocate a tensor of the tiled size.

    ``repeats`` must have at least as many entries as ``self`` has dimensions.
    """
    source_rank = self.dim()
    check(
        len(repeats) >= source_rank,
        lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
    )
    # When more repeat entries than source dimensions are given, the source
    # shape is left-padded with ones so both sequences line up.
    extra_leading_dims = len(repeats) - source_rank
    aligned_shape = (1,) * extra_leading_dims + tuple(self.shape)
    tiled_shape = [extent * count for extent, count in zip(aligned_shape, repeats)]
    return self.new_empty(tiled_shape)
@register_meta(aten.zero_.default)
def meta_zero_(self):
    """Meta implementation of in-place ``zero_``: returns ``self`` unchanged."""
    return self
@register_meta(
    [
        aten.mul_.Scalar,
        aten.div_.Scalar,
        aten.mul_.Tensor,
        aten.div_.Tensor,
        aten.logical_and_.default,
        aten.logical_or_.default,
        aten.logical_xor_.default,
    ],
)
def meta_binop_inplace(self, other):
    """Shared meta implementation for the listed in-place binary ops.

    ``other`` is ignored; ``self`` is returned as-is.
    """
    return self
@register_meta(
    [
        aten.add_.Scalar,
        aten.sub_.Scalar,
        aten.add_.Tensor,
        aten.sub_.Tensor,
    ],
)
def meta_binop_inplace_alpha(self, other, alpha=1):
    """Shared meta implementation for in-place ``add_``/``sub_`` overloads.

    ``other`` and ``alpha`` are ignored; ``self`` is returned as-is.
    """
    return self
@register_meta([aten.round.default, aten.round.decimals])
def meta_round(self, **kwargs):
    """Meta implementation of ``round``.

    Any keyword arguments are accepted and ignored; the result comes from
    ``_elementwise_meta`` with DEFAULT type promotion.
    """
    return _elementwise_meta(
        self, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
    )
@register_meta(aten.zero.default)
def meta_zero(self):
    """Meta implementation of ``zero``: a new empty tensor with the shape of ``self``."""
    return self.new_empty(self.shape)
@register_meta([aten.fill_.Tensor, aten.fill_.Scalar])
def meta_fill_(self, val):
    """Meta implementation of in-place ``fill_``: ``val`` is ignored, ``self`` is returned."""
    return self
@register_meta([aten.fill.Tensor, aten.fill.Scalar])
def meta_fill(self, val):
    """Meta implementation of ``fill``: ``val`` is ignored; returns ``torch.empty_like(self)``."""
    return torch.empty_like(self)
@register_meta(aten.relu_.default)
def meta_relu_(self):
    """Meta implementation of in-place ``relu_``: returns ``self`` unchanged."""
    return self
@register_meta(aten.index_put.default)
def meta_index_put(self, indices, values, accumulate=False):
    """Meta implementation of ``index_put``.

    ``indices``, ``values`` and ``accumulate`` are not inspected; the result is
    ``torch.empty_like(self)``.
    """
    return torch.empty_like(self)
@register_meta(aten.masked_fill_.Scalar)
def meta_masked_fill_(self, mask, value):
    """Meta implementation of in-place ``masked_fill_``: arguments are ignored, ``self`` is returned."""
    return self
@register_meta(aten.index_put_.default)
def meta_index_put_(self, indices, values, accumulate=False):
    """Meta implementation of in-place ``index_put_``: arguments are ignored, ``self`` is returned."""
    return self
@register_meta(aten.alias.default)
def meta_alias(self):
    """Meta implementation of ``alias``: a view of ``self`` with the same shape."""
    return self.view(self.shape)
def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
    """Shared shape validation and output allocation for ``bmm`` / ``baddbmm``.

    Args:
        batch1: 3D tensor; its sizes are read as ``(bs, res_rows, contraction)``.
        batch2: 3D tensor; its first two sizes must be ``(bs, contraction)``.
        is_bmm: when true, ``self_baddbmm`` is not validated.
        self_baddbmm: optional tensor which, for the non-bmm case, must be 3D
            and have exactly the output size.

    Returns:
        A new empty tensor of size ``(bs, res_rows, res_cols)`` created from
        ``batch2``.
    """
    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")

    batch1_sizes = batch1.size()
    batch2_sizes = batch2.size()

    bs = batch1_sizes[0]
    contraction_size = batch1_sizes[2]
    res_rows = batch1_sizes[1]
    res_cols = batch2_sizes[2]
    output_size = (bs, res_rows, res_cols)

    check(
        batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
        lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
        f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
    )

    # TODO: handle out

    output = batch2.new_empty(output_size)

    if not is_bmm and self_baddbmm is not None:
        check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
        # The message used to be a plain string that named ``self``, which is
        # not defined in this function; it now interpolates the real values.
        check(
            self_baddbmm.size() == output_size,
            lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}",
        )

    return output
@register_meta(aten.bmm.default)
def meta_bmm(self, mat2):
    """Meta implementation of ``bmm``; delegates to ``common_meta_baddbmm_bmm`` with ``is_bmm=True``."""
    return common_meta_baddbmm_bmm(self, mat2, True)
def div_rtn(x, y):
    """Return ``x // y``, decremented when the remainder is non-zero and its sign differs from ``y``'s."""
    quotient = x // y
    remainder = x % y
    if remainder == 0:
        return quotient
    # WARNING: explicit bool conversion here is necessary;
    # would be fixed by SymBool
    if bool(remainder < 0) != bool(y < 0):
        quotient -= 1
    return quotient
def pooling_output_shape_pad_lr(
    inputSize, kernelSize, pad_l, pad_r, stride, dilation, ceil_mode
):
    """Compute the pooled output length along one dimension.

    Supports separate left/right padding. In ``ceil_mode`` the result is
    reduced by one when ``(outputSize - 1) * stride >= inputSize + pad_l``.
    """
    ceil_adjustment = stride - 1 if ceil_mode else 0
    numerator = (
        inputSize
        + pad_l
        + pad_r
        - dilation * (kernelSize - 1)
        - 1
        + ceil_adjustment
    )
    outputSize = div_rtn(numerator, stride) + 1
    if ceil_mode and (outputSize - 1) * stride >= inputSize + pad_l:
        outputSize -= 1
    return outputSize
def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
    """Validate ``stride`` and ``pad``, then compute the pooled output length.

    The same ``pad`` is applied on both sides.
    """
    check(stride != 0, lambda: "stride should not be zero")
    check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
    max_pad = kernelSize // 2
    check(
        pad <= max_pad,
        lambda: f"pad should be at most half of kernel size, but got pad={pad} and kernel_size={kernelSize}",
    )
    pad_left = pad_right = pad
    return pooling_output_shape_pad_lr(
        inputSize, kernelSize, pad_left, pad_right, stride, dilation, ceil_mode
    )
def pool2d_shape_check(
    input,
    kH,
    kW,
    dH,
    dW,
    padH,
    padW,
    dilationH,
    dilationW,
    nInputPlane,
    inputHeight,
    inputWidth,
    outputHeight,
    outputWidth,
    memory_format,
):
    """Validate the arguments of a 2D pooling operation.

    Checks that kernel, stride and dilation are positive, that ``input`` has an
    acceptable rank with non-zero non-batch dimensions for ``memory_format``,
    that padding is at most half of the kernel size, and that the computed
    output is at least 1x1. Returns nothing; failures are reported through
    ``check``.
    """
    ndim = input.dim()
    nOutputPlane = nInputPlane

    # The messages below interpolate the offending values. Several were plain
    # strings before, so the ``{...}`` placeholders were printed literally.
    check(
        kW > 0 and kH > 0,
        lambda: f"kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
    )
    check(
        dW > 0 and dH > 0,
        lambda: f"stride should be greater than zero, but got dH: {dH}, dW: {dW}",
    )
    check(
        dilationH > 0 and dilationW > 0,
        lambda: f"dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
    )

    valid_dims = input.size(1) != 0 and input.size(2) != 0

    if memory_format == torch.channels_last:
        check(
            ndim == 4 and valid_dims and input.size(3) != 0,
            lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
            f" with optional 0 dim batch size for input, but got: {input.size()}",
        )
    else:
        check(
            (ndim == 3 and input.size(0) != 0 and valid_dims)
            or (ndim == 4 and valid_dims and input.size(3) != 0),
            lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
        )

    check(
        kW // 2 >= padW and kH // 2 >= padH,
        lambda: "pad should be smaller than or equal to half of kernel size, but got "
        f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
    )

    check(
        outputWidth >= 1 and outputHeight >= 1,
        lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
        f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
        "Output size is too small",
    )
def max_pool2d_checks_and_compute_shape(
    input, kernel_size, stride, padding, dilation, ceil_mode
):
    """Validate ``max_pool2d`` arguments and compute the output plane shape.

    Returns:
        ``(nInputPlane, outputHeight, outputWidth)``.
    """
    # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp
    def unpack(name, val):
        # A single entry is used for both height and width.
        check(
            len(val) in [1, 2],
            lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
        )
        height = val[0]
        width = val[1] if len(val) != 1 else height
        return height, width

    kH, kW = unpack("kernel_size", kernel_size)

    check(
        len(stride) in [0, 1, 2],
        lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    # An omitted stride defaults to the kernel size.
    dH, dW = unpack("stride", stride) if len(stride) != 0 else (kH, kW)

    padH, padW = unpack("padding", padding)
    dilationH, dilationW = unpack("dilation", dilation)

    nInputPlane = input.size(-3)
    inputHeight = input.size(-2)
    inputWidth = input.size(-1)

    memory_format = utils.suggest_memory_format(input)
    if memory_format == torch.channels_last:
        check(
            input.dim() == 4,
            lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
        )
    elif memory_format == torch.contiguous_format:
        check(
            input.dim() in [3, 4],
            lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
        )
    else:
        check(
            False,
            lambda: "Unsupport memory format. Supports only ChannelsLast, Contiguous",
        )

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)

    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        dilationH,
        dilationW,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        memory_format,
    )

    return nInputPlane, outputHeight, outputWidth
@register_meta(aten.max_pool2d_with_indices_backward.default)
def meta_max_pool2d_with_indices_backward(
    grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices
):
    """Meta implementation of ``max_pool2d_with_indices_backward``.

    Re-runs the forward shape checks on ``self``, verifies that ``grad_output``
    has the dtype of ``self`` and that ``grad_output`` and ``indices`` have the
    expected trailing three sizes.

    Returns:
        An empty tensor with the shape, dtype and device of ``self`` in the
        memory format suggested for ``self``.
    """
    nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape(
        self, kernel_size, stride, padding, dilation, ceil_mode
    )

    # The message was a plain string before, so the dtypes were never
    # interpolated.
    check(
        self.dtype == grad_output.dtype,
        lambda: f"expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}",
    )

    nOutputPlane = nInputPlane
    ndim = self.ndim

    def _check_dim_size(t):
        # Verify the (plane, height, width) sizes of ``t``.
        check_dim_size(t, ndim, ndim - 3, nOutputPlane)
        check_dim_size(t, ndim, ndim - 2, outputHeight)
        check_dim_size(t, ndim, ndim - 1, outputWidth)

    _check_dim_size(grad_output)
    _check_dim_size(indices)

    memory_format = utils.suggest_memory_format(self)
    return torch.empty(
        self.shape, dtype=self.dtype, device=self.device, memory_format=memory_format
    )
@register_meta(aten.max_pool2d_with_indices.default)
def meta_max_pool2d_with_indices(
    input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False
):
    """Meta implementation of ``max_pool2d_with_indices``.

    Returns:
        ``(output, indices)``: two empty tensors of the pooled size, the second
        with dtype ``torch.int64``.
    """
    nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape(
        input, kernel_size, stride, padding, dilation, ceil_mode
    )

    pooled_size = [nInputPlane, outputHeight, outputWidth]
    if input.dim() != 3:
        # Batched input: prepend the batch size.
        nbatch = input.size(-4) if input.dim() == 4 else 1
        pooled_size.insert(0, nbatch)

    layout = utils.suggest_memory_format(input)

    def allocate(dtype):
        return torch.empty(
            pooled_size, dtype=dtype, device=input.device, memory_format=layout
        )

    return (allocate(input.dtype), allocate(torch.int64))
@register_meta(aten.grid_sampler_2d_backward.default)
def grid_sampler_2d_backward_meta(
    grad_output,
    input,
    grid,
    interpolation_mode,
    padding_mode,
    align_corners,
    output_mask,
):
    """Meta implementation of ``grid_sampler_2d_backward``.

    Returns:
        ``(grad_input, grad_grid)``; ``grad_input`` is ``None`` when
        ``output_mask[0]`` is false.
    """
    wants_input_grad = output_mask[0]
    grad_input = (
        torch.zeros_like(input, memory_format=torch.contiguous_format)
        if wants_input_grad
        else None
    )
    grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format)
    return (grad_input, grad_grid)
@register_meta([aten.full.default])
def full(size, fill_value, *args, **kwargs):
    """Meta implementation of ``full``: ``fill_value`` is ignored; forwards the rest to ``torch.empty``."""
    return torch.empty(size, *args, **kwargs)
@register_meta(
    [
        aten.randint_like.default,
        aten.randint_like.low_dtype,
        aten.randn_like.default,
        aten.rand_like.default,
        aten.full_like.default,
        aten.ones_like.default,
    ]
)
def meta_like(self, *args, **kwargs):
    """Shared meta implementation for the listed ``*_like`` factory ops.

    Positional ``args`` are dropped; only ``kwargs`` are forwarded to
    ``aten.empty_like.default``.
    """
    return aten.empty_like.default(self, **kwargs)
# zeros_like is special cased to work for sparse
@register_meta(aten.zeros_like.default)
def zeros_like(
    self, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
):
    """Meta implementation of ``zeros_like``.

    For ``layout == torch.sparse_coo`` an empty sparse tensor is built and
    resized to match ``self``; any other layout defers to
    ``aten.empty_like.default``.
    """
    if layout != torch.sparse_coo:
        return aten.empty_like.default(
            self,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            memory_format=memory_format,
        )

    check(
        memory_format is None,
        lambda: "memory format option is only supported by strided tensors",
    )

    result_dtype = self.dtype if dtype is None else dtype
    result_device = self.device if device is None else device
    res = torch.empty(
        0,
        dtype=result_dtype,
        layout=layout,
        device=result_device,
        pin_memory=pin_memory,
    )

    if self.is_sparse:
        sparse_dims, dense_dims = self.sparse_dim(), self.dense_dim()
    else:
        sparse_dims, dense_dims = self.dim(), 0
    res.sparse_resize_and_clear_(self.size(), sparse_dims, dense_dims)

    res._coalesced_(True)
    return res
@register_meta(aten.select.int)
def meta_select(self, dim, index):
    """Meta implementation of ``select``: a strided view with ``dim`` removed.

    Raises ``IndexError`` (through ``check``) for a 0-dim tensor or an
    out-of-range ``index``.
    """
    ndim = self.dim()
    check(
        ndim != 0, lambda: "select() cannot be applied to a 0-dim tensor.", IndexError
    )

    if dim < 0:
        dim = dim + ndim
    size = self.size(dim)

    check(
        -index <= size and index < size,
        lambda: f"select(): index {index} out of range for tensor of size "
        f"{self.size()} at dimension {dim}",
        IndexError,
    )

    if index < 0:
        index = index + size

    sizes = list(self.size())
    strides = list(self.stride())

    # The view starts at the selected slice along ``dim``.
    offset = self.storage_offset() + index * strides[dim]
    sizes.pop(dim)
    strides.pop(dim)
    return self.as_strided(sizes, strides, offset)
@register_meta(aten.select_scatter.default)
def meta_select_scatter(self, src, dim, index):
    """Meta implementation of ``select_scatter``: returns ``utils.clone_preserve_strides(self)``."""
    return utils.clone_preserve_strides(self)
@register_meta(aten.slice_scatter.default)
def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1):
    """Meta implementation of ``slice_scatter``: returns ``utils.clone_preserve_strides(self)``."""
    return utils.clone_preserve_strides(self)
# TODO: Deduplicate this with canonicalize_dim
def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
    """Normalize a possibly negative ``dim`` against ``dim_post_expr`` dimensions.

    A non-positive ``dim_post_expr`` is treated as 1 (requires ``wrap_scalar``).
    Asserts that ``dim`` lies in ``[-dim_post_expr, dim_post_expr - 1]``.
    """
    if dim_post_expr <= 0:
        assert wrap_scalar
        dim_post_expr = 1
    lower = -dim_post_expr
    upper = dim_post_expr - 1
    assert dim >= lower and dim <= upper, f"dim {dim} out of bounds ({lower}, {upper})"
    return dim + dim_post_expr if dim < 0 else dim
def ensure_nonempty_size(t, dim):
    """Return ``t.shape[dim]``, or 1 when ``t`` is zero-dimensional."""
    if t.dim() == 0:
        return 1
    return t.shape[dim]
# From aten/src/ATen/native/ScatterGatherChecks.h
def gather_shape_check(self, dim, index):
    """Check that ``index`` is compatible with ``self`` for a gather along ``dim``.

    Both must have the same number of dimensions (scalars count as 1D), and
    ``index`` must not exceed ``self`` in any dimension other than ``dim``.
    """
    self_rank = max(self.dim(), 1)
    index_rank = max(index.dim(), 1)
    check(
        self_rank == index_rank,
        lambda: "Index tensor must have the same number of dimensions as input tensor",
    )
    for d in range(self_rank):
        if d == dim:
            continue
        check(
            ensure_nonempty_size(index, d) <= ensure_nonempty_size(self, d),
            lambda: f"Size does not match at dimension {d} expected index {index.shape}"
            + f" to be smaller than self {self.shape} apart from dimension {dim}",
        )
@register_meta(aten.gather.default)
def meta_gather(self, dim, index, sparse_grad=False):
    """Meta implementation of ``gather``: an empty tensor shaped like ``index``.

    The index dtype and shape are only validated when ``index`` is non-empty.
    """
    wrapped_dim = maybe_wrap_dim(dim, self.dim())
    index_has_elements = not (index.numel() == 0)
    if index_has_elements:
        check(
            index.dtype == torch.long,
            lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
        )
        gather_shape_check(self, wrapped_dim, index)
    return self.new_empty(index.shape)
# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
def get_operator_enum(reduce_, use_new_options=False):
    """Map a ``reduce`` argument to its reduction-operator name.

    ``use_new_options`` selects the accepted vocabulary: sum/prod/mean/amax/amin
    when true, add/multiply otherwise. Any other value fails through ``check``.
    """
    if use_new_options:
        accepted = (
            ("sum", "REDUCE_ADD"),
            ("prod", "REDUCE_MULTIPLY"),
            ("mean", "REDUCE_MEAN"),
            ("amax", "REDUCE_MAXIMUM"),
            ("amin", "REDUCE_MINIMUM"),
        )
        for spelling, operator_name in accepted:
            if reduce_ == spelling:
                return operator_name
        check(
            False,
            lambda: "reduce argument must be either sum, prod, mean, amax or amin.",
        )
        return
    else:
        accepted = (
            ("add", "REDUCE_ADD"),
            ("multiply", "REDUCE_MULTIPLY"),
        )
        for spelling, operator_name in accepted:
            if reduce_ == spelling:
                return operator_name
        check(False, lambda: "reduce argument must be either add or multiply.")
        return
# From aten/src/ATen/native/ScatterGatherChecks.h
def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
    """Validate dtypes for scatter/gather-style ops.

    A non-empty ``index`` must be ``torch.long``; when ``src_opt`` is given its
    dtype must equal that of ``self``. ``method_name`` only prefixes messages.
    """
    index_is_empty = index.numel() == 0
    if not index_is_empty:
        check(
            index.dtype == torch.long,
            lambda: f"{method_name}(): Expected dtype int64 for index",
        )

    if src_opt is None:
        return
    check(
        self.dtype == src_opt.dtype,
        lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype",
    )
def ensure_nonempty_dim(dim):
    """Return ``dim`` clamped to a minimum of 1."""
    return dim if dim >= 1 else 1
# From aten/src/ATen/native/ScatterGatherChecks.h
def scatter_shape_check(self, dim, index, src_opt=None):
    """Check that ``index`` fits ``self`` (and ``src_opt``) for a scatter along ``dim``.

    Nothing is checked for an empty ``index``. Otherwise ``index`` and ``self``
    must have the same number of dimensions, ``index`` must not exceed ``self``
    in any dimension other than ``dim``, and, when ``src_opt`` is given, must
    not exceed it in any dimension.
    """
    if index.numel() == 0:
        return

    self_rank = ensure_nonempty_dim(self.dim())
    check(
        self_rank == ensure_nonempty_dim(index.dim()),
        lambda: "Index tensor must have the same number of dimensions as self tensor",
    )

    # Check: index.size(d) <= self.size(d) for all d != dim
    is_wrong_shape = any(
        ensure_nonempty_size(index, d) > ensure_nonempty_size(self, d)
        for d in range(self_rank)
        if d != dim
    )

    if src_opt is None:
        check(
            not is_wrong_shape,
            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
            + f" apart from dimension {dim}",
        )
        return

    # Check: index.size(d) <= src.size(d) for all d if src is Tensor
    if not is_wrong_shape:
        is_wrong_shape = any(
            ensure_nonempty_size(index, d) > ensure_nonempty_size(src_opt, d)
            for d in range(self_rank)
        )

    check(
        ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
        lambda: "Index tensor must have the same number of dimensions as self tensor",
    )
    check(
        not is_wrong_shape,
        lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
        + f" apart from dimension {dim} and to be smaller than src {src_opt.shape}",
    )
# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options=False):
    """Run the shared scatter validations: dim wrapping, dtypes, shapes and ``reduce_``."""
    normalized_dim = maybe_wrap_dim(dim, self.dim())
    scatter_gather_dtype_check("scatter", self, index, src)
    scatter_shape_check(self, normalized_dim, index, src)
    if reduce_ is None:
        return
    # Check if we have a valid reduce operator.
    get_operator_enum(reduce_, use_new_options)
@register_meta(aten.scatter_add.default)
def meta_scatter_add(self, dim, index, src):
    """Meta implementation of ``scatter_add``: validate, then return a new empty tensor shaped like ``self``."""
    scatter_meta_impl(self, dim, index, src, "add")
    return self.new_empty(self.shape)
@register_meta(aten.scatter_add_)
def meta_scatter_add_(self, dim, index, src):
    """Meta implementation of in-place ``scatter_add_``: validate, then return ``self``."""
    scatter_meta_impl(self, dim, index, src, "add")
    return self
@register_meta(
    [
        aten.scatter.src,
        aten.scatter.value,
        aten.scatter.reduce,
        aten.scatter.value_reduce,
    ]
)
@out_wrapper()
def meta_scatter(self, dim, index, src_or_value, reduce=None):
    """Meta implementation of the ``scatter`` overloads.

    ``src_or_value`` takes part in validation only when it is a tensor.
    Returns a new empty tensor shaped like ``self``.
    """
    if isinstance(src_or_value, torch.Tensor):
        tensor_src = src_or_value
    else:
        tensor_src = None
    scatter_meta_impl(self, dim, index, tensor_src, reduce)
    return self.new_empty(self.shape)
@register_meta(
    [
        aten.scatter_.src,
        aten.scatter_.value,
        aten.scatter_.reduce,
        aten.scatter_.value_reduce,
    ]
)
def meta_scatter_(self, dim, index, src_or_value, reduce=None):
    """Meta implementation of the in-place ``scatter_`` overloads.

    ``src_or_value`` takes part in validation only when it is a tensor.
    Returns ``self``.
    """
    if isinstance(src_or_value, torch.Tensor):
        tensor_src = src_or_value
    else:
        tensor_src = None
    scatter_meta_impl(self, dim, index, tensor_src, reduce)
    return self
@register_meta(
    [
        aten._scaled_dot_product_flash_attention,
    ]
)
def meta__scaled_dot_product_flash(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
):
    """Meta implementation of ``_scaled_dot_product_flash_attention``.

    Only ``dropout_p == 0.0`` is supported. The sizes of ``query`` are read as
    ``(batch_size, num_heads, max_seqlen_batch_q, head_dim)`` and ``key.size(2)``
    as ``max_seqlen_batch_k``.

    Returns:
        A 9-tuple: output, logsumexp, the two cumulative-sequence-length
        tensors, ``max_seqlen_batch_q``, ``max_seqlen_batch_k``, two placeholder
        integers for the Philox seed and offset, and the debug mask.
    """
    # [Note] SDPA_flash's meta function returns incorrect Philox seed and offset:
    # We have added logic to torch/_dynamo/variables/torch.py
    # We need to check if scaled_dot_product_attention will run the flash attention
    # kernel and if dropout is != 0.0. If that is the case then we want dynamo
    # to graph break. The derivative calculation for _scaled_dot_product_flash_attention
    # does not function correctly with cuda graphs because the full philox state is not captured
    # the forward's return values. Another reason to graph break is that the meta function
    # returns the wrong outputs for philox seed and offset and these values get baked into the
    # inductor fallback calls to the eager kernels.
    check(
        dropout_p == 0.0,
        lambda: f"Can only trace _scaled_dot_product_flash_attention when dropout is set to 0 but got a dropout_p of {dropout_p}.",
    )
    batch_size = query.size(0)
    num_heads = query.size(1)
    max_seqlen_batch_q = query.size(2)
    head_dim = query.size(3)

    max_seqlen_batch_k = key.size(2)

    # After this point only ``query.dtype`` / ``query.device`` are read; the
    # transposed ``key`` and ``value`` are not used again in this function.
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    Nnz_q = batch_size * max_seqlen_batch_q

    # Allocate as (Nnz_q, heads, head_dim), then view/transpose so the result
    # has sizes (batch, heads, seq_q, head_dim).
    output = torch.empty(
        (Nnz_q, num_heads, head_dim), dtype=query.dtype, device=query.device
    )
    output = output.view(batch_size, max_seqlen_batch_q, num_heads, head_dim).transpose(
        1, 2
    )
    # The query sequence length is rounded up to a multiple of 16.
    max_seqlen_q = math.ceil(max_seqlen_batch_q / 16) * 16
    logsumexp = torch.empty(
        (batch_size, num_heads, max_seqlen_q),
        dtype=torch.float,
        device=query.device,
    )
    cumulative_sequence_length_q = torch.empty(
        batch_size + 1, dtype=torch.int32, device="meta"
    )
    cumulative_sequence_length_k = torch.empty(
        batch_size + 1, dtype=torch.int32, device="meta"
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        # NOTE(review): this uses max_seqlen_batch_q (not ..._k) and does not
        # multiply back by blocksize_c; the value is only kept when
        # max_seqlen_batch_k > 256. Confirm against the eager kernel.
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = torch.empty(
            (batch_size, num_heads, max_seqlen_q, max_seqlen_k),
            dtype=query.dtype,
            device=query.device,
        )
    else:
        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)

    return (
        output,
        logsumexp,
        cumulative_sequence_length_q,
        cumulative_sequence_length_k,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        1,  # Philox Seed will not be used, see note at top.
        1,  # Philox Offset will not be used, see note at top.
        debug_mask,
    )
@register_meta(
    [
        aten._scaled_dot_product_flash_attention_backward,
    ]
)
def meta__scaled_dot_product_flash_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    cum_seq_q: Tensor,
    cum_seq_k: Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: int,
    philox_offset: int,
):
    """Meta implementation of ``_scaled_dot_product_flash_attention_backward``.

    Returns:
        ``(grad_q, grad_k, grad_v)``: empty tensors built from the transposed
        and flattened ``query`` / ``key`` / ``value`` and viewed back as
        ``(batch_size, seq_len, num_heads, head_dim)`` with dims 1 and 2
        swapped.
    """
    batch_size = query.size(0)
    num_heads = query.size(1)
    head_dim = query.size(3)

    total_q = batch_size * max_q
    total_kv = batch_size * max_k

    # Transpose all three inputs first, as the reshapes below expect.
    q_t = query.transpose(1, 2)
    k_t = key.transpose(1, 2)
    v_t = value.transpose(1, 2)

    q_flat = q_t.reshape(total_q, num_heads, head_dim)
    k_flat = k_t.reshape(total_kv, num_heads, head_dim)
    v_flat = v_t.reshape(total_kv, num_heads, head_dim)

    def empty_grad(flat, seq_len):
        # Allocate in the flattened layout, then restore the 4D view.
        grad = torch.empty_like(flat)
        return grad.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

    grad_q = empty_grad(q_flat, max_q)
    grad_k = empty_grad(k_flat, max_k)
    grad_v = empty_grad(v_flat, max_k)

    return grad_q, grad_k, grad_v
@register_meta(
    [
        aten._scaled_dot_product_efficient_attention,
    ]
)
def meta__scaled_dot_product_efficient(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    compute_log_sumexp: bool,
    is_causal: bool = False,
):
    """Meta implementation of ``_scaled_dot_product_efficient_attention``.

    Returns:
        ``(res, logsum_exp)``. ``res`` is allocated as ``(B, M, num_heads, Kv)``
        and returned with dims 1 and 2 swapped. The last size of ``logsum_exp``
        is ``M`` rounded up to a multiple of 32, or 0 when
        ``compute_log_sumexp`` is false.
    """
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    B, M = query.size(0), query.size(1)
    N = key.size(1)
    num_heads, K = query.size(-2), query.size(-1)
    Kv = value.size(-1)

    if compute_log_sumexp:
        logsumexp_dim = math.ceil(M / 32) * 32
    else:
        logsumexp_dim = 0

    attention = torch.empty(
        B, M, num_heads, Kv, dtype=query.dtype, device=query.device
    )
    logsum_exp = torch.empty(
        (B, num_heads, logsumexp_dim),
        dtype=torch.float,
        device=query.device,
    )

    return attention.transpose(1, 2), logsum_exp
@register_meta(
    [
        aten._scaled_dot_product_efficient_attention_backward,
    ]
)
def meta__scaled_dot_product_efficient_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    is_causal: bool = False,
    chunk_grad_outputs=False,
):
    """Meta implementation of ``_scaled_dot_product_efficient_attention_backward``.

    Returns:
        ``(grad_q, grad_k, grad_v)``, each with dims 1 and 2 swapped. With
        ``chunk_grad_outputs`` the three are slices of one shared buffer;
        otherwise they are allocated separately, and ``grad_k`` / ``grad_v``
        are zero-initialized when ``is_causal and N > M``.
    """
    grad_out = grad_out.transpose(1, 2)
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    B, M = query.size(0), query.size(1)
    N = key.size(1)
    nH, K = query.size(2), query.size(3)

    grad_kv_needs_init = is_causal and N > M

    if chunk_grad_outputs:
        # One buffer holding q/k/v gradients side by side along dim 2.
        chunk = torch.empty((B, M, 3, nH, K), dtype=query.dtype, device=query.device)
        grad_q, grad_k, grad_v = chunk.select(2, 0), chunk.select(2, 1), chunk.select(2, 2)
    else:
        allocate_kv = torch.zeros if grad_kv_needs_init else torch.empty
        grad_q = torch.empty(query.shape, dtype=query.dtype, device=query.device)
        grad_k = allocate_kv(key.shape, dtype=key.dtype, device=key.device)
        grad_v = allocate_kv(value.shape, dtype=value.dtype, device=value.device)

    return grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2)
@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
@out_wrapper()
def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
    """Meta implementation of ``scatter_reduce.two``.

    Validates with the new reduce vocabulary (``use_new_options=True``) and
    returns a new empty tensor shaped like ``self``. ``include_self`` is not
    inspected.
    """
    scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
    return self.new_empty(self.shape)
@register_meta(aten.scatter_reduce_.two)
def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True):
    """Meta implementation of in-place ``scatter_reduce_.two``.

    Validates with the new reduce vocabulary (``use_new_options=True``) and
    returns ``self``. ``include_self`` is not inspected.
    """
    scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
    return self
def multiply_integers(vs):
    """Return the product of the values in ``vs`` (1 for an empty iterable)."""
    product = 1
    for factor in vs:
        product = product * factor
    return product
def upsample_common_check(input_size, output_size, num_spatial_dims):
    """Validate upsample sizes and return the full output size.

    ``output_size`` must have ``num_spatial_dims`` entries, ``input_size`` two
    more, and all spatial sizes must be positive.

    Returns:
        A tuple of the first two entries of ``input_size`` followed by
        ``output_size``.
    """
    check(
        len(output_size) == num_spatial_dims,
        lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}",
    )

    expected_input_dims = num_spatial_dims + 2  # N, C, ...
    check(
        len(input_size) == expected_input_dims,
        lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}",
    )

    spatial_input = input_size[2:]
    check(
        all([s > 0 for s in spatial_input]) and all([s > 0 for s in output_size]),
        lambda: f"Input and output sizes should be greater than 0, but got "
        f"input size {input_size} and output size {output_size}",
    )

    nbatch, channels = input_size[:2]
    full_size = (nbatch, channels) + tuple(output_size)
    return full_size
@register_meta(aten.upsample_nearest1d.default)
def upsample_nearest1d(input, output_size, scales=None):
    """Meta implementation of ``upsample_nearest1d``.

    Returns a new empty tensor of the full output size, converted to the
    memory format suggested for ``input``. ``scales`` is not inspected.
    """
    # The message is now an f-string so the actual sizes are reported; it was
    # a plain string that printed ``{input.size()}`` literally.
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=1
    )
    return input.new_empty(full_output_size).to(
        memory_format=utils.suggest_memory_format(input)
    )
@register_meta(aten.upsample_nearest2d.default)
def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
    """Meta implementation of ``upsample_nearest2d``.

    Returns a new empty tensor of the full output size. The memory format
    follows the one suggested for ``input``, except that cuda inputs with
    fewer than 4 channels always use the contiguous format. The scale
    arguments are not inspected.
    """
    # The message is now an f-string so the actual sizes are reported; it was
    # a plain string that printed ``{input.size()}`` literally.
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=2
    )
    output = input.new_empty(full_output_size)

    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(input)

    # following "heuristic: only use channels_last path when it's faster than the contiguous path"
    _, n_channels, _, _ = input.shape
    if input.device.type == "cuda" and n_channels < 4:
        memory_format = torch.contiguous_format

    output = output.contiguous(memory_format=memory_format)

    return output
@register_meta(aten.upsample_nearest3d.default)
def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None):
    """Meta implementation of ``upsample_nearest3d``.

    Returns a new empty tensor of the full output size, converted to the
    memory format suggested for ``input``. The scale arguments are not
    inspected.
    """
    # The message is now an f-string so the actual sizes are reported; it was
    # a plain string that printed ``{input.size()}`` literally.
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=3
    )
    return input.new_empty(full_output_size).to(
        memory_format=utils.suggest_memory_format(input)
    )
@register_meta([aten.sort.default, aten.sort.stable])
def meta_sort(self, stable=None, dim=-1, descending=False):
    """Meta implementation of ``sort``.

    Returns two empty tensors like ``self``; the second has dtype
    ``torch.int64``. ``stable``, ``dim`` and ``descending`` are not inspected.
    """
    return torch.empty_like(self), torch.empty_like(self, dtype=torch.int64)
def rnn_cell_checkSizes(
    input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden
):
    """Validate the tensor shapes and devices of a fused RNN cell.

    Args:
        input_gates: 2D tensor; ``hidden_gates`` must have the same shape.
        hidden_gates: see ``input_gates``.
        input_bias: optional 1D tensor with ``input_gates.size(1)`` elements.
        hidden_bias: optional tensor; compared with ``input_bias`` by shape
            when ``input_bias`` is given.
        factor: divisor applied to ``input_gates.size(0) * input_gates.size(1)``
            to obtain the expected ``prev_hidden.numel()``.
        prev_hidden: 2D tensor.

    Failures are reported through ``check``.
    """
    check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
    check(
        input_gates.shape == hidden_gates.shape,
        lambda: f"{input_gates.shape} != {hidden_gates.shape}",
    )
    gates_size = input_gates.size(1)
    if input_bias is not None:
        check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
        check(
            input_bias.numel() == gates_size,
            lambda: f"{input_bias.numel()} != {gates_size}",
        )
        check(
            input_bias.shape == hidden_bias.shape,
            lambda: f"{input_bias.shape} != {hidden_bias.shape}",
        )
    check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
    expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
    check(
        prev_hidden.numel() == expected_prev_hidden_numel,
        lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
    )
    # The biases are optional. Absent ones are skipped so the comparison does
    # not fail with AttributeError when reading ``None.device``.
    check(
        all(
            x.device == input_gates.device
            for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
            if x is not None
        ),
        lambda: "expected all inputs to be same device",
    )
@register_meta(aten._thnn_fused_lstm_cell.default)
def _thnn_fused_lstm_cell_meta(
    input_gates, hidden_gates, cx, input_bias=None, hidden_bias=None
):
    """Meta implementation of ``_thnn_fused_lstm_cell``.

    Returns:
        ``(hy, cy, workspace)``: contiguous empty tensors; ``hy`` and ``cy``
        are like ``cx``, ``workspace`` is like ``input_gates``.
    """
    rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx)
    layout = torch.contiguous_format
    workspace = torch.empty_like(input_gates, memory_format=layout)
    hy = torch.empty_like(cx, memory_format=layout)
    cy = torch.empty_like(cx, memory_format=layout)
    return (hy, cy, workspace)
@register_meta(aten._cudnn_rnn.default)
def _cudnn_rnn(
    input,
    weight,
    weight_stride0,
    weight_buf,
    hx,
    cx,
    mode,
    hidden_size,
    proj_size,
    num_layers,
    batch_first,
    dropout,
    train,
    bidirectional,
    batch_sizes,
    dropout_state,
):
    """Meta implementation of ``_cudnn_rnn``.

    Input is treated as packed when ``batch_sizes`` is non-empty.

    Returns:
        ``(output, hy, cy, reserve, weight_buf)``. ``cy`` is a size-0 tensor
        when ``cx`` is ``None``; ``reserve`` is currently always size 0.
    """
    num_directions = 2 if bidirectional else 1
    out_size = proj_size if proj_size != 0 else hidden_size
    feature_size = out_size * num_directions

    is_input_packed = len(batch_sizes) != 0
    if is_input_packed:
        mini_batch = batch_sizes[0]
        out_shape = [input.shape[0], feature_size]
    elif batch_first:
        mini_batch, seq_length = input.shape[0], input.shape[1]
        out_shape = [mini_batch, seq_length, feature_size]
    else:
        seq_length, mini_batch = input.shape[0], input.shape[1]
        out_shape = [seq_length, mini_batch, feature_size]
    output = input.new_empty(out_shape)

    state_leading = num_layers * num_directions
    if cx is None:
        cy = torch.empty(0, device=input.device)
    else:
        cy = cx.new_empty([state_leading, mini_batch, hidden_size])

    hy = hx.new_empty([state_leading, mini_batch, out_size])

    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
    reserve_shape = 0 if train else 0
    reserve = input.new_empty(reserve_shape, dtype=torch.uint8)

    return output, hy, cy, reserve, weight_buf
@register_meta(aten.mkldnn_rnn_layer.default)
def mkldnn_rnn_layer(
    input,
    w0,
    w1,
    w2,
    w3,
    hx_,
    cx_,
    reverse,
    batch_sizes,
    mode,
    hidden_size,
    num_layers,
    has_biases,
    bidirectional,
    batch_first,
    train,
):
    """Meta kernel for ``aten.mkldnn_rnn_layer``: shape inference only.

    The output keeps the first two dimensions of ``input`` (in the order
    implied by ``batch_first``) and uses ``hidden_size`` as its last dimension.

    Returns:
        ``(output, hy, cy, workspace)``. ``hy`` / ``cy`` mirror the shapes of
        ``hx_`` / ``cx_``, or are empty size-0 tensors when the corresponding
        state is ``None``. ``workspace`` is a size-0 ``uint8`` tensor.
    """
    dim0, dim1 = input.shape[0], input.shape[1]
    if batch_first:
        mini_batch, seq_length = dim0, dim1
        out_shape = [mini_batch, seq_length, hidden_size]
    else:
        seq_length, mini_batch = dim0, dim1
        out_shape = [seq_length, mini_batch, hidden_size]
    output = input.new_empty(out_shape)

    hy = (
        torch.empty(0, device=input.device)
        if hx_ is None
        else hx_.new_empty(hx_.shape)
    )
    cy = (
        torch.empty(0, device=input.device)
        if cx_ is None
        else cx_.new_empty(cx_.shape)
    )

    workspace = torch.empty(0, device=input.device, dtype=torch.uint8)
    return output, hy, cy, workspace
def zero_numel_check_dims(self, dim, fn_name):
    """Validate a reduction dimension for a tensor that may have zero elements.

    For a 0-dim tensor, ``dim`` must be ``0`` or ``-1``. Otherwise the size of
    ``self`` along ``dim`` must be non-zero. Failures are reported through
    ``check`` with ``IndexError`` as the exception type.
    """
    if self.ndim != 0:
        check(
            self.size(dim) != 0,
            lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.",
            IndexError,
        )
        return

    dim_ok = dim == 0 or dim == -1
    check(
        dim_ok,
        lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}",
        IndexError,
    )
# From aten/src/ATen/native/ReduceOps.cpp
def check_argmax_argmin(name, self, dim):
    """Validate the arguments of an argmax/argmin style reduction.

    With ``dim`` given, it is wrapped against ``self.dim()`` and checked by
    ``zero_numel_check_dims``. Without ``dim``, the tensor must contain at
    least one element. ``name`` is only used as a prefix in error messages.
    """
    if dim is None:
        check(
            self.numel() != 0,
            lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.",
        )
        return

    wrapped_dim = maybe_wrap_dim(dim, self.dim())
    zero_numel_check_dims(self, wrapped_dim, name)
def _argmax_argmin_meta_impl(name, self, dim, keepdim):
    """Shared shape inference for the ``argmax`` / ``argmin`` meta kernels.

    Args:
        name: Operator name used as the prefix of validation error messages.
        self: Input tensor.
        dim: Dimension to reduce, or ``None`` to reduce over all dimensions.
        keepdim: Whether the reduced dimension is kept in the output shape.

    Returns:
        An empty ``int64`` tensor with the reduced shape.
    """
    check_argmax_argmin(name, self, dim)
    dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None)
    shape = _compute_reduction_shape(self, dims, keepdim)
    return self.new_empty(shape, dtype=torch.int64)


# The two ops are registered separately so that validation errors are
# reported under the operator that was actually called (previously argmin
# failures were labelled "argmax").
@register_meta(aten.argmax.default)
def argmax_meta(self, dim=None, keepdim=False):
    """Meta kernel for ``aten.argmax``."""
    return _argmax_argmin_meta_impl("argmax", self, dim, keepdim)


@register_meta(aten.argmin.default)
def argmin_meta(self, dim=None, keepdim=False):
    """Meta kernel for ``aten.argmin``."""
    return _argmax_argmin_meta_impl("argmin", self, dim, keepdim)


# Backward-compatible alias for code that refers to the old combined name;
# it keeps the original signature and the original "argmax" labelling.
argmax_argmin_meta = argmax_meta
@register_meta(aten.scalar_tensor.default)
def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None):
    """Meta kernel for ``aten.scalar_tensor``.

    The scalar value ``s`` is not read; an empty 0-dim tensor is created with
    the requested factory options.
    """
    factory_kwargs = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
    }
    return torch.empty((), **factory_kwargs)
@register_meta(aten.topk.default)
def topk_meta(self, k, dim=-1, largest=True, sorted=True):
    """Meta kernel for ``aten.topk``: shape inference only.

    ``k`` must lie in ``[0, size of self along dim]`` (a 0-dim tensor is
    treated as having slice size 1). ``largest`` and ``sorted`` do not affect
    the output shape and are not read.

    Returns:
        ``(values, indices)``: two empty tensors shaped like ``self`` with
        dimension ``dim`` replaced by ``k``; ``indices`` has dtype ``int64``.
    """
    # From aten/src/ATen/native/Sorting.cpp
    dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)

    upper_bound = self.size(dim) if self.dim() > 0 else 1
    check(
        k >= 0 and k <= upper_bound,
        lambda: "selected index k out of range",
    )

    if self.dim() == 0:
        slice_size = 1
    else:
        slice_size = self.size(dim)
    check(k >= 0 and k <= slice_size, lambda: "k not in range for dimension")

    result_shape = list(self.shape)
    if len(result_shape) > 0:
        result_shape[dim] = k

    values = self.new_empty(result_shape)
    indices = self.new_empty(result_shape, dtype=torch.int64)
    return values, indices
# Module-level alias for ``torch.contiguous_format``; it is the memory format
# used for the tensors allocated by the fused LSTM cell backward meta kernel
# defined further down in this file.
legacy_contiguous_memory_format = torch.contiguous_format
# From aten/src/ATen/native/cuda/RNN.cu
def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
    """Validate the argument shapes of the fused LSTM cell backward.

    The reference size is taken from ``grad_hy`` when it is given, otherwise
    from ``grad_cy``; the caller must ensure at least one of them is not
    ``None``. The reference must be 2-D, every provided gradient as well as
    ``cx`` and ``cy`` must match it exactly, and ``workspace`` must be 2-D
    with four times as many elements as the reference.

    Failures are raised through ``check``; the messages name the offending
    argument so a failure can be diagnosed.
    """
    defined_grad = grad_hy if grad_hy is not None else grad_cy
    check(
        defined_grad.dim() == 2,
        lambda: f"checkLSTMBackwardSizes: expected a 2-D gradient but got {defined_grad.dim()}-D",
    )
    exp_size = defined_grad.size()
    if grad_hy is not None:
        check(
            grad_hy.size() == exp_size,
            lambda: f"checkLSTMBackwardSizes: expected grad_hy of size {exp_size} but got {grad_hy.size()}",
        )
    if grad_cy is not None:
        check(
            grad_cy.size() == exp_size,
            lambda: f"checkLSTMBackwardSizes: expected grad_cy of size {exp_size} but got {grad_cy.size()}",
        )
    check(
        cx.size() == exp_size,
        lambda: f"checkLSTMBackwardSizes: expected cx of size {exp_size} but got {cx.size()}",
    )
    check(
        cy.size() == exp_size,
        lambda: f"checkLSTMBackwardSizes: expected cy of size {exp_size} but got {cy.size()}",
    )
    check(
        workspace.dim() == 2,
        lambda: f"checkLSTMBackwardSizes: expected a 2-D workspace but got {workspace.dim()}-D",
    )
    # The workspace holds four gate blocks per element of the reference size.
    check(
        workspace.numel() == exp_size[0] * exp_size[1] * 4,
        lambda: (
            f"checkLSTMBackwardSizes: expected workspace with "
            f"{exp_size[0] * exp_size[1] * 4} elements but got {workspace.numel()}"
        ),
    )
# From aten/src/ATen/native/cuda/RNN.cu
@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default)
def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias):
    """Meta kernel for ``aten._thnn_fused_lstm_cell_backward_impl``.

    Returns ``(None, None, None)`` when both incoming gradients are ``None``.
    Otherwise the sizes are validated with ``checkLSTMBackwardSizes`` and the
    result is ``(grad_gates, grad_cx, grad_bias)``: ``grad_gates`` is laid out
    like ``workspace``, ``grad_cx`` like ``cx``, and ``grad_bias`` is
    ``grad_gates`` summed over dim 0 when ``has_bias`` is truthy, else ``None``.
    """
    no_incoming_grad = grad_hy is None and grad_cy is None
    if no_incoming_grad:
        return None, None, None

    checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace)

    fmt = legacy_contiguous_memory_format
    grad_gates = torch.empty_like(workspace, memory_format=fmt)
    grad_cx = torch.empty_like(cx, memory_format=fmt)

    if has_bias:
        grad_bias = grad_gates.sum(0, keepdim=False)
    else:
        grad_bias = None

    return grad_gates, grad_cx, grad_bias
@register_meta(aten.pixel_shuffle.default)
def meta_pixel_shuffle(self, upscale_factor):
    """Meta kernel for ``aten.pixel_shuffle``: shape inference only.

    The input needs at least three dimensions and its third-from-last
    dimension must be divisible by ``upscale_factor ** 2``. The output keeps
    the leading dimensions, divides that dimension by ``upscale_factor ** 2``
    and multiplies the last two dimensions by ``upscale_factor``.
    """
    assert (
        len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0
    ), f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}"

    def pick_memory_format():
        # Channels-last inputs keep channels-last, except on "cuda" where the
        # contiguous format is selected instead.
        suggested = torch._prims_common.suggest_memory_format(self)
        if suggested == torch.channels_last:
            if device_hint(self) == "cuda":
                return torch.contiguous_format
            return torch.channels_last
        if self.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        if self.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format
        # No case matched: the result is None.
        return None

    *leading, channels, height, width = self.shape
    out_shape = (
        *leading,
        channels // (upscale_factor * upscale_factor),
        height * upscale_factor,
        width * upscale_factor,
    )

    result = self.new_empty(out_shape)
    result = result.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
    return result
@register_meta(aten.mkldnn_rnn_layer_backward.default)
def mkldnn_rnn_layer_backward(
    input,
    weight0,
    weight1,
    weight2,
    weight3,
    hx_,
    cx_tmp,
    output,
    hy_,
    cy_,
    grad_output_r_opt,
    grad_hy_r_opt,
    grad_cy_r_opt,
    reverse,
    mode,
    hidden_size,
    num_layers,
    has_biases,
    train,
    bidirectional,
    batch_sizes,
    batch_first,
    workspace,
):
    """Meta kernel for ``aten.mkldnn_rnn_layer_backward``.

    Each gradient is an empty tensor with the shape of the argument it
    corresponds to. Only ``input``, ``weight0``, ``weight1``, ``weight2``,
    ``hx_`` and ``cx_tmp`` are read.

    Returns:
        ``(diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx)``.
        The same ``diff_b`` tensor object (shaped like ``weight2``) appears
        twice in the result.
    """

    def _empty_like_shape(ref):
        return ref.new_empty(ref.shape)

    diff_x = _empty_like_shape(input)
    diff_hx = _empty_like_shape(hx_)
    diff_cx = _empty_like_shape(cx_tmp)
    diff_w1 = _empty_like_shape(weight0)
    diff_w2 = _empty_like_shape(weight1)
    diff_b = _empty_like_shape(weight2)
    return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx
@register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out])
@out_wrapper()
def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
    """Meta kernel for ``aten.bucketize`` (functional and ``out`` overloads).

    Returns a contiguous empty tensor shaped like ``self`` whose dtype is
    ``int32`` when ``out_int32`` is truthy and ``int64`` otherwise.
    ``boundaries`` and ``right`` are not read.
    """
    if out_int32:
        index_dtype = torch.int32
    else:
        index_dtype = torch.int64
    result = torch.empty_like(self, dtype=index_dtype)
    return result.contiguous()
# We must also trigger meta registrations from PrimTorch ref
# decompositions
import torch._refs
import torch._refs.nn.functional
import torch._refs.special
def activate_meta():
    """Install the collected Python meta functions into the dispatcher.

    One function is selected per operator overload from
    ``global_decomposition_table``, then registered as the overload's Python
    ``Meta`` implementation. Unless the overload falls into one of the skipped
    categories below, the function is also registered through the matching
    ``torch.library`` Meta library.

    Raises:
        RuntimeError: if an overload that has a CompositeImplicitAutograd
            kernel also has an entry in the ``"meta"`` decomposition table.
    """
    selected = {}

    # For a given op, we pick the most specific decomp function from
    # global_decomp_table in the precedence order of meta > post_autograd > pre_autograd
    for table_kind in ("meta", "post_autograd", "pre_autograd"):
        registry = global_decomposition_table[table_kind]
        for opo in registry:
            # Earlier (higher-precedence) tables win; later ones never override.
            selected.setdefault(opo, registry[opo])

    skipped_op_names = {
        "aten::empty_strided",  # causing infinite recursion, test_meta.py
        "aten::clone",  # causing infinite recursion
        "aten::_to_copy",  # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite  # noqa: B950
        "aten::copy_",  # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64  # noqa: B950
        "aten::constant_pad_nd",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32  # noqa: B950
        "aten::rot90",  # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32  # noqa: B950
        "aten::as_strided_scatter",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32  # noqa: B950
    }

    for op_overload, fn in selected.items():
        assert isinstance(op_overload, OpOverload)

        op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)

        if torch._C._dispatch_has_kernel_for_dispatch_key(
            op_overload.name(), "CompositeImplicitAutograd"
        ):
            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if op_overload in global_decomposition_table["meta"]:
                raise RuntimeError(
                    f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't "
                    "register meta function for it. Instead, we should let the decomposition run and write "
                    "meta kernels for the base operators."
                )
            continue

        if op_overload.is_view:
            # Attempting to register a python meta kernel for a view operator.
            # We shouldn't do this, because the output will report as not having aliased storages.
            # All view ops have meta kernels in C++ today, so we should use those instead.
            continue

        if op_overload.name() in skipped_op_names:
            continue

        # Route the registration to the library that owns the op's namespace.
        if "mkldnn::" in op_overload.name():
            _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
        elif "mkl::" in op_overload.name():
            _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
        else:
            _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)
# Module-level call: the meta registrations are installed as a side effect of
# importing this module.
activate_meta()