Move _run_autocast_outofplace to a base class named TestAutocast to reduce redundancy (#134460)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134460
Approved by: https://github.com/EikanWang, https://github.com/ezyang
This commit is contained in: parent c2ff9fe042, commit 80a6d60829
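In practice, the change means that _run_autocast_outofplace (together with args_maybe_kwargs) now lives on a shared TestAutocast base class in torch/testing/_internal/autocast_test_lists.py, and each device-specific suite subclasses it and passes an explicit device argument instead of keeping its own copy of the helper. The sketch below shows the resulting call pattern, assuming this commit is applied; the subclass name, op, and tensor arguments are illustrative only and are not taken from the diff.

    # Minimal sketch of the post-refactor usage (illustrative, not part of the commit).
    import torch
    from torch.testing._internal.autocast_test_lists import TestAutocast
    from torch.testing._internal.common_utils import run_tests


    class TestAutocastCPUSketch(TestAutocast):  # hypothetical subclass name
        def test_mm_autocasts_to_bfloat16(self):
            a = torch.randn(4, 4)
            b = torch.randn(4, 4)
            # The shared helper enables torch.amp.autocast for the given device,
            # runs the op, checks the output dtype, and compares the result
            # against a manual cast-then-run control.
            self._run_autocast_outofplace("mm", (a, b), torch.bfloat16, device="cpu")


    if __name__ == "__main__":
        run_tests()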
@@ -1,10 +1,12 @@
 # Owner(s): ["module: unknown"]
 
 import collections
 import unittest
 
 import torch
-from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
+from torch.testing._internal.autocast_test_lists import (
+    AutocastCPUTestLists,
+    TestAutocast,
+)
 from torch.testing._internal.common_utils import (
     IS_WINDOWS,
     run_tests,
@@ -14,7 +16,7 @@ from torch.testing._internal.common_utils import (
 from torch.utils._python_dispatch import TorchDispatchMode
 
 
-class TestAutocastCPU(TestCase):
+class TestAutocastCPU(TestAutocast):
     def setUp(self):
         super().setUp()
         self.autocast_lists = AutocastCPUTestLists(torch.device("cpu"))
@@ -23,100 +25,6 @@ class TestAutocastCPU(TestCase):
         del self.autocast_lists
         super().tearDown()
 
-    def _run_autocast_outofplace(
-        self,
-        op,
-        args,
-        run_as_type,
-        out_type=None,
-        module=torch,
-        add_kwargs=None,
-        amp_dtype=torch.bfloat16,
-    ):
-        # helper to cast args
-        def cast(val, to_type):
-            if isinstance(val, torch.Tensor):
-                return val.to(to_type) if val.is_floating_point() else val
-            elif isinstance(val, collections.abc.Iterable):
-                return type(val)(cast(v, to_type) for v in val)
-            else:
-                return val
-
-        if add_kwargs is None:
-            add_kwargs = {}
-
-        self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
-        with torch.amp.autocast(device_type="cpu", dtype=amp_dtype):
-            self.assertTrue(torch.is_autocast_enabled(device_type="cpu"))
-            out_type = out_type if out_type is not None else run_as_type
-            output = output_method = None
-
-            # Try module.* variant, if requested:
-            if module is not None and hasattr(module, op):
-                output = getattr(module, op)(*args, **add_kwargs)
-                if isinstance(output, torch.Tensor):
-                    self.assertTrue(
-                        out_type == output.dtype,
-                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
-                    )
-            # Try Tensor.* variant:
-            if hasattr(torch.Tensor, op):
-                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
-                if isinstance(output_method, torch.Tensor):
-                    self.assertTrue(
-                        out_type == output_method.dtype,
-                        f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
-                    )
-
-            self.assertTrue(
-                (output is not None) or (output_method is not None),
-                f"{op} not found as an attribute on either Tensor or the requested module {module}",
-            )
-
-            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
-            # For example, lstm_cell returns a tuple and equal returns bool.
-            def compare(first, second):
-                if isinstance(first, torch.Tensor):
-                    return torch.equal(first, second)
-                elif isinstance(first, collections.abc.Iterable):
-                    return all(compare(f, s) for f, s in zip(first, second))
-                else:
-                    return first == second
-
-            # If both torch.* and Tensor.* variants were found, check outputs are identical
-            if (output is not None) and (output_method is not None):
-                self.assertTrue(type(output) == type(output_method))
-                comparison = compare(output, output_method)
-                self.assertTrue(
-                    comparison, f"torch.{op} result did not match Tensor.{op} result"
-                )
-
-            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
-            # as the C++-side autocasting, and should be bitwise accurate.
-            output_to_compare = output if output is not None else output_method
-            with torch.amp.autocast(device_type="cpu", enabled=False):
-                self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
-
-                if module is not None and hasattr(module, op):
-                    control = getattr(module, op)(
-                        *cast(args, run_as_type), **add_kwargs
-                    )
-                else:
-                    control = getattr(args[0].to(run_as_type), op)(
-                        *cast(args[1:], run_as_type), **add_kwargs
-                    )
-                self.assertTrue(type(output_to_compare) == type(control))
-                comparison = compare(output_to_compare, control)
-                self.assertTrue(comparison, f"torch.{op} result did not match control")
-            self.assertTrue(torch.is_autocast_enabled(device_type="cpu"))
-        self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
-
-    def args_maybe_kwargs(self, op_with_args):
-        if len(op_with_args) == 2:
-            return op_with_args[0], op_with_args[1], {}
-        else:
-            return op_with_args[0], op_with_args[1], op_with_args[2]
-
     @skipIfTorchDynamo()
     def test_autocast_torch_expect_builtin_promote(self):
         for (
@@ -125,9 +33,16 @@ class TestAutocastCPU(TestCase):
             args2,
             out_type,
         ) in self.autocast_lists.torch_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
             self._run_autocast_outofplace(
-                op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16
+                op, args1, torch.float32, device="cpu", out_type=out_type
+            )
+            self._run_autocast_outofplace(
+                op,
+                args2,
+                torch.float32,
+                device="cpu",
+                out_type=out_type,
+                amp_dtype=torch.float16,
             )
 
     @skipIfTorchDynamo()
@@ -139,12 +54,13 @@ class TestAutocastCPU(TestCase):
             out_type,
         ) in self.autocast_lists.methods_expect_builtin_promote:
             self._run_autocast_outofplace(
-                op, args1, torch.float32, module=None, out_type=out_type
+                op, args1, torch.float32, device="cpu", module=None, out_type=out_type
             )
             self._run_autocast_outofplace(
                 op,
                 args2,
                 torch.float32,
+                device="cpu",
                 module=None,
                 out_type=out_type,
                 amp_dtype=torch.float16,
@@ -155,12 +71,13 @@ class TestAutocastCPU(TestCase):
         for op_with_args in self.autocast_lists.torch_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(
-                op, args, torch.bfloat16, add_kwargs=maybe_kwargs
+                op, args, torch.bfloat16, device="cpu", add_kwargs=maybe_kwargs
             )
             self._run_autocast_outofplace(
                 op,
                 args,
                 torch.float16,
+                device="cpu",
                 add_kwargs=maybe_kwargs,
                 amp_dtype=torch.float16,
             )
@@ -170,12 +87,18 @@ class TestAutocastCPU(TestCase):
         for op_with_args in self.autocast_lists.nn_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(
-                op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs
+                op,
+                args,
+                torch.bfloat16,
+                device="cpu",
+                module=torch._C._nn,
+                add_kwargs=maybe_kwargs,
             )
             self._run_autocast_outofplace(
                 op,
                 args,
                 torch.float16,
+                device="cpu",
                 module=torch._C._nn,
                 add_kwargs=maybe_kwargs,
                 amp_dtype=torch.float16,
@@ -186,12 +109,13 @@ class TestAutocastCPU(TestCase):
         for op_with_args in self.autocast_lists.torch_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(
-                op, args, torch.float32, add_kwargs=maybe_kwargs
+                op, args, torch.float32, device="cpu", add_kwargs=maybe_kwargs
             )
             self._run_autocast_outofplace(
                 op,
                 args,
                 torch.float32,
+                device="cpu",
                 add_kwargs=maybe_kwargs,
                 amp_dtype=torch.float16,
             )
@@ -201,12 +125,18 @@ class TestAutocastCPU(TestCase):
         for op_with_args in self.autocast_lists.nn_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(
-                op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs
+                op,
+                args,
+                torch.float32,
+                device="cpu",
+                module=torch._C._nn,
+                add_kwargs=maybe_kwargs,
             )
             self._run_autocast_outofplace(
                 op,
                 args,
                 torch.float32,
+                device="cpu",
                 module=torch._C._nn,
                 add_kwargs=maybe_kwargs,
                 amp_dtype=torch.float16,
@@ -215,9 +145,9 @@ class TestAutocastCPU(TestCase):
     @skipIfTorchDynamo()
     def test_autocast_torch_need_autocast_promote(self):
         for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
-            self._run_autocast_outofplace(op, args1, torch.float32)
+            self._run_autocast_outofplace(op, args1, torch.float32, device="cpu")
             self._run_autocast_outofplace(
-                op, args2, torch.float32, amp_dtype=torch.float16
+                op, args2, torch.float32, device="cpu", amp_dtype=torch.float16
             )
 
     @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
test/test_cuda.py (1094 changes): file diff suppressed because it is too large.

test/test_xpu.py (107 changes):
@@ -1,6 +1,5 @@
 # Owner(s): ["module: intel"]
 
-import collections
 import subprocess
 import sys
 import tempfile
@@ -8,7 +7,7 @@ import unittest
 
 import torch
 import torch.xpu._gpu_trace as gpu_trace
-from torch.testing._internal.autocast_test_lists import AutocastTestLists
+from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     onlyXPU,
@@ -371,7 +370,7 @@ print(torch.xpu.device_count())
 instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)
 
 
-class TestXpuAutocast(TestCase):
+class TestXpuAutocast(TestAutocast):
     # These operators are not implemented on XPU backend and we can NOT fall back
     # them to CPU. So we have to skip them at this moment.
     # TODO: remove these operators from skip list when they are implemented on XPU backend.
@@ -385,89 +384,6 @@ class TestXpuAutocast(TestCase):
         del self.autocast_lists
         super().tearDown()
 
-    def _run_autocast_outofplace(
-        self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None
-    ):
-        # helper to cast args
-        def cast(val, to_type):
-            if isinstance(val, torch.Tensor):
-                return val.to(to_type) if val.is_floating_point() else val
-            elif isinstance(val, collections.abc.Iterable):
-                return type(val)(cast(v, to_type) for v in val)
-            else:
-                return val
-
-        if add_kwargs is None:
-            add_kwargs = {}
-        fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16
-        self.assertFalse(torch.is_autocast_enabled("xpu"))
-        with torch.amp.autocast("xpu", dtype=fast_dtype):
-            self.assertTrue(torch.is_autocast_enabled("xpu"))
-
-            out_type = out_type if out_type is not None else run_as_type
-            output = output_method = None
-
-            # Try module.* variant, if requested:
-            if module is not None and hasattr(module, op):
-                output = getattr(module, op)(*args, **add_kwargs)
-                if isinstance(output, torch.Tensor):
-                    self.assertTrue(
-                        out_type == output.dtype,
-                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
-                    )
-
-            # Try Tensor.* variant:
-            if hasattr(torch.Tensor, op):
-                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
-                if isinstance(output_method, torch.Tensor):
-                    self.assertTrue(
-                        out_type == output_method.dtype,
-                        f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
-                    )
-
-            self.assertTrue(
-                (output is not None) or (output_method is not None),
-                f"{op} not found as an attribute on either Tensor or the requested module {module}",
-            )
-
-            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
-            # For example, lstm_cell returns a tuple and equal returns bool.
-            def compare(first, second):
-                if isinstance(first, torch.Tensor):
-                    return torch.equal(first, second)
-                elif isinstance(first, collections.abc.Iterable):
-                    return all(compare(f, s) for f, s in zip(first, second))
-                else:
-                    return first == second
-
-            # If both torch.* and Tensor.* variants were found, check outputs are identical
-            if (output is not None) and (output_method is not None):
-                self.assertTrue(type(output) == type(output_method))
-                comparison = compare(output, output_method)
-                self.assertTrue(
-                    comparison, f"torch.{op} result did not match Tensor.{op} result"
-                )
-
-            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
-            # as the C++-side autocasting, and should be bitwise accurate.
-            output_to_compare = output if output is not None else output_method
-            with torch.amp.autocast("xpu", enabled=False):
-                self.assertFalse(torch.is_autocast_enabled("xpu"))
-
-                if module is not None and hasattr(module, op):
-                    control = getattr(module, op)(
-                        *cast(args, run_as_type), **add_kwargs
-                    )
-                else:
-                    control = getattr(args[0].to(run_as_type), op)(
-                        *cast(args[1:], run_as_type), **add_kwargs
-                    )
-                self.assertTrue(type(output_to_compare) == type(control))
-                comparison = compare(output_to_compare, control)
-                self.assertTrue(comparison, f"torch.{op} result did not match control")
-            self.assertTrue(torch.is_autocast_enabled("xpu"))
-        self.assertFalse(torch.is_autocast_enabled("xpu"))
-
     def test_autocast_torch_fp16(self):
         for op_with_args in self.autocast_lists.torch_fp16:
             skip_test = False
@@ -477,7 +393,9 @@ class TestXpuAutocast(TestCase):
             if len(op_with_args) == 3:
                 skip_test = True  # skip cudnn op
             if not skip_test:
-                self._run_autocast_outofplace(op, args, torch.float16)
+                self._run_autocast_outofplace(
+                    op, args, torch.float16, device="xpu", amp_dtype=torch.float16
+                )
 
     def test_autocast_torch_bf16(self):
         for op_with_args in self.autocast_lists.torch_fp16:
@@ -488,15 +406,24 @@ class TestXpuAutocast(TestCase):
             if len(op_with_args) == 3:
                 skip_test = True  # skip cudnn op
             if not skip_test:
-                self._run_autocast_outofplace(op, args, torch.bfloat16)
+                self._run_autocast_outofplace(op, args, torch.bfloat16, device="xpu")
 
     def test_autocast_torch_need_autocast_promote(self):
         for op, args in self.autocast_lists.torch_need_autocast_promote:
-            self._run_autocast_outofplace(op, args, torch.float32)
+            self._run_autocast_outofplace(
+                op, args, torch.float32, device="xpu", amp_dtype=torch.float16
+            )
 
     def test_autocast_torch_expect_builtin_promote(self):
         for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float32,
+                device="xpu",
+                out_type=out_type,
+                amp_dtype=torch.float16,
+            )
 
     def test_autocast_checkpointing(self):
         model = torch.nn.Sequential(
torch/testing/_internal/autocast_test_lists.py:
@@ -1,7 +1,10 @@
 # mypy: ignore-errors
 
+import collections
+
 import torch
 from torch.testing._internal.common_utils import TEST_WITH_ROCM
+from torch.testing._internal.common_utils import TestCase
 
 
 class AutocastTestLists:
@@ -234,6 +237,7 @@ class AutocastTestLists:
                 torch.rand((n, n), device=dev, dtype=torch.float32)), torch._C._nn),
         ]
 
+
 class AutocastCPUTestLists:
     # Supplies ops and arguments for test_autocast_* in test/test_cpu.py
     def __init__(self, dev):
@@ -368,3 +372,103 @@ class AutocastCPUTestLists:
             ("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
             ("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
         ]
+
+
+class TestAutocast(TestCase):
+    def args_maybe_kwargs(self, op_with_args):
+        if len(op_with_args) == 2:
+            return op_with_args[0], op_with_args[1], {}
+        else:
+            return op_with_args[0], op_with_args[1], op_with_args[2]
+
+    def _run_autocast_outofplace(
+        self,
+        op,
+        args,
+        run_as_type,
+        device,
+        out_type=None,
+        module=torch,
+        add_kwargs=None,
+        amp_dtype=torch.bfloat16,
+    ):
+        # helper to cast args
+        def cast(val, to_type):
+            if isinstance(val, torch.Tensor):
+                return val.to(to_type) if val.is_floating_point() else val
+            elif isinstance(val, collections.abc.Iterable):
+                return type(val)(cast(v, to_type) for v in val)
+            else:
+                return val
+
+        if add_kwargs is None:
+            add_kwargs = {}
+
+        self.assertFalse(torch.is_autocast_enabled(device_type=device))
+        with torch.amp.autocast(device_type=device, dtype=amp_dtype):
+            self.assertTrue(torch.is_autocast_enabled(device_type=device))
+
+            out_type = out_type if out_type is not None else run_as_type
+            output = output_method = None
+
+            # Try module.* variant, if requested:
+            if module is not None and hasattr(module, op):
+                output = getattr(module, op)(*args, **add_kwargs)
+                if isinstance(output, torch.Tensor):
+                    self.assertTrue(
+                        out_type == output.dtype,
+                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
+                    )
+            # Try Tensor.* variant:
+            if hasattr(torch.Tensor, op):
+                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
+                if isinstance(output_method, torch.Tensor):
+                    self.assertTrue(
+                        out_type == output_method.dtype,
+                        f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
+                    )
+
+            self.assertTrue(
+                (output is not None) or (output_method is not None),
+                f"{op} not found as an attribute on either Tensor or the requested module {module}",
+            )
+
+            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
+            # For example, lstm_cell returns a tuple and equal returns bool.
+            def compare(first, second):
+                if isinstance(first, torch.Tensor):
+                    return torch.equal(first, second)
+                elif isinstance(first, collections.abc.Iterable):
+                    return all(compare(f, s) for f, s in zip(first, second))
+                else:
+                    return first == second
+
+            # If both torch.* and Tensor.* variants were found, check outputs are identical
+            if (output is not None) and (output_method is not None):
+                self.assertTrue(type(output) == type(output_method))
+                comparison = compare(output, output_method)
+                self.assertTrue(
+                    comparison, f"torch.{op} result did not match Tensor.{op} result"
+                )
+
+            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
+            # as the C++-side autocasting, and should be bitwise accurate.
+            output_to_compare = output if output is not None else output_method
+            with torch.amp.autocast(device_type=device, enabled=False):
+                self.assertFalse(
+                    torch.is_autocast_enabled(device_type=device)
+                )
+
+                if module is not None and hasattr(module, op):
+                    control = getattr(module, op)(
+                        *cast(args, run_as_type), **add_kwargs
+                    )
+                else:
+                    control = getattr(args[0].to(run_as_type), op)(
+                        *cast(args[1:], run_as_type), **add_kwargs
+                    )
+                self.assertTrue(type(output_to_compare) == type(control))
+                comparison = compare(output_to_compare, control)
+                self.assertTrue(comparison, f"torch.{op} result did not match control")
+            self.assertTrue(torch.is_autocast_enabled(device_type=device))
+        self.assertFalse(torch.is_autocast_enabled(device_type=device))
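Since the new TestAutocast base class is device-agnostic, a suite for any other autocast-capable backend can reuse the same helper by subclassing it and passing the appropriate device string and amp_dtype. The rough sketch below illustrates this; the class name, op, and "cuda" device string are placeholders and are not part of this commit.

    # Hypothetical reuse of TestAutocast for another backend; the op and
    # arguments are illustrative only.
    import torch
    from torch.testing._internal.autocast_test_lists import TestAutocast
    from torch.testing._internal.common_utils import run_tests


    class TestAutocastCUDASketch(TestAutocast):  # hypothetical name
        def test_mm_autocasts_to_float16(self):
            if not torch.cuda.is_available():
                self.skipTest("CUDA not available")
            a = torch.randn(8, 8, device="cuda")
            b = torch.randn(8, 8, device="cuda")
            # amp_dtype is the dtype handed to torch.amp.autocast; run_as_type is
            # the dtype the op is expected to produce under autocast.
            self._run_autocast_outofplace(
                "mm", (a, b), torch.float16, device="cuda", amp_dtype=torch.float16
            )


    if __name__ == "__main__":
        run_tests()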