mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Expose is_signed for dtype (#29511)
Summary: Changelog: - Expose is_signed for torch.dtype by modifying torch/csrc/Dtype.cpp - Allow half, bfloat16 and bool to also been "known" by the isSignedType function Pull Request resolved: https://github.com/pytorch/pytorch/pull/29511 Test Plan: - Add tests in test/test_torch.py Closes https://github.com/pytorch/pytorch/issues/29475 Differential Revision: D18439030 Pulled By: albanD fbshipit-source-id: 4b1f9da70c1c8dfd0a5bc028b6936acd1c64af47
This commit is contained in:
parent
23fcc409d5
commit
69e343f2cc
|
|
@ -316,11 +316,11 @@ static inline bool isSignedType(ScalarType t) {
|
|||
case ScalarType::name: \
|
||||
return std::numeric_limits<ctype>::is_signed;
|
||||
|
||||
switch (t) {
|
||||
AT_FORALL_SCALAR_TYPES_AND(Half, CASE_SIGNED)
|
||||
default:
|
||||
AT_ERROR("Unknown ScalarType");
|
||||
}
|
||||
switch (toUnderlying(t)) {
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED)
|
||||
default:
|
||||
AT_ERROR("Unknown ScalarType");
|
||||
}
|
||||
#undef CASE_SIGNED
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2427,6 +2427,14 @@ class _TestTorchMixin(object):
|
|||
x = torch.Tensor([1, nan, 2])
|
||||
self.assertEqual(torch.isnan(x), torch.ByteTensor([0, 1, 0]))
|
||||
|
||||
def test_dtype_is_signed(self):
|
||||
for dtype in torch.testing.get_all_dtypes():
|
||||
self.assertEqual(dtype.is_signed, torch.is_signed(torch.tensor(0, dtype=dtype)))
|
||||
|
||||
self.assertFalse(torch.quint8.is_signed)
|
||||
self.assertTrue(torch.qint8.is_signed)
|
||||
self.assertTrue(torch.qint32.is_signed)
|
||||
|
||||
def test_RNGState(self):
|
||||
state = torch.get_rng_state()
|
||||
stateCloned = state.clone()
|
||||
|
|
|
|||
|
|
@ -29,6 +29,15 @@ PyObject *THPDtype_is_floating_point(THPDtype *self, PyObject *noargs)
|
|||
}
|
||||
}
|
||||
|
||||
PyObject *THPDtype_is_signed(THPDtype *self, PyObject *noargs)
|
||||
{
|
||||
if (at::isSignedType(self->scalar_type)) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
}
|
||||
|
||||
PyObject *THPDtype_reduce(THPDtype *self, PyObject *noargs)
|
||||
{
|
||||
/*
|
||||
|
|
@ -42,6 +51,7 @@ typedef PyObject *(*getter)(PyObject *, void *);
|
|||
|
||||
static struct PyGetSetDef THPDtype_properties[] = {
|
||||
{"is_floating_point", (getter)THPDtype_is_floating_point, nullptr, nullptr, nullptr},
|
||||
{"is_signed", (getter)THPDtype_is_signed, nullptr, nullptr, nullptr},
|
||||
{nullptr}
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user