Impl for ParameterList (#41259)
Summary:
This is a new PR for https://github.com/pytorch/pytorch/issues/40850, https://github.com/pytorch/pytorch/issues/40987, and https://github.com/pytorch/pytorch/issues/41206 (which I unintentionally closed), as I ran into some rebase issues with that one. Very sorry about that. I have also fixed the tests that failed in that PR.
This diff contains the implementation of the C++ API for ParameterList from https://github.com/pytorch/pytorch/issues/25883.
Refer to the Python API: bc9e8af218/torch/nn/modules/container.py (L376)
Not sure about some of the naming differences between the C++ API and the Python API: for example, should `append` be called `push_back`?
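For context, the intended usage mirrors the Python container closely. A minimal sketch based on the unit tests added in this diff (tensor shapes and values are arbitrary):

#include <iostream>
#include <torch/torch.h>

int main() {
  // Build a list from a variadic set of tensors, then grow it with append()
  // (the C++ counterpart of Python's ParameterList.append()).
  torch::nn::ParameterList list(
      torch::randn({1, 2}, torch::requires_grad(true)),
      torch::randn({1, 2}));
  list->append(torch::randn({1, 2, 3}));

  // Bounds-checked access, and iteration over (key, value) items.
  std::cout << list->at(0).sizes() << std::endl;
  for (const auto& item : *list) {
    std::cout << item.key() << ": " << item.value().sizes() << std::endl;
  }
}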
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41259
Test Plan: Add unit tests in this diff
Differential Revision: D22495780
Pulled By: glaringlee
fbshipit-source-id: 79ea3592db640f35477d445ecdaeafbdad814bec
This commit is contained in:
parent
fa153184c8
commit
98df9781a7
test/cpp/api/CMakeLists.txt
@@ -16,6 +16,7 @@ set(TORCH_API_TEST_SOURCES
   ${TORCH_API_TEST_DIR}/modulelist.cpp
   ${TORCH_API_TEST_DIR}/modules.cpp
   ${TORCH_API_TEST_DIR}/parameterdict.cpp
+  ${TORCH_API_TEST_DIR}/parameterlist.cpp
   ${TORCH_API_TEST_DIR}/namespace.cpp
   ${TORCH_API_TEST_DIR}/nn_utils.cpp
   ${TORCH_API_TEST_DIR}/optim.cpp
162  test/cpp/api/parameterlist.cpp  Normal file
@@ -0,0 +1,162 @@
#include <gtest/gtest.h>

#include <torch/torch.h>

#include <algorithm>
#include <memory>
#include <vector>

#include <test/cpp/api/support.h>

using namespace torch::nn;
using namespace torch::test;

struct ParameterListTest : torch::test::SeedingFixture {};

TEST_F(ParameterListTest, ConstructsFromSharedPointer) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  ASSERT_TRUE(ta.requires_grad());
  ASSERT_FALSE(tb.requires_grad());
  ParameterList list(ta, tb, tc);
  ASSERT_EQ(list->size(), 3);
}

TEST_F(ParameterListTest, isEmpty) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  ParameterList list;
  ASSERT_TRUE(list->is_empty());
  list->append(ta);
  ASSERT_FALSE(list->is_empty());
  ASSERT_EQ(list->size(), 1);
}

TEST_F(ParameterListTest, PushBackAddsAnElement) {
  ParameterList list;
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::Tensor td = torch::randn({1, 2, 3});
  ASSERT_EQ(list->size(), 0);
  ASSERT_TRUE(list->is_empty());
  list->append(ta);
  ASSERT_EQ(list->size(), 1);
  list->append(tb);
  ASSERT_EQ(list->size(), 2);
  list->append(tc);
  ASSERT_EQ(list->size(), 3);
  list->append(td);
  ASSERT_EQ(list->size(), 4);
}

TEST_F(ParameterListTest, ForEachLoop) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::Tensor td = torch::randn({1, 2, 3});
  ParameterList list(ta, tb, tc, td);
  std::vector<torch::Tensor> params = {ta, tb, tc, td};
  ASSERT_EQ(list->size(), 4);
  int idx = 0;
  for (const auto& pair : *list) {
    ASSERT_TRUE(
        torch::all(torch::eq(pair.value(), params[idx++])).item<bool>());
  }
}

TEST_F(ParameterListTest, AccessWithAt) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::Tensor td = torch::randn({1, 2, 3});
  std::vector<torch::Tensor> params = {ta, tb, tc, td};

  ParameterList list;
  for (auto& param : params) {
    list->append(param);
  }
  ASSERT_EQ(list->size(), 4);

  // returns the correct parameter for a given index
  for (size_t i = 0; i < params.size(); ++i) {
    ASSERT_TRUE(torch::all(torch::eq(list->at(i), params[i])).item<bool>());
  }

  for (size_t i = 0; i < params.size(); ++i) {
    ASSERT_TRUE(torch::all(torch::eq(list[i], params[i])).item<bool>());
  }

  // throws for a bad index
  ASSERT_THROWS_WITH(list->at(params.size() + 100), "Index out of range");
  ASSERT_THROWS_WITH(list->at(params.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(list[params.size() + 1], "Index out of range");
}

TEST_F(ParameterListTest, ExtendPushesParametersFromOtherParameterList) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::Tensor td = torch::randn({1, 2, 3});
  torch::Tensor te = torch::randn({1, 2});
  torch::Tensor tf = torch::randn({1, 2, 3});
  ParameterList a(ta, tb);
  ParameterList b(tc, td);
  a->extend(*b);

  ASSERT_EQ(a->size(), 4);
  ASSERT_TRUE(torch::all(torch::eq(a[0], ta)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(a[1], tb)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(a[2], tc)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(a[3], td)).item<bool>());

  ASSERT_EQ(b->size(), 2);
  ASSERT_TRUE(torch::all(torch::eq(b[0], tc)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(b[1], td)).item<bool>());

  std::vector<torch::Tensor> c = {te, tf};
  b->extend(c);

  ASSERT_EQ(b->size(), 4);
  ASSERT_TRUE(torch::all(torch::eq(b[0], tc)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(b[1], td)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(b[2], te)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(b[3], tf)).item<bool>());
}

TEST_F(ParameterListTest, PrettyPrintParameterList) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  ParameterList list(ta, tb, tc);
  ASSERT_EQ(
      c10::str(list),
      "torch::nn::ParameterList(\n"
      "(0): Parameter containing: [Float of size [1, 2]]\n"
      "(1): Parameter containing: [Float of size [1, 2]]\n"
      "(2): Parameter containing: [Float of size [1, 2]]\n"
      ")");
}

TEST_F(ParameterListTest, IncrementAdd) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::Tensor td = torch::randn({1, 2, 3});
  torch::Tensor te = torch::randn({1, 2});
  torch::Tensor tf = torch::randn({1, 2, 3});
  ParameterList listA(ta, tb, tc);
  ParameterList listB(td, te, tf);
  std::vector<torch::Tensor> tensors{ta, tb, tc, td, te, tf};
  int idx = 0;
  *listA += *listB;
  ASSERT_TRUE(torch::all(torch::eq(listA[0], ta)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(listA[1], tb)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(listA[2], tc)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(listA[3], td)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(listA[4], te)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(listA[5], tf)).item<bool>());
  for (const auto& P : listA->named_parameters(false))
    ASSERT_TRUE(torch::all(torch::eq(P.value(), tensors[idx++])).item<bool>());

  ASSERT_EQ(idx, 6);
}
torch/csrc/api/include/torch/nn/modules.h
@@ -10,6 +10,7 @@
 #include <torch/nn/modules/container/named_any.h>
 #include <torch/nn/modules/container/sequential.h>
 #include <torch/nn/modules/container/parameterdict.h>
+#include <torch/nn/modules/container/parameterlist.h>

 // Layers
 #include <torch/nn/modules/adaptive.h>

169  torch/csrc/api/include/torch/nn/modules/container/parameterlist.h  Normal file
@@ -0,0 +1,169 @@
#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>

#include <vector>

namespace torch {
namespace nn {
class ParameterListImpl : public Cloneable<ParameterListImpl> {
 public:
  using Iterator = typename std::vector<
      OrderedDict<std::string, torch::Tensor>::Item>::iterator;
  using ConstIterator = typename std::vector<
      OrderedDict<std::string, torch::Tensor>::Item>::const_iterator;

  ParameterListImpl() = default;

  /// Constructs the `ParameterList` from a variadic list of tensors.
  template <typename... Tensors>
  explicit ParameterListImpl(Tensors&&... params) {
    parameters_.reserve(sizeof...(Tensors));
    push_back_var(std::forward<Tensors>(params)...);
  }

  template <typename... Tensors>
  explicit ParameterListImpl(const Tensors&... params) {
    parameters_.reserve(sizeof...(Tensors));
    push_back_var(std::forward<Tensors>(params)...);
  }

  /// `reset()` is empty for `ParameterList`, since it does not have
  /// parameters of its own.
  void reset() override {}

  /// Pretty prints the `ParameterList` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::ParameterList(" << std::endl;
    for (const auto& pair : parameters_) {
      stream << "(" << pair.key() << ")"
             << ": Parameter containing: [" << pair.value().scalar_type()
             << " of size " << pair.value().sizes() << "]";
      stream << std::endl;
    }
    stream << ")";
  }

  /// Pushes a given parameter to the end of the list.
  void append(torch::Tensor&& param) {
    bool requires_grad = param.requires_grad();
    register_parameter(
        c10::to_string(parameters_.size()), std::move(param), requires_grad);
  }

  /// Pushes a given parameter to the end of the list.
  void append(const torch::Tensor& param) {
    bool requires_grad = param.requires_grad();
    register_parameter(
        c10::to_string(parameters_.size()), param, requires_grad);
  }

  /// Pushes a given parameter to the end of the list; the key of the pair
  /// is discarded, and only the value is added to the `ParameterList`.
  void append(const OrderedDict<std::string, torch::Tensor>::Item& pair) {
    register_parameter(
        c10::to_string(parameters_.size()),
        pair.value(),
        pair.value().requires_grad());
  }

  /// Extends parameters from a container to the end of the list.
  template <typename Container>
  void extend(const Container& container) {
    for (const auto& param : container) {
      append(param);
    }
  }

  /// Returns an iterator to the start of the `ParameterList`; the iterator
  /// is over elements of type `OrderedDict<std::string,
  /// torch::Tensor>::Item`.
  Iterator begin() {
    return parameters_.begin();
  }

  /// Returns a const iterator to the start of the `ParameterList`; the
  /// iterator is over elements of type `OrderedDict<std::string,
  /// torch::Tensor>::Item`.
  ConstIterator begin() const {
    return parameters_.begin();
  }

  /// Returns an iterator to the end of the `ParameterList`; the iterator
  /// is over elements of type `OrderedDict<std::string,
  /// torch::Tensor>::Item`.
  Iterator end() {
    return parameters_.end();
  }

  /// Returns a const iterator to the end of the `ParameterList`; the
  /// iterator is over elements of type `OrderedDict<std::string,
  /// torch::Tensor>::Item`.
  ConstIterator end() const {
    return parameters_.end();
  }

  /// Returns the parameter at the given `idx`. Throws an exception if the
  /// index is out of range; check `size()` first for a non-throwing way of
  /// access.
  at::Tensor& at(size_t idx) {
    TORCH_CHECK(idx < size(), "Index out of range");
    return parameters_[c10::to_string(idx)];
  }

  /// Returns the parameter at the given `idx`. Throws an exception if the
  /// index is out of range; check `size()` first for a non-throwing way of
  /// access.
  const at::Tensor& at(size_t idx) const {
    TORCH_CHECK(idx < size(), "Index out of range");
    return parameters_[c10::to_string(idx)];
  }

  /// Returns the parameter at the given `idx`. Throws an exception if the
  /// index is out of range; check `size()` first for a non-throwing way of
  /// access.
  at::Tensor& operator[](size_t idx) {
    return at(idx);
  }

  /// Returns the parameter at the given `idx`. Throws an exception if the
  /// index is out of range; check `size()` first for a non-throwing way of
  /// access.
  const at::Tensor& operator[](size_t idx) const {
    return at(idx);
  }

  /// Returns the size of the `ParameterList`.
  size_t size() const noexcept {
    return parameters_.size();
  }

  /// True if the `ParameterList` is empty.
  bool is_empty() const noexcept {
    return parameters_.is_empty();
  }

  /// Overloads `+=`, so that two `ParameterList`s can be incrementally added.
  template <typename Container>
  Container& operator+=(const Container& other) {
    extend(other);
    return *this;
  }

 private:
  template <typename Head, typename... Tail>
  void push_back_var(Head&& head, Tail&&... tail) {
    append(std::forward<Head>(head));
    // Recursively calls this method until the parameter pack has only one
    // entry left, then calls the empty base case (below) a final time.
    push_back_var(std::forward<Tail>(tail)...);
  }

  /// The base case, when the list of parameters is empty.
  void push_back_var() {}
};
TORCH_MODULE(ParameterList);
} // namespace nn
} // namespace torch
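For completeness (not part of this diff): the expected way to consume the new container inside a user-defined module follows the usual libtorch register_module pattern, so the list's tensors show up under the enclosing module's named_parameters(). A minimal sketch; `MyModule` and the shapes are made up for illustration:

#include <torch/torch.h>

struct MyModuleImpl : torch::nn::Module {
  MyModuleImpl() {
    // Registering the list makes its parameters visible to optimizers,
    // serialization, and named_parameters() traversal.
    params = register_module(
        "params",
        torch::nn::ParameterList(
            torch::randn({4, 4}, torch::requires_grad(true)),
            torch::randn({4}, torch::requires_grad(true))));
  }

  torch::Tensor forward(torch::Tensor x) {
    // Use the stored parameters as a weight and a bias.
    return torch::matmul(x, params->at(0)) + params->at(1);
  }

  torch::nn::ParameterList params{nullptr};
};
TORCH_MODULE(MyModule);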