Replace _prims_common.check with torch._check* (#103240)

This relands most of the changes from #102219 which were backed out by #103128. However, instead of removing `_prims_common.check`, it adds a warning and a comment mentioning that it will be removed in the future and `torch._check*` should be used instead. As mentioned in https://github.com/pytorch/pytorch/pull/103128#pullrequestreview-1466414415, `_prims_common.check` cannot yet be removed because of some internal usage

Part of #72948

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103240
Approved by: https://github.com/albanD
This commit is contained in:
Kurt Mohler 2023-06-21 00:46:14 +00:00 committed by PyTorch MergeBot
parent f3c3d12efb
commit ee83c646bb
15 changed files with 462 additions and 463 deletions

View File

@ -1161,6 +1161,10 @@ $1 = torch._ops.prims.sin.default($0)""")
def test_mul_complex(self): def test_mul_complex(self):
prims.mul(torch.randn(2), 1 + 1j) prims.mul(torch.randn(2), 1 + 1j)
def test_check_deprecation_warning(self):
with self.assertWarnsRegex(DeprecationWarning, 'will be removed in the future'):
torch._prims_common.check(True, lambda: 'message')
instantiate_device_type_tests(TestPrims, globals()) instantiate_device_type_tests(TestPrims, globals())

View File

@ -936,7 +936,7 @@ def is_warn_always_enabled():
# These error checking functions must be kept consistent with their C++ # These error checking functions must be kept consistent with their C++
# equivalents. Their C++ equivalents are mentioned where applicable. # equivalents. Their C++ equivalents are mentioned where applicable.
def _check_with(error_type, cond, message): def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):
if not isinstance(cond, (builtins.bool, torch.SymBool)): if not isinstance(cond, (builtins.bool, torch.SymBool)):
raise TypeError(f'cond must be a bool, but got {type(cond)}') raise TypeError(f'cond must be a bool, but got {type(cond)}')

View File

@ -149,7 +149,7 @@ def fill_scalar(self, value):
@register_decomposition([aten.fill.Tensor]) @register_decomposition([aten.fill.Tensor])
def fill_tensor(self, value: Tensor): def fill_tensor(self, value: Tensor):
utils.check( torch._check(
value.dim() == 0, value.dim() == 0,
lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions", lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions",
) )
@ -785,14 +785,14 @@ def im2col(
padding: List[int], padding: List[int],
stride: List[int], stride: List[int],
) -> Tensor: ) -> Tensor:
utils.check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported") torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
utils.check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported") torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
utils.check(len(padding) == 2, lambda: "im2col(): only 2D padding supported") torch._check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
utils.check(len(stride) == 2, lambda: "im2col(): only 2D stride supported") torch._check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")
def check_positive(param, param_name, strict=True): def check_positive(param, param_name, strict=True):
cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param) cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
utils.check( torch._check(
cond, lambda: "{param_name} should be greater {'than' zero, but got {param}" cond, lambda: "{param_name} should be greater {'than' zero, but got {param}"
) )
@ -803,7 +803,7 @@ def im2col(
shape = input.shape shape = input.shape
ndim = len(shape) ndim = len(shape)
utils.check( torch._check(
ndim in (3, 4) and all(d != 0 for d in shape[-3:]), ndim in (3, 4) and all(d != 0 for d in shape[-3:]),
lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size " lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size "
f"and non-zero dimensions, but got: {tuple(shape)}", f"and non-zero dimensions, but got: {tuple(shape)}",
@ -814,7 +814,7 @@ def im2col(
shape[-2:], padding, dilation, kernel_size, stride shape[-2:], padding, dilation, kernel_size, stride
) )
) )
utils.check( torch._check(
all(c > 0 for c in output_size), all(c > 0 for c in output_size),
lambda: f"Given an input with spacial size {tuple(shape[-2:])}, " lambda: f"Given an input with spacial size {tuple(shape[-2:])}, "
f"kernel_size={kernel_size}, dilation={dilation}, " f"kernel_size={kernel_size}, dilation={dilation}, "
@ -869,15 +869,15 @@ def col2im(
padding: List[int], padding: List[int],
stride: List[int], stride: List[int],
) -> Tensor: ) -> Tensor:
utils.check(len(output_size) == 2, lambda: "only 2D output_size supported") torch._check(len(output_size) == 2, lambda: "only 2D output_size supported")
utils.check(len(kernel_size) == 2, lambda: "only 2D kernel supported") torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
utils.check(len(dilation) == 2, lambda: "only 2D dilation supported") torch._check(len(dilation) == 2, lambda: "only 2D dilation supported")
utils.check(len(padding) == 2, lambda: "only 2D padding supported") torch._check(len(padding) == 2, lambda: "only 2D padding supported")
utils.check(len(stride) == 2, lambda: "only 2D stride supported") torch._check(len(stride) == 2, lambda: "only 2D stride supported")
def check_positive(param, param_name, strict=True): def check_positive(param, param_name, strict=True):
cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param) cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
utils.check( torch._check(
cond, lambda: "{param_name} should be greater than zero, but got {param}" cond, lambda: "{param_name} should be greater than zero, but got {param}"
) )
@ -889,13 +889,13 @@ def col2im(
shape = input.shape shape = input.shape
ndim = len(shape) ndim = len(shape)
utils.check( torch._check(
ndim in (2, 3) and all(d != 0 for d in shape[-2:]), ndim in (2, 3) and all(d != 0 for d in shape[-2:]),
lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size " lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size "
f"and non-zero dimensions, but got: {tuple(shape)}", f"and non-zero dimensions, but got: {tuple(shape)}",
) )
prod_kernel_size = kernel_size[0] * kernel_size[1] prod_kernel_size = kernel_size[0] * kernel_size[1]
utils.check( torch._check(
shape[-2] % prod_kernel_size == 0, shape[-2] % prod_kernel_size == 0,
lambda: "Expected size of input's first non-batch dimension to be divisible by the " lambda: "Expected size of input's first non-batch dimension to be divisible by the "
f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and " f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and "
@ -908,13 +908,13 @@ def col2im(
) )
] ]
L = col[0] * col[1] L = col[0] * col[1]
utils.check( torch._check(
shape[-1] == L, shape[-1] == L,
lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, " lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
f"dilation={dilation}, padding={padding}, stride={stride}, " f"dilation={dilation}, padding={padding}, stride={stride}, "
f"expected input.size(-1) to be {L} but got {shape[-1]}.", f"expected input.size(-1) to be {L} but got {shape[-1]}.",
) )
utils.check( torch._check(
L > 0, L > 0,
lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, " lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
f"dilation={dilation}, padding={padding}, stride={stride}, " f"dilation={dilation}, padding={padding}, stride={stride}, "
@ -961,7 +961,7 @@ def col2im(
def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float): def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
# According to the CUDA kernel implementation we should have this test; # According to the CUDA kernel implementation we should have this test;
# but it seems to fail tests! # but it seems to fail tests!
# utils.check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}") # torch._check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")
# Mimicking CUDA kernel's behavior for output stride: output follow input's memory format # Mimicking CUDA kernel's behavior for output stride: output follow input's memory format
# This different from TensorIterator's behavior # This different from TensorIterator's behavior
@ -1221,21 +1221,21 @@ def native_group_norm_backward(
) )
utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False) utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False)
utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False) utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False)
utils.check( torch._check(
input.numel() == N * C * HxW, input.numel() == N * C * HxW,
lambda: f"Expect input to have { N * C * HxW} elements", lambda: f"Expect input to have { N * C * HxW} elements",
) )
utils.check( torch._check(
mean.shape == (N, group), mean.shape == (N, group),
lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}", lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}",
) )
utils.check( torch._check(
gamma is None or gamma.numel() == C, gamma is None or gamma.numel() == C,
lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}", lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}",
) )
cpg, _rem = divmod(C, group) cpg, _rem = divmod(C, group)
utils.check( torch._check(
_rem == 0, _rem == 0,
lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}", lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}",
) )
@ -1834,12 +1834,12 @@ def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
device = input.device device = input.device
shape = input.shape shape = input.shape
ndim = len(shape) ndim = len(shape)
utils.check( torch._check(
ndim in (3, 4), ndim in (3, 4),
lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}", lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}",
) )
for d in input.shape[-2:]: for d in input.shape[-2:]:
utils.check( torch._check(
d != 0, d != 0,
lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for " lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for "
f"non-batch dimensions, but input has shape {tuple(shape)}.", f"non-batch dimensions, but input has shape {tuple(shape)}.",
@ -1966,13 +1966,13 @@ def _index_add(
alpha: NumberType = 1, alpha: NumberType = 1,
): ):
dim = utils.canonicalize_dims(x.ndim, dim) dim = utils.canonicalize_dims(x.ndim, dim)
utils.check( torch._check(
index.ndim <= 1, index.ndim <= 1,
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
) )
if alpha != 1: if alpha != 1:
python_type = utils.dtype_to_type(x.dtype) python_type = utils.dtype_to_type(x.dtype)
utils.check( torch._check(
python_type == bool python_type == bool
or utils.is_weakly_lesser_type(type(alpha), python_type), or utils.is_weakly_lesser_type(type(alpha), python_type),
lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
@ -2005,7 +2005,7 @@ def _index_copy(
x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool
): ):
dim = utils.canonicalize_dims(x.ndim, dim) dim = utils.canonicalize_dims(x.ndim, dim)
utils.check( torch._check(
index.ndim <= 1, index.ndim <= 1,
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
) )
@ -2060,19 +2060,19 @@ def uniform_(self, low=0, high=1, generator=None):
def upsample_compute_output_size(input_size, output_size, scale_factors): def upsample_compute_output_size(input_size, output_size, scale_factors):
spatial_dimensions = len(input_size) - 2 spatial_dimensions = len(input_size) - 2
if output_size is not None: if output_size is not None:
utils.check( torch._check(
scale_factors is None, scale_factors is None,
lambda: "Must specify exactly one of output_size and scale_factors", lambda: "Must specify exactly one of output_size and scale_factors",
) )
utils.check(len(output_size) == spatial_dimensions, lambda: "") torch._check(len(output_size) == spatial_dimensions, lambda: "")
return output_size return output_size
if scale_factors is not None: if scale_factors is not None:
# NB: this isn't necessary lol # NB: this isn't necessary lol
utils.check( torch._check(
output_size is None, output_size is None,
lambda: "Must specify exactly one of output_size and scale_factors", lambda: "Must specify exactly one of output_size and scale_factors",
) )
utils.check(len(scale_factors) == spatial_dimensions, lambda: "") torch._check(len(scale_factors) == spatial_dimensions, lambda: "")
output_size = [] output_size = []
for i, s in enumerate(scale_factors): for i, s in enumerate(scale_factors):
if int(s) == s: if int(s) == s:
@ -2080,7 +2080,7 @@ def upsample_compute_output_size(input_size, output_size, scale_factors):
else: else:
output_size.append(sym_int(input_size[i + 2] * s)) output_size.append(sym_int(input_size[i + 2] * s))
return output_size return output_size
utils.check( torch._check(
False, lambda: "Must specify exactly one of output_size and scale_factors" False, lambda: "Must specify exactly one of output_size and scale_factors"
) )
@ -2969,11 +2969,11 @@ def grid_sampler_2d(
padding_mode: int = 0, padding_mode: int = 0,
align_corners: bool = False, align_corners: bool = False,
) -> Tensor: ) -> Tensor:
utils.check( torch._check(
interpolation_mode in (0, 1, 2), interpolation_mode in (0, 1, 2),
lambda: f"Invalid interpolation mode {interpolation_mode}", lambda: f"Invalid interpolation mode {interpolation_mode}",
) )
utils.check( torch._check(
padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}" padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
) )
@ -3110,11 +3110,11 @@ def grid_sampler_2d(
@out_wrapper() @out_wrapper()
@pw_cast_for_opmath @pw_cast_for_opmath
def mv(self, vec): def mv(self, vec):
utils.check( torch._check(
self.dim() == 2 and vec.dim() == 1, self.dim() == 2 and vec.dim() == 1,
lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}", lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}",
) )
utils.check( torch._check(
self.size(1) == vec.size(0), self.size(1) == vec.size(0),
lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})", lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})",
) )
@ -3134,11 +3134,11 @@ def dot(self, other):
elif other.is_conj(): elif other.is_conj():
return torch.vdot(other.conj(), self) return torch.vdot(other.conj(), self)
utils.check( torch._check(
self.dim() == 1 and other.dim() == 1, self.dim() == 1 and other.dim() == 1,
lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors", lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
) )
utils.check( torch._check(
self.dtype == other.dtype, self.dtype == other.dtype,
lambda: f"dot : expected both vectors to have same dtype, but found {self.dtype} and {other.dtype}", lambda: f"dot : expected both vectors to have same dtype, but found {self.dtype} and {other.dtype}",
) )
@ -3149,7 +3149,7 @@ def dot(self, other):
f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively" f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively"
) )
utils.check(self.numel() == other.numel(), numel_error) torch._check(self.numel() == other.numel(), numel_error)
return (self * other).sum() return (self * other).sum()
@ -3296,7 +3296,7 @@ def matmul(tensor1, tensor2):
return tensor1_expanded.bmm(tensor2_expanded).view(output_shape) return tensor1_expanded.bmm(tensor2_expanded).view(output_shape)
else: else:
utils.check(False, lambda: "both arguments to matmul need to be at least 1D") torch._check(False, lambda: "both arguments to matmul need to be at least 1D")
@register_decomposition(aten.upsample_bicubic2d.default) @register_decomposition(aten.upsample_bicubic2d.default)
@ -3373,7 +3373,7 @@ def upsample_bicubic2d_vec(
align_corners: bool, align_corners: bool,
scale_factors: Optional[Tuple[float, float]] = None, scale_factors: Optional[Tuple[float, float]] = None,
) -> Tensor: ) -> Tensor:
utils.check( torch._check(
bool(output_size) + bool(scale_factors) == 1, bool(output_size) + bool(scale_factors) == 1,
lambda: "Must specify exactly one of output_size and scale_factors.", lambda: "Must specify exactly one of output_size and scale_factors.",
) )

View File

@ -72,7 +72,6 @@ from torch._inductor.compile_fx import (
remove_unaligned_input_idxs, remove_unaligned_input_idxs,
static_input, static_input,
) )
from torch._prims_common import check
from torch.multiprocessing.reductions import StorageWeakRef from torch.multiprocessing.reductions import StorageWeakRef
from torch.storage import UntypedStorage from torch.storage import UntypedStorage
from torch.utils import _pytree as pytree from torch.utils import _pytree as pytree
@ -1071,7 +1070,7 @@ class CUDAGraphNode:
self.output_storage_alias.append(UnaliasedStorage) self.output_storage_alias.append(UnaliasedStorage)
continue continue
check( torch._check(
o.is_cuda, o.is_cuda,
lambda: f"Expected all cuda outputs in cuda graph recording. Non cuda output from {self.stack_traces[i]}", lambda: f"Expected all cuda outputs in cuda graph recording. Non cuda output from {self.stack_traces[i]}",
), ),
@ -1447,7 +1446,7 @@ class CUDAGraphNode:
for idx in self.cudagraph_managed_idxs: for idx in self.cudagraph_managed_idxs:
inputs[idx] = None inputs[idx] = None
check( torch._check(
self._check_liveness( self._check_liveness(
self.expected_dead_indices_after_graph, self.path_weakrefs self.expected_dead_indices_after_graph, self.path_weakrefs
), ),
@ -1522,7 +1521,7 @@ def check_memory_pool(device, pool_id, live_storages_ptrs: List[StorageWeakRefWr
addr += block["size"] addr += block["size"]
check( torch._check(
len(unique_storages) == 0, len(unique_storages) == 0,
lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}", lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}",
) )

File diff suppressed because it is too large Load Diff

View File

@ -16,7 +16,6 @@ from torch._prims.debug_prims import register_debug_prims
from torch._prims.nvfuser_prims import register_nvprims from torch._prims.nvfuser_prims import register_nvprims
from torch._prims.rng_prims import register_rng_prims from torch._prims.rng_prims import register_rng_prims
from torch._prims_common import ( from torch._prims_common import (
check,
Dim, Dim,
DimsSequenceType, DimsSequenceType,
DimsType, DimsType,
@ -422,7 +421,7 @@ def _elementwise_meta(
def _complex_only_elementwise_meta(*args, **kwargs): def _complex_only_elementwise_meta(*args, **kwargs):
utils.check( torch._check(
utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported" utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported"
) )
return _elementwise_meta(*args, **kwargs) return _elementwise_meta(*args, **kwargs)
@ -581,7 +580,7 @@ bitwise_not = _make_elementwise_unary_prim(
def _cbrt_aten(a: torch.Tensor) -> Tensor: def _cbrt_aten(a: torch.Tensor) -> Tensor:
utils.check( torch._check(
not a.is_complex(), not a.is_complex(),
lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)", lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)",
) )
@ -1293,10 +1292,9 @@ def _validate_collapse_args(a: Tensor, start: int, end: int) -> None:
# Verifies end is strictly greater than start # Verifies end is strictly greater than start
# (Collapse requires a non-empty interval) # (Collapse requires a non-empty interval)
utils.check( torch._check_value(
end >= start, end >= start,
lambda: f"Attempting to collapse but end, {end}, is less than start, {start}!", lambda: f"Attempting to collapse but end, {end}, is less than start, {start}!",
ValueError,
) )
@ -1823,7 +1821,7 @@ def _as_strided_scatter_meta(
utils.validate_strides(stride) utils.validate_strides(stride)
required_size = utils.compute_required_storage_length(size, stride, storage_offset) required_size = utils.compute_required_storage_length(size, stride, storage_offset)
utils.check( torch._check(
input.numel() >= required_size, input.numel() >= required_size,
lambda: ( lambda: (
f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} " f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} "
@ -1832,7 +1830,7 @@ def _as_strided_scatter_meta(
f"for storage of size {input.numel() * input.element_size()}" f"for storage of size {input.numel() * input.element_size()}"
), ),
) )
utils.check( torch._check(
utils.is_same_shape(src.shape, size), utils.is_same_shape(src.shape, size),
lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}", lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}",
) )
@ -2432,11 +2430,11 @@ def _iota_meta(
device: torch.device, device: torch.device,
requires_grad: bool, requires_grad: bool,
) -> TensorLikeType: ) -> TensorLikeType:
utils.check( torch._check(
utils.is_integer_dtype(dtype), utils.is_integer_dtype(dtype),
lambda: "prims.iota only supports integer dtypes", lambda: "prims.iota only supports integer dtypes",
) )
utils.check(step != 0, lambda: "step must be nonzero") torch._check(step != 0, lambda: "step must be nonzero")
return torch.empty( return torch.empty(
length, length,
dtype=dtype, dtype=dtype,
@ -2532,7 +2530,7 @@ def _empty_permuted_meta(
) -> TensorLikeType: ) -> TensorLikeType:
p_strides = utils.make_contiguous_strides_for([shape[l] for l in physical_layout]) p_strides = utils.make_contiguous_strides_for([shape[l] for l in physical_layout])
dim = len(shape) dim = len(shape)
utils.check( torch._check(
len(physical_layout) == dim, len(physical_layout) == dim,
lambda: ( lambda: (
"Number of dimensions in the tensor input does not match the " "Number of dimensions in the tensor input does not match the "
@ -2543,7 +2541,7 @@ def _empty_permuted_meta(
strides = [0] * len(shape) strides = [0] * len(shape)
seen_dims = set() seen_dims = set()
for p, l in enumerate(physical_layout): for p, l in enumerate(physical_layout):
utils.check( torch._check(
0 <= l < dim, 0 <= l < dim,
lambda: ( lambda: (
f"Dimension out of range (expected to be between 0 and {dim - 1}, but got " f"Dimension out of range (expected to be between 0 and {dim - 1}, but got "
@ -2551,7 +2549,7 @@ def _empty_permuted_meta(
"not currently supported; file an issue if you want it." "not currently supported; file an issue if you want it."
), ),
) )
utils.check(l not in seen_dims, lambda: "Duplicate dim not allowed") torch._check(l not in seen_dims, lambda: "Duplicate dim not allowed")
strides[l] = p_strides[p] strides[l] = p_strides[p]
seen_dims.add(l) seen_dims.add(l)
return TensorMeta( return TensorMeta(
@ -2779,12 +2777,12 @@ def _normal_meta(
device: torch.device, device: torch.device,
requires_grad: bool, requires_grad: bool,
) -> TensorLikeType: ) -> TensorLikeType:
utils.check( torch._check(
std >= 0.0, std >= 0.0,
lambda: f"expected non-negative standard deviation, but got std={std}", lambda: f"expected non-negative standard deviation, but got std={std}",
) )
utils.check( torch._check(
utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}", lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}",
) )

View File

@ -7,6 +7,7 @@ from functools import reduce, cmp_to_key
import operator import operator
import sympy import sympy
import weakref import weakref
import warnings
import torch import torch
from torch import sym_float, sym_int, sym_max from torch import sym_float, sym_int, sym_max
@ -268,7 +269,7 @@ _memory_formats = {
def validate_memory_format(memory_format: torch.memory_format): def validate_memory_format(memory_format: torch.memory_format):
check( torch._check(
memory_format in _memory_formats, memory_format in _memory_formats,
lambda: f"Received unknown memory format {memory_format}!", lambda: f"Received unknown memory format {memory_format}!",
) )
@ -286,7 +287,7 @@ def is_contiguous_for_memory_format( # type: ignore[return]
if memory_format == torch.channels_last_3d: if memory_format == torch.channels_last_3d:
return is_channels_last_contiguous_3d(a) return is_channels_last_contiguous_3d(a)
check( torch._check(
False, False,
lambda: f"is_contiguous received unsupported memory format {memory_format}", lambda: f"is_contiguous received unsupported memory format {memory_format}",
) )
@ -795,13 +796,13 @@ def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]:
newsize = 1 newsize = 1
for i, d in enumerate(shape): for i, d in enumerate(shape):
if d == -1: if d == -1:
check(dim is None, lambda: "only one dimension can be inferred") torch._check(dim is None, lambda: "only one dimension can be inferred")
dim = i dim = i
elif d >= 0: elif d >= 0:
newsize *= d newsize *= d
else: else:
check(False, lambda: f"invalid shape dimension {d}") torch._check(False, lambda: f"invalid shape dimension {d}")
check( torch._check(
numel == newsize or (dim is not None and newsize > 0 and numel % newsize == 0), numel == newsize or (dim is not None and newsize > 0 and numel % newsize == 0),
lambda: f"shape '{list(shape)}' is invalid for input of size {numel}", lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
) )
@ -809,7 +810,7 @@ def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]:
# Convert to list to produce a compatible error message with core # Convert to list to produce a compatible error message with core
# PyTorch, which prints sequences in square brackets. # PyTorch, which prints sequences in square brackets.
shape = list(shape) shape = list(shape)
check( torch._check(
newsize != 0, newsize != 0,
lambda: (f"cannot reshape tensor of 0 elements into shape {shape} because the " lambda: (f"cannot reshape tensor of 0 elements into shape {shape} because the "
f"unspecified dimension size -1 can be any value and is ambiguous"), f"unspecified dimension size -1 can be any value and is ambiguous"),
@ -954,18 +955,18 @@ def check_fp_or_complex(
Checks whether the input is floating point or complex. Checks whether the input is floating point or complex.
If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32 If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32
""" """
check( torch._check(
is_float_dtype(dtype) or is_complex_dtype(dtype), is_float_dtype(dtype) or is_complex_dtype(dtype),
lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}", lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}",
) )
check( torch._check(
allow_low_precision_dtypes or not is_low_precision_dtype(dtype), allow_low_precision_dtypes or not is_low_precision_dtype(dtype),
lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}", lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}",
) )
def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"): def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"):
check( torch._check(
len(A.shape) >= 2, len(A.shape) >= 2,
lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.", lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
) )
@ -1060,11 +1061,11 @@ def get_higher_dtype(
def check_pin_memory(pin_memory: bool): def check_pin_memory(pin_memory: bool):
check(not pin_memory, lambda: "PrimTorch does not support pinned memory", NotImplementedError) torch._check_not_implemented(not pin_memory, lambda: "PrimTorch does not support pinned memory")
def check_layout(layout: torch.layout): def check_layout(layout: torch.layout):
check(layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}", NotImplementedError) torch._check_not_implemented(layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}")
# TODO: maybe unify with can_cast_to? # TODO: maybe unify with can_cast_to?
@ -1485,7 +1486,7 @@ def make_contiguous_strides_for(
def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]: def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
check( torch._check(
len(shape) == 3, len(shape) == 3,
lambda: "Only tensors of rank 3 can use the channels_last_1d memory format", lambda: "Only tensors of rank 3 can use the channels_last_1d memory format",
) )
@ -1503,7 +1504,7 @@ def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]: def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
# TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5? # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5?
check( torch._check(
len(shape) == 4, len(shape) == 4,
lambda: "Only tensors of rank 4 can use the channels_last memory format", lambda: "Only tensors of rank 4 can use the channels_last memory format",
) )
@ -1520,7 +1521,7 @@ def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]: def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
check( torch._check(
len(shape) == 5, len(shape) == 5,
lambda: "Only tensors of rank 5 can use the channels_last_3d memory format", lambda: "Only tensors of rank 5 can use the channels_last_3d memory format",
) )
@ -1654,6 +1655,9 @@ def check_in_bounds_for_storage(
raise ValueError(msg) raise ValueError(msg)
# NOTE: This function should ideally be removed, but some Meta internal models
# packaged with `torch.package` are using it, so it will have to be removed
# at some point in the future when those models no longer use this function.
def check( def check(
b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError
) -> None: ) -> None:
@ -1662,9 +1666,14 @@ def check(
Error message is a callable producing a string (to avoid wasting time Error message is a callable producing a string (to avoid wasting time
string formatting in non-error case, and also to make it easier for torchdynamo string formatting in non-error case, and also to make it easier for torchdynamo
to trace.) to trace.)
.. note:: This function is planned for removal in the future. Please use
`torch._check*` functions instead.
""" """
if not b: warnings.warn(DeprecationWarning((
raise exc_type(s()) "'torch._prims_common.check' will be removed in the future. Please use "
"'torch._check*' functions instead")))
torch._check_with(exc_type, b, s)
# This combines is_channels_last_strides_2d and is_channels_last_strides_3d in # This combines is_channels_last_strides_2d and is_channels_last_strides_3d in

View File

@ -176,13 +176,13 @@ def _safe_copy_out(
# Checks safe cast # Checks safe cast
if exact_dtype: if exact_dtype:
utils.check( torch._check(
copy_from.dtype == copy_to.dtype, copy_from.dtype == copy_to.dtype,
lambda: f"Expected out tensor to have dtype {copy_from.dtype} " lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
f"but got {copy_to.dtype} instead", f"but got {copy_to.dtype} instead",
) )
else: else:
utils.check( torch._check(
utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype), utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, " lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
"but this can't be cast because it is not safe!", "but this can't be cast because it is not safe!",
@ -255,10 +255,9 @@ def out_wrapper(*out_names: str, exact_dtype: bool = False):
_safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type] _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type]
else: else:
assert isinstance(out, Tuple) # type: ignore[arg-type] assert isinstance(out, Tuple) # type: ignore[arg-type]
utils.check( torch._check_type(
len(out) == len(result), len(out) == len(result),
lambda: f"expected tuple of {len(result)} elements but got {len(out)}", lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
TypeError,
) )
for r, o in zip(result, out): for r, o in zip(result, out):
# These two operations are done in-place # These two operations are done in-place

View File

@ -16,7 +16,6 @@ import torch._prims as prims
import torch._prims_common as utils import torch._prims_common as utils
from torch import sym_float, sym_int from torch import sym_float, sym_int
from torch._prims_common import ( from torch._prims_common import (
check,
DeviceLikeType, DeviceLikeType,
Dim, Dim,
DimsSequenceType, DimsSequenceType,
@ -626,7 +625,7 @@ def frac(x: TensorLikeType) -> TensorLikeType:
# imag does not use _make_elementwise_unary_reference because it does not support out # imag does not use _make_elementwise_unary_reference because it does not support out
def imag(a: TensorLikeType) -> TensorLikeType: def imag(a: TensorLikeType) -> TensorLikeType:
assert isinstance(a, TensorLike) assert isinstance(a, TensorLike)
utils.check( torch._check(
utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors." utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors."
) )
return prims.imag(a) return prims.imag(a)
@ -654,7 +653,7 @@ def isinf(a: TensorLikeType) -> TensorLikeType:
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isposinf(a: TensorLikeType) -> TensorLikeType: def isposinf(a: TensorLikeType) -> TensorLikeType:
utils.check( torch._check(
not utils.is_complex_dtype(a.dtype), not utils.is_complex_dtype(a.dtype),
lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}", lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}",
) )
@ -665,7 +664,7 @@ def isposinf(a: TensorLikeType) -> TensorLikeType:
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isneginf(a: TensorLikeType) -> TensorLikeType: def isneginf(a: TensorLikeType) -> TensorLikeType:
utils.check( torch._check(
not utils.is_complex_dtype(a.dtype), not utils.is_complex_dtype(a.dtype),
lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}", lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}",
) )
@ -788,7 +787,7 @@ def nan_to_num(
def _neg_meta(a: TensorLikeType): def _neg_meta(a: TensorLikeType):
check( torch._check(
a.dtype is not torch.bool, a.dtype is not torch.bool,
lambda: ( lambda: (
"Negation, the `-` operator, on a bool tensor is not supported. " "Negation, the `-` operator, on a bool tensor is not supported. "
@ -935,23 +934,20 @@ def _make_elementwise_binary_reference(
a: Union[Tensor, NumberType], a: Union[Tensor, NumberType],
b: Union[Tensor, NumberType], b: Union[Tensor, NumberType],
) -> Tensor: ) -> Tensor:
check( torch._check_value(
supports_lhs_python_scalar or not isinstance(a, Number), supports_lhs_python_scalar or not isinstance(a, Number),
lambda: f"{name}: Received a lhs Python scalar to an elementwise binary " lambda: f"{name}: Received a lhs Python scalar to an elementwise binary "
"operation that does not accept lhs scalars!", "operation that does not accept lhs scalars!",
ValueError,
) )
check( torch._check_value(
supports_rhs_python_scalar or not isinstance(b, Number), supports_rhs_python_scalar or not isinstance(b, Number),
lambda: f"{name}: Received a rhs Python scalar to an elementwise binary " lambda: f"{name}: Received a rhs Python scalar to an elementwise binary "
"operation that does not accept rhs scalars!", "operation that does not accept rhs scalars!",
ValueError,
) )
check( torch._check_value(
supports_two_python_scalars supports_two_python_scalars
or not (isinstance(a, Number) and isinstance(b, Number)), or not (isinstance(a, Number) and isinstance(b, Number)),
lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!", lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!",
ValueError,
) )
a, b = _maybe_broadcast(a, b) a, b = _maybe_broadcast(a, b)
return prim(a, b) return prim(a, b)
@ -1230,7 +1226,7 @@ def floor_divide(
elif utils.is_integer_dtype(dtype): elif utils.is_integer_dtype(dtype):
return _floor_divide_integer(a, b) return _floor_divide_integer(a, b)
else: else:
check(False, lambda: f"{dtype} not supported for floor_divide") torch._check(False, lambda: f"{dtype} not supported for floor_divide")
def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor: def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor:
@ -1374,20 +1370,19 @@ def _check_close_args(
rtol: float, rtol: float,
atol: float, atol: float,
) -> None: ) -> None:
check( torch._check_value(
a.dtype == b.dtype, a.dtype == b.dtype,
lambda: "{0}: Attempting to compare tensors of different dtypes {1} and {2}!".format( lambda: "{0}: Attempting to compare tensors of different dtypes {1} and {2}!".format(
name, a.dtype, b.dtype name, a.dtype, b.dtype
), ),
ValueError,
) )
check( torch._check(
rtol >= 0, rtol >= 0,
lambda: "{0}: rtol must be greater than or equal to zero, but got {1}!".format( lambda: "{0}: rtol must be greater than or equal to zero, but got {1}!".format(
name, rtol name, rtol
), ),
) )
check( torch._check(
atol >= 0, atol >= 0,
lambda: "{0}: atol must be greater than or equal to zero, but got {1}!".format( lambda: "{0}: atol must be greater than or equal to zero, but got {1}!".format(
name, atol name, atol
@ -1678,7 +1673,7 @@ def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
) )
def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
utils.check( torch._check(
isinstance(a, TensorLike) or isinstance(b, TensorLike), isinstance(a, TensorLike) or isinstance(b, TensorLike),
lambda: 'Expected either argument a or b to be a Tensor"', lambda: 'Expected either argument a or b to be a Tensor"',
) )
@ -1736,12 +1731,11 @@ def addcdiv(
if value is not None: if value is not None:
dtype = self.dtype # no scalars allowed, see add dtype = self.dtype # no scalars allowed, see add
python_type = utils.dtype_to_type(dtype) python_type = utils.dtype_to_type(dtype)
check( torch._check_value(
utils.is_weakly_lesser_type(type(value), python_type), utils.is_weakly_lesser_type(type(value), python_type),
lambda: "value argument of type {0} cannot be safely cast to type {1}!".format( lambda: "value argument of type {0} cannot be safely cast to type {1}!".format(
type(value), python_type type(value), python_type
), ),
exc_type=ValueError,
) )
return self + value * tensor1 / tensor2 return self + value * tensor1 / tensor2
@ -1766,12 +1760,11 @@ def addcmul(
if value is not None: if value is not None:
dtype = self.dtype # no scalars allowed, see add dtype = self.dtype # no scalars allowed, see add
python_type = utils.dtype_to_type(dtype) python_type = utils.dtype_to_type(dtype)
check( torch._check_value(
utils.is_weakly_lesser_type(type(value), python_type), utils.is_weakly_lesser_type(type(value), python_type),
lambda: "value argument of type {0} cannot be safely cast to type {1}!".format( lambda: "value argument of type {0} cannot be safely cast to type {1}!".format(
type(value), python_type type(value), python_type
), ),
exc_type=ValueError,
) )
return self + value * tensor1 * tensor2 return self + value * tensor1 * tensor2
@ -1851,7 +1844,7 @@ def where(
raise NotImplementedError raise NotImplementedError
utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True) utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True)
check( torch._check(
pred.dtype is torch.bool, pred.dtype is torch.bool,
lambda: f"expected predicate to be bool, got {pred.dtype}", lambda: f"expected predicate to be bool, got {pred.dtype}",
) )
@ -2229,7 +2222,7 @@ def sum_to_size(
*shape, *shape,
) -> Tensor: ) -> Tensor:
shape = utils.extract_shape_from_varargs(shape, validate=False) shape = utils.extract_shape_from_varargs(shape, validate=False)
utils.check( torch._check(
utils.is_expandable_to(shape, a.shape), utils.is_expandable_to(shape, a.shape),
lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"', lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"',
) )
@ -2402,7 +2395,7 @@ def mean(
if dtype is None: if dtype is None:
dtype = a.dtype dtype = a.dtype
# can't use out wrapper because of this argument # can't use out wrapper because of this argument
check( torch._check(
out is None or out.dtype == dtype, out is None or out.dtype == dtype,
lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead", lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead",
) )
@ -2415,7 +2408,7 @@ def mean(
out=None, out=None,
output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE, output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE,
) )
check( torch._check(
utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
lambda: ( lambda: (
f"mean(): could not infer output dtype. " f"mean(): could not infer output dtype. "
@ -2491,22 +2484,22 @@ def addr(
beta: NumberType = 1, beta: NumberType = 1,
alpha: NumberType = 1, alpha: NumberType = 1,
) -> TensorLikeType: ) -> TensorLikeType:
check( torch._check(
vec1.ndim == 1, vec1.ndim == 1,
lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D", lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D",
) )
check( torch._check(
vec2.ndim == 1, vec2.ndim == 1,
lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D", lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D",
) )
self = self.expand(vec1.shape[0], vec2.shape[0]) self = self.expand(vec1.shape[0], vec2.shape[0])
if utils.is_boolean_dtype(self.dtype): if utils.is_boolean_dtype(self.dtype):
# Integers are accepted for booleans # Integers are accepted for booleans
check( torch._check(
is_weakly_lesser_type(type(beta), int), is_weakly_lesser_type(type(beta), int),
lambda: f"expected bool/int beta but got {type(beta)}", lambda: f"expected bool/int beta but got {type(beta)}",
) )
check( torch._check(
is_weakly_lesser_type(type(alpha), int), is_weakly_lesser_type(type(alpha), int),
lambda: f"expected bool/int alpha but got {type(beta)}", lambda: f"expected bool/int alpha but got {type(beta)}",
) )
@ -2518,11 +2511,11 @@ def addr(
torch.outer(vec1, vec2) if alpha else torch.full_like(self, False), torch.outer(vec1, vec2) if alpha else torch.full_like(self, False),
) )
else: else:
check( torch._check(
is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)), is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)),
lambda: f"cannot safely convert {type(beta)} to {self.dtype}", lambda: f"cannot safely convert {type(beta)} to {self.dtype}",
) )
check( torch._check(
is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)), is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)),
lambda: f"cannot safely convert {type(alpha)} to {self.dtype}", lambda: f"cannot safely convert {type(alpha)} to {self.dtype}",
) )
@ -2712,7 +2705,7 @@ def conj(input: TensorLikeType) -> TensorLikeType:
def constant_pad_nd( def constant_pad_nd(
input: TensorLikeType, pad: List[int], value: NumberType = 0 input: TensorLikeType, pad: List[int], value: NumberType = 0
) -> TensorLikeType: ) -> TensorLikeType:
check( torch._check(
len(pad) % 2 == 0, len(pad) % 2 == 0,
lambda: f"Length of pad must be even but instead it equals {len(pad)}", lambda: f"Length of pad must be even but instead it equals {len(pad)}",
) )
@ -2723,7 +2716,7 @@ def constant_pad_nd(
l_pad = len(pad) // 2 l_pad = len(pad) // 2
l_diff = l_inp - l_pad l_diff = l_inp - l_pad
check( torch._check(
l_inp >= l_pad, l_inp >= l_pad,
lambda: "Length of pad should be no more than twice the number of " lambda: "Length of pad should be no more than twice the number of "
f"dimensions of the input. Pad length is {len(pad)} while the input has " f"dimensions of the input. Pad length is {len(pad)} while the input has "
@ -2748,7 +2741,7 @@ def constant_pad_nd(
for i in range(l_pad): for i in range(l_pad):
pad_idx = len(pad) - ((i + 1) * 2) pad_idx = len(pad) - ((i + 1) * 2)
new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1] new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
check( torch._check(
new_dim > 0, new_dim > 0,
lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding " lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, " f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
@ -2787,7 +2780,7 @@ def constant_pad_nd(
def contiguous( def contiguous(
a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format
) -> Tensor: ) -> Tensor:
check( torch._check(
memory_format != torch.preserve_format, memory_format != torch.preserve_format,
lambda: "preserve memory format is unsupported by the contiguous operator", lambda: "preserve memory format is unsupported by the contiguous operator",
) )
@ -2800,7 +2793,7 @@ def contiguous(
@out_wrapper() @out_wrapper()
def dstack(tensors: TensorSequenceType) -> TensorLikeType: def dstack(tensors: TensorSequenceType) -> TensorLikeType:
check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList") torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
aligned_tensors = atleast_3d(*tensors) aligned_tensors = atleast_3d(*tensors)
return cat(aligned_tensors, 2) return cat(aligned_tensors, 2)
@ -2813,7 +2806,7 @@ def expand(a: Tensor, *shape) -> Tensor:
if len(shape) == 1 and isinstance(shape[0], Sequence): if len(shape) == 1 and isinstance(shape[0], Sequence):
shape = tuple(shape[0]) shape = tuple(shape[0])
check( torch._check(
len(shape) >= len(a.shape), len(shape) >= len(a.shape),
lambda: "expand: the requested shape has too few dimensions!", lambda: "expand: the requested shape has too few dimensions!",
) )
@ -2823,7 +2816,7 @@ def expand(a: Tensor, *shape) -> Tensor:
for idx, x in enumerate(a.shape): for idx, x in enumerate(a.shape):
offset_idx = idx + offset offset_idx = idx + offset
requested_length = shape[offset_idx] requested_length = shape[offset_idx]
check( torch._check(
requested_length == x or x == 1 or requested_length == -1, requested_length == x or x == 1 or requested_length == -1,
lambda: f"expand: attempting to expand a dimension of length {x}!", lambda: f"expand: attempting to expand a dimension of length {x}!",
) )
@ -2917,13 +2910,13 @@ def narrow(
# Supports Tensor overload that was added for XLA: # Supports Tensor overload that was added for XLA:
# https://github.com/pytorch/pytorch/issues/31558 # https://github.com/pytorch/pytorch/issues/31558
if isinstance(start, TensorLike): if isinstance(start, TensorLike):
check( torch._check(
start.dim() == 0 and utils.is_integer_dtype(start.dtype), start.dim() == 0 and utils.is_integer_dtype(start.dtype),
lambda: "start must be an 0-dim integral Tensor.", lambda: "start must be an 0-dim integral Tensor.",
) )
start = start.item() # type: ignore[assignment] start = start.item() # type: ignore[assignment]
check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.") torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
check(length >= 0, lambda: "narrow(): length must be non-negative.") torch._check(length >= 0, lambda: "narrow(): length must be non-negative.")
dim = utils.canonicalize_dim(a.ndim, dim) dim = utils.canonicalize_dim(a.ndim, dim)
dim_length = a.size(dim) dim_length = a.size(dim)
# Start being the end is usually invalid since it's out of bounds. So it's # Start being the end is usually invalid since it's out of bounds. So it's
@ -2934,7 +2927,7 @@ def narrow(
# Note: a dimension isn't being canonicalized here, this reuses # Note: a dimension isn't being canonicalized here, this reuses
# canonicalize_dim because the semantics are similar. # canonicalize_dim because the semantics are similar.
start = utils.canonicalize_dim(dim_length, start) # type: ignore[arg-type] start = utils.canonicalize_dim(dim_length, start) # type: ignore[arg-type]
check( torch._check(
start <= dim_length - length, # type: ignore[arg-type] start <= dim_length - length, # type: ignore[arg-type]
lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).", lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
) )
@ -2993,11 +2986,11 @@ def native_group_norm(
num_groups: int, num_groups: int,
eps: float, eps: float,
) -> Tuple[Tensor, Tensor, Tensor]: ) -> Tuple[Tensor, Tensor, Tensor]:
utils.check( torch._check(
input.ndim >= 2, input.ndim >= 2,
lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
) )
utils.check( torch._check(
num_channels % num_groups == 0, num_channels % num_groups == 0,
lambda: "Expected number of channels in input to be divisible by num_groups, " lambda: "Expected number of channels in input to be divisible by num_groups, "
+ f"but got input of shape {input.shape} and num_groups = {num_groups}", + f"but got input of shape {input.shape} and num_groups = {num_groups}",
@ -3044,7 +3037,7 @@ def native_layer_norm(
eps: float, eps: float,
) -> Tuple[Tensor, Tensor, Tensor]: ) -> Tuple[Tensor, Tensor, Tensor]:
normalized_ndim = len(normalized_shape) normalized_ndim = len(normalized_shape)
utils.check( torch._check(
normalized_ndim >= 1, normalized_ndim >= 1,
lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., " lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., "
+ "containing at least one element, but got normalized_shape = " + "containing at least one element, but got normalized_shape = "
@ -3053,7 +3046,7 @@ def native_layer_norm(
# torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False # torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False
# while torch.Size([1, 2, 3]) == (1, 2, 3) is True # while torch.Size([1, 2, 3]) == (1, 2, 3) is True
# therefore we use tuple(normalized_shape) # therefore we use tuple(normalized_shape)
utils.check( torch._check(
weight is None or weight.shape == tuple(normalized_shape), weight is None or weight.shape == tuple(normalized_shape),
lambda: "Expected weight to be of same shape as normalized_shape, but got " lambda: "Expected weight to be of same shape as normalized_shape, but got "
+ "weight of shape " + "weight of shape "
@ -3061,7 +3054,7 @@ def native_layer_norm(
+ " and normalized_shape = " + " and normalized_shape = "
+ str(normalized_shape), + str(normalized_shape),
) )
utils.check( torch._check(
bias is None or bias.shape == tuple(normalized_shape), bias is None or bias.shape == tuple(normalized_shape),
lambda: "Expected bias to be of same shape as normalized_shape, but got " lambda: "Expected bias to be of same shape as normalized_shape, but got "
+ "bias of shape " + "bias of shape "
@ -3069,7 +3062,7 @@ def native_layer_norm(
+ " and normalized_shape = " + " and normalized_shape = "
+ str(normalized_shape), + str(normalized_shape),
) )
utils.check( torch._check(
input.ndim >= normalized_ndim input.ndim >= normalized_ndim
and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape), and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
lambda: "Given normalized_shape=" lambda: "Given normalized_shape="
@ -3123,12 +3116,12 @@ def _get_unfold_shape_stride(
max_size = 1 if a_ndim == 0 else a_shape[dim] max_size = 1 if a_ndim == 0 else a_shape[dim]
last_stride = 1 if a_ndim == 0 else a_stride[dim] last_stride = 1 if a_ndim == 0 else a_stride[dim]
utils.check( torch._check(
size <= max_size, size <= max_size,
lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}", lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}",
) )
utils.check( torch._check(
step > 0, step > 0,
lambda: f"Step is {step} but must be > 0", lambda: f"Step is {step} but must be > 0",
) )
@ -3146,7 +3139,7 @@ def _get_unfold_shape_stride(
@register_decomposition(aten.repeat) @register_decomposition(aten.repeat)
def repeat(a: Tensor, *repeat_shape) -> Tensor: def repeat(a: Tensor, *repeat_shape) -> Tensor:
repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False) repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False)
utils.check( torch._check(
len(repeat_shape) >= len(a.shape), len(repeat_shape) >= len(a.shape),
lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor", lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
) )
@ -3452,7 +3445,7 @@ def softmax(
# CompositeImplicitAutograd - don't register decomp # CompositeImplicitAutograd - don't register decomp
@out_wrapper() @out_wrapper()
def hstack(tensors: TensorSequenceType) -> TensorLikeType: def hstack(tensors: TensorSequenceType) -> TensorLikeType:
check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList") torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
aligned_tensors = atleast_1d(*tensors) aligned_tensors = atleast_1d(*tensors)
if aligned_tensors[0].ndim == 1: if aligned_tensors[0].ndim == 1:
return cat(aligned_tensors, 0) return cat(aligned_tensors, 0)
@ -3462,7 +3455,7 @@ def hstack(tensors: TensorSequenceType) -> TensorLikeType:
# CompositeImplicitAutograd - don't register decomp # CompositeImplicitAutograd - don't register decomp
@out_wrapper() @out_wrapper()
def vstack(tensors: TensorSequenceType) -> TensorLikeType: def vstack(tensors: TensorSequenceType) -> TensorLikeType:
check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList") torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
aligned_tensors = atleast_2d(*tensors) aligned_tensors = atleast_2d(*tensors)
return cat(aligned_tensors, 0) return cat(aligned_tensors, 0)
@ -3470,17 +3463,16 @@ def vstack(tensors: TensorSequenceType) -> TensorLikeType:
# CompositeImplicitAutograd - don't register decomp # CompositeImplicitAutograd - don't register decomp
def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType: def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType:
dim = utils.canonicalize_dim(a.ndim, dim) dim = utils.canonicalize_dim(a.ndim, dim)
utils.check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty") torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :])) return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :]))
@register_decomposition(aten.unbind) @register_decomposition(aten.unbind)
def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType: def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
dim = utils.canonicalize_dim(t.ndim, dim) dim = utils.canonicalize_dim(t.ndim, dim)
check( torch._check_index(
len(t.shape) > 0, len(t.shape) > 0,
lambda: "Dimension specified as 0 but tensor has no dimensions", lambda: "Dimension specified as 0 but tensor has no dimensions",
IndexError,
) )
if t.shape[dim] == 0: if t.shape[dim] == 0:
return tuple() return tuple()
@ -3499,7 +3491,7 @@ def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
dim = utils.canonicalize_dims(x.ndim, dim) dim = utils.canonicalize_dims(x.ndim, dim)
utils.check( torch._check(
index.ndim <= 1, index.ndim <= 1,
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
) )
@ -3532,12 +3524,12 @@ def _index_fill(
*, *,
inplace: bool, inplace: bool,
): ):
utils.check( torch._check(
index.ndim <= 1, index.ndim <= 1,
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
) )
if isinstance(value, TensorLike): if isinstance(value, TensorLike):
utils.check( torch._check(
value.ndim == 0, value.ndim == 0,
lambda: "Only supports 0-dimensional value tensor. " # type: ignore[union-attr] lambda: "Only supports 0-dimensional value tensor. " # type: ignore[union-attr]
f"Got a tensor with {value.ndim} dimensions.", f"Got a tensor with {value.ndim} dimensions.",
@ -3589,7 +3581,7 @@ def index_add(
@out_wrapper() @out_wrapper()
def index_select(x: TensorLike, dim: int, index: TensorLike): def index_select(x: TensorLike, dim: int, index: TensorLike):
dim = utils.canonicalize_dims(x.ndim, dim) dim = utils.canonicalize_dims(x.ndim, dim)
utils.check( torch._check(
index.ndim <= 1, index.ndim <= 1,
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
) )
@ -3713,7 +3705,7 @@ def tensor_split(
def hsplit( def hsplit(
a: TensorLikeType, indices_or_sections: DimsType a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]: ) -> Tuple[TensorLikeType, ...]:
check( torch._check(
a.ndim >= 1, a.ndim >= 1,
lambda: ( lambda: (
"torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with " "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "
@ -3724,7 +3716,7 @@ def hsplit(
dim = 0 if a.ndim == 1 else 1 dim = 0 if a.ndim == 1 else 1
if isinstance(indices_or_sections, IntLike): if isinstance(indices_or_sections, IntLike):
split_size = indices_or_sections split_size = indices_or_sections
check( torch._check(
(split_size != 0 and a.shape[dim] % split_size == 0), (split_size != 0 and a.shape[dim] % split_size == 0),
lambda: ( lambda: (
"torch.hsplit attempted to split along dimension " "torch.hsplit attempted to split along dimension "
@ -3738,14 +3730,13 @@ def hsplit(
) )
return tensor_split(a, split_size, dim) return tensor_split(a, split_size, dim)
check( torch._check_type(
isinstance(indices_or_sections, (list, tuple)), isinstance(indices_or_sections, (list, tuple)),
lambda: ( lambda: (
"hsplit(): received an invalid combination of arguments. " "hsplit(): received an invalid combination of arguments. "
"Expected indices_or_sections to be of type int, list of ints or tuple of ints " "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
f"but got type {type(indices_or_sections)}" f"but got type {type(indices_or_sections)}"
), ),
exc_type=TypeError,
) )
split_sizes = indices_or_sections split_sizes = indices_or_sections
@ -3756,7 +3747,7 @@ def hsplit(
def vsplit( def vsplit(
a: TensorLikeType, indices_or_sections: DimsType a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]: ) -> Tuple[TensorLikeType, ...]:
check( torch._check(
a.ndim >= 2, a.ndim >= 2,
lambda: ( lambda: (
"torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with " "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with "
@ -3766,7 +3757,7 @@ def vsplit(
) )
if isinstance(indices_or_sections, IntLike): if isinstance(indices_or_sections, IntLike):
split_size = indices_or_sections split_size = indices_or_sections
check( torch._check(
(split_size != 0 and a.shape[0] % split_size == 0), (split_size != 0 and a.shape[0] % split_size == 0),
lambda: ( lambda: (
f"torch.vsplit attempted to split along dimension 0" f"torch.vsplit attempted to split along dimension 0"
@ -3779,14 +3770,13 @@ def vsplit(
) )
return tensor_split(a, split_size, 0) return tensor_split(a, split_size, 0)
check( torch._check_type(
isinstance(indices_or_sections, (list, tuple)), isinstance(indices_or_sections, (list, tuple)),
lambda: ( lambda: (
"vsplit(): received an invalid combination of arguments. " "vsplit(): received an invalid combination of arguments. "
"Expected indices_or_sections to be of type int, list of ints or tuple of ints " "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
f"but got type {type(indices_or_sections)}" f"but got type {type(indices_or_sections)}"
), ),
exc_type=TypeError,
) )
split_sizes = indices_or_sections split_sizes = indices_or_sections
@ -3800,7 +3790,7 @@ def diag(
offset: int = 0, offset: int = 0,
) -> TensorLikeType: ) -> TensorLikeType:
ndim = self.dim() ndim = self.dim()
utils.check( torch._check(
ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D" ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D"
) )
if ndim == 1: if ndim == 1:
@ -3820,7 +3810,7 @@ def diagonal_scatter(
) -> TensorLikeType: ) -> TensorLikeType:
out = utils.clone_preserve_strides(input) out = utils.clone_preserve_strides(input)
diag = out.diagonal(offset, dim1, dim2) diag = out.diagonal(offset, dim1, dim2)
check( torch._check(
diag.shape == src.shape, diag.shape == src.shape,
lambda: "expected src to have a size equal to the diagonal of the input." lambda: "expected src to have a size equal to the diagonal of the input."
f"Got {src.shape} for a diagonal of shape {diag.shape}", f"Got {src.shape} for a diagonal of shape {diag.shape}",
@ -3843,7 +3833,7 @@ def diagonal(
dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims) dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims)
dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims) dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims)
check( torch._check(
dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
) )
@ -3896,7 +3886,7 @@ def diag_embed(
dim1 = utils.canonicalize_dim(rank=rank, idx=dim1) dim1 = utils.canonicalize_dim(rank=rank, idx=dim1)
dim2 = utils.canonicalize_dim(rank=rank, idx=dim2) dim2 = utils.canonicalize_dim(rank=rank, idx=dim2)
check( torch._check(
dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
) )
@ -3967,7 +3957,7 @@ def t(a: TensorLikeType):
# CompositeImplicitAutograd - don't register decomp # CompositeImplicitAutograd - don't register decomp
def T(a: TensorLikeType) -> TensorLikeType: def T(a: TensorLikeType) -> TensorLikeType:
# n != 2 && n != 0 is deprecated in regular PyTorch. # n != 2 && n != 0 is deprecated in regular PyTorch.
check( torch._check(
a.ndim in (0, 2), a.ndim in (0, 2),
lambda: ( lambda: (
"The use of `x.T` on tensors of dimension other than 0 or 2 " "The use of `x.T` on tensors of dimension other than 0 or 2 "
@ -4102,7 +4092,7 @@ def empty(
pin_memory: bool = False, pin_memory: bool = False,
memory_format: torch.memory_format = torch.contiguous_format, memory_format: torch.memory_format = torch.contiguous_format,
) -> TensorLikeType: ) -> TensorLikeType:
check( torch._check(
memory_format != torch.preserve_format, memory_format != torch.preserve_format,
lambda: "torch.empty: the Preserve memory format is not supported", lambda: "torch.empty: the Preserve memory format is not supported",
) )
@ -4114,7 +4104,7 @@ def empty(
elif memory_format == torch.channels_last_3d: elif memory_format == torch.channels_last_3d:
strides = utils.make_channels_last_3d_strides_for(shape) strides = utils.make_channels_last_3d_strides_for(shape)
else: # memory_format == torch.channels_last else: # memory_format == torch.channels_last
check( torch._check(
memory_format == torch.channels_last, memory_format == torch.channels_last,
lambda: f"torch.empty: received an unknown memory format {memory_format}!", lambda: f"torch.empty: received an unknown memory format {memory_format}!",
) )
@ -4398,8 +4388,8 @@ def arange(
if end is None: if end is None:
end = start end = start
start = 0 start = 0
utils.check(step != 0, lambda: "step must be nonzero") torch._check(step != 0, lambda: "step must be nonzero")
utils.check( torch._check(
(step > 0 and end >= start) or (step < 0 and end <= start), (step > 0 and end >= start) or (step < 0 and end <= start),
lambda: "upper bound and lower bound inconsistent with step sign", lambda: "upper bound and lower bound inconsistent with step sign",
) )
@ -4407,11 +4397,11 @@ def arange(
def is_finite(x): def is_finite(x):
return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x) return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x)
utils.check( torch._check(
is_finite(start) and is_finite(end), is_finite(start) and is_finite(end),
lambda: f"unsupported range: {start} -> {end}", lambda: f"unsupported range: {start} -> {end}",
) )
utils.check( torch._check(
is_finite(step), is_finite(step),
lambda: f"step must be finite but got {step}", lambda: f"step must be finite but got {step}",
) )
@ -4514,7 +4504,7 @@ def linspace(
if dtype is None: if dtype is None:
dtype = default_complex_dtype dtype = default_complex_dtype
else: else:
check( torch._check(
utils.is_complex_dtype(dtype), utils.is_complex_dtype(dtype),
lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}", lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
) )
@ -4523,13 +4513,12 @@ def linspace(
assert isinstance(dtype, torch.dtype) assert isinstance(dtype, torch.dtype)
# steps does not participate in the computation of the dtype # steps does not participate in the computation of the dtype
check( torch._check_type(
isinstance(steps, IntLike), isinstance(steps, IntLike),
lambda: "steps must be int, not float", lambda: "steps must be int, not float",
exc_type=TypeError,
) )
assert isinstance(steps, IntLike) # for mypy assert isinstance(steps, IntLike) # for mypy
check(steps >= 0, lambda: "number of steps must be non-negative") torch._check(steps >= 0, lambda: "number of steps must be non-negative")
factory_kwargs = { factory_kwargs = {
"layout": layout, "layout": layout,
@ -4631,19 +4620,19 @@ def meshgrid(
assert len(tensors) == 1 assert len(tensors) == 1
tensors = tuple(tensors[0]) tensors = tuple(tensors[0])
check( torch._check(
py_all(isinstance(a, TensorLike) for a in tensors), py_all(isinstance(a, TensorLike) for a in tensors),
lambda: "meshgrid expects its inputs to be tensors", lambda: "meshgrid expects its inputs to be tensors",
) )
check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList") torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")
for i in range(len(tensors) - 1): for i in range(len(tensors) - 1):
check( torch._check(
tensors[i].dtype == tensors[i + 1].dtype, # type: ignore[union-attr] tensors[i].dtype == tensors[i + 1].dtype, # type: ignore[union-attr]
lambda: "meshgrid expects all tensors to have the same dtype", lambda: "meshgrid expects all tensors to have the same dtype",
) )
check( torch._check(
tensors[i].device == tensors[i + 1].device, # type: ignore[union-attr] tensors[i].device == tensors[i + 1].device, # type: ignore[union-attr]
lambda: "meshgrid expects all tensors to have the same device", lambda: "meshgrid expects all tensors to have the same device",
) )
@ -4654,7 +4643,7 @@ def meshgrid(
if swap_first_and_second_tensors: if swap_first_and_second_tensors:
tensors = (tensors[1], tensors[0], *tensors[2:]) tensors = (tensors[1], tensors[0], *tensors[2:])
else: else:
check( torch._check(
indexing == "ij", indexing == "ij",
lambda: ( lambda: (
'torch.meshgrid: indexing must be one of "xy" or "ij", ' 'torch.meshgrid: indexing must be one of "xy" or "ij", '
@ -4665,7 +4654,7 @@ def meshgrid(
result_shape: List[int] = [] result_shape: List[int] = []
for t in tensors: for t in tensors:
assert isinstance(t, TensorLike) # mypy assert isinstance(t, TensorLike) # mypy
check( torch._check(
t.ndim == 0 or t.ndim == 1, t.ndim == 0 or t.ndim == 1,
lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}", lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}",
) )
@ -4701,7 +4690,7 @@ def movedim(
# Converts to list to produce a compatible error message with core PyTorch, # Converts to list to produce a compatible error message with core PyTorch,
# which prints sequences in square brackets. # which prints sequences in square brackets.
utils.check( torch._check(
len(source) == len(destination), # type: ignore[arg-type] len(source) == len(destination), # type: ignore[arg-type]
lambda: ( lambda: (
"movedim: Invalid source or destination dims: source " # type: ignore[arg-type] "movedim: Invalid source or destination dims: source " # type: ignore[arg-type]
@ -4718,11 +4707,11 @@ def movedim(
dss = set(ds) dss = set(ds)
# See above on why this converts to list in error messages. # See above on why this converts to list in error messages.
utils.check( torch._check(
len(ss) == len(sss), len(ss) == len(sss),
lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type] lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type]
) )
utils.check( torch._check(
len(ds) == len(dss), len(ds) == len(dss),
lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type] lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type]
) )
@ -4795,8 +4784,8 @@ def eye(
if m is None: if m is None:
m = n m = n
check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}") torch._check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}") torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")
range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False) range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False)
range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False) range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False)
@ -4994,13 +4983,13 @@ def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLi
# NOTE: Could not use value = item(value) as it resulted in # NOTE: Could not use value = item(value) as it resulted in
# RuntimeError: Cannot cast FakeTensor(cpu) to number # RuntimeError: Cannot cast FakeTensor(cpu) to number
value_ndim = value.ndim value_ndim = value.ndim
check( torch._check(
value_ndim == 0, value_ndim == 0,
lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension", lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension",
) )
# `masked_fill` allows cpu scalar to be moved to cuda and xpu but not otherwise. # `masked_fill` allows cpu scalar to be moved to cuda and xpu but not otherwise.
is_cpu_scalar = a.device.type in ["cuda", "xpu"] and value.device.type == "cpu" is_cpu_scalar = a.device.type in ["cuda", "xpu"] and value.device.type == "cpu"
check( torch._check(
is_cpu_scalar or value.device == a.device, is_cpu_scalar or value.device == a.device,
lambda: "Expected `value` to be on same device as `a`", lambda: "Expected `value` to be on same device as `a`",
) )
@ -5011,7 +5000,7 @@ def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLi
# We allow casting `value` to lower type for other case # We allow casting `value` to lower type for other case
# Eg. float -> int. # Eg. float -> int.
# Ref: https://github.com/pytorch/pytorch/issues/79195 # Ref: https://github.com/pytorch/pytorch/issues/79195
check( torch._check(
utils.is_weakly_lesser_type(value_type, python_type), utils.is_weakly_lesser_type(value_type, python_type),
lambda: f"could not convert to type {python_type} without overflow", lambda: f"could not convert to type {python_type} without overflow",
) )
@ -5101,7 +5090,7 @@ def norm(
@register_decomposition(aten.trace) @register_decomposition(aten.trace)
def trace(self: TensorLikeType) -> TensorLikeType: def trace(self: TensorLikeType) -> TensorLikeType:
utils.check( torch._check(
self.ndim == 2, lambda: "expected a matrix, but got tensor with dim {self.ndim}" self.ndim == 2, lambda: "expected a matrix, but got tensor with dim {self.ndim}"
) )
return torch.sum(torch.diag(self, 0)) return torch.sum(torch.diag(self, 0))
@ -5125,7 +5114,7 @@ rpow = _make_r_binary_op(pow)
@register_decomposition(aten.triu) @register_decomposition(aten.triu)
@out_wrapper() @out_wrapper()
def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
utils.check( torch._check(
a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions" a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions"
) )
h, w = a.shape[-2:] h, w = a.shape[-2:]
@ -5142,7 +5131,7 @@ def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
@register_decomposition(aten.tril) @register_decomposition(aten.tril)
@out_wrapper() @out_wrapper()
def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
utils.check( torch._check(
a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions" a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions"
) )
h, w = a.shape[-2:] h, w = a.shape[-2:]
@ -5187,9 +5176,9 @@ def _trilu_checks(
layout: torch.layout, layout: torch.layout,
pin_memory: bool, pin_memory: bool,
): ):
check(row >= 0, lambda: f"row must be non-negative, got {row}") torch._check(row >= 0, lambda: f"row must be non-negative, got {row}")
check(col >= 0, lambda: f"col must be non-negative, got {col}") torch._check(col >= 0, lambda: f"col must be non-negative, got {col}")
check( torch._check(
dtype in (torch.int32, torch.int64), dtype in (torch.int32, torch.int64),
lambda: f"\"{name}\" not implemented for '{dtype}'", lambda: f"\"{name}\" not implemented for '{dtype}'",
) )
@ -5306,7 +5295,7 @@ def bucketize(
out_int32: bool = False, out_int32: bool = False,
right: bool = False, right: bool = False,
): ):
utils.check( torch._check(
boundaries.dim() == 1, boundaries.dim() == 1,
lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})", lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})",
) )
@ -5364,14 +5353,14 @@ def bucketize(
) )
def cauchy(self, median=0, sigma=1, generator=None): def cauchy(self, median=0, sigma=1, generator=None):
assert generator is None assert generator is None
utils.check( torch._check(
not utils.is_complex_dtype(self.dtype) not utils.is_complex_dtype(self.dtype)
and not utils.is_integer_dtype(self.dtype) and not utils.is_integer_dtype(self.dtype)
and not utils.is_boolean_dtype(self.dtype), and not utils.is_boolean_dtype(self.dtype),
lambda: f"Cauchy distribution is a continuous probability distribution. \ lambda: f"Cauchy distribution is a continuous probability distribution. \
dtype must be a floating point but you specified {self.dtype}", dtype must be a floating point but you specified {self.dtype}",
) )
utils.check( torch._check(
sigma > 0.0, sigma > 0.0,
lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}", lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}",
) )
@ -5386,14 +5375,14 @@ def cauchy(self, median=0, sigma=1, generator=None):
) )
def exponential(self, rate=1, generator=None): def exponential(self, rate=1, generator=None):
assert generator is None assert generator is None
utils.check( torch._check(
not utils.is_complex_dtype(self.dtype) not utils.is_complex_dtype(self.dtype)
and not utils.is_integer_dtype(self.dtype) and not utils.is_integer_dtype(self.dtype)
and not utils.is_boolean_dtype(self.dtype), and not utils.is_boolean_dtype(self.dtype),
lambda: f"Exponential distribution is a continuous probability distribution. \ lambda: f"Exponential distribution is a continuous probability distribution. \
dtype must be a floating point but you specified {self.dtype}", dtype must be a floating point but you specified {self.dtype}",
) )
utils.check( torch._check(
rate > 0.0, rate > 0.0,
lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}", lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
) )
@ -5409,12 +5398,12 @@ def exponential(self, rate=1, generator=None):
def geometric(self, p, generator=None): def geometric(self, p, generator=None):
assert generator is None assert generator is None
# TODO: fix inductor rand_like for integer, bool dtypes # TODO: fix inductor rand_like for integer, bool dtypes
utils.check( torch._check(
not utils.is_complex_dtype(self.dtype) not utils.is_complex_dtype(self.dtype)
and not utils.is_boolean_dtype(self.dtype), and not utils.is_boolean_dtype(self.dtype),
lambda: f"geometric not implemented for {self.dtype}", lambda: f"geometric not implemented for {self.dtype}",
) )
utils.check( torch._check(
0 < p and p < 1, 0 < p and p < 1,
lambda: f"geometric_ expects p to be in (0, 1), but got p={p}", lambda: f"geometric_ expects p to be in (0, 1), but got p={p}",
) )
@ -5429,13 +5418,13 @@ def geometric(self, p, generator=None):
) )
def log_normal(self, mean=1, std=2, generator=None): def log_normal(self, mean=1, std=2, generator=None):
assert generator is None assert generator is None
utils.check( torch._check(
not utils.is_complex_dtype(self.dtype) not utils.is_complex_dtype(self.dtype)
and not utils.is_integer_dtype(self.dtype) and not utils.is_integer_dtype(self.dtype)
and not utils.is_boolean_dtype(self.dtype), and not utils.is_boolean_dtype(self.dtype),
lambda: f"log_normal not implemented for {self.dtype}", lambda: f"log_normal not implemented for {self.dtype}",
) )
utils.check( torch._check(
0 < std, 0 < std,
lambda: f"log_normal_ expects std > 0.0, but found std={std}", lambda: f"log_normal_ expects std > 0.0, but found std={std}",
) )
@ -5451,7 +5440,7 @@ def log_normal(self, mean=1, std=2, generator=None):
) )
def normal(self, mean=0, std=1, generator=None): def normal(self, mean=0, std=1, generator=None):
assert generator is None assert generator is None
utils.check(std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}") torch._check(std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}")
normal_samples = prims.normal( normal_samples = prims.normal(
self.shape, self.shape,
mean=0.0, mean=0.0,
@ -5465,7 +5454,7 @@ def normal(self, mean=0, std=1, generator=None):
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def rad2deg(self: TensorLikeType): def rad2deg(self: TensorLikeType):
utils.check( torch._check(
not utils.is_complex_dtype(self.dtype), not utils.is_complex_dtype(self.dtype),
lambda: "rad2deg is not supported for complex tensors.", lambda: "rad2deg is not supported for complex tensors.",
) )
@ -5475,7 +5464,7 @@ def rad2deg(self: TensorLikeType):
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def deg2rad(self: TensorLikeType): def deg2rad(self: TensorLikeType):
utils.check( torch._check(
not utils.is_complex_dtype(self.dtype), not utils.is_complex_dtype(self.dtype),
lambda: "deg2rad is not supported for complex tensors.", lambda: "deg2rad is not supported for complex tensors.",
) )

View File

@ -4,7 +4,7 @@ import torch._prims_common as utils
# Utilities should come BEFORE this import # Utilities should come BEFORE this import
from torch._decomp import register_decomposition from torch._decomp import register_decomposition
from torch._prims_common import check, TensorLikeType from torch._prims_common import TensorLikeType
from torch._prims_common.wrappers import out_wrapper from torch._prims_common.wrappers import out_wrapper
from torch._refs import _broadcast_shapes from torch._refs import _broadcast_shapes
@ -79,14 +79,14 @@ short = _make_conversion_method("short", torch.short)
@out_wrapper(exact_dtype=True) @out_wrapper(exact_dtype=True)
def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType: def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
allowed_dtypes = (torch.float32, torch.float64, torch.float16) allowed_dtypes = (torch.float32, torch.float64, torch.float16)
check( torch._check(
real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes, real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes,
lambda: ( lambda: (
f"Expected both inputs to be Half, Float or Double tensors but got " f"Expected both inputs to be Half, Float or Double tensors but got "
f"{real.dtype} and {imag.dtype}" f"{real.dtype} and {imag.dtype}"
), ),
) )
check( torch._check(
real.dtype == imag.dtype, real.dtype == imag.dtype,
lambda: ( lambda: (
f"Expected object of scalar type {real.dtype} but got " f"Expected object of scalar type {real.dtype} but got "

View File

@ -6,7 +6,7 @@ import torch
import torch._prims as prims import torch._prims as prims
import torch._prims_common as utils import torch._prims_common as utils
from torch._decomp import register_decomposition from torch._decomp import register_decomposition
from torch._prims_common import check, DimsType, ShapeType, TensorLikeType from torch._prims_common import DimsType, ShapeType, TensorLikeType
from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper
__all__ = [ __all__ = [
@ -43,7 +43,7 @@ def _apply_norm(
x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
) -> TensorLikeType: ) -> TensorLikeType:
"""Apply normalization to the un-normalized FFT result""" """Apply normalization to the un-normalized FFT result"""
check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}") torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")
if norm == "ortho": if norm == "ortho":
return x * (1 / math.sqrt(signal_numel)) return x * (1 / math.sqrt(signal_numel))
@ -116,7 +116,9 @@ def _fft_c2r(
input = _maybe_promote_tensor_fft(input, require_complex=True) input = _maybe_promote_tensor_fft(input, require_complex=True)
dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1) last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
check(last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified") torch._check(
last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified"
)
if n is not None: if n is not None:
input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,)) input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))
@ -138,7 +140,7 @@ def _fft_r2c(
onesided: bool, onesided: bool,
) -> TensorLikeType: ) -> TensorLikeType:
"""Common code for performing any real to complex FFT (rfft or ihfft)""" """Common code for performing any real to complex FFT (rfft or ihfft)"""
check( torch._check(
not input.dtype.is_complex, not input.dtype.is_complex,
lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}", lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
) )
@ -162,7 +164,7 @@ def _fft_c2c(
forward: bool, forward: bool,
) -> TensorLikeType: ) -> TensorLikeType:
"""Common code for performing any complex to complex FFT (fft or ifft)""" """Common code for performing any complex to complex FFT (fft or ifft)"""
check( torch._check(
input.dtype.is_complex, input.dtype.is_complex,
lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}", lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
) )
@ -265,20 +267,20 @@ def _canonicalize_fft_shape_and_dim_args(
ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False) ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)
# Check dims are unique # Check dims are unique
check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique") torch._check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique")
if shape is not None: if shape is not None:
if not isinstance(shape, Sequence): if not isinstance(shape, Sequence):
shape = (shape,) shape = (shape,)
# Has shape, might have dim # Has shape, might have dim
check( torch._check(
dim is None or len(dim) == len(shape), dim is None or len(dim) == len(shape),
lambda: "When given, dim and shape arguments must have the same length", lambda: "When given, dim and shape arguments must have the same length",
) )
transform_ndim = len(shape) transform_ndim = len(shape)
check( torch._check(
transform_ndim <= input_dim, transform_ndim <= input_dim,
lambda: f"Got shape with {transform_ndim} values but input tensor " lambda: f"Got shape with {transform_ndim} values but input tensor "
f"only has {input_dim} dimensions.", f"only has {input_dim} dimensions.",
@ -301,7 +303,7 @@ def _canonicalize_fft_shape_and_dim_args(
ret_shape = tuple(input_sizes[d] for d in ret_dims) ret_shape = tuple(input_sizes[d] for d in ret_dims)
for n in ret_shape: for n in ret_shape:
check(n > 0, lambda: f"Invalid number of data points ({n}) specified") torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")
return _ShapeAndDims(shape=ret_shape, dims=ret_dims) return _ShapeAndDims(shape=ret_shape, dims=ret_dims)
@ -323,7 +325,7 @@ def _fftn_c2c(
forward: bool, forward: bool,
) -> TensorLikeType: ) -> TensorLikeType:
"""Common code for n-dimensional complex to complex FFTs (fftn or ifftn)""" """Common code for n-dimensional complex to complex FFTs (fftn or ifftn)"""
check( torch._check(
input.dtype.is_complex, input.dtype.is_complex,
lambda: f"{function_name} expects a complex input tensor, " lambda: f"{function_name} expects a complex input tensor, "
f"but got {input.dtype}", f"but got {input.dtype}",
@ -367,7 +369,7 @@ def rfftn(
dim: Optional[DimsType] = None, dim: Optional[DimsType] = None,
norm: NormType = None, norm: NormType = None,
) -> TensorLikeType: ) -> TensorLikeType:
check( torch._check(
not input.dtype.is_complex, not input.dtype.is_complex,
lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}", lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
) )
@ -386,12 +388,12 @@ def ihfftn(
dim: Optional[DimsType] = None, dim: Optional[DimsType] = None,
norm: NormType = None, norm: NormType = None,
) -> TensorLikeType: ) -> TensorLikeType:
check( torch._check(
not input.dtype.is_complex, not input.dtype.is_complex,
lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}", lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
) )
shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim) shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
check(len(shape) > 0, lambda: "ihfftn must transform at least one axis") torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
input = _maybe_promote_tensor_fft(input, require_complex=False) input = _maybe_promote_tensor_fft(input, require_complex=False)
input = _resize_fft_input(input, dim, shape) input = _resize_fft_input(input, dim, shape)
@ -421,14 +423,14 @@ def _canonicalize_fft_c2r_shape_and_dim_args(
"""Canonicalize shape and dim arguments for n-dimensional c2r transforms, """Canonicalize shape and dim arguments for n-dimensional c2r transforms,
as well as calculating the last_dim_size which is shape[dim[-1]] for the output""" as well as calculating the last_dim_size which is shape[dim[-1]] for the output"""
(shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
check(len(shape) > 0, lambda: f"{fname} must transform at least one axis") torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")
if s is None or s[-1] == -1: if s is None or s[-1] == -1:
last_dim_size = 2 * (input.shape[dim[-1]] - 1) last_dim_size = 2 * (input.shape[dim[-1]] - 1)
else: else:
last_dim_size = shape[-1] last_dim_size = shape[-1]
check( torch._check(
last_dim_size >= 1, last_dim_size >= 1,
lambda: f"Invalid number of data points ({last_dim_size}) specified", lambda: f"Invalid number of data points ({last_dim_size}) specified",
) )

View File

@ -11,7 +11,6 @@ import torch._refs as refs
import torch._refs.linalg as linalg import torch._refs.linalg as linalg
from torch import Tensor from torch import Tensor
from torch._prims_common import ( from torch._prims_common import (
check,
check_fp_or_complex, check_fp_or_complex,
check_is_matrix, check_is_matrix,
Dim, Dim,
@ -29,11 +28,11 @@ def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_nam
Checks related to the dtype kwarg in `linalg.*norm` functions Checks related to the dtype kwarg in `linalg.*norm` functions
""" """
if dtype is not None: if dtype is not None:
check( torch._check(
utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}", lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
) )
check( torch._check(
utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype), utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype),
lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format( lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
fn_name=fn_name, fn_name=fn_name,
@ -41,7 +40,7 @@ def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_nam
dtype=dtype, dtype=dtype,
), ),
) )
check( torch._check(
utils.get_higher_dtype(dtype, x_dtype) == dtype, utils.get_higher_dtype(dtype, x_dtype) == dtype,
lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible " lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
"without narrowing to the specified dtype ({dtype})", "without narrowing to the specified dtype ({dtype})",
@ -79,7 +78,7 @@ def vector_norm(
dim = [dim] # type: ignore[assignment] dim = [dim] # type: ignore[assignment]
if x.numel() == 0 and (ord < 0.0 or ord == float("inf")): if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
check( torch._check(
dim is not None and len(dim) != 0, dim is not None and len(dim) != 0,
lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor " lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
"because the operation does not have an identity", "because the operation does not have an identity",
@ -87,7 +86,7 @@ def vector_norm(
shape = x.shape shape = x.shape
assert dim is not None # mypy does not seem to be able to see through check? assert dim is not None # mypy does not seem to be able to see through check?
for d in dim: for d in dim:
check( torch._check(
shape[d] != 0, shape[d] != 0,
lambda: f"linalg.vector_norm cannot compute the {ord} norm on the " lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
f"dimension {d} because this dimension is empty and the " f"dimension {d} because this dimension is empty and the "
@ -147,8 +146,10 @@ def matrix_norm(
dim = utils.canonicalize_dims(A.ndim, dim) dim = utils.canonicalize_dims(A.ndim, dim)
if isinstance(dim, Dim): if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment] dim = (dim,) # type: ignore[assignment]
check(len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}") torch._check(
check( len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
)
torch._check(
dim[0] != dim[1], dim[0] != dim[1],
lambda: "linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})", lambda: "linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
) )
@ -157,7 +158,7 @@ def matrix_norm(
if isinstance(ord, str): if isinstance(ord, str):
# ord # ord
check( torch._check(
ord in ("fro", "nuc"), ord in ("fro", "nuc"),
lambda: "linalg.matrix_norm: Order {ord} not supported.", lambda: "linalg.matrix_norm: Order {ord} not supported.",
) )
@ -180,7 +181,7 @@ def matrix_norm(
else: else:
# ord # ord
abs_ord = abs(ord) abs_ord = abs(ord)
check( torch._check(
abs_ord in (2, 1, float("inf")), abs_ord in (2, 1, float("inf")),
lambda: "linalg.matrix_norm: Order {ord} not supported.", lambda: "linalg.matrix_norm: Order {ord} not supported.",
) )
@ -224,12 +225,12 @@ def norm(
if dim is not None: if dim is not None:
if isinstance(dim, Dim): if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment] dim = (dim,) # type: ignore[assignment]
check( torch._check(
len(dim) in (1, 2), len(dim) in (1, 2),
lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}", lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
) )
elif ord is not None: elif ord is not None:
check( torch._check(
A.ndim in (1, 2), A.ndim in (1, 2),
lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D", lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
) )

View File

@ -8,7 +8,6 @@ import torch._prims_common as utils
import torch._refs as refs import torch._refs as refs
from torch._decomp import register_decomposition from torch._decomp import register_decomposition
from torch._prims_common import ( from torch._prims_common import (
check,
ELEMENTWISE_TYPE_PROMOTION_KIND, ELEMENTWISE_TYPE_PROMOTION_KIND,
NumberType, NumberType,
ShapeType, ShapeType,
@ -98,7 +97,7 @@ def alpha_dropout(
if not training: if not training:
return self return self
utils.check( torch._check(
p <= 1 and p >= 0, p <= 1 and p >= 0,
lambda: f"dropout probability has to be between 0 and 1, but got, {p}", lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
) )
@ -134,7 +133,7 @@ def _inplace_wrapper(fn):
@wraps(fn) @wraps(fn)
def _fn(a, *args, inplace=False, **kwargs): def _fn(a, *args, inplace=False, **kwargs):
if inplace: if inplace:
check( torch._check(
"out" not in kwargs, "out" not in kwargs,
lambda: "Cannot set inplace=True and pass out= at the same time", lambda: "Cannot set inplace=True and pass out= at the same time",
) )
@ -193,7 +192,7 @@ def dropout(
if not training: if not training:
return a return a
utils.check( torch._check(
p <= 1 and p >= 0, p <= 1 and p >= 0,
lambda: f"dropout probability has to be between 0 and 1, but got, {p}", lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
) )
@ -232,15 +231,15 @@ def elu(
# nb. This should be factored out into a can_cast aux function # nb. This should be factored out into a can_cast aux function
python_type = utils.dtype_to_type(a.dtype) python_type = utils.dtype_to_type(a.dtype)
check( torch._check(
utils.is_weakly_lesser_type(type(input_scale), python_type), utils.is_weakly_lesser_type(type(input_scale), python_type),
lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!", lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!",
) )
check( torch._check(
utils.is_weakly_lesser_type(type(scale), python_type), utils.is_weakly_lesser_type(type(scale), python_type),
lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!", lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!",
) )
check( torch._check(
utils.is_weakly_lesser_type(type(alpha), python_type), utils.is_weakly_lesser_type(type(alpha), python_type),
lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
) )
@ -276,14 +275,14 @@ def group_norm(
""" """
Reference implementation of :func:`torch.nn.functional.group_norm`. Reference implementation of :func:`torch.nn.functional.group_norm`.
""" """
utils.check( torch._check(
input.ndim >= 2, input.ndim >= 2,
lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
) )
batch_size = input.shape[0] batch_size = input.shape[0]
num_channels = input.shape[1] num_channels = input.shape[1]
utils.check( torch._check(
num_channels % num_groups == 0, num_channels % num_groups == 0,
lambda: "Expected number of channels in input to be divisible by num_groups, " lambda: "Expected number of channels in input to be divisible by num_groups, "
+ f"but got input of shape {input.shape} and num_groups = {num_groups}", + f"but got input of shape {input.shape} and num_groups = {num_groups}",
@ -394,7 +393,7 @@ def softmax(
# deprecated. For PrimTorch, it's fine to drop support for deprecated # deprecated. For PrimTorch, it's fine to drop support for deprecated
# behavior because it requires explicit opt in. This error is to inform # behavior because it requires explicit opt in. This error is to inform
# users how to update their calls. # users how to update their calls.
check(dim is not None, lambda: "implicit dim not supported, use dim=X") torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
@ -409,7 +408,7 @@ def softmin(
# deprecated. For PrimTorch, it's fine to drop support for deprecated # deprecated. For PrimTorch, it's fine to drop support for deprecated
# behavior because it requires explicit opt in. This error is to inform # behavior because it requires explicit opt in. This error is to inform
# users how to update their calls. # users how to update their calls.
check(dim is not None, lambda: "implicit dim not supported, use dim=X") torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
return torch.softmax(a=-a, dim=dim, dtype=dtype) # type: ignore[call-overload] return torch.softmax(a=-a, dim=dim, dtype=dtype) # type: ignore[call-overload]
@ -469,7 +468,7 @@ def softshrink(a: TensorLikeType, lambd: float = 0.5):
# softshrink(x) = x - lambd if x > lambd # softshrink(x) = x - lambd if x > lambd
# = x + lambd if x < -lambd # = x + lambd if x < -lambd
# = 0 otherwise # = 0 otherwise
check( torch._check(
lambd >= 0, lambd >= 0,
lambda: f"lambda must be greater or equal to 0, but found to be {lambd}", lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
) )
@ -596,7 +595,7 @@ def log_softmax(
# deprecated. For PrimTorch, it's fine to drop support for deprecated # deprecated. For PrimTorch, it's fine to drop support for deprecated
# behavior because it requires explicit opt in. This error is to inform # behavior because it requires explicit opt in. This error is to inform
# users how to update their calls. # users how to update their calls.
check(dim is not None, lambda: "implicit dim not supported, use dim=X") torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
@ -668,12 +667,12 @@ def _nll_loss_nd(
reduction: str, reduction: str,
ignore_index: int, ignore_index: int,
) -> TensorLikeType: ) -> TensorLikeType:
utils.check( torch._check(
input.ndim > 0 and input.ndim <= 3, input.ndim > 0 and input.ndim <= 3,
lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.", lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.",
) )
utils.check( torch._check(
(input.ndim == 1) or (input.shape[0] == target.shape[0]), (input.ndim == 1) or (input.shape[0] == target.shape[0]),
lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.", lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
) )
@ -693,7 +692,7 @@ def _nll_loss_nd(
(flat_target >= 0), (flat_target < num_classes) (flat_target >= 0), (flat_target < num_classes)
) )
class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask)) class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
utils.check( torch._check(
isinstance(target, FakeTensor) or bool(class_check.item()), isinstance(target, FakeTensor) or bool(class_check.item()),
lambda: "A target class is out-of-bounds and not the ignore index.", lambda: "A target class is out-of-bounds and not the ignore index.",
) )
@ -758,7 +757,7 @@ def nll_loss(
""" """
Reference implementation of torch.nn.functional.nll_loss Reference implementation of torch.nn.functional.nll_loss
""" """
utils.check( torch._check(
input.ndim > 0, input.ndim > 0,
lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})", lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})",
) )
@ -796,9 +795,13 @@ def nll_loss(
# For ndim > 3, we reshape the input and target to 3-D case. # For ndim > 3, we reshape the input and target to 3-D case.
# Input (N batch-size, C classes, k-dimensions) # Input (N batch-size, C classes, k-dimensions)
# Target (N batch-size, k-dimensions) # Target (N batch-size, k-dimensions)
utils.check( torch._check(
input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:], input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:],
lambda: f"Expected target shape {out_size} but got {target.shape}", lambda: (
"Expected input and target to both have ndim > 0 and "
"target.shape[1:] == input.shape[2:], but got "
f"target.shape {target.shape} and input.shape {input.shape}"
),
) )
batch_size = input.shape[0] batch_size = input.shape[0]
@ -837,7 +840,7 @@ def huber_loss(
if type(reduction) is int: if type(reduction) is int:
reduction = _reduction_int_to_str(reduction) reduction = _reduction_int_to_str(reduction)
_check_reduction_value(reduction) # type: ignore[arg-type] _check_reduction_value(reduction) # type: ignore[arg-type]
check( torch._check(
delta > 0, delta > 0,
lambda: "huber_loss does not support non-positive values for delta.", lambda: "huber_loss does not support non-positive values for delta.",
) )
@ -938,7 +941,7 @@ def _triplet_margin_with_distance_loss(
a_dim = anchor.ndim a_dim = anchor.ndim
p_dim = positive.ndim p_dim = positive.ndim
n_dim = negative.ndim n_dim = negative.ndim
check( torch._check(
a_dim == p_dim and p_dim == n_dim, a_dim == p_dim and p_dim == n_dim,
lambda: ( lambda: (
f"The anchor, positive, and negative tensors are expected to have " f"The anchor, positive, and negative tensors are expected to have "
@ -1075,25 +1078,25 @@ def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
""" """
Reference implementation of torch.nn.functional.prelu Reference implementation of torch.nn.functional.prelu
""" """
check( torch._check(
isinstance(a, TensorLike), isinstance(a, TensorLike),
lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}", lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
) )
check( torch._check(
isinstance(weight, TensorLike), isinstance(weight, TensorLike),
lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}", lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
) )
if weight.numel() != 1: if weight.numel() != 1:
check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.") torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
channel_size = a.shape[1] if a.ndim >= 2 else 1 channel_size = a.shape[1] if a.ndim >= 2 else 1
check( torch._check(
weight.numel() == channel_size, weight.numel() == channel_size,
lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers =" lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
f" {weight.numel()} and channel size = {channel_size}.", f" {weight.numel()} and channel size = {channel_size}.",
) )
check( torch._check(
weight.ndim == 0 or weight.ndim == 1, weight.ndim == 0 or weight.ndim == 1,
lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: " lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
f"ndim = {weight.ndim}", f"ndim = {weight.ndim}",
@ -1132,7 +1135,7 @@ def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
) )
def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType: def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType:
dim = utils.canonicalize_dims(a.ndim, dim) dim = utils.canonicalize_dims(a.ndim, dim)
check( torch._check(
a.shape[dim] % 2 == 0, a.shape[dim] % 2 == 0,
lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}", lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
) )
@ -1160,8 +1163,8 @@ def pairwise_distance(
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
) )
def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType: def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType:
check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D") torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
check(p >= 0, lambda: "pdist only supports non-negative p values") torch._check(p >= 0, lambda: "pdist only supports non-negative p values")
# For p == 2 we can use an efficient implementation, but other values of p # For p == 2 we can use an efficient implementation, but other values of p
# require creating a much bigger tensor for an intermediate step # require creating a much bigger tensor for an intermediate step
if p == 2: if p == 2:

View File

@ -148,7 +148,7 @@ def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
) )
def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
utils.check( torch._check(
isinstance(a, TensorLike) or isinstance(b, TensorLike), isinstance(a, TensorLike) or isinstance(b, TensorLike),
lambda: 'Expected either argument a or b to be a Tensor"', lambda: 'Expected either argument a or b to be a Tensor"',
) )

View File

@ -15,7 +15,6 @@ import torch._logging
from torch._guards import Source from torch._guards import Source
from torch._ops import OpOverload from torch._ops import OpOverload
from torch._prims_common import ( from torch._prims_common import (
check,
elementwise_dtypes, elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND, ELEMENTWISE_TYPE_PROMOTION_KIND,
is_boolean_dtype, is_boolean_dtype,
@ -1495,7 +1494,7 @@ class FakeTensorMode(TorchDispatchMode):
) = FakeTensor._find_common_device(func, args, kwargs) ) = FakeTensor._find_common_device(func, args, kwargs)
if isinstance(e, FakeTensor): if isinstance(e, FakeTensor):
check( torch._check(
e.device == common_device, e.device == common_device,
lambda: f"FakeTensor is wrapped to wrong device, found {e.device}, expected {common_device}", lambda: f"FakeTensor is wrapped to wrong device, found {e.device}, expected {common_device}",
) )