mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[jit] Fix dict type serialization (#32569)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32569 If the dict's contained types cannot be inferred from its contents (for example, `Dict[str, Tensor]` vs. `Dict[str, Optional[Tensor]]`), we must explicitly annotate the type. Also this removes some special handling that omits annotations on empty containers that have the default type. It makes the code more complex for not too much value, and was wrong for dicts anyway. Test Plan: Imported from OSS Differential Revision: D19551016 Pulled By: suo fbshipit-source-id: c529b112e72c10f509a6bc0f5876644caa1be967
This commit is contained in:
parent
3ada2e0d64
commit
8fd3eaed25
|
|
@ -34,6 +34,23 @@ class TestScriptPy3(JitTestCase):
|
|||
self.assertAlmostEqual(out, out_script)
|
||||
self.assertEqual(captured, captured_script)
|
||||
|
||||
def test_optional_dict_construct(self):
|
||||
class M(torch.nn.Module):
|
||||
def use(self, buffer: Dict[str, Optional[torch.Tensor]]):
|
||||
return buffer["prev_key"]
|
||||
|
||||
def forward(self, x):
|
||||
prev_key = torch.rand(2, 3)
|
||||
next_key = torch.rand(2, 3)
|
||||
saved_state: Dict[str, Optional[torch.Tensor]] = {
|
||||
"prev_key": prev_key,
|
||||
"next_key": next_key,
|
||||
}
|
||||
|
||||
return self.use(saved_state)
|
||||
|
||||
self.checkModule(M(), (torch.rand(2, 2),))
|
||||
|
||||
def test_kwarg_support(self):
|
||||
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"):
|
||||
class M(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -906,32 +906,38 @@ struct PythonPrintImpl {
|
|||
case prim::ListConstruct: {
|
||||
ListTypePtr list_type = node->output()->type()->expect<ListType>();
|
||||
TypePtr elem_type = list_type->getElementType();
|
||||
if (!elem_type->isSubtypeOf(TensorType::get())) {
|
||||
// when the list is empty and is not a list of tensors,
|
||||
// we need to annotate it, otherwise it won't be possible
|
||||
// to infer the type on import
|
||||
if (node->inputs().size() == 0) {
|
||||
stmt << "annotate(" << node->output()->type()->python_str()
|
||||
<< ", [])";
|
||||
} else if (!elementTypeCanBeInferredFromMembers(elem_type)) {
|
||||
stmt << "annotate(" << node->output()->type()->python_str() << ",";
|
||||
printValueList(stmt, node->inputs(), "[", "]");
|
||||
stmt << ")";
|
||||
} else {
|
||||
printValueList(stmt, node->inputs(), "[", "]");
|
||||
}
|
||||
// Empty lists must be annotated with their type so the compiler knows
|
||||
// what type is supposed to be inside them
|
||||
if (node->inputs().size() == 0) {
|
||||
stmt << "annotate(" << node->output()->type()->python_str()
|
||||
<< ", [])";
|
||||
// If we can't infer the type based on what's inside, explicitly
|
||||
// annotate it to disambiguate.
|
||||
// This happens for List[Tensor] vs. List[Optional[Tensor]]
|
||||
} else if (!elementTypeCanBeInferredFromMembers(elem_type)) {
|
||||
stmt << "annotate(" << node->output()->type()->python_str() << ", ";
|
||||
printValueList(stmt, node->inputs(), "[", "]");
|
||||
stmt << ")";
|
||||
// Otherwise just print a list
|
||||
} else {
|
||||
printValueList(stmt, node->inputs(), "[", "]");
|
||||
}
|
||||
} break;
|
||||
case prim::DictConstruct: {
|
||||
auto dict_type = node->output()->type()->expect<DictType>();
|
||||
bool is_default_type =
|
||||
dict_type->getKeyType()->isSubtypeOf(StringType::get()) &&
|
||||
dict_type->getKeyType()->isSubtypeOf(TensorType::get());
|
||||
if (node->inputs().size() == 0 && !is_default_type) {
|
||||
stmt << "annotate(" << node->output()->type()->python_str()
|
||||
<< ", {})";
|
||||
// There are cases where we must annotate the dict with an explicit type
|
||||
// to help the compiler out:
|
||||
// - the dict is empty
|
||||
// - the dict has potentially ambiguous element types
|
||||
// (e.g. Tensor vs. Optional[Tensor])
|
||||
if (
|
||||
node->inputs().size() == 0 ||
|
||||
!elementTypeCanBeInferredFromMembers(dict_type->getKeyType()) ||
|
||||
!elementTypeCanBeInferredFromMembers(dict_type->getValueType())) {
|
||||
stmt << "annotate(" << node->output()->type()->python_str() << ", ";
|
||||
printDict(stmt, node->inputs());
|
||||
stmt << ")";
|
||||
// Otherwise just print a dict
|
||||
} else {
|
||||
printDict(stmt, node->inputs());
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user