mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Correctly mark unannotated NamedTuple field to be inferred TensorType (#46969)
Summary: If there is no annotation given, we want to show users that the type is inferred Pull Request resolved: https://github.com/pytorch/pytorch/pull/46969 Test Plan: Added a new test case that throws an error with the expected error message Fixes https://github.com/pytorch/pytorch/issues/46326 Reviewed By: ZolotukhinM Differential Revision: D24614450 Pulled By: gmagogsfm fbshipit-source-id: dec555a53bfaa9cdefd3b21b5142f5e522847504
This commit is contained in:
parent
1e275bc1a6
commit
fee585b5a3
|
|
@ -82,7 +82,7 @@ from copy import deepcopy
|
|||
from itertools import product
|
||||
import itertools
|
||||
from textwrap import dedent
|
||||
from typing import List, Dict, Optional, Tuple, Union
|
||||
from typing import List, Dict, NamedTuple, Optional, Tuple, Union
|
||||
import inspect
|
||||
import math
|
||||
import functools
|
||||
|
|
@ -13796,6 +13796,23 @@ dedent """
|
|||
out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0)))
|
||||
self.assertEqual(out, torch.tensor(6.0))
|
||||
|
||||
def test_namedtuple_type_inference(self):
|
||||
_AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)])
|
||||
_UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value'])
|
||||
|
||||
def test_check_named_tuple_value():
|
||||
named_tuple = _AnnotatedNamedTuple(1)
|
||||
return named_tuple.value
|
||||
|
||||
self.checkScript(test_check_named_tuple_value, ())
|
||||
|
||||
def test_error():
|
||||
return _UnannotatedNamedTuple(1)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor \(inferred\)\' "
|
||||
r"for argument \'value\' but instead found type \'int\'."):
|
||||
torch.jit.script(test_error)
|
||||
|
||||
def test_isinstance_dynamic(self):
|
||||
@torch.jit.script
|
||||
def foo(a):
|
||||
|
|
|
|||
|
|
@ -839,7 +839,7 @@ def _get_named_tuple_properties(obj):
|
|||
the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range())
|
||||
annotations.append(the_type)
|
||||
else:
|
||||
annotations.append(torch._C.TensorType.get())
|
||||
annotations.append(torch._C.TensorType.getInferred())
|
||||
return type(obj).__name__, fields, annotations
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user