import math

from typing import List, Optional, Union

import torch
import torch._prims_common as utils
from torch import Tensor
from torch._decomp import _add_op_to_registry, global_decomposition_table, meta_table
from torch._ops import OpOverload
from torch._prims import _elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
from torch._prims_common import (
    check,
    corresponding_complex_dtype,
    corresponding_real_dtype,
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    IntLike,
    make_contiguous_strides_for,
)

from torch._prims_common.wrappers import out_wrapper
from torch._refs import _broadcast_shapes

from torch._subclasses.fake_tensor import check_no_bool_index_tensors
from torch.utils._pytree import tree_map


aten = torch.ops.aten

_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")
|
|
|
|
|
|
def register_meta(op):
|
|
def wrapper(fn):
|
|
def register(op):
|
|
_add_op_to_registry(meta_table, op, fn)
|
|
|
|
tree_map(register, op)
|
|
return fn
|
|
|
|
return wrapper
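# Usage sketch (the op name below is hypothetical, not one registered in this
# file): `register_meta` accepts a single OpOverload or a list of overloads,
# and `tree_map` registers the same meta function for each entry.
#
#     @register_meta([aten.my_op.default, aten.my_op.out])
#     def meta_my_op(self):
#         return self.new_empty(self.shape)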
|
|
|
|
|
|
def toRealValueType(dtype):
|
|
from_complex = {
|
|
torch.complex32: torch.half,
|
|
torch.cfloat: torch.float,
|
|
torch.cdouble: torch.double,
|
|
}
|
|
return from_complex.get(dtype, dtype)
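# Illustration (follows directly from the table above): complex dtypes map to
# their real counterparts, everything else is returned unchanged.
#
#     assert toRealValueType(torch.cfloat) == torch.float
#     assert toRealValueType(torch.double) == torch.double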
|
|
|
|
|
|
@register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
|
|
@out_wrapper()
|
|
def meta_fft_c2c(self, dim, normalization, forward):
|
|
assert self.dtype.is_complex
|
|
return self.new_empty(self.size())
|
|
|
|
|
|
@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
|
|
@out_wrapper()
|
|
def meta_fft_r2c(self, dim, normalization, onesided):
|
|
assert self.dtype.is_floating_point
|
|
output_sizes = list(self.size())
|
|
|
|
if onesided:
|
|
last_dim = dim[-1]
|
|
last_dim_halfsize = (output_sizes[last_dim] // 2) + 1
|
|
output_sizes[last_dim] = last_dim_halfsize
|
|
|
|
return self.new_empty(
|
|
output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
|
|
)
|
|
|
|
|
|
@register_meta(aten.randperm.generator_out)
|
|
def meta_randperm(n, *, generator=None, out):
|
|
assert out.ndim == 1 and out.size(0) == n
|
|
return out
|
|
|
|
|
|
@register_meta(aten.randint.default)
|
|
def meta_randint(
|
|
high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None
|
|
):
|
|
return torch.empty(
|
|
size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
|
|
)
|
|
|
|
|
|
@register_meta(aten.randint.low)
|
|
def meta_randint_low(
|
|
low, high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None
|
|
):
|
|
return torch.empty(
|
|
size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
|
|
)
|
|
|
|
|
|
@register_meta(aten.rand.default)
|
|
def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None):
|
|
return torch.empty(
|
|
size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
|
|
)
|
|
|
|
|
|
@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
|
|
@out_wrapper()
|
|
def meta_fft_c2r(self, dim, normalization, lastdim):
|
|
assert self.dtype.is_complex
|
|
output_sizes = list(self.size())
|
|
output_sizes[dim[-1]] = lastdim
|
|
return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))
|
|
|
|
|
|
@register_meta(aten.copy_.default)
|
|
def meta_copy_(self, src, non_blocking=False):
|
|
return self
|
|
|
|
|
|
def inferUnsqueezeGeometry(tensor, dim):
|
|
result_sizes = list(tensor.size())
|
|
result_strides = list(tensor.stride())
|
|
new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim]
|
|
result_sizes.insert(dim, 1)
|
|
result_strides.insert(dim, new_stride)
|
|
return result_sizes, result_strides
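# Worked example: unsqueezing dim=1 of a contiguous (2, 3) tensor (sizes
# [2, 3], strides [3, 1]) inserts a size-1 dim with stride
# result_sizes[1] * result_strides[1] = 3 * 1 = 3, giving sizes [2, 1, 3]
# and strides [3, 3, 1].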
|
|
|
|
|
|
@register_meta(aten.unsqueeze_.default)
|
|
def meta_unsqueeze_(self, dim):
|
|
dim = maybe_wrap_dim(dim, self.dim() + 1)
|
|
g_sizes, g_strides = inferUnsqueezeGeometry(self, dim)
|
|
self.as_strided_(g_sizes, g_strides)
|
|
return self
|
|
|
|
|
|
# Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
|
|
@register_meta(aten.index_select.default)
|
|
def meta_index_select(self, dim, index):
|
|
result_size = list(self.size())
|
|
if self.dim() > 0:
|
|
result_size[dim] = index.numel()
|
|
return self.new_empty(result_size)
|
|
|
|
|
|
@register_meta(aten.index_select.out)
|
|
def meta_index_select_out(self, dim, index, out):
|
|
torch._resize_output_(out, self.size(), self.device)
|
|
return out.copy_(torch.index_select(self, dim, index))
|
|
|
|
|
|
@register_meta([aten.max.default, aten.max.unary_out])
|
|
@out_wrapper()
|
|
def meta_max(self):
|
|
return self.new_empty(())
|
|
|
|
|
|
@register_meta(aten.max.dim)
|
|
def meta_max_dim(self, dim, keepdim=False):
|
|
dim = utils.reduction_dims(self.shape, (dim,))
|
|
output_shape = _compute_reduction_shape(self, dim, keepdim)
|
|
return (
|
|
self.new_empty(output_shape),
|
|
self.new_empty(output_shape, dtype=torch.long),
|
|
)
|
|
|
|
|
|
@register_meta([aten.min.default])
|
|
def meta_min(self):
|
|
return self.new_empty(())
|
|
|
|
|
|
@register_meta(aten.angle.default)
|
|
def meta_angle(self):
|
|
if self.is_complex():
|
|
result_dtype = corresponding_real_dtype(self.dtype)
|
|
else:
|
|
_, result_dtype = elementwise_dtypes(
|
|
self, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
|
|
)
|
|
return torch.empty_like(self, dtype=result_dtype)
|
|
|
|
|
|
@register_meta(aten.angle.out)
|
|
def meta_angle_out(self, out):
|
|
torch._resize_output_(out, self.size(), self.device)
|
|
return out.copy_(torch.angle(self))
|
|
|
|
|
|
# From aten/src/ATen/native/LinearAlgebraUtils.h
|
|
def squareCheckInputs(self: Tensor, f_name: str):
|
|
assert (
|
|
self.dim() >= 2
|
|
), f"{f_name}: The input tensor must have at least 2 dimensions."
|
|
assert self.size(-1) == self.size(
|
|
-2
|
|
), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"
|
|
|
|
|
|
# From aten/src/ATen/native/LinearAlgebraUtils.h
|
|
def checkFloatingOrComplex(
|
|
t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True
|
|
):
|
|
dtype = t.dtype
|
|
check(
|
|
t.is_floating_point() or t.is_complex(),
|
|
lambda: f"{f_name}, : Expected a floating point or complex tensor as input. Got , {dtype}",
|
|
)
|
|
    if not allow_low_precision_dtypes:
|
|
check(
|
|
dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
|
|
lambda: f"{f_name} : Low precision dtypes not supported. Got {dtype}",
|
|
)
|
|
|
|
|
|
# From aten/src/ATen/native/LinearAlgebraUtils.h
|
|
def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
|
|
check(
|
|
A.dim() >= 2,
|
|
lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
|
|
)
|
|
|
|
|
|
def checkUplo(uplo: str):
|
|
uplo_uppercase = uplo.upper()
|
|
assert (
|
|
        len(uplo) == 1 and (uplo_uppercase == "U" or uplo_uppercase == "L")
|
|
), f"Expected UPLO argument to be 'L' or 'U', but got {uplo}"
|
|
|
|
|
|
# @register_meta(aten.linalg_eigh.default)
|
|
def meta_linalg_eigh(self, uplo="L"):
|
|
squareCheckInputs(self, "linalg_eigh")
|
|
checkUplo(uplo)
|
|
real_dtype = toRealValueType(self.dtype)
|
|
assert self.dim() >= 2
|
|
values = self.new_empty(self.shape, dtype=real_dtype)
|
|
values.transpose_(-2, -1)
|
|
vectors = self.new_empty(self.shape[:-1])
|
|
return (values, vectors)
|
|
|
|
|
|
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
|
|
@register_meta(aten.linalg_cholesky_ex.default)
|
|
def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
|
|
squareCheckInputs(A, "linalg.cholesky")
|
|
checkFloatingOrComplex(A, "linalg.cholesky")
|
|
|
|
A_shape = A.shape
|
|
ndim = len(A_shape)
|
|
|
|
# L
|
|
L_strides = make_contiguous_strides_for(A_shape, False)
|
|
L = A.new_empty(A_shape)
|
|
L.as_strided_(A_shape, L_strides)
|
|
|
|
# infos
|
|
infos = A.new_empty(A_shape[0 : ndim - 2], dtype=torch.int32)
|
|
return L, infos
|
|
|
|
|
|
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
|
|
@register_meta(aten.linalg_inv_ex.default)
|
|
def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False):
|
|
squareCheckInputs(A, "linalg.inv_ex")
|
|
checkFloatingOrComplex(A, "linalg.inv_ex", allow_low_precision_dtypes=False)
|
|
|
|
L = A.new_empty(A.shape)
|
|
L.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))
|
|
|
|
infos = A.new_empty(A.shape[:-2], dtype=torch.int32)
|
|
return L, infos
|
|
|
|
|
|
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
|
|
# NOTE: matching defaults in aten/src/ATen/native/native_functions.yaml
|
|
@register_meta(aten._linalg_svd.default)
|
|
def _linalg_svd_meta(
|
|
A: Tensor, full_matrices: bool = False, compute_uv: bool = True, driver: str = None
|
|
):
|
|
checkIsMatrix(A, "linalg.svd")
|
|
checkFloatingOrComplex(A, "linalg.svd")
|
|
|
|
batch_dims = list(A.shape[:-2])
|
|
m = A.shape[-2]
|
|
n = A.shape[-1]
|
|
k = min(m, n)
|
|
|
|
if compute_uv:
|
|
U_shape = batch_dims + [m, m if full_matrices else k]
|
|
U = A.new_empty(U_shape)
|
|
U.as_strided_(U_shape, make_contiguous_strides_for(U_shape, row_major=False))
|
|
|
|
V_shape = batch_dims + [n if full_matrices else k, n]
|
|
V = A.new_empty(V_shape)
|
|
# TODO: need to distinguish cuSOLVER case? (see original code)
|
|
V.as_strided_(V_shape, make_contiguous_strides_for(V_shape, row_major=False))
|
|
else:
|
|
# doesn't matter
|
|
U = A.new_empty([0])
|
|
V = A.new_empty([0])
|
|
|
|
# S is always real, even when A is complex.
|
|
S = A.new_empty(batch_dims + [k], dtype=toRealValueType(A.dtype))
|
|
return U, S, V
|
|
|
|
|
|
# From aten/src/ATen/native/LinearAlgebra.cpp
|
|
@register_meta(aten._linalg_det.default)
|
|
def _linalg_det_meta(A):
|
|
squareCheckInputs(A, "linalg.det")
|
|
checkFloatingOrComplex(A, "linalg.det")
|
|
|
|
det = A.new_empty(A.shape[:-2])
|
|
|
|
LU = A.new_empty(A.shape)
|
|
LU.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))
|
|
|
|
pivots = A.new_empty(A.shape[:-1], dtype=torch.int32)
|
|
return det, LU, pivots
|
|
|
|
|
|
# From aten/src/ATen/native/ReflectionPad.cpp
|
|
@register_meta(
|
|
[aten.reflection_pad2d_backward.default, aten.replication_pad2d_backward.default]
|
|
)
|
|
def meta_pad2d_backward(grad_output, self, padding):
|
|
dim_w = 2
|
|
dim_h = 1
|
|
dim_plane = 0
|
|
nbatch = 1
|
|
|
|
self_shape = self.shape
|
|
if self.dim() == 4:
|
|
nbatch = self_shape[0]
|
|
dim_w += 1
|
|
dim_h += 1
|
|
dim_plane += 1
|
|
|
|
pad_l = padding[0]
|
|
pad_r = padding[1]
|
|
pad_t = padding[2]
|
|
pad_b = padding[3]
|
|
|
|
nplane = self_shape[dim_plane]
|
|
input_h = self_shape[dim_h]
|
|
input_w = self_shape[dim_w]
|
|
output_h = input_h + pad_t + pad_b
|
|
output_w = input_w + pad_l + pad_r
|
|
|
|
check(
|
|
output_w == grad_output.shape[dim_w],
|
|
lambda: f"gradOutput width unexpected. Expected: {output_w}, Got: {grad_output.shape[dim_w]}",
|
|
)
|
|
check(
|
|
output_h == grad_output.shape[dim_h],
|
|
lambda: f"gradOutput height unexpected. Expected: {output_h}, Got: {grad_output.shape[dim_h]}",
|
|
)
|
|
return self.new_empty(self.shape)
|
|
|
|
|
|
@register_meta(aten.reflection_pad2d.default)
|
|
def meta_pad2d(self, padding):
|
|
valid_dims = self.size(1) != 0 and self.size(2) != 0
|
|
check(
|
|
(self.ndim == 3 and valid_dims)
|
|
or (self.ndim == 4 and valid_dims and self.size(3) != 0),
|
|
lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}",
|
|
)
|
|
if self.ndim == 4:
|
|
nbatch, nplane, input_h, input_w = self.shape
|
|
else:
|
|
nbatch = 1
|
|
nplane, input_h, input_w = self.shape
|
|
|
|
pad_l, pad_r, pad_t, pad_b = padding
|
|
|
|
output_h = input_h + pad_t + pad_b
|
|
output_w = input_w + pad_l + pad_r
|
|
|
|
if self.ndim == 3:
|
|
return self.new_empty((nplane, output_h, output_w))
|
|
else:
|
|
return self.new_empty((nbatch, nplane, output_h, output_w))
|
|
|
|
|
|
@register_meta([aten.bernoulli.default, aten.bernoulli.out])
|
|
@out_wrapper()
|
|
def meta_bernoulli(self, *, generator=None):
|
|
# https://github.com/pytorch/pytorch/issues/88612
|
|
return torch.empty_like(self).contiguous()
|
|
|
|
|
|
@register_meta(aten.bernoulli_.float)
|
|
def meta_bernoulli_(self, p=0.5, generator=None):
|
|
return self
|
|
|
|
|
|
@register_meta(aten.bernoulli.p)
|
|
def meta_bernoulli_p(self, p=0.5, generator=None):
|
|
# https://github.com/pytorch/pytorch/issues/88612
|
|
return torch.empty_like(self).contiguous()
|
|
|
|
|
|
@register_meta(aten._fused_moving_avg_obs_fq_helper.default)
|
|
def meta__fused_moving_avg_obs_fq_helper(
|
|
self,
|
|
observer_on,
|
|
fake_quant_on,
|
|
running_min,
|
|
running_max,
|
|
scale,
|
|
zero_point,
|
|
averaging_const,
|
|
quant_min,
|
|
quant_max,
|
|
ch_axis,
|
|
per_row_fake_quant=False,
|
|
symmetric_quant=False,
|
|
):
|
|
check(
|
|
ch_axis < self.dim(),
|
|
lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()",
|
|
)
|
|
mask = torch.empty_like(self, dtype=torch.bool)
|
|
return (torch.empty_like(self), mask)
|
|
|
|
|
|
def dot_check(self, other):
|
|
check(
|
|
self.dim() == 1 and other.dim() == 1,
|
|
lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
|
|
)
|
|
|
|
|
|
@register_meta(aten.dot.default)
|
|
def meta_dot(self, tensor):
|
|
dot_check(self, tensor)
|
|
return self.new_empty(())
|
|
|
|
|
|
@register_meta([aten.mm.default])
|
|
def meta_mm(a, b):
|
|
check(a.dim() == 2, lambda: "a must be 2D")
|
|
check(b.dim() == 2, lambda: "b must be 2D")
|
|
N, M1 = a.shape
|
|
M2, P = b.shape
|
|
check(M1 == M2, lambda: "a and b must have same reduction dim")
|
|
return a.new_empty(N, P)
|
|
|
|
|
|
def _compute_reduction_shape(self, dims, keepdim):
|
|
if keepdim:
|
|
return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))
|
|
|
|
return utils.compute_reduction_output_shape(self.shape, dims)
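# E.g. for a (2, 3, 4) tensor reduced over dims=(1,): keepdim=True gives
# (2, 1, 4), keepdim=False gives (2, 4).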
|
|
|
|
|
|
# FakeTensors (meta tensors with a device) will report device as meta
|
|
# when running meta kernels. Here, access the "fake device" of FakeTensor if it
|
|
# exists so meta kernels which have diverge per device will be more
|
|
# accurate when run with FakeTensors
|
|
def device_hint(tensor) -> "str":
|
|
if isinstance(tensor, torch._subclasses.FakeTensor):
|
|
return tensor.fake_device.type
|
|
else:
|
|
return "cuda" # default to cuda
|
|
|
|
|
|
def calc_conv_nd_return_shape(
|
|
input_tensor: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
stride: Union[List[int], int],
|
|
padding: Union[List[int], int],
|
|
dilation: Union[List[int], int],
|
|
is_transposed: bool,
|
|
groups: int,
|
|
output_padding: Optional[Union[List[int], int]] = None,
|
|
):
|
|
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
|
|
"""
|
|
Formula to apply to calculate the length of some dimension of the output
|
|
|
|
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
|
|
|
|
Args:
|
|
ln: length of the dimension
|
|
p: padding in that dim
|
|
d: dilation in that dim
|
|
k: kernel size in that dim
|
|
s: stride in that dim
|
|
Returns:
|
|
The output length
|
|
"""
|
|
return (ln + 2 * p - d * (k - 1) - 1) // s + 1
|
|
|
|
def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
|
|
"""
|
|
Formula to apply to calculate the length of some dimension of the output
|
|
if transposed convolution is used.
|
|
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
|
|
|
|
Args:
|
|
ln: length of the dimension
|
|
p: padding in that dim
|
|
d: dilation in that dim
|
|
k: kernel size in that dim
|
|
s: stride in that dim
|
|
op: output padding in that dim
|
|
|
|
Returns:
|
|
The output length
|
|
"""
|
|
return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1
|
|
|
|
kernel_size = weight.shape[2:]
|
|
dims = input_tensor.shape[2:]
|
|
if is_transposed:
|
|
out_channels = groups * weight.shape[1]
|
|
else:
|
|
out_channels = weight.shape[0]
|
|
if weight.shape[1] * groups != input_tensor.shape[1]:
|
|
raise RuntimeError("Invalid channel dimensions")
|
|
|
|
ret_shape = [input_tensor.shape[0], out_channels]
|
|
if isinstance(stride, IntLike):
|
|
stride = [stride] * len(dims)
|
|
elif len(stride) == 1:
|
|
stride = [stride[0]] * len(dims)
|
|
|
|
if isinstance(padding, IntLike):
|
|
padding = [padding] * len(dims)
|
|
elif len(padding) == 1:
|
|
padding = [padding[0]] * len(dims)
|
|
|
|
if isinstance(dilation, IntLike):
|
|
dilation = [dilation] * len(dims)
|
|
elif len(dilation) == 1:
|
|
dilation = [dilation[0]] * len(dims)
|
|
|
|
output_padding_list: Optional[List[int]] = None
|
|
if output_padding:
|
|
if isinstance(output_padding, IntLike):
|
|
output_padding_list = [output_padding] * len(dims)
|
|
elif len(output_padding) == 1:
|
|
output_padding_list = [output_padding[0]] * len(dims)
|
|
else:
|
|
output_padding_list = output_padding
|
|
|
|
for i in range(len(dims)):
|
|
# If output_padding is present, we are dealing with a transposed convolution
|
|
if output_padding_list:
|
|
ret_shape.append(
|
|
_formula_transposed(
|
|
dims[i],
|
|
padding[i],
|
|
dilation[i],
|
|
kernel_size[i],
|
|
stride[i],
|
|
output_padding_list[i],
|
|
)
|
|
)
|
|
else:
|
|
ret_shape.append(
|
|
_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
|
|
)
|
|
|
|
return ret_shape
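# Shape sanity check (illustrative numbers only): a (1, 3, 32, 32) input with
# an (8, 3, 3, 3) weight, stride=1, padding=1, dilation=1, groups=1 and
# is_transposed=False yields (32 + 2*1 - 1*(3 - 1) - 1) // 1 + 1 = 32 per
# spatial dim, i.e. a return shape of [1, 8, 32, 32].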
|
|
|
|
|
|
def is_channels_last(ten):
|
|
return torch._prims_common.suggest_memory_format(ten) == torch.channels_last
|
|
|
|
|
|
@register_meta(aten.convolution.default)
|
|
def meta_conv(
|
|
input_tensor: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
bias: torch.Tensor,
|
|
stride: List[int],
|
|
padding: List[int],
|
|
dilation: List[int],
|
|
is_transposed: bool,
|
|
output_padding: List[int],
|
|
groups: int,
|
|
):
|
|
def pick_memory_format():
|
|
if device_hint(input_tensor) == "cuda":
|
|
if is_channels_last(input_tensor) or is_channels_last(weight):
|
|
return torch.channels_last
|
|
else:
|
|
if is_channels_last(input_tensor):
|
|
return torch.channels_last
|
|
if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
|
|
return torch.contiguous_format
|
|
elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
|
|
return torch.preserve_format
|
|
|
|
shape_out = calc_conv_nd_return_shape(
|
|
input_tensor,
|
|
weight,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
is_transposed,
|
|
groups,
|
|
output_padding if is_transposed else None,
|
|
)
|
|
|
|
out = input_tensor.new_empty(shape_out)
|
|
out = out.to(memory_format=pick_memory_format()) # type: ignore[call-overload]
|
|
return out
|
|
|
|
|
|
if torch._C.has_mkldnn:
|
|
_meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library(
|
|
"mkldnn", "IMPL", "Meta"
|
|
)
|
|
|
|
def pick_mkldnn_conv_memory_format(input_tensor, weight):
|
|
if weight.is_mkldnn:
|
|
return torch.channels_last
|
|
if is_channels_last(input_tensor) or is_channels_last(weight):
|
|
return torch.channels_last
|
|
if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
|
|
return torch.contiguous_format
|
|
elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
|
|
return torch.preserve_format
|
|
|
|
@register_meta(torch.ops.mkldnn._convolution_pointwise.default)
|
|
def meta_mkldnn_convolution_default(
|
|
input_tensor,
|
|
weight,
|
|
bias,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
attr,
|
|
scalars,
|
|
algorithm,
|
|
):
|
|
shape_out = calc_conv_nd_return_shape(
|
|
input_tensor, weight, stride, padding, dilation, False, groups, []
|
|
)
|
|
out = input_tensor.new_empty(shape_out)
|
|
out_memory_format = torch.channels_last
|
|
out = out.to(memory_format=out_memory_format) # type: ignore[call-overload]
|
|
return out
|
|
|
|
@register_meta(torch.ops.mkldnn._convolution_pointwise.binary)
|
|
def meta_mkldnn_convolution_binary(
|
|
input_tensor,
|
|
other,
|
|
weight,
|
|
bias,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
binary_attr,
|
|
alpha,
|
|
unary_attr,
|
|
unary_scalars,
|
|
unary_algorithm,
|
|
):
|
|
out = input_tensor.new_empty(other.size())
|
|
out = out.to(memory_format=torch.channels_last) # type: ignore[call-overload]
|
|
return out
|
|
|
|
@register_meta(torch.ops.mkldnn._convolution_pointwise_.binary)
|
|
def meta_mkldnn_convolution_binary_inplace(
|
|
input_tensor,
|
|
other,
|
|
weight,
|
|
bias,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
binary_attr,
|
|
alpha,
|
|
unary_attr,
|
|
unary_scalars,
|
|
unary_algorithm,
|
|
):
|
|
return other
|
|
|
|
@register_meta(torch.ops.mkldnn._linear_pointwise.default)
|
|
def meta_linear_pointwise_default(
|
|
input_tensor, weight, bias, attr, scalars, algorithm
|
|
):
|
|
return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0]))
|
|
|
|
@register_meta(torch.ops.mkldnn._linear_pointwise.binary)
|
|
def meta_linear_pointwise_binary(input_tensor, other, weight, bias, attr):
|
|
out = input_tensor.new_empty(other.size())
|
|
return out
|
|
|
|
if torch._C.has_mkl:
|
|
_meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library(
|
|
"mkl", "IMPL", "Meta"
|
|
)
|
|
|
|
@register_meta(torch.ops.mkl._mkl_linear)
|
|
def meta_mkl_linear(
|
|
input_tensor,
|
|
packed_weight,
|
|
orig_weight,
|
|
bias,
|
|
batch_size,
|
|
):
|
|
return input_tensor.new_empty(
|
|
(*input_tensor.shape[:-1], orig_weight.shape[0])
|
|
)
|
|
|
|
|
|
# from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
|
|
def check_dim_size(tensor, dim, dim_size, size):
|
|
check(
|
|
tensor.dim() == dim and tensor.shape[dim_size] == size,
|
|
lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
|
|
+ f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
|
|
)
|
|
|
|
|
|
@register_meta(aten.avg_pool2d.default)
|
|
def meta_avg_pool2d(
|
|
input,
|
|
kernel_size,
|
|
stride=(),
|
|
padding=(0,),
|
|
ceil_mode=False,
|
|
count_include_pad=True,
|
|
divisor_override=None,
|
|
):
|
|
def unpack(name, val):
|
|
check(
|
|
len(val) in [1, 2],
|
|
lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
|
|
)
|
|
H = val[0]
|
|
W = H if len(val) == 1 else val[1]
|
|
return H, W
|
|
|
|
kH, kW = unpack("kernel_size", kernel_size)
|
|
check(
|
|
len(stride) in [0, 1, 2],
|
|
lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
|
|
)
|
|
if len(stride) == 0:
|
|
dH, dW = kH, kW
|
|
elif len(stride) == 1:
|
|
dH, dW = stride[0], stride[0]
|
|
else:
|
|
dH, dW = unpack("stride", stride)
|
|
|
|
padH, padW = unpack("padding", padding)
|
|
|
|
check(
|
|
divisor_override is None or divisor_override != 0,
|
|
lambda: "divisor must be not zero",
|
|
)
|
|
|
|
nbatch = input.size(-4) if input.dim() == 4 else 1
|
|
nInputPlane = input.size(-3)
|
|
inputHeight = input.size(-2)
|
|
inputWidth = input.size(-1)
|
|
|
|
outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
|
|
outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
|
|
|
|
memory_format = utils.suggest_memory_format(input)
|
|
pool2d_shape_check(
|
|
input,
|
|
kH,
|
|
kW,
|
|
dH,
|
|
dW,
|
|
padH,
|
|
padW,
|
|
1,
|
|
1,
|
|
nInputPlane,
|
|
inputHeight,
|
|
inputWidth,
|
|
outputHeight,
|
|
outputWidth,
|
|
memory_format,
|
|
)
|
|
|
|
if input.dim() == 3:
|
|
size = [nInputPlane, outputHeight, outputWidth]
|
|
else:
|
|
size = [nbatch, nInputPlane, outputHeight, outputWidth]
|
|
return torch.empty(
|
|
size, dtype=input.dtype, device=input.device, memory_format=memory_format
|
|
)
|
|
|
|
|
|
# from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
|
|
def avg_pool2d_backward_shape_check(
|
|
input,
|
|
gradOutput,
|
|
nbatch,
|
|
kH,
|
|
kW,
|
|
dH,
|
|
dW,
|
|
padH,
|
|
padW,
|
|
nInputPlane,
|
|
inputHeight,
|
|
inputWidth,
|
|
outputHeight,
|
|
outputWidth,
|
|
mem_format,
|
|
):
|
|
pool2d_shape_check(
|
|
input,
|
|
kH,
|
|
kW,
|
|
dH,
|
|
dW,
|
|
padH,
|
|
padW,
|
|
1,
|
|
1,
|
|
nInputPlane,
|
|
inputHeight,
|
|
inputWidth,
|
|
outputHeight,
|
|
outputWidth,
|
|
mem_format,
|
|
)
|
|
|
|
ndim = input.dim()
|
|
nOutputPlane = nInputPlane
|
|
|
|
check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane)
|
|
check_dim_size(gradOutput, ndim, ndim - 2, outputHeight)
|
|
check_dim_size(gradOutput, ndim, ndim - 1, outputWidth)
|
|
|
|
|
|
# Don't override the C++ registration.
|
|
@register_meta(aten.avg_pool2d_backward.default)
|
|
def meta_avg_pool2d_backward(
|
|
gradOutput_,
|
|
input,
|
|
kernel_size,
|
|
stride,
|
|
padding,
|
|
ceil_mode,
|
|
count_include_pad,
|
|
divisor_override,
|
|
):
|
|
# From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func.
|
|
check(
|
|
len(kernel_size) == 1 or len(kernel_size) == 2,
|
|
lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints",
|
|
)
|
|
kH = kernel_size[0]
|
|
kW = kH if len(kernel_size) == 1 else kernel_size[1]
|
|
check(
|
|
len(stride) == 0 or len(stride) == 1 or len(stride) == 2,
|
|
lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
|
|
)
|
|
dH = kH if len(stride) == 0 else stride[0]
|
|
dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1]
|
|
check(
|
|
len(padding) == 1 or len(padding) == 2,
|
|
lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints",
|
|
)
|
|
padH = padding[0]
|
|
padW = padH if len(padding) == 1 else padding[1]
|
|
|
|
check(
|
|
divisor_override is None or divisor_override != 0,
|
|
lambda: "divisor must be not zero",
|
|
)
|
|
|
|
input_size = input.shape
|
|
nbatch = input_size[-4] if input.dim() == 4 else 1
|
|
nInputPlane = input_size[-3]
|
|
inputHeight = input_size[-2]
|
|
inputWidth = input_size[-1]
|
|
|
|
outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
|
|
outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
|
|
|
|
mem_format = utils.suggest_memory_format(input)
|
|
|
|
avg_pool2d_backward_shape_check(
|
|
input,
|
|
gradOutput_,
|
|
nbatch,
|
|
kH,
|
|
kW,
|
|
dH,
|
|
dW,
|
|
padH,
|
|
padW,
|
|
nInputPlane,
|
|
inputHeight,
|
|
inputWidth,
|
|
outputHeight,
|
|
outputWidth,
|
|
mem_format,
|
|
)
|
|
|
|
return torch.empty(
|
|
input_size, dtype=input.dtype, device=input.device, memory_format=mem_format
|
|
)
|
|
|
|
|
|
@register_meta(aten._adaptive_avg_pool2d.default)
|
|
def meta_adaptive_avg_pool2d(self, output_size):
|
|
check(
|
|
self.ndim == 3 or self.ndim == 4,
|
|
lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
|
|
)
|
|
output_shape = self.shape[:-2] + tuple(output_size)
|
|
memory_format = utils.suggest_memory_format(self)
|
|
# need to set memory_format to preserve the memory format of the input
|
|
# channel last input should have channel last output
|
|
return torch.empty(
|
|
output_shape, dtype=self.dtype, device=self.device, memory_format=memory_format
|
|
)
|
|
|
|
|
|
@register_meta(aten._adaptive_avg_pool3d.default)
|
|
def meta_adaptive_avg_pool3d(self, output_size):
|
|
check(
|
|
self.ndim == 4 or self.ndim == 5,
|
|
lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
|
|
)
|
|
return self.new_empty(self.shape[:-3] + tuple(output_size))
|
|
|
|
|
|
@register_meta(aten._adaptive_avg_pool2d_backward.default)
|
|
def meta__adaptive_avg_pool2d_backward(grad_out, self):
|
|
ndim = grad_out.ndim
|
|
for i in range(1, ndim):
|
|
check(
|
|
grad_out.size(i) > 0,
|
|
lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
|
|
size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
|
|
)
|
|
check(
|
|
ndim == 3 or ndim == 4,
|
|
lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
|
|
)
|
|
check(
|
|
self.dtype == grad_out.dtype,
|
|
lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
|
|
)
|
|
return self.new_empty(self.shape)
|
|
|
|
|
|
@register_meta(aten.repeat_interleave.Tensor)
|
|
def meta_repeat_interleave_Tensor(repeats, output_size=None):
|
|
if output_size is None:
|
|
raise RuntimeError("cannot repeat_interleave a meta tensor without output_size")
|
|
return repeats.new_empty(output_size)
|
|
|
|
|
|
@register_meta([aten.complex.default, aten.complex.out])
|
|
@out_wrapper()
|
|
def meta_complex(real, imag):
|
|
assert real.dtype.is_floating_point
|
|
assert imag.dtype.is_floating_point
|
|
out_shape = _broadcast_shapes(real.shape, imag.shape)
|
|
return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))
|
|
|
|
|
|
@register_meta(aten.vdot.default)
|
|
def vdot(self, other):
|
|
if not self.is_complex:
|
|
return torch.dot(self, other)
|
|
|
|
if self.is_conj():
|
|
if other.is_conj():
|
|
return torch.vdot(other.conj(), self.conj())
|
|
else:
|
|
return torch.dot(self.conj(), other)
|
|
elif other.is_conj():
|
|
return torch.dot(self, other.conj()).conj()
|
|
|
|
dot_check(self, other)
|
|
return self.new_empty(())
|
|
|
|
|
|
# Leaving this function around because a python implementation
|
|
# of indexing shape inference is useful,
|
|
# but not registering it to the dispatcher because we already
|
|
# get shape inference through structured kernels
|
|
@register_meta(aten.index.Tensor)
|
|
def meta_index_Tensor(self, indices):
|
|
check_no_bool_index_tensors(aten.index.Tensor, self, indices)
|
|
check(indices, lambda: "at least one index must be provided")
|
|
# aten::index is the internal advanced indexing implementation
|
|
# checkIndexTensorTypes and expandTensors
|
|
result: List[Optional[Tensor]] = []
|
|
for i, index in enumerate(indices):
|
|
if index is not None:
|
|
check(
|
|
index.dtype in [torch.long, torch.int, torch.int8, torch.bool],
|
|
lambda: "tensors used as indices must be long, int, byte or bool tensors",
|
|
)
|
|
if index.dtype in [torch.int8, torch.bool]:
|
|
nonzero = index.nonzero()
|
|
k = len(result)
|
|
check(
|
|
k + index.ndim <= self.ndim,
|
|
lambda: f"too many indices for tensor of dimension {self.ndim}",
|
|
IndexError,
|
|
)
|
|
for j in range(index.ndim):
|
|
check(
|
|
index.shape[j] == self.shape[k + j],
|
|
lambda: f"The shape of the mask {index.shape} at index {i} "
|
|
f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
|
|
IndexError,
|
|
)
|
|
result.append(nonzero.select(1, j))
|
|
else:
|
|
result.append(index)
|
|
else:
|
|
result.append(index)
|
|
indices = result
|
|
check(
|
|
len(indices) <= self.ndim,
|
|
lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
|
|
)
|
|
# expand_outplace
|
|
import torch._refs as refs # avoid import cycle in mypy
|
|
|
|
indices = list(refs._maybe_broadcast(*indices))
|
|
# add missing null tensors
|
|
while len(indices) < self.ndim:
|
|
indices.append(None)
|
|
|
|
# hasContiguousSubspace
|
|
# true if all non-null tensors are adjacent
|
|
# See:
|
|
# https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
|
|
# https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
|
|
state = 0
|
|
has_contiguous_subspace = False
|
|
for index in indices:
|
|
if state == 0:
|
|
if index is not None:
|
|
state = 1
|
|
elif state == 1:
|
|
if index is None:
|
|
state = 2
|
|
else:
|
|
if index is not None:
|
|
break
|
|
else:
|
|
has_contiguous_subspace = True
|
|
|
|
# transposeToFront
|
|
# This is the logic that causes the newly inserted dimensions to show up
|
|
# at the beginning of the tensor, if they're not contiguous
|
|
if not has_contiguous_subspace:
|
|
dims = []
|
|
transposed_indices = []
|
|
for i, index in enumerate(indices):
|
|
if index is not None:
|
|
dims.append(i)
|
|
transposed_indices.append(index)
|
|
for i, index in enumerate(indices):
|
|
if index is None:
|
|
dims.append(i)
|
|
transposed_indices.append(index)
|
|
self = self.permute(dims)
|
|
indices = transposed_indices
|
|
|
|
# AdvancedIndex::AdvancedIndex
|
|
# Now we can assume the indices have contiguous subspace
|
|
# This is simplified from AdvancedIndex which goes to more effort
|
|
# to put the input and indices in a form so that TensorIterator can
|
|
# take them. If we write a ref for this, probably that logic should
|
|
# get implemented
|
|
before_shape: List[int] = []
|
|
after_shape: List[int] = []
|
|
replacement_shape: List[int] = []
|
|
for dim, index in enumerate(indices):
|
|
if index is None:
|
|
if replacement_shape:
|
|
after_shape.append(self.shape[dim])
|
|
else:
|
|
before_shape.append(self.shape[dim])
|
|
else:
|
|
replacement_shape = list(index.shape)
|
|
return self.new_empty(before_shape + replacement_shape + after_shape)
|
|
|
|
|
|
@register_meta([aten.convolution_backward.default])
|
|
def meta_convolution_backward(
|
|
grad_output_,
|
|
input_,
|
|
weight_,
|
|
bias_sizes_opt,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
transposed,
|
|
output_padding,
|
|
groups,
|
|
output_mask,
|
|
):
|
|
# High level logic taken from slow_conv3d_backward_cpu which should
|
|
# be representative of all convolution_backward impls
|
|
backend_grad_input = None
|
|
backend_grad_weight = None
|
|
backend_grad_bias = None
|
|
|
|
if output_mask[0]:
|
|
backend_grad_input = grad_output_.new_empty(input_.size())
|
|
if output_mask[1]:
|
|
backend_grad_weight = grad_output_.new_empty(weight_.size())
|
|
if output_mask[2]:
|
|
backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)
|
|
|
|
return (backend_grad_input, backend_grad_weight, backend_grad_bias)
|
|
|
|
|
|
@register_meta([aten.addbmm.default, aten.addbmm.out])
|
|
@out_wrapper()
|
|
def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
|
|
dim1 = batch1.size(1)
|
|
dim2 = batch2.size(2)
|
|
self = self.expand((dim1, dim2))
|
|
check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
|
|
check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
|
|
check(
|
|
batch1.size(0) == batch2.size(0),
|
|
lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
|
|
)
|
|
check(
|
|
batch1.size(2) == batch2.size(1),
|
|
lambda: (
|
|
f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
|
|
f"and {batch2.size(1)}x{batch2.size(2)})"
|
|
),
|
|
)
|
|
check(
|
|
self.size(0) == dim1 and self.size(1) == dim2,
|
|
lambda: "self tensor does not match matmul output shape",
|
|
)
|
|
return self.new_empty(self.size())
|
|
|
|
|
|
@register_meta(aten._cdist_forward.default)
|
|
def meta_cdist_forward(x1, x2, p, compute_mode):
|
|
check(
|
|
x1.dim() >= 2,
|
|
lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
|
|
)
|
|
check(
|
|
x2.dim() >= 2,
|
|
lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
|
|
)
|
|
check(
|
|
x1.size(-1) == x2.size(-1),
|
|
lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
|
|
)
|
|
check(
|
|
utils.is_float_dtype(x1.dtype),
|
|
lambda: "cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
|
|
)
|
|
check(
|
|
utils.is_float_dtype(x2.dtype),
|
|
lambda: "cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
|
|
)
|
|
check(p >= 0, lambda: "cdist only supports non-negative p values")
|
|
check(
|
|
compute_mode in (None, 1, 2),
|
|
lambda: f"possible modes: None, 1, 2, but was: {compute_mode}",
|
|
)
|
|
r1 = x1.size(-2)
|
|
r2 = x2.size(-2)
|
|
batch_tensor1 = x1.shape[:-2]
|
|
batch_tensor2 = x2.shape[:-2]
|
|
output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
|
|
output_shape.extend([r1, r2])
|
|
return x1.new_empty(output_shape)
|
|
|
|
|
|
@register_meta(aten._embedding_bag.default)
|
|
def meta_embedding_bag(
|
|
weight,
|
|
indices,
|
|
offsets,
|
|
scale_grad_by_freq=False,
|
|
mode=0,
|
|
sparse=False,
|
|
per_sample_weights=None,
|
|
include_last_offset=False,
|
|
padding_idx=-1,
|
|
):
|
|
check(
|
|
indices.dtype in (torch.long, torch.int),
|
|
lambda: f"expected indices to be long or int, got {indices.dtype}",
|
|
)
|
|
check(
|
|
offsets.dtype in (torch.long, torch.int),
|
|
lambda: f"expected offsets to be long or int, got {offsets.dtype}",
|
|
)
|
|
check(
|
|
utils.is_float_dtype(weight.dtype),
|
|
lambda: f"expected weight to be floating point type, got {weight.dtype}",
|
|
)
|
|
|
|
num_bags = offsets.size(0)
|
|
if include_last_offset:
|
|
check(
|
|
num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1"
|
|
)
|
|
num_bags -= 1
|
|
|
|
output = weight.new_empty(num_bags, weight.size(1))
|
|
MODE_SUM, MODE_MEAN, MODE_MAX = range(3)
|
|
|
|
if per_sample_weights is not None:
|
|
check(
|
|
mode == MODE_SUM,
|
|
lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
|
|
)
|
|
check(
|
|
per_sample_weights.dtype == weight.dtype,
|
|
lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
|
|
)
|
|
check(
|
|
per_sample_weights.ndim == 1,
|
|
lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
|
|
)
|
|
check(
|
|
per_sample_weights.numel() == indices.numel(),
|
|
lambda: (
|
|
f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
|
|
f"to be the same as indices.numel() ({indices.numel()})"
|
|
),
|
|
)
|
|
|
|
def is_fast_path_index_select_scale(src, scale, output, padding_idx):
|
|
return (
|
|
is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
|
|
)
|
|
|
|
def is_fast_path_index_select(src, output, padding_idx):
|
|
return (
|
|
(src.dtype == torch.float or src.dtype == torch.half)
|
|
and src.stride(1) == 1
|
|
and output.stride(1) == 1
|
|
and padding_idx < 0
|
|
)
|
|
|
|
def is_fast_path(src, scale, output, padding_idx):
|
|
if scale is not None:
|
|
return is_fast_path_index_select_scale(src, scale, output, padding_idx)
|
|
else:
|
|
return is_fast_path_index_select(src, output, padding_idx)
|
|
|
|
if device_hint(offsets) != "cpu":
|
|
offset2bag = indices.new_empty(indices.size(0))
|
|
bag_size = indices.new_empty(offsets.size())
|
|
if mode == MODE_MAX:
|
|
max_indices = indices.new_empty(num_bags, weight.size(1))
|
|
else:
|
|
max_indices = indices.new_empty(0)
|
|
else:
|
|
fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
|
|
if mode == MODE_MEAN or mode == MODE_MAX or not fast_path_sum:
|
|
offset2bag = offsets.new_empty(indices.size(0))
|
|
else:
|
|
offset2bag = offsets.new_empty(0)
|
|
bag_size = offsets.new_empty(num_bags)
|
|
# This part of the logic comes from make_max_indices_out in EmbeddingBag.cpp
|
|
numBags = offsets.shape[0]
|
|
if mode == MODE_MAX:
|
|
if include_last_offset:
|
|
check(
|
|
numBags >= 1,
|
|
lambda: "include_last_offset: numBags should be at least 1",
|
|
)
|
|
numBags -= 1
|
|
max_indices = offsets.new_empty(numBags, weight.shape[1])
|
|
else:
|
|
max_indices = offsets.new_empty(bag_size.size())
|
|
return output, offset2bag, bag_size, max_indices
|
|
|
|
|
|
@register_meta(aten._embedding_bag_forward_only.default)
|
|
def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
|
|
output, offset2bag, bag_size, max_indices = meta_embedding_bag(
|
|
weight, indices, offsets, *args
|
|
)
|
|
if device_hint(offsets) == "cpu":
|
|
bag_size = offsets.new_empty(offsets.size())
|
|
return output, offset2bag, bag_size, max_indices
|
|
|
|
|
|
def _get_reduction_dtype(input, dtype, promote_int_to_long=True):
|
|
# if specified, dtype takes precedence
|
|
if dtype:
|
|
return dtype
|
|
|
|
if input.dtype.is_floating_point or input.dtype.is_complex:
|
|
return input.dtype
|
|
elif promote_int_to_long:
|
|
return torch.long
|
|
|
|
return input.dtype
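# E.g. an int32 input with dtype=None promotes to torch.long (sum-like
# reductions accumulate in int64), while a float16 input keeps torch.float16.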
|
|
|
|
|
|
@register_meta([aten.nansum.default, aten.nansum.out])
|
|
@out_wrapper()
|
|
def meta_nansum(input, dims=None, keepdim=False, *, dtype=None):
|
|
output_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True)
|
|
dims = utils.reduction_dims(input.shape, dims)
|
|
output_shape = _compute_reduction_shape(input, dims, keepdim)
|
|
return input.new_empty(output_shape, dtype=output_dtype)
|
|
|
|
|
|
@register_meta(aten.nanmedian.default)
|
|
def meta_nanmedian(input):
|
|
output_shape = utils.compute_reduction_output_shape(
|
|
input.shape, tuple(range(input.dim()))
|
|
)
|
|
return input.new_empty(output_shape)
|
|
|
|
|
|
@register_meta([aten.nanmedian.dim, aten.nanmedian.dim_values])
|
|
@out_wrapper("values", "indices")
|
|
def meta_nanmedian_dim(input, dim=-1, keepdim=False):
|
|
dim = utils.reduction_dims(input.shape, (dim,))
|
|
output_shape = _compute_reduction_shape(input, dim, keepdim)
|
|
return (
|
|
input.new_empty(output_shape),
|
|
input.new_empty(output_shape, dtype=torch.long),
|
|
)
|
|
|
|
|
|
@register_meta(aten.logical_not_.default)
|
|
def meta_logical_not_(self):
|
|
return self
|
|
|
|
|
|
@register_meta(aten.repeat.default)
|
|
def meta_repeat(self, repeats):
|
|
check(
|
|
len(repeats) >= self.dim(),
|
|
lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
|
|
)
|
|
# Add new leading dimensions to the tensor if the
|
|
# number of target dimensions is larger than the
|
|
# number of source dimensions.
|
|
num_new_dimensions = len(repeats) - self.dim()
|
|
padded_size = (1,) * num_new_dimensions + tuple(self.shape)
|
|
target_size = [padded_size[i] * repeats[i] for i in range(len(repeats))]
|
|
return self.new_empty(target_size)
|
|
|
|
|
|
@register_meta(aten.zero_.default)
|
|
def meta_zero_(self):
|
|
return self
|
|
|
|
|
|
@register_meta(
|
|
[
|
|
aten.mul_.Scalar,
|
|
aten.div_.Scalar,
|
|
aten.mul_.Tensor,
|
|
aten.div_.Tensor,
|
|
aten.logical_and_.default,
|
|
aten.logical_or_.default,
|
|
aten.logical_xor_.default,
|
|
],
|
|
)
|
|
def meta_binop_inplace(self, other):
|
|
return self
|
|
|
|
|
|
@register_meta(
|
|
[
|
|
aten.add_.Scalar,
|
|
aten.sub_.Scalar,
|
|
aten.add_.Tensor,
|
|
aten.sub_.Tensor,
|
|
],
|
|
)
|
|
def meta_binop_inplace_alpha(self, other, alpha=1):
|
|
return self
|
|
|
|
|
|
@register_meta([aten.round.default, aten.round.decimals])
|
|
def meta_round(self, **kwargs):
|
|
return _elementwise_meta(
|
|
self, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
|
|
)
|
|
|
|
|
|
@register_meta(aten.zero.default)
|
|
def meta_zero(self):
|
|
return self.new_empty(self.shape)
|
|
|
|
|
|
@register_meta([aten.fill_.Tensor, aten.fill_.Scalar])
|
|
def meta_fill_(self, val):
|
|
return self
|
|
|
|
|
|
@register_meta([aten.fill.Tensor, aten.fill.Scalar])
|
|
def meta_fill(self, val):
|
|
return torch.empty_like(self)
|
|
|
|
|
|
@register_meta(aten.relu_.default)
|
|
def meta_relu_(self):
|
|
return self
|
|
|
|
|
|
@register_meta(aten.index_put.default)
|
|
def meta_index_put(self, indices, values, accumulate=False):
|
|
return torch.empty_like(self)
|
|
|
|
|
|
@register_meta(aten.masked_fill_.Scalar)
|
|
def meta_masked_fill_(self, mask, value):
|
|
return self
|
|
|
|
|
|
@register_meta(aten.index_put_.default)
|
|
def meta_index_put_(self, indices, values, accumulate=False):
|
|
return self
|
|
|
|
|
|
@register_meta(aten.alias.default)
|
|
def meta_alias(self):
|
|
return self.view(self.shape)
|
|
|
|
|
|
def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
|
|
check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
|
|
check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
|
|
|
|
batch1_sizes = batch1.size()
|
|
batch2_sizes = batch2.size()
|
|
|
|
bs = batch1_sizes[0]
|
|
contraction_size = batch1_sizes[2]
|
|
res_rows = batch1_sizes[1]
|
|
res_cols = batch2_sizes[2]
|
|
output_size = (bs, res_rows, res_cols)
|
|
|
|
check(
|
|
batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
|
|
lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
|
|
f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
|
|
)
|
|
|
|
# TODO: handle out
|
|
|
|
output = batch2.new_empty(output_size)
|
|
|
|
if not is_bmm and self_baddbmm is not None:
|
|
check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
|
|
check(
|
|
self_baddbmm.size() == output_size,
|
|
lambda: "Expected an input tensor shape with shape {output_size} but got shape: {self.size()}",
|
|
)
|
|
|
|
return output
|
|
|
|
|
|
@register_meta(aten.bmm.default)
|
|
def meta_bmm(self, mat2):
|
|
return common_meta_baddbmm_bmm(self, mat2, True)
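# Shape rule illustrated: bmm of a (10, 3, 4) tensor with a (10, 4, 5) tensor
# returns a (10, 3, 5) meta tensor; a mismatched batch or contraction size
# trips the checks above.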
|
|
|
|
|
|
def div_rtn(x, y):
|
|
q = x // y
|
|
r = x % y
|
|
# WARNING: explicit bool conversion here is necessary;
|
|
# would be fixed by SymBool
|
|
if r != 0 and (bool(r < 0) != bool(y < 0)):
|
|
q -= 1
|
|
return q
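# div_rtn rounds the quotient toward negative infinity, e.g. div_rtn(7, 2) == 3
# and div_rtn(-1, 2) == -1; the explicit correction only matters when the
# division truncates toward zero (it is a no-op for plain Python ints).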
|
|
|
|
|
|
def pooling_output_shape_pad_lr(
|
|
inputSize, kernelSize, pad_l, pad_r, stride, dilation, ceil_mode
|
|
):
|
|
outputSize = (
|
|
div_rtn(
|
|
inputSize
|
|
+ pad_l
|
|
+ pad_r
|
|
- dilation * (kernelSize - 1)
|
|
- 1
|
|
+ (stride - 1 if ceil_mode else 0),
|
|
stride,
|
|
)
|
|
+ 1
|
|
)
|
|
if ceil_mode:
|
|
if (outputSize - 1) * stride >= inputSize + pad_l:
|
|
outputSize -= 1
|
|
return outputSize
|
|
|
|
|
|
def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
|
|
check(stride != 0, lambda: "stride should not be zero")
|
|
check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
|
|
check(
|
|
pad <= kernelSize // 2,
|
|
lambda: f"pad should be at most half of kernel size, but got pad={pad} and kernel_size={kernelSize}",
|
|
)
|
|
return pooling_output_shape_pad_lr(
|
|
inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode
|
|
)
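# Worked example: inputSize=7, kernelSize=3, pad=1, stride=2, dilation=1,
# ceil_mode=False gives div_rtn(7 + 1 + 1 - 1*(3 - 1) - 1, 2) + 1
# = div_rtn(6, 2) + 1 = 4 output elements along that dimension.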
|
|
|
|
|
|
def pool2d_shape_check(
|
|
input,
|
|
kH,
|
|
kW,
|
|
dH,
|
|
dW,
|
|
padH,
|
|
padW,
|
|
dilationH,
|
|
dilationW,
|
|
nInputPlane,
|
|
inputHeight,
|
|
inputWidth,
|
|
outputHeight,
|
|
outputWidth,
|
|
memory_format,
|
|
):
|
|
ndim = input.dim()
|
|
nOutputPlane = nInputPlane
|
|
|
|
check(
|
|
kW > 0 and kH > 0,
|
|
lambda: "kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
|
|
)
|
|
check(
|
|
dW > 0 and dH > 0,
|
|
lambda: "stride should be greater than zero, but got dH: {dH}, dW: {dW}",
|
|
)
|
|
check(
|
|
dilationH > 0 and dilationW > 0,
|
|
lambda: "dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
|
|
)
|
|
|
|
valid_dims = input.size(1) != 0 and input.size(2) != 0
|
|
|
|
if memory_format == torch.channels_last:
|
|
check(
|
|
ndim == 4 and valid_dims and input.size(3) != 0,
|
|
lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
|
|
" with optional 0 dim batch size for input, but got: {input.size()}",
|
|
)
|
|
else:
|
|
check(
|
|
(ndim == 3 and input.size(0) != 0 and valid_dims)
|
|
or (ndim == 4 and valid_dims and input.size(3) != 0),
|
|
lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
|
|
)
|
|
|
|
check(
|
|
kW // 2 >= padW and kH // 2 >= padH,
|
|
lambda: "pad should be smaller than or equal to half of kernel size, but got "
|
|
f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
|
|
)
|
|
|
|
check(
|
|
outputWidth >= 1 and outputHeight >= 1,
|
|
lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
|
|
f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
|
|
"Output size is too small",
|
|
)
|
|
|
|
|
|
def max_pool2d_checks_and_compute_shape(
|
|
input, kernel_size, stride, padding, dilation, ceil_mode
|
|
):
|
|
# Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp
|
|
def unpack(name, val):
|
|
check(
|
|
len(val) in [1, 2],
|
|
lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
|
|
)
|
|
H = val[0]
|
|
W = H if len(val) == 1 else val[1]
|
|
return H, W
|
|
|
|
kH, kW = unpack("kernel_size", kernel_size)
|
|
|
|
check(
|
|
len(stride) in [0, 1, 2],
|
|
lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
|
|
)
|
|
if len(stride) == 0:
|
|
dH, dW = kH, kW
|
|
else:
|
|
dH, dW = unpack("stride", stride)
|
|
|
|
padH, padW = unpack("padding", padding)
|
|
dilationH, dilationW = unpack("dilation", dilation)
|
|
nInputPlane = input.size(-3)
|
|
inputHeight = input.size(-2)
|
|
inputWidth = input.size(-1)
|
|
|
|
memory_format = utils.suggest_memory_format(input)
|
|
if memory_format == torch.channels_last:
|
|
check(
|
|
input.dim() == 4,
|
|
lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
|
|
)
|
|
elif memory_format == torch.contiguous_format:
|
|
check(
|
|
input.dim() in [3, 4],
|
|
lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
|
|
)
|
|
else:
|
|
check(
|
|
False,
|
|
lambda: "Unsupport memory format. Supports only ChannelsLast, Contiguous",
|
|
)
|
|
|
|
outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
|
|
outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
|
|
|
|
pool2d_shape_check(
|
|
input,
|
|
kH,
|
|
kW,
|
|
dH,
|
|
dW,
|
|
padH,
|
|
padW,
|
|
dilationH,
|
|
dilationW,
|
|
nInputPlane,
|
|
inputHeight,
|
|
inputWidth,
|
|
outputHeight,
|
|
outputWidth,
|
|
memory_format,
|
|
)
|
|
|
|
return nInputPlane, outputHeight, outputWidth
|
|
|
|
|
|
@register_meta(aten.max_pool2d_with_indices_backward.default)
|
|
def meta_max_pool2d_with_indices_backward(
|
|
grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices
|
|
):
|
|
nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape(
|
|
self, kernel_size, stride, padding, dilation, ceil_mode
|
|
)
|
|
|
|
check(
|
|
self.dtype == grad_output.dtype,
|
|
lambda: "expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}",
|
|
)
|
|
|
|
nOutputPlane = nInputPlane
|
|
ndim = self.ndim
|
|
|
|
def _check_dim_size(t):
|
|
check_dim_size(t, ndim, ndim - 3, nOutputPlane)
|
|
check_dim_size(t, ndim, ndim - 2, outputHeight)
|
|
check_dim_size(t, ndim, ndim - 1, outputWidth)
|
|
|
|
_check_dim_size(grad_output)
|
|
_check_dim_size(indices)
|
|
|
|
memory_format = utils.suggest_memory_format(self)
|
|
return torch.empty(
|
|
self.shape, dtype=self.dtype, device=self.device, memory_format=memory_format
|
|
)
|
|
|
|
|
|
@register_meta(aten.max_pool2d_with_indices.default)
|
|
def meta_max_pool2d_with_indices(
|
|
input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False
|
|
):
|
|
nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape(
|
|
input, kernel_size, stride, padding, dilation, ceil_mode
|
|
)
|
|
|
|
nbatch = input.size(-4) if input.dim() == 4 else 1
|
|
memory_format = utils.suggest_memory_format(input)
|
|
if input.dim() == 3:
|
|
size = [nInputPlane, outputHeight, outputWidth]
|
|
else:
|
|
size = [nbatch, nInputPlane, outputHeight, outputWidth]
|
|
return (
|
|
torch.empty(
|
|
size, dtype=input.dtype, device=input.device, memory_format=memory_format
|
|
),
|
|
torch.empty(
|
|
size, dtype=torch.int64, device=input.device, memory_format=memory_format
|
|
),
|
|
)
|
|
|
|
|
|
@register_meta(aten.grid_sampler_2d_backward.default)
|
|
def grid_sampler_2d_backward_meta(
|
|
grad_output,
|
|
input,
|
|
grid,
|
|
interpolation_mode,
|
|
padding_mode,
|
|
align_corners,
|
|
output_mask,
|
|
):
|
|
input_requires_grad = output_mask[0]
|
|
if input_requires_grad:
|
|
grad_input = torch.zeros_like(input, memory_format=torch.contiguous_format)
|
|
else:
|
|
grad_input = None
|
|
grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format)
|
|
return (grad_input, grad_grid)
|
|
|
|
|
|
@register_meta([aten.full.default])
|
|
def full(size, fill_value, *args, **kwargs):
|
|
return torch.empty(size, *args, **kwargs)
|
|
|
|
|
|
@register_meta(
|
|
[
|
|
aten.randint_like.default,
|
|
aten.randint_like.low_dtype,
|
|
aten.randn_like.default,
|
|
aten.rand_like.default,
|
|
aten.full_like.default,
|
|
aten.ones_like.default,
|
|
]
|
|
)
|
|
def meta_like(self, *args, **kwargs):
|
|
return aten.empty_like.default(self, **kwargs)
|
|
|
|
|
|
# zeros_like is special cased to work for sparse
|
|
@register_meta(aten.zeros_like.default)
|
|
def zeros_like(
|
|
self, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
|
|
):
|
|
if layout == torch.sparse_coo:
|
|
check(
|
|
memory_format is None,
|
|
lambda: "memory format option is only supported by strided tensors",
|
|
)
|
|
|
|
res = torch.empty(
|
|
0,
|
|
dtype=self.dtype if dtype is None else dtype,
|
|
layout=layout,
|
|
device=self.device if device is None else device,
|
|
pin_memory=pin_memory,
|
|
)
|
|
|
|
if self.is_sparse:
|
|
res.sparse_resize_and_clear_(
|
|
self.size(), self.sparse_dim(), self.dense_dim()
|
|
)
|
|
else:
|
|
res.sparse_resize_and_clear_(self.size(), self.dim(), 0)
|
|
|
|
res._coalesced_(True)
|
|
return res
|
|
return aten.empty_like.default(
|
|
self,
|
|
dtype=dtype,
|
|
layout=layout,
|
|
device=device,
|
|
pin_memory=pin_memory,
|
|
memory_format=memory_format,
|
|
)
|
|
|
|
|
|
@register_meta(aten.select.int)
|
|
def meta_select(self, dim, index):
|
|
ndim = self.dim()
|
|
check(
|
|
ndim != 0, lambda: "select() cannot be applied to a 0-dim tensor.", IndexError
|
|
)
|
|
|
|
dim = dim if dim >= 0 else dim + ndim
|
|
size = self.size(dim)
|
|
|
|
check(
|
|
not (-index > size or index >= size),
|
|
lambda: f"select(): index {index} out of range for tensor of size "
|
|
f"{self.size()} at dimension {dim}",
|
|
IndexError,
|
|
)
|
|
|
|
index = index if index >= 0 else index + size
|
|
|
|
new_size = list(self.size())
|
|
new_stride = list(self.stride())
|
|
|
|
new_storage_offset = self.storage_offset() + index * new_stride[dim]
|
|
del new_size[dim]
|
|
del new_stride[dim]
|
|
|
|
return self.as_strided(new_size, new_stride, new_storage_offset)
|
|
|
|
|
|
@register_meta(aten.select_scatter.default)
|
|
def meta_select_scatter(self, src, dim, index):
|
|
return utils.clone_preserve_strides(self)
|
|
|
|
|
|
@register_meta(aten.slice_scatter.default)
|
|
def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1):
|
|
return utils.clone_preserve_strides(self)
|
|
|
|
|
|
# TODO: Deduplicate this with canonicalize_dim
|
|
def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
|
|
if dim_post_expr <= 0:
|
|
assert wrap_scalar
|
|
dim_post_expr = 1
|
|
min = -dim_post_expr
|
|
max = dim_post_expr - 1
|
|
assert not (dim < min or dim > max), f"dim {dim} out of bounds ({min}, {max})"
|
|
if dim < 0:
|
|
dim += dim_post_expr
|
|
return dim
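# E.g. maybe_wrap_dim(-1, 4) == 3 and maybe_wrap_dim(2, 4) == 2; a dim outside
# [-4, 3] fails the assertion.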
|
|
|
|
|
|
def ensure_nonempty_size(t, dim):
|
|
return 1 if t.dim() == 0 else t.shape[dim]
|
|
|
|
|
|
# From aten/src/ATen/native/ScatterGatherChecks.h
|
|
def gather_shape_check(self, dim, index):
|
|
self_dims = max(self.dim(), 1)
|
|
index_dims = max(index.dim(), 1)
|
|
check(
|
|
self_dims == index_dims,
|
|
lambda: "Index tensor must have the same number of dimensions as input tensor",
|
|
)
|
|
for i in range(self_dims):
|
|
if i != dim:
|
|
check(
|
|
ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
|
|
lambda: f"Size does not match at dimension {i} expected index {index.shape}"
|
|
+ f" to be smaller than self {self.shape} apart from dimension {dim}",
|
|
)
|
|
|
|
|
|
@register_meta(aten.gather.default)
|
|
def meta_gather(self, dim, index, sparse_grad=False):
|
|
wrapped_dim = maybe_wrap_dim(dim, self.dim())
|
|
is_index_empty = index.numel() == 0
|
|
if not is_index_empty:
|
|
check(
|
|
index.dtype == torch.long,
|
|
lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
|
|
)
|
|
gather_shape_check(self, wrapped_dim, index)
|
|
return self.new_empty(index.shape)
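# Note that gather's output always takes the shape of `index`: gathering along
# dim=1 of a (4, 6) tensor with a (4, 2) int64 index yields a (4, 2) result.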
|
|
|
|
|
|
# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
|
|
def get_operator_enum(reduce_, use_new_options=False):
|
|
if use_new_options:
|
|
if reduce_ == "sum":
|
|
return "REDUCE_ADD"
|
|
elif reduce_ == "prod":
|
|
return "REDUCE_MULTIPLY"
|
|
elif reduce_ == "mean":
|
|
return "REDUCE_MEAN"
|
|
elif reduce_ == "amax":
|
|
return "REDUCE_MAXIMUM"
|
|
elif reduce_ == "amin":
|
|
return "REDUCE_MINIMUM"
|
|
check(
|
|
False,
|
|
lambda: "reduce argument must be either sum, prod, mean, amax or amin.",
|
|
)
|
|
return
|
|
else:
|
|
if reduce_ == "add":
|
|
return "REDUCE_ADD"
|
|
elif reduce_ == "multiply":
|
|
return "REDUCE_MULTIPLY"
|
|
check(False, lambda: "reduce argument must be either add or multiply.")
|
|
return
|
|
|
|
|
|
# From aten/src/ATen/native/ScatterGatherChecks.h
|
|
def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
|
|
if index.numel() != 0:
|
|
check(
|
|
index.dtype == torch.long,
|
|
lambda: f"{method_name}(): Expected dtype int64 for index",
|
|
)
|
|
|
|
if src_opt is not None:
|
|
check(
|
|
self.dtype == src_opt.dtype,
|
|
lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype",
|
|
)
|
|
|
|
|
|
def ensure_nonempty_dim(dim):
|
|
return max(dim, 1)
|
|
|
|
|
|
# From aten/src/ATen/native/ScatterGatherChecks.h
|
|
def scatter_shape_check(self, dim, index, src_opt=None):
|
|
if index.numel() == 0:
|
|
return
|
|
check(
|
|
ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
|
|
lambda: "Index tensor must have the same number of dimensions as self tensor",
|
|
)
|
|
|
|
is_wrong_shape = False
|
|
self_dims = ensure_nonempty_dim(self.dim())
|
|
|
|
# Check: index.size(d) <= self.size(d) for all d != dim
|
|
for d in range(self_dims):
|
|
index_d_size = ensure_nonempty_size(index, d)
|
|
if d == dim:
|
|
continue
|
|
if index_d_size > ensure_nonempty_size(self, d):
|
|
is_wrong_shape = True
|
|
break
|
|
|
|
# Check: index.size(d) <= src.size(d) for all d if src is Tensor
|
|
if not is_wrong_shape and src_opt is not None:
|
|
for d in range(self_dims):
|
|
index_d_size = ensure_nonempty_size(index, d)
|
|
if index_d_size > ensure_nonempty_size(src_opt, d):
|
|
is_wrong_shape = True
|
|
break
|
|
|
|
if src_opt is not None:
|
|
check(
|
|
ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
|
|
lambda: "Index tensor must have the same number of dimensions as self tensor",
|
|
)
|
|
check(
|
|
not is_wrong_shape,
|
|
lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
|
|
+ f" apart from dimension {dim} and to be smaller than src {src_opt.shape}",
|
|
)
|
|
else:
|
|
check(
|
|
not is_wrong_shape,
|
|
lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
|
|
+ f" apart from dimension {dim}",
|
|
)
|
|
|
|
|
|
# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
|
|
def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options=False):
|
|
wrapped_dim = maybe_wrap_dim(dim, self.dim())
|
|
scatter_gather_dtype_check("scatter", self, index, src)
|
|
scatter_shape_check(self, wrapped_dim, index, src)
|
|
if reduce_ is not None:
|
|
# Check if we have a valid reduce operator.
|
|
get_operator_enum(reduce_, use_new_options)
|
|
|
|
|
|
@register_meta(aten.scatter_add.default)
|
|
def meta_scatter_add(self, dim, index, src):
|
|
scatter_meta_impl(self, dim, index, src, "add")
|
|
return self.new_empty(self.shape)
|
|
|
|
|
|
@register_meta(aten.scatter_add_)
|
|
def meta_scatter_add_(self, dim, index, src):
|
|
scatter_meta_impl(self, dim, index, src, "add")
|
|
return self
|
|
|
|
|
|
@register_meta(
|
|
[
|
|
aten.scatter.src,
|
|
aten.scatter.value,
|
|
aten.scatter.reduce,
|
|
aten.scatter.value_reduce,
|
|
]
|
|
)
|
|
@out_wrapper()
|
|
def meta_scatter(self, dim, index, src_or_value, reduce=None):
|
|
src = src_or_value if isinstance(src_or_value, torch.Tensor) else None
|
|
scatter_meta_impl(self, dim, index, src, reduce)
|
|
return self.new_empty(self.shape)
|
|
|
|
|
|
@register_meta(
|
|
[
|
|
aten.scatter_.src,
|
|
aten.scatter_.value,
|
|
aten.scatter_.reduce,
|
|
aten.scatter_.value_reduce,
|
|
]
|
|
)
|
|
def meta_scatter_(self, dim, index, src_or_value, reduce=None):
|
|
src = src_or_value if isinstance(src_or_value, torch.Tensor) else None
|
|
scatter_meta_impl(self, dim, index, src, reduce)
|
|
return self
|
|
|
|
|
|
@register_meta(
    [
        aten._scaled_dot_product_flash_attention,
    ]
)
def meta__scaled_dot_product_flash(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
):
    # [Note] SDPA_flash's meta function returns incorrect Philox seed and offset:
    # We have added logic to torch/_dynamo/variables/torch.py
    # We need to check if scaled_dot_product_attention will run the flash attention
    # kernel and if dropout is != 0.0. If that is the case then we want dynamo
    # to graph break. The derivative calculation for _scaled_dot_product_flash_attention
    # does not function correctly with cuda graphs because the full philox state is not captured
    # in the forward's return values. Another reason to graph break is that the meta function
    # returns the wrong outputs for philox seed and offset and these values get baked into the
    # inductor fallback calls to the eager kernels.
    check(
        dropout_p == 0.0,
        lambda: f"Can only trace _scaled_dot_product_flash_attention when dropout is set to 0 but got a dropout_p of {dropout_p}.",
    )
    batch_size = query.size(0)
    num_heads = query.size(1)
    max_seqlen_batch_q = query.size(2)
    head_dim = query.size(3)

    max_seqlen_batch_k = key.size(2)

    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    Nnz_q = batch_size * max_seqlen_batch_q

    output = torch.empty(
        (Nnz_q, num_heads, head_dim), dtype=query.dtype, device=query.device
    )
    output = output.view(batch_size, max_seqlen_batch_q, num_heads, head_dim).transpose(
        1, 2
    )
    max_seqlen_q = math.ceil(max_seqlen_batch_q / 16) * 16
    logsumexp = torch.empty(
        (batch_size, num_heads, max_seqlen_q),
        dtype=torch.float,
        device=query.device,
    )
    cumulative_sequence_length_q = torch.empty(
        batch_size + 1, dtype=torch.int32, device="meta"
    )
    cumulative_sequence_length_k = torch.empty(
        batch_size + 1, dtype=torch.int32, device="meta"
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = torch.empty(
            (batch_size, num_heads, max_seqlen_q, max_seqlen_k),
            dtype=query.dtype,
            device=query.device,
        )
    else:
        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)

    return (
        output,
        logsumexp,
        cumulative_sequence_length_q,
        cumulative_sequence_length_k,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        1,  # Philox Seed will not be used, see note at top.
        1,  # Philox Offset will not be used, see note at top.
        debug_mask,
    )


@register_meta(
    [
        aten._scaled_dot_product_flash_attention_backward,
    ]
)
def meta__scaled_dot_product_flash_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    cum_seq_q: Tensor,
    cum_seq_k: Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: int,
    philox_offset: int,
):
    batch_size = query.size(0)
    num_heads = query.size(1)
    head_dim = query.size(3)

    Nnz_q = batch_size * max_q
    Nnz_kv = batch_size * max_k

    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    query_reshaped = query.reshape(Nnz_q, num_heads, head_dim)
    key_reshaped = key.reshape(Nnz_kv, num_heads, head_dim)
    value_reshaped = value.reshape(Nnz_kv, num_heads, head_dim)

    grad_q = torch.empty_like(query_reshaped)
    grad_k = torch.empty_like(key_reshaped)
    grad_v = torch.empty_like(value_reshaped)

    grad_q = grad_q.view(batch_size, max_q, num_heads, head_dim).transpose(1, 2)
    grad_k = grad_k.view(batch_size, max_k, num_heads, head_dim).transpose(1, 2)
    grad_v = grad_v.view(batch_size, max_k, num_heads, head_dim).transpose(1, 2)

    return grad_q, grad_k, grad_v


@register_meta(
    [
        aten._scaled_dot_product_efficient_attention,
    ]
)
def meta__scaled_dot_product_efficient(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    compute_log_sumexp: bool,
    is_causal: bool = False,
):
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    K = query.size(-1)
    Kv = value.size(-1)

    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)

    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
    logsum_exp = torch.empty(
        (B, num_heads, logsumexp_dim),
        dtype=torch.float,
        device=query.device,
    )

    res = res.transpose(1, 2)

    return res, logsum_exp


@register_meta(
    [
        aten._scaled_dot_product_efficient_attention_backward,
    ]
)
def meta__scaled_dot_product_efficient_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    is_causal: bool = False,
    chunk_grad_outputs=False,
):
    grad_out = grad_out.transpose(1, 2)
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    nH = query.size(2)
    K = query.size(3)

    grad_kv_needs_init = is_causal and N > M

    if chunk_grad_outputs:
        chunk = torch.empty((B, M, 3, nH, K), dtype=query.dtype, device=query.device)
        grad_q = chunk.select(2, 0)
        grad_k = chunk.select(2, 1)
        grad_v = chunk.select(2, 2)
    else:
        grad_q = torch.empty(query.shape, dtype=query.dtype, device=query.device)
        grad_k = (
            torch.zeros(key.shape, dtype=key.dtype, device=key.device)
            if grad_kv_needs_init
            else torch.empty(key.shape, dtype=key.dtype, device=key.device)
        )
        grad_v = (
            torch.zeros(value.shape, dtype=value.dtype, device=value.device)
            if grad_kv_needs_init
            else torch.empty(value.shape, dtype=value.dtype, device=value.device)
        )
    return grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2)


@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
@out_wrapper()
def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
    scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
    return self.new_empty(self.shape)


@register_meta(aten.scatter_reduce_.two)
def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True):
    scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
    return self


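# Product of an iterable of integers; returns 1 for an empty iterable.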
def multiply_integers(vs):
    r = 1
    for v in vs:
        r *= v
    return r


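# Validates output_size/input_size ranks and returns the full output shape
# (N, C, *output_size), e.g. ((2, 3, 8, 8), (16, 16), 2) -> (2, 3, 16, 16).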
def upsample_common_check(input_size, output_size, num_spatial_dims):
    check(
        len(output_size) == num_spatial_dims,
        lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}",
    )
    expected_input_dims = num_spatial_dims + 2  # N, C, ...
    check(
        len(input_size) == expected_input_dims,
        lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}",
    )

    check(
        all([s > 0 for s in input_size[2:]]) and all([s > 0 for s in output_size]),
        lambda: f"Input and output sizes should be greater than 0, but got "
        f"input size {input_size} and output size {output_size}",
    )

    nbatch, channels = input_size[:2]
    return (nbatch, channels, *output_size)


@register_meta(aten.upsample_nearest1d.default)
def upsample_nearest1d(input, output_size, scales=None):
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=1
    )
    return input.new_empty(full_output_size).to(
        memory_format=utils.suggest_memory_format(input)
    )


@register_meta(aten.upsample_nearest2d.default)
def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=2
    )
    output = input.new_empty(full_output_size)

    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(input)

    # following "heuristic: only use channels_last path when it's faster than the contiguous path"
    _, n_channels, _, _ = input.shape
    if input.device.type == "cuda" and n_channels < 4:
        memory_format = torch.contiguous_format

    output = output.contiguous(memory_format=memory_format)

    return output


@register_meta(aten.upsample_nearest3d.default)
def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None):
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=3
    )
    return input.new_empty(full_output_size).to(
        memory_format=utils.suggest_memory_format(input)
    )


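# sort returns (values, indices); indices are always int64.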
@register_meta([aten.sort.default, aten.sort.stable])
def meta_sort(self, stable=None, dim=-1, descending=False):
    return torch.empty_like(self), torch.empty_like(self, dtype=torch.int64)


def rnn_cell_checkSizes(
    input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden
):
    check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
    check(
        input_gates.shape == hidden_gates.shape,
        lambda: f"{input_gates.shape} != {hidden_gates.shape}",
    )
    gates_size = input_gates.size(1)
    if input_bias is not None:
        check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
        check(
            input_bias.numel() == gates_size,
            lambda: f"{input_bias.numel()} != {gates_size}",
        )
        check(
            input_bias.shape == hidden_bias.shape,
            lambda: f"{input_bias.shape} != {hidden_bias.shape}",
        )
    check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
    expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
    check(
        prev_hidden.numel() == expected_prev_hidden_numel,
        lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
    )
    check(
        all(
            x.device == input_gates.device
            for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
        ),
        lambda: "expected all inputs to be same device",
    )


@register_meta(aten._thnn_fused_lstm_cell.default)
def _thnn_fused_lstm_cell_meta(
    input_gates, hidden_gates, cx, input_bias=None, hidden_bias=None
):
    rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx)
    workspace = torch.empty_like(input_gates, memory_format=torch.contiguous_format)
    hy = torch.empty_like(cx, memory_format=torch.contiguous_format)
    cy = torch.empty_like(cx, memory_format=torch.contiguous_format)
    return (hy, cy, workspace)


@register_meta(aten._cudnn_rnn.default)
def _cudnn_rnn(
    input,
    weight,
    weight_stride0,
    weight_buf,
    hx,
    cx,
    mode,
    hidden_size,
    proj_size,
    num_layers,
    batch_first,
    dropout,
    train,
    bidirectional,
    batch_sizes,
    dropout_state,
):

    is_input_packed = len(batch_sizes) != 0
    if is_input_packed:
        seq_length = len(batch_sizes)
        mini_batch = batch_sizes[0]
        batch_sizes_sum = input.shape[0]
    else:
        seq_length = input.shape[1] if batch_first else input.shape[0]
        mini_batch = input.shape[0] if batch_first else input.shape[1]
        batch_sizes_sum = -1

    num_directions = 2 if bidirectional else 1
    out_size = proj_size if proj_size != 0 else hidden_size
    if is_input_packed:
        out_shape = [batch_sizes_sum, out_size * num_directions]
    else:
        out_shape = (
            [mini_batch, seq_length, out_size * num_directions]
            if batch_first
            else [seq_length, mini_batch, out_size * num_directions]
        )
    output = input.new_empty(out_shape)

    cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
    if cx is None:
        cy = torch.empty(0, device=input.device)
    else:
        cy = cx.new_empty(cell_shape)

    hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])

    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
    reserve_shape = 0 if train else 0
    reserve = input.new_empty(reserve_shape, dtype=torch.uint8)

    return output, hy, cy, reserve, weight_buf


@register_meta(aten.mkldnn_rnn_layer.default)
def mkldnn_rnn_layer(
    input,
    w0,
    w1,
    w2,
    w3,
    hx_,
    cx_,
    reverse,
    batch_sizes,
    mode,
    hidden_size,
    num_layers,
    has_biases,
    bidirectional,
    batch_first,
    train,
):
    seq_length = input.shape[1] if batch_first else input.shape[0]
    mini_batch = input.shape[0] if batch_first else input.shape[1]
    output_channels = hidden_size
    out_shape = (
        [mini_batch, seq_length, output_channels]
        if batch_first
        else [seq_length, mini_batch, output_channels]
    )
    output = input.new_empty(out_shape)
    if hx_ is None:
        hy = torch.empty(0, device=input.device)
    else:
        hy = hx_.new_empty(hx_.shape)
    if cx_ is None:
        cy = torch.empty(0, device=input.device)
    else:
        cy = cx_.new_empty(cx_.shape)
    workspace = torch.empty(0, device=input.device, dtype=torch.uint8)
    return output, hy, cy, workspace


def zero_numel_check_dims(self, dim, fn_name):
    if self.ndim == 0:
        check(
            dim == 0 or dim == -1,
            lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}",
            IndexError,
        )
    else:
        check(
            self.size(dim) != 0,
            lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.",
            IndexError,
        )


# From aten/src/ATen/native/ReduceOps.cpp
def check_argmax_argmin(name, self, dim):
    if dim is not None:
        dim = maybe_wrap_dim(dim, self.dim())
        zero_numel_check_dims(self, dim, name)
    else:
        check(
            self.numel() != 0,
            lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.",
        )


@register_meta([aten.argmax.default, aten.argmin.default])
def argmax_argmin_meta(self, dim=None, keepdim=False):
    check_argmax_argmin("argmax", self, dim)
    dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None)
    shape = _compute_reduction_shape(self, dims, keepdim)
    return self.new_empty(shape, dtype=torch.int64)


@register_meta(aten.scalar_tensor.default)
def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None):
    return torch.empty(
        (), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


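# topk returns (values, indices); both have self's shape with size k along dim.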
@register_meta(aten.topk.default)
def topk_meta(self, k, dim=-1, largest=True, sorted=True):
    # From aten/src/ATen/native/Sorting.cpp
    dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
    check(
        k >= 0 and k <= (self.size(dim) if self.dim() > 0 else 1),
        lambda: "selected index k out of range",
    )
    sliceSize = 1 if self.dim() == 0 else self.size(dim)
    check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension")

    topKSize = list(self.shape)
    if len(topKSize) > 0:
        topKSize[dim] = k
    return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64)


legacy_contiguous_memory_format = torch.contiguous_format


# From aten/src/ATen/native/cuda/RNN.cu
def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
    defined_grad = grad_hy if grad_hy is not None else grad_cy
    check(defined_grad.dim() == 2, lambda: "")
    exp_size = defined_grad.size()
    if grad_hy is not None:
        check(grad_hy.size() == exp_size, lambda: "")
    if grad_cy is not None:
        check(grad_cy.size() == exp_size, lambda: "")
    check(cx.size() == exp_size, lambda: "")
    check(cy.size() == exp_size, lambda: "")
    check(workspace.dim() == 2, lambda: "")
    check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "")


# From aten/src/ATen/native/cuda/RNN.cu
@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default)
def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias):
    if grad_hy is None and grad_cy is None:
        return None, None, None
    checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace)
    grad_gates = torch.empty_like(
        workspace, memory_format=legacy_contiguous_memory_format
    )
    grad_cx = torch.empty_like(cx, memory_format=legacy_contiguous_memory_format)
    grad_bias = grad_gates.sum(0, keepdim=False) if has_bias else None
    return grad_gates, grad_cx, grad_bias


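# pixel_shuffle maps (*, C * r**2, H, W) -> (*, C, H * r, W * r) for upscale_factor r.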
@register_meta(aten.pixel_shuffle.default)
def meta_pixel_shuffle(self, upscale_factor):
    assert (
        len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0
    ), f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}"

    def is_channels_last(ten):
        return torch._prims_common.suggest_memory_format(ten) == torch.channels_last

    def pick_memory_format():
        if is_channels_last(self):
            if device_hint(self) == "cuda":
                return torch.contiguous_format
            else:
                return torch.channels_last
        elif self.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        elif self.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format

    C = self.shape[-3] // (upscale_factor * upscale_factor)
    Hr = self.shape[-2] * upscale_factor
    Wr = self.shape[-1] * upscale_factor
    out_shape = (*self.shape[:-3], C, Hr, Wr)

    out = self.new_empty(out_shape)
    out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
    return out


@register_meta(aten.mkldnn_rnn_layer_backward.default)
def mkldnn_rnn_layer_backward(
    input,
    weight0,
    weight1,
    weight2,
    weight3,
    hx_,
    cx_tmp,
    output,
    hy_,
    cy_,
    grad_output_r_opt,
    grad_hy_r_opt,
    grad_cy_r_opt,
    reverse,
    mode,
    hidden_size,
    num_layers,
    has_biases,
    train,
    bidirectional,
    batch_sizes,
    batch_first,
    workspace,
):
    diff_x = input.new_empty(input.shape)
    diff_hx = hx_.new_empty(hx_.shape)
    diff_cx = cx_tmp.new_empty(cx_tmp.shape)
    diff_w1 = weight0.new_empty(weight0.shape)
    diff_w2 = weight1.new_empty(weight1.shape)
    diff_b = weight2.new_empty(weight2.shape)
    return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx


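# bucketize returns, for each element of self, the index of the boundary bucket it
# falls into; the output has the same shape as self (int32 if out_int32 else int64).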
@register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out])
@out_wrapper()
def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
    return torch.empty_like(
        self, dtype=torch.int32 if out_int32 else torch.int64
    ).contiguous()


# We must also trigger meta registrations from PrimTorch ref
# decompositions
import torch._refs
import torch._refs.nn.functional
import torch._refs.special


def activate_meta():

    activate_meta_table = {}

    # For a given op, we pick the most specific decomp function from
    # global_decomp_table in the precedence order of meta > post_autograd > pre_autograd
    for type in ["meta", "post_autograd", "pre_autograd"]:
        registry = global_decomposition_table[type]

        for opo in registry:
            if opo not in activate_meta_table:
                activate_meta_table[opo] = registry[opo]

    for op_overload, fn in activate_meta_table.items():
        assert isinstance(op_overload, OpOverload)

        op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)

        if torch._C._dispatch_has_kernel_for_dispatch_key(
            op_overload.name(), "CompositeImplicitAutograd"
        ):
            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if op_overload in global_decomposition_table["meta"]:
                raise RuntimeError(
                    f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't "
                    "register meta function for it. Instead, we should let the decomposition run and write "
                    "meta kernels for the base operators."
                )
            pass
        elif op_overload.is_view:
            # Attempting to register a python meta kernel for a view operator.
            # We shouldn't do this, because the output will report as not having aliased storages.
            # All view ops have meta kernels in C++ today, so we should use those instead.
            pass
        elif op_overload.name() in {
            "aten::empty_strided",  # causing infinite recursion, test_meta.py
            "aten::clone",  # causing infinite recursion
            "aten::_to_copy",  # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite  # noqa: B950
            "aten::copy_",  # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64  # noqa: B950
            "aten::constant_pad_nd",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32  # noqa: B950
            "aten::rot90",  # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32  # noqa: B950
            "aten::as_strided_scatter",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32  # noqa: B950
        }:
            pass
        else:
            if "mkldnn::" in op_overload.name():
                _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
            elif "mkl::" in op_overload.name():
                _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
            else:
                _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)


activate_meta()