pytorch/torch/csrc/jit/register_prim_ops.cpp
Elias Ellison 59f8e8ada7 First step at adding exceptions (#12789)
Summary:
This is a first step towards adding exceptions. We need minimal support in order to begin converting the torch library to weak script mode (which is the main goal here).

Some limitations (that are documented in the tests & compiler):
1. Cannot assign exceptions to variables
2. Any name after raise is being treated as a valid Exception
3. No control flow analysis yet. In the example below, `a` will be reported as undefined:

if True:
     a = 1
else:
     raise Exception("Hi")
return a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12789

Differential Revision: D12848936

Pulled By: eellison

fbshipit-source-id: 1f60ceef2381040486123ec797e97d65b074862d
2018-10-30 20:25:50 -07:00

942 lines
31 KiB
C++

#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/generated/variable_factories.h"
#include "torch/csrc/autograd/profiler.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/fusers/interface.h"
#include "torch/csrc/jit/graph_executor.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/operator.h"
#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/jit/script/jit_exception.h"
#include "torch/csrc/variable_tensor_functions.h"
#include <exception>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <ostream>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
namespace {
// Builds an Operation that leaves the stack untouched and returns 0.
// The node argument is not inspected.
Operation noop(Node* n) {
  auto do_nothing = [](Stack& stack) { return 0; };
  return do_nothing;
}
// Validates a tensor that is about to be implicitly converted to a scalar,
// using the rules from python_arg_parser FunctionParameter::check:
//   * the tensor must not require grad,
//   * the tensor must be 0-dimensional,
//   * if the destination is an int (toInt == true), the source must have an
//     integral scalar type.
// Throws std::runtime_error for the first rule that is violated.
void checkImplicitTensorToNum(at::Tensor t, bool toInt) {
  auto& var = autograd::as_variable_ref(t);
  if (var.requires_grad()) {
    throw std::runtime_error("Cannot input a tensor that requires grad as a scalar argument");
  }
  if (!t.sizes().empty()) {
    throw std::runtime_error("Cannot input a tensor of dimension other than 0 as a scalar argument");
  }
  // The scalar-type rule only applies to integral destinations.
  if (!toInt) {
    return;
  }
  if (isIntegralType(var.data().type().scalarType())) {
    return;
  }
  std::stringstream ss;
  ss << "Cannot input a tensor of type " << t.type().scalarType() << " as an integral argument";
  throw std::runtime_error(ss.str());
}
RegisterOperators reg({
// prim::MemoryFence does no work at run time: the stack is left as-is and
// 0 is returned.
Operator(
    prim::MemoryFence,
    [](Node* node) {
      auto fence = [](Stack& stack) { return 0; };
      return fence;
    }),
// prim::FusionGroup: the fusion handle is looked up once, when the operation
// is created, and is run against the stack on every invocation.
Operator(
    prim::FusionGroup,
    [](Node* node) {
      auto fusion = getFusionHandle(node);
      return [fusion](Stack& stack) {
        // The record object stays alive while the fused code runs.
        autograd::profiler::RecordFunction record("FusionGroup");
        fusion->run(stack);
        return 0;
      };
    }),
// prim::TensorToBool: pops a tensor and pushes its truth value.
// The value is tested with is_nonzero() instead of being converted to
// int64_t first; the integer conversion truncates, so a tensor holding a
// fractional value such as 0.5 would otherwise be reported as false.
Operator(
    prim::TensorToBool,
    [](Node* node) -> Operation {
      return [](Stack& stack) {
        at::Tensor a;
        pop(stack, a);
        at::DeviceGuard guard(a);
        push(stack, a.is_nonzero());
        return 0;
      };
    }),
// prim::TensorToNum: pops a tensor and pushes its single value as a number.
// The node's output type, inspected once at creation time, selects between
// an int64_t and a double result.
Operator(
    prim::TensorToNum,
    [](Node* node) -> Operation {
      const bool wants_int = node->output()->type() == IntType::get();
      if (!wants_int) {
        return [](Stack& stack) {
          at::Tensor source;
          pop(stack, source);
          at::DeviceGuard guard(source);
          push(stack, source.item<double>());
          return 0;
        };
      }
      return [](Stack& stack) {
        at::Tensor source;
        pop(stack, source);
        at::DeviceGuard guard(source);
        push(stack, source.item<int64_t>());
        return 0;
      };
    }),
// prim::ImplicitTensorToNum: like prim::TensorToNum, but the tensor is first
// validated with checkImplicitTensorToNum (which throws if the tensor is not
// acceptable as an implicit scalar argument).
Operator(
    prim::ImplicitTensorToNum,
    [](Node* node) -> Operation {
      const bool wants_int = node->output()->type() == IntType::get();
      if (!wants_int) {
        return [](Stack& stack) {
          at::Tensor source;
          pop(stack, source);
          checkImplicitTensorToNum(source, /*to int*/false);
          at::DeviceGuard guard(source);
          push(stack, source.item<double>());
          return 0;
        };
      }
      return [](Stack& stack) {
        at::Tensor source;
        pop(stack, source);
        checkImplicitTensorToNum(source, /*to int*/true);
        at::DeviceGuard guard(source);
        push(stack, source.item<int64_t>());
        return 0;
      };
    }),
// prim::NumToTensor: pops a scalar and pushes it wrapped as a variable.
Operator(
    prim::NumToTensor,
    [](Node* node) -> Operation {
      return [](Stack& stack) {
        at::Scalar number;
        pop(stack, number);
        auto as_tensor = at::scalar_to_tensor(number);
        push(stack, autograd::make_variable(std::move(as_tensor)));
        return 0;
      };
    }),
// prim::BoolToTensor: pops a bool and pushes it wrapped as a variable.
Operator(
    prim::BoolToTensor,
    [](Node* node) -> Operation {
      return [](Stack& stack) {
        bool flag;
        pop(stack, flag);
        auto as_tensor = at::scalar_to_tensor(flag);
        push(stack, autograd::make_variable(std::move(as_tensor)));
        return 0;
      };
    }),
// prim::IntToFloat: pops an int and pushes it as a floating-point value.
// The conversion goes straight to double (the type the other float ops in
// this file pop and push). Casting through 32-bit `float` first would round
// integers whose magnitude exceeds 2^24.
Operator(
    prim::IntToFloat,
    [](Node* node) -> Operation {
      return [](Stack& stack) {
        int64_t i;
        pop(stack, i);
        push(stack, static_cast<double>(i));
        return 0;
      };
    }),
// prim::FloatToInt: pops a double and pushes it converted to int64_t
// (the conversion discards the fractional part).
Operator(
    prim::FloatToInt,
    [](Node* node) -> Operation {
      return [](Stack& stack) {
        double value;
        pop(stack, value);
        push(stack, static_cast<int64_t>(value));
        return 0;
      };
    }),
// prim::StringToFloat: pops a string and pushes the matching double.
// Only the literals "inf" and "-inf" are accepted; anything else raises
// an error via AT_ERROR.
Operator(
    prim::StringToFloat,
    [](Node* node) -> Operation {
      return [](Stack& stack) {
        auto s = pop(stack).toString();
        const double infinity = std::numeric_limits<double>::infinity();
        if (s->string() == "inf") {
          push(stack, infinity);
          return 0;
        }
        if (s->string() == "-inf") {
          push(stack, -infinity);
          return 0;
        }
        AT_ERROR(
            "Only 'inf' or '-inf' can be cast to a float, but got '",
            s->string(),
            "'");
        return 0;
      };
    }),
// prim::TensorDevice: pops a tensor and pushes a two-element int list
// holding {device type, device index}.
Operator(
    prim::TensorDevice,
    [](Node* node) -> Operation {
      return [](Stack& stack) {
        at::Tensor tensor;
        pop(stack, tensor);
        const auto device = tensor.device();
        std::vector<int64_t> type_and_index(
            {static_cast<int64_t>(device.type()), device.index()});
        push(stack, std::move(type_and_index));
        return 0;
      };
    }),
// prim::TensorDType: pops a tensor and pushes its scalar type as an int.
Operator(
    prim::TensorDType,
    [](Node* node) -> Operation {
      return [](Stack& stack) {
        at::Tensor tensor;
        pop(stack, tensor);
        const auto dtype_code = static_cast<int64_t>(tensor.scalar_type());
        push(stack, dtype_code);
        return 0;
      };
    }),
// prim::TensorShape: pops a tensor and pushes its sizes.
Operator(
    prim::TensorShape,
    [](Node* node) -> Operation {
      return [](Stack& stack) {
        at::Tensor tensor;
        pop(stack, tensor);
        push(stack, tensor.sizes());
        return 0;
      };
    }),
// prim::Undefined: pushes a default-constructed at::Tensor.
Operator(
    prim::Undefined,
    [](Node* node) {
      return [](Stack& stack) {
        at::Tensor undefined;
        stack.push_back(std::move(undefined));
        return 0;
      };
    }),
// prim::None: pushes a default-constructed IValue.
Operator(
    prim::None,
    [](Node* node) {
      return [](Stack& stack) {
        stack.emplace_back();
        return 0;
      };
    }),
// prim::NoneGenerator: pushes a default-constructed IValue.
Operator(
    prim::NoneGenerator,
    [](Node* node) {
      return [](Stack& stack) {
        stack.push_back(IValue());
        return 0;
      };
    }),
// prim::Print: writes the node's inputs (the top `num_inputs` stack values)
// to std::cout separated by single spaces, ends the line with std::endl,
// and removes the inputs from the stack.
Operator(
    prim::Print,
    [](Node* node) {
      const size_t num_inputs = node->inputs().size();
      return [num_inputs](Stack& stack) {
        // Empty before the first value, a single space before every later one.
        const char* separator = "";
        for (const IValue& value : last(stack, num_inputs)) {
          std::cout << separator;
          std::cout << value;
          separator = " ";
        }
        drop(stack, num_inputs);
        std::cout << std::endl;
        return 0;
      };
    }),
// prim::RaiseException: pops the message string and throws it wrapped in a
// JITException.
Operator(
    prim::RaiseException,
    [](Node* node) -> Operation {
      return [](Stack& stack) {
        // `auto` takes a by-value copy, so the message does not refer to the
        // temporary IValue returned by pop().
        auto message = pop(stack).toStringRef();
        throw JITException(message);
        // Never reached; gives the lambda an int return type.
        return 0;
      };
    }),
// Load x, y
// Loads values from registers onto the stack. The callback itself does
// nothing, since the stack manipulation is already encoded in inst.inputs
// and inst.outputs.
Operator(prim::Load, noop),
// x, y = Store
// Stores values from the stack into registers. The callback itself does
// nothing, since the stack manipulation is already encoded in inst.inputs
// and inst.outputs.
Operator(prim::Store, noop),
// prim::Drop: removes one stack value per node input.
Operator(
    prim::Drop,
    [](Node* node) {
      const auto num_dropped = node->inputs().size();
      return [num_dropped](Stack& stack) {
        drop(stack, num_dropped);
        return 0;
      };
    }),
// prim::LoadWorld: pushes a World value initialised with 0.
Operator(
    prim::LoadWorld,
    [](Node* node) {
      return [](Stack& stack) {
        stack.emplace_back(World{0});
        return 0;
      };
    }),
// prim::StoreWorld: discards the value on top of the stack.
Operator(
    prim::StoreWorld,
    [](Node* node) {
      return [](Stack& stack) {
        const size_t one_value = 1;
        drop(stack, one_value);
        return 0;
      };
    }),
// prim::DummyWorld: always raises an error when executed.
Operator(
    prim::DummyWorld,
    [](Node* node) {
      auto fail = [](Stack& stack) {
        AT_ERROR("Encountered a dummy world during graph execution.");
        // Gives the lambda an int return type.
        return 0;
      };
      return fail;
    }),
// onnx::Reshape: pops (input, shape) and pushes `input` reshaped to the
// sizes stored in `shape`, which must be 1-dimensional and is read as
// int64_t data.
Operator(
    onnx::Reshape,
    [](Node* node) {
      return [](Stack& stack) {
        at::Tensor data, shape;
        pop(stack, data, shape);
        // Contiguous storage is needed to view the values as a flat array.
        shape = shape.contiguous();
        JIT_ASSERT(shape.ndimension() == 1);
        const int64_t* first_size = shape.data<int64_t>();
        at::IntList target_sizes(first_size, shape.size(0));
        push(stack, data.reshape(target_sizes));
        return 0;
      };
    }),
// onnx::Shape: pops a tensor and pushes a 1-D kLong tensor holding its sizes.
Operator(
    onnx::Shape,
    [](Node* node) {
      return [](Stack& stack) {
        auto input = pop(stack).toTensor();
        at::IntList sizes = input.sizes();
        const auto rank = static_cast<int64_t>(sizes.size());
        auto sizes_tensor = torch::empty({rank}, at::dtype(at::kLong));
        auto accessor = sizes_tensor.accessor<int64_t, 1>();
        size_t dim = 0;
        while (dim < sizes.size()) {
          accessor[dim] = sizes[dim];
          ++dim;
        }
        stack.push_back(sizes_tensor);
        return 0;
      };
    }),
// prim::AnyDefined: replaces the node's inputs on the stack with a bool that
// is true when at least one of them is a defined tensor.
Operator(
    prim::AnyDefined,
    [](Node* node) {
      const size_t num_inputs = node->inputs().size();
      return [num_inputs](Stack& stack) {
        bool any_defined = false;
        for (const IValue& candidate : last(stack, num_inputs)) {
          any_defined = candidate.toTensor().defined();
          if (any_defined) {
            break;
          }
        }
        drop(stack, num_inputs);
        stack.push_back(any_defined);
        return 0;
      };
    }),
// prim::AutogradAdd: pops (a, b) and pushes their sum, treating an undefined
// tensor as the additive identity: if `a` is undefined the result is `b`,
// otherwise if `b` is undefined the result is `a`.
Operator(
    prim::AutogradAdd,
    [](Node* node) {
      return [](Stack& stack) {
        at::Tensor lhs, rhs;
        pop(stack, lhs, rhs);
        if (!lhs.defined()) {
          stack.push_back(rhs);
          return 0;
        }
        if (!rhs.defined()) {
          stack.push_back(lhs);
          return 0;
        }
        stack.push_back(lhs + rhs);
        return 0;
      };
    }),
// prim::TupleUnpack: pops a tuple and pushes each of its elements in order.
// Raises an error if the tuple's length differs from the node's output count.
Operator(
    prim::TupleUnpack,
    [](Node* node) {
      const size_t num_elems = node->outputs().size();
      return [num_elems](Stack& stack) {
        auto tuple = pop(stack).toTuple();
        const auto& elems = tuple->elements();
        const bool size_matches = elems.size() == num_elems;
        if (!size_matches) {
          AT_ERROR("Expected a tuple of ", num_elems, " elements, but got ", elems.size());
        }
        stack.insert(stack.end(), elems.begin(), elems.end());
        return 0;
      };
    }),
// prim::TupleSlice: pops a tuple and pushes a new tuple made of the elements
// in [attr::beg, attr::end). Element access is bounds-checked via at().
Operator(
    prim::TupleSlice,
    [](Node* node) {
      const int64_t beg_ind = node->i(attr::beg);
      const int64_t end_ind = node->i(attr::end);
      return [beg_ind, end_ind](Stack& stack) {
        auto tuple = pop(stack).toTuple();
        const auto& elems = tuple->elements();
        std::vector<IValue> picked;
        int64_t position = beg_ind;
        while (position < end_ind) {
          picked.push_back(elems.at(position));
          ++position;
        }
        push(stack, Tuple::create(std::move(picked)));
        return 0;
      };
    }),
// prim::TupleIndex: pops a tuple and pushes the element at attr::index.
Operator(
    prim::TupleIndex,
    [](Node* node) {
      const auto index = node->i(attr::index);
      return [index](Stack& stack) {
        auto tuple = pop(stack).toTuple();
        // index is normalized to be positive at compile time
        const auto& selected = tuple->elements().at(index);
        stack.push_back(selected);
        return 0;
      };
    }),
// prim::TupleConstruct: moves the node's inputs (the top `num_inputs` stack
// values) into a new tuple and pushes that tuple in their place.
Operator(
    prim::TupleConstruct,
    [](Node* node) {
      const size_t num_inputs = node->inputs().size();
      return [num_inputs](Stack& stack) {
        const auto first_input = stack.end() - num_inputs;
        std::vector<IValue> elems(
            std::make_move_iterator(first_input),
            std::make_move_iterator(stack.end()));
        // The moved-from slots are removed before the tuple is pushed.
        drop(stack, num_inputs);
        push(stack, Tuple::create(std::move(elems)));
        return 0;
      };
    }),
// prim::ConstantChunk: pops a tensor, splits it with at::chunk(t, chunks, dim)
// and pushes the pieces. When fewer than `chunks` pieces come back, the
// missing outputs are padded with default IValues, but only if those outputs
// have no uses; otherwise an error is raised.
Operator(
    prim::ConstantChunk,
    [](Node* node) {
      int64_t chunks = node->i(attr::chunks);
      int64_t dim = node->i(attr::dim);
      // One flag per output: does anything consume it?
      auto outputs_used = fmap(node->outputs(), [](Value *v) { return v->uses().size() > 0; });
      return [=](Stack& stack) {
        autograd::profiler::RecordFunction record("chunk");
        at::Tensor t;
        pop(stack, t);
        auto result = at::chunk(t, chunks, dim);
        // NB: Chunk can sometimes return a smaller number of outputs.
        int64_t num_results = result.size();
        stack.insert(stack.end(), std::make_move_iterator(result.begin()),
                     std::make_move_iterator(result.end()));
        if (num_results == chunks) {
          return 0;
        }
        if (num_results > chunks) {
          JIT_ASSERTM(num_results == chunks,
                      "Expected chunk to return ", chunks, " outputs, but got ", num_results);
        }
        for (int64_t i = num_results; i < chunks; ++i) {
          AT_CHECK(!outputs_used[i],
                   "Expected chunk to return at least ", chunks, " outputs, but got only ", num_results);
          // We know that the output is unused, so it's ok to push anything on the stack.
          stack.emplace_back();
        }
        return 0;
      };
    }),
// prim::ListUnpack: pops a list and pushes each of its elements in order.
// The list's element type (int, float or DynamicType) is resolved once at
// creation time; any other element type raises an error. At run time the
// list length must equal the node's output count.
// NOTE(review): prim::ListConstruct tests the tensor case with
// isSubtypeOf(DynamicType::get()) while this operator uses ==; confirm the
// stricter check is intended.
Operator(
    prim::ListUnpack,
    [](Node* node) -> Operation {
      const size_t num_outputs = node->outputs().size();
      ListTypePtr lt = node->input()->type()->expect<ListType>();
      const auto elem_type = lt->getElementType();
      if (elem_type == IntType::get()) {
        return [=](Stack& stack) {
          auto popped = pop(stack);
          const auto& list = popped.toIntList()->elements();
          AT_CHECK(list.size() == num_outputs,
                   "Expected ", num_outputs, " elements in a list but found ", list.size());
          stack.insert(stack.end(), list.begin(), list.end());
          return 0;
        };
      }
      if (elem_type == FloatType::get()) {
        return [=](Stack& stack) {
          auto popped = pop(stack);
          const auto& list = popped.toDoubleList()->elements();
          AT_CHECK(list.size() == num_outputs,
                   "Expected ", num_outputs, " elements in a list but found ", list.size());
          stack.insert(stack.end(), list.begin(), list.end());
          return 0;
        };
      }
      if (elem_type == DynamicType::get()) {
        return [=](Stack& stack) {
          auto popped = pop(stack);
          const auto& list = popped.toTensorList()->elements();
          AT_CHECK(list.size() == num_outputs,
                   "Expected ", num_outputs, " elements in a list but found ", list.size());
          stack.insert(stack.end(), list.begin(), list.end());
          return 0;
        };
      }
      AT_ERROR("Unsupported list type: ", lt->getElementType()->str());
    }),
// prim::ListConstruct: replaces the node's inputs (the top `num_inputs`
// stack values) with a single list built from them. The output list's
// element type picks the representation: int64_t vector, double vector,
// tensor vector, or a generic IValue vector for everything else.
Operator(
    prim::ListConstruct,
    [](Node* node) -> Operation {
      const size_t num_inputs = node->inputs().size();
      ListTypePtr lt = node->output()->type()->expect<ListType>();
      const auto elem_type = lt->getElementType();
      if (IntType::get() == elem_type) {
        return [=](Stack& stack) {
          std::vector<int64_t> vals;
          vals.reserve(num_inputs);
          for (const IValue& v : peekSlice(stack, 0, num_inputs, num_inputs)) {
            vals.push_back(v.toInt());
          }
          drop(stack, num_inputs);
          push(stack, std::move(vals));
          return 0;
        };
      }
      if (FloatType::get() == elem_type) {
        return [=](Stack& stack) {
          std::vector<double> vals;
          vals.reserve(num_inputs);
          for (const IValue& v : peekSlice(stack, 0, num_inputs, num_inputs)) {
            vals.push_back(v.toDouble());
          }
          drop(stack, num_inputs);
          push(stack, std::move(vals));
          return 0;
        };
      }
      if (elem_type->isSubtypeOf(DynamicType::get())) {
        return [=](Stack& stack) {
          std::vector<at::Tensor> vals;
          vals.reserve(num_inputs);
          // Inputs are moved out of their stack slots before being dropped.
          for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
            vals.push_back(std::move(*it).toTensor());
          }
          drop(stack, num_inputs);
          push(stack, std::move(vals));
          return 0;
        };
      }
      return [=](Stack& stack) {
        std::vector<IValue> vals;
        vals.reserve(num_inputs);
        for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
          vals.push_back(std::move(*it));
        }
        drop(stack, num_inputs);
        push(stack, std::move(vals));
        return 0;
      };
    }),
// prim::fork: compiles the node's Subgraph attribute into Code when the
// operation is created. On each invocation the code is run on the current
// stack through a fresh InterpreterState, and the value left on top of the
// stack is replaced by a Future wrapping it.
Operator(
    prim::fork,
    [](Node* node) {
      // Check the node's shape before the Subgraph attribute is read, so the
      // assertions guard the access below instead of running after it.
      JIT_ASSERT(node->blocks().size() == 0);
      JIT_ASSERT(node->hasAttribute(attr::Subgraph));
      Code code(node->g(attr::Subgraph));
      return [=](Stack& stack) {
        InterpreterState(code).run(stack);
        push(stack, Future(pop(stack)));
        return 0;
      };
    }),
// aten::wait: pops a Future and pushes the value returned by its get().
Operator(
    "aten::wait(Future(t) self) -> t",
    [](Node* node) {
      return [](Stack& stack) {
        auto future = pop(stack).toFuture();
        push(stack, future->get());
        return 0;
      };
    }),
});
// define implementations for primitive number ops
//
// Every macro below expands to one or more `Operator(...)` entries, each
// followed by a trailing comma, so invocations are listed back to back inside
// a RegisterOperators initializer without separators.
// The `int_op` / `float_op` / `op` arguments are C++ expressions that are
// pasted into the generated lambda and must refer to the popped operands by
// the names `a` and `b`.

