Fix incorrect type comparison (#145449)

Summary: This change was incorrectly made as part of #145166

Differential Revision: D68536221

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145449
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein 2025-01-26 04:40:26 +00:00 committed by PyTorch MergeBot
parent 09ae69a364
commit c4523999a1
2 changed files with 21 additions and 25 deletions

View File

@ -9,6 +9,7 @@ import pickle
import sys import sys
import sympy import sympy
import tempfile import tempfile
import typing
import unittest import unittest
from types import BuiltinFunctionType from types import BuiltinFunctionType
from typing import Callable, NamedTuple, Optional, Union, List from typing import Callable, NamedTuple, Optional, Union, List
@ -1547,25 +1548,25 @@ class {test_classname}(torch.nn.Module):
(Optional[list[int]], list[int]), (Optional[list[int]], list[int]),
] + [ ] + [
# pre-PEP585 signatures # pre-PEP585 signatures
(list[int], int), (typing.List[int], int),
(list[int], create_type_hint([int, int])), (typing.List[int], create_type_hint([int, int])),
(list[int], create_type_hint((int, int))), (typing.List[int], create_type_hint((int, int))),
(list[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])), (typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
( (
list[torch.Tensor], typing.List[torch.Tensor],
create_type_hint([torch.nn.Parameter, torch.nn.Parameter]), create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
), ),
(list[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])), (typing.List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
(list[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])), (typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
(list[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))), (typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
( (
list[torch.Tensor], typing.List[torch.Tensor],
create_type_hint((torch.nn.Parameter, torch.nn.Parameter)), create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
), ),
(list[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))), (typing.List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
(list[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))), (typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
(Optional[list[torch.Tensor]], list[torch.Tensor]), (Optional[typing.List[torch.Tensor]], typing.List[torch.Tensor]),
(Optional[list[int]], list[int]), (Optional[typing.List[int]], typing.List[int]),
] ]
for sig_type, arg_type in should_be_equal: for sig_type, arg_type in should_be_equal:

View File

@ -281,10 +281,6 @@ def create_type_hint(x):
return x return x
_LIST_TYPES = (list, typing.List) # noqa: UP006
_TUPLE_TYPES = (tuple, typing.Tuple) # noqa: UP006
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
def type_matches(signature_type: Any, argument_type: Any): def type_matches(signature_type: Any, argument_type: Any):
sig_origin_type = getattr(signature_type, "__origin__", signature_type) sig_origin_type = getattr(signature_type, "__origin__", signature_type)
@ -298,24 +294,23 @@ def type_matches(signature_type: Any, argument_type: Any):
sig_contained = signature_type.__args__ sig_contained = signature_type.__args__
return any(type_matches(c, argument_type) for c in sig_contained) return any(type_matches(c, argument_type) for c in sig_contained)
if signature_type is typing.List[int] and argument_type is int: # noqa: UP006 if getattr(signature_type, "__origin__", None) is list:
# int can be promoted to List[int]
return True
if getattr(signature_type, "__origin__", None) in _LIST_TYPES:
sig_el_type = signature_type.__args__[0] sig_el_type = signature_type.__args__[0]
if sig_el_type is argument_type:
# int can be promoted to list[int]
if argument_type is int and sig_el_type is int:
return True return True
if not inspect.isclass(sig_el_type): if not inspect.isclass(sig_el_type):
warnings.warn( warnings.warn(
f"Does not support nested parametric types, got {signature_type}. Please file a bug." f"Does not support nested parametric types, got {signature_type}. Please file a bug."
) )
return False return False
if getattr(argument_type, "__origin__", None) in _LIST_TYPES: if getattr(argument_type, "__origin__", None) is list:
return issubclass(argument_type.__args__[0], sig_el_type) return issubclass(argument_type.__args__[0], sig_el_type)
def is_homogeneous_tuple(t): def is_homogeneous_tuple(t):
if getattr(t, "__origin__", None) not in _TUPLE_TYPES: if getattr(t, "__origin__", None) is not tuple:
return False return False
contained = t.__args__ contained = t.__args__
if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason