Correctly mark unannotated NamedTuple field to be inferred TensorType (#46969)

Summary:
If there is no annotation given, we want to show users that the type is inferred

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46969

Test Plan:
Added a new test case that throws an error with the expected error message

Fixes https://github.com/pytorch/pytorch/issues/46326

Reviewed By: ZolotukhinM

Differential Revision: D24614450

Pulled By: gmagogsfm

fbshipit-source-id: dec555a53bfaa9cdefd3b21b5142f5e522847504
This commit is contained in:
tmanlaibaatar 2020-10-29 11:58:10 -07:00 committed by Facebook GitHub Bot
parent 1e275bc1a6
commit fee585b5a3
2 changed files with 19 additions and 2 deletions

View File

@ -82,7 +82,7 @@ from copy import deepcopy
from itertools import product
import itertools
from textwrap import dedent
from typing import List, Dict, Optional, Tuple, Union
from typing import List, Dict, NamedTuple, Optional, Tuple, Union
import inspect
import math
import functools
@ -13796,6 +13796,23 @@ dedent """
out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0)))
self.assertEqual(out, torch.tensor(6.0))
def test_namedtuple_type_inference(self):
_AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)])
_UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value'])
def test_check_named_tuple_value():
named_tuple = _AnnotatedNamedTuple(1)
return named_tuple.value
self.checkScript(test_check_named_tuple_value, ())
def test_error():
return _UnannotatedNamedTuple(1)
with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor \(inferred\)\' "
r"for argument \'value\' but instead found type \'int\'."):
torch.jit.script(test_error)
def test_isinstance_dynamic(self):
@torch.jit.script
def foo(a):

View File

@ -839,7 +839,7 @@ def _get_named_tuple_properties(obj):
the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range())
annotations.append(the_type)
else:
annotations.append(torch._C.TensorType.get())
annotations.append(torch._C.TensorType.getInferred())
return type(obj).__name__, fields, annotations