pytorch/test/cpp/api/util.h
Peter Goldsborough ae44a6b5e3 Fix Sequential::clone() (#9372)
Summary:
I noticed that `Sequential::clone()` does not work. This is because `Sequential` does not use `reset()`, which is normally where modules initialize and register their submodules. The underlying reason is that `Sequential` accepts its modules as constructor arguments, and that does not fit with `reset()`, which does "late" initialization.

I've added better error messages inside `Cloneable::clone()` that make this kind of mistake clearer to other users, as well as tests for `Sequential::clone()`.
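The failure mode can be modeled without libtorch. This is a simplified sketch, not the actual `torch::nn::Cloneable` code; `MiniModule`, `clone_module`, `ResetModule`, and `CtorModule` are hypothetical names. It assumes a clone copy-constructs the module, clears the copy's submodule registry, calls `reset()` so the copy re-creates its submodules, and then compares submodule counts, throwing a descriptive error on a mismatch (roughly the situation the new error messages are meant to diagnose).

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for a module that owns named submodules.
struct MiniModule {
  virtual ~MiniModule() = default;

  // Creates and registers submodules. clone_module() calls it again on the
  // copy, so submodules registered anywhere else are lost in the clone.
  virtual void reset() {}

  std::map<std::string, std::shared_ptr<MiniModule>> children;

 protected:
  void register_child(const std::string& name) {
    children[name] = std::make_shared<MiniModule>();
  }
};

// Simplified clone: copy, drop the copied (shared) submodules, rebuild them
// via reset(), then check that reset() registered as many as the original has.
template <typename Derived>
std::shared_ptr<Derived> clone_module(const Derived& original) {
  auto copy = std::make_shared<Derived>(original);
  copy->children.clear();
  copy->reset();
  if (copy->children.size() != original.children.size()) {
    throw std::logic_error(
        "clone: the copy has " + std::to_string(copy->children.size()) +
        " submodules but the original has " +
        std::to_string(original.children.size()) +
        "; are all submodules registered in reset()?");
  }
  return copy;
}

// Registers its submodules in reset(), so a clone gets fresh ones.
struct ResetModule : MiniModule {
  ResetModule() { reset(); }
  void reset() override {
    register_child("linear1");
    register_child("linear2");
  }
};

// Registers submodules in its constructor, mimicking how Sequential received
// its modules, and keeps the empty default reset(), so a clone gets none.
struct CtorModule : MiniModule {
  explicit CtorModule(int count) {
    for (int i = 0; i < count; ++i) {
      register_child("layer" + std::to_string(i));
    }
  }
};
```

In this model, cloning `ResetModule` yields a copy with two new submodules, while cloning `CtorModule` throws because its empty `reset()` registers nothing.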

I also had to give `AnyModule` a deep `clone()` method.
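Because `AnyModule` erases the concrete module type, a deep clone has to reach through the erasure to the wrapped object. One common C++ way to do that is a virtual `clone()` on the internal placeholder, sketched below. `AnyHolder`, `Placeholder`, `Model`, and `Linear` are hypothetical names, and this is not the libtorch implementation; for simplicity the deep copy here uses the wrapped type's copy constructor, whereas a module wrapper would presumably call the module's own `clone()`.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical type-erased holder illustrating shallow vs. deep copies;
// a simplified model, not libtorch's AnyModule.
class AnyHolder {
 public:
  template <typename T>
  explicit AnyHolder(std::shared_ptr<T> value)
      : content_(std::make_unique<Model<T>>(std::move(value))) {}

  // Shallow copy: the new holder points at the same object.
  AnyHolder(const AnyHolder& other) : content_(other.content_->copy()) {}

  // Deep copy: the held object itself is duplicated through the
  // virtual clone() of the type-erased placeholder.
  AnyHolder clone() const { return AnyHolder(content_->clone()); }

  // Access the held object; T must be the type it was constructed with.
  template <typename T>
  T& get() const {
    return *static_cast<Model<T>&>(*content_).value;
  }

 private:
  struct Placeholder {
    virtual ~Placeholder() = default;
    virtual std::unique_ptr<Placeholder> copy() const = 0;
    virtual std::unique_ptr<Placeholder> clone() const = 0;
  };

  template <typename T>
  struct Model : Placeholder {
    explicit Model(std::shared_ptr<T> v) : value(std::move(v)) {}
    std::unique_ptr<Placeholder> copy() const override {
      return std::make_unique<Model>(value);  // shares the object
    }
    std::unique_ptr<Placeholder> clone() const override {
      return std::make_unique<Model>(std::make_shared<T>(*value));  // copies it
    }
    std::shared_ptr<T> value;
  };

  explicit AnyHolder(std::unique_ptr<Placeholder> content)
      : content_(std::move(content)) {}

  std::unique_ptr<Placeholder> content_;
};

// Hypothetical payload standing in for a module with some state.
struct Linear {
  int weight = 0;
};
```

After `AnyHolder shallow = original;` and `AnyHolder deep = original.clone();`, a change to `original.get<Linear>().weight` is visible through `shallow` but not through `deep`.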

ebetica ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9372

Differential Revision: D8865189

Pulled By: goldsborough

fbshipit-source-id: b81586e0d3157cd3c4265b19ac8dd87c5d8dcf94
2018-07-16 21:53:42 -07:00


#pragma once

#include <torch/nn/cloneable.h>

#include <string>
#include <utility>

namespace torch {
namespace test {

// Lets you use a container without making a new class,
// for experimental implementations
class SimpleContainer : public nn::Cloneable<SimpleContainer> {
 public:
  void reset() override {}

  template <typename ModuleHolder>
  ModuleHolder add(
      ModuleHolder module_holder,
      std::string name = std::string()) {
    return Module::register_module(std::move(name), module_holder);
  }
};

inline bool pointer_equal(torch::Tensor first, torch::Tensor second) {
  return first.data().data<float>() == second.data().data<float>();
}

} // namespace test
} // namespace torch
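`pointer_equal` compares the addresses of the two tensors' `float` data, so it reports whether both tensors point at the same underlying data, not whether their values are equal. A clone test can use that distinction to check that a copy did not simply alias the original's parameters. The sketch below models the distinction with the standard library only; `FakeTensor` and `same_storage` are hypothetical stand-ins, not libtorch types.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for a tensor handle: copies of the handle share one
// buffer, while clone() allocates a new buffer holding the same values.
struct FakeTensor {
  std::shared_ptr<std::vector<float>> storage;

  explicit FakeTensor(std::vector<float> values)
      : storage(std::make_shared<std::vector<float>>(std::move(values))) {}

  const float* data() const { return storage->data(); }

  FakeTensor clone() const { return FakeTensor(*storage); }
};

// Same idea as pointer_equal: identity of the underlying data,
// not equality of the values stored in it.
inline bool same_storage(const FakeTensor& first, const FakeTensor& second) {
  return first.data() == second.data();
}
```

A handle copy such as `FakeTensor alias = a;` passes `same_storage(a, alias)`, whereas `a.clone()` holds equal values in a different buffer and fails it, which is the property a deep clone should have.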