mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28828 This updates torch::script::Module to more closely match the behavior of nn.Module. In particular, it implements the (optionally recursive) iterators that retrieve submodules, parameters, and buffers and makes their names match the python versions. This also removes the individual accessors for Parameter, Module, Buffer, etc. and replaces them with a single `attr` function which is equivalent to writing `a.foo` in Python (`setattr` emulates `a.foo = v`). As we build out the user-facing API for TorchScript values this will end up matching how an attribute is accessed on general objects. This PR preserves the python bindings for script::Module by emulating the old API at the binding level. A follow-up will clean up the usage to more directly match the C++ API. Test Plan: Imported from OSS Differential Revision: D18197611 Pulled By: zdevito fbshipit-source-id: 7ee4dcbb258605d1c988314b05d938423f1ccee5
149 lines
4.2 KiB
C++
149 lines
4.2 KiB
C++
#include <torch/script.h>
|
|
#include <torch/cuda.h>
|
|
|
|
#include "op.h"
|
|
|
|
#include <memory>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include <iostream>
|
|
|
|
namespace helpers {
|
|
template <typename Predicate>
|
|
void check_all_parameters(
|
|
const torch::jit::script::Module& module,
|
|
Predicate predicate) {
|
|
for (at::Tensor parameter : module.parameters()) {
|
|
AT_ASSERT(predicate(parameter));
|
|
}
|
|
}
|
|
} // namespace helpers
|
|
|
|
void get_operator_from_registry_and_execute() {
|
|
auto& ops = torch::jit::getAllOperatorsFor(
|
|
torch::jit::Symbol::fromQualString("custom::op"));
|
|
AT_ASSERT(ops.size() == 1);
|
|
|
|
auto& op = ops.front();
|
|
AT_ASSERT(op->schema().name() == "custom::op");
|
|
|
|
torch::jit::Stack stack;
|
|
torch::jit::push(stack, torch::ones(5), 2.0, 3);
|
|
op->getOperation()(stack);
|
|
std::vector<torch::Tensor> output;
|
|
torch::jit::pop(stack, output);
|
|
|
|
const auto manual = custom_op(torch::ones(5), 2.0, 3);
|
|
|
|
AT_ASSERT(output.size() == 3);
|
|
for (size_t i = 0; i < output.size(); ++i) {
|
|
AT_ASSERT(output[i].allclose(torch::ones(5) * 2));
|
|
AT_ASSERT(output[i].allclose(manual[i]));
|
|
}
|
|
}
|
|
|
|
void load_serialized_module_with_custom_op_and_execute(
|
|
const std::string& path_to_exported_script_module) {
|
|
torch::jit::script::Module module =
|
|
torch::jit::load(path_to_exported_script_module);
|
|
std::vector<torch::jit::IValue> inputs;
|
|
inputs.push_back(torch::ones(5));
|
|
auto output = module.forward(inputs).toTensor();
|
|
|
|
AT_ASSERT(output.allclose(torch::ones(5) + 1));
|
|
}
|
|
|
|
void test_argument_checking_for_serialized_modules(
|
|
const std::string& path_to_exported_script_module) {
|
|
torch::jit::script::Module module =
|
|
torch::jit::load(path_to_exported_script_module);
|
|
|
|
try {
|
|
module.forward({torch::jit::IValue(1), torch::jit::IValue(2)});
|
|
AT_ASSERT(false);
|
|
} catch (const c10::Error& error) {
|
|
AT_ASSERT(
|
|
std::string(error.what_without_backtrace())
|
|
.find("Expected at most 2 argument(s) for operator 'forward', "
|
|
"but received 3 argument(s)") == 0);
|
|
}
|
|
|
|
try {
|
|
module.forward({torch::jit::IValue(5)});
|
|
AT_ASSERT(false);
|
|
} catch (const c10::Error& error) {
|
|
AT_ASSERT(
|
|
std::string(error.what_without_backtrace())
|
|
.find("forward() Expected a value of type 'Tensor' "
|
|
"for argument 'input' but instead found type 'int'") == 0);
|
|
}
|
|
|
|
try {
|
|
module.forward({});
|
|
AT_ASSERT(false);
|
|
} catch (const c10::Error& error) {
|
|
AT_ASSERT(
|
|
std::string(error.what_without_backtrace())
|
|
.find("forward() is missing value for argument 'input'") == 0);
|
|
}
|
|
}
|
|
|
|
/// Loads the exported script module and verifies parameter placement:
/// - initially on the CPU;
/// - on CUDA after `to(kCUDA)`;
/// - back on the CPU after `to(kCPU)`.
/// `main` only calls this when `torch::cuda::device_count()` is non-zero.
void test_move_to_device(const std::string& path_to_exported_script_module) {
  auto module = torch::jit::load(path_to_exported_script_module);

  const auto on_cpu = [](const torch::Tensor& t) {
    return t.device().is_cpu();
  };
  const auto on_cuda = [](const torch::Tensor& t) {
    return t.device().is_cuda();
  };

  helpers::check_all_parameters(module, on_cpu);

  module.to(torch::kCUDA);
  helpers::check_all_parameters(module, on_cuda);

  module.to(torch::kCPU);
  helpers::check_all_parameters(module, on_cpu);
}
|
|
|
|
void test_move_to_dtype(const std::string& path_to_exported_script_module) {
|
|
torch::jit::script::Module module =
|
|
torch::jit::load(path_to_exported_script_module);
|
|
|
|
module.to(torch::kInt);
|
|
|
|
helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
|
|
return tensor.dtype() == torch::kInt;
|
|
});
|
|
|
|
module.to(torch::kDouble);
|
|
|
|
helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
|
|
return tensor.dtype() == torch::kDouble;
|
|
});
|
|
}
|
|
|
|
int main(int argc, const char* argv[]) {
|
|
if (argc != 2) {
|
|
std::cerr << "usage: test_custom_ops <path-to-exported-script-module>\n";
|
|
return -1;
|
|
}
|
|
const std::string path_to_exported_script_module = argv[1];
|
|
|
|
get_operator_from_registry_and_execute();
|
|
load_serialized_module_with_custom_op_and_execute(
|
|
path_to_exported_script_module);
|
|
test_argument_checking_for_serialized_modules(path_to_exported_script_module);
|
|
test_move_to_dtype(path_to_exported_script_module);
|
|
|
|
if (torch::cuda::device_count() > 0) {
|
|
test_move_to_device(path_to_exported_script_module);
|
|
}
|
|
|
|
std::cout << "ok\n";
|
|
}
|