From 2f3b0befedaf66fb2b54bf3fbd815fbd274932eb Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Fri, 26 Apr 2024 14:34:52 +0000 Subject: [PATCH] [BE]: Apply ruff FURB 118. (#124743) Replaces various lambdas with operator.itemgetter, which is more efficient because it is a C-implemented builtin and avoids per-call Python frame overhead. This is particularly useful when the lambda serves as a 'key' function for sorting or selection, as in most of the call sites below. (A short standalone demonstration is appended after the diff.) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124743 Approved by: https://github.com/albanD, https://github.com/malfet --- .github/scripts/get_workflow_job_id.py | 3 +- torch/_export/serde/serialize.py | 10 +++---- torch/_functorch/benchmark_utils.py | 3 +- torch/_functorch/partitioners.py | 6 ++-- .../_functorch/top_operators_github_usage.py | 4 ++- torch/_inductor/pattern_matcher.py | 2 +- torch/_inductor/scheduler.py | 2 +- torch/_refs/__init__.py | 2 +- torch/_refs/linalg/__init__.py | 4 ++- torch/ao/quantization/fx/_equalize.py | 3 +- torch/cuda/_memory_viz.py | 3 +- torch/distributed/_shard/sharding_spec/api.py | 3 +- torch/distributed/_tools/memory_tracker.py | 3 +- torch/distributed/checkpoint/filesystem.py | 5 ++-- torch/distributed/checkpoint/planner.py | 3 +- torch/export/unflatten.py | 2 +- .../experimental/accelerator_partitioner.py | 2 +- .../unification/multipledispatch/conflict.py | 3 +- torch/profiler/_utils.py | 3 +- torch/testing/_internal/common_modules.py | 29 ++++++++++--------- .../_internal/distributed/rpc/rpc_test.py | 5 ++-- .../utils/benchmark/examples/op_benchmark.py | 3 +- .../benchmark/examples/sparse/op_benchmark.py | 3 +- torch/utils/benchmark/utils/compare.py | 3 +- .../operator_versions/gen_mobile_upgraders.py | 3 +- 25 files changed, 66 insertions(+), 46 deletions(-) diff --git a/.github/scripts/get_workflow_job_id.py b/.github/scripts/get_workflow_job_id.py index 75bc7e01617..28f337a3071 100644 --- a/.github/scripts/get_workflow_job_id.py +++ b/.github/scripts/get_workflow_job_id.py @@ -4,6 +4,7 @@ import argparse import json +import operator import os import re import sys @@ -126,7 +127,7 @@ def find_job_id_name(args: Any) -> Tuple[str, str]: # Sort the jobs list by start time, in descending order. We want to get the most # recently scheduled job on the runner. - jobs.sort(key=lambda job: job["started_at"], reverse=True) + jobs.sort(key=operator.itemgetter("started_at"), reverse=True) for job in jobs: if job["runner_name"] == args.runner_name: diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 31483e68f0b..77f0cf3d3c9 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -2630,12 +2630,12 @@ def _canonicalize_graph( n.metadata.clear() # Stage 4: Aggregate values. - sorted_tensor_values = dict(sorted(graph.tensor_values.items(), key=lambda x: x[0])) + sorted_tensor_values = dict(sorted(graph.tensor_values.items(), key=operator.itemgetter(0))) sorted_sym_int_values = dict( - sorted(graph.sym_int_values.items(), key=lambda x: x[0]) + sorted(graph.sym_int_values.items(), key=operator.itemgetter(0)) ) sorted_sym_bool_values = dict( - sorted(graph.sym_bool_values.items(), key=lambda x: x[0]) + sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0)) ) # Stage 5: Recurse in subgraphs.
@@ -2683,8 +2683,8 @@ def canonicalize(ep: ExportedProgram) -> ExportedProgram: """ ep = copy.deepcopy(ep) - opset_version = dict(sorted(ep.opset_version.items(), key=lambda x: x[0])) - range_constraints = dict(sorted(ep.range_constraints.items(), key=lambda x: x[0])) + opset_version = dict(sorted(ep.opset_version.items(), key=operator.itemgetter(0))) + range_constraints = dict(sorted(ep.range_constraints.items(), key=operator.itemgetter(0))) module_call_graph = sorted(ep.graph_module.module_call_graph, key=lambda x: x.fqn) signature = ep.graph_module.signature graph = ep.graph_module.graph diff --git a/torch/_functorch/benchmark_utils.py b/torch/_functorch/benchmark_utils.py index af606f20f94..e0bcae4c836 100644 --- a/torch/_functorch/benchmark_utils.py +++ b/torch/_functorch/benchmark_utils.py @@ -2,6 +2,7 @@ import contextlib import json +import operator import os import time @@ -94,7 +95,7 @@ def get_sorted_gpu_events(events): if not is_gpu_compute_event(event): continue sorted_gpu_events.append(event) - return sorted(sorted_gpu_events, key=lambda x: x["ts"]) + return sorted(sorted_gpu_events, key=operator.itemgetter("ts")) def get_duration(sorted_gpu_events): diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 873441a971a..78d580d0259 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -407,7 +407,7 @@ def _count_ops(graph): for node in graph.nodes: if node.op == "call_function": cnt[node.target.__name__] += 1 - print(sorted(cnt.items(), key=lambda x: x[1], reverse=True)) + print(sorted(cnt.items(), key=operator.itemgetter(1), reverse=True)) @functools.lru_cache(None) @@ -432,7 +432,7 @@ def sort_depths(args, depth_map): arg_depths = { arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node) } - return sorted(arg_depths.items(), key=lambda x: x[1], reverse=True) + return sorted(arg_depths.items(), key=operator.itemgetter(1), reverse=True) def reordering_to_mimic_autograd_engine(gm): @@ -1315,7 +1315,7 @@ def min_cut_rematerialization_partition( ) print( "Count of Ops Rematerialized: ", - sorted(counts.items(), key=lambda x: x[1], reverse=True), + sorted(counts.items(), key=operator.itemgetter(1), reverse=True), ) return fw_module, bw_module diff --git a/torch/_functorch/top_operators_github_usage.py b/torch/_functorch/top_operators_github_usage.py index ab5c984bada..ce74f7aadfb 100644 --- a/torch/_functorch/top_operators_github_usage.py +++ b/torch/_functorch/top_operators_github_usage.py @@ -4,6 +4,8 @@ From https://docs.google.com/spreadsheets/d/12R3nCOLskxPYjjiNkdqy4OdQ65eQp_htebXGODsjSeA/edit#gid=0 Try to keep this list in sync with that. 
""" +import operator + top_torch = [ ("t", 6837449), ("tensor", 585786), @@ -618,7 +620,7 @@ def get_nn_functional_top_list(): top_nn_functional_[functional_name] += count top_nn_functional_ = list(top_nn_functional_.items()) - top_nn_functional_.sort(key=lambda x: x[1], reverse=True) + top_nn_functional_.sort(key=operator.itemgetter(1), reverse=True) return top_nn_functional_ diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 1e0e9a5a87f..f1caf01eace 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -894,7 +894,7 @@ class ReplacementPatternEntry(PatternEntry): for n in output_nodes if isinstance(n, torch.fx.Node) ] - last_node = min(indices, key=lambda tup: tup[0])[1] + last_node = min(indices, key=operator.itemgetter(0))[1] def percolate_tags(node, recompute_tag, input_stops): queue = [node] diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 4c896137fc0..d1548b73e5c 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2290,7 +2290,7 @@ class Scheduler: ) # return the possible fusions with highest priority possible_fusions_with_highest_priority = min( - possible_fusions_group_by_priority.items(), key=lambda item: item[0] + possible_fusions_group_by_priority.items(), key=operator.itemgetter(0) )[1] assert len(possible_fusions_with_highest_priority) > 0 return possible_fusions_with_highest_priority diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 1f277ec9321..b9da3b67a98 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -3597,7 +3597,7 @@ def repeat(a: Tensor, *repeat_shape) -> Tensor: # derive permute order by sorting urtensor strides enumerated_stride = list(enumerate(urtensor_stride)) - enumerated_stride.sort(key=lambda item: item[1], reverse=True) + enumerated_stride.sort(key=operator.itemgetter(1), reverse=True) permute_order, sorted_stride = zip(*enumerated_stride) # add new and expand dimensions according to urtensor diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index b948e1eccc0..a1b59e94d27 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -63,6 +63,8 @@ def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_nam ) +import operator + # Utilities should come BEFORE this import from torch._decomp import register_decomposition from torch._decomp.decompositions import pw_cast_for_opmath @@ -165,7 +167,7 @@ def _backshift_permutation(dim0, dim1, ndim): def _inverse_permutation(perm): # Given a permutation, returns its inverse. 
It's equivalent to argsort on an array - return [i for i, j in sorted(enumerate(perm), key=lambda i_j: i_j[1])] + return [i for i, j in sorted(enumerate(perm), key=operator.itemgetter(1))] # CompositeImplicitAutograd diff --git a/torch/ao/quantization/fx/_equalize.py b/torch/ao/quantization/fx/_equalize.py index 55bcb52576b..b0965b9a705 100644 --- a/torch/ao/quantization/fx/_equalize.py +++ b/torch/ao/quantization/fx/_equalize.py @@ -19,6 +19,7 @@ from .utils import ( maybe_get_next_module, node_arg_is_weight, ) +import operator CUSTOM_MODULE_SUPP_LIST: List[Any] = [] @@ -810,7 +811,7 @@ def get_equalization_qconfig_dict( # Sort the layer_sqnr_dictionary values and get the layers with the lowest # SQNR values (aka highest quantization errors) - layer_sqnr_sorted = sorted(layer_sqnr_dict.items(), key=lambda item: item[1]) + layer_sqnr_sorted = sorted(layer_sqnr_dict.items(), key=operator.itemgetter(1)) layers_to_equalize = layer_sqnr_sorted[:num_layers_to_equalize] # Constructs an equalization_qconfig_dict that specifies to only equalize diff --git a/torch/cuda/_memory_viz.py b/torch/cuda/_memory_viz.py index d3838f34102..587d7e9c7c5 100644 --- a/torch/cuda/_memory_viz.py +++ b/torch/cuda/_memory_viz.py @@ -9,6 +9,7 @@ from typing import Any from itertools import groupby import base64 import warnings +import operator cache = lru_cache(None) @@ -492,7 +493,7 @@ def _profile_to_snapshot(profile): # create the final snapshot state blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames']) for (tensor_key, version), event in kv_to_elem.items()] - for device, blocks in groupby(sorted(blocks_at_end), key=lambda x: x[0]): + for device, blocks in groupby(sorted(blocks_at_end), key=operator.itemgetter(0)): seg = snapshot['segments'][device] # type: ignore[index] last_addr = seg['address'] for _, addr, size, frames in blocks: diff --git a/torch/distributed/_shard/sharding_spec/api.py b/torch/distributed/_shard/sharding_spec/api.py index bcfacbf0354..1824b66a819 100644 --- a/torch/distributed/_shard/sharding_spec/api.py +++ b/torch/distributed/_shard/sharding_spec/api.py @@ -15,6 +15,7 @@ from torch.distributed._shard.metadata import ShardMetadata import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta from torch.distributed._shard.op_registry_utils import _decorator_func +import operator if TYPE_CHECKING: # Only include ShardedTensor when do type checking, exclude it @@ -214,7 +215,7 @@ def _infer_sharding_spec_from_shards_metadata(shards_metadata): if chunk_sharding_dim is not None: # Ensure we infer the correct placement order from offsets placements = [ - x for _, x in sorted(zip(chunk_offset_list, placements), key=lambda e: e[0]) + x for _, x in sorted(zip(chunk_offset_list, placements), key=operator.itemgetter(0)) ] from .chunk_sharding_spec import ChunkShardingSpec diff --git a/torch/distributed/_tools/memory_tracker.py b/torch/distributed/_tools/memory_tracker.py index 96a3fa497c0..fdc60acdf82 100644 --- a/torch/distributed/_tools/memory_tracker.py +++ b/torch/distributed/_tools/memory_tracker.py @@ -17,6 +17,7 @@ import torch import torch.nn as nn from torch.utils.hooks import RemovableHandle from torch.utils._python_dispatch import TorchDispatchMode +import operator BYTES_PER_MB = 1024 * 1024.0 @@ -148,7 +149,7 @@ class MemoryTracker: print("------------------------------------------------") print(f"The number of cuda retries are: {self._num_cuda_retries}") print(f"Top {top} ops that generates memory are:") - for k, v in 
sorted(op_diff.items(), key=lambda item: item[1], reverse=True)[ + for k, v in sorted(op_diff.items(), key=operator.itemgetter(1), reverse=True)[ :top ]: print(f"{k}: {v}MB") diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py index 67688432923..9b6345862ce 100644 --- a/torch/distributed/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -1,6 +1,7 @@ import collections import dataclasses import io +import operator import os import pickle import queue @@ -177,7 +178,7 @@ class _OverlappingCpuLoader(_TensorLoader): if self.started: return self.started = True - self.items.sort(key=lambda x: x[0]) + self.items.sort(key=operator.itemgetter(0)) self._refill() def values(self) -> Iterator[Tuple[torch.Tensor, object]]: @@ -218,7 +219,7 @@ def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[Writ for wi in tensor_w: # TODO replace with headq - idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0] + idx = min(enumerate(bucket_sizes), key=operator.itemgetter(1))[0] buckets[idx].append(wi) bucket_sizes[idx] += _item_size(wi) diff --git a/torch/distributed/checkpoint/planner.py b/torch/distributed/checkpoint/planner.py index 7e29bc336c5..1492f09bf2a 100644 --- a/torch/distributed/checkpoint/planner.py +++ b/torch/distributed/checkpoint/planner.py @@ -1,5 +1,6 @@ import abc import io +import operator from dataclasses import dataclass from enum import auto, Enum from functools import reduce @@ -67,7 +68,7 @@ class WriteItem: if self.tensor_data is None: return None - numels = reduce(lambda x, y: x * y, self.tensor_data.size, 1) + numels = reduce(operator.mul, self.tensor_data.size, 1) dtype_size = torch._utils._element_size(self.tensor_data.properties.dtype) return numels * dtype_size diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index f8e220b00dc..ee3376204f3 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -840,7 +840,7 @@ def _reorder_submodules( _reorder_submodules(child, fqn_order, prefix=fqn + ".") delattr(parent, name) children.append((fqn_order[fqn], name, child)) - children.sort(key=lambda x: x[0]) + children.sort(key=operator.itemgetter(0)) for _, name, child in children: parent.register_module(name, child) diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py index 7bb91692b34..fc28f112323 100644 --- a/torch/fx/experimental/accelerator_partitioner.py +++ b/torch/fx/experimental/accelerator_partitioner.py @@ -259,7 +259,7 @@ def get_device_to_partitions_mapping( # Find devices for all the partitions without a device found_device = True for partition in no_device_partitions: - device_to_left_mem_bytes = dict(sorted(device_to_left_mem_bytes.items(), key=lambda item: item[1])) + device_to_left_mem_bytes = dict(sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1))) found_device = find_device_for(partition) if not found_device: break diff --git a/torch/fx/experimental/unification/multipledispatch/conflict.py b/torch/fx/experimental/unification/multipledispatch/conflict.py index 71db96dd476..6c247bd9811 100644 --- a/torch/fx/experimental/unification/multipledispatch/conflict.py +++ b/torch/fx/experimental/unification/multipledispatch/conflict.py @@ -1,5 +1,6 @@ from .utils import _toposort, groupby from .variadic import isvariadic +import operator __all__ = ["AmbiguityWarning", "supercedes", "consistent", "ambiguous", "ambiguities", "super_signature", "edge", "ordering"] @@ -111,7 
+112,7 @@ def ordering(signatures): """ signatures = list(map(tuple, signatures)) edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] - edges = groupby(lambda x: x[0], edges) + edges = groupby(operator.itemgetter(0), edges) for s in signatures: if s not in edges: edges[s] = [] diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index caacfb83036..1ad1293e3e9 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -1,4 +1,5 @@ import functools +import operator import re from collections import deque from dataclasses import dataclass @@ -316,7 +317,7 @@ class BasicEvaluation: event for _, event in sorted( zip(heuristic_score_list, event_list), - key=lambda x: x[0], + key=operator.itemgetter(0), reverse=True, ) ] diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index e111b20c086..ffd0e6f95a8 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -28,6 +28,7 @@ from torch.testing._internal.common_utils import ( skipIfTorchDynamo) from types import ModuleType from typing import List, Tuple, Type, Set, Dict +import operator # List of all namespaces containing modules to test. MODULE_NAMESPACES: List[ModuleType] = [ @@ -3374,7 +3375,7 @@ module_db: List[ModuleInfo] = [ unittest.expectedFailure, 'TestModule', 'test_memory_format', - active_if=lambda p: p['training'], + active_if=operator.itemgetter('training'), ),) ), ModuleInfo(torch.nn.AdaptiveAvgPool3d, @@ -3413,7 +3414,7 @@ module_db: List[ModuleInfo] = [ unittest.expectedFailure, 'TestModule', 'test_memory_format', - active_if=lambda p: p['training'], + active_if=operator.itemgetter('training'), device_type='cuda', ), # error: input types 'tensor' and 'tensor<15x10xf16>' are not broadcast compatible @@ -3440,13 +3441,13 @@ module_db: List[ModuleInfo] = [ DecorateInfo( unittest.expectedFailure, 'TestEagerFusionModuleInfo', 'test_aot_autograd_symbolic_module_exhaustive', - active_if=lambda p: p['training'] + active_if=operator.itemgetter('training') ), # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default DecorateInfo( unittest.expectedFailure, 'TestEagerFusionModuleInfo', 'test_aot_autograd_module_exhaustive', - active_if=lambda p: p['training'] + active_if=operator.itemgetter('training') )) ), ModuleInfo(torch.nn.BatchNorm2d, @@ -3461,13 +3462,13 @@ module_db: List[ModuleInfo] = [ DecorateInfo( unittest.expectedFailure, 'TestEagerFusionModuleInfo', 'test_aot_autograd_symbolic_module_exhaustive', - active_if=lambda p: p['training'] + active_if=operator.itemgetter('training') ), # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default DecorateInfo( unittest.expectedFailure, 'TestEagerFusionModuleInfo', 'test_aot_autograd_module_exhaustive', - active_if=lambda p: p['training'] + active_if=operator.itemgetter('training') ),) ), ModuleInfo(torch.nn.BatchNorm3d, @@ -3481,13 +3482,13 @@ module_db: List[ModuleInfo] = [ DecorateInfo( unittest.expectedFailure, 'TestEagerFusionModuleInfo', 'test_aot_autograd_symbolic_module_exhaustive', - active_if=lambda p: p['training'] + active_if=operator.itemgetter('training') ), # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default DecorateInfo( unittest.expectedFailure, 'TestEagerFusionModuleInfo', 'test_aot_autograd_module_exhaustive', - active_if=lambda p: p['training'] + active_if=operator.itemgetter('training') ),) ), ModuleInfo(torch.nn.CELU, 
@@ -3870,7 +3871,7 @@ module_db: List[ModuleInfo] = [ unittest.expectedFailure, 'TestModule', 'test_memory_format', - active_if=lambda p: p['training'], + active_if=operator.itemgetter('training'), device_type='mps', ),) ), @@ -4070,7 +4071,7 @@ module_db: List[ModuleInfo] = [ unittest.expectedFailure, 'TestModule', 'test_memory_format', - active_if=lambda p: p['training'], + active_if=operator.itemgetter('training'), device_type='mps', ),), supports_gradgrad=False), @@ -4193,7 +4194,7 @@ module_db: List[ModuleInfo] = [ unittest.expectedFailure, 'TestModule', 'test_memory_format', - active_if=lambda p: p['training'], + active_if=operator.itemgetter('training'), device_type='mps', ),) ), @@ -4235,7 +4236,7 @@ module_db: List[ModuleInfo] = [ unittest.expectedFailure, 'TestModule', 'test_memory_format', - active_if=lambda p: p['training'], + active_if=operator.itemgetter('training'), device_type='mps', ),) ), @@ -4298,7 +4299,7 @@ module_db: List[ModuleInfo] = [ unittest.expectedFailure, 'TestModule', 'test_memory_format', - active_if=lambda p: p['training'], + active_if=operator.itemgetter('training'), device_type='mps', ),) ), @@ -4311,7 +4312,7 @@ module_db: List[ModuleInfo] = [ unittest.expectedFailure, 'TestModule', 'test_memory_format', - active_if=lambda p: p['training'], + active_if=operator.itemgetter('training'), device_type='mps', ),) ), diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index ee98aaa161f..5d2a67cd473 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -55,6 +55,7 @@ from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( from torch.testing._internal.common_utils import TemporaryFileName from torch.autograd.profiler_legacy import profile as _profile +import operator def foo_add(): @@ -6309,7 +6310,7 @@ class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture, RpcTestCommon): @skip_if_lt_x_gpu(1) def test_cuda_future_can_extract_list_with_cuda_tensor(self): self._test_cuda_future_extraction( - wrapper=lambda t: [t], unwrapper=lambda v: v[0], sparse_tensor=False + wrapper=lambda t: [t], unwrapper=operator.itemgetter(0), sparse_tensor=False ) @skip_if_lt_x_gpu(1) @@ -6484,7 +6485,7 @@ class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture, RpcTestCommon): @skip_if_lt_x_gpu(1) def test_cuda_future_can_extract_list_with_cuda_sparse_tensor(self): self._test_cuda_future_extraction( - wrapper=lambda t: [t], unwrapper=lambda v: v[0], sparse_tensor=True + wrapper=lambda t: [t], unwrapper=operator.itemgetter(0), sparse_tensor=True ) @skip_if_lt_x_gpu(1) diff --git a/torch/utils/benchmark/examples/op_benchmark.py b/torch/utils/benchmark/examples/op_benchmark.py index b7536b9ec26..e2f0861d20a 100644 --- a/torch/utils/benchmark/examples/op_benchmark.py +++ b/torch/utils/benchmark/examples/op_benchmark.py @@ -9,6 +9,7 @@ import torch from torch.utils.benchmark import Timer from torch.utils.benchmark.op_fuzzers.binary import BinaryOpFuzzer from torch.utils.benchmark.op_fuzzers.unary import UnaryOpFuzzer +import operator _MEASURE_TIME = 1.0 @@ -75,7 +76,7 @@ def run(n, stmt, fuzzer_cls): order_len = max(order_len, len(order)) steps_len = max(steps_len, len(steps)) - parsed_results.sort(key=lambda x: x[2]) + parsed_results.sort(key=operator.itemgetter(2)) print(f"stmt: {stmt}") print(f" diff faster{'':>17}{' ' * name_len} ", end="") diff --git a/torch/utils/benchmark/examples/sparse/op_benchmark.py 
b/torch/utils/benchmark/examples/sparse/op_benchmark.py index d7e97d33cc1..f998f6d5db4 100644 --- a/torch/utils/benchmark/examples/sparse/op_benchmark.py +++ b/torch/utils/benchmark/examples/sparse/op_benchmark.py @@ -9,6 +9,7 @@ import torch from torch.utils.benchmark import Timer from torch.utils.benchmark.op_fuzzers.sparse_unary import UnaryOpSparseFuzzer from torch.utils.benchmark.op_fuzzers.sparse_binary import BinaryOpSparseFuzzer +import operator _MEASURE_TIME = 1.0 @@ -70,7 +71,7 @@ def run(n, stmt, fuzzer_cls): sparse_dim_len = max(sparse_dim_len, len(sparse_dim)) is_coalesced_len = max(is_coalesced_len, len(is_coalesced)) - parsed_results.sort(key=lambda x: x[2]) + parsed_results.sort(key=operator.itemgetter(2)) print(f"stmt: {stmt}") print(f" diff faster{'':>17}{' ' * name_len} ", end="") diff --git a/torch/utils/benchmark/utils/compare.py b/torch/utils/benchmark/utils/compare.py index 9c7863e6a74..337b742ca06 100644 --- a/torch/utils/benchmark/utils/compare.py +++ b/torch/utils/benchmark/utils/compare.py @@ -6,6 +6,7 @@ from typing import DefaultDict, List, Optional, Tuple from torch.utils.benchmark.utils import common from torch import tensor as _tensor +import operator __all__ = ["Colorize", "Compare"] @@ -167,7 +168,7 @@ class Table: ) self.row_keys = common.ordered_unique([self.row_fn(i) for i in results]) - self.row_keys.sort(key=lambda args: args[:2]) # preserve stmt order + self.row_keys.sort(key=operator.itemgetter(slice(2))) # preserve stmt order self.column_keys = common.ordered_unique([self.col_fn(i) for i in results]) self.rows, self.columns = self.populate_rows_and_columns() diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py index dab15685804..29070761c55 100644 --- a/torchgen/operator_versions/gen_mobile_upgraders.py +++ b/torchgen/operator_versions/gen_mobile_upgraders.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import os from enum import Enum +from operator import itemgetter from pathlib import Path from typing import Any, Dict, List @@ -263,7 +264,7 @@ def construct_version_maps( upgrader_bytecode_function_to_index_map: Dict[str, Any] ) -> str: version_map = torch._C._get_operator_version_map() - sorted_version_map_ = sorted(version_map.items(), key=lambda item: item[0]) # type: ignore[no-any-return] + sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0)) # type: ignore[no-any-return] sorted_version_map = dict(sorted_version_map_) operator_list_in_version_map_part = []
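
For reference, a minimal standalone sketch of the rewrites this patch applies (standard library only; the sample data and timing loop are illustrative, not taken from the PR, and absolute numbers vary by machine and Python version):

import operator
import timeit
from functools import reduce

# itemgetter(1) builds a C-level callable equivalent to `lambda x: x[1]`,
# so the sort never enters a Python frame to evaluate the key.
rows = [(i % 7, -i) for i in range(10_000)]
assert sorted(rows, key=operator.itemgetter(1)) == sorted(rows, key=lambda r: r[1])

# Dict keys work the same way, as in get_workflow_job_id.py:
jobs = [{"started_at": "2024-04-26T14:01:00Z"}, {"started_at": "2024-04-26T13:59:00Z"}]
jobs.sort(key=operator.itemgetter("started_at"), reverse=True)

# A slice object reproduces `lambda args: args[:2]`, as in compare.py:
prefix = operator.itemgetter(slice(2))
assert prefix(("stmt", "label", "env")) == ("stmt", "label")

# operator.mul replaces `lambda x, y: x * y` in planner.py's numel computation:
assert reduce(operator.mul, (2, 3, 4), 1) == 24

# Rough comparison of the two key styles:
t_lambda = timeit.timeit(lambda: sorted(rows, key=lambda r: r[1]), number=200)
t_getter = timeit.timeit(lambda: sorted(rows, key=operator.itemgetter(1)), number=200)
print(f"lambda key: {t_lambda:.3f}s  itemgetter key: {t_getter:.3f}s")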