pytorch/test/cpp
Will Feng e23a9dc140 [C++ API] RNN / GRU / LSTM layer refactoring (#34322)
Summary:
This PR refactors RNN / GRU / LSTM layers in C++ API to exactly match the implementation in Python API.

**BC-breaking changes:**
- Instead of returning `RNNOutput`, the RNN / GRU forward method now returns `std::tuple<Tensor, Tensor>`, and the LSTM forward method now returns `std::tuple<Tensor, std::tuple<Tensor, Tensor>>`, matching Python API.
- RNN / LSTM / GRU forward methods now accept the same inputs (input tensor and optionally hidden state), matching Python API.
- RNN / LSTM / GRU now have a `forward_with_packed_input` method, which accepts a `PackedSequence` as input and optionally hidden state, matching the `forward(PackedSequence, ...)` variant in Python API.
- In `RNNOptions`
    - `tanh()` / `relu()` / `activation` are removed. Instead, `nonlinearity` is added, which takes either `torch::kTanh` or `torch::kReLU`
    - `layers` -> `num_layers`
    - `with_bias` -> `bias`
- In `LSTMOptions`
    - `layers` -> `num_layers`
    - `with_bias` -> `bias`
- In `GRUOptions`
    - `layers` -> `num_layers`
    - `with_bias` -> `bias`

Most of the changes in this PR focus on refactoring the implementations in `torch/csrc/api/src/nn/modules/rnn.cpp` to match the Python API. The RNN tests are then updated to reflect the revised API design.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34322

Differential Revision: D20311699

Pulled By: yf225

fbshipit-source-id: e2b60fc7bac64367a8434647d74c08568a7b28f7
2020-03-14 12:09:04 -07:00
api [C++ API] RNN / GRU / LSTM layer refactoring (#34322) 2020-03-14 12:09:04 -07:00
common Trim libshm deps, move tempfile.h to c10 (#17019) 2019-02-13 19:38:35 -08:00
dist_autograd [Dist Autograd] Functional API for Dist Autograd and Dist Optimizer (#33711) 2020-02-26 19:08:28 -08:00
jit Move torchbind out of jit namespace (#34745) 2020-03-13 23:03:14 -07:00
rpc [pytorch-rpc] WireSerializer should check has_storage() (#34626) 2020-03-12 11:35:21 -07:00
tensorexpr [TensorExpr] Add IR Printer. (#33220) 2020-02-21 13:10:26 -08:00
__init__.py Add train() / eval() / is_training() to C++ ScriptModule API (#16044) 2019-02-01 13:07:38 -08:00