mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add unsigned support to IValue (#160102)
- Moved repeated logic of saving int64/uint64 into a polymorphic container into `THPUtils_unpackInteger` - Added `TestPythonDispatch.test_dispatch_uint64` regression test Fixes https://github.com/pytorch/pytorch/issues/159168 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160102 Approved by: https://github.com/ezyang
This commit is contained in:
parent
e7152ff8a6
commit
d8cb3db533
|
|
@ -97,6 +97,8 @@ c10::TypePtr IValue::TagType<c10::Type>::get(const IValue& v) {
|
|||
return ComplexType::get();
|
||||
case Tag::Int:
|
||||
return IntType::get();
|
||||
case Tag::UInt:
|
||||
return IntType::get();
|
||||
case Tag::SymInt:
|
||||
return c10::SymIntType::get();
|
||||
case Tag::SymFloat:
|
||||
|
|
@ -320,6 +322,8 @@ IValue IValue::equals(const IValue& rhs) const {
|
|||
return rhs.isComplexDouble() && lhs.toComplexDouble() == rhs.toComplexDouble();
|
||||
case Tag::Int:
|
||||
return rhs.isInt() && lhs.toInt() == rhs.toInt();
|
||||
case Tag::UInt:
|
||||
return rhs.isUnsigned() && lhs.toUInt() == rhs.toUInt();
|
||||
case Tag::SymInt:
|
||||
return rhs.isSymInt() && lhs.toSymInt() == rhs.toSymInt();
|
||||
case Tag::SymFloat:
|
||||
|
|
@ -379,6 +383,8 @@ size_t IValue::hash(const IValue& v) {
|
|||
case Tag::Int:
|
||||
return c10::get_hash(v.payload.u.as_int);
|
||||
// NB: these are technically strict aliasing violations
|
||||
case Tag::UInt:
|
||||
return c10::get_hash(v.payload.u.as_int);
|
||||
case Tag::SymInt:
|
||||
return c10::get_hash(v.payload.u.as_int);
|
||||
case Tag::SymFloat:
|
||||
|
|
@ -806,6 +812,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
|
|||
return printComplex(out, v);
|
||||
} case IValue::Tag::Int:
|
||||
return out << v.toInt();
|
||||
case IValue::Tag::UInt:
|
||||
return out << v.toUInt();
|
||||
case IValue::Tag::SymInt:
|
||||
return out << v.toSymInt();
|
||||
case IValue::Tag::SymFloat:
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include <c10/macros/Export.h>
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
|
@ -160,6 +161,7 @@ struct Capsule {
|
|||
_(Double) \
|
||||
_(ComplexDouble) \
|
||||
_(Int) \
|
||||
_(UInt) \
|
||||
_(SymInt) \
|
||||
_(SymFloat) \
|
||||
_(SymBool) \
|
||||
|
|
@ -653,6 +655,29 @@ struct TORCH_API IValue final {
|
|||
}
|
||||
}
|
||||
|
||||
// Unsigned
|
||||
IValue(uint64_t u) : tag( u <= std::numeric_limits<int64_t>::max() ? Tag::Int : Tag::UInt) {
|
||||
payload.u.as_uint = u;
|
||||
}
|
||||
|
||||
|
||||
// See Note [Meaning of HAS_u]
|
||||
// IValue type model closely follows that of c10::Scalar
|
||||
// Where all integers are upcast to 64-bit representation, and `as_int` is used as default
|
||||
// representation unless value could not be represented as signed int
|
||||
bool isUnsigned() const {
|
||||
return Tag::UInt == tag || (Tag::Int == tag && payload.u.as_int >= 0);
|
||||
}
|
||||
|
||||
uint64_t toUInt() const {
|
||||
if (isUnsigned()) {
|
||||
return payload.u.as_uint;
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(0, "expected unsigned int");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Bool
|
||||
IValue(bool b) : tag(Tag::Bool) {
|
||||
#if defined(__clang__) && defined(__x86_64__)
|
||||
|
|
@ -893,8 +918,14 @@ struct TORCH_API IValue final {
|
|||
} else {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
s.isIntegral(false), "Unknown type in Scalar");
|
||||
tag = Tag::Int;
|
||||
payload.u.as_int = s.toLong();
|
||||
if (s.isUnsigned()) {
|
||||
const auto val = s.toUInt64();
|
||||
payload.u.as_uint = val;
|
||||
tag = val <= std::numeric_limits<int64_t>::max() ? Tag::Int : Tag::UInt;
|
||||
} else {
|
||||
payload.u.as_int = s.toLong();
|
||||
tag = Tag::Int;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -918,6 +949,8 @@ struct TORCH_API IValue final {
|
|||
return toSymFloat();
|
||||
else if (isSymBool())
|
||||
return toSymBool();
|
||||
else if (isUnsigned())
|
||||
return toUInt();
|
||||
TORCH_CHECK(false, "IValue is not a Scalar");
|
||||
}
|
||||
|
||||
|
|
@ -1247,6 +1280,8 @@ struct TORCH_API IValue final {
|
|||
return true;
|
||||
case Tag::Int:
|
||||
return false;
|
||||
case Tag::UInt:
|
||||
return false;
|
||||
case Tag::SymInt:
|
||||
return true;
|
||||
case Tag::SymFloat:
|
||||
|
|
@ -1343,6 +1378,8 @@ struct TORCH_API IValue final {
|
|||
union TriviallyCopyablePayload {
|
||||
TriviallyCopyablePayload() : as_int(0) {}
|
||||
int64_t as_int;
|
||||
// See Note [Meaning of HAS_u]
|
||||
uint64_t as_uint;
|
||||
double as_double;
|
||||
bool as_bool;
|
||||
// Invariant: never nullptr; null state is represented as
|
||||
|
|
|
|||
|
|
@ -2513,6 +2513,19 @@ def forward(self, x_1):
|
|||
with Mode():
|
||||
torch.cond(pred, lambda x: x.sin(), lambda x: x.cos(), (x,))
|
||||
|
||||
def test_dispatch_uint64(self):
|
||||
class DummyMode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args, kwargs):
|
||||
self.last_args = args
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Value that could not be intepreted as signed int64
|
||||
uarg = 2**63 + 1
|
||||
with DummyMode() as m:
|
||||
a = torch.full((3, 3), uarg, dtype=torch.uint64)
|
||||
self.assertEqual(m.last_args[1], uarg)
|
||||
self.assertTrue((a == uarg).all().item())
|
||||
|
||||
|
||||
class TestPythonDispatcher(TestCase):
|
||||
def test_basic(self):
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional<int32_t> N) {
|
|||
if (PyBool_Check(obj.ptr())) {
|
||||
scalar = at::Scalar(THPUtils_unpackBool(obj.ptr()));
|
||||
} else if (THPUtils_checkLong(obj.ptr())) {
|
||||
scalar = at::Scalar(THPUtils_unpackLong(obj.ptr()));
|
||||
scalar = THPUtils_unpackInteger<at::Scalar>(obj.ptr());
|
||||
} else if (PyComplex_Check(obj.ptr())) {
|
||||
scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr()));
|
||||
} else if (THPUtils_checkDouble(obj.ptr())) {
|
||||
|
|
@ -512,7 +512,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional<int32_t> N) {
|
|||
if (py::isinstance<py::bool_>(obj)) {
|
||||
return py::cast<bool>(obj);
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
return py::cast<int64_t>(obj);
|
||||
return THPUtils_unpackInteger<IValue>(obj.ptr());
|
||||
} else if (py::isinstance<py::float_>(obj)) {
|
||||
return py::cast<double>(obj);
|
||||
} else if (PyComplex_CheckExact(obj.ptr())) {
|
||||
|
|
@ -598,6 +598,8 @@ py::object toPyObject(IValue ivalue) {
|
|||
return py::cast(*tensor.const_data_ptr<bool>());
|
||||
case at::ScalarType::Long:
|
||||
return py::cast(*tensor.const_data_ptr<int64_t>());
|
||||
case at::ScalarType::UInt64:
|
||||
return py::cast(*tensor.const_data_ptr<uint64_t>());
|
||||
case at::ScalarType::Double:
|
||||
return py::cast(*tensor.const_data_ptr<double>());
|
||||
case at::ScalarType::ComplexDouble:
|
||||
|
|
@ -763,6 +765,8 @@ py::object toPyObject(IValue ivalue) {
|
|||
return py::cast(std::move(ivalue).toSymFloat());
|
||||
} else if (ivalue.isSymBool()) {
|
||||
return py::cast(std::move(ivalue).toSymBool());
|
||||
} else if (ivalue.isUnsigned()) {
|
||||
return py::cast(std::move(ivalue).toUInt());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
|
|
|
|||
|
|
@ -1801,21 +1801,7 @@ at::Tensor PythonArgs::tensor_slow(int i) {
|
|||
if (PyBool_Check(obj)) {
|
||||
scalar = at::Scalar(THPUtils_unpackBool(obj));
|
||||
} else if (THPUtils_checkLong(obj)) {
|
||||
int overflow = -1;
|
||||
long long value = PyLong_AsLongLongAndOverflow(obj, &overflow);
|
||||
if (value == -1 && PyErr_Occurred()) {
|
||||
throw python_error();
|
||||
}
|
||||
if (overflow != 0) {
|
||||
// try unsigned
|
||||
unsigned long long value = PyLong_AsUnsignedLongLong(obj);
|
||||
if (value == static_cast<unsigned long long>(-1) && PyErr_Occurred()) {
|
||||
throw python_error();
|
||||
}
|
||||
scalar = at::Scalar(static_cast<uint64_t>(value));
|
||||
} else {
|
||||
scalar = at::Scalar(static_cast<int64_t>(value));
|
||||
}
|
||||
scalar = THPUtils_unpackInteger<at::Scalar>(obj);
|
||||
} else if (PyComplex_Check(obj)) {
|
||||
scalar = at::Scalar(THPUtils_unpackComplexDouble(obj));
|
||||
} else if (THPUtils_checkDouble(obj)) {
|
||||
|
|
|
|||
|
|
@ -208,3 +208,22 @@ inline c10::DeviceIndex THPUtils_unpackDeviceIndex(PyObject* obj) {
|
|||
}
|
||||
return (c10::DeviceIndex)value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T THPUtils_unpackInteger(PyObject* obj) {
|
||||
int overflow = -1;
|
||||
const auto value = PyLong_AsLongLongAndOverflow(obj, &overflow);
|
||||
if (value == -1 && PyErr_Occurred()) {
|
||||
throw python_error();
|
||||
}
|
||||
if (!overflow) {
|
||||
return static_cast<int64_t>(value);
|
||||
}
|
||||
// try unsigned
|
||||
const auto uvalue = PyLong_AsUnsignedLongLong(obj);
|
||||
if (uvalue == static_cast<std::decay_t<decltype(uvalue)>>(-1) &&
|
||||
PyErr_Occurred()) {
|
||||
throw python_error();
|
||||
}
|
||||
return static_cast<uint64_t>(uvalue);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user