Pretty printing of C++ modules (#15326)

Summary:
A long-outstanding nicety: pretty printing of C++ modules. E.g.
```
Sequential sequential(
    Linear(10, 3),
    Conv2d(1, 2, 3),
    Dropout(0.5),
    BatchNorm(5),
    Embedding(4, 10),
    LSTM(4, 5));
std::cout << sequential;
```
prints
```
torch::nn::Sequential(
  (0): torch::nn::Linear(in=10, out=3, with_bias=true)
  (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])
  (2): torch::nn::Dropout(rate=0.5)
  (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)
  (4): torch::nn::Embedding(count=4, dimension=10)
  (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)
)
```
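The hook behind this is a new `virtual void pretty_print(std::ostream&)` on `torch::nn::Module`, which user modules can override to control their own line in the printout. A minimal sketch of such an override (`MyModule` is a hypothetical example, not part of this diff):
```
#include <torch/torch.h>

#include <iostream>

// Hypothetical user module. Overriding pretty_print() replaces the default
// description (the module's name()) with a custom one-liner.
struct MyModule : torch::nn::Module {
  explicit MyModule(int64_t features) : features_(features) {}
  void pretty_print(std::ostream& stream) const override {
    stream << "MyModule(features=" << features_ << ")";
  }
  int64_t features_;
};

int main() {
  std::cout << MyModule(8) << '\n';  // prints: MyModule(features=8)
}
```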

apaszke ebetica ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15326

Differential Revision: D13518986

Pulled By: goldsborough

fbshipit-source-id: 63bf753672f0e348951de3645208f263581de5fb
Peter Goldsborough 2018-12-19 21:38:00 -08:00 committed by Facebook Github Bot
parent 2ef0f1222a
commit eb5d28ecef
23 changed files with 306 additions and 13 deletions

@@ -832,3 +832,23 @@ TEST_F(ModuleTest, ThrowsWhenAttemptingtoGetTopLevelModuleAsSharedPtr) {
ASSERT_NO_THROW(module->modules());
}
}
struct EmptyModule : torch::nn::Module {};
TEST_F(ModuleTest, PrettyPrint) {
struct TestModule : torch::nn::Module {
TestModule(int x, float y) : x_(x), y_(y) {}
void pretty_print(std::ostream& stream) const {
stream << "TestModule(x=" << x_ << ", y=" << y_ << ")";
}
int x_;
float y_;
};
using namespace torch::nn;
ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule");
ASSERT_EQ(c10::str(TestModule(1, 3.14)), "TestModule(x=1, y=3.14)");
}
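
For context, `c10::str` builds its result by streaming each argument into an `std::ostringstream`, so these assertions exercise the new `operator<<` overloads end to end. A rough equivalent of what the test relies on (the `pretty` helper is hypothetical):
```
#include <torch/torch.h>

#include <sstream>
#include <string>

// Roughly what c10::str(module) boils down to for a single Module argument.
std::string pretty(const torch::nn::Module& module) {
  std::ostringstream stream;
  stream << module;  // uses the operator<< added in this diff
  return stream.str();
}
```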

@@ -325,3 +325,85 @@ TEST_F(ModulesTest, Linear2_CUDA) {
ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
}
TEST_F(ModulesTest, PrettyPrintLinear) {
ASSERT_EQ(
c10::str(Linear(3, 4)), "torch::nn::Linear(in=3, out=4, with_bias=true)");
}
TEST_F(ModulesTest, PrettyPrintConv) {
ASSERT_EQ(
c10::str(Conv1d(3, 4, 5)),
"torch::nn::Conv1d(input_channels=3, output_channels=4, kernel_size=5, stride=1)");
ASSERT_EQ(
c10::str(Conv2d(3, 4, 5)),
"torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 5], stride=[1, 1])");
ASSERT_EQ(
c10::str(Conv2d(Conv2dOptions(3, 4, 5).stride(2))),
"torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 5], stride=[2, 2])");
const auto options = Conv2dOptions(3, 4, torch::IntList{5, 6}).stride({1, 2});
ASSERT_EQ(
c10::str(Conv2d(options)),
"torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 6], stride=[1, 2])");
}
TEST_F(ModulesTest, PrettyPrintDropout) {
ASSERT_EQ(c10::str(Dropout(0.5)), "torch::nn::Dropout(rate=0.5)");
ASSERT_EQ(
c10::str(FeatureDropout(0.5)), "torch::nn::FeatureDropout(rate=0.5)");
}
TEST_F(ModulesTest, PrettyPrintFunctional) {
ASSERT_EQ(c10::str(Functional(torch::relu)), "torch::nn::Functional()");
}
TEST_F(ModulesTest, PrettyPrintBatchNorm) {
ASSERT_EQ(
c10::str(BatchNorm(
BatchNormOptions(4).eps(0.5).momentum(0.1).affine(false).stateful(
true))),
"torch::nn::BatchNorm(features=4, eps=0.5, momentum=0.1, affine=false, stateful=true)");
}
TEST_F(ModulesTest, PrettyPrintEmbedding) {
ASSERT_EQ(
c10::str(Embedding(10, 2)),
"torch::nn::Embedding(count=10, dimension=2)");
}
TEST_F(ModulesTest, PrettyPrintNestedModel) {
struct InnerTestModule : torch::nn::Module {
InnerTestModule()
: torch::nn::Module("InnerTestModule"),
fc(register_module("fc", torch::nn::Linear(3, 4))),
table(register_module("table", torch::nn::Embedding(10, 2))) {}
torch::nn::Linear fc;
torch::nn::Embedding table;
};
struct TestModule : torch::nn::Module {
TestModule()
: torch::nn::Module("TestModule"),
fc(register_module("fc", torch::nn::Linear(4, 5))),
table(register_module("table", torch::nn::Embedding(10, 2))),
inner(register_module("inner", std::make_shared<InnerTestModule>())) {
}
torch::nn::Linear fc;
torch::nn::Embedding table;
std::shared_ptr<InnerTestModule> inner;
};
ASSERT_EQ(
c10::str(TestModule{}),
"TestModule(\n"
" (fc): torch::nn::Linear(in=4, out=5, with_bias=true)\n"
" (table): torch::nn::Embedding(count=10, dimension=2)\n"
" (inner): InnerTestModule(\n"
" (fc): torch::nn::Linear(in=3, out=4, with_bias=true)\n"
" (table): torch::nn::Embedding(count=10, dimension=2)\n"
" )\n"
")");
}

@@ -227,3 +227,15 @@ TEST_F(RNNTest, EndToEndRNNTanh_CUDA) {
ASSERT_TRUE(test_RNN_xor<RNN>(
[](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }, true));
}
TEST_F(RNNTest, PrettyPrintRNNs) {
ASSERT_EQ(
c10::str(LSTM(LSTMOptions(128, 64).layers(3).dropout(0.2))),
"torch::nn::LSTM(input_size=128, hidden_size=64, layers=3, dropout=0.2)");
ASSERT_EQ(
c10::str(GRU(GRUOptions(128, 64).layers(3).dropout(0.5))),
"torch::nn::GRU(input_size=128, hidden_size=64, layers=3, dropout=0.5)");
ASSERT_EQ(
c10::str(RNN(RNNOptions(128, 64).layers(3).dropout(0.2).tanh())),
"torch::nn::RNN(input_size=128, hidden_size=64, layers=3, dropout=0.2, activation=tanh)");
}

@@ -323,3 +323,23 @@ TEST_F(SequentialTest, CloneToDevice_CUDA) {
ASSERT_EQ(b.device(), device);
}
}
TEST_F(SequentialTest, PrettyPrintSequential) {
Sequential sequential(
Linear(10, 3),
Conv2d(1, 2, 3),
Dropout(0.5),
BatchNorm(5),
Embedding(4, 10),
LSTM(4, 5));
ASSERT_EQ(
c10::str(sequential),
"torch::nn::Sequential(\n"
" (0): torch::nn::Linear(in=10, out=3, with_bias=true)\n"
" (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
" (2): torch::nn::Dropout(rate=0.5)\n"
" (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
" (4): torch::nn::Embedding(count=4, dimension=10)\n"
" (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
")");
}

@@ -24,18 +24,17 @@ class ExpandingArray {
/// the length is checked against the `ExpandingArray`'s extent parameter `D`
/// at runtime.
/*implicit*/ ExpandingArray(std::initializer_list<T> list)
    : ExpandingArray(at::ArrayRef<T>(list)) {}
/// Constructs an `ExpandingArray` from an `at::ArrayRef`. The extent of
/// the length is checked against the `ExpandingArray`'s extent parameter `D`
/// at runtime.
/*implicit*/ ExpandingArray(at::ArrayRef<T> values) {
// clang-format off
AT_CHECK(
values.size() == D,
"Expected ",
D,
" values, but instead got ",
values.size());
"Expected ", D, " values, but instead got ", values.size());
// clang-format on
std::copy(values.begin(), values.end(), values_.begin());
}
@@ -84,4 +83,13 @@ class ExpandingArray {
std::array<T, D> values_;
};
template <size_t D, typename T>
std::ostream& operator<<(
std::ostream& stream,
const ExpandingArray<D, T>& expanding_array) {
if (expanding_array.size() == 1) {
return stream << expanding_array->at(0);
}
return stream << static_cast<at::ArrayRef<T>>(expanding_array);
}
} // namespace torch
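
This `operator<<` is why `Conv1d` above prints `kernel_size=5` while `Conv2d` prints `kernel_size=[5, 5]`: a one-element `ExpandingArray` is streamed bare, larger ones as a bracketed `ArrayRef`. A small usage sketch, assuming the existing single-value constructor that replicates its argument across all `D` slots:
```
#include <torch/torch.h>

#include <iostream>

int main() {
  torch::ExpandingArray<1> one(5);  // holds {5}
  torch::ExpandingArray<2> two(5);  // holds {5, 5}
  std::cout << one << '\n';  // prints: 5
  std::cout << two << '\n';  // prints: [5, 5]
}
```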

@@ -8,6 +8,7 @@
#include <ATen/ATen.h>
#include <functional>
#include <iosfwd>
#include <map>
#include <memory>
#include <string>
@@ -386,6 +387,15 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
/// Deserializes the `Module` from the given `InputArchive`.
virtual void load(serialize::InputArchive& archive);
/// Streams a pretty representation of the `Module` into the given `stream`.
/// By default, this representation will be the name of the module (taken from
/// `name()`), followed by a recursive pretty print of all of the `Module`'s
/// submodules.
///
/// Override this method to change the pretty print. The custom
/// representation should be written directly into the given `stream`.
virtual void pretty_print(std::ostream& stream) const;
protected:
/// Registers a parameter with this `Module`.
///
@@ -462,6 +472,11 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
template <typename Derived>
friend class Cloneable;
/// Pretty prints the given `Module` into the `ostream`.
TORCH_API friend std::ostream& operator<<(
std::ostream& stream,
const nn::Module& module);
// Private methods.
/// Used in the implementation of `Cloneable`.
@@ -471,6 +486,11 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
template <typename... Ts>
void to_impl(Ts&&... ts);
/// Implements pretty printing the module hierarchy.
void pretty_print_recursive(
std::ostream& stream,
const std::string& indentation) const;
/// Applies the `function` to every submodule recursively, starting at this
/// `Module`'s children (thus not including the module itself).
void apply_to_submodules(

@@ -53,6 +53,9 @@ class TORCH_API BatchNormImpl : public torch::nn::Cloneable<BatchNormImpl> {
void reset() override;
/// Pretty prints the `BatchNorm` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
/// Applies batch normalization on the `input` using the stored mean and
/// variance.
///

@@ -88,6 +88,9 @@ class ConvImpl : public torch::nn::Cloneable<Derived> {
void reset() override;
/// Pretty prints the `Conv{1,2,3}d` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
/// The options with which this `Module` was constructed.
ConvOptions<D> options;

@@ -39,9 +39,13 @@ class DropoutImplBase : public torch::nn::Cloneable<Derived> {
class TORCH_API DropoutImpl : public detail::DropoutImplBase<DropoutImpl> {
public:
using detail::DropoutImplBase<DropoutImpl>::DropoutImplBase;
/// During training, applies a noise mask to the input tensor.
/// During evaluation, applies an identity function.
Tensor forward(const Tensor& input);
/// Pretty prints the `Dropout` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
};
/// Applies spatial [Dropout](https://arxiv.org/abs/1207.0580) to inputs with
@@ -53,12 +57,17 @@ class TORCH_API DropoutImpl : public detail::DropoutImplBase<DropoutImpl> {
/// [Dropout3d](https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout3d) for
/// 3-D features. This `FeatureDropout` module can instead deal with both 2-D
/// and 3-D features.
class TORCH_API FeatureDropoutImpl
    : public detail::DropoutImplBase<FeatureDropoutImpl> {
public:
using detail::DropoutImplBase<FeatureDropoutImpl>::DropoutImplBase;
/// During training, applies a noise mask to the input tensor.
/// During evaluation, applies an identity function.
Tensor forward(const Tensor& input);
/// Pretty prints the `FeatureDropout` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
};
/// A `ModuleHolder` subclass for `DropoutImpl`.

@@ -28,6 +28,9 @@ class TORCH_API EmbeddingImpl : public torch::nn::Cloneable<EmbeddingImpl> {
void reset() override;
/// Pretty prints the `Embedding` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
/// Performs a lookup on the embedding table stored in `weight` using the
/// `indices` supplied and returns the result.
Tensor forward(const Tensor& indices);

@@ -79,6 +79,9 @@ class TORCH_API FunctionalImpl : public torch::nn::Cloneable<FunctionalImpl> {
void reset() override;
/// Pretty prints the `Functional` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
/// Forwards the `input` tensor to the underlying (bound) function object.
Tensor forward(Tensor input);

@@ -29,6 +29,9 @@ class TORCH_API LinearImpl : public Cloneable<LinearImpl> {
void reset() override;
/// Pretty prints the `Linear` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
/// Transforms the `input` tensor by multiplying with the `weight` and
/// optionally adding the `bias`, if `with_bias` is true in the options.
Tensor forward(const Tensor& input);

@@ -74,6 +74,9 @@ class RNNImplBase : public torch::nn::Cloneable<Derived> {
void to(torch::Dtype dtype, bool non_blocking = false) override;
void to(torch::Device device, bool non_blocking = false) override;
/// Pretty prints the RNN module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
/// Modifies the internal storage of weights for optimization purposes.
///
/// On CPU, this method should be called if any of the weight or bias vectors
@@ -136,7 +139,7 @@ class RNNImplBase : public torch::nn::Cloneable<Derived> {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
enum class RNNActivation : uint32_t TORCH_API{ReLU, Tanh};
/// Options for RNN modules.
struct TORCH_API RNNOptions {
@@ -177,6 +180,9 @@ class TORCH_API RNNImpl : public detail::RNNImplBase<RNNImpl> {
: RNNImpl(RNNOptions(input_size, hidden_size)) {}
explicit RNNImpl(const RNNOptions& options);
/// Pretty prints the `RNN` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
/// Applies the `RNN` module to an input sequence and input state.
/// The `input` should follow a `(sequence, batch, features)` layout unless
/// `batch_first` is true, in which case the layout should be `(batch,

@@ -11,6 +11,7 @@
#include <cstdint>
#include <memory>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
@@ -116,6 +117,11 @@ class SequentialImpl : public Cloneable<SequentialImpl> {
/// its own.
void reset() override {}
/// Pretty prints the `Sequential` module into the given `stream`.
void pretty_print(std::ostream& stream) const override {
stream << "torch::nn::Sequential";
}
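// Submodules are appended by Module::pretty_print_recursive, which produces
// the indexed "(0): ..." entries shown in the commit summary.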
/// Feeds `inputs` to the first module and then chains outputs to inputs,
/// returning the last output.
///

@@ -156,6 +156,14 @@ class ModuleHolder : torch::detail::ModuleHolderIndicator {
}
};
/// Pretty prints the given `Module` into the `ostream`.
template <typename ModuleType>
std::ostream& operator<<(
std::ostream& stream,
const nn::ModuleHolder<ModuleType>& module) {
return stream << *module;
}
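// Lets a ModuleHolder (Linear, Conv2d, Sequential, ...) be streamed directly;
// it forwards to the contained module's operator<< defined in module.cpp.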
/// Serializes a `ModuleHolder` into an `OutputArchive`.
template <typename ModuleType>
serialize::OutputArchive& operator<<(

@@ -9,6 +9,7 @@
#include <algorithm>
#include <functional>
#include <map>
#include <ostream>
#include <string>
#include <typeinfo>
@@ -321,6 +322,26 @@ Tensor& Module::register_buffer(std::string name, Tensor tensor) {
return buffers_.insert(std::move(name), std::move(tensor));
}
void Module::pretty_print(std::ostream& stream) const {
stream << name();
}
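// Prints this module's own description, then each registered child as an
// indented "(key): <child>" line, adding two spaces per nesting level.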
void Module::pretty_print_recursive(
std::ostream& stream,
const std::string& indentation) const {
pretty_print(stream);
if (!children_.is_empty()) {
stream << "(\n";
const std::string next_indentation = indentation + " ";
for (const auto& child : children_) {
stream << next_indentation << "(" << child.key() << "): ";
child.value()->pretty_print_recursive(stream, next_indentation);
stream << '\n';
}
stream << indentation << ")";
}
}
void Module::clone_(Module& other, const optional<Device>& device) {}
void Module::apply_to_submodules(
@@ -351,6 +372,11 @@ std::shared_ptr<Module> Module::shared_from_this_checked() const {
return std::const_pointer_cast<Module>(ptr);
}
std::ostream& operator<<(std::ostream& stream, const nn::Module& module) {
module.pretty_print_recursive(stream, "");
return stream;
}
serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const std::shared_ptr<nn::Module>& module) {

@@ -6,6 +6,7 @@
#include <c10/util/Exception.h>
#include <cstddef>
#include <ostream>
#include <utility>
#include <vector>
@@ -32,6 +33,14 @@ void BatchNormImpl::reset() {
}
}
void BatchNormImpl::pretty_print(std::ostream& stream) const {
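// std::boolalpha prints the bool options as "true"/"false" instead of 1/0.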
stream << std::boolalpha
<< "torch::nn::BatchNorm(features=" << options.features_
<< ", eps=" << options.eps_ << ", momentum=" << options.momentum_
<< ", affine=" << options.affine_ << ", stateful=" << options.stateful_
<< ")";
}
Tensor BatchNormImpl::forward(const Tensor& input) {
AT_CHECK(
options.stateful_,

@@ -59,6 +59,15 @@ void ConvImpl<D, Derived>::reset() {
}
}
template <size_t D, typename Derived>
void ConvImpl<D, Derived>::pretty_print(std::ostream& stream) const {
stream << "torch::nn::Conv" << D << "d"
<< "(input_channels=" << options.input_channels_
<< ", output_channels=" << options.output_channels_
<< ", kernel_size=" << options.kernel_size_
<< ", stride=" << options.stride_ << ")";
}
Tensor Conv1dImpl::forward(const Tensor& input) {
if (options.transposed_) {
return torch::conv_transpose1d(

@@ -5,6 +5,7 @@
#include <c10/util/Exception.h>
#include <cstddef>
#include <ostream>
#include <vector>
namespace torch {
@@ -30,8 +31,16 @@ Tensor DropoutImpl::forward(const Tensor& input) {
return torch::dropout(input, options.rate_, this->is_training());
}
void DropoutImpl::pretty_print(std::ostream& stream) const {
stream << "torch::nn::Dropout(rate=" << options.rate_ << ")";
}
Tensor FeatureDropoutImpl::forward(const Tensor& input) {
return torch::feature_dropout(input, options.rate_, this->is_training());
}
void FeatureDropoutImpl::pretty_print(std::ostream& stream) const {
stream << "torch::nn::FeatureDropout(rate=" << options.rate_ << ")";
}
} // namespace nn
} // namespace torch

@@ -4,6 +4,7 @@
#include <torch/utils.h>
#include <cstddef>
#include <ostream>
#include <utility>
#include <vector>
@@ -13,8 +14,7 @@ namespace nn {
EmbeddingOptions::EmbeddingOptions(int64_t count, int64_t dimension)
: count_(count), dimension_(dimension) {}
EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options) : options(options) {
reset();
}
@@ -25,6 +25,11 @@ void EmbeddingImpl::reset() {
weight.normal_(0, 1);
}
void EmbeddingImpl::pretty_print(std::ostream& stream) const {
stream << "torch::nn::Embedding(count=" << options.count_
<< ", dimension=" << options.dimension_ << ")";
}
Tensor EmbeddingImpl::forward(const Tensor& input) {
return torch::embedding(weight, /*indices=*/input);
}

@@ -12,6 +12,10 @@ FunctionalImpl::FunctionalImpl(Function function)
void FunctionalImpl::reset() {}
void FunctionalImpl::pretty_print(std::ostream& stream) const {
stream << "torch::nn::Functional()";
}
Tensor FunctionalImpl::forward(Tensor input) {
return function_(std::move(input));
}

@@ -28,6 +28,12 @@ void LinearImpl::reset() {
}
}
void LinearImpl::pretty_print(std::ostream& stream) const {
stream << std::boolalpha << "torch::nn::Linear(in=" << options.in_
<< ", out=" << options.out_ << ", with_bias=" << options.with_bias_
<< ")";
}
Tensor LinearImpl::forward(const Tensor& input) {
AT_ASSERT(!options.with_bias_ || bias.defined());
return torch::linear(input, weight, bias);

@@ -97,6 +97,16 @@ void RNNImplBase<Derived>::to(torch::Device device, bool non_blocking) {
flatten_parameters();
}
template <typename Derived>
void RNNImplBase<Derived>::pretty_print(std::ostream& stream) const {
const std::string name = this->name();
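// name() here is e.g. "torch::nn::LSTMImpl"; strip the trailing 4-character
// "Impl" suffix so the printed name matches the user-facing handle.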
const std::string name_without_impl = name.substr(0, name.size() - 4);
stream << name_without_impl << "(input_size=" << options.input_size_
<< ", hidden_size=" << options.hidden_size_
<< ", layers=" << options.layers_ << ", dropout=" << options.dropout_
<< ")";
}
template <typename Derived>
void RNNImplBase<Derived>::flatten_parameters() {
// Cache the flattened weight and bias vector.
@@ -203,6 +213,15 @@ RNNImpl::RNNImpl(const RNNOptions& options)
static_cast<CuDNNMode>(options.activation_)),
options(options) {}
void RNNImpl::pretty_print(std::ostream& stream) const {
stream << "torch::nn::RNN(input_size=" << options.input_size_
<< ", hidden_size=" << options.hidden_size_
<< ", layers=" << options.layers_ << ", dropout=" << options.dropout_
<< ", activation="
<< (options.activation_ == RNNActivation::Tanh ? "tanh" : "relu")
<< ")";
}
RNNOutput RNNImpl::forward(const Tensor& input, Tensor state) {
switch (options.activation_) {
case RNNActivation::ReLU: