mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
fix missing type check in dictionary literal
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/31375 Test Plan: Imported from OSS Differential Revision: D19145440 Pulled By: zdevito fbshipit-source-id: 69909089586149ef766b4858d3420864a81b2493
This commit is contained in:
parent
348d42114e
commit
457286a383
|
|
@ -16354,6 +16354,12 @@ a")
|
|||
self.checkScript(test_dict_tensor_key, (dict_a, inp1))
|
||||
self.checkScript(test_dict_tensor_key, (dict_a, inp2))
|
||||
|
||||
def test_dict_types(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "single type"):
|
||||
@torch.jit.script
|
||||
def foo():
|
||||
new_item = {'score': [1.0], 'ys': [1, 2, 3]}
|
||||
|
||||
def test_get_set_state_with_tensors(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -2809,6 +2809,7 @@ struct to_ir {
|
|||
auto value_trees = dl.value_inputs().tree()->trees();
|
||||
AT_ASSERT(key_trees.size() == value_trees.size());
|
||||
std::vector<Value*> keys, values;
|
||||
|
||||
for (size_t i = 0; i < key_trees.size(); ++i) {
|
||||
keys.push_back(emitExpr(Expr(key_trees[i])));
|
||||
values.push_back(emitExpr(Expr(value_trees[i])));
|
||||
|
|
@ -2821,15 +2822,34 @@ struct to_ir {
|
|||
auto dict_type = type_hint->expect<DictType>();
|
||||
key_type = dict_type->getKeyType();
|
||||
value_type = dict_type->getValueType();
|
||||
} else if (!keys.empty()) {
|
||||
key_type = keys.at(0)->type();
|
||||
value_type = values.at(0)->type();
|
||||
} else {
|
||||
} else if (keys.empty()) {
|
||||
key_type = StringType::get();
|
||||
value_type = TensorType::get();
|
||||
} else {
|
||||
key_type = keys.at(0)->type();
|
||||
value_type = values.at(0)->type();
|
||||
}
|
||||
AT_ASSERT(key_type != nullptr && value_type != nullptr);
|
||||
|
||||
auto checkTypeOfValues = [](const TypePtr& type,
|
||||
const char* what,
|
||||
const std::vector<Value*>& values,
|
||||
TreeList trees) {
|
||||
for (size_t i = 0, N = values.size(); i < N; ++i) {
|
||||
std::stringstream ss;
|
||||
if (!values[i]->type()->isSubtypeOfExt(type, &ss)) {
|
||||
throw ErrorReport(trees[i])
|
||||
<< "Dict " << what
|
||||
<< " must contain only a single type, expected: "
|
||||
<< type->python_str() << " but found "
|
||||
<< values[i]->type()->python_str() << " instead.\n"
|
||||
<< ss.str();
|
||||
}
|
||||
}
|
||||
};
|
||||
checkTypeOfValues(key_type, "keys", keys, key_trees);
|
||||
checkTypeOfValues(value_type, "values", values, value_trees);
|
||||
|
||||
return graph
|
||||
->insertNode(graph->createDict(key_type, value_type, keys, values))
|
||||
->output();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user