Fix incorrect type comparison (#145449)

Summary: This change was incorrectly made as part of #145166

Differential Revision: D68536221

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145449
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein 2025-01-26 04:40:26 +00:00 committed by PyTorch MergeBot
parent 09ae69a364
commit c4523999a1
2 changed files with 21 additions and 25 deletions

View File

@ -9,6 +9,7 @@ import pickle
import sys
import sympy
import tempfile
import typing
import unittest
from types import BuiltinFunctionType
from typing import Callable, NamedTuple, Optional, Union, List
@ -1547,25 +1548,25 @@ class {test_classname}(torch.nn.Module):
(Optional[list[int]], list[int]),
] + [
# pre-PEP585 signatures
(list[int], int),
(list[int], create_type_hint([int, int])),
(list[int], create_type_hint((int, int))),
(list[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
(typing.List[int], int),
(typing.List[int], create_type_hint([int, int])),
(typing.List[int], create_type_hint((int, int))),
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
(
list[torch.Tensor],
typing.List[torch.Tensor],
create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
),
(list[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
(list[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
(list[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
(typing.List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
(
list[torch.Tensor],
typing.List[torch.Tensor],
create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
),
(list[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
(list[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
(Optional[list[torch.Tensor]], list[torch.Tensor]),
(Optional[list[int]], list[int]),
(typing.List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
(Optional[typing.List[torch.Tensor]], typing.List[torch.Tensor]),
(Optional[typing.List[int]], typing.List[int]),
]
for sig_type, arg_type in should_be_equal:

View File

@ -281,10 +281,6 @@ def create_type_hint(x):
return x
_LIST_TYPES = (list, typing.List) # noqa: UP006
_TUPLE_TYPES = (tuple, typing.Tuple) # noqa: UP006
@compatibility(is_backward_compatible=False)
def type_matches(signature_type: Any, argument_type: Any):
sig_origin_type = getattr(signature_type, "__origin__", signature_type)
@ -298,24 +294,23 @@ def type_matches(signature_type: Any, argument_type: Any):
sig_contained = signature_type.__args__
return any(type_matches(c, argument_type) for c in sig_contained)
if signature_type is typing.List[int] and argument_type is int: # noqa: UP006
# int can be promoted to List[int]
return True
if getattr(signature_type, "__origin__", None) in _LIST_TYPES:
if getattr(signature_type, "__origin__", None) is list:
sig_el_type = signature_type.__args__[0]
if sig_el_type is argument_type:
# int can be promoted to list[int]
if argument_type is int and sig_el_type is int:
return True
if not inspect.isclass(sig_el_type):
warnings.warn(
f"Does not support nested parametric types, got {signature_type}. Please file a bug."
)
return False
if getattr(argument_type, "__origin__", None) in _LIST_TYPES:
if getattr(argument_type, "__origin__", None) is list:
return issubclass(argument_type.__args__[0], sig_el_type)
def is_homogeneous_tuple(t):
if getattr(t, "__origin__", None) not in _TUPLE_TYPES:
if getattr(t, "__origin__", None) is not tuple:
return False
contained = t.__args__
if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason