mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: I noticed that `Sequential::clone()` does not work. This is because `Sequential` does not use `reset()` which is normally where modules have to initialize and register their submodules. Further, this is because of the way `Sequential` allows its modules to be passed in the constructor, which doesn't work with `reset()` (since it does "late" initialization). I've added some better error messages inside `Cloneable::clone()` which make this kind of mistake clearer for other users, and tests for `Sequential::clone()`. I also had to give `AnyModule` a deep `clone()` method. ebetica ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/9372 Differential Revision: D8865189 Pulled By: goldsborough fbshipit-source-id: b81586e0d3157cd3c4265b19ac8dd87c5d8dcf94
30 lines
703 B
C++
30 lines
703 B
C++
#pragma once
|
|
|
|
#include <torch/nn/cloneable.h>
|
|
|
|
#include <string>
|
|
#include <utility>
|
|
|
|
namespace torch {
|
|
namespace test {
|
|
|
|
// Lets you use a container without making a new class,
|
|
// for experimental implementations
|
|
class SimpleContainer : public nn::Cloneable<SimpleContainer> {
|
|
public:
|
|
void reset() override {}
|
|
|
|
template <typename ModuleHolder>
|
|
ModuleHolder add(
|
|
ModuleHolder module_holder,
|
|
std::string name = std::string()) {
|
|
return Module::register_module(std::move(name), module_holder);
|
|
}
|
|
};
|
|
|
|
inline bool pointer_equal(torch::Tensor first, torch::Tensor second) {
|
|
return first.data().data<float>() == second.data().data<float>();
|
|
}
|
|
} // namespace test
|
|
} // namespace torch
|