Expose is_signed for dtype (#29511)

Summary:
Changelog:
- Expose is_signed for torch.dtype by modifying torch/csrc/Dtype.cpp
- Allow half, bfloat16 and bool to also been "known" by the isSignedType function
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29511

Test Plan:
- Add tests in test/test_torch.py

Closes https://github.com/pytorch/pytorch/issues/29475

Differential Revision: D18439030

Pulled By: albanD

fbshipit-source-id: 4b1f9da70c1c8dfd0a5bc028b6936acd1c64af47
This commit is contained in:
vishwakftw 2019-11-15 11:15:02 -08:00 committed by Facebook Github Bot
parent 23fcc409d5
commit 69e343f2cc
3 changed files with 23 additions and 5 deletions

View File

@ -316,11 +316,11 @@ static inline bool isSignedType(ScalarType t) {
case ScalarType::name: \
return std::numeric_limits<ctype>::is_signed;
switch (t) {
AT_FORALL_SCALAR_TYPES_AND(Half, CASE_SIGNED)
default:
AT_ERROR("Unknown ScalarType");
}
switch (toUnderlying(t)) {
AT_FORALL_SCALAR_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED)
default:
AT_ERROR("Unknown ScalarType");
}
#undef CASE_SIGNED
}

View File

@ -2427,6 +2427,14 @@ class _TestTorchMixin(object):
x = torch.Tensor([1, nan, 2])
self.assertEqual(torch.isnan(x), torch.ByteTensor([0, 1, 0]))
def test_dtype_is_signed(self):
for dtype in torch.testing.get_all_dtypes():
self.assertEqual(dtype.is_signed, torch.is_signed(torch.tensor(0, dtype=dtype)))
self.assertFalse(torch.quint8.is_signed)
self.assertTrue(torch.qint8.is_signed)
self.assertTrue(torch.qint32.is_signed)
def test_RNGState(self):
state = torch.get_rng_state()
stateCloned = state.clone()

View File

@ -29,6 +29,15 @@ PyObject *THPDtype_is_floating_point(THPDtype *self, PyObject *noargs)
}
}
PyObject *THPDtype_is_signed(THPDtype *self, PyObject *noargs)
{
if (at::isSignedType(self->scalar_type)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
}
PyObject *THPDtype_reduce(THPDtype *self, PyObject *noargs)
{
/*
@ -42,6 +51,7 @@ typedef PyObject *(*getter)(PyObject *, void *);
static struct PyGetSetDef THPDtype_properties[] = {
{"is_floating_point", (getter)THPDtype_is_floating_point, nullptr, nullptr, nullptr},
{"is_signed", (getter)THPDtype_is_signed, nullptr, nullptr, nullptr},
{nullptr}
};