Enable xdoctest runner in CI for real this time (#83816)

Builds on #83317 and enables actually running the doctests in CI. The remaining work is to figure out what is causing the failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83816
Approved by: https://github.com/ezyang, https://github.com/malfet
joncrall 2022-12-29 05:32:42 +00:00 committed by PyTorch MergeBot
parent fb4fc0dabe
commit ad782ff7df
90 changed files with 456 additions and 262 deletions
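
Most of the per-file changes below follow one pattern: examples that cannot run unconditionally are either skipped outright with "# xdoctest: +SKIP(...)" or gated behind an environment variable with "# xdoctest: +REQUIRES(env:...)", and run_test.py exports the matching TORCH_DOCTEST_* variable when the corresponding feature is enabled. A minimal sketch of that pattern, using a hypothetical function and gating variable rather than anything from this PR:

def double(x):
    """
    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_EXAMPLES)
        >>> # TORCH_DOCTEST_EXAMPLES is a hypothetical gate; run_test.py
        >>> # enables blocks like this by exporting the variable as '1'.
        >>> import torch
        >>> double(torch.ones(2)).tolist()
        [2.0, 2.0]

        >>> # xdoctest: +SKIP("output is non-deterministic")
        >>> double(torch.randn(2))
    """
    return 2 * x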

View File

@ -179,9 +179,9 @@ pytest-rerunfailures
#Pinned versions: #Pinned versions:
#test that import: #test that import:
xdoctest==1.0.2 xdoctest==1.1.0
#Description: runs doctests in pytest #Description: runs doctests in pytest
#Pinned versions: 1.0.2 #Pinned versions: 1.1.0
#test that import: #test that import:
pygments==2.12.0 pygments==2.12.0

View File

@ -19,4 +19,4 @@ pytest-shard==0.1.2
scipy==1.9.0 scipy==1.9.0
sympy==1.11.1 sympy==1.11.1
unittest-xml-reporting<=3.2.0,>=2.0.0 unittest-xml-reporting<=3.2.0,>=2.0.0
xdoctest==1.0.2 xdoctest==1.1.0

View File

@ -4,9 +4,12 @@ This script simply runs the torch doctests via the xdoctest runner.
This must be run from the root of the torch repo, as it needs the path to the This must be run from the root of the torch repo, as it needs the path to the
torch source code. torch source code.
"
#xdoctest -m torch --style=google list This script is provided as a developer convenience. On the CI the doctests are
invoked in 'run_test.py'
"
# To simply list tests
# xdoctest -m torch --style=google list
# Reference: https://stackoverflow.com/questions/59895/bash-script-dir # Reference: https://stackoverflow.com/questions/59895/bash-script-dir
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
@ -16,14 +19,10 @@ echo "TORCH_MODPATH = $TORCH_MODPATH"
if [[ ! -d "$TORCH_MODPATH" ]] ; then if [[ ! -d "$TORCH_MODPATH" ]] ; then
echo "Could not find the path to the torch module" echo "Could not find the path to the torch module"
else else
# Next version of xdoctest will support environment variables that overlo
export XDOCTEST_GLOBAL_EXEC="from torch import nn\nimport torch.nn.functional as F\nimport torch" export XDOCTEST_GLOBAL_EXEC="from torch import nn\nimport torch.nn.functional as F\nimport torch"
export XDOCTEST_OPTIONS="+IGNORE_WHITESPACE" export XDOCTEST_OPTIONS="+IGNORE_WHITESPACE"
# Note: google wont catch numpy style docstrings (a few exist) but it also wont fail # Note: google wont catch numpy style docstrings (a few exist) but it also wont fail
# on things not intended to be doctests. # on things not intended to be doctests.
export XDOCTEST_STYLE="google" export XDOCTEST_STYLE="google"
xdoctest "$TORCH_MODPATH" --style="$XDOCTEST_STYLE" --global-exec "$XDOCTEST_GLOBAL_EXEC" --options="$XDOCTEST_OPTIONS" xdoctest torch "$TORCH_MODPATH" --style="$XDOCTEST_STYLE" --global-exec "$XDOCTEST_GLOBAL_EXEC" --options="$XDOCTEST_OPTIONS"
fi fi
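
The script above is the developer-facing shell entry point; on CI the equivalent configuration is built in run_test.py (next file) and handed to xdoctest programmatically. A rough sketch of that programmatic path, assuming xdoctest's doctest_module() entry point and mirroring the config shown below:

# Rough programmatic equivalent of the shell invocation above; assumes
# xdoctest exposes doctest_module() at the package level.
import xdoctest

xdoctest_config = {
    # Literal '\n' separators, matching how run_test.py builds global_exec.
    'global_exec': r'\n'.join([
        'from torch import nn',
        'import torch.nn.functional as F',
        'import torch',
    ]),
    'style': 'google',
    'options': '+IGNORE_WHITESPACE',
}

# command='list' only collects and syntax-checks the doctests;
# command='all' actually executes them.
xdoctest.doctest_module('torch', command='all', config=xdoctest_config)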

View File

@ -659,10 +659,9 @@ def run_doctests(test_module, test_directory, options):
import pathlib import pathlib
pkgpath = pathlib.Path(torch.__file__).parent pkgpath = pathlib.Path(torch.__file__).parent
#
enabled = { enabled = {
# TODO: expose these options to the user # TODO: expose these options to the user
# Temporary disable all feature-conditional tests # For now disable all feature-conditional tests
# 'lapack': 'auto', # 'lapack': 'auto',
# 'cuda': 'auto', # 'cuda': 'auto',
# 'cuda1': 'auto', # 'cuda1': 'auto',
@ -671,6 +670,9 @@ def run_doctests(test_module, test_directory, options):
'cuda': 0, 'cuda': 0,
'cuda1': 0, 'cuda1': 0,
'qengine': 0, 'qengine': 0,
'autograd_profiler': 0,
'cpp_ext': 0,
'monitor': 0,
} }
# Resolve "auto" based on a test to determine if the feature is available. # Resolve "auto" based on a test to determine if the feature is available.
@ -707,13 +709,34 @@ def run_doctests(test_module, test_directory, options):
if enabled['qengine']: if enabled['qengine']:
os.environ['TORCH_DOCTEST_QENGINE'] = '1' os.environ['TORCH_DOCTEST_QENGINE'] = '1'
if enabled['autograd_profiler']:
os.environ['TORCH_DOCTEST_AUTOGRAD_PROFILER'] = '1'
if enabled['cpp_ext']:
os.environ['TORCH_DOCTEST_CPP_EXT'] = '1'
if enabled['monitor']:
os.environ['TORCH_DOCTEST_MONITOR'] = '1'
if 0:
# TODO: could try to enable some of these
os.environ['TORCH_DOCTEST_QUANTIZED_DYNAMIC'] = '1'
os.environ['TORCH_DOCTEST_ANOMOLY'] = '1'
os.environ['TORCH_DOCTEST_AUTOGRAD'] = '1'
os.environ['TORCH_DOCTEST_HUB'] = '1'
os.environ['TORCH_DOCTEST_DATALOADER'] = '1'
os.environ['TORCH_DOCTEST_ONNX'] = '1'
os.environ['TORCH_DOCTEST_FUTURES'] = '1'
pkgpath = os.path.dirname(torch.__file__) pkgpath = os.path.dirname(torch.__file__)
xdoctest_config = { xdoctest_config = {
'global_exec': r'\n'.join([ 'global_exec': r'\n'.join([
'from torch import nn', 'from torch import nn',
'import torch.nn.functional as F', 'import torch.nn.functional as F',
'import torch', 'import torch',
]), ]),
'analysis': 'static', # set to "auto" to test doctests in compiled modules
'style': 'google', 'style': 'google',
'options': '+IGNORE_WHITESPACE', 'options': '+IGNORE_WHITESPACE',
} }
@ -1016,7 +1039,7 @@ def parse_args():
) )
parser.add_argument( parser.add_argument(
"--xdoctest-command", "--xdoctest-command",
default='list', default='all',
help=( help=(
"Control the specific doctest action. " "Control the specific doctest action. "
"Use 'list' to simply parse doctests and check syntax. " "Use 'list' to simply parse doctests and check syntax. "

View File

@ -427,7 +427,7 @@ def is_tensor(obj):
obj (Object): Object to test obj (Object): Object to test
Example:: Example::
>>> x=torch.tensor([1,2,3]) >>> x = torch.tensor([1, 2, 3])
>>> torch.is_tensor(x) >>> torch.is_tensor(x)
True True
@ -627,10 +627,10 @@ def use_deterministic_algorithms(mode, *, warn_only=False):
Example:: Example::
>>> # xdoctest: +SKIP
>>> torch.use_deterministic_algorithms(True) >>> torch.use_deterministic_algorithms(True)
# Forward mode nondeterministic error # Forward mode nondeterministic error
>>> # xdoctest: +SKIP
>>> torch.randn(10, device='cuda').kthvalue(0) >>> torch.randn(10, device='cuda').kthvalue(0)
... ...
RuntimeError: kthvalue CUDA does not have a deterministic implementation... RuntimeError: kthvalue CUDA does not have a deterministic implementation...
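
The change above moves the +SKIP directive ahead of torch.use_deterministic_algorithms(True). That matters because a directive written on its own ">>>" line acts as a block directive in xdoctest, applying to the statements that follow it, so with the old placement the first statement still executed. A small illustrative docstring (not from this PR) showing the difference:

def _directive_scope_demo():
    """
    Example::

        >>> # xdoctest: +SKIP
        >>> torch.use_deterministic_algorithms(True)    # skipped
        >>> torch.randn(10, device='cuda').kthvalue(0)  # skipped

    With the pre-PR placement, the first statement still ran:

        >>> torch.use_deterministic_algorithms(True)    # executed
        >>> # xdoctest: +SKIP
        >>> torch.randn(10, device='cuda').kthvalue(0)  # skipped
    """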

View File

@ -251,6 +251,7 @@ def vjp(func: Callable, *primals, has_aux: bool = False):
Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager: Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager:
>>> # xdoctest: +SKIP(failing)
>>> with torch.no_grad(): >>> with torch.no_grad():
>>> vjp(f)(x) >>> vjp(f)(x)
@ -1286,6 +1287,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
Example of using ``grad``: Example of using ``grad``:
>>> # xdoctest: +SKIP
>>> from torch.func import grad >>> from torch.func import grad
>>> x = torch.randn([]) >>> x = torch.randn([])
>>> cos_x = grad(lambda x: torch.sin(x))(x) >>> cos_x = grad(lambda x: torch.sin(x))(x)
@ -1297,6 +1299,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients: When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:
>>> # xdoctest: +SKIP
>>> from torch.func import grad, vmap >>> from torch.func import grad, vmap
>>> batch_size, feature_size = 3, 5 >>> batch_size, feature_size = 3, 5
>>> >>>
@ -1317,6 +1320,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
Example of using ``grad`` with ``has_aux`` and ``argnums``: Example of using ``grad`` with ``has_aux`` and ``argnums``:
>>> # xdoctest: +SKIP
>>> from torch.func import grad >>> from torch.func import grad
>>> def my_loss_func(y, y_pred): >>> def my_loss_func(y, y_pred):
>>> loss_per_sample = (0.5 * y_pred - y) ** 2 >>> loss_per_sample = (0.5 * y_pred - y) ** 2
@ -1327,13 +1331,14 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
>>> y_true = torch.rand(4) >>> y_true = torch.rand(4)
>>> y_preds = torch.rand(4, requires_grad=True) >>> y_preds = torch.rand(4, requires_grad=True)
>>> out = fn(y_true, y_preds) >>> out = fn(y_true, y_preds)
>>> > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample)) >>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
.. note:: .. note::
Using PyTorch ``torch.no_grad`` together with ``grad``. Using PyTorch ``torch.no_grad`` together with ``grad``.
Case 1: Using ``torch.no_grad`` inside a function: Case 1: Using ``torch.no_grad`` inside a function:
>>> # xdoctest: +SKIP
>>> def f(x): >>> def f(x):
>>> with torch.no_grad(): >>> with torch.no_grad():
>>> c = x ** 2 >>> c = x ** 2
@ -1343,6 +1348,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
Case 2: Using ``grad`` inside ``torch.no_grad`` context manager: Case 2: Using ``grad`` inside ``torch.no_grad`` context manager:
>>> # xdoctest: +SKIP
>>> with torch.no_grad(): >>> with torch.no_grad():
>>> grad(f)(x) >>> grad(f)(x)
@ -1433,11 +1439,12 @@ def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable:
Example:: Example::
>>> # xdoctest: +SKIP
>>> import torch >>> import torch
>>> from torch.fx.experimental.proxy_tensor import make_fx >>> from torch.fx.experimental.proxy_tensor import make_fx
>>> from torch.func import functionalize >>> from torch.func import functionalize
>>> >>>
>>> A function that uses mutations and views, but only on intermediate tensors. >>> # A function that uses mutations and views, but only on intermediate tensors.
>>> def f(a): >>> def f(a):
... b = a + 1 ... b = a + 1
... c = b.view(-1) ... c = b.view(-1)
@ -1490,17 +1497,17 @@ def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable:
return view_copy_1 return view_copy_1
>>> A function that mutates its input tensor >>> # A function that mutates its input tensor
>>> def f(a): >>> def f(a):
... b = a.view(-1) ... b = a.view(-1)
... b.add_(1) ... b.add_(1)
... return a ... return a
... ...
>>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
>>> >>> #
>>> All mutations and views have been removed, >>> # All mutations and views have been removed,
>>> but there is an extra copy_ in the graph to correctly apply the mutation to the input >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input
>>> after the function has completed. >>> # after the function has completed.
>>> print(f_no_mutations_and_views_traced.code) >>> print(f_no_mutations_and_views_traced.code)

View File

@ -69,6 +69,7 @@ def minifier(fail_f: fx.GraphModule, inps, module_fails, dump_state: Callable =
2. Delta Debugging: Tries replacing half of the graph with inputs. If fails, 2. Delta Debugging: Tries replacing half of the graph with inputs. If fails,
tries replacing quarter of the graph, etc. tries replacing quarter of the graph, etc.
>>> # xdoctest: +SKIP(failing)
>>> failing_function = fx.symbolic_trace(f) >>> failing_function = fx.symbolic_trace(f)
>>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps)) >>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

View File

@ -122,10 +122,12 @@ def update_names(tensor, names, rename_map, inplace):
For example, For example,
``` ```
>>> # xdoctest: +SKIP
>>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
>>> x.rename('...', 'height', 'width').names >>> x.rename('...', 'height', 'width').names
('N', 'C', 'height', 'width') ('N', 'C', 'height', 'width')
>>> # xdoctest: +SKIP
>>> x.rename('batch', '...', 'width').names >>> x.rename('batch', '...', 'width').names
('batch', 'C', 'H', 'width') ('batch', 'C', 'H', 'width')
@ -136,6 +138,7 @@ def update_names(tensor, names, rename_map, inplace):
For example, For example,
``` ```
>>> # xdoctest: +SKIP
>>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
>>> x.rename(W='width', H='height').names >>> x.rename(W='width', H='height').names
('N', 'C', 'height', 'width') ('N', 'C', 'height', 'width')

View File

@ -1496,6 +1496,7 @@ def compute_required_storage_length(
>>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset()) >>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset())
200 200
>>> # xdoctest: +SKIP(failing)
>>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11)) >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11))
>>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset()) >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset())
>>> size == t.storage().size() >>> size == t.storage().size()

View File

@ -215,7 +215,6 @@ def _vector_str(self, indent, summarize, formatter1, formatter2=None):
elements_per_line = max( elements_per_line = max(
1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length))) 1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length)))
) )
# char_per_line = element_length * elements_per_line # unused
def _val_formatter(val, formatter1=formatter1, formatter2=formatter2): def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
if formatter2 is not None: if formatter2 is not None:

View File

@ -9,6 +9,7 @@ from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_un
in_dims_t = Union[int, Tuple] in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]] out_dims_t = Union[int, Tuple[int, ...]]
# Checks that all args-to-be-batched have the same batch dim size # Checks that all args-to-be-batched have the same batch dim size
def _validate_and_get_batch_size( def _validate_and_get_batch_size(
flat_in_dims: List[Optional[int]], flat_args: List flat_in_dims: List[Optional[int]], flat_args: List

View File

@ -19,9 +19,9 @@ class LinearReLU(nnqd.Linear):
Examples:: Examples::
>>> # xdoctest: +SKIP
>>> m = nn.intrinsic.quantized.dynamic.LinearReLU(20, 30) >>> m = nn.intrinsic.quantized.dynamic.LinearReLU(20, 30)
>>> input = torch.randn(128, 20) >>> input = torch.randn(128, 20)
>>> # xdoctest: +SKIP
>>> output = m(input) >>> output = m(input)
>>> print(output.size()) >>> print(output.size())
torch.Size([128, 30]) torch.Size([128, 30])

View File

@ -56,6 +56,7 @@ class LinearLeakyReLU(nnq.Linear):
Same as torch.nn.quantized.Linear Same as torch.nn.quantized.Linear
+ negative_slope + negative_slope
Examples:: Examples::
>>> # xdoctest: +SKIP
>>> m = nn.intrinsic.LinearLeakyReLU(20, 30, 0.01) >>> m = nn.intrinsic.LinearLeakyReLU(20, 30, 0.01)
>>> input = torch.randn(128, 20) >>> input = torch.randn(128, 20)
>>> output = m(input) >>> output = m(input)

View File

@ -15,6 +15,7 @@ import warnings
__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d'] __all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
class Conv1d(nnq.Conv1d): class Conv1d(nnq.Conv1d):
r"""A dynamically quantized conv module with floating point tensors as inputs and outputs. r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.
@ -31,9 +32,9 @@ class Conv1d(nnq.Conv1d):
Examples:: Examples::
>>> # xdoctest: +SKIP
>>> m = nn.quantized.dynamic.Conv1d(16, 33, 3, stride=2) >>> m = nn.quantized.dynamic.Conv1d(16, 33, 3, stride=2)
>>> input = torch.randn(20, 16, 100) >>> input = torch.randn(20, 16, 100)
>>> # xdoctest: +SKIP
>>> output = m(input) >>> output = m(input)
""" """
@ -102,6 +103,7 @@ class Conv2d(nnq.Conv2d):
Examples:: Examples::
>>> # xdoctest: +SKIP
>>> # With square kernels and equal stride >>> # With square kernels and equal stride
>>> m = nn.quantized.dynamic.Conv2d(16, 33, 3, stride=2) >>> m = nn.quantized.dynamic.Conv2d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding >>> # non-square kernels and unequal stride and with padding
@ -109,7 +111,6 @@ class Conv2d(nnq.Conv2d):
>>> # non-square kernels and unequal stride and with padding and dilation >>> # non-square kernels and unequal stride and with padding and dilation
>>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) >>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
>>> input = torch.randn(20, 16, 50, 100) >>> input = torch.randn(20, 16, 50, 100)
>>> # xdoctest: +SKIP
>>> output = m(input) >>> output = m(input)
""" """
@ -167,6 +168,7 @@ class Conv3d(nnq.Conv3d):
Examples:: Examples::
>>> # xdoctest: +SKIP
>>> # With square kernels and equal stride >>> # With square kernels and equal stride
>>> m = nn.quantized.dynamic.Conv3d(16, 33, 3, stride=2) >>> m = nn.quantized.dynamic.Conv3d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding >>> # non-square kernels and unequal stride and with padding
@ -174,7 +176,6 @@ class Conv3d(nnq.Conv3d):
>>> # non-square kernels and unequal stride and with padding and dilation >>> # non-square kernels and unequal stride and with padding and dilation
>>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2)) >>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
>>> input = torch.randn(20, 16, 56, 56, 56) >>> input = torch.randn(20, 16, 56, 56, 56)
>>> # xdoctest: +SKIP
>>> output = m(input) >>> output = m(input)
""" """
@ -233,8 +234,8 @@ class ConvTranspose1d(nnq.ConvTranspose1d):
Examples:: Examples::
>>> # With square kernels and equal stride
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> # With square kernels and equal stride
>>> m = nndq.ConvTranspose1d(16, 33, 3, stride=2) >>> m = nndq.ConvTranspose1d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding >>> # non-square kernels and unequal stride and with padding
>>> m = nndq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) >>> m = nndq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
@ -294,11 +295,11 @@ class ConvTranspose2d(nnq.ConvTranspose2d):
Examples:: Examples::
>>> # xdoctest: +SKIP
>>> # With square kernels and equal stride >>> # With square kernels and equal stride
>>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2) >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding >>> # non-square kernels and unequal stride and with padding
>>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
>>> # xdoctest: +SKIP
>>> output = m(input) >>> output = m(input)
>>> # exact output size can be also specified as an argument >>> # exact output size can be also specified as an argument
>>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1) >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1)
@ -355,11 +356,11 @@ class ConvTranspose3d(nnq.ConvTranspose3d):
Examples:: Examples::
>>> # xdoctest: +SKIP
>>> # With cubic kernels and equal stride >>> # With cubic kernels and equal stride
>>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2) >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2)
>>> # non-cubic kernels and unequal stride and with padding >>> # non-cubic kernels and unequal stride and with padding
>>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2)) >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
>>> # xdoctest: +SKIP
>>> output = m(input) >>> output = m(input)
>>> # exact output size can be also specified as an argument >>> # exact output size can be also specified as an argument
>>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1) >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1)

View File

@ -7,6 +7,7 @@ __all__ = [
"Linear", "Linear",
] ]
class Linear(nnq.Linear): class Linear(nnq.Linear):
r""" r"""
A dynamic quantized linear module with floating point tensor as inputs and outputs. A dynamic quantized linear module with floating point tensor as inputs and outputs.
@ -25,9 +26,9 @@ class Linear(nnq.Linear):
Examples:: Examples::
>>> # xdoctest: +SKIP
>>> m = nn.quantized.dynamic.Linear(20, 30) >>> m = nn.quantized.dynamic.Linear(20, 30)
>>> input = torch.randn(128, 20) >>> input = torch.randn(128, 20)
>>> # xdoctest: +SKIP
>>> output = m(input) >>> output = m(input)
>>> print(output.size()) >>> print(output.size())
torch.Size([128, 30]) torch.Size([128, 30])

View File

@ -11,13 +11,16 @@ from torch.ao.nn.quantized.modules.utils import _quantize_weight
__all__ = ['pack_weight_bias', 'PackedParameter', 'RNNBase', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell', __all__ = ['pack_weight_bias', 'PackedParameter', 'RNNBase', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell',
'GRUCell', "apply_permutation"] 'GRUCell', "apply_permutation"]
def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
return tensor.index_select(dim, permutation) return tensor.index_select(dim, permutation)
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
warnings.warn("apply_permutation is deprecated, please use tensor.index_select(dim, permutation) instead") warnings.warn("apply_permutation is deprecated, please use tensor.index_select(dim, permutation) instead")
return _apply_permutation(tensor, permutation, dim) return _apply_permutation(tensor, permutation, dim)
def pack_weight_bias(qweight, bias, dtype): def pack_weight_bias(qweight, bias, dtype):
if dtype == torch.qint8: if dtype == torch.qint8:
@ -39,6 +42,7 @@ def pack_weight_bias(qweight, bias, dtype):
return packed_weight return packed_weight
class PackedParameter(torch.nn.Module): class PackedParameter(torch.nn.Module):
def __init__(self, param): def __init__(self, param):
super(PackedParameter, self).__init__() super(PackedParameter, self).__init__()
@ -54,6 +58,7 @@ class PackedParameter(torch.nn.Module):
super(PackedParameter, self)._load_from_state_dict(state_dict, prefix, local_metadata, False, super(PackedParameter, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
missing_keys, unexpected_keys, error_msgs) missing_keys, unexpected_keys, error_msgs)
class RNNBase(torch.nn.Module): class RNNBase(torch.nn.Module):
_FLOAT_MODULE = nn.RNNBase _FLOAT_MODULE = nn.RNNBase
@ -347,7 +352,6 @@ class RNNBase(torch.nn.Module):
return qRNNBase return qRNNBase
def _weight_bias(self): def _weight_bias(self):
# Returns a dict of weights and biases # Returns a dict of weights and biases
weight_bias_dict: Dict[str, Dict] = {'weight' : {}, 'bias' : {}} weight_bias_dict: Dict[str, Dict] = {'weight' : {}, 'bias' : {}}
@ -376,6 +380,7 @@ class RNNBase(torch.nn.Module):
def get_bias(self): def get_bias(self):
return self._weight_bias()['bias'] return self._weight_bias()['bias']
class LSTM(RNNBase): class LSTM(RNNBase):
r""" r"""
A dynamic quantized LSTM module with floating point tensor as inputs and outputs. A dynamic quantized LSTM module with floating point tensor as inputs and outputs.
@ -384,6 +389,7 @@ class LSTM(RNNBase):
Examples:: Examples::
>>> # xdoctest: +SKIP
>>> rnn = nn.LSTM(10, 20, 2) >>> rnn = nn.LSTM(10, 20, 2)
>>> input = torch.randn(5, 3, 10) >>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20) >>> h0 = torch.randn(2, 3, 20)
@ -610,6 +616,7 @@ class GRU(RNNBase):
Examples:: Examples::
>>> # xdoctest: +SKIP
>>> rnn = nn.GRU(10, 20, 2) >>> rnn = nn.GRU(10, 20, 2)
>>> input = torch.randn(5, 3, 10) >>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20) >>> h0 = torch.randn(2, 3, 20)
@ -922,6 +929,7 @@ class RNNCellBase(torch.nn.Module):
super(RNNCellBase, self)._load_from_state_dict(state_dict, prefix, local_metadata, False, super(RNNCellBase, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
missing_keys, unexpected_keys, error_msgs) missing_keys, unexpected_keys, error_msgs)
class RNNCell(RNNCellBase): class RNNCell(RNNCellBase):
r"""An Elman RNN cell with tanh or ReLU non-linearity. r"""An Elman RNN cell with tanh or ReLU non-linearity.
A dynamic quantized RNNCell module with floating point tensor as inputs and outputs. A dynamic quantized RNNCell module with floating point tensor as inputs and outputs.
@ -930,6 +938,7 @@ class RNNCell(RNNCellBase):
Examples:: Examples::
>>> # xdoctest: +SKIP
>>> rnn = nn.RNNCell(10, 20) >>> rnn = nn.RNNCell(10, 20)
>>> input = torch.randn(6, 3, 10) >>> input = torch.randn(6, 3, 10)
>>> hx = torch.randn(3, 20) >>> hx = torch.randn(3, 20)
@ -982,6 +991,7 @@ class LSTMCell(RNNCellBase):
Examples:: Examples::
>>> # xdoctest: +SKIP
>>> rnn = nn.LSTMCell(10, 20) >>> rnn = nn.LSTMCell(10, 20)
>>> input = torch.randn(6, 3, 10) >>> input = torch.randn(6, 3, 10)
>>> hx = torch.randn(3, 20) >>> hx = torch.randn(3, 20)
@ -1014,6 +1024,7 @@ class LSTMCell(RNNCellBase):
def from_float(cls, mod): def from_float(cls, mod):
return super(LSTMCell, cls).from_float(mod) return super(LSTMCell, cls).from_float(mod)
class GRUCell(RNNCellBase): class GRUCell(RNNCellBase):
r"""A gated recurrent unit (GRU) cell r"""A gated recurrent unit (GRU) cell
@ -1023,6 +1034,7 @@ class GRUCell(RNNCellBase):
Examples:: Examples::
>>> # xdoctest: +SKIP
>>> rnn = nn.GRUCell(10, 20) >>> rnn = nn.GRUCell(10, 20)
>>> input = torch.randn(6, 3, 10) >>> input = torch.randn(6, 3, 10)
>>> hx = torch.randn(3, 20) >>> hx = torch.randn(3, 20)

View File

@ -164,6 +164,7 @@ def conv1d(input, weight, bias,
Examples:: Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> from torch.ao.nn.quantized import functional as qF >>> from torch.ao.nn.quantized import functional as qF
>>> filters = torch.randn(33, 16, 3, dtype=torch.float) >>> filters = torch.randn(33, 16, 3, dtype=torch.float)
>>> inputs = torch.randn(20, 16, 50, dtype=torch.float) >>> inputs = torch.randn(20, 16, 50, dtype=torch.float)
@ -223,6 +224,7 @@ def conv2d(input, weight, bias,
Examples:: Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> from torch.ao.nn.quantized import functional as qF >>> from torch.ao.nn.quantized import functional as qF
>>> filters = torch.randn(8, 4, 3, 3, dtype=torch.float) >>> filters = torch.randn(8, 4, 3, 3, dtype=torch.float)
>>> inputs = torch.randn(1, 4, 5, 5, dtype=torch.float) >>> inputs = torch.randn(1, 4, 5, 5, dtype=torch.float)
@ -283,6 +285,7 @@ def conv3d(input, weight, bias, stride=1, padding=0, dilation=1, groups=1,
Examples:: Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> from torch.ao.nn.quantized import functional as qF >>> from torch.ao.nn.quantized import functional as qF
>>> filters = torch.randn(8, 4, 3, 3, 3, dtype=torch.float) >>> filters = torch.randn(8, 4, 3, 3, 3, dtype=torch.float)
>>> inputs = torch.randn(1, 4, 5, 5, 5, dtype=torch.float) >>> inputs = torch.randn(1, 4, 5, 5, 5, dtype=torch.float)

View File

@ -293,6 +293,7 @@ class Conv1d(_ConvNd):
Examples:: Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> m = nn.quantized.Conv1d(16, 33, 3, stride=2) >>> m = nn.quantized.Conv1d(16, 33, 3, stride=2)
>>> input = torch.randn(20, 16, 100) >>> input = torch.randn(20, 16, 100)
>>> # quantize input to quint8 >>> # quantize input to quint8
@ -400,6 +401,7 @@ class Conv2d(_ConvNd):
Examples:: Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> # With square kernels and equal stride >>> # With square kernels and equal stride
>>> m = nn.quantized.Conv2d(16, 33, 3, stride=2) >>> m = nn.quantized.Conv2d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding >>> # non-square kernels and unequal stride and with padding
@ -498,6 +500,7 @@ class Conv3d(_ConvNd):
Examples:: Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> # With square kernels and equal stride >>> # With square kernels and equal stride
>>> m = nn.quantized.Conv3d(16, 33, 3, stride=2) >>> m = nn.quantized.Conv3d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding >>> # non-square kernels and unequal stride and with padding

View File

@ -115,6 +115,7 @@ class Linear(WeightedQuantizedModule):
Examples:: Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> m = nn.quantized.Linear(20, 30) >>> m = nn.quantized.Linear(20, 30)
>>> input = torch.randn(128, 20) >>> input = torch.randn(128, 20)
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP

View File

@ -90,10 +90,10 @@ class BaseDataScheduler(object):
Example: Example:
>>> def get_schedule_param(self): >>> def get_schedule_param(self):
... new_param = {} ... new_param = {}
... for name in self.sparsifier.data_groups.keys(): ... for name in self.sparsifier.data_groups.keys():
... new_param[name] = self.sparsifier.data_groups[name][self.schedule_param] * 0.5 ... new_param[name] = self.sparsifier.data_groups[name][self.schedule_param] * 0.5
... return new_param ... return new_param
When the step() function is called, the value in self.sparsifier.data_groups[name][self.schedule_param] When the step() function is called, the value in self.sparsifier.data_groups[name][self.schedule_param]
would be halved would be halved

View File

@ -88,6 +88,7 @@ class DTypeConfig:
Example usage:: Example usage::
>>> # xdoctest: +SKIP(failing)
>>> dtype_config1 = DTypeConfig( >>> dtype_config1 = DTypeConfig(
... input_dtype=torch.quint8, ... input_dtype=torch.quint8,
... output_dtype=torch.quint8, ... output_dtype=torch.quint8,

View File

@ -77,6 +77,7 @@ def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu):
bn: BatchNorm1d instance that needs to be fused with the linear layer bn: BatchNorm1d instance that needs to be fused with the linear layer
leaky_relu: LeakyReLU instance that needs to be fused with the linear layer leaky_relu: LeakyReLU instance that needs to be fused with the linear layer
Examples:: Examples::
>>> # xdoctest: +SKIP(failing)
>>> m1 = nn.Linear(20, 10) >>> m1 = nn.Linear(20, 10)
>>> b1 = nn.BatchNorm1d(10) >>> b1 = nn.BatchNorm1d(10)
>>> lr = nn.LeakyReLU(0.01) >>> lr = nn.LeakyReLU(0.01)

View File

@ -5,6 +5,7 @@ from typing import Any
__all__ = ["detect_anomaly", "set_detect_anomaly"] __all__ = ["detect_anomaly", "set_detect_anomaly"]
class detect_anomaly(object): class detect_anomaly(object):
r"""Context-manager that enable anomaly detection for the autograd engine. r"""Context-manager that enable anomaly detection for the autograd engine.
@ -22,6 +23,7 @@ class detect_anomaly(object):
Example: Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ANOMOLY)
>>> import torch >>> import torch
>>> from torch import autograd >>> from torch import autograd
>>> class MyFunc(autograd.Function): >>> class MyFunc(autograd.Function):

View File

@ -11,6 +11,7 @@ __all__ = ["UnpackedDualTensor", "enter_dual_level", "exit_dual_level", "make_du
# Global variable used to make the python API simpler to use # Global variable used to make the python API simpler to use
_current_level = -1 _current_level = -1
def enter_dual_level(): def enter_dual_level():
r"""Function that can be used to enter a new forward grad level. r"""Function that can be used to enter a new forward grad level.
This level can be used to make and unpack dual Tensors to compute This level can be used to make and unpack dual Tensors to compute
@ -27,6 +28,7 @@ def enter_dual_level():
_current_level = new_level _current_level = new_level
return new_level return new_level
def exit_dual_level(*, level=None): def exit_dual_level(*, level=None):
r"""Function that can be used to exit a forward grad level. r"""Function that can be used to exit a forward grad level.
This function deletes all the gradients associated with this This function deletes all the gradients associated with this
@ -44,6 +46,7 @@ def exit_dual_level(*, level=None):
torch._C._exit_dual_level(level=level) torch._C._exit_dual_level(level=level)
_current_level = level - 1 _current_level = level - 1
def make_dual(tensor, tangent, *, level=None): def make_dual(tensor, tangent, *, level=None):
r"""Associates a tensor value with a forward gradient, the tangent, to create a r"""Associates a tensor value with a forward gradient, the tangent, to create a
"dual tensor", which is used to compute forward AD gradients. "dual tensor", which is used to compute forward AD gradients.
@ -60,9 +63,9 @@ def make_dual(tensor, tangent, *, level=None):
>>> # xdoctest: +SKIP("Undefined variables") >>> # xdoctest: +SKIP("Undefined variables")
>>> with dual_level(): >>> with dual_level():
... inp = make_dual(x, v) ... inp = make_dual(x, v)
... out = f(inp) ... out = f(inp)
... y, jvp = unpack_dual(out) ... y, jvp = unpack_dual(out)
Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__ Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
for detailed steps on how to use this API. for detailed steps on how to use this API.
@ -104,11 +107,13 @@ def make_dual(tensor, tangent, *, level=None):
_UnpackedDualTensor = namedtuple('_UnpackedDualTensor', ['primal', 'tangent']) _UnpackedDualTensor = namedtuple('_UnpackedDualTensor', ['primal', 'tangent'])
class UnpackedDualTensor(_UnpackedDualTensor): class UnpackedDualTensor(_UnpackedDualTensor):
r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor. r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor.
See :func:`unpack_dual` for more details.""" See :func:`unpack_dual` for more details."""
pass pass
def unpack_dual(tensor, *, level=None): def unpack_dual(tensor, *, level=None):
r"""Unpacks a "dual tensor" to get both its Tensor value and its forward AD gradient. r"""Unpacks a "dual tensor" to get both its Tensor value and its forward AD gradient.
The result is a namedtuple ``(primal, tangent)`` where ``primal`` is a view of The result is a namedtuple ``(primal, tangent)`` where ``primal`` is a view of
@ -121,10 +126,10 @@ def unpack_dual(tensor, *, level=None):
>>> # xdoctest: +SKIP("Undefined variables") >>> # xdoctest: +SKIP("Undefined variables")
>>> with dual_level(): >>> with dual_level():
... inp = make_dual(x, x_t) ... inp = make_dual(x, x_t)
... out = f(inp) ... out = f(inp)
... y, jvp = unpack_dual(out) ... y, jvp = unpack_dual(out)
... jvp = unpack_dual(out).tangent ... jvp = unpack_dual(out).tangent
Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__ Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
for detailed steps on how to use this API. for detailed steps on how to use this API.
@ -139,6 +144,7 @@ def unpack_dual(tensor, *, level=None):
return UnpackedDualTensor(primal, dual) return UnpackedDualTensor(primal, dual)
class dual_level(_DecoratorContextManager): class dual_level(_DecoratorContextManager):
r"""Context-manager that enables forward AD. All forward AD computation must r"""Context-manager that enables forward AD. All forward AD computation must
be performed in a ``dual_level`` context. be performed in a ``dual_level`` context.
@ -159,10 +165,10 @@ class dual_level(_DecoratorContextManager):
>>> x = torch.tensor([1]) >>> x = torch.tensor([1])
>>> x_t = torch.tensor([1]) >>> x_t = torch.tensor([1])
>>> with dual_level(): >>> with dual_level():
... inp = make_dual(x, x_t) ... inp = make_dual(x, x_t)
... # Do computations with inp ... # Do computations with inp
... out = your_fn(inp) ... out = your_fn(inp)
... _, grad = unpack_dual(out) ... _, grad = unpack_dual(out)
>>> grad is None >>> grad is None
False False
>>> # After exiting the level, the grad is deleted >>> # After exiting the level, the grad is deleted

View File

@ -48,6 +48,7 @@ class FunctionCtx(object):
See :ref:`extending-autograd` for more details on how to use this method. See :ref:`extending-autograd` for more details on how to use this method.
Example:: Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Func(Function): >>> class Func(Function):
>>> @staticmethod >>> @staticmethod
>>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int): >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
@ -139,6 +140,7 @@ class FunctionCtx(object):
modification. modification.
Examples:: Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Inplace(Function): >>> class Inplace(Function):
>>> @staticmethod >>> @staticmethod
>>> def forward(ctx, x): >>> def forward(ctx, x):
@ -210,6 +212,7 @@ class FunctionCtx(object):
prior to calling the :func:`backward` and :func:`jvp` methods. prior to calling the :func:`backward` and :func:`jvp` methods.
Example:: Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class SimpleFunc(Function): >>> class SimpleFunc(Function):
>>> @staticmethod >>> @staticmethod
>>> def forward(ctx, x): >>> def forward(ctx, x):
@ -382,6 +385,7 @@ class Function(_SingleLevelFunction):
Examples:: Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Exp(Function): >>> class Exp(Function):
>>> @staticmethod >>> @staticmethod
>>> def forward(ctx, i): >>> def forward(ctx, i):

View File

@ -7,6 +7,7 @@ __all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]
# Utility functions # Utility functions
def _as_tuple_nocheck(x): def _as_tuple_nocheck(x):
if isinstance(x, tuple): if isinstance(x, tuple):
return x return x
@ -15,6 +16,7 @@ def _as_tuple_nocheck(x):
else: else:
return x, return x,
def _as_tuple(inp, arg_name=None, fn_name=None): def _as_tuple(inp, arg_name=None, fn_name=None):
# Ensures that inp is a tuple of Tensors # Ensures that inp is a tuple of Tensors
# Returns whether or not the original inp was a tuple and the tupled version of the input # Returns whether or not the original inp was a tuple and the tupled version of the input
@ -37,6 +39,7 @@ def _as_tuple(inp, arg_name=None, fn_name=None):
return is_inp_tuple, inp return is_inp_tuple, inp
def _tuple_postprocess(res, to_unpack): def _tuple_postprocess(res, to_unpack):
# Unpacks a potentially nested tuple of Tensors # Unpacks a potentially nested tuple of Tensors
# to_unpack should be a single boolean or a tuple of two booleans. # to_unpack should be a single boolean or a tuple of two booleans.
@ -54,6 +57,7 @@ def _tuple_postprocess(res, to_unpack):
res = res[0] res = res[0]
return res return res
def _grad_preprocess(inputs, create_graph, need_graph): def _grad_preprocess(inputs, create_graph, need_graph):
# Preprocess the inputs to make sure they require gradient # Preprocess the inputs to make sure they require gradient
# inputs is a tuple of Tensors to preprocess # inputs is a tuple of Tensors to preprocess
@ -88,6 +92,7 @@ def _grad_postprocess(inputs, create_graph):
else: else:
return tuple(_grad_postprocess(inp, create_graph) for inp in inputs) return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)
def _validate_v(v, other, is_other_tuple): def _validate_v(v, other, is_other_tuple):
# This assumes that other is the correct shape, and v should match # This assumes that other is the correct shape, and v should match
# Both are assumed to be tuples of Tensors # Both are assumed to be tuples of Tensors
@ -138,6 +143,7 @@ def _check_requires_grad(inputs, input_type, strict):
" The outputs must be computed in a differentiable manner from the input" " The outputs must be computed in a differentiable manner from the input"
" when running in strict mode.".format(i)) " when running in strict mode.".format(i))
def _autograd_grad(outputs, inputs, grad_outputs=None, create_graph=False, retain_graph=None, is_grads_batched=False): def _autograd_grad(outputs, inputs, grad_outputs=None, create_graph=False, retain_graph=None, is_grads_batched=False):
# Version of autograd.grad that accepts `None` in outputs and do not compute gradients for them. # Version of autograd.grad that accepts `None` in outputs and do not compute gradients for them.
# This has the extra constraint that inputs has to be a tuple # This has the extra constraint that inputs has to be a tuple
@ -162,6 +168,7 @@ def _autograd_grad(outputs, inputs, grad_outputs=None, create_graph=False, retai
create_graph=create_graph, retain_graph=retain_graph, create_graph=create_graph, retain_graph=retain_graph,
is_grads_batched=is_grads_batched) is_grads_batched=is_grads_batched)
def _fill_in_zeros(grads, refs, strict, create_graph, stage): def _fill_in_zeros(grads, refs, strict, create_graph, stage):
# Used to detect None in the grads and depending on the flags, either replace them # Used to detect None in the grads and depending on the flags, either replace them
# with Tensors full of 0s of the appropriate size based on the refs or raise an error. # with Tensors full of 0s of the appropriate size based on the refs or raise an error.
@ -204,6 +211,7 @@ def _fill_in_zeros(grads, refs, strict, create_graph, stage):
return res return res
# Public API # Public API
def vjp(func, inputs, v=None, create_graph=False, strict=False): def vjp(func, inputs, v=None, create_graph=False, strict=False):
@ -238,8 +246,9 @@ def vjp(func, inputs, v=None, create_graph=False, strict=False):
Example: Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> def exp_reducer(x): >>> def exp_reducer(x):
... return x.exp().sum(dim=1) ... return x.exp().sum(dim=1)
>>> inputs = torch.rand(4, 4) >>> inputs = torch.rand(4, 4)
>>> v = torch.ones(4) >>> v = torch.ones(4)
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
@ -258,7 +267,7 @@ def vjp(func, inputs, v=None, create_graph=False, strict=False):
[1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>)) [1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>))
>>> def adder(x, y): >>> def adder(x, y):
... return 2 * x + 3 * y ... return 2 * x + 3 * y
>>> inputs = (torch.rand(2), torch.rand(2)) >>> inputs = (torch.rand(2), torch.rand(2))
>>> v = torch.ones(2) >>> v = torch.ones(2)
>>> vjp(adder, inputs, v) >>> vjp(adder, inputs, v)
@ -335,8 +344,9 @@ def jvp(func, inputs, v=None, create_graph=False, strict=False):
Example: Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> def exp_reducer(x): >>> def exp_reducer(x):
... return x.exp().sum(dim=1) ... return x.exp().sum(dim=1)
>>> inputs = torch.rand(4, 4) >>> inputs = torch.rand(4, 4)
>>> v = torch.ones(4, 4) >>> v = torch.ones(4, 4)
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
@ -349,7 +359,7 @@ def jvp(func, inputs, v=None, create_graph=False, strict=False):
tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SqueezeBackward1>)) tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SqueezeBackward1>))
>>> def adder(x, y): >>> def adder(x, y):
... return 2 * x + 3 * y ... return 2 * x + 3 * y
>>> inputs = (torch.rand(2), torch.rand(2)) >>> inputs = (torch.rand(2), torch.rand(2))
>>> v = (torch.ones(2), torch.ones(2)) >>> v = (torch.ones(2), torch.ones(2))
>>> jvp(adder, inputs, v) >>> jvp(adder, inputs, v)
@ -536,8 +546,9 @@ def jacobian(func, inputs, create_graph=False, strict=False, vectorize=False, st
Example: Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> def exp_reducer(x): >>> def exp_reducer(x):
... return x.exp().sum(dim=1) ... return x.exp().sum(dim=1)
>>> inputs = torch.rand(2, 2) >>> inputs = torch.rand(2, 2)
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> jacobian(exp_reducer, inputs) >>> jacobian(exp_reducer, inputs)
@ -553,7 +564,7 @@ def jacobian(func, inputs, create_graph=False, strict=False, vectorize=False, st
[2.4369, 2.3799]]], grad_fn=<ViewBackward>) [2.4369, 2.3799]]], grad_fn=<ViewBackward>)
>>> def exp_adder(x, y): >>> def exp_adder(x, y):
... return 2 * x.exp() + 3 * y ... return 2 * x.exp() + 3 * y
>>> inputs = (torch.rand(2), torch.rand(2)) >>> inputs = (torch.rand(2), torch.rand(2))
>>> jacobian(exp_adder, inputs) >>> jacobian(exp_adder, inputs)
(tensor([[2.8052, 0.0000], (tensor([[2.8052, 0.0000],
@ -698,6 +709,7 @@ def jacobian(func, inputs, create_graph=False, strict=False, vectorize=False, st
return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple)) return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple))
def hessian(func, inputs, create_graph=False, strict=False, vectorize=False, outer_jacobian_strategy="reverse-mode"): def hessian(func, inputs, create_graph=False, strict=False, vectorize=False, outer_jacobian_strategy="reverse-mode"):
r"""Function that computes the Hessian of a given scalar function. r"""Function that computes the Hessian of a given scalar function.
@ -746,8 +758,9 @@ def hessian(func, inputs, create_graph=False, strict=False, vectorize=False, out
Example: Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> def pow_reducer(x): >>> def pow_reducer(x):
... return x.pow(3).sum() ... return x.pow(3).sum()
>>> inputs = torch.rand(2, 2) >>> inputs = torch.rand(2, 2)
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> hessian(pow_reducer, inputs) >>> hessian(pow_reducer, inputs)
@ -772,7 +785,7 @@ def hessian(func, inputs, create_graph=False, strict=False, vectorize=False, out
>>> def pow_adder_reducer(x, y): >>> def pow_adder_reducer(x, y):
... return (2 * x.pow(2) + 3 * y.pow(2)).sum() ... return (2 * x.pow(2) + 3 * y.pow(2)).sum()
>>> inputs = (torch.rand(2), torch.rand(2)) >>> inputs = (torch.rand(2), torch.rand(2))
>>> hessian(pow_adder_reducer, inputs) >>> hessian(pow_adder_reducer, inputs)
((tensor([[4., 0.], ((tensor([[4., 0.],
@ -849,8 +862,9 @@ def vhp(func, inputs, v=None, create_graph=False, strict=False):
Example: Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> def pow_reducer(x): >>> def pow_reducer(x):
... return x.pow(3).sum() ... return x.pow(3).sum()
>>> inputs = torch.rand(2, 2) >>> inputs = torch.rand(2, 2)
>>> v = torch.ones(2, 2) >>> v = torch.ones(2, 2)
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
@ -863,7 +877,7 @@ def vhp(func, inputs, v=None, create_graph=False, strict=False):
tensor([[1.0689, 1.2431], tensor([[1.0689, 1.2431],
[3.0989, 4.4456]], grad_fn=<MulBackward0>)) [3.0989, 4.4456]], grad_fn=<MulBackward0>))
>>> def pow_adder_reducer(x, y): >>> def pow_adder_reducer(x, y):
... return (2 * x.pow(2) + 3 * y.pow(2)).sum() ... return (2 * x.pow(2) + 3 * y.pow(2)).sum()
>>> inputs = (torch.rand(2), torch.rand(2)) >>> inputs = (torch.rand(2), torch.rand(2))
>>> v = (torch.zeros(2), torch.ones(2)) >>> v = (torch.zeros(2), torch.ones(2))
>>> vhp(pow_adder_reducer, inputs, v) >>> vhp(pow_adder_reducer, inputs, v)
@ -939,8 +953,9 @@ def hvp(func, inputs, v=None, create_graph=False, strict=False):
Example: Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> def pow_reducer(x): >>> def pow_reducer(x):
... return x.pow(3).sum() ... return x.pow(3).sum()
>>> inputs = torch.rand(2, 2) >>> inputs = torch.rand(2, 2)
>>> v = torch.ones(2, 2) >>> v = torch.ones(2, 2)
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
@ -956,7 +971,7 @@ def hvp(func, inputs, v=None, create_graph=False, strict=False):
>>> def pow_adder_reducer(x, y): >>> def pow_adder_reducer(x, y):
... return (2 * x.pow(2) + 3 * y.pow(2)).sum() ... return (2 * x.pow(2) + 3 * y.pow(2)).sum()
>>> inputs = (torch.rand(2), torch.rand(2)) >>> inputs = (torch.rand(2), torch.rand(2))
>>> v = (torch.zeros(2), torch.ones(2)) >>> v = (torch.zeros(2), torch.ones(2))
>>> hvp(pow_adder_reducer, inputs, v) >>> hvp(pow_adder_reducer, inputs, v)

View File

@ -120,7 +120,7 @@ class no_grad(_DecoratorContextManager):
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> x = torch.tensor([1.], requires_grad=True) >>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad(): >>> with torch.no_grad():
... y = x * 2 ... y = x * 2
>>> y.requires_grad >>> y.requires_grad
False False
>>> @torch.no_grad() >>> @torch.no_grad()
@ -166,8 +166,8 @@ class enable_grad(_DecoratorContextManager):
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> x = torch.tensor([1.], requires_grad=True) >>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad(): >>> with torch.no_grad():
... with torch.enable_grad(): ... with torch.enable_grad():
... y = x * 2 ... y = x * 2
>>> y.requires_grad >>> y.requires_grad
True True
>>> y.backward() >>> y.backward()
@ -217,7 +217,7 @@ class set_grad_enabled(_DecoratorContextManager):
>>> x = torch.tensor([1.], requires_grad=True) >>> x = torch.tensor([1.], requires_grad=True)
>>> is_train = False >>> is_train = False
>>> with torch.set_grad_enabled(is_train): >>> with torch.set_grad_enabled(is_train):
... y = x * 2 ... y = x * 2
>>> y.requires_grad >>> y.requires_grad
False False
>>> _ = torch.set_grad_enabled(True) >>> _ = torch.set_grad_enabled(True)
@ -270,10 +270,11 @@ class inference_mode(_DecoratorContextManager):
mode (bool): Flag whether to enable or disable inference mode mode (bool): Flag whether to enable or disable inference mode
Example:: Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> import torch >>> import torch
>>> x = torch.ones(1, 2, 3, requires_grad=True) >>> x = torch.ones(1, 2, 3, requires_grad=True)
>>> with torch.inference_mode(): >>> with torch.inference_mode():
... y = x * x ... y = x * x
>>> y.requires_grad >>> y.requires_grad
False False
>>> # xdoctest: +SKIP("want string isnt quite right") >>> # xdoctest: +SKIP("want string isnt quite right")
@ -283,7 +284,7 @@ class inference_mode(_DecoratorContextManager):
RuntimeError: Inference tensors do not track version counter. RuntimeError: Inference tensors do not track version counter.
>>> @torch.inference_mode() >>> @torch.inference_mode()
... def func(x): ... def func(x):
... return x * x ... return x * x
>>> out = func(x) >>> out = func(x)
>>> out.requires_grad >>> out.requires_grad
False False

View File

@ -48,6 +48,7 @@ class saved_tensors_hooks():
Example:: Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> def pack_hook(x): >>> def pack_hook(x):
... print("Packing", x) ... print("Packing", x)
... return x ... return x
@ -107,6 +108,7 @@ class save_on_cpu(saved_tensors_hooks):
Example:: Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> a = torch.randn(5, requires_grad=True, device="cuda") >>> a = torch.randn(5, requires_grad=True, device="cuda")
>>> b = torch.randn(5, requires_grad=True, device="cuda") >>> b = torch.randn(5, requires_grad=True, device="cuda")
>>> c = torch.randn(5, requires_grad=True, device="cuda") >>> c = torch.randn(5, requires_grad=True, device="cuda")
@ -160,6 +162,7 @@ def disable_saved_tensors_hooks(error_message):
Example:: Example::
>>> # xdoctest: +SKIP(failing)
>>> message = "saved tensors default hooks are disabled" >>> message = "saved tensors default hooks are disabled"
>>> with torch.autograd.graph.disable_saved_tensors_hooks(message): >>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
... # Raises RuntimeError: saved tensors default hooks are disabled ... # Raises RuntimeError: saved tensors default hooks are disabled

View File

@ -121,6 +121,7 @@ class profile(object):
Example: Example:
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
>>> x = torch.randn((1, 1), requires_grad=True) >>> x = torch.randn((1, 1), requires_grad=True)
>>> with torch.autograd.profiler.profile() as prof: >>> with torch.autograd.profiler.profile() as prof:
>>> for _ in range(100): # any normal python code, really! >>> for _ in range(100): # any normal python code, really!
@ -453,6 +454,7 @@ class record_function(_ContextDecorator):
non-distributed cases. non-distributed cases.
Example: Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
>>> x = torch.randn((1, 1), requires_grad=True) >>> x = torch.randn((1, 1), requires_grad=True)
>>> with torch.autograd.profiler.profile() as prof: >>> with torch.autograd.profiler.profile() as prof:
... y = x ** 2 ... y = x ** 2
@ -578,6 +580,7 @@ class emit_itt(object):
Example: Example:
>>> # xdoctest: +SKIP("Undefined variables") >>> # xdoctest: +SKIP("Undefined variables")
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
>>> with torch.autograd.profiler.emit_itt(): >>> with torch.autograd.profiler.emit_itt():
... model(x) ... model(x)
@ -646,8 +649,9 @@ class emit_nvtx(object):
Example: Example:
>>> # xdoctest: +SKIP("undefined variables") >>> # xdoctest: +SKIP("undefined variables")
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
>>> with torch.cuda.profiler.profile(): >>> with torch.cuda.profiler.profile():
... model(x) # Warmup CUDA memory allocator and profiler ... model(x) # Warmup CUDA memory allocator and profiler
... with torch.autograd.profiler.emit_nvtx(): ... with torch.autograd.profiler.emit_nvtx():
... model(x) ... model(x)

View File

@ -6,6 +6,7 @@ import re
__all__ : List[str] = [] __all__ : List[str] = []
class _CodeParser: class _CodeParser:
def __init__(self, code_string: str): def __init__(self, code_string: str):
optional_ws = r"\s*" optional_ws = r"\s*"
@ -37,6 +38,7 @@ class _CodeParser:
self.function_params = result["function_params"] self.function_params = result["function_params"]
self.function_body = result["function_body"] self.function_body = result["function_body"]
class _JittedFunction: class _JittedFunction:
def __init__(self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs): def __init__(self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs):
self.code_string = code_string self.code_string = code_string
@ -135,6 +137,7 @@ def _create_jit_fn(code_string: str, **kwargs) -> Callable:
return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs) return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs)
def _create_multi_output_jit_fn(code_string: str, num_outputs: int, **kwargs) -> Callable: def _create_multi_output_jit_fn(code_string: str, num_outputs: int, **kwargs) -> Callable:
""" """
Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs. Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs.

View File

@ -825,6 +825,7 @@ class DistributedDataParallel(Module):
Example:: Example::
Below is an example of a noop hook that returns the same tensor. Below is an example of a noop hook that returns the same tensor.
>>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
>>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
>>> fut = torch.futures.Future() >>> fut = torch.futures.Future()
>>> fut.set_result(bucket.buffer()) >>> fut.set_result(bucket.buffer())
@ -837,6 +838,7 @@ class DistributedDataParallel(Module):
Below is an example of a Parallel SGD algorithm where gradients are encoded before Below is an example of a Parallel SGD algorithm where gradients are encoded before
allreduce, and then decoded after allreduce. allreduce, and then decoded after allreduce.
>>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
>>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
>>> encoded_tensor = encode(bucket.buffer()) # encode gradients >>> encoded_tensor = encode(bucket.buffer()) # encode gradients
>>> fut = torch.distributed.all_reduce(encoded_tensor).get_future() >>> fut = torch.distributed.all_reduce(encoded_tensor).get_future()

View File

@ -195,6 +195,7 @@ def checkpoint(module: nn.Module, *, use_reentrant: bool = True) -> nn.Module:
autograd. autograd.
Example:: Example::
>>> # xdoctest: +SKIP
>>> import torch.nn as nn >>> import torch.nn as nn
>>> >>>
>>> class MyModel(nn.Module): >>> class MyModel(nn.Module):

View File

@ -41,6 +41,7 @@ def contract(state_cls: Type[_State] = _State):
``func.state(module)``. ``func.state(module)``.
Example:: Example::
>>> # xdoctest: +SKIP
>>> import torch.nn as nn >>> import torch.nn as nn
>>> >>>
>>> class MyModel(nn.Module): >>> class MyModel(nn.Module):

View File

@ -18,6 +18,7 @@ def replicate(
module (torch.nn.Module): module to replicate module (torch.nn.Module): module to replicate
Example:: Example::
>>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
>>> module = nn.Linear(3, 3) >>> module = nn.Linear(3, 3)
>>> replicate(module) >>> replicate(module)
""" """

View File

@ -427,6 +427,7 @@ def custom_sharded_op_impl(func):
parameters, the function provided will be invoked for that operator. parameters, the function provided will be invoked for that operator.
Example:: Example::
>>> # xdoctest: +SKIP
>>> @custom_sharded_op_impl(torch.nn.functional.linear) >>> @custom_sharded_op_impl(torch.nn.functional.linear)
>>> def my_custom_sharded_linear(types, args, kwargs, process_group): >>> def my_custom_sharded_linear(types, args, kwargs, process_group):
>>> ... >>> ...

View File

@ -805,9 +805,9 @@ class ShardedTensor(ShardedTensorBase):
tensor stored in the current rank. tensor stored in the current rank.
Examples: Examples:
>>> # xdoctest: +SKIP
>>> # All tensors below are of torch.int64 type. >>> # All tensors below are of torch.int64 type.
>>> # We have 2 process groups, 2 ranks. >>> # We have 2 process groups, 2 ranks.
>>> # xdoctest: +SKIP
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> local_tensor = torch.unsqueeze(torch.cat([tensor, tensor + 2])) >>> local_tensor = torch.unsqueeze(torch.cat([tensor, tensor + 2]))
>>> local_tensor >>> local_tensor
@ -955,8 +955,8 @@ class ShardedTensor(ShardedTensorBase):
A :class:`ShardedTensor` object whose local shards are resharded. A :class:`ShardedTensor` object whose local shards are resharded.
Examples: Examples:
>>> # We have 2 process groups, 2 ranks.
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.arange(4, dtype=torch.int64) + 1 + 2 * rank >>> tensor = torch.arange(4, dtype=torch.int64) + 1 + 2 * rank
>>> tensor = torch.stack([tensor, tensor]) >>> tensor = torch.stack([tensor, tensor])
>>> tensor >>> tensor

View File

@ -36,6 +36,7 @@ class ShardingPlan(object):
Suppose we want to shard a module with two linear layers and then run it with DDP; we also Suppose we want to shard a module with two linear layers and then run it with DDP; we also
want to convert the output of the second linear layer back to DDP. We can do it as follows: want to convert the output of the second linear layer back to DDP. We can do it as follows:
>>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
>>> class MyModule(nn.Module): >>> class MyModule(nn.Module):
>>> def __init__(self): >>> def __init__(self):
>>> super().__init__() >>> super().__init__()

View File

@ -54,6 +54,7 @@ class MemoryTracker:
Example usage: Example usage:
>>> # xdoctest: +SKIP(failing)
>>> net.cuda() >>> net.cuda()
>>> input = input.cuda() >>> input = input.cuda()

View File

@ -25,6 +25,7 @@ if is_available():
DistAutogradContext, DistAutogradContext,
) )
class context(object): class context(object):
''' '''
Context object to wrap forward and backward passes when using Context object to wrap forward and backward passes when using
@ -35,13 +36,13 @@ class context(object):
autograd pass. autograd pass.
Example:: Example::
>>> import torch.distributed.autograd as dist_autograd
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id: >>> with dist_autograd.context() as context_id:
>>> t1 = torch.rand((3, 3), requires_grad=True) >>> t1 = torch.rand((3, 3), requires_grad=True)
>>> t2 = torch.rand((3, 3), requires_grad=True) >>> t2 = torch.rand((3, 3), requires_grad=True)
>>> loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum() >>> loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
>>> dist_autograd.backward(context_id, [loss]) >>> dist_autograd.backward(context_id, [loss])
''' '''
def __enter__(self): def __enter__(self):
self.autograd_context = _new_context() self.autograd_context = _new_context()

View File

@ -202,6 +202,7 @@ def load_sharded_optimizer_state_dict(
""" """
Loads a state_dict to be used in conjunction with FSDP sharded optimizer state. Loads a state_dict to be used in conjunction with FSDP sharded optimizer state.
This is the currently recommended way to checkpoint FSDP. This is the currently recommended way to checkpoint FSDP.
>>> # xdoctest: +SKIP
>>> import torch.distributed.checkpoint as dist_cp >>> import torch.distributed.checkpoint as dist_cp
>>> import spmd.checkpoint as sp_cp >>> import spmd.checkpoint as sp_cp
>>> # Save >>> # Save
@ -224,7 +225,7 @@ def load_sharded_optimizer_state_dict(
>>> with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT): >>> with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT):
>>> model_state_dict = model_tp.state_dict() >>> model_state_dict = model_tp.state_dict()
>>> checkpoint = { >>> checkpoint = {
>>> "model" = model_state_dict >>> "model": model_state_dict
>>> } >>> }
>>> dist_cp.load_state_dict( >>> dist_cp.load_state_dict(
>>> state_dict=checkpoint, >>> state_dict=checkpoint,
@ -237,13 +238,13 @@ def load_sharded_optimizer_state_dict(
>>> model_state_dict, >>> model_state_dict,
>>> optimizer_key="optimizer", >>> optimizer_key="optimizer",
>>> storage_reader=dist_cp.FileSystemReader("checkpoint"), >>> storage_reader=dist_cp.FileSystemReader("checkpoint"),
>>> ) >>> )
>>> >>>
>>> flattened_osd = FSDP.flatten_sharded_optim_state_dict( >>> flattened_osd = FSDP.flatten_sharded_optim_state_dict(
>>> optim_state["optimizer"], model, optim >>> optim_state["optimizer"], model, optim
>>> ) >>> )
>>> >>>
>>> optim.load_state_dict(flattened_osd) >>> optim.load_state_dict(flattened_osd)
""" """
metadata = storage_reader.read_metadata() metadata = storage_reader.read_metadata()

View File

@ -1940,6 +1940,7 @@ def _tensor_to_object(tensor, tensor_size):
buf = tensor.numpy().tobytes()[:tensor_size] buf = tensor.numpy().tobytes()[:tensor_size]
return _unpickler(io.BytesIO(buf)).load() return _unpickler(io.BytesIO(buf)).load()
def _check_for_nccl_backend(group): def _check_for_nccl_backend(group):
pg = group or _get_default_group() pg = group or _get_default_group()
# Gate PG wrapper check on Gloo availability. # Gate PG wrapper check on Gloo availability.
@ -1954,6 +1955,7 @@ def _check_for_nccl_backend(group):
pg.name() == Backend.NCCL pg.name() == Backend.NCCL
) )
@exception_handler @exception_handler
def all_gather_object(object_list, obj, group=None): def all_gather_object(object_list, obj, group=None):
""" """
@ -3060,7 +3062,7 @@ def all_to_all_single(
>>> scatter_list = list(input.chunk(world_size)) >>> scatter_list = list(input.chunk(world_size))
>>> gather_list = list(output.chunk(world_size)) >>> gather_list = list(output.chunk(world_size))
>>> for i in range(world_size): >>> for i in range(world_size):
>>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i) >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i)
>>> # Another example with uneven split >>> # Another example with uneven split
>>> input >>> input
@ -3179,7 +3181,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
>>> scatter_list = input >>> scatter_list = input
>>> gather_list = output >>> gather_list = output
>>> for i in range(world_size): >>> for i in range(world_size):
>>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i) >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)
>>> input >>> input
tensor([0, 1, 2, 3, 4, 5]) # Rank 0 tensor([0, 1, 2, 3, 4, 5]) # Rank 0

View File

@ -323,11 +323,10 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
Example:: Example::
>>> # xdoctest: +SKIP
>>> import torch.nn as nn >>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer >>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP >>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> # xdoctest: +SKIP
>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)]) >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank]) >>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer( >>> opt = ZeroRedundancyOptimizer(

View File

@ -30,10 +30,12 @@ def _prepare_input_validate(
func (Callable): Same input function with validation logic added. func (Callable): Same input function with validation logic added.
Example:: Example::
>>> # xdoctest: +SKIP(failing)
>>> @_prepare_input_validate >>> @_prepare_input_validate
>>> def make_input_shard_1d(args, kwargs): >>> def make_input_shard_1d(args, kwargs):
>>> ... >>> ...
>>> >>>
>>> # xdoctest: +SKIP(failing)
>>> input = torch.rand(...) >>> input = torch.rand(...)
>>> dtensor = make_input_shard_1d(input, device_mesh, 1) >>> dtensor = make_input_shard_1d(input, device_mesh, 1)
>>> # This will call '_prepare_input_validate' first >>> # This will call '_prepare_input_validate' first
@ -71,14 +73,18 @@ def _prepare_output_validate(
Inject common validation logics for _prepare_output funcs via this Inject common validation logics for _prepare_output funcs via this
decorator, including verifying that output needs to be a DTensor decorator, including verifying that output needs to be a DTensor
and only 1D Device Mesh is passed in. and only 1D Device Mesh is passed in.
Example:: Example::
>>> # xdoctest: +SKIP(failing)
>>> @_prepare_output_validate >>> @_prepare_output_validate
>>> def make_output_shard_1d(args, kwargs): >>> def make_output_shard_1d(args, kwargs):
>>> ... >>> ...
>>> >>>
>>> # xdoctest: +SKIP(failing)
>>> dt = distribute(tensor, device_mesh, [Shard(0)]) >>> dt = distribute(tensor, device_mesh, [Shard(0)])
>>> make_output_shard_1d(dt, device_mesh, 1) >>> make_output_shard_1d(dt, device_mesh, 1)
>>> # This will call '_prepare_output_validate' first >>> # This will call '_prepare_output_validate' first
Args: Args:
_prepare_output_func (Callable): The func we want to inject the _prepare_output_func (Callable): The func we want to inject the
validation into. validation into.

View File

@ -61,7 +61,7 @@ def parallelize_module( # type: ignore[return]
Example:: Example::
>>> # xdoctest: +SKIP("distributed") >>> # xdoctest: +SKIP("distributed")
>>> from from torch.distributed._tensor.parallel import parallelize_module, PairwiseParallel >>> from torch.distributed._tensor.parallel import parallelize_module, PairwiseParallel
>>> >>>
>>> # Define the module. >>> # Define the module.
>>> m = Model(...) >>> m = Model(...)

View File

@ -8,6 +8,7 @@ from torch.distributions.utils import broadcast_all, lazy_property
__all__ = ['VonMises'] __all__ = ['VonMises']
def _eval_poly(y, coef): def _eval_poly(y, coef):
coef = list(coef) coef = list(coef)
result = coef.pop() result = coef.pop()
@ -77,7 +78,7 @@ class VonMises(Distribution):
Example:: Example::
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0])) >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample() # von Mises distributed with loc=1 and concentration=1 >>> m.sample() # von Mises distributed with loc=1 and concentration=1
tensor([1.9777]) tensor([1.9777])
:param torch.Tensor loc: an angle in radians. :param torch.Tensor loc: an angle in radians.
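The `+IGNORE_WANT` directive used in this hunk is worth spelling out: xdoctest runs the statement but skips comparing its printed output, which keeps stochastic examples such as `m.sample()` checkable. A self-contained sketch along the lines of the VonMises docstring (the wrapper function is made up):

    import torch
    from torch.distributions import VonMises

    def sample_once() -> torch.Tensor:
        """
        Example::

            >>> # xdoctest: +IGNORE_WANT("non-deterministic")
            >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
            >>> m.sample()
            tensor([1.9777])
        """
        return VonMises(torch.tensor([1.0]), torch.tensor([1.0])).sample()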

View File

@ -159,7 +159,7 @@ def split(
Example:: Example::
>>> a = torch.arange(10).reshape(5,2) >>> a = torch.arange(10).reshape(5, 2)
>>> a >>> a
tensor([[0, 1], tensor([[0, 1],
[2, 3], [2, 3],
@ -172,7 +172,7 @@ def split(
tensor([[4, 5], tensor([[4, 5],
[6, 7]]), [6, 7]]),
tensor([[8, 9]])) tensor([[8, 9]]))
>>> torch.split(a, [1,4]) >>> torch.split(a, [1, 4])
(tensor([[0, 1]]), (tensor([[0, 1]]),
tensor([[2, 3], tensor([[2, 3],
[4, 5], [4, 5],
@ -267,18 +267,18 @@ def einsum(*args: Any) -> Tensor:
Examples:: Examples::
>>> # trace
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> # trace
>>> torch.einsum('ii', torch.randn(4, 4)) >>> torch.einsum('ii', torch.randn(4, 4))
tensor(-1.2104) tensor(-1.2104)
>>> # diagonal
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> # diagonal
>>> torch.einsum('ii->i', torch.randn(4, 4)) >>> torch.einsum('ii->i', torch.randn(4, 4))
tensor([-0.1034, 0.7952, -0.2433, 0.4545]) tensor([-0.1034, 0.7952, -0.2433, 0.4545])
>>> # outer product
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> # outer product
>>> x = torch.randn(5) >>> x = torch.randn(5)
>>> y = torch.randn(4) >>> y = torch.randn(4)
>>> torch.einsum('i,j->ij', x, y) >>> torch.einsum('i,j->ij', x, y)
@ -288,10 +288,10 @@ def einsum(*args: Any) -> Tensor:
[ 0.1713, -0.4291, -0.5802, 0.7350], [ 0.1713, -0.4291, -0.5802, 0.7350],
[ 0.5704, -1.4290, -1.9323, 2.4480]]) [ 0.5704, -1.4290, -1.9323, 2.4480]])
>>> # batch matrix multiplication
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> As = torch.randn(3,2,5) >>> # batch matrix multiplication
>>> Bs = torch.randn(3,5,4) >>> As = torch.randn(3, 2, 5)
>>> Bs = torch.randn(3, 5, 4)
>>> torch.einsum('bij,bjk->bik', As, Bs) >>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
[-1.6706, -0.8097, -0.8025, -2.1183]], [-1.6706, -0.8097, -0.8025, -2.1183]],
@ -302,8 +302,8 @@ def einsum(*args: Any) -> Tensor:
[[ 2.8153, 1.8787, -4.3839, -1.2112], [[ 2.8153, 1.8787, -4.3839, -1.2112],
[ 0.3728, -2.1131, 0.0921, 0.8305]]]) [ 0.3728, -2.1131, 0.0921, 0.8305]]])
>>> # with sublist format and ellipsis
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> # with sublist format and ellipsis
>>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
[-1.6706, -0.8097, -0.8025, -2.1183]], [-1.6706, -0.8097, -0.8025, -2.1183]],
@ -320,9 +320,9 @@ def einsum(*args: Any) -> Tensor:
torch.Size([2, 3, 5, 4]) torch.Size([2, 3, 5, 4])
>>> # equivalent to torch.nn.functional.bilinear >>> # equivalent to torch.nn.functional.bilinear
>>> A = torch.randn(3,5,4) >>> A = torch.randn(3, 5, 4)
>>> l = torch.randn(2,5) >>> l = torch.randn(2, 5)
>>> r = torch.randn(2,4) >>> r = torch.randn(2, 4)
>>> torch.einsum('bn,anm,bm->ba', l, A, r) >>> torch.einsum('bn,anm,bm->ba', l, A, r)
tensor([[-0.3430, -5.2405, 0.4494], tensor([[-0.3430, -5.2405, 0.4494],
[ 0.3311, 5.5201, -3.0356]]) [ 0.3311, 5.5201, -3.0356]])
@ -1253,7 +1253,7 @@ def atleast_1d(*tensors):
tensor([1.]) tensor([1.])
>>> x = torch.tensor(0.5) >>> x = torch.tensor(0.5)
>>> y = torch.tensor(1.) >>> y = torch.tensor(1.)
>>> torch.atleast_1d((x,y)) >>> torch.atleast_1d((x, y))
(tensor([0.5000]), tensor([1.])) (tensor([0.5000]), tensor([1.]))
""" """
# This wrapper exists to support variadic args. # This wrapper exists to support variadic args.
@ -1282,7 +1282,7 @@ def atleast_2d(*tensors):
tensor(1.) tensor(1.)
>>> torch.atleast_2d(x) >>> torch.atleast_2d(x)
tensor([[1.]]) tensor([[1.]])
>>> x = torch.arange(4).view(2,2) >>> x = torch.arange(4).view(2, 2)
>>> x >>> x
tensor([[0, 1], tensor([[0, 1],
[2, 3]]) [2, 3]])
@ -1291,7 +1291,7 @@ def atleast_2d(*tensors):
[2, 3]]) [2, 3]])
>>> x = torch.tensor(0.5) >>> x = torch.tensor(0.5)
>>> y = torch.tensor(1.) >>> y = torch.tensor(1.)
>>> torch.atleast_2d((x,y)) >>> torch.atleast_2d((x, y))
(tensor([[0.5000]]), tensor([[1.]])) (tensor([[0.5000]]), tensor([[1.]]))
""" """
# This wrapper exists to support variadic args. # This wrapper exists to support variadic args.
@ -1320,7 +1320,7 @@ def atleast_3d(*tensors):
tensor(0.5000) tensor(0.5000)
>>> torch.atleast_3d(x) >>> torch.atleast_3d(x)
tensor([[[0.5000]]]) tensor([[[0.5000]]])
>>> y = torch.arange(4).view(2,2) >>> y = torch.arange(4).view(2, 2)
>>> y >>> y
tensor([[0, 1], tensor([[0, 1],
[2, 3]]) [2, 3]])
@ -1337,7 +1337,7 @@ def atleast_3d(*tensors):
tensor([[[1]]]) tensor([[[1]]])
>>> x = torch.tensor(0.5) >>> x = torch.tensor(0.5)
>>> y = torch.tensor(1.) >>> y = torch.tensor(1.)
>>> torch.atleast_3d((x,y)) >>> torch.atleast_3d((x, y))
(tensor([[[0.5000]]]), tensor([[[1.]]])) (tensor([[[0.5000]]]), tensor([[[1.]]]))
""" """
# This wrapper exists to support variadic args. # This wrapper exists to support variadic args.
@ -1464,15 +1464,15 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
tensor(4.) tensor(4.)
>>> torch.norm(b, float('inf')) >>> torch.norm(b, float('inf'))
tensor(4.) tensor(4.)
>>> c = torch.tensor([[ 1, 2, 3],[-1, 1, 4]] , dtype= torch.float) >>> c = torch.tensor([[ 1, 2, 3], [-1, 1, 4]] , dtype=torch.float)
>>> torch.norm(c, dim=0) >>> torch.norm(c, dim=0)
tensor([1.4142, 2.2361, 5.0000]) tensor([1.4142, 2.2361, 5.0000])
>>> torch.norm(c, dim=1) >>> torch.norm(c, dim=1)
tensor([3.7417, 4.2426]) tensor([3.7417, 4.2426])
>>> torch.norm(c, p=1, dim=1) >>> torch.norm(c, p=1, dim=1)
tensor([6., 6.]) tensor([6., 6.])
>>> d = torch.arange(8, dtype= torch.float).reshape(2,2,2) >>> d = torch.arange(8, dtype=torch.float).reshape(2, 2, 2)
>>> torch.norm(d, dim=(1,2)) >>> torch.norm(d, dim=(1, 2))
tensor([ 3.7417, 11.2250]) tensor([ 3.7417, 11.2250])
>>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :]) >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
(tensor(3.7417), tensor(11.2250)) (tensor(3.7417), tensor(11.2250))
@ -1604,6 +1604,7 @@ def chain_matmul(*matrices, out=None):
Example:: Example::
>>> # xdoctest: +SKIP
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> a = torch.randn(3, 4) >>> a = torch.randn(3, 4)
>>> b = torch.randn(4, 5) >>> b = torch.randn(4, 5)
@ -1720,7 +1721,7 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
[ 3, 3, 3]], dtype=torch.int32) [ 3, 3, 3]], dtype=torch.int32)
>>> A_LU, pivots, info = torch.lu(A, get_infos=True) >>> A_LU, pivots, info = torch.lu(A, get_infos=True)
>>> if info.nonzero().size(0) == 0: >>> if info.nonzero().size(0) == 0:
... print('LU factorization succeeded for all samples!') ... print('LU factorization succeeded for all samples!')
LU factorization succeeded for all samples! LU factorization succeeded for all samples!
""" """
# If get_infos is True, then we don't need to check for errors and vice versa # If get_infos is True, then we don't need to check for errors and vice versa

View File

@ -144,6 +144,7 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
on those futures independently. on those futures independently.
Example:: Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
>>> def callback(fut): >>> def callback(fut):
... print(f"RPC return value is {fut.wait()}.") ... print(f"RPC return value is {fut.wait()}.")
>>> fut = torch.futures.Future() >>> fut = torch.futures.Future()
@ -191,8 +192,9 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
for handling completion/waiting on those futures independently. for handling completion/waiting on those futures independently.
Example:: Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
>>> def callback(fut): >>> def callback(fut):
... print(f"This will run after the future has finished.") ... print("This will run after the future has finished.")
... print(fut.wait()) ... print(fut.wait())
>>> fut = torch.futures.Future() >>> fut = torch.futures.Future()
>>> fut.add_done_callback(callback) >>> fut.add_done_callback(callback)
@ -223,6 +225,7 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
result (object): the result object of this ``Future``. result (object): the result object of this ``Future``.
Example:: Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
>>> import threading >>> import threading
>>> import time >>> import time
>>> def slow_set_future(fut, value): >>> def slow_set_future(fut, value):
@ -251,6 +254,7 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
result (BaseException): the exception for this ``Future``. result (BaseException): the exception for this ``Future``.
Example:: Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
>>> fut = torch.futures.Future() >>> fut = torch.futures.Future()
>>> fut.set_exception(ValueError("foo")) >>> fut.set_exception(ValueError("foo"))
>>> fut.wait() >>> fut.wait()
@ -281,6 +285,7 @@ def collect_all(futures: List[Future]) -> Future[List[Future]]:
in Futures. in Futures.
Example:: Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
>>> fut0 = torch.futures.Future() >>> fut0 = torch.futures.Future()
>>> fut1 = torch.futures.Future() >>> fut1 = torch.futures.Future()
>>> fut = torch.futures.collect_all([fut0, fut1]) >>> fut = torch.futures.collect_all([fut0, fut1])
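Every Future example above is gated behind TORCH_DOCTEST_FUTURES, so a small ungated sketch of the same API may be useful; it only exercises calls that appear in the docstrings (`then`, `set_result`, `wait`, `collect_all`):

    import torch

    fut0 = torch.futures.Future()
    fut1 = torch.futures.Future()

    # then() chains a callback that receives the completed future.
    chained = fut0.then(lambda fut: fut.wait() + 1)

    fut0.set_result(41)
    fut1.set_result(torch.ones(2))

    print(chained.wait())                        # 42
    all_done = torch.futures.collect_all([fut0, fut1])
    print([f.wait() for f in all_done.wait()])   # [41, tensor([1., 1.])]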

View File

@ -36,10 +36,11 @@ class Dispatcher(object):
return self return self
return _ return _
class VarDispatcher(Dispatcher): class VarDispatcher(Dispatcher):
""" A dispatcher that calls functions with variable names """ A dispatcher that calls functions with variable names
>>> d = VarDispatcher('d')
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> d = VarDispatcher('d')
>>> x = var('x') >>> x = var('x')
>>> @d.register('inc', x) >>> @d.register('inc', x)
... def f(x): ... def f(x):
@ -58,8 +59,6 @@ class VarDispatcher(Dispatcher):
return func(**d) return func(**d)
global_namespace = {} # type: ignore[var-annotated] global_namespace = {} # type: ignore[var-annotated]

View File

@ -7,11 +7,11 @@ def unifiable(cls):
This uses the type and __dict__ or __slots__ attributes to define the This uses the type and __dict__ or __slots__ attributes to define the
nature of the term nature of the term
See Also: See Also:
>>> # xdoctest: +SKIP
>>> class A(object): >>> class A(object):
... def __init__(self, a, b): ... def __init__(self, a, b):
... self.a = a ... self.a = a
... self.b = b ... self.b = b
>>> # xdoctest: +SKIP
>>> unifiable(A) >>> unifiable(A)
<class 'unification.more.A'> <class 'unification.more.A'>
>>> x = var('x') >>> x = var('x')
@ -33,13 +33,13 @@ def unifiable(cls):
def reify_object(o, s): def reify_object(o, s):
""" Reify a Python object with a substitution """ Reify a Python object with a substitution
>>> # xdoctest: +SKIP
>>> class Foo(object): >>> class Foo(object):
... def __init__(self, a, b): ... def __init__(self, a, b):
... self.a = a ... self.a = a
... self.b = b ... self.b = b
... def __str__(self): ... def __str__(self):
... return "Foo(%s, %s)"%(str(self.a), str(self.b)) ... return "Foo(%s, %s)"%(str(self.a), str(self.b))
>>> # xdoctest: +SKIP
>>> x = var('x') >>> x = var('x')
>>> f = Foo(1, x) >>> f = Foo(1, x)
>>> print(f) >>> print(f)
@ -88,13 +88,13 @@ def _reify(o, s):
def unify_object(u, v, s): def unify_object(u, v, s):
""" Unify two Python objects """ Unify two Python objects
Unifies their type and ``__dict__`` attributes Unifies their type and ``__dict__`` attributes
>>> # xdoctest: +SKIP
>>> class Foo(object): >>> class Foo(object):
... def __init__(self, a, b): ... def __init__(self, a, b):
... self.a = a ... self.a = a
... self.b = b ... self.b = b
... def __str__(self): ... def __str__(self):
... return "Foo(%s, %s)"%(str(self.a), str(self.b)) ... return "Foo(%s, %s)"%(str(self.a), str(self.b))
>>> # xdoctest: +SKIP
>>> x = var('x') >>> x = var('x')
>>> f = Foo(1, x) >>> f = Foo(1, x)
>>> g = Foo(1, 2) >>> g = Foo(1, 2)
@ -110,6 +110,7 @@ def unify_object(u, v, s):
else: else:
return unify(u.__dict__, v.__dict__, s) return unify(u.__dict__, v.__dict__, s)
@dispatch(slice, slice, dict) @dispatch(slice, slice, dict)
def _unify(u, v, s): def _unify(u, v, s):
""" Unify a Python ``slice`` object """ """ Unify a Python ``slice`` object """

View File

@ -13,35 +13,37 @@ def dispatch(*types, **kwargs):
Collects implementations based on the function name. Ignores namespaces. Collects implementations based on the function name. Ignores namespaces.
If ambiguous type signatures occur a warning is raised when the function is If ambiguous type signatures occur a warning is raised when the function is
defined suggesting the additional method to break the ambiguity. defined suggesting the additional method to break the ambiguity.
Examples
-------- Example:
>>> @dispatch(int) >>> # xdoctest: +SKIP
... def f(x): >>> @dispatch(int)
... return x + 1 ... def f(x):
>>> @dispatch(float) ... return x + 1
... def f(x): >>> @dispatch(float)
... return x - 1 ... def f(x):
>>> f(3) ... return x - 1
4 >>> # xdoctest: +SKIP
>>> f(3.0) >>> f(3)
2.0 4
>>> # Specify an isolated namespace with the namespace keyword argument >>> f(3.0)
>>> my_namespace = {} 2.0
>>> @dispatch(int, namespace=my_namespace) >>> # Specify an isolated namespace with the namespace keyword argument
... def foo(x): >>> my_namespace = {}
... return x + 1 >>> @dispatch(int, namespace=my_namespace)
>>> # Dispatch on instance methods within classes ... def foo(x):
>>> class MyClass(object): ... return x + 1
... @dispatch(list) >>> # Dispatch on instance methods within classes
... def __init__(self, data): >>> class MyClass(object):
... self.data = data ... @dispatch(list)
... @dispatch(int) ... def __init__(self, data):
... def __init__(self, datum): ... self.data = data
... self.data = [datum] ... @dispatch(int)
>>> MyClass([1, 2, 3]).data ... def __init__(self, datum):
[1, 2, 3] ... self.data = [datum]
>>> MyClass(3).data >>> MyClass([1, 2, 3]).data
[3] [1, 2, 3]
>>> MyClass(3).data
[3]
""" """
namespace = kwargs.get('namespace', global_namespace) namespace = kwargs.get('namespace', global_namespace)

View File

@ -121,6 +121,7 @@ class Dispatcher(object):
def register(self, *types, **kwargs): def register(self, *types, **kwargs):
""" register dispatcher with new implementation """ register dispatcher with new implementation
>>> # xdoctest: +SKIP
>>> f = Dispatcher('f') >>> f = Dispatcher('f')
>>> @f.register(int) >>> @f.register(int)
... def inc(x): ... def inc(x):
@ -172,6 +173,7 @@ class Dispatcher(object):
def add(self, signature, func): def add(self, signature, func):
""" Add new types/method pair to dispatcher """ Add new types/method pair to dispatcher
>>> # xdoctest: +SKIP
>>> D = Dispatcher('add') >>> D = Dispatcher('add')
>>> D.add((int, int), lambda x, y: x + y) >>> D.add((int, int), lambda x, y: x + y)
>>> D.add((float, float), lambda x, y: x + y) >>> D.add((float, float), lambda x, y: x + y)

View File

@ -44,6 +44,7 @@ def isvariadic(obj):
Whether or not `obj` is variadic Whether or not `obj` is variadic
Examples Examples
-------- --------
>>> # xdoctest: +SKIP
>>> isvariadic(int) >>> isvariadic(int)
False False
>>> isvariadic(Variadic[int]) >>> isvariadic(Variadic[int])
@ -76,8 +77,8 @@ class Variadic(six.with_metaclass(VariadicSignatureMeta)):
representing a specific variadic signature. representing a specific variadic signature.
Examples Examples
-------- --------
>>> Variadic[int] # any number of int arguments
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> Variadic[int] # any number of int arguments
<class 'multipledispatch.variadic.Variadic[int]'> <class 'multipledispatch.variadic.Variadic[int]'>
>>> Variadic[(int, str)] # any number of one of int or str arguments >>> Variadic[(int, str)] # any number of one of int or str arguments
<class 'multipledispatch.variadic.Variadic[(int, str)]'> <class 'multipledispatch.variadic.Variadic[(int, str)]'>

View File

@ -7,6 +7,7 @@ __all__ = ('merge', 'merge_with', 'valmap', 'keymap', 'itemmap',
'valfilter', 'keyfilter', 'itemfilter', 'valfilter', 'keyfilter', 'itemfilter',
'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in') 'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in')
def _get_factory(f, kwargs): def _get_factory(f, kwargs):
factory = kwargs.pop('factory', dict) factory = kwargs.pop('factory', dict)
if kwargs: if kwargs:
@ -336,6 +337,7 @@ def get_in(keys, coll, default=None, no_default=False):
raise raise
return default return default
def getter(index): def getter(index):
if isinstance(index, list): if isinstance(index, list):
if len(index) == 1: if len(index) == 1:
@ -348,6 +350,7 @@ def getter(index):
else: else:
return operator.itemgetter(index) return operator.itemgetter(index)
def groupby(key, seq): def groupby(key, seq):
""" Group a collection by a key function """ Group a collection by a key function
@ -383,6 +386,7 @@ def groupby(key, seq):
rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined] rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined]
return rv return rv
def first(seq): def first(seq):
""" The first element in a sequence """ The first element in a sequence

View File

@ -36,8 +36,8 @@ def _toposort(edges):
edges - a dict of the form {a: {b, c}} where b and c depend on a edges - a dict of the form {a: {b, c}} where b and c depend on a
outputs: outputs:
L - an ordered list of nodes that satisfy the dependencies of edges L - an ordered list of nodes that satisfy the dependencies of edges
>>> _toposort({1: (2, 3), 2: (3, )})
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> _toposort({1: (2, 3), 2: (3, )})
[1, 2, 3] [1, 2, 3]
Closely follows the wikipedia page [2] Closely follows the wikipedia page [2]
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks", [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
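Since the `_toposort` docstring only gestures at Kahn's algorithm, a standalone sketch of the same idea (standard library only; the function name is mine) may make the `{a: {b, c}}` edge convention concrete:

    from collections import defaultdict, deque

    def toposort(edges):
        # edges: {a: {b, c}} where b and c depend on a, so a must come first.
        indegree = defaultdict(int)
        nodes = set(edges)
        for a, deps in edges.items():
            nodes.update(deps)
            for b in deps:
                indegree[b] += 1
        queue = deque(sorted(n for n in nodes if indegree[n] == 0))
        order = []
        while queue:
            n = queue.popleft()
            order.append(n)
            for b in sorted(edges.get(n, ())):
                indegree[b] -= 1
                if indegree[b] == 0:
                    queue.append(b)
        return order

    print(toposort({1: (2, 3), 2: (3,)}))   # [1, 2, 3]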

View File

@ -36,6 +36,7 @@ class Var(object):
def var(): def var():
return lambda *args: Var(*args) return lambda *args: Var(*args)
def vars(): def vars():
return lambda n: [var() for i in range(n)] return lambda n: [var() for i in range(n)]
@ -46,6 +47,7 @@ def isvar(v):
isvar isvar
@dispatch(object) # type: ignore[no-redef] @dispatch(object) # type: ignore[no-redef]
def isvar(o): def isvar(o):
return not not _glv and hashable(o) and o in _glv return not not _glv and hashable(o) and o in _glv
@ -53,23 +55,26 @@ def isvar(o):
@contextmanager @contextmanager
def variables(*variables): def variables(*variables):
""" Context manager for logic variables """
>>> from __future__ import with_statement Context manager for logic variables
>>> with variables(1):
... print(isvar(1)) Example:
True >>> # xdoctest: +SKIP("undefined vars")
>>> print(isvar(1)) >>> from __future__ import with_statement
False >>> with variables(1):
>>> # xdoctest: +SKIP("undefined vars") ... print(isvar(1))
>>> # Normal approach True
>>> from unification import unify >>> print(isvar(1))
>>> x = var('x') False
>>> unify(x, 1) >>> # Normal approach
{~x: 1} >>> from unification import unify
>>> # Context Manager approach >>> x = var('x')
>>> with variables('x'): >>> unify(x, 1)
... print(unify('x', 1)) {~x: 1}
{'x': 1} >>> # Context Manager approach
>>> with variables('x'):
... print(unify('x', 1))
{'x': 1}
""" """
old_global_logic_variables = _global_logic_variables.copy() old_global_logic_variables = _global_logic_variables.copy()
_global_logic_variables.update(set(variables)) _global_logic_variables.update(set(variables))

View File

@ -388,6 +388,7 @@ def list(github, force_reload=False, skip_validation=False, trust_repo=None):
list: The available callables entrypoint list: The available callables entrypoint
Example: Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True) >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
""" """
repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "list", verbose=True, repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "list", verbose=True,
@ -440,6 +441,7 @@ def help(github, model, force_reload=False, skip_validation=False, trust_repo=No
Default is ``None`` and will eventually change to ``"check"`` in v1.14. Default is ``None`` and will eventually change to ``"check"`` in v1.14.
Example: Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True)) >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
""" """
repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "help", verbose=True, repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "help", verbose=True,
@ -519,6 +521,7 @@ def load(repo_or_dir, model, *args, source='github', trust_repo=None, force_relo
``*args`` and ``**kwargs``. ``*args`` and ``**kwargs``.
Example: Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
>>> # from a github repo >>> # from a github repo
>>> repo = 'pytorch/vision' >>> repo = 'pytorch/vision'
>>> model = torch.hub.load(repo, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1') >>> model = torch.hub.load(repo, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1')
@ -586,6 +589,7 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True):
Default: True Default: True
Example: Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
>>> # xdoctest: +REQUIRES(POSIX) >>> # xdoctest: +REQUIRES(POSIX)
>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file') >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
@ -694,6 +698,7 @@ def load_state_dict_from_url(
file_name (str, optional): name for the downloaded file. Filename from ``url`` will be used if not set. file_name (str, optional): name for the downloaded file. Filename from ``url`` will be used if not set.
Example: Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
""" """

View File

@ -14,6 +14,7 @@ _impls: Set[str] = set()
# prim is reserved by TorchScript interpreter # prim is reserved by TorchScript interpreter
_reserved_namespaces = ['prim'] _reserved_namespaces = ['prim']
class Library: class Library:
""" """
A class to create libraries that can be used to register new operators or A class to create libraries that can be used to register new operators or
@ -57,6 +58,7 @@ class Library:
name of the operator as inferred from the schema. name of the operator as inferred from the schema.
Example:: Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LIBRARY)
>>> my_lib = Library("foo", "DEF") >>> my_lib = Library("foo", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor") >>> my_lib.define("sum(Tensor self) -> Tensor")
''' '''
@ -79,7 +81,7 @@ class Library:
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> my_lib = Library("aten", "IMPL") >>> my_lib = Library("aten", "IMPL")
>>> def div_cpu(self, other): >>> def div_cpu(self, other):
>>> return self * (1 / other) >>> return self * (1 / other)
>>> my_lib.impl("div.Tensor", "CPU") >>> my_lib.impl("div.Tensor", "CPU")
''' '''
if not callable(fn): if not callable(fn):
@ -105,7 +107,6 @@ class Library:
"'s behavior for {} dispatch key and {} namespace.". "'s behavior for {} dispatch key and {} namespace.".
format(name.split("::")[-1], dispatch_key, self.ns)) format(name.split("::")[-1], dispatch_key, self.ns))
if dispatch_key == "Meta": if dispatch_key == "Meta":
dispatcher_op_name = name dispatcher_op_name = name
if '::' not in dispatcher_op_name: if '::' not in dispatcher_op_name:
@ -135,6 +136,7 @@ class Library:
_impls.remove(key) _impls.remove(key)
del self.m del self.m
# decorator to register python functions for library ops # decorator to register python functions for library ops
# Note: this decorator API should remain consistent with `Library.impl` API # Note: this decorator API should remain consistent with `Library.impl` API
def impl(lib, name, dispatch_key=""): def impl(lib, name, dispatch_key=""):
@ -143,6 +145,7 @@ def impl(lib, name, dispatch_key=""):
return f return f
return wrap return wrap
def define(lib, schema, alias_analysis=""): def define(lib, schema, alias_analysis=""):
def wrap(f): def wrap(f):
name = lib.define(schema, alias_analysis) name = lib.define(schema, alias_analysis)
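The module-level `impl`/`define` decorators at the end of this file are easiest to grasp end to end; in the sketch below the namespace `mylib` and the op `twice` are invented for illustration, while the call pattern follows the Library docstrings shown above:

    import torch
    from torch.library import Library, impl

    my_lib = Library("mylib", "DEF")                 # new operator namespace
    my_lib.define("twice(Tensor self) -> Tensor")    # declare the schema

    @impl(my_lib, "twice", "CPU")                    # register a CPU kernel
    def twice_cpu(self):
        return self * 2

    x = torch.arange(3.0)
    print(torch.ops.mylib.twice(x))                  # tensor([0., 2., 4.])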

View File

@ -8,6 +8,7 @@ if TYPE_CHECKING:
STAT_EVENT = "torch.monitor.Stat" STAT_EVENT = "torch.monitor.Stat"
class TensorboardEventHandler: class TensorboardEventHandler:
""" """
TensorboardEventHandler is an event handler that will write known events to TensorboardEventHandler is an event handler that will write known events to
@ -16,11 +17,13 @@ class TensorboardEventHandler:
This currently only supports ``torch.monitor.Stat`` events which are logged This currently only supports ``torch.monitor.Stat`` events which are logged
as scalars. as scalars.
>>> # xdoctest: +REQUIRES(module:tensorboard) Example:
>>> from torch.utils.tensorboard import SummaryWriter >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_MONITOR)
>>> from torch.monitor import TensorboardEventHandler, register_event_handler >>> # xdoctest: +REQUIRES(module:tensorboard)
>>> writer = SummaryWriter("log_dir") >>> from torch.utils.tensorboard import SummaryWriter
>>> register_event_handler(TensorboardEventHandler(writer)) >>> from torch.monitor import TensorboardEventHandler, register_event_handler
>>> writer = SummaryWriter("log_dir")
>>> register_event_handler(TensorboardEventHandler(writer))
""" """
def __init__(self, writer: "SummaryWriter") -> None: def __init__(self, writer: "SummaryWriter") -> None:
""" """

View File

@ -2167,7 +2167,7 @@ def embedding(
>>> weights = torch.rand(10, 3) >>> weights = torch.rand(10, 3)
>>> weights[0, :].zero_() >>> weights[0, :].zero_()
>>> embedding_matrix = weights >>> embedding_matrix = weights
>>> input = torch.tensor([[0,2,0,5]]) >>> input = torch.tensor([[0, 2, 0, 5]])
>>> F.embedding(input, embedding_matrix, padding_idx=0) >>> F.embedding(input, embedding_matrix, padding_idx=0)
tensor([[[ 0.0000, 0.0000, 0.0000], tensor([[[ 0.0000, 0.0000, 0.0000],
[ 0.5609, 0.5384, 0.8720], [ 0.5609, 0.5384, 0.8720],
@ -2287,8 +2287,8 @@ def embedding_bag(
>>> # an Embedding module containing 10 tensors of size 3 >>> # an Embedding module containing 10 tensors of size 3
>>> embedding_matrix = torch.rand(10, 3) >>> embedding_matrix = torch.rand(10, 3)
>>> # a batch of 2 samples of 4 indices each >>> # a batch of 2 samples of 4 indices each
>>> input = torch.tensor([1,2,4,5,4,3,2,9]) >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
>>> offsets = torch.tensor([0,4]) >>> offsets = torch.tensor([0, 4])
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> F.embedding_bag(input, embedding_matrix, offsets) >>> F.embedding_bag(input, embedding_matrix, offsets)
tensor([[ 0.3397, 0.3552, 0.5545], tensor([[ 0.3397, 0.3552, 0.5545],
@ -2297,7 +2297,7 @@ def embedding_bag(
>>> # example with padding_idx >>> # example with padding_idx
>>> embedding_matrix = torch.rand(10, 3) >>> embedding_matrix = torch.rand(10, 3)
>>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9]) >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9])
>>> offsets = torch.tensor([0,4]) >>> offsets = torch.tensor([0, 4])
>>> F.embedding_bag(input, embedding_matrix, offsets, padding_idx=2, mode='sum') >>> F.embedding_bag(input, embedding_matrix, offsets, padding_idx=2, mode='sum')
tensor([[ 0.0000, 0.0000, 0.0000], tensor([[ 0.0000, 0.0000, 0.0000],
[-0.7082, 3.2145, -2.6251]]) [-0.7082, 3.2145, -2.6251]])
@ -2616,7 +2616,7 @@ def ctc_loss(
>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_() >>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
>>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long) >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
>>> input_lengths = torch.full((16,), 50, dtype=torch.long) >>> input_lengths = torch.full((16,), 50, dtype=torch.long)
>>> target_lengths = torch.randint(10,30,(16,), dtype=torch.long) >>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
>>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
>>> loss.backward() >>> loss.backward()
""" """

View File

@ -21,8 +21,8 @@ def conv1d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=
Examples:: Examples::
>>> input = torch.randn(1,1,3, requires_grad=True) >>> input = torch.randn(1, 1, 3, requires_grad=True)
>>> weight = torch.randn(1,1,1, requires_grad=True) >>> weight = torch.randn(1, 1, 1, requires_grad=True)
>>> output = F.conv1d(input, weight) >>> output = F.conv1d(input, weight)
>>> grad_output = torch.randn(output.shape) >>> grad_output = torch.randn(output.shape)
>>> grad_input = torch.autograd.grad(output, input, grad_output) >>> grad_input = torch.autograd.grad(output, input, grad_output)
@ -51,8 +51,8 @@ def conv1d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation
Examples:: Examples::
>>> input = torch.randn(1,1,3, requires_grad=True) >>> input = torch.randn(1, 1, 3, requires_grad=True)
>>> weight = torch.randn(1,1,1, requires_grad=True) >>> weight = torch.randn(1, 1, 1, requires_grad=True)
>>> output = F.conv1d(input, weight) >>> output = F.conv1d(input, weight)
>>> grad_output = torch.randn(output.shape) >>> grad_output = torch.randn(output.shape)
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
@ -84,8 +84,8 @@ def conv2d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=
Examples:: Examples::
>>> input = torch.randn(1,1,3,3, requires_grad=True) >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
>>> weight = torch.randn(1,1,1,2, requires_grad=True) >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
>>> output = F.conv2d(input, weight) >>> output = F.conv2d(input, weight)
>>> grad_output = torch.randn(output.shape) >>> grad_output = torch.randn(output.shape)
>>> grad_input = torch.autograd.grad(output, input, grad_output) >>> grad_input = torch.autograd.grad(output, input, grad_output)
@ -114,8 +114,8 @@ def conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation
Examples:: Examples::
>>> input = torch.randn(1,1,3,3, requires_grad=True) >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
>>> weight = torch.randn(1,1,1,2, requires_grad=True) >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
>>> output = F.conv2d(input, weight) >>> output = F.conv2d(input, weight)
>>> grad_output = torch.randn(output.shape) >>> grad_output = torch.randn(output.shape)
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP

View File

@ -14,6 +14,7 @@ __all__ = ['Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsi
'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Tanhshrink', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Tanhshrink',
'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax'] 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax']
class Threshold(Module): class Threshold(Module):
r"""Thresholds each element of the input Tensor. r"""Thresholds each element of the input Tensor.
@ -89,7 +90,7 @@ class ReLU(Module):
>>> m = nn.ReLU() >>> m = nn.ReLU()
>>> input = torch.randn(2).unsqueeze(0) >>> input = torch.randn(2).unsqueeze(0)
>>> output = torch.cat((m(input),m(-input))) >>> output = torch.cat((m(input), m(-input)))
""" """
__constants__ = ['inplace'] __constants__ = ['inplace']
inplace: bool inplace: bool

View File

@ -625,6 +625,7 @@ class SyncBatchNorm(_BatchNorm):
Examples:: Examples::
>>> # xdoctest: +SKIP
>>> # With Learnable Parameters >>> # With Learnable Parameters
>>> m = nn.SyncBatchNorm(100) >>> m = nn.SyncBatchNorm(100)
>>> # creating process group (optional) >>> # creating process group (optional)
@ -634,7 +635,6 @@ class SyncBatchNorm(_BatchNorm):
>>> # Note: every rank calls into new_group for every >>> # Note: every rank calls into new_group for every
>>> # process group created, even if that rank is not >>> # process group created, even if that rank is not
>>> # part of the group. >>> # part of the group.
>>> # xdoctest: +SKIP
>>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
>>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
>>> # Without Learnable Parameters >>> # Without Learnable Parameters

View File

@ -343,14 +343,14 @@ class GaussianNLLLoss(_Loss):
>>> loss = nn.GaussianNLLLoss() >>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True) >>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2) >>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 2, requires_grad=True) #heteroscedastic >>> var = torch.ones(5, 2, requires_grad=True) # heteroscedastic
>>> output = loss(input, target, var) >>> output = loss(input, target, var)
>>> output.backward() >>> output.backward()
>>> loss = nn.GaussianNLLLoss() >>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True) >>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2) >>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 1, requires_grad=True) #homoscedastic >>> var = torch.ones(5, 1, requires_grad=True) # homoscedastic
>>> output = loss(input, target, var) >>> output = loss(input, target, var)
>>> output.backward() >>> output.backward()

View File

@ -2082,8 +2082,8 @@ class Module:
>>> # xdoctest: +SKIP("undefined vars") >>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters(): >>> for name, param in self.named_parameters():
>>> if name in ['bias']: >>> if name in ['bias']:
>>> print(param.size()) >>> print(param.size())
""" """
gen = self._named_members( gen = self._named_members(
@ -2133,8 +2133,8 @@ class Module:
>>> # xdoctest: +SKIP("undefined vars") >>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers(): >>> for name, buf in self.named_buffers():
>>> if name in ['running_var']: >>> if name in ['running_var']:
>>> print(buf.size()) >>> print(buf.size())
""" """
gen = self._named_members( gen = self._named_members(

View File

@ -228,7 +228,7 @@ class MaxPool3d(_MaxPoolNd):
>>> m = nn.MaxPool3d(3, stride=2) >>> m = nn.MaxPool3d(3, stride=2)
>>> # pool of non-square window >>> # pool of non-square window
>>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2)) >>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2))
>>> input = torch.randn(20, 16, 50,44, 31) >>> input = torch.randn(20, 16, 50, 44, 31)
>>> output = m(input) >>> output = m(input)
.. _link: .. _link:
@ -524,7 +524,7 @@ class AvgPool1d(_AvgPoolNd):
>>> # pool with window of size=3, stride=2 >>> # pool with window of size=3, stride=2
>>> m = nn.AvgPool1d(3, stride=2) >>> m = nn.AvgPool1d(3, stride=2)
>>> m(torch.tensor([[[1.,2,3,4,5,6,7]]])) >>> m(torch.tensor([[[1., 2, 3, 4, 5, 6, 7]]]))
tensor([[[2., 4., 6.]]]) tensor([[[2., 4., 6.]]])
""" """
@ -688,7 +688,7 @@ class AvgPool3d(_AvgPoolNd):
>>> m = nn.AvgPool3d(3, stride=2) >>> m = nn.AvgPool3d(3, stride=2)
>>> # pool of non-square window >>> # pool of non-square window
>>> m = nn.AvgPool3d((3, 2, 2), stride=(2, 1, 2)) >>> m = nn.AvgPool3d((3, 2, 2), stride=(2, 1, 2))
>>> input = torch.randn(20, 16, 50,44, 31) >>> input = torch.randn(20, 16, 50, 44, 31)
>>> output = m(input) >>> output = m(input)
""" """
__constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad', 'divisor_override'] __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad', 'divisor_override']
@ -1043,7 +1043,7 @@ class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
Examples: Examples:
>>> # target output size of 5x7 >>> # target output size of 5x7
>>> m = nn.AdaptiveMaxPool2d((5,7)) >>> m = nn.AdaptiveMaxPool2d((5, 7))
>>> input = torch.randn(1, 64, 8, 9) >>> input = torch.randn(1, 64, 8, 9)
>>> output = m(input) >>> output = m(input)
>>> # target output size of 7x7 (square) >>> # target output size of 7x7 (square)
@ -1086,7 +1086,7 @@ class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
Examples: Examples:
>>> # target output size of 5x7x9 >>> # target output size of 5x7x9
>>> m = nn.AdaptiveMaxPool3d((5,7,9)) >>> m = nn.AdaptiveMaxPool3d((5, 7, 9))
>>> input = torch.randn(1, 64, 8, 9, 10) >>> input = torch.randn(1, 64, 8, 9, 10)
>>> output = m(input) >>> output = m(input)
>>> # target output size of 7x7x7 (cube) >>> # target output size of 7x7x7 (cube)
@ -1164,7 +1164,7 @@ class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
Examples: Examples:
>>> # target output size of 5x7 >>> # target output size of 5x7
>>> m = nn.AdaptiveAvgPool2d((5,7)) >>> m = nn.AdaptiveAvgPool2d((5, 7))
>>> input = torch.randn(1, 64, 8, 9) >>> input = torch.randn(1, 64, 8, 9)
>>> output = m(input) >>> output = m(input)
>>> # target output size of 7x7 (square) >>> # target output size of 7x7 (square)
@ -1203,7 +1203,7 @@ class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
Examples: Examples:
>>> # target output size of 5x7x9 >>> # target output size of 5x7x9
>>> m = nn.AdaptiveAvgPool3d((5,7,9)) >>> m = nn.AdaptiveAvgPool3d((5, 7, 9))
>>> input = torch.randn(1, 64, 8, 9, 10) >>> input = torch.randn(1, 64, 8, 9, 10)
>>> output = m(input) >>> output = m(input)
>>> # target output size of 7x7x7 (cube) >>> # target output size of 7x7x7 (cube)

View File

@ -23,10 +23,12 @@ _rnn_impls = {
def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
return tensor.index_select(dim, permutation) return tensor.index_select(dim, permutation)
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
warnings.warn("apply_permutation is deprecated, please use tensor.index_select(dim, permutation) instead") warnings.warn("apply_permutation is deprecated, please use tensor.index_select(dim, permutation) instead")
return _apply_permutation(tensor, permutation, dim) return _apply_permutation(tensor, permutation, dim)
class RNNBase(Module): class RNNBase(Module):
__constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias', __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias',
'batch_first', 'dropout', 'bidirectional', 'proj_size'] 'batch_first', 'dropout', 'bidirectional', 'proj_size']
@ -1203,9 +1205,9 @@ class LSTMCell(RNNCellBase):
Examples:: Examples::
>>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size) >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
>>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size) >>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
>>> hx = torch.randn(3, 20) # (batch, hidden_size) >>> hx = torch.randn(3, 20) # (batch, hidden_size)
>>> cx = torch.randn(3, 20) >>> cx = torch.randn(3, 20)
>>> output = [] >>> output = []
>>> for i in range(input.size()[0]): >>> for i in range(input.size()[0]):

View File

@ -68,7 +68,7 @@ class Embedding(Module):
>>> # an Embedding module containing 10 tensors of size 3 >>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3) >>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each >>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) >>> input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> embedding(input) >>> embedding(input)
tensor([[[-0.0251, -1.6902, 0.7172], tensor([[[-0.0251, -1.6902, 0.7172],
@ -84,7 +84,7 @@ class Embedding(Module):
>>> # example with padding_idx >>> # example with padding_idx
>>> embedding = nn.Embedding(10, 3, padding_idx=0) >>> embedding = nn.Embedding(10, 3, padding_idx=0)
>>> input = torch.LongTensor([[0,2,0,5]]) >>> input = torch.LongTensor([[0, 2, 0, 5]])
>>> embedding(input) >>> embedding(input)
tensor([[[ 0.0000, 0.0000, 0.0000], tensor([[[ 0.0000, 0.0000, 0.0000],
[ 0.1535, -2.0309, 0.9315], [ 0.1535, -2.0309, 0.9315],
@ -279,8 +279,8 @@ class EmbeddingBag(Module):
>>> # an EmbeddingBag module containing 10 tensors of size 3 >>> # an EmbeddingBag module containing 10 tensors of size 3
>>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum') >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
>>> # a batch of 2 samples of 4 indices each >>> # a batch of 2 samples of 4 indices each
>>> input = torch.tensor([1,2,4,5,4,3,2,9], dtype=torch.long) >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
>>> offsets = torch.tensor([0,4], dtype=torch.long) >>> offsets = torch.tensor([0, 4], dtype=torch.long)
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> embedding_sum(input, offsets) >>> embedding_sum(input, offsets)
tensor([[-0.8861, -5.4350, -0.0523], tensor([[-0.8861, -5.4350, -0.0523],
@ -289,7 +289,7 @@ class EmbeddingBag(Module):
>>> # Example with padding_idx >>> # Example with padding_idx
>>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum', padding_idx=2) >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum', padding_idx=2)
>>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9], dtype=torch.long) >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9], dtype=torch.long)
>>> offsets = torch.tensor([0,4], dtype=torch.long) >>> offsets = torch.tensor([0, 4], dtype=torch.long)
>>> embedding_sum(input, offsets) >>> embedding_sum(input, offsets)
tensor([[ 0.0000, 0.0000, 0.0000], tensor([[ 0.0000, 0.0000, 0.0000],
[-0.7082, 3.2145, -2.6251]]) [-0.7082, 3.2145, -2.6251]])

View File

@ -4,6 +4,7 @@ from typing import List, Dict, Any
__all__ = ['consume_prefix_in_state_dict_if_present'] __all__ = ['consume_prefix_in_state_dict_if_present']
def _ntuple(n, name="parse"): def _ntuple(n, name="parse"):
def parse(x): def parse(x):
if isinstance(x, collections.abc.Iterable): if isinstance(x, collections.abc.Iterable):

View File

@ -1052,8 +1052,8 @@ class DistributedDataParallel(Module, Joinable):
>>> # xdoctest: +SKIP("undefined variables") >>> # xdoctest: +SKIP("undefined variables")
>>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg) >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
>>> with ddp.no_sync(): >>> with ddp.no_sync():
>>> for input in inputs: >>> for input in inputs:
>>> ddp(input).backward() # no synchronization, accumulate grads >>> ddp(input).backward() # no synchronization, accumulate grads
>>> ddp(another_input).backward() # synchronize grads >>> ddp(another_input).backward() # synchronize grads
.. warning:: .. warning::
@ -1375,6 +1375,7 @@ class DistributedDataParallel(Module, Joinable):
Example:: Example::
>>> # xdoctest: +SKIP("Distributed")
>>> import torch >>> import torch
>>> import torch.distributed as dist >>> import torch.distributed as dist
>>> import os >>> import os
@ -1548,28 +1549,26 @@ class DistributedDataParallel(Module, Joinable):
Example:: Example::
Below is an example of a noop hook that returns the same tensor. Below is an example of a noop hook that returns the same tensor.
>>> # xdoctest: +SKIP('undefined name')
>>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
>>> fut = torch.futures.Future() >>> fut = torch.futures.Future()
>>> fut.set_result(bucket.buffer()) >>> fut.set_result(bucket.buffer())
>>> return fut >>> return fut
>>> # xdoctest: +SKIP('undefined name')
>>> ddp.register_comm_hook(state=None, hook=noop) >>> ddp.register_comm_hook(state=None, hook=noop)
Example:: Example::
Below is an example of a Parallel SGD algorithm where gradients are encoded before Below is an example of a Parallel SGD algorithm where gradients are encoded before
allreduce, and then decoded after allreduce. allreduce, and then decoded after allreduce.
>>> # xdoctest: +SKIP('undefined name')
>>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
>>> encoded_tensor = encode(bucket.buffer()) # encode gradients >>> encoded_tensor = encode(bucket.buffer()) # encode gradients
>>> fut = torch.distributed.all_reduce(encoded_tensor).get_future() >>> fut = torch.distributed.all_reduce(encoded_tensor).get_future()
>>> # Define the then callback to decode. >>> # Define the then callback to decode.
>>> def decode(fut): >>> def decode(fut):
>>> decoded_tensor = decode(fut.value()[0]) # decode gradients >>> decoded_tensor = decode(fut.value()[0]) # decode gradients
>>> return decoded_tensor >>> return decoded_tensor
>>> return fut.then(decode) >>> return fut.then(decode)
>>> # xdoctest: +SKIP('undefined name')
>>> ddp.register_comm_hook(state=None, hook=encode_and_decode) >>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
""" """
self._check_comm_hook(hook) self._check_comm_hook(hook)
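As a concrete illustration of the comm-hook docstring above, the no-op hook is just a function that hands the bucket buffer back through a completed Future; registering it needs an initialized process group and a wrapped model, so the last line stays commented out here.

import torch

def noop(state, bucket):
    # Return the bucket's gradient buffer unchanged, wrapped in an already-completed Future.
    fut = torch.futures.Future()
    fut.set_result(bucket.buffer())
    return fut

# ddp.register_comm_hook(state=None, hook=noop)  # `ddp` would be a torch.nn.parallel.DistributedDataParallel model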

View File

@ -9,6 +9,7 @@ from .expanded_weights_utils import \
THRESHOLD = 32 THRESHOLD = 32
def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt): def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt):
if func == F.conv1d: if func == F.conv1d:
return conv1dOpt return conv1dOpt
@ -18,6 +19,7 @@ def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt):
assert func == F.conv3d assert func == F.conv3d
return conv3dOpt return conv3dOpt
def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs): def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs):
args = expanded_args_and_kwargs[:len(expanded_args_and_kwargs) - len(kwarg_names)] args = expanded_args_and_kwargs[:len(expanded_args_and_kwargs) - len(kwarg_names)]
kwargs = expanded_args_and_kwargs[len(expanded_args_and_kwargs) - len(kwarg_names):] kwargs = expanded_args_and_kwargs[len(expanded_args_and_kwargs) - len(kwarg_names):]
@ -25,6 +27,7 @@ def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs):
return conv_normalizer(*args, **kwargs) return conv_normalizer(*args, **kwargs)
def conv_normalizer(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): def conv_normalizer(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
return (input, weight), {'bias': bias, 'stride': stride, 'padding': padding, 'dilation': dilation, 'groups': groups} return (input, weight), {'bias': bias, 'stride': stride, 'padding': padding, 'dilation': dilation, 'groups': groups}
@ -124,6 +127,7 @@ def conv_backward(func, ctx, grad_output):
set_grad_sample_if_exists(ctx.bias, lambda _: grad_output.reshape(*grad_output.shape[:2], -1).sum(dim=2)) set_grad_sample_if_exists(ctx.bias, lambda _: grad_output.reshape(*grad_output.shape[:2], -1).sum(dim=2))
return tuple(results) return tuple(results)
def conv_unfold_weight_grad_sample(input, grad_output, weight_shape, kernel_size, stride, padding, dilation, groups, func): def conv_unfold_weight_grad_sample(input, grad_output, weight_shape, kernel_size, stride, padding, dilation, groups, func):
n = input.shape[0] n = input.shape[0]
in_channels = input.shape[1] in_channels = input.shape[1]
@ -158,6 +162,7 @@ def conv_unfold_weight_grad_sample(input, grad_output, weight_shape, kernel_size
weight_grad_sample = weight_grad_sample.view(shape) weight_grad_sample = weight_grad_sample.view(shape)
return weight_grad_sample return weight_grad_sample
def conv_group_weight_grad_sample(input, grad_output, weight_shape, stride, padding, dilation, batch_size, func): def conv_group_weight_grad_sample(input, grad_output, weight_shape, stride, padding, dilation, batch_size, func):
I = input.shape[1] I = input.shape[1]
O = grad_output.shape[1] O = grad_output.shape[1]
@ -195,9 +200,9 @@ def unfold3d(
A tensor of shape ``(B, C * np.product(kernel_size), L)``, where L - output spatial dimensions. A tensor of shape ``(B, C * np.product(kernel_size), L)``, where L - output spatial dimensions.
See :class:`torch.nn.Unfold` for more details See :class:`torch.nn.Unfold` for more details
Example: Example:
>>> B, C, D, H, W = 3, 4, 5, 6, 7
>>> tensor = torch.arange(1, B*C*D*H*W + 1.).view(B, C, D, H, W)
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> B, C, D, H, W = 3, 4, 5, 6, 7
>>> tensor = torch.arange(1, B * C * D * H * W + 1.).view(B, C, D, H, W)
>>> unfold3d(tensor, kernel_size=2, padding=0, stride=1).shape >>> unfold3d(tensor, kernel_size=2, padding=0, stride=1).shape
torch.Size([3, 32, 120]) torch.Size([3, 32, 120])
""" """

View File

@ -6,6 +6,7 @@ from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeigh
from torch.utils._pytree import tree_flatten from torch.utils._pytree import tree_flatten
# dependency on `functional_call` means that this can't be exposed in utils # dependency on `functional_call` means that this can't be exposed in utils
# without creating circular dependency # without creating circular dependency
def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum"): def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum"):
@ -28,17 +29,17 @@ def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum"):
running mean across a batch. Must be "mean" or "sum". Default: "sum" running mean across a batch. Must be "mean" or "sum". Default: "sum"
Examples:: Examples::
>>> # xdoctest: +SKIP
>>> model = nn.Linear(4, 3) >>> model = nn.Linear(4, 3)
>>> batched_input = torch.randn(5, 4) # batch size of 5 >>> batched_input = torch.randn(5, 4) # batch size of 5
>>> # xdoctest: +SKIP
>>> res = call_for_per_sample_grads(model)(batched_input).sum() >>> res = call_for_per_sample_grads(model)(batched_input).sum()
>>> res.backward() >>> res.backward()
>>> assert model.weight.shape == (3, 4) >>> assert model.weight.shape == (3, 4)
>>> assert model.weight.grad_sample.shape == (5, 3, 4) >>> assert model.weight.grad_sample.shape == (5, 3, 4)
>>> assert model.weight.grad == None >>> assert model.weight.grad is None
>>> assert model.bias.shape == (3,) >>> assert model.bias.shape == (3,)
>>> assert model.bias.grad_sample.shape == (5, 3) >>> assert model.bias.grad_sample.shape == (5, 3)
>>> assert model.bias.grad == None >>> assert model.bias.grad is None
An example using "mean" loss reduction. The grad_sample fields will be scaled by batch_size from what they would be An example using "mean" loss reduction. The grad_sample fields will be scaled by batch_size from what they would be
if we ran the same code with loss_reduction="sum". This is because the mean at the end will scale all if we ran the same code with loss_reduction="sum". This is because the mean at the end will scale all
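The per-sample-gradients example above stays skipped in the doctest run; executed by hand it looks roughly like the sketch below. The import path is an assumption: call_for_per_sample_grads lives in a private module (torch.nn.utils._per_sample_grad at the time of writing) and may move between releases.

import torch
from torch import nn
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads  # private path, may change

model = nn.Linear(4, 3)
batched_input = torch.randn(5, 4)                     # batch size of 5
res = call_for_per_sample_grads(model)(batched_input).sum()
res.backward()
assert model.weight.grad_sample.shape == (5, 3, 4)    # one gradient per sample
assert model.weight.grad is None                      # regular .grad is untouched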

View File

@ -28,8 +28,8 @@ def skip_init(module_cls, *args, **kwargs):
Example:: Example::
>>> import torch
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> import torch
>>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1) >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
>>> m.weight >>> m.weight
Parameter containing: Parameter containing:
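skip_init itself is deterministic to call even though the resulting weights are not, which is why the doctest only ignores the printed values:

import torch

# Construct the module without running its parameter initialization.
m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
print(m.weight.shape)   # torch.Size([1, 5]); the values are whatever memory held before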

View File

@ -1,5 +1,6 @@
import torch import torch
def convert_conv2d_weight_memory_format(module, memory_format): def convert_conv2d_weight_memory_format(module, memory_format):
r"""Convert ``memory_format`` of ``nn.Conv2d.weight`` to ``memory_format`` r"""Convert ``memory_format`` of ``nn.Conv2d.weight`` to ``memory_format``
The conversion recursively applies to nested ``nn.Module``, including ``module``. The conversion recursively applies to nested ``nn.Module``, including ``module``.
@ -50,6 +51,7 @@ def convert_conv2d_weight_memory_format(module, memory_format):
The original module with updated ``nn.Conv2d`` The original module with updated ``nn.Conv2d``
Example: Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG) >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
>>> input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda") >>> input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda")
>>> model = nn.Sequential( >>> model = nn.Sequential(
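A runnable sketch of the memory-format conversion gated above, assuming a CUDA device is present; the Sequential below is an illustrative stand-in for the model that the hunk truncates.

import torch
from torch import nn

if torch.cuda.is_available():
    input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda")
    model = nn.Sequential(nn.Conv2d(8, 4, 3)).cuda().half()
    # Recursively convert the Conv2d weights to channels_last.
    model = nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)
    out = model(input)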

View File

@ -1002,9 +1002,9 @@ def ln_structured(module, name, amount, n, dim, importance_scores=None):
module (nn.Module): modified (i.e. pruned) version of the input module module (nn.Module): modified (i.e. pruned) version of the input module
Examples: Examples:
>>> # xdoctest: +SKIP >>> from torch.nn.utils import prune
>>> m = prune.ln_structured( >>> m = prune.ln_structured(
... nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf') ... nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf')
... ) ... )
""" """
LnStructured.apply( LnStructured.apply(
@ -1055,7 +1055,8 @@ def global_unstructured(parameters, pruning_method, importance_scores=None, **kw
scope of global pruning to unstructured methods. scope of global pruning to unstructured methods.
Examples: Examples:
>>> # xdoctest: +SKIP >>> from torch.nn.utils import prune
>>> from collections import OrderedDict
>>> net = nn.Sequential(OrderedDict([ >>> net = nn.Sequential(OrderedDict([
... ('first', nn.Linear(10, 4)), ... ('first', nn.Linear(10, 4)),
... ('second', nn.Linear(4, 1)), ... ('second', nn.Linear(4, 1)),
@ -1070,7 +1071,7 @@ def global_unstructured(parameters, pruning_method, importance_scores=None, **kw
... amount=10, ... amount=10,
... ) ... )
>>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0)) >>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0))
tensor(10, dtype=torch.uint8) tensor(10)
""" """
# ensure parameters is a list or generator of tuples # ensure parameters is a list or generator of tuples
@ -1156,7 +1157,7 @@ def custom_from_mask(module, name, mask):
module (nn.Module): modified (i.e. pruned) version of the input module module (nn.Module): modified (i.e. pruned) version of the input module
Examples: Examples:
>>> # xdoctest: +SKIP >>> from torch.nn.utils import prune
>>> m = prune.custom_from_mask( >>> m = prune.custom_from_mask(
... nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0]) ... nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
... ) ... )
@ -1211,8 +1212,8 @@ def is_pruned(module):
binary answer to whether ``module`` is pruned. binary answer to whether ``module`` is pruned.
Examples: Examples:
>>> from torch.nn.utils import prune
>>> m = nn.Linear(5, 7) >>> m = nn.Linear(5, 7)
>>> # xdoctest: +SKIP
>>> print(prune.is_pruned(m)) >>> print(prune.is_pruned(m))
False False
>>> prune.random_unstructured(m, name='weight', amount=0.2) >>> prune.random_unstructured(m, name='weight', amount=0.2)
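A small self-contained run of the pruning helpers whose doctests are adjusted above; unlike ln_structured, these two calls have deterministic observable effects.

import torch
from torch import nn
from torch.nn.utils import prune

m = nn.Linear(5, 3)
print(prune.is_pruned(m))        # False
prune.custom_from_mask(m, name='bias', mask=torch.tensor([0, 1, 0]))
print(prune.is_pruned(m))        # True
print(m.bias)                    # first and last entries are masked to zero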

View File

@ -20,6 +20,7 @@ PackedSequence_.__annotations__ = {'data': torch.Tensor, 'batch_sizes': torch.Te
'sorted_indices': Optional[torch.Tensor], 'sorted_indices': Optional[torch.Tensor],
'unsorted_indices': Optional[torch.Tensor]} 'unsorted_indices': Optional[torch.Tensor]}
def bind(optional, fn): def bind(optional, fn):
if optional is None: if optional is None:
return None return None
@ -279,7 +280,7 @@ def pad_packed_sequence(
Example: Example:
>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence >>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
>>> seq = torch.tensor([[1,2,0], [3,0,0], [4,5,6]]) >>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
>>> lens = [2, 1, 3] >>> lens = [2, 1, 3]
>>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False) >>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
>>> packed >>> packed
@ -464,8 +465,8 @@ def pack_sequence(sequences: List[Tensor], enforce_sorted: bool = True) -> Packe
Example: Example:
>>> from torch.nn.utils.rnn import pack_sequence >>> from torch.nn.utils.rnn import pack_sequence
>>> a = torch.tensor([1,2,3]) >>> a = torch.tensor([1, 2, 3])
>>> b = torch.tensor([4,5]) >>> b = torch.tensor([4, 5])
>>> c = torch.tensor([6]) >>> c = torch.tensor([6])
>>> pack_sequence([a, b, c]) >>> pack_sequence([a, b, c])
PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None) PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)
@ -492,8 +493,8 @@ def unpack_sequence(packed_sequences: PackedSequence) -> List[Tensor]:
Example: Example:
>>> from torch.nn.utils.rnn import pack_sequence, unpack_sequence >>> from torch.nn.utils.rnn import pack_sequence, unpack_sequence
>>> a = torch.tensor([1,2,3]) >>> a = torch.tensor([1, 2, 3])
>>> b = torch.tensor([4,5]) >>> b = torch.tensor([4, 5])
>>> c = torch.tensor([6]) >>> c = torch.tensor([6])
>>> sequences = [a, b, c] >>> sequences = [a, b, c]
>>> print(sequences) >>> print(sequences)
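The rnn doctests above only needed comma spacing normalized; for reference they run as:

import torch
from torch.nn.utils.rnn import pack_sequence, unpack_sequence

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
c = torch.tensor([6])
packed = pack_sequence([a, b, c])
print(packed.data)               # tensor([1, 4, 6, 2, 5, 3])
print(unpack_sequence(packed))   # [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]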

View File

@ -18,7 +18,8 @@ class DiagnosticEngine:
Examples: Examples:
Step 1: Create a set of rules. Step 1: Create a set of rules.
>>> rules = infra.RuleCollection.from_list( >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
>>> rules = infra.RuleCollection.custom_collection_from_list(
... "CustomRuleCollection", ... "CustomRuleCollection",
... [ ... [
... infra.Rule( ... infra.Rule(
@ -34,6 +35,7 @@ class DiagnosticEngine:
Step 3: Start a new diagnostic context. Step 3: Start a new diagnostic context.
>>> with engine.create_diagnostic_context("torch.onnx.export", version="1.0") as context: >>> with engine.create_diagnostic_context("torch.onnx.export", version="1.0") as context:
... ...
Step 4: Add diagnostics in your code. Step 4: Add diagnostics in your code.
... context.diagnose(rules.rule1, infra.Level.ERROR) ... context.diagnose(rules.rule1, infra.Level.ERROR)

View File

@ -63,6 +63,8 @@ class JitScalarType(enum.IntEnum):
Use ``JitScalarType`` to convert from torch and JIT scalar types to ONNX scalar types. Use ``JitScalarType`` to convert from torch and JIT scalar types to ONNX scalar types.
Examples: Examples:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> # xdoctest: +IGNORE_WANT("win32 has different output")
>>> JitScalarType.from_value(torch.ones(1, 2)).onnx_type() >>> JitScalarType.from_value(torch.ones(1, 2)).onnx_type()
TensorProtoDataType.FLOAT TensorProtoDataType.FLOAT

View File

@ -22,6 +22,7 @@ EPOCH_DEPRECATION_WARNING = (
"https://github.com/pytorch/pytorch/issues/new/choose." "https://github.com/pytorch/pytorch/issues/new/choose."
) )
class LRScheduler(object): class LRScheduler(object):
def __init__(self, optimizer, last_epoch=-1, verbose=False): def __init__(self, optimizer, last_epoch=-1, verbose=False):
@ -196,10 +197,10 @@ class LambdaLR(LRScheduler):
each update. Default: ``False``. each update. Default: ``False``.
Example: Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer has two groups. >>> # Assuming optimizer has two groups.
>>> lambda1 = lambda epoch: epoch // 30 >>> lambda1 = lambda epoch: epoch // 30
>>> lambda2 = lambda epoch: 0.95 ** epoch >>> lambda2 = lambda epoch: 0.95 ** epoch
>>> # xdoctest: +SKIP
>>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
>>> for epoch in range(100): >>> for epoch in range(100):
>>> train(...) >>> train(...)
@ -282,8 +283,8 @@ class MultiplicativeLR(LRScheduler):
each update. Default: ``False``. each update. Default: ``False``.
Example: Example:
>>> lmbda = lambda epoch: 0.95
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> lmbda = lambda epoch: 0.95
>>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda) >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
>>> for epoch in range(100): >>> for epoch in range(100):
>>> train(...) >>> train(...)
@ -365,12 +366,12 @@ class StepLR(LRScheduler):
each update. Default: ``False``. each update. Default: ``False``.
Example: Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 0.05 for all groups >>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05 if epoch < 30 >>> # lr = 0.05 if epoch < 30
>>> # lr = 0.005 if 30 <= epoch < 60 >>> # lr = 0.005 if 30 <= epoch < 60
>>> # lr = 0.0005 if 60 <= epoch < 90 >>> # lr = 0.0005 if 60 <= epoch < 90
>>> # ... >>> # ...
>>> # xdoctest: +SKIP
>>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
>>> for epoch in range(100): >>> for epoch in range(100):
>>> train(...) >>> train(...)
@ -414,11 +415,11 @@ class MultiStepLR(LRScheduler):
each update. Default: ``False``. each update. Default: ``False``.
Example: Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 0.05 for all groups >>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05 if epoch < 30 >>> # lr = 0.05 if epoch < 30
>>> # lr = 0.005 if 30 <= epoch < 80 >>> # lr = 0.005 if 30 <= epoch < 80
>>> # lr = 0.0005 if epoch >= 80 >>> # lr = 0.0005 if epoch >= 80
>>> # xdoctest: +SKIP
>>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
>>> for epoch in range(100): >>> for epoch in range(100):
>>> train(...) >>> train(...)
@ -463,13 +464,13 @@ class ConstantLR(LRScheduler):
each update. Default: ``False``. each update. Default: ``False``.
Example: Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 0.05 for all groups >>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.025 if epoch == 0 >>> # lr = 0.025 if epoch == 0
>>> # lr = 0.025 if epoch == 1 >>> # lr = 0.025 if epoch == 1
>>> # lr = 0.025 if epoch == 2 >>> # lr = 0.025 if epoch == 2
>>> # lr = 0.025 if epoch == 3 >>> # lr = 0.025 if epoch == 3
>>> # lr = 0.05 if epoch >= 4 >>> # lr = 0.05 if epoch >= 4
>>> # xdoctest: +SKIP
>>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4) >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4)
>>> for epoch in range(100): >>> for epoch in range(100):
>>> train(...) >>> train(...)
@ -525,13 +526,13 @@ class LinearLR(LRScheduler):
each update. Default: ``False``. each update. Default: ``False``.
Example: Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 0.05 for all groups >>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.025 if epoch == 0 >>> # lr = 0.025 if epoch == 0
>>> # lr = 0.03125 if epoch == 1 >>> # lr = 0.03125 if epoch == 1
>>> # lr = 0.0375 if epoch == 2 >>> # lr = 0.0375 if epoch == 2
>>> # lr = 0.04375 if epoch == 3 >>> # lr = 0.04375 if epoch == 3
>>> # lr = 0.05 if epoch >= 4 >>> # lr = 0.05 if epoch >= 4
>>> # xdoctest: +SKIP
>>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4) >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4)
>>> for epoch in range(100): >>> for epoch in range(100):
>>> train(...) >>> train(...)
@ -617,13 +618,13 @@ class SequentialLR(LRScheduler):
verbose (bool): Does nothing. verbose (bool): Does nothing.
Example: Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 1. for all groups >>> # Assuming optimizer uses lr = 1. for all groups
>>> # lr = 0.1 if epoch == 0 >>> # lr = 0.1 if epoch == 0
>>> # lr = 0.1 if epoch == 1 >>> # lr = 0.1 if epoch == 1
>>> # lr = 0.9 if epoch == 2 >>> # lr = 0.9 if epoch == 2
>>> # lr = 0.81 if epoch == 3 >>> # lr = 0.81 if epoch == 3
>>> # lr = 0.729 if epoch == 4 >>> # lr = 0.729 if epoch == 4
>>> # xdoctest: +SKIP
>>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
>>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
>>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2]) >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
@ -670,7 +671,6 @@ class SequentialLR(LRScheduler):
self._last_lr = schedulers[0].get_last_lr() self._last_lr = schedulers[0].get_last_lr()
def step(self): def step(self):
self.last_epoch += 1 self.last_epoch += 1
idx = bisect_right(self._milestones, self.last_epoch) idx = bisect_right(self._milestones, self.last_epoch)
@ -726,13 +726,13 @@ class PolynomialLR(LRScheduler):
each update. Default: ``False``. each update. Default: ``False``.
Example: Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> # Assuming optimizer uses lr = 0.001 for all groups >>> # Assuming optimizer uses lr = 0.001 for all groups
>>> # lr = 0.001 if epoch == 0 >>> # lr = 0.001 if epoch == 0
>>> # lr = 0.00075 if epoch == 1 >>> # lr = 0.00075 if epoch == 1
>>> # lr = 0.00050 if epoch == 2 >>> # lr = 0.00050 if epoch == 2
>>> # lr = 0.00025 if epoch == 3 >>> # lr = 0.00025 if epoch == 3
>>> # lr = 0.0 if epoch >= 4 >>> # lr = 0.0 if epoch >= 4
>>> # xdoctest: +SKIP("undefined vars")
>>> scheduler = PolynomialLR(self.opt, total_iters=4, power=1.0) >>> scheduler = PolynomialLR(self.opt, total_iters=4, power=1.0)
>>> for epoch in range(100): >>> for epoch in range(100):
>>> train(...) >>> train(...)
@ -846,13 +846,13 @@ class ChainedScheduler(LRScheduler):
schedulers (list): List of chained schedulers. schedulers (list): List of chained schedulers.
Example: Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 1. for all groups >>> # Assuming optimizer uses lr = 1. for all groups
>>> # lr = 0.09 if epoch == 0 >>> # lr = 0.09 if epoch == 0
>>> # lr = 0.081 if epoch == 1 >>> # lr = 0.081 if epoch == 1
>>> # lr = 0.729 if epoch == 2 >>> # lr = 0.729 if epoch == 2
>>> # lr = 0.6561 if epoch == 3 >>> # lr = 0.6561 if epoch == 3
>>> # lr = 0.59049 if epoch >= 4 >>> # lr = 0.59049 if epoch >= 4
>>> # xdoctest: +SKIP
>>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
>>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
>>> scheduler = ChainedScheduler([scheduler1, scheduler2]) >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
@ -1544,8 +1544,8 @@ class OneCycleLR(LRScheduler):
each update. Default: ``False``. each update. Default: ``False``.
Example: Example:
>>> data_loader = torch.utils.data.DataLoader(...)
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> data_loader = torch.utils.data.DataLoader(...)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10) >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
>>> for epoch in range(10): >>> for epoch in range(10):
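Every scheduler doctest above is skipped because it references an undefined optimizer and train(...); a minimal runnable variant, using a throwaway linear model as a stand-in, looks like:

import torch

model = torch.nn.Linear(4, 1)                        # stand-in for a real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(100):
    optimizer.step()                                 # stand-in for train(...) / validate(...)
    scheduler.step()
print(scheduler.get_last_lr())                       # lr has been decayed by gamma three times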

View File

@ -9,6 +9,7 @@ from torch.optim.lr_scheduler import LRScheduler
__all__ = ['AveragedModel', 'update_bn', 'SWALR'] __all__ = ['AveragedModel', 'update_bn', 'SWALR']
class AveragedModel(Module): class AveragedModel(Module):
r"""Implements averaged model for Stochastic Weight Averaging (SWA). r"""Implements averaged model for Stochastic Weight Averaging (SWA).

View File

@ -49,6 +49,7 @@ __all__ = [
'StorageType', 'StorageType',
] ]
class SourceChangeWarning(Warning): class SourceChangeWarning(Warning):
pass pass
@ -186,10 +187,12 @@ def _cuda_deserialize(obj, location):
else: else:
return obj.cuda(device) return obj.cuda(device)
def _mps_deserialize(obj, location): def _mps_deserialize(obj, location):
if location == 'mps': if location == 'mps':
return obj.mps() return obj.mps()
def _meta_deserialize(obj, location): def _meta_deserialize(obj, location):
if location == 'meta': if location == 'meta':
return torch.UntypedStorage(obj.nbytes(), device='meta') return torch.UntypedStorage(obj.nbytes(), device='meta')
@ -356,6 +359,7 @@ def _check_seekable(f) -> bool:
raise_err_msg(["seek", "tell"], e) raise_err_msg(["seek", "tell"], e)
return False return False
def _check_dill_version(pickle_module) -> None: def _check_dill_version(pickle_module) -> None:
'''Checks if using dill as the pickle module, and if so, checks if it is the correct version. '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
If dill version is lower than 0.3.1, a ValueError is raised. If dill version is lower than 0.3.1, a ValueError is raised.
@ -375,12 +379,14 @@ def _check_dill_version(pickle_module) -> None:
pickle_module.__version__ pickle_module.__version__
)) ))
def _check_save_filelike(f): def _check_save_filelike(f):
if not isinstance(f, (str, os.PathLike)) and not hasattr(f, 'write'): if not isinstance(f, (str, os.PathLike)) and not hasattr(f, 'write'):
raise AttributeError(( raise AttributeError((
"expected 'f' to be string, path, or a file-like object with " "expected 'f' to be string, path, or a file-like object with "
"a 'write' attribute")) "a 'write' attribute"))
def save( def save(
obj: object, obj: object,
f: FILE_LIKE, f: FILE_LIKE,
@ -420,6 +426,7 @@ def save(
to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``. to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.
Example: Example:
>>> # xdoctest: +SKIP("makes cwd dirty")
>>> # Save to file >>> # Save to file
>>> x = torch.tensor([0, 1, 2, 3, 4]) >>> x = torch.tensor([0, 1, 2, 3, 4])
>>> torch.save(x, 'tensor.pt') >>> torch.save(x, 'tensor.pt')
@ -753,7 +760,7 @@ def load(
# Load all tensors onto GPU 1 # Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1)) >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# Map tensors from GPU 1 to GPU 0 # Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'}) >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'})
# Load tensor from io.BytesIO object # Load tensor from io.BytesIO object
>>> with open('tensor.pt', 'rb') as f: >>> with open('tensor.pt', 'rb') as f:
... buffer = io.BytesIO(f.read()) ... buffer = io.BytesIO(f.read())
@ -1087,6 +1094,7 @@ def _get_restore_location(map_location):
return result return result
return restore_location return restore_location
class StorageType(): class StorageType():
def __init__(self, name): def __init__(self, name):
self.dtype = _get_dtype_from_pickle_storage_type(name) self.dtype = _get_dtype_from_pickle_storage_type(name)
@ -1094,6 +1102,7 @@ class StorageType():
def __str__(self): def __str__(self):
return f'StorageType(dtype={self.dtype})' return f'StorageType(dtype={self.dtype})'
def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickle_load_args): def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickle_load_args):
restore_location = _get_restore_location(map_location) restore_location = _get_restore_location(map_location)

View File

@ -90,6 +90,8 @@ def make_tensor(
TypeError: If :attr:`dtype` isn't supported by this function. TypeError: If :attr:`dtype` isn't supported by this function.
Examples: Examples:
>>> # xdoctest: +SKIP
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> from torch.testing import make_tensor >>> from torch.testing import make_tensor
>>> # Creates a float tensor with values in [-1, 1) >>> # Creates a float tensor with values in [-1, 1)
>>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1) >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
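The CPU branch of the make_tensor example needs neither new directive (the CUDA requirement presumably guards a later device='cuda' call, and the values are random anyway), so it can be run directly:

import torch
from torch.testing import make_tensor

# Values drawn uniformly from [-1, 1) on the CPU.
t = make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
print(t.shape, t.dtype)      # torch.Size([3]) torch.float32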

View File

@ -54,6 +54,7 @@ def skip_unless_torch_gpu(method: T) -> T:
""" """
Test decorator which skips the test unless there's a GPU available to torch. Test decorator which skips the test unless there's a GPU available to torch.
>>> # xdoctest: +SKIP
>>> @skip_unless_torch_gpu >>> @skip_unless_torch_gpu
>>> def test_some_method(self) -> None: >>> def test_some_method(self) -> None:
>>> ... >>> ...

View File

@ -22,6 +22,7 @@ def rename_privateuse1_backend(backend_name: str) -> None:
Example:: Example::
>>> # xdoctest: +SKIP("failing")
>>> torch.register_privateuse1_backend("foo") >>> torch.register_privateuse1_backend("foo")
# This will work, assuming that you've implemented the right C++ kernels # This will work, assuming that you've implemented the right C++ kernels
# to implement torch.ones. # to implement torch.ones.

View File

@ -912,6 +912,7 @@ def CppExtension(name, sources, *args, **kwargs):
Example: Example:
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
>>> from setuptools import setup >>> from setuptools import setup
>>> from torch.utils.cpp_extension import BuildExtension, CppExtension >>> from torch.utils.cpp_extension import BuildExtension, CppExtension
>>> setup( >>> setup(
@ -959,6 +960,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
Example: Example:
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
>>> from setuptools import setup >>> from setuptools import setup
>>> from torch.utils.cpp_extension import BuildExtension, CUDAExtension >>> from torch.utils.cpp_extension import BuildExtension, CUDAExtension
>>> setup( >>> setup(
@ -1006,14 +1008,12 @@ def CUDAExtension(name, sources, *args, **kwargs):
To workaround the issue, move python binding logic to pure C++ file. To workaround the issue, move python binding logic to pure C++ file.
Example use: Example use:
>>> # xdoctest: +SKIP #include <ATen/ATen.h>
>>> #include <ATen/ATen.h> at::Tensor SigmoidAlphaBlendForwardCuda(....)
>>> at::Tensor SigmoidAlphaBlendForwardCuda(....)
Instead of: Instead of:
>>> # xdoctest: +SKIP #include <torch/extension.h>
>>> #include <torch/extension.h> torch::Tensor SigmoidAlphaBlendForwardCuda(...)
>>> torch::Tensor SigmoidAlphaBlendForwardCuda(...)
Currently open issue for nvcc bug: https://github.com/pytorch/pytorch/issues/69460 Currently open issue for nvcc bug: https://github.com/pytorch/pytorch/issues/69460
Complete workaround code example: https://github.com/facebookresearch/pytorch3d/commit/cb170ac024a949f1f9614ffe6af1c38d972f7d48 Complete workaround code example: https://github.com/facebookresearch/pytorch3d/commit/cb170ac024a949f1f9614ffe6af1c38d972f7d48
@ -1037,6 +1037,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
Example: Example:
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
>>> CUDAExtension( >>> CUDAExtension(
... name='cuda_extension', ... name='cuda_extension',
... sources=['extension.cpp', 'extension_kernel.cu'], ... sources=['extension.cpp', 'extension_kernel.cu'],
@ -1362,6 +1363,7 @@ def load_inline(name,
causes issues. causes issues.
Example: Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
>>> from torch.utils.cpp_extension import load_inline >>> from torch.utils.cpp_extension import load_inline
>>> source = """ >>> source = """
at::Tensor sin_add(at::Tensor x, at::Tensor y) { at::Tensor sin_add(at::Tensor x, at::Tensor y) {
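The load_inline example now carries a TORCH_DOCTEST_CPP_EXT requirement because it invokes a C++ toolchain. Run by hand with a compiler available, the docstring's snippet continues roughly as below; the body of sin_add and the keyword arguments are filled in from the usual inline-extension example and are an assumption, not part of the hunk shown.

from torch.utils.cpp_extension import load_inline

source = """
at::Tensor sin_add(at::Tensor x, at::Tensor y) {
  return x.sin() + y.sin();
}
"""
# JIT-compiles the C++ snippet and exposes sin_add as a Python-callable function.
module = load_inline(name='inline_extension',
                     cpp_sources=[source],
                     functions=['sin_add'])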

View File

@ -33,11 +33,11 @@ def default_convert(data):
data: a single data point to be converted data: a single data point to be converted
Examples: Examples:
>>> # xdoctest: +SKIP
>>> # Example with `int` >>> # Example with `int`
>>> default_convert(0) >>> default_convert(0)
0 0
>>> # Example with NumPy array >>> # Example with NumPy array
>>> # xdoctest: +SKIP
>>> default_convert(np.array([0, 1])) >>> default_convert(np.array([0, 1]))
tensor([0, 1]) tensor([0, 1])
>>> # Example with NamedTuple >>> # Example with NamedTuple
@ -228,6 +228,7 @@ def default_collate(batch):
batch: a single batch to be collated batch: a single batch to be collated
Examples: Examples:
>>> # xdoctest: +SKIP
>>> # Example with a batch of `int`s: >>> # Example with a batch of `int`s:
>>> default_collate([0, 1, 2, 3]) >>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3]) tensor([0, 1, 2, 3])
@ -238,7 +239,6 @@ def default_collate(batch):
>>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]) >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
{'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])}
>>> # Example with `NamedTuple` inside the batch: >>> # Example with `NamedTuple` inside the batch:
>>> # xdoctest: +SKIP
>>> Point = namedtuple('Point', ['x', 'y']) >>> Point = namedtuple('Point', ['x', 'y'])
>>> default_collate([Point(0, 0), Point(1, 1)]) >>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1])) Point(x=tensor([0, 1]), y=tensor([0, 1]))
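The collate doctests above are skipped as a whole; the integer and dict cases can be checked directly with the public default_collate:

import torch
from torch.utils.data import default_collate

print(default_collate([0, 1, 2, 3]))
# tensor([0, 1, 2, 3])
print(default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]))
# {'A': tensor([  0, 100]), 'B': tensor([  1, 100])}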

View File

@ -183,7 +183,9 @@ class CollatorIterDataPipe(MapperIterDataPipe):
collate_fn: Customized collate function to collect and combine data or a batch of data. collate_fn: Customized collate function to collect and combine data or a batch of data.
Default function collates to Tensor(s) based on data type. Default function collates to Tensor(s) based on data type.
Example: Convert integer data to float Tensor Example:
>>> # xdoctest: +SKIP
>>> # Convert integer data to float Tensor
>>> class MyIterDataPipe(torch.utils.data.IterDataPipe): >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
... def __init__(self, start, end): ... def __init__(self, start, end):
... super(MyIterDataPipe).__init__() ... super(MyIterDataPipe).__init__()
@ -203,7 +205,6 @@ class CollatorIterDataPipe(MapperIterDataPipe):
>>> def collate_fn(batch): >>> def collate_fn(batch):
... return torch.tensor(batch, dtype=torch.float) ... return torch.tensor(batch, dtype=torch.float)
... ...
>>> # xdoctest: +SKIP
>>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn) >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn)
>>> print(list(collated_ds)) >>> print(list(collated_ds))
[tensor(3.), tensor(4.), tensor(5.), tensor(6.)] [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]

View File

@ -227,7 +227,7 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> from torchdata.datapipes.iter import IterableWrapper >>> from torchdata.datapipes.iter import IterableWrapper
>>> def group_fn(file): >>> def group_fn(file):
... return os.path.basename(file).split(".")[0] ... return os.path.basename(file).split(".")[0]
>>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"]) >>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
>>> dp0 = source_dp.groupby(group_key_fn=group_fn) >>> dp0 = source_dp.groupby(group_key_fn=group_fn)
>>> list(dp0) >>> list(dp0)

View File

@ -30,6 +30,7 @@ def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]
keyword-only arguments. keyword-only arguments.
Examples: Examples:
>>> # xdoctest: +SKIP("Failing on some CI machines")
>>> def f(a, b, *, c=1): >>> def f(a, b, *, c=1):
>>> return a + b + c >>> return a + b + c
>>> def f_def(a, b=1, *, c=1): >>> def f_def(a, b=1, *, c=1):
@ -117,6 +118,7 @@ def _is_local_fn(fn):
return "<locals>" in fn_type.__qualname__ return "<locals>" in fn_type.__qualname__
return False return False
def _check_unpickable_fn(fn: Callable): def _check_unpickable_fn(fn: Callable):
""" """
Checks function is pickable or not. If it is a lambda or local function, a UserWarning Checks function is pickable or not. If it is a lambda or local function, a UserWarning

View File

@ -81,6 +81,8 @@ class IterableDataset(Dataset[T_co]):
Example 1: splitting workload across all workers in :meth:`__iter__`:: Example 1: splitting workload across all workers in :meth:`__iter__`::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
>>> # xdoctest: +SKIP("Fails on MacOS12")
>>> class MyIterableDataset(torch.utils.data.IterableDataset): >>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end): ... def __init__(self, start, end):
... super(MyIterableDataset).__init__() ... super(MyIterableDataset).__init__()
@ -122,6 +124,7 @@ class IterableDataset(Dataset[T_co]):
Example 2: splitting workload across all workers using :attr:`worker_init_fn`:: Example 2: splitting workload across all workers using :attr:`worker_init_fn`::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
>>> class MyIterableDataset(torch.utils.data.IterableDataset): >>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end): ... def __init__(self, start, end):
... super(MyIterableDataset).__init__() ... super(MyIterableDataset).__init__()
@ -313,9 +316,12 @@ def random_split(dataset: Dataset[T], lengths: Sequence[Union[int, float]],
Optionally fix the generator for reproducible results, e.g.: Optionally fix the generator for reproducible results, e.g.:
>>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42)) Example:
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=torch.Generator( >>> # xdoctest: +SKIP
... ).manual_seed(42)) >>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
Args: Args:
dataset (Dataset): Dataset to be split dataset (Dataset): Dataset to be split
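The rewritten random_split example above remains skipped; with the generator pinned its behaviour is deterministic, for example:

import torch
from torch.utils.data import random_split

generator1 = torch.Generator().manual_seed(42)
splits = random_split(range(10), [3, 7], generator=generator1)
print([len(s) for s in splits])   # [3, 7]
print(list(splits[0]))            # the three elements picked by the seeded generator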

View File

@ -53,6 +53,7 @@ __all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_exte
'is_caffe2_gpu_file', 'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header', 'is_caffe2_gpu_file', 'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header',
'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'hipify'] 'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'hipify']
class InputError(Exception): class InputError(Exception):
# Exception raised for errors in the input. # Exception raised for errors in the input.
@ -79,6 +80,7 @@ class bcolors:
BOLD = '\033[1m' BOLD = '\033[1m'
UNDERLINE = '\033[4m' UNDERLINE = '\033[4m'
# To the programmer, the output of hipify most likely are intermediates. # To the programmer, the output of hipify most likely are intermediates.
# This class allows users of hipify to ask for a cleanup by running the # This class allows users of hipify to ask for a cleanup by running the
# hipify and compilation in a with instantiating this context manager class # hipify and compilation in a with instantiating this context manager class
@ -119,13 +121,16 @@ class GeneratedFileCleaner:
for d in self.dirs_to_clean[::-1]: for d in self.dirs_to_clean[::-1]:
os.rmdir(d) os.rmdir(d)
def match_extensions(filename: str, extensions: Iterable) -> bool: def match_extensions(filename: str, extensions: Iterable) -> bool:
"""Helper method to see if filename ends with certain extension""" """Helper method to see if filename ends with certain extension"""
return any(filename.endswith(e) for e in extensions) return any(filename.endswith(e) for e in extensions)
def _fnmatch(filepath, patterns): def _fnmatch(filepath, patterns):
return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns) return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns)
def matched_files_iter( def matched_files_iter(
root_path: str, root_path: str,
includes: Iterable = (), includes: Iterable = (),
@ -407,10 +412,8 @@ def find_closure_group(input_string, start, group):
find_closure_group returns the positions of group[0] and group[1] as a tuple. find_closure_group returns the positions of group[0] and group[1] as a tuple.
Example: Example:
find_closure_group("(hi)", 0, ["(", ")"]) >>> find_closure_group("(hi)", 0, ["(", ")"])
(0, 3)
Returns:
0, 3
""" """
inside_parenthesis = False inside_parenthesis = False
@ -522,7 +525,7 @@ def get_hip_file_path(rel_filepath, is_pytorch_extension=False):
""" """
# At the moment, some PyTorch source files are HIPified in place. The predicate # At the moment, some PyTorch source files are HIPified in place. The predicate
# is_out_of_place tells us if this is the case or not. # is_out_of_place tells us if this is the case or not.
assert(not os.path.isabs(rel_filepath)) assert not os.path.isabs(rel_filepath)
if not is_pytorch_extension and not is_out_of_place(rel_filepath): if not is_pytorch_extension and not is_out_of_place(rel_filepath):
return rel_filepath return rel_filepath
@ -589,7 +592,7 @@ def get_hip_file_path(rel_filepath, is_pytorch_extension=False):
def is_out_of_place(rel_filepath): def is_out_of_place(rel_filepath):
assert(not os.path.isabs(rel_filepath)) assert not os.path.isabs(rel_filepath)
if rel_filepath.startswith("torch/"): if rel_filepath.startswith("torch/"):
return False return False
if rel_filepath.startswith("tools/autograd/templates/"): if rel_filepath.startswith("tools/autograd/templates/"):
@ -599,7 +602,7 @@ def is_out_of_place(rel_filepath):
# Keep this synchronized with includes/ignores in build_amd.py # Keep this synchronized with includes/ignores in build_amd.py
def is_pytorch_file(rel_filepath): def is_pytorch_file(rel_filepath):
assert(not os.path.isabs(rel_filepath)) assert not os.path.isabs(rel_filepath)
if rel_filepath.startswith("aten/"): if rel_filepath.startswith("aten/"):
if rel_filepath.startswith("aten/src/ATen/core/"): if rel_filepath.startswith("aten/src/ATen/core/"):
return False return False
@ -616,8 +619,9 @@ def is_cusparse_file(rel_filepath):
return "sparse" in rel_filepath.lower() return "sparse" in rel_filepath.lower()
return False return False
def is_caffe2_gpu_file(rel_filepath): def is_caffe2_gpu_file(rel_filepath):
assert(not os.path.isabs(rel_filepath)) assert not os.path.isabs(rel_filepath)
if rel_filepath.startswith("c10/cuda"): if rel_filepath.startswith("c10/cuda"):
return True return True
filename = os.path.basename(rel_filepath) filename = os.path.basename(rel_filepath)
@ -732,6 +736,8 @@ Returns a dict with the following keys:
"skipped" if an identical hipified file already existed or hipified file couldn't be written out "skipped" if an identical hipified file already existed or hipified file couldn't be written out
"ignored" if the source file was a hipified file itself or not meant to be hipified "ignored" if the source file was a hipified file itself or not meant to be hipified
""" """
def preprocessor( def preprocessor(
output_directory: str, output_directory: str,
filepath: str, filepath: str,
@ -885,6 +891,7 @@ def preprocessor(
else: else:
return {"hipified_path": fout_path, "status": "[skipped, already hipified]"} return {"hipified_path": fout_path, "status": "[skipped, already hipified]"}
def file_specific_replacement(filepath, search_string, replace_string, strict=False): def file_specific_replacement(filepath, search_string, replace_string, strict=False):
with openf(filepath, "r+") as f: with openf(filepath, "r+") as f:
contents = f.read() contents = f.read()