mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix incorrect type comparison (#145449)
Summary: This change was incorrectly made as part of #145166 Differential Revision: D68536221 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145449 Approved by: https://github.com/bobrenjc93
This commit is contained in:
parent
09ae69a364
commit
c4523999a1
|
|
@ -9,6 +9,7 @@ import pickle
|
|||
import sys
|
||||
import sympy
|
||||
import tempfile
|
||||
import typing
|
||||
import unittest
|
||||
from types import BuiltinFunctionType
|
||||
from typing import Callable, NamedTuple, Optional, Union, List
|
||||
|
|
@ -1547,25 +1548,25 @@ class {test_classname}(torch.nn.Module):
|
|||
(Optional[list[int]], list[int]),
|
||||
] + [
|
||||
# pre-PEP585 signatures
|
||||
(list[int], int),
|
||||
(list[int], create_type_hint([int, int])),
|
||||
(list[int], create_type_hint((int, int))),
|
||||
(list[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
|
||||
(typing.List[int], int),
|
||||
(typing.List[int], create_type_hint([int, int])),
|
||||
(typing.List[int], create_type_hint((int, int))),
|
||||
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
|
||||
(
|
||||
list[torch.Tensor],
|
||||
typing.List[torch.Tensor],
|
||||
create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
|
||||
),
|
||||
(list[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
|
||||
(list[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
|
||||
(list[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
|
||||
(typing.List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
|
||||
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
|
||||
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
|
||||
(
|
||||
list[torch.Tensor],
|
||||
typing.List[torch.Tensor],
|
||||
create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
|
||||
),
|
||||
(list[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
|
||||
(list[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
|
||||
(Optional[list[torch.Tensor]], list[torch.Tensor]),
|
||||
(Optional[list[int]], list[int]),
|
||||
(typing.List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
|
||||
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
|
||||
(Optional[typing.List[torch.Tensor]], typing.List[torch.Tensor]),
|
||||
(Optional[typing.List[int]], typing.List[int]),
|
||||
]
|
||||
|
||||
for sig_type, arg_type in should_be_equal:
|
||||
|
|
|
|||
|
|
@ -281,10 +281,6 @@ def create_type_hint(x):
|
|||
return x
|
||||
|
||||
|
||||
_LIST_TYPES = (list, typing.List) # noqa: UP006
|
||||
_TUPLE_TYPES = (tuple, typing.Tuple) # noqa: UP006
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def type_matches(signature_type: Any, argument_type: Any):
|
||||
sig_origin_type = getattr(signature_type, "__origin__", signature_type)
|
||||
|
|
@ -298,24 +294,23 @@ def type_matches(signature_type: Any, argument_type: Any):
|
|||
sig_contained = signature_type.__args__
|
||||
return any(type_matches(c, argument_type) for c in sig_contained)
|
||||
|
||||
if signature_type is typing.List[int] and argument_type is int: # noqa: UP006
|
||||
# int can be promoted to List[int]
|
||||
return True
|
||||
|
||||
if getattr(signature_type, "__origin__", None) in _LIST_TYPES:
|
||||
if getattr(signature_type, "__origin__", None) is list:
|
||||
sig_el_type = signature_type.__args__[0]
|
||||
if sig_el_type is argument_type:
|
||||
|
||||
# int can be promoted to list[int]
|
||||
if argument_type is int and sig_el_type is int:
|
||||
return True
|
||||
|
||||
if not inspect.isclass(sig_el_type):
|
||||
warnings.warn(
|
||||
f"Does not support nested parametric types, got {signature_type}. Please file a bug."
|
||||
)
|
||||
return False
|
||||
if getattr(argument_type, "__origin__", None) in _LIST_TYPES:
|
||||
if getattr(argument_type, "__origin__", None) is list:
|
||||
return issubclass(argument_type.__args__[0], sig_el_type)
|
||||
|
||||
def is_homogeneous_tuple(t):
|
||||
if getattr(t, "__origin__", None) not in _TUPLE_TYPES:
|
||||
if getattr(t, "__origin__", None) is not tuple:
|
||||
return False
|
||||
contained = t.__args__
|
||||
if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user