[C++ API] Remove deprecated torch::nn::BatchNorm / FeatureDropout / modules_ordered_dict and torch::nn::init::Nonlinearity / FanMode (#34508)
Summary:
This PR is BC-breaking in the following ways (see the migration sketch after the list):
- The deprecated `torch::nn::BatchNorm` is removed in favor of `torch::nn::BatchNorm{1,2,3}d`
- The deprecated `torch::nn::FeatureDropout` is removed in favor of `torch::nn::Dropout{2,3}d`
- The deprecated `torch::nn::modules_ordered_dict` is removed. Users should write `Sequential sequential({{"m1", MyModule(1)}, {"m2", MyModule(2)}})` instead.
- The deprecated `torch::nn::init::Nonlinearity` is removed, in favor of the following enums:
  - `torch::kLinear`
  - `torch::kConv1D`
  - `torch::kConv2D`
  - `torch::kConv3D`
  - `torch::kConvTranspose1D`
  - `torch::kConvTranspose2D`
  - `torch::kConvTranspose3D`
  - `torch::kSigmoid`
  - `torch::kTanh`
  - `torch::kReLU`
  - `torch::kLeakyReLU`
- The deprecated `torch::nn::init::FanMode` is removed, in favor of the following enums:
  - `torch::kFanIn`
  - `torch::kFanOut`
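
For reference, a minimal migration sketch under the new API (illustrative only; the layer sizes, submodule names, and `main` wrapper are arbitrary and not part of this PR):

```cpp
#include <torch/torch.h>

int main() {
  // BatchNorm -> dimension-specific variants; pick the one matching the input rank.
  torch::nn::BatchNorm2d bn(10);       // was: torch::nn::BatchNorm bn(10);

  // FeatureDropout -> Dropout2d / Dropout3d.
  torch::nn::Dropout2d drop(0.5);      // was: torch::nn::FeatureDropout drop(0.5);

  // modules_ordered_dict -> brace-initialized named submodules.
  torch::nn::Sequential sequential({
      {"m1", torch::nn::Linear(3, 4)},
      {"m2", torch::nn::Functional(torch::relu)}});

  // init::Nonlinearity / init::FanMode -> torch::k* enums.
  double gain = torch::nn::init::calculate_gain(torch::kLeakyReLU);
  auto w = torch::empty({4, 5});
  torch::nn::init::kaiming_normal_(w, /*a=*/0.0, torch::kFanIn);

  (void)gain;
  return 0;
}
```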
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34508
Differential Revision: D20351601
Pulled By: yf225
fbshipit-source-id: cca0cd112f29a31bb023e348ca8f82780e42bea3
This commit is contained in:
parent e95657b87e
commit a54416d208

@@ -612,7 +612,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/rnn.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/upsampling.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/container/functional.cpp
-    ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/container/named_any.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/options/activation.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/options/adaptive.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/options/batchnorm.cpp

@@ -111,19 +111,19 @@ TEST(InitTest, CanInitializeTensorThatRequiresGrad) {
 
 TEST(InitTest, CalculateGainWithTanh) {
   double gain =
-      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::Tanh);
+      torch::nn::init::calculate_gain(torch::kTanh);
   ASSERT_DOUBLE_EQ(gain, 5.0 / 3.0);
 }
 
 TEST(InitTest, CalculateGainWithRelu) {
   double gain =
-      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::ReLU);
+      torch::nn::init::calculate_gain(torch::kReLU);
   ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0));
 }
 
 TEST(InitTest, CalculateGainWithLeakyRelu) {
   double gain =
-      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::LeakyReLU);
+      torch::nn::init::calculate_gain(torch::kLeakyReLU);
   ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0 / (1 + pow(0.01, 2))));
 }
 

@@ -131,41 +131,3 @@ TEST(InitTest, CanInitializeCnnWithOrthogonal) {
   torch::nn::Conv2d conv_layer(torch::nn::Conv2dOptions(3, 2, 3).stride(2));
   torch::nn::init::orthogonal_(conv_layer->named_parameters()["weight"]);
 }
-
-#define NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(func_name, enum_name, enum_torch_kname) \
-  { \
-    std::stringstream buffer; \
-    CerrRedirect cerr_redirect(buffer.rdbuf()); \
-    std::cerr << torch::nn::init::func_name(torch::nn::init::Nonlinearity::enum_name) << std::endl; \
-    ASSERT_EQ(count_substr_occurrences(buffer.str(), enum_torch_kname), 1); \
-  }
-
-#define FANMODE_ENUM_LEGACY_WARNING_CHECK(func_name, enum_name, enum_torch_kname) \
-  { \
-    std::stringstream buffer; \
-    CerrRedirect cerr_redirect(buffer.rdbuf()); \
-    std::cerr << torch::nn::init::func_name(torch::randn({4, 5}), 0, torch::nn::init::FanMode::enum_name) << std::endl; \
-    ASSERT_EQ(count_substr_occurrences(buffer.str(), enum_torch_kname), 1); \
-  }
-
-TEST(InitTest, NonlinearityLegacyEnum) {
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Linear, "torch::kLinear")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Conv1D, "torch::kConv1D")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Conv2D, "torch::kConv2D")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Conv3D, "torch::kConv3D")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ConvTranspose1D, "torch::kConvTranspose1D")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ConvTranspose2D, "torch::kConvTranspose2D")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ConvTranspose3D, "torch::kConvTranspose3D")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Sigmoid, "torch::kSigmoid")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Tanh, "torch::kTanh")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ReLU, "torch::kReLU")
-  NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, LeakyReLU, "torch::kLeakyReLU")
-}
-
-TEST(InitTest, FanModeLegacyEnum) {
-  FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_normal_, FanIn, "torch::kFanIn")
-  FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_normal_, FanOut, "torch::kFanOut")
-
-  FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_uniform_, FanIn, "torch::kFanIn")
-  FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_uniform_, FanOut, "torch::kFanOut")
-}

@@ -284,10 +284,10 @@ TEST_F(IntegrationTest, MNISTBatchNorm_CUDA) {
   torch::manual_seed(0);
   auto model = std::make_shared<SimpleContainer>();
   auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
-  auto batchnorm2d = model->add(BatchNorm(10), "batchnorm2d");
+  auto batchnorm2d = model->add(BatchNorm2d(10), "batchnorm2d");
   auto conv2 = model->add(Conv2d(10, 20, 5), "conv2");
   auto linear1 = model->add(Linear(320, 50), "linear1");
-  auto batchnorm1 = model->add(BatchNorm(50), "batchnorm1");
+  auto batchnorm1 = model->add(BatchNorm1d(50), "batchnorm1");
   auto linear2 = model->add(Linear(50, 10), "linear2");
 
   auto forward = [&](torch::Tensor x) {

@@ -158,7 +158,7 @@ TEST_F(ModuleListTest, SanityCheckForHoldingStandardModules) {
       Linear(10, 3),
       Conv2d(1, 2, 3),
       Dropout(0.5),
-      BatchNorm(5),
+      BatchNorm2d(5),
       Embedding(4, 10),
       LSTM(4, 5));
 }

@@ -210,7 +210,7 @@ TEST_F(ModuleListTest, HasReferenceSemantics) {
 }
 
 TEST_F(ModuleListTest, IsCloneable) {
-  ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
+  ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
   ModuleList clone = std::dynamic_pointer_cast<ModuleListImpl>(list->clone());
   ASSERT_EQ(list->size(), clone->size());
 

@@ -255,7 +255,7 @@ TEST_F(ModuleListTest, NestingIsPossible) {
 }
 
 TEST_F(ModuleListTest, CloneToDevice_CUDA) {
-  ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
+  ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
   torch::Device device(torch::kCUDA, 0);
   ModuleList clone =
       std::dynamic_pointer_cast<ModuleListImpl>(list->clone(device));

@@ -272,7 +272,7 @@ TEST_F(ModuleListTest, PrettyPrintModuleList) {
       Linear(10, 3),
       Conv2d(1, 2, 3),
       Dropout(0.5),
-      BatchNorm(5),
+      BatchNorm2d(5),
       Embedding(4, 10),
       LSTM(4, 5));
   ASSERT_EQ(

@@ -281,7 +281,7 @@ TEST_F(ModuleListTest, PrettyPrintModuleList) {
       " (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
       " (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
       " (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
-      " (3): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
+      " (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       " (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
       " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
       ")");

@@ -290,7 +290,7 @@ TEST_F(ModuleListTest, PrettyPrintModuleList) {
 TEST_F(ModuleListTest, RangeBasedForLoop) {
   torch::nn::ModuleList mlist(
     torch::nn::Linear(3, 4),
-    torch::nn::BatchNorm(4),
+    torch::nn::BatchNorm1d(4),
     torch::nn::Dropout(0.5)
   );
 

@@ -1350,37 +1350,6 @@ TEST_F(ModulesTest, Dropout3d) {
   ASSERT_EQ(y.sum().item<float>(), 100);
 }
 
-TEST_F(ModulesTest, FeatureDropout) {
-  FeatureDropout dropout(0.5);
-  torch::Tensor x = torch::ones({10, 10}, torch::requires_grad());
-  torch::Tensor y = dropout(x);
-
-  y.backward(torch::ones_like(y));
-  ASSERT_EQ(y.ndimension(), 2);
-  ASSERT_EQ(y.size(0), 10);
-  ASSERT_EQ(y.size(1), 10);
-  ASSERT_LT(y.sum().item<float>(), 130); // Probably
-  ASSERT_GT(y.sum().item<float>(), 70); // Probably
-
-  dropout->eval();
-  y = dropout(x);
-  ASSERT_EQ(y.sum().item<float>(), 100);
-}
-
-TEST_F(ModulesTest, FeatureDropoutLegacyWarning) {
-  std::stringstream buffer;
-  torch::test::CerrRedirect cerr_redirect(buffer.rdbuf());
-
-  FeatureDropout bn(0.5);
-
-  ASSERT_EQ(
-    count_substr_occurrences(
-      buffer.str(),
-      "torch::nn::FeatureDropout module is deprecated"
-    ),
-    1);
-}
-
 TEST_F(ModulesTest, Parameters) {
   auto model = std::make_shared<NestedModel>();
   auto parameters = model->named_parameters();

@@ -1431,74 +1400,6 @@ TEST_F(ModulesTest, FunctionalArgumentBinding) {
   ASSERT_EQ(functional(torch::ones({})).item<float>(), 0);
 }
 
-TEST_F(ModulesTest, BatchNormStateful) {
-  BatchNorm bn(5);
-
-  // Is stateful by default.
-  ASSERT_TRUE(bn->options.track_running_stats());
-
-  ASSERT_TRUE(bn->running_mean.defined());
-  ASSERT_EQ(bn->running_mean.dim(), 1);
-  ASSERT_EQ(bn->running_mean.size(0), 5);
-
-  ASSERT_TRUE(bn->running_var.defined());
-  ASSERT_EQ(bn->running_var.dim(), 1);
-  ASSERT_EQ(bn->running_var.size(0), 5);
-
-  // Is affine by default.
-  ASSERT_TRUE(bn->options.affine());
-
-  ASSERT_TRUE(bn->weight.defined());
-  ASSERT_EQ(bn->weight.dim(), 1);
-  ASSERT_EQ(bn->weight.size(0), 5);
-
-  ASSERT_TRUE(bn->bias.defined());
-  ASSERT_EQ(bn->bias.dim(), 1);
-  ASSERT_EQ(bn->bias.size(0), 5);
-}
-TEST_F(ModulesTest, BatchNormStateless) {
-  BatchNorm bn(BatchNormOptions(5).track_running_stats(false).affine(false));
-
-  ASSERT_FALSE(bn->running_mean.defined());
-  ASSERT_FALSE(bn->running_var.defined());
-  ASSERT_FALSE(bn->weight.defined());
-  ASSERT_FALSE(bn->bias.defined());
-
-  ASSERT_THROWS_WITH(
-      bn(torch::ones({2, 5})),
-      "Calling BatchNorm::forward is only permitted "
-      "when the 'track_running_stats' option is true (was false). "
-      "Use BatchNorm::pure_forward instead.");
-}
-
-TEST_F(ModulesTest, BatchNormPureForward) {
-  BatchNorm bn(BatchNormOptions(5).affine(false));
-  bn->eval();
-
-  // Want to make sure we use the supplied values in `pure_forward` even if
-  // we are stateful.
-  auto input = torch::randn({2, 5});
-  auto mean = torch::randn(5);
-  auto variance = torch::rand(5);
-  auto output = bn->pure_forward(input, mean, variance);
-  auto expected = (input - mean) / torch::sqrt(variance + bn->options.eps());
-  ASSERT_TRUE(output.allclose(expected));
-}
-
-TEST_F(ModulesTest, BatchNormLegacyWarning) {
-  std::stringstream buffer;
-  torch::test::CerrRedirect cerr_redirect(buffer.rdbuf());
-
-  BatchNorm bn(5);
-
-  ASSERT_EQ(
-    count_substr_occurrences(
-      buffer.str(),
-      "torch::nn::BatchNorm module is deprecated"
-    ),
-    1);
-}
-
 TEST_F(ModulesTest, BatchNorm1dStateful) {
   BatchNorm1d bn(5);
 

@@ -4087,24 +3988,10 @@ TEST_F(ModulesTest, PrettyPrintDropout3d) {
   ASSERT_EQ(c10::str(Dropout3d(Dropout3dOptions().p(0.42).inplace(true))), "torch::nn::Dropout3d(p=0.42, inplace=true)");
 }
 
-TEST_F(ModulesTest, PrettyPrintFeatureDropout) {
-  ASSERT_EQ(c10::str(FeatureDropout()), "torch::nn::FeatureDropout(p=0.5, inplace=false)");
-  ASSERT_EQ(c10::str(FeatureDropout(0.42)), "torch::nn::FeatureDropout(p=0.42, inplace=false)");
-  ASSERT_EQ(c10::str(FeatureDropout(FeatureDropoutOptions().p(0.42).inplace(true))), "torch::nn::FeatureDropout(p=0.42, inplace=true)");
-}
-
 TEST_F(ModulesTest, PrettyPrintFunctional) {
   ASSERT_EQ(c10::str(Functional(torch::relu)), "torch::nn::Functional()");
 }
 
-TEST_F(ModulesTest, PrettyPrintBatchNorm) {
-  ASSERT_EQ(
-    c10::str(BatchNorm(
-      BatchNormOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(
-        true))),
-    "torch::nn::BatchNorm(num_features=4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
-}
-
 TEST_F(ModulesTest, PrettyPrintBatchNorm1d) {
   ASSERT_EQ(
     c10::str(BatchNorm1d(

@@ -14,7 +14,7 @@ using namespace torch::test;
 struct SequentialTest : torch::test::SeedingFixture {};
 
 TEST_F(SequentialTest, CanContainThings) {
-  Sequential sequential(Linear(3, 4), ReLU(), BatchNorm(3));
+  Sequential sequential(Linear(3, 4), ReLU(), BatchNorm1d(3));
 }
 
 TEST_F(SequentialTest, ConstructsFromSharedPointer) {

@@ -94,22 +94,6 @@ TEST_F(SequentialTest, ConstructsFromModuleHolder) {
   ASSERT_EQ(sequential->size(), 3);
 }
 
-TEST_F(SequentialTest, LegacyBuilderForOrderedDictOfNamedModules) {
-  std::stringstream buffer;
-  CerrRedirect cerr_redirect(buffer.rdbuf());
-
-  Sequential sequential_named(modules_ordered_dict({
-    {"m1", Linear(3, 4)},
-    {"m2", ReLU()},
-    {"m3", BatchNorm(3)}
-  }));
-  ASSERT_EQ(sequential_named->size(), 3);
-
-  ASSERT_EQ(
-    count_substr_occurrences(buffer.str(), "`torch::nn::modules_ordered_dict` is deprecated"),
-    1);
-}
-
 TEST_F(SequentialTest, PushBackAddsAnElement) {
   struct M : torch::nn::Module {
     explicit M(int value_) : value(value_) {}

@@ -291,7 +275,7 @@ TEST_F(SequentialTest, SanityCheckForHoldingStandardModules) {
       Linear(10, 3),
       Conv2d(1, 2, 3),
       Dropout(0.5),
-      BatchNorm(5),
+      BatchNorm2d(5),
       Embedding(4, 10),
       LSTM(4, 5));
 }

@@ -358,7 +342,7 @@ TEST_F(SequentialTest, HasReferenceSemantics) {
 }
 
 TEST_F(SequentialTest, IsCloneable) {
-  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
+  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
   Sequential clone =
       std::dynamic_pointer_cast<SequentialImpl>(sequential->clone());
   ASSERT_EQ(sequential->size(), clone->size());

@@ -398,7 +382,7 @@ TEST_F(SequentialTest, RegistersElementsAsSubmodules) {
 }
 
 TEST_F(SequentialTest, CloneToDevice_CUDA) {
-  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
+  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
   torch::Device device(torch::kCUDA, 0);
   Sequential clone =
       std::dynamic_pointer_cast<SequentialImpl>(sequential->clone(device));

@@ -415,7 +399,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
       Linear(10, 3),
       Conv2d(1, 2, 3),
       Dropout(0.5),
-      BatchNorm(5),
+      BatchNorm2d(5),
       Embedding(4, 10),
       LSTM(4, 5));
   ASSERT_EQ(

@@ -424,7 +408,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
       " (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
       " (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
       " (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
-      " (3): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
+      " (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       " (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
       " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
       ")");

@@ -433,7 +417,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
       {"linear", Linear(10, 3)},
       {"conv2d", Conv2d(1, 2, 3)},
       {"dropout", Dropout(0.5)},
-      {"batchnorm", BatchNorm(5)},
+      {"batchnorm2d", BatchNorm2d(5)},
       {"embedding", Embedding(4, 10)},
      {"lstm", LSTM(4, 5)}
   });

@@ -443,7 +427,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
       " (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
       " (conv2d): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
       " (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
-      " (batchnorm): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
+      " (batchnorm2d): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       " (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
       " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
       ")");

@@ -254,7 +254,6 @@ torch_cpp_srcs = [
     "torch/csrc/api/src/nn/modules/rnn.cpp",
     "torch/csrc/api/src/nn/modules/upsampling.cpp",
     "torch/csrc/api/src/nn/modules/container/functional.cpp",
-    "torch/csrc/api/src/nn/modules/container/named_any.cpp",
     "torch/csrc/api/src/nn/options/activation.cpp",
     "torch/csrc/api/src/nn/options/adaptive.cpp",
     "torch/csrc/api/src/nn/options/batchnorm.cpp",

@@ -8,24 +8,6 @@ namespace torch {
 namespace nn {
 namespace init {
 
-// This enum class is deprecated and will be removed in 1.5
-enum class Nonlinearity {
-  Linear,
-  Conv1D,
-  Conv2D,
-  Conv3D,
-  ConvTranspose1D,
-  ConvTranspose2D,
-  ConvTranspose3D,
-  Sigmoid,
-  Tanh,
-  ReLU,
-  LeakyReLU
-};
-
-// This enum class is deprecated and will be removed in 1.5
-enum class FanMode { FanIn, FanOut };
-
 using NonlinearityType = c10::variant<
   enumtype::kLinear,
   enumtype::kConv1D,

@@ -37,18 +19,12 @@ using NonlinearityType = c10::variant<
   enumtype::kSigmoid,
   enumtype::kTanh,
   enumtype::kReLU,
-  enumtype::kLeakyReLU,
-
-  // Support for this enum class is deprecated and will be removed in 1.5.
-  Nonlinearity
+  enumtype::kLeakyReLU
 >;
 
 using FanModeType = c10::variant<
   enumtype::kFanIn,
-  enumtype::kFanOut,
-
-  // Support for this enum class is deprecated and will be removed in 1.5.
-  FanMode
+  enumtype::kFanOut
 >;
 
 } // namespace init

@@ -12,73 +12,6 @@
 namespace torch {
 namespace nn {
 
-/// Applies [Batch Normalization](https://arxiv.org/abs/1502.03167) to an input.
-///
-/// Refer to the documentation for
-/// [`BatchNorm1d`](https://pytorch.org/docs/stable/nn.html#torch.nn.BatchNorm1d)
-/// in PyTorch to learn more about the exact semantics of this module, __but see
-/// the note below regarding differences between the Python and C++ API__.
-///
-/// \rst
-/// .. attention::
-///   In the Python API, there are separate implementations for 1-D, 2-D and 3-D
-///   BatchNorm. In C++, there is only one `BatchNorm` module, which works for
-///   any of these dimensions.
-/// \endrst
-class TORCH_API BatchNormImpl : public torch::nn::Cloneable<BatchNormImpl> {
- public:
-  explicit BatchNormImpl(int64_t num_features)
-      : BatchNormImpl(BatchNormOptions(num_features)) {}
-  explicit BatchNormImpl(const BatchNormOptions& options_);
-
-  void reset() override;
-
-  /// Pretty prints the `BatchNorm` module into the given `stream`.
-  void pretty_print(std::ostream& stream) const override;
-
-  /// Applies batch normalization on the `input` using the stored mean and
-  /// variance.
-  ///
-  /// The module must be constructed with `track_running_stats = true` when calling this
-  /// method, as the module will otherwise not store running statistics. If you
-  /// want to supply the mean and variance yourself, use `pure_forward`.
-  Tensor forward(const Tensor& input);
-
-  /// Applies batch normalization on the `input` using the given `mean` and
-  /// `variance` statistics.
-  Tensor pure_forward(
-      const Tensor& input,
-      const Tensor& mean,
-      const Tensor& variance);
-
-  /// The options with which this module was constructed.
-  BatchNormOptions options;
-
-  /// The learned weight.
-  /// Only defined if the `affine` option was `true` upon construction.
-  Tensor weight;
-
-  /// The learned bias.
-  /// Only defined if the `affine` option was `true` upon construction.
-  Tensor bias;
-
-  /// The running mean.
-  /// Only defined if the `track_running_stats` option was `true` upon construction.
-  Tensor running_mean;
-
-  /// The running variance.
-  /// Only defined if the `track_running_stats` option was `true` upon construction.
-  Tensor running_var;
-};
-
-/// A `ModuleHolder` subclass for `BatchNormImpl`.
-/// See the documentation for `BatchNormImpl` class to learn what methods it
-/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
-/// module storage semantics.
-TORCH_MODULE(BatchNorm);
-
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
 /// Base class for all (dimension-specialized) batchnorm and instancenorm modules.
 template <size_t D, typename Derived, typename DerivedOptions>
 class NormImplBase : public torch::nn::Cloneable<Derived> {

@@ -191,6 +124,8 @@ class BatchNormImplBase : public NormImplBase<D, Derived, BatchNormOptions> {
   void pretty_print(std::ostream& stream) const override;
 };
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
 /// Applies the BatchNorm1d function.
 /// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm1d to learn
 /// about the exact behavior of this module.

@@ -217,6 +152,8 @@ class TORCH_API BatchNorm1dImpl : public BatchNormImplBase<1, BatchNorm1dImpl> {
 /// module storage semantics.
 TORCH_MODULE(BatchNorm1d);
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
 /// Applies the BatchNorm2d function.
 /// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm2d to learn
 /// about the exact behavior of this module.

@@ -243,6 +180,8 @@ class TORCH_API BatchNorm2dImpl : public BatchNormImplBase<2, BatchNorm2dImpl> {
 /// module storage semantics.
 TORCH_MODULE(BatchNorm2d);
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
 /// Applies the BatchNorm3d function.
 /// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm3d to learn
 /// about the exact behavior of this module.

@@ -24,7 +24,7 @@ namespace nn {
 ///   Sequential sequential(
 ///       Linear(3, 4),
 ///       Functional(torch::relu),
-///       BatchNorm(3),
+///       BatchNorm1d(3),
 ///       Functional(torch::elu, /*alpha=*/1));
 /// \endrst
 ///

@@ -15,7 +15,7 @@ namespace nn {
 ///
 ///   torch::nn::ModuleList mlist(
 ///     torch::nn::Linear(3, 4),
-///     torch::nn::BatchNorm(4),
+///     torch::nn::BatchNorm1d(4),
 ///     torch::nn::Dropout(0.5)
 ///   );
 ///

@@ -39,7 +39,7 @@ namespace nn {
 ///
 ///   torch::nn::ModuleList mlist(
 ///     torch::nn::Linear(3, 4),
-///     torch::nn::BatchNorm(4),
+///     torch::nn::BatchNorm1d(4),
 ///     torch::nn::Dropout(0.5)
 ///   );
 ///

@@ -91,15 +91,5 @@ class NamedAnyModule {
   AnyModule module_;
 };
 
-#ifndef DOXYGEN_SHOULD_SKIP_THIS
-
-C10_DEPRECATED_MESSAGE("`torch::nn::modules_ordered_dict` is deprecated. " \
-  "To construct a `Sequential` with named submodules, " \
-  "you can do `Sequential sequential({{\"m1\", MyModule(1)}, {\"m2\", MyModule(2)}})`")
-TORCH_API torch::OrderedDict<std::string, AnyModule> modules_ordered_dict(
-    std::initializer_list<NamedAnyModule> named_modules);
-
-#endif /* DOXYGEN_SHOULD_SKIP_THIS */
-
 } // namespace nn
 } // namespace torch

@@ -34,7 +34,7 @@ namespace nn {
 ///
 ///   torch::nn::Sequential seq(
 ///     torch::nn::Linear(3, 4),
-///     torch::nn::BatchNorm(4),
+///     torch::nn::BatchNorm1d(4),
 ///     torch::nn::Dropout(0.5)
 ///   );
 ///

@@ -69,7 +69,7 @@ namespace nn {
 ///
 ///   torch::nn::Sequential seq(
 ///     torch::nn::Linear(3, 4),
-///     torch::nn::BatchNorm(4),
+///     torch::nn::BatchNorm1d(4),
 ///     torch::nn::Dropout(0.5)
 ///   );
 ///

@@ -128,34 +128,6 @@ public:
 /// module storage semantics.
 TORCH_MODULE(Dropout3d);
 
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FeatureDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-/// Applies spatial [Dropout](https://arxiv.org/abs/1207.0580) to inputs with
-/// 2-D or 3-D features.
-///
-/// The equivalent in Python is
-/// [Dropout2d](https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout2d) for
-/// 2-D features and
-/// [Dropout3d](https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout3d) for
-/// 3-D features. This `FeatureDropout` module can instead deal with both 2-D
-/// and 3-D features.
-class TORCH_API FeatureDropoutImpl
-    : public detail::_DropoutNd<FeatureDropoutImpl> {
- public:
-  FeatureDropoutImpl(double p);
-
-  explicit FeatureDropoutImpl(const FeatureDropoutOptions& options_ = {});
-
-  /// During training, applies a noise mask to the input tensor.
-  /// During evaluation, applies an identity function.
-  Tensor forward(const Tensor& input);
-
-  /// Pretty prints the `FeatureDropout` module into the given `stream`.
-  void pretty_print(std::ostream& stream) const override;
-};
-
-TORCH_MODULE(FeatureDropout);
-
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 /// Applies Alpha Dropout over the input.

@@ -39,9 +39,6 @@ using Dropout2dOptions = DropoutOptions;
 /// ```
 using Dropout3dOptions = DropoutOptions;
 
-/// Options for `FeatureDropout` module.
-using FeatureDropoutOptions = DropoutOptions;
-
 /// Options for `AlphaDropout` module.
 ///
 /// Example:

@@ -35,57 +35,6 @@ struct Fan {
   int64_t out;
 };
 
-#define COMPUTE_NONLINEARITY_ENUM(name) /* NOLINT(cppcoreguidelines-macro-usage) */ \
-  case Nonlinearity::name: \
-    TORCH_WARN( \
-      "The enum value `torch::nn::init::Nonlinearity::", #name, "` is deprecated and will be removed in 1.5. ", \
-      "Please use `torch::k", #name, "` instead."); \
-    return torch::k##name;
-
-#define COMPUTE_FANMODE_ENUM(name) /* NOLINT(cppcoreguidelines-macro-usage) */ \
-  case FanMode::name: \
-    TORCH_WARN( \
-      "The enum value `torch::nn::init::FanMode::", #name, "` is deprecated and will be removed in 1.5. ", \
-      "Please use `torch::k", #name, "` instead."); \
-    return torch::k##name;
-
-NonlinearityType _compute_nonlinearity_type(Nonlinearity nonlinearity) {
-  switch (nonlinearity) {
-    COMPUTE_NONLINEARITY_ENUM(Linear)
-    COMPUTE_NONLINEARITY_ENUM(Conv1D)
-    COMPUTE_NONLINEARITY_ENUM(Conv2D)
-    COMPUTE_NONLINEARITY_ENUM(Conv3D)
-    COMPUTE_NONLINEARITY_ENUM(ConvTranspose1D)
-    COMPUTE_NONLINEARITY_ENUM(ConvTranspose2D)
-    COMPUTE_NONLINEARITY_ENUM(ConvTranspose3D)
-    COMPUTE_NONLINEARITY_ENUM(Sigmoid)
-    COMPUTE_NONLINEARITY_ENUM(Tanh)
-    COMPUTE_NONLINEARITY_ENUM(ReLU)
-    COMPUTE_NONLINEARITY_ENUM(LeakyReLU)
-    default:
-      TORCH_INTERNAL_ASSERT(
-        false,
-        "The enum class `torch::nn::init::Nonlinearity` is deprecated, ",
-        "please don't add any new enum to it. ",
-        "Instead, add the new enum to `torch/csrc/api/include/torch/enum.h` ",
-        "and use `torch::kEnumName` to reference it.")
-  }
-}
-
-FanModeType _compute_fanmode_type(FanMode fanmode) {
-  switch (fanmode) {
-    COMPUTE_FANMODE_ENUM(FanIn);
-    COMPUTE_FANMODE_ENUM(FanOut);
-    default:
-      TORCH_INTERNAL_ASSERT(
-        false,
-        "The enum class `torch::nn::init::Nonlinearity` is deprecated, ",
-        "please don't add any new enum to it. ",
-        "Instead, add the new enum to `torch/csrc/api/include/torch/enum.h` ",
-        "and use `torch::kEnumName` to reference it.")
-  }
-}
-
 double calculate_kaiming_std(
     Tensor tensor,
     double a,

@@ -96,10 +45,6 @@ double calculate_kaiming_std(
   const auto gain = calculate_gain(nonlinearity, a);
   double std = 0.0;
 
-  // Support for `torch::nn::init::FanMode` is deprecated and will be removed in 1.5.
-  if (c10::get_if<FanMode>(&mode)) {
-    mode = _compute_fanmode_type(c10::get<FanMode>(mode));
-  }
   if (c10::get_if<enumtype::kFanIn>(&mode)) {
     std = gain / std::sqrt(fan.in);
   } else {

@@ -110,10 +55,6 @@ double calculate_kaiming_std(
 } // namespace
 
 double calculate_gain(NonlinearityType nonlinearity, double param) {
-  // Support for `torch::nn::init::Nonlinearity` is deprecated and will be removed in 1.5.
-  if (c10::get_if<Nonlinearity>(&nonlinearity)) {
-    nonlinearity = _compute_nonlinearity_type(c10::get<Nonlinearity>(nonlinearity));
-  }
   if (c10::get_if<enumtype::kTanh>(&nonlinearity)) {
     return 5.0 / 3.0; // NOLINT
   } else if (c10::get_if<enumtype::kReLU>(&nonlinearity)) {

@@ -16,69 +16,6 @@ namespace F = torch::nn::functional;
 namespace torch {
 namespace nn {
 
-BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) { // NOLINT(modernize-pass-by-value)
-  TORCH_WARN("torch::nn::BatchNorm module is deprecated and will be removed in 1.5. "
-             "Use BatchNorm{1,2,3}d instead.");
-  reset();
-}
-
-void BatchNormImpl::reset() {
-  if (options.affine()) {
-    weight = register_parameter(
-        "weight", torch::empty({options.num_features()}).uniform_());
-    bias = register_parameter("bias", torch::zeros({options.num_features()}));
-  }
-
-  if (options.track_running_stats()) {
-    running_mean =
-        register_buffer("running_mean", torch::zeros({options.num_features()}));
-    running_var =
-        register_buffer("running_var", torch::ones({options.num_features()}));
-  }
-}
-
-void BatchNormImpl::pretty_print(std::ostream& stream) const {
-  stream << std::boolalpha
-         << "torch::nn::BatchNorm(num_features=" << options.num_features()
-         << ", eps=" << options.eps() << ", momentum=" << options.momentum().value()
-         << ", affine=" << options.affine() << ", track_running_stats=" << options.track_running_stats()
-         << ")";
-}
-
-Tensor BatchNormImpl::forward(const Tensor& input) {
-  TORCH_CHECK(
-      options.track_running_stats(),
-      "Calling BatchNorm::forward is only permitted when "
-      "the 'track_running_stats' option is true (was false). "
-      "Use BatchNorm::pure_forward instead.");
-  return pure_forward(input, running_mean, running_var);
-}
-
-Tensor BatchNormImpl::pure_forward(
-    const Tensor& input,
-    const Tensor& mean,
-    const Tensor& variance) {
-  if (is_training()) {
-    const auto num_channels = input.dim() > 1 ? input.size(1) : 1;
-    TORCH_CHECK(
-        input.numel() / num_channels > 1,
-        "BatchNorm expected more than 1 value per channel when training!");
-  }
-
-  return torch::batch_norm(
-      input,
-      weight,
-      bias,
-      mean,
-      variance,
-      is_training(),
-      options.momentum().value(),
-      options.eps(),
-      torch::cuda::cudnn_is_available());
-}
-
-// ===========================================================================
-
 template <size_t D, typename Derived>
 void BatchNormImplBase<D, Derived>::pretty_print(std::ostream& stream) const {
   stream << std::boolalpha

@@ -1,20 +0,0 @@
-#include <torch/nn/modules/container/named_any.h>
-
-namespace torch {
-namespace nn {
-
-torch::OrderedDict<std::string, AnyModule> modules_ordered_dict(
-    std::initializer_list<NamedAnyModule> named_modules) {
-  TORCH_WARN(
-    "`torch::nn::modules_ordered_dict` is deprecated. "
-    "To construct a `Sequential` with named submodules, "
-    "you can do `Sequential sequential({{\"m1\", MyModule(1)}, {\"m2\", MyModule(2)}})`");
-  torch::OrderedDict<std::string, AnyModule> dict;
-  for (auto named_module : named_modules) {
-    dict.insert(named_module.name(), std::move(named_module.module()));
-  }
-  return dict;
-}
-
-} // namespace nn
-} // namespace torch

@@ -53,30 +53,6 @@ void Dropout3dImpl::pretty_print(std::ostream& stream) const {
 
 // ============================================================================
 
-FeatureDropoutImpl::FeatureDropoutImpl(double p)
-    : detail::_DropoutNd<FeatureDropoutImpl>(p) {
-  TORCH_WARN("torch::nn::FeatureDropout module is deprecated."
-             "Use Dropout{2,3}d instead.");
-}
-
-FeatureDropoutImpl::FeatureDropoutImpl(const FeatureDropoutOptions& options_)
-    : detail::_DropoutNd<FeatureDropoutImpl>(options_) {
-  TORCH_WARN("torch::nn::FeatureDropout module is deprecated."
-             "Use Dropout{2,3}d instead.");
-}
-
-Tensor FeatureDropoutImpl::forward(const Tensor& input) {
-  return torch::feature_dropout(input, options.p(), is_training());
-}
-
-void FeatureDropoutImpl::pretty_print(std::ostream& stream) const {
-  stream << std::boolalpha
-         << "torch::nn::FeatureDropout(p=" << options.p()
-         << ", inplace=" << options.inplace() << ")";
-}
-
-// ============================================================================
-
 Tensor AlphaDropoutImpl::forward(const Tensor& input) {
   return F::detail::alpha_dropout(input, options.p(), is_training(), /*inplace=*/false);
 }