[BE][PYFMT] migrate PYFMT for torch/[p-z]*/ to ruff format (#144552)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144552
Approved by: https://github.com/ezyang
This commit is contained in:
Xuehai Pan 2025-08-06 20:57:29 +00:00 committed by PyTorch MergeBot
parent fd606a3a91
commit 5cedc5a0ff
65 changed files with 446 additions and 522 deletions

View File

@ -52,7 +52,6 @@ USE_BLACK_FILELIST = re.compile(
# torch/[e-m]*/**
# torch/optim/**
# torch/[p-z]*/**
"torch/[p-z]*/**",
],
),
)

View File

@ -2,6 +2,7 @@
"""Import mangling.
See mangling.md for details.
"""
import re

View File

@ -605,9 +605,9 @@ class PackageExporter:
dependencies (bool, optional): If ``True``, we scan the source for dependencies.
"""
assert (pickle_protocol == 4) or (
pickle_protocol == 3
), "torch.package only supports pickle protocols 3 and 4"
assert (pickle_protocol == 4) or (pickle_protocol == 3), (
"torch.package only supports pickle protocols 3 and 4"
)
filename = self._filename(package, resource)
# Write the pickle data for `obj`

View File

@ -423,7 +423,12 @@ class PackageImporter(Importer):
module.__dict__.setdefault(old_name, new_name)
return module
return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore[attr-defined]
return self._make_module(
name,
cur.source_file, # type: ignore[attr-defined]
isinstance(cur, _PackageNode),
parent,
)
def _compile_source(self, fullpath: str, mangled_filename: str):
source = self.zip_reader.get_record(fullpath)

View File

@ -7,6 +7,7 @@ examine their input shapes and stack traces, study device kernel activity and vi
An earlier version of the API in :mod:`torch.autograd` module is considered legacy and will be deprecated.
"""
import os
from typing import Any
from typing_extensions import TypeVarTuple, Unpack

View File

@ -239,10 +239,12 @@ class SchemaMatcher:
def match_schemas(cls, t: _ExtraFields_TorchOp) -> tuple[FunctionSchema, ...]:
signature = tuple(
# Tensor
TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata)
TensorKey.from_tensor(i)
if isinstance(i, _TensorMetadata)
#
# TensorList
else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list)
else [TensorKey.from_tensor(j) for j in i]
if isinstance(i, list)
#
# Scalar and uncaptured inputs.
else i

View File

@ -124,9 +124,9 @@ class BasicEvaluation:
for child_event in curr_event.children:
self_time -= child_event.duration_time_ns
stack.append(child_event)
assert (
EventKey(curr_event) not in self.metrics
), f"Duplicate id: {curr_event.id}, {curr_event.name}"
assert EventKey(curr_event) not in self.metrics, (
f"Duplicate id: {curr_event.id}, {curr_event.name}"
)
self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time)
self.metrics[
EventKey(curr_event)
@ -227,8 +227,7 @@ class BasicEvaluation:
while (
current_kernel_index < len(cuda_kernel_events)
and (cuda_kernel_events[current_kernel_index].start_ns())
<= start_time # type: ignore[possibly-undefined]
and (cuda_kernel_events[current_kernel_index].start_ns()) <= start_time # type: ignore[possibly-undefined]
):
current_kernel_index += 1
current_queue_depth = spawned_kernel_index - current_kernel_index + 1
@ -352,11 +351,11 @@ class BasicEvaluation:
output += "\n".join(
[
f"""{'-' * 80}
f"""{"-" * 80}
Event: {event}
Source code location: {source_code_location(event.event)}
Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}%
{'-' * 80}"""
{"-" * 80}"""
for event in event_list
]
)

View File

@ -624,8 +624,7 @@ class profile(_KinetoProfile):
]
) as p:
code_to_profile()
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:
@ -635,16 +634,17 @@ class profile(_KinetoProfile):
# on different iterations of the training loop;
# trace_handler is called every time a new trace becomes available
def trace_handler(prof):
print(prof.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
print(
prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)
)
# prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
# In this example with wait=1, warmup=1, active=2, repeat=1,
# profiler will skip the first step/iteration,
# start warming up on the second, record
@ -652,20 +652,15 @@ class profile(_KinetoProfile):
# after which the trace will become available
# and on_trace_ready (when set) is called;
# the cycle repeats starting with the next step
schedule=torch.profiler.schedule(
wait=1,
warmup=1,
active=2,
repeat=1),
on_trace_ready=trace_handler
schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
on_trace_ready=trace_handler,
# on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
# used when outputting for tensorboard
) as p:
for iter in range(N):
code_iteration_to_profile(iter)
# send a signal to the profiler that the next iteration has started
p.step()
) as p:
for iter in range(N):
code_iteration_to_profile(iter)
# send a signal to the profiler that the next iteration has started
p.step()
The following sample shows how to setup up an Execution Trace Observer (`execution_trace_observer`)

View File

@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
`torch/ao/quantization/fuser_method_mappings.py`, while adding an import statement
here.
"""
from torch.ao.quantization.fuser_method_mappings import (
_DEFAULT_OP_LIST_TO_FUSER_METHOD,
fuse_conv_bn,

View File

@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
here.
"""
from torch.ao.quantization.fx._equalize import (
_convert_equalization_ref,
_InputEqualizationObserver,

View File

@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
here.
"""
from torch.ao.quantization.fx.convert import convert

View File

@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
here.
"""
from torch.ao.quantization.fx.fuse import fuse

View File

@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
here.
"""
from torch.ao.quantization.fx.fuse_handler import DefaultFuseHandler, FuseHandler

View File

@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
here.
"""
from torch.ao.quantization.fx.graph_module import (
_is_observed_module,
_is_observed_standalone_module,

View File

@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
here.
"""
from torch.ao.quantization.fx.match_utils import (
_find_matches,
_is_match,

View File

@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
here.
"""
from torch.ao.quantization.fx.pattern_utils import (
_register_fusion_pattern,
_register_quant_pattern,

View File

@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
here.
"""
from torch.ao.quantization.fx.prepare import prepare

View File

@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
here.
"""
from torch.ao.quantization.fx.quantize_handler import (
BatchNormQuantizeHandler,
BinaryOpQuantizeHandler,

View File

@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
here.
"""
from torch.ao.quantization.utils import Pattern, QuantizerCls

View File

@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
here.
"""
from torch.ao.quantization.fx.utils import (
all_node_args_have_no_tensors,
assert_and_get_unique_device,

View File

@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
`torch/ao/quantization/observer.py`, while adding an import statement
here.
"""
from torch.ao.quantization.observer import (
_is_activation_post_process,
_is_per_channel_script_obs_instance,

View File

@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
`torch/ao/quantization/qconfig.py`, while adding an import statement
here.
"""
from torch.ao.quantization.qconfig import (
_add_module_to_qconfig_obs_ctr,
_assert_valid_qconfig,

View File

@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
`torch/ao/quantization/quantization_mappings.py`, while adding an import statement
here.
"""
from torch.ao.quantization.quantization_mappings import (
_get_special_act_post_process,
_has_special_act_post_process,

View File

@ -128,9 +128,7 @@ Examples::
>>> # Generates a periodic exponential window and decay factor equal to .5
>>> torch.signal.windows.exponential(10, sym=False,tau=.5)
tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04])
""".format(
**window_common_args
),
""".format(**window_common_args),
)
def exponential(
M: int,
@ -452,9 +450,7 @@ Examples::
>>> # Generates a periodic Hamming window.
>>> torch.signal.windows.hamming(10, sym=False)
tensor([0.0800, 0.1679, 0.3979, 0.6821, 0.9121, 1.0000, 0.9121, 0.6821, 0.3979, 0.1679])
""".format(
**window_common_args
),
""".format(**window_common_args),
)
def hamming(
M: int,
@ -508,9 +504,7 @@ Examples::
>>> # Generates a periodic Hann window.
>>> torch.signal.windows.hann(10, sym=False)
tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
""".format(
**window_common_args
),
""".format(**window_common_args),
)
def hann(
M: int,
@ -564,9 +558,7 @@ Examples::
>>> # Generates a periodic Blackman window.
>>> torch.signal.windows.blackman(5, sym=False)
tensor([-1.4901e-08, 2.0077e-01, 8.4923e-01, 8.4923e-01, 2.0077e-01])
""".format(
**window_common_args
),
""".format(**window_common_args),
)
def blackman(
M: int,
@ -627,9 +619,7 @@ Examples::
>>> # Generates a periodic Bartlett window.
>>> torch.signal.windows.bartlett(10, sym=False)
tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000, 0.8000, 0.6000, 0.4000, 0.2000])
""".format(
**window_common_args
),
""".format(**window_common_args),
)
def bartlett(
M: int,
@ -704,9 +694,7 @@ Examples::
>>> # Generates a periodic general cosine window with 2 coefficients.
>>> torch.signal.windows.general_cosine(10, a=[0.5, 1 - 0.5], sym=False)
tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
""".format(
**window_common_args
),
""".format(**window_common_args),
)
def general_cosine(
M,
@ -799,9 +787,7 @@ Examples::
>>> # Generates a periodic Hann window with the general Hamming window.
>>> torch.signal.windows.general_hamming(10, alpha=0.5, sym=False)
tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
""".format(
**window_common_args
),
""".format(**window_common_args),
)
def general_hamming(
M,
@ -866,9 +852,7 @@ Examples::
>>> # Generates a periodic Nuttall window.
>>> torch.signal.windows.general_hamming(5, sym=False)
tensor([3.6280e-04, 1.1052e-01, 7.9826e-01, 7.9826e-01, 1.1052e-01])
""".format(
**window_common_args
),
""".format(**window_common_args),
)
def nuttall(
M: int,

View File

@ -559,7 +559,11 @@ def as_sparse_gradcheck(gradcheck):
For example:
>>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)
>>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse_coo().requires_grad_(True)
>>> x = (
... torch.tensor([[0, 1], [2, 3]], dtype=torch.float64)
... .to_sparse_coo()
... .requires_grad_(True)
... )
>>> gradcheck(lambda x: x.to_sparse_csr(), x)
True
"""
@ -667,7 +671,7 @@ def as_sparse_gradcheck(gradcheck):
)
else:
raise NotImplementedError(
f'conversion of {d["layout"]} strided representation to tensor'
f"conversion of {d['layout']} strided representation to tensor"
)
new_args.append(a)
return tuple(new_args)

View File

@ -296,11 +296,11 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None):
for b in range(nbatches):
for i, r in enumerate(r_offsets):
r0, r1 = divmod(r, N)
acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns]
for g in range(c_indices[i], c_indices[i+1]):
acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
for g in range(c_indices[i], c_indices[i + 1]):
p = p_offsets[g]
q0, q1 = divmod(q_offsets[g], N)
acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns]
acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]
where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
integer multiples of ``Ms`` and ``Ks``, respectively.
@ -320,11 +320,11 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None):
n = (r % N) // Ns
r0, r1 = divmod(r, N)
c0, c1 = c_indices[m], c_indices[m + 1]
acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns]
acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
for i, p in enumerate(range(c0, c1)):
q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i]
q0, q1 = divmod(q, N)
acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns]
acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]
where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
integer multiples of ``Ms`` and ``Ks``, respectively.

