mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Make PyTorch argparser understand complex (#129580)
It understands float and int, so why not `complex`. Test plan: `python -c "import torch;print(torch.rand(3, dtype=complex))"` Fixes https://github.com/pytorch/pytorch/issues/126837 Pull Request resolved: https://github.com/pytorch/pytorch/pull/129580 Approved by: https://github.com/albanD
This commit is contained in:
parent
dfd55d1714
commit
2bc6f329b2
|
|
@ -20,7 +20,8 @@ inline bool THPDtype_Check(PyObject* obj) {
|
|||
|
||||
inline bool THPPythonScalarType_Check(PyObject* obj) {
|
||||
return obj == (PyObject*)(&PyFloat_Type) ||
|
||||
obj == (PyObject*)(&PyBool_Type) || obj == (PyObject*)(&PyLong_Type);
|
||||
obj == (PyObject*)(&PyComplex_Type) || obj == (PyObject*)(&PyBool_Type) ||
|
||||
obj == (PyObject*)(&PyLong_Type);
|
||||
}
|
||||
|
||||
TORCH_API PyObject* THPDtype_New(
|
||||
|
|
|
|||
|
|
@ -755,6 +755,9 @@ inline at::ScalarType toScalarType(PyObject* obj) {
|
|||
if (obj == (PyObject*)&PyLong_Type) {
|
||||
return at::ScalarType::Long;
|
||||
}
|
||||
if (obj == (PyObject*)&PyComplex_Type) {
|
||||
return at::ScalarType::ComplexDouble;
|
||||
}
|
||||
return reinterpret_cast<THPDtype*>(obj)->scalar_type;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user