Remove prim::Constant op (#32804)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32804

Constants are interpreter primitives, so the prim::Constant op was never actually invoked; this cleans up the logic around it. It also fixes constant propagation so that a failure to look up an op no longer silently halts propagation: only an error raised inside the op implementation itself will now do that.

Test Plan: Imported from OSS

Differential Revision: D19673156

Pulled By: zdevito

fbshipit-source-id: 7beee59a6a67a6c2f8261d86bd505280fefa999e
parent c59e35b147
commit 83c347ff4a
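To ground the summary before the diff: constants still appear in TorchScript IR as prim::Constant nodes; what this commit removes is the registered operator behind them, since the interpreter materializes constants directly. A minimal sketch (any PyTorch build with TorchScript should behave this way):

import torch

@torch.jit.script
def scale(x: torch.Tensor) -> torch.Tensor:
    # The literal 2.0 is captured as a prim::Constant node in the IR.
    return x * 2.0

# The node is still visible when printing the graph; it is simply
# executed as an interpreter primitive, not via a registered op.
print(scale.graph)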
@@ -3943,7 +3943,7 @@ graph(%Ra, %Rb):

         bar = Bar()
         ops = torch.jit.export_opnames(bar)
-        expected = ['aten::add.Tensor', 'aten::mul.Scalar', 'prim::Constant']
+        expected = ['aten::add.Tensor', 'aten::mul.Scalar']
         self.assertEqual(ops, expected)

     def test_pytorch_jit_env_off(self):
@@ -401,11 +401,11 @@ class TestScriptPy3(JitTestCase):

        scripted_M_mod = torch.jit.script(M())
        self.assertEqual(torch.jit.export_opnames(scripted_M_mod),
-                        ['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal', 'prim::Constant'])
+                        ['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'])

        scripted_M_mod.sub = torch.jit.script(FooMod())
        self.assertEqual(torch.jit.export_opnames(scripted_M_mod),
-                        ['aten::add.Tensor', 'aten::mul.Scalar', 'prim::Constant'])
+                        ['aten::add.Tensor', 'aten::mul.Scalar'])


if __name__ == '__main__':
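The test updates above follow directly: torch.jit.export_opnames enumerates the operators a module uses, and prim::Constant no longer qualifies as one. A hedged sketch with an illustrative module (this Bar is a stand-in, not the Bar defined in the test file):

import torch

class Bar(torch.nn.Module):
    # Illustrative stand-in for the test's Bar module.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + x * 2

bar = torch.jit.script(Bar())
# Expect only real operators (e.g. 'aten::add.Tensor', 'aten::mul.Scalar');
# 'prim::Constant' is no longer reported.
print(torch.jit.export_opnames(bar))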
@@ -140,117 +140,63 @@ c10::optional<Value*> tryInsertConstant(
   return g.insertNode(n)->output();
 }

-RegisterOperators reg({
-    Operator(
-        prim::Constant,
-        [](const Node* node) -> Operation {
-          TypePtr type = node->output()->type();
-          if (type->isSubtypeOf(TensorType::get())) {
-            auto t = node->t(attr::value);
-            return [t](Stack& stack) {
-              push(stack, t);
-              return 0;
-            };
-          } else if (type->isSubtypeOf(BoolType::get())) {
-            bool b = node->i(attr::value);
-            return [b](Stack& stack) {
-              push(stack, b);
-              return 0;
-            };
-          } else if (
-              type->isSubtypeOf(NumberType::get()) &&
-              node->kindOf(attr::value) == AttributeKind::i) {
-            auto i = node->i(attr::value);
-            return [i](Stack& stack) {
-              push(stack, i);
-              return 0;
-            };
-          } else if (
-              type->isSubtypeOf(NumberType::get()) &&
-              node->kindOf(attr::value) == AttributeKind::f) {
-            auto f = node->f(attr::value);
-            return [f](Stack& stack) {
-              push(stack, f);
-              return 0;
-            };
-          } else if (
-              type->cast<TupleType>() &&
-              node->kindOf(attr::value) == AttributeKind::ival) {
-            const auto& tup = node->ival(attr::value);
-            TORCH_INTERNAL_ASSERT(tup.isTuple());
-            return [tup](Stack& stack) {
-              push(stack, tup);
-              return 0;
-            };
-          } else if (type->isSubtypeOf(ListType::ofInts())) {
-            const auto& is = node->is(attr::value);
-            return [is](Stack& stack) {
-              push(stack, is);
-              return 0;
-            };
-          } else if (type->isSubtypeOf(ListType::ofFloats())) {
-            const auto& fs = node->fs(attr::value);
-            return [fs](Stack& stack) {
-              push(stack, fs);
-              return 0;
-            };
-          } else if (type->isSubtypeOf(ListType::ofBools())) {
-            const auto bs = fmap<bool>(node->is(attr::value));
-            return [bs](Stack& stack) {
-              push(stack, bs);
-              return 0;
-            };
-          } else if (type->isSubtypeOf(ListType::ofTensors())) {
-            const auto& ts = node->ts(attr::value);
-            return [ts](Stack& stack) {
-              push(stack, ts);
-              return 0;
-            };
-          } else if (type->isSubtypeOf(ListType::ofStrings())) {
-            const auto& ss = node->ss(attr::value);
-            auto vals = c10::impl::GenericList(StringType::get());
-            for (const auto& str : ss) {
-              vals.push_back(str);
-            }
-            return [vals](Stack& stack) {
-              push(stack, vals);
-              return 0;
-            };
-          } else if (type == StringType::get()) {
-            const auto& s = node->s(attr::value);
-            return [s](Stack& stack) {
-              push(stack, s);
-              return 0;
-            };
-          } else if (type == DeviceObjType::get()) {
-            auto d = c10::Device(node->s(attr::value));
-            return [d](Stack& stack) {
-              push(stack, d);
-              return 0;
-            };
-          } else if (node->mustBeNone()) {
-            return [](Stack& stack) {
-              push(stack, IValue());
-              return 0;
-            };
-          } else {
-            std::stringstream ss;
-            ss << "constant literal not supported for: " << type->str();
-            throw std::runtime_error(ss.str());
-          }
-        },
-        aliasAnalysisInternalSpecialCase()),
-});
-
 c10::optional<IValue> toIValue(const Value* v) {
   if (v->node()->kind() != prim::Constant || v->type()->cast<FunctionType>()) {
     return c10::nullopt;
   }
-  // use implementation of prim::Constant to compute the output IValue
-  auto op = v->node()->getOperation();
-  Stack stack;
-  op(stack);
-  return stack.back();
+  const Node* node = v->node();
+  const TypePtr& type = v->type();
+  if (type->isSubtypeOf(TensorType::get())) {
+    return node->t(attr::value);
+  } else if (type->isSubtypeOf(BoolType::get())) {
+    return (bool)node->i(attr::value);
+  } else if (
+      type->isSubtypeOf(NumberType::get()) &&
+      node->kindOf(attr::value) == AttributeKind::i) {
+    return node->i(attr::value);
+  } else if (
+      type->isSubtypeOf(NumberType::get()) &&
+      node->kindOf(attr::value) == AttributeKind::f) {
+    return node->f(attr::value);
+  } else if (
+      type->cast<TupleType>() &&
+      node->kindOf(attr::value) == AttributeKind::ival) {
+    const auto& tup = node->ival(attr::value);
+    TORCH_INTERNAL_ASSERT(tup.isTuple());
+    return tup;
+  } else if (type->isSubtypeOf(ListType::ofInts())) {
+    const auto& is = node->is(attr::value);
+    return is;
+  } else if (type->isSubtypeOf(ListType::ofFloats())) {
+    const auto& fs = node->fs(attr::value);
+    return fs;
+  } else if (type->isSubtypeOf(ListType::ofBools())) {
+    const auto bs = fmap<bool>(node->is(attr::value));
+    return bs;
+  } else if (type->isSubtypeOf(ListType::ofTensors())) {
+    const auto& ts = node->ts(attr::value);
+    return ts;
+  } else if (type->isSubtypeOf(ListType::ofStrings())) {
+    const auto& ss = node->ss(attr::value);
+    auto vals = c10::impl::GenericList(StringType::get());
+    for (const auto& str : ss) {
+      vals.push_back(str);
+    }
+    return vals;
+  } else if (type == StringType::get()) {
+    const auto& s = node->s(attr::value);
+    return s;
+  } else if (type == DeviceObjType::get()) {
+    auto d = c10::Device(node->s(attr::value));
+    return d;
+  } else if (node->mustBeNone()) {
+    return IValue();
+  } else {
+    std::stringstream ss;
+    ss << "constant literal not supported for: " << type->str();
+    throw std::runtime_error(ss.str());
+  }
 }

 } // namespace jit
 } // namespace torch
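toIValue() now decodes the constant straight out of the node's attributes instead of building and running a prim::Constant operation. It can be poked at from Python, assuming the Value.toIValue binding exposed by python_ir in this era of PyTorch:

import torch

@torch.jit.script
def affine(x: torch.Tensor) -> torch.Tensor:
    return x * 3.0

for node in affine.graph.nodes():
    if node.kind() == "prim::Constant":
        # Calls through to the C++ toIValue() rewritten above.
        print(node.output().toIValue())  # e.g. 3.0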
@@ -24,8 +24,8 @@ c10::optional<Stack> runNodeIfInputsAreConstant(const Node* node) {
       return c10::nullopt;
     }
   }
+  auto op = node->getOperation();
   try {
-    auto op = node->getOperation();
     op(stack);
     TORCH_INTERNAL_ASSERT(stack.size() == node->outputs().size());
   } catch (...) {
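The point of hoisting getOperation() out of the try block: an operator-lookup failure should now surface as an error, while only failures from actually running the op are swallowed. A language-agnostic sketch of that discipline in Python (names are illustrative, not PyTorch API):

from typing import Callable, List, Optional

Stack = List[int]  # simplified stand-in for the interpreter stack

def run_if_inputs_constant(get_op: Callable[[], Callable[[Stack], None]],
                           stack: Stack) -> Optional[Stack]:
    op = get_op()  # a failed lookup raises to the caller, loudly
    try:
        op(stack)  # only errors from executing the op are caught
    except Exception:
        return None
    return stack

print(run_if_inputs_constant(lambda: (lambda s: s.append(42)), []))  # [42]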
@@ -75,39 +75,35 @@ struct ConstantPropagator {
     }
   }

-  std::vector<IValue> runNode(Node* n) {
+  c10::optional<std::vector<IValue>> runNode(Node* n) {
     auto op = n->getOperation();
     Stack stack;
     for (auto input : n->inputs()) {
       stack.push_back(*toIValue(input));
     }
-    op(stack);
-    auto var_outputs = fmap(stack, [&](IValue v) -> IValue {
-      if (v.isTensor()) {
-        auto t = std::move(v).toTensor();
-        if (t.defined()) {
-          if (t.requires_grad()) {
-            // error gets caught within propagateNode()
-            throw c10::Error("Can't insert requires grad as constant", "");
-          }
-          return IValue(t);
-        } else {
-          return t;
-        }
-      } else {
-        return v;
-      }
-    });
-    return var_outputs;
+    try {
+      op(stack);
+    } catch (...) {
+      return c10::nullopt;
+    }
+    for (const IValue& v : stack) {
+      if (v.isTensor()) {
+        at::Tensor t = v.toTensor();
+        if (t.defined() && t.requires_grad()) {
+          // requires grad tensors cannot be constants
+          return c10::nullopt;
+        }
+      }
+    }
+    return stack;
   }

   void propagateNode(Node* n) {
     std::vector<IValue> outputs;
-    try {
-      outputs = runNode(n);
-    } catch (...) {
-      // Catch exceptions. This op may not be run,
-      // so catch the error here & leave the op in the graph
+    if (auto outputs_opt = runNode(n)) {
+      outputs = std::move(outputs_opt.value());
+    } else {
+      // The op failed to run, so we cannot continue constant-prop for it.
       return;
     }
     auto graph = n->owningGraph();
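End to end, the pass behaves as before for well-behaved ops; a node is simply left in the graph when running it fails (or when an output tensor requires grad). A sketch using the internal pass binding, whose name is assumed from PyTorch's own test suite and may change between versions:

import torch

@torch.jit.script
def fold_me(x: torch.Tensor) -> torch.Tensor:
    return x + (1 + 2)  # the integer arithmetic is foldable

# Internal, unstable binding; used here only to demonstrate the pass.
torch._C._jit_pass_constant_propagation(fold_me.graph)
print(fold_me.graph)  # 1 + 2 should now be a single prim::Constant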