[BE]: Update ruff to 0.4.5 (#126979)

Updates ruff to 0.4.5 and addresses some previously missed violations (false negatives) that the newer version now reports.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126979
Approved by: https://github.com/ezyang
Authored by Aaron Gokaslan on 2024-05-24 18:38:33 +00:00; committed by PyTorch MergeBot
parent 4a09117d16
commit 3cb16ebf08
13 changed files with 24 additions and 20 deletions
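Nearly all of the code changes below follow one pattern: ruff 0.4.5 flags quotes around type annotations that do not need to be quoted (rule UP037), which 0.4.4 apparently missed (the false negatives mentioned above), so the quotes are dropped; two test files get a per-file ignore for the rule instead. Below is a minimal, self-contained sketch of the pattern, modeled on the OuterLoopFusedKernel change in this diff; the class bodies and the trailing usage line are illustrative stand-ins, not the real torch._inductor code.

from typing import List


class OuterLoopFusedKernel:
    def __init__(self) -> None:
        # Previously spelled:  self.inner: List["LoopLevel"] = []
        # Ruff 0.4.5 reports the quoted form as UP037: an annotation on an
        # assignment inside a method body is never evaluated at runtime, so
        # the forward reference resolves for static checkers without quotes.
        self.inner: List[LoopLevel] = []


# Instantiating before LoopLevel is even defined works, which is why
# dropping the quotes is safe despite the forward reference.
kernel = OuterLoopFusedKernel()


class LoopLevel:
    # Simplified stand-in, defined after its first (annotation-only) use
    # above to show that the quotes were not needed for it.
    pass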

View File

@@ -2128,7 +2128,7 @@ init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
- 'ruff==0.4.4',
+ 'ruff==0.4.5',
]
is_formatter = true

View File

@@ -151,6 +151,9 @@ select = [
"test/torch_np/numpy_tests/**" = [
"F821",
]
"test/dynamo/test_debug_utils.py" = [
"UP037",
]
"test/jit/**" = [
"PLR0133", # tests require this for JIT
"PYI",
@@ -163,6 +166,9 @@ select = [
"RUF015",
"UP", # We don't want to modify the jit test as they test specify syntax
]
"test/inductor/test_torchinductor.py" = [
"UP037",
]
# autogenerated #TODO figure out why file level noqa is ignored
"torch/_inductor/fx_passes/serialized_patterns/**" = ["F401", "F501"]
"torch/onnx/**" = [

View File

@@ -2162,8 +2162,8 @@ class numpy_operator_wrapper:
def defake(x):
if not isinstance(x, FakeTensor):
return x
size: "torch._prims_common.ShapeType"
stride: "torch._prims_common.StrideType"
size: torch._prims_common.ShapeType
stride: torch._prims_common.StrideType
if x._has_symbolic_sizes_strides:
size = []
for s in x.size():
@@ -2204,7 +2204,7 @@ def build_checkpoint_variable(**options):
# TODO - This is a temporary situation where we have two versions of
# checkpointing implementation. We will converge on one and remove the other.
activation_checkpoint_op: "torch._ops.HigherOrderOperator" = (
activation_checkpoint_op: torch._ops.HigherOrderOperator = (
higher_order_ops.tag_activation_checkpoint
)
if torch._functorch.config.functionalize_rng_ops:

View File

@@ -74,7 +74,7 @@ def _check_input_constraints_for_graph(
# NOTE: export already guarantees that the same symbol is used in metadata
# for all InputDims related by equality constraints, so we can just unify
# symbols with given input dimension values to check equality constraints.
unification_map: "Dict[sympy.Symbol, Any]" = {}
unification_map: Dict[sympy.Symbol, Any] = {}
for (key_path, arg), node in zip(flat_args_with_path, input_placeholders):
node_val = node.meta.get("val")
if isinstance(node_val, FakeTensor):

View File

@@ -3412,7 +3412,7 @@ class CppKernelProxy(CppKernel):
class OuterLoopFusedKernel(CppKernel):
def __init__(self, kernel_group):
super().__init__(kernel_group.args, kernel_group.ws.num_threads)
self.inner: List["LoopLevel"] = []
self.inner: List[LoopLevel] = []
def decide_parallel_depth(self, max_parallel_depth, threads) -> int:
kernels_parallel_depth = []

View File

@@ -430,7 +430,7 @@ class WrapperCodeGen(CodeGen):
# If the generated source code is exactly the same, reuse the
# pre-existing kernel for it
self.src_to_kernel: Dict[str, str] = {}
- self.kernel_numel_expr: Set[Tuple[str, "GraphLowering"]] = set()
+ self.kernel_numel_expr: Set[Tuple[str, GraphLowering]] = set()
self.lines: List[Union[MemoryPlanningLine, LineContext]] = []
self.declare = ""
self.declare_maybe_reference = ""

View File

@@ -45,8 +45,8 @@ def raise_comms(
which is the beginning of the forwards pass. We'll have to either do a special pass for FSDP,
or we'll want to redo this pass with memory considerations so we handle the FSDP case in a general way.
"""
new_order_reversed: List["scheduler.BaseSchedulerNode"] = []
cur_comms: List["scheduler.BaseSchedulerNode"] = []
new_order_reversed: List[scheduler.BaseSchedulerNode] = []
cur_comms: List[scheduler.BaseSchedulerNode] = []
for snode in reversed(snodes):
if is_collective(snode.node):
cur_comms.append(snode)

View File

@@ -394,7 +394,7 @@ class GraphLowering(torch.fx.Interpreter):
self.aot_mode = aot_mode
self.graph_id = graph_id
self.post_grad_graph_id = next(_post_grad_graph_counter)
self.scheduler: "torch._inductor.scheduler.Scheduler" = None # type: ignore[assignment]
self.scheduler: torch._inductor.scheduler.Scheduler = None # type: ignore[assignment]
self.nodes_prefer_channels_last = (
self.find_nodes_prefer_channels_last() if self.layout_opt else set()
)

View File

@@ -2488,7 +2488,7 @@ class Scheduler:
if len(possible_fusions) == 0:
return possible_fusions
possible_fusions_group_by_priority: Dict[
int, List[Tuple["BaseSchedulerNode", "BaseSchedulerNode"]]
int, List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]
] = {}
for node1, node2 in possible_fusions:

View File

@@ -215,7 +215,7 @@ else:
# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
self._parent_mesh: Optional["DeviceMesh"] = None
self._parent_mesh: Optional[DeviceMesh] = None
self._thread_id = threading.get_ident()
# Skip process group initialization if xla device or init backend is False

View File

@@ -224,7 +224,7 @@ class ExportedProgram:
self._graph_signature: ExportGraphSignature = graph_signature
self._state_dict: Dict[str, Any] = state_dict
self._range_constraints: "Dict[sympy.Symbol, ValueRanges]" = range_constraints
self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
assert module_call_graph is not None
self._module_call_graph: List[ModuleCallEntry] = module_call_graph
self._example_inputs = example_inputs

View File

@@ -1,5 +1,5 @@
import inspect
- from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING
+ from typing import Any, Callable, Dict, List, Optional, Set
from collections import OrderedDict
import logging
@@ -8,8 +8,6 @@ from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
- if TYPE_CHECKING:
- import sympy # noqa: F401
__all__ = ["Partition", "split_module"]
_LOGGER = logging.getLogger(__name__)
@@ -172,9 +170,11 @@ def split_module(
base_mod_attrs[node.target] = attr_val # type: ignore[index]
return base_mod_env, base_mod_attrs
+ import sympy
partitions: Dict[str, Partition] = {}
orig_nodes: Dict[str, Node] = {}
symbol_to_node: Dict["sympy.Symbol", Node] = {}
symbol_to_node: Dict[sympy.Symbol, Node] = {}
def record_cross_partition_use(
def_node: Node, use_node: Optional[Node]
@@ -237,8 +237,6 @@ def split_module(
active_grad = None
active_autocasts = set()
- import sympy # noqa: F811
for node in m.graph.nodes:
if node.op in ["placeholder", "get_attr", "output"]:
if (

View File

@@ -71,7 +71,7 @@ class Library:
self.ns = ns
self._op_defs: Set[str] = set()
self._op_impls: Set[str] = set()
- self._registration_handles: List["torch._library.utils.RegistrationHandle"] = []
+ self._registration_handles: List[torch._library.utils.RegistrationHandle] = []
self.kind = kind
self.dispatch_key = dispatch_key
# Use a finalizer to setup the "destructor" instead of __del__.