mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Replace _prims_common.check with torch._check* (#103240)
This relands most of the changes from #102219 which were backed out by #103128. However, instead of removing `_prims_common.check`, it adds a warning and a comment mentioning that it will be removed in the future and `torch._check*` should be used instead. As mentioned in https://github.com/pytorch/pytorch/pull/103128#pullrequestreview-1466414415, `_prims_common.check` cannot yet be removed because of some internal usage Part of #72948 Pull Request resolved: https://github.com/pytorch/pytorch/pull/103240 Approved by: https://github.com/albanD
This commit is contained in:
parent
f3c3d12efb
commit
ee83c646bb
|
|
@ -1161,6 +1161,10 @@ $1 = torch._ops.prims.sin.default($0)""")
|
|||
def test_mul_complex(self):
|
||||
prims.mul(torch.randn(2), 1 + 1j)
|
||||
|
||||
def test_check_deprecation_warning(self):
|
||||
with self.assertWarnsRegex(DeprecationWarning, 'will be removed in the future'):
|
||||
torch._prims_common.check(True, lambda: 'message')
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestPrims, globals())
|
||||
|
||||
|
|
|
|||
|
|
@ -936,7 +936,7 @@ def is_warn_always_enabled():
|
|||
# These error checking functions must be kept consistent with their C++
|
||||
# equivalents. Their C++ equivalents are mentioned where applicable.
|
||||
|
||||
def _check_with(error_type, cond, message):
|
||||
def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):
|
||||
if not isinstance(cond, (builtins.bool, torch.SymBool)):
|
||||
raise TypeError(f'cond must be a bool, but got {type(cond)}')
|
||||
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ def fill_scalar(self, value):
|
|||
|
||||
@register_decomposition([aten.fill.Tensor])
|
||||
def fill_tensor(self, value: Tensor):
|
||||
utils.check(
|
||||
torch._check(
|
||||
value.dim() == 0,
|
||||
lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions",
|
||||
)
|
||||
|
|
@ -785,14 +785,14 @@ def im2col(
|
|||
padding: List[int],
|
||||
stride: List[int],
|
||||
) -> Tensor:
|
||||
utils.check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
|
||||
utils.check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
|
||||
utils.check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
|
||||
utils.check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")
|
||||
torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
|
||||
torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
|
||||
torch._check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
|
||||
torch._check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")
|
||||
|
||||
def check_positive(param, param_name, strict=True):
|
||||
cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
|
||||
utils.check(
|
||||
torch._check(
|
||||
cond, lambda: "{param_name} should be greater {'than' zero, but got {param}"
|
||||
)
|
||||
|
||||
|
|
@ -803,7 +803,7 @@ def im2col(
|
|||
|
||||
shape = input.shape
|
||||
ndim = len(shape)
|
||||
utils.check(
|
||||
torch._check(
|
||||
ndim in (3, 4) and all(d != 0 for d in shape[-3:]),
|
||||
lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size "
|
||||
f"and non-zero dimensions, but got: {tuple(shape)}",
|
||||
|
|
@ -814,7 +814,7 @@ def im2col(
|
|||
shape[-2:], padding, dilation, kernel_size, stride
|
||||
)
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
all(c > 0 for c in output_size),
|
||||
lambda: f"Given an input with spacial size {tuple(shape[-2:])}, "
|
||||
f"kernel_size={kernel_size}, dilation={dilation}, "
|
||||
|
|
@ -869,15 +869,15 @@ def col2im(
|
|||
padding: List[int],
|
||||
stride: List[int],
|
||||
) -> Tensor:
|
||||
utils.check(len(output_size) == 2, lambda: "only 2D output_size supported")
|
||||
utils.check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
|
||||
utils.check(len(dilation) == 2, lambda: "only 2D dilation supported")
|
||||
utils.check(len(padding) == 2, lambda: "only 2D padding supported")
|
||||
utils.check(len(stride) == 2, lambda: "only 2D stride supported")
|
||||
torch._check(len(output_size) == 2, lambda: "only 2D output_size supported")
|
||||
torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
|
||||
torch._check(len(dilation) == 2, lambda: "only 2D dilation supported")
|
||||
torch._check(len(padding) == 2, lambda: "only 2D padding supported")
|
||||
torch._check(len(stride) == 2, lambda: "only 2D stride supported")
|
||||
|
||||
def check_positive(param, param_name, strict=True):
|
||||
cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
|
||||
utils.check(
|
||||
torch._check(
|
||||
cond, lambda: "{param_name} should be greater than zero, but got {param}"
|
||||
)
|
||||
|
||||
|
|
@ -889,13 +889,13 @@ def col2im(
|
|||
|
||||
shape = input.shape
|
||||
ndim = len(shape)
|
||||
utils.check(
|
||||
torch._check(
|
||||
ndim in (2, 3) and all(d != 0 for d in shape[-2:]),
|
||||
lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size "
|
||||
f"and non-zero dimensions, but got: {tuple(shape)}",
|
||||
)
|
||||
prod_kernel_size = kernel_size[0] * kernel_size[1]
|
||||
utils.check(
|
||||
torch._check(
|
||||
shape[-2] % prod_kernel_size == 0,
|
||||
lambda: "Expected size of input's first non-batch dimension to be divisible by the "
|
||||
f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and "
|
||||
|
|
@ -908,13 +908,13 @@ def col2im(
|
|||
)
|
||||
]
|
||||
L = col[0] * col[1]
|
||||
utils.check(
|
||||
torch._check(
|
||||
shape[-1] == L,
|
||||
lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
|
||||
f"dilation={dilation}, padding={padding}, stride={stride}, "
|
||||
f"expected input.size(-1) to be {L} but got {shape[-1]}.",
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
L > 0,
|
||||
lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
|
||||
f"dilation={dilation}, padding={padding}, stride={stride}, "
|
||||
|
|
@ -961,7 +961,7 @@ def col2im(
|
|||
def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
|
||||
# According to the CUDA kernel implementation we should have this test;
|
||||
# but it seems to fail tests!
|
||||
# utils.check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")
|
||||
# torch._check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")
|
||||
|
||||
# Mimicking CUDA kernel's behavior for output stride: output follow input's memory format
|
||||
# This different from TensorIterator's behavior
|
||||
|
|
@ -1221,21 +1221,21 @@ def native_group_norm_backward(
|
|||
)
|
||||
utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False)
|
||||
utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False)
|
||||
utils.check(
|
||||
torch._check(
|
||||
input.numel() == N * C * HxW,
|
||||
lambda: f"Expect input to have { N * C * HxW} elements",
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
mean.shape == (N, group),
|
||||
lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}",
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
gamma is None or gamma.numel() == C,
|
||||
lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}",
|
||||
)
|
||||
|
||||
cpg, _rem = divmod(C, group)
|
||||
utils.check(
|
||||
torch._check(
|
||||
_rem == 0,
|
||||
lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}",
|
||||
)
|
||||
|
|
@ -1834,12 +1834,12 @@ def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
|
|||
device = input.device
|
||||
shape = input.shape
|
||||
ndim = len(shape)
|
||||
utils.check(
|
||||
torch._check(
|
||||
ndim in (3, 4),
|
||||
lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}",
|
||||
)
|
||||
for d in input.shape[-2:]:
|
||||
utils.check(
|
||||
torch._check(
|
||||
d != 0,
|
||||
lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for "
|
||||
f"non-batch dimensions, but input has shape {tuple(shape)}.",
|
||||
|
|
@ -1966,13 +1966,13 @@ def _index_add(
|
|||
alpha: NumberType = 1,
|
||||
):
|
||||
dim = utils.canonicalize_dims(x.ndim, dim)
|
||||
utils.check(
|
||||
torch._check(
|
||||
index.ndim <= 1,
|
||||
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
|
||||
)
|
||||
if alpha != 1:
|
||||
python_type = utils.dtype_to_type(x.dtype)
|
||||
utils.check(
|
||||
torch._check(
|
||||
python_type == bool
|
||||
or utils.is_weakly_lesser_type(type(alpha), python_type),
|
||||
lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
|
||||
|
|
@ -2005,7 +2005,7 @@ def _index_copy(
|
|||
x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool
|
||||
):
|
||||
dim = utils.canonicalize_dims(x.ndim, dim)
|
||||
utils.check(
|
||||
torch._check(
|
||||
index.ndim <= 1,
|
||||
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
|
||||
)
|
||||
|
|
@ -2060,19 +2060,19 @@ def uniform_(self, low=0, high=1, generator=None):
|
|||
def upsample_compute_output_size(input_size, output_size, scale_factors):
|
||||
spatial_dimensions = len(input_size) - 2
|
||||
if output_size is not None:
|
||||
utils.check(
|
||||
torch._check(
|
||||
scale_factors is None,
|
||||
lambda: "Must specify exactly one of output_size and scale_factors",
|
||||
)
|
||||
utils.check(len(output_size) == spatial_dimensions, lambda: "")
|
||||
torch._check(len(output_size) == spatial_dimensions, lambda: "")
|
||||
return output_size
|
||||
if scale_factors is not None:
|
||||
# NB: this isn't necessary lol
|
||||
utils.check(
|
||||
torch._check(
|
||||
output_size is None,
|
||||
lambda: "Must specify exactly one of output_size and scale_factors",
|
||||
)
|
||||
utils.check(len(scale_factors) == spatial_dimensions, lambda: "")
|
||||
torch._check(len(scale_factors) == spatial_dimensions, lambda: "")
|
||||
output_size = []
|
||||
for i, s in enumerate(scale_factors):
|
||||
if int(s) == s:
|
||||
|
|
@ -2080,7 +2080,7 @@ def upsample_compute_output_size(input_size, output_size, scale_factors):
|
|||
else:
|
||||
output_size.append(sym_int(input_size[i + 2] * s))
|
||||
return output_size
|
||||
utils.check(
|
||||
torch._check(
|
||||
False, lambda: "Must specify exactly one of output_size and scale_factors"
|
||||
)
|
||||
|
||||
|
|
@ -2969,11 +2969,11 @@ def grid_sampler_2d(
|
|||
padding_mode: int = 0,
|
||||
align_corners: bool = False,
|
||||
) -> Tensor:
|
||||
utils.check(
|
||||
torch._check(
|
||||
interpolation_mode in (0, 1, 2),
|
||||
lambda: f"Invalid interpolation mode {interpolation_mode}",
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
|
||||
)
|
||||
|
||||
|
|
@ -3110,11 +3110,11 @@ def grid_sampler_2d(
|
|||
@out_wrapper()
|
||||
@pw_cast_for_opmath
|
||||
def mv(self, vec):
|
||||
utils.check(
|
||||
torch._check(
|
||||
self.dim() == 2 and vec.dim() == 1,
|
||||
lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}",
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
self.size(1) == vec.size(0),
|
||||
lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})",
|
||||
)
|
||||
|
|
@ -3134,11 +3134,11 @@ def dot(self, other):
|
|||
elif other.is_conj():
|
||||
return torch.vdot(other.conj(), self)
|
||||
|
||||
utils.check(
|
||||
torch._check(
|
||||
self.dim() == 1 and other.dim() == 1,
|
||||
lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
self.dtype == other.dtype,
|
||||
lambda: f"dot : expected both vectors to have same dtype, but found {self.dtype} and {other.dtype}",
|
||||
)
|
||||
|
|
@ -3149,7 +3149,7 @@ def dot(self, other):
|
|||
f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively"
|
||||
)
|
||||
|
||||
utils.check(self.numel() == other.numel(), numel_error)
|
||||
torch._check(self.numel() == other.numel(), numel_error)
|
||||
|
||||
return (self * other).sum()
|
||||
|
||||
|
|
@ -3296,7 +3296,7 @@ def matmul(tensor1, tensor2):
|
|||
|
||||
return tensor1_expanded.bmm(tensor2_expanded).view(output_shape)
|
||||
else:
|
||||
utils.check(False, lambda: "both arguments to matmul need to be at least 1D")
|
||||
torch._check(False, lambda: "both arguments to matmul need to be at least 1D")
|
||||
|
||||
|
||||
@register_decomposition(aten.upsample_bicubic2d.default)
|
||||
|
|
@ -3373,7 +3373,7 @@ def upsample_bicubic2d_vec(
|
|||
align_corners: bool,
|
||||
scale_factors: Optional[Tuple[float, float]] = None,
|
||||
) -> Tensor:
|
||||
utils.check(
|
||||
torch._check(
|
||||
bool(output_size) + bool(scale_factors) == 1,
|
||||
lambda: "Must specify exactly one of output_size and scale_factors.",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -72,7 +72,6 @@ from torch._inductor.compile_fx import (
|
|||
remove_unaligned_input_idxs,
|
||||
static_input,
|
||||
)
|
||||
from torch._prims_common import check
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
from torch.storage import UntypedStorage
|
||||
from torch.utils import _pytree as pytree
|
||||
|
|
@ -1071,7 +1070,7 @@ class CUDAGraphNode:
|
|||
self.output_storage_alias.append(UnaliasedStorage)
|
||||
continue
|
||||
|
||||
check(
|
||||
torch._check(
|
||||
o.is_cuda,
|
||||
lambda: f"Expected all cuda outputs in cuda graph recording. Non cuda output from {self.stack_traces[i]}",
|
||||
),
|
||||
|
|
@ -1447,7 +1446,7 @@ class CUDAGraphNode:
|
|||
for idx in self.cudagraph_managed_idxs:
|
||||
inputs[idx] = None
|
||||
|
||||
check(
|
||||
torch._check(
|
||||
self._check_liveness(
|
||||
self.expected_dead_indices_after_graph, self.path_weakrefs
|
||||
),
|
||||
|
|
@ -1522,7 +1521,7 @@ def check_memory_pool(device, pool_id, live_storages_ptrs: List[StorageWeakRefWr
|
|||
|
||||
addr += block["size"]
|
||||
|
||||
check(
|
||||
torch._check(
|
||||
len(unique_storages) == 0,
|
||||
lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}",
|
||||
)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -16,7 +16,6 @@ from torch._prims.debug_prims import register_debug_prims
|
|||
from torch._prims.nvfuser_prims import register_nvprims
|
||||
from torch._prims.rng_prims import register_rng_prims
|
||||
from torch._prims_common import (
|
||||
check,
|
||||
Dim,
|
||||
DimsSequenceType,
|
||||
DimsType,
|
||||
|
|
@ -422,7 +421,7 @@ def _elementwise_meta(
|
|||
|
||||
|
||||
def _complex_only_elementwise_meta(*args, **kwargs):
|
||||
utils.check(
|
||||
torch._check(
|
||||
utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported"
|
||||
)
|
||||
return _elementwise_meta(*args, **kwargs)
|
||||
|
|
@ -581,7 +580,7 @@ bitwise_not = _make_elementwise_unary_prim(
|
|||
|
||||
|
||||
def _cbrt_aten(a: torch.Tensor) -> Tensor:
|
||||
utils.check(
|
||||
torch._check(
|
||||
not a.is_complex(),
|
||||
lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)",
|
||||
)
|
||||
|
|
@ -1293,10 +1292,9 @@ def _validate_collapse_args(a: Tensor, start: int, end: int) -> None:
|
|||
|
||||
# Verifies end is strictly greater than start
|
||||
# (Collapse requires a non-empty interval)
|
||||
utils.check(
|
||||
torch._check_value(
|
||||
end >= start,
|
||||
lambda: f"Attempting to collapse but end, {end}, is less than start, {start}!",
|
||||
ValueError,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1823,7 +1821,7 @@ def _as_strided_scatter_meta(
|
|||
utils.validate_strides(stride)
|
||||
|
||||
required_size = utils.compute_required_storage_length(size, stride, storage_offset)
|
||||
utils.check(
|
||||
torch._check(
|
||||
input.numel() >= required_size,
|
||||
lambda: (
|
||||
f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} "
|
||||
|
|
@ -1832,7 +1830,7 @@ def _as_strided_scatter_meta(
|
|||
f"for storage of size {input.numel() * input.element_size()}"
|
||||
),
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
utils.is_same_shape(src.shape, size),
|
||||
lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}",
|
||||
)
|
||||
|
|
@ -2432,11 +2430,11 @@ def _iota_meta(
|
|||
device: torch.device,
|
||||
requires_grad: bool,
|
||||
) -> TensorLikeType:
|
||||
utils.check(
|
||||
torch._check(
|
||||
utils.is_integer_dtype(dtype),
|
||||
lambda: "prims.iota only supports integer dtypes",
|
||||
)
|
||||
utils.check(step != 0, lambda: "step must be nonzero")
|
||||
torch._check(step != 0, lambda: "step must be nonzero")
|
||||
return torch.empty(
|
||||
length,
|
||||
dtype=dtype,
|
||||
|
|
@ -2532,7 +2530,7 @@ def _empty_permuted_meta(
|
|||
) -> TensorLikeType:
|
||||
p_strides = utils.make_contiguous_strides_for([shape[l] for l in physical_layout])
|
||||
dim = len(shape)
|
||||
utils.check(
|
||||
torch._check(
|
||||
len(physical_layout) == dim,
|
||||
lambda: (
|
||||
"Number of dimensions in the tensor input does not match the "
|
||||
|
|
@ -2543,7 +2541,7 @@ def _empty_permuted_meta(
|
|||
strides = [0] * len(shape)
|
||||
seen_dims = set()
|
||||
for p, l in enumerate(physical_layout):
|
||||
utils.check(
|
||||
torch._check(
|
||||
0 <= l < dim,
|
||||
lambda: (
|
||||
f"Dimension out of range (expected to be between 0 and {dim - 1}, but got "
|
||||
|
|
@ -2551,7 +2549,7 @@ def _empty_permuted_meta(
|
|||
"not currently supported; file an issue if you want it."
|
||||
),
|
||||
)
|
||||
utils.check(l not in seen_dims, lambda: "Duplicate dim not allowed")
|
||||
torch._check(l not in seen_dims, lambda: "Duplicate dim not allowed")
|
||||
strides[l] = p_strides[p]
|
||||
seen_dims.add(l)
|
||||
return TensorMeta(
|
||||
|
|
@ -2779,12 +2777,12 @@ def _normal_meta(
|
|||
device: torch.device,
|
||||
requires_grad: bool,
|
||||
) -> TensorLikeType:
|
||||
utils.check(
|
||||
torch._check(
|
||||
std >= 0.0,
|
||||
lambda: f"expected non-negative standard deviation, but got std={std}",
|
||||
)
|
||||
|
||||
utils.check(
|
||||
torch._check(
|
||||
utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
|
||||
lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from functools import reduce, cmp_to_key
|
|||
import operator
|
||||
import sympy
|
||||
import weakref
|
||||
import warnings
|
||||
import torch
|
||||
from torch import sym_float, sym_int, sym_max
|
||||
|
||||
|
|
@ -268,7 +269,7 @@ _memory_formats = {
|
|||
|
||||
|
||||
def validate_memory_format(memory_format: torch.memory_format):
|
||||
check(
|
||||
torch._check(
|
||||
memory_format in _memory_formats,
|
||||
lambda: f"Received unknown memory format {memory_format}!",
|
||||
)
|
||||
|
|
@ -286,7 +287,7 @@ def is_contiguous_for_memory_format( # type: ignore[return]
|
|||
if memory_format == torch.channels_last_3d:
|
||||
return is_channels_last_contiguous_3d(a)
|
||||
|
||||
check(
|
||||
torch._check(
|
||||
False,
|
||||
lambda: f"is_contiguous received unsupported memory format {memory_format}",
|
||||
)
|
||||
|
|
@ -795,13 +796,13 @@ def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]:
|
|||
newsize = 1
|
||||
for i, d in enumerate(shape):
|
||||
if d == -1:
|
||||
check(dim is None, lambda: "only one dimension can be inferred")
|
||||
torch._check(dim is None, lambda: "only one dimension can be inferred")
|
||||
dim = i
|
||||
elif d >= 0:
|
||||
newsize *= d
|
||||
else:
|
||||
check(False, lambda: f"invalid shape dimension {d}")
|
||||
check(
|
||||
torch._check(False, lambda: f"invalid shape dimension {d}")
|
||||
torch._check(
|
||||
numel == newsize or (dim is not None and newsize > 0 and numel % newsize == 0),
|
||||
lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
|
||||
)
|
||||
|
|
@ -809,7 +810,7 @@ def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]:
|
|||
# Convert to list to produce a compatible error message with core
|
||||
# PyTorch, which prints sequences in square brackets.
|
||||
shape = list(shape)
|
||||
check(
|
||||
torch._check(
|
||||
newsize != 0,
|
||||
lambda: (f"cannot reshape tensor of 0 elements into shape {shape} because the "
|
||||
f"unspecified dimension size -1 can be any value and is ambiguous"),
|
||||
|
|
@ -954,18 +955,18 @@ def check_fp_or_complex(
|
|||
Checks whether the input is floating point or complex.
|
||||
If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32
|
||||
"""
|
||||
check(
|
||||
torch._check(
|
||||
is_float_dtype(dtype) or is_complex_dtype(dtype),
|
||||
lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}",
|
||||
)
|
||||
check(
|
||||
torch._check(
|
||||
allow_low_precision_dtypes or not is_low_precision_dtype(dtype),
|
||||
lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}",
|
||||
)
|
||||
|
||||
|
||||
def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"):
|
||||
check(
|
||||
torch._check(
|
||||
len(A.shape) >= 2,
|
||||
lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
|
||||
)
|
||||
|
|
@ -1060,11 +1061,11 @@ def get_higher_dtype(
|
|||
|
||||
|
||||
def check_pin_memory(pin_memory: bool):
|
||||
check(not pin_memory, lambda: "PrimTorch does not support pinned memory", NotImplementedError)
|
||||
torch._check_not_implemented(not pin_memory, lambda: "PrimTorch does not support pinned memory")
|
||||
|
||||
|
||||
def check_layout(layout: torch.layout):
|
||||
check(layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}", NotImplementedError)
|
||||
torch._check_not_implemented(layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}")
|
||||
|
||||
|
||||
# TODO: maybe unify with can_cast_to?
|
||||
|
|
@ -1485,7 +1486,7 @@ def make_contiguous_strides_for(
|
|||
|
||||
|
||||
def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
|
||||
check(
|
||||
torch._check(
|
||||
len(shape) == 3,
|
||||
lambda: "Only tensors of rank 3 can use the channels_last_1d memory format",
|
||||
)
|
||||
|
|
@ -1503,7 +1504,7 @@ def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
|
|||
|
||||
def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
|
||||
# TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5?
|
||||
check(
|
||||
torch._check(
|
||||
len(shape) == 4,
|
||||
lambda: "Only tensors of rank 4 can use the channels_last memory format",
|
||||
)
|
||||
|
|
@ -1520,7 +1521,7 @@ def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
|
|||
|
||||
|
||||
def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
|
||||
check(
|
||||
torch._check(
|
||||
len(shape) == 5,
|
||||
lambda: "Only tensors of rank 5 can use the channels_last_3d memory format",
|
||||
)
|
||||
|
|
@ -1654,6 +1655,9 @@ def check_in_bounds_for_storage(
|
|||
raise ValueError(msg)
|
||||
|
||||
|
||||
# NOTE: This function should ideally be removed, but some Meta internal models
|
||||
# packaged with `torch.package` are using it, so it will have to be removed
|
||||
# at some point in the future when those models no longer use this function.
|
||||
def check(
|
||||
b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError
|
||||
) -> None:
|
||||
|
|
@ -1662,9 +1666,14 @@ def check(
|
|||
Error message is a callable producing a string (to avoid wasting time
|
||||
string formatting in non-error case, and also to make it easier for torchdynamo
|
||||
to trace.)
|
||||
|
||||
.. note:: This function is planned for removal in the future. Please use
|
||||
`torch._check*` functions instead.
|
||||
"""
|
||||
if not b:
|
||||
raise exc_type(s())
|
||||
warnings.warn(DeprecationWarning((
|
||||
"'torch._prims_common.check' will be removed in the future. Please use "
|
||||
"'torch._check*' functions instead")))
|
||||
torch._check_with(exc_type, b, s)
|
||||
|
||||
|
||||
# This combines is_channels_last_strides_2d and is_channels_last_strides_3d in
|
||||
|
|
|
|||
|
|
@ -176,13 +176,13 @@ def _safe_copy_out(
|
|||
|
||||
# Checks safe cast
|
||||
if exact_dtype:
|
||||
utils.check(
|
||||
torch._check(
|
||||
copy_from.dtype == copy_to.dtype,
|
||||
lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
|
||||
f"but got {copy_to.dtype} instead",
|
||||
)
|
||||
else:
|
||||
utils.check(
|
||||
torch._check(
|
||||
utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
|
||||
lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
|
||||
"but this can't be cast because it is not safe!",
|
||||
|
|
@ -255,10 +255,9 @@ def out_wrapper(*out_names: str, exact_dtype: bool = False):
|
|||
_safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type]
|
||||
else:
|
||||
assert isinstance(out, Tuple) # type: ignore[arg-type]
|
||||
utils.check(
|
||||
torch._check_type(
|
||||
len(out) == len(result),
|
||||
lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
|
||||
TypeError,
|
||||
)
|
||||
for r, o in zip(result, out):
|
||||
# These two operations are done in-place
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import torch._prims as prims
|
|||
import torch._prims_common as utils
|
||||
from torch import sym_float, sym_int
|
||||
from torch._prims_common import (
|
||||
check,
|
||||
DeviceLikeType,
|
||||
Dim,
|
||||
DimsSequenceType,
|
||||
|
|
@ -626,7 +625,7 @@ def frac(x: TensorLikeType) -> TensorLikeType:
|
|||
# imag does not use _make_elementwise_unary_reference because it does not support out
|
||||
def imag(a: TensorLikeType) -> TensorLikeType:
|
||||
assert isinstance(a, TensorLike)
|
||||
utils.check(
|
||||
torch._check(
|
||||
utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors."
|
||||
)
|
||||
return prims.imag(a)
|
||||
|
|
@ -654,7 +653,7 @@ def isinf(a: TensorLikeType) -> TensorLikeType:
|
|||
|
||||
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
|
||||
def isposinf(a: TensorLikeType) -> TensorLikeType:
|
||||
utils.check(
|
||||
torch._check(
|
||||
not utils.is_complex_dtype(a.dtype),
|
||||
lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}",
|
||||
)
|
||||
|
|
@ -665,7 +664,7 @@ def isposinf(a: TensorLikeType) -> TensorLikeType:
|
|||
|
||||
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
|
||||
def isneginf(a: TensorLikeType) -> TensorLikeType:
|
||||
utils.check(
|
||||
torch._check(
|
||||
not utils.is_complex_dtype(a.dtype),
|
||||
lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}",
|
||||
)
|
||||
|
|
@ -788,7 +787,7 @@ def nan_to_num(
|
|||
|
||||
|
||||
def _neg_meta(a: TensorLikeType):
|
||||
check(
|
||||
torch._check(
|
||||
a.dtype is not torch.bool,
|
||||
lambda: (
|
||||
"Negation, the `-` operator, on a bool tensor is not supported. "
|
||||
|
|
@ -935,23 +934,20 @@ def _make_elementwise_binary_reference(
|
|||
a: Union[Tensor, NumberType],
|
||||
b: Union[Tensor, NumberType],
|
||||
) -> Tensor:
|
||||
check(
|
||||
torch._check_value(
|
||||
supports_lhs_python_scalar or not isinstance(a, Number),
|
||||
lambda: f"{name}: Received a lhs Python scalar to an elementwise binary "
|
||||
"operation that does not accept lhs scalars!",
|
||||
ValueError,
|
||||
)
|
||||
check(
|
||||
torch._check_value(
|
||||
supports_rhs_python_scalar or not isinstance(b, Number),
|
||||
lambda: f"{name}: Received a rhs Python scalar to an elementwise binary "
|
||||
"operation that does not accept rhs scalars!",
|
||||
ValueError,
|
||||
)
|
||||
check(
|
||||
torch._check_value(
|
||||
supports_two_python_scalars
|
||||
or not (isinstance(a, Number) and isinstance(b, Number)),
|
||||
lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!",
|
||||
ValueError,
|
||||
)
|
||||
a, b = _maybe_broadcast(a, b)
|
||||
return prim(a, b)
|
||||
|
|
@ -1230,7 +1226,7 @@ def floor_divide(
|
|||
elif utils.is_integer_dtype(dtype):
|
||||
return _floor_divide_integer(a, b)
|
||||
else:
|
||||
check(False, lambda: f"{dtype} not supported for floor_divide")
|
||||
torch._check(False, lambda: f"{dtype} not supported for floor_divide")
|
||||
|
||||
|
||||
def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor:
|
||||
|
|
@ -1374,20 +1370,19 @@ def _check_close_args(
|
|||
rtol: float,
|
||||
atol: float,
|
||||
) -> None:
|
||||
check(
|
||||
torch._check_value(
|
||||
a.dtype == b.dtype,
|
||||
lambda: "{0}: Attempting to compare tensors of different dtypes {1} and {2}!".format(
|
||||
name, a.dtype, b.dtype
|
||||
),
|
||||
ValueError,
|
||||
)
|
||||
check(
|
||||
torch._check(
|
||||
rtol >= 0,
|
||||
lambda: "{0}: rtol must be greater than or equal to zero, but got {1}!".format(
|
||||
name, rtol
|
||||
),
|
||||
)
|
||||
check(
|
||||
torch._check(
|
||||
atol >= 0,
|
||||
lambda: "{0}: atol must be greater than or equal to zero, but got {1}!".format(
|
||||
name, atol
|
||||
|
|
@ -1678,7 +1673,7 @@ def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
|
|||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
||||
)
|
||||
def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
|
||||
utils.check(
|
||||
torch._check(
|
||||
isinstance(a, TensorLike) or isinstance(b, TensorLike),
|
||||
lambda: 'Expected either argument a or b to be a Tensor"',
|
||||
)
|
||||
|
|
@ -1736,12 +1731,11 @@ def addcdiv(
|
|||
if value is not None:
|
||||
dtype = self.dtype # no scalars allowed, see add
|
||||
python_type = utils.dtype_to_type(dtype)
|
||||
check(
|
||||
torch._check_value(
|
||||
utils.is_weakly_lesser_type(type(value), python_type),
|
||||
lambda: "value argument of type {0} cannot be safely cast to type {1}!".format(
|
||||
type(value), python_type
|
||||
),
|
||||
exc_type=ValueError,
|
||||
)
|
||||
|
||||
return self + value * tensor1 / tensor2
|
||||
|
|
@ -1766,12 +1760,11 @@ def addcmul(
|
|||
if value is not None:
|
||||
dtype = self.dtype # no scalars allowed, see add
|
||||
python_type = utils.dtype_to_type(dtype)
|
||||
check(
|
||||
torch._check_value(
|
||||
utils.is_weakly_lesser_type(type(value), python_type),
|
||||
lambda: "value argument of type {0} cannot be safely cast to type {1}!".format(
|
||||
type(value), python_type
|
||||
),
|
||||
exc_type=ValueError,
|
||||
)
|
||||
|
||||
return self + value * tensor1 * tensor2
|
||||
|
|
@ -1851,7 +1844,7 @@ def where(
|
|||
raise NotImplementedError
|
||||
|
||||
utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True)
|
||||
check(
|
||||
torch._check(
|
||||
pred.dtype is torch.bool,
|
||||
lambda: f"expected predicate to be bool, got {pred.dtype}",
|
||||
)
|
||||
|
|
@ -2229,7 +2222,7 @@ def sum_to_size(
|
|||
*shape,
|
||||
) -> Tensor:
|
||||
shape = utils.extract_shape_from_varargs(shape, validate=False)
|
||||
utils.check(
|
||||
torch._check(
|
||||
utils.is_expandable_to(shape, a.shape),
|
||||
lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"',
|
||||
)
|
||||
|
|
@ -2402,7 +2395,7 @@ def mean(
|
|||
if dtype is None:
|
||||
dtype = a.dtype
|
||||
# can't use out wrapper because of this argument
|
||||
check(
|
||||
torch._check(
|
||||
out is None or out.dtype == dtype,
|
||||
lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead",
|
||||
)
|
||||
|
|
@ -2415,7 +2408,7 @@ def mean(
|
|||
out=None,
|
||||
output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE,
|
||||
)
|
||||
check(
|
||||
torch._check(
|
||||
utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
|
||||
lambda: (
|
||||
f"mean(): could not infer output dtype. "
|
||||
|
|
@ -2491,22 +2484,22 @@ def addr(
|
|||
beta: NumberType = 1,
|
||||
alpha: NumberType = 1,
|
||||
) -> TensorLikeType:
|
||||
check(
|
||||
torch._check(
|
||||
vec1.ndim == 1,
|
||||
lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D",
|
||||
)
|
||||
check(
|
||||
torch._check(
|
||||
vec2.ndim == 1,
|
||||
lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D",
|
||||
)
|
||||
self = self.expand(vec1.shape[0], vec2.shape[0])
|
||||
if utils.is_boolean_dtype(self.dtype):
|
||||
# Integers are accepted for booleans
|
||||
check(
|
||||
torch._check(
|
||||
is_weakly_lesser_type(type(beta), int),
|
||||
lambda: f"expected bool/int beta but got {type(beta)}",
|
||||
)
|
||||
check(
|
||||
torch._check(
|
||||
is_weakly_lesser_type(type(alpha), int),
|
||||
lambda: f"expected bool/int alpha but got {type(beta)}",
|
||||
)
|
||||
|
|
@ -2518,11 +2511,11 @@ def addr(
|
|||
torch.outer(vec1, vec2) if alpha else torch.full_like(self, False),
|
||||
)
|
||||
else:
|
||||
check(
|
||||
torch._check(
|
||||
is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)),
|
||||
lambda: f"cannot safely convert {type(beta)} to {self.dtype}",
|
||||
)
|
||||
check(
|
||||
torch._check(
|
||||
is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)),
|
||||
lambda: f"cannot safely convert {type(alpha)} to {self.dtype}",
|
||||
)
|
||||
|
|
@ -2712,7 +2705,7 @@ def conj(input: TensorLikeType) -> TensorLikeType:
|
|||
def constant_pad_nd(
|
||||
input: TensorLikeType, pad: List[int], value: NumberType = 0
|
||||
) -> TensorLikeType:
|
||||
check(
|
||||
torch._check(
|
||||
len(pad) % 2 == 0,
|
||||
lambda: f"Length of pad must be even but instead it equals {len(pad)}",
|
||||
)
|
||||
|
|
@ -2723,7 +2716,7 @@ def constant_pad_nd(
|
|||
l_pad = len(pad) // 2
|
||||
l_diff = l_inp - l_pad
|
||||
|
||||
check(
|
||||
torch._check(
|
||||
l_inp >= l_pad,
|
||||
lambda: "Length of pad should be no more than twice the number of "
|
||||
f"dimensions of the input. Pad length is {len(pad)} while the input has "
|
||||
|
|
@ -2748,7 +2741,7 @@ def constant_pad_nd(
|
|||
for i in range(l_pad):
|
||||
pad_idx = len(pad) - ((i + 1) * 2)
|
||||
new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
|
||||
check(
|
||||
torch._check(
|
||||
new_dim > 0,
|
||||
lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
|
||||
f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
|
||||
|
|
@ -2787,7 +2780,7 @@ def constant_pad_nd(
|
|||
def contiguous(
|
||||
a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format
|
||||
) -> Tensor:
|
||||
check(
|
||||
torch._check(
|
||||
memory_format != torch.preserve_format,
|
||||
lambda: "preserve memory format is unsupported by the contiguous operator",
|
||||
)
|
||||
|
|
@ -2800,7 +2793,7 @@ def contiguous(
|
|||
|
||||
@out_wrapper()
|
||||
def dstack(tensors: TensorSequenceType) -> TensorLikeType:
|
||||
check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
|
||||
torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
|
||||
aligned_tensors = atleast_3d(*tensors)
|
||||
return cat(aligned_tensors, 2)
|
||||
|
||||
|
|
@ -2813,7 +2806,7 @@ def expand(a: Tensor, *shape) -> Tensor:
|
|||
if len(shape) == 1 and isinstance(shape[0], Sequence):
|
||||
shape = tuple(shape[0])
|
||||
|
||||
check(
|
||||
torch._check(
|
||||
len(shape) >= len(a.shape),
|
||||
lambda: "expand: the requested shape has too few dimensions!",
|
||||
)
|
||||
|
|
@ -2823,7 +2816,7 @@ def expand(a: Tensor, *shape) -> Tensor:
|
|||
for idx, x in enumerate(a.shape):
|
||||
offset_idx = idx + offset
|
||||
requested_length = shape[offset_idx]
|
||||
check(
|
||||
torch._check(
|
||||
requested_length == x or x == 1 or requested_length == -1,
|
||||
lambda: f"expand: attempting to expand a dimension of length {x}!",
|
||||
)
|
||||
|
|
@ -2917,13 +2910,13 @@ def narrow(
|
|||
# Supports Tensor overload that was added for XLA:
|
||||
# https://github.com/pytorch/pytorch/issues/31558
|
||||
if isinstance(start, TensorLike):
|
||||
check(
|
||||
torch._check(
|
||||
start.dim() == 0 and utils.is_integer_dtype(start.dtype),
|
||||
lambda: "start must be an 0-dim integral Tensor.",
|
||||
)
|
||||
start = start.item() # type: ignore[assignment]
|
||||
check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
|
||||
check(length >= 0, lambda: "narrow(): length must be non-negative.")
|
||||
torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
|
||||
torch._check(length >= 0, lambda: "narrow(): length must be non-negative.")
|
||||
dim = utils.canonicalize_dim(a.ndim, dim)
|
||||
dim_length = a.size(dim)
|
||||
# Start being the end is usually invalid since it's out of bounds. So it's
|
||||
|
|
@ -2934,7 +2927,7 @@ def narrow(
|
|||
# Note: a dimension isn't being canonicalized here, this reuses
|
||||
# canonicalize_dim because the semantics are similar.
|
||||
start = utils.canonicalize_dim(dim_length, start) # type: ignore[arg-type]
|
||||
check(
|
||||
torch._check(
|
||||
start <= dim_length - length, # type: ignore[arg-type]
|
||||
lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
|
||||
)
|
||||
|
|
@ -2993,11 +2986,11 @@ def native_group_norm(
|
|||
num_groups: int,
|
||||
eps: float,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
utils.check(
|
||||
torch._check(
|
||||
input.ndim >= 2,
|
||||
lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
num_channels % num_groups == 0,
|
||||
lambda: "Expected number of channels in input to be divisible by num_groups, "
|
||||
+ f"but got input of shape {input.shape} and num_groups = {num_groups}",
|
||||
|
|
@ -3044,7 +3037,7 @@ def native_layer_norm(
|
|||
eps: float,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
normalized_ndim = len(normalized_shape)
|
||||
utils.check(
|
||||
torch._check(
|
||||
normalized_ndim >= 1,
|
||||
lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., "
|
||||
+ "containing at least one element, but got normalized_shape = "
|
||||
|
|
@ -3053,7 +3046,7 @@ def native_layer_norm(
|
|||
# torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False
|
||||
# while torch.Size([1, 2, 3]) == (1, 2, 3) is True
|
||||
# therefore we use tuple(normalized_shape)
|
||||
utils.check(
|
||||
torch._check(
|
||||
weight is None or weight.shape == tuple(normalized_shape),
|
||||
lambda: "Expected weight to be of same shape as normalized_shape, but got "
|
||||
+ "weight of shape "
|
||||
|
|
@ -3061,7 +3054,7 @@ def native_layer_norm(
|
|||
+ " and normalized_shape = "
|
||||
+ str(normalized_shape),
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
bias is None or bias.shape == tuple(normalized_shape),
|
||||
lambda: "Expected bias to be of same shape as normalized_shape, but got "
|
||||
+ "bias of shape "
|
||||
|
|
@ -3069,7 +3062,7 @@ def native_layer_norm(
|
|||
+ " and normalized_shape = "
|
||||
+ str(normalized_shape),
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
input.ndim >= normalized_ndim
|
||||
and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
|
||||
lambda: "Given normalized_shape="
|
||||
|
|
@ -3123,12 +3116,12 @@ def _get_unfold_shape_stride(
|
|||
max_size = 1 if a_ndim == 0 else a_shape[dim]
|
||||
last_stride = 1 if a_ndim == 0 else a_stride[dim]
|
||||
|
||||
utils.check(
|
||||
torch._check(
|
||||
size <= max_size,
|
||||
lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}",
|
||||
)
|
||||
|
||||
utils.check(
|
||||
torch._check(
|
||||
step > 0,
|
||||
lambda: f"Step is {step} but must be > 0",
|
||||
)
|
||||
|
|
@ -3146,7 +3139,7 @@ def _get_unfold_shape_stride(
|
|||
@register_decomposition(aten.repeat)
|
||||
def repeat(a: Tensor, *repeat_shape) -> Tensor:
|
||||
repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False)
|
||||
utils.check(
|
||||
torch._check(
|
||||
len(repeat_shape) >= len(a.shape),
|
||||
lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
|
||||
)
|
||||
|
|
@ -3452,7 +3445,7 @@ def softmax(
|
|||
# CompositeImplicitAutograd - don't register decomp
|
||||
@out_wrapper()
|
||||
def hstack(tensors: TensorSequenceType) -> TensorLikeType:
|
||||
check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
|
||||
torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
|
||||
aligned_tensors = atleast_1d(*tensors)
|
||||
if aligned_tensors[0].ndim == 1:
|
||||
return cat(aligned_tensors, 0)
|
||||
|
|
@ -3462,7 +3455,7 @@ def hstack(tensors: TensorSequenceType) -> TensorLikeType:
|
|||
# CompositeImplicitAutograd - don't register decomp
|
||||
@out_wrapper()
|
||||
def vstack(tensors: TensorSequenceType) -> TensorLikeType:
|
||||
check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
|
||||
torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
|
||||
aligned_tensors = atleast_2d(*tensors)
|
||||
return cat(aligned_tensors, 0)
|
||||
|
||||
|
|
@ -3470,17 +3463,16 @@ def vstack(tensors: TensorSequenceType) -> TensorLikeType:
|
|||
# CompositeImplicitAutograd - don't register decomp
|
||||
def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType:
|
||||
dim = utils.canonicalize_dim(a.ndim, dim)
|
||||
utils.check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
|
||||
torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
|
||||
return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :]))
|
||||
|
||||
|
||||
@register_decomposition(aten.unbind)
|
||||
def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
|
||||
dim = utils.canonicalize_dim(t.ndim, dim)
|
||||
check(
|
||||
torch._check_index(
|
||||
len(t.shape) > 0,
|
||||
lambda: "Dimension specified as 0 but tensor has no dimensions",
|
||||
IndexError,
|
||||
)
|
||||
if t.shape[dim] == 0:
|
||||
return tuple()
|
||||
|
|
@ -3499,7 +3491,7 @@ def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
|
|||
|
||||
def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
|
||||
dim = utils.canonicalize_dims(x.ndim, dim)
|
||||
utils.check(
|
||||
torch._check(
|
||||
index.ndim <= 1,
|
||||
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
|
||||
)
|
||||
|
|
@ -3532,12 +3524,12 @@ def _index_fill(
|
|||
*,
|
||||
inplace: bool,
|
||||
):
|
||||
utils.check(
|
||||
torch._check(
|
||||
index.ndim <= 1,
|
||||
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
|
||||
)
|
||||
if isinstance(value, TensorLike):
|
||||
utils.check(
|
||||
torch._check(
|
||||
value.ndim == 0,
|
||||
lambda: "Only supports 0-dimensional value tensor. " # type: ignore[union-attr]
|
||||
f"Got a tensor with {value.ndim} dimensions.",
|
||||
|
|
@ -3589,7 +3581,7 @@ def index_add(
|
|||
@out_wrapper()
|
||||
def index_select(x: TensorLike, dim: int, index: TensorLike):
|
||||
dim = utils.canonicalize_dims(x.ndim, dim)
|
||||
utils.check(
|
||||
torch._check(
|
||||
index.ndim <= 1,
|
||||
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
|
||||
)
|
||||
|
|
@ -3713,7 +3705,7 @@ def tensor_split(
|
|||
def hsplit(
|
||||
a: TensorLikeType, indices_or_sections: DimsType
|
||||
) -> Tuple[TensorLikeType, ...]:
|
||||
check(
|
||||
torch._check(
|
||||
a.ndim >= 1,
|
||||
lambda: (
|
||||
"torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "
|
||||
|
|
@ -3724,7 +3716,7 @@ def hsplit(
|
|||
dim = 0 if a.ndim == 1 else 1
|
||||
if isinstance(indices_or_sections, IntLike):
|
||||
split_size = indices_or_sections
|
||||
check(
|
||||
torch._check(
|
||||
(split_size != 0 and a.shape[dim] % split_size == 0),
|
||||
lambda: (
|
||||
"torch.hsplit attempted to split along dimension "
|
||||
|
|
@ -3738,14 +3730,13 @@ def hsplit(
|
|||
)
|
||||
return tensor_split(a, split_size, dim)
|
||||
|
||||
check(
|
||||
torch._check_type(
|
||||
isinstance(indices_or_sections, (list, tuple)),
|
||||
lambda: (
|
||||
"hsplit(): received an invalid combination of arguments. "
|
||||
"Expected indices_or_sections to be of type int, list of ints or tuple of ints "
|
||||
f"but got type {type(indices_or_sections)}"
|
||||
),
|
||||
exc_type=TypeError,
|
||||
)
|
||||
|
||||
split_sizes = indices_or_sections
|
||||
|
|
@ -3756,7 +3747,7 @@ def hsplit(
|
|||
def vsplit(
|
||||
a: TensorLikeType, indices_or_sections: DimsType
|
||||
) -> Tuple[TensorLikeType, ...]:
|
||||
check(
|
||||
torch._check(
|
||||
a.ndim >= 2,
|
||||
lambda: (
|
||||
"torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with "
|
||||
|
|
@ -3766,7 +3757,7 @@ def vsplit(
|
|||
)
|
||||
if isinstance(indices_or_sections, IntLike):
|
||||
split_size = indices_or_sections
|
||||
check(
|
||||
torch._check(
|
||||
(split_size != 0 and a.shape[0] % split_size == 0),
|
||||
lambda: (
|
||||
f"torch.vsplit attempted to split along dimension 0"
|
||||
|
|
@ -3779,14 +3770,13 @@ def vsplit(
|
|||
)
|
||||
return tensor_split(a, split_size, 0)
|
||||
|
||||
check(
|
||||
torch._check_type(
|
||||
isinstance(indices_or_sections, (list, tuple)),
|
||||
lambda: (
|
||||
"vsplit(): received an invalid combination of arguments. "
|
||||
"Expected indices_or_sections to be of type int, list of ints or tuple of ints "
|
||||
f"but got type {type(indices_or_sections)}"
|
||||
),
|
||||
exc_type=TypeError,
|
||||
)
|
||||
|
||||
split_sizes = indices_or_sections
|
||||
|
|
@ -3800,7 +3790,7 @@ def diag(
|
|||
offset: int = 0,
|
||||
) -> TensorLikeType:
|
||||
ndim = self.dim()
|
||||
utils.check(
|
||||
torch._check(
|
||||
ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D"
|
||||
)
|
||||
if ndim == 1:
|
||||
|
|
@ -3820,7 +3810,7 @@ def diagonal_scatter(
|
|||
) -> TensorLikeType:
|
||||
out = utils.clone_preserve_strides(input)
|
||||
diag = out.diagonal(offset, dim1, dim2)
|
||||
check(
|
||||
torch._check(
|
||||
diag.shape == src.shape,
|
||||
lambda: "expected src to have a size equal to the diagonal of the input."
|
||||
f"Got {src.shape} for a diagonal of shape {diag.shape}",
|
||||
|
|
@ -3843,7 +3833,7 @@ def diagonal(
|
|||
dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims)
|
||||
dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims)
|
||||
|
||||
check(
|
||||
torch._check(
|
||||
dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
|
||||
)
|
||||
|
||||
|
|
@ -3896,7 +3886,7 @@ def diag_embed(
|
|||
dim1 = utils.canonicalize_dim(rank=rank, idx=dim1)
|
||||
dim2 = utils.canonicalize_dim(rank=rank, idx=dim2)
|
||||
|
||||
check(
|
||||
torch._check(
|
||||
dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
|
||||
)
|
||||
|
||||
|
|
@ -3967,7 +3957,7 @@ def t(a: TensorLikeType):
|
|||
# CompositeImplicitAutograd - don't register decomp
|
||||
def T(a: TensorLikeType) -> TensorLikeType:
|
||||
# n != 2 && n != 0 is deprecated in regular PyTorch.
|
||||
check(
|
||||
torch._check(
|
||||
a.ndim in (0, 2),
|
||||
lambda: (
|
||||
"The use of `x.T` on tensors of dimension other than 0 or 2 "
|
||||
|
|
@ -4102,7 +4092,7 @@ def empty(
|
|||
pin_memory: bool = False,
|
||||
memory_format: torch.memory_format = torch.contiguous_format,
|
||||
) -> TensorLikeType:
|
||||
check(
|
||||
torch._check(
|
||||
memory_format != torch.preserve_format,
|
||||
lambda: "torch.empty: the Preserve memory format is not supported",
|
||||
)
|
||||
|
|
@ -4114,7 +4104,7 @@ def empty(
|
|||
elif memory_format == torch.channels_last_3d:
|
||||
strides = utils.make_channels_last_3d_strides_for(shape)
|
||||
else: # memory_format == torch.channels_last
|
||||
check(
|
||||
torch._check(
|
||||
memory_format == torch.channels_last,
|
||||
lambda: f"torch.empty: received an unknown memory format {memory_format}!",
|
||||
)
|
||||
|
|
@ -4398,8 +4388,8 @@ def arange(
|
|||
if end is None:
|
||||
end = start
|
||||
start = 0
|
||||
utils.check(step != 0, lambda: "step must be nonzero")
|
||||
utils.check(
|
||||
torch._check(step != 0, lambda: "step must be nonzero")
|
||||
torch._check(
|
||||
(step > 0 and end >= start) or (step < 0 and end <= start),
|
||||
lambda: "upper bound and lower bound inconsistent with step sign",
|
||||
)
|
||||
|
|
@ -4407,11 +4397,11 @@ def arange(
|
|||
def is_finite(x):
|
||||
return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x)
|
||||
|
||||
utils.check(
|
||||
torch._check(
|
||||
is_finite(start) and is_finite(end),
|
||||
lambda: f"unsupported range: {start} -> {end}",
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
is_finite(step),
|
||||
lambda: f"step must be finite but got {step}",
|
||||
)
|
||||
|
|
@ -4514,7 +4504,7 @@ def linspace(
|
|||
if dtype is None:
|
||||
dtype = default_complex_dtype
|
||||
else:
|
||||
check(
|
||||
torch._check(
|
||||
utils.is_complex_dtype(dtype),
|
||||
lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
|
||||
)
|
||||
|
|
@ -4523,13 +4513,12 @@ def linspace(
|
|||
assert isinstance(dtype, torch.dtype)
|
||||
|
||||
# steps does not participate in the computation of the dtype
|
||||
check(
|
||||
torch._check_type(
|
||||
isinstance(steps, IntLike),
|
||||
lambda: "steps must be int, not float",
|
||||
exc_type=TypeError,
|
||||
)
|
||||
assert isinstance(steps, IntLike) # for mypy
|
||||
check(steps >= 0, lambda: "number of steps must be non-negative")
|
||||
torch._check(steps >= 0, lambda: "number of steps must be non-negative")
|
||||
|
||||
factory_kwargs = {
|
||||
"layout": layout,
|
||||
|
|
@ -4631,19 +4620,19 @@ def meshgrid(
|
|||
assert len(tensors) == 1
|
||||
tensors = tuple(tensors[0])
|
||||
|
||||
check(
|
||||
torch._check(
|
||||
py_all(isinstance(a, TensorLike) for a in tensors),
|
||||
lambda: "meshgrid expects its inputs to be tensors",
|
||||
)
|
||||
|
||||
check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")
|
||||
torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")
|
||||
|
||||
for i in range(len(tensors) - 1):
|
||||
check(
|
||||
torch._check(
|
||||
tensors[i].dtype == tensors[i + 1].dtype, # type: ignore[union-attr]
|
||||
lambda: "meshgrid expects all tensors to have the same dtype",
|
||||
)
|
||||
check(
|
||||
torch._check(
|
||||
tensors[i].device == tensors[i + 1].device, # type: ignore[union-attr]
|
||||
lambda: "meshgrid expects all tensors to have the same device",
|
||||
)
|
||||
|
|
@ -4654,7 +4643,7 @@ def meshgrid(
|
|||
if swap_first_and_second_tensors:
|
||||
tensors = (tensors[1], tensors[0], *tensors[2:])
|
||||
else:
|
||||
check(
|
||||
torch._check(
|
||||
indexing == "ij",
|
||||
lambda: (
|
||||
'torch.meshgrid: indexing must be one of "xy" or "ij", '
|
||||
|
|
@ -4665,7 +4654,7 @@ def meshgrid(
|
|||
result_shape: List[int] = []
|
||||
for t in tensors:
|
||||
assert isinstance(t, TensorLike) # mypy
|
||||
check(
|
||||
torch._check(
|
||||
t.ndim == 0 or t.ndim == 1,
|
||||
lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}",
|
||||
)
|
||||
|
|
@ -4701,7 +4690,7 @@ def movedim(
|
|||
|
||||
# Converts to list to produce a compatible error message with core PyTorch,
|
||||
# which prints sequences in square brackets.
|
||||
utils.check(
|
||||
torch._check(
|
||||
len(source) == len(destination), # type: ignore[arg-type]
|
||||
lambda: (
|
||||
"movedim: Invalid source or destination dims: source " # type: ignore[arg-type]
|
||||
|
|
@ -4718,11 +4707,11 @@ def movedim(
|
|||
dss = set(ds)
|
||||
|
||||
# See above on why this converts to list in error messages.
|
||||
utils.check(
|
||||
torch._check(
|
||||
len(ss) == len(sss),
|
||||
lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type]
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
len(ds) == len(dss),
|
||||
lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type]
|
||||
)
|
||||
|
|
@ -4795,8 +4784,8 @@ def eye(
|
|||
if m is None:
|
||||
m = n
|
||||
|
||||
check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
|
||||
check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")
|
||||
torch._check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
|
||||
torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")
|
||||
|
||||
range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False)
|
||||
range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False)
|
||||
|
|
@ -4994,13 +4983,13 @@ def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLi
|
|||
# NOTE: Could not use value = item(value) as it resulted in
|
||||
# RuntimeError: Cannot cast FakeTensor(cpu) to number
|
||||
value_ndim = value.ndim
|
||||
check(
|
||||
torch._check(
|
||||
value_ndim == 0,
|
||||
lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension",
|
||||
)
|
||||
# `masked_fill` allows cpu scalar to be moved to cuda and xpu but not otherwise.
|
||||
is_cpu_scalar = a.device.type in ["cuda", "xpu"] and value.device.type == "cpu"
|
||||
check(
|
||||
torch._check(
|
||||
is_cpu_scalar or value.device == a.device,
|
||||
lambda: "Expected `value` to be on same device as `a`",
|
||||
)
|
||||
|
|
@ -5011,7 +5000,7 @@ def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLi
|
|||
# We allow casting `value` to lower type for other case
|
||||
# Eg. float -> int.
|
||||
# Ref: https://github.com/pytorch/pytorch/issues/79195
|
||||
check(
|
||||
torch._check(
|
||||
utils.is_weakly_lesser_type(value_type, python_type),
|
||||
lambda: f"could not convert to type {python_type} without overflow",
|
||||
)
|
||||
|
|
@ -5101,7 +5090,7 @@ def norm(
|
|||
|
||||
@register_decomposition(aten.trace)
|
||||
def trace(self: TensorLikeType) -> TensorLikeType:
|
||||
utils.check(
|
||||
torch._check(
|
||||
self.ndim == 2, lambda: "expected a matrix, but got tensor with dim {self.ndim}"
|
||||
)
|
||||
return torch.sum(torch.diag(self, 0))
|
||||
|
|
@ -5125,7 +5114,7 @@ rpow = _make_r_binary_op(pow)
|
|||
@register_decomposition(aten.triu)
|
||||
@out_wrapper()
|
||||
def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
|
||||
utils.check(
|
||||
torch._check(
|
||||
a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions"
|
||||
)
|
||||
h, w = a.shape[-2:]
|
||||
|
|
@ -5142,7 +5131,7 @@ def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
|
|||
@register_decomposition(aten.tril)
|
||||
@out_wrapper()
|
||||
def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
|
||||
utils.check(
|
||||
torch._check(
|
||||
a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions"
|
||||
)
|
||||
h, w = a.shape[-2:]
|
||||
|
|
@ -5187,9 +5176,9 @@ def _trilu_checks(
|
|||
layout: torch.layout,
|
||||
pin_memory: bool,
|
||||
):
|
||||
check(row >= 0, lambda: f"row must be non-negative, got {row}")
|
||||
check(col >= 0, lambda: f"col must be non-negative, got {col}")
|
||||
check(
|
||||
torch._check(row >= 0, lambda: f"row must be non-negative, got {row}")
|
||||
torch._check(col >= 0, lambda: f"col must be non-negative, got {col}")
|
||||
torch._check(
|
||||
dtype in (torch.int32, torch.int64),
|
||||
lambda: f"\"{name}\" not implemented for '{dtype}'",
|
||||
)
|
||||
|
|
@ -5306,7 +5295,7 @@ def bucketize(
|
|||
out_int32: bool = False,
|
||||
right: bool = False,
|
||||
):
|
||||
utils.check(
|
||||
torch._check(
|
||||
boundaries.dim() == 1,
|
||||
lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})",
|
||||
)
|
||||
|
|
@ -5364,14 +5353,14 @@ def bucketize(
|
|||
)
|
||||
def cauchy(self, median=0, sigma=1, generator=None):
|
||||
assert generator is None
|
||||
utils.check(
|
||||
torch._check(
|
||||
not utils.is_complex_dtype(self.dtype)
|
||||
and not utils.is_integer_dtype(self.dtype)
|
||||
and not utils.is_boolean_dtype(self.dtype),
|
||||
lambda: f"Cauchy distribution is a continuous probability distribution. \
|
||||
dtype must be a floating point but you specified {self.dtype}",
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
sigma > 0.0,
|
||||
lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}",
|
||||
)
|
||||
|
|
@ -5386,14 +5375,14 @@ def cauchy(self, median=0, sigma=1, generator=None):
|
|||
)
|
||||
def exponential(self, rate=1, generator=None):
|
||||
assert generator is None
|
||||
utils.check(
|
||||
torch._check(
|
||||
not utils.is_complex_dtype(self.dtype)
|
||||
and not utils.is_integer_dtype(self.dtype)
|
||||
and not utils.is_boolean_dtype(self.dtype),
|
||||
lambda: f"Exponential distribution is a continuous probability distribution. \
|
||||
dtype must be a floating point but you specified {self.dtype}",
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
rate > 0.0,
|
||||
lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
|
||||
)
|
||||
|
|
@ -5409,12 +5398,12 @@ def exponential(self, rate=1, generator=None):
|
|||
def geometric(self, p, generator=None):
|
||||
assert generator is None
|
||||
# TODO: fix inductor rand_like for integer, bool dtypes
|
||||
utils.check(
|
||||
torch._check(
|
||||
not utils.is_complex_dtype(self.dtype)
|
||||
and not utils.is_boolean_dtype(self.dtype),
|
||||
lambda: f"geometric not implemented for {self.dtype}",
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
0 < p and p < 1,
|
||||
lambda: f"geometric_ expects p to be in (0, 1), but got p={p}",
|
||||
)
|
||||
|
|
@ -5429,13 +5418,13 @@ def geometric(self, p, generator=None):
|
|||
)
|
||||
def log_normal(self, mean=1, std=2, generator=None):
|
||||
assert generator is None
|
||||
utils.check(
|
||||
torch._check(
|
||||
not utils.is_complex_dtype(self.dtype)
|
||||
and not utils.is_integer_dtype(self.dtype)
|
||||
and not utils.is_boolean_dtype(self.dtype),
|
||||
lambda: f"log_normal not implemented for {self.dtype}",
|
||||
)
|
||||
utils.check(
|
||||
torch._check(
|
||||
0 < std,
|
||||
lambda: f"log_normal_ expects std > 0.0, but found std={std}",
|
||||
)
|
||||
|
|
@ -5451,7 +5440,7 @@ def log_normal(self, mean=1, std=2, generator=None):
|
|||
)
|
||||
def normal(self, mean=0, std=1, generator=None):
|
||||
assert generator is None
|
||||
utils.check(std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}")
|
||||
torch._check(std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}")
|
||||
normal_samples = prims.normal(
|
||||
self.shape,
|
||||
mean=0.0,
|
||||
|
|
@ -5465,7 +5454,7 @@ def normal(self, mean=0, std=1, generator=None):
|
|||
|
||||
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
|
||||
def rad2deg(self: TensorLikeType):
|
||||
utils.check(
|
||||
torch._check(
|
||||
not utils.is_complex_dtype(self.dtype),
|
||||
lambda: "rad2deg is not supported for complex tensors.",
|
||||
)
|
||||
|
|
@ -5475,7 +5464,7 @@ def rad2deg(self: TensorLikeType):
|
|||
|
||||
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
|
||||
def deg2rad(self: TensorLikeType):
|
||||
utils.check(
|
||||
torch._check(
|
||||
not utils.is_complex_dtype(self.dtype),
|
||||
lambda: "deg2rad is not supported for complex tensors.",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import torch._prims_common as utils
|
|||
# Utilities should come BEFORE this import
|
||||
from torch._decomp import register_decomposition
|
||||
|
||||
from torch._prims_common import check, TensorLikeType
|
||||
from torch._prims_common import TensorLikeType
|
||||
from torch._prims_common.wrappers import out_wrapper
|
||||
from torch._refs import _broadcast_shapes
|
||||
|
||||
|
|
@ -79,14 +79,14 @@ short = _make_conversion_method("short", torch.short)
|
|||
@out_wrapper(exact_dtype=True)
|
||||
def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
|
||||
allowed_dtypes = (torch.float32, torch.float64, torch.float16)
|
||||
check(
|
||||
torch._check(
|
||||
real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes,
|
||||
lambda: (
|
||||
f"Expected both inputs to be Half, Float or Double tensors but got "
|
||||
f"{real.dtype} and {imag.dtype}"
|
||||
),
|
||||
)
|
||||
check(
|
||||
torch._check(
|
||||
real.dtype == imag.dtype,
|
||||
lambda: (
|
||||
f"Expected object of scalar type {real.dtype} but got "
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import torch
|
|||
import torch._prims as prims
|
||||
import torch._prims_common as utils
|
||||
from torch._decomp import register_decomposition
|
||||
from torch._prims_common import check, DimsType, ShapeType, TensorLikeType
|
||||
from torch._prims_common import DimsType, ShapeType, TensorLikeType
|
||||
from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -43,7 +43,7 @@ def _apply_norm(
|
|||
x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
|
||||
) -> TensorLikeType:
|
||||
"""Apply normalization to the un-normalized FFT result"""
|
||||
check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")
|
||||
torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")
|
||||
|
||||
if norm == "ortho":
|
||||
return x * (1 / math.sqrt(signal_numel))
|
||||
|
|
@ -116,7 +116,9 @@ def _fft_c2r(
|
|||
input = _maybe_promote_tensor_fft(input, require_complex=True)
|
||||
dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
|
||||
last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
|
||||
check(last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified")
|
||||
torch._check(
|
||||
last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified"
|
||||
)
|
||||
|
||||
if n is not None:
|
||||
input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))
|
||||
|
|
@ -138,7 +140,7 @@ def _fft_r2c(
|
|||
onesided: bool,
|
||||
) -> TensorLikeType:
|
||||
"""Common code for performing any real to complex FFT (rfft or ihfft)"""
|
||||
check(
|
||||
torch._check(
|
||||
not input.dtype.is_complex,
|
||||
lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
|
||||
)
|
||||
|
|
@ -162,7 +164,7 @@ def _fft_c2c(
|
|||
forward: bool,
|
||||
) -> TensorLikeType:
|
||||
"""Common code for performing any complex to complex FFT (fft or ifft)"""
|
||||
check(
|
||||
torch._check(
|
||||
input.dtype.is_complex,
|
||||
lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
|
||||
)
|
||||
|
|
@ -265,20 +267,20 @@ def _canonicalize_fft_shape_and_dim_args(
|
|||
ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)
|
||||
|
||||
# Check dims are unique
|
||||
check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique")
|
||||
torch._check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique")
|
||||
|
||||
if shape is not None:
|
||||
if not isinstance(shape, Sequence):
|
||||
shape = (shape,)
|
||||
|
||||
# Has shape, might have dim
|
||||
check(
|
||||
torch._check(
|
||||
dim is None or len(dim) == len(shape),
|
||||
lambda: "When given, dim and shape arguments must have the same length",
|
||||
)
|
||||
transform_ndim = len(shape)
|
||||
|
||||
check(
|
||||
torch._check(
|
||||
transform_ndim <= input_dim,
|
||||
lambda: f"Got shape with {transform_ndim} values but input tensor "
|
||||
f"only has {input_dim} dimensions.",
|
||||
|
|
@ -301,7 +303,7 @@ def _canonicalize_fft_shape_and_dim_args(
|
|||
ret_shape = tuple(input_sizes[d] for d in ret_dims)
|
||||
|
||||
for n in ret_shape:
|
||||
check(n > 0, lambda: f"Invalid number of data points ({n}) specified")
|
||||
torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")
|
||||
|
||||
return _ShapeAndDims(shape=ret_shape, dims=ret_dims)
|
||||
|
||||
|
|
@ -323,7 +325,7 @@ def _fftn_c2c(
|
|||
forward: bool,
|
||||
) -> TensorLikeType:
|
||||
"""Common code for n-dimensional complex to complex FFTs (fftn or ifftn)"""
|
||||
check(
|
||||
torch._check(
|
||||
input.dtype.is_complex,
|
||||
lambda: f"{function_name} expects a complex input tensor, "
|
||||
f"but got {input.dtype}",
|
||||
|
|
@ -367,7 +369,7 @@ def rfftn(
|
|||
dim: Optional[DimsType] = None,
|
||||
norm: NormType = None,
|
||||
) -> TensorLikeType:
|
||||
check(
|
||||
torch._check(
|
||||
not input.dtype.is_complex,
|
||||
lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
|
||||
)
|
||||
|
|
@ -386,12 +388,12 @@ def ihfftn(
|
|||
dim: Optional[DimsType] = None,
|
||||
norm: NormType = None,
|
||||
) -> TensorLikeType:
|
||||
check(
|
||||
torch._check(
|
||||
not input.dtype.is_complex,
|
||||
lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
|
||||
)
|
||||
shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
|
||||
check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
|
||||
torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
|
||||
input = _maybe_promote_tensor_fft(input, require_complex=False)
|
||||
input = _resize_fft_input(input, dim, shape)
|
||||
|
||||
|
|
@ -421,14 +423,14 @@ def _canonicalize_fft_c2r_shape_and_dim_args(
|
|||
"""Canonicalize shape and dim arguments for n-dimensional c2r transforms,
|
||||
as well as calculating the last_dim_size which is shape[dim[-1]] for the output"""
|
||||
(shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
|
||||
check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")
|
||||
torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")
|
||||
|
||||
if s is None or s[-1] == -1:
|
||||
last_dim_size = 2 * (input.shape[dim[-1]] - 1)
|
||||
else:
|
||||
last_dim_size = shape[-1]
|
||||
|
||||
check(
|
||||
torch._check(
|
||||
last_dim_size >= 1,
|
||||
lambda: f"Invalid number of data points ({last_dim_size}) specified",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import torch._refs as refs
|
|||
import torch._refs.linalg as linalg
|
||||
from torch import Tensor
|
||||
from torch._prims_common import (
|
||||
check,
|
||||
check_fp_or_complex,
|
||||
check_is_matrix,
|
||||
Dim,
|
||||
|
|
@ -29,11 +28,11 @@ def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_nam
|
|||
Checks related to the dtype kwarg in `linalg.*norm` functions
|
||||
"""
|
||||
if dtype is not None:
|
||||
check(
|
||||
torch._check(
|
||||
utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
|
||||
lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
|
||||
)
|
||||
check(
|
||||
torch._check(
|
||||
utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype),
|
||||
lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
|
||||
fn_name=fn_name,
|
||||
|
|
@ -41,7 +40,7 @@ def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_nam
|
|||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
check(
|
||||
torch._check(
|
||||
utils.get_higher_dtype(dtype, x_dtype) == dtype,
|
||||
lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
|
||||
"without narrowing to the specified dtype ({dtype})",
|
||||
|
|
@ -79,7 +78,7 @@ def vector_norm(
|
|||
dim = [dim] # type: ignore[assignment]
|
||||
|
||||
if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
|
||||
check(
|
||||
torch._check(
|
||||
dim is not None and len(dim) != 0,
|
||||
lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
|
||||
"because the operation does not have an identity",
|
||||
|
|
@ -87,7 +86,7 @@ def vector_norm(
|
|||
shape = x.shape
|
||||
assert dim is not None # mypy does not seem to be able to see through check?
|
||||
for d in dim:
|
||||
check(
|
||||
torch._check(
|
||||
shape[d] != 0,
|
||||
lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
|
||||
f"dimension {d} because this dimension is empty and the "
|
||||
|
|
@ -147,8 +146,10 @@ def matrix_norm(
|
|||
dim = utils.canonicalize_dims(A.ndim, dim)
|
||||
if isinstance(dim, Dim):
|
||||
dim = (dim,) # type: ignore[assignment]
|
||||
check(len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}")
|
||||
check(
|
||||
torch._check(
|
||||
len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
|
||||
)
|
||||
torch._check(
|
||||
dim[0] != dim[1],
|
||||
lambda: "linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
|
||||
)
|
||||
|
|
@ -157,7 +158,7 @@ def matrix_norm(
|
|||
|
||||
if isinstance(ord, str):
|
||||
# ord
|
||||
check(
|
||||
torch._check(
|
||||
ord in ("fro", "nuc"),
|
||||
lambda: "linalg.matrix_norm: Order {ord} not supported.",
|
||||
)
|
||||
|
|
@ -180,7 +181,7 @@ def matrix_norm(
|
|||
else:
|
||||
# ord
|
||||
abs_ord = abs(ord)
|
||||
check(
|
||||
torch._check(
|
||||
abs_ord in (2, 1, float("inf")),
|
||||
lambda: "linalg.matrix_norm: Order {ord} not supported.",
|
||||
)
|
||||
|
|
@ -224,12 +225,12 @@ def norm(
|
|||
if dim is not None:
|
||||
if isinstance(dim, Dim):
|
||||
dim = (dim,) # type: ignore[assignment]
|
||||
check(
|
||||
torch._check(
|
||||
len(dim) in (1, 2),
|
||||
lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
|
||||
)
|
||||
elif ord is not None:
|
||||
check(
|
||||
torch._check(
|
||||
A.ndim in (1, 2),
|
||||
lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import torch._prims_common as utils
|
|||
import torch._refs as refs
|
||||
from torch._decomp import register_decomposition
|
||||
from torch._prims_common import (
|
||||
check,
|
||||
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
||||
NumberType,
|
||||
ShapeType,
|
||||
|
|
@ -98,7 +97,7 @@ def alpha_dropout(
|
|||
if not training:
|
||||
return self
|
||||
|
||||
utils.check(
|
||||
torch._check(
|
||||
p <= 1 and p >= 0,
|
||||
lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
|
||||
)
|
||||
|
|
@ -134,7 +133,7 @@ def _inplace_wrapper(fn):
|
|||
@wraps(fn)
|
||||
def _fn(a, *args, inplace=False, **kwargs):
|
||||
if inplace:
|
||||
check(
|
||||
torch._check(
|
||||
"out" not in kwargs,
|
||||
lambda: "Cannot set inplace=True and pass out= at the same time",
|
||||
)
|
||||
|
|
@ -193,7 +192,7 @@ def dropout(
|
|||
if not training:
|
||||
return a
|
||||
|
||||
utils.check(
|
||||
torch._check(
|
||||
p <= 1 and p >= 0,
|
||||
lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
|
||||
)
|
||||
|
|
@ -232,15 +231,15 @@ def elu(
|
|||
|
||||
# nb. This should be factored out into a can_cast aux function
|
||||
python_type = utils.dtype_to_type(a.dtype)
|
||||
check(
|
||||
torch._check(
|
||||
utils.is_weakly_lesser_type(type(input_scale), python_type),
|
||||
lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!",
|
||||
)
|
||||
check(
|
||||
torch._check(
|
||||
utils.is_weakly_lesser_type(type(scale), python_type),
|
||||
lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!",
|
||||
)
|
||||
check(
|
||||
torch._check(
|
||||
utils.is_weakly_lesser_type(type(alpha), python_type),
|
||||
lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
|
||||
)
|
||||
|
|
@ -276,14 +275,14 @@ def group_norm(
|
|||
"""
|
||||
Reference implementation of :func:`torch.nn.functional.group_norm`.
|
||||
"""
|
||||
utils.check(
|
||||
torch._check(
|
||||
input.ndim >= 2,
|
||||
lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
|
||||
)
|
||||
|
||||
batch_size = input.shape[0]
|
||||
num_channels = input.shape[1]
|
||||
utils.check(
|
||||
torch._check(
|
||||
num_channels % num_groups == 0,
|
||||
lambda: "Expected number of channels in input to be divisible by num_groups, "
|
||||
+ f"but got input of shape {input.shape} and num_groups = {num_groups}",
|
||||
|
|
@ -394,7 +393,7 @@ def softmax(
|
|||
# deprecated. For PrimTorch, it's fine to drop support for deprecated
|
||||
# behavior because it requires explicit opt in. This error is to inform
|
||||
# users how to update their calls.
|
||||
check(dim is not None, lambda: "implicit dim not supported, use dim=X")
|
||||
torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
|
||||
return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
|
||||
|
||||
|
||||
|
|
@ -409,7 +408,7 @@ def softmin(
|
|||
# deprecated. For PrimTorch, it's fine to drop support for deprecated
|
||||
# behavior because it requires explicit opt in. This error is to inform
|
||||
# users how to update their calls.
|
||||
check(dim is not None, lambda: "implicit dim not supported, use dim=X")
|
||||
torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
|
||||
return torch.softmax(a=-a, dim=dim, dtype=dtype) # type: ignore[call-overload]
|
||||
|
||||
|
||||
|
|
@ -469,7 +468,7 @@ def softshrink(a: TensorLikeType, lambd: float = 0.5):
|
|||
# softshrink(x) = x - lambd if x > lambd
|
||||
# = x + lambd if x < -lambd
|
||||
# = 0 otherwise
|
||||
check(
|
||||
torch._check(
|
||||
lambd >= 0,
|
||||
lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
|
||||
)
|
||||
|
|
@ -596,7 +595,7 @@ def log_softmax(
|
|||
# deprecated. For PrimTorch, it's fine to drop support for deprecated
|
||||
# behavior because it requires explicit opt in. This error is to inform
|
||||
# users how to update their calls.
|
||||
check(dim is not None, lambda: "implicit dim not supported, use dim=X")
|
||||
torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
|
||||
return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
|
||||
|
||||
|
||||
|
|
@ -668,12 +667,12 @@ def _nll_loss_nd(
|
|||
reduction: str,
|
||||
ignore_index: int,
|
||||
) -> TensorLikeType:
|
||||
utils.check(
|
||||
torch._check(
|
||||
input.ndim > 0 and input.ndim <= 3,
|
||||
lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.",
|
||||
)
|
||||
|
||||
utils.check(
|
||||
torch._check(
|
||||
(input.ndim == 1) or (input.shape[0] == target.shape[0]),
|
||||
lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
|
||||
)
|
||||
|
|
@ -693,7 +692,7 @@ def _nll_loss_nd(
|
|||
(flat_target >= 0), (flat_target < num_classes)
|
||||
)
|
||||
class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
|
||||
utils.check(
|
||||
torch._check(
|
||||
isinstance(target, FakeTensor) or bool(class_check.item()),
|
||||
lambda: "A target class is out-of-bounds and not the ignore index.",
|
||||
)
|
||||
|
|
@ -758,7 +757,7 @@ def nll_loss(
|
|||
"""
|
||||
Reference implementation of torch.nn.functional.nll_loss
|
||||
"""
|
||||
utils.check(
|
||||
torch._check(
|
||||
input.ndim > 0,
|
||||
lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})",
|
||||
)
|
||||
|
|
@ -796,9 +795,13 @@ def nll_loss(
|
|||
# For ndim > 3, we reshape the input and target to 3-D case.
|
||||
# Input (N batch-size, C classes, k-dimensions)
|
||||
# Target (N batch-size, k-dimensions)
|
||||
utils.check(
|
||||
torch._check(
|
||||
input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:],
|
||||
lambda: f"Expected target shape {out_size} but got {target.shape}",
|
||||
lambda: (
|
||||
"Expected input and target to both have ndim > 0 and "
|
||||
"target.shape[1:] == input.shape[2:], but got "
|
||||
f"target.shape {target.shape} and input.shape {input.shape}"
|
||||
),
|
||||
)
|
||||
|
||||
batch_size = input.shape[0]
|
||||
|
|
@ -837,7 +840,7 @@ def huber_loss(
|
|||
if type(reduction) is int:
|
||||
reduction = _reduction_int_to_str(reduction)
|
||||
_check_reduction_value(reduction) # type: ignore[arg-type]
|
||||
check(
|
||||
torch._check(
|
||||
delta > 0,
|
||||
lambda: "huber_loss does not support non-positive values for delta.",
|
||||
)
|
||||
|
|
@ -938,7 +941,7 @@ def _triplet_margin_with_distance_loss(
|
|||
a_dim = anchor.ndim
|
||||
p_dim = positive.ndim
|
||||
n_dim = negative.ndim
|
||||
check(
|
||||
torch._check(
|
||||
a_dim == p_dim and p_dim == n_dim,
|
||||
lambda: (
|
||||
f"The anchor, positive, and negative tensors are expected to have "
|
||||
|
|
@ -1075,25 +1078,25 @@ def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
|
|||
"""
|
||||
Reference implementation of torch.nn.functional.prelu
|
||||
"""
|
||||
check(
|
||||
torch._check(
|
||||
isinstance(a, TensorLike),
|
||||
lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
|
||||
)
|
||||
check(
|
||||
torch._check(
|
||||
isinstance(weight, TensorLike),
|
||||
lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
|
||||
)
|
||||
|
||||
if weight.numel() != 1:
|
||||
check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
|
||||
torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
|
||||
channel_size = a.shape[1] if a.ndim >= 2 else 1
|
||||
check(
|
||||
torch._check(
|
||||
weight.numel() == channel_size,
|
||||
lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
|
||||
f" {weight.numel()} and channel size = {channel_size}.",
|
||||
)
|
||||
|
||||
check(
|
||||
torch._check(
|
||||
weight.ndim == 0 or weight.ndim == 1,
|
||||
lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
|
||||
f"ndim = {weight.ndim}",
|
||||
|
|
@ -1132,7 +1135,7 @@ def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
|
|||
)
|
||||
def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType:
|
||||
dim = utils.canonicalize_dims(a.ndim, dim)
|
||||
check(
|
||||
torch._check(
|
||||
a.shape[dim] % 2 == 0,
|
||||
lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
|
||||
)
|
||||
|
|
@ -1160,8 +1163,8 @@ def pairwise_distance(
|
|||
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
||||
)
|
||||
def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType:
|
||||
check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
|
||||
check(p >= 0, lambda: "pdist only supports non-negative p values")
|
||||
torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
|
||||
torch._check(p >= 0, lambda: "pdist only supports non-negative p values")
|
||||
# For p == 2 we can use an efficient implementation, but other values of p
|
||||
# require creating a much bigger tensor for an intermediate step
|
||||
if p == 2:
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
|
|||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
||||
)
|
||||
def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
|
||||
utils.check(
|
||||
torch._check(
|
||||
isinstance(a, TensorLike) or isinstance(b, TensorLike),
|
||||
lambda: 'Expected either argument a or b to be a Tensor"',
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ import torch._logging
|
|||
from torch._guards import Source
|
||||
from torch._ops import OpOverload
|
||||
from torch._prims_common import (
|
||||
check,
|
||||
elementwise_dtypes,
|
||||
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
||||
is_boolean_dtype,
|
||||
|
|
@ -1495,7 +1494,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
) = FakeTensor._find_common_device(func, args, kwargs)
|
||||
|
||||
if isinstance(e, FakeTensor):
|
||||
check(
|
||||
torch._check(
|
||||
e.device == common_device,
|
||||
lambda: f"FakeTensor is wrapped to wrong device, found {e.device}, expected {common_device}",
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user