Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Use the new device-agnostic API instead of old device-specific APIs like torch.cpu and torch.cuda (#134448)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134448
Approved by: https://github.com/guangyey, https://github.com/shink, https://github.com/albanD
This commit is contained in:
parent 0c7856973b
commit f7467c3b95
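For context, here is a minimal sketch of the API migration this commit applies, using only calls that appear in the diff below; it assumes a PyTorch build recent enough to expose the device-agnostic torch.amp entry points:

import torch

# Old, device-specific spellings (being replaced):
#   with torch.cpu.amp.autocast(): ...
#   torch.is_autocast_cpu_enabled()
#   torch.get_autocast_cpu_dtype() / torch.get_autocast_gpu_dtype()

# New, device-agnostic spellings (used throughout the updated tests):
with torch.amp.autocast(device_type="cpu"):
    assert torch.is_autocast_enabled(device_type="cpu")

print(torch.get_autocast_dtype(device_type="cpu"))   # torch.bfloat16, the default CPU autocast dtype
print(torch.get_autocast_dtype(device_type="cuda"))  # torch.float16, the default CUDA autocast dtype

The same pattern should extend to other backends by passing their device_type string, which is what makes the new API device-agnostic.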
@@ -45,9 +45,9 @@ class TestAutocastCPU(TestCase):
         if add_kwargs is None:
             add_kwargs = {}

-        self.assertFalse(torch.is_autocast_cpu_enabled())
-        with torch.cpu.amp.autocast(dtype=amp_dtype):
-            self.assertTrue(torch.is_autocast_cpu_enabled())
+        self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
+        with torch.amp.autocast(device_type="cpu", dtype=amp_dtype):
+            self.assertTrue(torch.is_autocast_enabled(device_type="cpu"))
             out_type = out_type if out_type is not None else run_as_type
             output = output_method = None

@@ -94,8 +94,8 @@ class TestAutocastCPU(TestCase):
             # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
             # as the C++-side autocasting, and should be bitwise accurate.
             output_to_compare = output if output is not None else output_method
-            with torch.cpu.amp.autocast(enabled=False):
-                self.assertFalse(torch.is_autocast_cpu_enabled())
+            with torch.amp.autocast(device_type="cpu", enabled=False):
+                self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))

                 if module is not None and hasattr(module, op):
                     control = getattr(module, op)(
@@ -108,8 +108,8 @@ class TestAutocastCPU(TestCase):
                 self.assertTrue(type(output_to_compare) == type(control))
                 comparison = compare(output_to_compare, control)
                 self.assertTrue(comparison, f"torch.{op} result did not match control")
-            self.assertTrue(torch.is_autocast_cpu_enabled())
-        self.assertFalse(torch.is_autocast_cpu_enabled())
+            self.assertTrue(torch.is_autocast_enabled(device_type="cpu"))
+        self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))

     def args_maybe_kwargs(self, op_with_args):
         if len(op_with_args) == 2:
@@ -237,7 +237,7 @@ class TestAutocastCPU(TestCase):
                 m(x, (hx, cx))

             # Should be able to run the below case with autocast
-            with torch.cpu.amp.autocast():
+            with torch.amp.autocast(device_type="cpu"):
                 m(x, (hx, cx))

     def test_autocast_disabled_with_fp32_dtype(self):
@@ -249,7 +249,7 @@ class TestAutocastCPU(TestCase):
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             with torch.amp.autocast(device_type="cpu"):
                 generic_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
-            with torch.cpu.amp.autocast():
+            with torch.amp.autocast(device_type="cpu"):
                 cpu_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
             self.assertEqual(generic_autocast_output, cpu_autocast_output)

@@ -346,8 +346,8 @@ class TestAutocastGPU(TestCase):

 class TestTorchAutocast(TestCase):
     def test_autocast_fast_dtype(self):
-        gpu_fast_dtype = torch.get_autocast_gpu_dtype()
-        cpu_fast_dtype = torch.get_autocast_cpu_dtype()
+        gpu_fast_dtype = torch.get_autocast_dtype(device_type="cuda")
+        cpu_fast_dtype = torch.get_autocast_dtype(device_type="cpu")
         self.assertEqual(gpu_fast_dtype, torch.half)
         self.assertEqual(cpu_fast_dtype, torch.bfloat16)
