diff --git a/test/test_jit_autocast.py b/test/test_jit_autocast.py index 85a95c25403..62e735fec23 100644 --- a/test/test_jit_autocast.py +++ b/test/test_jit_autocast.py @@ -10,7 +10,7 @@ from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_utils import run_tests from torch.testing import FileCheck -TEST_BFLOAT16 = torch.cuda.is_bf16_supported() +TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported() class TestAutocast(JitTestCase): def setUp(self):