View File

@ -97,6 +97,7 @@ tune_bsr_dense_addmm to learn how to register a custom set of optimal
kernel parameters for addmm-based operations.
"""
__all__ = ["get_meta", "tune_bsr_dense_addmm", "tune__int_bsr_dense_addmm"]
import inspect
@ -432,9 +433,9 @@ def minimize(
def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):
assert (
sparsity <= 1.0 and sparsity >= 0.0
), "sparsity should be a value between 0 and 1"
assert sparsity <= 1.0 and sparsity >= 0.0, (
"sparsity should be a value between 0 and 1"
)
assert M % blocksize[0] == 0
assert N % blocksize[1] == 0
shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :]

View File

@ -465,14 +465,26 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
```
from torch.sparse import SparseSemiStructuredTensorCUTLASS
from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
from torch.sparse._semi_structured_conversions import (
_sparse_semi_structured_tile,
_compute_compressed_swizzled_bitmask,
)
pruned = _sparse_semi_structured_tile(dense)
packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(
pruned.t().contiguous()
)
bitmask = _compute_compressed_swizzled_bitmask(pruned)
SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask)
SparseSemiStructuredTensorCUTLASS(
dense.shape,
packed_cutlass,
meta_cutlass,
packed_t_cutlass,
meta_t_cutlass,
bitmask,
)
```
"""
# We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
@ -583,14 +595,19 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
```
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
from torch.sparse._semi_structured_conversions import (
_sparse_semi_structured_tile,
_compute_compressed_swizzled_bitmask,
)
pruned = _sparse_semi_structured_tile(dense)
packed_cusparselt = torch._cslt_compress(pruned)
packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
bitmask = _compute_compressed_swizzled_bitmask(pruned)
SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask)
SparseSemiStructuredTensorCUSPARSELT(
dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask
)
```
"""
(

View File

@ -134,9 +134,7 @@ Example::
>>> torch.special.digamma(a)
tensor([-0.5772, -1.9635])
""".format(
**common_args
),
""".format(**common_args),
)
gammaln = _add_docstr(
@ -162,9 +160,7 @@ Example::
>>> torch.special.gammaln(a)
tensor([ 0.5724, 0.0000, -0.1208])
""".format(
**common_args
),
""".format(**common_args),
)
polygamma = _add_docstr(
@ -200,9 +196,7 @@ Example::
tensor([ 6.4939, 97.4091])
>>> torch.special.polygamma(4, a)
tensor([ -24.8863, -771.4742])
""".format(
**common_args
),
""".format(**common_args),
)
erf = _add_docstr(
@ -226,9 +220,7 @@ Example::
>>> torch.special.erf(torch.tensor([0, -1., 10.]))
tensor([ 0.0000, -0.8427, 1.0000])
""".format(
**common_args
),
""".format(**common_args),
)
erfc = _add_docstr(
@ -253,9 +245,7 @@ Example::
>>> torch.special.erfc(torch.tensor([0, -1., 10.]))
tensor([ 1.0000, 1.8427, 0.0000])
""".format(
**common_args
),
""".format(**common_args),
)
erfcx = _add_docstr(
@ -283,9 +273,7 @@ Example::
>>> torch.special.erfcx(torch.tensor([0, -1., 10.]))
tensor([ 1.0000, 5.0090, 0.0561])
""".format(
**common_args
),
""".format(**common_args),
)
erfinv = _add_docstr(
@ -311,9 +299,7 @@ Example::
>>> torch.special.erfinv(torch.tensor([0, 0.5, -1.]))
tensor([ 0.0000, 0.4769, -inf])
""".format(
**common_args
),
""".format(**common_args),
)
logit = _add_docstr(
@ -351,9 +337,7 @@ Example::
tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516])
>>> torch.special.logit(a, eps=1e-6)
tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261])
""".format(
**common_args
),
""".format(**common_args),
)
logsumexp = _add_docstr(
@ -362,9 +346,7 @@ logsumexp = _add_docstr(
logsumexp(input, dim, keepdim=False, *, out=None)
Alias for :func:`torch.logsumexp`.
""".format(
**multi_dim_common
),
""".format(**multi_dim_common),
)
expit = _add_docstr(
@ -391,9 +373,7 @@ Example::
tensor([ 0.9213, 1.0887, -0.8858, -1.7683])
>>> torch.special.expit(t)
tensor([ 0.7153, 0.7481, 0.2920, 0.1458])
""".format(
**common_args
),
""".format(**common_args),
)
exp2 = _add_docstr(
@ -418,9 +398,7 @@ Example::
>>> torch.special.exp2(torch.tensor([0, math.log2(2.), 3, 4]))
tensor([ 1., 2., 8., 16.])
""".format(
**common_args
),
""".format(**common_args),
)
expm1 = _add_docstr(
@ -448,9 +426,7 @@ Example::
>>> torch.special.expm1(torch.tensor([0, math.log(2.)]))
tensor([ 0., 1.])
""".format(
**common_args
),
""".format(**common_args),
)
xlog1py = _add_docstr(
@ -495,9 +471,7 @@ Example::
tensor([1.6094, 3.2189, 4.8283])
>>> torch.special.xlog1py(2, y)
tensor([2.7726, 2.1972, 1.3863])
""".format(
**common_args
),
""".format(**common_args),
)
xlogy = _add_docstr(
@ -542,9 +516,7 @@ Example::
tensor([1.3863, 2.7726, 4.1589])
>>> torch.special.xlogy(2, y)
tensor([2.1972, 1.3863, 0.0000])
""".format(
**common_args
),
""".format(**common_args),
)
i0 = _add_docstr(
@ -570,9 +542,7 @@ Example::
>>> torch.i0(torch.arange(5, dtype=torch.float32))
tensor([ 1.0000, 1.2661, 2.2796, 4.8808, 11.3019])
""".format(
**common_args
),
""".format(**common_args),
)
i0e = _add_docstr(
@ -597,9 +567,7 @@ Example::
>>> torch.special.i0e(torch.arange(5, dtype=torch.float32))
tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070])
""".format(
**common_args
),
""".format(**common_args),
)
i1 = _add_docstr(
@ -624,9 +592,7 @@ Example::
>>> torch.special.i1(torch.arange(5, dtype=torch.float32))
tensor([0.0000, 0.5652, 1.5906, 3.9534, 9.7595])
""".format(
**common_args
),
""".format(**common_args),
)
i1e = _add_docstr(
@ -652,9 +618,7 @@ Example::
>>> torch.special.i1e(torch.arange(5, dtype=torch.float32))
tensor([0.0000, 0.2079, 0.2153, 0.1968, 0.1788])
""".format(
**common_args
),
""".format(**common_args),
)
ndtr = _add_docstr(
@ -679,9 +643,7 @@ Example::
>>> torch.special.ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
tensor([0.0013, 0.0228, 0.1587, 0.5000, 0.8413, 0.9772, 0.9987])
""".format(
**common_args
),
""".format(**common_args),
)
ndtri = _add_docstr(
@ -709,9 +671,7 @@ Example::
>>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1]))
tensor([ -inf, -0.6745, 0.0000, 0.6745, inf])
""".format(
**common_args
),
""".format(**common_args),
)
log_ndtr = _add_docstr(
@ -736,9 +696,7 @@ Example::
>>> torch.special.log_ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
tensor([-6.6077 -3.7832 -1.841 -0.6931 -0.1728 -0.023 -0.0014])
""".format(
**common_args
),
""".format(**common_args),
)
log1p = _add_docstr(
@ -779,9 +737,7 @@ Example::
tensor([ 0.2252, -0.2948, 1.0267, -1.1566])
>>> torch.special.sinc(t)
tensor([ 0.9186, 0.8631, -0.0259, -0.1300])
""".format(
**common_args
),
""".format(**common_args),
)
round = _add_docstr(
@ -886,9 +842,7 @@ Example::
tensor([1.6449, 0.0823])
>>> torch.special.zeta(2, torch.tensor([1., 2.]))
tensor([1.6449, 0.6449])
""".format(
**common_args
),
""".format(**common_args),
)
multigammaln = _add_docstr(
@ -925,9 +879,7 @@ Example::
>>> torch.special.multigammaln(a, 2)
tensor([[0.3928, 0.4007, 0.7586],
[1.0311, 0.3901, 0.5049]])
""".format(
**common_args
),
""".format(**common_args),
)
gammainc = _add_docstr(
@ -976,9 +928,7 @@ Example::
>>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
tensor([1., 1., 1.])
""".format(
**common_args
),
""".format(**common_args),
)
gammaincc = _add_docstr(
@ -1026,9 +976,7 @@ Example::
>>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
tensor([1., 1., 1.])
""".format(
**common_args
),
""".format(**common_args),
)
airy_ai = _add_docstr(
@ -1045,9 +993,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
bessel_j0 = _add_docstr(
@ -1064,9 +1010,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
bessel_j1 = _add_docstr(
@ -1083,9 +1027,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
bessel_y0 = _add_docstr(
@ -1102,9 +1044,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
bessel_y1 = _add_docstr(
@ -1121,9 +1061,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
chebyshev_polynomial_t = _add_docstr(
@ -1154,9 +1092,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
chebyshev_polynomial_u = _add_docstr(
@ -1188,9 +1124,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
chebyshev_polynomial_v = _add_docstr(
@ -1208,9 +1142,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
chebyshev_polynomial_w = _add_docstr(
@ -1228,9 +1160,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
hermite_polynomial_h = _add_docstr(
@ -1256,9 +1186,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
hermite_polynomial_he = _add_docstr(
@ -1284,9 +1212,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
laguerre_polynomial_l = _add_docstr(
@ -1312,9 +1238,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
legendre_polynomial_p = _add_docstr(
@ -1340,9 +1264,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
modified_bessel_i0 = _add_docstr(
@ -1359,9 +1281,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
modified_bessel_i1 = _add_docstr(
@ -1378,9 +1298,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
modified_bessel_k0 = _add_docstr(
@ -1397,9 +1315,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
modified_bessel_k1 = _add_docstr(
@ -1416,9 +1332,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
scaled_modified_bessel_k0 = _add_docstr(
@ -1435,9 +1349,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
scaled_modified_bessel_k1 = _add_docstr(
@ -1454,9 +1366,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
shifted_chebyshev_polynomial_t = _add_docstr(
@ -1474,9 +1384,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
shifted_chebyshev_polynomial_u = _add_docstr(
@ -1494,9 +1402,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
shifted_chebyshev_polynomial_v = _add_docstr(
@ -1514,9 +1420,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
shifted_chebyshev_polynomial_w = _add_docstr(
@ -1534,9 +1438,7 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)
spherical_bessel_j0 = _add_docstr(
@ -1553,7 +1455,5 @@ Args:
Keyword args:
{out}
""".format(
**common_args
),
""".format(**common_args),
)

View File

@ -1538,7 +1538,9 @@ def assert_close(
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = torch.tensor([1.0, 4.0, 5.0])
>>> # The default error message can be overwritten.
>>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
>>> torch.testing.assert_close(
... actual, expected, msg="Argh, the tensors are not close!"
... )
Traceback (most recent call last):
...
AssertionError: Argh, the tensors are not close!

View File

@ -115,11 +115,11 @@ def make_tensor(
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> from torch.testing import make_tensor
>>> # Creates a float tensor with values in [-1, 1)
>>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
>>> make_tensor((3,), device="cpu", dtype=torch.float32, low=-1, high=1)
>>> # xdoctest: +SKIP
tensor([ 0.1205, 0.2282, -0.6380])
>>> # Creates a bool tensor on CUDA
>>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
>>> make_tensor((2, 2), device="cuda", dtype=torch.bool)
tensor([[False, False],
[False, True]], device='cuda:0')
"""

