mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fixes #120488. - The shape for the forward pass is clearly stated in the main [transformer class](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html). - The Boolean mask for `*_key_padding_mask` is also explained in the main transformer class. Therefore, explicitly add a hyperlink to the transformer class so that users can refer back to the main class. Also, change several symbols in the transformer docs from normal text style to math style. Pull Request resolved: https://github.com/pytorch/pytorch/pull/120565 Approved by: https://github.com/mikaylagawarecki
957 lines
46 KiB
Python
957 lines
46 KiB
Python
import copy
|
|
from typing import Optional, Any, Union, Callable
|
|
|
|
import torch
|
|
import warnings
|
|
from torch import Tensor
|
|
from .. import functional as F
|
|
from .module import Module
|
|
from .activation import MultiheadAttention
|
|
from .container import ModuleList
|
|
from ..init import xavier_uniform_
|
|
from .dropout import Dropout
|
|
from .linear import Linear
|
|
from .normalization import LayerNorm
|
|
|
|
# Public API of this module.
__all__ = [
    "Transformer",
    "TransformerEncoder",
    "TransformerDecoder",
    "TransformerEncoderLayer",
    "TransformerDecoderLayer",
]
|
|
|
|
def _generate_square_subsequent_mask(
|
|
sz: int,
|
|
device: Optional[torch.device] = None,
|
|
dtype: Optional[torch.dtype] = None,
|
|
) -> Tensor:
|
|
r"""Generate a square causal mask for the sequence.
|
|
|
|
The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
|
|
"""
|
|
if device is None:
|
|
device = torch.device('cpu')
|
|
if dtype is None:
|
|
dtype = torch.float32
|
|
return torch.triu(
|
|
torch.full((sz, sz), float('-inf'), dtype=dtype, device=device),
|
|
diagonal=1,
|
|
)
|
|
|
|
|
|
def _get_seq_len(
|
|
src: Tensor,
|
|
batch_first: bool
|
|
) -> Optional[int]:
|
|
|
|
if src.is_nested:
|
|
return None
|
|
else:
|
|
src_size = src.size()
|
|
if len(src_size) == 2:
|
|
# unbatched: S, E
|
|
return src_size[0]
|
|
else:
|
|
# batched: B, S, E if batch_first else S, B, E
|
|
seq_len_pos = 1 if batch_first else 0
|
|
return src_size[seq_len_pos]
|
|
|
|
|
|
class Transformer(Module):
|
|
r"""A transformer model.
|
|
|
|
User is able to modify the attributes as needed. The architecture
|
|
is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
|
|
Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
|
|
Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
|
|
Processing Systems, pages 6000-6010.
|
|
|
|
Args:
|
|
d_model: the number of expected features in the encoder/decoder inputs (default=512).
|
|
nhead: the number of heads in the multiheadattention models (default=8).
|
|
num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
|
|
num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
|
|
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
|
dropout: the dropout value (default=0.1).
|
|
activation: the activation function of encoder/decoder intermediate layer, can be a string
|
|
("relu" or "gelu") or a unary callable. Default: relu
|
|
custom_encoder: custom encoder (default=None).
|
|
custom_decoder: custom decoder (default=None).
|
|
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
|
|
batch_first: If ``True``, then the input and output tensors are provided
|
|
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
|
norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
|
|
other attention and feedforward operations, otherwise after. Default: ``False`` (after).
|
|
bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
|
|
bias. Default: ``True``.
|
|
|
|
Examples::
|
|
>>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
|
|
>>> src = torch.rand((10, 32, 512))
|
|
>>> tgt = torch.rand((20, 32, 512))
|
|
>>> out = transformer_model(src, tgt)
|
|
|
|
Note: A full example to apply nn.Transformer module for the word language model is available in
|
|
https://github.com/pytorch/examples/tree/master/word_language_model
|
|
"""
|
|
|
|
def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
|
|
num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
|
|
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
|
custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
|
|
layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
|
|
bias: bool = True, device=None, dtype=None) -> None:
|
|
factory_kwargs = {'device': device, 'dtype': dtype}
|
|
super().__init__()
|
|
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
|
|
|
|
if custom_encoder is not None:
|
|
self.encoder = custom_encoder
|
|
else:
|
|
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
|
|
activation, layer_norm_eps, batch_first, norm_first,
|
|
bias, **factory_kwargs)
|
|
encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
|
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
|
|
|
if custom_decoder is not None:
|
|
self.decoder = custom_decoder
|
|
else:
|
|
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
|
|
activation, layer_norm_eps, batch_first, norm_first,
|
|
bias, **factory_kwargs)
|
|
decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
|
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
|
|
|
|
self._reset_parameters()
|
|
|
|
self.d_model = d_model
|
|
self.nhead = nhead
|
|
|
|
self.batch_first = batch_first
|
|
|
|
def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
|
|
memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
|
|
tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
|
|
src_is_causal: Optional[bool] = None, tgt_is_causal: Optional[bool] = None,
|
|
memory_is_causal: bool = False) -> Tensor:
|
|
r"""Take in and process masked source/target sequences.
|
|
|
|
Args:
|
|
src: the sequence to the encoder (required).
|
|
tgt: the sequence to the decoder (required).
|
|
src_mask: the additive mask for the src sequence (optional).
|
|
tgt_mask: the additive mask for the tgt sequence (optional).
|
|
memory_mask: the additive mask for the encoder output (optional).
|
|
src_key_padding_mask: the Tensor mask for src keys per batch (optional).
|
|
tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
|
|
memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
|
|
src_is_causal: If specified, applies a causal mask as ``src_mask``.
|
|
Default: ``None``; try to detect a causal mask.
|
|
Warning:
|
|
``src_is_causal`` provides a hint that ``src_mask`` is
|
|
the causal mask. Providing incorrect hints can result in
|
|
incorrect execution, including forward and backward
|
|
compatibility.
|
|
tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
|
|
Default: ``None``; try to detect a causal mask.
|
|
Warning:
|
|
``tgt_is_causal`` provides a hint that ``tgt_mask`` is
|
|
the causal mask. Providing incorrect hints can result in
|
|
incorrect execution, including forward and backward
|
|
compatibility.
|
|
memory_is_causal: If specified, applies a causal mask as
|
|
``memory_mask``.
|
|
Default: ``False``.
|
|
Warning:
|
|
``memory_is_causal`` provides a hint that
|
|
``memory_mask`` is the causal mask. Providing incorrect
|
|
hints can result in incorrect execution, including
|
|
forward and backward compatibility.
|
|
|
|
Shape:
|
|
- src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
|
|
`(N, S, E)` if `batch_first=True`.
|
|
- tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
|
|
`(N, T, E)` if `batch_first=True`.
|
|
- src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
|
|
- tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
|
|
- memory_mask: :math:`(T, S)`.
|
|
- src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
|
|
- tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
|
|
- memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
|
|
|
|
Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend the unmasked
|
|
positions. If a BoolTensor is provided, positions with ``True``
|
|
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
|
is provided, it will be added to the attention weight.
|
|
[src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
|
|
the attention. If a BoolTensor is provided, the positions with the
|
|
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
|
|
|
- output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
|
|
`(N, T, E)` if `batch_first=True`.
|
|
|
|
Note: Due to the multi-head attention architecture in the transformer model,
|
|
the output sequence length of a transformer is same as the input sequence
|
|
(i.e. target) length of the decoder.
|
|
|
|
where :math:`S` is the source sequence length, :math:`T` is the target sequence length, :math:`N` is the
|
|
batch size, :math:`E` is the feature number
|
|
|
|
Examples:
|
|
>>> # xdoctest: +SKIP
|
|
>>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
|
|
"""
|
|
is_batched = src.dim() == 3
|
|
if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
|
|
raise RuntimeError("the batch number of src and tgt must be equal")
|
|
elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
|
|
raise RuntimeError("the batch number of src and tgt must be equal")
|
|
|
|
if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
|
|
raise RuntimeError("the feature number of src and tgt must be equal to d_model")
|
|
|
|
memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask,
|
|
is_causal=src_is_causal)
|
|
output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
|
|
tgt_key_padding_mask=tgt_key_padding_mask,
|
|
memory_key_padding_mask=memory_key_padding_mask,
|
|
tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)
|
|
return output
|
|
|
|
@staticmethod
|
|
def generate_square_subsequent_mask(
|
|
sz: int,
|
|
device: Optional[torch.device] = None,
|
|
dtype: Optional[torch.dtype] = None,
|
|
) -> Tensor:
|
|
r"""Generate a square causal mask for the sequence.
|
|
|
|
The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
|
|
"""
|
|
return _generate_square_subsequent_mask(sz, dtype=dtype, device=device)
|
|
|
|
def _reset_parameters(self):
|
|
r"""Initiate parameters in the transformer model."""
|
|
for p in self.parameters():
|
|
if p.dim() > 1:
|
|
xavier_uniform_(p)
|
|
|
|
|
|
class TransformerEncoder(Module):
|
|
r"""TransformerEncoder is a stack of N encoder layers.
|
|
|
|
Users can build the BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
|
|
|
|
Args:
|
|
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
|
num_layers: the number of sub-encoder-layers in the encoder (required).
|
|
norm: the layer normalization component (optional).
|
|
enable_nested_tensor: if True, input will automatically convert to nested tensor
|
|
(and convert back on output). This will improve the overall performance of
|
|
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
|
|
|
|
Examples::
|
|
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
|
|
>>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
|
|
>>> src = torch.rand(10, 32, 512)
|
|
>>> out = transformer_encoder(src)
|
|
"""
|
|
|
|
__constants__ = ['norm']
|
|
|
|
def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True):
|
|
super().__init__()
|
|
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
|
|
self.layers = _get_clones(encoder_layer, num_layers)
|
|
self.num_layers = num_layers
|
|
self.norm = norm
|
|
# this attribute saves the value providedat object construction
|
|
self.enable_nested_tensor = enable_nested_tensor
|
|
# this attribute controls whether nested tensors are used
|
|
self.use_nested_tensor = enable_nested_tensor
|
|
self.mask_check = mask_check
|
|
|
|
enc_layer = "encoder_layer"
|
|
why_not_sparsity_fast_path = ''
|
|
if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer):
|
|
why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer"
|
|
elif encoder_layer.norm_first :
|
|
why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True"
|
|
elif not encoder_layer.self_attn.batch_first:
|
|
why_not_sparsity_fast_path = (f"{enc_layer}.self_attn.batch_first was not True" +
|
|
"(use batch_first for better inference performance)")
|
|
elif not encoder_layer.self_attn._qkv_same_embed_dim:
|
|
why_not_sparsity_fast_path = f"{enc_layer}.self_attn._qkv_same_embed_dim was not True"
|
|
elif encoder_layer.self_attn.in_proj_bias is None:
|
|
why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False"
|
|
elif not encoder_layer.activation_relu_or_gelu:
|
|
why_not_sparsity_fast_path = f"{enc_layer}.activation_relu_or_gelu was not True"
|
|
elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps) :
|
|
why_not_sparsity_fast_path = f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps"
|
|
elif encoder_layer.self_attn.num_heads % 2 == 1:
|
|
why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd"
|
|
|
|
if enable_nested_tensor and why_not_sparsity_fast_path:
|
|
warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}")
|
|
self.use_nested_tensor = False
|
|
|
|
|
|
def forward(
|
|
self,
|
|
src: Tensor,
|
|
mask: Optional[Tensor] = None,
|
|
src_key_padding_mask: Optional[Tensor] = None,
|
|
is_causal: Optional[bool] = None) -> Tensor:
|
|
r"""Pass the input through the encoder layers in turn.
|
|
|
|
Args:
|
|
src: the sequence to the encoder (required).
|
|
mask: the mask for the src sequence (optional).
|
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
|
is_causal: If specified, applies a causal mask as ``mask``.
|
|
Default: ``None``; try to detect a causal mask.
|
|
Warning:
|
|
``is_causal`` provides a hint that ``mask`` is the
|
|
causal mask. Providing incorrect hints can result in
|
|
incorrect execution, including forward and backward
|
|
compatibility.
|
|
|
|
Shape:
|
|
see the docs in :class:`~torch.nn.Transformer`.
|
|
"""
|
|
src_key_padding_mask = F._canonical_mask(
|
|
mask=src_key_padding_mask,
|
|
mask_name="src_key_padding_mask",
|
|
other_type=F._none_or_dtype(mask),
|
|
other_name="mask",
|
|
target_type=src.dtype
|
|
)
|
|
|
|
mask = F._canonical_mask(
|
|
mask=mask,
|
|
mask_name="mask",
|
|
other_type=None,
|
|
other_name="",
|
|
target_type=src.dtype,
|
|
check_other=False,
|
|
)
|
|
|
|
output = src
|
|
convert_to_nested = False
|
|
first_layer = self.layers[0]
|
|
src_key_padding_mask_for_layers = src_key_padding_mask
|
|
why_not_sparsity_fast_path = ''
|
|
str_first_layer = "self.layers[0]"
|
|
batch_first = first_layer.self_attn.batch_first
|
|
is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
|
|
|
|
if not is_fastpath_enabled:
|
|
why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
|
|
elif not hasattr(self, "use_nested_tensor"):
|
|
why_not_sparsity_fast_path = "use_nested_tensor attribute not present"
|
|
elif not self.use_nested_tensor:
|
|
why_not_sparsity_fast_path = "self.use_nested_tensor (set in init) was not True"
|
|
elif first_layer.training:
|
|
why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
|
|
elif not src.dim() == 3:
|
|
why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
|
|
elif src_key_padding_mask is None:
|
|
why_not_sparsity_fast_path = "src_key_padding_mask was None"
|
|
elif (((not hasattr(self, "mask_check")) or self.mask_check)
|
|
and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
|
|
why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
|
|
elif output.is_nested:
|
|
why_not_sparsity_fast_path = "NestedTensor input is not supported"
|
|
elif mask is not None:
|
|
why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
|
|
elif torch.is_autocast_enabled():
|
|
why_not_sparsity_fast_path = "autocast is enabled"
|
|
|
|
if not why_not_sparsity_fast_path:
|
|
tensor_args = (
|
|
src,
|
|
first_layer.self_attn.in_proj_weight,
|
|
first_layer.self_attn.in_proj_bias,
|
|
first_layer.self_attn.out_proj.weight,
|
|
first_layer.self_attn.out_proj.bias,
|
|
first_layer.norm1.weight,
|
|
first_layer.norm1.bias,
|
|
first_layer.norm2.weight,
|
|
first_layer.norm2.bias,
|
|
first_layer.linear1.weight,
|
|
first_layer.linear1.bias,
|
|
first_layer.linear2.weight,
|
|
first_layer.linear2.bias,
|
|
)
|
|
_supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
|
|
if torch.overrides.has_torch_function(tensor_args):
|
|
why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
|
|
elif src.device.type not in _supported_device_type:
|
|
why_not_sparsity_fast_path = f"src device is neither one of {_supported_device_type}"
|
|
elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
|
|
why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
|
|
"input/output projection weights or biases requires_grad")
|
|
|
|
if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
|
|
convert_to_nested = True
|
|
output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
|
|
src_key_padding_mask_for_layers = None
|
|
|
|
seq_len = _get_seq_len(src, batch_first)
|
|
is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)
|
|
|
|
for mod in self.layers:
|
|
output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
|
|
|
|
if convert_to_nested:
|
|
output = output.to_padded_tensor(0., src.size())
|
|
|
|
if self.norm is not None:
|
|
output = self.norm(output)
|
|
|
|
return output
|
|
|
|
|
|
class TransformerDecoder(Module):
|
|
r"""TransformerDecoder is a stack of N decoder layers.
|
|
|
|
Args:
|
|
decoder_layer: an instance of the TransformerDecoderLayer() class (required).
|
|
num_layers: the number of sub-decoder-layers in the decoder (required).
|
|
norm: the layer normalization component (optional).
|
|
|
|
Examples::
|
|
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
|
|
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
|
|
>>> memory = torch.rand(10, 32, 512)
|
|
>>> tgt = torch.rand(20, 32, 512)
|
|
>>> out = transformer_decoder(tgt, memory)
|
|
"""
|
|
|
|
__constants__ = ['norm']
|
|
|
|
def __init__(self, decoder_layer, num_layers, norm=None):
|
|
super().__init__()
|
|
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
|
|
self.layers = _get_clones(decoder_layer, num_layers)
|
|
self.num_layers = num_layers
|
|
self.norm = norm
|
|
|
|
def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
|
|
memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
|
|
memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None,
|
|
memory_is_causal: bool = False) -> Tensor:
|
|
r"""Pass the inputs (and mask) through the decoder layer in turn.
|
|
|
|
Args:
|
|
tgt: the sequence to the decoder (required).
|
|
memory: the sequence from the last layer of the encoder (required).
|
|
tgt_mask: the mask for the tgt sequence (optional).
|
|
memory_mask: the mask for the memory sequence (optional).
|
|
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
|
|
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
|
tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
|
|
Default: ``None``; try to detect a causal mask.
|
|
Warning:
|
|
``tgt_is_causal`` provides a hint that ``tgt_mask`` is
|
|
the causal mask. Providing incorrect hints can result in
|
|
incorrect execution, including forward and backward
|
|
compatibility.
|
|
memory_is_causal: If specified, applies a causal mask as
|
|
``memory mask``.
|
|
Default: ``False``.
|
|
Warning:
|
|
``memory_is_causal`` provides a hint that
|
|
``memory_mask`` is the causal mask. Providing incorrect
|
|
hints can result in incorrect execution, including
|
|
forward and backward compatibility.
|
|
|
|
Shape:
|
|
see the docs in :class:`~torch.nn.Transformer`.
|
|
"""
|
|
output = tgt
|
|
|
|
seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
|
|
tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)
|
|
|
|
for mod in self.layers:
|
|
output = mod(output, memory, tgt_mask=tgt_mask,
|
|
memory_mask=memory_mask,
|
|
tgt_key_padding_mask=tgt_key_padding_mask,
|
|
memory_key_padding_mask=memory_key_padding_mask,
|
|
tgt_is_causal=tgt_is_causal,
|
|
memory_is_causal=memory_is_causal)
|
|
|
|
if self.norm is not None:
|
|
output = self.norm(output)
|
|
|
|
return output
|
|
|
|
class TransformerEncoderLayer(Module):
|
|
r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
|
|
|
|
This standard encoder layer is based on the paper "Attention Is All You Need".
|
|
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
|
|
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
|
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
|
|
in a different way during application.
|
|
|
|
TransformerEncoderLayer can handle either traditional torch.tensor inputs,
|
|
or Nested Tensor inputs. Derived classes are expected to similarly accept
|
|
both input formats. (Not all combinations of inputs are currently
|
|
supported by TransformerEncoderLayer while Nested Tensor is in prototype
|
|
state.)
|
|
|
|
If you are implementing a custom layer, you may derive it either from
|
|
the Module or TransformerEncoderLayer class. If your custom layer
|
|
supports both torch.Tensors and Nested Tensors inputs, make its
|
|
implementation a derived class of TransformerEncoderLayer. If your custom
|
|
Layer supports only torch.Tensor inputs, derive its implementation from
|
|
Module.
|
|
|
|
Args:
|
|
d_model: the number of expected features in the input (required).
|
|
nhead: the number of heads in the multiheadattention models (required).
|
|
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
|
dropout: the dropout value (default=0.1).
|
|
activation: the activation function of the intermediate layer, can be a string
|
|
("relu" or "gelu") or a unary callable. Default: relu
|
|
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
|
|
batch_first: If ``True``, then the input and output tensors are provided
|
|
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
|
norm_first: if ``True``, layer norm is done prior to attention and feedforward
|
|
operations, respectively. Otherwise it's done after. Default: ``False`` (after).
|
|
bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
|
|
bias. Default: ``True``.
|
|
|
|
Examples::
|
|
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
|
|
>>> src = torch.rand(10, 32, 512)
|
|
>>> out = encoder_layer(src)
|
|
|
|
Alternatively, when ``batch_first`` is ``True``:
|
|
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
|
|
>>> src = torch.rand(32, 10, 512)
|
|
>>> out = encoder_layer(src)
|
|
|
|
Fast path:
|
|
forward() will use a special optimized implementation described in
|
|
`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
|
|
conditions are met:
|
|
|
|
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
|
|
argument ``requires_grad``
|
|
- training is disabled (using ``.eval()``)
|
|
- batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
|
|
- activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
|
|
- at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
|
|
- if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
|
|
nor ``src_key_padding_mask`` is passed
|
|
- the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
|
|
unless the caller has manually modified one without modifying the other)
|
|
|
|
If the optimized implementation is in use, a
|
|
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
|
|
passed for ``src`` to represent padding more efficiently than using a padding
|
|
mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
|
|
returned, and an additional speedup proportional to the fraction of the input that
|
|
is padding can be expected.
|
|
|
|
.. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
|
|
https://arxiv.org/abs/2205.14135
|
|
|
|
"""
|
|
|
|
__constants__ = ['norm_first']
|
|
|
|
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
|
|
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
|
layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
|
|
bias: bool = True, device=None, dtype=None) -> None:
|
|
factory_kwargs = {'device': device, 'dtype': dtype}
|
|
super().__init__()
|
|
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
|
|
bias=bias, batch_first=batch_first,
|
|
**factory_kwargs)
|
|
# Implementation of Feedforward model
|
|
self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
|
|
self.dropout = Dropout(dropout)
|
|
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
|
|
|
|
self.norm_first = norm_first
|
|
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
|
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
|
self.dropout1 = Dropout(dropout)
|
|
self.dropout2 = Dropout(dropout)
|
|
|
|
# Legacy string support for activation function.
|
|
if isinstance(activation, str):
|
|
activation = _get_activation_fn(activation)
|
|
|
|
# We can't test self.activation in forward() in TorchScript,
|
|
# so stash some information about it instead.
|
|
if activation is F.relu or isinstance(activation, torch.nn.ReLU):
|
|
self.activation_relu_or_gelu = 1
|
|
elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
|
|
self.activation_relu_or_gelu = 2
|
|
else:
|
|
self.activation_relu_or_gelu = 0
|
|
self.activation = activation
|
|
|
|
def __setstate__(self, state):
|
|
super().__setstate__(state)
|
|
if not hasattr(self, 'activation'):
|
|
self.activation = F.relu
|
|
|
|
|
|
def forward(
|
|
self,
|
|
src: Tensor,
|
|
src_mask: Optional[Tensor] = None,
|
|
src_key_padding_mask: Optional[Tensor] = None,
|
|
is_causal: bool = False) -> Tensor:
|
|
r"""Pass the input through the encoder layer.
|
|
|
|
Args:
|
|
src: the sequence to the encoder layer (required).
|
|
src_mask: the mask for the src sequence (optional).
|
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
|
is_causal: If specified, applies a causal mask as ``src mask``.
|
|
Default: ``False``.
|
|
Warning:
|
|
``is_causal`` provides a hint that ``src_mask`` is the
|
|
causal mask. Providing incorrect hints can result in
|
|
incorrect execution, including forward and backward
|
|
compatibility.
|
|
|
|
Shape:
|
|
see the docs in :class:`~torch.nn.Transformer`.
|
|
"""
|
|
src_key_padding_mask = F._canonical_mask(
|
|
mask=src_key_padding_mask,
|
|
mask_name="src_key_padding_mask",
|
|
other_type=F._none_or_dtype(src_mask),
|
|
other_name="src_mask",
|
|
target_type=src.dtype
|
|
)
|
|
|
|
src_mask = F._canonical_mask(
|
|
mask=src_mask,
|
|
mask_name="src_mask",
|
|
other_type=None,
|
|
other_name="",
|
|
target_type=src.dtype,
|
|
check_other=False,
|
|
)
|
|
|
|
is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
|
|
|
|
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
|
|
why_not_sparsity_fast_path = ''
|
|
if not is_fastpath_enabled:
|
|
why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
|
|
elif not src.dim() == 3:
|
|
why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
|
|
elif self.training:
|
|
why_not_sparsity_fast_path = "training is enabled"
|
|
elif not self.self_attn.batch_first:
|
|
why_not_sparsity_fast_path = "self_attn.batch_first was not True"
|
|
elif self.self_attn.in_proj_bias is None:
|
|
why_not_sparsity_fast_path = "self_attn was passed bias=False"
|
|
elif not self.self_attn._qkv_same_embed_dim:
|
|
why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
|
|
elif not self.activation_relu_or_gelu:
|
|
why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
|
|
elif not (self.norm1.eps == self.norm2.eps):
|
|
why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
|
|
elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
|
|
why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
|
|
elif self.self_attn.num_heads % 2 == 1:
|
|
why_not_sparsity_fast_path = "num_head is odd"
|
|
elif torch.is_autocast_enabled():
|
|
why_not_sparsity_fast_path = "autocast is enabled"
|
|
if not why_not_sparsity_fast_path:
|
|
tensor_args = (
|
|
src,
|
|
self.self_attn.in_proj_weight,
|
|
self.self_attn.in_proj_bias,
|
|
self.self_attn.out_proj.weight,
|
|
self.self_attn.out_proj.bias,
|
|
self.norm1.weight,
|
|
self.norm1.bias,
|
|
self.norm2.weight,
|
|
self.norm2.bias,
|
|
self.linear1.weight,
|
|
self.linear1.bias,
|
|
self.linear2.weight,
|
|
self.linear2.bias,
|
|
)
|
|
|
|
# We have to use list comprehensions below because TorchScript does not support
|
|
# generator expressions.
|
|
_supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
|
|
if torch.overrides.has_torch_function(tensor_args):
|
|
why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
|
|
elif not all((x.device.type in _supported_device_type) for x in tensor_args):
|
|
why_not_sparsity_fast_path = ("some Tensor argument's device is neither one of "
|
|
f"{_supported_device_type}")
|
|
elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
|
|
why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
|
|
"input/output projection weights or biases requires_grad")
|
|
|
|
if not why_not_sparsity_fast_path:
|
|
merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
|
|
return torch._transformer_encoder_layer_fwd(
|
|
src,
|
|
self.self_attn.embed_dim,
|
|
self.self_attn.num_heads,
|
|
self.self_attn.in_proj_weight,
|
|
self.self_attn.in_proj_bias,
|
|
self.self_attn.out_proj.weight,
|
|
self.self_attn.out_proj.bias,
|
|
self.activation_relu_or_gelu == 2,
|
|
self.norm_first,
|
|
self.norm1.eps,
|
|
self.norm1.weight,
|
|
self.norm1.bias,
|
|
self.norm2.weight,
|
|
self.norm2.bias,
|
|
self.linear1.weight,
|
|
self.linear1.bias,
|
|
self.linear2.weight,
|
|
self.linear2.bias,
|
|
merged_mask,
|
|
mask_type,
|
|
)
|
|
|
|
|
|
x = src
|
|
if self.norm_first:
|
|
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
|
|
x = x + self._ff_block(self.norm2(x))
|
|
else:
|
|
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
|
|
x = self.norm2(x + self._ff_block(x))
|
|
|
|
return x
|
|
|
|
# self-attention block
|
|
def _sa_block(self, x: Tensor,
|
|
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
|
|
x = self.self_attn(x, x, x,
|
|
attn_mask=attn_mask,
|
|
key_padding_mask=key_padding_mask,
|
|
need_weights=False, is_causal=is_causal)[0]
|
|
return self.dropout1(x)
|
|
|
|
# feed forward block
|
|
def _ff_block(self, x: Tensor) -> Tensor:
|
|
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
|
return self.dropout2(x)
|
|
|
|
|
|
class TransformerDecoderLayer(Module):
|
|
r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
|
|
|
|
This standard decoder layer is based on the paper "Attention Is All You Need".
|
|
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
|
|
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
|
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
|
|
in a different way during application.
|
|
|
|
Args:
|
|
d_model: the number of expected features in the input (required).
|
|
nhead: the number of heads in the multiheadattention models (required).
|
|
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
|
dropout: the dropout value (default=0.1).
|
|
activation: the activation function of the intermediate layer, can be a string
|
|
("relu" or "gelu") or a unary callable. Default: relu
|
|
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
|
|
batch_first: If ``True``, then the input and output tensors are provided
|
|
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
|
norm_first: if ``True``, layer norm is done prior to self attention, multihead
|
|
attention and feedforward operations, respectively. Otherwise it's done after.
|
|
Default: ``False`` (after).
|
|
bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
|
|
bias. Default: ``True``.
|
|
|
|
Examples::
|
|
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
|
|
>>> memory = torch.rand(10, 32, 512)
|
|
>>> tgt = torch.rand(20, 32, 512)
|
|
>>> out = decoder_layer(tgt, memory)
|
|
|
|
Alternatively, when ``batch_first`` is ``True``:
|
|
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
|
|
>>> memory = torch.rand(32, 10, 512)
|
|
>>> tgt = torch.rand(32, 20, 512)
|
|
>>> out = decoder_layer(tgt, memory)
|
|
"""
|
|
|
|
__constants__ = ['norm_first']
|
|
|
|
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
|
|
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
|
layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
|
|
bias: bool = True, device=None, dtype=None) -> None:
|
|
factory_kwargs = {'device': device, 'dtype': dtype}
|
|
super().__init__()
|
|
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
|
|
bias=bias, **factory_kwargs)
|
|
self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
|
|
bias=bias, **factory_kwargs)
|
|
# Implementation of Feedforward model
|
|
self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
|
|
self.dropout = Dropout(dropout)
|
|
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
|
|
|
|
self.norm_first = norm_first
|
|
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
|
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
|
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
|
self.dropout1 = Dropout(dropout)
|
|
self.dropout2 = Dropout(dropout)
|
|
self.dropout3 = Dropout(dropout)
|
|
|
|
# Legacy string support for activation function.
|
|
if isinstance(activation, str):
|
|
self.activation = _get_activation_fn(activation)
|
|
else:
|
|
self.activation = activation
|
|
|
|
def __setstate__(self, state):
|
|
if 'activation' not in state:
|
|
state['activation'] = F.relu
|
|
super().__setstate__(state)
|
|
|
|
def forward(
|
|
self,
|
|
tgt: Tensor,
|
|
memory: Tensor,
|
|
tgt_mask: Optional[Tensor] = None,
|
|
memory_mask: Optional[Tensor] = None,
|
|
tgt_key_padding_mask: Optional[Tensor] = None,
|
|
memory_key_padding_mask: Optional[Tensor] = None,
|
|
tgt_is_causal: bool = False,
|
|
memory_is_causal: bool = False,
|
|
) -> Tensor:
|
|
r"""Pass the inputs (and mask) through the decoder layer.
|
|
|
|
Args:
|
|
tgt: the sequence to the decoder layer (required).
|
|
memory: the sequence from the last layer of the encoder (required).
|
|
tgt_mask: the mask for the tgt sequence (optional).
|
|
memory_mask: the mask for the memory sequence (optional).
|
|
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
|
|
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
|
tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
|
|
Default: ``False``.
|
|
Warning:
|
|
``tgt_is_causal`` provides a hint that ``tgt_mask`` is
|
|
the causal mask. Providing incorrect hints can result in
|
|
incorrect execution, including forward and backward
|
|
compatibility.
|
|
memory_is_causal: If specified, applies a causal mask as
|
|
``memory mask``.
|
|
Default: ``False``.
|
|
Warning:
|
|
``memory_is_causal`` provides a hint that
|
|
``memory_mask`` is the causal mask. Providing incorrect
|
|
hints can result in incorrect execution, including
|
|
forward and backward compatibility.
|
|
|
|
Shape:
|
|
see the docs in :class:`~torch.nn.Transformer`.
|
|
"""
|
|
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
|
|
|
|
x = tgt
|
|
if self.norm_first:
|
|
x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
|
|
x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
|
|
x = x + self._ff_block(self.norm3(x))
|
|
else:
|
|
x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
|
|
x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
|
|
x = self.norm3(x + self._ff_block(x))
|
|
|
|
return x
|
|
|
|
# self-attention block
|
|
def _sa_block(self, x: Tensor,
|
|
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
|
|
x = self.self_attn(x, x, x,
|
|
attn_mask=attn_mask,
|
|
key_padding_mask=key_padding_mask,
|
|
is_causal=is_causal,
|
|
need_weights=False)[0]
|
|
return self.dropout1(x)
|
|
|
|
# multihead attention block
|
|
def _mha_block(self, x: Tensor, mem: Tensor,
|
|
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
|
|
x = self.multihead_attn(x, mem, mem,
|
|
attn_mask=attn_mask,
|
|
key_padding_mask=key_padding_mask,
|
|
is_causal=is_causal,
|
|
need_weights=False)[0]
|
|
return self.dropout2(x)
|
|
|
|
# feed forward block
|
|
def _ff_block(self, x: Tensor) -> Tensor:
|
|
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
|
return self.dropout3(x)
|
|
|
|
|
|
def _get_clones(module, N):
|
|
# FIXME: copy.deepcopy() is not defined on nn.module
|
|
return ModuleList([copy.deepcopy(module) for i in range(N)])
|
|
|
|
|
|
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
|
|
if activation == "relu":
|
|
return F.relu
|
|
elif activation == "gelu":
|
|
return F.gelu
|
|
|
|
raise RuntimeError(f"activation should be relu/gelu, not {activation}")
|
|
|
|
|
|
def _detect_is_causal_mask(
|
|
mask: Optional[Tensor],
|
|
is_causal: Optional[bool] = None,
|
|
size: Optional[int] = None,
|
|
) -> bool:
|
|
"""Return whether the given attention mask is causal.
|
|
|
|
Warning:
|
|
If ``is_causal`` is not ``None``, its value will be returned as is. If a
|
|
user supplies an incorrect ``is_causal`` hint,
|
|
|
|
``is_causal=False`` when the mask is in fact a causal attention.mask
|
|
may lead to reduced performance relative to what would be achievable
|
|
with ``is_causal=True``;
|
|
``is_causal=True`` when the mask is in fact not a causal attention.mask
|
|
may lead to incorrect and unpredictable execution - in some scenarios,
|
|
a causal mask may be applied based on the hint, in other execution
|
|
scenarios the specified mask may be used. The choice may not appear
|
|
to be deterministic, in that a number of factors like alignment,
|
|
hardware SKU, etc influence the decision whether to use a mask or
|
|
rely on the hint.
|
|
``size`` if not None, check whether the mask is a causal mask of the provided size
|
|
Otherwise, checks for any causal mask.
|
|
"""
|
|
# Prevent type refinement
|
|
make_causal = (is_causal is True)
|
|
|
|
if is_causal is None and mask is not None:
|
|
sz = size if size is not None else mask.size(-2)
|
|
causal_comparison = _generate_square_subsequent_mask(
|
|
sz, device=mask.device, dtype=mask.dtype)
|
|
|
|
# Do not use `torch.equal` so we handle batched masks by
|
|
# broadcasting the comparison.
|
|
if mask.size() == causal_comparison.size():
|
|
make_causal = bool((mask == causal_comparison).all())
|
|
else:
|
|
make_causal = False
|
|
|
|
return make_causal
|