mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
C++ API parity: LogSigmoid
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27060 Test Plan: Imported from OSS Differential Revision: D17682404 Pulled By: pbelevich fbshipit-source-id: d60d64cd4caf1f56a2e05c516f91321d46ec9624
This commit is contained in:
parent
17c672e704
commit
2cc1e69cc9
|
|
@ -347,3 +347,16 @@ TEST_F(FunctionalTest, LeakyReLU) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FunctionalTest, LogSigmoid) {
|
||||
const auto size = 3;
|
||||
LogSigmoid model;
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size});
|
||||
auto y = F::logsigmoid(x);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), torch::IntArrayRef({size, size, size}));
|
||||
auto y_exp = torch::log(torch::ones_like(x)/(torch::ones_like(x) + torch::exp(torch::neg(x))));
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1099,6 +1099,23 @@ TEST_F(ModulesTest, LeakyReLU) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, LogSigmoid) {
|
||||
const auto size = 3;
|
||||
LogSigmoid model;
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size}).set_requires_grad(true);
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), torch::IntArrayRef({size, size, size}));
|
||||
auto y_exp = torch::log(torch::ones_like(x)/(torch::ones_like(x) + torch::exp(torch::neg(x))));
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, PrettyPrintIdentity) {
|
||||
ASSERT_EQ(c10::str(Identity()), "torch::nn::Identity()");
|
||||
}
|
||||
|
|
@ -1349,3 +1366,7 @@ TEST_F(ModulesTest, PrettyPrintLeakyReLU) {
|
|||
LeakyReLUOptions().negative_slope(0.42).inplace(true))),
|
||||
"torch::nn::LeakyReLU(negative_slope=0.42, inplace=true)");
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, PrettyPrintLogSigmoid) {
|
||||
ASSERT_EQ(c10::str(LogSigmoid()), "torch::nn::LogSigmoid()");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,6 +36,10 @@ inline Tensor leaky_relu(Tensor& input, const LeakyReLUOptions& options) {
|
|||
}
|
||||
}
|
||||
|
||||
inline Tensor logsigmoid(const Tensor& input) {
|
||||
return torch::log_sigmoid(input);
|
||||
}
|
||||
|
||||
} // namespace functional
|
||||
} // namespace nn
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -101,5 +101,24 @@ class TORCH_API LeakyReLUImpl : public torch::nn::Cloneable<LeakyReLUImpl> {
|
|||
|
||||
TORCH_MODULE(LeakyReLU);
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
/// Applies the LogSigmoid function element-wise.
|
||||
/// See https://pytorch.org/docs/master/nn.html#torch.nn.LogSigmoid to learn
|
||||
/// about the exact behavior of this module.
|
||||
class TORCH_API LogSigmoidImpl : public torch::nn::Cloneable<LogSigmoidImpl> {
|
||||
public:
|
||||
LogSigmoidImpl() {}
|
||||
|
||||
Tensor forward(const Tensor& input);
|
||||
|
||||
void reset() override;
|
||||
|
||||
/// Pretty prints the `LogSigmoid` module into the given `stream`.
|
||||
void pretty_print(std::ostream& stream) const override;
|
||||
};
|
||||
|
||||
TORCH_MODULE(LogSigmoid);
|
||||
|
||||
} // namespace nn
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -84,5 +84,17 @@ void LeakyReLUImpl::pretty_print(std::ostream& stream) const {
|
|||
stream << ")";
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
||||
Tensor LogSigmoidImpl::forward(const Tensor& input) {
|
||||
return F::logsigmoid(input);
|
||||
}
|
||||
|
||||
void LogSigmoidImpl::reset() {}
|
||||
|
||||
void LogSigmoidImpl::pretty_print(std::ostream& stream) const {
|
||||
stream << "torch::nn::LogSigmoid()";
|
||||
}
|
||||
|
||||
} // namespace nn
|
||||
} // namespace torch
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user