mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Options to address the "undocumented python objects":

1. Reference the functions in the .rst via the torch.nn.modules namespace. Note that this changes the generated doc filenames / locations for most of these functions!
2. [Not an option] Monkeypatch `__module__` for these objects (broke several tests in CI due to `inspect.findsource` failing after this change).
3. Update the .rst files to also document the torch.nn.modules forms of these functions, duplicating docs.

#### [this is the docs page added](https://docs-preview.pytorch.org/pytorch/pytorch/158491/nn.aliases.html)

This PR takes option 3 by adding an rst page nn.aliases that documents the aliases in nested namespaces, removing all the torch.nn.modules.* entries from the coverage skiplist except:

- NLLLoss2d (deprecated)
- Container (deprecated)
- CrossMapLRN2d (what is this?)
- NonDynamicallyQuantizableLinear

This mostly required adding docstrings to `forward`, `extra_repr` and `reset_parameters`. Since forward arguments are already part of the module docstrings, I just added a very basic docstring.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158491
Approved by: https://github.com/janeyx99
128 lines
3.9 KiB
Python
128 lines
3.9 KiB
Python
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
|
|
from .module import Module
|
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["PixelShuffle", "PixelUnshuffle"]
|
|
|
|
|
|
class PixelShuffle(Module):
    r"""Rearrange elements in a tensor according to an upscaling factor.

    Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
    to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor.

    This is useful for implementing efficient sub-pixel convolution
    with a stride of :math:`1/r`.

    See the paper:
    `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_
    by Shi et al. (2016) for more details.

    Args:
        upscale_factor (int): factor to increase spatial resolution by

    Shape:
        - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
        - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where

    .. math::
        C_{out} = C_{in} \div \text{upscale\_factor}^2

    .. math::
        H_{out} = H_{in} \times \text{upscale\_factor}

    .. math::
        W_{out} = W_{in} \times \text{upscale\_factor}

    Examples::

        >>> pixel_shuffle = nn.PixelShuffle(3)
        >>> input = torch.randn(1, 9, 4, 4)
        >>> output = pixel_shuffle(input)
        >>> print(output.size())
        torch.Size([1, 1, 12, 12])

    .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
        https://arxiv.org/abs/1609.05158
    """

    __constants__ = ["upscale_factor"]
    upscale_factor: int

    def __init__(self, upscale_factor: int) -> None:
        """Store ``upscale_factor`` on the module.

        The value is not validated here; it is handed unchanged to
        ``F.pixel_shuffle`` on every :meth:`forward` call.
        """
        super().__init__()
        self.upscale_factor = upscale_factor

    def forward(self, input: Tensor) -> Tensor:
        """Run the forward pass.

        Delegates to :func:`torch.nn.functional.pixel_shuffle`, passing
        ``input`` together with this module's ``upscale_factor``.

        Args:
            input (Tensor): tensor to rearrange; see the ``Shape`` section
                of the class docstring for the expected layout.

        Returns:
            Tensor: the result of ``F.pixel_shuffle(input, self.upscale_factor)``.
        """
        return F.pixel_shuffle(input, self.upscale_factor)

    def extra_repr(self) -> str:
        """Return the extra representation of the module.

        Returns:
            str: the text ``upscale_factor=<value>``.
        """
        return f"upscale_factor={self.upscale_factor}"
|
|
|
|
|
|
class PixelUnshuffle(Module):
    r"""Reverse the PixelShuffle operation.

    Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements
    in a tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape
    :math:`(*, C \times r^2, H, W)`, where r is a downscale factor.

    See the paper:
    `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_
    by Shi et al. (2016) for more details.

    Args:
        downscale_factor (int): factor to decrease spatial resolution by

    Shape:
        - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
        - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where

    .. math::
        C_{out} = C_{in} \times \text{downscale\_factor}^2

    .. math::
        H_{out} = H_{in} \div \text{downscale\_factor}

    .. math::
        W_{out} = W_{in} \div \text{downscale\_factor}

    Examples::

        >>> pixel_unshuffle = nn.PixelUnshuffle(3)
        >>> input = torch.randn(1, 1, 12, 12)
        >>> output = pixel_unshuffle(input)
        >>> print(output.size())
        torch.Size([1, 9, 4, 4])

    .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
        https://arxiv.org/abs/1609.05158
    """

    __constants__ = ["downscale_factor"]
    downscale_factor: int

    def __init__(self, downscale_factor: int) -> None:
        """Store ``downscale_factor`` on the module.

        The value is not validated here; it is handed unchanged to
        ``F.pixel_unshuffle`` on every :meth:`forward` call.
        """
        super().__init__()
        self.downscale_factor = downscale_factor

    def forward(self, input: Tensor) -> Tensor:
        """Run the forward pass.

        Delegates to :func:`torch.nn.functional.pixel_unshuffle`, passing
        ``input`` together with this module's ``downscale_factor``.

        Args:
            input (Tensor): tensor to rearrange; see the ``Shape`` section
                of the class docstring for the expected layout.

        Returns:
            Tensor: the result of
            ``F.pixel_unshuffle(input, self.downscale_factor)``.
        """
        return F.pixel_unshuffle(input, self.downscale_factor)

    def extra_repr(self) -> str:
        """Return the extra representation of the module.

        Returns:
            str: the text ``downscale_factor=<value>``.
        """
        return f"downscale_factor={self.downscale_factor}"
|