[Intel GPU] add tf32 support for matmul on XPU (#144240)
Support TF32 matmul on XPU via torch.backends.mkldnn.allow_tf32. Whether matmul needs a dedicated API separate from this flag will be discussed in the future. ~~Support XPU TF32 matmul using torch.set_float32_matmul_precision. For conv, see https://github.com/pytorch/pytorch/pull/137570. We decided not to follow torch.backends.cuda.matmul.allow_tf32, because that API actually calls setAllowTF32CuBLAS to set matmul_precision to high. By not introducing a new API we also avoid other related TF32 changes (e.g. in Inductor).~~ Pull Request resolved: https://github.com/pytorch/pytorch/pull/144240 Approved by: https://github.com/EikanWang
This commit is contained in:
parent ff039d39ec
commit 9f98e37eb4
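For orientation before the diff: the user-facing switch this change keys off is torch.backends.mkldnn.allow_tf32 (already used for conv, per the PR linked above). A minimal usage sketch, assuming an XPU build of PyTorch that exposes this flag; shapes are illustrative:

import torch

# Flag read by the XPU matmul path in this PR
# (at::globalContext().allowTF32OneDNN() on the C++ side).
torch.backends.mkldnn.allow_tf32 = True

a = torch.randn(512, 512, device="xpu")  # TF32 applies only to fp32 inputs
b = torch.randn(512, 512, device="xpu")
c = a @ b  # oneDNN primitive created with fpmath_mode::tf32

torch.backends.mkldnn.allow_tf32 = False  # back to strict fp32 math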
@@ -194,7 +194,12 @@ sycl::event matmul(
   pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

   if (m1_dt == dnnl::memory::data_type::f32) {
-    pattr.set_fpmath_mode(dnnl::fpmath_mode::strict);
+    bool allow_tf32 = at::globalContext().allowTF32OneDNN();
+    if (allow_tf32) {
+      pattr.set_fpmath_mode(dnnl::fpmath_mode::tf32);
+    } else {
+      pattr.set_fpmath_mode(dnnl::fpmath_mode::strict);
+    }
   }

   // STEP3: create primitive
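The observable effect of the new branch (a side note, not part of the diff): with the flag on, fp32 GEMMs may round operands to TF32's 10-bit mantissa, so results drift further from an fp64 reference than strict fp32 does. A rough probe using the same torch.backends.mkldnn.flags API as the test helpers below, assuming an XPU device with fp64 support:

import torch

a = torch.randn(1024, 1024, device="xpu")
b = torch.randn(1024, 1024, device="xpu")
ref = (a.double() @ b.double()).float()  # fp64 reference, no TF32 involved

# Pass the current enabled/deterministic values through, as the test helpers
# below do, because flags() manages all of its fields at once.
enabled = torch.backends.mkldnn.enabled
deterministic = torch.backends.mkldnn.deterministic

with torch.backends.mkldnn.flags(
    enabled=enabled, deterministic=deterministic, allow_tf32=False
):
    err_strict = (a @ b - ref).abs().max().item()

with torch.backends.mkldnn.flags(
    enabled=enabled, deterministic=deterministic, allow_tf32=True
):
    err_tf32 = (a @ b - ref).abs().max().item()

print(err_strict, err_tf32)  # expect err_tf32 >= err_strict when TF32 kicks in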
@@ -1,5 +1,8 @@
 # Owner(s): ["module: intel"]

+import contextlib
+import functools
+import inspect
 import itertools
 import math
 import random
@@ -23,6 +26,102 @@ from torch.testing._internal.common_utils import (
 )


+@contextlib.contextmanager
+def tf32_off():
+    enabled = torch.backends.mkldnn.enabled
+    deterministic = torch.backends.mkldnn.deterministic
+    with torch.backends.mkldnn.flags(
+        enabled=enabled, deterministic=deterministic, allow_tf32=False
+    ):
+        yield
+
+
+@contextlib.contextmanager
+def tf32_on(self, tf32_precision=1e-5):
+    enabled = torch.backends.mkldnn.enabled
+    deterministic = torch.backends.mkldnn.deterministic
+    old_precision = self.precision
+    try:
+        self.precision = tf32_precision
+        with torch.backends.mkldnn.flags(
+            enabled=enabled, deterministic=deterministic, allow_tf32=True
+        ):
+            yield
+    finally:
+        self.precision = old_precision
+
+
+# This is a wrapper that wraps a test to run this test twice, one with
+# allow_tf32=True, another with allow_tf32=False. When running with
+# allow_tf32=True, it will use reduced precision as specified by the
+# argument. For example:
+#     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+#     @tf32_on_and_off(0.005)
+#     def test_matmul(self, device, dtype):
+#         a = ...; b = ...;
+#         c = torch.matmul(a, b)
+#         self.assertEqual(c, expected)
+# In the above example, when testing torch.float32, the matmul will be run with
+# TF32 mode on and with TF32 mode off, and in TF32 mode, the assertEqual will use
+# reduced precision to check values.
+#
+# This decorator can be used for function with or without device/dtype, such as
+# @tf32_on_and_off(0.005)
+# def test_my_op(self)
+# @tf32_on_and_off(0.005)
+# def test_my_op(self, device)
+# @tf32_on_and_off(0.005)
+# def test_my_op(self, device, dtype)
+# @tf32_on_and_off(0.005)
+# def test_my_op(self, dtype)
+def tf32_on_and_off(tf32_precision=1e-5):
+    def with_tf32_disabled(self, function_call):
+        with tf32_off():
+            function_call()
+
+    def with_tf32_enabled(self, function_call):
+        with tf32_on(self, tf32_precision):
+            function_call()
+
+    def wrapper(f):
+        params = inspect.signature(f).parameters
+        arg_names = tuple(params.keys())
+
+        @functools.wraps(f)
+        def wrapped(*args, **kwargs):
+            kwargs.update(zip(arg_names, args))
+            cond = True
+            if "device" in kwargs:
+                cond = cond and (torch.device(kwargs["device"]).type == "xpu")
+            if "dtype" in kwargs:
+                cond = cond and (
+                    kwargs["dtype"] in {torch.float32}
+                )  # TODO: add complex64
+            if cond:
+                with_tf32_disabled(kwargs["self"], lambda: f(**kwargs))
+                with_tf32_enabled(kwargs["self"], lambda: f(**kwargs))
+            else:
+                f(**kwargs)
+
+        return wrapped
+
+    return wrapper
+
+
+# This is a wrapper that wraps a test to run it with TF32 turned off.
+# This wrapper is designed to be used when a test uses matmul or convolutions
+# but the purpose of that test is not testing matmul or convolutions.
+# Disabling TF32 will enforce torch.float tensors to be always computed
+# at full precision.
+def with_tf32_off(f):
+    @functools.wraps(f)
+    def wrapped(*args, **kwargs):
+        with tf32_off():
+            return f(*args, **kwargs)
+
+    return wrapped
+
+
 class TestBasicGEMM(TestCase):
     def _test_addmm_addmv(
         self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None
@@ -133,11 +232,13 @@ class TestBasicGEMM(TestCase):

     @precisionOverride({torch.float: 1e-4, torch.double: 1e-6, torch.half: 1e-1})
     @dtypes(torch.float32, torch.half, torch.double)
+    @tf32_on_and_off(0.05)
     def test_addmm(self, device, dtype):
         self._test_addmm_impl(torch.addmm, None, device, dtype)

     @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4})
     @dtypes(torch.bfloat16, torch.half, torch.float, torch.double)
+    @tf32_on_and_off(0.005)
     def test_addmv(self, device, dtype):
         # have to use torch.randn(...).to(bfloat16) instead of
         # torch.randn(..., dtype=bfloat16). randn does not support
@@ -185,6 +286,7 @@ class TestBasicGEMM(TestCase):
         torch.float32,
         torch.float64,
     )
+    @tf32_on_and_off(0.05)
     def test_mm(self, device, dtype):
         def _test_mm(n, m, p, dtype, genf):
             # helper function
@@ -287,6 +389,7 @@ class TestBasicGEMM(TestCase):

     @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
     @dtypes(torch.float32, torch.bfloat16, torch.half, torch.float64)
+    @tf32_on_and_off(0.05)
     def test_bmm(self, device, dtype):
         batch_sizes = [1, 10]
         M, N, O = 23, 15, 12
@@ -403,6 +506,7 @@ class TestBasicGEMM(TestCase):

     @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
     @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half)
+    @tf32_on_and_off(0.005)
     def test_addbmm(self, device, dtype):
         num_batches = 2
         M, N, O = 16, 17, 18
@@ -506,6 +610,7 @@ class TestBasicGEMM(TestCase):

     @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5, torch.float64: 1e-6})
     @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half)
+    @tf32_on_and_off(0.01)
     def test_baddbmm(self, device, dtype):
         num_batches = 10
         M, N, O = 12, 8, 50
@@ -568,6 +673,7 @@ class TestBasicGEMM(TestCase):
         for b1, b2, ref, out_tensor in generate_tensor():
             self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor)

+    @tf32_on_and_off(0.05)
     def test_tensordot(self, device):
         a = torch.arange(60.0, device=device).reshape(3, 4, 5)
         b = torch.arange(24.0, device=device).reshape(4, 3, 2)
@@ -604,6 +710,7 @@ class TestBasicGEMM(TestCase):

     @dtypes(torch.float, torch.double)
     @precisionOverride({torch.float32: 1e-4})
+    @tf32_on_and_off(0.005)
     def test_1_sized_with_0_strided(self, device, dtype):
         a = make_tensor((8, 1, 64), dtype=dtype, device=device)
         a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1])
@@ -646,6 +753,7 @@ class TestBasicGEMM(TestCase):
             dims_small = [ds] + dims_small
         return (dims_small, dims_large, dims_full)

+    @tf32_on_and_off(0.005)
     def test_broadcast_fused_matmul(self, device):
         fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]

@@ -692,6 +800,7 @@ class TestBasicGEMM(TestCase):
         self.assertEqual(r0, r1)

     @dtypes(torch.float32, torch.float64)
+    @tf32_on_and_off(0.005)
     def test_strided_mm_bmm(self, device, dtype):
         # Tests strided view case with stride smaller than corresponding dimension size
         x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype, device=device)
@@ -706,6 +815,7 @@ class TestBasicGEMM(TestCase):
         torch_fn = lambda x: torch.mm(x, x)  # noqa: E731
         self.compare_with_numpy(torch_fn, np_fn, sx[0])

+    @tf32_on_and_off(0.005)
     def test_mm_empty_inputs_mixed_dtype_errors(self, device):
         a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device)
         b = torch.randn(10, 20, dtype=torch.float32, device=device)
@@ -714,6 +824,7 @@ class TestBasicGEMM(TestCase):
         ):
             torch.mm(a, b)

+    @tf32_on_and_off(0.005)
     def test_matmul_45724(self, device):
         # https://github.com/pytorch/pytorch/issues/45724
         a = torch.rand(65537, 22, 64, device=device, dtype=torch.half)
@@ -731,6 +842,7 @@ class TestBasicGEMM(TestCase):
         torch.float32,
         torch.float64,
     )
+    @tf32_on_and_off(0.005)
     def test_baddbmm_input_dtypes_compatibility(self, device, dtype):
         batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
         batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
@@ -745,6 +857,7 @@ class TestBasicGEMM(TestCase):
         self.assertEqual(out, y_ref)

     @dtypes(torch.float)
+    @tf32_on_and_off(0.005)
     def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):
         for shape in [[3, 2, 2], [2, 20, 20]]:
             mat1, mat2 = (
@@ -767,6 +880,7 @@ class TestBasicGEMM(TestCase):

     @precisionOverride({torch.double: 1e-6})
     @dtypes(torch.float, torch.double)
+    @tf32_on_and_off(0.005)
     def test_addmm_sizes(self, device, dtype):
         for m in [0, 1, 25]:
             for n in [0, 1, 10]:
@@ -798,6 +912,7 @@ class TestBasicGEMM(TestCase):
         }
     )
     @dtypes(torch.double, torch.float32, torch.bfloat16, torch.half)
+    @tf32_on_and_off(0.05)
     def test_addmm_gelu(self, device, dtype):
         self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)

@@ -812,10 +927,12 @@ class TestBasicGEMM(TestCase):
         }
     )
     @dtypes(torch.double, torch.float32, torch.bfloat16, torch.half)
+    @tf32_on_and_off(0.05)
     def test_addmm_relu(self, device, dtype):
         self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)

-    @dtypes(torch.float, torch.bfloat16, torch.half, torch.double)
+    @dtypes(torch.float, torch.bfloat16, torch.half)
+    @tf32_on_and_off(0.005)
     def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype):
         # tests (o, s)*(s). o is output size, s is summed size.
         o = 5
@@ -859,6 +976,7 @@ class TestBasicGEMM(TestCase):
         }
     )
     @dtypes(torch.double, torch.bfloat16, torch.half, torch.float32)
+    @tf32_on_and_off(0.005)
     def test_corner_cases_of_cublasltmatmul(self, device, dtype):
         # common case
         M = torch.randn(128, device=device).to(dtype)
@@ -998,6 +1116,7 @@ class TestBasicGEMM(TestCase):
             torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,), test_out=True)
         )

+    @tf32_on_and_off(0.005)
     def test_large_bmm_backward(self, device):
         A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
         B = torch.randn([1, 1024, 65536], device=device, requires_grad=True)
@@ -1006,6 +1125,7 @@ class TestBasicGEMM(TestCase):
         # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
         (A @ B).backward(G)

+    @tf32_on_and_off(0.005)
     def test_large_bmm_mm_backward(self, device):
         A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
         B = torch.randn([1024, 65536], device=device, requires_grad=True)
@@ -1104,6 +1224,7 @@ class TestBasicGEMM(TestCase):
             self.check_single_matmul(x, y)

     @dtypes(torch.float)
+    @tf32_on_and_off(0.005)
     def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
         a = torch.empty(
             (256, 512), device=device, dtype=dtype, requires_grad=True
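A subtlety in the tf32_on_and_off helper added above: it rebinds positional test arguments to their parameter names via inspect.signature, so it can branch on whether the test received an XPU device and an fp32 dtype. A standalone, simplified sketch of that dispatch pattern (the name run_twice_on_xpu is illustrative, not part of the PR):

import functools
import inspect

def run_twice_on_xpu(f):
    # Record the parameter order so positional args can be mapped to names.
    arg_names = tuple(inspect.signature(f).parameters.keys())

    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        kwargs.update(zip(arg_names, args))  # bind positionals to their names
        if str(kwargs.get("device", "xpu")).startswith("xpu"):
            f(**kwargs)  # first pass, e.g. with TF32 off
            f(**kwargs)  # second pass, e.g. with TF32 on
        else:
            f(**kwargs)

    return wrapped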