diff --git a/tools/linter/adapters/pyfmt_linter.py b/tools/linter/adapters/pyfmt_linter.py index 55ffa429e7f..927325bffeb 100644 --- a/tools/linter/adapters/pyfmt_linter.py +++ b/tools/linter/adapters/pyfmt_linter.py @@ -52,7 +52,6 @@ USE_BLACK_FILELIST = re.compile( # torch/[e-m]*/** # torch/optim/** # torch/[p-z]*/** - "torch/[p-z]*/**", ], ), ) diff --git a/torch/package/_mangling.py b/torch/package/_mangling.py index 09d7901c2d6..08b0560f793 100644 --- a/torch/package/_mangling.py +++ b/torch/package/_mangling.py @@ -2,6 +2,7 @@ """Import mangling. See mangling.md for details. """ + import re diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index 21446c626b9..6118e8ce809 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -605,9 +605,9 @@ class PackageExporter: dependencies (bool, optional): If ``True``, we scan the source for dependencies. """ - assert (pickle_protocol == 4) or ( - pickle_protocol == 3 - ), "torch.package only supports pickle protocols 3 and 4" + assert (pickle_protocol == 4) or (pickle_protocol == 3), ( + "torch.package only supports pickle protocols 3 and 4" + ) filename = self._filename(package, resource) # Write the pickle data for `obj` diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index a97cf475b35..7291227e42a 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -423,7 +423,12 @@ class PackageImporter(Importer): module.__dict__.setdefault(old_name, new_name) return module - return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore[attr-defined] + return self._make_module( + name, + cur.source_file, # type: ignore[attr-defined] + isinstance(cur, _PackageNode), + parent, + ) def _compile_source(self, fullpath: str, mangled_filename: str): source = self.zip_reader.get_record(fullpath) diff --git a/torch/profiler/__init__.py b/torch/profiler/__init__.py index a90a371130e..153d4560e26 100644 --- a/torch/profiler/__init__.py +++ b/torch/profiler/__init__.py @@ -7,6 +7,7 @@ examine their input shapes and stack traces, study device kernel activity and vi An earlier version of the API in :mod:`torch.autograd` module is considered legacy and will be deprecated. """ + import os from typing import Any from typing_extensions import TypeVarTuple, Unpack diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index 7ad917d1e86..d9f3a917c15 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -239,10 +239,12 @@ class SchemaMatcher: def match_schemas(cls, t: _ExtraFields_TorchOp) -> tuple[FunctionSchema, ...]: signature = tuple( # Tensor - TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata) + TensorKey.from_tensor(i) + if isinstance(i, _TensorMetadata) # # TensorList - else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list) + else [TensorKey.from_tensor(j) for j in i] + if isinstance(i, list) # # Scalar and uncaptured inputs. 
else i diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index b1160324cb9..5b631ef743c 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -124,9 +124,9 @@ class BasicEvaluation: for child_event in curr_event.children: self_time -= child_event.duration_time_ns stack.append(child_event) - assert ( - EventKey(curr_event) not in self.metrics - ), f"Duplicate id: {curr_event.id}, {curr_event.name}" + assert EventKey(curr_event) not in self.metrics, ( + f"Duplicate id: {curr_event.id}, {curr_event.name}" + ) self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time) self.metrics[ EventKey(curr_event) @@ -227,8 +227,7 @@ class BasicEvaluation: while ( current_kernel_index < len(cuda_kernel_events) - and (cuda_kernel_events[current_kernel_index].start_ns()) - <= start_time # type: ignore[possibly-undefined] + and (cuda_kernel_events[current_kernel_index].start_ns()) <= start_time # type: ignore[possibly-undefined] ): current_kernel_index += 1 current_queue_depth = spawned_kernel_index - current_kernel_index + 1 @@ -352,11 +351,11 @@ class BasicEvaluation: output += "\n".join( [ - f"""{'-' * 80} + f"""{"-" * 80} Event: {event} Source code location: {source_code_location(event.event)} Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}% -{'-' * 80}""" +{"-" * 80}""" for event in event_list ] ) diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index f7be416cfaa..d88d6c5cad7 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -624,8 +624,7 @@ class profile(_KinetoProfile): ] ) as p: code_to_profile() - print(p.key_averages().table( - sort_by="self_cuda_time_total", row_limit=-1)) + print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions: @@ -635,16 +634,17 @@ class profile(_KinetoProfile): # on different iterations of the training loop; # trace_handler is called every time a new trace becomes available def trace_handler(prof): - print(prof.key_averages().table( - sort_by="self_cuda_time_total", row_limit=-1)) + print( + prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1) + ) # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json") + with torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ], - # In this example with wait=1, warmup=1, active=2, repeat=1, # profiler will skip the first step/iteration, # start warming up on the second, record @@ -652,20 +652,15 @@ class profile(_KinetoProfile): # after which the trace will become available # and on_trace_ready (when set) is called; # the cycle repeats starting with the next step - - schedule=torch.profiler.schedule( - wait=1, - warmup=1, - active=2, - repeat=1), - on_trace_ready=trace_handler + schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1), + on_trace_ready=trace_handler, # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log') # used when outputting for tensorboard - ) as p: - for iter in range(N): - code_iteration_to_profile(iter) - # send a signal to the profiler that the next iteration has started - p.step() + ) as p: + for iter in range(N): + code_iteration_to_profile(iter) + # send a signal to the profiler that the next iteration has started + p.step() The following sample shows how to setup up an Execution Trace Observer (`execution_trace_observer`) diff --git a/torch/quantization/fuser_method_mappings.py 
b/torch/quantization/fuser_method_mappings.py index cfb13ac9627..5a68fbf0201 100644 --- a/torch/quantization/fuser_method_mappings.py +++ b/torch/quantization/fuser_method_mappings.py @@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the `torch/ao/quantization/fuser_method_mappings.py`, while adding an import statement here. """ + from torch.ao.quantization.fuser_method_mappings import ( _DEFAULT_OP_LIST_TO_FUSER_METHOD, fuse_conv_bn, diff --git a/torch/quantization/fx/_equalize.py b/torch/quantization/fx/_equalize.py index 7acea4f84a2..d6b8611d4a7 100644 --- a/torch/quantization/fx/_equalize.py +++ b/torch/quantization/fx/_equalize.py @@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx._equalize import ( _convert_equalization_ref, _InputEqualizationObserver, diff --git a/torch/quantization/fx/convert.py b/torch/quantization/fx/convert.py index 9d6ac350602..30a661da41e 100644 --- a/torch/quantization/fx/convert.py +++ b/torch/quantization/fx/convert.py @@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.convert import convert diff --git a/torch/quantization/fx/fuse.py b/torch/quantization/fx/fuse.py index 67527080304..22ad750e9f8 100644 --- a/torch/quantization/fx/fuse.py +++ b/torch/quantization/fx/fuse.py @@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.fuse import fuse diff --git a/torch/quantization/fx/fusion_patterns.py b/torch/quantization/fx/fusion_patterns.py index e29337b3f86..982d919655f 100644 --- a/torch/quantization/fx/fusion_patterns.py +++ b/torch/quantization/fx/fusion_patterns.py @@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.fuse_handler import DefaultFuseHandler, FuseHandler diff --git a/torch/quantization/fx/graph_module.py b/torch/quantization/fx/graph_module.py index a71e980a57b..74b63903d74 100644 --- a/torch/quantization/fx/graph_module.py +++ b/torch/quantization/fx/graph_module.py @@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.graph_module import ( _is_observed_module, _is_observed_standalone_module, diff --git a/torch/quantization/fx/match_utils.py b/torch/quantization/fx/match_utils.py index 8b49f7c645d..8585a21ad44 100644 --- a/torch/quantization/fx/match_utils.py +++ b/torch/quantization/fx/match_utils.py @@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. 
""" + from torch.ao.quantization.fx.match_utils import ( _find_matches, _is_match, diff --git a/torch/quantization/fx/pattern_utils.py b/torch/quantization/fx/pattern_utils.py index 2a83e180fc4..fa601d1eb61 100644 --- a/torch/quantization/fx/pattern_utils.py +++ b/torch/quantization/fx/pattern_utils.py @@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.pattern_utils import ( _register_fusion_pattern, _register_quant_pattern, diff --git a/torch/quantization/fx/prepare.py b/torch/quantization/fx/prepare.py index ca65dcc04dd..a6007ef242a 100644 --- a/torch/quantization/fx/prepare.py +++ b/torch/quantization/fx/prepare.py @@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.prepare import prepare diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 20d8cc52ee4..89f8d4406e9 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.quantize_handler import ( BatchNormQuantizeHandler, BinaryOpQuantizeHandler, diff --git a/torch/quantization/fx/quantization_types.py b/torch/quantization/fx/quantization_types.py index a422cdd3142..0820ea05707 100644 --- a/torch/quantization/fx/quantization_types.py +++ b/torch/quantization/fx/quantization_types.py @@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.utils import Pattern, QuantizerCls diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py index ef35559884b..e45c82b8fb6 100644 --- a/torch/quantization/fx/utils.py +++ b/torch/quantization/fx/utils.py @@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.utils import ( all_node_args_have_no_tensors, assert_and_get_unique_device, diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py index 6e6c7c1917c..2163e2717b0 100644 --- a/torch/quantization/observer.py +++ b/torch/quantization/observer.py @@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the `torch/ao/quantization/observer.py`, while adding an import statement here. """ + from torch.ao.quantization.observer import ( _is_activation_post_process, _is_per_channel_script_obs_instance, diff --git a/torch/quantization/qconfig.py b/torch/quantization/qconfig.py index 6bb7e14110c..a02ff7d6f73 100644 --- a/torch/quantization/qconfig.py +++ b/torch/quantization/qconfig.py @@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the `torch/ao/quantization/qconfig.py`, while adding an import statement here. 
""" + from torch.ao.quantization.qconfig import ( _add_module_to_qconfig_obs_ctr, _assert_valid_qconfig, diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index 8b44a980ce8..faa24d391d3 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the `torch/ao/quantization/quantization_mappings.py`, while adding an import statement here. """ + from torch.ao.quantization.quantization_mappings import ( _get_special_act_post_process, _has_special_act_post_process, diff --git a/torch/signal/windows/windows.py b/torch/signal/windows/windows.py index 7d67de3f838..e68c202f03e 100644 --- a/torch/signal/windows/windows.py +++ b/torch/signal/windows/windows.py @@ -128,9 +128,7 @@ Examples:: >>> # Generates a periodic exponential window and decay factor equal to .5 >>> torch.signal.windows.exponential(10, sym=False,tau=.5) tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04]) - """.format( - **window_common_args - ), + """.format(**window_common_args), ) def exponential( M: int, @@ -452,9 +450,7 @@ Examples:: >>> # Generates a periodic Hamming window. >>> torch.signal.windows.hamming(10, sym=False) tensor([0.0800, 0.1679, 0.3979, 0.6821, 0.9121, 1.0000, 0.9121, 0.6821, 0.3979, 0.1679]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def hamming( M: int, @@ -508,9 +504,7 @@ Examples:: >>> # Generates a periodic Hann window. >>> torch.signal.windows.hann(10, sym=False) tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def hann( M: int, @@ -564,9 +558,7 @@ Examples:: >>> # Generates a periodic Blackman window. >>> torch.signal.windows.blackman(5, sym=False) tensor([-1.4901e-08, 2.0077e-01, 8.4923e-01, 8.4923e-01, 2.0077e-01]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def blackman( M: int, @@ -627,9 +619,7 @@ Examples:: >>> # Generates a periodic Bartlett window. >>> torch.signal.windows.bartlett(10, sym=False) tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000, 0.8000, 0.6000, 0.4000, 0.2000]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def bartlett( M: int, @@ -704,9 +694,7 @@ Examples:: >>> # Generates a periodic general cosine window with 2 coefficients. >>> torch.signal.windows.general_cosine(10, a=[0.5, 1 - 0.5], sym=False) tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def general_cosine( M, @@ -799,9 +787,7 @@ Examples:: >>> # Generates a periodic Hann window with the general Hamming window. >>> torch.signal.windows.general_hamming(10, alpha=0.5, sym=False) tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def general_hamming( M, @@ -866,9 +852,7 @@ Examples:: >>> # Generates a periodic Nuttall window. 
>>> torch.signal.windows.general_hamming(5, sym=False) tensor([3.6280e-04, 1.1052e-01, 7.9826e-01, 7.9826e-01, 1.1052e-01]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def nuttall( M: int, diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 39d78e8c26a..31299314a85 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -559,7 +559,11 @@ def as_sparse_gradcheck(gradcheck): For example: >>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck) - >>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse_coo().requires_grad_(True) + >>> x = ( + ... torch.tensor([[0, 1], [2, 3]], dtype=torch.float64) + ... .to_sparse_coo() + ... .requires_grad_(True) + ... ) >>> gradcheck(lambda x: x.to_sparse_csr(), x) True """ @@ -667,7 +671,7 @@ def as_sparse_gradcheck(gradcheck): ) else: raise NotImplementedError( - f'conversion of {d["layout"]} strided representation to tensor' + f"conversion of {d['layout']} strided representation to tensor" ) new_args.append(a) return tuple(new_args) diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index a5e802084c2..ea36264d8f8 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -296,11 +296,11 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): for b in range(nbatches): for i, r in enumerate(r_offsets): r0, r1 = divmod(r, N) - acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] - for g in range(c_indices[i], c_indices[i+1]): + acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] + for g in range(c_indices[i], c_indices[i + 1]): p = p_offsets[g] q0, q1 = divmod(q_offsets[g], N) - acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] + acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are integer multiples of ``Ms`` and ``Ks``, respectively. @@ -320,11 +320,11 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): n = (r % N) // Ns r0, r1 = divmod(r, N) c0, c1 = c_indices[m], c_indices[m + 1] - acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] + acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] for i, p in enumerate(range(c0, c1)): q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i] q0, q1 = divmod(q, N) - acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] + acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are integer multiples of ``Ms`` and ``Ks``, respectively. diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py index 762874077c7..89245246395 100644 --- a/torch/sparse/_triton_ops_meta.py +++ b/torch/sparse/_triton_ops_meta.py @@ -97,6 +97,7 @@ tune_bsr_dense_addmm to learn how to register a custom set of optimal kernel parameters for addmm-based operations. 
""" + __all__ = ["get_meta", "tune_bsr_dense_addmm", "tune__int_bsr_dense_addmm"] import inspect @@ -432,9 +433,9 @@ def minimize( def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device): - assert ( - sparsity <= 1.0 and sparsity >= 0.0 - ), "sparsity should be a value between 0 and 1" + assert sparsity <= 1.0 and sparsity >= 0.0, ( + "sparsity should be a value between 0 and 1" + ) assert M % blocksize[0] == 0 assert N % blocksize[1] == 0 shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :] diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index 721f2551279..b225eaabb32 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -465,14 +465,26 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below: ``` from torch.sparse import SparseSemiStructuredTensorCUTLASS - from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask + from torch.sparse._semi_structured_conversions import ( + _sparse_semi_structured_tile, + _compute_compressed_swizzled_bitmask, + ) pruned = _sparse_semi_structured_tile(dense) packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned) - packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous()) + packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass( + pruned.t().contiguous() + ) bitmask = _compute_compressed_swizzled_bitmask(pruned) - SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask) + SparseSemiStructuredTensorCUTLASS( + dense.shape, + packed_cutlass, + meta_cutlass, + packed_t_cutlass, + meta_t_cutlass, + bitmask, + ) ``` """ # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag. 
@@ -583,14 +595,19 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor): The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below: ``` from torch.sparse import SparseSemiStructuredTensorCUSPARSELT - from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask + from torch.sparse._semi_structured_conversions import ( + _sparse_semi_structured_tile, + _compute_compressed_swizzled_bitmask, + ) pruned = _sparse_semi_structured_tile(dense) packed_cusparselt = torch._cslt_compress(pruned) packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous()) bitmask = _compute_compressed_swizzled_bitmask(pruned) - SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask) + SparseSemiStructuredTensorCUSPARSELT( + dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask + ) ``` """ ( diff --git a/torch/special/__init__.py b/torch/special/__init__.py index be027caa94c..dbc9314ad20 100644 --- a/torch/special/__init__.py +++ b/torch/special/__init__.py @@ -134,9 +134,7 @@ Example:: >>> torch.special.digamma(a) tensor([-0.5772, -1.9635]) -""".format( - **common_args - ), +""".format(**common_args), ) gammaln = _add_docstr( @@ -162,9 +160,7 @@ Example:: >>> torch.special.gammaln(a) tensor([ 0.5724, 0.0000, -0.1208]) -""".format( - **common_args - ), +""".format(**common_args), ) polygamma = _add_docstr( @@ -200,9 +196,7 @@ Example:: tensor([ 6.4939, 97.4091]) >>> torch.special.polygamma(4, a) tensor([ -24.8863, -771.4742]) -""".format( - **common_args - ), +""".format(**common_args), ) erf = _add_docstr( @@ -226,9 +220,7 @@ Example:: >>> torch.special.erf(torch.tensor([0, -1., 10.])) tensor([ 0.0000, -0.8427, 1.0000]) -""".format( - **common_args - ), +""".format(**common_args), ) erfc = _add_docstr( @@ -253,9 +245,7 @@ Example:: >>> torch.special.erfc(torch.tensor([0, -1., 10.])) tensor([ 1.0000, 1.8427, 0.0000]) -""".format( - **common_args - ), +""".format(**common_args), ) erfcx = _add_docstr( @@ -283,9 +273,7 @@ Example:: >>> torch.special.erfcx(torch.tensor([0, -1., 10.])) tensor([ 1.0000, 5.0090, 0.0561]) -""".format( - **common_args - ), +""".format(**common_args), ) erfinv = _add_docstr( @@ -311,9 +299,7 @@ Example:: >>> torch.special.erfinv(torch.tensor([0, 0.5, -1.])) tensor([ 0.0000, 0.4769, -inf]) -""".format( - **common_args - ), +""".format(**common_args), ) logit = _add_docstr( @@ -351,9 +337,7 @@ Example:: tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) >>> torch.special.logit(a, eps=1e-6) tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261]) -""".format( - **common_args - ), +""".format(**common_args), ) logsumexp = _add_docstr( @@ -362,9 +346,7 @@ logsumexp = _add_docstr( logsumexp(input, dim, keepdim=False, *, out=None) Alias for :func:`torch.logsumexp`. 
-""".format( - **multi_dim_common - ), +""".format(**multi_dim_common), ) expit = _add_docstr( @@ -391,9 +373,7 @@ Example:: tensor([ 0.9213, 1.0887, -0.8858, -1.7683]) >>> torch.special.expit(t) tensor([ 0.7153, 0.7481, 0.2920, 0.1458]) -""".format( - **common_args - ), +""".format(**common_args), ) exp2 = _add_docstr( @@ -418,9 +398,7 @@ Example:: >>> torch.special.exp2(torch.tensor([0, math.log2(2.), 3, 4])) tensor([ 1., 2., 8., 16.]) -""".format( - **common_args - ), +""".format(**common_args), ) expm1 = _add_docstr( @@ -448,9 +426,7 @@ Example:: >>> torch.special.expm1(torch.tensor([0, math.log(2.)])) tensor([ 0., 1.]) -""".format( - **common_args - ), +""".format(**common_args), ) xlog1py = _add_docstr( @@ -495,9 +471,7 @@ Example:: tensor([1.6094, 3.2189, 4.8283]) >>> torch.special.xlog1py(2, y) tensor([2.7726, 2.1972, 1.3863]) -""".format( - **common_args - ), +""".format(**common_args), ) xlogy = _add_docstr( @@ -542,9 +516,7 @@ Example:: tensor([1.3863, 2.7726, 4.1589]) >>> torch.special.xlogy(2, y) tensor([2.1972, 1.3863, 0.0000]) -""".format( - **common_args - ), +""".format(**common_args), ) i0 = _add_docstr( @@ -570,9 +542,7 @@ Example:: >>> torch.i0(torch.arange(5, dtype=torch.float32)) tensor([ 1.0000, 1.2661, 2.2796, 4.8808, 11.3019]) -""".format( - **common_args - ), +""".format(**common_args), ) i0e = _add_docstr( @@ -597,9 +567,7 @@ Example:: >>> torch.special.i0e(torch.arange(5, dtype=torch.float32)) tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070]) -""".format( - **common_args - ), +""".format(**common_args), ) i1 = _add_docstr( @@ -624,9 +592,7 @@ Example:: >>> torch.special.i1(torch.arange(5, dtype=torch.float32)) tensor([0.0000, 0.5652, 1.5906, 3.9534, 9.7595]) -""".format( - **common_args - ), +""".format(**common_args), ) i1e = _add_docstr( @@ -652,9 +618,7 @@ Example:: >>> torch.special.i1e(torch.arange(5, dtype=torch.float32)) tensor([0.0000, 0.2079, 0.2153, 0.1968, 0.1788]) -""".format( - **common_args - ), +""".format(**common_args), ) ndtr = _add_docstr( @@ -679,9 +643,7 @@ Example:: >>> torch.special.ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3])) tensor([0.0013, 0.0228, 0.1587, 0.5000, 0.8413, 0.9772, 0.9987]) -""".format( - **common_args - ), +""".format(**common_args), ) ndtri = _add_docstr( @@ -709,9 +671,7 @@ Example:: >>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1])) tensor([ -inf, -0.6745, 0.0000, 0.6745, inf]) -""".format( - **common_args - ), +""".format(**common_args), ) log_ndtr = _add_docstr( @@ -736,9 +696,7 @@ Example:: >>> torch.special.log_ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3])) tensor([-6.6077 -3.7832 -1.841 -0.6931 -0.1728 -0.023 -0.0014]) -""".format( - **common_args - ), +""".format(**common_args), ) log1p = _add_docstr( @@ -779,9 +737,7 @@ Example:: tensor([ 0.2252, -0.2948, 1.0267, -1.1566]) >>> torch.special.sinc(t) tensor([ 0.9186, 0.8631, -0.0259, -0.1300]) -""".format( - **common_args - ), +""".format(**common_args), ) round = _add_docstr( @@ -886,9 +842,7 @@ Example:: tensor([1.6449, 0.0823]) >>> torch.special.zeta(2, torch.tensor([1., 2.])) tensor([1.6449, 0.6449]) -""".format( - **common_args - ), +""".format(**common_args), ) multigammaln = _add_docstr( @@ -925,9 +879,7 @@ Example:: >>> torch.special.multigammaln(a, 2) tensor([[0.3928, 0.4007, 0.7586], [1.0311, 0.3901, 0.5049]]) -""".format( - **common_args - ), +""".format(**common_args), ) gammainc = _add_docstr( @@ -976,9 +928,7 @@ Example:: >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2) tensor([1., 1., 1.]) -""".format( - 
**common_args - ), +""".format(**common_args), ) gammaincc = _add_docstr( @@ -1026,9 +976,7 @@ Example:: >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2) tensor([1., 1., 1.]) -""".format( - **common_args - ), +""".format(**common_args), ) airy_ai = _add_docstr( @@ -1045,9 +993,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) bessel_j0 = _add_docstr( @@ -1064,9 +1010,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) bessel_j1 = _add_docstr( @@ -1083,9 +1027,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) bessel_y0 = _add_docstr( @@ -1102,9 +1044,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) bessel_y1 = _add_docstr( @@ -1121,9 +1061,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) chebyshev_polynomial_t = _add_docstr( @@ -1154,9 +1092,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) chebyshev_polynomial_u = _add_docstr( @@ -1188,9 +1124,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) chebyshev_polynomial_v = _add_docstr( @@ -1208,9 +1142,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) chebyshev_polynomial_w = _add_docstr( @@ -1228,9 +1160,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) hermite_polynomial_h = _add_docstr( @@ -1256,9 +1186,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) hermite_polynomial_he = _add_docstr( @@ -1284,9 +1212,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) laguerre_polynomial_l = _add_docstr( @@ -1312,9 +1238,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) legendre_polynomial_p = _add_docstr( @@ -1340,9 +1264,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) modified_bessel_i0 = _add_docstr( @@ -1359,9 +1281,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) modified_bessel_i1 = _add_docstr( @@ -1378,9 +1298,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) modified_bessel_k0 = _add_docstr( @@ -1397,9 +1315,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) modified_bessel_k1 = _add_docstr( @@ -1416,9 +1332,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) scaled_modified_bessel_k0 = _add_docstr( @@ -1435,9 +1349,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) scaled_modified_bessel_k1 = _add_docstr( @@ -1454,9 +1366,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) shifted_chebyshev_polynomial_t = _add_docstr( @@ -1474,9 +1384,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) shifted_chebyshev_polynomial_u = _add_docstr( @@ -1494,9 +1402,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) shifted_chebyshev_polynomial_v = _add_docstr( @@ -1514,9 +1420,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) shifted_chebyshev_polynomial_w = _add_docstr( @@ 
-1534,9 +1438,7 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) spherical_bessel_j0 = _add_docstr( @@ -1553,7 +1455,5 @@ Args: Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py index 228c04cd312..eff07c413de 100644 --- a/torch/testing/_comparison.py +++ b/torch/testing/_comparison.py @@ -1538,7 +1538,9 @@ def assert_close( >>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default error message can be overwritten. - >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") + >>> torch.testing.assert_close( + ... actual, expected, msg="Argh, the tensors are not close!" + ... ) Traceback (most recent call last): ... AssertionError: Argh, the tensors are not close! diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py index e513b8d8560..23d80d6ceae 100644 --- a/torch/testing/_creation.py +++ b/torch/testing/_creation.py @@ -115,11 +115,11 @@ def make_tensor( >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) >>> from torch.testing import make_tensor >>> # Creates a float tensor with values in [-1, 1) - >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1) + >>> make_tensor((3,), device="cpu", dtype=torch.float32, low=-1, high=1) >>> # xdoctest: +SKIP tensor([ 0.1205, 0.2282, -0.6380]) >>> # Creates a bool tensor on CUDA - >>> make_tensor((2, 2), device='cuda', dtype=torch.bool) + >>> make_tensor((2, 2), device="cuda", dtype=torch.bool) tensor([[False, False], [False, True]], device='cuda:0') """ diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 01499280da8..528497ba545 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -721,9 +721,9 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_fo intersect = set(except_for if except_for else []) & set( only_for if only_for else [] ) - assert ( - not intersect - ), f"device ({intersect}) appeared in both except_for and only_for" + assert not intersect, ( + f"device ({intersect}) appeared in both except_for and only_for" + ) # Replace your privateuse1 backend name with 'privateuse1' if is_privateuse1_backend_available(): @@ -1407,9 +1407,9 @@ class deviceCountAtLeast: self.num_required_devices = num_required_devices def __call__(self, fn): - assert not hasattr( - fn, "num_required_devices" - ), f"deviceCountAtLeast redefinition for {fn.__name__}" + assert not hasattr(fn, "num_required_devices"), ( + f"deviceCountAtLeast redefinition for {fn.__name__}" + ) fn.num_required_devices = self.num_required_devices @wraps(fn) @@ -1474,13 +1474,13 @@ def onlyNativeDeviceTypesAnd(devices=None): # self.precision *2, max(1, self.precision)). class precisionOverride: def __init__(self, d): - assert isinstance( - d, dict - ), "precisionOverride not given a dtype : precision dict!" + assert isinstance(d, dict), ( + "precisionOverride not given a dtype : precision dict!" + ) for dtype in d.keys(): - assert isinstance( - dtype, torch.dtype - ), f"precisionOverride given unknown dtype {dtype}" + assert isinstance(dtype, torch.dtype), ( + f"precisionOverride given unknown dtype {dtype}" + ) self.d = d @@ -1513,12 +1513,12 @@ class toleranceOverride: def __init__(self, d): assert isinstance(d, dict), "toleranceOverride not given a dtype : tol dict!" 
for dtype, prec in d.items(): - assert isinstance( - dtype, torch.dtype - ), f"toleranceOverride given unknown dtype {dtype}" - assert isinstance( - prec, tol - ), "toleranceOverride not given a dtype : tol dict!" + assert isinstance(dtype, torch.dtype), ( + f"toleranceOverride given unknown dtype {dtype}" + ) + assert isinstance(prec, tol), ( + "toleranceOverride not given a dtype : tol dict!" + ) self.d = d @@ -1546,13 +1546,13 @@ class dtypes: "all dtype variants must be. " f"Received non-list non-tuple dtype {str(arg)}" ) - assert all( - isinstance(dtype, torch.dtype) for dtype in arg - ), f"Unknown dtype in {str(arg)}" + assert all(isinstance(dtype, torch.dtype) for dtype in arg), ( + f"Unknown dtype in {str(arg)}" + ) else: - assert all( - isinstance(arg, torch.dtype) for arg in args - ), f"Unknown dtype in {str(args)}" + assert all(isinstance(arg, torch.dtype) for arg in args), ( + f"Unknown dtype in {str(args)}" + ) self.args = args self.device_type = device_type diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index af1aafd3871..0dbb6ca0ea7 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -253,9 +253,9 @@ def verify_ddp_error_logged(model_DDP, err_substr): if err_substr.find("\nException raised from ") == -1 else err_substr.split("\nException raised from ")[0] ) - assert ( - actual in logging_err - ), f"Did not find expected {actual} in ddp logging data error: {logging_err}" + assert actual in logging_err, ( + f"Did not find expected {actual} in ddp logging data error: {logging_err}" + ) def with_nccl_blocking_wait(func): @@ -294,9 +294,9 @@ def with_nccl_blocking_wait(func): finally: # restore old values. if cached_nccl_async_error_handling is not None: - os.environ[ - "TORCH_NCCL_ASYNC_ERROR_HANDLING" - ] = cached_nccl_async_error_handling + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( + cached_nccl_async_error_handling + ) if cached_nccl_blocking_wait is not None: os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait @@ -812,7 +812,7 @@ class MultiProcessTestCase(TestCase): sys.exit(TEST_SKIPS["generic"].exit_code) except Exception: logger.error( - "Caught exception: \n%s exiting " "process %s with exit code: %s", + "Caught exception: \n%s exiting process %s with exit code: %s", traceback.format_exc(), self.rank, MultiProcessTestCase.TEST_ERROR_EXIT_CODE, @@ -1689,9 +1689,7 @@ class MultiProcContinousTest(TestCase): cls.processes.append(process) cls.task_queues.append(task_queue) cls.completion_queues.append(completion_queue) - logger.info( - "Started process %s with pid %s", rank, process.pid - ) # noqa: UP031 + logger.info("Started process %s with pid %s", rank, process.pid) # noqa: UP031 @classmethod def setUpClass(cls): diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index a9e24eb90ef..0e50762893d 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -1285,10 +1285,10 @@ class FSDPTest(MultiProcessTestCase): loss = sharded_grad_scaler.scale(loss) if not mixed_precision and not use_pure_fp16: - assert ( - loss.dtype == torch.float32 - ), "loss data type should be float32, as the original \ + assert loss.dtype == torch.float32, ( + "loss data type should be float32, as the original \ parameter data type is float32." 
+ ) else: if use_pure_fp16: self.assertEqual(loss.dtype, torch.float16) @@ -1354,9 +1354,9 @@ class FSDPTest(MultiProcessTestCase): wrapper should provide data parallel semantics. If ``None``, then the callable defaults to the DDP constructor. """ - assert ( - fsdp_init_mode != FSDPInitMode.NO_FSDP - ), "Expects an FSDP init mode that wraps with FSDP" + assert fsdp_init_mode != FSDPInitMode.NO_FSDP, ( + "Expects an FSDP init mode that wraps with FSDP" + ) if init_kwargs is None: init_kwargs = {} lr = 1e-2 diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 780514e6743..96bab4a084c 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -1268,9 +1268,9 @@ def _get_optim_inputs_including_global_cliquey_kwargs( trivial. That said, we sometimes want to test for all possible configs on an optimizer including all supported flags, so this helper returns all optim inputs. """ - assert all( - x in ["foreach", "fused", "differentiable"] for x in skip - ), "skip must be a subset of ['foreach', 'fused', 'differentiable']" + assert all(x in ["foreach", "fused", "differentiable"] for x in skip), ( + "skip must be a subset of ['foreach', 'fused', 'differentiable']" + ) optim_inputs = optim_info.optim_inputs_func(device) diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index f3a72441f37..4eb6677a035 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -477,7 +477,9 @@ def with_comms( def decorator(func, eager_init: bool = False, backend: Optional[str] = None): @wraps(func) # pyre-ignore[6] def wrapper( - self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc] + self, + *args: tuple[object], + **kwargs: dict[str, Any], # type: ignore[misc] ) -> None: self.init_pg(eager_init, backend) diff --git a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py index 1ac9252d498..61c21be3ca0 100644 --- a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py +++ b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py @@ -253,7 +253,11 @@ class Trainer: else: input_batches = batches - with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext(): + with ( + self.hybrid_module.join() + if simulate_uneven_inputs + else contextlib.nullcontext() + ): for b in input_batches: with dist_autograd.context() as context_id: output = self.hybrid_module.forward(b) @@ -261,8 +265,7 @@ class Trainer: dist_autograd.backward(context_id, [loss]) grads_dict = dist_autograd.get_gradients(context_id) gLogger.info( - "Loss is %s for mini batch: %s. " - "Grads dict has %s entries: %s", + "Loss is %s for mini batch: %s. 
Grads dict has %s entries: %s", loss, mini_batch, len(grads_dict), diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py index 5cd248792dc..97dee3c7c0f 100644 --- a/torch/testing/_internal/opinfo/core.py +++ b/torch/testing/_internal/opinfo/core.py @@ -162,9 +162,7 @@ class SampleInput: # Allow calling either as SampleInput(input, args=args, kwargs=kwargs), or as # SampleInput(input, *args, **kwargs) but not to mix the two forms if args is not None or kwargs is not None: - assert ( - not var_args and not var_kwargs - ), """ + assert not var_args and not var_kwargs, """ A SampleInput can be constructed "naturally" with *args and **kwargs or by explicitly setting the "args" and "kwargs" parameters, but the two methods of construction cannot be mixed!""" @@ -226,7 +224,7 @@ cannot specify additional metadata in keyword arguments""" f"name={repr(self.name)}", ] - return f'SampleInput({", ".join(a for a in arguments if a is not None)})' + return f"SampleInput({', '.join(a for a in arguments if a is not None)})" def __repr__(self): return self._repr_helper(lambda x: x) @@ -1601,13 +1599,11 @@ class SampleRule(ABC): # returns a string identifier of the rule type @abstractmethod - def type(self) -> str: - ... + def type(self) -> str: ... # returns an appropriate context that handles the xfail, skips, etc. @abstractmethod - def get_context(self, test_case): - ... + def get_context(self, test_case): ... # useful for specifying xfails @@ -1791,8 +1787,10 @@ class ReductionOpInfo(OpInfo): # kwargs to use when calling the op. This is required for operators that # have other required parameters besides the input tensor. generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: ( - yield (), - {}, + yield ( + (), + {}, + ) ), # Options from the OpInfo base class **kwargs, @@ -2476,9 +2474,9 @@ class BinaryUfuncInfo(OpInfo): self.supports_one_python_scalar = True if self.supports_one_python_scalar: - assert ( - supports_rhs_python_scalar - ), "Can't support lhs and rhs Python scalars but not rhs scalars!" + assert supports_rhs_python_scalar, ( + "Can't support lhs and rhs Python scalars but not rhs scalars!" + ) # The following functions and classes are for testing elementwise unary operators. 
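One reflowed default in the `ReductionOpInfo` hunk above deserves a note: `generate_args_kwargs` is a lambda containing `yield`, which makes it a generator function, so calling it produces a generator of `(args, kwargs)` pairs rather than a tuple. A standalone sketch of that behavior (not the OpInfo machinery itself):

```python
# A `yield` inside a lambda turns it into a generator function.
generate_args_kwargs = lambda t, dim=None, keepdim=False: (  # noqa: E731
    yield (
        (),
        {},
    )
)

gen = generate_args_kwargs(None)
assert next(gen) == ((), {})  # one pair: empty args, empty kwargs
```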
diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py index e05299632d0..c5d08073803 100644 --- a/torch/testing/_internal/opinfo/definitions/_masked.py +++ b/torch/testing/_internal/opinfo/definitions/_masked.py @@ -102,8 +102,9 @@ def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwar for mask in _generate_masked_op_mask( sample_input.input.shape, device, **kwargs ): - sample_input_args, sample_input_kwargs = sample_input.args, dict( - mask=mask, **sample_input.kwargs + sample_input_args, sample_input_kwargs = ( + sample_input.args, + dict(mask=mask, **sample_input.kwargs), ) yield SampleInput( sample_input.input.detach().requires_grad_(requires_grad), @@ -224,8 +225,9 @@ def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs): op_info, device, dtype, requires_grad, **kwargs ): sample_input_args, sample_input_kwargs = ( - ord, - ) + sample_input.args, sample_input.kwargs.copy() + (ord,) + sample_input.args, + sample_input.kwargs.copy(), + ) yield SampleInput( sample_input.input.clone().requires_grad_(requires_grad), args=sample_input_args, @@ -276,8 +278,9 @@ def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs for mask in _generate_masked_op_mask( sample_input.input.shape, device, **kwargs ): - sample_input_args, sample_input_kwargs = sample_input.args, dict( - mask=mask, **sample_input.kwargs + sample_input_args, sample_input_kwargs = ( + sample_input.args, + dict(mask=mask, **sample_input.kwargs), ) yield SampleInput( sample_input.input.detach().requires_grad_(requires_grad), @@ -364,8 +367,9 @@ def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs) ): if type(mask) != torch.Tensor: continue - sample_input_args, sample_input_kwargs = sample_input.args, dict( - mask=mask, **sample_input.kwargs + sample_input_args, sample_input_kwargs = ( + sample_input.args, + dict(mask=mask, **sample_input.kwargs), ) if "keepdim" in sample_input_kwargs: sample_input_kwargs.pop("keepdim") diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 4ec4e5b5915..811b45fd1d6 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -112,7 +112,7 @@ class _Config(Generic[T]): @staticmethod def string_or_list_of_string_to_list( - val: Optional[Union[str, list[str]]] + val: Optional[Union[str, list[str]]], ) -> Optional[list[str]]: if val is None: return None @@ -135,8 +135,7 @@ if TYPE_CHECKING: env_name_force: Optional[Union[str, list[str]]] = None, value_type: Optional[type] = None, alias: Optional[str] = None, - ) -> T: - ... + ) -> T: ... 
else: @@ -323,9 +322,9 @@ class _ConfigEntry: # Ensure justknobs and envvars are allowlisted types if self.justknob is not None and self.default is not None: - assert isinstance( - self.default, bool - ), f"justknobs only support booleans, {self.default} is not a boolean" + assert isinstance(self.default, bool), ( + f"justknobs only support booleans, {self.default} is not a boolean" + ) if self.value_type is not None and ( config.env_name_default is not None or config.env_name_force is not None ): @@ -334,7 +333,9 @@ class _ConfigEntry: str, Optional[bool], Optional[str], - ), f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither" + ), ( + f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither" + ) class ConfigModule(ModuleType): diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 24c73061b71..5ddda2c7edb 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -282,9 +282,9 @@ def tree_is_leaf( False >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple)) True - >>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3}) + >>> tree_is_leaf({"a": 1, "b": 2, "c": 3}) False - >>> tree_is_leaf({'a': 1, 'b': 2, 'c': None}) + >>> tree_is_leaf({"a": 1, "b": 2, "c": None}) False Args: @@ -586,29 +586,28 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]] # These specializations help with type inference on the lambda passed to this # function @overload -def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: - ... +def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ... @overload -def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: - ... +def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ... @overload -def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]: - ... +def map_only( + type_or_types_or_pred: Type3[T, S, U], / +) -> MapOnlyFn[Fn3[T, S, U, Any]]: ... # This specialization is needed for the implementations below that call @overload -def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: - ... +def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ... @overload -def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]: - ... +def map_only( + type_or_types_or_pred: Callable[[Any], bool], / +) -> MapOnlyFn[FnAny[Any]]: ... def map_only( @@ -664,8 +663,7 @@ def tree_map_only( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -675,8 +673,7 @@ def tree_map_only( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -686,8 +683,7 @@ def tree_map_only( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -697,8 +693,7 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -708,8 +703,7 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... def tree_map_only( @@ -729,8 +723,7 @@ def tree_map_only_( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... 
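The `map_only` and `tree_map_only` churn here is purely stub-body placement: ruff-format moves the `...` of an `@overload` stub onto the signature line. The stubs exist only for the type checker; a compressed sketch of the pattern, with illustrative names rather than the pytree signatures:

```python
from typing import overload


@overload
def double(x: int) -> int: ...
@overload
def double(x: str) -> str: ...


def double(x):
    # Only this body runs; the `...` stub bodies above never execute.
    return x + x
```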
@overload @@ -740,8 +733,7 @@ def tree_map_only_( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -751,8 +743,7 @@ def tree_map_only_( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -762,8 +753,7 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -773,8 +763,7 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... def tree_map_only_( @@ -812,8 +801,7 @@ def tree_all_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -823,8 +811,7 @@ def tree_all_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -834,8 +821,7 @@ def tree_all_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... def tree_all_only( @@ -856,8 +842,7 @@ def tree_any_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -867,8 +852,7 @@ def tree_any_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -878,8 +862,7 @@ def tree_any_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... def tree_any_only( diff --git a/torch/utils/_functools.py b/torch/utils/_functools.py index 40ffd8f80a9..0b555ffc27f 100644 --- a/torch/utils/_functools.py +++ b/torch/utils/_functools.py @@ -12,7 +12,7 @@ _cache_sentinel = object() def cache_method( - f: Callable[Concatenate[_C, _P], _T] + f: Callable[Concatenate[_C, _P], _T], ) -> Callable[Concatenate[_C, _P], _T]: """ Like `@functools.cache` but for methods. diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 664994e6fe3..84353fbbebf 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -302,14 +302,12 @@ class BaseTorchDispatchMode(TorchDispatchMode): # Subtypes which have __tensor_flatten__ and __tensor_unflatten__. class TensorWithFlatten(Protocol): - def __tensor_flatten__(self) -> tuple[Sequence[str], object]: - ... + def __tensor_flatten__(self) -> tuple[Sequence[str], object]: ... @staticmethod def __tensor_unflatten__( inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... # It would be really nice to be able to say that the return of # is_traceable_wrapper_subclass() is Intersection[torch.Tensor, @@ -318,26 +316,20 @@ class TensorWithFlatten(Protocol): shape: torch._C.Size @overload - def stride(self, dim: None = None) -> tuple[int, ...]: - ... + def stride(self, dim: None = None) -> tuple[int, ...]: ... @overload - def stride(self, dim: int) -> int: - ... + def stride(self, dim: int) -> int: ... @overload - def size(self, dim: None = None) -> tuple[int, ...]: - ... + def size(self, dim: None = None) -> tuple[int, ...]: ... @overload - def size(self, dim: int) -> int: - ... + def size(self, dim: int) -> int: ... - def storage_offset(self) -> int: - ... 
+ def storage_offset(self) -> int: ... - def dim(self) -> int: - ... + def dim(self) -> int: ... @overload def to( @@ -347,8 +339,7 @@ class TensorWithFlatten(Protocol): copy: bool = False, *, memory_format: Optional[torch.memory_format] = None, - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... @overload def to( @@ -359,8 +350,7 @@ class TensorWithFlatten(Protocol): copy: bool = False, *, memory_format: Optional[torch.memory_format] = None, - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... @overload def to( @@ -370,8 +360,7 @@ class TensorWithFlatten(Protocol): copy: bool = False, *, memory_format: Optional[torch.memory_format] = None, - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]: diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 3e7cadc6dc7..02954d33866 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -99,17 +99,13 @@ NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND" class KeyEntry(Protocol): - def __hash__(self) -> int: - ... + def __hash__(self) -> int: ... - def __eq__(self, other: object) -> bool: - ... + def __eq__(self, other: object) -> bool: ... - def __str__(self) -> str: - ... + def __str__(self) -> str: ... - def get(self, parent: Any) -> Any: - ... + def get(self, parent: Any) -> Any: ... class EnumEncoder(json.JSONEncoder): @@ -757,7 +753,7 @@ def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]: def _tuple_flatten_with_keys( - d: tuple[T, ...] + d: tuple[T, ...], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _tuple_flatten(d) return [(SequenceKey(i), v) for i, v in enumerate(values)], context @@ -785,7 +781,7 @@ def _dict_flatten(d: dict[Any, T]) -> tuple[list[T], Context]: def _dict_flatten_with_keys( - d: dict[Any, T] + d: dict[Any, T], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _dict_flatten(d) return [(MappingKey(k), v) for k, v in zip(context, values)], context @@ -849,7 +845,7 @@ def _ordereddict_flatten(d: OrderedDict[Any, T]) -> tuple[list[T], Context]: def _ordereddict_flatten_with_keys( - d: OrderedDict[Any, T] + d: OrderedDict[Any, T], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _ordereddict_flatten(d) return [(MappingKey(k), v) for k, v in zip(context, values)], context @@ -872,7 +868,7 @@ def _defaultdict_flatten(d: defaultdict[Any, T]) -> tuple[list[T], Context]: def _defaultdict_flatten_with_keys( - d: defaultdict[Any, T] + d: defaultdict[Any, T], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _defaultdict_flatten(d) _, dict_context = context @@ -1035,9 +1031,9 @@ def tree_is_leaf( False >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple)) True - >>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3}) + >>> tree_is_leaf({"a": 1, "b": 2, "c": 3}) False - >>> tree_is_leaf({'a': 1, 'b': 2, 'c': None}) + >>> tree_is_leaf({"a": 1, "b": 2, "c": None}) False """ if is_leaf is not None and is_leaf(tree): @@ -1346,9 +1342,9 @@ def tree_map( See also :func:`tree_map_`. 
- >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)}) + >>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)}) {'x': 8, 'y': (43, 65)} - >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}) + >>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None}) {'x': False, 'y': (False, False), 'z': True} If multiple inputs are given, the structure of the tree is taken from the first input; @@ -1432,29 +1428,28 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]] # These specializations help with type inference on the lambda passed to this # function @overload -def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: - ... +def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ... @overload -def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: - ... +def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ... @overload -def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]: - ... +def map_only( + type_or_types_or_pred: Type3[T, S, U], / +) -> MapOnlyFn[Fn3[T, S, U, Any]]: ... # This specialization is needed for the implementations below that call @overload -def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: - ... +def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ... @overload -def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]: - ... +def map_only( + type_or_types_or_pred: Callable[[Any], bool], / +) -> MapOnlyFn[FnAny[Any]]: ... def map_only( @@ -1510,8 +1505,7 @@ def tree_map_only( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1521,8 +1515,7 @@ def tree_map_only( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1532,8 +1525,7 @@ def tree_map_only( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1543,8 +1535,7 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1554,8 +1545,7 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... def tree_map_only( @@ -1575,8 +1565,7 @@ def tree_map_only_( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1586,8 +1575,7 @@ def tree_map_only_( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1597,8 +1585,7 @@ def tree_map_only_( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1608,8 +1595,7 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1619,8 +1605,7 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... def tree_map_only_( @@ -1658,8 +1643,7 @@ def tree_all_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... 
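The `tree_map` docstring earlier in this file's hunks notes that with multiple inputs the tree structure is taken from the first input. A quick illustration of that multi-input form; `torch.utils._pytree` is a private module, so treat this as a sketch subject to change:

```python
import torch.utils._pytree as pytree

lhs = {"x": 7, "y": (42, 64)}
rhs = {"x": 1, "y": (2, 3)}
# func receives one leaf from each tree; the structures must match.
assert pytree.tree_map(lambda a, b: a + b, lhs, rhs) == {"x": 8, "y": (44, 67)}
```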
@overload @@ -1669,8 +1653,7 @@ def tree_all_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -1680,8 +1663,7 @@ def tree_all_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... def tree_all_only( @@ -1702,8 +1684,7 @@ def tree_any_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -1713,8 +1694,7 @@ def tree_any_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -1724,8 +1704,7 @@ def tree_any_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... def tree_any_only( @@ -1862,7 +1841,7 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: raise NotImplementedError( - f'Deserializing {json_schema["type"]} in pytree is not registered.', + f"Deserializing {json_schema['type']} in pytree is not registered.", ) typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]] diff --git a/torch/utils/_strobelight/cli_function_profiler.py b/torch/utils/_strobelight/cli_function_profiler.py index 39e981a78ac..9b94a7b7a48 100644 --- a/torch/utils/_strobelight/cli_function_profiler.py +++ b/torch/utils/_strobelight/cli_function_profiler.py @@ -301,7 +301,7 @@ def strobelight( profiler = StrobelightCLIFunctionProfiler(**kwargs) def strobelight_inner( - work_function: Callable[_P, _R] + work_function: Callable[_P, _R], ) -> Callable[_P, Optional[_R]]: @functools.wraps(work_function) def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 42c99839d41..2b6c159f5c3 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -98,7 +98,7 @@ def _is_symbols_binary_summation(expr: sympy.Expr) -> bool: def _keep_float( - f: Callable[[Unpack[_Ts]], _T] + f: Callable[[Unpack[_Ts]], _T], ) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]: @functools.wraps(f) def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]: @@ -926,10 +926,12 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc] _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) # noqa: E731 _eval_is_antihermitian = lambda s: _torf( # noqa: E731 - i.is_antihermitian for i in s.args # noqa: E731 + i.is_antihermitian + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_commutative = lambda s: _torf( # noqa: E731 - i.is_commutative for i in s.args # noqa: E731 + i.is_commutative + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_complex = lambda s: _torf(i.is_complex for i in s.args) # noqa: E731 _eval_is_composite = lambda s: _torf(i.is_composite for i in s.args) # noqa: E731 @@ -943,10 +945,12 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc] _eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731 _eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731 _eval_is_nonnegative = lambda s: _torf( # noqa: E731 - i.is_nonnegative for i in s.args # noqa: E731 + i.is_nonnegative + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_nonpositive = lambda s: _torf( # noqa: E731 - i.is_nonpositive for i in s.args # noqa: E731 + i.is_nonpositive + for i in s.args 
# noqa: E731 ) # noqa: E731 _eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731 _eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731 @@ -956,10 +960,12 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc] _eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731 _eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731 _eval_is_extended_real = lambda s: _torf( # noqa: E731 - i.is_extended_real for i in s.args # noqa: E731 + i.is_extended_real + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_transcendental = lambda s: _torf( # noqa: E731 - i.is_transcendental for i in s.args # noqa: E731 + i.is_transcendental + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731 diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 1b360337a53..e02e049cc36 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -144,16 +144,14 @@ class ValueRanges(Generic[_T]): self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn, - ) -> None: - ... + ) -> None: ... @overload def __init__( # type: ignore[misc] self: ValueRanges[SympyBoolean], lower: BoolIn, upper: BoolIn, - ) -> None: - ... + ) -> None: ... def __init__(self, lower: AllIn, upper: AllIn) -> None: lower = simple_sympify(lower) @@ -240,15 +238,13 @@ class ValueRanges(Generic[_T]): def __and__( self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr], - ) -> ValueRanges[sympy.Expr]: - ... + ) -> ValueRanges[sympy.Expr]: ... @overload def __and__( # type: ignore[misc] self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean], - ) -> ValueRanges[SympyBoolean]: - ... + ) -> ValueRanges[SympyBoolean]: ... def __and__(self: AllVR, other: AllVR) -> AllVR: if other in (ValueRanges.unknown(), ValueRanges.unknown_int()): @@ -272,15 +268,13 @@ class ValueRanges(Generic[_T]): def __or__( self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr], - ) -> ValueRanges[sympy.Expr]: - ... + ) -> ValueRanges[sympy.Expr]: ... @overload def __or__( # type: ignore[misc] self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean], - ) -> ValueRanges[SympyBoolean]: - ... + ) -> ValueRanges[SympyBoolean]: ... def __or__(self: AllVR, other: AllVR) -> AllVR: if ValueRanges.unknown() in (self, other): @@ -343,8 +337,7 @@ class ValueRanges(Generic[_T]): @overload @staticmethod - def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: - ... + def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: ... @overload @staticmethod @@ -384,8 +377,7 @@ class ValueRanges(Generic[_T]): x: Union[ExprIn, ExprVR], y: Union[ExprIn, ExprVR], fn: ExprFn2, - ) -> ExprVR: - ... + ) -> ExprVR: ... @overload @staticmethod @@ -393,8 +385,7 @@ class ValueRanges(Generic[_T]): x: Union[BoolIn, BoolVR], y: Union[BoolIn, BoolVR], fn: BoolFn2, - ) -> BoolVR: - ... + ) -> BoolVR: ... @staticmethod def coordinatewise_increasing_map( diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py index e11a7afc09d..5a83aede8d4 100644 --- a/torch/utils/backend_registration.py +++ b/torch/utils/backend_registration.py @@ -426,9 +426,9 @@ def _get_custom_mod_func(func_name: str): it is marked as private. It is a convenience function for backend implementers to more easily call the hooks into their backend extensions. """ - assert isinstance( - func_name, str - ), f"func_name must be `str`, but got `{type(func_name)}`." 
+    assert isinstance(func_name, str), (
+        f"func_name must be `str`, but got `{type(func_name)}`."
+    )
     backend_name = _get_privateuse1_backend_name()
     custom_device_mod = getattr(torch, backend_name, None)  # type: ignore[arg-type]
     function = getattr(custom_device_mod, func_name, None)  # type: ignore[arg-type]
diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py
index 68a4da0731c..3b291b1e60a 100644
--- a/torch/utils/data/_utils/collate.py
+++ b/torch/utils/data/_utils/collate.py
@@ -44,7 +44,7 @@ def default_convert(data):
         >>> default_convert(np.array([0, 1]))
         tensor([0, 1])
         >>> # Example with NamedTuple
-        >>> Point = namedtuple('Point', ['x', 'y'])
+        >>> Point = namedtuple("Point", ["x", "y"])
         >>> default_convert(Point(0, 0))
         Point(x=0, y=0)
         >>> default_convert(Point(np.array(0), np.array(0)))
@@ -366,13 +366,13 @@ def default_collate(batch):
         >>> default_collate([0, 1, 2, 3])
         tensor([0, 1, 2, 3])
         >>> # Example with a batch of `str`s:
-        >>> default_collate(['a', 'b', 'c'])
+        >>> default_collate(["a", "b", "c"])
         ['a', 'b', 'c']
         >>> # Example with `Map` inside the batch:
-        >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
+        >>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}])
         {'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
         >>> # Example with `NamedTuple` inside the batch:
-        >>> Point = namedtuple('Point', ['x', 'y'])
+        >>> Point = namedtuple("Point", ["x", "y"])
         >>> default_collate([Point(0, 0), Point(1, 1)])
         Point(x=tensor([0, 1]), y=tensor([0, 1]))
         >>> # Example with `Tuple` inside the batch:
diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py
index c75756dd5fd..b53c7aef959 100644
--- a/torch/utils/data/_utils/pin_memory.py
+++ b/torch/utils/data/_utils/pin_memory.py
@@ -69,7 +69,9 @@ def pin_memory(data, device=None):
                 )
                 return clone
             else:
-                return type(data)({k: pin_memory(sample, device) for k, sample in data.items()})  # type: ignore[call-arg]
+                return type(data)(
+                    {k: pin_memory(sample, device) for k, sample in data.items()}
+                )  # type: ignore[call-arg]
         except TypeError:
             # The mapping type may not support `copy()` / `update(mapping)`
             # or `__init__(iterable)`.
diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py
index a275e2e86b6..97c7243e78e 100644
--- a/torch/utils/data/_utils/worker.py
+++ b/torch/utils/data/_utils/worker.py
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
+r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
 
 These **need** to be in global scope since Py2 doesn't support serializing
 static methods.
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index dd7a73ea11e..991b4f00eb8 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -5,6 +5,7 @@
 To support these two classes, in `./_utils` we define many utility methods and
 functions to be run in multiprocessing. E.g., the data loading worker loop is
 in `./_utils/worker.py`.
""" + from __future__ import annotations import functools @@ -1208,7 +1209,10 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w) # .pid can be None only before process is spawned (not the case, so ignore) - _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] + _utils.signal_handling._set_worker_pids( + id(self), + tuple(w.pid for w in self._workers), # type: ignore[misc] + ) _utils.signal_handling._set_SIGCHLD_handler() self._worker_pids_set = True self._reset(loader, first_iter=True) diff --git a/torch/utils/data/datapipes/_decorator.py b/torch/utils/data/datapipes/_decorator.py index 13e28a19d62..0833f8fdf75 100644 --- a/torch/utils/data/datapipes/_decorator.py +++ b/torch/utils/data/datapipes/_decorator.py @@ -109,8 +109,7 @@ class non_deterministic: # Decorate with a functional argument if not ( - isinstance(args[0], type) - and issubclass(args[0], IterDataPipe) # type: ignore[arg-type] + isinstance(args[0], type) and issubclass(args[0], IterDataPipe) # type: ignore[arg-type] ): raise TypeError( f"Only `IterDataPipe` can be decorated, but {args[0].__name__} is found" diff --git a/torch/utils/data/datapipes/datapipe.py b/torch/utils/data/datapipes/datapipe.py index d3eeee0ebfd..506f642c411 100644 --- a/torch/utils/data/datapipes/datapipe.py +++ b/torch/utils/data/datapipes/datapipe.py @@ -99,7 +99,9 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): >>> from torchdata.datapipes.iter import IterableWrapper, Mapper >>> dp = IterableWrapper(range(10)) >>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor - >>> map_dp_2 = dp.map(lambda x: x + 1) # Using functional form (recommended) + >>> map_dp_2 = dp.map( + ... lambda x: x + 1 + ... ) # Using functional form (recommended) >>> list(map_dp_1) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] >>> list(map_dp_2) @@ -114,7 +116,9 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): >>> list(it1) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] >>> it1 = iter(source_dp) - >>> it2 = iter(source_dp) # The creation of a new iterator invalidates `it1` + >>> it2 = iter( + ... source_dp + ... ) # The creation of a new iterator invalidates `it1` >>> next(it2) 0 >>> next(it1) # Further usage of `it1` will raise a `RunTimeError` diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index 718e728c938..41c6bb362af 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -55,7 +55,8 @@ class MapperIterDataPipe(IterDataPipe[_T_co]): >>> def add_one(x): ... return x + 1 >>> dp = IterableWrapper(range(10)) - >>> map_dp_1 = dp.map(add_one) # Invocation via functional form is preferred + >>> # Invocation via functional form is preferred + ... map_dp_1 = dp.map(add_one) >>> list(map_dp_1) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle` @@ -202,7 +203,7 @@ class CollatorIterDataPipe(MapperIterDataPipe): >>> class MyIterDataPipe(torch.utils.data.IterDataPipe): ... def __init__(self, start, end): ... super(MyIterDataPipe).__init__() - ... assert end > start, "this example code only works with end >= start" + ... assert end > start, "this example only works with end >= start" ... self.start = start ... self.end = end ... @@ -211,13 +212,11 @@ class CollatorIterDataPipe(MapperIterDataPipe): ... ... def __len__(self): ... 
         ...         return self.end - self.start
-        ...
         >>> ds = MyIterDataPipe(start=3, end=7)
         >>> print(list(ds))
         [3, 4, 5, 6]
         >>> def collate_fn(batch):
         ...     return torch.tensor(batch, dtype=torch.float)
-        ...
         >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn)
         >>> print(list(collated_ds))
         [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
diff --git a/torch/utils/data/datapipes/iter/combinatorics.py b/torch/utils/data/datapipes/iter/combinatorics.py
index 4c602ce4eed..f92edd6b7b3 100644
--- a/torch/utils/data/datapipes/iter/combinatorics.py
+++ b/torch/utils/data/datapipes/iter/combinatorics.py
@@ -38,15 +38,17 @@ class SamplerIterDataPipe(IterDataPipe[_T_co]):
         sampler_args: Optional[tuple] = None,
         sampler_kwargs: Optional[dict] = None,
     ) -> None:
-        assert isinstance(
-            datapipe, Sized
-        ), "Sampler class requires input datapipe implemented `__len__`"
+        assert isinstance(datapipe, Sized), (
+            "Sampler class requires an input datapipe that implements `__len__`"
+        )
         super().__init__()
         self.datapipe = datapipe
         self.sampler_args = () if sampler_args is None else sampler_args
         self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
         # https://github.com/python/mypy/pull/9629 will solve
-        self.sampler = sampler(*self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs)  # type: ignore[misc]
+        self.sampler = sampler(
+            *self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs
+        )  # type: ignore[misc]
 
     def __iter__(self) -> Iterator[_T_co]:
         return iter(self.sampler)
diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py
index deaca079c68..8c6abc50621 100644
--- a/torch/utils/data/datapipes/iter/combining.py
+++ b/torch/utils/data/datapipes/iter/combining.py
@@ -116,16 +116,13 @@ class _ContainerTemplate(ABC):
     r"""Abstract class for container ``DataPipes``. The following are the three required methods."""
 
     @abstractmethod
-    def get_next_element_by_instance(self, instance_id: int):
-        ...
+    def get_next_element_by_instance(self, instance_id: int): ...
 
     @abstractmethod
-    def is_every_instance_exhausted(self) -> bool:
-        ...
+    def is_every_instance_exhausted(self) -> bool: ...
 
     @abstractmethod
-    def reset(self) -> None:
-        ...
+    def reset(self) -> None: ...
 
     @abstractmethod
     def get_length_by_instance(self, instance_id: int):
@@ -403,7 +400,9 @@ class DemultiplexerIterDataPipe(IterDataPipe):
         >>> # It can also filter out any element that gets `None` from the `classifier_fn`
         >>> def odd_or_even_no_zero(n):
         ...     return n % 2 if n != 0 else None
-        >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True)
+        >>> dp1, dp2 = source_dp.demux(
+        ...     num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True
+        ... )
         >>> list(dp1)
         [2, 4]
         >>> list(dp2)
@@ -428,7 +427,9 @@ class DemultiplexerIterDataPipe(IterDataPipe):
         # When num_instances == 1, demux can be replaced by filter,
         # but keep it as Demultiplexer for the sake of consistency
         # like throwing Error when classification result is out of range
-        container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size)  # type: ignore[abstract]
+        container = _DemultiplexerIterDataPipe(
+            datapipe, num_instances, classifier_fn, drop_none, buffer_size
+        )  # type: ignore[abstract]
         return [_ChildDataPipe(container, i) for i in range(num_instances)]
@@ -602,16 +603,18 @@ class MultiplexerIterDataPipe(IterDataPipe):
     Example:
         >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
-        >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
+        >>> dp1, dp2, dp3 = (
+        ...     IterableWrapper(range(3)),
+        ...     IterableWrapper(range(10, 15)),
+        ...     IterableWrapper(range(20, 25)),
+        ... )
         >>> list(dp1.mux(dp2, dp3))
         [0, 10, 20, 1, 11, 21, 2, 12, 22]
     """
 
     def __init__(self, *datapipes):
         self.datapipes = datapipes
-        self.buffer: list = (
-            []
-        )  # Store values to be yielded only when every iterator provides one
+        self.buffer: list = []  # Store values to be yielded only when every iterator provides one
 
     def __iter__(self):
         iterators = [iter(x) for x in self.datapipes]
@@ -670,7 +673,11 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]):
     Example:
         >>> # xdoctest: +REQUIRES(module:torchdata)
         >>> from torchdata.datapipes.iter import IterableWrapper
-        >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
+        >>> dp1, dp2, dp3 = (
+        ...     IterableWrapper(range(5)),
+        ...     IterableWrapper(range(10, 15)),
+        ...     IterableWrapper(range(20, 25)),
+        ... )
         >>> list(dp1.zip(dp2, dp3))
         [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
     """
diff --git a/torch/utils/data/datapipes/iter/fileopener.py b/torch/utils/data/datapipes/iter/fileopener.py
index 2542c89773b..3025b809e12 100644
--- a/torch/utils/data/datapipes/iter/fileopener.py
+++ b/torch/utils/data/datapipes/iter/fileopener.py
@@ -33,8 +33,12 @@ class FileOpenerIterDataPipe(IterDataPipe[tuple[str, IOBase]]):
 
     Example:
         >>> # xdoctest: +SKIP
-        >>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader
-        >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt'))
+        >>> from torchdata.datapipes.iter import (
+        ...     FileLister,
+        ...     FileOpener,
+        ...     StreamReader,
+        ... )
+        >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith(".txt"))
         >>> dp = FileOpener(dp)
         >>> dp = StreamReader(dp)
         >>> list(dp)
diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py
index 08d124fdc60..055d9c28b09 100644
--- a/torch/utils/data/datapipes/iter/grouping.py
+++ b/torch/utils/data/datapipes/iter/grouping.py
@@ -182,7 +182,9 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
         >>> from torchdata.datapipes.iter import IterableWrapper
         >>> def group_fn(file):
         ...     return os.path.basename(file).split(".")[0]
-        >>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
+        >>> source_dp = IterableWrapper(
+        ...     ["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"]
+        ... )
         >>> dp0 = source_dp.groupby(group_key_fn=group_fn)
         >>> list(dp0)
         [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
@@ -191,7 +193,12 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
         >>> list(dp1)
         [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
         >>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size`
-        >>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
+        >>> dp2 = source_dp.groupby(
+        ...     group_key_fn=group_fn,
+        ...     buffer_size=3,
+        ...     group_size=3,
+        ...     guaranteed_group_size=2,
+        ... )
         >>> list(dp2)
         [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
     """
diff --git a/torch/utils/data/datapipes/map/utils.py b/torch/utils/data/datapipes/map/utils.py
index 02865e8064f..e1290df3237 100644
--- a/torch/utils/data/datapipes/map/utils.py
+++ b/torch/utils/data/datapipes/map/utils.py
@@ -31,8 +31,8 @@ class SequenceWrapperMapDataPipe(MapDataPipe[_T]):
         >>> dp = SequenceWrapper(range(10))
         >>> list(dp)
         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
-        >>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
-        >>> dp['a']
+        >>> dp = SequenceWrapper({"a": 100, "b": 200, "c": 300, "d": 400})
+        >>> dp["a"]
         100
     """
diff --git a/torch/utils/data/datapipes/utils/decoder.py b/torch/utils/data/datapipes/utils/decoder.py
index ee5bee8f152..9db7309bdc5 100644
--- a/torch/utils/data/datapipes/utils/decoder.py
+++ b/torch/utils/data/datapipes/utils/decoder.py
@@ -45,8 +45,8 @@ def basichandlers(extension: str, data):
 
     Example:
         >>> import pickle
-        >>> data = pickle.dumps('some data')
-        >>> new_data = basichandlers('pickle', data)
+        >>> data = pickle.dumps("some data")
+        >>> new_data = basichandlers("pickle", data)
         >>> new_data
         'some data'
 
@@ -169,9 +169,9 @@ class ImageHandler:
     """
 
     def __init__(self, imagespec):
-        assert imagespec in list(
-            imagespecs.keys()
-        ), f"unknown image specification: {imagespec}"
+        assert imagespec in list(imagespecs.keys()), (
+            f"unknown image specification: {imagespec}"
+        )
         self.imagespec = imagespec.lower()
 
     def __call__(self, extension, data):
@@ -205,18 +205,18 @@ class ImageHandler:
             return img
         elif atype == "numpy":
             result = np.asarray(img)
-            assert (
-                result.dtype == np.uint8
-            ), f"numpy image array should be type uint8, but got {result.dtype}"
+            assert result.dtype == np.uint8, (
+                f"numpy image array should be type uint8, but got {result.dtype}"
+            )
             if etype == "uint8":
                 return result
             else:
                 return result.astype("f") / 255.0
         elif atype == "torch":
             result = np.asarray(img)
-            assert (
-                result.dtype == np.uint8
-            ), f"numpy image array should be type uint8, but got {result.dtype}"
+            assert result.dtype == np.uint8, (
+                f"numpy image array should be type uint8, but got {result.dtype}"
+            )
 
             if etype == "uint8":
                 result = np.array(result.transpose(2, 0, 1))
diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py
index d0234c553ce..e8164e015a6 100644
--- a/torch/utils/data/dataset.py
+++ b/torch/utils/data/dataset.py
@@ -96,7 +96,7 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
         >>> class MyIterableDataset(torch.utils.data.IterableDataset):
         ...     def __init__(self, start, end):
         ...         super(MyIterableDataset).__init__()
-        ...         assert end > start, "this example code only works with end >= start"
+        ...         assert end > start, "this example only works with end > start"
         ...         self.start = start
         ...         self.end = end
         ...
@@ -138,7 +138,7 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
         >>> class MyIterableDataset(torch.utils.data.IterableDataset):
         ...     def __init__(self, start, end):
         ...         super(MyIterableDataset).__init__()
-        ...         assert end > start, "this example code only works with end >= start"
+        ...         assert end > start, "this example only works with end > start"
         ...         self.start = start
         ...         self.end = end
         ...
@@ -198,9 +198,9 @@ class TensorDataset(Dataset[tuple[Tensor, ...]]):
     tensors: tuple[Tensor, ...]
 
     def __init__(self, *tensors: Tensor) -> None:
-        assert all(
-            tensors[0].size(0) == tensor.size(0) for tensor in tensors
-        ), "Size mismatch between tensors"
+        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), (
+            "Size mismatch between tensors"
+        )
         self.tensors = tensors
 
     def __getitem__(self, index):
@@ -222,7 +222,7 @@ class StackDataset(Dataset[_T_stack]):
         >>> tuple_stack = StackDataset(images, texts)
         >>> tuple_stack[0] == (images[0], texts[0])
         >>> dict_stack = StackDataset(image=images, text=texts)
-        >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}
+        >>> dict_stack[0] == {"image": images[0], "text": texts[0]}
 
     Args:
         *args (Dataset): Datasets for stacking returned as tuple.
@@ -323,9 +323,9 @@ class ConcatDataset(Dataset[_T_co]):
         self.datasets = list(datasets)
         assert len(self.datasets) > 0, "datasets should not be an empty iterable"  # type: ignore[arg-type]
         for d in self.datasets:
-            assert not isinstance(
-                d, IterableDataset
-            ), "ConcatDataset does not support IterableDataset"
+            assert not isinstance(d, IterableDataset), (
+                "ConcatDataset does not support IterableDataset"
+            )
         self.cumulative_sizes = self.cumsum(self.datasets)
 
     def __len__(self):
@@ -371,17 +371,17 @@ class ChainDataset(IterableDataset):
 
     def __iter__(self):
        for d in self.datasets:
-            assert isinstance(
-                d, IterableDataset
-            ), "ChainDataset only supports IterableDataset"
+            assert isinstance(d, IterableDataset), (
+                "ChainDataset only supports IterableDataset"
+            )
             yield from d
 
     def __len__(self):
         total = 0
         for d in self.datasets:
-            assert isinstance(
-                d, IterableDataset
-            ), "ChainDataset only supports IterableDataset"
+            assert isinstance(d, IterableDataset), (
+                "ChainDataset only supports IterableDataset"
+            )
             total += len(d)  # type: ignore[arg-type]
         return total
diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py
index c92bdbb00e1..6c2e6dcaf2f 100644
--- a/torch/utils/data/sampler.py
+++ b/torch/utils/data/sampler.py
@@ -236,9 +236,17 @@ class WeightedRandomSampler(Sampler[int]):
 
     Example:
         >>> # xdoctest: +IGNORE_WANT("non-deterministic")
-        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
+        >>> list(
+        ...     WeightedRandomSampler(
+        ...         [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True
+        ...     )
+        ... )
         [4, 4, 1, 4, 5]
-        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
+        >>> list(
+        ...     WeightedRandomSampler(
+        ...         [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False
+        ...     )
+        ... )
         [0, 1, 4, 3, 2]
     """
 
@@ -298,9 +306,15 @@ class BatchSampler(Sampler[list[int]]):
             its size would be less than ``batch_size``
 
     Example:
-        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
+        >>> list(
+        ...     BatchSampler(
+        ...         SequentialSampler(range(10)), batch_size=3, drop_last=False
+        ...     )
+        ... )
         [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
-        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
+        >>> list(
+        ...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
+        ... )
         [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
     """
diff --git a/torch/utils/module_tracker.py b/torch/utils/module_tracker.py
index 8ac97f2e2e8..4c7dec04815 100644
--- a/torch/utils/module_tracker.py
+++ b/torch/utils/module_tracker.py
@@ -49,6 +49,7 @@ class ModuleTracker:
         def my_linear(m1, m2, bias):
             print(f"Current modules: {tracker.parents}")
             return torch.mm(m1, m2.t()) + bias
+
         torch.nn.functional.linear = my_linear
 
         mod(torch.rand(2, 2))
diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py
index 23e3a25c90f..9a4ade5e71e 100644
--- a/torch/xpu/__init__.py
+++ b/torch/xpu/__init__.py
@@ -6,6 +6,7 @@ Intel GPU optimization.
 This package is lazily initialized, so you can always import it, and use
 :func:`is_available()` to determine if your system supports XPU.
 """
+
 import threading
 import traceback
 from functools import lru_cache
@@ -292,6 +293,7 @@ class StreamContext:
         ``None``.
     .. note:: Streams are per-device.
     """
+
     cur_stream: Optional["torch.xpu.Stream"]
 
     def __init__(self, stream: Optional["torch.xpu.Stream"]):
@@ -438,7 +440,7 @@ def get_gencode_flags() -> str:
     arch_list = get_arch_list()
     if len(arch_list) == 0:
         return ""
-    return f'-device {",".join(arch for arch in arch_list)}'
+    return f"-device {','.join(arch for arch in arch_list)}"
 
 
 def _get_generator(device: torch.device) -> torch._C.Generator: