Dict mutability (#16884)

Summary:
Adds `aten::_set_item` for `dict[key]` calls
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16884

Differential Revision: D14000488

Pulled By: driazati

fbshipit-source-id: ea1b46e0a736d095053effb4bc52753f696617b2
This commit is contained in:
David Riazati 2019-02-21 16:09:43 -08:00 committed by Facebook Github Bot
parent 3a47d56946
commit ac00a0cd47
4 changed files with 36 additions and 9 deletions

View File

@ -41,7 +41,7 @@ from common_methods_invocations import create_input, unpack_variables, \
exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
from copy import deepcopy
import random
from typing import List, Optional
from typing import List, Dict, Optional
from torch.jit.frontend import NotSupportedError
from torch.jit import BatchTensor
@ -10328,6 +10328,16 @@ a")
self.checkScript(list_of_dicts, ())
def test_dict_mutability(self):
@torch.jit.script
def fn():
# type: () -> Dict[str, int]
a = torch.jit.annotate(Dict[str, int], {})
a['ok'] = 10
return a
self.assertEqual(fn(), {'ok': 10})
def dict_to_python(self):
def python_lookup(my_dict, keys):
# type: (Dict[str, int], List[str]) -> List[int]

View File

@ -167,8 +167,9 @@ struct SchemaParser {
auto key_type = parseType().first;
L.expect(',');
auto value_type = parseType().first;
alias_info = parseAliasAnnotation();
L.expect(')');
alias_info = parseAliasAnnotation();
value = DictType::create(key_type, value_type);
} else {
auto value_alias = parseBaseType();

View File

@ -801,9 +801,10 @@ struct PythonPrintPass {
} break;
case prim::DictConstruct: {
auto dict_type = node->output()->type()->expect<DictType>();
if (node->inputs().size() == 0 &&
!dict_type->getKeyType()->isSubtypeOf(StringType::get()) &&
!dict_type->getValueType()->isSubtypeOf(TensorType::get())) {
bool is_default_type =
dict_type->getKeyType()->isSubtypeOf(StringType::get()) &&
dict_type->getKeyType()->isSubtypeOf(TensorType::get());
if (node->inputs().size() == 0 && !is_default_type) {
stmt << "annotate(" << node->output()->type()->python_str()
<< ", {})";
} else {

View File

@ -1179,6 +1179,15 @@ int listSetItem<Shared<BoolList>, bool>(Stack& stack) {
return 0;
}
int dictSetItem(Stack& stack) {
auto value = pop(stack);
auto idx = pop(stack);
auto& dict = pop(stack).toGenericDict()->elements();
dict[idx] = value;
push(stack, dict);
return 0;
}
int dictLen(Stack& stack) {
auto dict = pop(stack).toGenericDictRef();
push(stack, int64_t(dict.size()));
@ -1316,6 +1325,8 @@ RegisterOperators reg2({
// NOTE: this must be after the other list specializations so that operator
// resolution doesn't pick this up first
CREATE_MUTABLE_LIST_OPS("t", GenericList),
#undef CREATE_IMMUTABLE_LIST_OPS
#undef CREATE_MUTABLE_LIST_OPS
#define CREATE_LIST_OPS(decl_type, c_type) \
Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
@ -1506,13 +1517,17 @@ RegisterOperators reg2({
#define CREATE_DICT_OPS(key_type) \
Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen), \
Operator( \
"aten::keys(Dict(" key_type ", t) self) -> " key_type "[]", \
"aten::keys(Dict(" key_type ", t) self) -> " key_type "[](*)", \
dictKeys), \
Operator("aten::values(Dict(" key_type ", t) self) -> t[]", dictValues), \
Operator("aten::values(Dict(" key_type ", t) self) -> t[](*)", dictValues),\
Operator( \
"prim::DictIndex(Dict(" key_type ", t) self, " key_type \
" key) -> t", \
dictIndex)
" key) -> t(*)", \
dictIndex), \
Operator( \
"aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type \
" idx, t v) -> ()", \
dictSetItem)
CREATE_DICT_OPS("str"),
CREATE_DICT_OPS("int"),