// DEFINE_GENERIC_OP: registers `aten_op` twice, once with the schema
// "(int a, int b) -> int_result" evaluating `int_op` on int64_t operands, and
// once with "(float a, float b) -> float_result" evaluating `float_op` on
// double operands.
#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
Operator( \
#aten_op "(int a, int b) -> " #int_result, \
[](Node* node) { \
return [=](Stack& stack) { \
int64_t a, b; \
pop(stack, a, b); \
push(stack, int_op); \
return 0; \
}; \
}), \
Operator( \
#aten_op "(float a, float b) -> " #float_result, [](Node* node) { \
return [=](Stack& stack) { \
double a, b; \
pop(stack, a, b); \
push(stack, float_op); \
return 0; \
}; \
}),
// DEFINE_INT_OP: registers only the "(int a, int b) -> int" overload of
// `aten_op`, evaluating `op` on int64_t operands.
#define DEFINE_INT_OP(aten_op, op) \
Operator(#aten_op "(int a, int b) -> int", [](Node* node) { \
return [=](Stack& stack) { \
int64_t a, b; \
pop(stack, a, b); \
push(stack, op); \
return 0; \
}; \
}),
// DEFINE_BINARY_OP: int and float overloads sharing one expression; the
// result has the same kind as the operands (int -> int, float -> float).
#define DEFINE_BINARY_OP(aten_op, op) \
DEFINE_GENERIC_OP(aten_op, op, op, int, float)
// DEFINE_COMPARISON_OP: int and float overloads sharing one expression; both
// overloads are declared as returning bool.
#define DEFINE_COMPARISON_OP(aten_op, op) \
DEFINE_GENERIC_OP(aten_op, op, op, bool, bool)
// DEFINE_BOOL_OP: registers the "(bool a, bool b) -> bool" overload of
// `aten_op`, evaluating `op` on bool operands.
#define DEFINE_BOOL_OP(aten_op, op) \
Operator(#aten_op "(bool a, bool b) -> bool", [](Node* node) { \
return [=](Stack& stack) { \
bool a, b; \
pop(stack, a, b); \
push(stack, op); \
return 0; \
}; \
}),
// Convert a Python-style index (which may be negative) into an index usable
// for a C++ container. A negative index counts from the end of the list, so
// `list_size` is added to it; non-negative indices are returned unchanged.
// No range check is performed: the result may still be negative or >= size.
int64_t normalizeIndex(int64_t idx, int64_t list_size) {
  return idx < 0 ? list_size + idx : idx;
}
// Operation factory for list append: pops (list, element) and appends the
// element to the list in place. Nothing is pushed; only the top two values
// are popped, so any value below them stays on the stack.
template <typename TList, typename TElement>
Operation listAppend(Node* node) {
  auto append = [](Stack& stack) {
    TList target;
    TElement new_element;
    pop(stack, target, new_element);
    auto& storage = target->elements();
    storage.push_back(new_element);
    return 0;
  };
  return append;
}
// Operation factory for list indexing: pops (list, idx) and pushes a copy of
// the selected element. Negative indices count from the end of the list.
// Throws std::out_of_range when the index falls outside the list.
template <typename T>
Operation listSelect(Node* node) {
  return [](Stack& stack) {
    T list;
    int64_t idx;
    pop(stack, list, idx);
    const auto& items = list->elements();
    const int64_t list_size = items.size();
    const int64_t position = normalizeIndex(idx, list_size);
    const bool in_range = position >= 0 && position < list_size;
    if (!in_range) {
      throw std::out_of_range("list index out of range");
    }
    auto selected = items[position];
    push(stack, std::move(selected));
    return 0;
  };
}
// Operation factory for list length: pops a list and pushes its element
// count as an int64_t.
template <typename T>
Operation listLen(Node* node) {
  return [](Stack& stack) {
    T list;
    pop(stack, list);
    const int64_t length = list->elements().size();
    push(stack, length);
    return 0;
  };
}
// Operation factory for list equality: pops two lists and pushes the int 1
// when their element containers compare equal with operator==, else 0.
template <typename T>
Operation listEq(Node* node) {
  return [](Stack& stack) {
    T lhs;
    T rhs;
    pop(stack, lhs, rhs);
    if (lhs->elements() == rhs->elements()) {
      push(stack, 1);
    } else {
      push(stack, 0);
    }
    return 0;
  };
}
// Specialization for at::Tensor, since it doesn't define operator==.
// Pops two tensor lists and pushes the int 1 when they have the same length
// and every pair of corresponding tensors compares as equal, else 0.
template <>
Operation listEq<Shared<TensorList>>(Node* node) {
  return [](Stack& stack) {
    Shared<TensorList> lhs;
    Shared<TensorList> rhs;
    pop(stack, lhs, rhs);
    const auto& lhs_tensors = lhs->elements();
    const auto& rhs_tensors = rhs->elements();
    bool all_equal = lhs_tensors.size() == rhs_tensors.size();
    for (size_t i = 0; all_equal && i < lhs_tensors.size(); ++i) {
      // This preserves Python's semantics, which uses eq() to compare two
      // elements, then passes the result to bool().
      // see: https://docs.python.org/3.4/reference/datamodel.html#object.__ge__
      const auto cmp_result = lhs_tensors[i].eq(rhs_tensors[i]);
      all_equal = cmp_result.is_nonzero();
    }
    if (all_equal) {
      push(stack, 1);
    } else {
      push(stack, 0);
    }
    return 0;
  };
}
// Operation factory for list concatenation: pops (a, b) and pushes a new
// list holding the elements of `a` followed by the elements of `b`.
// Neither input list is modified.
template <class TList, class TElement>
Operation listAdd(Node* node) {
  return [=](Stack& stack) {
    TList a;
    TList b;
    pop(stack, a, b);

    const auto& a_elements = a->elements();
    const auto& b_elements = b->elements();

    std::vector<TElement> ret;
    ret.reserve(a_elements.size() + b_elements.size());
    ret.insert(ret.end(), a_elements.begin(), a_elements.end());
    ret.insert(ret.end(), b_elements.begin(), b_elements.end());

    // Hand the buffer over to the stack instead of copying it a second time.
    push(stack, std::move(ret));
    return 0;
  };
}
// Operation factory for list slicing: pops (list, start, end, step) and
// pushes a new list with the elements at start, start + step, ... that lie
// before `end`. Negative `start` / `end` count from the end of the list, and
// both are clamped to the bounds of the list.
//
// A slice whose clamped range is empty yields an empty list regardless of
// `step`. Otherwise `step` must be positive: a zero step would never advance
// and a negative step would index before the start of the list, so both are
// rejected with an error instead.
template <typename TList, typename TElement>
Operation listSlice(Node* node) {
  return [](Stack& stack) {
    TList list;
    int64_t start;
    int64_t end;
    int64_t step;
    pop(stack, list, start, end, step);
    const int64_t list_size = list->elements().size();
    // clamp start and end to the bounds of the list
    const auto normalized_start =
        std::max((int64_t)0, normalizeIndex(start, list_size));
    const auto normalized_end =
        std::min(list_size, normalizeIndex(end, list_size));
    std::vector<TElement> sliced_list;
    if (normalized_end <= normalized_start) {
      // early exit if the slice is trivially empty
      push(stack, std::move(sliced_list));
      return 0;
    }
    AT_CHECK(step > 0, "slice step must be positive, but got ", step);
    // The range holds at least one element here, so this is
    // ceil((end - start) / step) computed without overflow.
    const int64_t num_selected =
        (normalized_end - normalized_start - 1) / step + 1;
    sliced_list.reserve(num_selected);
    // Index by element count rather than by `i += step`, so a very large
    // step cannot overflow the running index.
    for (int64_t k = 0; k < num_selected; ++k) {
      sliced_list.push_back(list->elements()[normalized_start + k * step]);
    }
    push(stack, std::move(sliced_list));
    return 0;
  };
}
RegisterOperators reg2({
// DEFINE_STRING_OP: registers `op_name` with the schema
// "(str a, str b) -><result>" and evaluates `string_op`, an expression that
// must refer to the operands by the names `a` and `b`.
// `b` is popped before `a`, so `a` is the value that was deeper on the stack.
// Both locals are declared with plain `auto`, i.e. they are by-value copies
// and do not refer to the temporary returned by pop().
// The expansion ends with a trailing comma, so invocations are listed
// without separators.
#define DEFINE_STRING_OP(op_name, string_op, result) \
Operator( \
#op_name "(str a, str b) ->" #result, \
[](Node* node) { \
return [=](Stack& stack) { \
auto b = pop(stack).toStringRef(); \
auto a = pop(stack).toStringRef(); \
push(stack, string_op); \
return 0; \
}; \
}),
// String equality, inequality and concatenation.
DEFINE_STRING_OP(aten::eq, a == b, bool)
DEFINE_STRING_OP(aten::ne, a != b, bool)
DEFINE_STRING_OP(aten::add, a + b, str)
// CREATE_LIST_OPS: registers the select / len / add / slice / append
// operators for one list flavour. `decl_type` is the element type as spelled
// in the schema string (e.g. "int") and `c_type` is the matching C++ list
// class (e.g. IntList). The expansion ends with a trailing comma, so
// invocations are listed without separators.
// The default for `end` in the slice schema, 9223372036854775807, is the
// largest int64_t value.
//
// The last line of the definition deliberately has no line-continuation
// backslash: with one, the following CREATE_LIST_OPS("int", IntList) line
// would be spliced into the macro body instead of being an invocation.
#define CREATE_LIST_OPS(decl_type, c_type) \
Operator("aten::select(" decl_type "[] a, int b) -> " decl_type, listSelect<Shared<c_type>>), \
Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
Operator("aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type "[]", listAdd<Shared<c_type>, c_type::ElemType>), \
Operator( \
"aten::slice(" decl_type "[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type "[]", \
listSlice<Shared<c_type>, c_type::ElemType>), \
Operator( \
"aten::append(World w, " decl_type "[] self, " decl_type " el) -> World", \
listAppend<Shared<c_type>, c_type::ElemType>),
CREATE_LIST_OPS("int", IntList)
CREATE_LIST_OPS("float", DoubleList)
CREATE_LIST_OPS("Tensor", TensorList)
CREATE_LIST_OPS("t", GenericList)
// List equality for int, float and Tensor lists; each is declared as
// returning int.
Operator(
    "aten::eq(int[] a, int[] b) -> int",
    listEq<Shared<IntList>>),
Operator(
    "aten::eq(float[] a, float[] b) -> int",
    listEq<Shared<DoubleList>>),
Operator(
    "aten::eq(Tensor[] a, Tensor[] b) -> int",
    listEq<Shared<TensorList>>),
// Scalar arithmetic. Each DEFINE_BINARY_OP line registers an int overload
// and a float overload that evaluate the same expression.
DEFINE_BINARY_OP(aten::add, a + b)
DEFINE_BINARY_OP(aten::sub, a - b)
DEFINE_BINARY_OP(aten::mul, a * b)
// pow() is evaluated for both overloads and the result is cast back to the
// operand type (int64_t for ints, double for floats).
DEFINE_BINARY_OP(aten::pow, static_cast<decltype(a)>(pow(a, b)))
// Two separate expressions are passed for int and float because % in C++
// only works on integers. C++ and Python compute the modulus differently for
// negative operands; the Python behavior is preserved here, since it is the
// more common expectation and matches the Python syntax, hence the extra
// "+ b ... % b" step.
// NOTE(review): the int expression evaluates `a % b`, which is undefined in
// C++ when b == 0 - confirm callers never pass a zero divisor.
DEFINE_GENERIC_OP(aten::remainder, (b + (a % b)) % b, fmod((b + fmod(a, b)), b), int, float)
// TODO: Support python floordiv (//)
// Right now aten::floordiv is only used by loop unrolling
// NOTE(review): `a / b` on int64_t truncates toward zero and is undefined
// for b == 0 - confirm this is acceptable for the loop-unrolling use.
DEFINE_INT_OP(aten::floordiv, a / b)
// NB: This is the python truediv operation: both ints are converted to
// double before dividing, so the result is a float.
Operator(
    "aten::div(int a, int b) -> float",
    [](Node* node) {
      return [](Stack& stack) {
        int64_t numerator, denominator;
        pop(stack, numerator, denominator);
        const double quotient =
            static_cast<double>(numerator) / static_cast<double>(denominator);
        push(stack, quotient);
        return 0;
      };
    }),
// Float division: pops (a, b) and pushes a / b.
Operator(
    "aten::div(float a, float b) -> float",
    [](Node* node) {
      return [](Stack& stack) {
        double numerator, denominator;
        pop(stack, numerator, denominator);
        const double quotient = numerator / denominator;
        push(stack, quotient);
        return 0;
      };
    }),
// Scalar comparisons. Each line registers an int overload and a float
// overload, both declared as returning bool.
DEFINE_COMPARISON_OP(aten::ne, a != b)
DEFINE_COMPARISON_OP(aten::eq, a == b)
DEFINE_COMPARISON_OP(aten::lt, a < b)
DEFINE_COMPARISON_OP(aten::gt, a > b)
DEFINE_COMPARISON_OP(aten::le, a <= b)
DEFINE_COMPARISON_OP(aten::ge, a >= b)
// Logical and / or on bool operands. Both operands have already been popped
// from the stack before the expression is evaluated.
DEFINE_BOOL_OP(aten::__and__, a && b)
DEFINE_BOOL_OP(aten::__or__, a || b)
// Constructors for empty lists: each pushes a freshly created, empty vector
// of the corresponding element type (int64_t, double, at::Tensor).
Operator(
    "aten::_construct_empty_int_list() -> int[]",
    [](Node* node) -> Operation {
      return [](Stack& stack) {
        std::vector<int64_t> empty_list;
        push(stack, std::move(empty_list));
        return 0;
      };
    }),
Operator(
    "aten::_construct_empty_float_list() -> float[]",
    [](Node* node) -> Operation {
      return [](Stack& stack) {
        std::vector<double> empty_list;
        push(stack, std::move(empty_list));
        return 0;
      };
    }),
Operator(
    "aten::_construct_empty_tensor_list() -> Tensor[]",
    [](Node* node) -> Operation {
      return [](Stack& stack) {
        std::vector<at::Tensor> empty_list;
        push(stack, std::move(empty_list));
        return 0;
      };
    }),
// Unary negation for int and float: pops the operand and pushes its negation.
Operator(
    "aten::neg(int self) -> int",
    [](Node* node) {
      return [](Stack& stack) {
        const auto operand = pop(stack).toInt();
        push(stack, -operand);
        return 0;
      };
    }),
Operator(
    "aten::neg(float self) -> float",
    [](Node* node) {
      return [](Stack& stack) {
        const auto operand = pop(stack).toDouble();
        push(stack, -operand);
        return 0;
      };
    }),
// Logical not: pops an int and pushes the result of applying `!` to it.
// NOTE(review): `!` yields a C++ bool, while the schema declares the result
// as int - confirm which stack representation consumers expect.
Operator(
    "aten::__not__(int self) -> int",
    [](Node* node) {
      return [](Stack& stack) {
        const auto operand = pop(stack).toInt();
        push(stack, !operand);
        return 0;
      };
    }),
// aten::_tensor_to_list: pops a tensor and pushes an int list holding the
// value of each entry along dimension 0. Every entry t[i] is read through
// data<int32_t>().
// NOTE(review): this presumably expects a 1-D kInt tensor such as the one
// produced by aten::_list_to_tensor - confirm against callers.
Operator(
    "aten::_tensor_to_list(Tensor self) -> int[]",
    [](Node* node) {
      return [=](Stack& stack) {
        at::Tensor t;
        pop(stack, t);
        // Use the size's own 64-bit type for the bound and the index, so the
        // loop counter cannot overflow where a plain `int` would.
        const int64_t num_elems = t.size(0);
        std::vector<int64_t> elems;
        elems.reserve(static_cast<size_t>(num_elems));
        for (int64_t i = 0; i < num_elems; ++i) {
          elems.push_back(*t[i].data<int32_t>());
        }
        push(stack, jit::IntList::create(std::move(elems)));
        return 0;
      };
    }),
// aten::_list_to_tensor: pops an int list and pushes a 1-D kInt tensor of
// the same length holding the list's values.
Operator(
    "aten::_list_to_tensor(int[] self) -> Tensor",
    [](Node* node) {
      return [](Stack& stack) {
        std::vector<int64_t> values;
        pop(stack, values);
        const auto length = static_cast<int64_t>(values.size());
        auto result = torch::empty({length}, at::dtype(at::kInt));
        size_t i = 0;
        while (i < values.size()) {
          result[i] = values[i];
          ++i;
        }
        push(stack, result);
        return 0;
      };
    }),
});
// Thin wrapper that forwards to at::leaky_relu; registered below as
// _test::leaky_relu.
at::Tensor leaky_relu(at::Tensor tensor, double scalar) {
  at::Tensor activated = at::leaky_relu(tensor, scalar);
  return activated;
}
// Thin wrapper that forwards to at::cat; registered below as _test::cat.
at::Tensor cat(std::vector<at::Tensor> tensors) {
  at::Tensor joined = at::cat(tensors);
  return joined;
}
// Registers the two wrapper functions above as custom operators in the
// `_test` namespace. The schema strings declare the argument types and the
// default of 0.01 for leaky_relu's `v`.
static auto reg3 =
torch::jit::RegisterOperators()
.op("_test::leaky_relu(Tensor self, float v=0.01) -> Tensor", &leaky_relu)
.op("_test::cat(Tensor[] inputs) -> Tensor", &cat);
}}} // torch::jit::anon