Remove prim::Constant op (#32804)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32804

Constants are interpreter primitives, so the prim::Constant op was never
actually invoked. This change removes the op and cleans up the logic
around it.

This also fixes constant propagation so that a failure to look up an op
no longer silently stops constant propagation; only errors raised inside
the op implementation itself will do that.
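
To make that policy concrete: operator lookup happens outside the try block, so a missing operator surfaces as an error, while an exception thrown by the operator body merely leaves the node unfolded. Below is a minimal standalone C++ sketch of that policy, not the actual torch::jit code; FakeNode, Op, lookupOp, and tryFold are illustrative stand-ins for Node, Operation, Node::getOperation(), and the runNode() shown in the constant_propagation.cpp hunk further down.

// Minimal standalone sketch (assumption: illustrative names only, not the
// real torch::jit API). It mirrors the policy of the new runNode():
// operator lookup failures surface, op-execution failures skip folding.
#include <functional>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

using Stack = std::vector<int>;
using Op = std::function<void(Stack&)>;

struct FakeNode {
  std::string name;
  Stack constant_inputs;
};

// Stand-in for Node::getOperation(): throws if no operator is registered.
Op lookupOp(const FakeNode& n) {
  if (n.name == "good_op") {
    return [](Stack& s) { s.push_back(s.at(0) + s.at(1)); };
  }
  if (n.name == "throwing_op") {
    return [](Stack&) { throw std::runtime_error("runtime failure"); };
  }
  throw std::runtime_error("no operator registered for " + n.name);
}

// Lookup stays OUTSIDE the try so a missing op is reported loudly; only
// errors raised while running the op cause the node to be left unfolded.
std::optional<Stack> tryFold(const FakeNode& n) {
  Op op = lookupOp(n);
  Stack stack = n.constant_inputs;
  try {
    op(stack);
  } catch (...) {
    return std::nullopt;
  }
  return stack;
}

int main() {
  if (auto out = tryFold({"good_op", {2, 3}})) {
    std::cout << "good_op folded to " << out->back() << "\n"; // prints 5
  }
  if (!tryFold({"throwing_op", {1}})) {
    std::cout << "throwing_op left unfolded\n";
  }
  // tryFold({"unknown_op", {}}) would throw, surfacing the lookup failure.
}

The same shape appears in propagateNode() in the diff: the node is folded only when runNode() returns a value.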

Test Plan: Imported from OSS

Differential Revision: D19673156

Pulled By: zdevito

fbshipit-source-id: 7beee59a6a67a6c2f8261d86bd505280fefa999e
Author: Zachary DeVito, 2020-02-18 15:02:36 -08:00, committed by Facebook Github Bot
Parent: c59e35b147
Commit: 83c347ff4a
4 changed files with 74 additions and 132 deletions

@@ -3943,7 +3943,7 @@ graph(%Ra, %Rb):
         bar = Bar()
         ops = torch.jit.export_opnames(bar)
-        expected = ['aten::add.Tensor', 'aten::mul.Scalar', 'prim::Constant']
+        expected = ['aten::add.Tensor', 'aten::mul.Scalar']
         self.assertEqual(ops, expected)

     def test_pytorch_jit_env_off(self):

@@ -401,11 +401,11 @@ class TestScriptPy3(JitTestCase):
         scripted_M_mod = torch.jit.script(M())
         self.assertEqual(torch.jit.export_opnames(scripted_M_mod),
-                         ['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal', 'prim::Constant'])
+                         ['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'])

         scripted_M_mod.sub = torch.jit.script(FooMod())
         self.assertEqual(torch.jit.export_opnames(scripted_M_mod),
-                         ['aten::add.Tensor', 'aten::mul.Scalar', 'prim::Constant'])
+                         ['aten::add.Tensor', 'aten::mul.Scalar'])

 if __name__ == '__main__':

@@ -140,117 +140,63 @@ c10::optional<Value*> tryInsertConstant(
   return g.insertNode(n)->output();
 }
 
-RegisterOperators reg({
-    Operator(
-        prim::Constant,
-        [](const Node* node) -> Operation {
-          TypePtr type = node->output()->type();
-          if (type->isSubtypeOf(TensorType::get())) {
-            auto t = node->t(attr::value);
-            return [t](Stack& stack) {
-              push(stack, t);
-              return 0;
-            };
-          } else if (type->isSubtypeOf(BoolType::get())) {
-            bool b = node->i(attr::value);
-            return [b](Stack& stack) {
-              push(stack, b);
-              return 0;
-            };
-          } else if (
-              type->isSubtypeOf(NumberType::get()) &&
-              node->kindOf(attr::value) == AttributeKind::i) {
-            auto i = node->i(attr::value);
-            return [i](Stack& stack) {
-              push(stack, i);
-              return 0;
-            };
-          } else if (
-              type->isSubtypeOf(NumberType::get()) &&
-              node->kindOf(attr::value) == AttributeKind::f) {
-            auto f = node->f(attr::value);
-            return [f](Stack& stack) {
-              push(stack, f);
-              return 0;
-            };
-          } else if (
-              type->cast<TupleType>() &&
-              node->kindOf(attr::value) == AttributeKind::ival) {
-            const auto& tup = node->ival(attr::value);
-            TORCH_INTERNAL_ASSERT(tup.isTuple());
-            return [tup](Stack& stack) {
-              push(stack, tup);
-              return 0;
-            };
-          } else if (type->isSubtypeOf(ListType::ofInts())) {
-            const auto& is = node->is(attr::value);
-            return [is](Stack& stack) {
-              push(stack, is);
-              return 0;
-            };
-          } else if (type->isSubtypeOf(ListType::ofFloats())) {
-            const auto& fs = node->fs(attr::value);
-            return [fs](Stack& stack) {
-              push(stack, fs);
-              return 0;
-            };
-          } else if (type->isSubtypeOf(ListType::ofBools())) {
-            const auto bs = fmap<bool>(node->is(attr::value));
-            return [bs](Stack& stack) {
-              push(stack, bs);
-              return 0;
-            };
-          } else if (type->isSubtypeOf(ListType::ofTensors())) {
-            const auto& ts = node->ts(attr::value);
-            return [ts](Stack& stack) {
-              push(stack, ts);
-              return 0;
-            };
-          } else if (type->isSubtypeOf(ListType::ofStrings())) {
-            const auto& ss = node->ss(attr::value);
-            auto vals = c10::impl::GenericList(StringType::get());
-            for (const auto& str : ss) {
-              vals.push_back(str);
-            }
-            return [vals](Stack& stack) {
-              push(stack, vals);
-              return 0;
-            };
-          } else if (type == StringType::get()) {
-            const auto& s = node->s(attr::value);
-            return [s](Stack& stack) {
-              push(stack, s);
-              return 0;
-            };
-          } else if (type == DeviceObjType::get()) {
-            auto d = c10::Device(node->s(attr::value));
-            return [d](Stack& stack) {
-              push(stack, d);
-              return 0;
-            };
-          } else if (node->mustBeNone()) {
-            return [](Stack& stack) {
-              push(stack, IValue());
-              return 0;
-            };
-          } else {
-            std::stringstream ss;
-            ss << "constant literal not supported for: " << type->str();
-            throw std::runtime_error(ss.str());
-          }
-        },
-        aliasAnalysisInternalSpecialCase()),
-});
-
 c10::optional<IValue> toIValue(const Value* v) {
   if (v->node()->kind() != prim::Constant || v->type()->cast<FunctionType>()) {
     return c10::nullopt;
   }
-  // use implementation of prim::Constant to compute the output IValue
-  auto op = v->node()->getOperation();
-  Stack stack;
-  op(stack);
-  return stack.back();
+  const Node* node = v->node();
+  const TypePtr& type = v->type();
+  if (type->isSubtypeOf(TensorType::get())) {
+    return node->t(attr::value);
+  } else if (type->isSubtypeOf(BoolType::get())) {
+    return (bool) node->i(attr::value);
+  } else if (
+      type->isSubtypeOf(NumberType::get()) &&
+      node->kindOf(attr::value) == AttributeKind::i) {
+    return node->i(attr::value);
+  } else if (
+      type->isSubtypeOf(NumberType::get()) &&
+      node->kindOf(attr::value) == AttributeKind::f) {
+    return node->f(attr::value);
+  } else if (
+      type->cast<TupleType>() &&
+      node->kindOf(attr::value) == AttributeKind::ival) {
+    const auto& tup = node->ival(attr::value);
+    TORCH_INTERNAL_ASSERT(tup.isTuple());
+    return tup;
+  } else if (type->isSubtypeOf(ListType::ofInts())) {
+    const auto& is = node->is(attr::value);
+    return is;
+  } else if (type->isSubtypeOf(ListType::ofFloats())) {
+    const auto& fs = node->fs(attr::value);
+    return fs;
+  } else if (type->isSubtypeOf(ListType::ofBools())) {
+    const auto bs = fmap<bool>(node->is(attr::value));
+    return bs;
+  } else if (type->isSubtypeOf(ListType::ofTensors())) {
+    const auto& ts = node->ts(attr::value);
+    return ts;
+  } else if (type->isSubtypeOf(ListType::ofStrings())) {
+    const auto& ss = node->ss(attr::value);
+    auto vals = c10::impl::GenericList(StringType::get());
+    for (const auto& str : ss) {
+      vals.push_back(str);
+    }
+    return vals;
+  } else if (type == StringType::get()) {
+    const auto& s = node->s(attr::value);
+    return s;
+  } else if (type == DeviceObjType::get()) {
+    auto d = c10::Device(node->s(attr::value));
+    return d;
+  } else if (node->mustBeNone()) {
+    return IValue();
+  } else {
+    std::stringstream ss;
+    ss << "constant literal not supported for: " << type->str();
+    throw std::runtime_error(ss.str());
+  }
 }
 } // namespace jit
 } // namespace torch

@@ -24,8 +24,8 @@ c10::optional<Stack> runNodeIfInputsAreConstant(const Node* node) {
       return c10::nullopt;
     }
   }
-  auto op = node->getOperation();
   try {
+    auto op = node->getOperation();
     op(stack);
     TORCH_INTERNAL_ASSERT(stack.size() == node->outputs().size());
   } catch (...) {
@@ -75,39 +75,35 @@ struct ConstantPropagator {
     }
   }
 
-  std::vector<IValue> runNode(Node* n) {
+  c10::optional<std::vector<IValue>> runNode(Node* n) {
     auto op = n->getOperation();
     Stack stack;
     for (auto input : n->inputs()) {
       stack.push_back(*toIValue(input));
     }
-    op(stack);
-    auto var_outputs = fmap(stack, [&](IValue v) -> IValue {
+    try {
+      op(stack);
+    } catch (...) {
+      return c10::nullopt;
+    }
+    for (const IValue& v : stack) {
       if (v.isTensor()) {
-        auto t = std::move(v).toTensor();
-        if (t.defined()) {
-          if (t.requires_grad()) {
-            // error gets caught within propagateNode()
-            throw c10::Error("Can't insert requires grad as constant", "");
-          }
-          return IValue(t);
-        } else {
-          return t;
+        at::Tensor t = v.toTensor();
+        if (t.defined() && t.requires_grad()) {
+          // requires grad tensors cannot be constants
+          return c10::nullopt;
         }
-      } else {
-        return v;
       }
-    });
-    return var_outputs;
+    }
+    return stack;
   }
 
   void propagateNode(Node* n) {
     std::vector<IValue> outputs;
-    try {
-      outputs = runNode(n);
-    } catch (...) {
-      // Catch exceptions. This op may not be run,
-      // so catch the error here & leave the op in the graph
+    if (auto outputs_opt = runNode(n)) {
+      outputs = std::move(outputs_opt.value());
+    } else {
+      // The op failed to run, so we cannot continue constant-prop for it.
       return;
     }
     auto graph = n->owningGraph();