fix pybind issue for get_autocast_cpu_dtype and get_autocast_gpu_dtype (#66396)
Summary:
There is an issue when calling **torch.get_autocast_cpu_dtype** and **torch.get_autocast_gpu_dtype**:

```
>>> torch.get_autocast_gpu_dtype()==torch.half
False
>>> torch.get_autocast_cpu_dtype()==torch.bfloat16
False
```

but the expected results should be:

```
>>> torch.get_autocast_gpu_dtype()==torch.half
True
>>> torch.get_autocast_cpu_dtype()==torch.bfloat16
True
```

This PR fixes that issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66396

Reviewed By: ejguan

Differential Revision: D31541727

Pulled By: albanD

fbshipit-source-id: 1a0fe070a82590ef2926a517bf48046c2633d168
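The practical impact is that any Python code dispatching on the active autocast dtype silently took the wrong branch. A minimal sketch of that pattern (the helper name and tolerance values are illustrative, not from the PR), run against a build that includes this fix:

```python
import torch

# Hypothetical dispatch of the kind the broken bindings defeated: with the old
# bindings the equality check was always False, so the bfloat16 branch below
# was unreachable even though bfloat16 is the default CPU autocast fast dtype.
def autocast_tolerance() -> float:
    if torch.get_autocast_cpu_dtype() == torch.bfloat16:
        return 1e-2  # bfloat16 keeps far fewer mantissa bits than float32
    return 1e-5

print(autocast_tolerance())  # expected: 0.01 once the fix is applied
```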
This commit is contained in:
Parent: 1b40daac74
Commit: 822c0850cb
```diff
@@ -119,5 +119,13 @@ class TestAutocastCPU(TestCase):
         for op, args in self.autocast_lists.torch_need_autocast_promote:
             self._run_autocast_outofplace(op, args, torch.float32)
 
+
+class TestTorchAutocast(TestCase):
+    def test_autocast_fast_dtype(self):
+        gpu_fast_dtype = torch.get_autocast_gpu_dtype()
+        cpu_fast_dtype = torch.get_autocast_cpu_dtype()
+        self.assertEqual(gpu_fast_dtype, torch.half)
+        self.assertEqual(cpu_fast_dtype, torch.bfloat16)
+
 if __name__ == '__main__':
     run_tests()
```
```diff
@@ -390,14 +390,18 @@ static const char* scalarTypeName(const at::ScalarType type) {
 static PyObject * get_autocast_gpu_dtype(PyObject* _unused, PyObject *arg){
   HANDLE_TH_ERRORS
   at::ScalarType current_dtype = at::autocast::get_autocast_gpu_dtype();
-  return THPDtype_New(current_dtype, scalarTypeName(current_dtype));
+  auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
+  Py_INCREF(dtype);
+  return dtype;
   END_HANDLE_TH_ERRORS
 }
 
 static PyObject * get_autocast_cpu_dtype(PyObject* _unused, PyObject *arg){
   HANDLE_TH_ERRORS
   at::ScalarType current_dtype = at::autocast::get_autocast_cpu_dtype();
-  return THPDtype_New(current_dtype, scalarTypeName(current_dtype));
+  auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
+  Py_INCREF(dtype);
+  return dtype;
   END_HANDLE_TH_ERRORS
 }
 
```
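Why the rewrite works: the old binding built a brand-new THPDtype object with THPDtype_New, whereas torch.half and torch.bfloat16 are the dtype objects interned at module initialization, so comparing the fresh object against the interned one apparently fell back to identity and returned False. The new binding looks up the existing interned object via torch::getTHPDtype and returns it with its refcount bumped (Py_INCREF), since the caller receives a new strong reference. Assuming the getters now hand back those shared singletons (an inference from the diff, not stated in the PR), both identity and equality checks hold:

```python
import torch

# Sketch of the post-fix behaviour, assuming the bindings return the interned
# dtype objects that the torch module exposes as torch.half / torch.bfloat16.
gpu_dtype = torch.get_autocast_gpu_dtype()
cpu_dtype = torch.get_autocast_cpu_dtype()

assert gpu_dtype == torch.half and gpu_dtype is torch.half
assert cpu_dtype == torch.bfloat16 and cpu_dtype is torch.bfloat16
print(gpu_dtype, cpu_dtype)  # torch.float16 torch.bfloat16
```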