mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: This allows custom ops to take string parameters. Pull Request resolved: https://github.com/pytorch/pytorch/pull/13683 Differential Revision: D13017010 Pulled By: soumith fbshipit-source-id: 7c40aca7f57ba3f8812d34bc55828ff362c69bd2
35 lines
963 B
C++
35 lines
963 B
C++
#include <torch/script.h>
|
|
|
|
#include "op.h"
|
|
|
|
#include <cstddef>
|
|
#include <vector>
|
|
#include <string>
|
|
|
|
/// Produces `repeat` independently computed copies of `tensor * scalar`.
///
/// @param tensor  Input tensor; each output element is `tensor * scalar`.
/// @param scalar  Multiplier applied to `tensor`.
/// @param repeat  Number of outputs to produce. A value <= 0 yields an empty
///                vector (mirroring Python's `[x] * n` for non-positive n).
/// @return A vector holding max(repeat, 0) tensors.
std::vector<torch::Tensor> custom_op(
    torch::Tensor tensor,
    double scalar,
    int64_t repeat) {
  std::vector<torch::Tensor> output;
  if (repeat <= 0) {
    // Without this guard a negative count would be implicitly converted to a
    // huge size_t by reserve(), which throws std::length_error.
    return output;
  }
  // Safe: repeat is strictly positive here.
  output.reserve(static_cast<std::size_t>(repeat));
  for (int64_t i = 0; i < repeat; ++i) {
    output.push_back(tensor * scalar);
  }
  return output;
}
|
|
|
|
/// Three-way lexicographic comparison of two strings.
///
/// @param s1  Left-hand string.
/// @param s2  Right-hand string.
/// @return -1 if s1 orders before s2, 0 if they are equal, 1 otherwise.
///
/// std::string::compare only guarantees the *sign* of its result; the
/// magnitude is implementation-defined. The result is normalized so the
/// value seen by TorchScript callers is the same on every platform.
int64_t custom_op2(std::string s1, std::string s2) {
  const int cmp = s1.compare(s2);
  if (cmp < 0) {
    return -1;
  }
  if (cmp > 0) {
    return 1;
  }
  return 0;
}
|
|
|
|
// Registers this file's custom operators with the TorchScript operator
// registry. Registration happens as a side effect of constructing `registry`
// during this translation unit's static initialization.
namespace {

// Hand-written schema for op_with_defaults. Spelling the schema out lets us
// give arguments default values (defaults cannot be recovered from a C++
// function signature) and gives the arguments readable names.
constexpr const char kOpWithDefaultsSchema[] =
    "custom::op_with_defaults(Tensor tensor, float scalar = 1, int repeat = 1) -> Tensor[]";

auto registry =
    torch::jit::RegisterOperators()
        // Only an operator name is given here; the schema is derived from the
        // C++ signature of the bound function.
        .op("custom::op", &custom_op)
        .op("custom::op2", &custom_op2)
        // Same kernel as custom::op, exposed under an explicit schema.
        .op(kOpWithDefaultsSchema, &custom_op);

} // namespace
|