Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
[C++ API] Remove deprecated torch::nn::BatchNorm / FeatureDropout / modules_ordered_dict and torch::nn::init::Nonlinearity / FanMode (#34508)
Summary:
This PR is BC-breaking in the following ways:
- The deprecated `torch::nn::BatchNorm` is removed in favor of `torch::nn::BatchNorm{1,2,3}d`
- The deprecated `torch::nn::FeatureDropout` is removed in favor of `torch::nn::Dropout{2,3}d`
- The deprecated `torch::nn::modules_ordered_dict` is removed. Users should pass named submodules directly to the `Sequential` constructor, e.g. `Sequential sequential({{"m1", MyModule(1)}, {"m2", MyModule(2)}})` (see the migration sketch after this list).
- The deprecated `torch::nn::init::Nonlinearity` is removed, in favor of the following enums:
- `torch::kLinear`
- `torch::kConv1D`
- `torch::kConv2D`
- `torch::kConv3D`
- `torch::kConvTranspose1D`
- `torch::kConvTranspose2D`
- `torch::kConvTranspose3D`
- `torch::kSigmoid`
- `torch::kTanh`
- `torch::kReLU`
- `torch::kLeakyReLU`
- The deprecated `torch::nn::init::FanMode` is removed, in favor of the following enums:
- `torch::kFanIn`
- `torch::kFanOut`
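
For reference, below is a minimal migration sketch covering the replacements listed above. It is illustrative only and not part of this diff; the module, option, and init-function names are the ones in the libtorch C++ API, and the submodule names in the `Sequential` example are arbitrary.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Dimension-specific modules replace the removed torch::nn::BatchNorm.
  torch::nn::BatchNorm1d bn1d(50);   // was: torch::nn::BatchNorm(50)
  torch::nn::BatchNorm2d bn2d(10);   // was: torch::nn::BatchNorm(10)

  // Dropout2d / Dropout3d replace the removed torch::nn::FeatureDropout.
  torch::nn::Dropout2d drop2d(0.5);  // was: torch::nn::FeatureDropout(0.5)

  // Named submodules go straight into the Sequential constructor;
  // the torch::nn::modules_ordered_dict helper is gone.
  torch::nn::Sequential sequential({
      {"linear", torch::nn::Linear(3, 4)},
      {"relu", torch::nn::ReLU()},
      {"batchnorm1d", torch::nn::BatchNorm1d(4)}});

  // torch::k* enums replace torch::nn::init::Nonlinearity / FanMode.
  double gain = torch::nn::init::calculate_gain(torch::kLeakyReLU);
  auto w = torch::empty({4, 5});
  torch::nn::init::kaiming_uniform_(w, /*a=*/0, torch::kFanOut, torch::kLeakyReLU);

  std::cout << "leaky_relu gain: " << gain << "\n";
  return 0;
}
```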
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34508
Differential Revision: D20351601
Pulled By: yf225
fbshipit-source-id: cca0cd112f29a31bb023e348ca8f82780e42bea3
This commit is contained in:
parent e95657b87e
commit a54416d208
@@ -612,7 +612,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/rnn.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/upsampling.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/container/functional.cpp
-    ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/container/named_any.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/options/activation.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/options/adaptive.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/options/batchnorm.cpp
@@ -111,19 +111,19 @@ TEST(InitTest, CanInitializeTensorThatRequiresGrad) {
 
 TEST(InitTest, CalculateGainWithTanh) {
   double gain =
-      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::Tanh);
+      torch::nn::init::calculate_gain(torch::kTanh);
   ASSERT_DOUBLE_EQ(gain, 5.0 / 3.0);
 }
 
 TEST(InitTest, CalculateGainWithRelu) {
   double gain =
-      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::ReLU);
+      torch::nn::init::calculate_gain(torch::kReLU);
   ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0));
 }
 
 TEST(InitTest, CalculateGainWithLeakyRelu) {
   double gain =
-      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::LeakyReLU);
+      torch::nn::init::calculate_gain(torch::kLeakyReLU);
   ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0 / (1 + pow(0.01, 2))));
 }
 
@@ -131,41 +131,3 @@ TEST(InitTest, CanInitializeCnnWithOrthogonal) {
   torch::nn::Conv2d conv_layer(torch::nn::Conv2dOptions(3, 2, 3).stride(2));
   torch::nn::init::orthogonal_(conv_layer->named_parameters()["weight"]);
 }
-
-#define NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(func_name, enum_name, enum_torch_kname) \
-  { \
-    std::stringstream buffer; \
-    CerrRedirect cerr_redirect(buffer.rdbuf()); \
-    std::cerr << torch::nn::init::func_name(torch::nn::init::Nonlinearity::enum_name) << std::endl; \
-    ASSERT_EQ(count_substr_occurrences(buffer.str(), enum_torch_kname), 1); \
-  }
-
-#define FANMODE_ENUM_LEGACY_WARNING_CHECK(func_name, enum_name, enum_torch_kname) \
-  { \
-    std::stringstream buffer; \
-    CerrRedirect cerr_redirect(buffer.rdbuf()); \
-    std::cerr << torch::nn::init::func_name(torch::randn({4, 5}), 0, torch::nn::init::FanMode::enum_name) << std::endl; \
-    ASSERT_EQ(count_substr_occurrences(buffer.str(), enum_torch_kname), 1); \
-  }
-
-TEST(InitTest, NonlinearityLegacyEnum) {
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Linear, "torch::kLinear")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Conv1D, "torch::kConv1D")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Conv2D, "torch::kConv2D")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Conv3D, "torch::kConv3D")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ConvTranspose1D, "torch::kConvTranspose1D")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ConvTranspose2D, "torch::kConvTranspose2D")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ConvTranspose3D, "torch::kConvTranspose3D")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Sigmoid, "torch::kSigmoid")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Tanh, "torch::kTanh")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ReLU, "torch::kReLU")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, LeakyReLU, "torch::kLeakyReLU")
-}
-
-TEST(InitTest, FanModeLegacyEnum) {
-  FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_normal_, FanIn, "torch::kFanIn")
-  FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_normal_, FanOut, "torch::kFanOut")
-
-  FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_uniform_, FanIn, "torch::kFanIn")
-  FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_uniform_, FanOut, "torch::kFanOut")
-}
@@ -284,10 +284,10 @@ TEST_F(IntegrationTest, MNISTBatchNorm_CUDA) {
   torch::manual_seed(0);
   auto model = std::make_shared<SimpleContainer>();
   auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
-  auto batchnorm2d = model->add(BatchNorm(10), "batchnorm2d");
+  auto batchnorm2d = model->add(BatchNorm2d(10), "batchnorm2d");
   auto conv2 = model->add(Conv2d(10, 20, 5), "conv2");
   auto linear1 = model->add(Linear(320, 50), "linear1");
-  auto batchnorm1 = model->add(BatchNorm(50), "batchnorm1");
+  auto batchnorm1 = model->add(BatchNorm1d(50), "batchnorm1");
   auto linear2 = model->add(Linear(50, 10), "linear2");
 
   auto forward = [&](torch::Tensor x) {
@@ -158,7 +158,7 @@ TEST_F(ModuleListTest, SanityCheckForHoldingStandardModules) {
       Linear(10, 3),
       Conv2d(1, 2, 3),
       Dropout(0.5),
-      BatchNorm(5),
+      BatchNorm2d(5),
       Embedding(4, 10),
       LSTM(4, 5));
 }
@@ -210,7 +210,7 @@ TEST_F(ModuleListTest, HasReferenceSemantics) {
 }
 
 TEST_F(ModuleListTest, IsCloneable) {
-  ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
+  ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
   ModuleList clone = std::dynamic_pointer_cast<ModuleListImpl>(list->clone());
   ASSERT_EQ(list->size(), clone->size());
 
@@ -255,7 +255,7 @@ TEST_F(ModuleListTest, NestingIsPossible) {
 }
 
 TEST_F(ModuleListTest, CloneToDevice_CUDA) {
-  ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
+  ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
   torch::Device device(torch::kCUDA, 0);
   ModuleList clone =
       std::dynamic_pointer_cast<ModuleListImpl>(list->clone(device));
@@ -272,7 +272,7 @@ TEST_F(ModuleListTest, PrettyPrintModuleList) {
       Linear(10, 3),
       Conv2d(1, 2, 3),
       Dropout(0.5),
-      BatchNorm(5),
+      BatchNorm2d(5),
       Embedding(4, 10),
       LSTM(4, 5));
   ASSERT_EQ(
@@ -281,7 +281,7 @@ TEST_F(ModuleListTest, PrettyPrintModuleList) {
       " (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
       " (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
       " (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
-      " (3): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
+      " (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       " (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
       " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
       ")");
@@ -290,7 +290,7 @@ TEST_F(ModuleListTest, PrettyPrintModuleList) {
 TEST_F(ModuleListTest, RangeBasedForLoop) {
   torch::nn::ModuleList mlist(
     torch::nn::Linear(3, 4),
-    torch::nn::BatchNorm(4),
+    torch::nn::BatchNorm1d(4),
     torch::nn::Dropout(0.5)
   );
 
@@ -1350,37 +1350,6 @@ TEST_F(ModulesTest, Dropout3d) {
   ASSERT_EQ(y.sum().item<float>(), 100);
 }
 
-TEST_F(ModulesTest, FeatureDropout) {
-  FeatureDropout dropout(0.5);
-  torch::Tensor x = torch::ones({10, 10}, torch::requires_grad());
-  torch::Tensor y = dropout(x);
-
-  y.backward(torch::ones_like(y));
-  ASSERT_EQ(y.ndimension(), 2);
-  ASSERT_EQ(y.size(0), 10);
-  ASSERT_EQ(y.size(1), 10);
-  ASSERT_LT(y.sum().item<float>(), 130); // Probably
-  ASSERT_GT(y.sum().item<float>(), 70); // Probably
-
-  dropout->eval();
-  y = dropout(x);
-  ASSERT_EQ(y.sum().item<float>(), 100);
-}
-
-TEST_F(ModulesTest, FeatureDropoutLegacyWarning) {
-  std::stringstream buffer;
-  torch::test::CerrRedirect cerr_redirect(buffer.rdbuf());
-
-  FeatureDropout bn(0.5);
-
-  ASSERT_EQ(
-    count_substr_occurrences(
-      buffer.str(),
-      "torch::nn::FeatureDropout module is deprecated"
-    ),
-    1);
-}
-
 TEST_F(ModulesTest, Parameters) {
   auto model = std::make_shared<NestedModel>();
   auto parameters = model->named_parameters();
@@ -1431,74 +1400,6 @@ TEST_F(ModulesTest, FunctionalArgumentBinding) {
   ASSERT_EQ(functional(torch::ones({})).item<float>(), 0);
 }
 
-TEST_F(ModulesTest, BatchNormStateful) {
-  BatchNorm bn(5);
-
-  // Is stateful by default.
-  ASSERT_TRUE(bn->options.track_running_stats());
-
-  ASSERT_TRUE(bn->running_mean.defined());
-  ASSERT_EQ(bn->running_mean.dim(), 1);
-  ASSERT_EQ(bn->running_mean.size(0), 5);
-
-  ASSERT_TRUE(bn->running_var.defined());
-  ASSERT_EQ(bn->running_var.dim(), 1);
-  ASSERT_EQ(bn->running_var.size(0), 5);
-
-  // Is affine by default.
-  ASSERT_TRUE(bn->options.affine());
-
-  ASSERT_TRUE(bn->weight.defined());
-  ASSERT_EQ(bn->weight.dim(), 1);
-  ASSERT_EQ(bn->weight.size(0), 5);
-
-  ASSERT_TRUE(bn->bias.defined());
-  ASSERT_EQ(bn->bias.dim(), 1);
-  ASSERT_EQ(bn->bias.size(0), 5);
-}
-TEST_F(ModulesTest, BatchNormStateless) {
-  BatchNorm bn(BatchNormOptions(5).track_running_stats(false).affine(false));
-
-  ASSERT_FALSE(bn->running_mean.defined());
-  ASSERT_FALSE(bn->running_var.defined());
-  ASSERT_FALSE(bn->weight.defined());
-  ASSERT_FALSE(bn->bias.defined());
-
-  ASSERT_THROWS_WITH(
-      bn(torch::ones({2, 5})),
-      "Calling BatchNorm::forward is only permitted "
-      "when the 'track_running_stats' option is true (was false). "
-      "Use BatchNorm::pure_forward instead.");
-}
-
-TEST_F(ModulesTest, BatchNormPureForward) {
-  BatchNorm bn(BatchNormOptions(5).affine(false));
-  bn->eval();
-
-  // Want to make sure we use the supplied values in `pure_forward` even if
-  // we are stateful.
-  auto input = torch::randn({2, 5});
-  auto mean = torch::randn(5);
-  auto variance = torch::rand(5);
-  auto output = bn->pure_forward(input, mean, variance);
-  auto expected = (input - mean) / torch::sqrt(variance + bn->options.eps());
-  ASSERT_TRUE(output.allclose(expected));
-}
-
-TEST_F(ModulesTest, BatchNormLegacyWarning) {
-  std::stringstream buffer;
-  torch::test::CerrRedirect cerr_redirect(buffer.rdbuf());
-
-  BatchNorm bn(5);
-
-  ASSERT_EQ(
-    count_substr_occurrences(
-      buffer.str(),
-      "torch::nn::BatchNorm module is deprecated"
-    ),
-    1);
-}
-
 TEST_F(ModulesTest, BatchNorm1dStateful) {
   BatchNorm1d bn(5);
 
@@ -4087,24 +3988,10 @@ TEST_F(ModulesTest, PrettyPrintDropout3d) {
   ASSERT_EQ(c10::str(Dropout3d(Dropout3dOptions().p(0.42).inplace(true))), "torch::nn::Dropout3d(p=0.42, inplace=true)");
 }
 
-TEST_F(ModulesTest, PrettyPrintFeatureDropout) {
-  ASSERT_EQ(c10::str(FeatureDropout()), "torch::nn::FeatureDropout(p=0.5, inplace=false)");
-  ASSERT_EQ(c10::str(FeatureDropout(0.42)), "torch::nn::FeatureDropout(p=0.42, inplace=false)");
-  ASSERT_EQ(c10::str(FeatureDropout(FeatureDropoutOptions().p(0.42).inplace(true))), "torch::nn::FeatureDropout(p=0.42, inplace=true)");
-}
-
 TEST_F(ModulesTest, PrettyPrintFunctional) {
   ASSERT_EQ(c10::str(Functional(torch::relu)), "torch::nn::Functional()");
 }
 
-TEST_F(ModulesTest, PrettyPrintBatchNorm) {
-  ASSERT_EQ(
-    c10::str(BatchNorm(
-      BatchNormOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(
-        true))),
-    "torch::nn::BatchNorm(num_features=4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
-}
-
 TEST_F(ModulesTest, PrettyPrintBatchNorm1d) {
   ASSERT_EQ(
     c10::str(BatchNorm1d(
@@ -14,7 +14,7 @@ using namespace torch::test;
 struct SequentialTest : torch::test::SeedingFixture {};
 
 TEST_F(SequentialTest, CanContainThings) {
-  Sequential sequential(Linear(3, 4), ReLU(), BatchNorm(3));
+  Sequential sequential(Linear(3, 4), ReLU(), BatchNorm1d(3));
 }
 
 TEST_F(SequentialTest, ConstructsFromSharedPointer) {
@@ -94,22 +94,6 @@ TEST_F(SequentialTest, ConstructsFromModuleHolder) {
   ASSERT_EQ(sequential->size(), 3);
 }
 
-TEST_F(SequentialTest, LegacyBuilderForOrderedDictOfNamedModules) {
-  std::stringstream buffer;
-  CerrRedirect cerr_redirect(buffer.rdbuf());
-
-  Sequential sequential_named(modules_ordered_dict({
-    {"m1", Linear(3, 4)},
-    {"m2", ReLU()},
-    {"m3", BatchNorm(3)}
-  }));
-  ASSERT_EQ(sequential_named->size(), 3);
-
-  ASSERT_EQ(
-    count_substr_occurrences(buffer.str(), "`torch::nn::modules_ordered_dict` is deprecated"),
-    1);
-}
-
 TEST_F(SequentialTest, PushBackAddsAnElement) {
   struct M : torch::nn::Module {
     explicit M(int value_) : value(value_) {}
@@ -291,7 +275,7 @@ TEST_F(SequentialTest, SanityCheckForHoldingStandardModules) {
       Linear(10, 3),
      Conv2d(1, 2, 3),
       Dropout(0.5),
-      BatchNorm(5),
+      BatchNorm2d(5),
       Embedding(4, 10),
       LSTM(4, 5));
 }
@@ -358,7 +342,7 @@ TEST_F(SequentialTest, HasReferenceSemantics) {
 }
 
 TEST_F(SequentialTest, IsCloneable) {
-  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
+  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
   Sequential clone =
       std::dynamic_pointer_cast<SequentialImpl>(sequential->clone());
   ASSERT_EQ(sequential->size(), clone->size());
@@ -398,7 +382,7 @@ TEST_F(SequentialTest, RegistersElementsAsSubmodules) {
 }
 
 TEST_F(SequentialTest, CloneToDevice_CUDA) {
-  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
+  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
   torch::Device device(torch::kCUDA, 0);
   Sequential clone =
       std::dynamic_pointer_cast<SequentialImpl>(sequential->clone(device));
@@ -415,7 +399,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
       Linear(10, 3),
       Conv2d(1, 2, 3),
       Dropout(0.5),
-      BatchNorm(5),
+      BatchNorm2d(5),
       Embedding(4, 10),
       LSTM(4, 5));
   ASSERT_EQ(
@@ -424,7 +408,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
       " (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      " (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
       " (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
-      " (3): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
+      " (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       " (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
       " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
       ")");
@@ -433,7 +417,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
       {"linear", Linear(10, 3)},
       {"conv2d", Conv2d(1, 2, 3)},
       {"dropout", Dropout(0.5)},
-      {"batchnorm", BatchNorm(5)},
+      {"batchnorm2d", BatchNorm2d(5)},
       {"embedding", Embedding(4, 10)},
       {"lstm", LSTM(4, 5)}
     });
@@ -443,7 +427,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
       " (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
       " (conv2d): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
       " (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
-      " (batchnorm): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
+      " (batchnorm2d): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       " (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
       " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
       ")");
@@ -254,7 +254,6 @@ torch_cpp_srcs = [
     "torch/csrc/api/src/nn/modules/rnn.cpp",
     "torch/csrc/api/src/nn/modules/upsampling.cpp",
     "torch/csrc/api/src/nn/modules/container/functional.cpp",
-    "torch/csrc/api/src/nn/modules/container/named_any.cpp",
     "torch/csrc/api/src/nn/options/activation.cpp",
     "torch/csrc/api/src/nn/options/adaptive.cpp",
     "torch/csrc/api/src/nn/options/batchnorm.cpp",
@@ -8,24 +8,6 @@ namespace torch {
 namespace nn {
 namespace init {
 
-// This enum class is deprecated and will be removed in 1.5
-enum class Nonlinearity {
-  Linear,
-  Conv1D,
-  Conv2D,
-  Conv3D,
-  ConvTranspose1D,
-  ConvTranspose2D,
-  ConvTranspose3D,
-  Sigmoid,
-  Tanh,
-  ReLU,
-  LeakyReLU
-};
-
-// This enum class is deprecated and will be removed in 1.5
-enum class FanMode { FanIn, FanOut };
-
 using NonlinearityType = c10::variant<
   enumtype::kLinear,
   enumtype::kConv1D,
@@ -37,18 +19,12 @@ using NonlinearityType = c10::variant<
   enumtype::kSigmoid,
   enumtype::kTanh,
   enumtype::kReLU,
-  enumtype::kLeakyReLU,
-
-  // Support for this enum class is deprecated and will be removed in 1.5.
-  Nonlinearity
+  enumtype::kLeakyReLU
 >;
 
 using FanModeType = c10::variant<
   enumtype::kFanIn,
-  enumtype::kFanOut,
-
-  // Support for this enum class is deprecated and will be removed in 1.5.
-  FanMode
+  enumtype::kFanOut
 >;
 
 } // namespace init
@@ -12,73 +12,6 @@
 namespace torch {
 namespace nn {
 
-/// Applies [Batch Normalization](https://arxiv.org/abs/1502.03167) to an input.
-///
-/// Refer to the documentation for
-/// [`BatchNorm1d`](https://pytorch.org/docs/stable/nn.html#torch.nn.BatchNorm1d)
-/// in PyTorch to learn more about the exact semantics of this module, __but see
-/// the note below regarding differences between the Python and C++ API__.
-///
-/// \rst
-/// .. attention::
-///   In the Python API, there are separate implementations for 1-D, 2-D and 3-D
-///   BatchNorm. In C++, there is only one `BatchNorm` module, which works for
-///   any of these dimensions.
-/// \endrst
-class TORCH_API BatchNormImpl : public torch::nn::Cloneable<BatchNormImpl> {
- public:
-  explicit BatchNormImpl(int64_t num_features)
-      : BatchNormImpl(BatchNormOptions(num_features)) {}
-  explicit BatchNormImpl(const BatchNormOptions& options_);
-
-  void reset() override;
-
-  /// Pretty prints the `BatchNorm` module into the given `stream`.
-  void pretty_print(std::ostream& stream) const override;
-
-  /// Applies batch normalization on the `input` using the stored mean and
-  /// variance.
-  ///
-  /// The module must be constructed with `track_running_stats = true` when calling this
-  /// method, as the module will otherwise not store running statistics. If you
-  /// want to supply the mean and variance yourself, use `pure_forward`.
-  Tensor forward(const Tensor& input);
-
-  /// Applies batch normalization on the `input` using the given `mean` and
-  /// `variance` statistics.
-  Tensor pure_forward(
-      const Tensor& input,
-      const Tensor& mean,
-      const Tensor& variance);
-
-  /// The options with which this module was constructed.
-  BatchNormOptions options;
-
-  /// The learned weight.
-  /// Only defined if the `affine` option was `true` upon construction.
-  Tensor weight;
-
-  /// The learned bias.
-  /// Only defined if the `affine` option was `true` upon construction.
-  Tensor bias;
-
-  /// The running mean.
-  /// Only defined if the `track_running_stats` option was `true` upon construction.
-  Tensor running_mean;
-
-  /// The running variance.
-  /// Only defined if the `track_running_stats` option was `true` upon construction.
-  Tensor running_var;
-};
-
-/// A `ModuleHolder` subclass for `BatchNormImpl`.
-/// See the documentation for `BatchNormImpl` class to learn what methods it
-/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
-/// module storage semantics.
-TORCH_MODULE(BatchNorm);
-
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
 /// Base class for all (dimension-specialized) batchnorm and instancenorm modules.
 template <size_t D, typename Derived, typename DerivedOptions>
 class NormImplBase : public torch::nn::Cloneable<Derived> {
@@ -191,6 +124,8 @@ class BatchNormImplBase : public NormImplBase<D, Derived, BatchNormOptions> {
   void pretty_print(std::ostream& stream) const override;
 };
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
 /// Applies the BatchNorm1d function.
 /// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm1d to learn
 /// about the exact behavior of this module.
@@ -217,6 +152,8 @@ class TORCH_API BatchNorm1dImpl : public BatchNormImplBase<1, BatchNorm1dImpl> {
 /// module storage semantics.
 TORCH_MODULE(BatchNorm1d);
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
 /// Applies the BatchNorm2d function.
 /// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm2d to learn
 /// about the exact behavior of this module.
@@ -243,6 +180,8 @@ class TORCH_API BatchNorm2dImpl : public BatchNormImplBase<2, BatchNorm2dImpl> {
 /// module storage semantics.
 TORCH_MODULE(BatchNorm2d);
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
 /// Applies the BatchNorm3d function.
 /// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm3d to learn
 /// about the exact behavior of this module.
@@ -24,7 +24,7 @@ namespace nn {
 ///   Sequential sequential(
 ///       Linear(3, 4),
 ///       Functional(torch::relu),
-///       BatchNorm(3),
+///       BatchNorm1d(3),
 ///       Functional(torch::elu, /*alpha=*/1));
 /// \endrst
 ///
@@ -15,7 +15,7 @@ namespace nn {
 ///
 ///   torch::nn::ModuleList mlist(
 ///     torch::nn::Linear(3, 4),
-///     torch::nn::BatchNorm(4),
+///     torch::nn::BatchNorm1d(4),
 ///     torch::nn::Dropout(0.5)
 ///   );
 ///
@@ -39,7 +39,7 @@ namespace nn {
 ///
 ///   torch::nn::ModuleList mlist(
 ///     torch::nn::Linear(3, 4),
-///     torch::nn::BatchNorm(4),
+///     torch::nn::BatchNorm1d(4),
 ///     torch::nn::Dropout(0.5)
 ///   );
 ///
@@ -91,15 +91,5 @@ class NamedAnyModule {
   AnyModule module_;
 };
 
-#ifndef DOXYGEN_SHOULD_SKIP_THIS
-
-C10_DEPRECATED_MESSAGE("`torch::nn::modules_ordered_dict` is deprecated. " \
-  "To construct a `Sequential` with named submodules, " \
-  "you can do `Sequential sequential({{\"m1\", MyModule(1)}, {\"m2\", MyModule(2)}})`")
-TORCH_API torch::OrderedDict<std::string, AnyModule> modules_ordered_dict(
-  std::initializer_list<NamedAnyModule> named_modules);
-
-#endif /* DOXYGEN_SHOULD_SKIP_THIS */
-
 } // namespace nn
 } // namespace torch
@@ -34,7 +34,7 @@ namespace nn {
 ///
 ///   torch::nn::Sequential seq(
 ///     torch::nn::Linear(3, 4),
-///     torch::nn::BatchNorm(4),
+///     torch::nn::BatchNorm1d(4),
 ///     torch::nn::Dropout(0.5)
 ///   );
 ///
@@ -69,7 +69,7 @@ namespace nn {
 ///
 ///   torch::nn::Sequential seq(
 ///     torch::nn::Linear(3, 4),
-///     torch::nn::BatchNorm(4),
+///     torch::nn::BatchNorm1d(4),
 ///     torch::nn::Dropout(0.5)
 ///   );
 ///
@@ -128,34 +128,6 @@ public:
 /// module storage semantics.
 TORCH_MODULE(Dropout3d);
 
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FeatureDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-/// Applies spatial [Dropout](https://arxiv.org/abs/1207.0580) to inputs with
-/// 2-D or 3-D features.
-///
-/// The equivalent in Python is
-/// [Dropout2d](https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout2d) for
-/// 2-D features and
-/// [Dropout3d](https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout3d) for
-/// 3-D features. This `FeatureDropout` module can instead deal with both 2-D
-/// and 3-D features.
-class TORCH_API FeatureDropoutImpl
-    : public detail::_DropoutNd<FeatureDropoutImpl> {
- public:
-  FeatureDropoutImpl(double p);
-
-  explicit FeatureDropoutImpl(const FeatureDropoutOptions& options_ = {});
-
-  /// During training, applies a noise mask to the input tensor.
-  /// During evaluation, applies an identity function.
-  Tensor forward(const Tensor& input);
-
-  /// Pretty prints the `FeatureDropout` module into the given `stream`.
-  void pretty_print(std::ostream& stream) const override;
-};
-
-TORCH_MODULE(FeatureDropout);
-
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 /// Applies Alpha Dropout over the input.
@@ -39,9 +39,6 @@ using Dropout2dOptions = DropoutOptions;
 /// ```
 using Dropout3dOptions = DropoutOptions;
 
-/// Options for `FeatureDropout` module.
-using FeatureDropoutOptions = DropoutOptions;
-
 /// Options for `AlphaDropout` module.
 ///
 /// Example:
@@ -35,57 +35,6 @@ struct Fan {
   int64_t out;
 };
 
-#define COMPUTE_NONLINEARITY_ENUM(name) /* NOLINT(cppcoreguidelines-macro-usage) */ \
-  case Nonlinearity::name: \
-    TORCH_WARN( \
-      "The enum value `torch::nn::init::Nonlinearity::", #name, "` is deprecated and will be removed in 1.5. ", \
-      "Please use `torch::k", #name, "` instead."); \
-    return torch::k##name;
-
-#define COMPUTE_FANMODE_ENUM(name) /* NOLINT(cppcoreguidelines-macro-usage) */ \
-  case FanMode::name: \
-    TORCH_WARN( \
-      "The enum value `torch::nn::init::FanMode::", #name, "` is deprecated and will be removed in 1.5. ", \
-      "Please use `torch::k", #name, "` instead."); \
-    return torch::k##name;
-
-NonlinearityType _compute_nonlinearity_type(Nonlinearity nonlinearity) {
-  switch (nonlinearity) {
-    COMPUTE_NONLINEARITY_ENUM(Linear)
-    COMPUTE_NONLINEARITY_ENUM(Conv1D)
-    COMPUTE_NONLINEARITY_ENUM(Conv2D)
-    COMPUTE_NONLINEARITY_ENUM(Conv3D)
-    COMPUTE_NONLINEARITY_ENUM(ConvTranspose1D)
-    COMPUTE_NONLINEARITY_ENUM(ConvTranspose2D)
-    COMPUTE_NONLINEARITY_ENUM(ConvTranspose3D)
-    COMPUTE_NONLINEARITY_ENUM(Sigmoid)
-    COMPUTE_NONLINEARITY_ENUM(Tanh)
-    COMPUTE_NONLINEARITY_ENUM(ReLU)
-    COMPUTE_NONLINEARITY_ENUM(LeakyReLU)
-    default:
-      TORCH_INTERNAL_ASSERT(
-        false,
-        "The enum class `torch::nn::init::Nonlinearity` is deprecated, ",
-        "please don't add any new enum to it. ",
-        "Instead, add the new enum to `torch/csrc/api/include/torch/enum.h` ",
-        "and use `torch::kEnumName` to reference it.")
-  }
-}
-
-FanModeType _compute_fanmode_type(FanMode fanmode) {
-  switch (fanmode) {
-    COMPUTE_FANMODE_ENUM(FanIn);
-    COMPUTE_FANMODE_ENUM(FanOut);
-    default:
-      TORCH_INTERNAL_ASSERT(
-        false,
-        "The enum class `torch::nn::init::Nonlinearity` is deprecated, ",
-        "please don't add any new enum to it. ",
-        "Instead, add the new enum to `torch/csrc/api/include/torch/enum.h` ",
-        "and use `torch::kEnumName` to reference it.")
-  }
-}
-
 double calculate_kaiming_std(
     Tensor tensor,
     double a,
@@ -96,10 +45,6 @@ double calculate_kaiming_std(
   const auto gain = calculate_gain(nonlinearity, a);
   double std = 0.0;
 
-  // Support for `torch::nn::init::FanMode` is deprecated and will be removed in 1.5.
-  if (c10::get_if<FanMode>(&mode)) {
-    mode = _compute_fanmode_type(c10::get<FanMode>(mode));
-  }
   if (c10::get_if<enumtype::kFanIn>(&mode)) {
     std = gain / std::sqrt(fan.in);
   } else {
@@ -110,10 +55,6 @@ double calculate_kaiming_std(
 } // namespace
 
 double calculate_gain(NonlinearityType nonlinearity, double param) {
-  // Support for `torch::nn::init::Nonlinearity` is deprecated and will be removed in 1.5.
-  if (c10::get_if<Nonlinearity>(&nonlinearity)) {
-    nonlinearity = _compute_nonlinearity_type(c10::get<Nonlinearity>(nonlinearity));
-  }
   if (c10::get_if<enumtype::kTanh>(&nonlinearity)) {
     return 5.0 / 3.0; // NOLINT
   } else if (c10::get_if<enumtype::kReLU>(&nonlinearity)) {
@@ -16,69 +16,6 @@ namespace F = torch::nn::functional;
 namespace torch {
 namespace nn {
 
-BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) { // NOLINT(modernize-pass-by-value)
-  TORCH_WARN("torch::nn::BatchNorm module is deprecated and will be removed in 1.5. "
-             "Use BatchNorm{1,2,3}d instead.");
-  reset();
-}
-
-void BatchNormImpl::reset() {
-  if (options.affine()) {
-    weight = register_parameter(
-        "weight", torch::empty({options.num_features()}).uniform_());
-    bias = register_parameter("bias", torch::zeros({options.num_features()}));
-  }
-
-  if (options.track_running_stats()) {
-    running_mean =
-        register_buffer("running_mean", torch::zeros({options.num_features()}));
-    running_var =
-        register_buffer("running_var", torch::ones({options.num_features()}));
-  }
-}
-
-void BatchNormImpl::pretty_print(std::ostream& stream) const {
-  stream << std::boolalpha
-         << "torch::nn::BatchNorm(num_features=" << options.num_features()
-         << ", eps=" << options.eps() << ", momentum=" << options.momentum().value()
-         << ", affine=" << options.affine() << ", track_running_stats=" << options.track_running_stats()
-         << ")";
-}
-
-Tensor BatchNormImpl::forward(const Tensor& input) {
-  TORCH_CHECK(
-      options.track_running_stats(),
-      "Calling BatchNorm::forward is only permitted when "
-      "the 'track_running_stats' option is true (was false). "
-      "Use BatchNorm::pure_forward instead.");
-  return pure_forward(input, running_mean, running_var);
-}
-
-Tensor BatchNormImpl::pure_forward(
-    const Tensor& input,
-    const Tensor& mean,
-    const Tensor& variance) {
-  if (is_training()) {
-    const auto num_channels = input.dim() > 1 ? input.size(1) : 1;
-    TORCH_CHECK(
-        input.numel() / num_channels > 1,
-        "BatchNorm expected more than 1 value per channel when training!");
-  }
-
-  return torch::batch_norm(
-      input,
-      weight,
-      bias,
-      mean,
-      variance,
-      is_training(),
-      options.momentum().value(),
-      options.eps(),
-      torch::cuda::cudnn_is_available());
-}
-
-// ===========================================================================
-
 template <size_t D, typename Derived>
 void BatchNormImplBase<D, Derived>::pretty_print(std::ostream& stream) const {
   stream << std::boolalpha
@@ -1,20 +0,0 @@
-#include <torch/nn/modules/container/named_any.h>
-
-namespace torch {
-namespace nn {
-
-torch::OrderedDict<std::string, AnyModule> modules_ordered_dict(
-    std::initializer_list<NamedAnyModule> named_modules) {
-  TORCH_WARN(
-    "`torch::nn::modules_ordered_dict` is deprecated. "
-    "To construct a `Sequential` with named submodules, "
-    "you can do `Sequential sequential({{\"m1\", MyModule(1)}, {\"m2\", MyModule(2)}})`");
-  torch::OrderedDict<std::string, AnyModule> dict;
-  for (auto named_module : named_modules) {
-    dict.insert(named_module.name(), std::move(named_module.module()));
-  }
-  return dict;
-}
-
-} // namespace nn
-} // namespace torch
@@ -53,30 +53,6 @@ void Dropout3dImpl::pretty_print(std::ostream& stream) const {
 
 // ============================================================================
 
-FeatureDropoutImpl::FeatureDropoutImpl(double p)
-    : detail::_DropoutNd<FeatureDropoutImpl>(p) {
-  TORCH_WARN("torch::nn::FeatureDropout module is deprecated."
-             "Use Dropout{2,3}d instead.");
-}
-
-FeatureDropoutImpl::FeatureDropoutImpl(const FeatureDropoutOptions& options_)
-    : detail::_DropoutNd<FeatureDropoutImpl>(options_) {
-  TORCH_WARN("torch::nn::FeatureDropout module is deprecated."
-             "Use Dropout{2,3}d instead.");
-}
-
-Tensor FeatureDropoutImpl::forward(const Tensor& input) {
-  return torch::feature_dropout(input, options.p(), is_training());
-}
-
-void FeatureDropoutImpl::pretty_print(std::ostream& stream) const {
-  stream << std::boolalpha
-         << "torch::nn::FeatureDropout(p=" << options.p()
-         << ", inplace=" << options.inplace() << ")";
-}
-
-// ============================================================================
-
 Tensor AlphaDropoutImpl::forward(const Tensor& input) {
   return F::detail::alpha_dropout(input, options.p(), is_training(), /*inplace=*/false);
 }