fix missing type check in dictionary literal

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/31375

Test Plan: Imported from OSS

Differential Revision: D19145440

Pulled By: zdevito

fbshipit-source-id: 69909089586149ef766b4858d3420864a81b2493
This commit is contained in:
Zachary DeVito 2019-12-19 16:19:50 -08:00 committed by Facebook Github Bot
parent 348d42114e
commit 457286a383
2 changed files with 30 additions and 4 deletions

View File

@ -16354,6 +16354,12 @@ a")
self.checkScript(test_dict_tensor_key, (dict_a, inp1))
self.checkScript(test_dict_tensor_key, (dict_a, inp2))
def test_dict_types(self):
with self.assertRaisesRegex(RuntimeError, "single type"):
@torch.jit.script
def foo():
new_item = {'score': [1.0], 'ys': [1, 2, 3]}
def test_get_set_state_with_tensors(self):
class M(torch.nn.Module):
def __init__(self):

View File

@ -2809,6 +2809,7 @@ struct to_ir {
auto value_trees = dl.value_inputs().tree()->trees();
AT_ASSERT(key_trees.size() == value_trees.size());
std::vector<Value*> keys, values;
for (size_t i = 0; i < key_trees.size(); ++i) {
keys.push_back(emitExpr(Expr(key_trees[i])));
values.push_back(emitExpr(Expr(value_trees[i])));
@ -2821,15 +2822,34 @@ struct to_ir {
auto dict_type = type_hint->expect<DictType>();
key_type = dict_type->getKeyType();
value_type = dict_type->getValueType();
} else if (!keys.empty()) {
key_type = keys.at(0)->type();
value_type = values.at(0)->type();
} else {
} else if (keys.empty()) {
key_type = StringType::get();
value_type = TensorType::get();
} else {
key_type = keys.at(0)->type();
value_type = values.at(0)->type();
}
AT_ASSERT(key_type != nullptr && value_type != nullptr);
auto checkTypeOfValues = [](const TypePtr& type,
const char* what,
const std::vector<Value*>& values,
TreeList trees) {
for (size_t i = 0, N = values.size(); i < N; ++i) {
std::stringstream ss;
if (!values[i]->type()->isSubtypeOfExt(type, &ss)) {
throw ErrorReport(trees[i])
<< "Dict " << what
<< " must contain only a single type, expected: "
<< type->python_str() << " but found "
<< values[i]->type()->python_str() << " instead.\n"
<< ss.str();
}
}
};
checkTypeOfValues(key_type, "keys", keys, key_trees);
checkTypeOfValues(value_type, "values", values, value_trees);
return graph
->insertNode(graph->createDict(key_type, value_type, keys, values))
->output();