mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Operations on `Variable`s (or `torch::Tensor`) usually return `at::Tensor`. This is usually fine, but the `AnyModule` used in the implementation of `torch::Sequential` is very picky about types and does not understand implicit conversions like this. This means that `sequential.forward(at_tensor_that_is_actually_a_variable)` will fail unless you wrap `at_tensor_that_is_actually_a_variable` with `torch::Tensor`. This PR adds a special case to `AnyModule` that converts an `at::Tensor` to `torch::Tensor` when the tensor is really a variable, and otherwise passes the `at::Tensor` through unchanged. This is a nice little usability improvement for the often-used `Sequential` class. ebetica ezyang Closes https://github.com/pytorch/pytorch/pull/8968 Reviewed By: ezyang Differential Revision: D8670407 Pulled By: goldsborough fbshipit-source-id: 3635ed6ed28238f3900ce4a876d07f1b11713831
373 lines
11 KiB
C++
373 lines
11 KiB
C++
#include <catch.hpp>
|
|
|
|
#include <torch/torch.h>
|
|
#include <torch/utils.h>
|
|
#include <torch/nn/modules/any.h>
|
|
|
|
#include <algorithm>
|
|
#include <string>
|
|
|
|
using namespace torch::nn;
|
|
using namespace torch::detail;
|
|
|
|
using Catch::Contains;
|
|
using Catch::StartsWith;
|
|
|
|
// Tests for `AnyModule`. Covers:
// - forwarding calls to modules with various forward() signatures;
// - the error messages for wrong argument types and counts;
// - get()/ptr() casts and the behavior of an empty AnyModule;
// - reassignment, and construction from a ModuleHolder;
// - implicit at::Tensor -> torch::Tensor conversion of arguments.
TEST_CASE("any-module") {
  torch::manual_seed(0);
  SECTION("int()") {
    struct M : torch::nn::Module {
      int forward() {
        return 123;
      }
    };
    AnyModule any(M{});
    REQUIRE(any.forward().get<int>() == 123);
  }
  SECTION("int(int)") {
    struct M : torch::nn::Module {
      int forward(int x) {
        return x;
      }
    };
    AnyModule any(M{});
    REQUIRE(any.forward(5).get<int>() == 5);
  }
  SECTION("const char*(const char*)") {
    struct M : torch::nn::Module {
      const char* forward(const char* x) {
        return x;
      }
    };
    AnyModule any(M{});
    // Compare contents, not pointers.
    REQUIRE(any.forward("hello").get<const char*>() == std::string("hello"));
  }

  SECTION("string(int, const double)") {
    struct M : torch::nn::Module {
      std::string forward(int x, const double f) {
        return std::to_string(static_cast<int>(x + f));
      }
    };
    AnyModule any(M{});
    // `x` is a named lvalue, so this also exercises passing a variable
    // rather than a literal.
    int x = 4;
    REQUIRE(any.forward(x, 3.14).get<std::string>() == std::string("7"));
  }

  SECTION("Tensor(string, const string&, string&&)") {
    struct M : torch::nn::Module {
      torch::Tensor forward(
          std::string a,
          const std::string& b,
          std::string&& c) {
        const auto s = a + b + c;
        return torch::ones({static_cast<int64_t>(s.size())});
      }
    };
    AnyModule any(M{});
    // "a" + "ab" + "abc" has 6 characters, so the result is ones(6).
    REQUIRE(
        any.forward(std::string("a"), std::string("ab"), std::string("abc"))
            .get<torch::Tensor>()
            .sum()
            .toCInt() == 6);
  }
  SECTION("wrong argument type") {
    struct M : torch::nn::Module {
      int forward(float x) {
        // Explicit conversion; the test only cares about the parameter type.
        return static_cast<int>(x);
      }
    };
    AnyModule any(M{});
    // A double argument is not silently narrowed to float.
    REQUIRE_THROWS_WITH(
        any.forward(5.0),
        StartsWith("Expected argument #0 to be of type float, "
                   "but received value of type double"));
  }
  SECTION("wrong number of arguments") {
    struct M : torch::nn::Module {
      int forward(int a, int b) {
        return a + b;
      }
    };
    AnyModule any(M{});
    REQUIRE_THROWS_WITH(
        any.forward(),
        Contains("M's forward() method expects 2 arguments, but received 0"));
    REQUIRE_THROWS_WITH(
        any.forward(5),
        Contains("M's forward() method expects 2 arguments, but received 1"));
    REQUIRE_THROWS_WITH(
        any.forward(1, 2, 3),
        Contains("M's forward() method expects 2 arguments, but received 3"));
  }
  SECTION("get()") {
    struct M : torch::nn::Module {
      explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
      int value;
      int forward(float x) {
        return static_cast<int>(x);
      }
    };
    AnyModule any(M{5});

    SECTION("good cast") {
      REQUIRE(any.get<M>().value == 5);
    }

    SECTION("bad cast") {
      struct N : torch::nn::Module {};
      REQUIRE_THROWS_WITH(any.get<N>(), StartsWith("Attempted to cast module"));
    }
  }
  SECTION("ptr()") {
    struct M : torch::nn::Module {
      explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
      int value;
      int forward(float x) {
        return static_cast<int>(x);
      }
    };
    AnyModule any(M{5});

    SECTION("base class cast") {
      auto ptr = any.ptr();
      REQUIRE(ptr != nullptr);
      REQUIRE(ptr->name() == "M");
    }

    SECTION("good downcast") {
      auto ptr = any.ptr<M>();
      REQUIRE(ptr != nullptr);
      REQUIRE(ptr->value == 5);
    }

    SECTION("bad downcast") {
      struct N : torch::nn::Module {};
      REQUIRE_THROWS_WITH(any.ptr<N>(), StartsWith("Attempted to cast module"));
    }
  }
  SECTION("default state is empty") {
    struct M : torch::nn::Module {
      explicit M(int value_) : value(value_) {}
      int value;
      int forward(float x) {
        return static_cast<int>(x);
      }
    };
    AnyModule any;
    REQUIRE(any.is_empty());
    any = std::make_shared<M>(5);
    REQUIRE(!any.is_empty());
    REQUIRE(any.get<M>().value == 5);
  }
  SECTION("all methods throw for empty AnyModule") {
    struct M : torch::nn::Module {
      int forward(int x) {
        return x;
      }
    };
    AnyModule any;
    REQUIRE(any.is_empty());
    REQUIRE_THROWS_WITH(
        any.get<M>(), StartsWith("Cannot call get() on an empty AnyModule"));
    REQUIRE_THROWS_WITH(
        any.ptr<M>(), StartsWith("Cannot call ptr() on an empty AnyModule"));
    REQUIRE_THROWS_WITH(
        any.ptr(), StartsWith("Cannot call ptr() on an empty AnyModule"));
    REQUIRE_THROWS_WITH(
        any.type_info(),
        StartsWith("Cannot call type_info() on an empty AnyModule"));
    REQUIRE_THROWS_WITH(
        any.forward<int>(5),
        StartsWith("Cannot call forward() on an empty AnyModule"));
  }
  SECTION("can move assign different modules") {
    struct M : torch::nn::Module {
      std::string forward(int x) {
        return std::to_string(x);
      }
    };
    struct N : torch::nn::Module {
      int forward(float x) {
        return static_cast<int>(3 + x);
      }
    };
    AnyModule any;
    REQUIRE(any.is_empty());
    any = std::make_shared<M>();
    REQUIRE(!any.is_empty());
    REQUIRE(any.forward(5).get<std::string>() == "5");
    // Reassigning to a module with a different forward() signature works.
    any = std::make_shared<N>();
    REQUIRE(!any.is_empty());
    REQUIRE(any.forward(5.0f).get<int>() == 8);
  }
  SECTION("has reference semantics") {
    // A copy-constructed Sequential must contain an element-wise equal
    // sequence of modules.
    Sequential first(Linear(2, 3), Linear(4, 4), Linear(4, 5));
    Sequential second(first);

    REQUIRE(first.size() == second.size());
    REQUIRE(std::equal(first.begin(), first.end(), second.begin()));
  }
  SECTION("constructs from ModuleHolder") {
    struct MImpl : torch::nn::Module {
      explicit MImpl(int value_) : torch::nn::Module("M"), value(value_) {}
      int value;
      int forward(float x) {
        return static_cast<int>(x);
      }
    };

    struct M : torch::nn::ModuleHolder<MImpl> {
      using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
      using torch::nn::ModuleHolder<MImpl>::get;
    };

    AnyModule any(M{5});
    // Both the implementation type and the holder type are accepted by get().
    REQUIRE(any.get<MImpl>().value == 5);
    REQUIRE(any.get<M>()->value == 5);
  }
  SECTION("converts at::Tensor to torch::Tensor correctly") {
    struct M : torch::nn::Module {
      torch::Tensor forward(torch::Tensor input) {
        return input;
      }
    };
    struct N : torch::nn::Module {
      at::Tensor forward(at::Tensor input) {
        return input;
      }
    };
    {
      // When you get an at::Tensor by performing an operation on a
      // torch::Tensor, the tensor should be converted back to torch::Tensor
      // before being passed to the function (to avoid a type mismatch).
      AnyModule any(M{});
      at::Tensor tensor_that_is_actually_a_variable = torch::ones(5) * 2;
      REQUIRE(
          any.forward(tensor_that_is_actually_a_variable)
              .get<torch::Tensor>()
              .sum()
              .toCFloat() == 10);
      // But tensors that are really tensors should just error.
      REQUIRE_THROWS_WITH(
          any.forward(at::ones(5)),
          StartsWith(
              "Expected argument #0 to be of type torch::autograd::Variable, "
              "but received value of type at::Tensor"));
    }
    {
      // If the function does really accept an `at::Tensor`, this should still
      // work.
      AnyModule any(N{});
      REQUIRE(any.forward(at::ones(5)).get<at::Tensor>().sum().toCFloat() == 5);
    }
  }
}
|
|
|
|
namespace torch {
|
|
namespace nn {
|
|
struct TestValue {
|
|
template <typename T>
|
|
explicit TestValue(T&& value) : value_(std::forward<T>(value)) {}
|
|
AnyModule::Value operator()() {
|
|
return std::move(value_);
|
|
}
|
|
AnyModule::Value value_;
|
|
};
|
|
template <typename T>
|
|
AnyModule::Value make_value(T&& value) {
|
|
return TestValue(std::forward<T>(value))();
|
|
}
|
|
} // namespace nn
|
|
} // namespace torch
|
|
|
|
// Tests for `AnyModule::Value`, the type-erased holder returned by
// `AnyModule::forward()`. Covers:
// - typed access through try_get()/get();
// - failures on the wrong type;
// - move construction and move assignment;
// - type_info().
TEST_CASE("any-value") {
  torch::manual_seed(0);
  SECTION("gets the correct value for the right type") {
    SECTION("int") {
      auto value = make_value(5);
      // const and non-const types have the same typeid()
      REQUIRE(value.try_get<int>() != nullptr);
      REQUIRE(value.try_get<const int>() != nullptr);
      REQUIRE(value.get<int>() == 5);
    }
    SECTION("const int") {
      // Build the value from a const lvalue, so this section actually
      // differs from the "int" section above.
      const int five = 5;
      auto value = make_value(five);
      REQUIRE(value.try_get<const int>() != nullptr);
      REQUIRE(value.try_get<int>() != nullptr);
      REQUIRE(value.get<const int>() == 5);
    }
    SECTION("const char*") {
      auto value = make_value("hello");
      REQUIRE(value.try_get<const char*>() != nullptr);
      REQUIRE(value.get<const char*>() == std::string("hello"));
    }
    SECTION("std::string") {
      auto value = make_value(std::string("hello"));
      REQUIRE(value.try_get<std::string>() != nullptr);
      REQUIRE(value.get<std::string>() == "hello");
    }
    SECTION("pointers") {
      std::string s("hello");
      std::string* p = &s;
      auto value = make_value(p);
      REQUIRE(value.try_get<std::string*>() != nullptr);
      REQUIRE(*value.get<std::string*>() == "hello");
    }
    SECTION("references") {
      std::string s("hello");
      const std::string& t = s;
      auto value = make_value(t);
      // A value created from a reference is retrieved by its value type.
      REQUIRE(value.try_get<std::string>() != nullptr);
      REQUIRE(value.get<std::string>() == "hello");
    }
  }
  SECTION("try_get returns nullptr for the wrong type") {
    auto value = make_value(5);
    REQUIRE(value.try_get<int>() != nullptr);
    REQUIRE(value.try_get<float>() == nullptr);
    REQUIRE(value.try_get<long>() == nullptr);
    REQUIRE(value.try_get<std::string>() == nullptr);
  }
  SECTION("get throws for the wrong type") {
    auto value = make_value(5);
    REQUIRE(value.try_get<int>() != nullptr);
    REQUIRE_THROWS_WITH(
        value.get<float>(),
        StartsWith("Attempted to cast Value to float, "
                   "but its actual type is int"));
    REQUIRE_THROWS_WITH(
        value.get<long>(),
        StartsWith("Attempted to cast Value to long, "
                   "but its actual type is int"));
  }
  SECTION("move is allowed") {
    auto value = make_value(5);
    SECTION("construction") {
      auto copy = make_value(std::move(value));
      REQUIRE(copy.try_get<int>() != nullptr);
      REQUIRE(copy.get<int>() == 5);
    }
    SECTION("assignment") {
      auto copy = make_value(10);
      copy = std::move(value);
      REQUIRE(copy.try_get<int>() != nullptr);
      REQUIRE(copy.get<int>() == 5);
    }
  }
  SECTION("type_info is correct") {
    SECTION("int") {
      auto value = make_value(5);
      REQUIRE(value.type_info().hash_code() == typeid(int).hash_code());
    }
    SECTION("const char*") {
      // A string literal is stored as a decayed `const char*`.
      auto value = make_value("hello");
      REQUIRE(value.type_info().hash_code() == typeid(const char*).hash_code());
    }
    SECTION("std::string") {
      auto value = make_value(std::string("hello"));
      REQUIRE(value.type_info().hash_code() == typeid(std::string).hash_code());
    }
  }
}
|