mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE]: Update ruff to 0.4.5 (#126979)
Update ruff to 0.4.5 and addresses some false negatives that have been found in the newer version. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126979 Approved by: https://github.com/ezyang
This commit is contained in:
parent
4a09117d16
commit
3cb16ebf08
|
|
@ -2128,7 +2128,7 @@ init_command = [
|
|||
'python3',
|
||||
'tools/linter/adapters/pip_init.py',
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'ruff==0.4.4',
|
||||
'ruff==0.4.5',
|
||||
]
|
||||
is_formatter = true
|
||||
|
||||
|
|
|
|||
|
|
@ -151,6 +151,9 @@ select = [
|
|||
"test/torch_np/numpy_tests/**" = [
|
||||
"F821",
|
||||
]
|
||||
"test/dynamo/test_debug_utils.py" = [
|
||||
"UP037",
|
||||
]
|
||||
"test/jit/**" = [
|
||||
"PLR0133", # tests require this for JIT
|
||||
"PYI",
|
||||
|
|
@ -163,6 +166,9 @@ select = [
|
|||
"RUF015",
|
||||
"UP", # We don't want to modify the jit test as they test specify syntax
|
||||
]
|
||||
"test/inductor/test_torchinductor.py" = [
|
||||
"UP037",
|
||||
]
|
||||
# autogenerated #TODO figure out why file level noqa is ignored
|
||||
"torch/_inductor/fx_passes/serialized_patterns/**" = ["F401", "F501"]
|
||||
"torch/onnx/**" = [
|
||||
|
|
|
|||
|
|
@ -2162,8 +2162,8 @@ class numpy_operator_wrapper:
|
|||
def defake(x):
|
||||
if not isinstance(x, FakeTensor):
|
||||
return x
|
||||
size: "torch._prims_common.ShapeType"
|
||||
stride: "torch._prims_common.StrideType"
|
||||
size: torch._prims_common.ShapeType
|
||||
stride: torch._prims_common.StrideType
|
||||
if x._has_symbolic_sizes_strides:
|
||||
size = []
|
||||
for s in x.size():
|
||||
|
|
@ -2204,7 +2204,7 @@ def build_checkpoint_variable(**options):
|
|||
|
||||
# TODO - This is a temporary situation where we have two versions of
|
||||
# checkpointing implementation. We will converge on one and remove the other.
|
||||
activation_checkpoint_op: "torch._ops.HigherOrderOperator" = (
|
||||
activation_checkpoint_op: torch._ops.HigherOrderOperator = (
|
||||
higher_order_ops.tag_activation_checkpoint
|
||||
)
|
||||
if torch._functorch.config.functionalize_rng_ops:
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ def _check_input_constraints_for_graph(
|
|||
# NOTE: export already guarantees that the same symbol is used in metadata
|
||||
# for all InputDims related by equality constraints, so we can just unify
|
||||
# symbols with given input dimension values to check equality constraints.
|
||||
unification_map: "Dict[sympy.Symbol, Any]" = {}
|
||||
unification_map: Dict[sympy.Symbol, Any] = {}
|
||||
for (key_path, arg), node in zip(flat_args_with_path, input_placeholders):
|
||||
node_val = node.meta.get("val")
|
||||
if isinstance(node_val, FakeTensor):
|
||||
|
|
|
|||
|
|
@ -3412,7 +3412,7 @@ class CppKernelProxy(CppKernel):
|
|||
class OuterLoopFusedKernel(CppKernel):
|
||||
def __init__(self, kernel_group):
|
||||
super().__init__(kernel_group.args, kernel_group.ws.num_threads)
|
||||
self.inner: List["LoopLevel"] = []
|
||||
self.inner: List[LoopLevel] = []
|
||||
|
||||
def decide_parallel_depth(self, max_parallel_depth, threads) -> int:
|
||||
kernels_parallel_depth = []
|
||||
|
|
|
|||
|
|
@ -430,7 +430,7 @@ class WrapperCodeGen(CodeGen):
|
|||
# If the generated source code is exactly the same, reuse the
|
||||
# pre-existing kernel for it
|
||||
self.src_to_kernel: Dict[str, str] = {}
|
||||
self.kernel_numel_expr: Set[Tuple[str, "GraphLowering"]] = set()
|
||||
self.kernel_numel_expr: Set[Tuple[str, GraphLowering]] = set()
|
||||
self.lines: List[Union[MemoryPlanningLine, LineContext]] = []
|
||||
self.declare = ""
|
||||
self.declare_maybe_reference = ""
|
||||
|
|
|
|||
|
|
@ -45,8 +45,8 @@ def raise_comms(
|
|||
which is the beginning of the forwards pass. We'll have to either do a special pass for FSDP,
|
||||
or we'll want to redo this pass with memory considerations so we handle the FSDP case in a general way.
|
||||
"""
|
||||
new_order_reversed: List["scheduler.BaseSchedulerNode"] = []
|
||||
cur_comms: List["scheduler.BaseSchedulerNode"] = []
|
||||
new_order_reversed: List[scheduler.BaseSchedulerNode] = []
|
||||
cur_comms: List[scheduler.BaseSchedulerNode] = []
|
||||
for snode in reversed(snodes):
|
||||
if is_collective(snode.node):
|
||||
cur_comms.append(snode)
|
||||
|
|
|
|||
|
|
@ -394,7 +394,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
self.aot_mode = aot_mode
|
||||
self.graph_id = graph_id
|
||||
self.post_grad_graph_id = next(_post_grad_graph_counter)
|
||||
self.scheduler: "torch._inductor.scheduler.Scheduler" = None # type: ignore[assignment]
|
||||
self.scheduler: torch._inductor.scheduler.Scheduler = None # type: ignore[assignment]
|
||||
self.nodes_prefer_channels_last = (
|
||||
self.find_nodes_prefer_channels_last() if self.layout_opt else set()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2488,7 +2488,7 @@ class Scheduler:
|
|||
if len(possible_fusions) == 0:
|
||||
return possible_fusions
|
||||
possible_fusions_group_by_priority: Dict[
|
||||
int, List[Tuple["BaseSchedulerNode", "BaseSchedulerNode"]]
|
||||
int, List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]
|
||||
] = {}
|
||||
|
||||
for node1, node2 in possible_fusions:
|
||||
|
|
|
|||
|
|
@ -215,7 +215,7 @@ else:
|
|||
|
||||
# private field to pre-generate DeviceMesh's hash
|
||||
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
|
||||
self._parent_mesh: Optional["DeviceMesh"] = None
|
||||
self._parent_mesh: Optional[DeviceMesh] = None
|
||||
self._thread_id = threading.get_ident()
|
||||
|
||||
# Skip process group initialization if xla device or init backend is False
|
||||
|
|
|
|||
|
|
@ -224,7 +224,7 @@ class ExportedProgram:
|
|||
|
||||
self._graph_signature: ExportGraphSignature = graph_signature
|
||||
self._state_dict: Dict[str, Any] = state_dict
|
||||
self._range_constraints: "Dict[sympy.Symbol, ValueRanges]" = range_constraints
|
||||
self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
|
||||
assert module_call_graph is not None
|
||||
self._module_call_graph: List[ModuleCallEntry] = module_call_graph
|
||||
self._example_inputs = example_inputs
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
from collections import OrderedDict
|
||||
import logging
|
||||
|
||||
|
|
@ -8,8 +8,6 @@ from torch.fx._compatibility import compatibility
|
|||
from torch.fx.graph_module import GraphModule
|
||||
from torch.fx.node import Node
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import sympy # noqa: F401
|
||||
|
||||
__all__ = ["Partition", "split_module"]
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
|
@ -172,9 +170,11 @@ def split_module(
|
|||
base_mod_attrs[node.target] = attr_val # type: ignore[index]
|
||||
return base_mod_env, base_mod_attrs
|
||||
|
||||
import sympy
|
||||
|
||||
partitions: Dict[str, Partition] = {}
|
||||
orig_nodes: Dict[str, Node] = {}
|
||||
symbol_to_node: Dict["sympy.Symbol", Node] = {}
|
||||
symbol_to_node: Dict[sympy.Symbol, Node] = {}
|
||||
|
||||
def record_cross_partition_use(
|
||||
def_node: Node, use_node: Optional[Node]
|
||||
|
|
@ -237,8 +237,6 @@ def split_module(
|
|||
active_grad = None
|
||||
active_autocasts = set()
|
||||
|
||||
import sympy # noqa: F811
|
||||
|
||||
for node in m.graph.nodes:
|
||||
if node.op in ["placeholder", "get_attr", "output"]:
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class Library:
|
|||
self.ns = ns
|
||||
self._op_defs: Set[str] = set()
|
||||
self._op_impls: Set[str] = set()
|
||||
self._registration_handles: List["torch._library.utils.RegistrationHandle"] = []
|
||||
self._registration_handles: List[torch._library.utils.RegistrationHandle] = []
|
||||
self.kind = kind
|
||||
self.dispatch_key = dispatch_key
|
||||
# Use a finalizer to setup the "destructor" instead of __del__.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user