View File

@ -721,9 +721,9 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_fo
intersect = set(except_for if except_for else []) & set(
only_for if only_for else []
)
assert (
not intersect
), f"device ({intersect}) appeared in both except_for and only_for"
assert not intersect, (
f"device ({intersect}) appeared in both except_for and only_for"
)
# Replace your privateuse1 backend name with 'privateuse1'
if is_privateuse1_backend_available():
@ -1407,9 +1407,9 @@ class deviceCountAtLeast:
self.num_required_devices = num_required_devices
def __call__(self, fn):
assert not hasattr(
fn, "num_required_devices"
), f"deviceCountAtLeast redefinition for {fn.__name__}"
assert not hasattr(fn, "num_required_devices"), (
f"deviceCountAtLeast redefinition for {fn.__name__}"
)
fn.num_required_devices = self.num_required_devices
@wraps(fn)
@ -1474,13 +1474,13 @@ def onlyNativeDeviceTypesAnd(devices=None):
# self.precision *2, max(1, self.precision)).
class precisionOverride:
def __init__(self, d):
assert isinstance(
d, dict
), "precisionOverride not given a dtype : precision dict!"
assert isinstance(d, dict), (
"precisionOverride not given a dtype : precision dict!"
)
for dtype in d.keys():
assert isinstance(
dtype, torch.dtype
), f"precisionOverride given unknown dtype {dtype}"
assert isinstance(dtype, torch.dtype), (
f"precisionOverride given unknown dtype {dtype}"
)
self.d = d
@ -1513,12 +1513,12 @@ class toleranceOverride:
def __init__(self, d):
assert isinstance(d, dict), "toleranceOverride not given a dtype : tol dict!"
for dtype, prec in d.items():
assert isinstance(
dtype, torch.dtype
), f"toleranceOverride given unknown dtype {dtype}"
assert isinstance(
prec, tol
), "toleranceOverride not given a dtype : tol dict!"
assert isinstance(dtype, torch.dtype), (
f"toleranceOverride given unknown dtype {dtype}"
)
assert isinstance(prec, tol), (
"toleranceOverride not given a dtype : tol dict!"
)
self.d = d
@ -1546,13 +1546,13 @@ class dtypes:
"all dtype variants must be. "
f"Received non-list non-tuple dtype {str(arg)}"
)
assert all(
isinstance(dtype, torch.dtype) for dtype in arg
), f"Unknown dtype in {str(arg)}"
assert all(isinstance(dtype, torch.dtype) for dtype in arg), (
f"Unknown dtype in {str(arg)}"
)
else:
assert all(
isinstance(arg, torch.dtype) for arg in args
), f"Unknown dtype in {str(args)}"
assert all(isinstance(arg, torch.dtype) for arg in args), (
f"Unknown dtype in {str(args)}"
)
self.args = args
self.device_type = device_type

