Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14414

The previous functions were CUDA-centric, and led to lots of places where we improperly assumed that CUDA is the only game in town (it's not). Best to delete them.

What are your alternatives? This diff fixes some use sites, which may give you some ideas. In particular, "given a device type, give me the current device for that device type" might be a good function to enshrine for real.

Reviewed By: gchanan

Differential Revision: D13218540

fbshipit-source-id: 2f42cd6b9bdab4930d25166b8041c9466a1c6e0a
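The "current device for a given device type" helper mentioned in the summary is not part of this diff. A minimal sketch of what it could look like, assuming the c10 device-guard interface (c10::impl::getDeviceGuardImpl and DeviceGuardImplInterface::getDevice); the name current_device_for is hypothetical:

#include <c10/core/Device.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>

// Hypothetical helper (not in this diff): ask the backend registered for
// `type` which device it currently considers current. CPU-like backends
// simply report their single device.
c10::Device current_device_for(c10::DeviceType type) {
  const c10::impl::DeviceGuardImplInterface* impl =
      c10::impl::getDeviceGuardImpl(type);
  return impl->getDevice();
}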
#include <torch/script.h>
#include <torch/cuda.h>

#include "op.h"

#include <memory>
#include <string>
#include <vector>

#include <iostream>

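// Recursively applies `predicate` to every parameter tensor of `module` and
// of all of its submodules, asserting that it holds for each one.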
namespace helpers {
template <typename Predicate>
void check_all_parameters(
    const torch::jit::script::Module& module,
    Predicate predicate) {
  for (const auto& parameter : module.get_parameters()) {
    AT_ASSERT(predicate(*parameter->slot()));
  }
  for (const auto& child : module.get_modules()) {
    check_all_parameters(*child->module, predicate);
  }
}
} // namespace helpers

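// Looks up the registered "custom::op" operator in the JIT operator registry,
// invokes it through the generic Stack interface, and checks that the results
// match a direct call to custom_op().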
void get_operator_from_registry_and_execute() {
  auto& ops = torch::jit::getAllOperatorsFor(
      torch::jit::Symbol::fromQualString("custom::op"));
  AT_ASSERT(ops.size() == 1);

  auto& op = ops.front();
  AT_ASSERT(op->schema().name() == "custom::op");

  torch::jit::Stack stack;
  torch::jit::push(stack, torch::ones(5), 2.0, 3);
  op->getOperation()(stack);
  std::vector<torch::Tensor> output;
  torch::jit::pop(stack, output);

  const auto manual = custom_op(torch::ones(5), 2.0, 3);

  AT_ASSERT(output.size() == 3);
  for (size_t i = 0; i < output.size(); ++i) {
    AT_ASSERT(output[i].allclose(torch::ones(5) * 2));
    AT_ASSERT(output[i].allclose(manual[i]));
  }
}

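// Loads a script module that was exported from Python with the custom op and
// verifies that forward() produces the expected output.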
void load_serialized_module_with_custom_op_and_execute(
    const std::string& path_to_exported_script_module) {
  std::shared_ptr<torch::jit::script::Module> module =
      torch::jit::load(path_to_exported_script_module);
  AT_ASSERT(module != nullptr);

  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones(5));
  auto output = module->forward(inputs).toTensor();

  AT_ASSERT(output.allclose(torch::ones(5) + 1));
}

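// Checks that calling forward() with the wrong number or type of arguments
// raises a c10::Error whose message starts with the expected text.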
void test_argument_checking_for_serialized_modules(
    const std::string& path_to_exported_script_module) {
  std::shared_ptr<torch::jit::script::Module> module =
      torch::jit::load(path_to_exported_script_module);
  AT_ASSERT(module != nullptr);

  try {
    module->forward({torch::jit::IValue(1), torch::jit::IValue(2)});
    AT_ASSERT(false);
  } catch (const c10::Error& error) {
    AT_ASSERT(
        std::string(error.what_without_backtrace())
            .find("Expected at most 1 argument(s) for operator 'forward', "
                  "but received 2 argument(s)") == 0);
  }

  try {
    module->forward({torch::jit::IValue(5)});
    AT_ASSERT(false);
  } catch (const c10::Error& error) {
    AT_ASSERT(
        std::string(error.what_without_backtrace())
            .find("Expected value of type Tensor for argument 'input' in "
                  "position 0, but instead got value of type int") == 0);
  }

  try {
    module->forward({});
    AT_ASSERT(false);
  } catch (const c10::Error& error) {
    AT_ASSERT(
        std::string(error.what_without_backtrace())
            .find("forward() is missing value for argument 'input'") == 0);
  }
}

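// Moves all parameters of the loaded module to CUDA and back to CPU, checking
// the device of every parameter after each move. Only called from main() when
// a CUDA device is available.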
void test_move_to_device(const std::string& path_to_exported_script_module) {
  std::shared_ptr<torch::jit::script::Module> module =
      torch::jit::load(path_to_exported_script_module);
  AT_ASSERT(module != nullptr);

  helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
    return tensor.device().is_cpu();
  });

  module->to(torch::kCUDA);

  helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
    return tensor.device().is_cuda();
  });

  module->to(torch::kCPU);

  helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
    return tensor.device().is_cpu();
  });
}

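// Converts all parameters of the loaded module to kInt and then to kDouble,
// checking the dtype of every parameter after each conversion.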
void test_move_to_dtype(const std::string& path_to_exported_script_module) {
  std::shared_ptr<torch::jit::script::Module> module =
      torch::jit::load(path_to_exported_script_module);
  AT_ASSERT(module != nullptr);

  module->to(torch::kInt);

  helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
    return tensor.dtype() == torch::kInt;
  });

  module->to(torch::kDouble);

  helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
    return tensor.dtype() == torch::kDouble;
  });
}

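// Entry point: expects the path to the exported script module as the single
// command-line argument and runs the tests above, skipping the device-move
// test when no CUDA device is present.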
int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: test_custom_ops <path-to-exported-script-module>\n";
    return -1;
  }
  const std::string path_to_exported_script_module = argv[1];

  get_operator_from_registry_and_execute();
  load_serialized_module_with_custom_op_and_execute(
      path_to_exported_script_module);
  test_argument_checking_for_serialized_modules(path_to_exported_script_module);
  test_move_to_dtype(path_to_exported_script_module);

  if (torch::cuda::device_count() > 0) {
    test_move_to_device(path_to_exported_script_module);
  }

  std::cout << "ok\n";
}