fix pybind issue for get_autocast_cpu_dtype and get_autocast_gpu_dtype (#66396)

Summary:
There has an issue when calling **torch.get_autocast_cpu_dtype** and **torch.get_autocast_gpu_dtype**:
```
>>> torch.get_autocast_gpu_dtype()==torch.half
False
>>> torch.get_autocast_cpu_dtype()==torch.bfloat16
False
```
but the expected results  should be :
```
>>> torch.get_autocast_gpu_dtype()==torch.half
True
>>> torch.get_autocast_cpu_dtype()==torch.bfloat16
True
```

This PR is about fixing this issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66396

Reviewed By: ejguan

Differential Revision: D31541727

Pulled By: albanD

fbshipit-source-id: 1a0fe070a82590ef2926a517bf48046c2633d168
This commit is contained in:
XiaobingSuper 2021-10-11 08:32:15 -07:00 committed by Facebook GitHub Bot
parent 1b40daac74
commit 822c0850cb
2 changed files with 14 additions and 2 deletions

View File

@ -119,5 +119,13 @@ class TestAutocastCPU(TestCase):
for op, args in self.autocast_lists.torch_need_autocast_promote:
self._run_autocast_outofplace(op, args, torch.float32)
class TestTorchAutocast(TestCase):
def test_autocast_fast_dtype(self):
gpu_fast_dtype = torch.get_autocast_gpu_dtype()
cpu_fast_dtype = torch.get_autocast_cpu_dtype()
self.assertEqual(gpu_fast_dtype, torch.half)
self.assertEqual(cpu_fast_dtype, torch.bfloat16)
if __name__ == '__main__':
run_tests()

View File

@ -390,14 +390,18 @@ static const char* scalarTypeName(const at::ScalarType type) {
static PyObject * get_autocast_gpu_dtype(PyObject* _unused, PyObject *arg){
HANDLE_TH_ERRORS
at::ScalarType current_dtype = at::autocast::get_autocast_gpu_dtype();
return THPDtype_New(current_dtype, scalarTypeName(current_dtype));
auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
Py_INCREF(dtype);
return dtype;
END_HANDLE_TH_ERRORS
}
static PyObject * get_autocast_cpu_dtype(PyObject* _unused, PyObject *arg){
HANDLE_TH_ERRORS
at::ScalarType current_dtype = at::autocast::get_autocast_cpu_dtype();
return THPDtype_New(current_dtype, scalarTypeName(current_dtype));
auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
Py_INCREF(dtype);
return dtype;
END_HANDLE_TH_ERRORS
}