View File

@ -253,9 +253,9 @@ def verify_ddp_error_logged(model_DDP, err_substr):
if err_substr.find("\nException raised from ") == -1
else err_substr.split("\nException raised from ")[0]
)
assert (
actual in logging_err
), f"Did not find expected {actual} in ddp logging data error: {logging_err}"
assert actual in logging_err, (
f"Did not find expected {actual} in ddp logging data error: {logging_err}"
)
def with_nccl_blocking_wait(func):
@ -294,9 +294,9 @@ def with_nccl_blocking_wait(func):
finally:
# restore old values.
if cached_nccl_async_error_handling is not None:
os.environ[
"TORCH_NCCL_ASYNC_ERROR_HANDLING"
] = cached_nccl_async_error_handling
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
cached_nccl_async_error_handling
)
if cached_nccl_blocking_wait is not None:
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait
@ -812,7 +812,7 @@ class MultiProcessTestCase(TestCase):
sys.exit(TEST_SKIPS["generic"].exit_code)
except Exception:
logger.error(
"Caught exception: \n%s exiting " "process %s with exit code: %s",
"Caught exception: \n%s exiting process %s with exit code: %s",
traceback.format_exc(),
self.rank,
MultiProcessTestCase.TEST_ERROR_EXIT_CODE,
@ -1689,9 +1689,7 @@ class MultiProcContinousTest(TestCase):
cls.processes.append(process)
cls.task_queues.append(task_queue)
cls.completion_queues.append(completion_queue)
logger.info(
"Started process %s with pid %s", rank, process.pid
) # noqa: UP031
logger.info("Started process %s with pid %s", rank, process.pid) # noqa: UP031
@classmethod
def setUpClass(cls):

View File

@ -1285,10 +1285,10 @@ class FSDPTest(MultiProcessTestCase):
loss = sharded_grad_scaler.scale(loss)
if not mixed_precision and not use_pure_fp16:
assert (
loss.dtype == torch.float32
), "loss data type should be float32, as the original \
assert loss.dtype == torch.float32, (
"loss data type should be float32, as the original \
parameter data type is float32."
)
else:
if use_pure_fp16:
self.assertEqual(loss.dtype, torch.float16)
@ -1354,9 +1354,9 @@ class FSDPTest(MultiProcessTestCase):
wrapper should provide data parallel semantics. If ``None``,
then the callable defaults to the DDP constructor.
"""
assert (
fsdp_init_mode != FSDPInitMode.NO_FSDP
), "Expects an FSDP init mode that wraps with FSDP"
assert fsdp_init_mode != FSDPInitMode.NO_FSDP, (
"Expects an FSDP init mode that wraps with FSDP"
)
if init_kwargs is None:
init_kwargs = {}
lr = 1e-2

View File

@ -1268,9 +1268,9 @@ def _get_optim_inputs_including_global_cliquey_kwargs(
trivial. That said, we sometimes want to test for all possible configs on an
optimizer including all supported flags, so this helper returns all optim inputs.
"""
assert all(
x in ["foreach", "fused", "differentiable"] for x in skip
), "skip must be a subset of ['foreach', 'fused', 'differentiable']"
assert all(x in ["foreach", "fused", "differentiable"] for x in skip), (
"skip must be a subset of ['foreach', 'fused', 'differentiable']"
)
optim_inputs = optim_info.optim_inputs_func(device)

View File

@ -477,7 +477,9 @@ def with_comms(
def decorator(func, eager_init: bool = False, backend: Optional[str] = None):
@wraps(func) # pyre-ignore[6]
def wrapper(
self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc]
self,
*args: tuple[object],
**kwargs: dict[str, Any], # type: ignore[misc]
) -> None:
self.init_pg(eager_init, backend)

View File

@ -253,7 +253,11 @@ class Trainer:
else:
input_batches = batches
with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext():
with (
self.hybrid_module.join()
if simulate_uneven_inputs
else contextlib.nullcontext()
):
for b in input_batches:
with dist_autograd.context() as context_id:
output = self.hybrid_module.forward(b)
@ -261,8 +265,7 @@ class Trainer:
dist_autograd.backward(context_id, [loss])
grads_dict = dist_autograd.get_gradients(context_id)
gLogger.info(
"Loss is %s for mini batch: %s. "
"Grads dict has %s entries: %s",
"Loss is %s for mini batch: %s. Grads dict has %s entries: %s",
loss,
mini_batch,
len(grads_dict),

View File

@ -162,9 +162,7 @@ class SampleInput:
# Allow calling either as SampleInput(input, args=args, kwargs=kwargs), or as
# SampleInput(input, *args, **kwargs) but not to mix the two forms
if args is not None or kwargs is not None:
assert (
not var_args and not var_kwargs
), """
assert not var_args and not var_kwargs, """
A SampleInput can be constructed "naturally" with *args and **kwargs or by
explicitly setting the "args" and "kwargs" parameters, but the two
methods of construction cannot be mixed!"""
@ -226,7 +224,7 @@ cannot specify additional metadata in keyword arguments"""
f"name={repr(self.name)}",
]
return f'SampleInput({", ".join(a for a in arguments if a is not None)})'
return f"SampleInput({', '.join(a for a in arguments if a is not None)})"
def __repr__(self):
return self._repr_helper(lambda x: x)
@ -1601,13 +1599,11 @@ class SampleRule(ABC):
# returns a string identifier of the rule type
@abstractmethod
def type(self) -> str:
...
def type(self) -> str: ...
# returns an appropriate context that handles the xfail, skips, etc.
@abstractmethod
def get_context(self, test_case):
...
def get_context(self, test_case): ...
# useful for specifying xfails
@ -1791,8 +1787,10 @@ class ReductionOpInfo(OpInfo):
# kwargs to use when calling the op. This is required for operators that
# have other required parameters besides the input tensor.
generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: (
yield (),
{},
yield (
(),
{},
)
),
# Options from the OpInfo base class
**kwargs,
@ -2476,9 +2474,9 @@ class BinaryUfuncInfo(OpInfo):
self.supports_one_python_scalar = True
if self.supports_one_python_scalar:
assert (
supports_rhs_python_scalar
), "Can't support lhs and rhs Python scalars but not rhs scalars!"
assert supports_rhs_python_scalar, (
"Can't support lhs and rhs Python scalars but not rhs scalars!"
)
# The following functions and classes are for testing elementwise unary operators.

View File

