mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30168
The previous implementation of `clone` in `script::Module` copies both the module instance and the class type. After type sharing was enabled (https://github.com/pytorch/pytorch/pull/26666), we also need a function that clones only the instance and shares the underlying class type.
Test Plan: tbd
Imported from OSS
Differential Revision: D18631324
fbshipit-source-id: dbadcf19695faee0f755f45093b24618c047b9d1
46 lines
1.3 KiB
C++
46 lines
1.3 KiB
C++
#include <test/cpp/jit/test_base.h>
|
|
#include <test/cpp/jit/test_utils.h>
|
|
#include <torch/torch.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
using namespace torch::jit::script;
|
|
|
|
void testModuleCloneInstance() {
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
auto cls = ClassType::create("foo.bar", cu, true);
|
|
auto attr_name = "attr";
|
|
cls->addAttribute(attr_name, IntType::get());
|
|
Module m(cu, cls);
|
|
auto v = IValue(2);
|
|
m.register_attribute(attr_name,
|
|
IntType::get(),
|
|
v,
|
|
false);
|
|
|
|
Module m2 = m.clone();
|
|
Module m3 = m.clone_instance();
|
|
// Make sure copy works
|
|
ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
|
|
ASSERT_EQ(m3.attr(attr_name).toInt(), 2);
|
|
|
|
// clone will copy both type and data, therefore we'll have a
|
|
// different type
|
|
ASSERT_NE(m.type(), m2.type());
|
|
// clone_instance only copies data, type is shared
|
|
ASSERT_EQ(m.type(), m3.type());
|
|
|
|
// change value of copied instance
|
|
m3.register_attribute(attr_name,
|
|
IntType::get(),
|
|
IValue(3),
|
|
false);
|
|
// Verify value of original instance doesn't change
|
|
ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
|
|
ASSERT_EQ(m3.attr(attr_name).toInt(), 3);
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|