pytorch/test/test_transformers.py
Joel Benjamin Schlosser f15c5bf133 Initial private SDP interface and naive composite impl (#81956)
Adds an initial private API version of the SDP interface.

Signature:
```
_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None,
    float dropout_p=0.0, bool need_attn_weights=True, bool is_causal=False) -> (Tensor, Tensor)
```

Returns a tuple of `(output, attn_weights)`.
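For illustration only, a minimal sketch of a call (arbitrary shapes; this mirrors how the test in this file invokes the private op via `torch.ops.aten`):

```python
import torch

# Arbitrary small shapes: batch N, target length L, source length S, embed dim E.
N, L, S, E = 2, 4, 3, 8
query = torch.randn(N, L, E)
key = torch.randn(N, S, E)
value = torch.randn(N, S, E)

# Positional args follow the signature above:
# (query, key, value, attn_mask, dropout_p, need_attn_weights, is_causal)
output, attn_weights = torch.ops.aten._scaled_dot_product_attention(
    query, key, value, None, 0.0, True, False)
# output has shape (N, L, E); attn_weights is expected to be (N, L, S) since need_attn_weights=True.
```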

Note the following:
* `need_attn_weights`: flag indicating that attention weights should be computed. This is useful to toggle off for flash attention, which does not materialize the weights by default and would incur extra cost to return them.
* Boolean attention mask support only; `True` values within `attn_mask` indicate that the element should take part in attention (notably, this is the reverse of MHA, which uses `True` to mask *out* values; see the sketch after this list). The mask is optional.
* `is_causal`: Temporary flag indicating whether to use a causal attention weighting. If this is set to `True`, it takes precedence over any value passed in for `attn_mask`. Longer term, the `is_causal` flag can be subsumed into the `attn_mask` arg via tensor subclassing (see e.g. [CausalTensor](https://github.com/facebookresearch/xformers/blob/sparse_cleanup/xformers/sparse/causal_tensor.py) in xFormers).
* Testing is currently done by comparing against the existing Python impl of `F._scaled_dot_product_attention`.
* This PR does not yet drop in the new SDP anywhere; a future PR can hook it up in BT or MHA.
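As a rough sketch of the mask convention (assumed shapes; this mirrors the conversion the test performs before calling the Python reference impl), a boolean mask for the new op maps onto an additive float mask as follows:

```python
import torch

L, S = 4, 3  # target and source lengths (arbitrary)
# Boolean mask for the new op: True means "attend", False means "mask out".
bool_mask = torch.randint(0, 2, (L, S), dtype=torch.bool)

# Equivalent additive float mask for the existing Python impl (and MHA-style masks):
# masked-out positions contribute -inf to the attention logits before softmax.
float_mask = torch.zeros(L, S)
float_mask.masked_fill_(bool_mask.logical_not(), float("-inf"))
```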
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81956
Approved by: https://github.com/drisspg, https://github.com/erichan1
2022-07-27 15:41:45 +00:00


# Owner(s): ["module: nn"]

import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F
import unittest

from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    TEST_FAIRSEQ, run_tests, parametrize, instantiate_parametrized_tests, freeze_rng_state)
from torch.testing._internal.common_cuda import TEST_CUDA

if TEST_FAIRSEQ:
    import fairseq.models.transformer as fairseq_transformer


@contextlib.contextmanager
def set_default_dtype(dtype):
    saved_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(saved_dtype)

class TestTransformers(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    device_list = ['cpu']  # TODO: is there a way to do parametrize for this?
    if TEST_CUDA:
        device_list.append('cuda')

    @unittest.skip("4D mask not supported yet - activate when 4D mask supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")  # TODO: make this work for both cuda and cpu
    def test_self_attn_TxT_attn_mask(self):
        embed_dim = 16
        num_heads = 4
        batch_size = 10
        tgt_len = 16

        query = torch.rand(batch_size, tgt_len, embed_dim, device="cuda")  # [N, T, D]
        attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float()  # [T, T]
        attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0))

        attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)

        mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda()
        mta_model.eval()

        # Generate 3D results
        with torch.inference_mode():
            output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
            output_mask_4d = output_mask_4d.transpose(0, 1)  # [N, T, D]

            output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
            output_mask_TxT = output_mask_TxT.transpose(0, 1)  # [N, T, D]

            self.assertEqual(output_mask_4d, output_mask_TxT)
@parametrize("device", device_list)
def test_transformerencoderlayer_src_mask(self, device):
batch_size = 2
seqlen = 4
d_model = 8
nhead = 8
dim_feedforward = 32
model = torch.nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
batch_first=True).to(device)
src = torch.rand(batch_size, seqlen, d_model).to(device) # bs, seqlen, d_model
src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
model(src, src_mask=src_mask)
model.eval()
with torch.no_grad():
model(src, src_mask=src_mask)
@parametrize("use_torchscript", [True, False])
@parametrize("with_no_grad", [True, False])
@parametrize("training", [True, False])
def test_transformerencoder_fastpath_torchscript(self, use_torchscript, with_no_grad, training):
"""
Test TransformerEncoder does not crash
"""
model = torch.nn.TransformerEncoder(
torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True),
num_layers=2,
enable_nested_tensor=True
)
if training:
model = model.train()
else:
model = model.eval()
if use_torchscript:
model = torch.jit.script(model)
x = torch.Tensor([[[1, 2], [3, 4]]]).to(torch.float)
mask = torch.Tensor([[0, 1]]).to(torch.bool)
if with_no_grad:
cm = torch.no_grad()
else:
cm = contextlib.nullcontext()
with cm:
model(x, src_key_padding_mask=mask)
@parametrize("with_no_grad", [True, False])
@parametrize("training", [True, False])
@parametrize("enable_nested_tensor", [False])
@parametrize("device", device_list)
def test_transformerencoder_square_input(self, with_no_grad, training, enable_nested_tensor, device):
"""
Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has
batch size == sequence length
"""
model = torch.nn.TransformerEncoder(
torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True),
num_layers=2,
enable_nested_tensor=enable_nested_tensor
).to(device)
with torch.no_grad():
# set constant weights of the model
for idx, p in enumerate(model.parameters()):
x = p.data
sz = x.view(-1).size(0)
shape = x.shape
x = torch.cos(torch.arange(0, sz).float().view(shape))
p.data.copy_(x)
if training:
model = model.train()
else:
model = model.eval()
x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.float).to(device)
src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device)
if with_no_grad:
cm = torch.no_grad()
else:
cm = contextlib.nullcontext()
with cm:
result = model(x, mask=src_mask)
ref_output = torch.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351],
[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]],
[[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689],
[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]]
).to(device)
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
@parametrize("batch_first", [True, False])
@parametrize("training", [True, False])
@parametrize("enable_nested_tensor", [True, False])
@parametrize("device", device_list)
def test_transformerencoder(self, batch_first, training, enable_nested_tensor, device):
def get_a_test_layer(activation, batch_first=False):
d_model = 4
nhead = 2
dim_feedforward = 16
dropout = 0.0
layer = nn.TransformerEncoderLayer(
d_model,
nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
batch_first=batch_first,
).to(device)
with torch.no_grad():
# set constant weights of the model
for idx, p in enumerate(layer.parameters()):
x = p.data
sz = x.view(-1).size(0)
shape = x.shape
x = torch.cos(torch.arange(0, sz).float().view(shape))
p.data.copy_(x)
return layer
# this is a deterministic test for TransformerEncoder
activation = F.relu
def _test(batch_first, training, enable_nested_tensor):
def perm_fn(x):
return x.transpose(1, 0) if batch_first else x
encoder_layer = get_a_test_layer(activation=activation,
batch_first=batch_first)
model = nn.TransformerEncoder(encoder_layer, 1).to(device)
if not training:
model = model.eval()
# deterministic input
encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
[0.5387, 0.1655, 0.3565, 0.0471]],
[[0.8335, 0.2799, 0.5031, 0.2947],
[0.1402, 0.0318, 0.7636, 0.1346]],
[[0.6333, 0.9344, 0.1376, 0.9938],
[0.8924, 0.2872, 0.6692, 0.2944]],
[[0.9897, 0.6915, 0.3154, 0.1733],
[0.8645, 0.3513, 0.3064, 0.0767]],
[[0.8117, 0.2366, 0.4838, 0.7881],
[0.3718, 0.4945, 0.9511, 0.0864]]]
)).to(device)
result = model(encoder_input)
ref_output = perm_fn(torch.tensor([[[2.428589, 0.020835, -0.602055, -0.085249],
[2.427987, 0.021213, -0.602496, -0.084103]],
[[2.424689, 0.019155, -0.604793, -0.085672],
[2.413863, 0.022211, -0.612486, -0.072490]],
[[2.433774, 0.021598, -0.598343, -0.087548],
[2.425104, 0.019748, -0.604515, -0.084839]],
[[2.436185, 0.022682, -0.596625, -0.087261],
[2.433556, 0.021891, -0.598509, -0.086832]],
[[2.416246, 0.017512, -0.610712, -0.082961],
[2.422901, 0.024187, -0.606178, -0.074929]]]
)).to(device)
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
# all 0 src_mask
src_mask = torch.zeros([5, 5]).to(device) == 1
result = model(encoder_input, mask=src_mask)
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
# all 0
mask = torch.zeros([2, 5]).to(device) == 1
result = model(encoder_input, src_key_padding_mask=mask)
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
mask[0, 1] = 1
mask[1, 3] = 1
mask[1, 4] = 1
# If mask is not left aligned
# We disable nested tensor
model.enable_nested_tensor = enable_nested_tensor
result = model(encoder_input, src_key_padding_mask=mask)
ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642],
[2.428811, 0.021445, -0.601912, -0.084252]],
[[2.425009, 0.019155, -0.604566, -0.085899],
[2.415408, 0.02249, -0.611415, -0.073]],
[[2.434199, 0.021682, -0.598039, -0.087699],
[2.42598, 0.019941, -0.603896, -0.085091]],
[[2.436457, 0.022736, -0.59643, -0.08736],
[2.434021, 0.022093, -0.598179, -0.08679]],
[[2.416531, 0.017498, -0.610513, -0.083181],
[2.4242, 0.024653, -0.605266, -0.074959]]]
)).to(device)
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
            # test case 2, multiple layers no norm
            model = nn.TransformerEncoder(encoder_layer, 2, enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[2.419051, 0.017446, -0.608738, -0.085003],
                                                [2.419102, 0.017452, -0.608703, -0.085026]],
                                               [[2.419043, 0.017445, -0.608744, -0.084999],
                                                [2.419052, 0.017446, -0.608738, -0.085004]],
                                               [[2.419067, 0.017448, -0.608727, -0.085010],
                                                [2.419098, 0.017452, -0.608706, -0.085024]],
                                               [[2.419072, 0.017449, -0.608724, -0.085012],
                                                [2.419119, 0.017455, -0.608691, -0.085034]],
                                               [[2.419019, 0.017442, -0.608761, -0.084989],
                                                [2.419075, 0.017449, -0.608722, -0.085014]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            model = nn.TransformerEncoder(encoder_layer, 6, enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # test case 3, multiple layers with norm
            # d_model = 4
            norm = nn.LayerNorm(4)
            model = nn.TransformerEncoder(encoder_layer, 2, norm=norm, enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[1.695949, -0.357635, -0.893077, -0.445238],
                                                [1.695955, -0.357639, -0.893050, -0.445266]],
                                               [[1.695948, -0.357634, -0.893082, -0.445233],
                                                [1.695950, -0.357635, -0.893077, -0.445238]],
                                               [[1.695951, -0.357636, -0.893069, -0.445246],
                                                [1.695955, -0.357639, -0.893052, -0.445264]],
                                               [[1.695952, -0.357636, -0.893066, -0.445249],
                                                [1.695957, -0.357641, -0.893041, -0.445276]],
                                               [[1.695946, -0.357632, -0.893095, -0.445220],
                                                [1.695952, -0.357637, -0.893065, -0.445251]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            model = nn.TransformerEncoder(encoder_layer, 6, norm=norm, enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

        # TODO: remove set default dtype to double by making ref_output more precise.
        # Added because this test was copied from test_nn.py, which has default
        # dtype double. If default dtype is float, tests will say tensors not close because
        # ref output precision too low
        with set_default_dtype(torch.double):
            if training:
                cm = contextlib.nullcontext()
            else:
                cm = torch.no_grad()  # transformer fast path requires no grad
            with cm:
                _test(batch_first, training, enable_nested_tensor)
    @unittest.skipIf(not TEST_FAIRSEQ, "Fairseq not found")
    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    def test_decoder_only_layer(self):
        DEFAULT_PADDING_IDX = 0

        class FairseqDecoder(torch.nn.Module):
            def __init__(
                self,
                embed_dim,
                attention_heads,
                ffn_embed_dim,
                num_layers,
                embedding_layer,  # torch.nn.Embedding. Must have a padding_idx field
                dropout=0,
                normalize_before=False,
                torch_encoder=None,  # torch encoder that you can map weights from
                activation="relu",
            ):
                super().__init__()

                cfg = fairseq_transformer.TransformerConfig()
                cfg.decoder.embed_dim = embed_dim
                cfg.decoder.output_dim = embed_dim
                cfg.decoder.attention_heads = attention_heads
                cfg.decoder.ffn_embed_dim = ffn_embed_dim
                cfg.dropout = dropout
                cfg.decoder.normalize_before = normalize_before
                cfg.decoder.layers = num_layers
                # make embedding behavior same as other encoders
                cfg.no_token_positional_embeddings = True
                cfg.no_scale_embedding = True
                cfg.activation_fn = activation

                dictionary = {}  # TODO: verify what this is

                self.decoder = fairseq_transformer.TransformerDecoder(
                    cfg,
                    dictionary,
                    embedding_layer,
                    no_encoder_attn=True,
                    output_projection=None,
                )

                if torch_encoder is not None:
                    self.decoder = torch_to_fairseq(torch_encoder, self.decoder)
                self.decoder = self.decoder.eval().cuda().half()

            def forward(
                self,
                tokens,
                src_lengths=None,
                with_triangle_mask=False,
                incremental_state=None,
            ):
                return self.decoder(
                    prev_output_tokens=tokens,
                    encoder_out=None,
                    incremental_state=incremental_state,
                    features_only=True,
                    full_context_alignment=not with_triangle_mask,
                    alignment_layer=None,
                    alignment_heads=None,
                    src_lengths=src_lengths,
                    return_all_hiddens=False,
                )[0]
        class BetterDecoder(torch.nn.Module):
            """
            Only incremental decoder for now
            """

            def __init__(self, transformer, embedding, pad_idx):
                super().__init__()
                self.transformer = transformer
                self.embedding = embedding
                self.padding_idx = pad_idx

            def forward(
                self,
                x,
                src_mask=None,
                include_padding_mask=True,
                incr_key_lst=None,
                incr_value_lst=None,
                is_incremental_decoding=False,
            ):
                padding_mask = None
                if not x.is_nested and include_padding_mask:
                    padding_mask = x.eq(self.padding_idx)
                if(is_incremental_decoding):
                    x = x[:, -1:]  # only take the last token
                x = self.embedding(x)

                one_encoder_layer = self.transformer.layers[0]
                self_attn = one_encoder_layer.self_attn
                embed_dim = self_attn.embed_dim
                num_heads = self_attn.num_heads

                use_gelu = (
                    one_encoder_layer.activation_relu_or_gelu == 2
                )  # see torch/nn/modules/activation attention impl. 1 == relu, 2 == gelu
                assert (
                    one_encoder_layer.activation_relu_or_gelu != 0
                )  # 0 == not relu or gelu

                norm_first = one_encoder_layer.norm_first

                # TODO: make this a bit less janky. but for now we initialize with an empty tensor.
                if(not is_incremental_decoding):
                    assert len(incr_key_lst) == 0 or incr_key_lst[0] is None
                    assert len(incr_value_lst) == 0 or incr_value_lst[0] is None
                while len(incr_key_lst) <= len(self.transformer.layers):
                    if(is_incremental_decoding):
                        incr_key_lst.append(torch.Tensor([]).cuda().half())
                        incr_value_lst.append(torch.Tensor([]).cuda().half())
                    else:
                        incr_key_lst.append(None)
                        incr_value_lst.append(None)

                for i, layer in enumerate(self.transformer.layers):
                    incr_key = incr_key_lst[i]
                    incr_value = incr_value_lst[i]

                    x, incr_key, incr_value = torch._transformer_decoder_only_layer_fwd(
                        src=x,
                        embed_dim=embed_dim,
                        num_heads=num_heads,
                        qkv_weight=layer.self_attn.in_proj_weight,
                        qkv_bias=layer.self_attn.in_proj_bias,
                        proj_weight=layer.self_attn.out_proj.weight,
                        proj_bias=layer.self_attn.out_proj.bias,
                        use_gelu=use_gelu,
                        norm_first=norm_first,
                        # TODO: layer_norm_eps hardcoded to be same as nn.TransformerEncoder default.
                        # fix by pulling from self_attn.norm1
                        eps=1e-5,
                        norm_weight_1=layer.norm1.weight,
                        norm_bias_1=layer.norm1.bias,
                        norm_weight_2=layer.norm2.weight,
                        norm_bias_2=layer.norm2.bias,
                        ffn_weight_1=layer.linear1.weight,
                        ffn_bias_1=layer.linear1.bias,
                        ffn_weight_2=layer.linear2.weight,
                        ffn_bias_2=layer.linear2.bias,
                        mask=src_mask,
                        incr_key=incr_key,  # altered in place
                        incr_value=incr_value,
                    )

                    # not in-place
                    if(not is_incremental_decoding):
                        incr_key = None
                        incr_value = None

                    incr_key_lst[i] = incr_key
                    incr_value_lst[i] = incr_value

                return x, incr_key_lst, incr_value_lst
        def torch_to_fairseq(torch_encoder, fairseq_encoder):
            for src_layer, dst_layer in zip(torch_encoder.layers, fairseq_encoder.layers):
                w_q, w_k, w_v = src_layer.self_attn.in_proj_weight.chunk(3, dim=0)
                b_q, b_k, b_v = src_layer.self_attn.in_proj_bias.chunk(3, dim=0)

                dst_layer.self_attn.q_proj.weight = torch.nn.Parameter(w_q)
                dst_layer.self_attn.q_proj.bias = torch.nn.Parameter(b_q)
                dst_layer.self_attn.k_proj.weight = torch.nn.Parameter(w_k)
                dst_layer.self_attn.k_proj.bias = torch.nn.Parameter(b_k)
                dst_layer.self_attn.v_proj.weight = torch.nn.Parameter(w_v)
                dst_layer.self_attn.v_proj.bias = torch.nn.Parameter(b_v)

                dst_layer.self_attn.out_proj.weight = src_layer.self_attn.out_proj.weight
                dst_layer.self_attn.out_proj.bias = src_layer.self_attn.out_proj.bias

                dst_layer.fc1.weight = src_layer.linear1.weight
                dst_layer.fc1.bias = src_layer.linear1.bias

                # fairseq may use fusedlayernorm from nvidia apex - diff properties
                dst_layer.self_attn_layer_norm.load_state_dict(src_layer.norm1.state_dict())

                dst_layer.fc2.weight = src_layer.linear2.weight
                dst_layer.fc2.bias = src_layer.linear2.bias

                dst_layer.final_layer_norm.load_state_dict(src_layer.norm2.state_dict())

            return fairseq_encoder

        def set_weights_deterministic(model):
            for idx, p in enumerate(model.parameters()):
                x = p.data
                sz = x.view(-1).size(0)
                shape = x.shape
                x = torch.cos(torch.arange(0, sz).float().view(shape))
                p.data.copy_(x)
        D = 4  # d_model
        H = 2  # nhead
        FD = 16  # dim_feedforward
        V = 100  # vocab size
        L = 2  # num layers

        embedding_layer = torch.nn.Embedding(V, D, DEFAULT_PADDING_IDX)
        layer = torch.nn.TransformerEncoderLayer(
            d_model=D,
            nhead=H,
            dim_feedforward=FD,
            batch_first=True,
            activation="gelu",
        )
        transformer = torch.nn.TransformerEncoder(
            layer,
            num_layers=L,
        ).eval().cuda().half()

        set_weights_deterministic(embedding_layer)
        set_weights_deterministic(transformer)

        better_decoder = (
            BetterDecoder(transformer, embedding_layer, DEFAULT_PADDING_IDX)
            .eval()
            .cuda()
            .half()
        )
        fairseq_decoder = (
            FairseqDecoder(
                D,
                H,
                FD,
                L,
                embedding_layer,
                dropout=0,
                normalize_before=False,
                torch_encoder=transformer,
                activation="gelu",
            )
            .eval()
            .cuda()
            .half()
        )

        tokens = torch.Tensor([
            [5, 6, 7, 8],
            [9, 10, 11, 12]
        ]).to(torch.int).cuda()
        lengths_tensor = torch.Tensor([2, 2]).to(torch.int).cuda()
        # bs = 2, seqlen = 4
        bs, seqlen = tokens.shape

        upper_triangle = torch.zeros(seqlen, seqlen)
        upper_triangle.fill_(-100000000)
        upper_triangle = torch.triu(upper_triangle, 1)
        upper_triangle = upper_triangle.cuda().half()
        upper_triangle_expanded = upper_triangle.unsqueeze(0).unsqueeze(0)
        upper_triangle_expanded = upper_triangle_expanded.expand(
            bs, H, -1, -1
        )

        # test forced decoding
        with torch.no_grad():
            result, _, _ = better_decoder(
                tokens,
                src_mask=upper_triangle_expanded,
                include_padding_mask=False,
                incr_key_lst=[],
                incr_value_lst=[],
                is_incremental_decoding=False,
            )
        ref_output = fairseq_decoder(tokens, lengths_tensor, with_triangle_mask=True)

        self.assertEqual(result.shape, ref_output.shape)
        torch.testing.assert_close(result, ref_output, atol=1e-3, rtol=1e-2)

        # test incremental decoding
        bs, seqlen = tokens.shape

        incr_state = {}
        ref_outputs = [fairseq_decoder(
            tokens[:, :i],
            src_lengths=None,
            with_triangle_mask=False,
            incremental_state=incr_state,
        ) for i in range(1, seqlen + 1)]
        ref_output = torch.stack(ref_outputs)

        incr_key_lst = []
        incr_value_lst = []
        results = []
        for i in range(1, seqlen + 1):
            res, incr_key_lst, incr_value_lst = better_decoder(
                tokens[:, :i],
                src_mask=None,
                include_padding_mask=False,
                incr_key_lst=incr_key_lst,
                incr_value_lst=incr_value_lst,
                is_incremental_decoding=True,
            )
            results.append(res)
        result = torch.stack(results)

        self.assertEqual(result.shape, ref_output.shape)
        torch.testing.assert_close(result, ref_output, atol=1e-3, rtol=1e-2)
@parametrize("attn_mask_dim,is_causal",
[(None, False), (2, False), (2, True), (3, False), (3, True)],
name_fn=lambda dim, is_causal: (f"{dim}D_{'causal_' if is_causal else ''}attn_mask"
if dim is not None else "no_attn_mask"))
@parametrize("dropout_p", [0.0, 0.2, 0.5])
@parametrize("device", device_list)
def test_scaled_dot_product_attention(self, device, attn_mask_dim, is_causal, dropout_p):
# TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used.
dtypes = [torch.double, torch.float]
for dtype in dtypes:
# This test compares python and C++ implementations of SDP.
N, L, S, E = 5, 4, 3, 6
query = torch.randn(N, L, E, device=device, dtype=dtype)
key = torch.randn(N, S, E, device=device, dtype=dtype)
value = torch.randn(N, S, E, device=device, dtype=dtype)
attn_mask = None
if attn_mask_dim is not None:
assert attn_mask_dim in [2, 3]
mask_size = (L, S) if attn_mask_dim == 2 else (N, L, S)
attn_mask = (torch.ones(mask_size, device=device, dtype=torch.bool).tril() if is_causal
else torch.randint(0, 2, size=mask_size, device=device, dtype=torch.bool))
with freeze_rng_state():
# Python impl only supports float mask.
attn_mask_float = attn_mask
if attn_mask_float is not None:
attn_mask_float = torch.zeros_like(attn_mask, dtype=query.dtype)
attn_mask_float.masked_fill_(attn_mask.logical_not(), float("-inf"))
expected = F._scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask_float, dropout_p=dropout_p)
need_attn_weights: bool = True
with freeze_rng_state():
if is_causal:
# NB: Don't pass attn_mask here
actual = torch.ops.aten._scaled_dot_product_attention(
query, key, value, None, dropout_p, need_attn_weights, is_causal)
# Error case: both explicit attn_mask and is_causal are set
with self.assertRaisesRegex(RuntimeError,
"Explicit attn_mask should not be set when is_causal=True"):
torch.ops.aten._scaled_dot_product_attention(
query, key, value, attn_mask, dropout_p, need_attn_weights, is_causal)
else:
actual = torch.ops.aten._scaled_dot_product_attention(
query, key, value, attn_mask, dropout_p, need_attn_weights, is_causal)
# freeze_rng_state() doesn't seem to work outside of CPU, so dropout makes the results incomparable.
# TODO: Do this skipping in a nicer way once the granular test skipping logic lands.
if dropout_p == 0.0 or device == 'cpu':
self.assertEqual(actual, expected)
# TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
# cross device / dtype testing.
instantiate_parametrized_tests(TestTransformers)
if __name__ == '__main__':
run_tests()