@ -102,8 +102,9 @@ def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwar
for mask in _generate_masked_op_mask(
sample_input.input.shape, device, **kwargs
):
sample_input_args, sample_input_kwargs = sample_input.args, dict(
mask=mask, **sample_input.kwargs
sample_input_args, sample_input_kwargs = (
sample_input.args,
dict(mask=mask, **sample_input.kwargs),
)
yield SampleInput(
sample_input.input.detach().requires_grad_(requires_grad),
@ -224,8 +225,9 @@ def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
op_info, device, dtype, requires_grad, **kwargs
):
sample_input_args, sample_input_kwargs = (
ord,
) + sample_input.args, sample_input.kwargs.copy()
(ord,) + sample_input.args,
sample_input.kwargs.copy(),
)
yield SampleInput(
sample_input.input.clone().requires_grad_(requires_grad),
args=sample_input_args,
@ -276,8 +278,9 @@ def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs
for mask in _generate_masked_op_mask(
sample_input.input.shape, device, **kwargs
):
sample_input_args, sample_input_kwargs = sample_input.args, dict(
mask=mask, **sample_input.kwargs
sample_input_args, sample_input_kwargs = (
sample_input.args,
dict(mask=mask, **sample_input.kwargs),
)
yield SampleInput(
sample_input.input.detach().requires_grad_(requires_grad),
@ -364,8 +367,9 @@ def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs)
):
if type(mask) != torch.Tensor:
continue
sample_input_args, sample_input_kwargs = sample_input.args, dict(
mask=mask, **sample_input.kwargs
sample_input_args, sample_input_kwargs = (
sample_input.args,
dict(mask=mask, **sample_input.kwargs),
)
if "keepdim" in sample_input_kwargs:
sample_input_kwargs.pop("keepdim")

View File

@ -112,7 +112,7 @@ class _Config(Generic[T]):
@staticmethod
def string_or_list_of_string_to_list(
val: Optional[Union[str, list[str]]]
val: Optional[Union[str, list[str]]],
) -> Optional[list[str]]:
if val is None:
return None
@ -135,8 +135,7 @@ if TYPE_CHECKING:
env_name_force: Optional[Union[str, list[str]]] = None,
value_type: Optional[type] = None,
alias: Optional[str] = None,
) -> T:
...
) -> T: ...
else:
@ -323,9 +322,9 @@ class _ConfigEntry:
# Ensure justknobs and envvars are allowlisted types
if self.justknob is not None and self.default is not None:
assert isinstance(
self.default, bool
), f"justknobs only support booleans, {self.default} is not a boolean"
assert isinstance(self.default, bool), (
f"justknobs only support booleans, {self.default} is not a boolean"
)
if self.value_type is not None and (
config.env_name_default is not None or config.env_name_force is not None
):
@ -334,7 +333,9 @@ class _ConfigEntry:
str,
Optional[bool],
Optional[str],
), f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
), (
f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
)
class ConfigModule(ModuleType):

View File

