[1/N] Apply UP035 rule in tests (#163947)

Apply UP035 `ruff` rule in tests, but some tests for `fx` and `dynamo` are excluded in case the old typing is the test target.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163947
Approved by: https://github.com/ezyang
This commit is contained in:
Yuanyuan Chen 2025-09-29 01:42:01 +00:00 committed by PyTorch MergeBot
parent dc54ce7554
commit a8c528c105
41 changed files with 79 additions and 49 deletions

View File

@ -6,7 +6,8 @@ import itertools
import os import os
import tempfile import tempfile
import unittest import unittest
from typing import Callable, Optional, Union from collections.abc import Callable
from typing import Optional, Union
from unittest.mock import MagicMock from unittest.mock import MagicMock
import torch import torch

View File

@ -3,7 +3,7 @@
import copy import copy
import functools import functools
import unittest import unittest
from typing import Callable from collections.abc import Callable
import torch import torch
import torch.distributed as dist import torch.distributed as dist

View File

@ -3,8 +3,9 @@
import contextlib import contextlib
import functools import functools
import unittest import unittest
from collections.abc import Callable
from copy import deepcopy from copy import deepcopy
from typing import Callable, Optional, Union from typing import Optional, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist

View File

@ -1,7 +1,8 @@
# Owner(s): ["oncall: distributed"] # Owner(s): ["oncall: distributed"]
import unittest import unittest
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, cast, Union from typing import Any, cast, Union
import torch import torch
from torch import nn, optim from torch import nn, optim

View File

@ -2,8 +2,9 @@
import shutil import shutil
import tempfile import tempfile
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Any, Callable, Optional from typing import Any, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist

View File

@ -3,8 +3,9 @@
import copy import copy
import functools import functools
import sys import sys
from collections.abc import Callable
from itertools import chain from itertools import chain
from typing import Callable, Union from typing import Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist

View File

@ -16,8 +16,9 @@ import tempfile
import time import time
import unittest import unittest
import uuid import uuid
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Optional from typing import Optional
from unittest import mock from unittest import mock
from unittest.mock import Mock, patch from unittest.mock import Mock, patch

View File

@ -15,8 +15,9 @@ import signal
import sys import sys
import tempfile import tempfile
import time import time
from collections.abc import Callable
from itertools import product from itertools import product
from typing import Callable, Union from typing import Union
from unittest import mock from unittest import mock
import torch import torch

View File

@ -9,8 +9,9 @@
import os import os
import tempfile import tempfile
from base64 import b64encode from base64 import b64encode
from collections.abc import Callable
from datetime import timedelta from datetime import timedelta
from typing import Callable, cast, ClassVar from typing import cast, ClassVar
from unittest import mock, TestCase from unittest import mock, TestCase
from rendezvous_backend_test import RendezvousBackendTestMixin from rendezvous_backend_test import RendezvousBackendTestMixin

View File

@ -14,8 +14,9 @@ import threading
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from base64 import b64encode from base64 import b64encode
from collections.abc import Callable
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Callable, cast, Optional from typing import cast, Optional
from unittest import TestCase from unittest import TestCase
from unittest.mock import call, MagicMock, Mock, patch, PropertyMock from unittest.mock import call, MagicMock, Mock, patch, PropertyMock

View File

@ -7,7 +7,8 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, cast, Optional from collections.abc import Callable
from typing import Any, cast, Optional
from torch.distributed.elastic.rendezvous import RendezvousStateError from torch.distributed.elastic.rendezvous import RendezvousStateError
from torch.distributed.elastic.rendezvous.dynamic_rendezvous import ( from torch.distributed.elastic.rendezvous.dynamic_rendezvous import (

View File

@ -4,7 +4,8 @@ import functools
import itertools import itertools
import sys import sys
import unittest import unittest
from typing import Any, Callable, Optional from collections.abc import Callable
from typing import Any, Optional
from unittest import mock from unittest import mock
import torch import torch

View File

@ -2,9 +2,10 @@
import bisect import bisect
import sys import sys
from collections.abc import Callable
from copy import deepcopy from copy import deepcopy
from enum import auto, Enum from enum import auto, Enum
from typing import Any, Callable, Optional from typing import Any, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn

View File

@ -5,8 +5,9 @@ import itertools
import os import os
import tempfile import tempfile
import unittest import unittest
from collections.abc import Callable
from enum import auto, Enum from enum import auto, Enum
from typing import Callable, Union from typing import Union
import torch import torch
import torch.nn as nn import torch.nn as nn

View File

@ -2,7 +2,8 @@
# Owner(s): ["oncall: distributed"] # Owner(s): ["oncall: distributed"]
# This file is a Schedule zoo for testing torch.distributed.pipelining. # This file is a Schedule zoo for testing torch.distributed.pipelining.
# It includes schedules designed purely for testing purposes # It includes schedules designed purely for testing purposes
from typing import Callable, Optional from collections.abc import Callable
from typing import Optional
from torch.distributed.pipelining.schedules import ( from torch.distributed.pipelining.schedules import (
_Action, _Action,

View File

@ -1,8 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates # Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"] # Owner(s): ["oncall: distributed"]
from collections.abc import Sequence from collections.abc import Callable, Sequence
from typing import Any, Callable, Optional from typing import Any, Optional
from unittest import skip from unittest import skip
import torch import torch

View File

@ -3,8 +3,9 @@
import os import os
import unittest import unittest
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Any, Callable from typing import Any
import numpy as np import numpy as np

View File

@ -2129,7 +2129,8 @@ class GraphModule(torch.nn.Module):
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
@parametrize("dynamic", [True, False]) @parametrize("dynamic", [True, False])
def test_mark_static_with_subclass_desugaring(self, dynamic): def test_mark_static_with_subclass_desugaring(self, dynamic):
from typing import Any, Callable, Optional from collections.abc import Callable
from typing import Any, Optional
from torch._dynamo.decorators import mark_static_address from torch._dynamo.decorators import mark_static_address
from torch._inductor.compile_fx import compile_fx from torch._inductor.compile_fx import compile_fx

View File

@ -10,9 +10,10 @@ import copy
import itertools import itertools
import unittest import unittest
import warnings import warnings
from collections.abc import Callable
from contextlib import ContextDecorator, ExitStack, nullcontext from contextlib import ContextDecorator, ExitStack, nullcontext
from functools import partial, wraps from functools import partial, wraps
from typing import Any, Callable, Optional, Union from typing import Any, Optional, Union
from unittest.mock import patch from unittest.mock import patch
from common_utils import ( from common_utils import (

View File

@ -3,7 +3,7 @@
import inspect import inspect
import random import random
import unittest import unittest
from typing import Callable from collections.abc import Callable
import torch import torch
import torch.fx as fx import torch.fx as fx

View File

@ -9,8 +9,8 @@ import sys
import tempfile import tempfile
import unittest import unittest
import zipfile import zipfile
from collections.abc import Callable
from pathlib import Path from pathlib import Path
from typing import Callable
from parameterized import parameterized_class from parameterized import parameterized_class

View File

@ -8,7 +8,7 @@ from os import environ
from pathlib import Path from pathlib import Path
from random import randint from random import randint
from tempfile import gettempdir from tempfile import gettempdir
from typing import Any, Callable, Sequence from typing import Any, TYPE_CHECKING
from typing_extensions import Self from typing_extensions import Self
from unittest.mock import patch from unittest.mock import patch
@ -20,6 +20,10 @@ from torch.testing._internal.common_utils import (
) )
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
class TestMixin: class TestMixin:
@staticmethod @staticmethod
def abstract_cache_types() -> set[type[icache.Cache]]: def abstract_cache_types() -> set[type[icache.Cache]]:

View File

@ -8,7 +8,7 @@ import os
import platform import platform
import sys import sys
import unittest import unittest
from typing import Callable from collections.abc import Callable
from unittest.mock import patch from unittest.mock import patch
import torch import torch

View File

@ -7,10 +7,11 @@ import string
import unittest import unittest
import warnings import warnings
from collections import namedtuple from collections import namedtuple
from collections.abc import Callable
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from itertools import product from itertools import product
from typing import Callable, Optional, TypeVar, Union from typing import Optional, TypeVar, Union
from unittest import expectedFailure, skip, skipUnless from unittest import expectedFailure, skip, skipUnless
from unittest.mock import patch from unittest.mock import patch

View File

@ -5,7 +5,8 @@ import functools
import sys import sys
import unittest import unittest
from collections import namedtuple from collections import namedtuple
from typing import Callable, Optional, Union from collections.abc import Callable
from typing import Optional, Union
from unittest import expectedFailure from unittest import expectedFailure
from unittest.mock import patch from unittest.mock import patch

View File

@ -6,7 +6,8 @@ Test the FX IR backend.
import itertools import itertools
import operator import operator
import unittest import unittest
from typing import Callable, Optional from collections.abc import Callable
from typing import Optional
import sympy import sympy

View File

@ -10,7 +10,8 @@ import random
import re import re
import tempfile import tempfile
import unittest import unittest
from typing import Callable, Optional from collections.abc import Callable
from typing import Optional
from unittest import mock from unittest import mock
import torch import torch

View File

@ -3,7 +3,8 @@ import copy
import itertools import itertools
import os import os
import unittest import unittest
from typing import Callable, Optional from collections.abc import Callable
from typing import Optional
import torch import torch
import torch._dynamo.config as dynamo_config import torch._dynamo.config as dynamo_config
@ -180,8 +181,7 @@ class TestPatternMatcher(TestCase):
self._test_fused_int_mm_mul_impl(fn2, args, True) self._test_fused_int_mm_mul_impl(fn2, args, True)
def test_duplicate_search(self): def test_duplicate_search(self):
from collections.abc import Iterable from collections.abc import Callable, Iterable
from typing import Callable
import torch import torch
from torch._inductor.pattern_matcher import ( from torch._inductor.pattern_matcher import (

View File

@ -3,7 +3,8 @@ import json
import os import os
import tempfile import tempfile
import unittest import unittest
from typing import Callable, Optional from collections.abc import Callable
from typing import Optional
import torch import torch
import torch._inductor.test_case import torch._inductor.test_case

View File

@ -2,7 +2,7 @@
import contextlib import contextlib
import functools import functools
import unittest.mock import unittest.mock
from typing import Callable from collections.abc import Callable
from unittest.mock import patch from unittest.mock import patch
import torch import torch

View File

@ -19,8 +19,9 @@ import time
import unittest import unittest
import unittest.mock import unittest.mock
import weakref import weakref
from collections.abc import Callable
from pathlib import Path from pathlib import Path
from typing import Callable, TypeVar from typing import TypeVar
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from unittest.mock import patch from unittest.mock import patch

View File

@ -1,6 +1,7 @@
# Owner(s): ["module: inductor"] # Owner(s): ["module: inductor"]
import importlib import importlib
from typing import Any, Callable, Optional from collections.abc import Callable
from typing import Any, Optional
from unittest import skipIf from unittest import skipIf
import torch import torch

View File

@ -5,7 +5,8 @@ import dataclasses
import importlib import importlib
import math import math
import unittest import unittest
from typing import Any, Callable, Optional, Union from collections.abc import Callable
from typing import Any, Optional, Union
import torch import torch
import torch.utils._pytree as pytree import torch.utils._pytree as pytree

View File

@ -7,8 +7,8 @@ import dataclasses
import io import io
import os import os
import unittest import unittest
from collections.abc import Collection, Iterable, Mapping, Sequence from collections.abc import Callable, Collection, Iterable, Mapping, Sequence
from typing import Any, Callable, Optional, Union from typing import Any, Optional, Union
import numpy as np import numpy as np
import onnxruntime import onnxruntime

View File

@ -12,8 +12,8 @@ import pprint
import sys import sys
import unittest import unittest
import warnings import warnings
from collections.abc import Collection, Iterable, Mapping, Sequence from collections.abc import Callable, Collection, Iterable, Mapping, Sequence
from typing import Any, Callable, Optional, TypeVar from typing import Any, Optional, TypeVar
import error_reproduction import error_reproduction
import numpy as np import numpy as np

View File

@ -39,7 +39,7 @@ from __future__ import annotations
import copy import copy
import dataclasses import dataclasses
import functools import functools
from typing import Any, Callable, Optional, TYPE_CHECKING from typing import Any, Optional, TYPE_CHECKING
from typing_extensions import Self from typing_extensions import Self
import numpy as np import numpy as np
@ -52,7 +52,7 @@ from torch.testing._internal.opinfo import definitions as opinfo_definitions
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Collection from collections.abc import Callable, Collection
# Create a copy of the op_db to modify # Create a copy of the op_db to modify

View File

@ -25,7 +25,7 @@ errors.
from __future__ import annotations from __future__ import annotations
import os import os
from typing import Callable, Optional, TYPE_CHECKING from typing import Optional, TYPE_CHECKING
import error_reproduction import error_reproduction
import numpy as np import numpy as np
@ -44,7 +44,7 @@ from torch.utils import _pytree as pytree
if TYPE_CHECKING: if TYPE_CHECKING:
import unittest import unittest
from collections.abc import Sequence from collections.abc import Callable, Sequence
from torch.testing._internal.opinfo import core as opinfo_core from torch.testing._internal.opinfo import core as opinfo_core

View File

@ -4,8 +4,8 @@ import gc
import itertools as it import itertools as it
import textwrap import textwrap
import unittest import unittest
from collections.abc import Iterator from collections.abc import Callable, Iterator
from typing import Callable, Optional from typing import Optional
import torch import torch
from torch._C._profiler import _EventType, _TensorMetadata from torch._C._profiler import _EventType, _TensorMetadata

View File

@ -7,7 +7,7 @@ import logging
import os import os
import pkgutil import pkgutil
import unittest import unittest
from typing import Callable from collections.abc import Callable
import torch import torch
from torch._utils_internal import get_file_path_2 # @manual from torch._utils_internal import get_file_path_2 # @manual

View File

@ -5,7 +5,7 @@ import itertools
import math import math
import pickle import pickle
import sys import sys
from typing import Callable from collections.abc import Callable
import sympy import sympy

View File

@ -12,7 +12,8 @@ import re
import subprocess import subprocess
import sys import sys
import unittest.mock import unittest.mock
from typing import Any, Callable from typing import Any
from collections.abc import Callable
from collections.abc import Iterator from collections.abc import Iterator
import torch import torch