Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Replace _prims_common.check with torch._check* (#103240)
This relands most of the changes from #102219, which were backed out by #103128. However, instead of removing `_prims_common.check`, it adds a warning and a comment mentioning that it will be removed in the future and that `torch._check*` should be used instead. As mentioned in https://github.com/pytorch/pytorch/pull/103128#pullrequestreview-1466414415, `_prims_common.check` cannot yet be removed because of some internal usage.

Part of #72948

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103240
Approved by: https://github.com/albanD
This commit is contained in:
parent f3c3d12efb
commit ee83c646bb
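The change is mechanical: each `_prims_common.check(cond, msg_fn, exc_type)` call site in the hunks below is rewritten to the `torch._check*` helper that raises the corresponding exception type, and the old helper now only warns and delegates. A minimal sketch of the mapping, assuming current `torch` APIs; the tensor and messages here are illustrative, not taken from the diff:

    import torch

    x = torch.randn(2, 3)

    # Deprecated spelling: one helper, exception type passed explicitly.
    #   torch._prims_common.check(x.dim() == 2, lambda: "expected 2D", RuntimeError)

    # Replacement spelling: one torch._check* helper per exception type.
    torch._check(x.dim() == 2, lambda: f"expected 2D, got {x.dim()}D")         # raises RuntimeError
    torch._check_value(x.size(1) == 3, lambda: "bad size")                     # raises ValueError
    torch._check_type(isinstance(x, torch.Tensor), lambda: "bad type")         # raises TypeError
    torch._check_not_implemented(x.layout == torch.strided,
                                 lambda: "unsupported layout")                 # raises NotImplementedError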
@@ -1161,6 +1161,10 @@ $1 = torch._ops.prims.sin.default($0)""")
     def test_mul_complex(self):
         prims.mul(torch.randn(2), 1 + 1j)

+    def test_check_deprecation_warning(self):
+        with self.assertWarnsRegex(DeprecationWarning, 'will be removed in the future'):
+            torch._prims_common.check(True, lambda: 'message')
+

 instantiate_device_type_tests(TestPrims, globals())

@@ -936,7 +936,7 @@ def is_warn_always_enabled():
 # These error checking functions must be kept consistent with their C++
 # equivalents. Their C++ equivalents are mentioned where applicable.

-def _check_with(error_type, cond, message):
+def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):
     if not isinstance(cond, (builtins.bool, torch.SymBool)):
         raise TypeError(f'cond must be a bool, but got {type(cond)}')

@@ -149,7 +149,7 @@ def fill_scalar(self, value):

 @register_decomposition([aten.fill.Tensor])
 def fill_tensor(self, value: Tensor):
-    utils.check(
+    torch._check(
         value.dim() == 0,
         lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions",
     )

@@ -785,14 +785,14 @@ def im2col(
     padding: List[int],
     stride: List[int],
 ) -> Tensor:
-    utils.check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
-    utils.check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
-    utils.check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
-    utils.check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")
+    torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
+    torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
+    torch._check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
+    torch._check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")

     def check_positive(param, param_name, strict=True):
         cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
-        utils.check(
+        torch._check(
             cond, lambda: "{param_name} should be greater than zero, but got {param}"
         )

@@ -803,7 +803,7 @@ def im2col(

     shape = input.shape
     ndim = len(shape)
-    utils.check(
+    torch._check(
         ndim in (3, 4) and all(d != 0 for d in shape[-3:]),
         lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size "
         f"and non-zero dimensions, but got: {tuple(shape)}",

@@ -814,7 +814,7 @@ def im2col(
             shape[-2:], padding, dilation, kernel_size, stride
         )
     )
-    utils.check(
+    torch._check(
         all(c > 0 for c in output_size),
         lambda: f"Given an input with spacial size {tuple(shape[-2:])}, "
         f"kernel_size={kernel_size}, dilation={dilation}, "
@@ -869,15 +869,15 @@ def col2im(
     padding: List[int],
     stride: List[int],
 ) -> Tensor:
-    utils.check(len(output_size) == 2, lambda: "only 2D output_size supported")
-    utils.check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
-    utils.check(len(dilation) == 2, lambda: "only 2D dilation supported")
-    utils.check(len(padding) == 2, lambda: "only 2D padding supported")
-    utils.check(len(stride) == 2, lambda: "only 2D stride supported")
+    torch._check(len(output_size) == 2, lambda: "only 2D output_size supported")
+    torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
+    torch._check(len(dilation) == 2, lambda: "only 2D dilation supported")
+    torch._check(len(padding) == 2, lambda: "only 2D padding supported")
+    torch._check(len(stride) == 2, lambda: "only 2D stride supported")

     def check_positive(param, param_name, strict=True):
         cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
-        utils.check(
+        torch._check(
             cond, lambda: "{param_name} should be greater than zero, but got {param}"
         )

@@ -889,13 +889,13 @@ def col2im(

     shape = input.shape
     ndim = len(shape)
-    utils.check(
+    torch._check(
         ndim in (2, 3) and all(d != 0 for d in shape[-2:]),
         lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size "
         f"and non-zero dimensions, but got: {tuple(shape)}",
     )
     prod_kernel_size = kernel_size[0] * kernel_size[1]
-    utils.check(
+    torch._check(
         shape[-2] % prod_kernel_size == 0,
         lambda: "Expected size of input's first non-batch dimension to be divisible by the "
         f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and "

@@ -908,13 +908,13 @@ def col2im(
         )
     ]
     L = col[0] * col[1]
-    utils.check(
+    torch._check(
         shape[-1] == L,
         lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
         f"dilation={dilation}, padding={padding}, stride={stride}, "
         f"expected input.size(-1) to be {L} but got {shape[-1]}.",
     )
-    utils.check(
+    torch._check(
         L > 0,
         lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
         f"dilation={dilation}, padding={padding}, stride={stride}, "

@@ -961,7 +961,7 @@ def col2im(
 def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
     # According to the CUDA kernel implementation we should have this test;
     # but it seems to fail tests!
-    # utils.check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")
+    # torch._check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")

     # Mimicking CUDA kernel's behavior for output stride: output follow input's memory format
     # This different from TensorIterator's behavior
@@ -1221,21 +1221,21 @@ def native_group_norm_backward(
     )
     utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False)
     utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False)
-    utils.check(
+    torch._check(
         input.numel() == N * C * HxW,
         lambda: f"Expect input to have { N * C * HxW} elements",
     )
-    utils.check(
+    torch._check(
         mean.shape == (N, group),
         lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}",
     )
-    utils.check(
+    torch._check(
         gamma is None or gamma.numel() == C,
         lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}",
     )

     cpg, _rem = divmod(C, group)
-    utils.check(
+    torch._check(
         _rem == 0,
         lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}",
     )

@@ -1834,12 +1834,12 @@ def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
     device = input.device
     shape = input.shape
     ndim = len(shape)
-    utils.check(
+    torch._check(
         ndim in (3, 4),
         lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}",
     )
     for d in input.shape[-2:]:
-        utils.check(
+        torch._check(
             d != 0,
             lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for "
            f"non-batch dimensions, but input has shape {tuple(shape)}.",

@@ -1966,13 +1966,13 @@ def _index_add(
     alpha: NumberType = 1,
 ):
     dim = utils.canonicalize_dims(x.ndim, dim)
-    utils.check(
+    torch._check(
         index.ndim <= 1,
         lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
     )
     if alpha != 1:
         python_type = utils.dtype_to_type(x.dtype)
-        utils.check(
+        torch._check(
             python_type == bool
             or utils.is_weakly_lesser_type(type(alpha), python_type),
             lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
@@ -2005,7 +2005,7 @@ def _index_copy(
     x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool
 ):
     dim = utils.canonicalize_dims(x.ndim, dim)
-    utils.check(
+    torch._check(
         index.ndim <= 1,
         lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
     )

@@ -2060,19 +2060,19 @@ def uniform_(self, low=0, high=1, generator=None):
 def upsample_compute_output_size(input_size, output_size, scale_factors):
     spatial_dimensions = len(input_size) - 2
     if output_size is not None:
-        utils.check(
+        torch._check(
             scale_factors is None,
             lambda: "Must specify exactly one of output_size and scale_factors",
         )
-        utils.check(len(output_size) == spatial_dimensions, lambda: "")
+        torch._check(len(output_size) == spatial_dimensions, lambda: "")
         return output_size
     if scale_factors is not None:
         # NB: this isn't necessary lol
-        utils.check(
+        torch._check(
             output_size is None,
             lambda: "Must specify exactly one of output_size and scale_factors",
         )
-        utils.check(len(scale_factors) == spatial_dimensions, lambda: "")
+        torch._check(len(scale_factors) == spatial_dimensions, lambda: "")
         output_size = []
         for i, s in enumerate(scale_factors):
             if int(s) == s:

@@ -2080,7 +2080,7 @@ def upsample_compute_output_size(input_size, output_size, scale_factors):
             else:
                 output_size.append(sym_int(input_size[i + 2] * s))
         return output_size
-    utils.check(
+    torch._check(
         False, lambda: "Must specify exactly one of output_size and scale_factors"
     )

@@ -2969,11 +2969,11 @@ def grid_sampler_2d(
     padding_mode: int = 0,
     align_corners: bool = False,
 ) -> Tensor:
-    utils.check(
+    torch._check(
         interpolation_mode in (0, 1, 2),
         lambda: f"Invalid interpolation mode {interpolation_mode}",
     )
-    utils.check(
+    torch._check(
         padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
     )

@@ -3110,11 +3110,11 @@ def grid_sampler_2d(
 @out_wrapper()
 @pw_cast_for_opmath
 def mv(self, vec):
-    utils.check(
+    torch._check(
         self.dim() == 2 and vec.dim() == 1,
         lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}",
     )
-    utils.check(
+    torch._check(
         self.size(1) == vec.size(0),
         lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})",
     )

@@ -3134,11 +3134,11 @@ def dot(self, other):
     elif other.is_conj():
         return torch.vdot(other.conj(), self)

-    utils.check(
+    torch._check(
         self.dim() == 1 and other.dim() == 1,
         lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
     )
-    utils.check(
+    torch._check(
         self.dtype == other.dtype,
         lambda: f"dot : expected both vectors to have same dtype, but found {self.dtype} and {other.dtype}",
     )

@@ -3149,7 +3149,7 @@ def dot(self, other):
         f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively"
     )

-    utils.check(self.numel() == other.numel(), numel_error)
+    torch._check(self.numel() == other.numel(), numel_error)

     return (self * other).sum()

@@ -3296,7 +3296,7 @@ def matmul(tensor1, tensor2):

         return tensor1_expanded.bmm(tensor2_expanded).view(output_shape)
     else:
-        utils.check(False, lambda: "both arguments to matmul need to be at least 1D")
+        torch._check(False, lambda: "both arguments to matmul need to be at least 1D")


 @register_decomposition(aten.upsample_bicubic2d.default)

@@ -3373,7 +3373,7 @@ def upsample_bicubic2d_vec(
     align_corners: bool,
     scale_factors: Optional[Tuple[float, float]] = None,
 ) -> Tensor:
-    utils.check(
+    torch._check(
         bool(output_size) + bool(scale_factors) == 1,
         lambda: "Must specify exactly one of output_size and scale_factors.",
     )
@@ -72,7 +72,6 @@ from torch._inductor.compile_fx import (
     remove_unaligned_input_idxs,
     static_input,
 )
-from torch._prims_common import check
 from torch.multiprocessing.reductions import StorageWeakRef
 from torch.storage import UntypedStorage
 from torch.utils import _pytree as pytree

@@ -1071,7 +1070,7 @@ class CUDAGraphNode:
                 self.output_storage_alias.append(UnaliasedStorage)
                 continue

-            check(
+            torch._check(
                 o.is_cuda,
                 lambda: f"Expected all cuda outputs in cuda graph recording. Non cuda output from {self.stack_traces[i]}",
             ),

@@ -1447,7 +1446,7 @@ class CUDAGraphNode:
         for idx in self.cudagraph_managed_idxs:
             inputs[idx] = None

-        check(
+        torch._check(
             self._check_liveness(
                 self.expected_dead_indices_after_graph, self.path_weakrefs
             ),

@@ -1522,7 +1521,7 @@ def check_memory_pool(device, pool_id, live_storages_ptrs: List[StorageWeakRefWr

         addr += block["size"]

-    check(
+    torch._check(
         len(unique_storages) == 0,
         lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}",
     )
File diff suppressed because it is too large
@@ -16,7 +16,6 @@ from torch._prims.debug_prims import register_debug_prims
 from torch._prims.nvfuser_prims import register_nvprims
 from torch._prims.rng_prims import register_rng_prims
 from torch._prims_common import (
-    check,
     Dim,
     DimsSequenceType,
     DimsType,

@@ -422,7 +421,7 @@ def _elementwise_meta(


 def _complex_only_elementwise_meta(*args, **kwargs):
-    utils.check(
+    torch._check(
         utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported"
     )
     return _elementwise_meta(*args, **kwargs)

@@ -581,7 +580,7 @@ bitwise_not = _make_elementwise_unary_prim(


 def _cbrt_aten(a: torch.Tensor) -> Tensor:
-    utils.check(
+    torch._check(
         not a.is_complex(),
         lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)",
     )

@@ -1293,10 +1292,9 @@ def _validate_collapse_args(a: Tensor, start: int, end: int) -> None:

     # Verifies end is strictly greater than start
     # (Collapse requires a non-empty interval)
-    utils.check(
+    torch._check_value(
         end >= start,
         lambda: f"Attempting to collapse but end, {end}, is less than start, {start}!",
-        ValueError,
     )


@@ -1823,7 +1821,7 @@ def _as_strided_scatter_meta(
     utils.validate_strides(stride)

     required_size = utils.compute_required_storage_length(size, stride, storage_offset)
-    utils.check(
+    torch._check(
         input.numel() >= required_size,
         lambda: (
             f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} "

@@ -1832,7 +1830,7 @@ def _as_strided_scatter_meta(
             f"for storage of size {input.numel() * input.element_size()}"
         ),
     )
-    utils.check(
+    torch._check(
         utils.is_same_shape(src.shape, size),
         lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}",
     )

@@ -2432,11 +2430,11 @@ def _iota_meta(
     device: torch.device,
     requires_grad: bool,
 ) -> TensorLikeType:
-    utils.check(
+    torch._check(
         utils.is_integer_dtype(dtype),
         lambda: "prims.iota only supports integer dtypes",
     )
-    utils.check(step != 0, lambda: "step must be nonzero")
+    torch._check(step != 0, lambda: "step must be nonzero")
     return torch.empty(
         length,
         dtype=dtype,

@@ -2532,7 +2530,7 @@ def _empty_permuted_meta(
 ) -> TensorLikeType:
     p_strides = utils.make_contiguous_strides_for([shape[l] for l in physical_layout])
     dim = len(shape)
-    utils.check(
+    torch._check(
         len(physical_layout) == dim,
         lambda: (
             "Number of dimensions in the tensor input does not match the "

@@ -2543,7 +2541,7 @@ def _empty_permuted_meta(
     strides = [0] * len(shape)
     seen_dims = set()
     for p, l in enumerate(physical_layout):
-        utils.check(
+        torch._check(
             0 <= l < dim,
             lambda: (
                 f"Dimension out of range (expected to be between 0 and {dim - 1}, but got "

@@ -2551,7 +2549,7 @@ def _empty_permuted_meta(
                 "not currently supported; file an issue if you want it."
             ),
         )
-        utils.check(l not in seen_dims, lambda: "Duplicate dim not allowed")
+        torch._check(l not in seen_dims, lambda: "Duplicate dim not allowed")
         strides[l] = p_strides[p]
         seen_dims.add(l)
     return TensorMeta(

@@ -2779,12 +2777,12 @@ def _normal_meta(
     device: torch.device,
     requires_grad: bool,
 ) -> TensorLikeType:
-    utils.check(
+    torch._check(
         std >= 0.0,
         lambda: f"expected non-negative standard deviation, but got std={std}",
     )

-    utils.check(
+    torch._check(
         utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
         lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}",
     )
@@ -7,6 +7,7 @@ from functools import reduce, cmp_to_key
 import operator
 import sympy
 import weakref
+import warnings
 import torch
 from torch import sym_float, sym_int, sym_max

@@ -268,7 +269,7 @@ _memory_formats = {


 def validate_memory_format(memory_format: torch.memory_format):
-    check(
+    torch._check(
         memory_format in _memory_formats,
         lambda: f"Received unknown memory format {memory_format}!",
     )

@@ -286,7 +287,7 @@ def is_contiguous_for_memory_format(  # type: ignore[return]
     if memory_format == torch.channels_last_3d:
         return is_channels_last_contiguous_3d(a)

-    check(
+    torch._check(
         False,
         lambda: f"is_contiguous received unsupported memory format {memory_format}",
     )

@@ -795,13 +796,13 @@ def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]:
     newsize = 1
     for i, d in enumerate(shape):
         if d == -1:
-            check(dim is None, lambda: "only one dimension can be inferred")
+            torch._check(dim is None, lambda: "only one dimension can be inferred")
             dim = i
         elif d >= 0:
             newsize *= d
         else:
-            check(False, lambda: f"invalid shape dimension {d}")
-    check(
+            torch._check(False, lambda: f"invalid shape dimension {d}")
+    torch._check(
         numel == newsize or (dim is not None and newsize > 0 and numel % newsize == 0),
         lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
     )

@@ -809,7 +810,7 @@ def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]:
         # Convert to list to produce a compatible error message with core
         # PyTorch, which prints sequences in square brackets.
         shape = list(shape)
-        check(
+        torch._check(
             newsize != 0,
             lambda: (f"cannot reshape tensor of 0 elements into shape {shape} because the "
                      f"unspecified dimension size -1 can be any value and is ambiguous"),

@@ -954,18 +955,18 @@ def check_fp_or_complex(
     Checks whether the input is floating point or complex.
     If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32
     """
-    check(
+    torch._check(
         is_float_dtype(dtype) or is_complex_dtype(dtype),
         lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}",
     )
-    check(
+    torch._check(
         allow_low_precision_dtypes or not is_low_precision_dtype(dtype),
         lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}",
     )


 def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"):
-    check(
+    torch._check(
         len(A.shape) >= 2,
         lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
     )

@@ -1060,11 +1061,11 @@ def get_higher_dtype(


 def check_pin_memory(pin_memory: bool):
-    check(not pin_memory, lambda: "PrimTorch does not support pinned memory", NotImplementedError)
+    torch._check_not_implemented(not pin_memory, lambda: "PrimTorch does not support pinned memory")


 def check_layout(layout: torch.layout):
-    check(layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}", NotImplementedError)
+    torch._check_not_implemented(layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}")


 # TODO: maybe unify with can_cast_to?

@@ -1485,7 +1486,7 @@ def make_contiguous_strides_for(


 def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
-    check(
+    torch._check(
         len(shape) == 3,
         lambda: "Only tensors of rank 3 can use the channels_last_1d memory format",
     )

@@ -1503,7 +1504,7 @@ def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]:

 def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
     # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5?
-    check(
+    torch._check(
         len(shape) == 4,
         lambda: "Only tensors of rank 4 can use the channels_last memory format",
     )

@@ -1520,7 +1521,7 @@ def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]:


 def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
-    check(
+    torch._check(
         len(shape) == 5,
         lambda: "Only tensors of rank 5 can use the channels_last_3d memory format",
     )
@@ -1654,6 +1655,9 @@ def check_in_bounds_for_storage(
         raise ValueError(msg)


+# NOTE: This function should ideally be removed, but some Meta internal models
+# packaged with `torch.package` are using it, so it will have to be removed
+# at some point in the future when those models no longer use this function.
 def check(
     b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError
 ) -> None:

@@ -1662,9 +1666,14 @@ def check(
     Error message is a callable producing a string (to avoid wasting time
     string formatting in non-error case, and also to make it easier for torchdynamo
     to trace.)
+
+    .. note:: This function is planned for removal in the future. Please use
+        `torch._check*` functions instead.
     """
-    if not b:
-        raise exc_type(s())
+    warnings.warn(DeprecationWarning((
+        "'torch._prims_common.check' will be removed in the future. Please use "
+        "'torch._check*' functions instead")))
+    torch._check_with(exc_type, b, s)


 # This combines is_channels_last_strides_2d and is_channels_last_strides_3d in
@@ -176,13 +176,13 @@ def _safe_copy_out(

     # Checks safe cast
     if exact_dtype:
-        utils.check(
+        torch._check(
             copy_from.dtype == copy_to.dtype,
             lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
             f"but got {copy_to.dtype} instead",
         )
     else:
-        utils.check(
+        torch._check(
             utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
             lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
             "but this can't be cast because it is not safe!",

@@ -255,10 +255,9 @@ def out_wrapper(*out_names: str, exact_dtype: bool = False):
                 _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
             else:
                 assert isinstance(out, Tuple)  # type: ignore[arg-type]
-                utils.check(
+                torch._check_type(
                     len(out) == len(result),
                     lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
-                    TypeError,
                 )
                 for r, o in zip(result, out):
                     # These two operations are done in-place
@@ -16,7 +16,6 @@ import torch._prims as prims
 import torch._prims_common as utils
 from torch import sym_float, sym_int
 from torch._prims_common import (
-    check,
     DeviceLikeType,
     Dim,
     DimsSequenceType,

@@ -626,7 +625,7 @@ def frac(x: TensorLikeType) -> TensorLikeType:
 # imag does not use _make_elementwise_unary_reference because it does not support out
 def imag(a: TensorLikeType) -> TensorLikeType:
     assert isinstance(a, TensorLike)
-    utils.check(
+    torch._check(
         utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors."
     )
     return prims.imag(a)

@@ -654,7 +653,7 @@ def isinf(a: TensorLikeType) -> TensorLikeType:

 @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
 def isposinf(a: TensorLikeType) -> TensorLikeType:
-    utils.check(
+    torch._check(
         not utils.is_complex_dtype(a.dtype),
         lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}",
     )

@@ -665,7 +664,7 @@ def isposinf(a: TensorLikeType) -> TensorLikeType:

 @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
 def isneginf(a: TensorLikeType) -> TensorLikeType:
-    utils.check(
+    torch._check(
         not utils.is_complex_dtype(a.dtype),
         lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}",
     )

@@ -788,7 +787,7 @@ def nan_to_num(


 def _neg_meta(a: TensorLikeType):
-    check(
+    torch._check(
         a.dtype is not torch.bool,
         lambda: (
             "Negation, the `-` operator, on a bool tensor is not supported. "

@@ -935,23 +934,20 @@ def _make_elementwise_binary_reference(
         a: Union[Tensor, NumberType],
         b: Union[Tensor, NumberType],
     ) -> Tensor:
-        check(
+        torch._check_value(
             supports_lhs_python_scalar or not isinstance(a, Number),
             lambda: f"{name}: Received a lhs Python scalar to an elementwise binary "
             "operation that does not accept lhs scalars!",
-            ValueError,
         )
-        check(
+        torch._check_value(
             supports_rhs_python_scalar or not isinstance(b, Number),
             lambda: f"{name}: Received a rhs Python scalar to an elementwise binary "
             "operation that does not accept rhs scalars!",
-            ValueError,
         )
-        check(
+        torch._check_value(
             supports_two_python_scalars
             or not (isinstance(a, Number) and isinstance(b, Number)),
             lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!",
-            ValueError,
         )
         a, b = _maybe_broadcast(a, b)
         return prim(a, b)
@@ -1230,7 +1226,7 @@ def floor_divide(
     elif utils.is_integer_dtype(dtype):
         return _floor_divide_integer(a, b)
     else:
-        check(False, lambda: f"{dtype} not supported for floor_divide")
+        torch._check(False, lambda: f"{dtype} not supported for floor_divide")


 def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor:

@@ -1374,20 +1370,19 @@ def _check_close_args(
     rtol: float,
     atol: float,
 ) -> None:
-    check(
+    torch._check_value(
         a.dtype == b.dtype,
         lambda: "{0}: Attempting to compare tensors of different dtypes {1} and {2}!".format(
             name, a.dtype, b.dtype
         ),
-        ValueError,
     )
-    check(
+    torch._check(
         rtol >= 0,
         lambda: "{0}: rtol must be greater than or equal to zero, but got {1}!".format(
             name, rtol
         ),
     )
-    check(
+    torch._check(
         atol >= 0,
         lambda: "{0}: atol must be greater than or equal to zero, but got {1}!".format(
             name, atol

@@ -1678,7 +1673,7 @@ def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
     type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
 )
 def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
-    utils.check(
+    torch._check(
         isinstance(a, TensorLike) or isinstance(b, TensorLike),
         lambda: 'Expected either argument a or b to be a Tensor"',
     )

@@ -1736,12 +1731,11 @@ def addcdiv(
     if value is not None:
         dtype = self.dtype  # no scalars allowed, see add
         python_type = utils.dtype_to_type(dtype)
-        check(
+        torch._check_value(
             utils.is_weakly_lesser_type(type(value), python_type),
             lambda: "value argument of type {0} cannot be safely cast to type {1}!".format(
                 type(value), python_type
             ),
-            exc_type=ValueError,
         )

     return self + value * tensor1 / tensor2

@@ -1766,12 +1760,11 @@ def addcmul(
     if value is not None:
         dtype = self.dtype  # no scalars allowed, see add
         python_type = utils.dtype_to_type(dtype)
-        check(
+        torch._check_value(
             utils.is_weakly_lesser_type(type(value), python_type),
             lambda: "value argument of type {0} cannot be safely cast to type {1}!".format(
                 type(value), python_type
             ),
-            exc_type=ValueError,
         )

     return self + value * tensor1 * tensor2

@@ -1851,7 +1844,7 @@ def where(
         raise NotImplementedError

     utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True)
-    check(
+    torch._check(
         pred.dtype is torch.bool,
         lambda: f"expected predicate to be bool, got {pred.dtype}",
     )
@@ -2229,7 +2222,7 @@ def sum_to_size(
     *shape,
 ) -> Tensor:
     shape = utils.extract_shape_from_varargs(shape, validate=False)
-    utils.check(
+    torch._check(
         utils.is_expandable_to(shape, a.shape),
         lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"',
     )

@@ -2402,7 +2395,7 @@ def mean(
     if dtype is None:
         dtype = a.dtype
     # can't use out wrapper because of this argument
-    check(
+    torch._check(
         out is None or out.dtype == dtype,
         lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead",
     )

@@ -2415,7 +2408,7 @@ def mean(
         out=None,
         output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE,
     )
-    check(
+    torch._check(
         utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
         lambda: (
             f"mean(): could not infer output dtype. "

@@ -2491,22 +2484,22 @@ def addr(
     beta: NumberType = 1,
     alpha: NumberType = 1,
 ) -> TensorLikeType:
-    check(
+    torch._check(
         vec1.ndim == 1,
         lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D",
     )
-    check(
+    torch._check(
         vec2.ndim == 1,
         lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D",
     )
     self = self.expand(vec1.shape[0], vec2.shape[0])
     if utils.is_boolean_dtype(self.dtype):
         # Integers are accepted for booleans
-        check(
+        torch._check(
             is_weakly_lesser_type(type(beta), int),
             lambda: f"expected bool/int beta but got {type(beta)}",
         )
-        check(
+        torch._check(
             is_weakly_lesser_type(type(alpha), int),
             lambda: f"expected bool/int alpha but got {type(beta)}",
         )

@@ -2518,11 +2511,11 @@ def addr(
             torch.outer(vec1, vec2) if alpha else torch.full_like(self, False),
         )
     else:
-        check(
+        torch._check(
             is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)),
             lambda: f"cannot safely convert {type(beta)} to {self.dtype}",
         )
-        check(
+        torch._check(
             is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)),
             lambda: f"cannot safely convert {type(alpha)} to {self.dtype}",
         )
@@ -2712,7 +2705,7 @@ def conj(input: TensorLikeType) -> TensorLikeType:
 def constant_pad_nd(
     input: TensorLikeType, pad: List[int], value: NumberType = 0
 ) -> TensorLikeType:
-    check(
+    torch._check(
         len(pad) % 2 == 0,
         lambda: f"Length of pad must be even but instead it equals {len(pad)}",
     )

@@ -2723,7 +2716,7 @@ def constant_pad_nd(
     l_pad = len(pad) // 2
     l_diff = l_inp - l_pad

-    check(
+    torch._check(
         l_inp >= l_pad,
         lambda: "Length of pad should be no more than twice the number of "
         f"dimensions of the input. Pad length is {len(pad)} while the input has "

@@ -2748,7 +2741,7 @@ def constant_pad_nd(
     for i in range(l_pad):
         pad_idx = len(pad) - ((i + 1) * 2)
         new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
-        check(
+        torch._check(
             new_dim > 0,
             lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
            f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "

@@ -2787,7 +2780,7 @@ def constant_pad_nd(
 def contiguous(
     a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format
 ) -> Tensor:
-    check(
+    torch._check(
         memory_format != torch.preserve_format,
         lambda: "preserve memory format is unsupported by the contiguous operator",
     )

@@ -2800,7 +2793,7 @@ def contiguous(

 @out_wrapper()
 def dstack(tensors: TensorSequenceType) -> TensorLikeType:
-    check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
+    torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
     aligned_tensors = atleast_3d(*tensors)
     return cat(aligned_tensors, 2)


@@ -2813,7 +2806,7 @@ def expand(a: Tensor, *shape) -> Tensor:
     if len(shape) == 1 and isinstance(shape[0], Sequence):
         shape = tuple(shape[0])

-    check(
+    torch._check(
         len(shape) >= len(a.shape),
         lambda: "expand: the requested shape has too few dimensions!",
     )

@@ -2823,7 +2816,7 @@ def expand(a: Tensor, *shape) -> Tensor:
     for idx, x in enumerate(a.shape):
         offset_idx = idx + offset
         requested_length = shape[offset_idx]
-        check(
+        torch._check(
             requested_length == x or x == 1 or requested_length == -1,
             lambda: f"expand: attempting to expand a dimension of length {x}!",
         )
@ -2917,13 +2910,13 @@ def narrow(
|
||||||
# Supports Tensor overload that was added for XLA:
|
# Supports Tensor overload that was added for XLA:
|
||||||
# https://github.com/pytorch/pytorch/issues/31558
|
# https://github.com/pytorch/pytorch/issues/31558
|
||||||
if isinstance(start, TensorLike):
|
if isinstance(start, TensorLike):
|
||||||
check(
|
torch._check(
|
||||||
start.dim() == 0 and utils.is_integer_dtype(start.dtype),
|
start.dim() == 0 and utils.is_integer_dtype(start.dtype),
|
||||||
lambda: "start must be an 0-dim integral Tensor.",
|
lambda: "start must be an 0-dim integral Tensor.",
|
||||||
)
|
)
|
||||||
start = start.item() # type: ignore[assignment]
|
start = start.item() # type: ignore[assignment]
|
||||||
check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
|
torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
|
||||||
check(length >= 0, lambda: "narrow(): length must be non-negative.")
|
torch._check(length >= 0, lambda: "narrow(): length must be non-negative.")
|
||||||
dim = utils.canonicalize_dim(a.ndim, dim)
|
dim = utils.canonicalize_dim(a.ndim, dim)
|
||||||
dim_length = a.size(dim)
|
dim_length = a.size(dim)
|
||||||
# Start being the end is usually invalid since it's out of bounds. So it's
|
# Start being the end is usually invalid since it's out of bounds. So it's
|
||||||
|
|
@ -2934,7 +2927,7 @@ def narrow(
|
||||||
# Note: a dimension isn't being canonicalized here, this reuses
|
# Note: a dimension isn't being canonicalized here, this reuses
|
||||||
# canonicalize_dim because the semantics are similar.
|
# canonicalize_dim because the semantics are similar.
|
||||||
start = utils.canonicalize_dim(dim_length, start) # type: ignore[arg-type]
|
start = utils.canonicalize_dim(dim_length, start) # type: ignore[arg-type]
|
||||||
check(
|
torch._check(
|
||||||
start <= dim_length - length, # type: ignore[arg-type]
|
start <= dim_length - length, # type: ignore[arg-type]
|
||||||
lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
|
lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
|
||||||
)
|
)
|
||||||
|
|
@ -2993,11 +2986,11 @@ def native_group_norm(
|
||||||
num_groups: int,
|
num_groups: int,
|
||||||
eps: float,
|
eps: float,
|
||||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||||
utils.check(
|
torch._check(
|
||||||
input.ndim >= 2,
|
input.ndim >= 2,
|
||||||
lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
|
lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
|
||||||
)
|
)
|
||||||
utils.check(
|
torch._check(
|
||||||
num_channels % num_groups == 0,
|
num_channels % num_groups == 0,
|
||||||
lambda: "Expected number of channels in input to be divisible by num_groups, "
|
lambda: "Expected number of channels in input to be divisible by num_groups, "
|
||||||
+ f"but got input of shape {input.shape} and num_groups = {num_groups}",
|
+ f"but got input of shape {input.shape} and num_groups = {num_groups}",
|
||||||
|
|
@ -3044,7 +3037,7 @@ def native_layer_norm(
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
normalized_ndim = len(normalized_shape)
utils.check(
torch._check(
normalized_ndim >= 1,
lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., "
+ "containing at least one element, but got normalized_shape = "

@ -3053,7 +3046,7 @@ def native_layer_norm(
# torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False
# while torch.Size([1, 2, 3]) == (1, 2, 3) is True
# therefore we use tuple(normalized_shape)
utils.check(
torch._check(
weight is None or weight.shape == tuple(normalized_shape),
lambda: "Expected weight to be of same shape as normalized_shape, but got "
+ "weight of shape "

@ -3061,7 +3054,7 @@ def native_layer_norm(
+ " and normalized_shape = "
+ str(normalized_shape),
)
utils.check(
torch._check(
bias is None or bias.shape == tuple(normalized_shape),
lambda: "Expected bias to be of same shape as normalized_shape, but got "
+ "bias of shape "

@ -3069,7 +3062,7 @@ def native_layer_norm(
+ " and normalized_shape = "
+ str(normalized_shape),
)
utils.check(
torch._check(
input.ndim >= normalized_ndim
and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
lambda: "Given normalized_shape="

@ -3123,12 +3116,12 @@ def _get_unfold_shape_stride(
max_size = 1 if a_ndim == 0 else a_shape[dim]
last_stride = 1 if a_ndim == 0 else a_stride[dim]

utils.check(
torch._check(
size <= max_size,
lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}",
)

utils.check(
torch._check(
step > 0,
lambda: f"Step is {step} but must be > 0",
)

@ -3146,7 +3139,7 @@ def _get_unfold_shape_stride(
@register_decomposition(aten.repeat)
def repeat(a: Tensor, *repeat_shape) -> Tensor:
repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False)
utils.check(
torch._check(
len(repeat_shape) >= len(a.shape),
lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
)

@ -3452,7 +3445,7 @@ def softmax(
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def hstack(tensors: TensorSequenceType) -> TensorLikeType:
check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
aligned_tensors = atleast_1d(*tensors)
if aligned_tensors[0].ndim == 1:
return cat(aligned_tensors, 0)

@ -3462,7 +3455,7 @@ def hstack(tensors: TensorSequenceType) -> TensorLikeType:
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def vstack(tensors: TensorSequenceType) -> TensorLikeType:
check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
aligned_tensors = atleast_2d(*tensors)
return cat(aligned_tensors, 0)

@ -3470,17 +3463,16 @@ def vstack(tensors: TensorSequenceType) -> TensorLikeType:
# CompositeImplicitAutograd - don't register decomp
def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType:
dim = utils.canonicalize_dim(a.ndim, dim)
utils.check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :]))


@register_decomposition(aten.unbind)
def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
dim = utils.canonicalize_dim(t.ndim, dim)
check(
torch._check_index(
len(t.shape) > 0,
lambda: "Dimension specified as 0 but tensor has no dimensions",
IndexError,
)
if t.shape[dim] == 0:
return tuple()

@ -3499,7 +3491,7 @@ def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):

def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
dim = utils.canonicalize_dims(x.ndim, dim)
utils.check(
torch._check(
index.ndim <= 1,
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
)

@ -3532,12 +3524,12 @@ def _index_fill(
*,
inplace: bool,
):
utils.check(
torch._check(
index.ndim <= 1,
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
)
if isinstance(value, TensorLike):
utils.check(
torch._check(
value.ndim == 0,
lambda: "Only supports 0-dimensional value tensor. " # type: ignore[union-attr]
f"Got a tensor with {value.ndim} dimensions.",

@ -3589,7 +3581,7 @@ def index_add(
@out_wrapper()
def index_select(x: TensorLike, dim: int, index: TensorLike):
dim = utils.canonicalize_dims(x.ndim, dim)
utils.check(
torch._check(
index.ndim <= 1,
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
)

@ -3713,7 +3705,7 @@ def tensor_split(
def hsplit(
a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
check(
torch._check(
a.ndim >= 1,
lambda: (
"torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "

@ -3724,7 +3716,7 @@ def hsplit(
dim = 0 if a.ndim == 1 else 1
if isinstance(indices_or_sections, IntLike):
split_size = indices_or_sections
check(
torch._check(
(split_size != 0 and a.shape[dim] % split_size == 0),
lambda: (
"torch.hsplit attempted to split along dimension "

@ -3738,14 +3730,13 @@ def hsplit(
)
return tensor_split(a, split_size, dim)

check(
torch._check_type(
isinstance(indices_or_sections, (list, tuple)),
lambda: (
"hsplit(): received an invalid combination of arguments. "
"Expected indices_or_sections to be of type int, list of ints or tuple of ints "
f"but got type {type(indices_or_sections)}"
),
exc_type=TypeError,
)

split_sizes = indices_or_sections

@ -3756,7 +3747,7 @@ def hsplit(
def vsplit(
a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
check(
torch._check(
a.ndim >= 2,
lambda: (
"torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with "

@ -3766,7 +3757,7 @@ def vsplit(
)
if isinstance(indices_or_sections, IntLike):
split_size = indices_or_sections
check(
torch._check(
(split_size != 0 and a.shape[0] % split_size == 0),
lambda: (
f"torch.vsplit attempted to split along dimension 0"

@ -3779,14 +3770,13 @@ def vsplit(
)
return tensor_split(a, split_size, 0)

check(
torch._check_type(
isinstance(indices_or_sections, (list, tuple)),
lambda: (
"vsplit(): received an invalid combination of arguments. "
"Expected indices_or_sections to be of type int, list of ints or tuple of ints "
f"but got type {type(indices_or_sections)}"
),
exc_type=TypeError,
)

split_sizes = indices_or_sections

@ -3800,7 +3790,7 @@ def diag(
offset: int = 0,
) -> TensorLikeType:
ndim = self.dim()
utils.check(
torch._check(
ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D"
)
if ndim == 1:

@ -3820,7 +3810,7 @@ def diagonal_scatter(
) -> TensorLikeType:
out = utils.clone_preserve_strides(input)
diag = out.diagonal(offset, dim1, dim2)
check(
torch._check(
diag.shape == src.shape,
lambda: "expected src to have a size equal to the diagonal of the input."
f"Got {src.shape} for a diagonal of shape {diag.shape}",

@ -3843,7 +3833,7 @@ def diagonal(
dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims)
dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims)

check(
torch._check(
dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
)


@ -3896,7 +3886,7 @@ def diag_embed(
dim1 = utils.canonicalize_dim(rank=rank, idx=dim1)
dim2 = utils.canonicalize_dim(rank=rank, idx=dim2)

check(
torch._check(
dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
)


@ -3967,7 +3957,7 @@ def t(a: TensorLikeType):
# CompositeImplicitAutograd - don't register decomp
def T(a: TensorLikeType) -> TensorLikeType:
# n != 2 && n != 0 is deprecated in regular PyTorch.
check(
torch._check(
a.ndim in (0, 2),
lambda: (
"The use of `x.T` on tensors of dimension other than 0 or 2 "

@ -4102,7 +4092,7 @@ def empty(
pin_memory: bool = False,
memory_format: torch.memory_format = torch.contiguous_format,
) -> TensorLikeType:
check(
torch._check(
memory_format != torch.preserve_format,
lambda: "torch.empty: the Preserve memory format is not supported",
)

@ -4114,7 +4104,7 @@ def empty(
elif memory_format == torch.channels_last_3d:
strides = utils.make_channels_last_3d_strides_for(shape)
else: # memory_format == torch.channels_last
check(
torch._check(
memory_format == torch.channels_last,
lambda: f"torch.empty: received an unknown memory format {memory_format}!",
)

@ -4398,8 +4388,8 @@ def arange(
if end is None:
end = start
start = 0
utils.check(step != 0, lambda: "step must be nonzero")
torch._check(step != 0, lambda: "step must be nonzero")
utils.check(
torch._check(
(step > 0 and end >= start) or (step < 0 and end <= start),
lambda: "upper bound and lower bound inconsistent with step sign",
)

@ -4407,11 +4397,11 @@ def arange(
def is_finite(x):
return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x)

utils.check(
torch._check(
is_finite(start) and is_finite(end),
lambda: f"unsupported range: {start} -> {end}",
)
utils.check(
torch._check(
is_finite(step),
lambda: f"step must be finite but got {step}",
)

@ -4514,7 +4504,7 @@ def linspace(
if dtype is None:
dtype = default_complex_dtype
else:
check(
torch._check(
utils.is_complex_dtype(dtype),
lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
)

@ -4523,13 +4513,12 @@ def linspace(
assert isinstance(dtype, torch.dtype)

# steps does not participate in the computation of the dtype
check(
torch._check_type(
isinstance(steps, IntLike),
lambda: "steps must be int, not float",
exc_type=TypeError,
)
assert isinstance(steps, IntLike) # for mypy
check(steps >= 0, lambda: "number of steps must be non-negative")
torch._check(steps >= 0, lambda: "number of steps must be non-negative")

factory_kwargs = {
"layout": layout,

@ -4631,19 +4620,19 @@ def meshgrid(
assert len(tensors) == 1
tensors = tuple(tensors[0])

check(
torch._check(
py_all(isinstance(a, TensorLike) for a in tensors),
lambda: "meshgrid expects its inputs to be tensors",
)

check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")
torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")

for i in range(len(tensors) - 1):
check(
torch._check(
tensors[i].dtype == tensors[i + 1].dtype, # type: ignore[union-attr]
lambda: "meshgrid expects all tensors to have the same dtype",
)
check(
torch._check(
tensors[i].device == tensors[i + 1].device, # type: ignore[union-attr]
lambda: "meshgrid expects all tensors to have the same device",
)

@ -4654,7 +4643,7 @@ def meshgrid(
if swap_first_and_second_tensors:
tensors = (tensors[1], tensors[0], *tensors[2:])
else:
check(
torch._check(
indexing == "ij",
lambda: (
'torch.meshgrid: indexing must be one of "xy" or "ij", '

@ -4665,7 +4654,7 @@ def meshgrid(
result_shape: List[int] = []
for t in tensors:
assert isinstance(t, TensorLike) # mypy
check(
torch._check(
t.ndim == 0 or t.ndim == 1,
lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}",
)

@ -4701,7 +4690,7 @@ def movedim(

# Converts to list to produce a compatible error message with core PyTorch,
# which prints sequences in square brackets.
utils.check(
torch._check(
len(source) == len(destination), # type: ignore[arg-type]
lambda: (
"movedim: Invalid source or destination dims: source " # type: ignore[arg-type]

@ -4718,11 +4707,11 @@ def movedim(
dss = set(ds)

# See above on why this converts to list in error messages.
utils.check(
torch._check(
len(ss) == len(sss),
lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type]
)
utils.check(
torch._check(
len(ds) == len(dss),
lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type]
)

@ -4795,8 +4784,8 @@ def eye(
if m is None:
m = n

check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
torch._check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")
torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")

range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False)
range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False)

@ -4994,13 +4983,13 @@ def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLi
# NOTE: Could not use value = item(value) as it resulted in
# RuntimeError: Cannot cast FakeTensor(cpu) to number
value_ndim = value.ndim
check(
torch._check(
value_ndim == 0,
lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension",
)
# `masked_fill` allows cpu scalar to be moved to cuda and xpu but not otherwise.
is_cpu_scalar = a.device.type in ["cuda", "xpu"] and value.device.type == "cpu"
check(
torch._check(
is_cpu_scalar or value.device == a.device,
lambda: "Expected `value` to be on same device as `a`",
)

@ -5011,7 +5000,7 @@ def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLi
# We allow casting `value` to lower type for other case
# Eg. float -> int.
# Ref: https://github.com/pytorch/pytorch/issues/79195
check(
torch._check(
utils.is_weakly_lesser_type(value_type, python_type),
lambda: f"could not convert to type {python_type} without overflow",
)

@ -5101,7 +5090,7 @@ def norm(

@register_decomposition(aten.trace)
def trace(self: TensorLikeType) -> TensorLikeType:
utils.check(
torch._check(
self.ndim == 2, lambda: "expected a matrix, but got tensor with dim {self.ndim}"
)
return torch.sum(torch.diag(self, 0))

@ -5125,7 +5114,7 @@ rpow = _make_r_binary_op(pow)
@register_decomposition(aten.triu)
@out_wrapper()
def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
utils.check(
torch._check(
a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions"
)
h, w = a.shape[-2:]

@ -5142,7 +5131,7 @@ def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
@register_decomposition(aten.tril)
@out_wrapper()
def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
utils.check(
torch._check(
a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions"
)
h, w = a.shape[-2:]

@ -5187,9 +5176,9 @@ def _trilu_checks(
layout: torch.layout,
pin_memory: bool,
):
check(row >= 0, lambda: f"row must be non-negative, got {row}")
torch._check(row >= 0, lambda: f"row must be non-negative, got {row}")
check(col >= 0, lambda: f"col must be non-negative, got {col}")
torch._check(col >= 0, lambda: f"col must be non-negative, got {col}")
check(
torch._check(
dtype in (torch.int32, torch.int64),
lambda: f"\"{name}\" not implemented for '{dtype}'",
)

@ -5306,7 +5295,7 @@ def bucketize(
out_int32: bool = False,
right: bool = False,
):
utils.check(
torch._check(
boundaries.dim() == 1,
lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})",
)

@ -5364,14 +5353,14 @@ def bucketize(
)
def cauchy(self, median=0, sigma=1, generator=None):
assert generator is None
utils.check(
torch._check(
not utils.is_complex_dtype(self.dtype)
and not utils.is_integer_dtype(self.dtype)
and not utils.is_boolean_dtype(self.dtype),
lambda: f"Cauchy distribution is a continuous probability distribution. \
dtype must be a floating point but you specified {self.dtype}",
)
utils.check(
torch._check(
sigma > 0.0,
lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}",
)

@ -5386,14 +5375,14 @@ def cauchy(self, median=0, sigma=1, generator=None):
)
def exponential(self, rate=1, generator=None):
assert generator is None
utils.check(
torch._check(
not utils.is_complex_dtype(self.dtype)
and not utils.is_integer_dtype(self.dtype)
and not utils.is_boolean_dtype(self.dtype),
lambda: f"Exponential distribution is a continuous probability distribution. \
dtype must be a floating point but you specified {self.dtype}",
)
utils.check(
torch._check(
rate > 0.0,
lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
)

@ -5409,12 +5398,12 @@ def exponential(self, rate=1, generator=None):
def geometric(self, p, generator=None):
assert generator is None
# TODO: fix inductor rand_like for integer, bool dtypes
utils.check(
torch._check(
not utils.is_complex_dtype(self.dtype)
and not utils.is_boolean_dtype(self.dtype),
lambda: f"geometric not implemented for {self.dtype}",
)
utils.check(
torch._check(
0 < p and p < 1,
lambda: f"geometric_ expects p to be in (0, 1), but got p={p}",
)

@ -5429,13 +5418,13 @@ def geometric(self, p, generator=None):
)
def log_normal(self, mean=1, std=2, generator=None):
assert generator is None
utils.check(
torch._check(
not utils.is_complex_dtype(self.dtype)
and not utils.is_integer_dtype(self.dtype)
and not utils.is_boolean_dtype(self.dtype),
lambda: f"log_normal not implemented for {self.dtype}",
)
utils.check(
torch._check(
0 < std,
lambda: f"log_normal_ expects std > 0.0, but found std={std}",
)

@ -5451,7 +5440,7 @@ def log_normal(self, mean=1, std=2, generator=None):
)
def normal(self, mean=0, std=1, generator=None):
assert generator is None
utils.check(std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}")
torch._check(std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}")
normal_samples = prims.normal(
self.shape,
mean=0.0,

@ -5465,7 +5454,7 @@ def normal(self, mean=0, std=1, generator=None):

@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def rad2deg(self: TensorLikeType):
utils.check(
torch._check(
not utils.is_complex_dtype(self.dtype),
lambda: "rad2deg is not supported for complex tensors.",
)

@ -5475,7 +5464,7 @@ def rad2deg(self: TensorLikeType):

@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def deg2rad(self: TensorLikeType):
utils.check(
torch._check(
not utils.is_complex_dtype(self.dtype),
lambda: "deg2rad is not supported for complex tensors.",
)

@ -4,7 +4,7 @@ import torch._prims_common as utils
# Utilities should come BEFORE this import
from torch._decomp import register_decomposition

from torch._prims_common import check, TensorLikeType
from torch._prims_common import TensorLikeType
from torch._prims_common.wrappers import out_wrapper
from torch._refs import _broadcast_shapes

@ -79,14 +79,14 @@ short = _make_conversion_method("short", torch.short)
@out_wrapper(exact_dtype=True)
def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
allowed_dtypes = (torch.float32, torch.float64, torch.float16)
check(
torch._check(
real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes,
lambda: (
f"Expected both inputs to be Half, Float or Double tensors but got "
f"{real.dtype} and {imag.dtype}"
),
)
check(
torch._check(
real.dtype == imag.dtype,
lambda: (
f"Expected object of scalar type {real.dtype} but got "

@ -6,7 +6,7 @@ import torch
import torch._prims as prims
import torch._prims_common as utils
from torch._decomp import register_decomposition
from torch._prims_common import check, DimsType, ShapeType, TensorLikeType
from torch._prims_common import DimsType, ShapeType, TensorLikeType
from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper

__all__ = [

@ -43,7 +43,7 @@ def _apply_norm(
x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
) -> TensorLikeType:
"""Apply normalization to the un-normalized FFT result"""
check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")
torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")

if norm == "ortho":
return x * (1 / math.sqrt(signal_numel))

@ -116,7 +116,9 @@ def _fft_c2r(
input = _maybe_promote_tensor_fft(input, require_complex=True)
dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
check(last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified")
torch._check(
last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified"
)

if n is not None:
input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))

@ -138,7 +140,7 @@ def _fft_r2c(
onesided: bool,
) -> TensorLikeType:
"""Common code for performing any real to complex FFT (rfft or ihfft)"""
check(
torch._check(
not input.dtype.is_complex,
lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
)

@ -162,7 +164,7 @@ def _fft_c2c(
forward: bool,
) -> TensorLikeType:
"""Common code for performing any complex to complex FFT (fft or ifft)"""
check(
torch._check(
input.dtype.is_complex,
lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
)

@ -265,20 +267,20 @@ def _canonicalize_fft_shape_and_dim_args(
ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)

# Check dims are unique
check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique")
torch._check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique")

if shape is not None:
if not isinstance(shape, Sequence):
shape = (shape,)

# Has shape, might have dim
check(
torch._check(
dim is None or len(dim) == len(shape),
lambda: "When given, dim and shape arguments must have the same length",
)
transform_ndim = len(shape)

check(
torch._check(
transform_ndim <= input_dim,
lambda: f"Got shape with {transform_ndim} values but input tensor "
f"only has {input_dim} dimensions.",

@ -301,7 +303,7 @@ def _canonicalize_fft_shape_and_dim_args(
ret_shape = tuple(input_sizes[d] for d in ret_dims)

for n in ret_shape:
check(n > 0, lambda: f"Invalid number of data points ({n}) specified")
torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")

return _ShapeAndDims(shape=ret_shape, dims=ret_dims)


@ -323,7 +325,7 @@ def _fftn_c2c(
forward: bool,
) -> TensorLikeType:
"""Common code for n-dimensional complex to complex FFTs (fftn or ifftn)"""
check(
torch._check(
input.dtype.is_complex,
lambda: f"{function_name} expects a complex input tensor, "
f"but got {input.dtype}",

@ -367,7 +369,7 @@ def rfftn(
dim: Optional[DimsType] = None,
norm: NormType = None,
) -> TensorLikeType:
check(
torch._check(
not input.dtype.is_complex,
lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
)

@ -386,12 +388,12 @@ def ihfftn(
dim: Optional[DimsType] = None,
norm: NormType = None,
) -> TensorLikeType:
check(
torch._check(
not input.dtype.is_complex,
lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
)
shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
input = _maybe_promote_tensor_fft(input, require_complex=False)
input = _resize_fft_input(input, dim, shape)

@ -421,14 +423,14 @@ def _canonicalize_fft_c2r_shape_and_dim_args(
"""Canonicalize shape and dim arguments for n-dimensional c2r transforms,
as well as calculating the last_dim_size which is shape[dim[-1]] for the output"""
(shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")
torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")

if s is None or s[-1] == -1:
last_dim_size = 2 * (input.shape[dim[-1]] - 1)
else:
last_dim_size = shape[-1]

check(
torch._check(
last_dim_size >= 1,
lambda: f"Invalid number of data points ({last_dim_size}) specified",
)

@ -11,7 +11,6 @@ import torch._refs as refs
import torch._refs.linalg as linalg
from torch import Tensor
from torch._prims_common import (
check,
check_fp_or_complex,
check_is_matrix,
Dim,

@ -29,11 +28,11 @@ def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_nam
Checks related to the dtype kwarg in `linalg.*norm` functions
"""
if dtype is not None:
check(
torch._check(
utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
)
check(
torch._check(
utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype),
lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
fn_name=fn_name,

@ -41,7 +40,7 @@ def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_nam
dtype=dtype,
),
)
check(
torch._check(
utils.get_higher_dtype(dtype, x_dtype) == dtype,
lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
"without narrowing to the specified dtype ({dtype})",

@ -79,7 +78,7 @@ def vector_norm(
dim = [dim] # type: ignore[assignment]

if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
check(
torch._check(
dim is not None and len(dim) != 0,
lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
"because the operation does not have an identity",

@ -87,7 +86,7 @@ def vector_norm(
shape = x.shape
assert dim is not None # mypy does not seem to be able to see through check?
for d in dim:
check(
torch._check(
shape[d] != 0,
lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
f"dimension {d} because this dimension is empty and the "

@ -147,8 +146,10 @@ def matrix_norm(
dim = utils.canonicalize_dims(A.ndim, dim)
if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment]
check(len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}")
torch._check(
check(
len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
)
torch._check(
dim[0] != dim[1],
lambda: "linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
)

@ -157,7 +158,7 @@ def matrix_norm(

if isinstance(ord, str):
# ord
check(
torch._check(
ord in ("fro", "nuc"),
lambda: "linalg.matrix_norm: Order {ord} not supported.",
)

@ -180,7 +181,7 @@ def matrix_norm(
else:
# ord
abs_ord = abs(ord)
check(
torch._check(
abs_ord in (2, 1, float("inf")),
lambda: "linalg.matrix_norm: Order {ord} not supported.",
)

@ -224,12 +225,12 @@ def norm(
if dim is not None:
if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment]
check(
torch._check(
len(dim) in (1, 2),
lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
)
elif ord is not None:
check(
torch._check(
A.ndim in (1, 2),
lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
)

@ -8,7 +8,6 @@ import torch._prims_common as utils
|
||||||
import torch._refs as refs
|
import torch._refs as refs
|
||||||
from torch._decomp import register_decomposition
|
from torch._decomp import register_decomposition
|
||||||
from torch._prims_common import (
|
from torch._prims_common import (
|
||||||
check,
|
|
||||||
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
||||||
NumberType,
|
NumberType,
|
||||||
ShapeType,
|
ShapeType,
|
||||||
|
|
@ -98,7 +97,7 @@ def alpha_dropout(
|
||||||
if not training:
|
if not training:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
utils.check(
|
torch._check(
|
||||||
p <= 1 and p >= 0,
|
p <= 1 and p >= 0,
|
||||||
lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
|
lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
|
||||||
)
|
)
|
||||||
|
|
@@ -134,7 +133,7 @@ def _inplace_wrapper(fn):
 @wraps(fn)
 def _fn(a, *args, inplace=False, **kwargs):
 if inplace:
-check(
+torch._check(
 "out" not in kwargs,
 lambda: "Cannot set inplace=True and pass out= at the same time",
 )
@@ -193,7 +192,7 @@ def dropout(
 if not training:
 return a

-utils.check(
+torch._check(
 p <= 1 and p >= 0,
 lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
 )
@@ -232,15 +231,15 @@ def elu(

 # nb. This should be factored out into a can_cast aux function
 python_type = utils.dtype_to_type(a.dtype)
-check(
+torch._check(
 utils.is_weakly_lesser_type(type(input_scale), python_type),
 lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!",
 )
-check(
+torch._check(
 utils.is_weakly_lesser_type(type(scale), python_type),
 lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!",
 )
-check(
+torch._check(
 utils.is_weakly_lesser_type(type(alpha), python_type),
 lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
 )
@@ -276,14 +275,14 @@ def group_norm(
 """
 Reference implementation of :func:`torch.nn.functional.group_norm`.
 """
-utils.check(
+torch._check(
 input.ndim >= 2,
 lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
 )

 batch_size = input.shape[0]
 num_channels = input.shape[1]
-utils.check(
+torch._check(
 num_channels % num_groups == 0,
 lambda: "Expected number of channels in input to be divisible by num_groups, "
 + f"but got input of shape {input.shape} and num_groups = {num_groups}",
@@ -394,7 +393,7 @@ def softmax(
 # deprecated. For PrimTorch, it's fine to drop support for deprecated
 # behavior because it requires explicit opt in. This error is to inform
 # users how to update their calls.
-check(dim is not None, lambda: "implicit dim not supported, use dim=X")
+torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
 return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]

@@ -409,7 +408,7 @@ def softmin(
 # deprecated. For PrimTorch, it's fine to drop support for deprecated
 # behavior because it requires explicit opt in. This error is to inform
 # users how to update their calls.
-check(dim is not None, lambda: "implicit dim not supported, use dim=X")
+torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
 return torch.softmax(a=-a, dim=dim, dtype=dtype) # type: ignore[call-overload]

@@ -469,7 +468,7 @@ def softshrink(a: TensorLikeType, lambd: float = 0.5):
 # softshrink(x) = x - lambd if x > lambd
 # = x + lambd if x < -lambd
 # = 0 otherwise
-check(
+torch._check(
 lambd >= 0,
 lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
 )
@@ -596,7 +595,7 @@ def log_softmax(
 # deprecated. For PrimTorch, it's fine to drop support for deprecated
 # behavior because it requires explicit opt in. This error is to inform
 # users how to update their calls.
-check(dim is not None, lambda: "implicit dim not supported, use dim=X")
+torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
 return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]

@@ -668,12 +667,12 @@ def _nll_loss_nd(
 reduction: str,
 ignore_index: int,
 ) -> TensorLikeType:
-utils.check(
+torch._check(
 input.ndim > 0 and input.ndim <= 3,
 lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.",
 )

-utils.check(
+torch._check(
 (input.ndim == 1) or (input.shape[0] == target.shape[0]),
 lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
 )
@@ -693,7 +692,7 @@ def _nll_loss_nd(
 (flat_target >= 0), (flat_target < num_classes)
 )
 class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
-utils.check(
+torch._check(
 isinstance(target, FakeTensor) or bool(class_check.item()),
 lambda: "A target class is out-of-bounds and not the ignore index.",
 )
@@ -758,7 +757,7 @@ def nll_loss(
 """
 Reference implementation of torch.nn.functional.nll_loss
 """
-utils.check(
+torch._check(
 input.ndim > 0,
 lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})",
 )
@@ -796,9 +795,13 @@ def nll_loss(
 # For ndim > 3, we reshape the input and target to 3-D case.
 # Input (N batch-size, C classes, k-dimensions)
 # Target (N batch-size, k-dimensions)
-utils.check(
+torch._check(
 input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:],
-lambda: f"Expected target shape {out_size} but got {target.shape}",
+lambda: (
+"Expected input and target to both have ndim > 0 and "
+"target.shape[1:] == input.shape[2:], but got "
+f"target.shape {target.shape} and input.shape {input.shape}"
+),
 )

 batch_size = input.shape[0]
@@ -837,7 +840,7 @@ def huber_loss(
 if type(reduction) is int:
 reduction = _reduction_int_to_str(reduction)
 _check_reduction_value(reduction) # type: ignore[arg-type]
-check(
+torch._check(
 delta > 0,
 lambda: "huber_loss does not support non-positive values for delta.",
 )
@@ -938,7 +941,7 @@ def _triplet_margin_with_distance_loss(
 a_dim = anchor.ndim
 p_dim = positive.ndim
 n_dim = negative.ndim
-check(
+torch._check(
 a_dim == p_dim and p_dim == n_dim,
 lambda: (
 f"The anchor, positive, and negative tensors are expected to have "
@@ -1075,25 +1078,25 @@ def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
 """
 Reference implementation of torch.nn.functional.prelu
 """
-check(
+torch._check(
 isinstance(a, TensorLike),
 lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
 )
-check(
+torch._check(
 isinstance(weight, TensorLike),
 lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
 )

 if weight.numel() != 1:
-check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
+torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
 channel_size = a.shape[1] if a.ndim >= 2 else 1
-check(
+torch._check(
 weight.numel() == channel_size,
 lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
 f" {weight.numel()} and channel size = {channel_size}.",
 )

-check(
+torch._check(
 weight.ndim == 0 or weight.ndim == 1,
 lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
 f"ndim = {weight.ndim}",
@@ -1132,7 +1135,7 @@ def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
 )
 def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType:
 dim = utils.canonicalize_dims(a.ndim, dim)
-check(
+torch._check(
 a.shape[dim] % 2 == 0,
 lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
 )
@@ -1160,8 +1163,8 @@ def pairwise_distance(
 type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
 )
 def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType:
-check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
-check(p >= 0, lambda: "pdist only supports non-negative p values")
+torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
+torch._check(p >= 0, lambda: "pdist only supports non-negative p values")
 # For p == 2 we can use an efficient implementation, but other values of p
 # require creating a much bigger tensor for an intermediate step
 if p == 2:
@@ -148,7 +148,7 @@ def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
 )
 def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
-utils.check(
+torch._check(
 isinstance(a, TensorLike) or isinstance(b, TensorLike),
 lambda: 'Expected either argument a or b to be a Tensor"',
 )
@@ -15,7 +15,6 @@ import torch._logging
 from torch._guards import Source
 from torch._ops import OpOverload
 from torch._prims_common import (
-check,
 elementwise_dtypes,
 ELEMENTWISE_TYPE_PROMOTION_KIND,
 is_boolean_dtype,
@@ -1495,7 +1494,7 @@ class FakeTensorMode(TorchDispatchMode):
 ) = FakeTensor._find_common_device(func, args, kwargs)

 if isinstance(e, FakeTensor):
-check(
+torch._check(
 e.device == common_device,
 lambda: f"FakeTensor is wrapped to wrong device, found {e.device}, expected {common_device}",
 )
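Illustrative note (not part of the commit): the hunks above apply one mechanical rewrite, swapping the internal `check`/`utils.check` helper for `torch._check`, which takes a boolean (or SymBool) condition plus an optional callable that lazily builds the error message and raises RuntimeError when the condition is false. A minimal sketch of the calling convention, with `_validate_probability` being a hypothetical helper used only for this example:

import torch

def _validate_probability(p: float) -> None:
    # Old style (now deprecated): torch._prims_common.check(cond, lambda: "message")
    # New style used throughout the diff above; the lambda is only evaluated on failure:
    torch._check(
        p <= 1 and p >= 0,
        lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
    )

_validate_probability(0.5)    # condition holds, no effect
# _validate_probability(1.5)  # would raise RuntimeError with the formatted message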