@ -282,9 +282,9 @@ def tree_is_leaf(
False
>>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple))
True
>>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3})
>>> tree_is_leaf({"a": 1, "b": 2, "c": 3})
False
>>> tree_is_leaf({'a': 1, 'b': 2, 'c': None})
>>> tree_is_leaf({"a": 1, "b": 2, "c": None})
False
Args:
@ -586,29 +586,28 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]]
# These specializations help with type inference on the lambda passed to this
# function
@overload
def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]:
...
def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ...
@overload
def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]:
...
def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ...
@overload
def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]:
...
def map_only(
type_or_types_or_pred: Type3[T, S, U], /
) -> MapOnlyFn[Fn3[T, S, U, Any]]: ...
# This specialization is needed for the implementations below that call
@overload
def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]:
...
def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ...
@overload
def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]:
...
def map_only(
type_or_types_or_pred: Callable[[Any], bool], /
) -> MapOnlyFn[FnAny[Any]]: ...
def map_only(
@ -664,8 +663,7 @@ def tree_map_only(
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -675,8 +673,7 @@ def tree_map_only(
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -686,8 +683,7 @@ def tree_map_only(
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -697,8 +693,7 @@ def tree_map_only(
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -708,8 +703,7 @@ def tree_map_only(
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
def tree_map_only(
@ -729,8 +723,7 @@ def tree_map_only_(
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -740,8 +733,7 @@ def tree_map_only_(
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -751,8 +743,7 @@ def tree_map_only_(
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -762,8 +753,7 @@ def tree_map_only_(
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -773,8 +763,7 @@ def tree_map_only_(
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
def tree_map_only_(
@ -812,8 +801,7 @@ def tree_all_only(
pred: Fn[T, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
) -> bool: ...
@overload
@ -823,8 +811,7 @@ def tree_all_only(
pred: Fn2[T, S, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
) -> bool: ...
@overload
@ -834,8 +821,7 @@ def tree_all_only(
pred: Fn3[T, S, U, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
) -> bool: ...
def tree_all_only(
@ -856,8 +842,7 @@ def tree_any_only(
pred: Fn[T, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
) -> bool: ...
@overload
@ -867,8 +852,7 @@ def tree_any_only(
pred: Fn2[T, S, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
) -> bool: ...
@overload
@ -878,8 +862,7 @@ def tree_any_only(
pred: Fn3[T, S, U, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
) -> bool: ...
def tree_any_only(

View File

@ -12,7 +12,7 @@ _cache_sentinel = object()
def cache_method(
f: Callable[Concatenate[_C, _P], _T]
f: Callable[Concatenate[_C, _P], _T],
) -> Callable[Concatenate[_C, _P], _T]:
"""
Like `@functools.cache` but for methods.

View File

@ -302,14 +302,12 @@ class BaseTorchDispatchMode(TorchDispatchMode):
# Subtypes which have __tensor_flatten__ and __tensor_unflatten__.
class TensorWithFlatten(Protocol):
def __tensor_flatten__(self) -> tuple[Sequence[str], object]:
...
def __tensor_flatten__(self) -> tuple[Sequence[str], object]: ...
@staticmethod
def __tensor_unflatten__(
inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int
) -> torch.Tensor:
...
) -> torch.Tensor: ...
# It would be really nice to be able to say that the return of
# is_traceable_wrapper_subclass() is Intersection[torch.Tensor,
@ -318,26 +316,20 @@ class TensorWithFlatten(Protocol):
shape: torch._C.Size
@overload
def stride(self, dim: None = None) -> tuple[int, ...]:
...
def stride(self, dim: None = None) -> tuple[int, ...]: ...
@overload
def stride(self, dim: int) -> int:
...
def stride(self, dim: int) -> int: ...
@overload
def size(self, dim: None = None) -> tuple[int, ...]:
...
def size(self, dim: None = None) -> tuple[int, ...]: ...
@overload
def size(self, dim: int) -> int:
...
def size(self, dim: int) -> int: ...
def storage_offset(self) -> int:
...
def storage_offset(self) -> int: ...
def dim(self) -> int:
...
def dim(self) -> int: ...
@overload
def to(
@ -347,8 +339,7 @@ class TensorWithFlatten(Protocol):
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None,
) -> torch.Tensor:
...
) -> torch.Tensor: ...
@overload
def to(
@ -359,8 +350,7 @@ class TensorWithFlatten(Protocol):
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None,
) -> torch.Tensor:
...
) -> torch.Tensor: ...
@overload
def to(
@ -370,8 +360,7 @@ class TensorWithFlatten(Protocol):
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None,
) -> torch.Tensor:
...
) -> torch.Tensor: ...
def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:

View File

@ -99,17 +99,13 @@ NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND"
class KeyEntry(Protocol):
def __hash__(self) -> int:
...
def __hash__(self) -> int: ...
def __eq__(self, other: object) -> bool:
...
def __eq__(self, other: object) -> bool: ...
def __str__(self) -> str:
...
def __str__(self) -> str: ...
def get(self, parent: Any) -> Any:
...
def get(self, parent: Any) -> Any: ...
class EnumEncoder(json.JSONEncoder):
@ -757,7 +753,7 @@ def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]:
def _tuple_flatten_with_keys(
d: tuple[T, ...]
d: tuple[T, ...],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _tuple_flatten(d)
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
@ -785,7 +781,7 @@ def _dict_flatten(d: dict[Any, T]) -> tuple[list[T], Context]:
def _dict_flatten_with_keys(
d: dict[Any, T]
d: dict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _dict_flatten(d)
return [(MappingKey(k), v) for k, v in zip(context, values)], context
@ -849,7 +845,7 @@ def _ordereddict_flatten(d: OrderedDict[Any, T]) -> tuple[list[T], Context]:
def _ordereddict_flatten_with_keys(
d: OrderedDict[Any, T]
d: OrderedDict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _ordereddict_flatten(d)
return [(MappingKey(k), v) for k, v in zip(context, values)], context
@ -872,7 +868,7 @@ def _defaultdict_flatten(d: defaultdict[Any, T]) -> tuple[list[T], Context]:
def _defaultdict_flatten_with_keys(
d: defaultdict[Any, T]
d: defaultdict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _defaultdict_flatten(d)
_, dict_context = context
@ -1035,9 +1031,9 @@ def tree_is_leaf(
False
>>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple))
True
>>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3})
>>> tree_is_leaf({"a": 1, "b": 2, "c": 3})
False
>>> tree_is_leaf({'a': 1, 'b': 2, 'c': None})
>>> tree_is_leaf({"a": 1, "b": 2, "c": None})
False
"""
if is_leaf is not None and is_leaf(tree):
@ -1346,9 +1342,9 @@ def tree_map(
See also :func:`tree_map_`.
>>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
>>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)})
{'x': 8, 'y': (43, 65)}
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
>>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None})
{'x': False, 'y': (False, False), 'z': True}
If multiple inputs are given, the structure of the tree is taken from the first input;
@ -1432,29 +1428,28 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]]
# These specializations help with type inference on the lambda passed to this
# function
@overload
def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]:
...
def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ...
@overload
def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]:
...
def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ...
@overload
def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]:
...
def map_only(
type_or_types_or_pred: Type3[T, S, U], /
) -> MapOnlyFn[Fn3[T, S, U, Any]]: ...
# This specialization is needed for the implementations below that call
@overload
def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]:
...
def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ...
@overload
def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]:
...
def map_only(
type_or_types_or_pred: Callable[[Any], bool], /
) -> MapOnlyFn[FnAny[Any]]: ...
def map_only(
@ -1510,8 +1505,7 @@ def tree_map_only(
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -1521,8 +1515,7 @@ def tree_map_only(
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -1532,8 +1525,7 @@ def tree_map_only(
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -1543,8 +1535,7 @@ def tree_map_only(
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -1554,8 +1545,7 @@ def tree_map_only(
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
def tree_map_only(
@ -1575,8 +1565,7 @@ def tree_map_only_(
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -1586,8 +1575,7 @@ def tree_map_only_(
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -1597,8 +1585,7 @@ def tree_map_only_(
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -1608,8 +1595,7 @@ def tree_map_only_(
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
@overload
@ -1619,8 +1605,7 @@ def tree_map_only_(
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
) -> PyTree: ...
def tree_map_only_(
@ -1658,8 +1643,7 @@ def tree_all_only(
pred: Fn[T, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
) -> bool: ...
@overload
@ -1669,8 +1653,7 @@ def tree_all_only(
pred: Fn2[T, S, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
) -> bool: ...
@overload
@ -1680,8 +1663,7 @@ def tree_all_only(
pred: Fn3[T, S, U, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
) -> bool: ...
def tree_all_only(
@ -1702,8 +1684,7 @@ def tree_any_only(
pred: Fn[T, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
) -> bool: ...
@overload
@ -1713,8 +1694,7 @@ def tree_any_only(
pred: Fn2[T, S, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
) -> bool: ...
@overload
@ -1724,8 +1704,7 @@ def tree_any_only(
pred: Fn3[T, S, U, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
) -> bool: ...
def tree_any_only(
@ -1862,7 +1841,7 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
raise NotImplementedError(
f'Deserializing {json_schema["type"]} in pytree is not registered.',
f"Deserializing {json_schema['type']} in pytree is not registered.",
)
typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]]

View File

@ -301,7 +301,7 @@ def strobelight(
profiler = StrobelightCLIFunctionProfiler(**kwargs)
def strobelight_inner(
work_function: Callable[_P, _R]
work_function: Callable[_P, _R],
) -> Callable[_P, Optional[_R]]:
@functools.wraps(work_function)
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:

View File

@ -98,7 +98,7 @@ def _is_symbols_binary_summation(expr: sympy.Expr) -> bool:
def _keep_float(
f: Callable[[Unpack[_Ts]], _T]
f: Callable[[Unpack[_Ts]], _T],
) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
@functools.wraps(f)
def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
@ -926,10 +926,12 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
_eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) # noqa: E731
_eval_is_antihermitian = lambda s: _torf( # noqa: E731
i.is_antihermitian for i in s.args # noqa: E731
i.is_antihermitian
for i in s.args # noqa: E731
) # noqa: E731
_eval_is_commutative = lambda s: _torf( # noqa: E731
i.is_commutative for i in s.args # noqa: E731
i.is_commutative
for i in s.args # noqa: E731
) # noqa: E731
_eval_is_complex = lambda s: _torf(i.is_complex for i in s.args) # noqa: E731
_eval_is_composite = lambda s: _torf(i.is_composite for i in s.args) # noqa: E731
@ -943,10 +945,12 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
_eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731
_eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731
_eval_is_nonnegative = lambda s: _torf( # noqa: E731
i.is_nonnegative for i in s.args # noqa: E731
i.is_nonnegative
for i in s.args # noqa: E731
) # noqa: E731
_eval_is_nonpositive = lambda s: _torf( # noqa: E731
i.is_nonpositive for i in s.args # noqa: E731
i.is_nonpositive
for i in s.args # noqa: E731
) # noqa: E731
_eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731
_eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731
@ -956,10 +960,12 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
_eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731
_eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731
_eval_is_extended_real = lambda s: _torf( # noqa: E731
i.is_extended_real for i in s.args # noqa: E731
i.is_extended_real
for i in s.args # noqa: E731
) # noqa: E731
_eval_is_transcendental = lambda s: _torf( # noqa: E731
i.is_transcendental for i in s.args # noqa: E731
i.is_transcendental
for i in s.args # noqa: E731
) # noqa: E731
_eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731

View File

@ -144,16 +144,14 @@ class ValueRanges(Generic[_T]):
self: ValueRanges[sympy.Expr],
lower: ExprIn,
upper: ExprIn,
) -> None:
...
) -> None: ...
@overload
def __init__( # type: ignore[misc]
self: ValueRanges[SympyBoolean],
lower: BoolIn,
upper: BoolIn,
) -> None:
...
) -> None: ...
def __init__(self, lower: AllIn, upper: AllIn) -> None:
lower = simple_sympify(lower)
@ -240,15 +238,13 @@ class ValueRanges(Generic[_T]):
def __and__(
self: ValueRanges[sympy.Expr],
other: ValueRanges[sympy.Expr],
) -> ValueRanges[sympy.Expr]:
...
) -> ValueRanges[sympy.Expr]: ...
@overload
def __and__( # type: ignore[misc]
self: ValueRanges[SympyBoolean],
other: ValueRanges[SympyBoolean],
) -> ValueRanges[SympyBoolean]:
...
) -> ValueRanges[SympyBoolean]: ...
def __and__(self: AllVR, other: AllVR) -> AllVR:
if other in (ValueRanges.unknown(), ValueRanges.unknown_int()):
@ -272,15 +268,13 @@ class ValueRanges(Generic[_T]):
def __or__(
self: ValueRanges[sympy.Expr],
other: ValueRanges[sympy.Expr],
) -> ValueRanges[sympy.Expr]:
...
) -> ValueRanges[sympy.Expr]: ...
@overload
def __or__( # type: ignore[misc]
self: ValueRanges[SympyBoolean],
other: ValueRanges[SympyBoolean],
) -> ValueRanges[SympyBoolean]:
...
) -> ValueRanges[SympyBoolean]: ...
def __or__(self: AllVR, other: AllVR) -> AllVR:
if ValueRanges.unknown() in (self, other):
@ -343,8 +337,7 @@ class ValueRanges(Generic[_T]):
@overload
@staticmethod
def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
...
def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: ...
@overload
@staticmethod
@ -384,8 +377,7 @@ class ValueRanges(Generic[_T]):
x: Union[ExprIn, ExprVR],
y: Union[ExprIn, ExprVR],
fn: ExprFn2,
) -> ExprVR:
...
) -> ExprVR: ...
@overload
@staticmethod
@ -393,8 +385,7 @@ class ValueRanges(Generic[_T]):
x: Union[BoolIn, BoolVR],
y: Union[BoolIn, BoolVR],
fn: BoolFn2,
) -> BoolVR:
...
) -> BoolVR: ...
@staticmethod
def coordinatewise_increasing_map(

View File

@ -426,9 +426,9 @@ def _get_custom_mod_func(func_name: str):
it is marked as private. It is a convenience function for backend implementers to
more easily call the hooks into their backend extensions.
"""
assert isinstance(
func_name, str
), f"func_name must be `str`, but got `{type(func_name)}`."
assert isinstance(func_name, str), (
f"func_name must be `str`, but got `{type(func_name)}`."
)
backend_name = _get_privateuse1_backend_name()
custom_device_mod = getattr(torch, backend_name, None) # type: ignore[arg-type]
function = getattr(custom_device_mod, func_name, None) # type: ignore[arg-type]

View File

@ -44,7 +44,7 @@ def default_convert(data):
>>> default_convert(np.array([0, 1]))
tensor([0, 1])
>>> # Example with NamedTuple
>>> Point = namedtuple('Point', ['x', 'y'])
>>> Point = namedtuple("Point", ["x", "y"])
>>> default_convert(Point(0, 0))
Point(x=0, y=0)
>>> default_convert(Point(np.array(0), np.array(0)))
@ -366,13 +366,13 @@ def default_collate(batch):
>>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3])
>>> # Example with a batch of `str`s:
>>> default_collate(['a', 'b', 'c'])
>>> default_collate(["a", "b", "c"])
['a', 'b', 'c']
>>> # Example with `Map` inside the batch:
>>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
>>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}])
{'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])}
>>> # Example with `NamedTuple` inside the batch:
>>> Point = namedtuple('Point', ['x', 'y'])
>>> Point = namedtuple("Point", ["x", "y"])
>>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1]))
>>> # Example with `Tuple` inside the batch:

View File

@ -69,7 +69,9 @@ def pin_memory(data, device=None):
)
return clone
else:
return type(data)({k: pin_memory(sample, device) for k, sample in data.items()}) # type: ignore[call-arg]
return type(data)(
{k: pin_memory(sample, device) for k, sample in data.items()}
) # type: ignore[call-arg]
except TypeError:
# The mapping type may not support `copy()` / `update(mapping)`
# or `__init__(iterable)`.

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
These **needs** to be in global scope since Py2 doesn't support serializing
static methods.

View File

@ -5,6 +5,7 @@ To support these two classes, in `./_utils` we define many utility methods and
functions to be run in multiprocessing. E.g., the data loading worker loop is
in `./_utils/worker.py`.
"""
from __future__ import annotations
import functools
@ -1208,7 +1209,10 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
# .pid can be None only before process is spawned (not the case, so ignore)
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
_utils.signal_handling._set_worker_pids(
id(self),
tuple(w.pid for w in self._workers), # type: ignore[misc]
)
_utils.signal_handling._set_SIGCHLD_handler()
self._worker_pids_set = True
self._reset(loader, first_iter=True)

View File

@ -109,8 +109,7 @@ class non_deterministic:
# Decorate with a functional argument
if not (
isinstance(args[0], type)
and issubclass(args[0], IterDataPipe) # type: ignore[arg-type]
isinstance(args[0], type) and issubclass(args[0], IterDataPipe) # type: ignore[arg-type]
):
raise TypeError(
f"Only `IterDataPipe` can be decorated, but {args[0].__name__} is found"

View File

@ -99,7 +99,9 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
>>> from torchdata.datapipes.iter import IterableWrapper, Mapper
>>> dp = IterableWrapper(range(10))
>>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor
>>> map_dp_2 = dp.map(lambda x: x + 1) # Using functional form (recommended)
>>> map_dp_2 = dp.map(
... lambda x: x + 1
... ) # Using functional form (recommended)
>>> list(map_dp_1)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> list(map_dp_2)
@ -114,7 +116,9 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
>>> list(it1)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> it1 = iter(source_dp)
>>> it2 = iter(source_dp) # The creation of a new iterator invalidates `it1`
>>> it2 = iter(
... source_dp
... ) # The creation of a new iterator invalidates `it1`
>>> next(it2)
0
>>> next(it1) # Further usage of `it1` will raise a `RunTimeError`

View File

@ -55,7 +55,8 @@ class MapperIterDataPipe(IterDataPipe[_T_co]):
>>> def add_one(x):
... return x + 1
>>> dp = IterableWrapper(range(10))
>>> map_dp_1 = dp.map(add_one) # Invocation via functional form is preferred
>>> # Invocation via functional form is preferred
... map_dp_1 = dp.map(add_one)
>>> list(map_dp_1)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
@ -202,7 +203,7 @@ class CollatorIterDataPipe(MapperIterDataPipe):
>>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
... def __init__(self, start, end):
... super(MyIterDataPipe).__init__()
... assert end > start, "this example code only works with end >= start"
... assert end > start, "this example only works with end >= start"
... self.start = start
... self.end = end
...
@ -211,13 +212,11 @@ class CollatorIterDataPipe(MapperIterDataPipe):
...
... def __len__(self):
... return self.end - self.start
...
>>> ds = MyIterDataPipe(start=3, end=7)
>>> print(list(ds))
[3, 4, 5, 6]
>>> def collate_fn(batch):
... return torch.tensor(batch, dtype=torch.float)
...
>>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn)
>>> print(list(collated_ds))
[tensor(3.), tensor(4.), tensor(5.), tensor(6.)]

View File

@ -38,15 +38,17 @@ class SamplerIterDataPipe(IterDataPipe[_T_co]):
sampler_args: Optional[tuple] = None,
sampler_kwargs: Optional[dict] = None,
) -> None:
assert isinstance(
datapipe, Sized
), "Sampler class requires input datapipe implemented `__len__`"
assert isinstance(datapipe, Sized), (
"Sampler class requires input datapipe implemented `__len__`"
)
super().__init__()
self.datapipe = datapipe
self.sampler_args = () if sampler_args is None else sampler_args
self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
# https://github.com/python/mypy/pull/9629 will solve
self.sampler = sampler(*self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs) # type: ignore[misc]
self.sampler = sampler(
*self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs
) # type: ignore[misc]
def __iter__(self) -> Iterator[_T_co]:
return iter(self.sampler)

View File

@ -116,16 +116,13 @@ class _ContainerTemplate(ABC):
r"""Abstract class for container ``DataPipes``. The followings are three required methods."""
@abstractmethod
def get_next_element_by_instance(self, instance_id: int):
...
def get_next_element_by_instance(self, instance_id: int): ...
@abstractmethod
def is_every_instance_exhausted(self) -> bool:
...
def is_every_instance_exhausted(self) -> bool: ...
@abstractmethod
def reset(self) -> None:
...
def reset(self) -> None: ...
@abstractmethod
def get_length_by_instance(self, instance_id: int):
@ -403,7 +400,9 @@ class DemultiplexerIterDataPipe(IterDataPipe):
>>> # It can also filter out any element that gets `None` from the `classifier_fn`
>>> def odd_or_even_no_zero(n):
... return n % 2 if n != 0 else None
>>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True)
>>> dp1, dp2 = source_dp.demux(
... num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True
... )
>>> list(dp1)
[2, 4]
>>> list(dp2)
@ -428,7 +427,9 @@ class DemultiplexerIterDataPipe(IterDataPipe):
# When num_instances == 1, demux can be replaced by filter,
# but keep it as Demultiplexer for the sake of consistency
# like throwing Error when classification result is out of o range
container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size) # type: ignore[abstract]
container = _DemultiplexerIterDataPipe(
datapipe, num_instances, classifier_fn, drop_none, buffer_size
) # type: ignore[abstract]
return [_ChildDataPipe(container, i) for i in range(num_instances)]
@ -602,16 +603,18 @@ class MultiplexerIterDataPipe(IterDataPipe):
Example:
>>> # xdoctest: +REQUIRES(module:torchdata)
>>> from torchdata.datapipes.iter import IterableWrapper
>>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
>>> dp1, dp2, dp3 = (
... IterableWrapper(range(3)),
... IterableWrapper(range(10, 15)),
... IterableWrapper(range(20, 25)),
... )
>>> list(dp1.mux(dp2, dp3))
[0, 10, 20, 1, 11, 21, 2, 12, 22]
"""
def __init__(self, *datapipes):
self.datapipes = datapipes
self.buffer: list = (
[]
) # Store values to be yielded only when every iterator provides one
self.buffer: list = [] # Store values to be yielded only when every iterator provides one
def __iter__(self):
iterators = [iter(x) for x in self.datapipes]
@ -670,7 +673,11 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]):
Example:
>>> # xdoctest: +REQUIRES(module:torchdata)
>>> from torchdata.datapipes.iter import IterableWrapper
>>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
>>> dp1, dp2, dp3 = (
... IterableWrapper(range(5)),
... IterableWrapper(range(10, 15)),
... IterableWrapper(range(20, 25)),
... )
>>> list(dp1.zip(dp2, dp3))
[(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
"""

View File

@ -33,8 +33,12 @@ class FileOpenerIterDataPipe(IterDataPipe[tuple[str, IOBase]]):
Example:
>>> # xdoctest: +SKIP
>>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader
>>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt'))
>>> from torchdata.datapipes.iter import (
... FileLister,
... FileOpener,
... StreamReader,
... )
>>> dp = FileLister(root=".").filter(lambda fname: fname.endswith(".txt"))
>>> dp = FileOpener(dp)
>>> dp = StreamReader(dp)
>>> list(dp)

View File

@ -182,7 +182,9 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
>>> from torchdata.datapipes.iter import IterableWrapper
>>> def group_fn(file):
... return os.path.basename(file).split(".")[0]
>>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
>>> source_dp = IterableWrapper(
... ["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"]
... )
>>> dp0 = source_dp.groupby(group_key_fn=group_fn)
>>> list(dp0)
[['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
@ -191,7 +193,12 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
>>> list(dp1)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
>>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size`
>>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
>>> dp2 = source_dp.groupby(
... group_key_fn=group_fn,
... buffer_size=3,
... group_size=3,
... guaranteed_group_size=2,
... )
>>> list(dp2)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
"""

View File

@ -31,8 +31,8 @@ class SequenceWrapperMapDataPipe(MapDataPipe[_T]):
>>> dp = SequenceWrapper(range(10))
>>> list(dp)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
>>> dp['a']
>>> dp = SequenceWrapper({"a": 100, "b": 200, "c": 300, "d": 400})
>>> dp["a"]
100
"""

View File

@ -45,8 +45,8 @@ def basichandlers(extension: str, data):
Example:
>>> import pickle
>>> data = pickle.dumps('some data')
>>> new_data = basichandlers('pickle', data)
>>> data = pickle.dumps("some data")
>>> new_data = basichandlers("pickle", data)
>>> new_data
some data
@ -169,9 +169,9 @@ class ImageHandler:
"""
def __init__(self, imagespec):
assert imagespec in list(
imagespecs.keys()
), f"unknown image specification: {imagespec}"
assert imagespec in list(imagespecs.keys()), (
f"unknown image specification: {imagespec}"
)
self.imagespec = imagespec.lower()
def __call__(self, extension, data):
@ -205,18 +205,18 @@ class ImageHandler:
return img
elif atype == "numpy":
result = np.asarray(img)
assert (
result.dtype == np.uint8
), f"numpy image array should be type uint8, but got {result.dtype}"
assert result.dtype == np.uint8, (
f"numpy image array should be type uint8, but got {result.dtype}"
)
if etype == "uint8":
return result
else:
return result.astype("f") / 255.0
elif atype == "torch":
result = np.asarray(img)
assert (
result.dtype == np.uint8
), f"numpy image array should be type uint8, but got {result.dtype}"
assert result.dtype == np.uint8, (
f"numpy image array should be type uint8, but got {result.dtype}"
)
if etype == "uint8":
result = np.array(result.transpose(2, 0, 1))

View File

@ -96,7 +96,7 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end):
... super(MyIterableDataset).__init__()
... assert end > start, "this example code only works with end >= start"
... assert end > start, "this example only works with end >= start"
... self.start = start
... self.end = end
...
@ -138,7 +138,7 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end):
... super(MyIterableDataset).__init__()
... assert end > start, "this example code only works with end >= start"
... assert end > start, "this example only works with end >= start"
... self.start = start
... self.end = end
...
@ -198,9 +198,9 @@ class TensorDataset(Dataset[tuple[Tensor, ...]]):
tensors: tuple[Tensor, ...]
def __init__(self, *tensors: Tensor) -> None:
assert all(
tensors[0].size(0) == tensor.size(0) for tensor in tensors
), "Size mismatch between tensors"
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), (
"Size mismatch between tensors"
)
self.tensors = tensors
def __getitem__(self, index):
@ -222,7 +222,7 @@ class StackDataset(Dataset[_T_stack]):
>>> tuple_stack = StackDataset(images, texts)
>>> tuple_stack[0] == (images[0], texts[0])
>>> dict_stack = StackDataset(image=images, text=texts)
>>> dict_stack[0] == {'image': images[0], 'text': texts[0]}
>>> dict_stack[0] == {"image": images[0], "text": texts[0]}
Args:
*args (Dataset): Datasets for stacking returned as tuple.
@ -323,9 +323,9 @@ class ConcatDataset(Dataset[_T_co]):
self.datasets = list(datasets)
assert len(self.datasets) > 0, "datasets should not be an empty iterable" # type: ignore[arg-type]
for d in self.datasets:
assert not isinstance(
d, IterableDataset
), "ConcatDataset does not support IterableDataset"
assert not isinstance(d, IterableDataset), (
"ConcatDataset does not support IterableDataset"
)
self.cumulative_sizes = self.cumsum(self.datasets)
def __len__(self):
@ -371,17 +371,17 @@ class ChainDataset(IterableDataset):
def __iter__(self):
for d in self.datasets:
assert isinstance(
d, IterableDataset
), "ChainDataset only supports IterableDataset"
assert isinstance(d, IterableDataset), (
"ChainDataset only supports IterableDataset"
)
yield from d
def __len__(self):
total = 0
for d in self.datasets:
assert isinstance(
d, IterableDataset
), "ChainDataset only supports IterableDataset"
assert isinstance(d, IterableDataset), (
"ChainDataset only supports IterableDataset"
)
total += len(d) # type: ignore[arg-type]
return total

View File

@ -236,9 +236,17 @@ class WeightedRandomSampler(Sampler[int]):
Example:
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
>>> list(
... WeightedRandomSampler(
... [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True
... )
... )
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
>>> list(
... WeightedRandomSampler(
... [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False
... )
... )
[0, 1, 4, 3, 2]
"""
@ -298,9 +306,15 @@ class BatchSampler(Sampler[list[int]]):
its size would be less than ``batch_size``
Example:
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
>>> list(
... BatchSampler(
... SequentialSampler(range(10)), batch_size=3, drop_last=False
... )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
>>> list(
... BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
"""

View File

@ -49,6 +49,7 @@ class ModuleTracker:
def my_linear(m1, m2, bias):
print(f"Current modules: {tracker.parents}")
return torch.mm(m1, m2.t()) + bias
torch.nn.functional.linear = my_linear
mod(torch.rand(2, 2))

View File

@ -6,6 +6,7 @@ Intel GPU optimization.
This package is lazily initialized, so you can always import it, and use
:func:`is_available()` to determine if your system supports XPU.
"""
import threading
import traceback
from functools import lru_cache
@ -292,6 +293,7 @@ class StreamContext:
``None``.
.. note:: Streams are per-device.
"""
cur_stream: Optional["torch.xpu.Stream"]
def __init__(self, stream: Optional["torch.xpu.Stream"]):
@ -438,7 +440,7 @@ def get_gencode_flags() -> str:
arch_list = get_arch_list()
if len(arch_list) == 0:
return ""
return f'-device {",".join(arch for arch in arch_list)}'
return f"-device {','.join(arch for arch in arch_list)}"
def _get_generator(device: torch.device) -> torch._C.Generator: