Adds an initial private-API version of the scaled dot product (SDP) attention interface.
Signature:
```
_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None,
float dropout_p=0.0, bool need_attn_weights=True, bool is_causal=False) -> (Tensor, Tensor)
```
Returns a tuple of `(output, attn_weights)`.
Note the following:
* `need_attn_weights`: flag indicating that attention weights should be computed. This is useful to toggle off for flash attention as it does not materialize the weights by default, making it more expensive to return them.
* Boolean attention mask support only; `True` values within `attn_mask` indicate that the element should take part in attention (note that this is the reverse of MHA, which uses `True` to mask *out* values). The mask is optional; a usage sketch follows these notes.
* `is_causal`: temporary flag indicating whether to apply a causal attention weighting. Passing an explicit `attn_mask` together with `is_causal=True` is an error (see the test below). Longer term, the `is_causal` flag can be subsumed into the `attn_mask` arg via tensor subclassing (see e.g. [CausalTensor](https://github.com/facebookresearch/xformers/blob/sparse_cleanup/xformers/sparse/causal_tensor.py) in xFormers).
* Testing is currently done by comparing against the existing Python impl, `F._scaled_dot_product_attention`, as a reference.
* This PR does not yet hook the new SDP up anywhere; a future PR can wire it into BT or MHA.
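For illustration, here is a minimal usage sketch of the semantics above (not part of this PR; shapes and values are arbitrary, and the call goes through `torch.ops.aten._scaled_dot_product_attention` exactly as exercised in the test file below):
```
import torch

N, L, S, E = 2, 4, 3, 8  # batch, target len, source len, embed dim (arbitrary)
query = torch.randn(N, L, E)
key = torch.randn(N, S, E)
value = torch.randn(N, S, E)

# Boolean mask: True means the element takes part in attention (reverse of MHA).
attn_mask = torch.randint(0, 2, (L, S), dtype=torch.bool)
output, attn_weights = torch.ops.aten._scaled_dot_product_attention(
    query, key, value, attn_mask, 0.0, True, False)

# Causal weighting: pass attn_mask=None; combining an explicit attn_mask
# with is_causal=True raises a RuntimeError.
output, attn_weights = torch.ops.aten._scaled_dot_product_attention(
    query, key, value, None, 0.0, True, True)
```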
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81956
Approved by: https://github.com/drisspg, https://github.com/erichan1
# Owner(s): ["module: nn"]

import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F
import unittest

from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    TEST_FAIRSEQ, run_tests, parametrize, instantiate_parametrized_tests, freeze_rng_state)
from torch.testing._internal.common_cuda import TEST_CUDA

if TEST_FAIRSEQ:
    import fairseq.models.transformer as fairseq_transformer

@contextlib.contextmanager
def set_default_dtype(dtype):
    saved_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(saved_dtype)

class TestTransformers(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    device_list = ['cpu']  # TODO: is there a way to do parametrize for this?
    if TEST_CUDA:
        device_list.append('cuda')

    @unittest.skip("4D mask not supported yet - activate when 4D mask supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")  # TODO: make this work for both cuda and cpu
    def test_self_attn_TxT_attn_mask(self):
        embed_dim = 16
        num_heads = 4
        batch_size = 10
        tgt_len = 16

        query = torch.rand(batch_size, tgt_len, embed_dim, device="cuda")  # [N, T, D]
        attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float()  # [T, T]
        attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0))

        attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)

        mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda()
        mta_model.eval()

        # Generate 3D results
        with torch.inference_mode():
            output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
            output_mask_4d = output_mask_4d.transpose(0, 1)  # [N, T, D]

            output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
            output_mask_TxT = output_mask_TxT.transpose(0, 1)  # [N, T, D]

            self.assertEqual(output_mask_4d, output_mask_TxT)

    @parametrize("device", device_list)
    def test_transformerencoderlayer_src_mask(self, device):
        batch_size = 2
        seqlen = 4
        d_model = 8
        nhead = 8
        dim_feedforward = 32

        model = torch.nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True).to(device)
        src = torch.rand(batch_size, seqlen, d_model).to(device)  # bs, seqlen, d_model
        src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)

        model(src, src_mask=src_mask)
        model.eval()
        with torch.no_grad():
            model(src, src_mask=src_mask)

    @parametrize("use_torchscript", [True, False])
    @parametrize("with_no_grad", [True, False])
    @parametrize("training", [True, False])
    def test_transformerencoder_fastpath_torchscript(self, use_torchscript, with_no_grad, training):
        """
        Test TransformerEncoder does not crash
        """
        model = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True),
            num_layers=2,
            enable_nested_tensor=True
        )

        if training:
            model = model.train()
        else:
            model = model.eval()

        if use_torchscript:
            model = torch.jit.script(model)

        x = torch.Tensor([[[1, 2], [3, 4]]]).to(torch.float)
        mask = torch.Tensor([[0, 1]]).to(torch.bool)

        if with_no_grad:
            cm = torch.no_grad()
        else:
            cm = contextlib.nullcontext()
        with cm:
            model(x, src_key_padding_mask=mask)

    @parametrize("with_no_grad", [True, False])
    @parametrize("training", [True, False])
    @parametrize("enable_nested_tensor", [False])
    @parametrize("device", device_list)
    def test_transformerencoder_square_input(self, with_no_grad, training, enable_nested_tensor, device):
        """
        Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has
        batch size == sequence length
        """
        model = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True),
            num_layers=2,
            enable_nested_tensor=enable_nested_tensor
        ).to(device)

        with torch.no_grad():
            # set constant weights of the model
            for idx, p in enumerate(model.parameters()):
                x = p.data
                sz = x.view(-1).size(0)
                shape = x.shape
                x = torch.cos(torch.arange(0, sz).float().view(shape))
                p.data.copy_(x)

        if training:
            model = model.train()
        else:
            model = model.eval()
        x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.float).to(device)
        src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device)

        if with_no_grad:
            cm = torch.no_grad()
        else:
            cm = contextlib.nullcontext()
        with cm:
            result = model(x, mask=src_mask)

        ref_output = torch.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351],
                                    [2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]],
                                   [[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689],
                                    [2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]]
                                  ).to(device)
        self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
        torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

    @parametrize("batch_first", [True, False])
    @parametrize("training", [True, False])
    @parametrize("enable_nested_tensor", [True, False])
    @parametrize("device", device_list)
    def test_transformerencoder(self, batch_first, training, enable_nested_tensor, device):
        def get_a_test_layer(activation, batch_first=False):
            d_model = 4
            nhead = 2
            dim_feedforward = 16
            dropout = 0.0

            layer = nn.TransformerEncoderLayer(
                d_model,
                nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                batch_first=batch_first,
            ).to(device)

            with torch.no_grad():
                # set constant weights of the model
                for idx, p in enumerate(layer.parameters()):
                    x = p.data
                    sz = x.view(-1).size(0)
                    shape = x.shape
                    x = torch.cos(torch.arange(0, sz).float().view(shape))
                    p.data.copy_(x)

            return layer

        # this is a deterministic test for TransformerEncoder
        activation = F.relu

        def _test(batch_first, training, enable_nested_tensor):
            def perm_fn(x):
                return x.transpose(1, 0) if batch_first else x

            encoder_layer = get_a_test_layer(activation=activation,
                                             batch_first=batch_first)

            model = nn.TransformerEncoder(encoder_layer, 1).to(device)
            if not training:
                model = model.eval()

            # deterministic input
            encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
                                                   [0.5387, 0.1655, 0.3565, 0.0471]],
                                                  [[0.8335, 0.2799, 0.5031, 0.2947],
                                                   [0.1402, 0.0318, 0.7636, 0.1346]],
                                                  [[0.6333, 0.9344, 0.1376, 0.9938],
                                                   [0.8924, 0.2872, 0.6692, 0.2944]],
                                                  [[0.9897, 0.6915, 0.3154, 0.1733],
                                                   [0.8645, 0.3513, 0.3064, 0.0767]],
                                                  [[0.8117, 0.2366, 0.4838, 0.7881],
                                                   [0.3718, 0.4945, 0.9511, 0.0864]]]
                                                 )).to(device)
            result = model(encoder_input)
            ref_output = perm_fn(torch.tensor([[[2.428589, 0.020835, -0.602055, -0.085249],
                                                [2.427987, 0.021213, -0.602496, -0.084103]],
                                               [[2.424689, 0.019155, -0.604793, -0.085672],
                                                [2.413863, 0.022211, -0.612486, -0.072490]],
                                               [[2.433774, 0.021598, -0.598343, -0.087548],
                                                [2.425104, 0.019748, -0.604515, -0.084839]],
                                               [[2.436185, 0.022682, -0.596625, -0.087261],
                                                [2.433556, 0.021891, -0.598509, -0.086832]],
                                               [[2.416246, 0.017512, -0.610712, -0.082961],
                                                [2.422901, 0.024187, -0.606178, -0.074929]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # all 0 src_mask
            src_mask = torch.zeros([5, 5]).to(device) == 1
            result = model(encoder_input, mask=src_mask)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # all 0
            mask = torch.zeros([2, 5]).to(device) == 1
            result = model(encoder_input, src_key_padding_mask=mask)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            mask[0, 1] = 1
            mask[1, 3] = 1
            mask[1, 4] = 1
            # If mask is not left aligned
            # We disable nested tensor
            model.enable_nested_tensor = enable_nested_tensor
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642],
                                                [2.428811, 0.021445, -0.601912, -0.084252]],
                                               [[2.425009, 0.019155, -0.604566, -0.085899],
                                                [2.415408, 0.02249, -0.611415, -0.073]],
                                               [[2.434199, 0.021682, -0.598039, -0.087699],
                                                [2.42598, 0.019941, -0.603896, -0.085091]],
                                               [[2.436457, 0.022736, -0.59643, -0.08736],
                                                [2.434021, 0.022093, -0.598179, -0.08679]],
                                               [[2.416531, 0.017498, -0.610513, -0.083181],
                                                [2.4242, 0.024653, -0.605266, -0.074959]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # test case 2, multiple layers no norm
            model = nn.TransformerEncoder(encoder_layer, 2, enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[2.419051, 0.017446, -0.608738, -0.085003],
                                                [2.419102, 0.017452, -0.608703, -0.085026]],
                                               [[2.419043, 0.017445, -0.608744, -0.084999],
                                                [2.419052, 0.017446, -0.608738, -0.085004]],
                                               [[2.419067, 0.017448, -0.608727, -0.085010],
                                                [2.419098, 0.017452, -0.608706, -0.085024]],
                                               [[2.419072, 0.017449, -0.608724, -0.085012],
                                                [2.419119, 0.017455, -0.608691, -0.085034]],
                                               [[2.419019, 0.017442, -0.608761, -0.084989],
                                                [2.419075, 0.017449, -0.608722, -0.085014]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            model = nn.TransformerEncoder(encoder_layer, 6, enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # test case 3, multiple layers with norm
            # d_model = 4
            norm = nn.LayerNorm(4)
            model = nn.TransformerEncoder(encoder_layer, 2, norm=norm, enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[1.695949, -0.357635, -0.893077, -0.445238],
                                                [1.695955, -0.357639, -0.893050, -0.445266]],
                                               [[1.695948, -0.357634, -0.893082, -0.445233],
                                                [1.695950, -0.357635, -0.893077, -0.445238]],
                                               [[1.695951, -0.357636, -0.893069, -0.445246],
                                                [1.695955, -0.357639, -0.893052, -0.445264]],
                                               [[1.695952, -0.357636, -0.893066, -0.445249],
                                                [1.695957, -0.357641, -0.893041, -0.445276]],
                                               [[1.695946, -0.357632, -0.893095, -0.445220],
                                                [1.695952, -0.357637, -0.893065, -0.445251]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            model = nn.TransformerEncoder(encoder_layer, 6, norm=norm, enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

        # TODO: remove set default dtype to double by making ref_output more precise.
        # Added because this test was copied from test_nn.py, which has default
        # dtype double. If default dtype is float, tests will say tensors not close because
        # ref output precision too low
        with set_default_dtype(torch.double):
            if training:
                cm = contextlib.nullcontext()
            else:
                cm = torch.no_grad()  # transformer fast path requires no grad
            with cm:
                _test(batch_first, training, enable_nested_tensor)

    @unittest.skipIf(not TEST_FAIRSEQ, "Fairseq not found")
    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    def test_decoder_only_layer(self):
        DEFAULT_PADDING_IDX = 0

        class FairseqDecoder(torch.nn.Module):
            def __init__(
                self,
                embed_dim,
                attention_heads,
                ffn_embed_dim,
                num_layers,
                embedding_layer,  # torch.nn.Embedding. Must have a padding_idx field
                dropout=0,
                normalize_before=False,
                torch_encoder=None,  # torch encoder that you can map weights from
                activation="relu",
            ):
                super().__init__()

                cfg = fairseq_transformer.TransformerConfig()
                cfg.decoder.embed_dim = embed_dim
                cfg.decoder.output_dim = embed_dim
                cfg.decoder.attention_heads = attention_heads
                cfg.decoder.ffn_embed_dim = ffn_embed_dim
                cfg.dropout = dropout
                cfg.decoder.normalize_before = normalize_before
                cfg.decoder.layers = num_layers
                # make embedding behavior same as other encoders
                cfg.no_token_positional_embeddings = True
                cfg.no_scale_embedding = True
                cfg.activation_fn = activation

                dictionary = {}  # TODO: verify what this is

                self.decoder = fairseq_transformer.TransformerDecoder(
                    cfg,
                    dictionary,
                    embedding_layer,
                    no_encoder_attn=True,
                    output_projection=None,
                )

                if torch_encoder is not None:
                    self.decoder = torch_to_fairseq(torch_encoder, self.decoder)
                self.decoder = self.decoder.eval().cuda().half()

            def forward(
                self,
                tokens,
                src_lengths=None,
                with_triangle_mask=False,
                incremental_state=None,
            ):
                return self.decoder(
                    prev_output_tokens=tokens,
                    encoder_out=None,
                    incremental_state=incremental_state,
                    features_only=True,
                    full_context_alignment=not with_triangle_mask,
                    alignment_layer=None,
                    alignment_heads=None,
                    src_lengths=src_lengths,
                    return_all_hiddens=False,
                )[0]

        class BetterDecoder(torch.nn.Module):
            """
            Only incremental decoder for now
            """

            def __init__(self, transformer, embedding, pad_idx):
                super().__init__()
                self.transformer = transformer
                self.embedding = embedding
                self.padding_idx = pad_idx

            def forward(
                self,
                x,
                src_mask=None,
                include_padding_mask=True,
                incr_key_lst=None,
                incr_value_lst=None,
                is_incremental_decoding=False,
            ):
                padding_mask = None
                if not x.is_nested and include_padding_mask:
                    padding_mask = x.eq(self.padding_idx)
                if is_incremental_decoding:
                    x = x[:, -1:]  # only take the last token
                x = self.embedding(x)

                one_encoder_layer = self.transformer.layers[0]
                self_attn = one_encoder_layer.self_attn
                embed_dim = self_attn.embed_dim
                num_heads = self_attn.num_heads

                use_gelu = (
                    one_encoder_layer.activation_relu_or_gelu == 2
                )  # see torch/nn/modules/activation attention impl. 1 == relu, 2 == gelu
                assert (
                    one_encoder_layer.activation_relu_or_gelu != 0
                )  # 0 == not relu or gelu

                norm_first = one_encoder_layer.norm_first

                # TODO: make this a bit less janky. but for now we initialize with an empty tensor.
                if not is_incremental_decoding:
                    assert len(incr_key_lst) == 0 or incr_key_lst[0] is None
                    assert len(incr_value_lst) == 0 or incr_value_lst[0] is None
                while len(incr_key_lst) <= len(self.transformer.layers):
                    if is_incremental_decoding:
                        incr_key_lst.append(torch.Tensor([]).cuda().half())
                        incr_value_lst.append(torch.Tensor([]).cuda().half())
                    else:
                        incr_key_lst.append(None)
                        incr_value_lst.append(None)

                for i, layer in enumerate(self.transformer.layers):
                    incr_key = incr_key_lst[i]
                    incr_value = incr_value_lst[i]

                    x, incr_key, incr_value = torch._transformer_decoder_only_layer_fwd(
                        src=x,
                        embed_dim=embed_dim,
                        num_heads=num_heads,
                        qkv_weight=layer.self_attn.in_proj_weight,
                        qkv_bias=layer.self_attn.in_proj_bias,
                        proj_weight=layer.self_attn.out_proj.weight,
                        proj_bias=layer.self_attn.out_proj.bias,
                        use_gelu=use_gelu,
                        norm_first=norm_first,
                        # TODO: layer_norm_eps hardcoded to be same as nn.TransformerEncoder default.
                        # fix by pulling from self_attn.norm1
                        eps=1e-5,
                        norm_weight_1=layer.norm1.weight,
                        norm_bias_1=layer.norm1.bias,
                        norm_weight_2=layer.norm2.weight,
                        norm_bias_2=layer.norm2.bias,
                        ffn_weight_1=layer.linear1.weight,
                        ffn_bias_1=layer.linear1.bias,
                        ffn_weight_2=layer.linear2.weight,
                        ffn_bias_2=layer.linear2.bias,
                        mask=src_mask,
                        incr_key=incr_key,  # altered in place
                        incr_value=incr_value,
                    )

                    # not in-place
                    if not is_incremental_decoding:
                        incr_key = None
                        incr_value = None
                    incr_key_lst[i] = incr_key
                    incr_value_lst[i] = incr_value

                return x, incr_key_lst, incr_value_lst

        def torch_to_fairseq(torch_encoder, fairseq_encoder):
            for src_layer, dst_layer in zip(torch_encoder.layers, fairseq_encoder.layers):
                w_q, w_k, w_v = src_layer.self_attn.in_proj_weight.chunk(3, dim=0)
                b_q, b_k, b_v = src_layer.self_attn.in_proj_bias.chunk(3, dim=0)

                dst_layer.self_attn.q_proj.weight = torch.nn.Parameter(w_q)
                dst_layer.self_attn.q_proj.bias = torch.nn.Parameter(b_q)
                dst_layer.self_attn.k_proj.weight = torch.nn.Parameter(w_k)
                dst_layer.self_attn.k_proj.bias = torch.nn.Parameter(b_k)
                dst_layer.self_attn.v_proj.weight = torch.nn.Parameter(w_v)
                dst_layer.self_attn.v_proj.bias = torch.nn.Parameter(b_v)

                dst_layer.self_attn.out_proj.weight = src_layer.self_attn.out_proj.weight
                dst_layer.self_attn.out_proj.bias = src_layer.self_attn.out_proj.bias

                dst_layer.fc1.weight = src_layer.linear1.weight
                dst_layer.fc1.bias = src_layer.linear1.bias

                # fairseq may use fusedlayernorm from nvidia apex - diff properties
                dst_layer.self_attn_layer_norm.load_state_dict(src_layer.norm1.state_dict())

                dst_layer.fc2.weight = src_layer.linear2.weight
                dst_layer.fc2.bias = src_layer.linear2.bias

                dst_layer.final_layer_norm.load_state_dict(src_layer.norm2.state_dict())

            return fairseq_encoder

        def set_weights_deterministic(model):
            for idx, p in enumerate(model.parameters()):
                x = p.data
                sz = x.view(-1).size(0)
                shape = x.shape
                x = torch.cos(torch.arange(0, sz).float().view(shape))
                p.data.copy_(x)

        D = 4  # d_model
        H = 2  # nhead
        FD = 16  # dim_feedforward
        V = 100  # vocab size
        L = 2  # num layers

        embedding_layer = torch.nn.Embedding(V, D, DEFAULT_PADDING_IDX)
        layer = torch.nn.TransformerEncoderLayer(
            d_model=D,
            nhead=H,
            dim_feedforward=FD,
            batch_first=True,
            activation="gelu",
        )
        transformer = torch.nn.TransformerEncoder(
            layer,
            num_layers=L,
        ).eval().cuda().half()

        set_weights_deterministic(embedding_layer)
        set_weights_deterministic(transformer)

        better_decoder = (
            BetterDecoder(transformer, embedding_layer, DEFAULT_PADDING_IDX)
            .eval()
            .cuda()
            .half()
        )
        fairseq_decoder = (
            FairseqDecoder(
                D,
                H,
                FD,
                L,
                embedding_layer,
                dropout=0,
                normalize_before=False,
                torch_encoder=transformer,
                activation="gelu",
            )
            .eval()
            .cuda()
            .half()
        )

        tokens = torch.Tensor([
            [5, 6, 7, 8],
            [9, 10, 11, 12]
        ]).to(torch.int).cuda()
        lengths_tensor = torch.Tensor([2, 2]).to(torch.int).cuda()
        # bs = 2, seqlen = 4
        bs, seqlen = tokens.shape

        upper_triangle = torch.zeros(seqlen, seqlen)
        upper_triangle.fill_(-100000000)
        upper_triangle = torch.triu(upper_triangle, 1)
        upper_triangle = upper_triangle.cuda().half()
        upper_triangle_expanded = upper_triangle.unsqueeze(0).unsqueeze(0)
        upper_triangle_expanded = upper_triangle_expanded.expand(
            bs, H, -1, -1
        )

        # test forced decoding
        with torch.no_grad():
            result, _, _ = better_decoder(
                tokens,
                src_mask=upper_triangle_expanded,
                include_padding_mask=False,
                incr_key_lst=[],
                incr_value_lst=[],
                is_incremental_decoding=False,
            )
            ref_output = fairseq_decoder(tokens, lengths_tensor, with_triangle_mask=True)

        self.assertEqual(result.shape, ref_output.shape)
        torch.testing.assert_close(result, ref_output, atol=1e-3, rtol=1e-2)

        # test incremental decoding
        bs, seqlen = tokens.shape

        incr_state = {}
        ref_outputs = [fairseq_decoder(
            tokens[:, :i],
            src_lengths=None,
            with_triangle_mask=False,
            incremental_state=incr_state,
        ) for i in range(1, seqlen + 1)]
        ref_output = torch.stack(ref_outputs)

        incr_key_lst = []
        incr_value_lst = []
        results = []
        for i in range(1, seqlen + 1):
            res, incr_key_lst, incr_value_lst = better_decoder(
                tokens[:, :i],
                src_mask=None,
                include_padding_mask=False,
                incr_key_lst=incr_key_lst,
                incr_value_lst=incr_value_lst,
                is_incremental_decoding=True,
            )
            results.append(res)
        result = torch.stack(results)

        self.assertEqual(result.shape, ref_output.shape)
        torch.testing.assert_close(result, ref_output, atol=1e-3, rtol=1e-2)

    @parametrize("attn_mask_dim,is_causal",
                 [(None, False), (2, False), (2, True), (3, False), (3, True)],
                 name_fn=lambda dim, is_causal: (f"{dim}D_{'causal_' if is_causal else ''}attn_mask"
                                                 if dim is not None else "no_attn_mask"))
    @parametrize("dropout_p", [0.0, 0.2, 0.5])
    @parametrize("device", device_list)
    def test_scaled_dot_product_attention(self, device, attn_mask_dim, is_causal, dropout_p):
        # TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used.
        dtypes = [torch.double, torch.float]
        for dtype in dtypes:
            # This test compares python and C++ implementations of SDP.
            N, L, S, E = 5, 4, 3, 6
            query = torch.randn(N, L, E, device=device, dtype=dtype)
            key = torch.randn(N, S, E, device=device, dtype=dtype)
            value = torch.randn(N, S, E, device=device, dtype=dtype)

            attn_mask = None
            if attn_mask_dim is not None:
                assert attn_mask_dim in [2, 3]
                mask_size = (L, S) if attn_mask_dim == 2 else (N, L, S)
                attn_mask = (torch.ones(mask_size, device=device, dtype=torch.bool).tril() if is_causal
                             else torch.randint(0, 2, size=mask_size, device=device, dtype=torch.bool))

            with freeze_rng_state():
                # Python impl only supports float mask.
                attn_mask_float = attn_mask
                if attn_mask_float is not None:
                    attn_mask_float = torch.zeros_like(attn_mask, dtype=query.dtype)
                    attn_mask_float.masked_fill_(attn_mask.logical_not(), float("-inf"))
                expected = F._scaled_dot_product_attention(
                    query, key, value, attn_mask=attn_mask_float, dropout_p=dropout_p)

            need_attn_weights: bool = True
            with freeze_rng_state():
                if is_causal:
                    # NB: Don't pass attn_mask here
                    actual = torch.ops.aten._scaled_dot_product_attention(
                        query, key, value, None, dropout_p, need_attn_weights, is_causal)

                    # Error case: both explicit attn_mask and is_causal are set
                    with self.assertRaisesRegex(RuntimeError,
                                                "Explicit attn_mask should not be set when is_causal=True"):
                        torch.ops.aten._scaled_dot_product_attention(
                            query, key, value, attn_mask, dropout_p, need_attn_weights, is_causal)
                else:
                    actual = torch.ops.aten._scaled_dot_product_attention(
                        query, key, value, attn_mask, dropout_p, need_attn_weights, is_causal)

            # freeze_rng_state() doesn't seem to work outside of CPU, so dropout makes the results incomparable.
            # TODO: Do this skipping in a nicer way once the granular test skipping logic lands.
            if dropout_p == 0.0 or device == 'cpu':
                self.assertEqual(actual, expected)

# TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
# cross device / dtype testing.
instantiate_parametrized_tests(TestTransformers)

if __name__ == '__main__':
    run_tests()