[jit] Fix dict type serialization (#32569)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32569

If the dict's contained types cannot be inferred from its contents (for
example, `Dict[str, Tensor]` vs. `Dict[str, Optional[Tensor]]`), we must
explicitly annotate the type.

Also this removes some special handling that omits annotations on empty
containers that have the default type. It makes the code more complex
for not too much value, and was wrong for dicts anyway.

Test Plan: Imported from OSS

Differential Revision: D19551016

Pulled By: suo

fbshipit-source-id: c529b112e72c10f509a6bc0f5876644caa1be967
This commit is contained in:
Michael Suo 2020-01-24 03:17:47 -08:00 committed by Facebook Github Bot
parent 3ada2e0d64
commit 8fd3eaed25
2 changed files with 43 additions and 20 deletions

View File

@ -34,6 +34,23 @@ class TestScriptPy3(JitTestCase):
self.assertAlmostEqual(out, out_script)
self.assertEqual(captured, captured_script)
def test_optional_dict_construct(self):
class M(torch.nn.Module):
def use(self, buffer: Dict[str, Optional[torch.Tensor]]):
return buffer["prev_key"]
def forward(self, x):
prev_key = torch.rand(2, 3)
next_key = torch.rand(2, 3)
saved_state: Dict[str, Optional[torch.Tensor]] = {
"prev_key": prev_key,
"next_key": next_key,
}
return self.use(saved_state)
self.checkModule(M(), (torch.rand(2, 2),))
def test_kwarg_support(self):
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"):
class M(torch.nn.Module):

View File

@ -906,32 +906,38 @@ struct PythonPrintImpl {
case prim::ListConstruct: {
ListTypePtr list_type = node->output()->type()->expect<ListType>();
TypePtr elem_type = list_type->getElementType();
if (!elem_type->isSubtypeOf(TensorType::get())) {
// when the list is empty and is not a list of tensors,
// we need to annotate it, otherwise it won't be possible
// to infer the type on import
if (node->inputs().size() == 0) {
stmt << "annotate(" << node->output()->type()->python_str()
<< ", [])";
} else if (!elementTypeCanBeInferredFromMembers(elem_type)) {
stmt << "annotate(" << node->output()->type()->python_str() << ",";
printValueList(stmt, node->inputs(), "[", "]");
stmt << ")";
} else {
printValueList(stmt, node->inputs(), "[", "]");
}
// Empty lists must be annotated with their type so the compiler knows
// what type is supposed to be inside them
if (node->inputs().size() == 0) {
stmt << "annotate(" << node->output()->type()->python_str()
<< ", [])";
// If we can't infer the type based on what's inside, explicitly
// annotate it to disambiguate.
// This happens for List[Tensor] vs. List[Optional[Tensor]]
} else if (!elementTypeCanBeInferredFromMembers(elem_type)) {
stmt << "annotate(" << node->output()->type()->python_str() << ", ";
printValueList(stmt, node->inputs(), "[", "]");
stmt << ")";
// Otherwise just print a list
} else {
printValueList(stmt, node->inputs(), "[", "]");
}
} break;
case prim::DictConstruct: {
auto dict_type = node->output()->type()->expect<DictType>();
bool is_default_type =
dict_type->getKeyType()->isSubtypeOf(StringType::get()) &&
dict_type->getKeyType()->isSubtypeOf(TensorType::get());
if (node->inputs().size() == 0 && !is_default_type) {
stmt << "annotate(" << node->output()->type()->python_str()
<< ", {})";
// There are cases where we must annotate the dict with an explicit type
// to help the compiler out:
// - the dict is empty
// - the dict has potentially ambiguous element types
// (e.g. Tensor vs. Optional[Tensor])
if (
node->inputs().size() == 0 ||
!elementTypeCanBeInferredFromMembers(dict_type->getKeyType()) ||
!elementTypeCanBeInferredFromMembers(dict_type->getValueType())) {
stmt << "annotate(" << node->output()->type()->python_str() << ", ";
printDict(stmt, node->inputs());
stmt << ")";
// Otherwise just print a dict
} else {
printDict(stmt, node->inputs());
}