mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Dict mutability (#16884)
Summary: Adds `aten::_set_item` for `dict[key]` calls Pull Request resolved: https://github.com/pytorch/pytorch/pull/16884 Differential Revision: D14000488 Pulled By: driazati fbshipit-source-id: ea1b46e0a736d095053effb4bc52753f696617b2
This commit is contained in:
parent
3a47d56946
commit
ac00a0cd47
|
|
@ -41,7 +41,7 @@ from common_methods_invocations import create_input, unpack_variables, \
|
|||
exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
|
||||
from copy import deepcopy
|
||||
import random
|
||||
from typing import List, Optional
|
||||
from typing import List, Dict, Optional
|
||||
from torch.jit.frontend import NotSupportedError
|
||||
from torch.jit import BatchTensor
|
||||
|
||||
|
|
@ -10328,6 +10328,16 @@ a")
|
|||
|
||||
self.checkScript(list_of_dicts, ())
|
||||
|
||||
def test_dict_mutability(self):
|
||||
@torch.jit.script
|
||||
def fn():
|
||||
# type: () -> Dict[str, int]
|
||||
a = torch.jit.annotate(Dict[str, int], {})
|
||||
a['ok'] = 10
|
||||
return a
|
||||
|
||||
self.assertEqual(fn(), {'ok': 10})
|
||||
|
||||
def dict_to_python(self):
|
||||
def python_lookup(my_dict, keys):
|
||||
# type: (Dict[str, int], List[str]) -> List[int]
|
||||
|
|
|
|||
|
|
@ -167,8 +167,9 @@ struct SchemaParser {
|
|||
auto key_type = parseType().first;
|
||||
L.expect(',');
|
||||
auto value_type = parseType().first;
|
||||
alias_info = parseAliasAnnotation();
|
||||
L.expect(')');
|
||||
alias_info = parseAliasAnnotation();
|
||||
|
||||
value = DictType::create(key_type, value_type);
|
||||
} else {
|
||||
auto value_alias = parseBaseType();
|
||||
|
|
|
|||
|
|
@ -801,9 +801,10 @@ struct PythonPrintPass {
|
|||
} break;
|
||||
case prim::DictConstruct: {
|
||||
auto dict_type = node->output()->type()->expect<DictType>();
|
||||
if (node->inputs().size() == 0 &&
|
||||
!dict_type->getKeyType()->isSubtypeOf(StringType::get()) &&
|
||||
!dict_type->getValueType()->isSubtypeOf(TensorType::get())) {
|
||||
bool is_default_type =
|
||||
dict_type->getKeyType()->isSubtypeOf(StringType::get()) &&
|
||||
dict_type->getKeyType()->isSubtypeOf(TensorType::get());
|
||||
if (node->inputs().size() == 0 && !is_default_type) {
|
||||
stmt << "annotate(" << node->output()->type()->python_str()
|
||||
<< ", {})";
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -1179,6 +1179,15 @@ int listSetItem<Shared<BoolList>, bool>(Stack& stack) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
int dictSetItem(Stack& stack) {
|
||||
auto value = pop(stack);
|
||||
auto idx = pop(stack);
|
||||
auto& dict = pop(stack).toGenericDict()->elements();
|
||||
dict[idx] = value;
|
||||
push(stack, dict);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int dictLen(Stack& stack) {
|
||||
auto dict = pop(stack).toGenericDictRef();
|
||||
push(stack, int64_t(dict.size()));
|
||||
|
|
@ -1316,6 +1325,8 @@ RegisterOperators reg2({
|
|||
// NOTE: this must be after the other list specializations so that operator
|
||||
// resolution doesn't pick this up first
|
||||
CREATE_MUTABLE_LIST_OPS("t", GenericList),
|
||||
#undef CREATE_IMMUTABLE_LIST_OPS
|
||||
#undef CREATE_MUTABLE_LIST_OPS
|
||||
|
||||
#define CREATE_LIST_OPS(decl_type, c_type) \
|
||||
Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
|
||||
|
|
@ -1506,13 +1517,17 @@ RegisterOperators reg2({
|
|||
#define CREATE_DICT_OPS(key_type) \
|
||||
Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen), \
|
||||
Operator( \
|
||||
"aten::keys(Dict(" key_type ", t) self) -> " key_type "[]", \
|
||||
"aten::keys(Dict(" key_type ", t) self) -> " key_type "[](*)", \
|
||||
dictKeys), \
|
||||
Operator("aten::values(Dict(" key_type ", t) self) -> t[]", dictValues), \
|
||||
Operator("aten::values(Dict(" key_type ", t) self) -> t[](*)", dictValues),\
|
||||
Operator( \
|
||||
"prim::DictIndex(Dict(" key_type ", t) self, " key_type \
|
||||
" key) -> t", \
|
||||
dictIndex)
|
||||
" key) -> t(*)", \
|
||||
dictIndex), \
|
||||
Operator( \
|
||||
"aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type \
|
||||
" idx, t v) -> ()", \
|
||||
dictSetItem)
|
||||
|
||||
CREATE_DICT_OPS("str"),
|
||||
CREATE_DICT_OPS("int"),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user