[nn] lstm : no batch dim support (#71056)

Summary:
Adds support for unbatched (no batch dim) inputs to `nn.LSTM`, updates the LSTM/RNN/GRU docs accordingly, and extends the ModuleInfo-based test infrastructure to cover the new samples.

Reference: https://github.com/pytorch/pytorch/issues/60585

TODO:
* [x] Update docs
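
A minimal sketch of the unbatched usage this PR enables (sizes are illustrative; the shapes follow the updated LSTM docstring in the diff below):

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
    inp = torch.randn(5, 10)    # (L, H_in): no batch dimension
    h0 = torch.randn(2, 20)     # (D * num_layers, H_out)
    c0 = torch.randn(2, 20)     # (D * num_layers, H_cell)
    out, (hn, cn) = lstm(inp, (h0, c0))
    print(out.shape)            # torch.Size([5, 20]) -> (L, D * H_out)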

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71056

Reviewed By: samdow

Differential Revision: D33638643

Pulled By: jbschlosser

fbshipit-source-id: c0949829de8a8e6e7b2873f459a8d7da597a3be3
(cherry picked from commit f94d5849f6)
Author: kshitij12345
Date: 2022-01-24 07:08:32 -08:00
Committed by: PyTorch MergeBot
Parent: 99d9883a22
Commit: b372be4211

3 changed files with 151 additions and 65 deletions


@@ -363,8 +363,11 @@ class TestModule(TestCase):
                 grad_output = default_output.clone().detach_().normal_()
                 default_output.backward(grad_output, retain_graph=True)
             else:
-                grad_output = tuple(o.clone().detach_().normal_() for o in default_output)
-                for o, g_o in zip(default_output, grad_output):
+                grad_output = tuple(self._traverse_obj(o, lambda o: o.clone().detach_().normal_())
+                                    for o in default_output)
+                flattened_default_output, _ = torch.utils._pytree.tree_flatten(default_output)
+                flattened_grad_output, _ = torch.utils._pytree.tree_flatten(grad_output)
+                for o, g_o in zip(flattened_default_output, flattened_grad_output):
                     o.backward(g_o, retain_graph=True)

             default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs)))
@@ -388,7 +391,9 @@ class TestModule(TestCase):
             if isinstance(out, torch.Tensor):
                 out.backward(g_out_copy, retain_graph=True)
             else:
-                for o, g_o in zip(out, g_out_copy):
+                flattened_out, _ = torch.utils._pytree.tree_flatten(out)
+                flattened_g_out_copy, _ = torch.utils._pytree.tree_flatten(g_out_copy)
+                for o, g_o in zip(flattened_out, flattened_g_out_copy):
                     o.backward(g_o, retain_graph=True)

             input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs))
@@ -447,7 +452,9 @@ class TestModule(TestCase):
                 new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}

                 with freeze_rng_state():
-                    return m(*new_input_args, **new_kwargs, **other_kwargs)
+                    output = m(*new_input_args, **new_kwargs, **other_kwargs)
+                    output_flattened, _ = torch.utils._pytree.tree_flatten(output)
+                    return output_flattened

             self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
@@ -531,7 +538,9 @@ class TestModule(TestCase):
                 if isinstance(cpu_outputs, torch.Tensor):
                     check_backward(cpu_outputs, gpu_outputs)
                 else:
-                    for cpu_output, gpu_output in zip(cpu_outputs, gpu_outputs):
+                    flatten_cpu_outputs, _ = torch.utils._pytree.tree_flatten(cpu_outputs)
+                    flatten_gpu_outputs, _ = torch.utils._pytree.tree_flatten(gpu_outputs)
+                    for cpu_output, gpu_output in zip(flatten_cpu_outputs, flatten_gpu_outputs):
                         check_backward(cpu_output, gpu_output)
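
The tree_flatten calls above handle nested module outputs such as LSTM's (output, (h_n, c_n)) tuple. A small standalone illustration of what torch.utils._pytree.tree_flatten returns (not part of the diff):

    import torch
    from torch.utils._pytree import tree_flatten

    out = torch.randn(5, 3)
    h_n, c_n = torch.randn(1, 3), torch.randn(1, 3)
    nested = (out, (h_n, c_n))          # LSTM-style nested output

    flat, spec = tree_flatten(nested)   # flat == [out, h_n, c_n]; spec records the nesting
    print(len(flat))                    # 3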


@@ -393,7 +393,7 @@ class RNN(RNNBase):
             ``output.view(seq_len, batch, num_directions, hidden_size)``.

         .. note::
-            `batch_first` argument is ignored for unbatched inputs.
+            ``batch_first`` argument is ignored for unbatched inputs.

         .. include:: ../cudnn_rnn_determinism.rst
@@ -568,16 +568,19 @@ class LSTM(RNNBase):
         proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0

     Inputs: input, (h_0, c_0)
-        * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or
+        * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
+          :math:`(L, N, H_{in})` when ``batch_first=False`` or
           :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
           the input sequence. The input can also be a packed variable length sequence.
           See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
           :func:`torch.nn.utils.rnn.pack_sequence` for details.
-        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the
-          initial hidden state for each element in the batch.
+        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
+          :math:`(D * \text{num\_layers}, N, H_{out})` containing the
+          initial hidden state for each element in the input sequence.
           Defaults to zeros if (h_0, c_0) is not provided.
-        * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
-          initial cell state for each element in the batch.
+        * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
+          :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
+          initial cell state for each element in the input sequence.
           Defaults to zeros if (h_0, c_0) is not provided.

         where:
@@ -593,15 +596,18 @@ class LSTM(RNNBase):
             \end{aligned}

     Outputs: output, (h_n, c_n)
-        * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
+        * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
+          :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
           :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
           `(h_t)` from the last layer of the LSTM, for each `t`. If a
           :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
           will also be a packed sequence.
-        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the
-          final hidden state for each element in the batch.
-        * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
-          final cell state for each element in the batch.
+        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
+          :math:`(D * \text{num\_layers}, N, H_{out})` containing the
+          final hidden state for each element in the sequence.
+        * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
+          :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
+          final cell state for each element in the sequence.

     Attributes:
         weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
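
A quick shape check of the documented unbatched outputs (an illustrative sketch with arbitrary sizes, bidirectional and with projections):

    import torch
    import torch.nn as nn

    # D = 2 (bidirectional), H_out = proj_size = 4, H_cell = hidden_size = 5
    lstm = nn.LSTM(input_size=3, hidden_size=5, num_layers=2, proj_size=4, bidirectional=True)
    x = torch.randn(7, 3)       # (L, H_in), unbatched
    out, (h_n, c_n) = lstm(x)
    print(out.shape)            # torch.Size([7, 8])  -> (L, D * H_out)
    print(h_n.shape)            # torch.Size([4, 4])  -> (D * num_layers, H_out)
    print(c_n.shape)            # torch.Size([4, 5])  -> (D * num_layers, H_cell)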
@@ -639,6 +645,9 @@ class LSTM(RNNBase):
         Example of splitting the output layers when ``batch_first=False``:
         ``output.view(seq_len, batch, num_directions, hidden_size)``.

+    .. note::
+        ``batch_first`` argument is ignored for unbatched inputs.
+
     .. include:: ../cudnn_rnn_determinism.rst

     .. include:: ../cudnn_persistent_rnn.rst
@@ -704,12 +713,17 @@ class LSTM(RNNBase):
     def forward(self, input, hx=None):  # noqa: F811
         orig_input = input
         # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        batch_sizes = None
         if isinstance(orig_input, PackedSequence):
             input, batch_sizes, sorted_indices, unsorted_indices = input
             max_batch_size = batch_sizes[0]
             max_batch_size = int(max_batch_size)
         else:
             batch_sizes = None
+            is_batched = input.dim() == 3
+            batch_dim = 0 if self.batch_first else 1
+            if not is_batched:
+                input = input.unsqueeze(batch_dim)
             max_batch_size = input.size(0) if self.batch_first else input.size(1)
             sorted_indices = None
             unsorted_indices = None
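
The batch_dim selected above is where a singleton batch dimension is inserted for unbatched input; illustratively:

    import torch

    x = torch.randn(5, 10)          # unbatched sequence: (L, H_in)
    print(x.unsqueeze(1).shape)     # torch.Size([5, 1, 10]) -> (L, N=1, H_in), batch_first=False
    print(x.unsqueeze(0).shape)     # torch.Size([1, 5, 10]) -> (N=1, L, H_in), batch_first=True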
@@ -725,6 +739,19 @@ class LSTM(RNNBase):
                                   dtype=input.dtype, device=input.device)
             hx = (h_zeros, c_zeros)
         else:
+            if batch_sizes is None:  # If not PackedSequence input.
+                if is_batched:
+                    if (hx[0].dim() != 3 or hx[1].dim() != 3):
+                        msg = ("For batched 3-D input, hx and cx should "
+                               f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
+                        raise RuntimeError(msg)
+                else:
+                    if hx[0].dim() != 2 or hx[1].dim() != 2:
+                        msg = ("For unbatched 2-D input, hx and cx should "
+                               f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
+                        raise RuntimeError(msg)
+                    hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
+
             # Each batch of the hidden state should match the input sequence that
             # the user believes he/she is passing in.
             hx = self.permute_hidden(hx, sorted_indices)
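
With these checks, passing hidden states whose dimensionality does not match the input fails loudly; for example (a sketch, error text as added in the diff):

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=4, hidden_size=8)
    x = torch.randn(5, 4)                                    # unbatched 2-D input
    bad_hx = (torch.randn(1, 1, 8), torch.randn(1, 1, 8))    # 3-D hidden/cell states
    try:
        lstm(x, bad_hx)
    except RuntimeError as e:
        print(e)  # For unbatched 2-D input, hx and cx should also be 2-D but got (3-D, 3-D) tensors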
@@ -743,6 +770,9 @@ class LSTM(RNNBase):
             output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
             return output_packed, self.permute_hidden(hidden, unsorted_indices)
         else:
+            if not is_batched:
+                output = output.squeeze(batch_dim)
+                hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
             return output, self.permute_hidden(hidden, unsorted_indices)
@@ -844,7 +874,7 @@ class GRU(RNNBase):
             ``output.view(seq_len, batch, num_directions, hidden_size)``.

         .. note::
-            `batch_first` argument is ignored for unbatched inputs.
+            ``batch_first`` argument is ignored for unbatched inputs.

         .. include:: ../cudnn_persistent_rnn.rst


@@ -379,6 +379,30 @@ def no_batch_dim_reference_rnn_gru(m, p, *args, **kwargs):
     return (output[0].squeeze(batch_dim), output[1].squeeze(1))


+def no_batch_dim_reference_lstm(m, p, *args, **kwargs):
+    """Reference function for LSTM supporting no batch dimensions.
+
+    Unbatched inputs are unsqueezed to form a
+    single batch input before passing them to the module.
+    The output is squeezed to compare with the
+    output of unbatched input to the module.
+    """
+    if len(args) == 1:
+        inp, = args
+        h = None
+    elif len(args) == 2:
+        inp, h = args
+        h = (h[0].unsqueeze(1), h[1].unsqueeze(1))
+
+    batch_dim = 0 if kwargs['batch_first'] else 1
+    kwargs.pop('batch_first')
+    inp = inp.unsqueeze(batch_dim)
+    single_batch_input_args = (inp, h)
+    with freeze_rng_state():
+        output = m(*single_batch_input_args, **kwargs)
+        return (output[0].squeeze(batch_dim), (output[1][0].squeeze(1), output[1][1].squeeze(1)))
+
+
 def no_batch_dim_reference_lstmcell(m, p, *args, **kwargs):
     """Reference function for LSTMCell supporting no batch dimensions.
@@ -901,6 +925,72 @@ def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, **kwargs):
     return samples


+def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    bias = (False, True)
+    batch_first = (False, True)
+    bidirectional = (False, True)
+    proj_sizes = (0, 2)
+
+    samples = []
+    prod_gen = product(bias, batch_first, bidirectional, proj_sizes)
+
+    for args in prod_gen:
+        b, b_f, bidir, proj_size = args
+        hidden_size = 3
+        cons_args = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
+                     'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+        cons_args_hidden = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
+                            'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args),
+                forward_input=FunctionInput(make_input((2, 2))),
+                reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
+            )
+        )
+
+        h_out = proj_size if proj_size > 0 else hidden_size
+        hx = (make_input((4 if bidir else 2, h_out)), make_input((4 if bidir else 2, hidden_size)))
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args_hidden),
+                forward_input=FunctionInput(make_input((3, 2)), hx),
+                reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
+            )
+        )
+
+    return samples
+
+
+# All these operators share similar issues on cuDNN and MIOpen
+rnn_gru_lstm_module_info_decorators = (
+    # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
+    # We could not generate a fallback
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_grad",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented.
+    # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_gradgrad",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # CUDNN GRU doesn't accept non-contiguous hx
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # MIOPEN GRU doesn't accept non-contiguous hx (this is dispatched to miopen only for float).
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
+        active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
+    ),
+)
+
+
 # Database of ModuleInfo entries in alphabetical order.
 module_db: List[ModuleInfo] = [
     ModuleInfo(torch.nn.AdaptiveAvgPool2d,
@@ -1192,55 +1282,12 @@ module_db: List[ModuleInfo] = [
                module_inputs_func=module_inputs_torch_nn_Sigmoid),
     ModuleInfo(torch.nn.RNN,
                module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=True),
-               decorators=(
-                   # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
-                   # We could not generate a fallback
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_grad",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented.
-                   # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_gradgrad",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # CUDNN RNN doesn't accept non-contiguous hx
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # MIOPEN RNN doesn't accept non-contiguous hx (this is dispatched to miopen only for float).
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
-                       active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
-                   ),
-               )
+               decorators=rnn_gru_lstm_module_info_decorators
                ),
     ModuleInfo(torch.nn.GRU,
                module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=False),
-               decorators=(
-                   # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
-                   # We could not generate a fallback
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_grad",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented.
-                   # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_gradgrad",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # CUDNN GRU doesn't accept non-contiguous hx
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # MIOPEN GRU doesn't accept non-contiguous hx (this is dispatched to miopen only for float).
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
-                       active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
-                   ),
-               ))
+               decorators=rnn_gru_lstm_module_info_decorators),
+    ModuleInfo(torch.nn.LSTM,
+               module_inputs_func=module_inputs_torch_nn_LSTM,
+               decorators=rnn_gru_lstm_module_info_decorators)
 ]