mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: In the C++ API, `Sequential` was previously not refcounted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was thus accessed via `.`. This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to give `Sequential` reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool. One thing I had to change: the `ModuleHolder`, which implements the pImpl mechanism, previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential`, since it takes a variadic parameter pack. Instead, `ModuleHolder` now forwards all arguments to the underlying module, and the trick of forwarding parameters to a module's options type has been pushed down into the modules themselves. This adds one constructor per Module in the library. User modules do not have to do this (unless they want this nice forwarding themselves). It makes the code simpler overall. ezyang ebetica apaszke Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151 Reviewed By: ezyang Differential Revision: D8809298 Pulled By: goldsborough fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
366 lines
10 KiB
C++
366 lines
10 KiB
C++
#include <catch.hpp>
|
|
|
|
#include <torch/torch.h>
|
|
#include <torch/utils.h>
|
|
#include <torch/nn/modules/any.h>
|
|
|
|
#include <algorithm>
|
|
#include <string>
|
|
|
|
using namespace torch::nn;
|
|
using namespace torch::detail;
|
|
|
|
using Catch::Contains;
|
|
using Catch::StartsWith;
|
|
|
|
// Tests for AnyModule: type-erased forward() calls with various argument and
// return types, argument type/count checking, get()/ptr() casts, behavior of
// an empty AnyModule, reassignment to different module types, construction
// from a ModuleHolder, and at::Tensor vs. torch::Tensor argument handling.
TEST_CASE("any-module") {
  torch::manual_seed(0);
  SECTION("int()") {
    struct M : torch::nn::Module {
      int forward() {
        return 123;
      }
    };
    AnyModule any(M{});
    REQUIRE(any.forward().get<int>() == 123);
  }
  SECTION("int(int)") {
    struct M : torch::nn::Module {
      int forward(int x) {
        return x;
      }
    };
    AnyModule any(M{});
    REQUIRE(any.forward(5).get<int>() == 5);
  }
  SECTION("const char*(const char*)") {
    struct M : torch::nn::Module {
      const char* forward(const char* x) {
        return x;
      }
    };
    AnyModule any(M{});
    REQUIRE(any.forward("hello").get<const char*>() == std::string("hello"));
  }

  SECTION("string(int, const double)") {
    struct M : torch::nn::Module {
      std::string forward(int x, const double f) {
        return std::to_string(static_cast<int>(x + f));
      }
    };
    AnyModule any(M{});
    int x = 4;
    REQUIRE(any.forward(x, 3.14).get<std::string>() == std::string("7"));
  }

  SECTION("Tensor(string, const string&, string&&)") {
    struct M : torch::nn::Module {
      torch::Tensor forward(
          std::string a,
          const std::string& b,
          std::string&& c) {
        const auto s = a + b + c;
        return torch::ones({static_cast<int64_t>(s.size())});
      }
    };
    AnyModule any(M{});
    REQUIRE(
        any.forward(std::string("a"), std::string("ab"), std::string("abc"))
            .get<torch::Tensor>()
            .sum()
            .toCInt() == 6);
  }
  SECTION("wrong argument type") {
    struct M : torch::nn::Module {
      int forward(float x) {
        // Explicit float -> int conversion instead of an implicit narrowing.
        return static_cast<int>(x);
      }
    };
    AnyModule any(M{});
    REQUIRE_THROWS_WITH(
        any.forward(5.0),
        StartsWith("Expected argument #0 to be of type float, "
                   "but received value of type double"));
  }
  SECTION("wrong number of arguments") {
    struct M : torch::nn::Module {
      int forward(int a, int b) {
        return a + b;
      }
    };
    AnyModule any(M{});
    REQUIRE_THROWS_WITH(
        any.forward(),
        Contains("M's forward() method expects 2 arguments, but received 0"));
    REQUIRE_THROWS_WITH(
        any.forward(5),
        Contains("M's forward() method expects 2 arguments, but received 1"));
    REQUIRE_THROWS_WITH(
        any.forward(1, 2, 3),
        Contains("M's forward() method expects 2 arguments, but received 3"));
  }
  SECTION("get()") {
    struct M : torch::nn::Module {
      explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
      int value;
      int forward(float x) {
        return static_cast<int>(x);
      }
    };
    AnyModule any(M{5});

    SECTION("good cast") {
      REQUIRE(any.get<M>().value == 5);
    }

    SECTION("bad cast") {
      struct N : torch::nn::Module {};
      REQUIRE_THROWS_WITH(any.get<N>(), StartsWith("Attempted to cast module"));
    }
  }
  SECTION("ptr()") {
    struct M : torch::nn::Module {
      explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
      int value;
      int forward(float x) {
        return static_cast<int>(x);
      }
    };
    AnyModule any(M{5});

    SECTION("base class cast") {
      auto ptr = any.ptr();
      REQUIRE(ptr != nullptr);
      REQUIRE(ptr->name() == "M");
    }

    SECTION("good downcast") {
      auto ptr = any.ptr<M>();
      REQUIRE(ptr != nullptr);
      REQUIRE(ptr->value == 5);
    }

    SECTION("bad downcast") {
      struct N : torch::nn::Module {};
      REQUIRE_THROWS_WITH(any.ptr<N>(), StartsWith("Attempted to cast module"));
    }
  }
  SECTION("default state is empty") {
    struct M : torch::nn::Module {
      explicit M(int value_) : value(value_) {}
      int value;
      int forward(float x) {
        return static_cast<int>(x);
      }
    };
    AnyModule any;
    REQUIRE(any.is_empty());
    any = std::make_shared<M>(5);
    REQUIRE(!any.is_empty());
    REQUIRE(any.get<M>().value == 5);
  }
  SECTION("all methods throw for empty AnyModule") {
    struct M : torch::nn::Module {
      int forward(int x) {
        return x;
      }
    };
    AnyModule any;
    REQUIRE(any.is_empty());
    REQUIRE_THROWS_WITH(
        any.get<M>(), StartsWith("Cannot call get() on an empty AnyModule"));
    REQUIRE_THROWS_WITH(
        any.ptr<M>(), StartsWith("Cannot call ptr() on an empty AnyModule"));
    REQUIRE_THROWS_WITH(
        any.ptr(), StartsWith("Cannot call ptr() on an empty AnyModule"));
    REQUIRE_THROWS_WITH(
        any.type_info(),
        StartsWith("Cannot call type_info() on an empty AnyModule"));
    REQUIRE_THROWS_WITH(
        any.forward<int>(5),
        StartsWith("Cannot call forward() on an empty AnyModule"));
  }
  SECTION("can move assign different modules") {
    struct M : torch::nn::Module {
      std::string forward(int x) {
        return std::to_string(x);
      }
    };
    struct N : torch::nn::Module {
      int forward(float x) {
        // 3 + 5.0f == 8.0f; convert explicitly rather than narrowing implicitly.
        return static_cast<int>(3 + x);
      }
    };
    AnyModule any;
    REQUIRE(any.is_empty());
    any = std::make_shared<M>();
    REQUIRE(!any.is_empty());
    REQUIRE(any.forward(5).get<std::string>() == "5");
    any = std::make_shared<N>();
    REQUIRE(!any.is_empty());
    REQUIRE(any.forward(5.0f).get<int>() == 8);
  }
  SECTION("constructs from ModuleHolder") {
    struct MImpl : torch::nn::Module {
      explicit MImpl(int value_) : torch::nn::Module("M"), value(value_) {}
      int value;
      int forward(float x) {
        return static_cast<int>(x);
      }
    };

    struct M : torch::nn::ModuleHolder<MImpl> {
      using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
      using torch::nn::ModuleHolder<MImpl>::get;
    };

    AnyModule any(M{5});
    REQUIRE(any.get<MImpl>().value == 5);
    REQUIRE(any.get<M>()->value == 5);
  }
  SECTION("converts at::Tensor to torch::Tensor correctly") {
    struct M : torch::nn::Module {
      torch::Tensor forward(torch::Tensor input) {
        return input;
      }
    };
    struct N : torch::nn::Module {
      at::Tensor forward(at::Tensor input) {
        return input;
      }
    };
    {
      // When you get an at::Tensor by performing an operation on a
      // torch::Tensor, the tensor should be converted back to torch::Tensor
      // before being passed to the function (to avoid a type mismatch).
      AnyModule any(M{});
      at::Tensor tensor_that_is_actually_a_variable = torch::ones(5) * 2;
      REQUIRE(
          any.forward(tensor_that_is_actually_a_variable)
              .get<torch::Tensor>()
              .sum()
              .toCFloat() == 10);
      // But tensors that are really tensors should just error.
      REQUIRE_THROWS_WITH(
          any.forward(at::ones(5)),
          StartsWith(
              "Expected argument #0 to be of type torch::autograd::Variable, "
              "but received value of type at::Tensor"));
    }
    {
      // If the function does really accept an `at::Tensor`, this should still
      // work.
      AnyModule any(N{});
      REQUIRE(any.forward(at::ones(5)).get<at::Tensor>().sum().toCFloat() == 5);
    }
  }
}
|
|
|
|
namespace torch {
|
|
namespace nn {
|
|
struct TestValue {
|
|
template <typename T>
|
|
explicit TestValue(T&& value) : value_(std::forward<T>(value)) {}
|
|
AnyModule::Value operator()() {
|
|
return std::move(value_);
|
|
}
|
|
AnyModule::Value value_;
|
|
};
|
|
template <typename T>
|
|
AnyModule::Value make_value(T&& value) {
|
|
return TestValue(std::forward<T>(value))();
|
|
}
|
|
} // namespace nn
|
|
} // namespace torch
|
|
|
|
// Tests for AnyModule::Value (created via make_value): typed retrieval with
// try_get()/get(), rejection of mismatched types, move construction and move
// assignment, and the type reported by type_info().
TEST_CASE("any-value") {
  torch::manual_seed(0);
  SECTION("gets the correct value for the right type") {
    SECTION("int") {
      auto value = make_value(5);
      // const and non-const types have the same typeid()
      REQUIRE(value.try_get<int>() != nullptr);
      REQUIRE(value.try_get<const int>() != nullptr);
      REQUIRE(value.get<int>() == 5);
    }
    SECTION("const int") {
      // Build from a const lvalue so this section actually differs from
      // "int" above: the stored value must still be retrievable as both
      // const int and int.
      const int five = 5;
      auto value = make_value(five);
      REQUIRE(value.try_get<const int>() != nullptr);
      REQUIRE(value.try_get<int>() != nullptr);
      REQUIRE(value.get<const int>() == 5);
    }
    SECTION("const char*") {
      auto value = make_value("hello");
      REQUIRE(value.try_get<const char*>() != nullptr);
      REQUIRE(value.get<const char*>() == std::string("hello"));
    }
    SECTION("std::string") {
      auto value = make_value(std::string("hello"));
      REQUIRE(value.try_get<std::string>() != nullptr);
      REQUIRE(value.get<std::string>() == "hello");
    }
    SECTION("pointers") {
      std::string s("hello");
      std::string* p = &s;
      auto value = make_value(p);
      REQUIRE(value.try_get<std::string*>() != nullptr);
      REQUIRE(*value.get<std::string*>() == "hello");
    }
    SECTION("references") {
      std::string s("hello");
      const std::string& t = s;
      auto value = make_value(t);
      REQUIRE(value.try_get<std::string>() != nullptr);
      REQUIRE(value.get<std::string>() == "hello");
    }
  }
  SECTION("try_get returns nullptr for the wrong type") {
    auto value = make_value(5);
    REQUIRE(value.try_get<int>() != nullptr);
    REQUIRE(value.try_get<float>() == nullptr);
    REQUIRE(value.try_get<long>() == nullptr);
    REQUIRE(value.try_get<std::string>() == nullptr);
  }
  SECTION("get throws for the wrong type") {
    auto value = make_value(5);
    REQUIRE(value.try_get<int>() != nullptr);
    REQUIRE_THROWS_WITH(
        value.get<float>(),
        StartsWith("Attempted to cast Value to float, "
                   "but its actual type is int"));
    REQUIRE_THROWS_WITH(
        value.get<long>(),
        StartsWith("Attempted to cast Value to long, "
                   "but its actual type is int"));
  }
  SECTION("move is allowed") {
    auto value = make_value(5);
    SECTION("construction") {
      auto copy = make_value(std::move(value));
      REQUIRE(copy.try_get<int>() != nullptr);
      REQUIRE(copy.get<int>() == 5);
    }
    SECTION("assignment") {
      auto copy = make_value(10);
      copy = std::move(value);
      REQUIRE(copy.try_get<int>() != nullptr);
      REQUIRE(copy.get<int>() == 5);
    }
  }
  SECTION("type_info is correct") {
    // Compare std::type_info objects directly: distinct types are allowed to
    // share a hash_code(), so equal hashes alone do not prove equal types.
    SECTION("int") {
      auto value = make_value(5);
      REQUIRE(value.type_info() == typeid(int));
    }
    SECTION("const char*") {
      auto value = make_value("hello");
      REQUIRE(value.type_info() == typeid(const char*));
    }
    SECTION("std::string") {
      auto value = make_value(std::string("hello"));
      REQUIRE(value.type_info() == typeid(std::string));
    }
  }
}
|