Make PyTorch argparser understand complex (#129580)

It understands float and int, so why not `complex`.

Test plan: `python -c "import torch;print(torch.rand(3, dtype=complex))"`

Fixes https://github.com/pytorch/pytorch/issues/126837

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129580
Approved by: https://github.com/albanD
This commit is contained in:
Nikita Shulga 2024-06-29 01:21:12 +00:00 committed by PyTorch MergeBot
parent dfd55d1714
commit 2bc6f329b2
2 changed files with 5 additions and 1 deletions

View File

@ -20,7 +20,8 @@ inline bool THPDtype_Check(PyObject* obj) {
inline bool THPPythonScalarType_Check(PyObject* obj) {
return obj == (PyObject*)(&PyFloat_Type) ||
obj == (PyObject*)(&PyBool_Type) || obj == (PyObject*)(&PyLong_Type);
obj == (PyObject*)(&PyComplex_Type) || obj == (PyObject*)(&PyBool_Type) ||
obj == (PyObject*)(&PyLong_Type);
}
TORCH_API PyObject* THPDtype_New(

View File

@ -755,6 +755,9 @@ inline at::ScalarType toScalarType(PyObject* obj) {
if (obj == (PyObject*)&PyLong_Type) {
return at::ScalarType::Long;
}
if (obj == (PyObject*)&PyComplex_Type) {
return at::ScalarType::ComplexDouble;
}
return reinterpret_cast<THPDtype*>(obj)->scalar_type;
}