[BE][PYFMT] migrate PYFMT for torch/[p-z]*/ to ruff format (#144552)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144552
Approved by: https://github.com/ezyang

This commit is contained in:
parent fd606a3a91
commit 5cedc5a0ff
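Most of the hunks below are the same handful of mechanical rewrites that ruff format applies where black had formatted differently: assert messages move into a parenthesized group, over-long call and `.format(...)` argument lists are collapsed or re-wrapped, string quotes are normalized to double quotes, and slices gain spaces around `:` when the bounds are expressions. As a hedged illustration only (the assert text is taken from the PackageExporter hunk further down; the variable is defined here just so the snippet runs standalone):

# Illustrative sketch of the dominant rewrite in this commit.
pickle_protocol = 4

# black layout (before): the message trails the closing parenthesis
# assert (pickle_protocol == 4) or (
#     pickle_protocol == 3
# ), "torch.package only supports pickle protocols 3 and 4"

# ruff format layout (after): the message gets its own parenthesized group
assert (pickle_protocol == 4) or (pickle_protocol == 3), (
    "torch.package only supports pickle protocols 3 and 4"
)

Locally, a comparable pass can be reproduced with `ruff format <path>`; that invocation is illustrative, since the repository drives the formatter through its own lint tooling.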
@@ -52,7 +52,6 @@ USE_BLACK_FILELIST = re.compile(
         # torch/[e-m]*/**
         # torch/optim/**
         # torch/[p-z]*/**
-        "torch/[p-z]*/**",
     ],
 ),
 )
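For context: this filelist is what keeps a path on the legacy black formatter, so dropping `torch/[p-z]*/**` is what switches every file under those directories over to ruff format. A minimal sketch of the idea, assuming a toy fnmatch-based check rather than the adapter's real compiled-regex implementation; the remaining entry below is a hypothetical placeholder:

# Toy model of the filelist check; not the linter adapter's actual code.
from fnmatch import fnmatch

USE_BLACK_FILELIST = [
    # "torch/[p-z]*/**",   # removed by this commit; these paths now fall through to ruff format
    "some/legacy/dir/**",  # hypothetical leftover entry, purely illustrative
]


def formatted_with_black(path: str) -> bool:
    # A file keeps using black only while some glob on the filelist matches it.
    return any(fnmatch(path, pattern) for pattern in USE_BLACK_FILELIST)


print(formatted_with_black("torch/profiler/profiler.py"))  # False after this commit
print(formatted_with_black("some/legacy/dir/module.py"))   # True for the placeholder entry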
@@ -2,6 +2,7 @@
 """Import mangling.

 See mangling.md for details.
 """

 import re

(This hunk adds a single line; from the flattened view it is most likely the blank line ruff format places after the module docstring.)
@@ -605,9 +605,9 @@ class PackageExporter:
             dependencies (bool, optional): If ``True``, we scan the source for dependencies.
         """

-        assert (pickle_protocol == 4) or (
-            pickle_protocol == 3
-        ), "torch.package only supports pickle protocols 3 and 4"
+        assert (pickle_protocol == 4) or (pickle_protocol == 3), (
+            "torch.package only supports pickle protocols 3 and 4"
+        )

         filename = self._filename(package, resource)
         # Write the pickle data for `obj`
@@ -423,7 +423,12 @@ class PackageImporter(Importer):
             module.__dict__.setdefault(old_name, new_name)

             return module
-        return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent)  # type: ignore[attr-defined]
+        return self._make_module(
+            name,
+            cur.source_file,  # type: ignore[attr-defined]
+            isinstance(cur, _PackageNode),
+            parent,
+        )

     def _compile_source(self, fullpath: str, mangled_filename: str):
         source = self.zip_reader.get_record(fullpath)
@@ -7,6 +7,7 @@ examine their input shapes and stack traces, study device kernel activity and vi
 An earlier version of the API in :mod:`torch.autograd` module is considered legacy and will be deprecated.

 """

 import os
 from typing import Any
 from typing_extensions import TypeVarTuple, Unpack

(This hunk adds a single line; most likely a blank line in the module header, per ruff format's docstring spacing.)
@@ -239,10 +239,12 @@ class SchemaMatcher:
     def match_schemas(cls, t: _ExtraFields_TorchOp) -> tuple[FunctionSchema, ...]:
         signature = tuple(
             # Tensor
-            TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata)
+            TensorKey.from_tensor(i)
+            if isinstance(i, _TensorMetadata)
             #
             # TensorList
-            else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list)
+            else [TensorKey.from_tensor(j) for j in i]
+            if isinstance(i, list)
             #
             # Scalar and uncaptured inputs.
             else i
@@ -124,9 +124,9 @@ class BasicEvaluation:
             for child_event in curr_event.children:
                 self_time -= child_event.duration_time_ns
                 stack.append(child_event)
-            assert (
-                EventKey(curr_event) not in self.metrics
-            ), f"Duplicate id: {curr_event.id}, {curr_event.name}"
+            assert EventKey(curr_event) not in self.metrics, (
+                f"Duplicate id: {curr_event.id}, {curr_event.name}"
+            )
             self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time)
             self.metrics[
                 EventKey(curr_event)

@@ -227,8 +227,7 @@ class BasicEvaluation:

         while (
             current_kernel_index < len(cuda_kernel_events)
-            and (cuda_kernel_events[current_kernel_index].start_ns())
-            <= start_time  # type: ignore[possibly-undefined]
+            and (cuda_kernel_events[current_kernel_index].start_ns()) <= start_time  # type: ignore[possibly-undefined]
         ):
             current_kernel_index += 1
         current_queue_depth = spawned_kernel_index - current_kernel_index + 1

@@ -352,11 +351,11 @@ class BasicEvaluation:

         output += "\n".join(
             [
-                f"""{'-' * 80}
+                f"""{"-" * 80}
 Event: {event}
 Source code location: {source_code_location(event.event)}
 Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}%
-{'-' * 80}"""
+{"-" * 80}"""
                 for event in event_list
             ]
         )
@@ -624,8 +624,7 @@ class profile(_KinetoProfile):
             ]
         ) as p:
             code_to_profile()
-        print(p.key_averages().table(
-            sort_by="self_cuda_time_total", row_limit=-1))
+        print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

     Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:

@@ -635,16 +634,17 @@ class profile(_KinetoProfile):
         # on different iterations of the training loop;
         # trace_handler is called every time a new trace becomes available
         def trace_handler(prof):
-            print(prof.key_averages().table(
-                sort_by="self_cuda_time_total", row_limit=-1))
+            print(
+                prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)
+            )
             # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")

         with torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
                 torch.profiler.ProfilerActivity.CUDA,
             ],

             # In this example with wait=1, warmup=1, active=2, repeat=1,
             # profiler will skip the first step/iteration,
             # start warming up on the second, record

@@ -652,20 +652,15 @@ class profile(_KinetoProfile):
             # after which the trace will become available
             # and on_trace_ready (when set) is called;
             # the cycle repeats starting with the next step
-            schedule=torch.profiler.schedule(
-                wait=1,
-                warmup=1,
-                active=2,
-                repeat=1),
-            on_trace_ready=trace_handler
+            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
+            on_trace_ready=trace_handler,
             # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
             # used when outputting for tensorboard
         ) as p:
             for iter in range(N):
                 code_iteration_to_profile(iter)
                 # send a signal to the profiler that the next iteration has started
                 p.step()

     The following sample shows how to setup up an Execution Trace Observer (`execution_trace_observer`)
@@ -6,6 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
 `torch/ao/quantization/fuser_method_mappings.py`, while adding an import statement
 here.
 """

 from torch.ao.quantization.fuser_method_mappings import (
     _DEFAULT_OP_LIST_TO_FUSER_METHOD,
     fuse_conv_bn,

The next fourteen hunks are the same one-line change repeated across the legacy quantization shim modules: each adds a single line to the module header (most likely the blank line ruff format places between the module docstring and the imports), while the docstring and imports themselves are unchanged context. The affected shims are the ones that re-export, respectively, from:
- torch.ao.quantization.fx._equalize (_convert_equalization_ref, _InputEqualizationObserver, ...)
- torch.ao.quantization.fx.convert (convert)
- torch.ao.quantization.fx.fuse (fuse)
- torch.ao.quantization.fx.fuse_handler (DefaultFuseHandler, FuseHandler)
- torch.ao.quantization.fx.graph_module (_is_observed_module, _is_observed_standalone_module, ...)
- torch.ao.quantization.fx.match_utils (_find_matches, _is_match, ...)
- torch.ao.quantization.fx.pattern_utils (_register_fusion_pattern, _register_quant_pattern, ...)
- torch.ao.quantization.fx.prepare (prepare)
- torch.ao.quantization.fx.quantize_handler (BatchNormQuantizeHandler, BinaryOpQuantizeHandler, ...)
- torch.ao.quantization.utils (Pattern, QuantizerCls)
- torch.ao.quantization.fx.utils (all_node_args_have_no_tensors, assert_and_get_unique_device, ...)
- torch.ao.quantization.observer (_is_activation_post_process, _is_per_channel_script_obs_instance, ...)
- torch.ao.quantization.qconfig (_add_module_to_qconfig_obs_ctr, _assert_valid_qconfig, ...)
- torch.ao.quantization.quantization_mappings (_get_special_act_post_process, _has_special_act_post_process, ...)
@@ -128,9 +128,7 @@ Examples::

     >>> # Generates a periodic exponential window and decay factor equal to .5
     >>> torch.signal.windows.exponential(10, sym=False,tau=.5)
     tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04])
-""".format(
-    **window_common_args
-),
+""".format(**window_common_args),
 )
 def exponential(
     M: int,

The next seven hunks in torch.signal.windows apply exactly the same collapse of the trailing `""".format(**window_common_args),` call to the decorated docstrings of hamming, hann, blackman, bartlett, general_cosine, general_hamming, and nuttall; in each case the surrounding doctest example lines are unchanged context.
@@ -559,7 +559,11 @@ def as_sparse_gradcheck(gradcheck):
     For example:

     >>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)
-    >>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse_coo().requires_grad_(True)
+    >>> x = (
+    ...     torch.tensor([[0, 1], [2, 3]], dtype=torch.float64)
+    ...     .to_sparse_coo()
+    ...     .requires_grad_(True)
+    ... )
     >>> gradcheck(lambda x: x.to_sparse_csr(), x)
     True
     """

@@ -667,7 +671,7 @@ def as_sparse_gradcheck(gradcheck):
                     )
                 else:
                     raise NotImplementedError(
-                        f'conversion of {d["layout"]} strided representation to tensor'
+                        f"conversion of {d['layout']} strided representation to tensor"
                     )
                 new_args.append(a)
             return tuple(new_args)
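The quote flip inside the f-string above is one of the few places where the formatter has to change nested quotes rather than just the outer ones. A small self-contained illustration; the dict below is a stand-in, not the `d` from torch.sparse:

# Hedged illustration of the quote normalization shown in the hunk above.
d = {"layout": "torch.sparse_csr"}  # stand-in value, purely illustrative

# before: single-quoted f-string, double quotes around the key
old = f'conversion of {d["layout"]} strided representation to tensor'
# after: ruff format prefers double quotes outside, so the key's quotes flip to single
new = f"conversion of {d['layout']} strided representation to tensor"

assert old == new  # the rewrite is purely cosmetic; the runtime string is identical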
@@ -296,11 +296,11 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None):
             for b in range(nbatches):
                 for i, r in enumerate(r_offsets):
                     r0, r1 = divmod(r, N)
-                    acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns]
-                    for g in range(c_indices[i], c_indices[i+1]):
+                    acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
+                    for g in range(c_indices[i], c_indices[i + 1]):
                         p = p_offsets[g]
                         q0, q1 = divmod(q_offsets[g], N)
-                        acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns]
+                        acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]

         where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
         integer multiples of ``Ms`` and ``Ks``, respectively.

@@ -320,11 +320,11 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None):
                 n = (r % N) // Ns
                 r0, r1 = divmod(r, N)
                 c0, c1 = c_indices[m], c_indices[m + 1]
-                acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns]
+                acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
                 for i, p in enumerate(range(c0, c1)):
                     q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i]
                     q0, q1 = divmod(q, N)
-                    acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns]
+                    acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]

         where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
         integer multiples of ``Ms`` and ``Ks``, respectively.

@@ -97,6 +97,7 @@ tune_bsr_dense_addmm to learn how to register a custom set of optimal
 kernel parameters for addmm-based operations.

 """

 __all__ = ["get_meta", "tune_bsr_dense_addmm", "tune__int_bsr_dense_addmm"]

 import inspect

(The last hunk adds a single line in the module header.)
@@ -432,9 +433,9 @@ def minimize(


 def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):
-    assert (
-        sparsity <= 1.0 and sparsity >= 0.0
-    ), "sparsity should be a value between 0 and 1"
+    assert sparsity <= 1.0 and sparsity >= 0.0, (
+        "sparsity should be a value between 0 and 1"
+    )
     assert M % blocksize[0] == 0
     assert N % blocksize[1] == 0
     shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :]
@@ -465,14 +465,26 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
     The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
     ```
     from torch.sparse import SparseSemiStructuredTensorCUTLASS
-    from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
+    from torch.sparse._semi_structured_conversions import (
+        _sparse_semi_structured_tile,
+        _compute_compressed_swizzled_bitmask,
+    )

     pruned = _sparse_semi_structured_tile(dense)
     packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
-    packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
+    packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(
+        pruned.t().contiguous()
+    )
     bitmask = _compute_compressed_swizzled_bitmask(pruned)

-    SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask)
+    SparseSemiStructuredTensorCUTLASS(
+        dense.shape,
+        packed_cutlass,
+        meta_cutlass,
+        packed_t_cutlass,
+        meta_t_cutlass,
+        bitmask,
+    )
     ```
     """
     # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.

@@ -583,14 +595,19 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
     The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
     ```
     from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
-    from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
+    from torch.sparse._semi_structured_conversions import (
+        _sparse_semi_structured_tile,
+        _compute_compressed_swizzled_bitmask,
+    )

     pruned = _sparse_semi_structured_tile(dense)
     packed_cusparselt = torch._cslt_compress(pruned)
     packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
     bitmask = _compute_compressed_swizzled_bitmask(pruned)

-    SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask)
+    SparseSemiStructuredTensorCUSPARSELT(
+        dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask
+    )
     ```
     """
     (
@@ -134,9 +134,7 @@ Example::
     >>> torch.special.digamma(a)
     tensor([-0.5772, -1.9635])

-""".format(
-    **common_args
-),
+""".format(**common_args),
 )

 gammaln = _add_docstr(

Every remaining hunk in the torch.special docstrings is this same collapse of a trailing multi-line `""".format(` / `**common_args` / `),` into the single line `""".format(**common_args),` (the logsumexp docstring formats with `**multi_dim_common` instead). The surrounding doctest examples and the `Keyword args: {out}` blocks are unchanged context. Each hunk sits at the end of one docstring and is immediately followed by the next `_add_docstr` entry; in order, those following entries are: gammaln, polygamma, erf, erfc, erfcx, erfinv, logit, logsumexp, expit, exp2, expm1, xlog1py, xlogy, i0, i0e, i1, i1e, ndtr, ndtri, log_ndtr, log1p, round, multigammaln, gammainc, gammaincc, airy_ai, bessel_j0, bessel_j1, bessel_y0, bessel_y1, chebyshev_polynomial_t, chebyshev_polynomial_u, chebyshev_polynomial_v, chebyshev_polynomial_w, hermite_polynomial_h, hermite_polynomial_he, laguerre_polynomial_l, legendre_polynomial_p, modified_bessel_i0, modified_bessel_i1, modified_bessel_k0, modified_bessel_k1, scaled_modified_bessel_k0, scaled_modified_bessel_k1, shifted_chebyshev_polynomial_t, shifted_chebyshev_polynomial_u, shifted_chebyshev_polynomial_v, shifted_chebyshev_polynomial_w, and spherical_bessel_j0.
@@ -1538,7 +1538,9 @@ def assert_close(
         >>> expected = torch.tensor([1.0, 2.0, 3.0])
         >>> actual = torch.tensor([1.0, 4.0, 5.0])
         >>> # The default error message can be overwritten.
-        >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
+        >>> torch.testing.assert_close(
+        ...     actual, expected, msg="Argh, the tensors are not close!"
+        ... )
         Traceback (most recent call last):
         ...
         AssertionError: Argh, the tensors are not close!
@@ -115,11 +115,11 @@ def make_tensor(
         >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
         >>> from torch.testing import make_tensor
         >>> # Creates a float tensor with values in [-1, 1)
-        >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
+        >>> make_tensor((3,), device="cpu", dtype=torch.float32, low=-1, high=1)
         >>> # xdoctest: +SKIP
         tensor([ 0.1205, 0.2282, -0.6380])
         >>> # Creates a bool tensor on CUDA
-        >>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
+        >>> make_tensor((2, 2), device="cuda", dtype=torch.bool)
         tensor([[False, False],
                 [False, True]], device='cuda:0')
     """
@@ -721,9 +721,9 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_fo
     intersect = set(except_for if except_for else []) & set(
         only_for if only_for else []
     )
-    assert (
-        not intersect
-    ), f"device ({intersect}) appeared in both except_for and only_for"
+    assert not intersect, (
+        f"device ({intersect}) appeared in both except_for and only_for"
+    )

     # Replace your privateuse1 backend name with 'privateuse1'
     if is_privateuse1_backend_available():

@@ -1407,9 +1407,9 @@ class deviceCountAtLeast:
         self.num_required_devices = num_required_devices

     def __call__(self, fn):
-        assert not hasattr(
-            fn, "num_required_devices"
-        ), f"deviceCountAtLeast redefinition for {fn.__name__}"
+        assert not hasattr(fn, "num_required_devices"), (
+            f"deviceCountAtLeast redefinition for {fn.__name__}"
+        )
         fn.num_required_devices = self.num_required_devices

         @wraps(fn)

@@ -1474,13 +1474,13 @@ def onlyNativeDeviceTypesAnd(devices=None):
 # self.precision *2, max(1, self.precision)).
 class precisionOverride:
     def __init__(self, d):
-        assert isinstance(
-            d, dict
-        ), "precisionOverride not given a dtype : precision dict!"
+        assert isinstance(d, dict), (
+            "precisionOverride not given a dtype : precision dict!"
+        )
         for dtype in d.keys():
-            assert isinstance(
-                dtype, torch.dtype
-            ), f"precisionOverride given unknown dtype {dtype}"
+            assert isinstance(dtype, torch.dtype), (
+                f"precisionOverride given unknown dtype {dtype}"
+            )

         self.d = d

@@ -1513,12 +1513,12 @@ class toleranceOverride:
     def __init__(self, d):
         assert isinstance(d, dict), "toleranceOverride not given a dtype : tol dict!"
         for dtype, prec in d.items():
-            assert isinstance(
-                dtype, torch.dtype
-            ), f"toleranceOverride given unknown dtype {dtype}"
+            assert isinstance(dtype, torch.dtype), (
+                f"toleranceOverride given unknown dtype {dtype}"
+            )
-            assert isinstance(
-                prec, tol
-            ), "toleranceOverride not given a dtype : tol dict!"
+            assert isinstance(prec, tol), (
+                "toleranceOverride not given a dtype : tol dict!"
+            )

         self.d = d

@@ -1546,13 +1546,13 @@ class dtypes:
                 "all dtype variants must be. "
                 f"Received non-list non-tuple dtype {str(arg)}"
             )
-            assert all(
-                isinstance(dtype, torch.dtype) for dtype in arg
-            ), f"Unknown dtype in {str(arg)}"
+            assert all(isinstance(dtype, torch.dtype) for dtype in arg), (
+                f"Unknown dtype in {str(arg)}"
+            )
         else:
-            assert all(
-                isinstance(arg, torch.dtype) for arg in args
-            ), f"Unknown dtype in {str(args)}"
+            assert all(isinstance(arg, torch.dtype) for arg in args), (
+                f"Unknown dtype in {str(args)}"
+            )

         self.args = args
         self.device_type = device_type
@ -253,9 +253,9 @@ def verify_ddp_error_logged(model_DDP, err_substr):
|
||||||
if err_substr.find("\nException raised from ") == -1
|
if err_substr.find("\nException raised from ") == -1
|
||||||
else err_substr.split("\nException raised from ")[0]
|
else err_substr.split("\nException raised from ")[0]
|
||||||
)
|
)
|
||||||
assert (
|
assert actual in logging_err, (
|
||||||
actual in logging_err
|
f"Did not find expected {actual} in ddp logging data error: {logging_err}"
|
||||||
), f"Did not find expected {actual} in ddp logging data error: {logging_err}"
|
)
|
||||||
|
|
||||||
|
|
||||||
def with_nccl_blocking_wait(func):
|
def with_nccl_blocking_wait(func):
|
||||||
|
|
@ -294,9 +294,9 @@ def with_nccl_blocking_wait(func):
|
||||||
finally:
|
finally:
|
||||||
# restore old values.
|
# restore old values.
|
||||||
if cached_nccl_async_error_handling is not None:
|
if cached_nccl_async_error_handling is not None:
|
||||||
os.environ[
|
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
|
||||||
"TORCH_NCCL_ASYNC_ERROR_HANDLING"
|
cached_nccl_async_error_handling
|
||||||
] = cached_nccl_async_error_handling
|
)
|
||||||
|
|
||||||
if cached_nccl_blocking_wait is not None:
|
if cached_nccl_blocking_wait is not None:
|
||||||
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait
|
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait
|
||||||
|
|
@@ -812,7 +812,7 @@ class MultiProcessTestCase(TestCase):
                 sys.exit(TEST_SKIPS["generic"].exit_code)
             except Exception:
                 logger.error(
-                    "Caught exception: \n%s exiting " "process %s with exit code: %s",
+                    "Caught exception: \n%s exiting process %s with exit code: %s",
                     traceback.format_exc(),
                     self.rank,
                     MultiProcessTestCase.TEST_ERROR_EXIT_CODE,
@@ -1689,9 +1689,7 @@ class MultiProcContinousTest(TestCase):
             cls.processes.append(process)
             cls.task_queues.append(task_queue)
             cls.completion_queues.append(completion_queue)
-            logger.info(
-                "Started process %s with pid %s", rank, process.pid
-            )  # noqa: UP031
+            logger.info("Started process %s with pid %s", rank, process.pid)  # noqa: UP031
 
     @classmethod
     def setUpClass(cls):
@@ -1285,10 +1285,10 @@ class FSDPTest(MultiProcessTestCase):
                 loss = sharded_grad_scaler.scale(loss)
 
             if not mixed_precision and not use_pure_fp16:
-                assert (
-                    loss.dtype == torch.float32
-                ), "loss data type should be float32, as the original \
+                assert loss.dtype == torch.float32, (
+                    "loss data type should be float32, as the original \
                     parameter data type is float32."
+                )
             else:
                 if use_pure_fp16:
                     self.assertEqual(loss.dtype, torch.float16)
@@ -1354,9 +1354,9 @@ class FSDPTest(MultiProcessTestCase):
            wrapper should provide data parallel semantics. If ``None``,
            then the callable defaults to the DDP constructor.
        """
-        assert (
-            fsdp_init_mode != FSDPInitMode.NO_FSDP
-        ), "Expects an FSDP init mode that wraps with FSDP"
+        assert fsdp_init_mode != FSDPInitMode.NO_FSDP, (
+            "Expects an FSDP init mode that wraps with FSDP"
+        )
         if init_kwargs is None:
             init_kwargs = {}
         lr = 1e-2
 
@@ -1268,9 +1268,9 @@ def _get_optim_inputs_including_global_cliquey_kwargs(
     trivial. That said, we sometimes want to test for all possible configs on an
     optimizer including all supported flags, so this helper returns all optim inputs.
     """
-    assert all(
-        x in ["foreach", "fused", "differentiable"] for x in skip
-    ), "skip must be a subset of ['foreach', 'fused', 'differentiable']"
+    assert all(x in ["foreach", "fused", "differentiable"] for x in skip), (
+        "skip must be a subset of ['foreach', 'fused', 'differentiable']"
+    )
 
     optim_inputs = optim_info.optim_inputs_func(device)
 
@@ -477,7 +477,9 @@ def with_comms(
    def decorator(func, eager_init: bool = False, backend: Optional[str] = None):
        @wraps(func)  # pyre-ignore[6]
        def wrapper(
-            self, *args: tuple[object], **kwargs: dict[str, Any]  # type: ignore[misc]
+            self,
+            *args: tuple[object],
+            **kwargs: dict[str, Any],  # type: ignore[misc]
        ) -> None:
            self.init_pg(eager_init, backend)
 
@@ -253,7 +253,11 @@ class Trainer:
        else:
            input_batches = batches
 
-        with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext():
+        with (
+            self.hybrid_module.join()
+            if simulate_uneven_inputs
+            else contextlib.nullcontext()
+        ):
            for b in input_batches:
                with dist_autograd.context() as context_id:
                    output = self.hybrid_module.forward(b)
@@ -261,8 +265,7 @@ class Trainer:
                    dist_autograd.backward(context_id, [loss])
                    grads_dict = dist_autograd.get_gradients(context_id)
                    gLogger.info(
-                        "Loss is %s for mini batch: %s. "
-                        "Grads dict has %s entries: %s",
+                        "Loss is %s for mini batch: %s. Grads dict has %s entries: %s",
                        loss,
                        mini_batch,
                        len(grads_dict),
@@ -162,9 +162,7 @@ class SampleInput:
        # Allow calling either as SampleInput(input, args=args, kwargs=kwargs), or as
        # SampleInput(input, *args, **kwargs) but not to mix the two forms
        if args is not None or kwargs is not None:
-            assert (
-                not var_args and not var_kwargs
-            ), """
+            assert not var_args and not var_kwargs, """
A SampleInput can be constructed "naturally" with *args and **kwargs or by
explicitly setting the "args" and "kwargs" parameters, but the two
methods of construction cannot be mixed!"""
@@ -226,7 +224,7 @@ cannot specify additional metadata in keyword arguments"""
            f"name={repr(self.name)}",
        ]
 
-        return f'SampleInput({", ".join(a for a in arguments if a is not None)})'
+        return f"SampleInput({', '.join(a for a in arguments if a is not None)})"
 
    def __repr__(self):
        return self._repr_helper(lambda x: x)
@@ -1601,13 +1599,11 @@ class SampleRule(ABC):
 
    # returns a string identifier of the rule type
    @abstractmethod
-    def type(self) -> str:
-        ...
+    def type(self) -> str: ...
 
    # returns an appropriate context that handles the xfail, skips, etc.
    @abstractmethod
-    def get_context(self, test_case):
-        ...
+    def get_context(self, test_case): ...
 
 
    # useful for specifying xfails
@@ -1791,8 +1787,10 @@ class ReductionOpInfo(OpInfo):
        # kwargs to use when calling the op. This is required for operators that
        # have other required parameters besides the input tensor.
        generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: (
-            yield (),
-            {},
+            yield (
+                (),
+                {},
+            )
        ),
        # Options from the OpInfo base class
        **kwargs,
@@ -2476,9 +2474,9 @@ class BinaryUfuncInfo(OpInfo):
            self.supports_one_python_scalar = True
 
        if self.supports_one_python_scalar:
-            assert (
-                supports_rhs_python_scalar
-            ), "Can't support lhs and rhs Python scalars but not rhs scalars!"
+            assert supports_rhs_python_scalar, (
+                "Can't support lhs and rhs Python scalars but not rhs scalars!"
+            )
 
 
# The following functions and classes are for testing elementwise unary operators.
@@ -102,8 +102,9 @@ def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwargs):
        for mask in _generate_masked_op_mask(
            sample_input.input.shape, device, **kwargs
        ):
-            sample_input_args, sample_input_kwargs = sample_input.args, dict(
-                mask=mask, **sample_input.kwargs
-            )
+            sample_input_args, sample_input_kwargs = (
+                sample_input.args,
+                dict(mask=mask, **sample_input.kwargs),
+            )
            yield SampleInput(
                sample_input.input.detach().requires_grad_(requires_grad),
@@ -224,8 +225,9 @@ def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
            op_info, device, dtype, requires_grad, **kwargs
        ):
            sample_input_args, sample_input_kwargs = (
-                ord,
-            ) + sample_input.args, sample_input.kwargs.copy()
+                (ord,) + sample_input.args,
+                sample_input.kwargs.copy(),
+            )
            yield SampleInput(
                sample_input.input.clone().requires_grad_(requires_grad),
                args=sample_input_args,
@@ -276,8 +278,9 @@ def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs):
        for mask in _generate_masked_op_mask(
            sample_input.input.shape, device, **kwargs
        ):
-            sample_input_args, sample_input_kwargs = sample_input.args, dict(
-                mask=mask, **sample_input.kwargs
-            )
+            sample_input_args, sample_input_kwargs = (
+                sample_input.args,
+                dict(mask=mask, **sample_input.kwargs),
+            )
            yield SampleInput(
                sample_input.input.detach().requires_grad_(requires_grad),
@@ -364,8 +367,9 @@ def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs):
        ):
            if type(mask) != torch.Tensor:
                continue
-            sample_input_args, sample_input_kwargs = sample_input.args, dict(
-                mask=mask, **sample_input.kwargs
-            )
+            sample_input_args, sample_input_kwargs = (
+                sample_input.args,
+                dict(mask=mask, **sample_input.kwargs),
+            )
            if "keepdim" in sample_input_kwargs:
                sample_input_kwargs.pop("keepdim")
 
@@ -112,7 +112,7 @@ class _Config(Generic[T]):
 
    @staticmethod
    def string_or_list_of_string_to_list(
-        val: Optional[Union[str, list[str]]]
+        val: Optional[Union[str, list[str]]],
    ) -> Optional[list[str]]:
        if val is None:
            return None
@@ -135,8 +135,7 @@ if TYPE_CHECKING:
        env_name_force: Optional[Union[str, list[str]]] = None,
        value_type: Optional[type] = None,
        alias: Optional[str] = None,
-    ) -> T:
-        ...
+    ) -> T: ...
 
else:
 
@@ -323,9 +322,9 @@ class _ConfigEntry:
 
        # Ensure justknobs and envvars are allowlisted types
        if self.justknob is not None and self.default is not None:
-            assert isinstance(
-                self.default, bool
-            ), f"justknobs only support booleans, {self.default} is not a boolean"
+            assert isinstance(self.default, bool), (
+                f"justknobs only support booleans, {self.default} is not a boolean"
+            )
        if self.value_type is not None and (
            config.env_name_default is not None or config.env_name_force is not None
        ):
@@ -334,7 +333,9 @@ class _ConfigEntry:
                str,
                Optional[bool],
                Optional[str],
-            ), f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
+            ), (
+                f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
+            )
 
 
class ConfigModule(ModuleType):
@@ -282,9 +282,9 @@ def tree_is_leaf(
    False
    >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple))
    True
-    >>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3})
+    >>> tree_is_leaf({"a": 1, "b": 2, "c": 3})
    False
-    >>> tree_is_leaf({'a': 1, 'b': 2, 'c': None})
+    >>> tree_is_leaf({"a": 1, "b": 2, "c": None})
    False
 
    Args:
@@ -586,29 +586,28 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]]
# These specializations help with type inference on the lambda passed to this
# function
@overload
-def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]:
-    ...
+def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ...


@overload
-def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]:
-    ...
+def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ...


@overload
-def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]:
-    ...
+def map_only(
+    type_or_types_or_pred: Type3[T, S, U], /
+) -> MapOnlyFn[Fn3[T, S, U, Any]]: ...


# This specialization is needed for the implementations below that call
@overload
-def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]:
-    ...
+def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ...


@overload
-def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]:
-    ...
+def map_only(
+    type_or_types_or_pred: Callable[[Any], bool], /
+) -> MapOnlyFn[FnAny[Any]]: ...


def map_only(
@@ -664,8 +663,7 @@ def tree_map_only(
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...


@overload
@@ -675,8 +673,7 @@ def tree_map_only(
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...


@overload
@@ -686,8 +683,7 @@ def tree_map_only(
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...


@overload
@@ -697,8 +693,7 @@ def tree_map_only(
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...


@overload
@@ -708,8 +703,7 @@ def tree_map_only(
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...


def tree_map_only(
@@ -729,8 +723,7 @@ def tree_map_only_(
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...


@overload
@@ -740,8 +733,7 @@ def tree_map_only_(
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...


@overload
@@ -751,8 +743,7 @@ def tree_map_only_(
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...


@overload
@@ -762,8 +753,7 @@ def tree_map_only_(
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...


@overload
@@ -773,8 +763,7 @@ def tree_map_only_(
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> PyTree:
-    ...
+) -> PyTree: ...


def tree_map_only_(
@@ -812,8 +801,7 @@ def tree_all_only(
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> bool:
-    ...
+) -> bool: ...


@overload
@@ -823,8 +811,7 @@ def tree_all_only(
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> bool:
-    ...
+) -> bool: ...


@overload
@@ -834,8 +821,7 @@ def tree_all_only(
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> bool:
-    ...
+) -> bool: ...


def tree_all_only(
@@ -856,8 +842,7 @@ def tree_any_only(
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> bool:
-    ...
+) -> bool: ...


@overload
@@ -867,8 +852,7 @@ def tree_any_only(
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> bool:
-    ...
+) -> bool: ...


@overload
@@ -878,8 +862,7 @@ def tree_any_only(
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
-) -> bool:
-    ...
+) -> bool: ...


def tree_any_only(
@@ -12,7 +12,7 @@ _cache_sentinel = object()


def cache_method(
-    f: Callable[Concatenate[_C, _P], _T]
+    f: Callable[Concatenate[_C, _P], _T],
) -> Callable[Concatenate[_C, _P], _T]:
    """
    Like `@functools.cache` but for methods.
@@ -302,14 +302,12 @@ class BaseTorchDispatchMode(TorchDispatchMode):

# Subtypes which have __tensor_flatten__ and __tensor_unflatten__.
class TensorWithFlatten(Protocol):
-    def __tensor_flatten__(self) -> tuple[Sequence[str], object]:
-        ...
+    def __tensor_flatten__(self) -> tuple[Sequence[str], object]: ...

    @staticmethod
    def __tensor_unflatten__(
        inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int
-    ) -> torch.Tensor:
-        ...
+    ) -> torch.Tensor: ...

    # It would be really nice to be able to say that the return of
    # is_traceable_wrapper_subclass() is Intersection[torch.Tensor,
@@ -318,26 +316,20 @@ class TensorWithFlatten(Protocol):
    shape: torch._C.Size

    @overload
-    def stride(self, dim: None = None) -> tuple[int, ...]:
-        ...
+    def stride(self, dim: None = None) -> tuple[int, ...]: ...

    @overload
-    def stride(self, dim: int) -> int:
-        ...
+    def stride(self, dim: int) -> int: ...

    @overload
-    def size(self, dim: None = None) -> tuple[int, ...]:
-        ...
+    def size(self, dim: None = None) -> tuple[int, ...]: ...

    @overload
-    def size(self, dim: int) -> int:
-        ...
+    def size(self, dim: int) -> int: ...

-    def storage_offset(self) -> int:
-        ...
+    def storage_offset(self) -> int: ...

-    def dim(self) -> int:
-        ...
+    def dim(self) -> int: ...

    @overload
    def to(
@@ -347,8 +339,7 @@ class TensorWithFlatten(Protocol):
        copy: bool = False,
        *,
        memory_format: Optional[torch.memory_format] = None,
-    ) -> torch.Tensor:
-        ...
+    ) -> torch.Tensor: ...

    @overload
    def to(
@@ -359,8 +350,7 @@ class TensorWithFlatten(Protocol):
        copy: bool = False,
        *,
        memory_format: Optional[torch.memory_format] = None,
-    ) -> torch.Tensor:
-        ...
+    ) -> torch.Tensor: ...

    @overload
    def to(
@@ -370,8 +360,7 @@ class TensorWithFlatten(Protocol):
        copy: bool = False,
        *,
        memory_format: Optional[torch.memory_format] = None,
-    ) -> torch.Tensor:
-        ...
+    ) -> torch.Tensor: ...


def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
@@ -99,17 +99,13 @@ NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND"


class KeyEntry(Protocol):
-    def __hash__(self) -> int:
-        ...
+    def __hash__(self) -> int: ...

-    def __eq__(self, other: object) -> bool:
-        ...
+    def __eq__(self, other: object) -> bool: ...

-    def __str__(self) -> str:
-        ...
+    def __str__(self) -> str: ...

-    def get(self, parent: Any) -> Any:
-        ...
+    def get(self, parent: Any) -> Any: ...


class EnumEncoder(json.JSONEncoder):
@@ -757,7 +753,7 @@ def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]:


def _tuple_flatten_with_keys(
-    d: tuple[T, ...]
+    d: tuple[T, ...],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
    values, context = _tuple_flatten(d)
    return [(SequenceKey(i), v) for i, v in enumerate(values)], context
@@ -785,7 +781,7 @@ def _dict_flatten(d: dict[Any, T]) -> tuple[list[T], Context]:


def _dict_flatten_with_keys(
-    d: dict[Any, T]
+    d: dict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
    values, context = _dict_flatten(d)
    return [(MappingKey(k), v) for k, v in zip(context, values)], context
@@ -849,7 +845,7 @@ def _ordereddict_flatten(d: OrderedDict[Any, T]) -> tuple[list[T], Context]:


def _ordereddict_flatten_with_keys(
-    d: OrderedDict[Any, T]
+    d: OrderedDict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
    values, context = _ordereddict_flatten(d)
    return [(MappingKey(k), v) for k, v in zip(context, values)], context
@@ -872,7 +868,7 @@ def _defaultdict_flatten(d: defaultdict[Any, T]) -> tuple[list[T], Context]:


def _defaultdict_flatten_with_keys(
-    d: defaultdict[Any, T]
+    d: defaultdict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
    values, context = _defaultdict_flatten(d)
    _, dict_context = context
@ -1035,9 +1031,9 @@ def tree_is_leaf(
|
||||||
False
|
False
|
||||||
>>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple))
|
>>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple))
|
||||||
True
|
True
|
||||||
>>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3})
|
>>> tree_is_leaf({"a": 1, "b": 2, "c": 3})
|
||||||
False
|
False
|
||||||
>>> tree_is_leaf({'a': 1, 'b': 2, 'c': None})
|
>>> tree_is_leaf({"a": 1, "b": 2, "c": None})
|
||||||
False
|
False
|
||||||
"""
|
"""
|
||||||
if is_leaf is not None and is_leaf(tree):
|
if is_leaf is not None and is_leaf(tree):
|
||||||
|
|
@ -1346,9 +1342,9 @@ def tree_map(
|
||||||
|
|
||||||
See also :func:`tree_map_`.
|
See also :func:`tree_map_`.
|
||||||
|
|
||||||
>>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
|
>>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)})
|
||||||
{'x': 8, 'y': (43, 65)}
|
{'x': 8, 'y': (43, 65)}
|
||||||
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
|
>>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None})
|
||||||
{'x': False, 'y': (False, False), 'z': True}
|
{'x': False, 'y': (False, False), 'z': True}
|
||||||
|
|
||||||
If multiple inputs are given, the structure of the tree is taken from the first input;
|
If multiple inputs are given, the structure of the tree is taken from the first input;
|
||||||
|
|
@ -1432,29 +1428,28 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]]
|
||||||
# These specializations help with type inference on the lambda passed to this
|
# These specializations help with type inference on the lambda passed to this
|
||||||
# function
|
# function
|
||||||
@overload
|
@overload
|
||||||
def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]:
|
def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]:
|
def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]:
|
def map_only(
|
||||||
...
|
type_or_types_or_pred: Type3[T, S, U], /
|
||||||
|
) -> MapOnlyFn[Fn3[T, S, U, Any]]: ...
|
||||||
|
|
||||||
|
|
||||||
# This specialization is needed for the implementations below that call
|
# This specialization is needed for the implementations below that call
|
||||||
@overload
|
@overload
|
||||||
def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]:
|
def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]:
|
def map_only(
|
||||||
...
|
type_or_types_or_pred: Callable[[Any], bool], /
|
||||||
|
) -> MapOnlyFn[FnAny[Any]]: ...
|
||||||
|
|
||||||
|
|
||||||
def map_only(
|
def map_only(
|
||||||
|
|
@ -1510,8 +1505,7 @@ def tree_map_only(
|
||||||
func: Fn[T, Any],
|
func: Fn[T, Any],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> PyTree:
|
) -> PyTree: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -1521,8 +1515,7 @@ def tree_map_only(
|
||||||
func: Fn2[T, S, Any],
|
func: Fn2[T, S, Any],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> PyTree:
|
) -> PyTree: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -1532,8 +1525,7 @@ def tree_map_only(
|
||||||
func: Fn3[T, S, U, Any],
|
func: Fn3[T, S, U, Any],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> PyTree:
|
) -> PyTree: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -1543,8 +1535,7 @@ def tree_map_only(
|
||||||
func: FnAny[Any],
|
func: FnAny[Any],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> PyTree:
|
) -> PyTree: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -1554,8 +1545,7 @@ def tree_map_only(
|
||||||
func: FnAny[Any],
|
func: FnAny[Any],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> PyTree:
|
) -> PyTree: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def tree_map_only(
|
def tree_map_only(
|
||||||
|
|
@ -1575,8 +1565,7 @@ def tree_map_only_(
|
||||||
func: Fn[T, Any],
|
func: Fn[T, Any],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> PyTree:
|
) -> PyTree: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -1586,8 +1575,7 @@ def tree_map_only_(
|
||||||
func: Fn2[T, S, Any],
|
func: Fn2[T, S, Any],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> PyTree:
|
) -> PyTree: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -1597,8 +1585,7 @@ def tree_map_only_(
|
||||||
func: Fn3[T, S, U, Any],
|
func: Fn3[T, S, U, Any],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> PyTree:
|
) -> PyTree: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -1608,8 +1595,7 @@ def tree_map_only_(
|
||||||
func: FnAny[Any],
|
func: FnAny[Any],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> PyTree:
|
) -> PyTree: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -1619,8 +1605,7 @@ def tree_map_only_(
|
||||||
func: FnAny[Any],
|
func: FnAny[Any],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> PyTree:
|
) -> PyTree: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def tree_map_only_(
|
def tree_map_only_(
|
||||||
|
|
@ -1658,8 +1643,7 @@ def tree_all_only(
|
||||||
pred: Fn[T, bool],
|
pred: Fn[T, bool],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> bool:
|
) -> bool: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -1669,8 +1653,7 @@ def tree_all_only(
|
||||||
pred: Fn2[T, S, bool],
|
pred: Fn2[T, S, bool],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> bool:
|
) -> bool: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -1680,8 +1663,7 @@ def tree_all_only(
|
||||||
pred: Fn3[T, S, U, bool],
|
pred: Fn3[T, S, U, bool],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> bool:
|
) -> bool: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def tree_all_only(
|
def tree_all_only(
|
||||||
|
|
@ -1702,8 +1684,7 @@ def tree_any_only(
|
||||||
pred: Fn[T, bool],
|
pred: Fn[T, bool],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> bool:
|
) -> bool: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -1713,8 +1694,7 @@ def tree_any_only(
|
||||||
pred: Fn2[T, S, bool],
|
pred: Fn2[T, S, bool],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> bool:
|
) -> bool: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -1724,8 +1704,7 @@ def tree_any_only(
|
||||||
pred: Fn3[T, S, U, bool],
|
pred: Fn3[T, S, U, bool],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> bool:
|
) -> bool: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def tree_any_only(
|
def tree_any_only(
|
||||||
|
|
@@ -1862,7 +1841,7 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:

    if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
        raise NotImplementedError(
-            f'Deserializing {json_schema["type"]} in pytree is not registered.',
+            f"Deserializing {json_schema['type']} in pytree is not registered.",
        )

    typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]]
@@ -301,7 +301,7 @@ def strobelight(
    profiler = StrobelightCLIFunctionProfiler(**kwargs)

    def strobelight_inner(
-        work_function: Callable[_P, _R]
+        work_function: Callable[_P, _R],
    ) -> Callable[_P, Optional[_R]]:
        @functools.wraps(work_function)
        def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
@@ -98,7 +98,7 @@ def _is_symbols_binary_summation(expr: sympy.Expr) -> bool:


def _keep_float(
-    f: Callable[[Unpack[_Ts]], _T]
+    f: Callable[[Unpack[_Ts]], _T],
) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
    @functools.wraps(f)
    def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
@@ -926,10 +926,12 @@ class MinMaxBase(Expr, LatticeOp):  # type: ignore[misc]

    _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args)  # noqa: E731
    _eval_is_antihermitian = lambda s: _torf(  # noqa: E731
-        i.is_antihermitian for i in s.args  # noqa: E731
+        i.is_antihermitian
+        for i in s.args  # noqa: E731
    )  # noqa: E731
    _eval_is_commutative = lambda s: _torf(  # noqa: E731
-        i.is_commutative for i in s.args  # noqa: E731
+        i.is_commutative
+        for i in s.args  # noqa: E731
    )  # noqa: E731
    _eval_is_complex = lambda s: _torf(i.is_complex for i in s.args)  # noqa: E731
    _eval_is_composite = lambda s: _torf(i.is_composite for i in s.args)  # noqa: E731
@ -943,10 +945,12 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
|
||||||
_eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731
|
_eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731
|
||||||
_eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731
|
_eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731
|
||||||
_eval_is_nonnegative = lambda s: _torf( # noqa: E731
|
_eval_is_nonnegative = lambda s: _torf( # noqa: E731
|
||||||
i.is_nonnegative for i in s.args # noqa: E731
|
i.is_nonnegative
|
||||||
|
for i in s.args # noqa: E731
|
||||||
) # noqa: E731
|
) # noqa: E731
|
||||||
_eval_is_nonpositive = lambda s: _torf( # noqa: E731
|
_eval_is_nonpositive = lambda s: _torf( # noqa: E731
|
||||||
i.is_nonpositive for i in s.args # noqa: E731
|
i.is_nonpositive
|
||||||
|
for i in s.args # noqa: E731
|
||||||
) # noqa: E731
|
) # noqa: E731
|
||||||
_eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731
|
_eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731
|
||||||
_eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731
|
_eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731
|
||||||
|
|
@ -956,10 +960,12 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
|
||||||
_eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731
|
_eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731
|
||||||
_eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731
|
_eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731
|
||||||
_eval_is_extended_real = lambda s: _torf( # noqa: E731
|
_eval_is_extended_real = lambda s: _torf( # noqa: E731
|
||||||
i.is_extended_real for i in s.args # noqa: E731
|
i.is_extended_real
|
||||||
|
for i in s.args # noqa: E731
|
||||||
) # noqa: E731
|
) # noqa: E731
|
||||||
_eval_is_transcendental = lambda s: _torf( # noqa: E731
|
_eval_is_transcendental = lambda s: _torf( # noqa: E731
|
||||||
i.is_transcendental for i in s.args # noqa: E731
|
i.is_transcendental
|
||||||
|
for i in s.args # noqa: E731
|
||||||
) # noqa: E731
|
) # noqa: E731
|
||||||
_eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731
|
_eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -144,16 +144,14 @@ class ValueRanges(Generic[_T]):
|
||||||
self: ValueRanges[sympy.Expr],
|
self: ValueRanges[sympy.Expr],
|
||||||
lower: ExprIn,
|
lower: ExprIn,
|
||||||
upper: ExprIn,
|
upper: ExprIn,
|
||||||
) -> None:
|
) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __init__( # type: ignore[misc]
|
def __init__( # type: ignore[misc]
|
||||||
self: ValueRanges[SympyBoolean],
|
self: ValueRanges[SympyBoolean],
|
||||||
lower: BoolIn,
|
lower: BoolIn,
|
||||||
upper: BoolIn,
|
upper: BoolIn,
|
||||||
) -> None:
|
) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def __init__(self, lower: AllIn, upper: AllIn) -> None:
|
def __init__(self, lower: AllIn, upper: AllIn) -> None:
|
||||||
lower = simple_sympify(lower)
|
lower = simple_sympify(lower)
|
||||||
|
|
@ -240,15 +238,13 @@ class ValueRanges(Generic[_T]):
|
||||||
def __and__(
|
def __and__(
|
||||||
self: ValueRanges[sympy.Expr],
|
self: ValueRanges[sympy.Expr],
|
||||||
other: ValueRanges[sympy.Expr],
|
other: ValueRanges[sympy.Expr],
|
||||||
) -> ValueRanges[sympy.Expr]:
|
) -> ValueRanges[sympy.Expr]: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __and__( # type: ignore[misc]
|
def __and__( # type: ignore[misc]
|
||||||
self: ValueRanges[SympyBoolean],
|
self: ValueRanges[SympyBoolean],
|
||||||
other: ValueRanges[SympyBoolean],
|
other: ValueRanges[SympyBoolean],
|
||||||
) -> ValueRanges[SympyBoolean]:
|
) -> ValueRanges[SympyBoolean]: ...
|
||||||
...
|
|
||||||
|
|
||||||
def __and__(self: AllVR, other: AllVR) -> AllVR:
|
def __and__(self: AllVR, other: AllVR) -> AllVR:
|
||||||
if other in (ValueRanges.unknown(), ValueRanges.unknown_int()):
|
if other in (ValueRanges.unknown(), ValueRanges.unknown_int()):
|
||||||
|
|
@ -272,15 +268,13 @@ class ValueRanges(Generic[_T]):
|
||||||
def __or__(
|
def __or__(
|
||||||
self: ValueRanges[sympy.Expr],
|
self: ValueRanges[sympy.Expr],
|
||||||
other: ValueRanges[sympy.Expr],
|
other: ValueRanges[sympy.Expr],
|
||||||
) -> ValueRanges[sympy.Expr]:
|
) -> ValueRanges[sympy.Expr]: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __or__( # type: ignore[misc]
|
def __or__( # type: ignore[misc]
|
||||||
self: ValueRanges[SympyBoolean],
|
self: ValueRanges[SympyBoolean],
|
||||||
other: ValueRanges[SympyBoolean],
|
other: ValueRanges[SympyBoolean],
|
||||||
) -> ValueRanges[SympyBoolean]:
|
) -> ValueRanges[SympyBoolean]: ...
|
||||||
...
|
|
||||||
|
|
||||||
def __or__(self: AllVR, other: AllVR) -> AllVR:
|
def __or__(self: AllVR, other: AllVR) -> AllVR:
|
||||||
if ValueRanges.unknown() in (self, other):
|
if ValueRanges.unknown() in (self, other):
|
||||||
|
|
@ -343,8 +337,7 @@ class ValueRanges(Generic[_T]):
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
|
def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -384,8 +377,7 @@ class ValueRanges(Generic[_T]):
|
||||||
x: Union[ExprIn, ExprVR],
|
x: Union[ExprIn, ExprVR],
|
||||||
y: Union[ExprIn, ExprVR],
|
y: Union[ExprIn, ExprVR],
|
||||||
fn: ExprFn2,
|
fn: ExprFn2,
|
||||||
) -> ExprVR:
|
) -> ExprVR: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -393,8 +385,7 @@ class ValueRanges(Generic[_T]):
|
||||||
x: Union[BoolIn, BoolVR],
|
x: Union[BoolIn, BoolVR],
|
||||||
y: Union[BoolIn, BoolVR],
|
y: Union[BoolIn, BoolVR],
|
||||||
fn: BoolFn2,
|
fn: BoolFn2,
|
||||||
) -> BoolVR:
|
) -> BoolVR: ...
|
||||||
...
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def coordinatewise_increasing_map(
|
def coordinatewise_increasing_map(
|
||||||
|
|
|
||||||
|
|
@@ -426,9 +426,9 @@ def _get_custom_mod_func(func_name: str):
    it is marked as private. It is a convenience function for backend implementers to
    more easily call the hooks into their backend extensions.
    """
-    assert isinstance(
-        func_name, str
-    ), f"func_name must be `str`, but got `{type(func_name)}`."
+    assert isinstance(func_name, str), (
+        f"func_name must be `str`, but got `{type(func_name)}`."
+    )
    backend_name = _get_privateuse1_backend_name()
    custom_device_mod = getattr(torch, backend_name, None)  # type: ignore[arg-type]
    function = getattr(custom_device_mod, func_name, None)  # type: ignore[arg-type]
@ -44,7 +44,7 @@ def default_convert(data):
|
||||||
>>> default_convert(np.array([0, 1]))
|
>>> default_convert(np.array([0, 1]))
|
||||||
tensor([0, 1])
|
tensor([0, 1])
|
||||||
>>> # Example with NamedTuple
|
>>> # Example with NamedTuple
|
||||||
>>> Point = namedtuple('Point', ['x', 'y'])
|
>>> Point = namedtuple("Point", ["x", "y"])
|
||||||
>>> default_convert(Point(0, 0))
|
>>> default_convert(Point(0, 0))
|
||||||
Point(x=0, y=0)
|
Point(x=0, y=0)
|
||||||
>>> default_convert(Point(np.array(0), np.array(0)))
|
>>> default_convert(Point(np.array(0), np.array(0)))
|
||||||
|
|
@ -366,13 +366,13 @@ def default_collate(batch):
|
||||||
>>> default_collate([0, 1, 2, 3])
|
>>> default_collate([0, 1, 2, 3])
|
||||||
tensor([0, 1, 2, 3])
|
tensor([0, 1, 2, 3])
|
||||||
>>> # Example with a batch of `str`s:
|
>>> # Example with a batch of `str`s:
|
||||||
>>> default_collate(['a', 'b', 'c'])
|
>>> default_collate(["a", "b", "c"])
|
||||||
['a', 'b', 'c']
|
['a', 'b', 'c']
|
||||||
>>> # Example with `Map` inside the batch:
|
>>> # Example with `Map` inside the batch:
|
||||||
>>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
|
>>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}])
|
||||||
{'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])}
|
{'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])}
|
||||||
>>> # Example with `NamedTuple` inside the batch:
|
>>> # Example with `NamedTuple` inside the batch:
|
||||||
>>> Point = namedtuple('Point', ['x', 'y'])
|
>>> Point = namedtuple("Point", ["x", "y"])
|
||||||
>>> default_collate([Point(0, 0), Point(1, 1)])
|
>>> default_collate([Point(0, 0), Point(1, 1)])
|
||||||
Point(x=tensor([0, 1]), y=tensor([0, 1]))
|
Point(x=tensor([0, 1]), y=tensor([0, 1]))
|
||||||
>>> # Example with `Tuple` inside the batch:
|
>>> # Example with `Tuple` inside the batch:
|
||||||
|
|
|
||||||
|
|
@ -69,7 +69,9 @@ def pin_memory(data, device=None):
|
||||||
)
|
)
|
||||||
return clone
|
return clone
|
||||||
else:
|
else:
|
||||||
return type(data)({k: pin_memory(sample, device) for k, sample in data.items()}) # type: ignore[call-arg]
|
return type(data)(
|
||||||
|
{k: pin_memory(sample, device) for k, sample in data.items()}
|
||||||
|
) # type: ignore[call-arg]
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# The mapping type may not support `copy()` / `update(mapping)`
|
# The mapping type may not support `copy()` / `update(mapping)`
|
||||||
# or `__init__(iterable)`.
|
# or `__init__(iterable)`.
|
||||||
|
|
|
||||||
|
|
@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
-r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
+r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.

These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
@ -5,6 +5,7 @@ To support these two classes, in `./_utils` we define many utility methods and
|
||||||
functions to be run in multiprocessing. E.g., the data loading worker loop is
|
functions to be run in multiprocessing. E.g., the data loading worker loop is
|
||||||
in `./_utils/worker.py`.
|
in `./_utils/worker.py`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
|
@@ -1208,7 +1209,10 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
            atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)

        # .pid can be None only before process is spawned (not the case, so ignore)
-        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
+        _utils.signal_handling._set_worker_pids(
+            id(self),
+            tuple(w.pid for w in self._workers),  # type: ignore[misc]
+        )
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True
        self._reset(loader, first_iter=True)
@ -109,8 +109,7 @@ class non_deterministic:
|
||||||
|
|
||||||
# Decorate with a functional argument
|
# Decorate with a functional argument
|
||||||
if not (
|
if not (
|
||||||
isinstance(args[0], type)
|
isinstance(args[0], type) and issubclass(args[0], IterDataPipe) # type: ignore[arg-type]
|
||||||
and issubclass(args[0], IterDataPipe) # type: ignore[arg-type]
|
|
||||||
):
|
):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Only `IterDataPipe` can be decorated, but {args[0].__name__} is found"
|
f"Only `IterDataPipe` can be decorated, but {args[0].__name__} is found"
|
||||||
|
|
|
||||||
|
|
@ -99,7 +99,9 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
|
||||||
>>> from torchdata.datapipes.iter import IterableWrapper, Mapper
|
>>> from torchdata.datapipes.iter import IterableWrapper, Mapper
|
||||||
>>> dp = IterableWrapper(range(10))
|
>>> dp = IterableWrapper(range(10))
|
||||||
>>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor
|
>>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor
|
||||||
>>> map_dp_2 = dp.map(lambda x: x + 1) # Using functional form (recommended)
|
>>> map_dp_2 = dp.map(
|
||||||
|
... lambda x: x + 1
|
||||||
|
... ) # Using functional form (recommended)
|
||||||
>>> list(map_dp_1)
|
>>> list(map_dp_1)
|
||||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||||
>>> list(map_dp_2)
|
>>> list(map_dp_2)
|
||||||
|
|
@ -114,7 +116,9 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
|
||||||
>>> list(it1)
|
>>> list(it1)
|
||||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||||
>>> it1 = iter(source_dp)
|
>>> it1 = iter(source_dp)
|
||||||
>>> it2 = iter(source_dp) # The creation of a new iterator invalidates `it1`
|
>>> it2 = iter(
|
||||||
|
... source_dp
|
||||||
|
... ) # The creation of a new iterator invalidates `it1`
|
||||||
>>> next(it2)
|
>>> next(it2)
|
||||||
0
|
0
|
||||||
>>> next(it1) # Further usage of `it1` will raise a `RunTimeError`
|
>>> next(it1) # Further usage of `it1` will raise a `RunTimeError`
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,8 @@ class MapperIterDataPipe(IterDataPipe[_T_co]):
|
||||||
>>> def add_one(x):
|
>>> def add_one(x):
|
||||||
... return x + 1
|
... return x + 1
|
||||||
>>> dp = IterableWrapper(range(10))
|
>>> dp = IterableWrapper(range(10))
|
||||||
>>> map_dp_1 = dp.map(add_one) # Invocation via functional form is preferred
|
>>> # Invocation via functional form is preferred
|
||||||
|
... map_dp_1 = dp.map(add_one)
|
||||||
>>> list(map_dp_1)
|
>>> list(map_dp_1)
|
||||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||||
>>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
|
>>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
|
||||||
|
|
@ -202,7 +203,7 @@ class CollatorIterDataPipe(MapperIterDataPipe):
|
||||||
>>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
|
>>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
|
||||||
... def __init__(self, start, end):
|
... def __init__(self, start, end):
|
||||||
... super(MyIterDataPipe).__init__()
|
... super(MyIterDataPipe).__init__()
|
||||||
... assert end > start, "this example code only works with end >= start"
|
... assert end > start, "this example only works with end >= start"
|
||||||
... self.start = start
|
... self.start = start
|
||||||
... self.end = end
|
... self.end = end
|
||||||
...
|
...
|
||||||
|
|
@@ -211,13 +212,11 @@ class CollatorIterDataPipe(MapperIterDataPipe):
         ...
         ...     def __len__(self):
         ...         return self.end - self.start
-        ...
         >>> ds = MyIterDataPipe(start=3, end=7)
         >>> print(list(ds))
         [3, 4, 5, 6]
         >>> def collate_fn(batch):
         ...     return torch.tensor(batch, dtype=torch.float)
-        ...
         >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn)
         >>> print(list(collated_ds))
         [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
@@ -38,15 +38,17 @@ class SamplerIterDataPipe(IterDataPipe[_T_co]):
         sampler_args: Optional[tuple] = None,
         sampler_kwargs: Optional[dict] = None,
     ) -> None:
-        assert isinstance(
-            datapipe, Sized
-        ), "Sampler class requires input datapipe implemented `__len__`"
+        assert isinstance(datapipe, Sized), (
+            "Sampler class requires input datapipe implemented `__len__`"
+        )
         super().__init__()
         self.datapipe = datapipe
         self.sampler_args = () if sampler_args is None else sampler_args
         self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
         # https://github.com/python/mypy/pull/9629 will solve
-        self.sampler = sampler(*self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs)  # type: ignore[misc]
+        self.sampler = sampler(
+            *self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs
+        )  # type: ignore[misc]

     def __iter__(self) -> Iterator[_T_co]:
         return iter(self.sampler)
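
Many hunks in this diff follow the same shape as the one above: where the previous formatter split the assert condition across lines, ruff format keeps the condition on one line and parenthesizes the message instead. A small illustrative sketch of the two layouts (not code from the PR); both spellings are equivalent at runtime:

    x = [1, 2, 3]

    # Previous layout: the condition is broken across lines.
    assert isinstance(
        x, list
    ), "expected a list"

    # ruff format layout: the condition stays intact, the message is wrapped in parentheses.
    assert isinstance(x, list), (
        "expected a list"
    )
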
@@ -116,16 +116,13 @@ class _ContainerTemplate(ABC):
     r"""Abstract class for container ``DataPipes``. The followings are three required methods."""

     @abstractmethod
-    def get_next_element_by_instance(self, instance_id: int):
-        ...
+    def get_next_element_by_instance(self, instance_id: int): ...

     @abstractmethod
-    def is_every_instance_exhausted(self) -> bool:
-        ...
+    def is_every_instance_exhausted(self) -> bool: ...

     @abstractmethod
-    def reset(self) -> None:
-        ...
+    def reset(self) -> None: ...

     @abstractmethod
     def get_length_by_instance(self, instance_id: int):
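
The collapsed stubs above are another purely mechanical change: an Ellipsis body on the same line as the signature is the same function as the two-line form. An illustrative sketch with a hypothetical class name (not from the PR):

    from abc import ABC, abstractmethod


    class Template(ABC):  # hypothetical stand-in for an abstract container
        @abstractmethod
        def reset(self) -> None: ...  # equivalent to a body of "..." on its own line
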
@@ -403,7 +400,9 @@ class DemultiplexerIterDataPipe(IterDataPipe):
         >>> # It can also filter out any element that gets `None` from the `classifier_fn`
         >>> def odd_or_even_no_zero(n):
         ...     return n % 2 if n != 0 else None
-        >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True)
+        >>> dp1, dp2 = source_dp.demux(
+        ...     num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True
+        ... )
         >>> list(dp1)
         [2, 4]
         >>> list(dp2)
@@ -428,7 +427,9 @@ class DemultiplexerIterDataPipe(IterDataPipe):
         # When num_instances == 1, demux can be replaced by filter,
         # but keep it as Demultiplexer for the sake of consistency
         # like throwing Error when classification result is out of o range
-        container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size)  # type: ignore[abstract]
+        container = _DemultiplexerIterDataPipe(
+            datapipe, num_instances, classifier_fn, drop_none, buffer_size
+        )  # type: ignore[abstract]
         return [_ChildDataPipe(container, i) for i in range(num_instances)]
@@ -602,16 +603,18 @@ class MultiplexerIterDataPipe(IterDataPipe):
     Example:
         >>> # xdoctest: +REQUIRES(module:torchdata)
         >>> from torchdata.datapipes.iter import IterableWrapper
-        >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
+        >>> dp1, dp2, dp3 = (
+        ...     IterableWrapper(range(3)),
+        ...     IterableWrapper(range(10, 15)),
+        ...     IterableWrapper(range(20, 25)),
+        ... )
         >>> list(dp1.mux(dp2, dp3))
         [0, 10, 20, 1, 11, 21, 2, 12, 22]
     """

     def __init__(self, *datapipes):
         self.datapipes = datapipes
-        self.buffer: list = (
-            []
-        )  # Store values to be yielded only when every iterator provides one
+        self.buffer: list = []  # Store values to be yielded only when every iterator provides one

     def __iter__(self):
         iterators = [iter(x) for x in self.datapipes]
@@ -670,7 +673,11 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]):
     Example:
         >>> # xdoctest: +REQUIRES(module:torchdata)
         >>> from torchdata.datapipes.iter import IterableWrapper
-        >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
+        >>> dp1, dp2, dp3 = (
+        ...     IterableWrapper(range(5)),
+        ...     IterableWrapper(range(10, 15)),
+        ...     IterableWrapper(range(20, 25)),
+        ... )
         >>> list(dp1.zip(dp2, dp3))
         [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
     """
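
As with the mux example above, the only change here is that the tuple on the right-hand side is now parenthesized with one element per line. A short illustrative sketch (plain ranges instead of datapipes, not PR code) showing that the unpacking is unaffected:

    a, b, c = (
        range(3),
        range(10, 15),
        range(20, 25),
    )
    print(list(a), list(b), list(c))  # [0, 1, 2] [10, 11, 12, 13, 14] [20, 21, 22, 23, 24]
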
@@ -33,8 +33,12 @@ class FileOpenerIterDataPipe(IterDataPipe[tuple[str, IOBase]]):

     Example:
         >>> # xdoctest: +SKIP
-        >>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader
-        >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt'))
+        >>> from torchdata.datapipes.iter import (
+        ...     FileLister,
+        ...     FileOpener,
+        ...     StreamReader,
+        ... )
+        >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith(".txt"))
         >>> dp = FileOpener(dp)
         >>> dp = StreamReader(dp)
         >>> list(dp)
@@ -182,7 +182,9 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
         >>> from torchdata.datapipes.iter import IterableWrapper
         >>> def group_fn(file):
         ...     return os.path.basename(file).split(".")[0]
-        >>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
+        >>> source_dp = IterableWrapper(
+        ...     ["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"]
+        ... )
         >>> dp0 = source_dp.groupby(group_key_fn=group_fn)
         >>> list(dp0)
         [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
@@ -191,7 +193,12 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
         >>> list(dp1)
         [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
         >>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size`
-        >>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
+        >>> dp2 = source_dp.groupby(
+        ...     group_key_fn=group_fn,
+        ...     buffer_size=3,
+        ...     group_size=3,
+        ...     guaranteed_group_size=2,
+        ... )
         >>> list(dp2)
         [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
     """
@@ -31,8 +31,8 @@ class SequenceWrapperMapDataPipe(MapDataPipe[_T]):
         >>> dp = SequenceWrapper(range(10))
         >>> list(dp)
         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
-        >>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
-        >>> dp['a']
+        >>> dp = SequenceWrapper({"a": 100, "b": 200, "c": 300, "d": 400})
+        >>> dp["a"]
         100
     """
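
This hunk and the next are pure quote normalization: string contents are untouched, only the delimiters change to the formatter's preferred double quotes. An illustrative sketch (not PR code):

    # Single- and double-quoted literals are the same values to Python.
    assert {'a': 100} == {"a": 100}
    assert 'some data' == "some data"
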
@ -45,8 +45,8 @@ def basichandlers(extension: str, data):
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> import pickle
|
>>> import pickle
|
||||||
>>> data = pickle.dumps('some data')
|
>>> data = pickle.dumps("some data")
|
||||||
>>> new_data = basichandlers('pickle', data)
|
>>> new_data = basichandlers("pickle", data)
|
||||||
>>> new_data
|
>>> new_data
|
||||||
some data
|
some data
|
||||||
|
|
||||||
|
|
@@ -169,9 +169,9 @@ class ImageHandler:
     """

     def __init__(self, imagespec):
-        assert imagespec in list(
-            imagespecs.keys()
-        ), f"unknown image specification: {imagespec}"
+        assert imagespec in list(imagespecs.keys()), (
+            f"unknown image specification: {imagespec}"
+        )
         self.imagespec = imagespec.lower()

     def __call__(self, extension, data):
@@ -205,18 +205,18 @@ class ImageHandler:
             return img
         elif atype == "numpy":
             result = np.asarray(img)
-            assert (
-                result.dtype == np.uint8
-            ), f"numpy image array should be type uint8, but got {result.dtype}"
+            assert result.dtype == np.uint8, (
+                f"numpy image array should be type uint8, but got {result.dtype}"
+            )
             if etype == "uint8":
                 return result
             else:
                 return result.astype("f") / 255.0
         elif atype == "torch":
             result = np.asarray(img)
-            assert (
-                result.dtype == np.uint8
-            ), f"numpy image array should be type uint8, but got {result.dtype}"
+            assert result.dtype == np.uint8, (
+                f"numpy image array should be type uint8, but got {result.dtype}"
+            )

             if etype == "uint8":
                 result = np.array(result.transpose(2, 0, 1))
@@ -96,7 +96,7 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
         >>> class MyIterableDataset(torch.utils.data.IterableDataset):
         ...     def __init__(self, start, end):
         ...         super(MyIterableDataset).__init__()
-        ...         assert end > start, "this example code only works with end >= start"
+        ...         assert end > start, "this example only works with end >= start"
         ...         self.start = start
         ...         self.end = end
         ...
@@ -138,7 +138,7 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
         >>> class MyIterableDataset(torch.utils.data.IterableDataset):
         ...     def __init__(self, start, end):
         ...         super(MyIterableDataset).__init__()
-        ...         assert end > start, "this example code only works with end >= start"
+        ...         assert end > start, "this example only works with end >= start"
         ...         self.start = start
         ...         self.end = end
         ...
@@ -198,9 +198,9 @@ class TensorDataset(Dataset[tuple[Tensor, ...]]):
     tensors: tuple[Tensor, ...]

     def __init__(self, *tensors: Tensor) -> None:
-        assert all(
-            tensors[0].size(0) == tensor.size(0) for tensor in tensors
-        ), "Size mismatch between tensors"
+        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), (
+            "Size mismatch between tensors"
+        )
         self.tensors = tensors

     def __getitem__(self, index):
@@ -222,7 +222,7 @@ class StackDataset(Dataset[_T_stack]):
         >>> tuple_stack = StackDataset(images, texts)
         >>> tuple_stack[0] == (images[0], texts[0])
         >>> dict_stack = StackDataset(image=images, text=texts)
-        >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}
+        >>> dict_stack[0] == {"image": images[0], "text": texts[0]}

     Args:
         *args (Dataset): Datasets for stacking returned as tuple.
@@ -323,9 +323,9 @@ class ConcatDataset(Dataset[_T_co]):
         self.datasets = list(datasets)
         assert len(self.datasets) > 0, "datasets should not be an empty iterable"  # type: ignore[arg-type]
         for d in self.datasets:
-            assert not isinstance(
-                d, IterableDataset
-            ), "ConcatDataset does not support IterableDataset"
+            assert not isinstance(d, IterableDataset), (
+                "ConcatDataset does not support IterableDataset"
+            )
         self.cumulative_sizes = self.cumsum(self.datasets)

     def __len__(self):
@@ -371,17 +371,17 @@ class ChainDataset(IterableDataset):

     def __iter__(self):
         for d in self.datasets:
-            assert isinstance(
-                d, IterableDataset
-            ), "ChainDataset only supports IterableDataset"
+            assert isinstance(d, IterableDataset), (
+                "ChainDataset only supports IterableDataset"
+            )
             yield from d

     def __len__(self):
         total = 0
         for d in self.datasets:
-            assert isinstance(
-                d, IterableDataset
-            ), "ChainDataset only supports IterableDataset"
+            assert isinstance(d, IterableDataset), (
+                "ChainDataset only supports IterableDataset"
+            )
             total += len(d)  # type: ignore[arg-type]
         return total
@@ -236,9 +236,17 @@ class WeightedRandomSampler(Sampler[int]):

     Example:
         >>> # xdoctest: +IGNORE_WANT("non-deterministic")
-        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
+        >>> list(
+        ...     WeightedRandomSampler(
+        ...         [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True
+        ...     )
+        ... )
         [4, 4, 1, 4, 5]
-        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
+        >>> list(
+        ...     WeightedRandomSampler(
+        ...         [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False
+        ...     )
+        ... )
         [0, 1, 4, 3, 2]
     """
@@ -298,9 +306,15 @@ class BatchSampler(Sampler[list[int]]):
             its size would be less than ``batch_size``

     Example:
-        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
+        >>> list(
+        ...     BatchSampler(
+        ...         SequentialSampler(range(10)), batch_size=3, drop_last=False
+        ...     )
+        ... )
         [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
-        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
+        >>> list(
+        ...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
+        ... )
         [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
     """
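
The reflowed sampler doctests above are still ordinary expressions; for instance, the batch layout they document can be reproduced directly. A small sketch, assuming a recent torch install (illustrative, not PR code):

    from torch.utils.data import BatchSampler, SequentialSampler

    # SequentialSampler is deterministic, so the batches below are reproducible.
    batches = list(
        BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)
    )
    print(batches)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
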
@@ -49,6 +49,7 @@ class ModuleTracker:
         def my_linear(m1, m2, bias):
             print(f"Current modules: {tracker.parents}")
             return torch.mm(m1, m2.t()) + bias

         torch.nn.functional.linear = my_linear

         mod(torch.rand(2, 2))
@@ -6,6 +6,7 @@ Intel GPU optimization.
 This package is lazily initialized, so you can always import it, and use
 :func:`is_available()` to determine if your system supports XPU.
 """
+
 import threading
 import traceback
 from functools import lru_cache
@@ -292,6 +293,7 @@ class StreamContext:
         ``None``.
     .. note:: Streams are per-device.
     """
+
     cur_stream: Optional["torch.xpu.Stream"]

     def __init__(self, stream: Optional["torch.xpu.Stream"]):
@@ -438,7 +440,7 @@ def get_gencode_flags() -> str:
     arch_list = get_arch_list()
     if len(arch_list) == 0:
         return ""
-    return f'-device {",".join(arch for arch in arch_list)}'
+    return f"-device {','.join(arch for arch in arch_list)}"


 def _get_generator(device: torch.device) -> torch._C.Generator:
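
The final hunk swaps which quote style is nested inside the f-string rather than splitting the expression. A small illustrative check (hypothetical architecture names; the real list comes from get_arch_list(), and this is not code from the PR):

    arch_list = ["arch_a", "arch_b"]  # hypothetical values for illustration

    old_style = f'-device {",".join(arch for arch in arch_list)}'
    new_style = f"-device {','.join(arch for arch in arch_list)}"
    assert old_style == new_style == "-device arch_a,arch_b"
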