[nn] lstm : no batch dim support (#71056)
Summary:
Reference: https://github.com/pytorch/pytorch/issues/60585
TODO:
* [x] Update docs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71056
Reviewed By: samdow
Differential Revision: D33638643
Pulled By: jbschlosser
fbshipit-source-id: c0949829de8a8e6e7b2873f459a8d7da597a3be3
(cherry picked from commit f94d5849f6)
parent 99d9883a22
commit b372be4211
@@ -363,8 +363,11 @@ class TestModule(TestCase):
                     grad_output = default_output.clone().detach_().normal_()
                     default_output.backward(grad_output, retain_graph=True)
                 else:
-                    grad_output = tuple(o.clone().detach_().normal_() for o in default_output)
-                    for o, g_o in zip(default_output, grad_output):
+                    grad_output = tuple(self._traverse_obj(o, lambda o: o.clone().detach_().normal_())
+                                        for o in default_output)
+                    flattened_default_output, _ = torch.utils._pytree.tree_flatten(default_output)
+                    flattened_grad_output, _ = torch.utils._pytree.tree_flatten(grad_output)
+                    for o, g_o in zip(flattened_default_output, flattened_grad_output):
                         o.backward(g_o, retain_graph=True)
 
                 default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs)))
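The test-harness changes above (and in the next three hunks) all follow the same pattern: nested module outputs such as LSTM's (output, (h_n, c_n)) are flattened with torch.utils._pytree.tree_flatten before gradients are paired up leaf-by-leaf. A minimal sketch of that step against the public nn.LSTM API (illustration only, not part of the diff):

    import torch
    from torch.utils._pytree import tree_flatten

    lstm = torch.nn.LSTM(input_size=2, hidden_size=3)
    out = lstm(torch.randn(5, 1, 2))            # nested output: (output, (h_n, c_n))

    # tree_flatten returns the leaf tensors plus a spec describing the nesting,
    # so grad_outputs can be generated and zipped leaf-by-leaf, as the test now does.
    leaves, spec = tree_flatten(out)            # leaves == [output, h_n, c_n]
    grad_outputs = [torch.ones_like(leaf) for leaf in leaves]
    for o, g_o in zip(leaves, grad_outputs):
        o.backward(g_o, retain_graph=True)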
@@ -388,7 +391,9 @@ class TestModule(TestCase):
                 if isinstance(out, torch.Tensor):
                     out.backward(g_out_copy, retain_graph=True)
                 else:
-                    for o, g_o in zip(out, g_out_copy):
+                    flattened_out, _ = torch.utils._pytree.tree_flatten(out)
+                    flattened_g_out_copy, _ = torch.utils._pytree.tree_flatten(g_out_copy)
+                    for o, g_o in zip(flattened_out, flattened_g_out_copy):
                         o.backward(g_o, retain_graph=True)
 
                 input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs))
@@ -447,7 +452,9 @@ class TestModule(TestCase):
             new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}
 
             with freeze_rng_state():
-                return m(*new_input_args, **new_kwargs, **other_kwargs)
+                output = m(*new_input_args, **new_kwargs, **other_kwargs)
+                output_flattened, _ = torch.utils._pytree.tree_flatten(output)
+                return output_flattened
 
         self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
 
@@ -531,7 +538,9 @@ class TestModule(TestCase):
             if isinstance(cpu_outputs, torch.Tensor):
                 check_backward(cpu_outputs, gpu_outputs)
             else:
-                for cpu_output, gpu_output in zip(cpu_outputs, gpu_outputs):
+                flatten_cpu_outputs, _ = torch.utils._pytree.tree_flatten(cpu_outputs)
+                flatten_gpu_outputs, _ = torch.utils._pytree.tree_flatten(gpu_outputs)
+                for cpu_output, gpu_output in zip(flatten_cpu_outputs, flatten_gpu_outputs):
                     check_backward(cpu_output, gpu_output)
 
 
@@ -393,7 +393,7 @@ class RNN(RNNBase):
     ``output.view(seq_len, batch, num_directions, hidden_size)``.
 
     .. note::
-        `batch_first` argument is ignored for unbatched inputs.
+        ``batch_first`` argument is ignored for unbatched inputs.
 
     .. include:: ../cudnn_rnn_determinism.rst
 
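Illustration of the note (my own example, not from the patch): with an unbatched input there is no batch axis to move, so batch_first has no observable effect.

    import torch

    inp = torch.randn(5, 10)                        # (L, H_in), no batch dimension
    rnn_a = torch.nn.RNN(10, 20, batch_first=False)
    rnn_b = torch.nn.RNN(10, 20, batch_first=True)
    rnn_b.load_state_dict(rnn_a.state_dict())       # same weights in both modules

    out_a, h_a = rnn_a(inp)
    out_b, h_b = rnn_b(inp)
    # batch_first is ignored for unbatched inputs, so the two calls agree.
    assert torch.allclose(out_a, out_b) and torch.allclose(h_a, h_b)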
@@ -568,16 +568,19 @@ class LSTM(RNNBase):
         proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0
 
     Inputs: input, (h_0, c_0)
-        * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or
+        * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
+          :math:`(L, N, H_{in})` when ``batch_first=False`` or
           :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
           the input sequence. The input can also be a packed variable length sequence.
           See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
           :func:`torch.nn.utils.rnn.pack_sequence` for details.
-        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the
-          initial hidden state for each element in the batch.
+        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
+          :math:`(D * \text{num\_layers}, N, H_{out})` containing the
+          initial hidden state for each element in the input sequence.
           Defaults to zeros if (h_0, c_0) is not provided.
-        * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
-          initial cell state for each element in the batch.
+        * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
+          :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
+          initial cell state for each element in the input sequence.
           Defaults to zeros if (h_0, c_0) is not provided.
 
     where:
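The unbatched input shapes documented above can be exercised directly once this patch is in place; a short sketch (mine, not part of the diff):

    import torch

    lstm = torch.nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

    # Unbatched: input is (L, H_in); h_0 and c_0 drop the batch dimension N.
    inp = torch.randn(5, 10)                        # (L, H_in)
    h_0 = torch.randn(2, 20)                        # (D * num_layers, H_out)
    c_0 = torch.randn(2, 20)                        # (D * num_layers, H_cell)
    out, (h_n, c_n) = lstm(inp, (h_0, c_0))

    # Batched (batch_first=False): the same call with N inserted at dim 1 of the
    # input and dim 1 of the hidden states.
    inp_b = torch.randn(5, 3, 10)                   # (L, N, H_in)
    h_0_b = torch.randn(2, 3, 20)                   # (D * num_layers, N, H_out)
    c_0_b = torch.randn(2, 3, 20)                   # (D * num_layers, N, H_cell)
    out_b, (h_n_b, c_n_b) = lstm(inp_b, (h_0_b, c_0_b))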
@@ -593,15 +596,18 @@ class LSTM(RNNBase):
         \end{aligned}
 
     Outputs: output, (h_n, c_n)
-        * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
+        * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
+          :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
           :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
           `(h_t)` from the last layer of the LSTM, for each `t`. If a
           :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
           will also be a packed sequence.
-        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the
-          final hidden state for each element in the batch.
-        * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
-          final cell state for each element in the batch.
+        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
+          :math:`(D * \text{num\_layers}, N, H_{out})` containing the
+          final hidden state for each element in the sequence.
+        * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
+          :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
+          final cell state for each element in the sequence.
 
     Attributes:
         weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
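And the corresponding output shapes, using a bidirectional, projected LSTM so that D and H_out differ from their defaults (sketch, not part of the diff):

    import torch

    # D = 2 (bidirectional), H_out = proj_size = 4, H_cell = hidden_size = 20.
    lstm = torch.nn.LSTM(input_size=10, hidden_size=20, num_layers=2,
                         bidirectional=True, proj_size=4)

    out, (h_n, c_n) = lstm(torch.randn(5, 10))      # unbatched (L, H_in) input
    print(out.shape)    # torch.Size([5, 8])  == (L, D * H_out)
    print(h_n.shape)    # torch.Size([4, 4])  == (D * num_layers, H_out)
    print(c_n.shape)    # torch.Size([4, 20]) == (D * num_layers, H_cell)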
@@ -639,6 +645,9 @@ class LSTM(RNNBase):
     Example of splitting the output layers when ``batch_first=False``:
     ``output.view(seq_len, batch, num_directions, hidden_size)``.
 
+    .. note::
+        ``batch_first`` argument is ignored for unbatched inputs.
+
     .. include:: ../cudnn_rnn_determinism.rst
 
     .. include:: ../cudnn_persistent_rnn.rst
@@ -704,12 +713,17 @@ class LSTM(RNNBase):
     def forward(self, input, hx=None):  # noqa: F811
         orig_input = input
         # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        batch_sizes = None
         if isinstance(orig_input, PackedSequence):
             input, batch_sizes, sorted_indices, unsorted_indices = input
             max_batch_size = batch_sizes[0]
             max_batch_size = int(max_batch_size)
         else:
             batch_sizes = None
+            is_batched = input.dim() == 3
+            batch_dim = 0 if self.batch_first else 1
+            if not is_batched:
+                input = input.unsqueeze(batch_dim)
             max_batch_size = input.size(0) if self.batch_first else input.size(1)
             sorted_indices = None
             unsorted_indices = None
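The new branch makes an unbatched call equivalent to unsqueezing a batch dimension of size one, running the batched path, and squeezing it back out on return (the squeeze is in a later hunk). A quick equivalence check (my own sketch):

    import torch

    lstm = torch.nn.LSTM(input_size=10, hidden_size=20)
    inp = torch.randn(5, 10)                  # unbatched; batch_first=False, so batch_dim == 1

    out_u, (h_u, c_u) = lstm(inp)
    out_b, (h_b, c_b) = lstm(inp.unsqueeze(1))        # explicit N=1 batch

    assert torch.allclose(out_u, out_b.squeeze(1))
    assert torch.allclose(h_u, h_b.squeeze(1))
    assert torch.allclose(c_u, c_b.squeeze(1))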
@@ -725,6 +739,19 @@ class LSTM(RNNBase):
                                   dtype=input.dtype, device=input.device)
             hx = (h_zeros, c_zeros)
         else:
+            if batch_sizes is None:  # If not PackedSequence input.
+                if is_batched:
+                    if (hx[0].dim() != 3 or hx[1].dim() != 3):
+                        msg = ("For batched 3-D input, hx and cx should "
+                               f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
+                        raise RuntimeError(msg)
+                else:
+                    if hx[0].dim() != 2 or hx[1].dim() != 2:
+                        msg = ("For unbatched 2-D input, hx and cx should "
+                               f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
+                        raise RuntimeError(msg)
+                    hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
+
             # Each batch of the hidden state should match the input sequence that
             # the user believes he/she is passing in.
             hx = self.permute_hidden(hx, sorted_indices)
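With this validation in place, mixing an unbatched input with batched hidden states (or vice versa) fails loudly instead of broadcasting; a sketch of the failure mode the added check produces (example mine):

    import torch

    lstm = torch.nn.LSTM(input_size=10, hidden_size=20, num_layers=1)

    inp = torch.randn(5, 10)            # unbatched 2-D input
    h_0 = torch.randn(1, 3, 20)         # 3-D (batched) hidden state
    c_0 = torch.randn(1, 3, 20)

    try:
        lstm(inp, (h_0, c_0))
    except RuntimeError as e:
        # "For unbatched 2-D input, hx and cx should also be 2-D but got (3-D, 3-D) tensors"
        print(e)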
@@ -743,6 +770,9 @@ class LSTM(RNNBase):
             output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
             return output_packed, self.permute_hidden(hidden, unsorted_indices)
         else:
+            if not is_batched:
+                output = output.squeeze(batch_dim)
+                hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
             return output, self.permute_hidden(hidden, unsorted_indices)
 
 
@@ -844,7 +874,7 @@ class GRU(RNNBase):
     ``output.view(seq_len, batch, num_directions, hidden_size)``.
 
     .. note::
-        `batch_first` argument is ignored for unbatched inputs.
+        ``batch_first`` argument is ignored for unbatched inputs.
 
     .. include:: ../cudnn_persistent_rnn.rst
 
@@ -379,6 +379,30 @@ def no_batch_dim_reference_rnn_gru(m, p, *args, **kwargs):
         return (output[0].squeeze(batch_dim), output[1].squeeze(1))
 
 
+def no_batch_dim_reference_lstm(m, p, *args, **kwargs):
+    """Reference function for LSTM supporting no batch dimensions.
+
+    Unbatched inputs are unsqueezed to form a
+    single batch input before passing them to the module.
+    The output is squeezed to compare with the
+    output of unbatched input to the module.
+    """
+    if len(args) == 1:
+        inp, = args
+        h = None
+    elif len(args) == 2:
+        inp, h = args
+        h = (h[0].unsqueeze(1), h[1].unsqueeze(1))
+
+    batch_dim = 0 if kwargs['batch_first'] else 1
+    kwargs.pop('batch_first')
+    inp = inp.unsqueeze(batch_dim)
+    single_batch_input_args = (inp, h)
+    with freeze_rng_state():
+        output = m(*single_batch_input_args, **kwargs)
+        return (output[0].squeeze(batch_dim), (output[1][0].squeeze(1), output[1][1].squeeze(1)))
+
+
 def no_batch_dim_reference_lstmcell(m, p, *args, **kwargs):
     """Reference function for LSTMCell supporting no batch dimensions.
 
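The helper can also be called directly, outside the ModuleInfo harness; the p argument is unused in the body shown above, so it is passed as None below purely as a placeholder (a sketch under that assumption, not how the test infrastructure invokes it):

    import torch
    from torch.testing._internal.common_modules import no_batch_dim_reference_lstm

    lstm = torch.nn.LSTM(input_size=2, hidden_size=3)
    inp = torch.randn(4, 2)                              # unbatched (L, H_in)
    hx = (torch.randn(1, 3), torch.randn(1, 3))          # unbatched (num_layers, H_out/H_cell)

    ref_out, (ref_h, ref_c) = no_batch_dim_reference_lstm(lstm, None, inp, hx, batch_first=False)
    out, (h_n, c_n) = lstm(inp, hx)                      # native no-batch-dim path
    assert torch.allclose(out, ref_out)
    assert torch.allclose(h_n, ref_h) and torch.allclose(c_n, ref_c)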
@@ -901,6 +925,72 @@ def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, **kwargs):
     return samples
 
 
+def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    bias = (False, True)
+    batch_first = (False, True)
+    bidirectional = (False, True)
+    proj_sizes = (0, 2)
+
+    samples = []
+    prod_gen = product(bias, batch_first, bidirectional, proj_sizes)
+
+    for args in prod_gen:
+        b, b_f, bidir, proj_size = args
+        hidden_size = 3
+        cons_args = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
+                     'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+        cons_args_hidden = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
+                            'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args),
+                forward_input=FunctionInput(make_input((2, 2))),
+                reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
+            )
+        )
+
+        h_out = proj_size if proj_size > 0 else hidden_size
+        hx = (make_input((4 if bidir else 2, h_out)), make_input((4 if bidir else 2, hidden_size)))
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args_hidden),
+                forward_input=FunctionInput(make_input((3, 2)), hx),
+                reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
+            )
+        )
+
+    return samples
+
+
+# All these operators share similar issues on cuDNN and MIOpen
+rnn_gru_lstm_module_info_decorators = (
+    # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
+    # We could not generate a fallback
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_grad",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented.
+    # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_gradgrad",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # CUDNN GRU doesn't accept non-contiguous hx
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # MIOPEN GRU doesn't accept non-contiguous hx (this is dispatched to miopen only for float).
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
+        active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
+    ),
+)
+
+
 # Database of ModuleInfo entries in alphabetical order.
 module_db: List[ModuleInfo] = [
     ModuleInfo(torch.nn.AdaptiveAvgPool2d,
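Each of the 16 constructor configurations (2 values each for bias, batch_first, bidirectional, and proj_size) contributes two samples, one without and one with an explicit (h_0, c_0), for 32 unbatched samples in total. A sketch of what a single configuration exercises (illustrative, one configuration hard-coded):

    import torch

    # One point from the product above: bias=True, batch_first=False,
    # bidirectional=True, proj_size=2, so H_out = proj_size = 2 and the
    # hidden states carry D * num_layers = 4 layers.
    lstm = torch.nn.LSTM(input_size=2, hidden_size=3, num_layers=2, proj_size=2,
                         batch_first=False, bias=True, bidirectional=True)

    # Sample 1: unbatched input only, shape (L, H_in) = (2, 2).
    out, (h_n, c_n) = lstm(torch.randn(2, 2))

    # Sample 2: unbatched input of shape (3, 2) plus an unbatched (h_0, c_0).
    hx = (torch.randn(4, 2), torch.randn(4, 3))   # (D*num_layers, H_out), (D*num_layers, hidden_size)
    out2, (h_n2, c_n2) = lstm(torch.randn(3, 2), hx)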
@@ -1192,55 +1282,12 @@ module_db: List[ModuleInfo] = [
                module_inputs_func=module_inputs_torch_nn_Sigmoid),
     ModuleInfo(torch.nn.RNN,
                module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=True),
-               decorators=(
-                   # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
-                   # We could not generate a fallback
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_grad",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented.
-                   # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_gradgrad",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # CUDNN RNN doesn't accept non-contiguous hx
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # MIOPEN RNN doesn't accept non-contiguous hx (this is dispatched to miopen only for float).
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
-                       active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
-                   ),
-               )
+               decorators=rnn_gru_lstm_module_info_decorators
                ),
     ModuleInfo(torch.nn.GRU,
                module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=False),
-               decorators=(
-                   # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
-                   # We could not generate a fallback
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_grad",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented.
-                   # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_gradgrad",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # CUDNN GRU doesn't accept non-contiguous hx
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
-                       active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
-                   ),
-                   # MIOPEN GRU doesn't accept non-contiguous hx (this is dispatched to miopen only for float).
-                   DecorateInfo(
-                       unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
-                       active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
-                   ),
-               ))
+               decorators=rnn_gru_lstm_module_info_decorators),
+    ModuleInfo(torch.nn.LSTM,
+               module_inputs_func=module_inputs_torch_nn_LSTM,
+               decorators=rnn_gru_lstm_module_info_decorators)
 ]