[BE][PYFMT] migrate PYFMT for torch/[p-z]*/ to ruff format (#144552)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144552
Approved by: https://github.com/ezyang
Author: Xuehai Pan
Date: 2025-08-06 20:57:29 +00:00
Committed by: PyTorch MergeBot
Parent: fd606a3a91
Commit: 5cedc5a0ff
65 changed files with 446 additions and 522 deletions

View File

@@ -52,7 +52,6 @@ USE_BLACK_FILELIST = re.compile(
         # torch/[e-m]*/**
         # torch/optim/**
         # torch/[p-z]*/**
-        "torch/[p-z]*/**",
     ],
 ),
 )

View File

@@ -2,6 +2,7 @@
 """Import mangling.
 See mangling.md for details.
 """
+
 import re

View File

@@ -605,9 +605,9 @@ class PackageExporter:
             dependencies (bool, optional): If ``True``, we scan the source for dependencies.
         """
-        assert (pickle_protocol == 4) or (
-            pickle_protocol == 3
-        ), "torch.package only supports pickle protocols 3 and 4"
+        assert (pickle_protocol == 4) or (pickle_protocol == 3), (
+            "torch.package only supports pickle protocols 3 and 4"
+        )
         filename = self._filename(package, resource)
         # Write the pickle data for `obj`
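The hunk above is the signature transformation of this migration, and it repeats throughout the commit: Black wraps a long assert by splitting the condition and letting the message trail the closing parenthesis, whereas ruff format keeps the condition intact and parenthesizes the message instead. A minimal standalone sketch of the two styles (not the actual source):

```python
pickle_protocol = 4

# Black style: condition split across lines, message trailing the parenthesis
assert (pickle_protocol == 4) or (
    pickle_protocol == 3
), "torch.package only supports pickle protocols 3 and 4"

# ruff format style: condition on one line, message parenthesized
assert (pickle_protocol == 4) or (pickle_protocol == 3), (
    "torch.package only supports pickle protocols 3 and 4"
)
```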

View File

@@ -423,7 +423,12 @@ class PackageImporter(Importer):
             module.__dict__.setdefault(old_name, new_name)
             return module

-        return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent)  # type: ignore[attr-defined]
+        return self._make_module(
+            name,
+            cur.source_file,  # type: ignore[attr-defined]
+            isinstance(cur, _PackageNode),
+            parent,
+        )

     def _compile_source(self, fullpath: str, mangled_filename: str):
         source = self.zip_reader.get_record(fullpath)
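When a call no longer fits within the line limit, ruff format explodes it one argument per line with a trailing comma, which also lets a narrowing comment such as `# type: ignore[attr-defined]` sit on the exact argument it covers (whether the comment was relocated by the formatter or by hand is not visible from the diff). A hypothetical sketch of the same shape:

```python
def make_module(name, source_file, is_package, parent):  # hypothetical helper
    return (name, source_file, is_package, parent)


module = make_module(
    "pkg.mod",
    "pkg/mod.py",  # a per-argument comment can now target just this line
    True,
    None,
)
```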

View File

@@ -7,6 +7,7 @@ examine their input shapes and stack traces, study device kernel activity and vi
 An earlier version of the API in :mod:`torch.autograd` module is considered legacy and will be deprecated.
 """
+
 import os
 from typing import Any
 from typing_extensions import TypeVarTuple, Unpack

View File

@@ -239,10 +239,12 @@ class SchemaMatcher:
     def match_schemas(cls, t: _ExtraFields_TorchOp) -> tuple[FunctionSchema, ...]:
         signature = tuple(
             # Tensor
-            TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata)
+            TensorKey.from_tensor(i)
+            if isinstance(i, _TensorMetadata)
             #
             # TensorList
-            else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list)
+            else [TensorKey.from_tensor(j) for j in i]
+            if isinstance(i, list)
             #
             # Scalar and uncaptured inputs.
             else i

View File

@@ -124,9 +124,9 @@ class BasicEvaluation:
             for child_event in curr_event.children:
                 self_time -= child_event.duration_time_ns
                 stack.append(child_event)
-            assert (
-                EventKey(curr_event) not in self.metrics
-            ), f"Duplicate id: {curr_event.id}, {curr_event.name}"
+            assert EventKey(curr_event) not in self.metrics, (
+                f"Duplicate id: {curr_event.id}, {curr_event.name}"
+            )
             self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time)
             self.metrics[
                 EventKey(curr_event)

@@ -227,8 +227,7 @@ class BasicEvaluation:
             while (
                 current_kernel_index < len(cuda_kernel_events)
-                and (cuda_kernel_events[current_kernel_index].start_ns())
-                <= start_time  # type: ignore[possibly-undefined]
+                and (cuda_kernel_events[current_kernel_index].start_ns()) <= start_time  # type: ignore[possibly-undefined]
             ):
                 current_kernel_index += 1
             current_queue_depth = spawned_kernel_index - current_kernel_index + 1

@@ -352,11 +351,11 @@ class BasicEvaluation:
         output += "\n".join(
             [
-                f"""{'-' * 80}
+                f"""{"-" * 80}
 Event: {event}
 Source code location: {source_code_location(event.event)}
 Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}%
-{'-' * 80}"""
+{"-" * 80}"""
                 for event in event_list
             ]
         )
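The `{'-' * 80}` to `{"-" * 80}` edits above show ruff format normalizing string quotes inside f-string replacement fields to the preferred double quotes. That is safe here because the outer delimiter is triple-quoted, so the nested double quotes cannot terminate the literal. A small self-contained check:

```python
before = f"""{'-' * 80}
Event: demo
{'-' * 80}"""

after = f"""{"-" * 80}
Event: demo
{"-" * 80}"""

# Only the quoting changed, not the resulting string.
assert before == after
```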

View File

@@ -624,8 +624,7 @@ class profile(_KinetoProfile):
                 ]
             ) as p:
                 code_to_profile()
-            print(p.key_averages().table(
-                sort_by="self_cuda_time_total", row_limit=-1))
+            print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

     Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:

@@ -635,16 +634,17 @@ class profile(_KinetoProfile):
         # on different iterations of the training loop;
         # trace_handler is called every time a new trace becomes available
         def trace_handler(prof):
-            print(prof.key_averages().table(
-                sort_by="self_cuda_time_total", row_limit=-1))
+            print(
+                prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)
+            )
             # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")

         with torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
                 torch.profiler.ProfilerActivity.CUDA,
             ],
             # In this example with wait=1, warmup=1, active=2, repeat=1,
             # profiler will skip the first step/iteration,
             # start warming up on the second, record

@@ -652,20 +652,15 @@ class profile(_KinetoProfile):
             # after which the trace will become available
             # and on_trace_ready (when set) is called;
             # the cycle repeats starting with the next step
-            schedule=torch.profiler.schedule(
-                wait=1,
-                warmup=1,
-                active=2,
-                repeat=1),
-            on_trace_ready=trace_handler
+            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
+            on_trace_ready=trace_handler,
             # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
             # used when outputting for tensorboard
         ) as p:
             for iter in range(N):
                 code_iteration_to_profile(iter)
                 # send a signal to the profiler that the next iteration has started
                 p.step()

     The following sample shows how to setup up an Execution Trace Observer (`execution_trace_observer`)

View File

@@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
 `torch/ao/quantization/fuser_method_mappings.py`, while adding an import statement
 here.
 """
+
 from torch.ao.quantization.fuser_method_mappings import (
     _DEFAULT_OP_LIST_TO_FUSER_METHOD,
     fuse_conv_bn,

View File

@@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
 here.
 """
+
 from torch.ao.quantization.fx._equalize import (
     _convert_equalization_ref,
     _InputEqualizationObserver,

View File

@@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
 here.
 """
+
 from torch.ao.quantization.fx.convert import convert

View File

@@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
 here.
 """
+
 from torch.ao.quantization.fx.fuse import fuse

View File

@@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
 here.
 """
+
 from torch.ao.quantization.fx.fuse_handler import DefaultFuseHandler, FuseHandler

View File

@@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
 here.
 """
+
 from torch.ao.quantization.fx.graph_module import (
     _is_observed_module,
     _is_observed_standalone_module,

View File

@@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
 here.
 """
+
 from torch.ao.quantization.fx.match_utils import (
     _find_matches,
     _is_match,

View File

@@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
 here.
 """
+
 from torch.ao.quantization.fx.pattern_utils import (
     _register_fusion_pattern,
     _register_quant_pattern,

View File

@@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
 here.
 """
+
 from torch.ao.quantization.fx.prepare import prepare

View File

@@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
 here.
 """
+
 from torch.ao.quantization.fx.quantize_handler import (
     BatchNormQuantizeHandler,
     BinaryOpQuantizeHandler,

View File

@@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
 here.
 """
+
 from torch.ao.quantization.utils import Pattern, QuantizerCls

View File

@@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
 here.
 """
+
 from torch.ao.quantization.fx.utils import (
     all_node_args_have_no_tensors,
     assert_and_get_unique_device,

View File

@@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
 `torch/ao/quantization/observer.py`, while adding an import statement
 here.
 """
+
 from torch.ao.quantization.observer import (
     _is_activation_post_process,
     _is_per_channel_script_obs_instance,

View File

@@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
 `torch/ao/quantization/qconfig.py`, while adding an import statement
 here.
 """
+
 from torch.ao.quantization.qconfig import (
     _add_module_to_qconfig_obs_ctr,
     _assert_valid_qconfig,

View File

@@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
 `torch/ao/quantization/quantization_mappings.py`, while adding an import statement
 here.
 """
+
 from torch.ao.quantization.quantization_mappings import (
     _get_special_act_post_process,
     _has_special_act_post_process,

View File

@@ -128,9 +128,7 @@ Examples::
     >>> # Generates a periodic exponential window and decay factor equal to .5
     >>> torch.signal.windows.exponential(10, sym=False,tau=.5)
     tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04])
-""".format(
-    **window_common_args
-),
+""".format(**window_common_args),
 )

 def exponential(
     M: int,

@@ -452,9 +450,7 @@ Examples::
     >>> # Generates a periodic Hamming window.
     >>> torch.signal.windows.hamming(10, sym=False)
     tensor([0.0800, 0.1679, 0.3979, 0.6821, 0.9121, 1.0000, 0.9121, 0.6821, 0.3979, 0.1679])
-""".format(
-    **window_common_args
-),
+""".format(**window_common_args),
 )

 def hamming(
     M: int,

@@ -508,9 +504,7 @@ Examples::
     >>> # Generates a periodic Hann window.
     >>> torch.signal.windows.hann(10, sym=False)
     tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
-""".format(
-    **window_common_args
-),
+""".format(**window_common_args),
 )

 def hann(
     M: int,

@@ -564,9 +558,7 @@ Examples::
     >>> # Generates a periodic Blackman window.
     >>> torch.signal.windows.blackman(5, sym=False)
     tensor([-1.4901e-08, 2.0077e-01, 8.4923e-01, 8.4923e-01, 2.0077e-01])
-""".format(
-    **window_common_args
-),
+""".format(**window_common_args),
 )

 def blackman(
     M: int,

@@ -627,9 +619,7 @@ Examples::
     >>> # Generates a periodic Bartlett window.
     >>> torch.signal.windows.bartlett(10, sym=False)
     tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000, 0.8000, 0.6000, 0.4000, 0.2000])
-""".format(
-    **window_common_args
-),
+""".format(**window_common_args),
 )

 def bartlett(
     M: int,

@@ -704,9 +694,7 @@ Examples::
     >>> # Generates a periodic general cosine window with 2 coefficients.
     >>> torch.signal.windows.general_cosine(10, a=[0.5, 1 - 0.5], sym=False)
     tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
-""".format(
-    **window_common_args
-),
+""".format(**window_common_args),
 )

 def general_cosine(
     M,

@@ -799,9 +787,7 @@ Examples::
     >>> # Generates a periodic Hann window with the general Hamming window.
     >>> torch.signal.windows.general_hamming(10, alpha=0.5, sym=False)
     tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
-""".format(
-    **window_common_args
-),
+""".format(**window_common_args),
 )

 def general_hamming(
     M,

@@ -866,9 +852,7 @@ Examples::
     >>> # Generates a periodic Nuttall window.
     >>> torch.signal.windows.general_hamming(5, sym=False)
     tensor([3.6280e-04, 1.1052e-01, 7.9826e-01, 7.9826e-01, 1.1052e-01])
-""".format(
-    **window_common_args
-),
+""".format(**window_common_args),
 )

 def nuttall(
     M: int,

View File

@@ -559,7 +559,11 @@ def as_sparse_gradcheck(gradcheck):
     For example:

     >>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)
-    >>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse_coo().requires_grad_(True)
+    >>> x = (
+    ...     torch.tensor([[0, 1], [2, 3]], dtype=torch.float64)
+    ...     .to_sparse_coo()
+    ...     .requires_grad_(True)
+    ... )
     >>> gradcheck(lambda x: x.to_sparse_csr(), x)
     True
     """

@@ -667,7 +671,7 @@ def as_sparse_gradcheck(gradcheck):
                     )
                 else:
                     raise NotImplementedError(
-                        f'conversion of {d["layout"]} strided representation to tensor'
+                        f"conversion of {d['layout']} strided representation to tensor"
                     )
                 new_args.append(a)
             return tuple(new_args)

View File

@@ -296,11 +296,11 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None):
         for b in range(nbatches):
             for i, r in enumerate(r_offsets):
                 r0, r1 = divmod(r, N)
-                acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns]
-                for g in range(c_indices[i], c_indices[i+1]):
+                acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
+                for g in range(c_indices[i], c_indices[i + 1]):
                     p = p_offsets[g]
                     q0, q1 = divmod(q_offsets[g], N)
-                    acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns]
+                    acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]

     where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
     integer multiples of ``Ms`` and ``Ks``, respectively.

@@ -320,11 +320,11 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None):
             n = (r % N) // Ns
             r0, r1 = divmod(r, N)
             c0, c1 = c_indices[m], c_indices[m + 1]
-            acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns]
+            acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
             for i, p in enumerate(range(c0, c1)):
                 q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i]
                 q0, q1 = divmod(q, N)
-                acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns]
+                acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]

     where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
     integer multiples of ``Ms`` and ``Ks``, respectively.
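These lines live inside a docstring, so they are only rewritten because docstring code formatting appears to be enabled for ruff in this repo (an assumption based on the diff; Black did not reformat these snippets). The substantive rule is PEP 8 slice spacing: once a slice bound is a compound expression, the colon is treated like a binary operator and given symmetric spacing. A runnable sketch with made-up dimensions:

```python
import torch

accumulators = torch.zeros(2, 8, 8)
b, r0, r1, Ms, Ns = 0, 2, 2, 4, 4

# Simple bounds keep the tight form:
row = accumulators[0, 0:4]

# Compound bounds get spaces around the colon, as ruff format writes them:
acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
```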

View File

@@ -97,6 +97,7 @@ tune_bsr_dense_addmm to learn how to register a custom set of optimal
 kernel parameters for addmm-based operations.
 """
+
 __all__ = ["get_meta", "tune_bsr_dense_addmm", "tune__int_bsr_dense_addmm"]

 import inspect

@@ -432,9 +433,9 @@ def minimize(
 def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):
-    assert (
-        sparsity <= 1.0 and sparsity >= 0.0
-    ), "sparsity should be a value between 0 and 1"
+    assert sparsity <= 1.0 and sparsity >= 0.0, (
+        "sparsity should be a value between 0 and 1"
+    )
     assert M % blocksize[0] == 0
     assert N % blocksize[1] == 0
     shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :]

View File

@@ -465,14 +465,26 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
         The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
         ```
         from torch.sparse import SparseSemiStructuredTensorCUTLASS
-        from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
+        from torch.sparse._semi_structured_conversions import (
+            _sparse_semi_structured_tile,
+            _compute_compressed_swizzled_bitmask,
+        )

         pruned = _sparse_semi_structured_tile(dense)
         packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
-        packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
+        packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(
+            pruned.t().contiguous()
+        )
         bitmask = _compute_compressed_swizzled_bitmask(pruned)

-        SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask)
+        SparseSemiStructuredTensorCUTLASS(
+            dense.shape,
+            packed_cutlass,
+            meta_cutlass,
+            packed_t_cutlass,
+            meta_t_cutlass,
+            bitmask,
+        )
         ```
         """
         # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.

@@ -583,14 +595,19 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
         The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
         ```
         from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
-        from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
+        from torch.sparse._semi_structured_conversions import (
+            _sparse_semi_structured_tile,
+            _compute_compressed_swizzled_bitmask,
+        )

         pruned = _sparse_semi_structured_tile(dense)
         packed_cusparselt = torch._cslt_compress(pruned)
         packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
         bitmask = _compute_compressed_swizzled_bitmask(pruned)

-        SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask)
+        SparseSemiStructuredTensorCUSPARSELT(
+            dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask
+        )
         ```
         """
         (
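The same docstring formatting also explodes over-long `from ... import ...` statements into the parenthesized, one-name-per-line form with a trailing comma. The import below is the one from the diff; it only runs where the private torch module is importable:

```python
# Too long for one line:
#   from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
# becomes:
from torch.sparse._semi_structured_conversions import (
    _sparse_semi_structured_tile,
    _compute_compressed_swizzled_bitmask,
)
```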

View File

@@ -134,9 +134,7 @@ Example::
     >>> torch.special.digamma(a)
     tensor([-0.5772, -1.9635])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 gammaln = _add_docstr(

@@ -162,9 +160,7 @@ Example::
     >>> torch.special.gammaln(a)
     tensor([ 0.5724, 0.0000, -0.1208])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 polygamma = _add_docstr(

@@ -200,9 +196,7 @@ Example::
     tensor([ 6.4939, 97.4091])
     >>> torch.special.polygamma(4, a)
     tensor([ -24.8863, -771.4742])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 erf = _add_docstr(

@@ -226,9 +220,7 @@ Example::
     >>> torch.special.erf(torch.tensor([0, -1., 10.]))
     tensor([ 0.0000, -0.8427, 1.0000])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 erfc = _add_docstr(

@@ -253,9 +245,7 @@ Example::
     >>> torch.special.erfc(torch.tensor([0, -1., 10.]))
     tensor([ 1.0000, 1.8427, 0.0000])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 erfcx = _add_docstr(

@@ -283,9 +273,7 @@ Example::
    >>> torch.special.erfcx(torch.tensor([0, -1., 10.]))
    tensor([ 1.0000, 5.0090, 0.0561])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 erfinv = _add_docstr(

@@ -311,9 +299,7 @@ Example::
     >>> torch.special.erfinv(torch.tensor([0, 0.5, -1.]))
     tensor([ 0.0000, 0.4769, -inf])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 logit = _add_docstr(

@@ -351,9 +337,7 @@ Example::
     tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516])
     >>> torch.special.logit(a, eps=1e-6)
     tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 logsumexp = _add_docstr(

@@ -362,9 +346,7 @@ logsumexp = _add_docstr(
 logsumexp(input, dim, keepdim=False, *, out=None)

 Alias for :func:`torch.logsumexp`.
-""".format(
-    **multi_dim_common
-),
+""".format(**multi_dim_common),
 )

 expit = _add_docstr(

@@ -391,9 +373,7 @@ Example::
     tensor([ 0.9213, 1.0887, -0.8858, -1.7683])
     >>> torch.special.expit(t)
     tensor([ 0.7153, 0.7481, 0.2920, 0.1458])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 exp2 = _add_docstr(

@@ -418,9 +398,7 @@ Example::
     >>> torch.special.exp2(torch.tensor([0, math.log2(2.), 3, 4]))
     tensor([ 1., 2., 8., 16.])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 expm1 = _add_docstr(

@@ -448,9 +426,7 @@ Example::
     >>> torch.special.expm1(torch.tensor([0, math.log(2.)]))
     tensor([ 0., 1.])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 xlog1py = _add_docstr(

@@ -495,9 +471,7 @@ Example::
     tensor([1.6094, 3.2189, 4.8283])
     >>> torch.special.xlog1py(2, y)
     tensor([2.7726, 2.1972, 1.3863])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 xlogy = _add_docstr(

@@ -542,9 +516,7 @@ Example::
     tensor([1.3863, 2.7726, 4.1589])
     >>> torch.special.xlogy(2, y)
     tensor([2.1972, 1.3863, 0.0000])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 i0 = _add_docstr(

@@ -570,9 +542,7 @@ Example::
     >>> torch.i0(torch.arange(5, dtype=torch.float32))
     tensor([ 1.0000, 1.2661, 2.2796, 4.8808, 11.3019])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 i0e = _add_docstr(

@@ -597,9 +567,7 @@ Example::
     >>> torch.special.i0e(torch.arange(5, dtype=torch.float32))
     tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 i1 = _add_docstr(

@@ -624,9 +592,7 @@ Example::
     >>> torch.special.i1(torch.arange(5, dtype=torch.float32))
     tensor([0.0000, 0.5652, 1.5906, 3.9534, 9.7595])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 i1e = _add_docstr(

@@ -652,9 +618,7 @@ Example::
     >>> torch.special.i1e(torch.arange(5, dtype=torch.float32))
     tensor([0.0000, 0.2079, 0.2153, 0.1968, 0.1788])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 ndtr = _add_docstr(

@@ -679,9 +643,7 @@ Example::
     >>> torch.special.ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
     tensor([0.0013, 0.0228, 0.1587, 0.5000, 0.8413, 0.9772, 0.9987])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 ndtri = _add_docstr(

@@ -709,9 +671,7 @@ Example::
     >>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1]))
     tensor([ -inf, -0.6745, 0.0000, 0.6745, inf])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 log_ndtr = _add_docstr(

@@ -736,9 +696,7 @@ Example::
     >>> torch.special.log_ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
     tensor([-6.6077 -3.7832 -1.841 -0.6931 -0.1728 -0.023 -0.0014])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 log1p = _add_docstr(

@@ -779,9 +737,7 @@ Example::
     tensor([ 0.2252, -0.2948, 1.0267, -1.1566])
     >>> torch.special.sinc(t)
     tensor([ 0.9186, 0.8631, -0.0259, -0.1300])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 round = _add_docstr(

@@ -886,9 +842,7 @@ Example::
     tensor([1.6449, 0.0823])
     >>> torch.special.zeta(2, torch.tensor([1., 2.]))
     tensor([1.6449, 0.6449])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 multigammaln = _add_docstr(

@@ -925,9 +879,7 @@ Example::
     >>> torch.special.multigammaln(a, 2)
     tensor([[0.3928, 0.4007, 0.7586],
             [1.0311, 0.3901, 0.5049]])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 gammainc = _add_docstr(

@@ -976,9 +928,7 @@ Example::
     >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
     tensor([1., 1., 1.])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 gammaincc = _add_docstr(

@@ -1026,9 +976,7 @@ Example::
     >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
     tensor([1., 1., 1.])
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 airy_ai = _add_docstr(

@@ -1045,9 +993,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 bessel_j0 = _add_docstr(

@@ -1064,9 +1010,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 bessel_j1 = _add_docstr(

@@ -1083,9 +1027,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 bessel_y0 = _add_docstr(

@@ -1102,9 +1044,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 bessel_y1 = _add_docstr(

@@ -1121,9 +1061,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 chebyshev_polynomial_t = _add_docstr(

@@ -1154,9 +1092,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 chebyshev_polynomial_u = _add_docstr(

@@ -1188,9 +1124,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 chebyshev_polynomial_v = _add_docstr(

@@ -1208,9 +1142,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 chebyshev_polynomial_w = _add_docstr(

@@ -1228,9 +1160,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 hermite_polynomial_h = _add_docstr(

@@ -1256,9 +1186,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 hermite_polynomial_he = _add_docstr(

@@ -1284,9 +1212,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 laguerre_polynomial_l = _add_docstr(

@@ -1312,9 +1238,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 legendre_polynomial_p = _add_docstr(

@@ -1340,9 +1264,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 modified_bessel_i0 = _add_docstr(

@@ -1359,9 +1281,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 modified_bessel_i1 = _add_docstr(

@@ -1378,9 +1298,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 modified_bessel_k0 = _add_docstr(

@@ -1397,9 +1315,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 modified_bessel_k1 = _add_docstr(

@@ -1416,9 +1332,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 scaled_modified_bessel_k0 = _add_docstr(

@@ -1435,9 +1349,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 scaled_modified_bessel_k1 = _add_docstr(

@@ -1454,9 +1366,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 shifted_chebyshev_polynomial_t = _add_docstr(

@@ -1474,9 +1384,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 shifted_chebyshev_polynomial_u = _add_docstr(

@@ -1494,9 +1402,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 shifted_chebyshev_polynomial_v = _add_docstr(

@@ -1514,9 +1420,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 shifted_chebyshev_polynomial_w = _add_docstr(

@@ -1534,9 +1438,7 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 spherical_bessel_j0 = _add_docstr(

@@ -1553,7 +1455,5 @@ Args:
 Keyword args:
     {out}
-""".format(
-    **common_args
-),
+""".format(**common_args),
 )
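Nearly every hunk in this file is the same mechanical collapse: a `.format(...)` call with a single short argument that Black had expanded is folded onto the line that closes the docstring. A standalone sketch:

```python
common_args = {"out": "out (Tensor, optional): the output tensor."}

# Black kept the short call expanded:
doc_old = """
Keyword args:
    {out}
""".format(
    **common_args
)

# ruff format collapses it onto the closing-quote line:
doc_new = """
Keyword args:
    {out}
""".format(**common_args)

assert doc_old == doc_new
```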

View File

@@ -1538,7 +1538,9 @@ def assert_close(
         >>> expected = torch.tensor([1.0, 2.0, 3.0])
         >>> actual = torch.tensor([1.0, 4.0, 5.0])
         >>> # The default error message can be overwritten.
-        >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
+        >>> torch.testing.assert_close(
+        ...     actual, expected, msg="Argh, the tensors are not close!"
+        ... )
         Traceback (most recent call last):
         ...
         AssertionError: Argh, the tensors are not close!

View File

@@ -115,11 +115,11 @@ def make_tensor(
         >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
         >>> from torch.testing import make_tensor
         >>> # Creates a float tensor with values in [-1, 1)
-        >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
+        >>> make_tensor((3,), device="cpu", dtype=torch.float32, low=-1, high=1)
         >>> # xdoctest: +SKIP
         tensor([ 0.1205, 0.2282, -0.6380])
         >>> # Creates a bool tensor on CUDA
-        >>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
+        >>> make_tensor((2, 2), device="cuda", dtype=torch.bool)
         tensor([[False, False],
                 [False, True]], device='cuda:0')
     """

View File

@@ -721,9 +721,9 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_for=None):
     intersect = set(except_for if except_for else []) & set(
         only_for if only_for else []
     )
-    assert (
-        not intersect
-    ), f"device ({intersect}) appeared in both except_for and only_for"
+    assert not intersect, (
+        f"device ({intersect}) appeared in both except_for and only_for"
+    )

     # Replace your privateuse1 backend name with 'privateuse1'
     if is_privateuse1_backend_available():

@@ -1407,9 +1407,9 @@ class deviceCountAtLeast:
         self.num_required_devices = num_required_devices

     def __call__(self, fn):
-        assert not hasattr(
-            fn, "num_required_devices"
-        ), f"deviceCountAtLeast redefinition for {fn.__name__}"
+        assert not hasattr(fn, "num_required_devices"), (
+            f"deviceCountAtLeast redefinition for {fn.__name__}"
+        )
         fn.num_required_devices = self.num_required_devices

         @wraps(fn)

@@ -1474,13 +1474,13 @@ def onlyNativeDeviceTypesAnd(devices=None):
 # self.precision *2, max(1, self.precision)).
 class precisionOverride:
     def __init__(self, d):
-        assert isinstance(
-            d, dict
-        ), "precisionOverride not given a dtype : precision dict!"
+        assert isinstance(d, dict), (
+            "precisionOverride not given a dtype : precision dict!"
+        )
         for dtype in d.keys():
-            assert isinstance(
-                dtype, torch.dtype
-            ), f"precisionOverride given unknown dtype {dtype}"
+            assert isinstance(dtype, torch.dtype), (
+                f"precisionOverride given unknown dtype {dtype}"
+            )

         self.d = d

@@ -1513,12 +1513,12 @@ class toleranceOverride:
     def __init__(self, d):
         assert isinstance(d, dict), "toleranceOverride not given a dtype : tol dict!"
         for dtype, prec in d.items():
-            assert isinstance(
-                dtype, torch.dtype
-            ), f"toleranceOverride given unknown dtype {dtype}"
-            assert isinstance(
-                prec, tol
-            ), "toleranceOverride not given a dtype : tol dict!"
+            assert isinstance(dtype, torch.dtype), (
+                f"toleranceOverride given unknown dtype {dtype}"
+            )
+            assert isinstance(prec, tol), (
+                "toleranceOverride not given a dtype : tol dict!"
+            )

         self.d = d

@@ -1546,13 +1546,13 @@ class dtypes:
                 "all dtype variants must be. "
                 f"Received non-list non-tuple dtype {str(arg)}"
             )
-            assert all(
-                isinstance(dtype, torch.dtype) for dtype in arg
-            ), f"Unknown dtype in {str(arg)}"
+            assert all(isinstance(dtype, torch.dtype) for dtype in arg), (
+                f"Unknown dtype in {str(arg)}"
+            )
         else:
-            assert all(
-                isinstance(arg, torch.dtype) for arg in args
-            ), f"Unknown dtype in {str(args)}"
+            assert all(isinstance(arg, torch.dtype) for arg in args), (
+                f"Unknown dtype in {str(args)}"
+            )

         self.args = args
         self.device_type = device_type

View File

@@ -253,9 +253,9 @@ def verify_ddp_error_logged(model_DDP, err_substr):
         if err_substr.find("\nException raised from ") == -1
         else err_substr.split("\nException raised from ")[0]
     )
-    assert (
-        actual in logging_err
-    ), f"Did not find expected {actual} in ddp logging data error: {logging_err}"
+    assert actual in logging_err, (
+        f"Did not find expected {actual} in ddp logging data error: {logging_err}"
+    )

 def with_nccl_blocking_wait(func):

@@ -294,9 +294,9 @@ def with_nccl_blocking_wait(func):
         finally:
             # restore old values.
             if cached_nccl_async_error_handling is not None:
-                os.environ[
-                    "TORCH_NCCL_ASYNC_ERROR_HANDLING"
-                ] = cached_nccl_async_error_handling
+                os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
+                    cached_nccl_async_error_handling
+                )

             if cached_nccl_blocking_wait is not None:
                 os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait

@@ -812,7 +812,7 @@ class MultiProcessTestCase(TestCase):
                 sys.exit(TEST_SKIPS["generic"].exit_code)
             except Exception:
                 logger.error(
-                    "Caught exception: \n%s exiting " "process %s with exit code: %s",
+                    "Caught exception: \n%s exiting process %s with exit code: %s",
                     traceback.format_exc(),
                     self.rank,
                     MultiProcessTestCase.TEST_ERROR_EXIT_CODE,

@@ -1689,9 +1689,7 @@ class MultiProcContinousTest(TestCase):
             cls.processes.append(process)
             cls.task_queues.append(task_queue)
             cls.completion_queues.append(completion_queue)
-            logger.info(
-                "Started process %s with pid %s", rank, process.pid
-            )  # noqa: UP031
+            logger.info("Started process %s with pid %s", rank, process.pid)  # noqa: UP031

     @classmethod
     def setUpClass(cls):
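One change above is not pure re-wrapping: the adjacent literals `"Caught exception: \n%s exiting " "process %s with exit code: %s"` are joined into a single string. Implicit concatenation like this is usually residue from an earlier manual wrap; whether the join came from the formatter or from a lint autofix for implicit string concatenation is not visible from the diff. A sketch showing the two forms are equivalent:

```python
import logging

logger = logging.getLogger(__name__)

# Residue of an earlier wrap: two adjacent string literals
msg_old = "Caught exception: \n%s exiting " "process %s with exit code: %s"
# Joined into one literal that fits on a single line
msg_new = "Caught exception: \n%s exiting process %s with exit code: %s"

assert msg_old == msg_new
logger.error(msg_new, "traceback...", 0, 1)
```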

View File

@@ -1285,10 +1285,10 @@ class FSDPTest(MultiProcessTestCase):
                     loss = sharded_grad_scaler.scale(loss)

             if not mixed_precision and not use_pure_fp16:
-                assert (
-                    loss.dtype == torch.float32
-                ), "loss data type should be float32, as the original \
+                assert loss.dtype == torch.float32, (
+                    "loss data type should be float32, as the original \
                     parameter data type is float32."
+                )
             else:
                 if use_pure_fp16:
                     self.assertEqual(loss.dtype, torch.float16)

@@ -1354,9 +1354,9 @@ class FSDPTest(MultiProcessTestCase):
             wrapper should provide data parallel semantics. If ``None``,
             then the callable defaults to the DDP constructor.
         """
-        assert (
-            fsdp_init_mode != FSDPInitMode.NO_FSDP
-        ), "Expects an FSDP init mode that wraps with FSDP"
+        assert fsdp_init_mode != FSDPInitMode.NO_FSDP, (
+            "Expects an FSDP init mode that wraps with FSDP"
+        )
         if init_kwargs is None:
             init_kwargs = {}
         lr = 1e-2

View File

@@ -1268,9 +1268,9 @@ def _get_optim_inputs_including_global_cliquey_kwargs(
     trivial. That said, we sometimes want to test for all possible configs on an
     optimizer including all supported flags, so this helper returns all optim inputs.
     """
-    assert all(
-        x in ["foreach", "fused", "differentiable"] for x in skip
-    ), "skip must be a subset of ['foreach', 'fused', 'differentiable']"
+    assert all(x in ["foreach", "fused", "differentiable"] for x in skip), (
+        "skip must be a subset of ['foreach', 'fused', 'differentiable']"
+    )

     optim_inputs = optim_info.optim_inputs_func(device)

View File

@@ -477,7 +477,9 @@ def with_comms(
     def decorator(func, eager_init: bool = False, backend: Optional[str] = None):
         @wraps(func)  # pyre-ignore[6]
         def wrapper(
-            self, *args: tuple[object], **kwargs: dict[str, Any]  # type: ignore[misc]
+            self,
+            *args: tuple[object],
+            **kwargs: dict[str, Any],  # type: ignore[misc]
         ) -> None:
             self.init_pg(eager_init, backend)

View File

@@ -253,7 +253,11 @@ class Trainer:
         else:
             input_batches = batches

-        with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext():
+        with (
+            self.hybrid_module.join()
+            if simulate_uneven_inputs
+            else contextlib.nullcontext()
+        ):
             for b in input_batches:
                 with dist_autograd.context() as context_id:
                     output = self.hybrid_module.forward(b)

@@ -261,8 +265,7 @@ class Trainer:
                     dist_autograd.backward(context_id, [loss])
                     grads_dict = dist_autograd.get_gradients(context_id)
                     gLogger.info(
-                        "Loss is %s for mini batch: %s. "
-                        "Grads dict has %s entries: %s",
+                        "Loss is %s for mini batch: %s. Grads dict has %s entries: %s",
                         loss,
                         mini_batch,
                         len(grads_dict),
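The `with` rewrite above relies on plain expression grouping: the single conditional context manager is wrapped in parentheses so it can break across lines. Since there is only one with-item and no `as` binding, this parses on any supported Python version. A minimal sketch:

```python
import contextlib

simulate_uneven_inputs = False

with (
    contextlib.suppress(ValueError)
    if simulate_uneven_inputs
    else contextlib.nullcontext()
):
    pass
```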

View File

@@ -162,9 +162,7 @@ class SampleInput:
         # Allow calling either as SampleInput(input, args=args, kwargs=kwargs), or as
         # SampleInput(input, *args, **kwargs) but not to mix the two forms
         if args is not None or kwargs is not None:
-            assert (
-                not var_args and not var_kwargs
-            ), """
+            assert not var_args and not var_kwargs, """
 A SampleInput can be constructed "naturally" with *args and **kwargs or by
 explicitly setting the "args" and "kwargs" parameters, but the two
 methods of construction cannot be mixed!"""

@@ -226,7 +224,7 @@ cannot specify additional metadata in keyword arguments"""
             f"name={repr(self.name)}",
         ]

-        return f'SampleInput({", ".join(a for a in arguments if a is not None)})'
+        return f"SampleInput({', '.join(a for a in arguments if a is not None)})"

     def __repr__(self):
         return self._repr_helper(lambda x: x)

@@ -1601,13 +1599,11 @@ class SampleRule(ABC):
     # returns a string identifier of the rule type
     @abstractmethod
-    def type(self) -> str:
-        ...
+    def type(self) -> str: ...

     # returns an appropriate context that handles the xfail, skips, etc.
     @abstractmethod
-    def get_context(self, test_case):
-        ...
+    def get_context(self, test_case): ...

     # useful for specifying xfails

@@ -1791,8 +1787,10 @@ class ReductionOpInfo(OpInfo):
         # kwargs to use when calling the op. This is required for operators that
         # have other required parameters besides the input tensor.
         generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: (
-            yield (),
-            {},
+            yield (
+                (),
+                {},
+            )
         ),
         # Options from the OpInfo base class
         **kwargs,

@@ -2476,9 +2474,9 @@ class BinaryUfuncInfo(OpInfo):
                 self.supports_one_python_scalar = True

         if self.supports_one_python_scalar:
-            assert (
-                supports_rhs_python_scalar
-            ), "Can't support lhs and rhs Python scalars but not rhs scalars!"
+            assert supports_rhs_python_scalar, (
+                "Can't support lhs and rhs Python scalars but not rhs scalars!"
+            )

 # The following functions and classes are for testing elementwise unary operators.
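The `...` edits show another ruff format rule: stub-style bodies consisting only of an ellipsis are collapsed onto the signature line, both for `@abstractmethod` methods as here and for the `@overload` declarations later in the commit. A standalone sketch:

```python
from abc import ABC, abstractmethod


class RuleSketch(ABC):  # hypothetical stand-in for SampleRule
    @abstractmethod
    def type(self) -> str: ...

    @abstractmethod
    def get_context(self, test_case): ...
```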

View File

@@ -102,8 +102,9 @@ def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwargs):
         for mask in _generate_masked_op_mask(
             sample_input.input.shape, device, **kwargs
         ):
-            sample_input_args, sample_input_kwargs = sample_input.args, dict(
-                mask=mask, **sample_input.kwargs
+            sample_input_args, sample_input_kwargs = (
+                sample_input.args,
+                dict(mask=mask, **sample_input.kwargs),
             )
             yield SampleInput(
                 sample_input.input.detach().requires_grad_(requires_grad),

@@ -224,8 +225,9 @@ def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
             op_info, device, dtype, requires_grad, **kwargs
         ):
             sample_input_args, sample_input_kwargs = (
-                ord,
-            ) + sample_input.args, sample_input.kwargs.copy()
+                (ord,) + sample_input.args,
+                sample_input.kwargs.copy(),
+            )
             yield SampleInput(
                 sample_input.input.clone().requires_grad_(requires_grad),
                 args=sample_input_args,

@@ -276,8 +278,9 @@ def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs):
         for mask in _generate_masked_op_mask(
             sample_input.input.shape, device, **kwargs
         ):
-            sample_input_args, sample_input_kwargs = sample_input.args, dict(
-                mask=mask, **sample_input.kwargs
+            sample_input_args, sample_input_kwargs = (
+                sample_input.args,
+                dict(mask=mask, **sample_input.kwargs),
             )
             yield SampleInput(
                 sample_input.input.detach().requires_grad_(requires_grad),

@@ -364,8 +367,9 @@ def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs):
         ):
             if type(mask) != torch.Tensor:
                 continue
-            sample_input_args, sample_input_kwargs = sample_input.args, dict(
-                mask=mask, **sample_input.kwargs
+            sample_input_args, sample_input_kwargs = (
+                sample_input.args,
+                dict(mask=mask, **sample_input.kwargs),
             )
             if "keepdim" in sample_input_kwargs:
                 sample_input_kwargs.pop("keepdim")

View File

@@ -112,7 +112,7 @@ class _Config(Generic[T]):
     @staticmethod
     def string_or_list_of_string_to_list(
-        val: Optional[Union[str, list[str]]]
+        val: Optional[Union[str, list[str]]],
     ) -> Optional[list[str]]:
         if val is None:
             return None

@@ -135,8 +135,7 @@ if TYPE_CHECKING:
         env_name_force: Optional[Union[str, list[str]]] = None,
         value_type: Optional[type] = None,
        alias: Optional[str] = None,
-    ) -> T:
-        ...
+    ) -> T: ...

 else:

@@ -323,9 +322,9 @@ class _ConfigEntry:
         # Ensure justknobs and envvars are allowlisted types
         if self.justknob is not None and self.default is not None:
-            assert isinstance(
-                self.default, bool
-            ), f"justknobs only support booleans, {self.default} is not a boolean"
+            assert isinstance(self.default, bool), (
+                f"justknobs only support booleans, {self.default} is not a boolean"
+            )
         if self.value_type is not None and (
             config.env_name_default is not None or config.env_name_force is not None
         ):

@@ -334,7 +333,9 @@ class _ConfigEntry:
                 str,
                 Optional[bool],
                 Optional[str],
-            ), f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
+            ), (
+                f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
+            )

 class ConfigModule(ModuleType):
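The first hunk in this file is one of the smaller Black/ruff divergences in the commit: the lone parameter of an already-split signature gains a trailing comma. A sketch of the resulting shape:

```python
from typing import Optional, Union


def string_or_list_of_string_to_list(
    val: Optional[Union[str, list[str]]],  # trailing comma added by ruff format
) -> Optional[list[str]]:
    if val is None:
        return None
    return [val] if isinstance(val, str) else val
```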

View File

@@ -282,9 +282,9 @@ def tree_is_leaf(
     False
     >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple))
     True
-    >>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3})
+    >>> tree_is_leaf({"a": 1, "b": 2, "c": 3})
     False
-    >>> tree_is_leaf({'a': 1, 'b': 2, 'c': None})
+    >>> tree_is_leaf({"a": 1, "b": 2, "c": None})
     False

     Args:

@@ -586,29 +586,28 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]]
 # These specializations help with type inference on the lambda passed to this
 # function
 @overload
-def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]:
-    ...
+def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ...

 @overload
-def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]:
-    ...
+def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ...

 @overload
-def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]:
-    ...
+def map_only(
+    type_or_types_or_pred: Type3[T, S, U], /
+) -> MapOnlyFn[Fn3[T, S, U, Any]]: ...

 # This specialization is needed for the implementations below that call
 @overload
-def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]:
-    ...
+def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ...

 @overload
-def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]:
-    ...
+def map_only(
+    type_or_types_or_pred: Callable[[Any], bool], /
+) -> MapOnlyFn[FnAny[Any]]: ...

 def map_only(

@@ -664,8 +663,7 @@ def tree_map_only(
     func: Fn[T, Any],
     tree: PyTree,
     is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...

 @overload

@@ -675,8 +673,7 @@ def tree_map_only(
     func: Fn2[T, S, Any],
     tree: PyTree,
     is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...

 @overload

@@ -686,8 +683,7 @@ def tree_map_only(
     func: Fn3[T, S, U, Any],
     tree: PyTree,
     is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...

 @overload

@@ -697,8 +693,7 @@ def tree_map_only(
     func: FnAny[Any],
     tree: PyTree,
     is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...

 @overload

@@ -708,8 +703,7 @@ def tree_map_only(
     func: FnAny[Any],
     tree: PyTree,
     is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...

 def tree_map_only(

@@ -729,8 +723,7 @@ def tree_map_only_(
     func: Fn[T, Any],
     tree: PyTree,
     is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...

 @overload

@@ -740,8 +733,7 @@ def tree_map_only_(
     func: Fn2[T, S, Any],
     tree: PyTree,
     is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...

 @overload

@@ -751,8 +743,7 @@ def tree_map_only_(
     func: Fn3[T, S, U, Any],
     tree: PyTree,
     is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree: ...
...
@overload @overload
@ -762,8 +753,7 @@ def tree_map_only_(
func: FnAny[Any], func: FnAny[Any],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree: ...
...
@overload @overload
@ -773,8 +763,7 @@ def tree_map_only_(
func: FnAny[Any], func: FnAny[Any],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree: ...
...
def tree_map_only_( def tree_map_only_(
@ -812,8 +801,7 @@ def tree_all_only(
pred: Fn[T, bool], pred: Fn[T, bool],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ) -> bool: ...
...
@overload @overload
@ -823,8 +811,7 @@ def tree_all_only(
pred: Fn2[T, S, bool], pred: Fn2[T, S, bool],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ) -> bool: ...
...
@overload @overload
@ -834,8 +821,7 @@ def tree_all_only(
pred: Fn3[T, S, U, bool], pred: Fn3[T, S, U, bool],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ) -> bool: ...
...
def tree_all_only( def tree_all_only(
@ -856,8 +842,7 @@ def tree_any_only(
pred: Fn[T, bool], pred: Fn[T, bool],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ) -> bool: ...
...
@overload @overload
@ -867,8 +852,7 @@ def tree_any_only(
pred: Fn2[T, S, bool], pred: Fn2[T, S, bool],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ) -> bool: ...
...
@overload @overload
@ -878,8 +862,7 @@ def tree_any_only(
pred: Fn3[T, S, U, bool], pred: Fn3[T, S, U, bool],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ) -> bool: ...
...
def tree_any_only( def tree_any_only(
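All of the overload rewrites above come from one rule: ruff format collapses a stub body consisting only of `...` onto the `def` line. A self-contained sketch (the `double` function is hypothetical):

from typing import overload

@overload
def double(x: int) -> int: ...
@overload
def double(x: str) -> str: ...
def double(x):
    # Single runtime implementation behind both overload declarations.
    return x * 2

print(double(21), double("ab"))  # prints: 42 abab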


@ -12,7 +12,7 @@ _cache_sentinel = object()
def cache_method( def cache_method(
f: Callable[Concatenate[_C, _P], _T] f: Callable[Concatenate[_C, _P], _T],
) -> Callable[Concatenate[_C, _P], _T]: ) -> Callable[Concatenate[_C, _P], _T]:
""" """
Like `@functools.cache` but for methods. Like `@functools.cache` but for methods.
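`cache_method` above also gains a trailing comma after its only parameter; once a signature is split across lines, ruff format adds the magic trailing comma so the parameter list stays expanded. A hedged sketch of the same shape (the `log_calls` decorator is illustrative, not the real `cache_method`):

from collections.abc import Callable

def log_calls(
    f: Callable[[int], int],
) -> Callable[[int], int]:
    # Wrap ``f`` and report each call; the trailing comma above keeps the
    # one-parameter signature on its own line after formatting.
    def inner(x: int) -> int:
        print(f"calling {f.__name__}({x})")
        return f(x)
    return inner

@log_calls
def square(x: int) -> int:
    return x * x

square(3)  # prints: calling square(3)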


@ -302,14 +302,12 @@ class BaseTorchDispatchMode(TorchDispatchMode):
# Subtypes which have __tensor_flatten__ and __tensor_unflatten__. # Subtypes which have __tensor_flatten__ and __tensor_unflatten__.
class TensorWithFlatten(Protocol): class TensorWithFlatten(Protocol):
def __tensor_flatten__(self) -> tuple[Sequence[str], object]: def __tensor_flatten__(self) -> tuple[Sequence[str], object]: ...
...
@staticmethod @staticmethod
def __tensor_unflatten__( def __tensor_unflatten__(
inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int
) -> torch.Tensor: ) -> torch.Tensor: ...
...
# It would be really nice to be able to say that the return of # It would be really nice to be able to say that the return of
# is_traceable_wrapper_subclass() is Intersection[torch.Tensor, # is_traceable_wrapper_subclass() is Intersection[torch.Tensor,
@ -318,26 +316,20 @@ class TensorWithFlatten(Protocol):
shape: torch._C.Size shape: torch._C.Size
@overload @overload
def stride(self, dim: None = None) -> tuple[int, ...]: def stride(self, dim: None = None) -> tuple[int, ...]: ...
...
@overload @overload
def stride(self, dim: int) -> int: def stride(self, dim: int) -> int: ...
...
@overload @overload
def size(self, dim: None = None) -> tuple[int, ...]: def size(self, dim: None = None) -> tuple[int, ...]: ...
...
@overload @overload
def size(self, dim: int) -> int: def size(self, dim: int) -> int: ...
...
def storage_offset(self) -> int: def storage_offset(self) -> int: ...
...
def dim(self) -> int: def dim(self) -> int: ...
...
@overload @overload
def to( def to(
@ -347,8 +339,7 @@ class TensorWithFlatten(Protocol):
copy: bool = False, copy: bool = False,
*, *,
memory_format: Optional[torch.memory_format] = None, memory_format: Optional[torch.memory_format] = None,
) -> torch.Tensor: ) -> torch.Tensor: ...
...
@overload @overload
def to( def to(
@ -359,8 +350,7 @@ class TensorWithFlatten(Protocol):
copy: bool = False, copy: bool = False,
*, *,
memory_format: Optional[torch.memory_format] = None, memory_format: Optional[torch.memory_format] = None,
) -> torch.Tensor: ) -> torch.Tensor: ...
...
@overload @overload
def to( def to(
@ -370,8 +360,7 @@ class TensorWithFlatten(Protocol):
copy: bool = False, copy: bool = False,
*, *,
memory_format: Optional[torch.memory_format] = None, memory_format: Optional[torch.memory_format] = None,
) -> torch.Tensor: ) -> torch.Tensor: ...
...
def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]: def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
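The `TensorWithFlatten` protocol gets the same `...`-on-the-def-line treatment as the overloads: protocol members with ellipsis bodies collapse to one line each. A minimal runnable sketch with a hypothetical protocol:

from typing import Protocol, runtime_checkable

@runtime_checkable
class SupportsArea(Protocol):
    def area(self) -> float: ...

class Square:
    def __init__(self, side: float) -> None:
        self.side = side

    def area(self) -> float:
        return self.side * self.side

print(isinstance(Square(2.0), SupportsArea))  # True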


@ -99,17 +99,13 @@ NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND"
class KeyEntry(Protocol): class KeyEntry(Protocol):
def __hash__(self) -> int: def __hash__(self) -> int: ...
...
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool: ...
...
def __str__(self) -> str: def __str__(self) -> str: ...
...
def get(self, parent: Any) -> Any: def get(self, parent: Any) -> Any: ...
...
class EnumEncoder(json.JSONEncoder): class EnumEncoder(json.JSONEncoder):
@ -757,7 +753,7 @@ def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]:
def _tuple_flatten_with_keys( def _tuple_flatten_with_keys(
d: tuple[T, ...] d: tuple[T, ...],
) -> tuple[list[tuple[KeyEntry, T]], Context]: ) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _tuple_flatten(d) values, context = _tuple_flatten(d)
return [(SequenceKey(i), v) for i, v in enumerate(values)], context return [(SequenceKey(i), v) for i, v in enumerate(values)], context
@ -785,7 +781,7 @@ def _dict_flatten(d: dict[Any, T]) -> tuple[list[T], Context]:
def _dict_flatten_with_keys( def _dict_flatten_with_keys(
d: dict[Any, T] d: dict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]: ) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _dict_flatten(d) values, context = _dict_flatten(d)
return [(MappingKey(k), v) for k, v in zip(context, values)], context return [(MappingKey(k), v) for k, v in zip(context, values)], context
@ -849,7 +845,7 @@ def _ordereddict_flatten(d: OrderedDict[Any, T]) -> tuple[list[T], Context]:
def _ordereddict_flatten_with_keys( def _ordereddict_flatten_with_keys(
d: OrderedDict[Any, T] d: OrderedDict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]: ) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _ordereddict_flatten(d) values, context = _ordereddict_flatten(d)
return [(MappingKey(k), v) for k, v in zip(context, values)], context return [(MappingKey(k), v) for k, v in zip(context, values)], context
@ -872,7 +868,7 @@ def _defaultdict_flatten(d: defaultdict[Any, T]) -> tuple[list[T], Context]:
def _defaultdict_flatten_with_keys( def _defaultdict_flatten_with_keys(
d: defaultdict[Any, T] d: defaultdict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]: ) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _defaultdict_flatten(d) values, context = _defaultdict_flatten(d)
_, dict_context = context _, dict_context = context
@ -1035,9 +1031,9 @@ def tree_is_leaf(
False False
>>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple)) >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple))
True True
>>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3}) >>> tree_is_leaf({"a": 1, "b": 2, "c": 3})
False False
>>> tree_is_leaf({'a': 1, 'b': 2, 'c': None}) >>> tree_is_leaf({"a": 1, "b": 2, "c": None})
False False
""" """
if is_leaf is not None and is_leaf(tree): if is_leaf is not None and is_leaf(tree):
@ -1346,9 +1342,9 @@ def tree_map(
See also :func:`tree_map_`. See also :func:`tree_map_`.
>>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)}) >>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)})
{'x': 8, 'y': (43, 65)} {'x': 8, 'y': (43, 65)}
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}) >>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None})
{'x': False, 'y': (False, False), 'z': True} {'x': False, 'y': (False, False), 'z': True}
If multiple inputs are given, the structure of the tree is taken from the first input; If multiple inputs are given, the structure of the tree is taken from the first input;
@ -1432,29 +1428,28 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]]
# These specializations help with type inference on the lambda passed to this # These specializations help with type inference on the lambda passed to this
# function # function
@overload @overload
def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ...
...
@overload @overload
def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ...
...
@overload @overload
def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]: def map_only(
... type_or_types_or_pred: Type3[T, S, U], /
) -> MapOnlyFn[Fn3[T, S, U, Any]]: ...
# This specialization is needed for the implementations below that call # This specialization is needed for the implementations below that call
@overload @overload
def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ...
...
@overload @overload
def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]: def map_only(
... type_or_types_or_pred: Callable[[Any], bool], /
) -> MapOnlyFn[FnAny[Any]]: ...
def map_only( def map_only(
@ -1510,8 +1505,7 @@ def tree_map_only(
func: Fn[T, Any], func: Fn[T, Any],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree: ...
...
@overload @overload
@ -1521,8 +1515,7 @@ def tree_map_only(
func: Fn2[T, S, Any], func: Fn2[T, S, Any],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree: ...
...
@overload @overload
@ -1532,8 +1525,7 @@ def tree_map_only(
func: Fn3[T, S, U, Any], func: Fn3[T, S, U, Any],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree: ...
...
@overload @overload
@ -1543,8 +1535,7 @@ def tree_map_only(
func: FnAny[Any], func: FnAny[Any],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree: ...
...
@overload @overload
@ -1554,8 +1545,7 @@ def tree_map_only(
func: FnAny[Any], func: FnAny[Any],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree: ...
...
def tree_map_only( def tree_map_only(
@ -1575,8 +1565,7 @@ def tree_map_only_(
func: Fn[T, Any], func: Fn[T, Any],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree: ...
...
@overload @overload
@ -1586,8 +1575,7 @@ def tree_map_only_(
func: Fn2[T, S, Any], func: Fn2[T, S, Any],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree: ...
...
@overload @overload
@ -1597,8 +1585,7 @@ def tree_map_only_(
func: Fn3[T, S, U, Any], func: Fn3[T, S, U, Any],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree: ...
...
@overload @overload
@ -1608,8 +1595,7 @@ def tree_map_only_(
func: FnAny[Any], func: FnAny[Any],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree: ...
...
@overload @overload
@ -1619,8 +1605,7 @@ def tree_map_only_(
func: FnAny[Any], func: FnAny[Any],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree: ...
...
def tree_map_only_( def tree_map_only_(
@ -1658,8 +1643,7 @@ def tree_all_only(
pred: Fn[T, bool], pred: Fn[T, bool],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ) -> bool: ...
...
@overload @overload
@ -1669,8 +1653,7 @@ def tree_all_only(
pred: Fn2[T, S, bool], pred: Fn2[T, S, bool],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ) -> bool: ...
...
@overload @overload
@ -1680,8 +1663,7 @@ def tree_all_only(
pred: Fn3[T, S, U, bool], pred: Fn3[T, S, U, bool],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ) -> bool: ...
...
def tree_all_only( def tree_all_only(
@ -1702,8 +1684,7 @@ def tree_any_only(
pred: Fn[T, bool], pred: Fn[T, bool],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ) -> bool: ...
...
@overload @overload
@ -1713,8 +1694,7 @@ def tree_any_only(
pred: Fn2[T, S, bool], pred: Fn2[T, S, bool],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ) -> bool: ...
...
@overload @overload
@ -1724,8 +1704,7 @@ def tree_any_only(
pred: Fn3[T, S, U, bool], pred: Fn3[T, S, U, bool],
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ) -> bool: ...
...
def tree_any_only( def tree_any_only(
@ -1862,7 +1841,7 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
raise NotImplementedError( raise NotImplementedError(
f'Deserializing {json_schema["type"]} in pytree is not registered.', f"Deserializing {json_schema['type']} in pytree is not registered.",
) )
typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]] typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]]
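The `NotImplementedError` message above flips from an outer single-quoted f-string with double quotes inside to the opposite arrangement. ruff format normalizes string literals to double quotes, and for f-strings it swaps the nested quotes to single so the result stays valid on Python versions before 3.12, where a replacement field could not reuse the enclosing quote character. For illustration:

json_schema = {"type": "builtins.dict"}  # illustrative data

# original style: outer single quotes, inner double quotes
msg_old = f'Deserializing {json_schema["type"]} in pytree is not registered.'
# ruff format style: outer double quotes, inner single quotes
msg_new = f"Deserializing {json_schema['type']} in pytree is not registered."
assert msg_old == msg_new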


@ -301,7 +301,7 @@ def strobelight(
profiler = StrobelightCLIFunctionProfiler(**kwargs) profiler = StrobelightCLIFunctionProfiler(**kwargs)
def strobelight_inner( def strobelight_inner(
work_function: Callable[_P, _R] work_function: Callable[_P, _R],
) -> Callable[_P, Optional[_R]]: ) -> Callable[_P, Optional[_R]]:
@functools.wraps(work_function) @functools.wraps(work_function)
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:


@ -98,7 +98,7 @@ def _is_symbols_binary_summation(expr: sympy.Expr) -> bool:
def _keep_float( def _keep_float(
f: Callable[[Unpack[_Ts]], _T] f: Callable[[Unpack[_Ts]], _T],
) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]: ) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
@functools.wraps(f) @functools.wraps(f)
def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]: def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
@ -926,10 +926,12 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
_eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) # noqa: E731 _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) # noqa: E731
_eval_is_antihermitian = lambda s: _torf( # noqa: E731 _eval_is_antihermitian = lambda s: _torf( # noqa: E731
i.is_antihermitian for i in s.args # noqa: E731 i.is_antihermitian
for i in s.args # noqa: E731
) # noqa: E731 ) # noqa: E731
_eval_is_commutative = lambda s: _torf( # noqa: E731 _eval_is_commutative = lambda s: _torf( # noqa: E731
i.is_commutative for i in s.args # noqa: E731 i.is_commutative
for i in s.args # noqa: E731
) # noqa: E731 ) # noqa: E731
_eval_is_complex = lambda s: _torf(i.is_complex for i in s.args) # noqa: E731 _eval_is_complex = lambda s: _torf(i.is_complex for i in s.args) # noqa: E731
_eval_is_composite = lambda s: _torf(i.is_composite for i in s.args) # noqa: E731 _eval_is_composite = lambda s: _torf(i.is_composite for i in s.args) # noqa: E731
@ -943,10 +945,12 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
_eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731 _eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731
_eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731 _eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731
_eval_is_nonnegative = lambda s: _torf( # noqa: E731 _eval_is_nonnegative = lambda s: _torf( # noqa: E731
i.is_nonnegative for i in s.args # noqa: E731 i.is_nonnegative
for i in s.args # noqa: E731
) # noqa: E731 ) # noqa: E731
_eval_is_nonpositive = lambda s: _torf( # noqa: E731 _eval_is_nonpositive = lambda s: _torf( # noqa: E731
i.is_nonpositive for i in s.args # noqa: E731 i.is_nonpositive
for i in s.args # noqa: E731
) # noqa: E731 ) # noqa: E731
_eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731 _eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731
_eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731 _eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731
@ -956,10 +960,12 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
_eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731 _eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731
_eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731 _eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731
_eval_is_extended_real = lambda s: _torf( # noqa: E731 _eval_is_extended_real = lambda s: _torf( # noqa: E731
i.is_extended_real for i in s.args # noqa: E731 i.is_extended_real
for i in s.args # noqa: E731
) # noqa: E731 ) # noqa: E731
_eval_is_transcendental = lambda s: _torf( # noqa: E731 _eval_is_transcendental = lambda s: _torf( # noqa: E731
i.is_transcendental for i in s.args # noqa: E731 i.is_transcendental
for i in s.args # noqa: E731
) # noqa: E731 ) # noqa: E731
_eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731 _eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731
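When ruff format splits the generator expressions inside these `_torf` lambdas, the `# noqa: E731` comments attached to each physical line travel with the code, which is why the suppression now also sits on the `for` line; `noqa` comments are scoped to the line they appear on. A toy version of the shape (with a stand-in `_torf`):

def _torf(iterable):
    # Stand-in: True if all items are true, False if none are, else None.
    values = list(iterable)
    if all(values):
        return True
    if not any(values):
        return False
    return None

_eval_is_positive = lambda s: _torf(  # noqa: E731
    v > 0
    for v in s  # noqa: E731
)  # noqa: E731

print(_eval_is_positive([1, 2, 3]))  # True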


@ -144,16 +144,14 @@ class ValueRanges(Generic[_T]):
self: ValueRanges[sympy.Expr], self: ValueRanges[sympy.Expr],
lower: ExprIn, lower: ExprIn,
upper: ExprIn, upper: ExprIn,
) -> None: ) -> None: ...
...
@overload @overload
def __init__( # type: ignore[misc] def __init__( # type: ignore[misc]
self: ValueRanges[SympyBoolean], self: ValueRanges[SympyBoolean],
lower: BoolIn, lower: BoolIn,
upper: BoolIn, upper: BoolIn,
) -> None: ) -> None: ...
...
def __init__(self, lower: AllIn, upper: AllIn) -> None: def __init__(self, lower: AllIn, upper: AllIn) -> None:
lower = simple_sympify(lower) lower = simple_sympify(lower)
@ -240,15 +238,13 @@ class ValueRanges(Generic[_T]):
def __and__( def __and__(
self: ValueRanges[sympy.Expr], self: ValueRanges[sympy.Expr],
other: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr],
) -> ValueRanges[sympy.Expr]: ) -> ValueRanges[sympy.Expr]: ...
...
@overload @overload
def __and__( # type: ignore[misc] def __and__( # type: ignore[misc]
self: ValueRanges[SympyBoolean], self: ValueRanges[SympyBoolean],
other: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean],
) -> ValueRanges[SympyBoolean]: ) -> ValueRanges[SympyBoolean]: ...
...
def __and__(self: AllVR, other: AllVR) -> AllVR: def __and__(self: AllVR, other: AllVR) -> AllVR:
if other in (ValueRanges.unknown(), ValueRanges.unknown_int()): if other in (ValueRanges.unknown(), ValueRanges.unknown_int()):
@ -272,15 +268,13 @@ class ValueRanges(Generic[_T]):
def __or__( def __or__(
self: ValueRanges[sympy.Expr], self: ValueRanges[sympy.Expr],
other: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr],
) -> ValueRanges[sympy.Expr]: ) -> ValueRanges[sympy.Expr]: ...
...
@overload @overload
def __or__( # type: ignore[misc] def __or__( # type: ignore[misc]
self: ValueRanges[SympyBoolean], self: ValueRanges[SympyBoolean],
other: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean],
) -> ValueRanges[SympyBoolean]: ) -> ValueRanges[SympyBoolean]: ...
...
def __or__(self: AllVR, other: AllVR) -> AllVR: def __or__(self: AllVR, other: AllVR) -> AllVR:
if ValueRanges.unknown() in (self, other): if ValueRanges.unknown() in (self, other):
@ -343,8 +337,7 @@ class ValueRanges(Generic[_T]):
@overload @overload
@staticmethod @staticmethod
def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: ...
...
@overload @overload
@staticmethod @staticmethod
@ -384,8 +377,7 @@ class ValueRanges(Generic[_T]):
x: Union[ExprIn, ExprVR], x: Union[ExprIn, ExprVR],
y: Union[ExprIn, ExprVR], y: Union[ExprIn, ExprVR],
fn: ExprFn2, fn: ExprFn2,
) -> ExprVR: ) -> ExprVR: ...
...
@overload @overload
@staticmethod @staticmethod
@ -393,8 +385,7 @@ class ValueRanges(Generic[_T]):
x: Union[BoolIn, BoolVR], x: Union[BoolIn, BoolVR],
y: Union[BoolIn, BoolVR], y: Union[BoolIn, BoolVR],
fn: BoolFn2, fn: BoolFn2,
) -> BoolVR: ) -> BoolVR: ...
...
@staticmethod @staticmethod
def coordinatewise_increasing_map( def coordinatewise_increasing_map(


@ -426,9 +426,9 @@ def _get_custom_mod_func(func_name: str):
it is marked as private. It is a convenience function for backend implementers to it is marked as private. It is a convenience function for backend implementers to
more easily call the hooks into their backend extensions. more easily call the hooks into their backend extensions.
""" """
assert isinstance( assert isinstance(func_name, str), (
func_name, str f"func_name must be `str`, but got `{type(func_name)}`."
), f"func_name must be `str`, but got `{type(func_name)}`." )
backend_name = _get_privateuse1_backend_name() backend_name = _get_privateuse1_backend_name()
custom_device_mod = getattr(torch, backend_name, None) # type: ignore[arg-type] custom_device_mod = getattr(torch, backend_name, None) # type: ignore[arg-type]
function = getattr(custom_device_mod, func_name, None) # type: ignore[arg-type] function = getattr(custom_device_mod, func_name, None) # type: ignore[arg-type]


@ -44,7 +44,7 @@ def default_convert(data):
>>> default_convert(np.array([0, 1])) >>> default_convert(np.array([0, 1]))
tensor([0, 1]) tensor([0, 1])
>>> # Example with NamedTuple >>> # Example with NamedTuple
>>> Point = namedtuple('Point', ['x', 'y']) >>> Point = namedtuple("Point", ["x", "y"])
>>> default_convert(Point(0, 0)) >>> default_convert(Point(0, 0))
Point(x=0, y=0) Point(x=0, y=0)
>>> default_convert(Point(np.array(0), np.array(0))) >>> default_convert(Point(np.array(0), np.array(0)))
@ -366,13 +366,13 @@ def default_collate(batch):
>>> default_collate([0, 1, 2, 3]) >>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3]) tensor([0, 1, 2, 3])
>>> # Example with a batch of `str`s: >>> # Example with a batch of `str`s:
>>> default_collate(['a', 'b', 'c']) >>> default_collate(["a", "b", "c"])
['a', 'b', 'c'] ['a', 'b', 'c']
>>> # Example with `Map` inside the batch: >>> # Example with `Map` inside the batch:
>>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]) >>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}])
{'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])}
>>> # Example with `NamedTuple` inside the batch: >>> # Example with `NamedTuple` inside the batch:
>>> Point = namedtuple('Point', ['x', 'y']) >>> Point = namedtuple("Point", ["x", "y"])
>>> default_collate([Point(0, 0), Point(1, 1)]) >>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1])) Point(x=tensor([0, 1]), y=tensor([0, 1]))
>>> # Example with `Tuple` inside the batch: >>> # Example with `Tuple` inside the batch:


@ -69,7 +69,9 @@ def pin_memory(data, device=None):
) )
return clone return clone
else: else:
return type(data)({k: pin_memory(sample, device) for k, sample in data.items()}) # type: ignore[call-arg] return type(data)(
{k: pin_memory(sample, device) for k, sample in data.items()}
) # type: ignore[call-arg]
except TypeError: except TypeError:
# The mapping type may not support `copy()` / `update(mapping)` # The mapping type may not support `copy()` / `update(mapping)`
# or `__init__(iterable)`. # or `__init__(iterable)`.
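One detail worth flagging in the `pin_memory` hunk: when the one-line `return type(data)({...})` call is split, the `# type: ignore[call-arg]` comment ends up on the closing-parenthesis line. mypy matches `type: ignore` comments to the physical line an error is reported on, so suppressions that move during a reflow deserve a second look. A sketch of the reflowed shape, with a stand-in `pin` in place of the real recursive `pin_memory`:

from collections import OrderedDict

def pin(sample: int) -> int:
    # Stand-in for per-sample pinning; identity for illustration.
    return sample

data = OrderedDict(a=1, b=2)
# The mapping is rebuilt via its own type; after splitting, any checker
# suppression that trailed the call now trails the closing parenthesis.
clone = type(data)(
    {k: pin(v) for k, v in data.items()}
)

print(clone)  # prints the rebuilt OrderedDict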


@ -1,5 +1,5 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers. r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
These **need** to be in global scope since Py2 doesn't support serializing These **need** to be in global scope since Py2 doesn't support serializing
static methods. static methods.


@ -5,6 +5,7 @@ To support these two classes, in `./_utils` we define many utility methods and
functions to be run in multiprocessing. E.g., the data loading worker loop is functions to be run in multiprocessing. E.g., the data loading worker loop is
in `./_utils/worker.py`. in `./_utils/worker.py`.
""" """
from __future__ import annotations from __future__ import annotations
import functools import functools
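The lone addition in the hunk above is a blank line after the module docstring: ruff format (matching Black's 2024 stable style) enforces exactly one empty line between a module docstring and the first statement. Schematically:

"""Example module docstring.

One blank line is guaranteed between this docstring and the first
import that follows it.
"""

from __future__ import annotations

import functools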
@ -1208,7 +1209,10 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w) atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
# .pid can be None only before process is spawned (not the case, so ignore) # .pid can be None only before process is spawned (not the case, so ignore)
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] _utils.signal_handling._set_worker_pids(
id(self),
tuple(w.pid for w in self._workers), # type: ignore[misc]
)
_utils.signal_handling._set_SIGCHLD_handler() _utils.signal_handling._set_SIGCHLD_handler()
self._worker_pids_set = True self._worker_pids_set = True
self._reset(loader, first_iter=True) self._reset(loader, first_iter=True)


@ -109,8 +109,7 @@ class non_deterministic:
# Decorate with a functional argument # Decorate with a functional argument
if not ( if not (
isinstance(args[0], type) isinstance(args[0], type) and issubclass(args[0], IterDataPipe) # type: ignore[arg-type]
and issubclass(args[0], IterDataPipe) # type: ignore[arg-type]
): ):
raise TypeError( raise TypeError(
f"Only `IterDataPipe` can be decorated, but {args[0].__name__} is found" f"Only `IterDataPipe` can be decorated, but {args[0].__name__} is found"


@ -99,7 +99,9 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
>>> from torchdata.datapipes.iter import IterableWrapper, Mapper >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
>>> dp = IterableWrapper(range(10)) >>> dp = IterableWrapper(range(10))
>>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor >>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor
>>> map_dp_2 = dp.map(lambda x: x + 1) # Using functional form (recommended) >>> map_dp_2 = dp.map(
... lambda x: x + 1
... ) # Using functional form (recommended)
>>> list(map_dp_1) >>> list(map_dp_1)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> list(map_dp_2) >>> list(map_dp_2)
@ -114,7 +116,9 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
>>> list(it1) >>> list(it1)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> it1 = iter(source_dp) >>> it1 = iter(source_dp)
>>> it2 = iter(source_dp) # The creation of a new iterator invalidates `it1` >>> it2 = iter(
... source_dp
... ) # The creation of a new iterator invalidates `it1`
>>> next(it2) >>> next(it2)
0 0
>>> next(it1) # Further usage of `it1` will raise a `RunTimeError` >>> next(it1) # Further usage of `it1` will raise a `RunTimeError`
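The doctest changes in this file (and in several datapipe files below) are driven by ruff format's docstring code formatting: with the `docstring-code-format` option enabled, example code inside docstrings is formatted like any other code, so long doctest lines get split onto `...` continuation lines and their quotes are normalized. A self-contained sketch of the resulting shape:

def demo() -> None:
    """Show how a long doctest call is reflowed.

    >>> mapped = list(
    ...     map(lambda x: x + 1, range(3))
    ... )  # the call is split onto ``...`` continuation lines
    >>> mapped
    [1, 2, 3]
    """

if __name__ == "__main__":
    import doctest

    doctest.testmod()  # both examples above pass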


@ -55,7 +55,8 @@ class MapperIterDataPipe(IterDataPipe[_T_co]):
>>> def add_one(x): >>> def add_one(x):
... return x + 1 ... return x + 1
>>> dp = IterableWrapper(range(10)) >>> dp = IterableWrapper(range(10))
>>> map_dp_1 = dp.map(add_one) # Invocation via functional form is preferred >>> # Invocation via functional form is preferred
... map_dp_1 = dp.map(add_one)
>>> list(map_dp_1) >>> list(map_dp_1)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle` >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
@ -202,7 +203,7 @@ class CollatorIterDataPipe(MapperIterDataPipe):
>>> class MyIterDataPipe(torch.utils.data.IterDataPipe): >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
... def __init__(self, start, end): ... def __init__(self, start, end):
... super(MyIterDataPipe).__init__() ... super(MyIterDataPipe).__init__()
... assert end > start, "this example code only works with end >= start" ... assert end > start, "this example only works with end >= start"
... self.start = start ... self.start = start
... self.end = end ... self.end = end
... ...
@ -211,13 +212,11 @@ class CollatorIterDataPipe(MapperIterDataPipe):
... ...
... def __len__(self): ... def __len__(self):
... return self.end - self.start ... return self.end - self.start
...
>>> ds = MyIterDataPipe(start=3, end=7) >>> ds = MyIterDataPipe(start=3, end=7)
>>> print(list(ds)) >>> print(list(ds))
[3, 4, 5, 6] [3, 4, 5, 6]
>>> def collate_fn(batch): >>> def collate_fn(batch):
... return torch.tensor(batch, dtype=torch.float) ... return torch.tensor(batch, dtype=torch.float)
...
>>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn) >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn)
>>> print(list(collated_ds)) >>> print(list(collated_ds))
[tensor(3.), tensor(4.), tensor(5.), tensor(6.)] [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]


@ -38,15 +38,17 @@ class SamplerIterDataPipe(IterDataPipe[_T_co]):
sampler_args: Optional[tuple] = None, sampler_args: Optional[tuple] = None,
sampler_kwargs: Optional[dict] = None, sampler_kwargs: Optional[dict] = None,
) -> None: ) -> None:
assert isinstance( assert isinstance(datapipe, Sized), (
datapipe, Sized "Sampler class requires the input datapipe to implement `__len__`"
), "Sampler class requires the input datapipe to implement `__len__`" )
super().__init__() super().__init__()
self.datapipe = datapipe self.datapipe = datapipe
self.sampler_args = () if sampler_args is None else sampler_args self.sampler_args = () if sampler_args is None else sampler_args
self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
# https://github.com/python/mypy/pull/9629 will solve this typing limitation # https://github.com/python/mypy/pull/9629 will solve this typing limitation
self.sampler = sampler(*self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs) # type: ignore[misc] self.sampler = sampler(
*self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs
) # type: ignore[misc]
def __iter__(self) -> Iterator[_T_co]: def __iter__(self) -> Iterator[_T_co]:
return iter(self.sampler) return iter(self.sampler)


@ -116,16 +116,13 @@ class _ContainerTemplate(ABC):
r"""Abstract class for container ``DataPipes``. The followings are three required methods.""" r"""Abstract class for container ``DataPipes``. The followings are three required methods."""
@abstractmethod @abstractmethod
def get_next_element_by_instance(self, instance_id: int): def get_next_element_by_instance(self, instance_id: int): ...
...
@abstractmethod @abstractmethod
def is_every_instance_exhausted(self) -> bool: def is_every_instance_exhausted(self) -> bool: ...
...
@abstractmethod @abstractmethod
def reset(self) -> None: def reset(self) -> None: ...
...
@abstractmethod @abstractmethod
def get_length_by_instance(self, instance_id: int): def get_length_by_instance(self, instance_id: int):
@ -403,7 +400,9 @@ class DemultiplexerIterDataPipe(IterDataPipe):
>>> # It can also filter out any element that gets `None` from the `classifier_fn` >>> # It can also filter out any element that gets `None` from the `classifier_fn`
>>> def odd_or_even_no_zero(n): >>> def odd_or_even_no_zero(n):
... return n % 2 if n != 0 else None ... return n % 2 if n != 0 else None
>>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True) >>> dp1, dp2 = source_dp.demux(
... num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True
... )
>>> list(dp1) >>> list(dp1)
[2, 4] [2, 4]
>>> list(dp2) >>> list(dp2)
@ -428,7 +427,9 @@ class DemultiplexerIterDataPipe(IterDataPipe):
# When num_instances == 1, demux can be replaced by filter, # When num_instances == 1, demux can be replaced by filter,
# but keep it as Demultiplexer for the sake of consistency # but keep it as Demultiplexer for the sake of consistency
# like throwing an Error when the classification result is out of range # like throwing an Error when the classification result is out of range
container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size) # type: ignore[abstract] container = _DemultiplexerIterDataPipe(
datapipe, num_instances, classifier_fn, drop_none, buffer_size
) # type: ignore[abstract]
return [_ChildDataPipe(container, i) for i in range(num_instances)] return [_ChildDataPipe(container, i) for i in range(num_instances)]
@ -602,16 +603,18 @@ class MultiplexerIterDataPipe(IterDataPipe):
Example: Example:
>>> # xdoctest: +REQUIRES(module:torchdata) >>> # xdoctest: +REQUIRES(module:torchdata)
>>> from torchdata.datapipes.iter import IterableWrapper >>> from torchdata.datapipes.iter import IterableWrapper
>>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25)) >>> dp1, dp2, dp3 = (
... IterableWrapper(range(3)),
... IterableWrapper(range(10, 15)),
... IterableWrapper(range(20, 25)),
... )
>>> list(dp1.mux(dp2, dp3)) >>> list(dp1.mux(dp2, dp3))
[0, 10, 20, 1, 11, 21, 2, 12, 22] [0, 10, 20, 1, 11, 21, 2, 12, 22]
""" """
def __init__(self, *datapipes): def __init__(self, *datapipes):
self.datapipes = datapipes self.datapipes = datapipes
self.buffer: list = ( self.buffer: list = [] # Store values to be yielded only when every iterator provides one
[]
) # Store values to be yielded only when every iterator provides one
def __iter__(self): def __iter__(self):
iterators = [iter(x) for x in self.datapipes] iterators = [iter(x) for x in self.datapipes]
@ -670,7 +673,11 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]):
Example: Example:
>>> # xdoctest: +REQUIRES(module:torchdata) >>> # xdoctest: +REQUIRES(module:torchdata)
>>> from torchdata.datapipes.iter import IterableWrapper >>> from torchdata.datapipes.iter import IterableWrapper
>>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25)) >>> dp1, dp2, dp3 = (
... IterableWrapper(range(5)),
... IterableWrapper(range(10, 15)),
... IterableWrapper(range(20, 25)),
... )
>>> list(dp1.zip(dp2, dp3)) >>> list(dp1.zip(dp2, dp3))
[(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)] [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
""" """


@ -33,8 +33,12 @@ class FileOpenerIterDataPipe(IterDataPipe[tuple[str, IOBase]]):
Example: Example:
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader >>> from torchdata.datapipes.iter import (
>>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt')) ... FileLister,
... FileOpener,
... StreamReader,
... )
>>> dp = FileLister(root=".").filter(lambda fname: fname.endswith(".txt"))
>>> dp = FileOpener(dp) >>> dp = FileOpener(dp)
>>> dp = StreamReader(dp) >>> dp = StreamReader(dp)
>>> list(dp) >>> list(dp)


@ -182,7 +182,9 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
>>> from torchdata.datapipes.iter import IterableWrapper >>> from torchdata.datapipes.iter import IterableWrapper
>>> def group_fn(file): >>> def group_fn(file):
... return os.path.basename(file).split(".")[0] ... return os.path.basename(file).split(".")[0]
>>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"]) >>> source_dp = IterableWrapper(
... ["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"]
... )
>>> dp0 = source_dp.groupby(group_key_fn=group_fn) >>> dp0 = source_dp.groupby(group_key_fn=group_fn)
>>> list(dp0) >>> list(dp0)
[['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']] [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
@ -191,7 +193,12 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
>>> list(dp1) >>> list(dp1)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
>>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size` >>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size`
>>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2) >>> dp2 = source_dp.groupby(
... group_key_fn=group_fn,
... buffer_size=3,
... group_size=3,
... guaranteed_group_size=2,
... )
>>> list(dp2) >>> list(dp2)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
""" """


@ -31,8 +31,8 @@ class SequenceWrapperMapDataPipe(MapDataPipe[_T]):
>>> dp = SequenceWrapper(range(10)) >>> dp = SequenceWrapper(range(10))
>>> list(dp) >>> list(dp)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400}) >>> dp = SequenceWrapper({"a": 100, "b": 200, "c": 300, "d": 400})
>>> dp['a'] >>> dp["a"]
100 100
""" """


@ -45,8 +45,8 @@ def basichandlers(extension: str, data):
Example: Example:
>>> import pickle >>> import pickle
>>> data = pickle.dumps('some data') >>> data = pickle.dumps("some data")
>>> new_data = basichandlers('pickle', data) >>> new_data = basichandlers("pickle", data)
>>> new_data >>> new_data
some data some data
@ -169,9 +169,9 @@ class ImageHandler:
""" """
def __init__(self, imagespec): def __init__(self, imagespec):
assert imagespec in list( assert imagespec in list(imagespecs.keys()), (
imagespecs.keys() f"unknown image specification: {imagespec}"
), f"unknown image specification: {imagespec}" )
self.imagespec = imagespec.lower() self.imagespec = imagespec.lower()
def __call__(self, extension, data): def __call__(self, extension, data):
@ -205,18 +205,18 @@ class ImageHandler:
return img return img
elif atype == "numpy": elif atype == "numpy":
result = np.asarray(img) result = np.asarray(img)
assert ( assert result.dtype == np.uint8, (
result.dtype == np.uint8 f"numpy image array should be type uint8, but got {result.dtype}"
), f"numpy image array should be type uint8, but got {result.dtype}" )
if etype == "uint8": if etype == "uint8":
return result return result
else: else:
return result.astype("f") / 255.0 return result.astype("f") / 255.0
elif atype == "torch": elif atype == "torch":
result = np.asarray(img) result = np.asarray(img)
assert ( assert result.dtype == np.uint8, (
result.dtype == np.uint8 f"numpy image array should be type uint8, but got {result.dtype}"
), f"numpy image array should be type uint8, but got {result.dtype}" )
if etype == "uint8": if etype == "uint8":
result = np.array(result.transpose(2, 0, 1)) result = np.array(result.transpose(2, 0, 1))


@ -96,7 +96,7 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
>>> class MyIterableDataset(torch.utils.data.IterableDataset): >>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end): ... def __init__(self, start, end):
... super(MyIterableDataset).__init__() ... super(MyIterableDataset).__init__()
... assert end > start, "this example code only works with end >= start" ... assert end > start, "this example only works with end >= start"
... self.start = start ... self.start = start
... self.end = end ... self.end = end
... ...
@ -138,7 +138,7 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
>>> class MyIterableDataset(torch.utils.data.IterableDataset): >>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end): ... def __init__(self, start, end):
... super(MyIterableDataset).__init__() ... super(MyIterableDataset).__init__()
... assert end > start, "this example code only works with end >= start" ... assert end > start, "this example only works with end >= start"
... self.start = start ... self.start = start
... self.end = end ... self.end = end
... ...
@ -198,9 +198,9 @@ class TensorDataset(Dataset[tuple[Tensor, ...]]):
tensors: tuple[Tensor, ...] tensors: tuple[Tensor, ...]
def __init__(self, *tensors: Tensor) -> None: def __init__(self, *tensors: Tensor) -> None:
assert all( assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), (
tensors[0].size(0) == tensor.size(0) for tensor in tensors "Size mismatch between tensors"
), "Size mismatch between tensors" )
self.tensors = tensors self.tensors = tensors
def __getitem__(self, index): def __getitem__(self, index):
@ -222,7 +222,7 @@ class StackDataset(Dataset[_T_stack]):
>>> tuple_stack = StackDataset(images, texts) >>> tuple_stack = StackDataset(images, texts)
>>> tuple_stack[0] == (images[0], texts[0]) >>> tuple_stack[0] == (images[0], texts[0])
>>> dict_stack = StackDataset(image=images, text=texts) >>> dict_stack = StackDataset(image=images, text=texts)
>>> dict_stack[0] == {'image': images[0], 'text': texts[0]} >>> dict_stack[0] == {"image": images[0], "text": texts[0]}
Args: Args:
*args (Dataset): Datasets for stacking returned as tuple. *args (Dataset): Datasets for stacking returned as tuple.
@ -323,9 +323,9 @@ class ConcatDataset(Dataset[_T_co]):
self.datasets = list(datasets) self.datasets = list(datasets)
assert len(self.datasets) > 0, "datasets should not be an empty iterable" # type: ignore[arg-type] assert len(self.datasets) > 0, "datasets should not be an empty iterable" # type: ignore[arg-type]
for d in self.datasets: for d in self.datasets:
assert not isinstance( assert not isinstance(d, IterableDataset), (
d, IterableDataset "ConcatDataset does not support IterableDataset"
), "ConcatDataset does not support IterableDataset" )
self.cumulative_sizes = self.cumsum(self.datasets) self.cumulative_sizes = self.cumsum(self.datasets)
def __len__(self): def __len__(self):
@ -371,17 +371,17 @@ class ChainDataset(IterableDataset):
def __iter__(self): def __iter__(self):
for d in self.datasets: for d in self.datasets:
assert isinstance( assert isinstance(d, IterableDataset), (
d, IterableDataset "ChainDataset only supports IterableDataset"
), "ChainDataset only supports IterableDataset" )
yield from d yield from d
def __len__(self): def __len__(self):
total = 0 total = 0
for d in self.datasets: for d in self.datasets:
assert isinstance( assert isinstance(d, IterableDataset), (
d, IterableDataset "ChainDataset only supports IterableDataset"
), "ChainDataset only supports IterableDataset" )
total += len(d) # type: ignore[arg-type] total += len(d) # type: ignore[arg-type]
return total return total


@ -236,9 +236,17 @@ class WeightedRandomSampler(Sampler[int]):
Example: Example:
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) >>> list(
... WeightedRandomSampler(
... [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True
... )
... )
[4, 4, 1, 4, 5] [4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) >>> list(
... WeightedRandomSampler(
... [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False
... )
... )
[0, 1, 4, 3, 2] [0, 1, 4, 3, 2]
""" """
@ -298,9 +306,15 @@ class BatchSampler(Sampler[list[int]]):
its size would be less than ``batch_size`` its size would be less than ``batch_size``
Example: Example:
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) >>> list(
... BatchSampler(
... SequentialSampler(range(10)), batch_size=3, drop_last=False
... )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) >>> list(
... BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]] [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
""" """


@ -49,6 +49,7 @@ class ModuleTracker:
def my_linear(m1, m2, bias): def my_linear(m1, m2, bias):
print(f"Current modules: {tracker.parents}") print(f"Current modules: {tracker.parents}")
return torch.mm(m1, m2.t()) + bias return torch.mm(m1, m2.t()) + bias
torch.nn.functional.linear = my_linear torch.nn.functional.linear = my_linear
mod(torch.rand(2, 2)) mod(torch.rand(2, 2))


@ -6,6 +6,7 @@ Intel GPU optimization.
This package is lazily initialized, so you can always import it, and use This package is lazily initialized, so you can always import it, and use
:func:`is_available()` to determine if your system supports XPU. :func:`is_available()` to determine if your system supports XPU.
""" """
import threading import threading
import traceback import traceback
from functools import lru_cache from functools import lru_cache
@ -292,6 +293,7 @@ class StreamContext:
``None``. ``None``.
.. note:: Streams are per-device. .. note:: Streams are per-device.
""" """
cur_stream: Optional["torch.xpu.Stream"] cur_stream: Optional["torch.xpu.Stream"]
def __init__(self, stream: Optional["torch.xpu.Stream"]): def __init__(self, stream: Optional["torch.xpu.Stream"]):
@ -438,7 +440,7 @@ def get_gencode_flags() -> str:
arch_list = get_arch_list() arch_list = get_arch_list()
if len(arch_list) == 0: if len(arch_list) == 0:
return "" return ""
return f'-device {",".join(arch for arch in arch_list)}' return f"-device {','.join(arch for arch in arch_list)}"
def _get_generator(device: torch.device) -> torch._C.Generator: def _get_generator(device: torch.device) -> torch._C.Generator: