Split test_transformers.py (#147441)
Split test_transformers.py into test_transformers.py and test_transformers_privateuse1.py. Currently the privateuse1 test cases in test_transformers.py are skipped because they conflict with the CUDA test cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147441
Approved by: https://github.com/drisspg
This commit is contained in:
parent cf6d1e6824
commit 7ffae2c028
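
The new suite is self-contained: it ends in the usual `if __name__ == "__main__": run_tests()` guard, and the CUSTOM_HANDLERS change below routes it through run_test_with_openreg when invoked via test/run_test.py. As a quick local check it can also be launched directly; the sketch below is an illustration only, assuming the repository root as the working directory and that the pytorch_openreg helper package these tests import is already installed.

# Sketch: run only the split-out privateuse1 SDPA suite in its own process,
# keeping it separate from the CUDA-backed test_transformers.py run.
# Assumes: cwd is the PyTorch repo root and pytorch_openreg is importable.
import subprocess
import sys

subprocess.run(
    [sys.executable, "test/test_transformers_privateuse1.py"],
    check=True,  # surface a non-zero exit if any test in the suite fails
)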
@@ -467,6 +467,7 @@ S390X_TESTLIST = [
     "test_tensorexpr_pybind",
     "test_torch",
     "test_transformers",
+    "test_transformers_privateuse1",
     "test_type_hints",
     "test_type_info",
     "test_type_promotion",
@@ -1483,7 +1484,7 @@ CUSTOM_HANDLERS = {
     "test_autoload_enable": test_autoload_enable,
     "test_autoload_disable": test_autoload_disable,
     "test_cpp_extensions_open_device_registration": run_test_with_openreg,
-    "test_transformers": run_test_with_openreg,
+    "test_transformers_privateuse1": run_test_with_openreg,
 }
@@ -4,7 +4,6 @@ import contextlib
 from functools import partial
 from collections import namedtuple
 import sys
-import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -22,7 +21,6 @@ from typing import Optional
 import torch.utils.cpp_extension
 from torch.testing._internal.common_nn import NNTestCase
 from torch.testing._internal.common_utils import (
-    IS_FBCODE,
     TEST_WITH_ROCM,
     skipIfRocm,
     skipIfTorchDynamo,
@@ -38,7 +36,6 @@ from torch.testing._internal.common_utils import (
     NOTEST_CPU,
     IS_WINDOWS,
     TEST_WITH_TORCHDYNAMO,
-    TEST_XPU,
 )
 from torch._dynamo.testing import CompileCounterWithBackend
@@ -4006,69 +4003,6 @@ class TestAttnBias(NNTestCase):
         with self.assertRaisesRegex(ValueError, "CausalBias should not be used with causal=True"):
             scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, is_causal=True, dropout_p=0.0)
 
-@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently")
-@unittest.skipIf(IS_FBCODE, "Ninja is required to load C++ extensions and it's not compatible with Buck ")
-@unittest.skip("TODO: This test is broken and should be moved into a dedicated process for registering new extensions")
-class TestSDPAPrivateUse1Only(NNTestCase):
-    @classmethod
-    def setUpClass(cls):
-        import pytorch_openreg  # noqa: F401
-
-        torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
-        cls.module = torch.utils.cpp_extension.load(
-            name="custom_device_extension",
-            sources=[
-                f"{'test/' if not os.getcwd().endswith('test') else ''}cpp_extensions/open_registration_extension.cpp",
-            ],
-            extra_include_paths=["cpp_extensions"],
-            extra_cflags=["-g"],
-            verbose=True,
-        )
-
-    @skipIfTorchDynamo()
-    def test_fused_sdp_choice_privateuseone(self):
-        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
-        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
-        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
-        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
-        q_privateuse1 = q_cpu.to("openreg")
-        k_privateuse1 = k_cpu.to("openreg")
-        v_privateuse1 = v_cpu.to("openreg")
-        assert torch._fused_sdp_choice(q_privateuse1, k_privateuse1, v_privateuse1) == SDPBackend.OVERRIDEABLE.value
-
-    def test_scaled_dot_product_fused_attention_overrideable(self):
-        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
-        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
-        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
-        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
-        q_privateuse1 = q_cpu.to("openreg")
-        k_privateuse1 = k_cpu.to("openreg")
-        v_privateuse1 = v_cpu.to("openreg")
-        torch.nn.functional.scaled_dot_product_attention(
-            q_privateuse1, k_privateuse1, v_privateuse1, attn_mask=None, dropout_p=0.0)
-
-    def test_scaled_dot_product_fused_attention_overrideable_backward(self):
-        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
-        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16, requires_grad=True)
-        shape = (batch_size, num_heads, seq_len, head_dim)
-        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
-        attn_mask = make_tensor((batch_size, num_heads, seq_len, seq_len))
-        q_privateuse1 = q_cpu.to("openreg")
-        k_privateuse1 = k_cpu.to("openreg")
-        v_privateuse1 = v_cpu.to("openreg")
-        attn_mask_privateuse1 = attn_mask.to("openreg")
-        output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = \
-            torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
-                q_privateuse1, k_privateuse1, v_privateuse1, attn_bias=attn_mask_privateuse1)
-
-        rand_upward = torch.rand(shape, device="cpu", dtype=torch.float16, requires_grad=False)
-        rand_upward_privateuse1 = rand_upward.to("openreg")
-        grad_input_mask = [True, True, True, True]
-        grad_q, grad_k, grad_v, grad_attn_mask = torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
-            rand_upward_privateuse1, q_privateuse1, k_privateuse1, v_privateuse1, attn_mask_privateuse1,
-            grad_input_mask, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p=0.0,
-            is_causal=False, philox_seed=philox_seed, philox_offset=philox_offset)
-
 if NOTEST_CPU:
     device_types = ("cuda", )
 else:
test/test_transformers_privateuse1.py (new file, 124 lines)
@@ -0,0 +1,124 @@
+# Owner(s): ["module: sdpa"]
+
+import os
+import unittest
+from collections import namedtuple
+from functools import partial
+
+import pytorch_openreg  # noqa: F401
+
+import torch
+import torch.utils.cpp_extension
+from torch.nn.attention import SDPBackend
+from torch.testing._internal.common_nn import NNTestCase
+from torch.testing._internal.common_utils import (
+    IS_FBCODE,
+    run_tests,
+    skipIfTorchDynamo,
+    TEST_XPU,
+)
+
+
+SdpaShape = namedtuple("Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"])
+
+
+@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently")
+@unittest.skipIf(
+    IS_FBCODE,
+    "Ninja is required to load C++ extensions and it's not compatible with Buck ",
+)
+class TestSDPAPrivateUse1Only(NNTestCase):
+    @classmethod
+    def setUpClass(cls):
+        torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
+        cls.module = torch.utils.cpp_extension.load(
+            name="custom_device_extension",
+            sources=[
+                f"{'test/' if not os.getcwd().endswith('test') else ''}cpp_extensions/open_registration_extension.cpp",
+            ],
+            extra_include_paths=["cpp_extensions"],
+            extra_cflags=["-g"],
+            verbose=True,
+        )
+
+    @skipIfTorchDynamo()
+    def test_fused_sdp_choice_privateuseone(self):
+        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
+        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
+        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
+        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
+        q_privateuse1 = q_cpu.to("openreg")
+        k_privateuse1 = k_cpu.to("openreg")
+        v_privateuse1 = v_cpu.to("openreg")
+        assert (
+            torch._fused_sdp_choice(q_privateuse1, k_privateuse1, v_privateuse1)
+            == SDPBackend.OVERRIDEABLE.value
+        )
+
+    def test_scaled_dot_product_fused_attention_overrideable(self):
+        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
+        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
+        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
+        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
+        q_privateuse1 = q_cpu.to("openreg")
+        k_privateuse1 = k_cpu.to("openreg")
+        v_privateuse1 = v_cpu.to("openreg")
+        torch.nn.functional.scaled_dot_product_attention(
+            q_privateuse1, k_privateuse1, v_privateuse1, attn_mask=None, dropout_p=0.0
+        )
+
+    def test_scaled_dot_product_fused_attention_overrideable_backward(self):
+        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
+        make_tensor = partial(
+            torch.rand, device="cpu", dtype=torch.float16, requires_grad=True
+        )
+        shape = (batch_size, num_heads, seq_len, head_dim)
+        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
+        attn_mask = make_tensor((batch_size, num_heads, seq_len, seq_len))
+        q_privateuse1 = q_cpu.to("openreg")
+        k_privateuse1 = k_cpu.to("openreg")
+        v_privateuse1 = v_cpu.to("openreg")
+        attn_mask_privateuse1 = attn_mask.to("openreg")
+        (
+            output,
+            logsumexp,
+            cum_seq_q,
+            cum_seq_k,
+            max_q,
+            max_k,
+            philox_seed,
+            philox_offset,
+            debug_attn_mask,
+        ) = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
+            q_privateuse1, k_privateuse1, v_privateuse1, attn_bias=attn_mask_privateuse1
+        )
+
+        rand_upward = torch.rand(
+            shape, device="cpu", dtype=torch.float16, requires_grad=False
+        )
+        rand_upward_privateuse1 = rand_upward.to("openreg")
+        grad_input_mask = [True, True, True, True]
+        grad_q, grad_k, grad_v, grad_attn_mask = (
+            torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
+                rand_upward_privateuse1,
+                q_privateuse1,
+                k_privateuse1,
+                v_privateuse1,
+                attn_mask_privateuse1,
+                grad_input_mask,
+                output,
+                logsumexp,
+                cum_seq_q,
+                cum_seq_k,
+                max_q,
+                max_k,
+                dropout_p=0.0,
+                is_causal=False,
+                philox_seed=philox_seed,
+                philox_offset=philox_offset,
+            )
+        )
+
+
+if __name__ == "__main__":
+    run_tests()