Switch to torch.float16 on XPU AMP mode (#127741)

# Motivation
Previously, the default dtype for AMP on XPU was aligned with the CPU. To align with other GPUs, we intend to change the default dtype for AMP to `torch.float16`. This change aims to save users the effort of converting models from `torch.float16` to `torch.bfloat16`, or vice versa when they want to run the model on different types of GPUs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127741
Approved by: https://github.com/EikanWang, https://github.com/albanD
This commit is contained in:
Yu, Guangye 2024-06-03 15:29:31 +00:00 committed by PyTorch MergeBot
parent 1d0c1087dd
commit 304956e1fb
2 changed files with 131 additions and 1 deletions

View File

@ -68,7 +68,7 @@ thread_local std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
at::kBFloat16, // XLA / TPU at::kBFloat16, // XLA / TPU
at::ScalarType::Undefined, // Vulkan at::ScalarType::Undefined, // Vulkan
at::ScalarType::Undefined, // Metal at::ScalarType::Undefined, // Metal
at::kBFloat16, // XPU at::kHalf, // XPU
at::ScalarType::Undefined, // MPS at::ScalarType::Undefined, // MPS
at::ScalarType::Undefined, // Meta (tensors with no data) at::ScalarType::Undefined, // Meta (tensors with no data)
at::kBFloat16, // HPU / HABANA at::kBFloat16, // HPU / HABANA

View File

@ -1,11 +1,13 @@
# Owner(s): ["module: intel"] # Owner(s): ["module: intel"]
import collections
import sys import sys
import tempfile import tempfile
import unittest import unittest
import torch import torch
import torch.xpu._gpu_trace as gpu_trace import torch.xpu._gpu_trace as gpu_trace
from torch.testing._internal.autocast_test_lists import AutocastTestLists
from torch.testing._internal.common_device_type import ( from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, instantiate_device_type_tests,
onlyXPU, onlyXPU,
@ -309,6 +311,134 @@ if __name__ == "__main__":
instantiate_device_type_tests(TestXpu, globals(), only_for="xpu") instantiate_device_type_tests(TestXpu, globals(), only_for="xpu")
class TestXpuAutocast(TestCase):
def setUp(self):
super().setUp()
self.autocast_lists = AutocastTestLists(torch.device("xpu"))
def tearDown(self):
del self.autocast_lists
super().tearDown()
def _run_autocast_outofplace(
self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None
):
# helper to cast args
def cast(val, to_type):
if isinstance(val, torch.Tensor):
return val.to(to_type) if val.is_floating_point() else val
elif isinstance(val, collections.abc.Iterable):
return type(val)(cast(v, to_type) for v in val)
else:
return val
if add_kwargs is None:
add_kwargs = {}
fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16
self.assertFalse(torch.is_autocast_enabled())
with torch.amp.autocast("xpu", dtype=fast_dtype):
self.assertTrue(torch.is_autocast_enabled())
out_type = out_type if out_type is not None else run_as_type
output = output_method = None
# Try module.* variant, if requested:
if module is not None and hasattr(module, op):
output = getattr(module, op)(*args, **add_kwargs)
if isinstance(output, torch.Tensor):
self.assertTrue(
out_type == output.dtype,
f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
)
# Try Tensor.* variant:
if hasattr(torch.Tensor, op):
output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
if isinstance(output_method, torch.Tensor):
self.assertTrue(
out_type == output_method.dtype,
f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
)
self.assertTrue(
(output is not None) or (output_method is not None),
f"{op} not found as an attribute on either Tensor or the requested module {module}",
)
# Accounts for ops that return Tensors, iterables, and other non-Tensors.
# For example, lstm_cell returns a tuple and equal returns bool.
def compare(first, second):
if isinstance(first, torch.Tensor):
return torch.equal(first, second)
elif isinstance(first, collections.abc.Iterable):
return all(compare(f, s) for f, s in zip(first, second))
else:
return first == second
# If both torch.* and Tensor.* variants were found, check outputs are identical
if (output is not None) and (output_method is not None):
self.assertTrue(type(output) == type(output_method))
comparison = compare(output, output_method)
self.assertTrue(
comparison, f"torch.{op} result did not match Tensor.{op} result"
)
# Compare numerics to Python-side "autocasting" that (we expect) does the same thing
# as the C++-side autocasting, and should be bitwise accurate.
output_to_compare = output if output is not None else output_method
with torch.amp.autocast("xpu", enabled=False):
self.assertFalse(torch.is_autocast_enabled())
if module is not None and hasattr(module, op):
control = getattr(module, op)(
*cast(args, run_as_type), **add_kwargs
)
else:
control = getattr(args[0].to(run_as_type), op)(
*cast(args[1:], run_as_type), **add_kwargs
)
self.assertTrue(type(output_to_compare) == type(control))
comparison = compare(output_to_compare, control)
self.assertTrue(comparison, f"torch.{op} result did not match control")
self.assertTrue(torch.is_autocast_enabled())
self.assertFalse(torch.is_autocast_enabled())
def test_autocast_torch_fp16(self):
for op_with_args in self.autocast_lists.torch_fp16:
skip_test = False
op, args = op_with_args[0], op_with_args[1]
if len(op_with_args) == 3:
skip_test = True # skip cudnn op
if not skip_test:
self._run_autocast_outofplace(op, args, torch.float16)
def test_autocast_torch_bf16(self):
for op_with_args in self.autocast_lists.torch_fp16:
skip_test = False
op, args = op_with_args[0], op_with_args[1]
if len(op_with_args) == 3:
skip_test = True # skip cudnn op
if not skip_test:
self._run_autocast_outofplace(op, args, torch.bfloat16)
def test_autocast_torch_need_autocast_promote(self):
for op, args in self.autocast_lists.torch_need_autocast_promote:
self._run_autocast_outofplace(op, args, torch.float32)
def test_autocast_torch_expect_builtin_promote(self):
for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
def test_xpu_autocast_dtype(self):
dtype = torch.get_autocast_dtype("xpu")
self.assertEqual(dtype, torch.float16)
mat0_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu")
mat1_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu")
with torch.amp.autocast("xpu"):
result = torch.mm(mat0_fp32, mat1_fp32)
self.assertEqual(result.dtype, torch.float16)
class TestXpuTrace(TestCase): class TestXpuTrace(TestCase):
def setUp(self): def setUp(self):
torch._C._activate_gpu_trace() torch._C._activate_gpu_trace()