Split test_transformers.py (#147441)

Split test_transformers.py into test_transformers.py and test_transformers_privateuse1.py. Currently, the PrivateUse1 test cases in test_transformers.py are skipped because they conflict with the CUDA test cases; moving them into their own file allows them to run separately.
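For illustration, the kind of check that moves into the new file looks roughly like this minimal sketch (an assumption-laden sketch, not the test file itself: it presumes the pytorch_openreg package that the tests import, which is expected to register the "openreg" PrivateUse1 device on import). On a PrivateUse1 device, the fused SDPA choice should resolve to the overrideable backend rather than a CUDA kernel:

from functools import partial

import pytorch_openreg  # noqa: F401  # assumed to register the "openreg" PrivateUse1 device
import torch
from torch.nn.attention import SDPBackend

# Build q/k/v on CPU, then move them to the PrivateUse1 ("openreg") device.
make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
q, k, v = (make_tensor(4, 2, 256, 128).to("openreg") for _ in range(3))

# For a PrivateUse1 device, SDPA should dispatch to the overrideable backend.
assert torch._fused_sdp_choice(q, k, v) == SDPBackend.OVERRIDEABLE.value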

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147441
Approved by: https://github.com/drisspg
Authored by Zhenbin Lin on 2025-02-26 11:54:24 +00:00; committed by PyTorch MergeBot
parent cf6d1e6824
commit 7ffae2c028
3 changed files with 126 additions and 67 deletions

test/run_test.py

@@ -467,6 +467,7 @@ S390X_TESTLIST = [
     "test_tensorexpr_pybind",
     "test_torch",
     "test_transformers",
+    "test_transformers_privateuse1",
     "test_type_hints",
     "test_type_info",
     "test_type_promotion",
@@ -1483,7 +1484,7 @@ CUSTOM_HANDLERS = {
     "test_autoload_enable": test_autoload_enable,
     "test_autoload_disable": test_autoload_disable,
     "test_cpp_extensions_open_device_registration": run_test_with_openreg,
-    "test_transformers": run_test_with_openreg,
+    "test_transformers_privateuse1": run_test_with_openreg,
 }

test/test_transformers.py

@@ -4,7 +4,6 @@ import contextlib
 from functools import partial
 from collections import namedtuple
 import sys
-import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -22,7 +21,6 @@ from typing import Optional
 import torch.utils.cpp_extension
 from torch.testing._internal.common_nn import NNTestCase
 from torch.testing._internal.common_utils import (
-    IS_FBCODE,
     TEST_WITH_ROCM,
     skipIfRocm,
     skipIfTorchDynamo,
@@ -38,7 +36,6 @@ from torch.testing._internal.common_utils import (
     NOTEST_CPU,
     IS_WINDOWS,
     TEST_WITH_TORCHDYNAMO,
-    TEST_XPU,
 )
 from torch._dynamo.testing import CompileCounterWithBackend
@@ -4006,69 +4003,6 @@ class TestAttnBias(NNTestCase):
         with self.assertRaisesRegex(ValueError, "CausalBias should not be used with causal=True"):
             scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, is_causal=True, dropout_p=0.0)
-@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently")
-@unittest.skipIf(IS_FBCODE, "Ninja is required to load C++ extensions and it's not compatible with Buck ")
-@unittest.skip("TODO: This test is broken and should be moved into a dedicated process for registering new extensions")
-class TestSDPAPrivateUse1Only(NNTestCase):
-    @classmethod
-    def setUpClass(cls):
-        import pytorch_openreg  # noqa: F401
-        torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
-        cls.module = torch.utils.cpp_extension.load(
-            name="custom_device_extension",
-            sources=[
-                f"{'test/' if not os.getcwd().endswith('test') else ''}cpp_extensions/open_registration_extension.cpp",
-            ],
-            extra_include_paths=["cpp_extensions"],
-            extra_cflags=["-g"],
-            verbose=True,
-        )
-    @skipIfTorchDynamo()
-    def test_fused_sdp_choice_privateuseone(self):
-        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
-        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
-        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
-        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
-        q_privateuse1 = q_cpu.to("openreg")
-        k_privateuse1 = k_cpu.to("openreg")
-        v_privateuse1 = v_cpu.to("openreg")
-        assert torch._fused_sdp_choice(q_privateuse1, k_privateuse1, v_privateuse1) == SDPBackend.OVERRIDEABLE.value
-    def test_scaled_dot_product_fused_attention_overrideable(self):
-        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
-        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
-        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
-        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
-        q_privateuse1 = q_cpu.to("openreg")
-        k_privateuse1 = k_cpu.to("openreg")
-        v_privateuse1 = v_cpu.to("openreg")
-        torch.nn.functional.scaled_dot_product_attention(
-            q_privateuse1, k_privateuse1, v_privateuse1, attn_mask=None, dropout_p=0.0)
-    def test_scaled_dot_product_fused_attention_overrideable_backward(self):
-        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
-        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16, requires_grad=True)
-        shape = (batch_size, num_heads, seq_len, head_dim)
-        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
-        attn_mask = make_tensor((batch_size, num_heads, seq_len, seq_len))
-        q_privateuse1 = q_cpu.to("openreg")
-        k_privateuse1 = k_cpu.to("openreg")
-        v_privateuse1 = v_cpu.to("openreg")
-        attn_mask_privateuse1 = attn_mask.to("openreg")
-        output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = \
-            torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
-                q_privateuse1, k_privateuse1, v_privateuse1, attn_bias=attn_mask_privateuse1)
-        rand_upward = torch.rand(shape, device="cpu", dtype=torch.float16, requires_grad=False)
-        rand_upward_privateuse1 = rand_upward.to("openreg")
-        grad_input_mask = [True, True, True, True]
-        grad_q, grad_k, grad_v, grad_attn_mask = torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
-            rand_upward_privateuse1, q_privateuse1, k_privateuse1, v_privateuse1, attn_mask_privateuse1,
-            grad_input_mask, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p=0.0,
-            is_causal=False, philox_seed=philox_seed, philox_offset=philox_offset)
 if NOTEST_CPU:
     device_types = ("cuda", )
 else:

test/test_transformers_privateuse1.py (new file)

@@ -0,0 +1,124 @@
# Owner(s): ["module: sdpa"]
import os
import unittest
from collections import namedtuple
from functools import partial

import pytorch_openreg  # noqa: F401
import torch
import torch.utils.cpp_extension
from torch.nn.attention import SDPBackend
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    run_tests,
    skipIfTorchDynamo,
    TEST_XPU,
)

SdpaShape = namedtuple("Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"])


@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently")
@unittest.skipIf(
    IS_FBCODE,
    "Ninja is required to load C++ extensions and it's not compatible with Buck ",
)
class TestSDPAPrivateUse1Only(NNTestCase):
    @classmethod
    def setUpClass(cls):
        torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
        cls.module = torch.utils.cpp_extension.load(
            name="custom_device_extension",
            sources=[
                f"{'test/' if not os.getcwd().endswith('test') else ''}cpp_extensions/open_registration_extension.cpp",
            ],
            extra_include_paths=["cpp_extensions"],
            extra_cflags=["-g"],
            verbose=True,
        )

    @skipIfTorchDynamo()
    def test_fused_sdp_choice_privateuseone(self):
        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
        q_privateuse1 = q_cpu.to("openreg")
        k_privateuse1 = k_cpu.to("openreg")
        v_privateuse1 = v_cpu.to("openreg")
        assert (
            torch._fused_sdp_choice(q_privateuse1, k_privateuse1, v_privateuse1)
            == SDPBackend.OVERRIDEABLE.value
        )

    def test_scaled_dot_product_fused_attention_overrideable(self):
        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
        q_privateuse1 = q_cpu.to("openreg")
        k_privateuse1 = k_cpu.to("openreg")
        v_privateuse1 = v_cpu.to("openreg")
        torch.nn.functional.scaled_dot_product_attention(
            q_privateuse1, k_privateuse1, v_privateuse1, attn_mask=None, dropout_p=0.0
        )

    def test_scaled_dot_product_fused_attention_overrideable_backward(self):
        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
        make_tensor = partial(
            torch.rand, device="cpu", dtype=torch.float16, requires_grad=True
        )
        shape = (batch_size, num_heads, seq_len, head_dim)
        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
        attn_mask = make_tensor((batch_size, num_heads, seq_len, seq_len))
        q_privateuse1 = q_cpu.to("openreg")
        k_privateuse1 = k_cpu.to("openreg")
        v_privateuse1 = v_cpu.to("openreg")
        attn_mask_privateuse1 = attn_mask.to("openreg")
        (
            output,
            logsumexp,
            cum_seq_q,
            cum_seq_k,
            max_q,
            max_k,
            philox_seed,
            philox_offset,
            debug_attn_mask,
        ) = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
            q_privateuse1, k_privateuse1, v_privateuse1, attn_bias=attn_mask_privateuse1
        )
        rand_upward = torch.rand(
            shape, device="cpu", dtype=torch.float16, requires_grad=False
        )
        rand_upward_privateuse1 = rand_upward.to("openreg")
        grad_input_mask = [True, True, True, True]
        grad_q, grad_k, grad_v, grad_attn_mask = (
            torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
                rand_upward_privateuse1,
                q_privateuse1,
                k_privateuse1,
                v_privateuse1,
                attn_mask_privateuse1,
                grad_input_mask,
                output,
                logsumexp,
                cum_seq_q,
                cum_seq_k,
                max_q,
                max_k,
                dropout_p=0.0,
                is_causal=False,
                philox_seed=philox_seed,
                philox_offset=philox_offset,
            )
        )


if __name__ == "__main__":
    run_tests()