mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Summary: This is a simple yet useful addition to the torch.nn modules: an identity module. This is a first draft - please let me know what you think and I will edit my PR.

There is no identity module today - nn.Sequential() can be used, but it is argument-sensitive, so it can't be swapped interchangeably with any other module. This adds nn.Identity(...), which can be swapped with any module because it accepts (and ignores) arbitrary arguments. It is also more understandable than seeing an empty Sequential inside a model. See discussion on #9160.

The current solution is to use nn.Sequential(). However, that won't work in the following case:

```python
batch_norm = nn.BatchNorm2d
if dont_use_batch_norm:
    batch_norm = Identity
```

Then in your network, you have:

```python
nn.Sequential(
    ...
    batch_norm(N, momentum=0.05),
    ...
)
```

If you try to simply set `Identity = nn.Sequential`, this will fail since `nn.Sequential` expects modules as arguments. Of course there are many ways to get around this, including:

- Conditionally adding modules to an existing Sequential module
- Not using Sequential but writing the usual `forward` function with an if statement
- ...

**However, I think that an identity module is the most pythonic strategy,** assuming you want to use nn.Sequential. Using the very simple class (this isn't the same as the one in my commit):

```python
class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x):
        return x
```

we no longer need an empty nn.Sequential as a stand-in, and `batch_norm(N, momentum=0.05)` will work. There are of course other situations where this would be useful.

Thank you. Best, Miles

Pull Request resolved: https://github.com/pytorch/pytorch/pull/19249
Differential Revision: D15012969
Pulled By: ezyang
fbshipit-source-id: 9f47e252137a1679e306fd4c169dca832eb82c0c
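For concreteness, here is a minimal runnable sketch of the pattern described above, using the nn.Identity added in this PR. The flag `dont_use_batch_norm`, the Conv2d layer, and the layer sizes are illustrative assumptions, not part of the PR itself:

```python
import torch
import torch.nn as nn

# Toggle batch norm off by swapping in nn.Identity; both classes can be
# constructed with the same arguments because Identity ignores them.
dont_use_batch_norm = True
batch_norm = nn.Identity if dont_use_batch_norm else nn.BatchNorm2d

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    batch_norm(16, momentum=0.05),  # ignored by Identity, used by BatchNorm2d
    nn.ReLU(),
)

x = torch.randn(8, 3, 32, 32)
print(model(x).size())  # torch.Size([8, 16, 32, 32])
```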
170 lines · 6.0 KiB · Python
import math

import torch
from torch.nn.parameter import Parameter
from .. import functional as F
from .. import init
from .module import Module
from ..._jit_internal import weak_module, weak_script_method


@weak_module
class Identity(Module):
    r"""A placeholder identity operator that is argument-insensitive.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)

    Examples::

        >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 20])

    """
    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    @weak_script_method
    def forward(self, input):
        return input


@weak_module
class Linear(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
            If :attr:`bias` is ``True``, the values are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{1}{\text{in\_features}}`

    Examples::

        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['bias']

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    @weak_script_method
    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )


@weak_module
class Bilinear(Module):
    r"""Applies a bilinear transformation to the incoming data:
    :math:`y = x_1 A x_2 + b`

    Args:
        in1_features: size of each first input sample
        in2_features: size of each second input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input1: :math:`(N, *, H_{in1})` where :math:`H_{in1}=\text{in1\_features}` and
          :math:`*` means any number of additional dimensions. All but the last dimension
          of the inputs should be the same.
        - Input2: :math:`(N, *, H_{in2})` where :math:`H_{in2}=\text{in2\_features}`.
        - Output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}`
          and all but the last dimension are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in1\_features}, \text{in2\_features})`.
            The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in1\_features}}`
        bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
            If :attr:`bias` is ``True``, the values are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in1\_features}}`

    Examples::

        >>> m = nn.Bilinear(20, 30, 40)
        >>> input1 = torch.randn(128, 20)
        >>> input2 = torch.randn(128, 30)
        >>> output = m(input1, input2)
        >>> print(output.size())
        torch.Size([128, 40])
    """
    __constants__ = ['in1_features', 'in2_features', 'out_features', 'bias']

    def __init__(self, in1_features, in2_features, out_features, bias=True):
        super(Bilinear, self).__init__()
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in1_features, in2_features))

        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        bound = 1 / math.sqrt(self.weight.size(1))
        init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            init.uniform_(self.bias, -bound, bound)

    @weak_script_method
    def forward(self, input1, input2):
        return F.bilinear(input1, input2, self.weight, self.bias)

    def extra_repr(self):
        return 'in1_features={}, in2_features={}, out_features={}, bias={}'.format(
            self.in1_features, self.in2_features, self.out_features, self.bias is not None
        )


# TODO: PartialLinear - maybe in sparse?