diff --git a/test/test_autocast.py b/test/test_autocast.py index 24f87944990..4c3f031a88c 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -45,9 +45,9 @@ class TestAutocastCPU(TestCase): if add_kwargs is None: add_kwargs = {} - self.assertFalse(torch.is_autocast_cpu_enabled()) - with torch.cpu.amp.autocast(dtype=amp_dtype): - self.assertTrue(torch.is_autocast_cpu_enabled()) + self.assertFalse(torch.is_autocast_enabled(device_type="cpu")) + with torch.amp.autocast(device_type="cpu", dtype=amp_dtype): + self.assertTrue(torch.is_autocast_enabled(device_type="cpu")) out_type = out_type if out_type is not None else run_as_type output = output_method = None @@ -94,8 +94,8 @@ class TestAutocastCPU(TestCase): # Compare numerics to Python-side "autocasting" that (we expect) does the same thing # as the C++-side autocasting, and should be bitwise accurate. output_to_compare = output if output is not None else output_method - with torch.cpu.amp.autocast(enabled=False): - self.assertFalse(torch.is_autocast_cpu_enabled()) + with torch.amp.autocast(device_type="cpu", enabled=False): + self.assertFalse(torch.is_autocast_enabled(device_type="cpu")) if module is not None and hasattr(module, op): control = getattr(module, op)( @@ -108,8 +108,8 @@ class TestAutocastCPU(TestCase): self.assertTrue(type(output_to_compare) == type(control)) comparison = compare(output_to_compare, control) self.assertTrue(comparison, f"torch.{op} result did not match control") - self.assertTrue(torch.is_autocast_cpu_enabled()) - self.assertFalse(torch.is_autocast_cpu_enabled()) + self.assertTrue(torch.is_autocast_enabled(device_type="cpu")) + self.assertFalse(torch.is_autocast_enabled(device_type="cpu")) def args_maybe_kwargs(self, op_with_args): if len(op_with_args) == 2: @@ -237,7 +237,7 @@ class TestAutocastCPU(TestCase): m(x, (hx, cx)) # Should be able to run the below case with autocast - with torch.cpu.amp.autocast(): + with torch.amp.autocast(device_type="cpu"): m(x, (hx, cx)) def test_autocast_disabled_with_fp32_dtype(self): @@ -249,7 +249,7 @@ class TestAutocastCPU(TestCase): op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) with torch.amp.autocast(device_type="cpu"): generic_autocast_output = getattr(torch, op)(*args, **maybe_kwargs) - with torch.cpu.amp.autocast(): + with torch.amp.autocast(device_type="cpu"): cpu_autocast_output = getattr(torch, op)(*args, **maybe_kwargs) self.assertEqual(generic_autocast_output, cpu_autocast_output) @@ -346,8 +346,8 @@ class TestAutocastGPU(TestCase): class TestTorchAutocast(TestCase): def test_autocast_fast_dtype(self): - gpu_fast_dtype = torch.get_autocast_gpu_dtype() - cpu_fast_dtype = torch.get_autocast_cpu_dtype() + gpu_fast_dtype = torch.get_autocast_dtype(device_type="cuda") + cpu_fast_dtype = torch.get_autocast_dtype(device_type="cpu") self.assertEqual(gpu_fast_dtype, torch.half) self.assertEqual(cpu_fast_dtype, torch.bfloat16)