Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32791

When a registered operator has varargs (i.e. its schema ends with ...), the interpreter now appends the number of arguments to the top of the stack before invoking the operator. This allows the removal of more uses of Node* in the interpreter. This PR also cleans up the constructors for Operator to make it more likely that someone chooses the correct one.

Once these ops:

```
USES NODE: prim::TupleUnpack(...) -> (...)
USES NODE: prim::TupleSlice(...) -> (...)
USES NODE: prim::TupleConstruct(...) -> (...)
USES NODE: prim::ListUnpack(...) -> (...)
USES NODE: prim::ListConstruct(...) -> (...)
USES NODE: prim::DictConstruct(...) -> (...)
USES NODE: prim::Constant() -> (...)
USES NODE: prim::isinstance(...) -> (...)
USES NODE: prim::CreateObject(...) -> (...)
USES NODE: prim::fork(...) -> (...)
USES NODE: aten::warn(str message, *, int stacklevel=2) -> () # need stack level information, so ideally in interpreter so it can look at the stack
```

are turned into interpreter primitives, we can remove all but two constructors for Operator: one that takes (schema_string, operation), and one that takes (symbol, op_creator) for the remaining weird primitives.

Test Plan: Imported from OSS

Differential Revision: D19673158

Pulled By: zdevito

fbshipit-source-id: 95442a001538a6f53c1db4a210f8557ef118de66
257 lines · 8.1 KiB · C++
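For context, here is a minimal sketch (not part of this file) of what a vararg operation sees under the calling convention the commit message describes. The function name and body are hypothetical; it assumes the push/pop helpers from ATen/core/stack.h and the era's convention that an Operation returns 0.

```
#include <ATen/core/stack.h>

#include <vector>

// Hypothetical vararg operation; illustrative only. For a schema ending in
// "...", the interpreter pushes the argument count above the real arguments
// before invoking the operation.
int exampleVarargOp(torch::jit::Stack& stack) {
  const int64_t num_inputs = torch::jit::pop(stack).toInt();
  // Pop the arguments themselves; the top of the stack is the last argument.
  std::vector<c10::IValue> args(num_inputs);
  for (int64_t i = num_inputs - 1; i >= 0; --i) {
    args[i] = torch::jit::pop(stack);
  }
  // ... operate on args, then push any results; here, just the count ...
  torch::jit::push(stack, static_cast<int64_t>(args.size()));
  return 0;
}
```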
#include <torch/csrc/jit/constants.h>

#include <ATen/core/functional.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>

namespace torch {
namespace jit {

namespace {
c10::OperatorOptions aliasAnalysisInternalSpecialCase() {
  c10::OperatorOptions options;
  options.setAliasAnalysis(AliasAnalysisKind::INTERNAL_SPECIAL_CASE);
  return options;
}
} // namespace

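// A value may be baked into the graph as a constant only if it is
// reconstructible without autograd state: tensors must not require grad, and
// aggregates (lists/tuples) are insertable only if all of their elements are.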
bool insertableTensor(const at::Tensor& ten) {
  return !ten.requires_grad();
}

bool insertableIValue(const IValue& ivalue) {
  if (ivalue.isInt() || ivalue.isNone() || ivalue.isBool() ||
      ivalue.isDouble() || ivalue.isString() || ivalue.isDevice()) {
    return true;
  }
  if (ivalue.isTensor()) {
    return insertableTensor(ivalue.toTensor());
  }
  if (ivalue.isList() || ivalue.isTuple()) {
    c10::ArrayRef<IValue> elems;
    if (ivalue.isTuple()) {
      elems = ivalue.toTuple()->elements();
    } else {
      elems = ivalue.toListRef();
    }
    return std::all_of(elems.begin(), elems.end(), [](const IValue& tup_elem) {
      return insertableIValue(tup_elem);
    });
  }
  return false;
}

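// Inserts `val` into `g` as a prim::Constant node and returns its output
// Value; throws constant_not_supported_error for unsupported value kinds.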
Value* insertConstant(
    Graph& g,
    const IValue& val,
    c10::optional<SourceRange> loc,
    c10::optional<ScopePtr> scope) {
  auto value = tryInsertConstant(g, val, loc, scope);
  if (value) {
    return *value;
  }
  throw constant_not_supported_error(
      "Unsupported value kind: " + val.tagKind());
}

// IValue -> Constant node. Returns c10::nullopt (after destroying the node)
// when the value kind cannot be encoded as a constant attribute.
c10::optional<Value*> tryInsertConstant(
    Graph& g,
    const IValue& val,
    c10::optional<SourceRange> loc,
    c10::optional<ScopePtr> scope) {
  Node* n = g.create(prim::Constant);
  if (val.isTensor()) {
    at::Tensor ref = val.toTensor();
    if (!ref.defined()) {
      n->destroy();
      return g.insertNode(g.createNone())->output();
    }
    TORCH_INTERNAL_ASSERT(!ref.requires_grad());
    n->output()->inferTypeFrom(
        ref); // note: before t_ because of std::move(ref)
    n->t_(attr::value, std::move(ref));
  } else if (val.isInt()) {
    n->i_(attr::value, val.toInt());
    n->output()->setType(IntType::get());
  } else if (val.isDouble()) {
    n->f_(attr::value, val.toDouble());
    n->output()->setType(FloatType::get());
  } else if (val.isBool()) {
    n->i_(attr::value, val.toBool());
    n->output()->setType(BoolType::get());
  } else if (val.isBoolList()) {
    auto bool_list = val.toBoolList();
    n->is_(
        attr::value, std::vector<int64_t>(bool_list.begin(), bool_list.end()));
    n->output()->setType(ListType::ofBools());
  } else if (val.isIntList()) {
    n->is_(attr::value, val.toIntVector());
    n->output()->setType(ListType::ofInts());
  } else if (val.isTensorList()) {
    n->ts_(
        attr::value,
        fmap(val.toTensorVector(), [](const at::Tensor& t) {
          AT_ASSERT(!t.requires_grad());
          return t;
        }));
    n->output()->setType(ListType::ofTensors());
  } else if (val.isDoubleList()) {
    auto double_list = val.toDoubleList();
    n->fs_(
        attr::value,
        std::vector<double>(double_list.begin(), double_list.end()));
    n->output()->setType(ListType::ofFloats());
  } else if (val.isString()) {
    n->s_(attr::value, val.toString()->string());
    n->output()->setType(StringType::get());
  } else if (val.type()->isSubtypeOf(ListType::ofStrings())) {
    std::vector<std::string> ss;
    auto generic_list = val.toListRef();
    for (const IValue& ival : generic_list) {
      ss.push_back(ival.toStringRef());
    }
    n->ss_(attr::value, ss);
    n->output()->setType(ListType::create(StringType::get()));
  } else if (val.isDevice()) {
    std::stringstream ss;
    ss << val.toDevice();
    n->s_(attr::value, ss.str());
    n->output()->setType(DeviceObjType::get());
  } else if (val.isNone()) {
    n->output()->setType(NoneType::get());
  } else if (val.isTuple()) {
    if (insertableIValue(val)) {
      n->ival_(attr::value, val);
      n->output()->setType(val.type());
    } else {
      n->destroy();
      return c10::nullopt;
    }
  } else {
    n->destroy();
    return c10::nullopt;
  }
  if (loc)
    n->setSourceRange(*loc);
  if (scope)
    n->setScope(*scope);
  return g.insertNode(n)->output();
}

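// The interpreter implementation of prim::Constant: dispatch once on the
// node's output type, then return a closure that pushes the cached value
// onto the stack.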
RegisterOperators reg({
    Operator(
        prim::Constant,
        [](const Node* node) -> Operation {
          TypePtr type = node->output()->type();
          if (type->isSubtypeOf(TensorType::get())) {
            auto t = node->t(attr::value);
            return [t](Stack& stack) {
              push(stack, t);
              return 0;
            };
          } else if (type->isSubtypeOf(BoolType::get())) {
            bool b = node->i(attr::value);
            return [b](Stack& stack) {
              push(stack, b);
              return 0;
            };
          } else if (
              type->isSubtypeOf(NumberType::get()) &&
              node->kindOf(attr::value) == AttributeKind::i) {
            auto i = node->i(attr::value);
            return [i](Stack& stack) {
              push(stack, i);
              return 0;
            };
          } else if (
              type->isSubtypeOf(NumberType::get()) &&
              node->kindOf(attr::value) == AttributeKind::f) {
            auto f = node->f(attr::value);
            return [f](Stack& stack) {
              push(stack, f);
              return 0;
            };
          } else if (
              type->cast<TupleType>() &&
              node->kindOf(attr::value) == AttributeKind::ival) {
            const auto& tup = node->ival(attr::value);
            TORCH_INTERNAL_ASSERT(tup.isTuple());
            return [tup](Stack& stack) {
              push(stack, tup);
              return 0;
            };
          } else if (type->isSubtypeOf(ListType::ofInts())) {
            const auto& is = node->is(attr::value);
            return [is](Stack& stack) {
              push(stack, is);
              return 0;
            };
          } else if (type->isSubtypeOf(ListType::ofFloats())) {
            const auto& fs = node->fs(attr::value);
            return [fs](Stack& stack) {
              push(stack, fs);
              return 0;
            };
          } else if (type->isSubtypeOf(ListType::ofBools())) {
            const auto bs = fmap<bool>(node->is(attr::value));
            return [bs](Stack& stack) {
              push(stack, bs);
              return 0;
            };
          } else if (type->isSubtypeOf(ListType::ofTensors())) {
            const auto& ts = node->ts(attr::value);
            return [ts](Stack& stack) {
              push(stack, ts);
              return 0;
            };
          } else if (type->isSubtypeOf(ListType::ofStrings())) {
            const auto& ss = node->ss(attr::value);
            auto vals = c10::impl::GenericList(StringType::get());
            for (const auto& str : ss) {
              vals.push_back(str);
            }
            return [vals](Stack& stack) {
              push(stack, vals);
              return 0;
            };
          } else if (type == StringType::get()) {
            const auto& s = node->s(attr::value);
            return [s](Stack& stack) {
              push(stack, s);
              return 0;
            };
          } else if (type == DeviceObjType::get()) {
            auto d = c10::Device(node->s(attr::value));
            return [d](Stack& stack) {
              push(stack, d);
              return 0;
            };
          } else if (node->mustBeNone()) {
            return [](Stack& stack) {
              push(stack, IValue());
              return 0;
            };
          } else {
            std::stringstream ss;
            ss << "constant literal not supported for: " << type->str();
            throw std::runtime_error(ss.str());
          }
        },
        aliasAnalysisInternalSpecialCase()),
});

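// Recovers the IValue held by a constant by running its prim::Constant
// operation on an empty stack; returns nullopt for non-constant nodes and
// for function constants, which carry no attribute-encoded value.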
c10::optional<IValue> toIValue(const Value* v) {
  if (v->node()->kind() != prim::Constant || v->type()->cast<FunctionType>()) {
    return c10::nullopt;
  }
  // use implementation of prim::Constant to compute the output IValue
  auto op = v->node()->getOperation();
  Stack stack;
  op(stack);
  return stack.back();
}
} // namespace jit
} // namespace torch