pytorch/test/custom_operator/test_custom_ops.cpp
Zachary DeVito 796363147f Implement more of of the nn.Module API (#28828)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28828

This updates torch::script::Module to more closely match the behavior
of nn.Module. In particular, it implements the (optionally recursive)
iterators that retrieve submodules, parameters, and buffers and makes
their names match the python versions.

This also removes the individual accessors for Parameter, Module, Buffer, etc.
and replaces them with a single `attr` function which is equivalent to
writing `a.foo` in Python (`setattr` emulates `a.foo = v`).
As we build out the user-facing API for TorchScript values this will end
up matching how an attribute is accessed on general objects.

This PR preserves the Python bindings for script::Module by emulating the
old API at the binding level. A followup will clean up the usage to more
directly match the C++ API.

Test Plan: Imported from OSS

Differential Revision: D18197611

Pulled By: zdevito

fbshipit-source-id: 7ee4dcbb258605d1c988314b05d938423f1ccee5
2019-11-06 22:58:25 -08:00

149 lines
4.2 KiB
C++

#include <torch/script.h>
#include <torch/cuda.h>
#include "op.h"
#include <memory>
#include <string>
#include <vector>
#include <iostream>
namespace helpers {

/// Asserts, via AT_ASSERT, that `predicate` holds for every tensor yielded by
/// `module.parameters()`.
///
/// @param module    the scripted module whose parameters are inspected.
/// @param predicate callable invoked with each parameter tensor; its result is
///                  used as the assertion condition.
template <typename Predicate>
void check_all_parameters(
    const torch::jit::script::Module& module,
    Predicate predicate) {
  for (const at::Tensor& parameter : module.parameters()) {
    AT_ASSERT(predicate(parameter));
  }
}

} // namespace helpers
void get_operator_from_registry_and_execute() {
auto& ops = torch::jit::getAllOperatorsFor(
torch::jit::Symbol::fromQualString("custom::op"));
AT_ASSERT(ops.size() == 1);
auto& op = ops.front();
AT_ASSERT(op->schema().name() == "custom::op");
torch::jit::Stack stack;
torch::jit::push(stack, torch::ones(5), 2.0, 3);
op->getOperation()(stack);
std::vector<torch::Tensor> output;
torch::jit::pop(stack, output);
const auto manual = custom_op(torch::ones(5), 2.0, 3);
AT_ASSERT(output.size() == 3);
for (size_t i = 0; i < output.size(); ++i) {
AT_ASSERT(output[i].allclose(torch::ones(5) * 2));
AT_ASSERT(output[i].allclose(manual[i]));
}
}
void load_serialized_module_with_custom_op_and_execute(
const std::string& path_to_exported_script_module) {
torch::jit::script::Module module =
torch::jit::load(path_to_exported_script_module);
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones(5));
auto output = module.forward(inputs).toTensor();
AT_ASSERT(output.allclose(torch::ones(5) + 1));
}
void test_argument_checking_for_serialized_modules(
const std::string& path_to_exported_script_module) {
torch::jit::script::Module module =
torch::jit::load(path_to_exported_script_module);
try {
module.forward({torch::jit::IValue(1), torch::jit::IValue(2)});
AT_ASSERT(false);
} catch (const c10::Error& error) {
AT_ASSERT(
std::string(error.what_without_backtrace())
.find("Expected at most 2 argument(s) for operator 'forward', "
"but received 3 argument(s)") == 0);
}
try {
module.forward({torch::jit::IValue(5)});
AT_ASSERT(false);
} catch (const c10::Error& error) {
AT_ASSERT(
std::string(error.what_without_backtrace())
.find("forward() Expected a value of type 'Tensor' "
"for argument 'input' but instead found type 'int'") == 0);
}
try {
module.forward({});
AT_ASSERT(false);
} catch (const c10::Error& error) {
AT_ASSERT(
std::string(error.what_without_backtrace())
.find("forward() is missing value for argument 'input'") == 0);
}
}
/// Loads the exported module and moves it CPU -> CUDA -> CPU. After the load
/// and after each move, it asserts that every parameter is on the expected
/// device type.
void test_move_to_device(const std::string& path_to_exported_script_module) {
  const auto on_cpu = [](const torch::Tensor& tensor) {
    return tensor.device().is_cpu();
  };
  const auto on_cuda = [](const torch::Tensor& tensor) {
    return tensor.device().is_cuda();
  };

  torch::jit::script::Module module =
      torch::jit::load(path_to_exported_script_module);
  helpers::check_all_parameters(module, on_cpu);

  module.to(torch::kCUDA);
  helpers::check_all_parameters(module, on_cuda);

  module.to(torch::kCPU);
  helpers::check_all_parameters(module, on_cpu);
}
void test_move_to_dtype(const std::string& path_to_exported_script_module) {
torch::jit::script::Module module =
torch::jit::load(path_to_exported_script_module);
module.to(torch::kInt);
helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
return tensor.dtype() == torch::kInt;
});
module.to(torch::kDouble);
helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
return tensor.dtype() == torch::kDouble;
});
}
/// Test driver. Expects exactly one argument: the path to an exported
/// TorchScript module. It runs each check in turn; the CUDA device-move check
/// runs only when torch reports at least one CUDA device. Prints "ok" once
/// every check has returned; on a wrong argument count it prints usage and
/// returns -1.
int main(int argc, const char* argv[]) {
  if (argc == 2) {
    const std::string module_path(argv[1]);

    get_operator_from_registry_and_execute();
    load_serialized_module_with_custom_op_and_execute(module_path);
    test_argument_checking_for_serialized_modules(module_path);
    test_move_to_dtype(module_path);

    const bool have_cuda = torch::cuda::device_count() > 0;
    if (have_cuda) {
      test_move_to_device(module_path);
    }

    std::cout << "ok\n";
    return 0;
  }

  std::cerr << "usage: test_custom_ops <path-to-exported-script-module>\n";
  return -1;
}