[C++ API] Remove deprecated torch::nn::BatchNorm / FeatureDropout / modules_ordered_dict and torch::nn::init::Nonlinearity / FanMode (#34508)

Summary:
This PR is BC-breaking in the following way (a brief migration sketch follows the list):
- The deprecated `torch::nn::BatchNorm` is removed in favor of `torch::nn::BatchNorm{1,2,3}d`
- The deprecated `torch::nn::FeatureDropout` is removed in favor of `torch::nn::Dropout{2,3}d`
- The deprecated `torch::nn::modules_ordered_dict` is removed. Users should use `Sequential sequential({{"m1", MyModule(1)}, {"m2", MyModule(2)}})` instead.
- The deprecated `torch::nn::init::Nonlinearity` is removed in favor of the following enums:
    - `torch::kLinear`
    - `torch::kConv1D`
    - `torch::kConv2D`
    - `torch::kConv3D`
    - `torch::kConvTranspose1D`
    - `torch::kConvTranspose2D`
    - `torch::kConvTranspose3D`
    - `torch::kSigmoid`
    - `torch::kTanh`
    - `torch::kReLU`
    - `torch::kLeakyReLU`
- The deprecated `torch::nn::init::FanMode` is removed in favor of the following enums:
    - `torch::kFanIn`
    - `torch::kFanOut`
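
A brief, illustrative migration sketch. The API names come from this PR; the module sizes, tensor shapes, and submodule names below are made up for demonstration:

```cpp
#include <torch/torch.h>
#include <iostream>

using namespace torch::nn;

int main() {
  // BatchNorm(num_features) -> dimension-specific modules:
  BatchNorm1d bn1d(10); // for (N, C) or (N, C, L) input
  BatchNorm2d bn2d(10); // for (N, C, H, W) input
  BatchNorm3d bn3d(10); // for (N, C, D, H, W) input

  // FeatureDropout -> Dropout2d / Dropout3d:
  Dropout2d drop2d(0.5);
  Dropout3d drop3d(0.5);

  // modules_ordered_dict -> brace-initialized named submodules:
  Sequential sequential({{"linear", Linear(3, 4)},
                         {"bn", BatchNorm1d(4)}});

  // torch::nn::init::Nonlinearity::ReLU -> torch::kReLU:
  double gain = torch::nn::init::calculate_gain(torch::kReLU);
  std::cout << "ReLU gain: " << gain << "\n";

  // torch::nn::init::FanMode::FanIn -> torch::kFanIn:
  auto weight = torch::empty({4, 5});
  torch::nn::init::kaiming_normal_(weight, /*a=*/0, torch::kFanIn, torch::kReLU);

  return 0;
}
```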
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34508

Differential Revision: D20351601

Pulled By: yf225

fbshipit-source-id: cca0cd112f29a31bb023e348ca8f82780e42bea3
Will Feng 2020-03-12 10:07:03 -07:00 committed by Facebook GitHub Bot
parent e95657b87e
commit a54416d208
19 changed files with 32 additions and 493 deletions


@ -612,7 +612,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/rnn.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/upsampling.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/container/functional.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/container/named_any.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/activation.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/adaptive.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/batchnorm.cpp


@ -111,19 +111,19 @@ TEST(InitTest, CanInitializeTensorThatRequiresGrad) {
TEST(InitTest, CalculateGainWithTanh) {
double gain =
torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::Tanh);
torch::nn::init::calculate_gain(torch::kTanh);
ASSERT_DOUBLE_EQ(gain, 5.0 / 3.0);
}
TEST(InitTest, CalculateGainWithRelu) {
double gain =
torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::ReLU);
torch::nn::init::calculate_gain(torch::kReLU);
ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0));
}
TEST(InitTest, CalculateGainWithLeakyRelu) {
double gain =
torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::LeakyReLU);
torch::nn::init::calculate_gain(torch::kLeakyReLU);
ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0 / (1 + pow(0.01, 2))));
}
@ -131,41 +131,3 @@ TEST(InitTest, CanInitializeCnnWithOrthogonal) {
torch::nn::Conv2d conv_layer(torch::nn::Conv2dOptions(3, 2, 3).stride(2));
torch::nn::init::orthogonal_(conv_layer->named_parameters()["weight"]);
}
#define NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(func_name, enum_name, enum_torch_kname) \
{ \
std::stringstream buffer; \
CerrRedirect cerr_redirect(buffer.rdbuf()); \
std::cerr << torch::nn::init::func_name(torch::nn::init::Nonlinearity::enum_name) << std::endl; \
ASSERT_EQ(count_substr_occurrences(buffer.str(), enum_torch_kname), 1); \
}
#define FANMODE_ENUM_LEGACY_WARNING_CHECK(func_name, enum_name, enum_torch_kname) \
{ \
std::stringstream buffer; \
CerrRedirect cerr_redirect(buffer.rdbuf()); \
std::cerr << torch::nn::init::func_name(torch::randn({4, 5}), 0, torch::nn::init::FanMode::enum_name) << std::endl; \
ASSERT_EQ(count_substr_occurrences(buffer.str(), enum_torch_kname), 1); \
}
TEST(InitTest, NonlinearityLegacyEnum) {
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Linear, "torch::kLinear")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Conv1D, "torch::kConv1D")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Conv2D, "torch::kConv2D")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Conv3D, "torch::kConv3D")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ConvTranspose1D, "torch::kConvTranspose1D")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ConvTranspose2D, "torch::kConvTranspose2D")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ConvTranspose3D, "torch::kConvTranspose3D")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Sigmoid, "torch::kSigmoid")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Tanh, "torch::kTanh")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ReLU, "torch::kReLU")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, LeakyReLU, "torch::kLeakyReLU")
}
TEST(InitTest, FanModeLegacyEnum) {
FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_normal_, FanIn, "torch::kFanIn")
FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_normal_, FanOut, "torch::kFanOut")
FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_uniform_, FanIn, "torch::kFanIn")
FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_uniform_, FanOut, "torch::kFanOut")
}


@ -284,10 +284,10 @@ TEST_F(IntegrationTest, MNISTBatchNorm_CUDA) {
torch::manual_seed(0);
auto model = std::make_shared<SimpleContainer>();
auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
auto batchnorm2d = model->add(BatchNorm(10), "batchnorm2d");
auto batchnorm2d = model->add(BatchNorm2d(10), "batchnorm2d");
auto conv2 = model->add(Conv2d(10, 20, 5), "conv2");
auto linear1 = model->add(Linear(320, 50), "linear1");
auto batchnorm1 = model->add(BatchNorm(50), "batchnorm1");
auto batchnorm1 = model->add(BatchNorm1d(50), "batchnorm1");
auto linear2 = model->add(Linear(50, 10), "linear2");
auto forward = [&](torch::Tensor x) {


@ -158,7 +158,7 @@ TEST_F(ModuleListTest, SanityCheckForHoldingStandardModules) {
Linear(10, 3),
Conv2d(1, 2, 3),
Dropout(0.5),
BatchNorm(5),
BatchNorm2d(5),
Embedding(4, 10),
LSTM(4, 5));
}
@ -210,7 +210,7 @@ TEST_F(ModuleListTest, HasReferenceSemantics) {
}
TEST_F(ModuleListTest, IsCloneable) {
ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
ModuleList clone = std::dynamic_pointer_cast<ModuleListImpl>(list->clone());
ASSERT_EQ(list->size(), clone->size());
@ -255,7 +255,7 @@ TEST_F(ModuleListTest, NestingIsPossible) {
}
TEST_F(ModuleListTest, CloneToDevice_CUDA) {
ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
torch::Device device(torch::kCUDA, 0);
ModuleList clone =
std::dynamic_pointer_cast<ModuleListImpl>(list->clone(device));
@ -272,7 +272,7 @@ TEST_F(ModuleListTest, PrettyPrintModuleList) {
Linear(10, 3),
Conv2d(1, 2, 3),
Dropout(0.5),
BatchNorm(5),
BatchNorm2d(5),
Embedding(4, 10),
LSTM(4, 5));
ASSERT_EQ(
@ -281,7 +281,7 @@ TEST_F(ModuleListTest, PrettyPrintModuleList) {
" (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
" (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
" (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
" (3): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
" (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
" (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
" (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
")");
@ -290,7 +290,7 @@ TEST_F(ModuleListTest, PrettyPrintModuleList) {
TEST_F(ModuleListTest, RangeBasedForLoop) {
torch::nn::ModuleList mlist(
torch::nn::Linear(3, 4),
torch::nn::BatchNorm(4),
torch::nn::BatchNorm1d(4),
torch::nn::Dropout(0.5)
);


@ -1350,37 +1350,6 @@ TEST_F(ModulesTest, Dropout3d) {
ASSERT_EQ(y.sum().item<float>(), 100);
}
TEST_F(ModulesTest, FeatureDropout) {
FeatureDropout dropout(0.5);
torch::Tensor x = torch::ones({10, 10}, torch::requires_grad());
torch::Tensor y = dropout(x);
y.backward(torch::ones_like(y));
ASSERT_EQ(y.ndimension(), 2);
ASSERT_EQ(y.size(0), 10);
ASSERT_EQ(y.size(1), 10);
ASSERT_LT(y.sum().item<float>(), 130); // Probably
ASSERT_GT(y.sum().item<float>(), 70); // Probably
dropout->eval();
y = dropout(x);
ASSERT_EQ(y.sum().item<float>(), 100);
}
TEST_F(ModulesTest, FeatureDropoutLegacyWarning) {
std::stringstream buffer;
torch::test::CerrRedirect cerr_redirect(buffer.rdbuf());
FeatureDropout bn(0.5);
ASSERT_EQ(
count_substr_occurrences(
buffer.str(),
"torch::nn::FeatureDropout module is deprecated"
),
1);
}
TEST_F(ModulesTest, Parameters) {
auto model = std::make_shared<NestedModel>();
auto parameters = model->named_parameters();
@ -1431,74 +1400,6 @@ TEST_F(ModulesTest, FunctionalArgumentBinding) {
ASSERT_EQ(functional(torch::ones({})).item<float>(), 0);
}
TEST_F(ModulesTest, BatchNormStateful) {
BatchNorm bn(5);
// Is stateful by default.
ASSERT_TRUE(bn->options.track_running_stats());
ASSERT_TRUE(bn->running_mean.defined());
ASSERT_EQ(bn->running_mean.dim(), 1);
ASSERT_EQ(bn->running_mean.size(0), 5);
ASSERT_TRUE(bn->running_var.defined());
ASSERT_EQ(bn->running_var.dim(), 1);
ASSERT_EQ(bn->running_var.size(0), 5);
// Is affine by default.
ASSERT_TRUE(bn->options.affine());
ASSERT_TRUE(bn->weight.defined());
ASSERT_EQ(bn->weight.dim(), 1);
ASSERT_EQ(bn->weight.size(0), 5);
ASSERT_TRUE(bn->bias.defined());
ASSERT_EQ(bn->bias.dim(), 1);
ASSERT_EQ(bn->bias.size(0), 5);
}
TEST_F(ModulesTest, BatchNormStateless) {
BatchNorm bn(BatchNormOptions(5).track_running_stats(false).affine(false));
ASSERT_FALSE(bn->running_mean.defined());
ASSERT_FALSE(bn->running_var.defined());
ASSERT_FALSE(bn->weight.defined());
ASSERT_FALSE(bn->bias.defined());
ASSERT_THROWS_WITH(
bn(torch::ones({2, 5})),
"Calling BatchNorm::forward is only permitted "
"when the 'track_running_stats' option is true (was false). "
"Use BatchNorm::pure_forward instead.");
}
TEST_F(ModulesTest, BatchNormPureForward) {
BatchNorm bn(BatchNormOptions(5).affine(false));
bn->eval();
// Want to make sure we use the supplied values in `pure_forward` even if
// we are stateful.
auto input = torch::randn({2, 5});
auto mean = torch::randn(5);
auto variance = torch::rand(5);
auto output = bn->pure_forward(input, mean, variance);
auto expected = (input - mean) / torch::sqrt(variance + bn->options.eps());
ASSERT_TRUE(output.allclose(expected));
}
TEST_F(ModulesTest, BatchNormLegacyWarning) {
std::stringstream buffer;
torch::test::CerrRedirect cerr_redirect(buffer.rdbuf());
BatchNorm bn(5);
ASSERT_EQ(
count_substr_occurrences(
buffer.str(),
"torch::nn::BatchNorm module is deprecated"
),
1);
}
TEST_F(ModulesTest, BatchNorm1dStateful) {
BatchNorm1d bn(5);
@ -4087,24 +3988,10 @@ TEST_F(ModulesTest, PrettyPrintDropout3d) {
ASSERT_EQ(c10::str(Dropout3d(Dropout3dOptions().p(0.42).inplace(true))), "torch::nn::Dropout3d(p=0.42, inplace=true)");
}
TEST_F(ModulesTest, PrettyPrintFeatureDropout) {
ASSERT_EQ(c10::str(FeatureDropout()), "torch::nn::FeatureDropout(p=0.5, inplace=false)");
ASSERT_EQ(c10::str(FeatureDropout(0.42)), "torch::nn::FeatureDropout(p=0.42, inplace=false)");
ASSERT_EQ(c10::str(FeatureDropout(FeatureDropoutOptions().p(0.42).inplace(true))), "torch::nn::FeatureDropout(p=0.42, inplace=true)");
}
TEST_F(ModulesTest, PrettyPrintFunctional) {
ASSERT_EQ(c10::str(Functional(torch::relu)), "torch::nn::Functional()");
}
TEST_F(ModulesTest, PrettyPrintBatchNorm) {
ASSERT_EQ(
c10::str(BatchNorm(
BatchNormOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(
true))),
"torch::nn::BatchNorm(num_features=4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}
TEST_F(ModulesTest, PrettyPrintBatchNorm1d) {
ASSERT_EQ(
c10::str(BatchNorm1d(


@ -14,7 +14,7 @@ using namespace torch::test;
struct SequentialTest : torch::test::SeedingFixture {};
TEST_F(SequentialTest, CanContainThings) {
Sequential sequential(Linear(3, 4), ReLU(), BatchNorm(3));
Sequential sequential(Linear(3, 4), ReLU(), BatchNorm1d(3));
}
TEST_F(SequentialTest, ConstructsFromSharedPointer) {
@ -94,22 +94,6 @@ TEST_F(SequentialTest, ConstructsFromModuleHolder) {
ASSERT_EQ(sequential->size(), 3);
}
TEST_F(SequentialTest, LegacyBuilderForOrderedDictOfNamedModules) {
std::stringstream buffer;
CerrRedirect cerr_redirect(buffer.rdbuf());
Sequential sequential_named(modules_ordered_dict({
{"m1", Linear(3, 4)},
{"m2", ReLU()},
{"m3", BatchNorm(3)}
}));
ASSERT_EQ(sequential_named->size(), 3);
ASSERT_EQ(
count_substr_occurrences(buffer.str(), "`torch::nn::modules_ordered_dict` is deprecated"),
1);
}
TEST_F(SequentialTest, PushBackAddsAnElement) {
struct M : torch::nn::Module {
explicit M(int value_) : value(value_) {}
@ -291,7 +275,7 @@ TEST_F(SequentialTest, SanityCheckForHoldingStandardModules) {
Linear(10, 3),
Conv2d(1, 2, 3),
Dropout(0.5),
BatchNorm(5),
BatchNorm2d(5),
Embedding(4, 10),
LSTM(4, 5));
}
@ -358,7 +342,7 @@ TEST_F(SequentialTest, HasReferenceSemantics) {
}
TEST_F(SequentialTest, IsCloneable) {
Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
Sequential clone =
std::dynamic_pointer_cast<SequentialImpl>(sequential->clone());
ASSERT_EQ(sequential->size(), clone->size());
@ -398,7 +382,7 @@ TEST_F(SequentialTest, RegistersElementsAsSubmodules) {
}
TEST_F(SequentialTest, CloneToDevice_CUDA) {
Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
torch::Device device(torch::kCUDA, 0);
Sequential clone =
std::dynamic_pointer_cast<SequentialImpl>(sequential->clone(device));
@ -415,7 +399,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
Linear(10, 3),
Conv2d(1, 2, 3),
Dropout(0.5),
BatchNorm(5),
BatchNorm2d(5),
Embedding(4, 10),
LSTM(4, 5));
ASSERT_EQ(
@ -424,7 +408,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
" (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
" (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
" (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
" (3): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
" (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
" (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
" (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
")");
@ -433,7 +417,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
{"linear", Linear(10, 3)},
{"conv2d", Conv2d(1, 2, 3)},
{"dropout", Dropout(0.5)},
{"batchnorm", BatchNorm(5)},
{"batchnorm2d", BatchNorm2d(5)},
{"embedding", Embedding(4, 10)},
{"lstm", LSTM(4, 5)}
});
@ -443,7 +427,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
" (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
" (conv2d): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
" (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
" (batchnorm): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
" (batchnorm2d): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
" (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
" (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
")");


@ -254,7 +254,6 @@ torch_cpp_srcs = [
"torch/csrc/api/src/nn/modules/rnn.cpp",
"torch/csrc/api/src/nn/modules/upsampling.cpp",
"torch/csrc/api/src/nn/modules/container/functional.cpp",
"torch/csrc/api/src/nn/modules/container/named_any.cpp",
"torch/csrc/api/src/nn/options/activation.cpp",
"torch/csrc/api/src/nn/options/adaptive.cpp",
"torch/csrc/api/src/nn/options/batchnorm.cpp",


@ -8,24 +8,6 @@ namespace torch {
namespace nn {
namespace init {
// This enum class is deprecated and will be removed in 1.5
enum class Nonlinearity {
Linear,
Conv1D,
Conv2D,
Conv3D,
ConvTranspose1D,
ConvTranspose2D,
ConvTranspose3D,
Sigmoid,
Tanh,
ReLU,
LeakyReLU
};
// This enum class is deprecated and will be removed in 1.5
enum class FanMode { FanIn, FanOut };
using NonlinearityType = c10::variant<
enumtype::kLinear,
enumtype::kConv1D,
@ -37,18 +19,12 @@ using NonlinearityType = c10::variant<
enumtype::kSigmoid,
enumtype::kTanh,
enumtype::kReLU,
enumtype::kLeakyReLU,
// Support for this enum class is deprecated and will be removed in 1.5.
Nonlinearity
enumtype::kLeakyReLU
>;
using FanModeType = c10::variant<
enumtype::kFanIn,
enumtype::kFanOut,
// Support for this enum class is deprecated and will be removed in 1.5.
FanMode
enumtype::kFanOut
>;
} // namespace init


@ -12,73 +12,6 @@
namespace torch {
namespace nn {
/// Applies [Batch Normalization](https://arxiv.org/abs/1502.03167) to an input.
///
/// Refer to the documentation for
/// [`BatchNorm1d`](https://pytorch.org/docs/stable/nn.html#torch.nn.BatchNorm1d)
/// in PyTorch to learn more about the exact semantics of this module, __but see
/// the note below regarding differences between the Python and C++ API__.
///
/// \rst
/// .. attention::
/// In the Python API, there are separate implementations for 1-D, 2-D and 3-D
/// BatchNorm. In C++, there is only one `BatchNorm` module, which works for
/// any of these dimensions.
/// \endrst
class TORCH_API BatchNormImpl : public torch::nn::Cloneable<BatchNormImpl> {
public:
explicit BatchNormImpl(int64_t num_features)
: BatchNormImpl(BatchNormOptions(num_features)) {}
explicit BatchNormImpl(const BatchNormOptions& options_);
void reset() override;
/// Pretty prints the `BatchNorm` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
/// Applies batch normalization on the `input` using the stored mean and
/// variance.
///
/// The module must be constructed with `track_running_stats = true` when calling this
/// method, as the module will otherwise not store running statistics. If you
/// want to supply the mean and variance yourself, use `pure_forward`.
Tensor forward(const Tensor& input);
/// Applies batch normalization on the `input` using the given `mean` and
/// `variance` statistics.
Tensor pure_forward(
const Tensor& input,
const Tensor& mean,
const Tensor& variance);
/// The options with which this module was constructed.
BatchNormOptions options;
/// The learned weight.
/// Only defined if the `affine` option was `true` upon construction.
Tensor weight;
/// The learned bias.
/// Only defined if the `affine` option was `true` upon construction.
Tensor bias;
/// The running mean.
/// Only defined if the `track_running_stats` option was `true` upon construction.
Tensor running_mean;
/// The running variance.
/// Only defined if the `track_running_stats` option was `true` upon construction.
Tensor running_var;
};
/// A `ModuleHolder` subclass for `BatchNormImpl`.
/// See the documentation for `BatchNormImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(BatchNorm);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Base class for all (dimension-specialized) batchnorm and instancenorm modules.
template <size_t D, typename Derived, typename DerivedOptions>
class NormImplBase : public torch::nn::Cloneable<Derived> {
@ -191,6 +124,8 @@ class BatchNormImplBase : public NormImplBase<D, Derived, BatchNormOptions> {
void pretty_print(std::ostream& stream) const override;
};
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Applies the BatchNorm1d function.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm1d to learn
/// about the exact behavior of this module.
@ -217,6 +152,8 @@ class TORCH_API BatchNorm1dImpl : public BatchNormImplBase<1, BatchNorm1dImpl> {
/// module storage semantics.
TORCH_MODULE(BatchNorm1d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Applies the BatchNorm2d function.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm2d to learn
/// about the exact behavior of this module.
@ -243,6 +180,8 @@ class TORCH_API BatchNorm2dImpl : public BatchNormImplBase<2, BatchNorm2dImpl> {
/// module storage semantics.
TORCH_MODULE(BatchNorm2d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Applies the BatchNorm3d function.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm3d to learn
/// about the exact behavior of this module.


@ -24,7 +24,7 @@ namespace nn {
/// Sequential sequential(
/// Linear(3, 4),
/// Functional(torch::relu),
/// BatchNorm(3),
/// BatchNorm1d(3),
/// Functional(torch::elu, /*alpha=*/1));
/// \endrst
///


@ -15,7 +15,7 @@ namespace nn {
///
/// torch::nn::ModuleList mlist(
/// torch::nn::Linear(3, 4),
/// torch::nn::BatchNorm(4),
/// torch::nn::BatchNorm1d(4),
/// torch::nn::Dropout(0.5)
/// );
///
@ -39,7 +39,7 @@ namespace nn {
///
/// torch::nn::ModuleList mlist(
/// torch::nn::Linear(3, 4),
/// torch::nn::BatchNorm(4),
/// torch::nn::BatchNorm1d(4),
/// torch::nn::Dropout(0.5)
/// );
///


@ -91,15 +91,5 @@ class NamedAnyModule {
AnyModule module_;
};
#ifndef DOXYGEN_SHOULD_SKIP_THIS
C10_DEPRECATED_MESSAGE("`torch::nn::modules_ordered_dict` is deprecated. " \
"To construct a `Sequential` with named submodules, " \
"you can do `Sequential sequential({{\"m1\", MyModule(1)}, {\"m2\", MyModule(2)}})`")
TORCH_API torch::OrderedDict<std::string, AnyModule> modules_ordered_dict(
std::initializer_list<NamedAnyModule> named_modules);
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
} // namespace nn
} // namespace torch


@ -34,7 +34,7 @@ namespace nn {
///
/// torch::nn::Sequential seq(
/// torch::nn::Linear(3, 4),
/// torch::nn::BatchNorm(4),
/// torch::nn::BatchNorm1d(4),
/// torch::nn::Dropout(0.5)
/// );
///
@ -69,7 +69,7 @@ namespace nn {
///
/// torch::nn::Sequential seq(
/// torch::nn::Linear(3, 4),
/// torch::nn::BatchNorm(4),
/// torch::nn::BatchNorm1d(4),
/// torch::nn::Dropout(0.5)
/// );
///


@ -128,34 +128,6 @@ public:
/// module storage semantics.
TORCH_MODULE(Dropout3d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FeatureDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Applies spatial [Dropout](https://arxiv.org/abs/1207.0580) to inputs with
/// 2-D or 3-D features.
///
/// The equivalent in Python is
/// [Dropout2d](https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout2d) for
/// 2-D features and
/// [Dropout3d](https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout3d) for
/// 3-D features. This `FeatureDropout` module can instead deal with both 2-D
/// and 3-D features.
class TORCH_API FeatureDropoutImpl
: public detail::_DropoutNd<FeatureDropoutImpl> {
public:
FeatureDropoutImpl(double p);
explicit FeatureDropoutImpl(const FeatureDropoutOptions& options_ = {});
/// During training, applies a noise mask to the input tensor.
/// During evaluation, applies an identity function.
Tensor forward(const Tensor& input);
/// Pretty prints the `FeatureDropout` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
};
TORCH_MODULE(FeatureDropout);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Applies Alpha Dropout over the input.


@ -39,9 +39,6 @@ using Dropout2dOptions = DropoutOptions;
/// ```
using Dropout3dOptions = DropoutOptions;
/// Options for `FeatureDropout` module.
using FeatureDropoutOptions = DropoutOptions;
/// Options for `AlphaDropout` module.
///
/// Example:


@ -35,57 +35,6 @@ struct Fan {
int64_t out;
};
#define COMPUTE_NONLINEARITY_ENUM(name) /* NOLINT(cppcoreguidelines-macro-usage) */ \
case Nonlinearity::name: \
TORCH_WARN( \
"The enum value `torch::nn::init::Nonlinearity::", #name, "` is deprecated and will be removed in 1.5. ", \
"Please use `torch::k", #name, "` instead."); \
return torch::k##name;
#define COMPUTE_FANMODE_ENUM(name) /* NOLINT(cppcoreguidelines-macro-usage) */ \
case FanMode::name: \
TORCH_WARN( \
"The enum value `torch::nn::init::FanMode::", #name, "` is deprecated and will be removed in 1.5. ", \
"Please use `torch::k", #name, "` instead."); \
return torch::k##name;
NonlinearityType _compute_nonlinearity_type(Nonlinearity nonlinearity) {
switch (nonlinearity) {
COMPUTE_NONLINEARITY_ENUM(Linear)
COMPUTE_NONLINEARITY_ENUM(Conv1D)
COMPUTE_NONLINEARITY_ENUM(Conv2D)
COMPUTE_NONLINEARITY_ENUM(Conv3D)
COMPUTE_NONLINEARITY_ENUM(ConvTranspose1D)
COMPUTE_NONLINEARITY_ENUM(ConvTranspose2D)
COMPUTE_NONLINEARITY_ENUM(ConvTranspose3D)
COMPUTE_NONLINEARITY_ENUM(Sigmoid)
COMPUTE_NONLINEARITY_ENUM(Tanh)
COMPUTE_NONLINEARITY_ENUM(ReLU)
COMPUTE_NONLINEARITY_ENUM(LeakyReLU)
default:
TORCH_INTERNAL_ASSERT(
false,
"The enum class `torch::nn::init::Nonlinearity` is deprecated, ",
"please don't add any new enum to it. ",
"Instead, add the new enum to `torch/csrc/api/include/torch/enum.h` ",
"and use `torch::kEnumName` to reference it.")
}
}
FanModeType _compute_fanmode_type(FanMode fanmode) {
switch (fanmode) {
COMPUTE_FANMODE_ENUM(FanIn);
COMPUTE_FANMODE_ENUM(FanOut);
default:
TORCH_INTERNAL_ASSERT(
false,
"The enum class `torch::nn::init::Nonlinearity` is deprecated, ",
"please don't add any new enum to it. ",
"Instead, add the new enum to `torch/csrc/api/include/torch/enum.h` ",
"and use `torch::kEnumName` to reference it.")
}
}
double calculate_kaiming_std(
Tensor tensor,
double a,
@ -96,10 +45,6 @@ double calculate_kaiming_std(
const auto gain = calculate_gain(nonlinearity, a);
double std = 0.0;
// Support for `torch::nn::init::FanMode` is deprecated and will be removed in 1.5.
if (c10::get_if<FanMode>(&mode)) {
mode = _compute_fanmode_type(c10::get<FanMode>(mode));
}
if (c10::get_if<enumtype::kFanIn>(&mode)) {
std = gain / std::sqrt(fan.in);
} else {
@ -110,10 +55,6 @@ double calculate_kaiming_std(
} // namespace
double calculate_gain(NonlinearityType nonlinearity, double param) {
// Support for `torch::nn::init::Nonlinearity` is deprecated and will be removed in 1.5.
if (c10::get_if<Nonlinearity>(&nonlinearity)) {
nonlinearity = _compute_nonlinearity_type(c10::get<Nonlinearity>(nonlinearity));
}
if (c10::get_if<enumtype::kTanh>(&nonlinearity)) {
return 5.0 / 3.0; // NOLINT
} else if (c10::get_if<enumtype::kReLU>(&nonlinearity)) {


@ -16,69 +16,6 @@ namespace F = torch::nn::functional;
namespace torch {
namespace nn {
BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) { // NOLINT(modernize-pass-by-value)
TORCH_WARN("torch::nn::BatchNorm module is deprecated and will be removed in 1.5. "
"Use BatchNorm{1,2,3}d instead.");
reset();
}
void BatchNormImpl::reset() {
if (options.affine()) {
weight = register_parameter(
"weight", torch::empty({options.num_features()}).uniform_());
bias = register_parameter("bias", torch::zeros({options.num_features()}));
}
if (options.track_running_stats()) {
running_mean =
register_buffer("running_mean", torch::zeros({options.num_features()}));
running_var =
register_buffer("running_var", torch::ones({options.num_features()}));
}
}
void BatchNormImpl::pretty_print(std::ostream& stream) const {
stream << std::boolalpha
<< "torch::nn::BatchNorm(num_features=" << options.num_features()
<< ", eps=" << options.eps() << ", momentum=" << options.momentum().value()
<< ", affine=" << options.affine() << ", track_running_stats=" << options.track_running_stats()
<< ")";
}
Tensor BatchNormImpl::forward(const Tensor& input) {
TORCH_CHECK(
options.track_running_stats(),
"Calling BatchNorm::forward is only permitted when "
"the 'track_running_stats' option is true (was false). "
"Use BatchNorm::pure_forward instead.");
return pure_forward(input, running_mean, running_var);
}
Tensor BatchNormImpl::pure_forward(
const Tensor& input,
const Tensor& mean,
const Tensor& variance) {
if (is_training()) {
const auto num_channels = input.dim() > 1 ? input.size(1) : 1;
TORCH_CHECK(
input.numel() / num_channels > 1,
"BatchNorm expected more than 1 value per channel when training!");
}
return torch::batch_norm(
input,
weight,
bias,
mean,
variance,
is_training(),
options.momentum().value(),
options.eps(),
torch::cuda::cudnn_is_available());
}
// ===========================================================================
template <size_t D, typename Derived>
void BatchNormImplBase<D, Derived>::pretty_print(std::ostream& stream) const {
stream << std::boolalpha


@ -1,20 +0,0 @@
#include <torch/nn/modules/container/named_any.h>
namespace torch {
namespace nn {
torch::OrderedDict<std::string, AnyModule> modules_ordered_dict(
std::initializer_list<NamedAnyModule> named_modules) {
TORCH_WARN(
"`torch::nn::modules_ordered_dict` is deprecated. "
"To construct a `Sequential` with named submodules, "
"you can do `Sequential sequential({{\"m1\", MyModule(1)}, {\"m2\", MyModule(2)}})`");
torch::OrderedDict<std::string, AnyModule> dict;
for (auto named_module : named_modules) {
dict.insert(named_module.name(), std::move(named_module.module()));
}
return dict;
}
} // namespace nn
} // namespace torch


@ -53,30 +53,6 @@ void Dropout3dImpl::pretty_print(std::ostream& stream) const {
// ============================================================================
FeatureDropoutImpl::FeatureDropoutImpl(double p)
: detail::_DropoutNd<FeatureDropoutImpl>(p) {
TORCH_WARN("torch::nn::FeatureDropout module is deprecated."
"Use Dropout{2,3}d instead.");
}
FeatureDropoutImpl::FeatureDropoutImpl(const FeatureDropoutOptions& options_)
: detail::_DropoutNd<FeatureDropoutImpl>(options_) {
TORCH_WARN("torch::nn::FeatureDropout module is deprecated."
"Use Dropout{2,3}d instead.");
}
Tensor FeatureDropoutImpl::forward(const Tensor& input) {
return torch::feature_dropout(input, options.p(), is_training());
}
void FeatureDropoutImpl::pretty_print(std::ostream& stream) const {
stream << std::boolalpha
<< "torch::nn::FeatureDropout(p=" << options.p()
<< ", inplace=" << options.inplace() << ")";
}
// ============================================================================
Tensor AlphaDropoutImpl::forward(const Tensor& input) {
return F::detail::alpha_dropout(input, options.p(), is_training(), /*inplace=*/false);
}