From ee83c646bb30d5f11b64013b54174768b733214b Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 21 Jun 2023 00:46:14 +0000 Subject: [PATCH] Replace `_prims_common.check` with `torch._check*` (#103240) This relands most of the changes from #102219 which were backed out by #103128. However, instead of removing `_prims_common.check`, it adds a warning and a comment mentioning that it will be removed in the future and `torch._check*` should be used instead. As mentioned in https://github.com/pytorch/pytorch/pull/103128#pullrequestreview-1466414415, `_prims_common.check` cannot yet be removed because of some internal usage Part of #72948 Pull Request resolved: https://github.com/pytorch/pytorch/pull/103240 Approved by: https://github.com/albanD --- test/test_prims.py | 4 + torch/__init__.py | 2 +- torch/_decomp/decompositions.py | 84 +++--- torch/_inductor/cudagraph_trees.py | 7 +- torch/_meta_registrations.py | 406 +++++++++++++------------- torch/_prims/__init__.py | 26 +- torch/_prims_common/__init__.py | 41 ++- torch/_prims_common/wrappers.py | 7 +- torch/_refs/__init__.py | 219 +++++++------- torch/_refs/_conversions.py | 6 +- torch/_refs/fft.py | 32 +- torch/_refs/linalg/__init__.py | 25 +- torch/_refs/nn/functional/__init__.py | 61 ++-- torch/_refs/special/__init__.py | 2 +- torch/_subclasses/fake_tensor.py | 3 +- 15 files changed, 462 insertions(+), 463 deletions(-) diff --git a/test/test_prims.py b/test/test_prims.py index da1f5a101c3..ad1e63f58a1 100644 --- a/test/test_prims.py +++ b/test/test_prims.py @@ -1161,6 +1161,10 @@ $1 = torch._ops.prims.sin.default($0)""") def test_mul_complex(self): prims.mul(torch.randn(2), 1 + 1j) + def test_check_deprecation_warning(self): + with self.assertWarnsRegex(DeprecationWarning, 'will be removed in the future'): + torch._prims_common.check(True, lambda: 'message') + instantiate_device_type_tests(TestPrims, globals()) diff --git a/torch/__init__.py b/torch/__init__.py index 148336c0424..b4c45f72089 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -936,7 +936,7 @@ def is_warn_always_enabled(): # These error checking functions must be kept consistent with their C++ # equivalents. Their C++ equivalents are mentioned where applicable. 
-def _check_with(error_type, cond, message): +def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]): if not isinstance(cond, (builtins.bool, torch.SymBool)): raise TypeError(f'cond must be a bool, but got {type(cond)}') diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 0c307c86994..8715b1fe314 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -149,7 +149,7 @@ def fill_scalar(self, value): @register_decomposition([aten.fill.Tensor]) def fill_tensor(self, value: Tensor): - utils.check( + torch._check( value.dim() == 0, lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions", ) @@ -785,14 +785,14 @@ def im2col( padding: List[int], stride: List[int], ) -> Tensor: - utils.check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported") - utils.check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported") - utils.check(len(padding) == 2, lambda: "im2col(): only 2D padding supported") - utils.check(len(stride) == 2, lambda: "im2col(): only 2D stride supported") + torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported") + torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported") + torch._check(len(padding) == 2, lambda: "im2col(): only 2D padding supported") + torch._check(len(stride) == 2, lambda: "im2col(): only 2D stride supported") def check_positive(param, param_name, strict=True): cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param) - utils.check( + torch._check( cond, lambda: "{param_name} should be greater {'than' zero, but got {param}" ) @@ -803,7 +803,7 @@ def im2col( shape = input.shape ndim = len(shape) - utils.check( + torch._check( ndim in (3, 4) and all(d != 0 for d in shape[-3:]), lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size " f"and non-zero dimensions, but got: {tuple(shape)}", @@ -814,7 +814,7 @@ def im2col( shape[-2:], padding, dilation, kernel_size, stride ) ) - utils.check( + torch._check( all(c > 0 for c in output_size), lambda: f"Given an input with spacial size {tuple(shape[-2:])}, " f"kernel_size={kernel_size}, dilation={dilation}, " @@ -869,15 +869,15 @@ def col2im( padding: List[int], stride: List[int], ) -> Tensor: - utils.check(len(output_size) == 2, lambda: "only 2D output_size supported") - utils.check(len(kernel_size) == 2, lambda: "only 2D kernel supported") - utils.check(len(dilation) == 2, lambda: "only 2D dilation supported") - utils.check(len(padding) == 2, lambda: "only 2D padding supported") - utils.check(len(stride) == 2, lambda: "only 2D stride supported") + torch._check(len(output_size) == 2, lambda: "only 2D output_size supported") + torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported") + torch._check(len(dilation) == 2, lambda: "only 2D dilation supported") + torch._check(len(padding) == 2, lambda: "only 2D padding supported") + torch._check(len(stride) == 2, lambda: "only 2D stride supported") def check_positive(param, param_name, strict=True): cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param) - utils.check( + torch._check( cond, lambda: "{param_name} should be greater than zero, but got {param}" ) @@ -889,13 +889,13 @@ def col2im( shape = input.shape ndim = len(shape) - utils.check( + torch._check( ndim in (2, 3) and all(d != 0 for d in shape[-2:]), lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch 
size " f"and non-zero dimensions, but got: {tuple(shape)}", ) prod_kernel_size = kernel_size[0] * kernel_size[1] - utils.check( + torch._check( shape[-2] % prod_kernel_size == 0, lambda: "Expected size of input's first non-batch dimension to be divisible by the " f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and " @@ -908,13 +908,13 @@ def col2im( ) ] L = col[0] * col[1] - utils.check( + torch._check( shape[-1] == L, lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, " f"dilation={dilation}, padding={padding}, stride={stride}, " f"expected input.size(-1) to be {L} but got {shape[-1]}.", ) - utils.check( + torch._check( L > 0, lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, " f"dilation={dilation}, padding={padding}, stride={stride}, " @@ -961,7 +961,7 @@ def col2im( def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float): # According to the CUDA kernel implementation we should have this test; # but it seems to fail tests! - # utils.check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}") + # torch._check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}") # Mimicking CUDA kernel's behavior for output stride: output follow input's memory format # This different from TensorIterator's behavior @@ -1221,21 +1221,21 @@ def native_group_norm_backward( ) utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False) utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False) - utils.check( + torch._check( input.numel() == N * C * HxW, lambda: f"Expect input to have { N * C * HxW} elements", ) - utils.check( + torch._check( mean.shape == (N, group), lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}", ) - utils.check( + torch._check( gamma is None or gamma.numel() == C, lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}", ) cpg, _rem = divmod(C, group) - utils.check( + torch._check( _rem == 0, lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}", ) @@ -1834,12 +1834,12 @@ def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]): device = input.device shape = input.shape ndim = len(shape) - utils.check( + torch._check( ndim in (3, 4), lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}", ) for d in input.shape[-2:]: - utils.check( + torch._check( d != 0, lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for " f"non-batch dimensions, but input has shape {tuple(shape)}.", @@ -1966,13 +1966,13 @@ def _index_add( alpha: NumberType = 1, ): dim = utils.canonicalize_dims(x.ndim, dim) - utils.check( + torch._check( index.ndim <= 1, lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", ) if alpha != 1: python_type = utils.dtype_to_type(x.dtype) - utils.check( + torch._check( python_type == bool or utils.is_weakly_lesser_type(type(alpha), python_type), lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", @@ -2005,7 +2005,7 @@ def _index_copy( x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool ): dim = utils.canonicalize_dims(x.ndim, dim) - utils.check( + torch._check( index.ndim <= 1, lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", ) @@ -2060,19 +2060,19 @@ def uniform_(self, low=0, high=1, generator=None): def upsample_compute_output_size(input_size, output_size, scale_factors): 
spatial_dimensions = len(input_size) - 2 if output_size is not None: - utils.check( + torch._check( scale_factors is None, lambda: "Must specify exactly one of output_size and scale_factors", ) - utils.check(len(output_size) == spatial_dimensions, lambda: "") + torch._check(len(output_size) == spatial_dimensions, lambda: "") return output_size if scale_factors is not None: # NB: this isn't necessary lol - utils.check( + torch._check( output_size is None, lambda: "Must specify exactly one of output_size and scale_factors", ) - utils.check(len(scale_factors) == spatial_dimensions, lambda: "") + torch._check(len(scale_factors) == spatial_dimensions, lambda: "") output_size = [] for i, s in enumerate(scale_factors): if int(s) == s: @@ -2080,7 +2080,7 @@ def upsample_compute_output_size(input_size, output_size, scale_factors): else: output_size.append(sym_int(input_size[i + 2] * s)) return output_size - utils.check( + torch._check( False, lambda: "Must specify exactly one of output_size and scale_factors" ) @@ -2969,11 +2969,11 @@ def grid_sampler_2d( padding_mode: int = 0, align_corners: bool = False, ) -> Tensor: - utils.check( + torch._check( interpolation_mode in (0, 1, 2), lambda: f"Invalid interpolation mode {interpolation_mode}", ) - utils.check( + torch._check( padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}" ) @@ -3110,11 +3110,11 @@ def grid_sampler_2d( @out_wrapper() @pw_cast_for_opmath def mv(self, vec): - utils.check( + torch._check( self.dim() == 2 and vec.dim() == 1, lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}", ) - utils.check( + torch._check( self.size(1) == vec.size(0), lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})", ) @@ -3134,11 +3134,11 @@ def dot(self, other): elif other.is_conj(): return torch.vdot(other.conj(), self) - utils.check( + torch._check( self.dim() == 1 and other.dim() == 1, lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors", ) - utils.check( + torch._check( self.dtype == other.dtype, lambda: f"dot : expected both vectors to have same dtype, but found {self.dtype} and {other.dtype}", ) @@ -3149,7 +3149,7 @@ def dot(self, other): f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively" ) - utils.check(self.numel() == other.numel(), numel_error) + torch._check(self.numel() == other.numel(), numel_error) return (self * other).sum() @@ -3296,7 +3296,7 @@ def matmul(tensor1, tensor2): return tensor1_expanded.bmm(tensor2_expanded).view(output_shape) else: - utils.check(False, lambda: "both arguments to matmul need to be at least 1D") + torch._check(False, lambda: "both arguments to matmul need to be at least 1D") @register_decomposition(aten.upsample_bicubic2d.default) @@ -3373,7 +3373,7 @@ def upsample_bicubic2d_vec( align_corners: bool, scale_factors: Optional[Tuple[float, float]] = None, ) -> Tensor: - utils.check( + torch._check( bool(output_size) + bool(scale_factors) == 1, lambda: "Must specify exactly one of output_size and scale_factors.", ) diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index e74fe631020..fccdad3bf32 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -72,7 +72,6 @@ from torch._inductor.compile_fx import ( remove_unaligned_input_idxs, static_input, ) -from torch._prims_common import check from torch.multiprocessing.reductions import StorageWeakRef from torch.storage import UntypedStorage from torch.utils import 
_pytree as pytree @@ -1071,7 +1070,7 @@ class CUDAGraphNode: self.output_storage_alias.append(UnaliasedStorage) continue - check( + torch._check( o.is_cuda, lambda: f"Expected all cuda outputs in cuda graph recording. Non cuda output from {self.stack_traces[i]}", ), @@ -1447,7 +1446,7 @@ class CUDAGraphNode: for idx in self.cudagraph_managed_idxs: inputs[idx] = None - check( + torch._check( self._check_liveness( self.expected_dead_indices_after_graph, self.path_weakrefs ), @@ -1522,7 +1521,7 @@ def check_memory_pool(device, pool_id, live_storages_ptrs: List[StorageWeakRefWr addr += block["size"] - check( + torch._check( len(unique_storages) == 0, lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}", ) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index e3c4f1b3a2c..c3460456232 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -13,7 +13,6 @@ from torch._decomp import ( from torch._ops import OpOverload from torch._prims import _elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND from torch._prims_common import ( - check, corresponding_complex_dtype, corresponding_real_dtype, elementwise_dtypes, @@ -63,7 +62,7 @@ def toRealValueType(dtype): def check_inplace_broadcast(self_shape, *args_shape): broadcasted_shape = tuple(_broadcast_shapes(self_shape, *args_shape)) - check( + torch._check( broadcasted_shape == self_shape, lambda: f"output with shape {self_shape} doesn't match the broadcast shape {broadcasted_shape}", ) @@ -73,15 +72,14 @@ def check_inplace_broadcast(self_shape, *args_shape): @out_wrapper() def meta_take(self, index): # Type and device checks - check( + torch._check( index.dtype == torch.long, lambda: f"take(): Expected a long tensor for index, but got {index.dtype}", ) # Index checks - check( + torch._check_index( not (self.numel() == 0 and index.numel() != 0), lambda: "take(): tried to take from an empty tensor", - IndexError, ) return self.new_empty(index.shape) @@ -91,11 +89,11 @@ def meta_take(self, index): def linalg_cross(self, other, *, dim=-1): x_d = self.ndim y_d = other.ndim - check( + torch._check( x_d == y_d, lambda: "linalg.cross: inputs must have the same number of dimensions.", ) - check( + torch._check( self.size(dim) == 3 and other.size(dim) == 3, lambda: ( f"linalg.cross: inputs dimension {dim} must have length 3. " @@ -334,7 +332,7 @@ def linearSolveCheckInputs( A: Tensor, name: str, ): - check( + torch._check( self.device == A.device, lambda: ( f"Expected b and A to be on the same device, but found b on " @@ -342,7 +340,7 @@ def linearSolveCheckInputs( ), ) - check( + torch._check( self.dtype == A.dtype, lambda: ( f"Expected b and A to have the same dtype, but found b of type " @@ -350,7 +348,7 @@ def linearSolveCheckInputs( ), ) - check( + torch._check( A.size(-1) == A.size(-2), lambda: ( f"A must be batches of square matrices, " @@ -358,7 +356,7 @@ def linearSolveCheckInputs( ), ) - check( + torch._check( A.size(-1) == self.size(-2), lambda: ( f"Incompatible matrix sizes for {name}: each A " @@ -373,12 +371,12 @@ def checkFloatingOrComplex( t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True ): dtype = t.dtype - check( + torch._check( t.is_floating_point() or t.is_complex(), lambda: f"{f_name}: Expected a floating point or complex tensor as input. 
Got {dtype}", ) if not allow_low_precision_dtypes: - check( + torch._check( dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble), lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}", ) @@ -386,7 +384,7 @@ def checkFloatingOrComplex( # From aten/src/ATen/native/LinearAlgebraUtils.h def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"): - check( + torch._check( A.dim() >= 2, lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.", ) @@ -400,7 +398,7 @@ def checkInputsSolver( ): squareCheckInputs(A, f_name) checkIsMatrix(B, f_name) - check( + torch._check( A.size(-2) == B.size(-2) if left else A.size(-1) == B.size(-1), lambda: ( f"{f_name}: Incompatible shapes of A and B for the equation " @@ -413,7 +411,7 @@ def checkInputsSolver( def checkSameDevice( fn_name: str, result: Tensor, input: Tensor, result_name: str = "result" ): - check( + torch._check( result.device == input.device, lambda: ( f"{fn_name}: Expected {result_name} and input tensors to be on the same device, but got " @@ -424,7 +422,7 @@ def checkSameDevice( def checkUplo(UPLO: str): UPLO_uppercase = UPLO.upper() - check( + torch._check( len(UPLO) == 1 and (UPLO_uppercase == "U" or UPLO_uppercase == "L"), lambda: f"Expected UPLO argument to be 'L' or 'U', but got {UPLO}", ) @@ -477,20 +475,20 @@ def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = Fals ) @out_wrapper() def linalg_householder_product(input: Tensor, tau: Tensor) -> Tensor: - check( + torch._check( input.ndim >= 2, lambda: "torch.linalg.householder_product: input must have at least 2 dimensions.", ) - check( + torch._check( input.size(-2) >= input.size(-1), lambda: "torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]", ) - check( + torch._check( input.size(-1) >= tau.size(-1), lambda: "torch.linalg.householder_product: input.shape[-1] must be greater than or equal to tau.shape[-1]", ) - check( + torch._check( input.ndim - tau.ndim == 1, lambda: ( f"torch.linalg.householder_product: Expected tau to have one dimension less than input, " @@ -500,7 +498,7 @@ def linalg_householder_product(input: Tensor, tau: Tensor) -> Tensor: if input.ndim > 2: expected_batch_tau_shape = input.shape[:-2] actual_batch_tau_shape = tau.shape[:-1] - check( + torch._check( actual_batch_tau_shape == expected_batch_tau_shape, lambda: ( f"torch.linalg.householder_product: Expected batch dimensions of tau to be " @@ -508,7 +506,7 @@ def linalg_householder_product(input: Tensor, tau: Tensor) -> Tensor: ), ) - check( + torch._check( tau.dtype == input.dtype, lambda: ( f"torch.linalg.householder_product: tau dtype {tau.dtype}" @@ -567,7 +565,7 @@ def linalg_ldl_solve_meta( squareCheckInputs(LD, "torch.linalg.ldl_solve") checkFloatingOrComplex(LD, "torch.linalg.ldl_solve") linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve") - check( + torch._check( B.ndim >= 2, lambda: ( f"torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, " @@ -575,18 +573,18 @@ def linalg_ldl_solve_meta( ), ) expected_pivots_shape = LD.shape[:-1] - check( + torch._check( expected_pivots_shape == pivots.shape, lambda: ( f"torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, " f"but got pivots with shape {pivots.shape} instead" ), ) - check( + torch._check( utils.is_integer_dtype(pivots.dtype), lambda: f"torch.linalg.ldl_solve: Expected pivots to be integers. 
Got {pivots.dtype}", ) - check( + torch._check( LD.dtype == B.dtype, lambda: f"torch.linalg.ldl_solve: LD dtype {LD.dtype} does not match b dtype {B.dtype}", ) @@ -602,7 +600,7 @@ def linalg_ldl_solve_meta( @register_meta([aten.linalg_lu.default, aten.linalg_lu.out]) @out_wrapper("P", "L", "U") def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - check( + torch._check( A.ndim >= 2, lambda: f"linalg.lu: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead", ) @@ -632,7 +630,7 @@ def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> Tuple[Tensor, Tensor, Te def linalg_lu_factor_ex_meta( A: Tensor, *, pivot: bool = True, check_errors: bool = False ) -> Tuple[Tensor, Tensor, Tensor]: - check( + torch._check( A.ndim >= 2, lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead", ) @@ -672,14 +670,14 @@ def linalg_lu_solve_meta( ) -> Tensor: # dtype checkFloatingOrComplex(LU, "torch.linalg.lu_solve") - check( + torch._check( LU.dtype == B.dtype, lambda: ( f"linalg.lu_solve: Expected LU and B to have the same dtype, " f"but found LU of type {LU.dtype} and B of type {B.dtype} instead" ), ) - check( + torch._check( pivots.dtype == torch.int, lambda: "linalg.lu_solve: pivots should be a Tensor of scalar type torch.int32", ) @@ -687,13 +685,13 @@ def linalg_lu_solve_meta( # matrix shapes squareCheckInputs(LU, "torch.linalg.lu_solve") checkInputsSolver(LU, B, left, "linalg.lu_solve") - check( + torch._check( LU.size(-1) == pivots.size(-1), lambda: "linalg.lu_solve: Number of pivots per batch should be same as the dimension of the matrix", ) # batches - check( + torch._check( LU.shape[:-1] == pivots.shape, lambda: ( f"linalg.lu_solve: Expected LU.shape[:-1] and pivots.shape to be the same, " @@ -770,7 +768,7 @@ def _parse_qr_mode(mode: str) -> Tuple[bool, bool]: compute_q = False reduced = True # this is actually irrelevant in this mode else: - check( + torch._check( False, lambda: ( f"qr received unrecognized mode '{mode}' " @@ -1043,11 +1041,11 @@ def meta_pad2d_backward(grad_output, self, padding): output_h = input_h + pad_t + pad_b output_w = input_w + pad_l + pad_r - check( + torch._check( output_w == grad_output.shape[dim_w], lambda: f"gradOutput width unexpected. Expected: {output_w}, Got: {grad_output.shape[dim_w]}", ) - check( + torch._check( output_h == grad_output.shape[dim_h], lambda: f"gradOutput height unexpected. 
Expected: {output_h}, Got: {grad_output.shape[dim_h]}", ) @@ -1057,7 +1055,7 @@ def meta_pad2d_backward(grad_output, self, padding): @register_meta(aten.reflection_pad2d.default) def meta_pad2d(self, padding): valid_dims = self.size(1) != 0 and self.size(2) != 0 - check( + torch._check( (self.ndim == 3 and valid_dims) or (self.ndim == 4 and valid_dims and self.size(3) != 0), lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}", @@ -1086,9 +1084,9 @@ def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1): dim2 = batch1.size(1) dim3 = batch2.size(2) self = self.expand((dim1, dim2, dim3)) - check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") - check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") - check( + torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") + torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") + torch._check( self.dtype == batch1.dtype == batch2.dtype, lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}", ) @@ -1096,7 +1094,7 @@ def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1): batch2_sizes = batch2.shape bs = batch1_sizes[0] contraction_size = batch1_sizes[2] - check( + torch._check( batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size, lambda: ( f"Expected size for first two dimensions of batch2 tensor to be: " @@ -1140,7 +1138,7 @@ def meta__fused_moving_avg_obs_fq_helper( per_row_fake_quant=False, symmetric_quant=False, ): - check( + torch._check( ch_axis < self.dim(), lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()", ) @@ -1149,7 +1147,7 @@ def meta__fused_moving_avg_obs_fq_helper( def dot_check(self, other): - check( + torch._check( self.dim() == 1 and other.dim() == 1, lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors", ) @@ -1163,11 +1161,11 @@ def meta_dot(self, tensor): @register_meta([aten.mm.default]) def meta_mm(a, b): - check(a.dim() == 2, lambda: "a must be 2D") - check(b.dim() == 2, lambda: "b must be 2D") + torch._check(a.dim() == 2, lambda: "a must be 2D") + torch._check(b.dim() == 2, lambda: "b must be 2D") N, M1 = a.shape M2, P = b.shape - check( + torch._check( M1 == M2, lambda: f"a and b must have same reduction dim, but got [{N}, {M1}] X [{M2}, {P}].", ) @@ -1389,7 +1387,7 @@ if torch._C._has_mkldnn: # from check_dim_size() in aten/src/ATen/TensorUtils.cpp. 
def check_dim_size(tensor, dim, dim_size, size): - check( + torch._check( tensor.dim() == dim and tensor.shape[dim_size] == size, lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, " + f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}", @@ -1407,7 +1405,7 @@ def meta_avg_pool2d( divisor_override=None, ): def unpack(name, val): - check( + torch._check( len(val) in [1, 2], lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints", ) @@ -1416,7 +1414,7 @@ def meta_avg_pool2d( return H, W kH, kW = unpack("kernel_size", kernel_size) - check( + torch._check( len(stride) in [0, 1, 2], lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints", ) @@ -1429,7 +1427,7 @@ def meta_avg_pool2d( padH, padW = unpack("padding", padding) - check( + torch._check( divisor_override is None or divisor_override != 0, lambda: "divisor must be not zero", ) @@ -1530,26 +1528,26 @@ def meta_avg_pool2d_backward( divisor_override, ): # From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func. - check( + torch._check( len(kernel_size) == 1 or len(kernel_size) == 2, lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints", ) kH = kernel_size[0] kW = kH if len(kernel_size) == 1 else kernel_size[1] - check( + torch._check( len(stride) == 0 or len(stride) == 1 or len(stride) == 2, lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints", ) dH = kH if len(stride) == 0 else stride[0] dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1] - check( + torch._check( len(padding) == 1 or len(padding) == 2, lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints", ) padH = padding[0] padW = padH if len(padding) == 1 else padding[1] - check( + torch._check( divisor_override is None or divisor_override != 0, lambda: "divisor must be not zero", ) @@ -1602,7 +1600,7 @@ def meta_avg_pool3d( count_include_pad=True, divisor_override=None, ): - check( + torch._check( len(kernel_size) in (1, 3), lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints", ) @@ -1610,7 +1608,7 @@ def meta_avg_pool3d( kH = kT if len(kernel_size) == 1 else kernel_size[1] kW = kT if len(kernel_size) == 1 else kernel_size[2] - check( + torch._check( not stride or len(stride) in (1, 3), lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints", ) @@ -1618,7 +1616,7 @@ def meta_avg_pool3d( dH = kH if not stride else (dT if len(stride) == 1 else stride[1]) dW = kW if not stride else (dT if len(stride) == 1 else stride[2]) - check( + torch._check( len(padding) in (1, 3), lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints", ) @@ -1626,12 +1624,12 @@ def meta_avg_pool3d( padH = padT if len(padding) == 1 else padding[1] padW = padT if len(padding) == 1 else padding[2] - check( + torch._check( input.ndim in (4, 5), lambda: "non-empty 4D or 5D (batch mode) tensor expected for input", ) - check( + torch._check( not divisor_override or divisor_override != 0, lambda: "divisor must be not zero", ) @@ -1689,7 +1687,7 @@ def meta_avg_pool3d_backward( count_include_pad, divisor_override, ): - check( + torch._check( len(kernel_size) in (1, 3), lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints", ) @@ -1697,7 +1695,7 @@ def meta_avg_pool3d_backward( kH = kT if len(kernel_size) == 1 else kernel_size[1] kW = kT if len(kernel_size) == 1 
else kernel_size[2] - check( + torch._check( not stride or len(stride) in (1, 3), lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints", ) @@ -1705,7 +1703,7 @@ def meta_avg_pool3d_backward( dH = kH if not stride else (dT if len(stride) == 1 else stride[1]) dW = kW if not stride else (dT if len(stride) == 1 else stride[2]) - check( + torch._check( len(padding) in (1, 3), lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints", ) @@ -1713,12 +1711,12 @@ def meta_avg_pool3d_backward( padH = padT if len(padding) == 1 else padding[1] padW = padT if len(padding) == 1 else padding[2] - check( + torch._check( input.ndim in (4, 5), lambda: "non-empty 4D or 5D (batch mode) tensor expected for input", ) - check( + torch._check( not divisor_override or divisor_override != 0, lambda: "divisor must be not zero", ) @@ -1759,7 +1757,7 @@ def meta_avg_pool3d_backward( @register_meta(aten._adaptive_avg_pool2d.default) def meta_adaptive_avg_pool2d(self, output_size): - check( + torch._check( self.ndim == 3 or self.ndim == 4, lambda: f"Expected 3D or 4D tensor, but got {self.shape}", ) @@ -1777,7 +1775,7 @@ def meta_adaptive_avg_pool2d(self, output_size): @register_meta(aten._adaptive_avg_pool3d.default) def meta_adaptive_avg_pool3d(self, output_size): - check( + torch._check( self.ndim == 4 or self.ndim == 5, lambda: f"Expected 4D or 5D tensor, but got {self.shape}", ) @@ -1788,16 +1786,16 @@ def meta_adaptive_avg_pool3d(self, output_size): def meta__adaptive_avg_pool2d_backward(grad_out, self): ndim = grad_out.ndim for i in range(1, ndim): - check( + torch._check( grad_out.size(i) > 0, lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \ size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty", ) - check( + torch._check( ndim == 3 or ndim == 4, lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}", ) - check( + torch._check( self.dtype == grad_out.dtype, lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}", ) @@ -1852,30 +1850,28 @@ def nonzero_static(self, *, size: int, fill_value: int = -1): @register_meta([aten.index.Tensor, aten._unsafe_index.Tensor]) def meta_index_Tensor(self, indices): - check(indices, lambda: "at least one index must be provided") + torch._check(bool(indices), lambda: "at least one index must be provided") # aten::index is the internal advanced indexing implementation # checkIndexTensorTypes and expandTensors result: List[Optional[Tensor]] = [] for i, index in enumerate(indices): if index is not None: - check( + torch._check( index.dtype in [torch.long, torch.int, torch.int8, torch.bool], lambda: "tensors used as indices must be long, int, byte or bool tensors", ) if index.dtype in [torch.int8, torch.bool]: nonzero = index.nonzero() k = len(result) - check( + torch._check_index( k + index.ndim <= self.ndim, lambda: f"too many indices for tensor of dimension {self.ndim}", - IndexError, ) for j in range(index.ndim): - check( + torch._check_index( index.shape[j] == self.shape[k + j], lambda: f"The shape of the mask {index.shape} at index {i} " f"does not match the shape of the indexed tensor {self.shape} at index {k + j}", - IndexError, ) result.append(nonzero.select(1, j)) else: @@ -1883,7 +1879,7 @@ def meta_index_Tensor(self, indices): else: result.append(index) indices = result - check( + torch._check( len(indices) <= self.ndim, lambda: f"too many indices for tensor of dimension {self.ndim} (got 
{len(indices)})", ) @@ -1988,20 +1984,20 @@ def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1): dim1 = batch1.size(1) dim2 = batch2.size(2) self = self.expand((dim1, dim2)) - check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") - check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") - check( + torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") + torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") + torch._check( batch1.size(0) == batch2.size(0), lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}", ) - check( + torch._check( batch1.size(2) == batch2.size(1), lambda: ( f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} " f"and {batch2.size(1)}x{batch2.size(2)})" ), ) - check( + torch._check( self.size(0) == dim1 and self.size(1) == dim2, lambda: "self tensor does not match matmul output shape", ) @@ -2015,7 +2011,7 @@ def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1): ] ) def meta__foreach_unaop_(self): - check( + torch._check( isinstance(self, List), lambda: f"Expect List[Tensor] but got {type(self)}", ) @@ -2029,7 +2025,7 @@ def meta__foreach_unaop_(self): ] ) def meta__foreach_unaop(self): - check( + torch._check( isinstance(self, List), lambda: f"Expect List[Tensor] but got {type(self)}", ) @@ -2037,14 +2033,14 @@ def meta__foreach_unaop(self): def _check_foreach_binop_tensor_lists(self, other): - check( + torch._check( isinstance(self, List) and isinstance(other, List), lambda: ( "The first two arguments of must be List[Tensor], " f"but got {type(self)} and {type(other)}." ), ) - check( + torch._check( len(self) > 0 and len(self) == len(other), lambda: ( "self and other must be non-empty and match in length, " @@ -2100,7 +2096,7 @@ def meta__foreach_binop_list(self, other): ] ) def meta__foreach_binop__scalar(self, scalar=1): - check( + torch._check( isinstance(self, List), lambda: f"The first argument of must be List[Tensor], but got {type(self)}.", ) @@ -2115,7 +2111,7 @@ def meta__foreach_binop__scalar(self, scalar=1): ] ) def meta__foreach_binop_scalar(self, scalar=1): - check( + torch._check( isinstance(self, List), lambda: f"The first argument of must be List[Tensor], but got {type(self)}.", ) @@ -2129,15 +2125,15 @@ def meta__foreach_binop_scalar(self, scalar=1): ] ) def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1): - check( + torch._check( all(isinstance(l, List) for l in [self, tensor1, tensor2]), lambda: ( "All arguments of _foreach_addc*_ must be List[Tensor], " f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}" ), ) - check(len(self) > 0, lambda: "input tensor list must not be empty.") - check( + torch._check(len(self) > 0, lambda: "input tensor list must not be empty.") + torch._check( len(self) == len(tensor1) and len(self) == len(tensor2), lambda: "All input tensor lists must have the same length", ) @@ -2150,15 +2146,15 @@ def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1): ] ) def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1): - check( + torch._check( all(isinstance(l, List) for l in [self, tensor1, tensor2]), lambda: ( "All arguments must be List[Tensor], " f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}" ), ) - check(len(self) > 0, lambda: "input tensor list must not be empty.") - check( + torch._check(len(self) > 0, lambda: "input tensor list must not be empty.") + torch._check( len(self) == len(tensor1) and len(self) == len(tensor2), lambda: "All input 
tensor lists must have the same length", ) @@ -2168,7 +2164,7 @@ def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1): @register_meta([aten._foreach_pow.ScalarAndTensor]) def meta__foreach_pow_scalar_and_tensor(self, exponent): - check( + torch._check( isinstance(exponent, List), lambda: f"exponent must be a tensor list but got {type(exponent)}", ) @@ -2177,7 +2173,7 @@ def meta__foreach_pow_scalar_and_tensor(self, exponent): @register_meta([aten._foreach_addcdiv_.Tensor, aten._foreach_addcmul_.Tensor]) def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars): - check( + torch._check( all(isinstance(l, List) for l in [self, tensor1, tensor2]) and isinstance(scalars, torch.Tensor), lambda: ( @@ -2185,8 +2181,8 @@ def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars): f"but got: {type(self)}, {type(tensor1)}, {type(tensor2)}, and {type(scalars)}" ), ) - check(len(self) > 0, lambda: "input tensor list must not be empty.") - check( + torch._check(len(self) > 0, lambda: "input tensor list must not be empty.") + torch._check( len(self) == len(tensor1) and len(self) == len(tensor2), lambda: "All input tensor lists must have the same length", ) @@ -2212,7 +2208,7 @@ def meta__fused_adam_( found_inf=None, ): for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]: - check( + torch._check( isinstance(l, List), lambda: f"exponent must be a tensor list but got {type(l)}", ) @@ -2238,7 +2234,7 @@ def meta__fused_adam( found_inf=None, ): for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]: - check( + torch._check( isinstance(l, List), lambda: f"exponent must be a tensor list but got {type(l)}", ) @@ -2258,17 +2254,17 @@ def meta__fused_adam( @register_meta([aten._int_mm]) @out_wrapper() def meta__int_mm(a, b): - check(a.dim() == 2, lambda: "a must be a 2D tensor") - check(b.dim() == 2, lambda: "b must be a 2D tensor") - check( + torch._check(a.dim() == 2, lambda: "a must be a 2D tensor") + torch._check(b.dim() == 2, lambda: "b must be a 2D tensor") + torch._check( a.dtype is torch.int8, lambda: f"expected self to be int8, got {a.dtype}", ) - check( + torch._check( b.dtype is torch.int8, lambda: f"expected mat2 to be int8, got {b.dtype}", ) - check( + torch._check( a.size(1) == b.size(0), lambda: ( f"Incompatible matrix sizes for _int_mm ({a.size(0)}x{a.size(1)} " @@ -2280,28 +2276,28 @@ def meta__int_mm(a, b): @register_meta(aten._cdist_forward.default) def meta_cdist_forward(x1, x2, p, compute_mode): - check( + torch._check( x1.dim() >= 2, lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D", ) - check( + torch._check( x2.dim() >= 2, lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D", ) - check( + torch._check( x1.size(-1) == x2.size(-1), lambda: f"X1 and X2 must have the same number of columns. 
X1: {x1.size(-1)} X2: {x2.size(-1)}", ) - check( + torch._check( utils.is_float_dtype(x1.dtype), lambda: "cdist only supports floating-point dtypes, X1 got: {x1.dtype}", ) - check( + torch._check( utils.is_float_dtype(x2.dtype), lambda: "cdist only supports floating-point dtypes, X2 got: {x2.dtype}", ) - check(p >= 0, lambda: "cdist only supports non-negative p values") - check( + torch._check(p >= 0, lambda: "cdist only supports non-negative p values") + torch._check( compute_mode in (None, 1, 2), lambda: f"possible modes: None, 1, 2, but was: {compute_mode}", ) @@ -2326,22 +2322,22 @@ def meta_embedding_bag( include_last_offset=False, padding_idx=-1, ): - check( + torch._check( indices.dtype in (torch.long, torch.int), lambda: f"expected indices to be long or int, got {indices.dtype}", ) - check( + torch._check( offsets.dtype in (torch.long, torch.int), lambda: f"expected offsets to be long or int, got {offsets.dtype}", ) - check( + torch._check( utils.is_float_dtype(weight.dtype), lambda: f"expected weight to be floating point type, got {weight.dtype}", ) num_bags = offsets.size(0) if include_last_offset: - check( + torch._check( num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1", ) @@ -2351,19 +2347,19 @@ def meta_embedding_bag( MODE_SUM, MODE_MEAN, MODE_MAX = range(3) if per_sample_weights is not None: - check( + torch._check( mode == MODE_SUM, lambda: "embedding_bag: per_sample_weights only supported with mode='sum'", ) - check( + torch._check( per_sample_weights.dtype == weight.dtype, lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype", ) - check( + torch._check( per_sample_weights.ndim == 1, lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D", ) - check( + torch._check( per_sample_weights.numel() == indices.numel(), lambda: ( f"expected per_sample_weights.numel() ({per_sample_weights.numel()} " @@ -2408,7 +2404,7 @@ def meta_embedding_bag( numBags = offsets.shape[0] if mode == MODE_MAX: if include_last_offset: - check( + torch._check( numBags >= 1, lambda: "include_last_offset: numBags should be at least 1", ) @@ -2477,7 +2473,7 @@ def meta_logical_not_(self): @register_meta(aten.repeat.default) def meta_repeat(self, repeats): - check( + torch._check( len(repeats) >= self.dim(), lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor", ) @@ -2534,17 +2530,17 @@ def meta_round(self, **kwargs): def shift_dtype_check(fn_name, self, val): - check( + torch._check( utils.is_integer_dtype(self.dtype), lambda: f"{fn_name}: Expected input tensor to have an integral dtype. Got {self.dtype}", ) if isinstance(val, torch.Tensor): - check( + torch._check( utils.is_integer_dtype(val.dtype), lambda: f"{fn_name}: Expected shift value to have an integral dtype. Got {val.dtype}", ) else: - check( + torch._check( isinstance(val, IntLike), lambda: f"{fn_name}: Expected shift value to be an int. 
Got {val}", ) @@ -2620,8 +2616,8 @@ def meta_alias(self): def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None): - check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") - check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") + torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") + torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") batch1_sizes = batch1.size() batch2_sizes = batch2.size() @@ -2632,7 +2628,7 @@ def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None): res_cols = batch2_sizes[2] output_size = (bs, res_rows, res_cols) - check( + torch._check( batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size, lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}" f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].", @@ -2643,8 +2639,8 @@ def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None): output = batch2.new_empty(output_size) if not is_bmm and self_baddbmm is not None: - check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor") - check( + torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor") + torch._check( self_baddbmm.size() == output_size, lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}", ) @@ -2689,9 +2685,9 @@ def pooling_output_shape_pad_lr( def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode): - check(stride != 0, lambda: "stride should not be zero") - check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}") - check( + torch._check(stride != 0, lambda: "stride should not be zero") + torch._check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}") + torch._check( pad <= kernelSize // 2, lambda: f"pad should be at most half of kernel size, but got pad={pad} and kernel_size={kernelSize}", ) @@ -2720,15 +2716,15 @@ def pool2d_shape_check( ndim = input.dim() nOutputPlane = nInputPlane - check( + torch._check( kW > 0 and kH > 0, lambda: "kernel size should be greater than zero, but got kH: {kH}, kW: {kW}", ) - check( + torch._check( dW > 0 and dH > 0, lambda: "stride should be greater than zero, but got dH: {dH}, dW: {dW}", ) - check( + torch._check( dilationH > 0 and dilationW > 0, lambda: "dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}", ) @@ -2736,25 +2732,25 @@ def pool2d_shape_check( valid_dims = input.size(1) != 0 and input.size(2) != 0 if memory_format == torch.channels_last: - check( + torch._check( ndim == 4 and valid_dims and input.size(3) != 0, lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout" " with optional 0 dim batch size for input, but got: {input.size()}", ) else: - check( + torch._check( (ndim == 3 and input.size(0) != 0 and valid_dims) or (ndim == 4 and valid_dims and input.size(3) != 0), lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}", ) - check( + torch._check( kW // 2 >= padW and kH // 2 >= padH, lambda: "pad should be smaller than or equal to half of kernel size, but got " f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}", ) - check( + torch._check( outputWidth >= 1 and outputHeight >= 1, lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). " f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). 
" @@ -2788,21 +2784,21 @@ def pool3d_shape_check( ): ndim = input.ndim - check( + torch._check( kT > 0 and kW > 0 and kH > 0, lambda: ( f"kernel size should be greater than zero, but got " f"kT: {kT}, kH: {kH}, kW: {kW}" ), ) - check( + torch._check( dT > 0 and dW > 0 and dH > 0, lambda: ( f"stride should be greater than zero, but got " f"dT: {dT}, dH: {dH}, dW: {dW}" ), ) - check( + torch._check( dilationT > 0 and dilationW > 0 and dilationH > 0, lambda: ( f"dilation should be greater than zero, but got " @@ -2810,7 +2806,7 @@ def pool3d_shape_check( ), ) - check( + torch._check( ndim in (4, 5), lambda: f"{fn_name}: Expected 4D or 5D tensor for input, but got: {input.shape}", ) @@ -2819,7 +2815,7 @@ def pool3d_shape_check( if ndim == 5 and i == 0: # size of batch-dim can be 0. continue - check( + torch._check( input.size(i) > 0, lambda: ( f"{fn_name}: Expected input's non-batch dimensions to have positive length," @@ -2829,7 +2825,7 @@ def pool3d_shape_check( ) if check_input_size: # AveragePool3d - check( + torch._check( itime >= kT and iheight >= kH and iwidth >= kW, lambda: ( f"input image (T: {itime} H: {iheight} W: {iwidth}) smaller than " @@ -2837,7 +2833,7 @@ def pool3d_shape_check( ), ) - check( + torch._check( kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH, lambda: ( f"pad should be smaller than or equal to half of kernel size, but got " @@ -2845,7 +2841,7 @@ def pool3d_shape_check( ), ) - check( + torch._check( otime >= 1 and owidth >= 1 and oheight >= 1, lambda: ( f"Given input size: ({nslices}x{itime}x{iheight}x{iwidth}). " @@ -2914,7 +2910,7 @@ def max_pool2d_checks_and_compute_shape( ): # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp def unpack(name, val): - check( + torch._check( len(val) in [1, 2], lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints", ) @@ -2924,7 +2920,7 @@ def max_pool2d_checks_and_compute_shape( kH, kW = unpack("kernel_size", kernel_size) - check( + torch._check( len(stride) in [0, 1, 2], lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints", ) @@ -2941,17 +2937,17 @@ def max_pool2d_checks_and_compute_shape( memory_format = utils.suggest_memory_format(input) if memory_format == torch.channels_last: - check( + torch._check( input.dim() == 4, lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout", ) elif memory_format == torch.contiguous_format: - check( + torch._check( input.dim() in [3, 4], lambda: "non-empty 3D or 4D (batch mode) tensor expected for input", ) else: - check( + torch._check( False, lambda: "Unsupport memory format. 
Supports only ChannelsLast, Contiguous", ) @@ -2999,7 +2995,7 @@ def meta_max_pool2d_with_indices_backward( self, kernel_size, stride, padding, dilation, ceil_mode ) - check( + torch._check( self.dtype == grad_output.dtype, lambda: f"Expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}", ) @@ -3093,7 +3089,7 @@ def zeros_like( memory_format=None, ): if layout == torch.sparse_coo: - check( + torch._check( memory_format is None, lambda: "memory format option is only supported by strided tensors", ) @@ -3131,20 +3127,18 @@ def zeros_like( @register_meta(aten.select.int) def meta_select(self, dim, index): ndim = self.dim() - check( + torch._check_index( ndim != 0, lambda: "select() cannot be applied to a 0-dim tensor.", - IndexError, ) dim = dim if dim >= 0 else dim + ndim size = self.size(dim) - check( + torch._check_index( not (-index > size or index >= size), lambda: f"select(): index {index} out of range for tensor of size " f"{self.size()} at dimension {dim}", - IndexError, ) index = index if index >= 0 else index + size @@ -3190,13 +3184,13 @@ def ensure_nonempty_size(t, dim): def gather_shape_check(self, dim, index): self_dims = max(self.dim(), 1) index_dims = max(index.dim(), 1) - check( + torch._check( self_dims == index_dims, lambda: "Index tensor must have the same number of dimensions as input tensor", ) for i in range(self_dims): if i != dim: - check( + torch._check( ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i), lambda: f"Size does not match at dimension {i} expected index {index.shape}" + f" to be smaller than self {self.shape} apart from dimension {dim}", @@ -3208,7 +3202,7 @@ def meta_gather(self, dim, index, sparse_grad=False): wrapped_dim = maybe_wrap_dim(dim, self.dim()) is_index_empty = index.numel() == 0 if not is_index_empty: - check( + torch._check( index.dtype == torch.long, lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}", ) @@ -3229,7 +3223,7 @@ def get_operator_enum(reduce_, use_new_options=False): return "REDUCE_MAXIMUM" elif reduce_ == "amin": return "REDUCE_MINIMUM" - check( + torch._check( False, lambda: "reduce argument must be either sum, prod, mean, amax or amin.", ) @@ -3239,20 +3233,20 @@ def get_operator_enum(reduce_, use_new_options=False): return "REDUCE_ADD" elif reduce_ == "multiply": return "REDUCE_MULTIPLY" - check(False, lambda: "reduce argument must be either add or multiply.") + torch._check(False, lambda: "reduce argument must be either add or multiply.") return # From aten/src/ATen/native/ScatterGatherChecks.h def scatter_gather_dtype_check(method_name, self, index, src_opt=None): if index.numel() != 0: - check( + torch._check( index.dtype == torch.long, lambda: f"{method_name}(): Expected dtype int64 for index", ) if src_opt is not None: - check( + torch._check( self.dtype == src_opt.dtype, lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype", ) @@ -3266,7 +3260,7 @@ def ensure_nonempty_dim(dim): def scatter_shape_check(self, dim, index, src_opt=None): if index.numel() == 0: return - check( + torch._check( ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()), lambda: "Index tensor must have the same number of dimensions as self tensor", ) @@ -3292,17 +3286,17 @@ def scatter_shape_check(self, dim, index, src_opt=None): break if src_opt is not None: - check( + torch._check( ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()), lambda: "Index tensor must have the same number of dimensions as self tensor", ) - check( + 
torch._check( not is_wrong_shape, lambda: f"Expected index {index.shape} to be smaller than self {self.shape}" + f" apart from dimension {dim} and to be smaller than src {src_opt.shape}", ) else: - check( + torch._check( not is_wrong_shape, lambda: f"Expected index {index.shape} to be smaller than self {self.shape}" + f" apart from dimension {dim}", @@ -3588,7 +3582,7 @@ def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True): @register_meta([aten.multinomial.default, aten.multinomial.out]) @out_wrapper() def meta_multinomial(input, num_samples, replacement=False, *, generator=None): - check( + torch._check( 0 < input.dim() <= 2, lambda: f"The probabilty distributions dimensions must be 1 or 2, but got {input.dim()}", ) @@ -3607,17 +3601,17 @@ def multiply_integers(vs): def upsample_common_check(input_size, output_size, num_spatial_dims): - check( + torch._check( len(output_size) == num_spatial_dims, lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}", ) expected_input_dims = num_spatial_dims + 2 # N, C, ... - check( + torch._check( len(input_size) == expected_input_dims, lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}", ) - check( + torch._check( all(s > 0 for s in input_size[2:]) and all(s > 0 for s in output_size), lambda: f"Input and output sizes should be greater than 0, but got " f"input size {input_size} and output size {output_size}", @@ -3629,7 +3623,7 @@ def upsample_common_check(input_size, output_size, num_spatial_dims): @register_meta(aten.upsample_nearest1d.default) def upsample_nearest1d(input, output_size, scales=None): - check( + torch._check( input.numel() != 0 or multiply_integers(input.size()[1:]), lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}", ) @@ -3643,7 +3637,7 @@ def upsample_nearest1d(input, output_size, scales=None): @register_meta(aten.upsample_nearest2d.default) def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None): - check( + torch._check( input.numel() != 0 or multiply_integers(input.size()[1:]), lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}", ) @@ -3676,12 +3670,12 @@ def upsample_nearest2d_backward( full_output_size = upsample_common_check( input_size, output_size, num_spatial_dims=2 ) - check( + torch._check( grad_output.ndim == 4, lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}", ) for i in range(4): - check( + torch._check( grad_output.size(i) == full_output_size[i], lambda: ( f"Expected grad_output to have the same shape as output;" @@ -3697,7 +3691,7 @@ def upsample_nearest2d_backward( @register_meta(aten.upsample_nearest3d.default) def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None): - check( + torch._check( input.numel() != 0 or multiply_integers(input.size()[1:]), lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}", ) @@ -3739,29 +3733,29 @@ def meta_sort(self, stable=None, dim=-1, descending=False, values=None, indices= def rnn_cell_checkSizes( input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden ): - check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2") - check( + torch._check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2") + torch._check( input_gates.shape == hidden_gates.shape, lambda: f"{input_gates.shape} != {hidden_gates.shape}", ) gates_size = 
input_gates.size(1) if input_bias is not None: - check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1") - check( + torch._check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1") + torch._check( input_bias.numel() == gates_size, lambda: f"{input_bias.numel()} != {gates_size}", ) - check( + torch._check( input_bias.shape == hidden_bias.shape, lambda: f"{input_bias.shape} != {hidden_bias.shape}", ) - check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2") + torch._check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2") expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor - check( + torch._check( prev_hidden.numel() == expected_prev_hidden_numel, lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})", ) - check( + torch._check( all( x.device == input_gates.device for x in [hidden_gates, input_bias, hidden_bias, prev_hidden] @@ -3879,16 +3873,14 @@ def mkldnn_rnn_layer( def zero_numel_check_dims(self, dim, fn_name): if self.ndim == 0: - check( + torch._check_index( dim == 0 or dim == -1, lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}", - IndexError, ) else: - check( + torch._check_index( self.size(dim) != 0, lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.", - IndexError, ) @@ -3898,7 +3890,7 @@ def check_argmax_argmin(name, self, dim): dim = maybe_wrap_dim(dim, self.dim()) zero_numel_check_dims(self, dim, name) else: - check( + torch._check( self.numel() != 0, lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.", ) @@ -3923,12 +3915,12 @@ def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None): def topk_meta(self, k, dim=-1, largest=True, sorted=True): # From aten/src/ATen/native/Sorting.cpp dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True) - check( + torch._check( k >= 0 and k <= (self.size(dim) if self.dim() > 0 else 1), lambda: "selected index k out of range", ) sliceSize = 1 if self.dim() == 0 else self.size(dim) - check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension") + torch._check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension") topKSize = list(self.shape) if len(topKSize) > 0: @@ -3942,16 +3934,16 @@ legacy_contiguous_memory_format = torch.contiguous_format # From aten/src/ATen/native/cuda/RNN.cu def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace): defined_grad = grad_hy if grad_hy is not None else grad_cy - check(defined_grad.dim() == 2, lambda: "") + torch._check(defined_grad.dim() == 2, lambda: "") exp_size = defined_grad.size() if grad_hy is not None: - check(grad_hy.size() == exp_size, lambda: "") + torch._check(grad_hy.size() == exp_size, lambda: "") if grad_cy is not None: - check(grad_cy.size() == exp_size, lambda: "") - check(cx.size() == exp_size, lambda: "") - check(cy.size() == exp_size, lambda: "") - check(workspace.dim() == 2, lambda: "") - check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "") + torch._check(grad_cy.size() == exp_size, lambda: "") + torch._check(cx.size() == exp_size, lambda: "") + torch._check(cy.size() == exp_size, lambda: "") + torch._check(workspace.dim() == 2, lambda: "") + torch._check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "") # From aten/src/ATen/native/cuda/RNN.cu @@ -4048,7 +4040,7 @@ def meta_upsample_bilinear2d_aa( full_output_size = upsample_common_check( input.size(), output_size, num_spatial_dims=2 ) - check( + torch._check( 
input.numel() != 0 or all(size > 0 for size in input.size()[1:]), lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}", ) @@ -4060,13 +4052,17 @@ def meta_upsample_bilinear2d_aa( # From aten/src/ATen/native/cuda/AmpKernels.cu @register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default) def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale): - check(found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor.") - check(inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor.") - check( + torch._check( + found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor." + ) + torch._check( + inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor." + ) + torch._check( found_inf.dtype.is_floating_point, lambda: "found_inf must be a float tensor.", ) - check( + torch._check( inv_scale.dtype.is_floating_point, lambda: "inv_scale must be a float tensor.", ) diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 29cb75cbb0b..65a408086af 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -16,7 +16,6 @@ from torch._prims.debug_prims import register_debug_prims from torch._prims.nvfuser_prims import register_nvprims from torch._prims.rng_prims import register_rng_prims from torch._prims_common import ( - check, Dim, DimsSequenceType, DimsType, @@ -422,7 +421,7 @@ def _elementwise_meta( def _complex_only_elementwise_meta(*args, **kwargs): - utils.check( + torch._check( utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported" ) return _elementwise_meta(*args, **kwargs) @@ -581,7 +580,7 @@ bitwise_not = _make_elementwise_unary_prim( def _cbrt_aten(a: torch.Tensor) -> Tensor: - utils.check( + torch._check( not a.is_complex(), lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)", ) @@ -1293,10 +1292,9 @@ def _validate_collapse_args(a: Tensor, start: int, end: int) -> None: # Verifies end is strictly greater than start # (Collapse requires a non-empty interval) - utils.check( + torch._check_value( end >= start, lambda: f"Attempting to collapse but end, {end}, is less than start, {start}!", - ValueError, ) @@ -1823,7 +1821,7 @@ def _as_strided_scatter_meta( utils.validate_strides(stride) required_size = utils.compute_required_storage_length(size, stride, storage_offset) - utils.check( + torch._check( input.numel() >= required_size, lambda: ( f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} " @@ -1832,7 +1830,7 @@ def _as_strided_scatter_meta( f"for storage of size {input.numel() * input.element_size()}" ), ) - utils.check( + torch._check( utils.is_same_shape(src.shape, size), lambda: f"expected src to have a size equal to the slice of self. 
src size = {src.shape}, slice size = {size}", ) @@ -2432,11 +2430,11 @@ def _iota_meta( device: torch.device, requires_grad: bool, ) -> TensorLikeType: - utils.check( + torch._check( utils.is_integer_dtype(dtype), lambda: "prims.iota only supports integer dtypes", ) - utils.check(step != 0, lambda: "step must be nonzero") + torch._check(step != 0, lambda: "step must be nonzero") return torch.empty( length, dtype=dtype, @@ -2532,7 +2530,7 @@ def _empty_permuted_meta( ) -> TensorLikeType: p_strides = utils.make_contiguous_strides_for([shape[l] for l in physical_layout]) dim = len(shape) - utils.check( + torch._check( len(physical_layout) == dim, lambda: ( "Number of dimensions in the tensor input does not match the " @@ -2543,7 +2541,7 @@ def _empty_permuted_meta( strides = [0] * len(shape) seen_dims = set() for p, l in enumerate(physical_layout): - utils.check( + torch._check( 0 <= l < dim, lambda: ( f"Dimension out of range (expected to be between 0 and {dim - 1}, but got " @@ -2551,7 +2549,7 @@ def _empty_permuted_meta( "not currently supported; file an issue if you want it." ), ) - utils.check(l not in seen_dims, lambda: "Duplicate dim not allowed") + torch._check(l not in seen_dims, lambda: "Duplicate dim not allowed") strides[l] = p_strides[p] seen_dims.add(l) return TensorMeta( @@ -2779,12 +2777,12 @@ def _normal_meta( device: torch.device, requires_grad: bool, ) -> TensorLikeType: - utils.check( + torch._check( std >= 0.0, lambda: f"expected non-negative standard deviation, but got std={std}", ) - utils.check( + torch._check( utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}", ) diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index b0c86645243..080719754b3 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -7,6 +7,7 @@ from functools import reduce, cmp_to_key import operator import sympy import weakref +import warnings import torch from torch import sym_float, sym_int, sym_max @@ -268,7 +269,7 @@ _memory_formats = { def validate_memory_format(memory_format: torch.memory_format): - check( + torch._check( memory_format in _memory_formats, lambda: f"Received unknown memory format {memory_format}!", ) @@ -286,7 +287,7 @@ def is_contiguous_for_memory_format( # type: ignore[return] if memory_format == torch.channels_last_3d: return is_channels_last_contiguous_3d(a) - check( + torch._check( False, lambda: f"is_contiguous received unsupported memory format {memory_format}", ) @@ -795,13 +796,13 @@ def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]: newsize = 1 for i, d in enumerate(shape): if d == -1: - check(dim is None, lambda: "only one dimension can be inferred") + torch._check(dim is None, lambda: "only one dimension can be inferred") dim = i elif d >= 0: newsize *= d else: - check(False, lambda: f"invalid shape dimension {d}") - check( + torch._check(False, lambda: f"invalid shape dimension {d}") + torch._check( numel == newsize or (dim is not None and newsize > 0 and numel % newsize == 0), lambda: f"shape '{list(shape)}' is invalid for input of size {numel}", ) @@ -809,7 +810,7 @@ def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]: # Convert to list to produce a compatible error message with core # PyTorch, which prints sequences in square brackets. 
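A minimal illustration of the mechanical substitution applied in the torch/_prims_common/__init__.py hunks above: calls of the form check(cond, lambda: msg) become torch._check(cond, lambda: msg). The message stays a callable so the string is only formatted when the condition fails, and torch._check raises RuntimeError in that case. The helper name validate_dim and its message below are invented for the example and are not part of the patch.

import torch

def validate_dim(d: int) -> None:
    # Same shape as the calls rewritten in this hunk: condition first,
    # then a lazy lambda that builds the error text only on failure.
    torch._check(d >= 0, lambda: f"invalid shape dimension {d}")

validate_dim(3)     # passes; the lambda is never evaluated
# validate_dim(-1)  # would raise RuntimeError: invalid shape dimension -1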
shape = list(shape) - check( + torch._check( newsize != 0, lambda: (f"cannot reshape tensor of 0 elements into shape {shape} because the " f"unspecified dimension size -1 can be any value and is ambiguous"), @@ -954,18 +955,18 @@ def check_fp_or_complex( Checks whether the input is floating point or complex. If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32 """ - check( + torch._check( is_float_dtype(dtype) or is_complex_dtype(dtype), lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}", ) - check( + torch._check( allow_low_precision_dtypes or not is_low_precision_dtype(dtype), lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}", ) def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"): - check( + torch._check( len(A.shape) >= 2, lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.", ) @@ -1060,11 +1061,11 @@ def get_higher_dtype( def check_pin_memory(pin_memory: bool): - check(not pin_memory, lambda: "PrimTorch does not support pinned memory", NotImplementedError) + torch._check_not_implemented(not pin_memory, lambda: "PrimTorch does not support pinned memory") def check_layout(layout: torch.layout): - check(layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}", NotImplementedError) + torch._check_not_implemented(layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}") # TODO: maybe unify with can_cast_to? @@ -1485,7 +1486,7 @@ def make_contiguous_strides_for( def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]: - check( + torch._check( len(shape) == 3, lambda: "Only tensors of rank 3 can use the channels_last_1d memory format", ) @@ -1503,7 +1504,7 @@ def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]: def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]: # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5? - check( + torch._check( len(shape) == 4, lambda: "Only tensors of rank 4 can use the channels_last memory format", ) @@ -1520,7 +1521,7 @@ def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]: def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]: - check( + torch._check( len(shape) == 5, lambda: "Only tensors of rank 5 can use the channels_last_3d memory format", ) @@ -1654,6 +1655,9 @@ def check_in_bounds_for_storage( raise ValueError(msg) +# NOTE: This function should ideally be removed, but some Meta internal models +# packaged with `torch.package` are using it, so it will have to be removed +# at some point in the future when those models no longer use this function. def check( b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError ) -> None: @@ -1662,9 +1666,14 @@ def check( Error message is a callable producing a string (to avoid wasting time string formatting in non-error case, and also to make it easier for torchdynamo to trace.) + + .. note:: This function is planned for removal in the future. Please use + `torch._check*` functions instead. """ - if not b: - raise exc_type(s()) + warnings.warn(DeprecationWarning(( + "'torch._prims_common.check' will be removed in the future. 
Please use " + "'torch._check*' functions instead"))) + torch._check_with(exc_type, b, s) # This combines is_channels_last_strides_2d and is_channels_last_strides_3d in diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index 5f10a2a5170..041fb764c8b 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -176,13 +176,13 @@ def _safe_copy_out( # Checks safe cast if exact_dtype: - utils.check( + torch._check( copy_from.dtype == copy_to.dtype, lambda: f"Expected out tensor to have dtype {copy_from.dtype} " f"but got {copy_to.dtype} instead", ) else: - utils.check( + torch._check( utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype), lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, " "but this can't be cast because it is not safe!", @@ -255,10 +255,9 @@ def out_wrapper(*out_names: str, exact_dtype: bool = False): _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type] else: assert isinstance(out, Tuple) # type: ignore[arg-type] - utils.check( + torch._check_type( len(out) == len(result), lambda: f"expected tuple of {len(result)} elements but got {len(out)}", - TypeError, ) for r, o in zip(result, out): # These two operations are done in-place diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index f1778eb6059..a1e66cbdaa9 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -16,7 +16,6 @@ import torch._prims as prims import torch._prims_common as utils from torch import sym_float, sym_int from torch._prims_common import ( - check, DeviceLikeType, Dim, DimsSequenceType, @@ -626,7 +625,7 @@ def frac(x: TensorLikeType) -> TensorLikeType: # imag does not use _make_elementwise_unary_reference because it does not support out def imag(a: TensorLikeType) -> TensorLikeType: assert isinstance(a, TensorLike) - utils.check( + torch._check( utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors." ) return prims.imag(a) @@ -654,7 +653,7 @@ def isinf(a: TensorLikeType) -> TensorLikeType: @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) def isposinf(a: TensorLikeType) -> TensorLikeType: - utils.check( + torch._check( not utils.is_complex_dtype(a.dtype), lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}", ) @@ -665,7 +664,7 @@ def isposinf(a: TensorLikeType) -> TensorLikeType: @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) def isneginf(a: TensorLikeType) -> TensorLikeType: - utils.check( + torch._check( not utils.is_complex_dtype(a.dtype), lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}", ) @@ -788,7 +787,7 @@ def nan_to_num( def _neg_meta(a: TensorLikeType): - check( + torch._check( a.dtype is not torch.bool, lambda: ( "Negation, the `-` operator, on a bool tensor is not supported. 
" @@ -935,23 +934,20 @@ def _make_elementwise_binary_reference( a: Union[Tensor, NumberType], b: Union[Tensor, NumberType], ) -> Tensor: - check( + torch._check_value( supports_lhs_python_scalar or not isinstance(a, Number), lambda: f"{name}: Received a lhs Python scalar to an elementwise binary " "operation that does not accept lhs scalars!", - ValueError, ) - check( + torch._check_value( supports_rhs_python_scalar or not isinstance(b, Number), lambda: f"{name}: Received a rhs Python scalar to an elementwise binary " "operation that does not accept rhs scalars!", - ValueError, ) - check( + torch._check_value( supports_two_python_scalars or not (isinstance(a, Number) and isinstance(b, Number)), lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!", - ValueError, ) a, b = _maybe_broadcast(a, b) return prim(a, b) @@ -1230,7 +1226,7 @@ def floor_divide( elif utils.is_integer_dtype(dtype): return _floor_divide_integer(a, b) else: - check(False, lambda: f"{dtype} not supported for floor_divide") + torch._check(False, lambda: f"{dtype} not supported for floor_divide") def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor: @@ -1374,20 +1370,19 @@ def _check_close_args( rtol: float, atol: float, ) -> None: - check( + torch._check_value( a.dtype == b.dtype, lambda: "{0}: Attempting to compare tensors of different dtypes {1} and {2}!".format( name, a.dtype, b.dtype ), - ValueError, ) - check( + torch._check( rtol >= 0, lambda: "{0}: rtol must be greater than or equal to zero, but got {1}!".format( name, rtol ), ) - check( + torch._check( atol >= 0, lambda: "{0}: atol must be greater than or equal to zero, but got {1}!".format( name, atol @@ -1678,7 +1673,7 @@ def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, ) def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): - utils.check( + torch._check( isinstance(a, TensorLike) or isinstance(b, TensorLike), lambda: 'Expected either argument a or b to be a Tensor"', ) @@ -1736,12 +1731,11 @@ def addcdiv( if value is not None: dtype = self.dtype # no scalars allowed, see add python_type = utils.dtype_to_type(dtype) - check( + torch._check_value( utils.is_weakly_lesser_type(type(value), python_type), lambda: "value argument of type {0} cannot be safely cast to type {1}!".format( type(value), python_type ), - exc_type=ValueError, ) return self + value * tensor1 / tensor2 @@ -1766,12 +1760,11 @@ def addcmul( if value is not None: dtype = self.dtype # no scalars allowed, see add python_type = utils.dtype_to_type(dtype) - check( + torch._check_value( utils.is_weakly_lesser_type(type(value), python_type), lambda: "value argument of type {0} cannot be safely cast to type {1}!".format( type(value), python_type ), - exc_type=ValueError, ) return self + value * tensor1 * tensor2 @@ -1851,7 +1844,7 @@ def where( raise NotImplementedError utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True) - check( + torch._check( pred.dtype is torch.bool, lambda: f"expected predicate to be bool, got {pred.dtype}", ) @@ -2229,7 +2222,7 @@ def sum_to_size( *shape, ) -> Tensor: shape = utils.extract_shape_from_varargs(shape, validate=False) - utils.check( + torch._check( utils.is_expandable_to(shape, a.shape), lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"', ) @@ -2402,7 +2395,7 @@ def mean( if dtype is None: dtype = a.dtype # can't use out wrapper because of this argument - check( + torch._check( 
out is None or out.dtype == dtype, lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead", ) @@ -2415,7 +2408,7 @@ def mean( out=None, output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE, ) - check( + torch._check( utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), lambda: ( f"mean(): could not infer output dtype. " @@ -2491,22 +2484,22 @@ def addr( beta: NumberType = 1, alpha: NumberType = 1, ) -> TensorLikeType: - check( + torch._check( vec1.ndim == 1, lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D", ) - check( + torch._check( vec2.ndim == 1, lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D", ) self = self.expand(vec1.shape[0], vec2.shape[0]) if utils.is_boolean_dtype(self.dtype): # Integers are accepted for booleans - check( + torch._check( is_weakly_lesser_type(type(beta), int), lambda: f"expected bool/int beta but got {type(beta)}", ) - check( + torch._check( is_weakly_lesser_type(type(alpha), int), lambda: f"expected bool/int alpha but got {type(beta)}", ) @@ -2518,11 +2511,11 @@ def addr( torch.outer(vec1, vec2) if alpha else torch.full_like(self, False), ) else: - check( + torch._check( is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)), lambda: f"cannot safely convert {type(beta)} to {self.dtype}", ) - check( + torch._check( is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)), lambda: f"cannot safely convert {type(alpha)} to {self.dtype}", ) @@ -2712,7 +2705,7 @@ def conj(input: TensorLikeType) -> TensorLikeType: def constant_pad_nd( input: TensorLikeType, pad: List[int], value: NumberType = 0 ) -> TensorLikeType: - check( + torch._check( len(pad) % 2 == 0, lambda: f"Length of pad must be even but instead it equals {len(pad)}", ) @@ -2723,7 +2716,7 @@ def constant_pad_nd( l_pad = len(pad) // 2 l_diff = l_inp - l_pad - check( + torch._check( l_inp >= l_pad, lambda: "Length of pad should be no more than twice the number of " f"dimensions of the input. 
Pad length is {len(pad)} while the input has " @@ -2748,7 +2741,7 @@ def constant_pad_nd( for i in range(l_pad): pad_idx = len(pad) - ((i + 1) * 2) new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1] - check( + torch._check( new_dim > 0, lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding " f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, " @@ -2787,7 +2780,7 @@ def constant_pad_nd( def contiguous( a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format ) -> Tensor: - check( + torch._check( memory_format != torch.preserve_format, lambda: "preserve memory format is unsupported by the contiguous operator", ) @@ -2800,7 +2793,7 @@ def contiguous( @out_wrapper() def dstack(tensors: TensorSequenceType) -> TensorLikeType: - check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList") + torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList") aligned_tensors = atleast_3d(*tensors) return cat(aligned_tensors, 2) @@ -2813,7 +2806,7 @@ def expand(a: Tensor, *shape) -> Tensor: if len(shape) == 1 and isinstance(shape[0], Sequence): shape = tuple(shape[0]) - check( + torch._check( len(shape) >= len(a.shape), lambda: "expand: the requested shape has too few dimensions!", ) @@ -2823,7 +2816,7 @@ def expand(a: Tensor, *shape) -> Tensor: for idx, x in enumerate(a.shape): offset_idx = idx + offset requested_length = shape[offset_idx] - check( + torch._check( requested_length == x or x == 1 or requested_length == -1, lambda: f"expand: attempting to expand a dimension of length {x}!", ) @@ -2917,13 +2910,13 @@ def narrow( # Supports Tensor overload that was added for XLA: # https://github.com/pytorch/pytorch/issues/31558 if isinstance(start, TensorLike): - check( + torch._check( start.dim() == 0 and utils.is_integer_dtype(start.dtype), lambda: "start must be an 0-dim integral Tensor.", ) start = start.item() # type: ignore[assignment] - check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.") - check(length >= 0, lambda: "narrow(): length must be non-negative.") + torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.") + torch._check(length >= 0, lambda: "narrow(): length must be non-negative.") dim = utils.canonicalize_dim(a.ndim, dim) dim_length = a.size(dim) # Start being the end is usually invalid since it's out of bounds. So it's @@ -2934,7 +2927,7 @@ def narrow( # Note: a dimension isn't being canonicalized here, this reuses # canonicalize_dim because the semantics are similar. 
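The torch/_refs/__init__.py hunks in this region also replace calls that passed an explicit exception type. A short sketch of the mapping used throughout the patch, assuming only the torch._check* helpers already referenced in this diff; the conditions and messages here are invented for illustration.

import torch

dim, ndim = 1, 2

# check(cond, msg)                      -> torch._check             (RuntimeError)
torch._check(ndim > 0, lambda: f"expected a positive rank, got {ndim}")
# check(cond, msg, ValueError)          -> torch._check_value       (ValueError)
torch._check_value(ndim >= dim, lambda: f"rank {ndim} too small for dim {dim}")
# check(cond, msg, IndexError)          -> torch._check_index       (IndexError)
torch._check_index(dim < ndim, lambda: f"dim {dim} out of range for rank {ndim}")
# check(cond, msg, TypeError)           -> torch._check_type        (TypeError)
torch._check_type(isinstance(dim, int), lambda: f"dim must be an int, got {type(dim)}")
# check(cond, msg, NotImplementedError) -> torch._check_not_implemented
torch._check_not_implemented(True, lambda: "this path is not supported")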
start = utils.canonicalize_dim(dim_length, start) # type: ignore[arg-type] - check( + torch._check( start <= dim_length - length, # type: ignore[arg-type] lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).", ) @@ -2993,11 +2986,11 @@ def native_group_norm( num_groups: int, eps: float, ) -> Tuple[Tensor, Tensor, Tensor]: - utils.check( + torch._check( input.ndim >= 2, lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", ) - utils.check( + torch._check( num_channels % num_groups == 0, lambda: "Expected number of channels in input to be divisible by num_groups, " + f"but got input of shape {input.shape} and num_groups = {num_groups}", @@ -3044,7 +3037,7 @@ def native_layer_norm( eps: float, ) -> Tuple[Tensor, Tensor, Tensor]: normalized_ndim = len(normalized_shape) - utils.check( + torch._check( normalized_ndim >= 1, lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., " + "containing at least one element, but got normalized_shape = " @@ -3053,7 +3046,7 @@ def native_layer_norm( # torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False # while torch.Size([1, 2, 3]) == (1, 2, 3) is True # therefore we use tuple(normalized_shape) - utils.check( + torch._check( weight is None or weight.shape == tuple(normalized_shape), lambda: "Expected weight to be of same shape as normalized_shape, but got " + "weight of shape " @@ -3061,7 +3054,7 @@ def native_layer_norm( + " and normalized_shape = " + str(normalized_shape), ) - utils.check( + torch._check( bias is None or bias.shape == tuple(normalized_shape), lambda: "Expected bias to be of same shape as normalized_shape, but got " + "bias of shape " @@ -3069,7 +3062,7 @@ def native_layer_norm( + " and normalized_shape = " + str(normalized_shape), ) - utils.check( + torch._check( input.ndim >= normalized_ndim and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape), lambda: "Given normalized_shape=" @@ -3123,12 +3116,12 @@ def _get_unfold_shape_stride( max_size = 1 if a_ndim == 0 else a_shape[dim] last_stride = 1 if a_ndim == 0 else a_stride[dim] - utils.check( + torch._check( size <= max_size, lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}", ) - utils.check( + torch._check( step > 0, lambda: f"Step is {step} but must be > 0", ) @@ -3146,7 +3139,7 @@ def _get_unfold_shape_stride( @register_decomposition(aten.repeat) def repeat(a: Tensor, *repeat_shape) -> Tensor: repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False) - utils.check( + torch._check( len(repeat_shape) >= len(a.shape), lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor", ) @@ -3452,7 +3445,7 @@ def softmax( # CompositeImplicitAutograd - don't register decomp @out_wrapper() def hstack(tensors: TensorSequenceType) -> TensorLikeType: - check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList") + torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList") aligned_tensors = atleast_1d(*tensors) if aligned_tensors[0].ndim == 1: return cat(aligned_tensors, 0) @@ -3462,7 +3455,7 @@ def hstack(tensors: TensorSequenceType) -> TensorLikeType: # CompositeImplicitAutograd - don't register decomp @out_wrapper() def vstack(tensors: TensorSequenceType) -> TensorLikeType: - check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList") + torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList") aligned_tensors = 
atleast_2d(*tensors) return cat(aligned_tensors, 0) @@ -3470,17 +3463,16 @@ def vstack(tensors: TensorSequenceType) -> TensorLikeType: # CompositeImplicitAutograd - don't register decomp def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType: dim = utils.canonicalize_dim(a.ndim, dim) - utils.check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty") + torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty") return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :])) @register_decomposition(aten.unbind) def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType: dim = utils.canonicalize_dim(t.ndim, dim) - check( + torch._check_index( len(t.shape) > 0, lambda: "Dimension specified as 0 but tensor has no dimensions", - IndexError, ) if t.shape[dim] == 0: return tuple() @@ -3499,7 +3491,7 @@ def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): dim = utils.canonicalize_dims(x.ndim, dim) - utils.check( + torch._check( index.ndim <= 1, lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", ) @@ -3532,12 +3524,12 @@ def _index_fill( *, inplace: bool, ): - utils.check( + torch._check( index.ndim <= 1, lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", ) if isinstance(value, TensorLike): - utils.check( + torch._check( value.ndim == 0, lambda: "Only supports 0-dimensional value tensor. " # type: ignore[union-attr] f"Got a tensor with {value.ndim} dimensions.", @@ -3589,7 +3581,7 @@ def index_add( @out_wrapper() def index_select(x: TensorLike, dim: int, index: TensorLike): dim = utils.canonicalize_dims(x.ndim, dim) - utils.check( + torch._check( index.ndim <= 1, lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", ) @@ -3713,7 +3705,7 @@ def tensor_split( def hsplit( a: TensorLikeType, indices_or_sections: DimsType ) -> Tuple[TensorLikeType, ...]: - check( + torch._check( a.ndim >= 1, lambda: ( "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with " @@ -3724,7 +3716,7 @@ def hsplit( dim = 0 if a.ndim == 1 else 1 if isinstance(indices_or_sections, IntLike): split_size = indices_or_sections - check( + torch._check( (split_size != 0 and a.shape[dim] % split_size == 0), lambda: ( "torch.hsplit attempted to split along dimension " @@ -3738,14 +3730,13 @@ def hsplit( ) return tensor_split(a, split_size, dim) - check( + torch._check_type( isinstance(indices_or_sections, (list, tuple)), lambda: ( "hsplit(): received an invalid combination of arguments. " "Expected indices_or_sections to be of type int, list of ints or tuple of ints " f"but got type {type(indices_or_sections)}" ), - exc_type=TypeError, ) split_sizes = indices_or_sections @@ -3756,7 +3747,7 @@ def hsplit( def vsplit( a: TensorLikeType, indices_or_sections: DimsType ) -> Tuple[TensorLikeType, ...]: - check( + torch._check( a.ndim >= 2, lambda: ( "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with " @@ -3766,7 +3757,7 @@ def vsplit( ) if isinstance(indices_or_sections, IntLike): split_size = indices_or_sections - check( + torch._check( (split_size != 0 and a.shape[0] % split_size == 0), lambda: ( f"torch.vsplit attempted to split along dimension 0" @@ -3779,14 +3770,13 @@ def vsplit( ) return tensor_split(a, split_size, 0) - check( + torch._check_type( isinstance(indices_or_sections, (list, tuple)), lambda: ( "vsplit(): received an invalid combination of arguments. 
" "Expected indices_or_sections to be of type int, list of ints or tuple of ints " f"but got type {type(indices_or_sections)}" ), - exc_type=TypeError, ) split_sizes = indices_or_sections @@ -3800,7 +3790,7 @@ def diag( offset: int = 0, ) -> TensorLikeType: ndim = self.dim() - utils.check( + torch._check( ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D" ) if ndim == 1: @@ -3820,7 +3810,7 @@ def diagonal_scatter( ) -> TensorLikeType: out = utils.clone_preserve_strides(input) diag = out.diagonal(offset, dim1, dim2) - check( + torch._check( diag.shape == src.shape, lambda: "expected src to have a size equal to the diagonal of the input." f"Got {src.shape} for a diagonal of shape {diag.shape}", @@ -3843,7 +3833,7 @@ def diagonal( dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims) dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims) - check( + torch._check( dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" ) @@ -3896,7 +3886,7 @@ def diag_embed( dim1 = utils.canonicalize_dim(rank=rank, idx=dim1) dim2 = utils.canonicalize_dim(rank=rank, idx=dim2) - check( + torch._check( dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" ) @@ -3967,7 +3957,7 @@ def t(a: TensorLikeType): # CompositeImplicitAutograd - don't register decomp def T(a: TensorLikeType) -> TensorLikeType: # n != 2 && n != 0 is deprecated in regular PyTorch. - check( + torch._check( a.ndim in (0, 2), lambda: ( "The use of `x.T` on tensors of dimension other than 0 or 2 " @@ -4102,7 +4092,7 @@ def empty( pin_memory: bool = False, memory_format: torch.memory_format = torch.contiguous_format, ) -> TensorLikeType: - check( + torch._check( memory_format != torch.preserve_format, lambda: "torch.empty: the Preserve memory format is not supported", ) @@ -4114,7 +4104,7 @@ def empty( elif memory_format == torch.channels_last_3d: strides = utils.make_channels_last_3d_strides_for(shape) else: # memory_format == torch.channels_last - check( + torch._check( memory_format == torch.channels_last, lambda: f"torch.empty: received an unknown memory format {memory_format}!", ) @@ -4398,8 +4388,8 @@ def arange( if end is None: end = start start = 0 - utils.check(step != 0, lambda: "step must be nonzero") - utils.check( + torch._check(step != 0, lambda: "step must be nonzero") + torch._check( (step > 0 and end >= start) or (step < 0 and end <= start), lambda: "upper bound and lower bound inconsistent with step sign", ) @@ -4407,11 +4397,11 @@ def arange( def is_finite(x): return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x) - utils.check( + torch._check( is_finite(start) and is_finite(end), lambda: f"unsupported range: {start} -> {end}", ) - utils.check( + torch._check( is_finite(step), lambda: f"step must be finite but got {step}", ) @@ -4514,7 +4504,7 @@ def linspace( if dtype is None: dtype = default_complex_dtype else: - check( + torch._check( utils.is_complex_dtype(dtype), lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}", ) @@ -4523,13 +4513,12 @@ def linspace( assert isinstance(dtype, torch.dtype) # steps does not participate in the computation of the dtype - check( + torch._check_type( isinstance(steps, IntLike), lambda: "steps must be int, not float", - exc_type=TypeError, ) assert isinstance(steps, IntLike) # for mypy - check(steps >= 0, lambda: "number of steps must be non-negative") + torch._check(steps >= 0, lambda: "number of steps must be non-negative") factory_kwargs = { "layout": 
layout, @@ -4631,19 +4620,19 @@ def meshgrid( assert len(tensors) == 1 tensors = tuple(tensors[0]) - check( + torch._check( py_all(isinstance(a, TensorLike) for a in tensors), lambda: "meshgrid expects its inputs to be tensors", ) - check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList") + torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList") for i in range(len(tensors) - 1): - check( + torch._check( tensors[i].dtype == tensors[i + 1].dtype, # type: ignore[union-attr] lambda: "meshgrid expects all tensors to have the same dtype", ) - check( + torch._check( tensors[i].device == tensors[i + 1].device, # type: ignore[union-attr] lambda: "meshgrid expects all tensors to have the same device", ) @@ -4654,7 +4643,7 @@ def meshgrid( if swap_first_and_second_tensors: tensors = (tensors[1], tensors[0], *tensors[2:]) else: - check( + torch._check( indexing == "ij", lambda: ( 'torch.meshgrid: indexing must be one of "xy" or "ij", ' @@ -4665,7 +4654,7 @@ def meshgrid( result_shape: List[int] = [] for t in tensors: assert isinstance(t, TensorLike) # mypy - check( + torch._check( t.ndim == 0 or t.ndim == 1, lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}", ) @@ -4701,7 +4690,7 @@ def movedim( # Converts to list to produce a compatible error message with core PyTorch, # which prints sequences in square brackets. - utils.check( + torch._check( len(source) == len(destination), # type: ignore[arg-type] lambda: ( "movedim: Invalid source or destination dims: source " # type: ignore[arg-type] @@ -4718,11 +4707,11 @@ def movedim( dss = set(ds) # See above on why this converts to list in error messages. - utils.check( + torch._check( len(ss) == len(sss), lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type] ) - utils.check( + torch._check( len(ds) == len(dss), lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type] ) @@ -4795,8 +4784,8 @@ def eye( if m is None: m = n - check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}") - check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}") + torch._check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}") + torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}") range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False) range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False) @@ -4994,13 +4983,13 @@ def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLi # NOTE: Could not use value = item(value) as it resulted in # RuntimeError: Cannot cast FakeTensor(cpu) to number value_ndim = value.ndim - check( + torch._check( value_ndim == 0, lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension", ) # `masked_fill` allows cpu scalar to be moved to cuda and xpu but not otherwise. is_cpu_scalar = a.device.type in ["cuda", "xpu"] and value.device.type == "cpu" - check( + torch._check( is_cpu_scalar or value.device == a.device, lambda: "Expected `value` to be on same device as `a`", ) @@ -5011,7 +5000,7 @@ def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLi # We allow casting `value` to lower type for other case # Eg. float -> int. 
# Ref: https://github.com/pytorch/pytorch/issues/79195 - check( + torch._check( utils.is_weakly_lesser_type(value_type, python_type), lambda: f"could not convert to type {python_type} without overflow", ) @@ -5101,7 +5090,7 @@ def norm( @register_decomposition(aten.trace) def trace(self: TensorLikeType) -> TensorLikeType: - utils.check( + torch._check( self.ndim == 2, lambda: "expected a matrix, but got tensor with dim {self.ndim}" ) return torch.sum(torch.diag(self, 0)) @@ -5125,7 +5114,7 @@ rpow = _make_r_binary_op(pow) @register_decomposition(aten.triu) @out_wrapper() def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: - utils.check( + torch._check( a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions" ) h, w = a.shape[-2:] @@ -5142,7 +5131,7 @@ def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: @register_decomposition(aten.tril) @out_wrapper() def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: - utils.check( + torch._check( a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions" ) h, w = a.shape[-2:] @@ -5187,9 +5176,9 @@ def _trilu_checks( layout: torch.layout, pin_memory: bool, ): - check(row >= 0, lambda: f"row must be non-negative, got {row}") - check(col >= 0, lambda: f"col must be non-negative, got {col}") - check( + torch._check(row >= 0, lambda: f"row must be non-negative, got {row}") + torch._check(col >= 0, lambda: f"col must be non-negative, got {col}") + torch._check( dtype in (torch.int32, torch.int64), lambda: f"\"{name}\" not implemented for '{dtype}'", ) @@ -5306,7 +5295,7 @@ def bucketize( out_int32: bool = False, right: bool = False, ): - utils.check( + torch._check( boundaries.dim() == 1, lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})", ) @@ -5364,14 +5353,14 @@ def bucketize( ) def cauchy(self, median=0, sigma=1, generator=None): assert generator is None - utils.check( + torch._check( not utils.is_complex_dtype(self.dtype) and not utils.is_integer_dtype(self.dtype) and not utils.is_boolean_dtype(self.dtype), lambda: f"Cauchy distribution is a continuous probability distribution. \ dtype must be a floating point but you specified {self.dtype}", ) - utils.check( + torch._check( sigma > 0.0, lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}", ) @@ -5386,14 +5375,14 @@ def cauchy(self, median=0, sigma=1, generator=None): ) def exponential(self, rate=1, generator=None): assert generator is None - utils.check( + torch._check( not utils.is_complex_dtype(self.dtype) and not utils.is_integer_dtype(self.dtype) and not utils.is_boolean_dtype(self.dtype), lambda: f"Exponential distribution is a continuous probability distribution. 
\ dtype must be a floating point but you specified {self.dtype}", ) - utils.check( + torch._check( rate > 0.0, lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}", ) @@ -5409,12 +5398,12 @@ def exponential(self, rate=1, generator=None): def geometric(self, p, generator=None): assert generator is None # TODO: fix inductor rand_like for integer, bool dtypes - utils.check( + torch._check( not utils.is_complex_dtype(self.dtype) and not utils.is_boolean_dtype(self.dtype), lambda: f"geometric not implemented for {self.dtype}", ) - utils.check( + torch._check( 0 < p and p < 1, lambda: f"geometric_ expects p to be in (0, 1), but got p={p}", ) @@ -5429,13 +5418,13 @@ def geometric(self, p, generator=None): ) def log_normal(self, mean=1, std=2, generator=None): assert generator is None - utils.check( + torch._check( not utils.is_complex_dtype(self.dtype) and not utils.is_integer_dtype(self.dtype) and not utils.is_boolean_dtype(self.dtype), lambda: f"log_normal not implemented for {self.dtype}", ) - utils.check( + torch._check( 0 < std, lambda: f"log_normal_ expects std > 0.0, but found std={std}", ) @@ -5451,7 +5440,7 @@ def log_normal(self, mean=1, std=2, generator=None): ) def normal(self, mean=0, std=1, generator=None): assert generator is None - utils.check(std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}") + torch._check(std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}") normal_samples = prims.normal( self.shape, mean=0.0, @@ -5465,7 +5454,7 @@ def normal(self, mean=0, std=1, generator=None): @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) def rad2deg(self: TensorLikeType): - utils.check( + torch._check( not utils.is_complex_dtype(self.dtype), lambda: "rad2deg is not supported for complex tensors.", ) @@ -5475,7 +5464,7 @@ def rad2deg(self: TensorLikeType): @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) def deg2rad(self: TensorLikeType): - utils.check( + torch._check( not utils.is_complex_dtype(self.dtype), lambda: "deg2rad is not supported for complex tensors.", ) diff --git a/torch/_refs/_conversions.py b/torch/_refs/_conversions.py index d0607d89702..fa1ca242825 100644 --- a/torch/_refs/_conversions.py +++ b/torch/_refs/_conversions.py @@ -4,7 +4,7 @@ import torch._prims_common as utils # Utilities should come BEFORE this import from torch._decomp import register_decomposition -from torch._prims_common import check, TensorLikeType +from torch._prims_common import TensorLikeType from torch._prims_common.wrappers import out_wrapper from torch._refs import _broadcast_shapes @@ -79,14 +79,14 @@ short = _make_conversion_method("short", torch.short) @out_wrapper(exact_dtype=True) def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType: allowed_dtypes = (torch.float32, torch.float64, torch.float16) - check( + torch._check( real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes, lambda: ( f"Expected both inputs to be Half, Float or Double tensors but got " f"{real.dtype} and {imag.dtype}" ), ) - check( + torch._check( real.dtype == imag.dtype, lambda: ( f"Expected object of scalar type {real.dtype} but got " diff --git a/torch/_refs/fft.py b/torch/_refs/fft.py index 54a98c273e8..4b22d8670f0 100644 --- a/torch/_refs/fft.py +++ b/torch/_refs/fft.py @@ -6,7 +6,7 @@ import torch import torch._prims as prims import torch._prims_common as utils from torch._decomp import register_decomposition -from torch._prims_common import check, DimsType, ShapeType, 
TensorLikeType +from torch._prims_common import DimsType, ShapeType, TensorLikeType from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper __all__ = [ @@ -43,7 +43,7 @@ def _apply_norm( x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool ) -> TensorLikeType: """Apply normalization to the un-normalized FFT result""" - check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}") + torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}") if norm == "ortho": return x * (1 / math.sqrt(signal_numel)) @@ -116,7 +116,9 @@ def _fft_c2r( input = _maybe_promote_tensor_fft(input, require_complex=True) dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1) - check(last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified") + torch._check( + last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified" + ) if n is not None: input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,)) @@ -138,7 +140,7 @@ def _fft_r2c( onesided: bool, ) -> TensorLikeType: """Common code for performing any real to complex FFT (rfft or ihfft)""" - check( + torch._check( not input.dtype.is_complex, lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}", ) @@ -162,7 +164,7 @@ def _fft_c2c( forward: bool, ) -> TensorLikeType: """Common code for performing any complex to complex FFT (fft or ifft)""" - check( + torch._check( input.dtype.is_complex, lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}", ) @@ -265,20 +267,20 @@ def _canonicalize_fft_shape_and_dim_args( ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False) # Check dims are unique - check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique") + torch._check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique") if shape is not None: if not isinstance(shape, Sequence): shape = (shape,) # Has shape, might have dim - check( + torch._check( dim is None or len(dim) == len(shape), lambda: "When given, dim and shape arguments must have the same length", ) transform_ndim = len(shape) - check( + torch._check( transform_ndim <= input_dim, lambda: f"Got shape with {transform_ndim} values but input tensor " f"only has {input_dim} dimensions.", @@ -301,7 +303,7 @@ def _canonicalize_fft_shape_and_dim_args( ret_shape = tuple(input_sizes[d] for d in ret_dims) for n in ret_shape: - check(n > 0, lambda: f"Invalid number of data points ({n}) specified") + torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified") return _ShapeAndDims(shape=ret_shape, dims=ret_dims) @@ -323,7 +325,7 @@ def _fftn_c2c( forward: bool, ) -> TensorLikeType: """Common code for n-dimensional complex to complex FFTs (fftn or ifftn)""" - check( + torch._check( input.dtype.is_complex, lambda: f"{function_name} expects a complex input tensor, " f"but got {input.dtype}", @@ -367,7 +369,7 @@ def rfftn( dim: Optional[DimsType] = None, norm: NormType = None, ) -> TensorLikeType: - check( + torch._check( not input.dtype.is_complex, lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}", ) @@ -386,12 +388,12 @@ def ihfftn( dim: Optional[DimsType] = None, norm: NormType = None, ) -> TensorLikeType: - check( + torch._check( not input.dtype.is_complex, lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}", ) shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim) - 
check(len(shape) > 0, lambda: "ihfftn must transform at least one axis") + torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis") input = _maybe_promote_tensor_fft(input, require_complex=False) input = _resize_fft_input(input, dim, shape) @@ -421,14 +423,14 @@ def _canonicalize_fft_c2r_shape_and_dim_args( """Canonicalize shape and dim arguments for n-dimensional c2r transforms, as well as calculating the last_dim_size which is shape[dim[-1]] for the output""" (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) - check(len(shape) > 0, lambda: f"{fname} must transform at least one axis") + torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis") if s is None or s[-1] == -1: last_dim_size = 2 * (input.shape[dim[-1]] - 1) else: last_dim_size = shape[-1] - check( + torch._check( last_dim_size >= 1, lambda: f"Invalid number of data points ({last_dim_size}) specified", ) diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index f22926c26bc..97fb1d9d57f 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -11,7 +11,6 @@ import torch._refs as refs import torch._refs.linalg as linalg from torch import Tensor from torch._prims_common import ( - check, check_fp_or_complex, check_is_matrix, Dim, @@ -29,11 +28,11 @@ def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_nam Checks related to the dtype kwarg in `linalg.*norm` functions """ if dtype is not None: - check( + torch._check( utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}", ) - check( + torch._check( utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype), lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format( fn_name=fn_name, @@ -41,7 +40,7 @@ def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_nam dtype=dtype, ), ) - check( + torch._check( utils.get_higher_dtype(dtype, x_dtype) == dtype, lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible " "without narrowing to the specified dtype ({dtype})", @@ -79,7 +78,7 @@ def vector_norm( dim = [dim] # type: ignore[assignment] if x.numel() == 0 and (ord < 0.0 or ord == float("inf")): - check( + torch._check( dim is not None and len(dim) != 0, lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor " "because the operation does not have an identity", @@ -87,7 +86,7 @@ def vector_norm( shape = x.shape assert dim is not None # mypy does not seem to be able to see through check? for d in dim: - check( + torch._check( shape[d] != 0, lambda: f"linalg.vector_norm cannot compute the {ord} norm on the " f"dimension {d} because this dimension is empty and the " @@ -147,8 +146,10 @@ def matrix_norm( dim = utils.canonicalize_dims(A.ndim, dim) if isinstance(dim, Dim): dim = (dim,) # type: ignore[assignment] - check(len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}") - check( + torch._check( + len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}" + ) + torch._check( dim[0] != dim[1], lambda: "linalg.matrix_norm: dims must be different. 
Got ({dim[0]}, {dim[1]})", ) @@ -157,7 +158,7 @@ def matrix_norm( if isinstance(ord, str): # ord - check( + torch._check( ord in ("fro", "nuc"), lambda: "linalg.matrix_norm: Order {ord} not supported.", ) @@ -180,7 +181,7 @@ def matrix_norm( else: # ord abs_ord = abs(ord) - check( + torch._check( abs_ord in (2, 1, float("inf")), lambda: "linalg.matrix_norm: Order {ord} not supported.", ) @@ -224,12 +225,12 @@ def norm( if dim is not None: if isinstance(dim, Dim): dim = (dim,) # type: ignore[assignment] - check( + torch._check( len(dim) in (1, 2), lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}", ) elif ord is not None: - check( + torch._check( A.ndim in (1, 2), lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D", ) diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py index be82d0ab781..eaa6618379f 100644 --- a/torch/_refs/nn/functional/__init__.py +++ b/torch/_refs/nn/functional/__init__.py @@ -8,7 +8,6 @@ import torch._prims_common as utils import torch._refs as refs from torch._decomp import register_decomposition from torch._prims_common import ( - check, ELEMENTWISE_TYPE_PROMOTION_KIND, NumberType, ShapeType, @@ -98,7 +97,7 @@ def alpha_dropout( if not training: return self - utils.check( + torch._check( p <= 1 and p >= 0, lambda: f"dropout probability has to be between 0 and 1, but got, {p}", ) @@ -134,7 +133,7 @@ def _inplace_wrapper(fn): @wraps(fn) def _fn(a, *args, inplace=False, **kwargs): if inplace: - check( + torch._check( "out" not in kwargs, lambda: "Cannot set inplace=True and pass out= at the same time", ) @@ -193,7 +192,7 @@ def dropout( if not training: return a - utils.check( + torch._check( p <= 1 and p >= 0, lambda: f"dropout probability has to be between 0 and 1, but got, {p}", ) @@ -232,15 +231,15 @@ def elu( # nb. This should be factored out into a can_cast aux function python_type = utils.dtype_to_type(a.dtype) - check( + torch._check( utils.is_weakly_lesser_type(type(input_scale), python_type), lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!", ) - check( + torch._check( utils.is_weakly_lesser_type(type(scale), python_type), lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!", ) - check( + torch._check( utils.is_weakly_lesser_type(type(alpha), python_type), lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", ) @@ -276,14 +275,14 @@ def group_norm( """ Reference implementation of :func:`torch.nn.functional.group_norm`. """ - utils.check( + torch._check( input.ndim >= 2, lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", ) batch_size = input.shape[0] num_channels = input.shape[1] - utils.check( + torch._check( num_channels % num_groups == 0, lambda: "Expected number of channels in input to be divisible by num_groups, " + f"but got input of shape {input.shape} and num_groups = {num_groups}", @@ -394,7 +393,7 @@ def softmax( # deprecated. For PrimTorch, it's fine to drop support for deprecated # behavior because it requires explicit opt in. This error is to inform # users how to update their calls. - check(dim is not None, lambda: "implicit dim not supported, use dim=X") + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] @@ -409,7 +408,7 @@ def softmin( # deprecated. 
For PrimTorch, it's fine to drop support for deprecated # behavior because it requires explicit opt in. This error is to inform # users how to update their calls. - check(dim is not None, lambda: "implicit dim not supported, use dim=X") + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") return torch.softmax(a=-a, dim=dim, dtype=dtype) # type: ignore[call-overload] @@ -469,7 +468,7 @@ def softshrink(a: TensorLikeType, lambd: float = 0.5): # softshrink(x) = x - lambd if x > lambd # = x + lambd if x < -lambd # = 0 otherwise - check( + torch._check( lambd >= 0, lambda: f"lambda must be greater or equal to 0, but found to be {lambd}", ) @@ -596,7 +595,7 @@ def log_softmax( # deprecated. For PrimTorch, it's fine to drop support for deprecated # behavior because it requires explicit opt in. This error is to inform # users how to update their calls. - check(dim is not None, lambda: "implicit dim not supported, use dim=X") + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] @@ -668,12 +667,12 @@ def _nll_loss_nd( reduction: str, ignore_index: int, ) -> TensorLikeType: - utils.check( + torch._check( input.ndim > 0 and input.ndim <= 3, lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.", ) - utils.check( + torch._check( (input.ndim == 1) or (input.shape[0] == target.shape[0]), lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.", ) @@ -693,7 +692,7 @@ def _nll_loss_nd( (flat_target >= 0), (flat_target < num_classes) ) class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask)) - utils.check( + torch._check( isinstance(target, FakeTensor) or bool(class_check.item()), lambda: "A target class is out-of-bounds and not the ignore index.", ) @@ -758,7 +757,7 @@ def nll_loss( """ Reference implementation of torch.nn.functional.nll_loss """ - utils.check( + torch._check( input.ndim > 0, lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})", ) @@ -796,9 +795,13 @@ def nll_loss( # For ndim > 3, we reshape the input and target to 3-D case. 
# Input (N batch-size, C classes, k-dimensions) # Target (N batch-size, k-dimensions) - utils.check( + torch._check( input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:], - lambda: f"Expected target shape {out_size} but got {target.shape}", + lambda: ( + "Expected input and target to both have ndim > 0 and " + "target.shape[1:] == input.shape[2:], but got " + f"target.shape {target.shape} and input.shape {input.shape}" + ), ) batch_size = input.shape[0] @@ -837,7 +840,7 @@ def huber_loss( if type(reduction) is int: reduction = _reduction_int_to_str(reduction) _check_reduction_value(reduction) # type: ignore[arg-type] - check( + torch._check( delta > 0, lambda: "huber_loss does not support non-positive values for delta.", ) @@ -938,7 +941,7 @@ def _triplet_margin_with_distance_loss( a_dim = anchor.ndim p_dim = positive.ndim n_dim = negative.ndim - check( + torch._check( a_dim == p_dim and p_dim == n_dim, lambda: ( f"The anchor, positive, and negative tensors are expected to have " @@ -1075,25 +1078,25 @@ def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType: """ Reference implementation of torch.nn.functional.prelu """ - check( + torch._check( isinstance(a, TensorLike), lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}", ) - check( + torch._check( isinstance(weight, TensorLike), lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}", ) if weight.numel() != 1: - check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.") + torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.") channel_size = a.shape[1] if a.ndim >= 2 else 1 - check( + torch._check( weight.numel() == channel_size, lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers =" f" {weight.numel()} and channel size = {channel_size}.", ) - check( + torch._check( weight.ndim == 0 or weight.ndim == 1, lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: " f"ndim = {weight.ndim}", @@ -1132,7 +1135,7 @@ def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: ) def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType: dim = utils.canonicalize_dims(a.ndim, dim) - check( + torch._check( a.shape[dim] % 2 == 0, lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}", ) @@ -1160,8 +1163,8 @@ def pairwise_distance( type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType: - check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D") - check(p >= 0, lambda: "pdist only supports non-negative p values") + torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D") + torch._check(p >= 0, lambda: "pdist only supports non-negative p values") # For p == 2 we can use an efficient implementation, but other values of p # require creating a much bigger tensor for an intermediate step if p == 2: diff --git a/torch/_refs/special/__init__.py b/torch/_refs/special/__init__.py index 4369265a90f..048de83506d 100644 --- a/torch/_refs/special/__init__.py +++ b/torch/_refs/special/__init__.py @@ -148,7 +148,7 @@ def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType: type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, ) def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): - utils.check( + torch._check( isinstance(a, TensorLike) or isinstance(b, TensorLike), lambda: 'Expected either argument a 
or b to be a Tensor"',
     )
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 3ba091a86f7..fe1dd93b33d 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -15,7 +15,6 @@ import torch._logging
 from torch._guards import Source
 from torch._ops import OpOverload
 from torch._prims_common import (
-    check,
     elementwise_dtypes,
     ELEMENTWISE_TYPE_PROMOTION_KIND,
     is_boolean_dtype,
@@ -1495,7 +1494,7 @@ class FakeTensorMode(TorchDispatchMode):
                 ) = FakeTensor._find_common_device(func, args, kwargs)
 
             if isinstance(e, FakeTensor):
-                check(
+                torch._check(
                     e.device == common_device,
                     lambda: f"FakeTensor is wrapped to wrong device, found {e.device}, expected {common_device}",
                 )
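For completeness, a self-contained sketch of the deprecation pattern this patch applies to the retained _prims_common.check shim: warn, then forward to torch._check_with so the exception behaviour is unchanged. The function name deprecated_check is hypothetical; it only mirrors the shape of the shim above and is not part of the patch.

import warnings
from typing import Callable, Type

import torch

def deprecated_check(
    b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError
) -> None:
    # Emit the deprecation notice, then delegate to the shared torch helper,
    # which raises exc_type with the lazily-built message when b is False.
    warnings.warn(
        DeprecationWarning(
            "'deprecated_check' will be removed in the future. "
            "Please use 'torch._check*' functions instead"
        )
    )
    torch._check_with(exc_type, b, s)

# Passing checks only warn; failing checks still raise the requested type.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    deprecated_check(True, lambda: "unused")
assert any(issubclass(w.category, DeprecationWarning) for w in caught)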