[nn] lstm : no batch dim support (#71056)
Summary:
Reference: https://github.com/pytorch/pytorch/issues/60585
TODO:
* [x] Update docs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71056
Reviewed By: samdow
Differential Revision: D33638643
Pulled By: jbschlosser
fbshipit-source-id: c0949829de8a8e6e7b2873f459a8d7da597a3be3
(cherry picked from commit f94d5849f6)
This commit is contained in:
parent 99d9883a22
commit b372be4211
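To illustrate the behavior this commit adds, here is a minimal usage sketch based on the updated LSTM docstrings further down in the diff (the concrete sizes are illustrative, not taken from the patch):

    import torch
    import torch.nn as nn

    # Unbatched input: (L, H_in) instead of (L, N, H_in).
    lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
    inp = torch.randn(5, 10)    # seq_len L=5, H_in=10, no batch dimension
    h0 = torch.randn(2, 20)     # (D * num_layers, H_out), also without a batch dimension
    c0 = torch.randn(2, 20)     # (D * num_layers, H_cell)
    output, (hn, cn) = lstm(inp, (h0, c0))
    print(output.shape)         # torch.Size([5, 20]) -- the output is unbatched as well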
@@ -363,8 +363,11 @@ class TestModule(TestCase):
                 grad_output = default_output.clone().detach_().normal_()
                 default_output.backward(grad_output, retain_graph=True)
             else:
-                grad_output = tuple(o.clone().detach_().normal_() for o in default_output)
-                for o, g_o in zip(default_output, grad_output):
+                grad_output = tuple(self._traverse_obj(o, lambda o: o.clone().detach_().normal_())
+                                    for o in default_output)
+                flattened_default_output, _ = torch.utils._pytree.tree_flatten(default_output)
+                flattened_grad_output, _ = torch.utils._pytree.tree_flatten(grad_output)
+                for o, g_o in zip(flattened_default_output, flattened_grad_output):
                     o.backward(g_o, retain_graph=True)

             default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs)))
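For context on the test change above: torch.utils._pytree.tree_flatten turns an arbitrarily nested module output, such as the LSTM's (output, (h_n, c_n)) tuple, into a flat list of tensors plus a spec, which is what lets the test zip outputs with grad outputs uniformly. A small sketch (the tensor shapes are arbitrary):

    import torch
    from torch.utils._pytree import tree_flatten

    out = (torch.randn(5, 20), (torch.randn(2, 20), torch.randn(2, 20)))  # LSTM-style nested output
    flat, spec = tree_flatten(out)
    print(len(flat))   # 3 -- output, h_n and c_n as one flat list of tensors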
@@ -388,7 +391,9 @@ class TestModule(TestCase):
             if isinstance(out, torch.Tensor):
                 out.backward(g_out_copy, retain_graph=True)
             else:
-                for o, g_o in zip(out, g_out_copy):
+                flattened_out, _ = torch.utils._pytree.tree_flatten(out)
+                flattened_g_out_copy, _ = torch.utils._pytree.tree_flatten(g_out_copy)
+                for o, g_o in zip(flattened_out, flattened_g_out_copy):
                     o.backward(g_o, retain_graph=True)

             input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs))
@@ -447,7 +452,9 @@ class TestModule(TestCase):
                 new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}

                 with freeze_rng_state():
-                    return m(*new_input_args, **new_kwargs, **other_kwargs)
+                    output = m(*new_input_args, **new_kwargs, **other_kwargs)
+                    output_flattened, _ = torch.utils._pytree.tree_flatten(output)
+                    return output_flattened

             self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))

@@ -531,7 +538,9 @@ class TestModule(TestCase):
         if isinstance(cpu_outputs, torch.Tensor):
             check_backward(cpu_outputs, gpu_outputs)
         else:
-            for cpu_output, gpu_output in zip(cpu_outputs, gpu_outputs):
+            flatten_cpu_outputs, _ = torch.utils._pytree.tree_flatten(cpu_outputs)
+            flatten_gpu_outputs, _ = torch.utils._pytree.tree_flatten(gpu_outputs)
+            for cpu_output, gpu_output in zip(flatten_cpu_outputs, flatten_gpu_outputs):
                 check_backward(cpu_output, gpu_output)

@@ -393,7 +393,7 @@ class RNN(RNNBase):
         ``output.view(seq_len, batch, num_directions, hidden_size)``.

     .. note::
-        `batch_first` argument is ignored for unbatched inputs.
+        ``batch_first`` argument is ignored for unbatched inputs.

     .. include:: ../cudnn_rnn_determinism.rst

@@ -568,16 +568,19 @@ class LSTM(RNNBase):
         proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0

     Inputs: input, (h_0, c_0)
-        * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or
+        * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
+          :math:`(L, N, H_{in})` when ``batch_first=False`` or
           :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
           the input sequence. The input can also be a packed variable length sequence.
           See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
           :func:`torch.nn.utils.rnn.pack_sequence` for details.
-        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the
-          initial hidden state for each element in the batch.
+        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
+          :math:`(D * \text{num\_layers}, N, H_{out})` containing the
+          initial hidden state for each element in the input sequence.
           Defaults to zeros if (h_0, c_0) is not provided.
-        * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
-          initial cell state for each element in the batch.
+        * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
+          :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
+          initial cell state for each element in the input sequence.
           Defaults to zeros if (h_0, c_0) is not provided.

     where:
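A short sketch of the documented unbatched input shapes; note that H_out equals proj_size when proj_size > 0, which is why h_0 and c_0 differ below (sizes are hypothetical):

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, proj_size=4)
    x = torch.randn(7, 8)     # (L, H_in): unbatched input
    h0 = torch.randn(2, 4)    # (D * num_layers, H_out) with D=1 and H_out = proj_size = 4
    c0 = torch.randn(2, 16)   # (D * num_layers, H_cell): the cell state keeps hidden_size
    out, (hn, cn) = lstm(x, (h0, c0))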
@@ -593,15 +596,18 @@ class LSTM(RNNBase):
         \end{aligned}

     Outputs: output, (h_n, c_n)
-        * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
+        * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
+          :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
           :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
           `(h_t)` from the last layer of the LSTM, for each `t`. If a
           :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
           will also be a packed sequence.
-        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the
-          final hidden state for each element in the batch.
-        * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
-          final cell state for each element in the batch.
+        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
+          :math:`(D * \text{num\_layers}, N, H_{out})` containing the
+          final hidden state for each element in the sequence.
+        * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
+          :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
+          final cell state for each element in the sequence.

     Attributes:
         weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
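And the corresponding unbatched output shapes, here for a hypothetical bidirectional module (D = 2) with default zero initial states:

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, bidirectional=True)
    out, (hn, cn) = lstm(torch.randn(7, 8))   # unbatched (L, H_in) input
    print(out.shape)   # torch.Size([7, 32]) -- (L, D * H_out)
    print(hn.shape)    # torch.Size([4, 16]) -- (D * num_layers, H_out)
    print(cn.shape)    # torch.Size([4, 16]) -- (D * num_layers, H_cell)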
@@ -639,6 +645,9 @@ class LSTM(RNNBase):
         Example of splitting the output layers when ``batch_first=False``:
         ``output.view(seq_len, batch, num_directions, hidden_size)``.

+    .. note::
+        ``batch_first`` argument is ignored for unbatched inputs.
+
     .. include:: ../cudnn_rnn_determinism.rst

     .. include:: ../cudnn_persistent_rnn.rst
@@ -704,12 +713,17 @@ class LSTM(RNNBase):
     def forward(self, input, hx=None):  # noqa: F811
         orig_input = input
         # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        batch_sizes = None
         if isinstance(orig_input, PackedSequence):
             input, batch_sizes, sorted_indices, unsorted_indices = input
             max_batch_size = batch_sizes[0]
             max_batch_size = int(max_batch_size)
         else:
             batch_sizes = None
+            is_batched = input.dim() == 3
+            batch_dim = 0 if self.batch_first else 1
+            if not is_batched:
+                input = input.unsqueeze(batch_dim)
             max_batch_size = input.size(0) if self.batch_first else input.size(1)
             sorted_indices = None
             unsorted_indices = None
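A standalone sketch of what the added lines do with an unbatched input; this is the same tensor manipulation, outside the module's actual code path:

    import torch

    batch_first = False
    x = torch.randn(5, 10)             # unbatched (L, H_in)
    is_batched = x.dim() == 3          # False
    batch_dim = 0 if batch_first else 1
    if not is_batched:
        x = x.unsqueeze(batch_dim)     # -> (L, 1, H_in): treated as a batch of size one
    print(x.shape)                     # torch.Size([5, 1, 10])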
@@ -725,6 +739,19 @@ class LSTM(RNNBase):
                                       dtype=input.dtype, device=input.device)
                 hx = (h_zeros, c_zeros)
             else:
+                if batch_sizes is None:  # If not PackedSequence input.
+                    if is_batched:
+                        if (hx[0].dim() != 3 or hx[1].dim() != 3):
+                            msg = ("For batched 3-D input, hx and cx should "
+                                   f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
+                            raise RuntimeError(msg)
+                    else:
+                        if hx[0].dim() != 2 or hx[1].dim() != 2:
+                            msg = ("For unbatched 2-D input, hx and cx should "
+                                   f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
+                            raise RuntimeError(msg)
+                        hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
+
                 # Each batch of the hidden state should match the input sequence that
                 # the user believes he/she is passing in.
                 hx = self.permute_hidden(hx, sorted_indices)
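With this check, mixing a batched input with unbatched hidden states fails early with a clear RuntimeError instead of erroring deep inside the kernel. An illustration with arbitrary sizes:

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
    batched_x = torch.randn(5, 3, 10)                      # (L, N, H_in)
    bad_hx = (torch.randn(2, 20), torch.randn(2, 20))      # 2-D states paired with a 3-D input
    try:
        lstm(batched_x, bad_hx)
    except RuntimeError as e:
        print(e)   # For batched 3-D input, hx and cx should also be 3-D but got (2-D, 2-D) tensors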
@@ -743,6 +770,9 @@ class LSTM(RNNBase):
             output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
             return output_packed, self.permute_hidden(hidden, unsorted_indices)
         else:
+            if not is_batched:
+                output = output.squeeze(batch_dim)
+                hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
             return output, self.permute_hidden(hidden, unsorted_indices)

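Because the output squeeze mirrors the input unsqueeze, an unbatched call should match a manual batch-of-one call exactly (on CPU, with the same weights); a small sanity-check sketch:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    lstm = nn.LSTM(input_size=10, hidden_size=20)
    x = torch.randn(5, 10)
    out_unbatched, (hn, cn) = lstm(x)
    out_batched, (hn_b, cn_b) = lstm(x.unsqueeze(1))              # (L, 1, H_in)
    print(torch.allclose(out_unbatched, out_batched.squeeze(1)))  # True
    print(torch.allclose(hn, hn_b.squeeze(1)))                    # True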
@@ -844,7 +874,7 @@ class GRU(RNNBase):
         ``output.view(seq_len, batch, num_directions, hidden_size)``.

     .. note::
-        `batch_first` argument is ignored for unbatched inputs.
+        ``batch_first`` argument is ignored for unbatched inputs.

     .. include:: ../cudnn_persistent_rnn.rst

@@ -379,6 +379,30 @@ def no_batch_dim_reference_rnn_gru(m, p, *args, **kwargs):
     return (output[0].squeeze(batch_dim), output[1].squeeze(1))


+def no_batch_dim_reference_lstm(m, p, *args, **kwargs):
+    """Reference function for LSTM supporting no batch dimensions.
+
+    Unbatched inputs are unsqueezed to form a
+    single batch input before passing them to the module.
+    The output is squeezed to compare with the
+    output of unbatched input to the module.
+    """
+    if len(args) == 1:
+        inp, = args
+        h = None
+    elif len(args) == 2:
+        inp, h = args
+        h = (h[0].unsqueeze(1), h[1].unsqueeze(1))
+
+    batch_dim = 0 if kwargs['batch_first'] else 1
+    kwargs.pop('batch_first')
+    inp = inp.unsqueeze(batch_dim)
+    single_batch_input_args = (inp, h)
+    with freeze_rng_state():
+        output = m(*single_batch_input_args, **kwargs)
+    return (output[0].squeeze(batch_dim), (output[1][0].squeeze(1), output[1][1].squeeze(1)))
+
+
 def no_batch_dim_reference_lstmcell(m, p, *args, **kwargs):
     """Reference function for LSTMCell supporting no batch dimensions.

@@ -901,6 +925,72 @@ def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, **kwargs):
     return samples


+def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    bias = (False, True)
+    batch_first = (False, True)
+    bidirectional = (False, True)
+    proj_sizes = (0, 2)
+
+    samples = []
+    prod_gen = product(bias, batch_first, bidirectional, proj_sizes)
+
+    for args in prod_gen:
+        b, b_f, bidir, proj_size = args
+        hidden_size = 3
+        cons_args = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
+                     'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+        cons_args_hidden = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
+                            'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args),
+                forward_input=FunctionInput(make_input((2, 2))),
+                reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
+            )
+        )
+
+        h_out = proj_size if proj_size > 0 else hidden_size
+        hx = (make_input((4 if bidir else 2, h_out)), make_input((4 if bidir else 2, hidden_size)))
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args_hidden),
+                forward_input=FunctionInput(make_input((3, 2)), hx),
+                reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
+            )
+        )
+
+    return samples
+
+
+# All these operators share similar issues on cuDNN and MIOpen
+rnn_gru_lstm_module_info_decorators = (
+    # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
+    # We could not generate a fallback
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_grad",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented.
+    # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_gradgrad",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # CUDNN GRU doesn't accept non-contiguous hx
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # MIOPEN GRU doesn't accept non-contiguous hx (this is dispatched to miopen only for float).
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
+        active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
+    ),
+)
+
+
 # Database of ModuleInfo entries in alphabetical order.
 module_db: List[ModuleInfo] = [
     ModuleInfo(torch.nn.AdaptiveAvgPool2d,
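For intuition, the product() above enumerates 16 constructor configurations. Stripped of the test-harness wrappers (ModuleInput, FunctionInput and make_tensor are internal helpers), the grid is roughly equivalent to this hedged sketch:

    from itertools import product

    import torch
    import torch.nn as nn

    for bias, batch_first, bidirectional, proj_size in product((False, True), (False, True),
                                                               (False, True), (0, 2)):
        lstm = nn.LSTM(input_size=2, hidden_size=3, num_layers=2, proj_size=proj_size,
                       batch_first=batch_first, bias=bias, bidirectional=bidirectional)
        out, (hn, cn) = lstm(torch.randn(2, 2))   # unbatched (L=2, H_in=2) sample, as in the diff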
@@ -1192,55 +1282,12 @@ module_db: List[ModuleInfo] = [
                module_inputs_func=module_inputs_torch_nn_Sigmoid),
     ModuleInfo(torch.nn.RNN,
                module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=True),
-               decorators=(
-                   # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
-                   # We could not generate a fallback
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_grad",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented.
-                   # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_gradgrad",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # CUDNN RNN doesn't accept non-contiguous hx
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # MIOPEN RNN doesn't accept non-contiguous hx (this is dispatched to miopen only for float).
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
-                       active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
-                   ),
-               )
+               decorators=rnn_gru_lstm_module_info_decorators
                ),
     ModuleInfo(torch.nn.GRU,
                module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=False),
-               decorators=(
-                   # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
-                   # We could not generate a fallback
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_grad",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented.
-                   # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_gradgrad",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # CUDNN GRU doesn't accept non-contiguous hx
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # MIOPEN GRU doesn't accept non-contiguous hx (this is dispatched to miopen only for float).
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
-                       active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
-                   ),
-               ))
+               decorators=rnn_gru_lstm_module_info_decorators),
+    ModuleInfo(torch.nn.LSTM,
+               module_inputs_func=module_inputs_torch_nn_LSTM,
+               decorators=rnn_gru_lstm_module_info_decorators)
 ]