pytorch/test/custom_operator/op.cpp
Thomas Viehmann 4c3b76c402 Add std::string to the getTypePtr for JIT inference of custom op types (#13683)
Summary:
This allows custom ops to take string parameters.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13683

Differential Revision: D13017010

Pulled By: soumith

fbshipit-source-id: 7c40aca7f57ba3f8812d34bc55828ff362c69bd2
2018-11-10 12:58:53 -08:00

35 lines
963 B
C++

#include <torch/script.h>
#include "op.h"
#include <cstddef>
#include <vector>
#include <string>
// Returns a vector of `repeat` tensors, each the result of computing
// `tensor * scalar` anew on every iteration.
//
// Parameters:
//   tensor - input tensor to scale.
//   scalar - multiplier applied to `tensor`.
//   repeat - number of output elements; a value <= 0 yields an empty vector.
std::vector<torch::Tensor> custom_op(
    torch::Tensor tensor,
    double scalar,
    int64_t repeat) {
  std::vector<torch::Tensor> output;
  // `reserve` takes an unsigned size; passing a negative int64_t would wrap to
  // a huge value and throw, so only pre-allocate for positive counts.
  if (repeat > 0) {
    output.reserve(static_cast<std::size_t>(repeat));
  }
  for (int64_t i = 0; i < repeat; ++i) {
    output.push_back(tensor * scalar);
  }
  return output;
}
// Three-way lexicographic comparison of two strings, delegated to
// std::string::compare.
//
// Returns a negative value when `s1` orders before `s2`, zero when they are
// equal, and a positive value when `s1` orders after `s2`. The exact magnitude
// is whatever std::string::compare produces, widened to int64_t.
int64_t custom_op2(std::string s1, std::string s2) {
  const int ordering = s1.compare(s2);
  return static_cast<int64_t>(ordering);
}
// Registers this file's custom operators through torch::jit::RegisterOperators.
// The registration runs as part of the dynamic initialization of this
// namespace-scope static, i.e. before main() / when the library is loaded.
static auto registry =
torch::jit::RegisterOperators()
// We parse the schema for the user.
// (For these ops only a name is given; the argument and return types are
// inferred from the C++ function signature.)
.op("custom::op", &custom_op)
// Inferred schema with two std::string arguments and an int64_t result.
.op("custom::op2", &custom_op2)
// User provided schema. Among other things, allows defaulting values,
// because we cannot infer default values from the signature. It also
// gives arguments meaningful names.
// Here the schema's `float scalar` and `int repeat` line up positionally with
// custom_op's `double scalar` and `int64_t repeat`, both defaulting to 1.
.op("custom::op_with_defaults(Tensor tensor, float scalar = 1, int repeat = 1) -> Tensor[]",
&custom_op);