Add unsigned support to IValue (#160102)

- Moved repeated logic of saving int64/uint64 into a polymorphic container into `THPUtils_unpackInteger`
- Added `TestPythonDispatch.test_dispatch_uint64` regression test

Fixes https://github.com/pytorch/pytorch/issues/159168

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160102
Approved by: https://github.com/ezyang
This commit is contained in:
Nikita Shulga 2025-08-10 20:07:40 -04:00 committed by PyTorch MergeBot
parent e7152ff8a6
commit d8cb3db533
6 changed files with 86 additions and 19 deletions

View File

@ -97,6 +97,8 @@ c10::TypePtr IValue::TagType<c10::Type>::get(const IValue& v) {
return ComplexType::get();
case Tag::Int:
return IntType::get();
case Tag::UInt:
return IntType::get();
case Tag::SymInt:
return c10::SymIntType::get();
case Tag::SymFloat:
@ -320,6 +322,8 @@ IValue IValue::equals(const IValue& rhs) const {
return rhs.isComplexDouble() && lhs.toComplexDouble() == rhs.toComplexDouble();
case Tag::Int:
return rhs.isInt() && lhs.toInt() == rhs.toInt();
case Tag::UInt:
return rhs.isUnsigned() && lhs.toUInt() == rhs.toUInt();
case Tag::SymInt:
return rhs.isSymInt() && lhs.toSymInt() == rhs.toSymInt();
case Tag::SymFloat:
@ -379,6 +383,8 @@ size_t IValue::hash(const IValue& v) {
case Tag::Int:
return c10::get_hash(v.payload.u.as_int);
// NB: these are technically strict aliasing violations
case Tag::UInt:
return c10::get_hash(v.payload.u.as_int);
case Tag::SymInt:
return c10::get_hash(v.payload.u.as_int);
case Tag::SymFloat:
@ -806,6 +812,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
return printComplex(out, v);
} case IValue::Tag::Int:
return out << v.toInt();
case IValue::Tag::UInt:
return out << v.toUInt();
case IValue::Tag::SymInt:
return out << v.toSymInt();
case IValue::Tag::SymFloat:

View File

@ -12,6 +12,7 @@
#include <c10/macros/Export.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/intrusive_ptr.h>
#include <limits>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
@ -160,6 +161,7 @@ struct Capsule {
_(Double) \
_(ComplexDouble) \
_(Int) \
_(UInt) \
_(SymInt) \
_(SymFloat) \
_(SymBool) \
@ -653,6 +655,29 @@ struct TORCH_API IValue final {
}
}
// Unsigned
IValue(uint64_t u) : tag( u <= std::numeric_limits<int64_t>::max() ? Tag::Int : Tag::UInt) {
payload.u.as_uint = u;
}
// See Note [Meaning of HAS_u]
// IValue type model closely follows that of c10::Scalar
// Where all integers are upcast to 64-bit representation, and `as_int` is used as default
// representation unless value could not be represented as signed int
bool isUnsigned() const {
return Tag::UInt == tag || (Tag::Int == tag && payload.u.as_int >= 0);
}
uint64_t toUInt() const {
if (isUnsigned()) {
return payload.u.as_uint;
} else {
TORCH_INTERNAL_ASSERT(0, "expected unsigned int");
}
}
// Bool
IValue(bool b) : tag(Tag::Bool) {
#if defined(__clang__) && defined(__x86_64__)
@ -893,8 +918,14 @@ struct TORCH_API IValue final {
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
s.isIntegral(false), "Unknown type in Scalar");
tag = Tag::Int;
if (s.isUnsigned()) {
const auto val = s.toUInt64();
payload.u.as_uint = val;
tag = val <= std::numeric_limits<int64_t>::max() ? Tag::Int : Tag::UInt;
} else {
payload.u.as_int = s.toLong();
tag = Tag::Int;
}
}
}
@ -918,6 +949,8 @@ struct TORCH_API IValue final {
return toSymFloat();
else if (isSymBool())
return toSymBool();
else if (isUnsigned())
return toUInt();
TORCH_CHECK(false, "IValue is not a Scalar");
}
@ -1247,6 +1280,8 @@ struct TORCH_API IValue final {
return true;
case Tag::Int:
return false;
case Tag::UInt:
return false;
case Tag::SymInt:
return true;
case Tag::SymFloat:
@ -1343,6 +1378,8 @@ struct TORCH_API IValue final {
union TriviallyCopyablePayload {
TriviallyCopyablePayload() : as_int(0) {}
int64_t as_int;
// See Note [Meaning of HAS_u]
uint64_t as_uint;
double as_double;
bool as_bool;
// Invariant: never nullptr; null state is represented as

View File

@ -2513,6 +2513,19 @@ def forward(self, x_1):
with Mode():
torch.cond(pred, lambda x: x.sin(), lambda x: x.cos(), (x,))
def test_dispatch_uint64(self):
class DummyMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args, kwargs):
self.last_args = args
return func(*args, **kwargs)
# Value that could not be intepreted as signed int64
uarg = 2**63 + 1
with DummyMode() as m:
a = torch.full((3, 3), uarg, dtype=torch.uint64)
self.assertEqual(m.last_args[1], uarg)
self.assertTrue((a == uarg).all().item())
class TestPythonDispatcher(TestCase):
def test_basic(self):

View File

@ -90,7 +90,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional<int32_t> N) {
if (PyBool_Check(obj.ptr())) {
scalar = at::Scalar(THPUtils_unpackBool(obj.ptr()));
} else if (THPUtils_checkLong(obj.ptr())) {
scalar = at::Scalar(THPUtils_unpackLong(obj.ptr()));
scalar = THPUtils_unpackInteger<at::Scalar>(obj.ptr());
} else if (PyComplex_Check(obj.ptr())) {
scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr()));
} else if (THPUtils_checkDouble(obj.ptr())) {
@ -512,7 +512,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional<int32_t> N) {
if (py::isinstance<py::bool_>(obj)) {
return py::cast<bool>(obj);
} else if (py::isinstance<py::int_>(obj)) {
return py::cast<int64_t>(obj);
return THPUtils_unpackInteger<IValue>(obj.ptr());
} else if (py::isinstance<py::float_>(obj)) {
return py::cast<double>(obj);
} else if (PyComplex_CheckExact(obj.ptr())) {
@ -598,6 +598,8 @@ py::object toPyObject(IValue ivalue) {
return py::cast(*tensor.const_data_ptr<bool>());
case at::ScalarType::Long:
return py::cast(*tensor.const_data_ptr<int64_t>());
case at::ScalarType::UInt64:
return py::cast(*tensor.const_data_ptr<uint64_t>());
case at::ScalarType::Double:
return py::cast(*tensor.const_data_ptr<double>());
case at::ScalarType::ComplexDouble:
@ -763,6 +765,8 @@ py::object toPyObject(IValue ivalue) {
return py::cast(std::move(ivalue).toSymFloat());
} else if (ivalue.isSymBool()) {
return py::cast(std::move(ivalue).toSymBool());
} else if (ivalue.isUnsigned()) {
return py::cast(std::move(ivalue).toUInt());
} else {
TORCH_CHECK(
false,

View File

@ -1801,21 +1801,7 @@ at::Tensor PythonArgs::tensor_slow(int i) {
if (PyBool_Check(obj)) {
scalar = at::Scalar(THPUtils_unpackBool(obj));
} else if (THPUtils_checkLong(obj)) {
int overflow = -1;
long long value = PyLong_AsLongLongAndOverflow(obj, &overflow);
if (value == -1 && PyErr_Occurred()) {
throw python_error();
}
if (overflow != 0) {
// try unsigned
unsigned long long value = PyLong_AsUnsignedLongLong(obj);
if (value == static_cast<unsigned long long>(-1) && PyErr_Occurred()) {
throw python_error();
}
scalar = at::Scalar(static_cast<uint64_t>(value));
} else {
scalar = at::Scalar(static_cast<int64_t>(value));
}
scalar = THPUtils_unpackInteger<at::Scalar>(obj);
} else if (PyComplex_Check(obj)) {
scalar = at::Scalar(THPUtils_unpackComplexDouble(obj));
} else if (THPUtils_checkDouble(obj)) {

View File

@ -208,3 +208,22 @@ inline c10::DeviceIndex THPUtils_unpackDeviceIndex(PyObject* obj) {
}
return (c10::DeviceIndex)value;
}
template <typename T>
inline T THPUtils_unpackInteger(PyObject* obj) {
int overflow = -1;
const auto value = PyLong_AsLongLongAndOverflow(obj, &overflow);
if (value == -1 && PyErr_Occurred()) {
throw python_error();
}
if (!overflow) {
return static_cast<int64_t>(value);
}
// try unsigned
const auto uvalue = PyLong_AsUnsignedLongLong(obj);
if (uvalue == static_cast<std::decay_t<decltype(uvalue)>>(-1) &&
PyErr_Occurred()) {
throw python_error();
}
return static_cast<uint64_t>(uvalue);
}