From 3cb16ebf080d317d06c390df04354632847e6261 Mon Sep 17 00:00:00 2001
From: Aaron Gokaslan
Date: Fri, 24 May 2024 18:38:33 +0000
Subject: [PATCH] [BE]: Update ruff to 0.4.5 (#126979)

Update ruff to 0.4.5 and address some false negatives that the newer
version now catches.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126979
Approved by: https://github.com/ezyang
---
 .lintrunner.toml                   |  2 +-
 pyproject.toml                     |  6 ++++++
 torch/_dynamo/utils.py             |  6 +++---
 torch/_export/utils.py             |  2 +-
 torch/_inductor/codegen/cpp.py     |  2 +-
 torch/_inductor/codegen/wrapper.py |  2 +-
 torch/_inductor/comms.py           |  4 ++--
 torch/_inductor/graph.py           |  2 +-
 torch/_inductor/scheduler.py       |  2 +-
 torch/distributed/device_mesh.py   |  2 +-
 torch/export/exported_program.py   |  2 +-
 torch/fx/passes/split_module.py    | 10 ++++------
 torch/library.py                   |  2 +-
 13 files changed, 24 insertions(+), 20 deletions(-)

diff --git a/.lintrunner.toml b/.lintrunner.toml
index 2f2f0db1776..6c102a3011d 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -2128,7 +2128,7 @@ init_command = [
     'python3',
     'tools/linter/adapters/pip_init.py',
     '--dry-run={{DRYRUN}}',
-    'ruff==0.4.4',
+    'ruff==0.4.5',
 ]
 is_formatter = true

diff --git a/pyproject.toml b/pyproject.toml
index 07f07508209..bd6c6a6ae1e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -151,6 +151,9 @@ select = [
 "test/torch_np/numpy_tests/**" = [
     "F821",
 ]
+"test/dynamo/test_debug_utils.py" = [
+    "UP037",
+]
 "test/jit/**" = [
     "PLR0133", # tests require this for JIT
     "PYI",
@@ -163,6 +166,9 @@ select = [
     "RUF015",
     "UP", # We don't want to modify the jit test as they test specify syntax
 ]
+"test/inductor/test_torchinductor.py" = [
+    "UP037",
+]
 # autogenerated #TODO figure out why file level noqa is ignored
 "torch/_inductor/fx_passes/serialized_patterns/**" = ["F401", "F501"]
 "torch/onnx/**" = [
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index fcfbde1a6a7..56340a3a1fe 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -2162,8 +2162,8 @@ class numpy_operator_wrapper:
 def defake(x):
     if not isinstance(x, FakeTensor):
         return x
-    size: "torch._prims_common.ShapeType"
-    stride: "torch._prims_common.StrideType"
+    size: torch._prims_common.ShapeType
+    stride: torch._prims_common.StrideType
     if x._has_symbolic_sizes_strides:
         size = []
         for s in x.size():
@@ -2204,7 +2204,7 @@ def build_checkpoint_variable(**options):

     # TODO - This is a temporary situation where we have two versions of
     # checkpointing implementation. We will converge on one and remove the other.
-    activation_checkpoint_op: "torch._ops.HigherOrderOperator" = (
+    activation_checkpoint_op: torch._ops.HigherOrderOperator = (
         higher_order_ops.tag_activation_checkpoint
     )
     if torch._functorch.config.functionalize_rng_ops:
diff --git a/torch/_export/utils.py b/torch/_export/utils.py
index 19fc4e9bdc4..72616f6ddfb 100644
--- a/torch/_export/utils.py
+++ b/torch/_export/utils.py
@@ -74,7 +74,7 @@ def _check_input_constraints_for_graph(
     # NOTE: export already guarantees that the same symbol is used in metadata
     # for all InputDims related by equality constraints, so we can just unify
     # symbols with given input dimension values to check equality constraints.
-    unification_map: "Dict[sympy.Symbol, Any]" = {}
+    unification_map: Dict[sympy.Symbol, Any] = {}
     for (key_path, arg), node in zip(flat_args_with_path, input_placeholders):
         node_val = node.meta.get("val")
         if isinstance(node_val, FakeTensor):
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 52f92bd0bec..077363e40bf 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -3412,7 +3412,7 @@ class CppKernelProxy(CppKernel):
 class OuterLoopFusedKernel(CppKernel):
     def __init__(self, kernel_group):
         super().__init__(kernel_group.args, kernel_group.ws.num_threads)
-        self.inner: List["LoopLevel"] = []
+        self.inner: List[LoopLevel] = []

     def decide_parallel_depth(self, max_parallel_depth, threads) -> int:
         kernels_parallel_depth = []
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 0bab978630d..6bc584a184d 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -430,7 +430,7 @@ class WrapperCodeGen(CodeGen):
         # If the generated source code is exactly the same, reuse the
         # pre-existing kernel for it
         self.src_to_kernel: Dict[str, str] = {}
-        self.kernel_numel_expr: Set[Tuple[str, "GraphLowering"]] = set()
+        self.kernel_numel_expr: Set[Tuple[str, GraphLowering]] = set()
         self.lines: List[Union[MemoryPlanningLine, LineContext]] = []
         self.declare = ""
         self.declare_maybe_reference = ""
diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py
index 4247f9a2cff..a1fe0e1cdce 100644
--- a/torch/_inductor/comms.py
+++ b/torch/_inductor/comms.py
@@ -45,8 +45,8 @@ def raise_comms(
     which is the beginning of the forwards pass. We'll have to either do a
     special pass for FSDP, or we'll want to redo this pass with memory
     considerations so we handle the FSDP case in a general way.
""" - new_order_reversed: List["scheduler.BaseSchedulerNode"] = [] - cur_comms: List["scheduler.BaseSchedulerNode"] = [] + new_order_reversed: List[scheduler.BaseSchedulerNode] = [] + cur_comms: List[scheduler.BaseSchedulerNode] = [] for snode in reversed(snodes): if is_collective(snode.node): cur_comms.append(snode) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index eb397920d65..ad4d09b7e6d 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -394,7 +394,7 @@ class GraphLowering(torch.fx.Interpreter): self.aot_mode = aot_mode self.graph_id = graph_id self.post_grad_graph_id = next(_post_grad_graph_counter) - self.scheduler: "torch._inductor.scheduler.Scheduler" = None # type: ignore[assignment] + self.scheduler: torch._inductor.scheduler.Scheduler = None # type: ignore[assignment] self.nodes_prefer_channels_last = ( self.find_nodes_prefer_channels_last() if self.layout_opt else set() ) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 59a696ecd21..650ad2e4e98 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2488,7 +2488,7 @@ class Scheduler: if len(possible_fusions) == 0: return possible_fusions possible_fusions_group_by_priority: Dict[ - int, List[Tuple["BaseSchedulerNode", "BaseSchedulerNode"]] + int, List[Tuple[BaseSchedulerNode, BaseSchedulerNode]] ] = {} for node1, node2 in possible_fusions: diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index c0981a549c6..57b8fa1cf56 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -215,7 +215,7 @@ else: # private field to pre-generate DeviceMesh's hash self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) - self._parent_mesh: Optional["DeviceMesh"] = None + self._parent_mesh: Optional[DeviceMesh] = None self._thread_id = threading.get_ident() # Skip process group initialization if xla device or init backend is False diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index ffb3467055b..cc6a9e65dd3 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -224,7 +224,7 @@ class ExportedProgram: self._graph_signature: ExportGraphSignature = graph_signature self._state_dict: Dict[str, Any] = state_dict - self._range_constraints: "Dict[sympy.Symbol, ValueRanges]" = range_constraints + self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints assert module_call_graph is not None self._module_call_graph: List[ModuleCallEntry] = module_call_graph self._example_inputs = example_inputs diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index daf86ec52dc..977741cfe62 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -1,5 +1,5 @@ import inspect -from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, Set from collections import OrderedDict import logging @@ -8,8 +8,6 @@ from torch.fx._compatibility import compatibility from torch.fx.graph_module import GraphModule from torch.fx.node import Node -if TYPE_CHECKING: - import sympy # noqa: F401 __all__ = ["Partition", "split_module"] _LOGGER = logging.getLogger(__name__) @@ -172,9 +170,11 @@ def split_module( base_mod_attrs[node.target] = attr_val # type: ignore[index] return base_mod_env, base_mod_attrs + import sympy + partitions: Dict[str, Partition] = {} orig_nodes: Dict[str, Node] = {} - symbol_to_node: 
Dict["sympy.Symbol", Node] = {} + symbol_to_node: Dict[sympy.Symbol, Node] = {} def record_cross_partition_use( def_node: Node, use_node: Optional[Node] @@ -237,8 +237,6 @@ def split_module( active_grad = None active_autocasts = set() - import sympy # noqa: F811 - for node in m.graph.nodes: if node.op in ["placeholder", "get_attr", "output"]: if ( diff --git a/torch/library.py b/torch/library.py index 8ebdcb969d2..4d99f4f8284 100644 --- a/torch/library.py +++ b/torch/library.py @@ -71,7 +71,7 @@ class Library: self.ns = ns self._op_defs: Set[str] = set() self._op_impls: Set[str] = set() - self._registration_handles: List["torch._library.utils.RegistrationHandle"] = [] + self._registration_handles: List[torch._library.utils.RegistrationHandle] = [] self.kind = kind self.dispatch_key = dispatch_key # Use a finalizer to setup the "destructor" instead of __del__.