mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE][PYFMT] migrate PYFMT for torch/[p-z]*/ to ruff format (#144552)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144552 Approved by: https://github.com/ezyang
This commit is contained in:
parent
fd606a3a91
commit
5cedc5a0ff
|
|
@ -52,7 +52,6 @@ USE_BLACK_FILELIST = re.compile(
|
|||
# torch/[e-m]*/**
|
||||
# torch/optim/**
|
||||
# torch/[p-z]*/**
|
||||
"torch/[p-z]*/**",
|
||||
],
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
"""Import mangling.
|
||||
See mangling.md for details.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -605,9 +605,9 @@ class PackageExporter:
|
|||
dependencies (bool, optional): If ``True``, we scan the source for dependencies.
|
||||
"""
|
||||
|
||||
assert (pickle_protocol == 4) or (
|
||||
pickle_protocol == 3
|
||||
), "torch.package only supports pickle protocols 3 and 4"
|
||||
assert (pickle_protocol == 4) or (pickle_protocol == 3), (
|
||||
"torch.package only supports pickle protocols 3 and 4"
|
||||
)
|
||||
|
||||
filename = self._filename(package, resource)
|
||||
# Write the pickle data for `obj`
|
||||
|
|
|
|||
|
|
@ -423,7 +423,12 @@ class PackageImporter(Importer):
|
|||
module.__dict__.setdefault(old_name, new_name)
|
||||
|
||||
return module
|
||||
return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore[attr-defined]
|
||||
return self._make_module(
|
||||
name,
|
||||
cur.source_file, # type: ignore[attr-defined]
|
||||
isinstance(cur, _PackageNode),
|
||||
parent,
|
||||
)
|
||||
|
||||
def _compile_source(self, fullpath: str, mangled_filename: str):
|
||||
source = self.zip_reader.get_record(fullpath)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ examine their input shapes and stack traces, study device kernel activity and vi
|
|||
An earlier version of the API in :mod:`torch.autograd` module is considered legacy and will be deprecated.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
from typing_extensions import TypeVarTuple, Unpack
|
||||
|
|
|
|||
|
|
@ -239,10 +239,12 @@ class SchemaMatcher:
|
|||
def match_schemas(cls, t: _ExtraFields_TorchOp) -> tuple[FunctionSchema, ...]:
|
||||
signature = tuple(
|
||||
# Tensor
|
||||
TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata)
|
||||
TensorKey.from_tensor(i)
|
||||
if isinstance(i, _TensorMetadata)
|
||||
#
|
||||
# TensorList
|
||||
else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list)
|
||||
else [TensorKey.from_tensor(j) for j in i]
|
||||
if isinstance(i, list)
|
||||
#
|
||||
# Scalar and uncaptured inputs.
|
||||
else i
|
||||
|
|
|
|||
|
|
@ -124,9 +124,9 @@ class BasicEvaluation:
|
|||
for child_event in curr_event.children:
|
||||
self_time -= child_event.duration_time_ns
|
||||
stack.append(child_event)
|
||||
assert (
|
||||
EventKey(curr_event) not in self.metrics
|
||||
), f"Duplicate id: {curr_event.id}, {curr_event.name}"
|
||||
assert EventKey(curr_event) not in self.metrics, (
|
||||
f"Duplicate id: {curr_event.id}, {curr_event.name}"
|
||||
)
|
||||
self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time)
|
||||
self.metrics[
|
||||
EventKey(curr_event)
|
||||
|
|
@ -227,8 +227,7 @@ class BasicEvaluation:
|
|||
|
||||
while (
|
||||
current_kernel_index < len(cuda_kernel_events)
|
||||
and (cuda_kernel_events[current_kernel_index].start_ns())
|
||||
<= start_time # type: ignore[possibly-undefined]
|
||||
and (cuda_kernel_events[current_kernel_index].start_ns()) <= start_time # type: ignore[possibly-undefined]
|
||||
):
|
||||
current_kernel_index += 1
|
||||
current_queue_depth = spawned_kernel_index - current_kernel_index + 1
|
||||
|
|
@ -352,11 +351,11 @@ class BasicEvaluation:
|
|||
|
||||
output += "\n".join(
|
||||
[
|
||||
f"""{'-' * 80}
|
||||
f"""{"-" * 80}
|
||||
Event: {event}
|
||||
Source code location: {source_code_location(event.event)}
|
||||
Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}%
|
||||
{'-' * 80}"""
|
||||
{"-" * 80}"""
|
||||
for event in event_list
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -624,8 +624,7 @@ class profile(_KinetoProfile):
|
|||
]
|
||||
) as p:
|
||||
code_to_profile()
|
||||
print(p.key_averages().table(
|
||||
sort_by="self_cuda_time_total", row_limit=-1))
|
||||
print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
|
||||
|
||||
Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:
|
||||
|
||||
|
|
@ -635,16 +634,17 @@ class profile(_KinetoProfile):
|
|||
# on different iterations of the training loop;
|
||||
# trace_handler is called every time a new trace becomes available
|
||||
def trace_handler(prof):
|
||||
print(prof.key_averages().table(
|
||||
sort_by="self_cuda_time_total", row_limit=-1))
|
||||
print(
|
||||
prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)
|
||||
)
|
||||
# prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")
|
||||
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
|
||||
# In this example with wait=1, warmup=1, active=2, repeat=1,
|
||||
# profiler will skip the first step/iteration,
|
||||
# start warming up on the second, record
|
||||
|
|
@ -652,20 +652,15 @@ class profile(_KinetoProfile):
|
|||
# after which the trace will become available
|
||||
# and on_trace_ready (when set) is called;
|
||||
# the cycle repeats starting with the next step
|
||||
|
||||
schedule=torch.profiler.schedule(
|
||||
wait=1,
|
||||
warmup=1,
|
||||
active=2,
|
||||
repeat=1),
|
||||
on_trace_ready=trace_handler
|
||||
schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
|
||||
on_trace_ready=trace_handler,
|
||||
# on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
|
||||
# used when outputting for tensorboard
|
||||
) as p:
|
||||
for iter in range(N):
|
||||
code_iteration_to_profile(iter)
|
||||
# send a signal to the profiler that the next iteration has started
|
||||
p.step()
|
||||
) as p:
|
||||
for iter in range(N):
|
||||
code_iteration_to_profile(iter)
|
||||
# send a signal to the profiler that the next iteration has started
|
||||
p.step()
|
||||
|
||||
The following sample shows how to setup up an Execution Trace Observer (`execution_trace_observer`)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
|
|||
`torch/ao/quantization/fuser_method_mappings.py`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
|
||||
from torch.ao.quantization.fuser_method_mappings import (
|
||||
_DEFAULT_OP_LIST_TO_FUSER_METHOD,
|
||||
fuse_conv_bn,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
|
|||
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
|
||||
from torch.ao.quantization.fx._equalize import (
|
||||
_convert_equalization_ref,
|
||||
_InputEqualizationObserver,
|
||||
|
|
|
|||
|
|
@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the
|
|||
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
|
||||
from torch.ao.quantization.fx.convert import convert
|
||||
|
|
|
|||
|
|
@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the
|
|||
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
|
||||
from torch.ao.quantization.fx.fuse import fuse
|
||||
|
|
|
|||
|
|
@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the
|
|||
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
|
||||
from torch.ao.quantization.fx.fuse_handler import DefaultFuseHandler, FuseHandler
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
|
|||
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
|
||||
from torch.ao.quantization.fx.graph_module import (
|
||||
_is_observed_module,
|
||||
_is_observed_standalone_module,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
|
|||
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
|
||||
from torch.ao.quantization.fx.match_utils import (
|
||||
_find_matches,
|
||||
_is_match,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
|
|||
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
|
||||
from torch.ao.quantization.fx.pattern_utils import (
|
||||
_register_fusion_pattern,
|
||||
_register_quant_pattern,
|
||||
|
|
|
|||
|
|
@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the
|
|||
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
|
||||
from torch.ao.quantization.fx.prepare import prepare
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
|
|||
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
|
||||
from torch.ao.quantization.fx.quantize_handler import (
|
||||
BatchNormQuantizeHandler,
|
||||
BinaryOpQuantizeHandler,
|
||||
|
|
|
|||
|
|
@ -6,4 +6,5 @@ If you are adding a new entry/functionality, please, add it to the
|
|||
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
|
||||
from torch.ao.quantization.utils import Pattern, QuantizerCls
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
|
|||
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
|
||||
from torch.ao.quantization.fx.utils import (
|
||||
all_node_args_have_no_tensors,
|
||||
assert_and_get_unique_device,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
|
|||
`torch/ao/quantization/observer.py`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
|
||||
from torch.ao.quantization.observer import (
|
||||
_is_activation_post_process,
|
||||
_is_per_channel_script_obs_instance,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
|
|||
`torch/ao/quantization/qconfig.py`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
|
||||
from torch.ao.quantization.qconfig import (
|
||||
_add_module_to_qconfig_obs_ctr,
|
||||
_assert_valid_qconfig,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
|
|||
`torch/ao/quantization/quantization_mappings.py`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
|
||||
from torch.ao.quantization.quantization_mappings import (
|
||||
_get_special_act_post_process,
|
||||
_has_special_act_post_process,
|
||||
|
|
|
|||
|
|
@ -128,9 +128,7 @@ Examples::
|
|||
>>> # Generates a periodic exponential window and decay factor equal to .5
|
||||
>>> torch.signal.windows.exponential(10, sym=False,tau=.5)
|
||||
tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04])
|
||||
""".format(
|
||||
**window_common_args
|
||||
),
|
||||
""".format(**window_common_args),
|
||||
)
|
||||
def exponential(
|
||||
M: int,
|
||||
|
|
@ -452,9 +450,7 @@ Examples::
|
|||
>>> # Generates a periodic Hamming window.
|
||||
>>> torch.signal.windows.hamming(10, sym=False)
|
||||
tensor([0.0800, 0.1679, 0.3979, 0.6821, 0.9121, 1.0000, 0.9121, 0.6821, 0.3979, 0.1679])
|
||||
""".format(
|
||||
**window_common_args
|
||||
),
|
||||
""".format(**window_common_args),
|
||||
)
|
||||
def hamming(
|
||||
M: int,
|
||||
|
|
@ -508,9 +504,7 @@ Examples::
|
|||
>>> # Generates a periodic Hann window.
|
||||
>>> torch.signal.windows.hann(10, sym=False)
|
||||
tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
|
||||
""".format(
|
||||
**window_common_args
|
||||
),
|
||||
""".format(**window_common_args),
|
||||
)
|
||||
def hann(
|
||||
M: int,
|
||||
|
|
@ -564,9 +558,7 @@ Examples::
|
|||
>>> # Generates a periodic Blackman window.
|
||||
>>> torch.signal.windows.blackman(5, sym=False)
|
||||
tensor([-1.4901e-08, 2.0077e-01, 8.4923e-01, 8.4923e-01, 2.0077e-01])
|
||||
""".format(
|
||||
**window_common_args
|
||||
),
|
||||
""".format(**window_common_args),
|
||||
)
|
||||
def blackman(
|
||||
M: int,
|
||||
|
|
@ -627,9 +619,7 @@ Examples::
|
|||
>>> # Generates a periodic Bartlett window.
|
||||
>>> torch.signal.windows.bartlett(10, sym=False)
|
||||
tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000, 0.8000, 0.6000, 0.4000, 0.2000])
|
||||
""".format(
|
||||
**window_common_args
|
||||
),
|
||||
""".format(**window_common_args),
|
||||
)
|
||||
def bartlett(
|
||||
M: int,
|
||||
|
|
@ -704,9 +694,7 @@ Examples::
|
|||
>>> # Generates a periodic general cosine window with 2 coefficients.
|
||||
>>> torch.signal.windows.general_cosine(10, a=[0.5, 1 - 0.5], sym=False)
|
||||
tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
|
||||
""".format(
|
||||
**window_common_args
|
||||
),
|
||||
""".format(**window_common_args),
|
||||
)
|
||||
def general_cosine(
|
||||
M,
|
||||
|
|
@ -799,9 +787,7 @@ Examples::
|
|||
>>> # Generates a periodic Hann window with the general Hamming window.
|
||||
>>> torch.signal.windows.general_hamming(10, alpha=0.5, sym=False)
|
||||
tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
|
||||
""".format(
|
||||
**window_common_args
|
||||
),
|
||||
""".format(**window_common_args),
|
||||
)
|
||||
def general_hamming(
|
||||
M,
|
||||
|
|
@ -866,9 +852,7 @@ Examples::
|
|||
>>> # Generates a periodic Nuttall window.
|
||||
>>> torch.signal.windows.general_hamming(5, sym=False)
|
||||
tensor([3.6280e-04, 1.1052e-01, 7.9826e-01, 7.9826e-01, 1.1052e-01])
|
||||
""".format(
|
||||
**window_common_args
|
||||
),
|
||||
""".format(**window_common_args),
|
||||
)
|
||||
def nuttall(
|
||||
M: int,
|
||||
|
|
|
|||
|
|
@ -559,7 +559,11 @@ def as_sparse_gradcheck(gradcheck):
|
|||
For example:
|
||||
|
||||
>>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)
|
||||
>>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse_coo().requires_grad_(True)
|
||||
>>> x = (
|
||||
... torch.tensor([[0, 1], [2, 3]], dtype=torch.float64)
|
||||
... .to_sparse_coo()
|
||||
... .requires_grad_(True)
|
||||
... )
|
||||
>>> gradcheck(lambda x: x.to_sparse_csr(), x)
|
||||
True
|
||||
"""
|
||||
|
|
@ -667,7 +671,7 @@ def as_sparse_gradcheck(gradcheck):
|
|||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'conversion of {d["layout"]} strided representation to tensor'
|
||||
f"conversion of {d['layout']} strided representation to tensor"
|
||||
)
|
||||
new_args.append(a)
|
||||
return tuple(new_args)
|
||||
|
|
|
|||
|
|
@ -296,11 +296,11 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None):
|
|||
for b in range(nbatches):
|
||||
for i, r in enumerate(r_offsets):
|
||||
r0, r1 = divmod(r, N)
|
||||
acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns]
|
||||
for g in range(c_indices[i], c_indices[i+1]):
|
||||
acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
|
||||
for g in range(c_indices[i], c_indices[i + 1]):
|
||||
p = p_offsets[g]
|
||||
q0, q1 = divmod(q_offsets[g], N)
|
||||
acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns]
|
||||
acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]
|
||||
|
||||
where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
|
||||
integer multiples of ``Ms`` and ``Ks``, respectively.
|
||||
|
|
@ -320,11 +320,11 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None):
|
|||
n = (r % N) // Ns
|
||||
r0, r1 = divmod(r, N)
|
||||
c0, c1 = c_indices[m], c_indices[m + 1]
|
||||
acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns]
|
||||
acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
|
||||
for i, p in enumerate(range(c0, c1)):
|
||||
q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i]
|
||||
q0, q1 = divmod(q, N)
|
||||
acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns]
|
||||
acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]
|
||||
|
||||
where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
|
||||
integer multiples of ``Ms`` and ``Ks``, respectively.
|
||||
|
|
|
|||
|
|
@ -97,6 +97,7 @@ tune_bsr_dense_addmm to learn how to register a custom set of optimal
|
|||
kernel parameters for addmm-based operations.
|
||||
|
||||
"""
|
||||
|
||||
__all__ = ["get_meta", "tune_bsr_dense_addmm", "tune__int_bsr_dense_addmm"]
|
||||
|
||||
import inspect
|
||||
|
|
@ -432,9 +433,9 @@ def minimize(
|
|||
|
||||
|
||||
def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):
|
||||
assert (
|
||||
sparsity <= 1.0 and sparsity >= 0.0
|
||||
), "sparsity should be a value between 0 and 1"
|
||||
assert sparsity <= 1.0 and sparsity >= 0.0, (
|
||||
"sparsity should be a value between 0 and 1"
|
||||
)
|
||||
assert M % blocksize[0] == 0
|
||||
assert N % blocksize[1] == 0
|
||||
shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :]
|
||||
|
|
|
|||
|
|
@ -465,14 +465,26 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
|
|||
The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
|
||||
```
|
||||
from torch.sparse import SparseSemiStructuredTensorCUTLASS
|
||||
from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
|
||||
from torch.sparse._semi_structured_conversions import (
|
||||
_sparse_semi_structured_tile,
|
||||
_compute_compressed_swizzled_bitmask,
|
||||
)
|
||||
|
||||
pruned = _sparse_semi_structured_tile(dense)
|
||||
packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
|
||||
packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
|
||||
packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(
|
||||
pruned.t().contiguous()
|
||||
)
|
||||
bitmask = _compute_compressed_swizzled_bitmask(pruned)
|
||||
|
||||
SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask)
|
||||
SparseSemiStructuredTensorCUTLASS(
|
||||
dense.shape,
|
||||
packed_cutlass,
|
||||
meta_cutlass,
|
||||
packed_t_cutlass,
|
||||
meta_t_cutlass,
|
||||
bitmask,
|
||||
)
|
||||
```
|
||||
"""
|
||||
# We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
|
||||
|
|
@ -583,14 +595,19 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
|
|||
The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
|
||||
```
|
||||
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
|
||||
from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
|
||||
from torch.sparse._semi_structured_conversions import (
|
||||
_sparse_semi_structured_tile,
|
||||
_compute_compressed_swizzled_bitmask,
|
||||
)
|
||||
|
||||
pruned = _sparse_semi_structured_tile(dense)
|
||||
packed_cusparselt = torch._cslt_compress(pruned)
|
||||
packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
|
||||
bitmask = _compute_compressed_swizzled_bitmask(pruned)
|
||||
|
||||
SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask)
|
||||
SparseSemiStructuredTensorCUSPARSELT(
|
||||
dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask
|
||||
)
|
||||
```
|
||||
"""
|
||||
(
|
||||
|
|
|
|||
|
|
@ -134,9 +134,7 @@ Example::
|
|||
>>> torch.special.digamma(a)
|
||||
tensor([-0.5772, -1.9635])
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
gammaln = _add_docstr(
|
||||
|
|
@ -162,9 +160,7 @@ Example::
|
|||
>>> torch.special.gammaln(a)
|
||||
tensor([ 0.5724, 0.0000, -0.1208])
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
polygamma = _add_docstr(
|
||||
|
|
@ -200,9 +196,7 @@ Example::
|
|||
tensor([ 6.4939, 97.4091])
|
||||
>>> torch.special.polygamma(4, a)
|
||||
tensor([ -24.8863, -771.4742])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
erf = _add_docstr(
|
||||
|
|
@ -226,9 +220,7 @@ Example::
|
|||
|
||||
>>> torch.special.erf(torch.tensor([0, -1., 10.]))
|
||||
tensor([ 0.0000, -0.8427, 1.0000])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
erfc = _add_docstr(
|
||||
|
|
@ -253,9 +245,7 @@ Example::
|
|||
|
||||
>>> torch.special.erfc(torch.tensor([0, -1., 10.]))
|
||||
tensor([ 1.0000, 1.8427, 0.0000])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
erfcx = _add_docstr(
|
||||
|
|
@ -283,9 +273,7 @@ Example::
|
|||
|
||||
>>> torch.special.erfcx(torch.tensor([0, -1., 10.]))
|
||||
tensor([ 1.0000, 5.0090, 0.0561])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
erfinv = _add_docstr(
|
||||
|
|
@ -311,9 +299,7 @@ Example::
|
|||
|
||||
>>> torch.special.erfinv(torch.tensor([0, 0.5, -1.]))
|
||||
tensor([ 0.0000, 0.4769, -inf])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
logit = _add_docstr(
|
||||
|
|
@ -351,9 +337,7 @@ Example::
|
|||
tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516])
|
||||
>>> torch.special.logit(a, eps=1e-6)
|
||||
tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
logsumexp = _add_docstr(
|
||||
|
|
@ -362,9 +346,7 @@ logsumexp = _add_docstr(
|
|||
logsumexp(input, dim, keepdim=False, *, out=None)
|
||||
|
||||
Alias for :func:`torch.logsumexp`.
|
||||
""".format(
|
||||
**multi_dim_common
|
||||
),
|
||||
""".format(**multi_dim_common),
|
||||
)
|
||||
|
||||
expit = _add_docstr(
|
||||
|
|
@ -391,9 +373,7 @@ Example::
|
|||
tensor([ 0.9213, 1.0887, -0.8858, -1.7683])
|
||||
>>> torch.special.expit(t)
|
||||
tensor([ 0.7153, 0.7481, 0.2920, 0.1458])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
exp2 = _add_docstr(
|
||||
|
|
@ -418,9 +398,7 @@ Example::
|
|||
|
||||
>>> torch.special.exp2(torch.tensor([0, math.log2(2.), 3, 4]))
|
||||
tensor([ 1., 2., 8., 16.])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
expm1 = _add_docstr(
|
||||
|
|
@ -448,9 +426,7 @@ Example::
|
|||
|
||||
>>> torch.special.expm1(torch.tensor([0, math.log(2.)]))
|
||||
tensor([ 0., 1.])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
xlog1py = _add_docstr(
|
||||
|
|
@ -495,9 +471,7 @@ Example::
|
|||
tensor([1.6094, 3.2189, 4.8283])
|
||||
>>> torch.special.xlog1py(2, y)
|
||||
tensor([2.7726, 2.1972, 1.3863])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
xlogy = _add_docstr(
|
||||
|
|
@ -542,9 +516,7 @@ Example::
|
|||
tensor([1.3863, 2.7726, 4.1589])
|
||||
>>> torch.special.xlogy(2, y)
|
||||
tensor([2.1972, 1.3863, 0.0000])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
i0 = _add_docstr(
|
||||
|
|
@ -570,9 +542,7 @@ Example::
|
|||
>>> torch.i0(torch.arange(5, dtype=torch.float32))
|
||||
tensor([ 1.0000, 1.2661, 2.2796, 4.8808, 11.3019])
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
i0e = _add_docstr(
|
||||
|
|
@ -597,9 +567,7 @@ Example::
|
|||
|
||||
>>> torch.special.i0e(torch.arange(5, dtype=torch.float32))
|
||||
tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
i1 = _add_docstr(
|
||||
|
|
@ -624,9 +592,7 @@ Example::
|
|||
|
||||
>>> torch.special.i1(torch.arange(5, dtype=torch.float32))
|
||||
tensor([0.0000, 0.5652, 1.5906, 3.9534, 9.7595])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
i1e = _add_docstr(
|
||||
|
|
@ -652,9 +618,7 @@ Example::
|
|||
|
||||
>>> torch.special.i1e(torch.arange(5, dtype=torch.float32))
|
||||
tensor([0.0000, 0.2079, 0.2153, 0.1968, 0.1788])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
ndtr = _add_docstr(
|
||||
|
|
@ -679,9 +643,7 @@ Example::
|
|||
|
||||
>>> torch.special.ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
|
||||
tensor([0.0013, 0.0228, 0.1587, 0.5000, 0.8413, 0.9772, 0.9987])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
ndtri = _add_docstr(
|
||||
|
|
@ -709,9 +671,7 @@ Example::
|
|||
|
||||
>>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1]))
|
||||
tensor([ -inf, -0.6745, 0.0000, 0.6745, inf])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
log_ndtr = _add_docstr(
|
||||
|
|
@ -736,9 +696,7 @@ Example::
|
|||
|
||||
>>> torch.special.log_ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
|
||||
tensor([-6.6077 -3.7832 -1.841 -0.6931 -0.1728 -0.023 -0.0014])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
log1p = _add_docstr(
|
||||
|
|
@ -779,9 +737,7 @@ Example::
|
|||
tensor([ 0.2252, -0.2948, 1.0267, -1.1566])
|
||||
>>> torch.special.sinc(t)
|
||||
tensor([ 0.9186, 0.8631, -0.0259, -0.1300])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
round = _add_docstr(
|
||||
|
|
@ -886,9 +842,7 @@ Example::
|
|||
tensor([1.6449, 0.0823])
|
||||
>>> torch.special.zeta(2, torch.tensor([1., 2.]))
|
||||
tensor([1.6449, 0.6449])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
multigammaln = _add_docstr(
|
||||
|
|
@ -925,9 +879,7 @@ Example::
|
|||
>>> torch.special.multigammaln(a, 2)
|
||||
tensor([[0.3928, 0.4007, 0.7586],
|
||||
[1.0311, 0.3901, 0.5049]])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
gammainc = _add_docstr(
|
||||
|
|
@ -976,9 +928,7 @@ Example::
|
|||
>>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
|
||||
tensor([1., 1., 1.])
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
gammaincc = _add_docstr(
|
||||
|
|
@ -1026,9 +976,7 @@ Example::
|
|||
>>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
|
||||
tensor([1., 1., 1.])
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
airy_ai = _add_docstr(
|
||||
|
|
@ -1045,9 +993,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
bessel_j0 = _add_docstr(
|
||||
|
|
@ -1064,9 +1010,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
bessel_j1 = _add_docstr(
|
||||
|
|
@ -1083,9 +1027,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
bessel_y0 = _add_docstr(
|
||||
|
|
@ -1102,9 +1044,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
bessel_y1 = _add_docstr(
|
||||
|
|
@ -1121,9 +1061,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
chebyshev_polynomial_t = _add_docstr(
|
||||
|
|
@ -1154,9 +1092,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
chebyshev_polynomial_u = _add_docstr(
|
||||
|
|
@ -1188,9 +1124,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
chebyshev_polynomial_v = _add_docstr(
|
||||
|
|
@ -1208,9 +1142,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
chebyshev_polynomial_w = _add_docstr(
|
||||
|
|
@ -1228,9 +1160,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
hermite_polynomial_h = _add_docstr(
|
||||
|
|
@ -1256,9 +1186,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
hermite_polynomial_he = _add_docstr(
|
||||
|
|
@ -1284,9 +1212,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
laguerre_polynomial_l = _add_docstr(
|
||||
|
|
@ -1312,9 +1238,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
legendre_polynomial_p = _add_docstr(
|
||||
|
|
@ -1340,9 +1264,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
modified_bessel_i0 = _add_docstr(
|
||||
|
|
@ -1359,9 +1281,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
modified_bessel_i1 = _add_docstr(
|
||||
|
|
@ -1378,9 +1298,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
modified_bessel_k0 = _add_docstr(
|
||||
|
|
@ -1397,9 +1315,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
modified_bessel_k1 = _add_docstr(
|
||||
|
|
@ -1416,9 +1332,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
scaled_modified_bessel_k0 = _add_docstr(
|
||||
|
|
@ -1435,9 +1349,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
scaled_modified_bessel_k1 = _add_docstr(
|
||||
|
|
@ -1454,9 +1366,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
shifted_chebyshev_polynomial_t = _add_docstr(
|
||||
|
|
@ -1474,9 +1384,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
shifted_chebyshev_polynomial_u = _add_docstr(
|
||||
|
|
@ -1494,9 +1402,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
shifted_chebyshev_polynomial_v = _add_docstr(
|
||||
|
|
@ -1514,9 +1420,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
shifted_chebyshev_polynomial_w = _add_docstr(
|
||||
|
|
@ -1534,9 +1438,7 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
spherical_bessel_j0 = _add_docstr(
|
||||
|
|
@ -1553,7 +1455,5 @@ Args:
|
|||
|
||||
Keyword args:
|
||||
{out}
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1538,7 +1538,9 @@ def assert_close(
|
|||
>>> expected = torch.tensor([1.0, 2.0, 3.0])
|
||||
>>> actual = torch.tensor([1.0, 4.0, 5.0])
|
||||
>>> # The default error message can be overwritten.
|
||||
>>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
|
||||
>>> torch.testing.assert_close(
|
||||
... actual, expected, msg="Argh, the tensors are not close!"
|
||||
... )
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError: Argh, the tensors are not close!
|
||||
|
|
|
|||
|
|
@ -115,11 +115,11 @@ def make_tensor(
|
|||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
||||
>>> from torch.testing import make_tensor
|
||||
>>> # Creates a float tensor with values in [-1, 1)
|
||||
>>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
|
||||
>>> make_tensor((3,), device="cpu", dtype=torch.float32, low=-1, high=1)
|
||||
>>> # xdoctest: +SKIP
|
||||
tensor([ 0.1205, 0.2282, -0.6380])
|
||||
>>> # Creates a bool tensor on CUDA
|
||||
>>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
|
||||
>>> make_tensor((2, 2), device="cuda", dtype=torch.bool)
|
||||
tensor([[False, False],
|
||||
[False, True]], device='cuda:0')
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -721,9 +721,9 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_fo
|
|||
intersect = set(except_for if except_for else []) & set(
|
||||
only_for if only_for else []
|
||||
)
|
||||
assert (
|
||||
not intersect
|
||||
), f"device ({intersect}) appeared in both except_for and only_for"
|
||||
assert not intersect, (
|
||||
f"device ({intersect}) appeared in both except_for and only_for"
|
||||
)
|
||||
|
||||
# Replace your privateuse1 backend name with 'privateuse1'
|
||||
if is_privateuse1_backend_available():
|
||||
|
|
@ -1407,9 +1407,9 @@ class deviceCountAtLeast:
|
|||
self.num_required_devices = num_required_devices
|
||||
|
||||
def __call__(self, fn):
|
||||
assert not hasattr(
|
||||
fn, "num_required_devices"
|
||||
), f"deviceCountAtLeast redefinition for {fn.__name__}"
|
||||
assert not hasattr(fn, "num_required_devices"), (
|
||||
f"deviceCountAtLeast redefinition for {fn.__name__}"
|
||||
)
|
||||
fn.num_required_devices = self.num_required_devices
|
||||
|
||||
@wraps(fn)
|
||||
|
|
@ -1474,13 +1474,13 @@ def onlyNativeDeviceTypesAnd(devices=None):
|
|||
# self.precision *2, max(1, self.precision)).
|
||||
class precisionOverride:
|
||||
def __init__(self, d):
|
||||
assert isinstance(
|
||||
d, dict
|
||||
), "precisionOverride not given a dtype : precision dict!"
|
||||
assert isinstance(d, dict), (
|
||||
"precisionOverride not given a dtype : precision dict!"
|
||||
)
|
||||
for dtype in d.keys():
|
||||
assert isinstance(
|
||||
dtype, torch.dtype
|
||||
), f"precisionOverride given unknown dtype {dtype}"
|
||||
assert isinstance(dtype, torch.dtype), (
|
||||
f"precisionOverride given unknown dtype {dtype}"
|
||||
)
|
||||
|
||||
self.d = d
|
||||
|
||||
|
|
@ -1513,12 +1513,12 @@ class toleranceOverride:
|
|||
def __init__(self, d):
|
||||
assert isinstance(d, dict), "toleranceOverride not given a dtype : tol dict!"
|
||||
for dtype, prec in d.items():
|
||||
assert isinstance(
|
||||
dtype, torch.dtype
|
||||
), f"toleranceOverride given unknown dtype {dtype}"
|
||||
assert isinstance(
|
||||
prec, tol
|
||||
), "toleranceOverride not given a dtype : tol dict!"
|
||||
assert isinstance(dtype, torch.dtype), (
|
||||
f"toleranceOverride given unknown dtype {dtype}"
|
||||
)
|
||||
assert isinstance(prec, tol), (
|
||||
"toleranceOverride not given a dtype : tol dict!"
|
||||
)
|
||||
|
||||
self.d = d
|
||||
|
||||
|
|
@ -1546,13 +1546,13 @@ class dtypes:
|
|||
"all dtype variants must be. "
|
||||
f"Received non-list non-tuple dtype {str(arg)}"
|
||||
)
|
||||
assert all(
|
||||
isinstance(dtype, torch.dtype) for dtype in arg
|
||||
), f"Unknown dtype in {str(arg)}"
|
||||
assert all(isinstance(dtype, torch.dtype) for dtype in arg), (
|
||||
f"Unknown dtype in {str(arg)}"
|
||||
)
|
||||
else:
|
||||
assert all(
|
||||
isinstance(arg, torch.dtype) for arg in args
|
||||
), f"Unknown dtype in {str(args)}"
|
||||
assert all(isinstance(arg, torch.dtype) for arg in args), (
|
||||
f"Unknown dtype in {str(args)}"
|
||||
)
|
||||
|
||||
self.args = args
|
||||
self.device_type = device_type
|
||||
|
|
|
|||
|
|
@ -253,9 +253,9 @@ def verify_ddp_error_logged(model_DDP, err_substr):
|
|||
if err_substr.find("\nException raised from ") == -1
|
||||
else err_substr.split("\nException raised from ")[0]
|
||||
)
|
||||
assert (
|
||||
actual in logging_err
|
||||
), f"Did not find expected {actual} in ddp logging data error: {logging_err}"
|
||||
assert actual in logging_err, (
|
||||
f"Did not find expected {actual} in ddp logging data error: {logging_err}"
|
||||
)
|
||||
|
||||
|
||||
def with_nccl_blocking_wait(func):
|
||||
|
|
@ -294,9 +294,9 @@ def with_nccl_blocking_wait(func):
|
|||
finally:
|
||||
# restore old values.
|
||||
if cached_nccl_async_error_handling is not None:
|
||||
os.environ[
|
||||
"TORCH_NCCL_ASYNC_ERROR_HANDLING"
|
||||
] = cached_nccl_async_error_handling
|
||||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
|
||||
cached_nccl_async_error_handling
|
||||
)
|
||||
|
||||
if cached_nccl_blocking_wait is not None:
|
||||
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait
|
||||
|
|
@ -812,7 +812,7 @@ class MultiProcessTestCase(TestCase):
|
|||
sys.exit(TEST_SKIPS["generic"].exit_code)
|
||||
except Exception:
|
||||
logger.error(
|
||||
"Caught exception: \n%s exiting " "process %s with exit code: %s",
|
||||
"Caught exception: \n%s exiting process %s with exit code: %s",
|
||||
traceback.format_exc(),
|
||||
self.rank,
|
||||
MultiProcessTestCase.TEST_ERROR_EXIT_CODE,
|
||||
|
|
@ -1689,9 +1689,7 @@ class MultiProcContinousTest(TestCase):
|
|||
cls.processes.append(process)
|
||||
cls.task_queues.append(task_queue)
|
||||
cls.completion_queues.append(completion_queue)
|
||||
logger.info(
|
||||
"Started process %s with pid %s", rank, process.pid
|
||||
) # noqa: UP031
|
||||
logger.info("Started process %s with pid %s", rank, process.pid) # noqa: UP031
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
|
|
|||
|
|
@ -1285,10 +1285,10 @@ class FSDPTest(MultiProcessTestCase):
|
|||
loss = sharded_grad_scaler.scale(loss)
|
||||
|
||||
if not mixed_precision and not use_pure_fp16:
|
||||
assert (
|
||||
loss.dtype == torch.float32
|
||||
), "loss data type should be float32, as the original \
|
||||
assert loss.dtype == torch.float32, (
|
||||
"loss data type should be float32, as the original \
|
||||
parameter data type is float32."
|
||||
)
|
||||
else:
|
||||
if use_pure_fp16:
|
||||
self.assertEqual(loss.dtype, torch.float16)
|
||||
|
|
@ -1354,9 +1354,9 @@ class FSDPTest(MultiProcessTestCase):
|
|||
wrapper should provide data parallel semantics. If ``None``,
|
||||
then the callable defaults to the DDP constructor.
|
||||
"""
|
||||
assert (
|
||||
fsdp_init_mode != FSDPInitMode.NO_FSDP
|
||||
), "Expects an FSDP init mode that wraps with FSDP"
|
||||
assert fsdp_init_mode != FSDPInitMode.NO_FSDP, (
|
||||
"Expects an FSDP init mode that wraps with FSDP"
|
||||
)
|
||||
if init_kwargs is None:
|
||||
init_kwargs = {}
|
||||
lr = 1e-2
|
||||
|
|
|
|||
|
|
@ -1268,9 +1268,9 @@ def _get_optim_inputs_including_global_cliquey_kwargs(
|
|||
trivial. That said, we sometimes want to test for all possible configs on an
|
||||
optimizer including all supported flags, so this helper returns all optim inputs.
|
||||
"""
|
||||
assert all(
|
||||
x in ["foreach", "fused", "differentiable"] for x in skip
|
||||
), "skip must be a subset of ['foreach', 'fused', 'differentiable']"
|
||||
assert all(x in ["foreach", "fused", "differentiable"] for x in skip), (
|
||||
"skip must be a subset of ['foreach', 'fused', 'differentiable']"
|
||||
)
|
||||
|
||||
optim_inputs = optim_info.optim_inputs_func(device)
|
||||
|
||||
|
|
|
|||
|
|
@ -477,7 +477,9 @@ def with_comms(
|
|||
def decorator(func, eager_init: bool = False, backend: Optional[str] = None):
|
||||
@wraps(func) # pyre-ignore[6]
|
||||
def wrapper(
|
||||
self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc]
|
||||
self,
|
||||
*args: tuple[object],
|
||||
**kwargs: dict[str, Any], # type: ignore[misc]
|
||||
) -> None:
|
||||
self.init_pg(eager_init, backend)
|
||||
|
||||
|
|
|
|||
|
|
@ -253,7 +253,11 @@ class Trainer:
|
|||
else:
|
||||
input_batches = batches
|
||||
|
||||
with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext():
|
||||
with (
|
||||
self.hybrid_module.join()
|
||||
if simulate_uneven_inputs
|
||||
else contextlib.nullcontext()
|
||||
):
|
||||
for b in input_batches:
|
||||
with dist_autograd.context() as context_id:
|
||||
output = self.hybrid_module.forward(b)
|
||||
|
|
@ -261,8 +265,7 @@ class Trainer:
|
|||
dist_autograd.backward(context_id, [loss])
|
||||
grads_dict = dist_autograd.get_gradients(context_id)
|
||||
gLogger.info(
|
||||
"Loss is %s for mini batch: %s. "
|
||||
"Grads dict has %s entries: %s",
|
||||
"Loss is %s for mini batch: %s. Grads dict has %s entries: %s",
|
||||
loss,
|
||||
mini_batch,
|
||||
len(grads_dict),
|
||||
|
|
|
|||
|
|
@ -162,9 +162,7 @@ class SampleInput:
|
|||
# Allow calling either as SampleInput(input, args=args, kwargs=kwargs), or as
|
||||
# SampleInput(input, *args, **kwargs) but not to mix the two forms
|
||||
if args is not None or kwargs is not None:
|
||||
assert (
|
||||
not var_args and not var_kwargs
|
||||
), """
|
||||
assert not var_args and not var_kwargs, """
|
||||
A SampleInput can be constructed "naturally" with *args and **kwargs or by
|
||||
explicitly setting the "args" and "kwargs" parameters, but the two
|
||||
methods of construction cannot be mixed!"""
|
||||
|
|
@ -226,7 +224,7 @@ cannot specify additional metadata in keyword arguments"""
|
|||
f"name={repr(self.name)}",
|
||||
]
|
||||
|
||||
return f'SampleInput({", ".join(a for a in arguments if a is not None)})'
|
||||
return f"SampleInput({', '.join(a for a in arguments if a is not None)})"
|
||||
|
||||
def __repr__(self):
|
||||
return self._repr_helper(lambda x: x)
|
||||
|
|
@ -1601,13 +1599,11 @@ class SampleRule(ABC):
|
|||
|
||||
# returns a string identifier of the rule type
|
||||
@abstractmethod
|
||||
def type(self) -> str:
|
||||
...
|
||||
def type(self) -> str: ...
|
||||
|
||||
# returns an appropriate context that handles the xfail, skips, etc.
|
||||
@abstractmethod
|
||||
def get_context(self, test_case):
|
||||
...
|
||||
def get_context(self, test_case): ...
|
||||
|
||||
|
||||
# useful for specifying xfails
|
||||
|
|
@ -1791,8 +1787,10 @@ class ReductionOpInfo(OpInfo):
|
|||
# kwargs to use when calling the op. This is required for operators that
|
||||
# have other required parameters besides the input tensor.
|
||||
generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: (
|
||||
yield (),
|
||||
{},
|
||||
yield (
|
||||
(),
|
||||
{},
|
||||
)
|
||||
),
|
||||
# Options from the OpInfo base class
|
||||
**kwargs,
|
||||
|
|
@ -2476,9 +2474,9 @@ class BinaryUfuncInfo(OpInfo):
|
|||
self.supports_one_python_scalar = True
|
||||
|
||||
if self.supports_one_python_scalar:
|
||||
assert (
|
||||
supports_rhs_python_scalar
|
||||
), "Can't support lhs and rhs Python scalars but not rhs scalars!"
|
||||
assert supports_rhs_python_scalar, (
|
||||
"Can't support lhs and rhs Python scalars but not rhs scalars!"
|
||||
)
|
||||
|
||||
|
||||
# The following functions and classes are for testing elementwise unary operators.
|
||||
|
|
|
|||
|
|
@ -102,8 +102,9 @@ def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwar
|
|||
for mask in _generate_masked_op_mask(
|
||||
sample_input.input.shape, device, **kwargs
|
||||
):
|
||||
sample_input_args, sample_input_kwargs = sample_input.args, dict(
|
||||
mask=mask, **sample_input.kwargs
|
||||
sample_input_args, sample_input_kwargs = (
|
||||
sample_input.args,
|
||||
dict(mask=mask, **sample_input.kwargs),
|
||||
)
|
||||
yield SampleInput(
|
||||
sample_input.input.detach().requires_grad_(requires_grad),
|
||||
|
|
@ -224,8 +225,9 @@ def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
|
|||
op_info, device, dtype, requires_grad, **kwargs
|
||||
):
|
||||
sample_input_args, sample_input_kwargs = (
|
||||
ord,
|
||||
) + sample_input.args, sample_input.kwargs.copy()
|
||||
(ord,) + sample_input.args,
|
||||
sample_input.kwargs.copy(),
|
||||
)
|
||||
yield SampleInput(
|
||||
sample_input.input.clone().requires_grad_(requires_grad),
|
||||
args=sample_input_args,
|
||||
|
|
@ -276,8 +278,9 @@ def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs
|
|||
for mask in _generate_masked_op_mask(
|
||||
sample_input.input.shape, device, **kwargs
|
||||
):
|
||||
sample_input_args, sample_input_kwargs = sample_input.args, dict(
|
||||
mask=mask, **sample_input.kwargs
|
||||
sample_input_args, sample_input_kwargs = (
|
||||
sample_input.args,
|
||||
dict(mask=mask, **sample_input.kwargs),
|
||||
)
|
||||
yield SampleInput(
|
||||
sample_input.input.detach().requires_grad_(requires_grad),
|
||||
|
|
@ -364,8 +367,9 @@ def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs)
|
|||
):
|
||||
if type(mask) != torch.Tensor:
|
||||
continue
|
||||
sample_input_args, sample_input_kwargs = sample_input.args, dict(
|
||||
mask=mask, **sample_input.kwargs
|
||||
sample_input_args, sample_input_kwargs = (
|
||||
sample_input.args,
|
||||
dict(mask=mask, **sample_input.kwargs),
|
||||
)
|
||||
if "keepdim" in sample_input_kwargs:
|
||||
sample_input_kwargs.pop("keepdim")
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ class _Config(Generic[T]):
|
|||
|
||||
@staticmethod
|
||||
def string_or_list_of_string_to_list(
|
||||
val: Optional[Union[str, list[str]]]
|
||||
val: Optional[Union[str, list[str]]],
|
||||
) -> Optional[list[str]]:
|
||||
if val is None:
|
||||
return None
|
||||
|
|
@ -135,8 +135,7 @@ if TYPE_CHECKING:
|
|||
env_name_force: Optional[Union[str, list[str]]] = None,
|
||||
value_type: Optional[type] = None,
|
||||
alias: Optional[str] = None,
|
||||
) -> T:
|
||||
...
|
||||
) -> T: ...
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -323,9 +322,9 @@ class _ConfigEntry:
|
|||
|
||||
# Ensure justknobs and envvars are allowlisted types
|
||||
if self.justknob is not None and self.default is not None:
|
||||
assert isinstance(
|
||||
self.default, bool
|
||||
), f"justknobs only support booleans, {self.default} is not a boolean"
|
||||
assert isinstance(self.default, bool), (
|
||||
f"justknobs only support booleans, {self.default} is not a boolean"
|
||||
)
|
||||
if self.value_type is not None and (
|
||||
config.env_name_default is not None or config.env_name_force is not None
|
||||
):
|
||||
|
|
@ -334,7 +333,9 @@ class _ConfigEntry:
|
|||
str,
|
||||
Optional[bool],
|
||||
Optional[str],
|
||||
), f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
|
||||
), (
|
||||
f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
|
||||
)
|
||||
|
||||
|
||||
class ConfigModule(ModuleType):
|
||||
|
|
|
|||
|
|
@ -282,9 +282,9 @@ def tree_is_leaf(
|
|||
False
|
||||
>>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple))
|
||||
True
|
||||
>>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3})
|
||||
>>> tree_is_leaf({"a": 1, "b": 2, "c": 3})
|
||||
False
|
||||
>>> tree_is_leaf({'a': 1, 'b': 2, 'c': None})
|
||||
>>> tree_is_leaf({"a": 1, "b": 2, "c": None})
|
||||
False
|
||||
|
||||
Args:
|
||||
|
|
@ -586,29 +586,28 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]]
|
|||
# These specializations help with type inference on the lambda passed to this
|
||||
# function
|
||||
@overload
|
||||
def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]:
|
||||
...
|
||||
def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]:
|
||||
...
|
||||
def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]:
|
||||
...
|
||||
def map_only(
|
||||
type_or_types_or_pred: Type3[T, S, U], /
|
||||
) -> MapOnlyFn[Fn3[T, S, U, Any]]: ...
|
||||
|
||||
|
||||
# This specialization is needed for the implementations below that call
|
||||
@overload
|
||||
def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]:
|
||||
...
|
||||
def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]:
|
||||
...
|
||||
def map_only(
|
||||
type_or_types_or_pred: Callable[[Any], bool], /
|
||||
) -> MapOnlyFn[FnAny[Any]]: ...
|
||||
|
||||
|
||||
def map_only(
|
||||
|
|
@ -664,8 +663,7 @@ def tree_map_only(
|
|||
func: Fn[T, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -675,8 +673,7 @@ def tree_map_only(
|
|||
func: Fn2[T, S, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -686,8 +683,7 @@ def tree_map_only(
|
|||
func: Fn3[T, S, U, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -697,8 +693,7 @@ def tree_map_only(
|
|||
func: FnAny[Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -708,8 +703,7 @@ def tree_map_only(
|
|||
func: FnAny[Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
def tree_map_only(
|
||||
|
|
@ -729,8 +723,7 @@ def tree_map_only_(
|
|||
func: Fn[T, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -740,8 +733,7 @@ def tree_map_only_(
|
|||
func: Fn2[T, S, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -751,8 +743,7 @@ def tree_map_only_(
|
|||
func: Fn3[T, S, U, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -762,8 +753,7 @@ def tree_map_only_(
|
|||
func: FnAny[Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -773,8 +763,7 @@ def tree_map_only_(
|
|||
func: FnAny[Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
def tree_map_only_(
|
||||
|
|
@ -812,8 +801,7 @@ def tree_all_only(
|
|||
pred: Fn[T, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -823,8 +811,7 @@ def tree_all_only(
|
|||
pred: Fn2[T, S, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -834,8 +821,7 @@ def tree_all_only(
|
|||
pred: Fn3[T, S, U, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
def tree_all_only(
|
||||
|
|
@ -856,8 +842,7 @@ def tree_any_only(
|
|||
pred: Fn[T, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -867,8 +852,7 @@ def tree_any_only(
|
|||
pred: Fn2[T, S, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -878,8 +862,7 @@ def tree_any_only(
|
|||
pred: Fn3[T, S, U, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
def tree_any_only(
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ _cache_sentinel = object()
|
|||
|
||||
|
||||
def cache_method(
|
||||
f: Callable[Concatenate[_C, _P], _T]
|
||||
f: Callable[Concatenate[_C, _P], _T],
|
||||
) -> Callable[Concatenate[_C, _P], _T]:
|
||||
"""
|
||||
Like `@functools.cache` but for methods.
|
||||
|
|
|
|||
|
|
@ -302,14 +302,12 @@ class BaseTorchDispatchMode(TorchDispatchMode):
|
|||
|
||||
# Subtypes which have __tensor_flatten__ and __tensor_unflatten__.
|
||||
class TensorWithFlatten(Protocol):
|
||||
def __tensor_flatten__(self) -> tuple[Sequence[str], object]:
|
||||
...
|
||||
def __tensor_flatten__(self) -> tuple[Sequence[str], object]: ...
|
||||
|
||||
@staticmethod
|
||||
def __tensor_unflatten__(
|
||||
inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
) -> torch.Tensor: ...
|
||||
|
||||
# It would be really nice to be able to say that the return of
|
||||
# is_traceable_wrapper_subclass() is Intersection[torch.Tensor,
|
||||
|
|
@ -318,26 +316,20 @@ class TensorWithFlatten(Protocol):
|
|||
shape: torch._C.Size
|
||||
|
||||
@overload
|
||||
def stride(self, dim: None = None) -> tuple[int, ...]:
|
||||
...
|
||||
def stride(self, dim: None = None) -> tuple[int, ...]: ...
|
||||
|
||||
@overload
|
||||
def stride(self, dim: int) -> int:
|
||||
...
|
||||
def stride(self, dim: int) -> int: ...
|
||||
|
||||
@overload
|
||||
def size(self, dim: None = None) -> tuple[int, ...]:
|
||||
...
|
||||
def size(self, dim: None = None) -> tuple[int, ...]: ...
|
||||
|
||||
@overload
|
||||
def size(self, dim: int) -> int:
|
||||
...
|
||||
def size(self, dim: int) -> int: ...
|
||||
|
||||
def storage_offset(self) -> int:
|
||||
...
|
||||
def storage_offset(self) -> int: ...
|
||||
|
||||
def dim(self) -> int:
|
||||
...
|
||||
def dim(self) -> int: ...
|
||||
|
||||
@overload
|
||||
def to(
|
||||
|
|
@ -347,8 +339,7 @@ class TensorWithFlatten(Protocol):
|
|||
copy: bool = False,
|
||||
*,
|
||||
memory_format: Optional[torch.memory_format] = None,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
) -> torch.Tensor: ...
|
||||
|
||||
@overload
|
||||
def to(
|
||||
|
|
@ -359,8 +350,7 @@ class TensorWithFlatten(Protocol):
|
|||
copy: bool = False,
|
||||
*,
|
||||
memory_format: Optional[torch.memory_format] = None,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
) -> torch.Tensor: ...
|
||||
|
||||
@overload
|
||||
def to(
|
||||
|
|
@ -370,8 +360,7 @@ class TensorWithFlatten(Protocol):
|
|||
copy: bool = False,
|
||||
*,
|
||||
memory_format: Optional[torch.memory_format] = None,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
) -> torch.Tensor: ...
|
||||
|
||||
|
||||
def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
|
||||
|
|
|
|||
|
|
@ -99,17 +99,13 @@ NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND"
|
|||
|
||||
|
||||
class KeyEntry(Protocol):
|
||||
def __hash__(self) -> int:
|
||||
...
|
||||
def __hash__(self) -> int: ...
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
...
|
||||
def __eq__(self, other: object) -> bool: ...
|
||||
|
||||
def __str__(self) -> str:
|
||||
...
|
||||
def __str__(self) -> str: ...
|
||||
|
||||
def get(self, parent: Any) -> Any:
|
||||
...
|
||||
def get(self, parent: Any) -> Any: ...
|
||||
|
||||
|
||||
class EnumEncoder(json.JSONEncoder):
|
||||
|
|
@ -757,7 +753,7 @@ def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]:
|
|||
|
||||
|
||||
def _tuple_flatten_with_keys(
|
||||
d: tuple[T, ...]
|
||||
d: tuple[T, ...],
|
||||
) -> tuple[list[tuple[KeyEntry, T]], Context]:
|
||||
values, context = _tuple_flatten(d)
|
||||
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
|
||||
|
|
@ -785,7 +781,7 @@ def _dict_flatten(d: dict[Any, T]) -> tuple[list[T], Context]:
|
|||
|
||||
|
||||
def _dict_flatten_with_keys(
|
||||
d: dict[Any, T]
|
||||
d: dict[Any, T],
|
||||
) -> tuple[list[tuple[KeyEntry, T]], Context]:
|
||||
values, context = _dict_flatten(d)
|
||||
return [(MappingKey(k), v) for k, v in zip(context, values)], context
|
||||
|
|
@ -849,7 +845,7 @@ def _ordereddict_flatten(d: OrderedDict[Any, T]) -> tuple[list[T], Context]:
|
|||
|
||||
|
||||
def _ordereddict_flatten_with_keys(
|
||||
d: OrderedDict[Any, T]
|
||||
d: OrderedDict[Any, T],
|
||||
) -> tuple[list[tuple[KeyEntry, T]], Context]:
|
||||
values, context = _ordereddict_flatten(d)
|
||||
return [(MappingKey(k), v) for k, v in zip(context, values)], context
|
||||
|
|
@ -872,7 +868,7 @@ def _defaultdict_flatten(d: defaultdict[Any, T]) -> tuple[list[T], Context]:
|
|||
|
||||
|
||||
def _defaultdict_flatten_with_keys(
|
||||
d: defaultdict[Any, T]
|
||||
d: defaultdict[Any, T],
|
||||
) -> tuple[list[tuple[KeyEntry, T]], Context]:
|
||||
values, context = _defaultdict_flatten(d)
|
||||
_, dict_context = context
|
||||
|
|
@ -1035,9 +1031,9 @@ def tree_is_leaf(
|
|||
False
|
||||
>>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple))
|
||||
True
|
||||
>>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3})
|
||||
>>> tree_is_leaf({"a": 1, "b": 2, "c": 3})
|
||||
False
|
||||
>>> tree_is_leaf({'a': 1, 'b': 2, 'c': None})
|
||||
>>> tree_is_leaf({"a": 1, "b": 2, "c": None})
|
||||
False
|
||||
"""
|
||||
if is_leaf is not None and is_leaf(tree):
|
||||
|
|
@ -1346,9 +1342,9 @@ def tree_map(
|
|||
|
||||
See also :func:`tree_map_`.
|
||||
|
||||
>>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
|
||||
>>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)})
|
||||
{'x': 8, 'y': (43, 65)}
|
||||
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
|
||||
>>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None})
|
||||
{'x': False, 'y': (False, False), 'z': True}
|
||||
|
||||
If multiple inputs are given, the structure of the tree is taken from the first input;
|
||||
|
|
@ -1432,29 +1428,28 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]]
|
|||
# These specializations help with type inference on the lambda passed to this
|
||||
# function
|
||||
@overload
|
||||
def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]:
|
||||
...
|
||||
def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]:
|
||||
...
|
||||
def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]:
|
||||
...
|
||||
def map_only(
|
||||
type_or_types_or_pred: Type3[T, S, U], /
|
||||
) -> MapOnlyFn[Fn3[T, S, U, Any]]: ...
|
||||
|
||||
|
||||
# This specialization is needed for the implementations below that call
|
||||
@overload
|
||||
def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]:
|
||||
...
|
||||
def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]:
|
||||
...
|
||||
def map_only(
|
||||
type_or_types_or_pred: Callable[[Any], bool], /
|
||||
) -> MapOnlyFn[FnAny[Any]]: ...
|
||||
|
||||
|
||||
def map_only(
|
||||
|
|
@ -1510,8 +1505,7 @@ def tree_map_only(
|
|||
func: Fn[T, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -1521,8 +1515,7 @@ def tree_map_only(
|
|||
func: Fn2[T, S, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -1532,8 +1525,7 @@ def tree_map_only(
|
|||
func: Fn3[T, S, U, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -1543,8 +1535,7 @@ def tree_map_only(
|
|||
func: FnAny[Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -1554,8 +1545,7 @@ def tree_map_only(
|
|||
func: FnAny[Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
def tree_map_only(
|
||||
|
|
@ -1575,8 +1565,7 @@ def tree_map_only_(
|
|||
func: Fn[T, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -1586,8 +1575,7 @@ def tree_map_only_(
|
|||
func: Fn2[T, S, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -1597,8 +1585,7 @@ def tree_map_only_(
|
|||
func: Fn3[T, S, U, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -1608,8 +1595,7 @@ def tree_map_only_(
|
|||
func: FnAny[Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -1619,8 +1605,7 @@ def tree_map_only_(
|
|||
func: FnAny[Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
) -> PyTree: ...
|
||||
|
||||
|
||||
def tree_map_only_(
|
||||
|
|
@ -1658,8 +1643,7 @@ def tree_all_only(
|
|||
pred: Fn[T, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -1669,8 +1653,7 @@ def tree_all_only(
|
|||
pred: Fn2[T, S, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -1680,8 +1663,7 @@ def tree_all_only(
|
|||
pred: Fn3[T, S, U, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
def tree_all_only(
|
||||
|
|
@ -1702,8 +1684,7 @@ def tree_any_only(
|
|||
pred: Fn[T, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -1713,8 +1694,7 @@ def tree_any_only(
|
|||
pred: Fn2[T, S, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -1724,8 +1704,7 @@ def tree_any_only(
|
|||
pred: Fn3[T, S, U, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
def tree_any_only(
|
||||
|
|
@ -1862,7 +1841,7 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
|
|||
|
||||
if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
|
||||
raise NotImplementedError(
|
||||
f'Deserializing {json_schema["type"]} in pytree is not registered.',
|
||||
f"Deserializing {json_schema['type']} in pytree is not registered.",
|
||||
)
|
||||
|
||||
typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]]
|
||||
|
|
|
|||
|
|
@ -301,7 +301,7 @@ def strobelight(
|
|||
profiler = StrobelightCLIFunctionProfiler(**kwargs)
|
||||
|
||||
def strobelight_inner(
|
||||
work_function: Callable[_P, _R]
|
||||
work_function: Callable[_P, _R],
|
||||
) -> Callable[_P, Optional[_R]]:
|
||||
@functools.wraps(work_function)
|
||||
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ def _is_symbols_binary_summation(expr: sympy.Expr) -> bool:
|
|||
|
||||
|
||||
def _keep_float(
|
||||
f: Callable[[Unpack[_Ts]], _T]
|
||||
f: Callable[[Unpack[_Ts]], _T],
|
||||
) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
|
||||
@functools.wraps(f)
|
||||
def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
|
||||
|
|
@ -926,10 +926,12 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
|
|||
|
||||
_eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) # noqa: E731
|
||||
_eval_is_antihermitian = lambda s: _torf( # noqa: E731
|
||||
i.is_antihermitian for i in s.args # noqa: E731
|
||||
i.is_antihermitian
|
||||
for i in s.args # noqa: E731
|
||||
) # noqa: E731
|
||||
_eval_is_commutative = lambda s: _torf( # noqa: E731
|
||||
i.is_commutative for i in s.args # noqa: E731
|
||||
i.is_commutative
|
||||
for i in s.args # noqa: E731
|
||||
) # noqa: E731
|
||||
_eval_is_complex = lambda s: _torf(i.is_complex for i in s.args) # noqa: E731
|
||||
_eval_is_composite = lambda s: _torf(i.is_composite for i in s.args) # noqa: E731
|
||||
|
|
@ -943,10 +945,12 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
|
|||
_eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731
|
||||
_eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731
|
||||
_eval_is_nonnegative = lambda s: _torf( # noqa: E731
|
||||
i.is_nonnegative for i in s.args # noqa: E731
|
||||
i.is_nonnegative
|
||||
for i in s.args # noqa: E731
|
||||
) # noqa: E731
|
||||
_eval_is_nonpositive = lambda s: _torf( # noqa: E731
|
||||
i.is_nonpositive for i in s.args # noqa: E731
|
||||
i.is_nonpositive
|
||||
for i in s.args # noqa: E731
|
||||
) # noqa: E731
|
||||
_eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731
|
||||
_eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731
|
||||
|
|
@ -956,10 +960,12 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
|
|||
_eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731
|
||||
_eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731
|
||||
_eval_is_extended_real = lambda s: _torf( # noqa: E731
|
||||
i.is_extended_real for i in s.args # noqa: E731
|
||||
i.is_extended_real
|
||||
for i in s.args # noqa: E731
|
||||
) # noqa: E731
|
||||
_eval_is_transcendental = lambda s: _torf( # noqa: E731
|
||||
i.is_transcendental for i in s.args # noqa: E731
|
||||
i.is_transcendental
|
||||
for i in s.args # noqa: E731
|
||||
) # noqa: E731
|
||||
_eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731
|
||||
|
||||
|
|
|
|||
|
|
@ -144,16 +144,14 @@ class ValueRanges(Generic[_T]):
|
|||
self: ValueRanges[sympy.Expr],
|
||||
lower: ExprIn,
|
||||
upper: ExprIn,
|
||||
) -> None:
|
||||
...
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def __init__( # type: ignore[misc]
|
||||
self: ValueRanges[SympyBoolean],
|
||||
lower: BoolIn,
|
||||
upper: BoolIn,
|
||||
) -> None:
|
||||
...
|
||||
) -> None: ...
|
||||
|
||||
def __init__(self, lower: AllIn, upper: AllIn) -> None:
|
||||
lower = simple_sympify(lower)
|
||||
|
|
@ -240,15 +238,13 @@ class ValueRanges(Generic[_T]):
|
|||
def __and__(
|
||||
self: ValueRanges[sympy.Expr],
|
||||
other: ValueRanges[sympy.Expr],
|
||||
) -> ValueRanges[sympy.Expr]:
|
||||
...
|
||||
) -> ValueRanges[sympy.Expr]: ...
|
||||
|
||||
@overload
|
||||
def __and__( # type: ignore[misc]
|
||||
self: ValueRanges[SympyBoolean],
|
||||
other: ValueRanges[SympyBoolean],
|
||||
) -> ValueRanges[SympyBoolean]:
|
||||
...
|
||||
) -> ValueRanges[SympyBoolean]: ...
|
||||
|
||||
def __and__(self: AllVR, other: AllVR) -> AllVR:
|
||||
if other in (ValueRanges.unknown(), ValueRanges.unknown_int()):
|
||||
|
|
@ -272,15 +268,13 @@ class ValueRanges(Generic[_T]):
|
|||
def __or__(
|
||||
self: ValueRanges[sympy.Expr],
|
||||
other: ValueRanges[sympy.Expr],
|
||||
) -> ValueRanges[sympy.Expr]:
|
||||
...
|
||||
) -> ValueRanges[sympy.Expr]: ...
|
||||
|
||||
@overload
|
||||
def __or__( # type: ignore[misc]
|
||||
self: ValueRanges[SympyBoolean],
|
||||
other: ValueRanges[SympyBoolean],
|
||||
) -> ValueRanges[SympyBoolean]:
|
||||
...
|
||||
) -> ValueRanges[SympyBoolean]: ...
|
||||
|
||||
def __or__(self: AllVR, other: AllVR) -> AllVR:
|
||||
if ValueRanges.unknown() in (self, other):
|
||||
|
|
@ -343,8 +337,7 @@ class ValueRanges(Generic[_T]):
|
|||
|
||||
@overload
|
||||
@staticmethod
|
||||
def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
|
||||
...
|
||||
def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: ...
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
|
|
@ -384,8 +377,7 @@ class ValueRanges(Generic[_T]):
|
|||
x: Union[ExprIn, ExprVR],
|
||||
y: Union[ExprIn, ExprVR],
|
||||
fn: ExprFn2,
|
||||
) -> ExprVR:
|
||||
...
|
||||
) -> ExprVR: ...
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
|
|
@ -393,8 +385,7 @@ class ValueRanges(Generic[_T]):
|
|||
x: Union[BoolIn, BoolVR],
|
||||
y: Union[BoolIn, BoolVR],
|
||||
fn: BoolFn2,
|
||||
) -> BoolVR:
|
||||
...
|
||||
) -> BoolVR: ...
|
||||
|
||||
@staticmethod
|
||||
def coordinatewise_increasing_map(
|
||||
|
|
|
|||
|
|
@ -426,9 +426,9 @@ def _get_custom_mod_func(func_name: str):
|
|||
it is marked as private. It is a convenience function for backend implementers to
|
||||
more easily call the hooks into their backend extensions.
|
||||
"""
|
||||
assert isinstance(
|
||||
func_name, str
|
||||
), f"func_name must be `str`, but got `{type(func_name)}`."
|
||||
assert isinstance(func_name, str), (
|
||||
f"func_name must be `str`, but got `{type(func_name)}`."
|
||||
)
|
||||
backend_name = _get_privateuse1_backend_name()
|
||||
custom_device_mod = getattr(torch, backend_name, None) # type: ignore[arg-type]
|
||||
function = getattr(custom_device_mod, func_name, None) # type: ignore[arg-type]
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ def default_convert(data):
|
|||
>>> default_convert(np.array([0, 1]))
|
||||
tensor([0, 1])
|
||||
>>> # Example with NamedTuple
|
||||
>>> Point = namedtuple('Point', ['x', 'y'])
|
||||
>>> Point = namedtuple("Point", ["x", "y"])
|
||||
>>> default_convert(Point(0, 0))
|
||||
Point(x=0, y=0)
|
||||
>>> default_convert(Point(np.array(0), np.array(0)))
|
||||
|
|
@ -366,13 +366,13 @@ def default_collate(batch):
|
|||
>>> default_collate([0, 1, 2, 3])
|
||||
tensor([0, 1, 2, 3])
|
||||
>>> # Example with a batch of `str`s:
|
||||
>>> default_collate(['a', 'b', 'c'])
|
||||
>>> default_collate(["a", "b", "c"])
|
||||
['a', 'b', 'c']
|
||||
>>> # Example with `Map` inside the batch:
|
||||
>>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
|
||||
>>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}])
|
||||
{'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])}
|
||||
>>> # Example with `NamedTuple` inside the batch:
|
||||
>>> Point = namedtuple('Point', ['x', 'y'])
|
||||
>>> Point = namedtuple("Point", ["x", "y"])
|
||||
>>> default_collate([Point(0, 0), Point(1, 1)])
|
||||
Point(x=tensor([0, 1]), y=tensor([0, 1]))
|
||||
>>> # Example with `Tuple` inside the batch:
|
||||
|
|
|
|||
|
|
@ -69,7 +69,9 @@ def pin_memory(data, device=None):
|
|||
)
|
||||
return clone
|
||||
else:
|
||||
return type(data)({k: pin_memory(sample, device) for k, sample in data.items()}) # type: ignore[call-arg]
|
||||
return type(data)(
|
||||
{k: pin_memory(sample, device) for k, sample in data.items()}
|
||||
) # type: ignore[call-arg]
|
||||
except TypeError:
|
||||
# The mapping type may not support `copy()` / `update(mapping)`
|
||||
# or `__init__(iterable)`.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
|
||||
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
|
||||
|
||||
These **needs** to be in global scope since Py2 doesn't support serializing
|
||||
static methods.
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ To support these two classes, in `./_utils` we define many utility methods and
|
|||
functions to be run in multiprocessing. E.g., the data loading worker loop is
|
||||
in `./_utils/worker.py`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
|
|
@ -1208,7 +1209,10 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
|||
atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
|
||||
|
||||
# .pid can be None only before process is spawned (not the case, so ignore)
|
||||
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
|
||||
_utils.signal_handling._set_worker_pids(
|
||||
id(self),
|
||||
tuple(w.pid for w in self._workers), # type: ignore[misc]
|
||||
)
|
||||
_utils.signal_handling._set_SIGCHLD_handler()
|
||||
self._worker_pids_set = True
|
||||
self._reset(loader, first_iter=True)
|
||||
|
|
|
|||
|
|
@ -109,8 +109,7 @@ class non_deterministic:
|
|||
|
||||
# Decorate with a functional argument
|
||||
if not (
|
||||
isinstance(args[0], type)
|
||||
and issubclass(args[0], IterDataPipe) # type: ignore[arg-type]
|
||||
isinstance(args[0], type) and issubclass(args[0], IterDataPipe) # type: ignore[arg-type]
|
||||
):
|
||||
raise TypeError(
|
||||
f"Only `IterDataPipe` can be decorated, but {args[0].__name__} is found"
|
||||
|
|
|
|||
|
|
@ -99,7 +99,9 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
|
|||
>>> from torchdata.datapipes.iter import IterableWrapper, Mapper
|
||||
>>> dp = IterableWrapper(range(10))
|
||||
>>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor
|
||||
>>> map_dp_2 = dp.map(lambda x: x + 1) # Using functional form (recommended)
|
||||
>>> map_dp_2 = dp.map(
|
||||
... lambda x: x + 1
|
||||
... ) # Using functional form (recommended)
|
||||
>>> list(map_dp_1)
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
>>> list(map_dp_2)
|
||||
|
|
@ -114,7 +116,9 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
|
|||
>>> list(it1)
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
>>> it1 = iter(source_dp)
|
||||
>>> it2 = iter(source_dp) # The creation of a new iterator invalidates `it1`
|
||||
>>> it2 = iter(
|
||||
... source_dp
|
||||
... ) # The creation of a new iterator invalidates `it1`
|
||||
>>> next(it2)
|
||||
0
|
||||
>>> next(it1) # Further usage of `it1` will raise a `RunTimeError`
|
||||
|
|
|
|||
|
|
@ -55,7 +55,8 @@ class MapperIterDataPipe(IterDataPipe[_T_co]):
|
|||
>>> def add_one(x):
|
||||
... return x + 1
|
||||
>>> dp = IterableWrapper(range(10))
|
||||
>>> map_dp_1 = dp.map(add_one) # Invocation via functional form is preferred
|
||||
>>> # Invocation via functional form is preferred
|
||||
... map_dp_1 = dp.map(add_one)
|
||||
>>> list(map_dp_1)
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
>>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
|
||||
|
|
@ -202,7 +203,7 @@ class CollatorIterDataPipe(MapperIterDataPipe):
|
|||
>>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
|
||||
... def __init__(self, start, end):
|
||||
... super(MyIterDataPipe).__init__()
|
||||
... assert end > start, "this example code only works with end >= start"
|
||||
... assert end > start, "this example only works with end >= start"
|
||||
... self.start = start
|
||||
... self.end = end
|
||||
...
|
||||
|
|
@ -211,13 +212,11 @@ class CollatorIterDataPipe(MapperIterDataPipe):
|
|||
...
|
||||
... def __len__(self):
|
||||
... return self.end - self.start
|
||||
...
|
||||
>>> ds = MyIterDataPipe(start=3, end=7)
|
||||
>>> print(list(ds))
|
||||
[3, 4, 5, 6]
|
||||
>>> def collate_fn(batch):
|
||||
... return torch.tensor(batch, dtype=torch.float)
|
||||
...
|
||||
>>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn)
|
||||
>>> print(list(collated_ds))
|
||||
[tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
|
||||
|
|
|
|||
|
|
@ -38,15 +38,17 @@ class SamplerIterDataPipe(IterDataPipe[_T_co]):
|
|||
sampler_args: Optional[tuple] = None,
|
||||
sampler_kwargs: Optional[dict] = None,
|
||||
) -> None:
|
||||
assert isinstance(
|
||||
datapipe, Sized
|
||||
), "Sampler class requires input datapipe implemented `__len__`"
|
||||
assert isinstance(datapipe, Sized), (
|
||||
"Sampler class requires input datapipe implemented `__len__`"
|
||||
)
|
||||
super().__init__()
|
||||
self.datapipe = datapipe
|
||||
self.sampler_args = () if sampler_args is None else sampler_args
|
||||
self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
|
||||
# https://github.com/python/mypy/pull/9629 will solve
|
||||
self.sampler = sampler(*self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs) # type: ignore[misc]
|
||||
self.sampler = sampler(
|
||||
*self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs
|
||||
) # type: ignore[misc]
|
||||
|
||||
def __iter__(self) -> Iterator[_T_co]:
|
||||
return iter(self.sampler)
|
||||
|
|
|
|||
|
|
@ -116,16 +116,13 @@ class _ContainerTemplate(ABC):
|
|||
r"""Abstract class for container ``DataPipes``. The followings are three required methods."""
|
||||
|
||||
@abstractmethod
|
||||
def get_next_element_by_instance(self, instance_id: int):
|
||||
...
|
||||
def get_next_element_by_instance(self, instance_id: int): ...
|
||||
|
||||
@abstractmethod
|
||||
def is_every_instance_exhausted(self) -> bool:
|
||||
...
|
||||
def is_every_instance_exhausted(self) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
...
|
||||
def reset(self) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_length_by_instance(self, instance_id: int):
|
||||
|
|
@ -403,7 +400,9 @@ class DemultiplexerIterDataPipe(IterDataPipe):
|
|||
>>> # It can also filter out any element that gets `None` from the `classifier_fn`
|
||||
>>> def odd_or_even_no_zero(n):
|
||||
... return n % 2 if n != 0 else None
|
||||
>>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True)
|
||||
>>> dp1, dp2 = source_dp.demux(
|
||||
... num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True
|
||||
... )
|
||||
>>> list(dp1)
|
||||
[2, 4]
|
||||
>>> list(dp2)
|
||||
|
|
@ -428,7 +427,9 @@ class DemultiplexerIterDataPipe(IterDataPipe):
|
|||
# When num_instances == 1, demux can be replaced by filter,
|
||||
# but keep it as Demultiplexer for the sake of consistency
|
||||
# like throwing Error when classification result is out of o range
|
||||
container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size) # type: ignore[abstract]
|
||||
container = _DemultiplexerIterDataPipe(
|
||||
datapipe, num_instances, classifier_fn, drop_none, buffer_size
|
||||
) # type: ignore[abstract]
|
||||
return [_ChildDataPipe(container, i) for i in range(num_instances)]
|
||||
|
||||
|
||||
|
|
@ -602,16 +603,18 @@ class MultiplexerIterDataPipe(IterDataPipe):
|
|||
Example:
|
||||
>>> # xdoctest: +REQUIRES(module:torchdata)
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper
|
||||
>>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
|
||||
>>> dp1, dp2, dp3 = (
|
||||
... IterableWrapper(range(3)),
|
||||
... IterableWrapper(range(10, 15)),
|
||||
... IterableWrapper(range(20, 25)),
|
||||
... )
|
||||
>>> list(dp1.mux(dp2, dp3))
|
||||
[0, 10, 20, 1, 11, 21, 2, 12, 22]
|
||||
"""
|
||||
|
||||
def __init__(self, *datapipes):
|
||||
self.datapipes = datapipes
|
||||
self.buffer: list = (
|
||||
[]
|
||||
) # Store values to be yielded only when every iterator provides one
|
||||
self.buffer: list = [] # Store values to be yielded only when every iterator provides one
|
||||
|
||||
def __iter__(self):
|
||||
iterators = [iter(x) for x in self.datapipes]
|
||||
|
|
@ -670,7 +673,11 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]):
|
|||
Example:
|
||||
>>> # xdoctest: +REQUIRES(module:torchdata)
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper
|
||||
>>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
|
||||
>>> dp1, dp2, dp3 = (
|
||||
... IterableWrapper(range(5)),
|
||||
... IterableWrapper(range(10, 15)),
|
||||
... IterableWrapper(range(20, 25)),
|
||||
... )
|
||||
>>> list(dp1.zip(dp2, dp3))
|
||||
[(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -33,8 +33,12 @@ class FileOpenerIterDataPipe(IterDataPipe[tuple[str, IOBase]]):
|
|||
|
||||
Example:
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader
|
||||
>>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt'))
|
||||
>>> from torchdata.datapipes.iter import (
|
||||
... FileLister,
|
||||
... FileOpener,
|
||||
... StreamReader,
|
||||
... )
|
||||
>>> dp = FileLister(root=".").filter(lambda fname: fname.endswith(".txt"))
|
||||
>>> dp = FileOpener(dp)
|
||||
>>> dp = StreamReader(dp)
|
||||
>>> list(dp)
|
||||
|
|
|
|||
|
|
@ -182,7 +182,9 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
|
|||
>>> from torchdata.datapipes.iter import IterableWrapper
|
||||
>>> def group_fn(file):
|
||||
... return os.path.basename(file).split(".")[0]
|
||||
>>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
|
||||
>>> source_dp = IterableWrapper(
|
||||
... ["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"]
|
||||
... )
|
||||
>>> dp0 = source_dp.groupby(group_key_fn=group_fn)
|
||||
>>> list(dp0)
|
||||
[['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
|
||||
|
|
@ -191,7 +193,12 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
|
|||
>>> list(dp1)
|
||||
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
|
||||
>>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size`
|
||||
>>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
|
||||
>>> dp2 = source_dp.groupby(
|
||||
... group_key_fn=group_fn,
|
||||
... buffer_size=3,
|
||||
... group_size=3,
|
||||
... guaranteed_group_size=2,
|
||||
... )
|
||||
>>> list(dp2)
|
||||
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -31,8 +31,8 @@ class SequenceWrapperMapDataPipe(MapDataPipe[_T]):
|
|||
>>> dp = SequenceWrapper(range(10))
|
||||
>>> list(dp)
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
>>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
|
||||
>>> dp['a']
|
||||
>>> dp = SequenceWrapper({"a": 100, "b": 200, "c": 300, "d": 400})
|
||||
>>> dp["a"]
|
||||
100
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -45,8 +45,8 @@ def basichandlers(extension: str, data):
|
|||
|
||||
Example:
|
||||
>>> import pickle
|
||||
>>> data = pickle.dumps('some data')
|
||||
>>> new_data = basichandlers('pickle', data)
|
||||
>>> data = pickle.dumps("some data")
|
||||
>>> new_data = basichandlers("pickle", data)
|
||||
>>> new_data
|
||||
some data
|
||||
|
||||
|
|
@ -169,9 +169,9 @@ class ImageHandler:
|
|||
"""
|
||||
|
||||
def __init__(self, imagespec):
|
||||
assert imagespec in list(
|
||||
imagespecs.keys()
|
||||
), f"unknown image specification: {imagespec}"
|
||||
assert imagespec in list(imagespecs.keys()), (
|
||||
f"unknown image specification: {imagespec}"
|
||||
)
|
||||
self.imagespec = imagespec.lower()
|
||||
|
||||
def __call__(self, extension, data):
|
||||
|
|
@ -205,18 +205,18 @@ class ImageHandler:
|
|||
return img
|
||||
elif atype == "numpy":
|
||||
result = np.asarray(img)
|
||||
assert (
|
||||
result.dtype == np.uint8
|
||||
), f"numpy image array should be type uint8, but got {result.dtype}"
|
||||
assert result.dtype == np.uint8, (
|
||||
f"numpy image array should be type uint8, but got {result.dtype}"
|
||||
)
|
||||
if etype == "uint8":
|
||||
return result
|
||||
else:
|
||||
return result.astype("f") / 255.0
|
||||
elif atype == "torch":
|
||||
result = np.asarray(img)
|
||||
assert (
|
||||
result.dtype == np.uint8
|
||||
), f"numpy image array should be type uint8, but got {result.dtype}"
|
||||
assert result.dtype == np.uint8, (
|
||||
f"numpy image array should be type uint8, but got {result.dtype}"
|
||||
)
|
||||
|
||||
if etype == "uint8":
|
||||
result = np.array(result.transpose(2, 0, 1))
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
|
|||
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
|
||||
... def __init__(self, start, end):
|
||||
... super(MyIterableDataset).__init__()
|
||||
... assert end > start, "this example code only works with end >= start"
|
||||
... assert end > start, "this example only works with end >= start"
|
||||
... self.start = start
|
||||
... self.end = end
|
||||
...
|
||||
|
|
@ -138,7 +138,7 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
|
|||
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
|
||||
... def __init__(self, start, end):
|
||||
... super(MyIterableDataset).__init__()
|
||||
... assert end > start, "this example code only works with end >= start"
|
||||
... assert end > start, "this example only works with end >= start"
|
||||
... self.start = start
|
||||
... self.end = end
|
||||
...
|
||||
|
|
@ -198,9 +198,9 @@ class TensorDataset(Dataset[tuple[Tensor, ...]]):
|
|||
tensors: tuple[Tensor, ...]
|
||||
|
||||
def __init__(self, *tensors: Tensor) -> None:
|
||||
assert all(
|
||||
tensors[0].size(0) == tensor.size(0) for tensor in tensors
|
||||
), "Size mismatch between tensors"
|
||||
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), (
|
||||
"Size mismatch between tensors"
|
||||
)
|
||||
self.tensors = tensors
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
|
@ -222,7 +222,7 @@ class StackDataset(Dataset[_T_stack]):
|
|||
>>> tuple_stack = StackDataset(images, texts)
|
||||
>>> tuple_stack[0] == (images[0], texts[0])
|
||||
>>> dict_stack = StackDataset(image=images, text=texts)
|
||||
>>> dict_stack[0] == {'image': images[0], 'text': texts[0]}
|
||||
>>> dict_stack[0] == {"image": images[0], "text": texts[0]}
|
||||
|
||||
Args:
|
||||
*args (Dataset): Datasets for stacking returned as tuple.
|
||||
|
|
@ -323,9 +323,9 @@ class ConcatDataset(Dataset[_T_co]):
|
|||
self.datasets = list(datasets)
|
||||
assert len(self.datasets) > 0, "datasets should not be an empty iterable" # type: ignore[arg-type]
|
||||
for d in self.datasets:
|
||||
assert not isinstance(
|
||||
d, IterableDataset
|
||||
), "ConcatDataset does not support IterableDataset"
|
||||
assert not isinstance(d, IterableDataset), (
|
||||
"ConcatDataset does not support IterableDataset"
|
||||
)
|
||||
self.cumulative_sizes = self.cumsum(self.datasets)
|
||||
|
||||
def __len__(self):
|
||||
|
|
@ -371,17 +371,17 @@ class ChainDataset(IterableDataset):
|
|||
|
||||
def __iter__(self):
|
||||
for d in self.datasets:
|
||||
assert isinstance(
|
||||
d, IterableDataset
|
||||
), "ChainDataset only supports IterableDataset"
|
||||
assert isinstance(d, IterableDataset), (
|
||||
"ChainDataset only supports IterableDataset"
|
||||
)
|
||||
yield from d
|
||||
|
||||
def __len__(self):
|
||||
total = 0
|
||||
for d in self.datasets:
|
||||
assert isinstance(
|
||||
d, IterableDataset
|
||||
), "ChainDataset only supports IterableDataset"
|
||||
assert isinstance(d, IterableDataset), (
|
||||
"ChainDataset only supports IterableDataset"
|
||||
)
|
||||
total += len(d) # type: ignore[arg-type]
|
||||
return total
|
||||
|
||||
|
|
|
|||
|
|
@ -236,9 +236,17 @@ class WeightedRandomSampler(Sampler[int]):
|
|||
|
||||
Example:
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
|
||||
>>> list(
|
||||
... WeightedRandomSampler(
|
||||
... [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True
|
||||
... )
|
||||
... )
|
||||
[4, 4, 1, 4, 5]
|
||||
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
|
||||
>>> list(
|
||||
... WeightedRandomSampler(
|
||||
... [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False
|
||||
... )
|
||||
... )
|
||||
[0, 1, 4, 3, 2]
|
||||
"""
|
||||
|
||||
|
|
@ -298,9 +306,15 @@ class BatchSampler(Sampler[list[int]]):
|
|||
its size would be less than ``batch_size``
|
||||
|
||||
Example:
|
||||
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
|
||||
>>> list(
|
||||
... BatchSampler(
|
||||
... SequentialSampler(range(10)), batch_size=3, drop_last=False
|
||||
... )
|
||||
... )
|
||||
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
|
||||
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
|
||||
>>> list(
|
||||
... BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
|
||||
... )
|
||||
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ class ModuleTracker:
|
|||
def my_linear(m1, m2, bias):
|
||||
print(f"Current modules: {tracker.parents}")
|
||||
return torch.mm(m1, m2.t()) + bias
|
||||
|
||||
torch.nn.functional.linear = my_linear
|
||||
|
||||
mod(torch.rand(2, 2))
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ Intel GPU optimization.
|
|||
This package is lazily initialized, so you can always import it, and use
|
||||
:func:`is_available()` to determine if your system supports XPU.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import traceback
|
||||
from functools import lru_cache
|
||||
|
|
@ -292,6 +293,7 @@ class StreamContext:
|
|||
``None``.
|
||||
.. note:: Streams are per-device.
|
||||
"""
|
||||
|
||||
cur_stream: Optional["torch.xpu.Stream"]
|
||||
|
||||
def __init__(self, stream: Optional["torch.xpu.Stream"]):
|
||||
|
|
@ -438,7 +440,7 @@ def get_gencode_flags() -> str:
|
|||
arch_list = get_arch_list()
|
||||
if len(arch_list) == 0:
|
||||
return ""
|
||||
return f'-device {",".join(arch for arch in arch_list)}'
|
||||
return f"-device {','.join(arch for arch in arch_list)}"
|
||||
|
||||
|
||||
def _get_generator(device: torch.device) -> torch._C.Generator:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user