mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os
def get_compiled_files_list():
    """Return the source files listed in build/compile_commands.json.

    Each entry's ``file`` path is made relative to the current working
    directory. Entries of the form ``build/<path>.DEFAULT.cpp`` are rewritten
    to ``<path>`` (presumably build-generated copies of ``<path>``; verify
    against the build system). All other paths are returned unchanged, in
    compilation-database order.
    """
    import json

    prefix, suffix = 'build/', '.DEFAULT.cpp'

    def _normalize(path):
        # Map a generated build/<path>.DEFAULT.cpp back to <path>.
        if path.startswith(prefix) and path.endswith(suffix):
            return path[len(prefix):-len(suffix)]
        return path

    with open("build/compile_commands.json") as f:
        entries = json.load(f)
    return [_normalize(os.path.relpath(entry['file'])) for entry in entries]
def run_clang_tidy(fname):
    """Run tools/clang_tidy.py on ``fname`` and commit whatever it changed.

    If the run leaves no tracked file modified, nothing is committed.
    """
    tidy_cmd = ["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname, "-s"]
    check_call(tidy_cmd)
    # `git ls-files -m` lists tracked files that differ from the index; empty
    # output means the tool left the working tree untouched.
    if check_output(["git", "ls-files", "-m"]):
        check_call(["git", "commit", "--all", "-m", f"NOLINT stubs for {fname}"])
def main():
    """Run run_clang_tidy() on every tracked file that is also compiled.

    Files under caffe2/contrib/aten/ are skipped. Progress is printed as
    ``[index/total]`` over the full ``git ls-files`` listing.
    """
    git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
    # Use a set so each membership test below is O(1); scanning the list made
    # the loop O(len(git_files) * len(compiled_files)).
    compiled_files = set(get_compiled_files_list())
    for idx, fname in enumerate(git_files):
        if fname not in compiled_files:
            continue
        if fname.startswith("caffe2/contrib/aten/"):
            continue
        print(f"[{idx}/{len(git_files)}] Processing {fname}")
        run_clang_tidy(fname)


if __name__ == "__main__":
    main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892
Reviewed By: H-Huang
Differential Revision: D27991944
Pulled By: malfet
fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
141 lines · 4.3 KiB · C++
#include <gtest/gtest.h>

#include <test/cpp/jit/test_custom_class_registrations.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/custom_class.h>
#include <torch/script.h>

#include <iostream>
#include <sstream>
#include <string>
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
|
TEST(CustomClassTest, TorchbindIValueAPI) {
|
|
script::Module m("m");
|
|
|
|
// test make_custom_class API
|
|
auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
|
|
std::vector<std::string>{"foo", "bar"});
|
|
m.define(R"(
|
|
def forward(self, s : __torch__.torch.classes._TorchScriptTesting._StackString):
|
|
return s.pop(), s
|
|
)");
|
|
|
|
auto test_with_obj = [&m](IValue obj, std::string expected) {
|
|
auto res = m.run_method("forward", obj);
|
|
auto tup = res.toTuple();
|
|
AT_ASSERT(tup->elements().size() == 2);
|
|
auto str = tup->elements()[0].toStringRef();
|
|
auto other_obj =
|
|
tup->elements()[1].toCustomClass<MyStackClass<std::string>>();
|
|
AT_ASSERT(str == expected);
|
|
auto ref_obj = obj.toCustomClass<MyStackClass<std::string>>();
|
|
AT_ASSERT(other_obj.get() == ref_obj.get());
|
|
};
|
|
|
|
test_with_obj(custom_class_obj, "bar");
|
|
|
|
// test IValue() API
|
|
auto my_new_stack = c10::make_intrusive<MyStackClass<std::string>>(
|
|
std::vector<std::string>{"baz", "boo"});
|
|
auto new_stack_ivalue = c10::IValue(my_new_stack);
|
|
|
|
test_with_obj(new_stack_ivalue, "boo");
|
|
}
|
|
|
|
// Custom class with no data members. It is registered below under
// _TorchBindTest so that TestDocString can check docstring propagation.
class TorchBindTestClass : public torch::jit::CustomClassHolder {
 public:
  // Returns a fixed greeting string; reads no object state.
  std::string get() {
    const char* const greeting = "Hello, I am your test custom class";
    return std::string(greeting);
  }
};
|
|
|
|
// Class-level docstring passed to torch::class_ when registering
// TorchBindTestClass. TestDocString requires ClassType::doc_string() to equal
// this text exactly; the wording itself is arbitrary test data.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr char class_doc_string[] = R"(
  I am docstring for TorchBindTestClass
  Args:
      What is an argument? Oh never mind, I don't take any.

  Return:
      How would I know? I am just a holder of some meaningless test methods.
  )";
// Docstring attached only to the "get_with_docstring" method; the plain "get"
// method is registered without one.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr char method_doc_string[] =
    "I am docstring for TorchBindTestClass get_with_docstring method";
|
|
|
|
namespace {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
|
static auto reg =
|
|
torch::class_<TorchBindTestClass>(
|
|
"_TorchBindTest",
|
|
"_TorchBindTestClass",
|
|
class_doc_string)
|
|
.def("get", &TorchBindTestClass::get)
|
|
.def("get_with_docstring", &TorchBindTestClass::get, method_doc_string);
|
|
|
|
} // namespace
|
|
|
|
// Tests DocString is properly propagated when defining CustomClasses.
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
|
TEST(CustomClassTest, TestDocString) {
|
|
auto class_type = getCustomClass(
|
|
"__torch__.torch.classes._TorchBindTest._TorchBindTestClass");
|
|
AT_ASSERT(class_type);
|
|
AT_ASSERT(class_type->doc_string() == class_doc_string);
|
|
|
|
AT_ASSERT(class_type->getMethod("get").doc_string().empty());
|
|
AT_ASSERT(
|
|
class_type->getMethod("get_with_docstring").doc_string() ==
|
|
method_doc_string);
|
|
}
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
|
TEST(CustomClassTest, Serialization) {
|
|
script::Module m("m");
|
|
|
|
// test make_custom_class API
|
|
auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
|
|
std::vector<std::string>{"foo", "bar"});
|
|
m.register_attribute(
|
|
"s",
|
|
custom_class_obj.type(),
|
|
custom_class_obj,
|
|
// NOLINTNEXTLINE(bugprone-argument-comment)
|
|
/*is_parameter=*/false);
|
|
m.define(R"(
|
|
def forward(self):
|
|
return self.s.return_a_tuple()
|
|
)");
|
|
|
|
auto test_with_obj = [](script::Module& mod) {
|
|
auto res = mod.run_method("forward");
|
|
auto tup = res.toTuple();
|
|
AT_ASSERT(tup->elements().size() == 2);
|
|
auto i = tup->elements()[1].toInt();
|
|
AT_ASSERT(i == 123);
|
|
};
|
|
|
|
auto frozen_m = torch::jit::freeze_module(m.clone());
|
|
|
|
test_with_obj(m);
|
|
test_with_obj(frozen_m);
|
|
|
|
std::ostringstream oss;
|
|
m.save(oss);
|
|
std::istringstream iss(oss.str());
|
|
caffe2::serialize::IStreamAdapter adapter{&iss};
|
|
auto loaded_module = torch::jit::load(iss, torch::kCPU);
|
|
|
|
std::ostringstream oss_frozen;
|
|
frozen_m.save(oss_frozen);
|
|
std::istringstream iss_frozen(oss_frozen.str());
|
|
caffe2::serialize::IStreamAdapter adapter_frozen{&iss_frozen};
|
|
auto loaded_frozen_module = torch::jit::load(iss_frozen, torch::kCPU);
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|