mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][numpy] Add unsigned integer dtypes (#125717)
We should support these to whatever extent we can. They corresponding `torch.uint<w>` types are defined, so I don't see an issue with generating the various casting rules and allowing them to trace. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125717 Approved by: https://github.com/lezcano
This commit is contained in:
parent
4ce5322a1f
commit
879d01afcb
|
|
@ -476,13 +476,18 @@ class TestNumPyInterop(TestCase):
|
|||
self.assertTrue(r2.requires_grad)
|
||||
|
||||
@onlyCPU
|
||||
def test_parse_numpy_int(self, device):
|
||||
@skipIfTorchDynamo()
|
||||
def test_parse_numpy_int_overflow(self, device):
|
||||
# assertRaises uses a try-except which dynamo has issues with
|
||||
# Only concrete class can be given where "Type[number[_64Bit]]" is expected
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"(Overflow|an integer is required)",
|
||||
lambda: torch.mean(torch.randn(1, 1), np.uint64(-1)),
|
||||
) # type: ignore[call-overload]
|
||||
|
||||
@onlyCPU
|
||||
def test_parse_numpy_int(self, device):
|
||||
# https://github.com/pytorch/pytorch/issues/29252
|
||||
for nptype in [np.int16, np.int8, np.uint8, np.int32, np.int64]:
|
||||
scalar = 3
|
||||
|
|
|
|||
|
|
@ -8,14 +8,17 @@ import warnings
|
|||
|
||||
# from numpy.core.getlimits import _discovered_machar, _float_ma
|
||||
|
||||
from unittest import skipIf
|
||||
from unittest import expectedFailure as xfail, skipIf
|
||||
|
||||
import numpy
|
||||
|
||||
from pytest import raises as assert_raises
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
subtest,
|
||||
TEST_WITH_TORCHDYNAMO,
|
||||
TestCase,
|
||||
xpassIfTorchDynamo,
|
||||
|
|
@ -109,6 +112,7 @@ class TestFinfo(TestCase):
|
|||
getattr(finfo(dt), attr)
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class TestIinfo(TestCase):
|
||||
def test_basic(self):
|
||||
dts = list(
|
||||
|
|
@ -129,11 +133,19 @@ class TestIinfo(TestCase):
|
|||
with assert_raises((TypeError, ValueError)):
|
||||
iinfo("f4")
|
||||
|
||||
def test_unsigned_max(self):
|
||||
types = np.sctypes["uint"]
|
||||
for T in types:
|
||||
max_calculated = T(0) - T(1)
|
||||
assert_equal(iinfo(T).max, max_calculated)
|
||||
@parametrize(
|
||||
"T",
|
||||
[
|
||||
np.uint8,
|
||||
# xfail: unsupported add (uint[16,32,64])
|
||||
subtest(np.uint16, decorators=[xfail]),
|
||||
subtest(np.uint32, decorators=[xfail]),
|
||||
subtest(np.uint64, decorators=[xfail]),
|
||||
],
|
||||
)
|
||||
def test_unsigned_max(self, T):
|
||||
max_calculated = T(0) - T(1)
|
||||
assert_equal(iinfo(T).max, max_calculated)
|
||||
|
||||
|
||||
class TestRepr(TestCase):
|
||||
|
|
|
|||
|
|
@ -732,13 +732,16 @@ class TestAbs(TestCase):
|
|||
|
||||
@instantiate_parametrized_tests
|
||||
class TestBitShifts(TestCase):
|
||||
@parametrize("type_code", np.typecodes["Integer"] + "B")
|
||||
@parametrize("type_code", np.typecodes["AllInteger"])
|
||||
@parametrize("op", [operator.rshift, operator.lshift])
|
||||
def test_shift_all_bits(self, type_code, op):
|
||||
"""Shifts where the shift amount is the width of the type or wider"""
|
||||
# gh-2449
|
||||
dt = np.dtype(type_code)
|
||||
nbits = dt.itemsize * 8
|
||||
if dt in (np.dtype(np.uint64), np.dtype(np.uint32), np.dtype(np.uint16)):
|
||||
raise SkipTest("NYI: bitshift uint64")
|
||||
|
||||
for val in [5, -5]:
|
||||
for shift in [nbits, nbits + 4]:
|
||||
val_scl = np.array(val).astype(dt)[()]
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from torch.testing._internal.common_utils import (
|
|||
dtype_names = [
|
||||
"bool_",
|
||||
*[f"int{w}" for w in [8, 16, 32, 64]],
|
||||
"uint8",
|
||||
*[f"uint{w}" for w in [8, 16, 32, 64]],
|
||||
*[f"float{w}" for w in [16, 32, 64]],
|
||||
*[f"complex{w}" for w in [64, 128]],
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import torch
|
||||
|
||||
# These two dicts are autogenerated with autogen/gen_dtypes.py,
|
||||
# using numpy version 1.23.5.
|
||||
# using numpy version 1.24.3.
|
||||
|
||||
_can_cast_dict = {
|
||||
"no": {
|
||||
|
|
@ -14,6 +14,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -27,6 +30,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -40,6 +46,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -53,6 +62,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -66,6 +78,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -79,6 +94,57 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: True,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: False,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint16: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: False,
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: True,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: False,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint32: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: False,
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: True,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: False,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint64: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: False,
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: True,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -92,6 +158,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: True,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -105,6 +174,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: True,
|
||||
torch.int32: False,
|
||||
|
|
@ -118,6 +190,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: True,
|
||||
|
|
@ -131,6 +206,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -144,6 +222,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -159,6 +240,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -172,6 +256,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -185,6 +272,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -198,6 +288,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -211,6 +304,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -224,6 +320,57 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: True,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: False,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint16: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: False,
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: True,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: False,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint32: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: False,
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: True,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: False,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint64: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: False,
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: True,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -237,6 +384,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: True,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -250,6 +400,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: True,
|
||||
torch.int32: False,
|
||||
|
|
@ -263,6 +416,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: True,
|
||||
|
|
@ -276,6 +432,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -289,6 +448,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -304,6 +466,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -317,6 +482,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -330,6 +498,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -343,6 +514,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -356,6 +530,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -369,12 +546,63 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: False,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint16: {
|
||||
torch.float16: False,
|
||||
torch.float32: True,
|
||||
torch.float64: True,
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint32: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: True,
|
||||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: True,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint64: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: True,
|
||||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: True,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: False,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.int8: {
|
||||
torch.float16: True,
|
||||
torch.float32: True,
|
||||
|
|
@ -382,6 +610,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -395,6 +626,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -408,6 +642,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: True,
|
||||
|
|
@ -421,6 +658,9 @@ _can_cast_dict = {
|
|||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -434,6 +674,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -449,6 +692,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -462,6 +708,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -475,6 +724,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -488,6 +740,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -501,6 +756,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
|
|
@ -514,6 +772,57 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint16: {
|
||||
torch.float16: True,
|
||||
torch.float32: True,
|
||||
torch.float64: True,
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint32: {
|
||||
torch.float16: True,
|
||||
torch.float32: True,
|
||||
torch.float64: True,
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint64: {
|
||||
torch.float16: True,
|
||||
torch.float32: True,
|
||||
torch.float64: True,
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -527,6 +836,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -540,6 +852,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -553,6 +868,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -566,6 +884,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -579,6 +900,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -594,6 +918,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -607,6 +934,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -620,6 +950,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -633,6 +966,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -646,6 +982,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -659,6 +998,57 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: True,
|
||||
},
|
||||
torch.uint16: {
|
||||
torch.float16: True,
|
||||
torch.float32: True,
|
||||
torch.float64: True,
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: True,
|
||||
},
|
||||
torch.uint32: {
|
||||
torch.float16: True,
|
||||
torch.float32: True,
|
||||
torch.float64: True,
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: True,
|
||||
},
|
||||
torch.uint64: {
|
||||
torch.float16: True,
|
||||
torch.float32: True,
|
||||
torch.float64: True,
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -672,6 +1062,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -685,6 +1078,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -698,6 +1094,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -711,6 +1110,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -724,6 +1126,9 @@ _can_cast_dict = {
|
|||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
|
|
@ -742,6 +1147,9 @@ _result_type_dict = {
|
|||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.float16,
|
||||
torch.uint16: torch.float32,
|
||||
torch.uint32: torch.float64,
|
||||
torch.uint64: torch.float64,
|
||||
torch.int8: torch.float16,
|
||||
torch.int16: torch.float32,
|
||||
torch.int32: torch.float64,
|
||||
|
|
@ -755,6 +1163,9 @@ _result_type_dict = {
|
|||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.float32,
|
||||
torch.uint16: torch.float32,
|
||||
torch.uint32: torch.float64,
|
||||
torch.uint64: torch.float64,
|
||||
torch.int8: torch.float32,
|
||||
torch.int16: torch.float32,
|
||||
torch.int32: torch.float64,
|
||||
|
|
@ -768,6 +1179,9 @@ _result_type_dict = {
|
|||
torch.complex64: torch.complex128,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.float64,
|
||||
torch.uint16: torch.float64,
|
||||
torch.uint32: torch.float64,
|
||||
torch.uint64: torch.float64,
|
||||
torch.int8: torch.float64,
|
||||
torch.int16: torch.float64,
|
||||
torch.int32: torch.float64,
|
||||
|
|
@ -781,6 +1195,9 @@ _result_type_dict = {
|
|||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.complex64,
|
||||
torch.uint16: torch.complex64,
|
||||
torch.uint32: torch.complex128,
|
||||
torch.uint64: torch.complex128,
|
||||
torch.int8: torch.complex64,
|
||||
torch.int16: torch.complex64,
|
||||
torch.int32: torch.complex128,
|
||||
|
|
@ -794,6 +1211,9 @@ _result_type_dict = {
|
|||
torch.complex64: torch.complex128,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.complex128,
|
||||
torch.uint16: torch.complex128,
|
||||
torch.uint32: torch.complex128,
|
||||
torch.uint64: torch.complex128,
|
||||
torch.int8: torch.complex128,
|
||||
torch.int16: torch.complex128,
|
||||
torch.int32: torch.complex128,
|
||||
|
|
@ -807,12 +1227,63 @@ _result_type_dict = {
|
|||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.uint8,
|
||||
torch.uint16: torch.uint16,
|
||||
torch.uint32: torch.uint32,
|
||||
torch.uint64: torch.uint64,
|
||||
torch.int8: torch.int16,
|
||||
torch.int16: torch.int16,
|
||||
torch.int32: torch.int32,
|
||||
torch.int64: torch.int64,
|
||||
torch.bool: torch.uint8,
|
||||
},
|
||||
torch.uint16: {
|
||||
torch.float16: torch.float32,
|
||||
torch.float32: torch.float32,
|
||||
torch.float64: torch.float64,
|
||||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.uint16,
|
||||
torch.uint16: torch.uint16,
|
||||
torch.uint32: torch.uint32,
|
||||
torch.uint64: torch.uint64,
|
||||
torch.int8: torch.int32,
|
||||
torch.int16: torch.int32,
|
||||
torch.int32: torch.int32,
|
||||
torch.int64: torch.int64,
|
||||
torch.bool: torch.uint16,
|
||||
},
|
||||
torch.uint32: {
|
||||
torch.float16: torch.float64,
|
||||
torch.float32: torch.float64,
|
||||
torch.float64: torch.float64,
|
||||
torch.complex64: torch.complex128,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.uint32,
|
||||
torch.uint16: torch.uint32,
|
||||
torch.uint32: torch.uint32,
|
||||
torch.uint64: torch.uint64,
|
||||
torch.int8: torch.int64,
|
||||
torch.int16: torch.int64,
|
||||
torch.int32: torch.int64,
|
||||
torch.int64: torch.int64,
|
||||
torch.bool: torch.uint32,
|
||||
},
|
||||
torch.uint64: {
|
||||
torch.float16: torch.float64,
|
||||
torch.float32: torch.float64,
|
||||
torch.float64: torch.float64,
|
||||
torch.complex64: torch.complex128,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.uint64,
|
||||
torch.uint16: torch.uint64,
|
||||
torch.uint32: torch.uint64,
|
||||
torch.uint64: torch.uint64,
|
||||
torch.int8: torch.float64,
|
||||
torch.int16: torch.float64,
|
||||
torch.int32: torch.float64,
|
||||
torch.int64: torch.float64,
|
||||
torch.bool: torch.uint64,
|
||||
},
|
||||
torch.int8: {
|
||||
torch.float16: torch.float16,
|
||||
torch.float32: torch.float32,
|
||||
|
|
@ -820,6 +1291,9 @@ _result_type_dict = {
|
|||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.int16,
|
||||
torch.uint16: torch.int32,
|
||||
torch.uint32: torch.int64,
|
||||
torch.uint64: torch.float64,
|
||||
torch.int8: torch.int8,
|
||||
torch.int16: torch.int16,
|
||||
torch.int32: torch.int32,
|
||||
|
|
@ -833,6 +1307,9 @@ _result_type_dict = {
|
|||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.int16,
|
||||
torch.uint16: torch.int32,
|
||||
torch.uint32: torch.int64,
|
||||
torch.uint64: torch.float64,
|
||||
torch.int8: torch.int16,
|
||||
torch.int16: torch.int16,
|
||||
torch.int32: torch.int32,
|
||||
|
|
@ -846,6 +1323,9 @@ _result_type_dict = {
|
|||
torch.complex64: torch.complex128,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.int32,
|
||||
torch.uint16: torch.int32,
|
||||
torch.uint32: torch.int64,
|
||||
torch.uint64: torch.float64,
|
||||
torch.int8: torch.int32,
|
||||
torch.int16: torch.int32,
|
||||
torch.int32: torch.int32,
|
||||
|
|
@ -859,6 +1339,9 @@ _result_type_dict = {
|
|||
torch.complex64: torch.complex128,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.int64,
|
||||
torch.uint16: torch.int64,
|
||||
torch.uint32: torch.int64,
|
||||
torch.uint64: torch.float64,
|
||||
torch.int8: torch.int64,
|
||||
torch.int16: torch.int64,
|
||||
torch.int32: torch.int64,
|
||||
|
|
@ -872,6 +1355,9 @@ _result_type_dict = {
|
|||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.uint8,
|
||||
torch.uint16: torch.uint16,
|
||||
torch.uint32: torch.uint32,
|
||||
torch.uint64: torch.uint64,
|
||||
torch.int8: torch.int8,
|
||||
torch.int16: torch.int16,
|
||||
torch.int32: torch.int32,
|
||||
|
|
|
|||
|
|
@ -113,6 +113,24 @@ class uint8(unsignedinteger):
|
|||
torch_dtype = torch.uint8
|
||||
|
||||
|
||||
class uint16(unsignedinteger):
|
||||
name = "uint16"
|
||||
typecode = "H"
|
||||
torch_dtype = torch.uint16
|
||||
|
||||
|
||||
class uint32(signedinteger):
|
||||
name = "uint32"
|
||||
typecode = "I"
|
||||
torch_dtype = torch.uint32
|
||||
|
||||
|
||||
class uint64(signedinteger):
|
||||
name = "uint64"
|
||||
typecode = "L"
|
||||
torch_dtype = torch.uint64
|
||||
|
||||
|
||||
# floating point
|
||||
|
||||
|
||||
|
|
@ -160,6 +178,7 @@ _name_aliases = {
|
|||
"byte": int8,
|
||||
"short": int16,
|
||||
"longlong": int64, # XXX: is this correct?
|
||||
"ulonglong": uint64,
|
||||
"ubyte": uint8,
|
||||
"half": float16,
|
||||
"single": float32,
|
||||
|
|
@ -180,7 +199,7 @@ for name, obj in _name_aliases.items():
|
|||
# cf tests/core/test_scalar_methods.py
|
||||
sctypes = {
|
||||
"int": [int8, int16, int32, int64],
|
||||
"uint": [uint8],
|
||||
"uint": [uint8, uint16, uint32, uint64],
|
||||
"float": [float16, float32, float64],
|
||||
"complex": [complex64, complex128],
|
||||
"others": [bool_],
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user