pytorch/torch/nn/modules/linear.py
MilesCranmer 30292d994f Add an identity module (#19249)
Summary:
This is a simple yet useful addition to the torch.nn modules: an identity module. This is a first draft - please let me know what you think and I will edit my PR.

There is currently no identity module. `nn.Sequential()` can be used instead, but it is argument-sensitive, so it can't be swapped interchangeably with other modules. This adds `nn.Identity(...)`, which can stand in for any module because it accepts (and ignores) arbitrary arguments. It is also easier to understand than an empty `Sequential` inside a model.

See the discussion on #9160. The current workaround is to use `nn.Sequential()`. However, that breaks down in cases like the following:

```python
batch_norm = nn.BatchNorm2d
if dont_use_batch_norm:
    batch_norm = Identity
```

Then in your network, you have:

```python
nn.Sequential(
    ...
    batch_norm(N, momentum=0.05),
    ...
)
```

If you try to simply set `Identity = nn.Sequential`, this will fail since `nn.Sequential` expects modules as arguments. Of course there are many ways to get around this, including:

- Conditionally adding modules to an existing Sequential module (sketched just below this list)
- Not using Sequential but writing the usual `forward` function with an if statement
- ...
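
For instance, the first of these workarounds might look roughly like the following sketch (the layer sizes and the `dont_use_batch_norm` flag here are placeholders, not part of this PR):

```python
import torch.nn as nn

dont_use_batch_norm = True  # placeholder flag
N = 16                      # placeholder channel count

# Build the layer list imperatively and only append BatchNorm when wanted.
layers = [nn.Conv2d(3, N, kernel_size=3, padding=1)]
if not dont_use_batch_norm:
    layers.append(nn.BatchNorm2d(N, momentum=0.05))
layers.append(nn.ReLU())

model = nn.Sequential(*layers)
```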

**However, I think that an identity module is the most pythonic strategy,** assuming you want to use nn.Sequential.

Using the very simple class (this isn't the same as the one in my commit):

```python
import torch.nn as nn

class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x):
        return x
```

we can keep using nn.Sequential, and `batch_norm(N, momentum=0.05)` will work either way. There are of course other situations where this would be useful.
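Put together, the example from above then reads roughly as follows (a sketch only, assuming `nn.Identity` from this PR or the class above is available; the surrounding layers and sizes are placeholders):

```python
import torch
import torch.nn as nn

dont_use_batch_norm = True
N = 16

batch_norm = nn.BatchNorm2d
if dont_use_batch_norm:
    batch_norm = nn.Identity  # nn.Identity ignores N and momentum below

model = nn.Sequential(
    nn.Conv2d(3, N, kernel_size=3, padding=1),
    batch_norm(N, momentum=0.05),  # valid for either choice of batch_norm
    nn.ReLU(),
)

output = model(torch.randn(2, 3, 32, 32))  # shape: (2, 16, 32, 32)
```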

Thank you.
Best,
Miles
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19249

Differential Revision: D15012969

Pulled By: ezyang

fbshipit-source-id: 9f47e252137a1679e306fd4c169dca832eb82c0c
2019-04-19 10:12:18 -07:00


import math
import torch
from torch.nn.parameter import Parameter
from .. import functional as F
from .. import init
from .module import Module
from ..._jit_internal import weak_module, weak_script_method

@weak_module
class Identity(Module):
    r"""A placeholder identity operator that is argument-insensitive.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)

    Examples::

        >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 20])
    """
    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    @weak_script_method
    def forward(self, input):
        return input

@weak_module
class Linear(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
            If :attr:`bias` is ``True``, the values are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{1}{\text{in\_features}}`

    Examples::

        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['bias']

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # kaiming_uniform_ with a=sqrt(5) is equivalent to sampling from
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)), matching the docstring above.
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    @weak_script_method
    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

@weak_module
class Bilinear(Module):
    r"""Applies a bilinear transformation to the incoming data:
    :math:`y = x_1 A x_2 + b`

    Args:
        in1_features: size of each first input sample
        in2_features: size of each second input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input1: :math:`(N, *, H_{in1})` where :math:`H_{in1}=\text{in1\_features}` and
          :math:`*` means any number of additional dimensions. All but the last dimension
          of the inputs should be the same.
        - Input2: :math:`(N, *, H_{in2})` where :math:`H_{in2}=\text{in2\_features}`.
        - Output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}`
          and all but the last dimension are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in1\_features}, \text{in2\_features})`.
            The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in1\_features}}`
        bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
            If :attr:`bias` is ``True``, the values are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in1\_features}}`

    Examples::

        >>> m = nn.Bilinear(20, 30, 40)
        >>> input1 = torch.randn(128, 20)
        >>> input2 = torch.randn(128, 30)
        >>> output = m(input1, input2)
        >>> print(output.size())
        torch.Size([128, 40])
    """
    __constants__ = ['in1_features', 'in2_features', 'out_features', 'bias']

    def __init__(self, in1_features, in2_features, out_features, bias=True):
        super(Bilinear, self).__init__()
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in1_features, in2_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        bound = 1 / math.sqrt(self.weight.size(1))
        init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            init.uniform_(self.bias, -bound, bound)

    @weak_script_method
    def forward(self, input1, input2):
        return F.bilinear(input1, input2, self.weight, self.bias)

    def extra_repr(self):
        return 'in1_features={}, in2_features={}, out_features={}, bias={}'.format(
            self.in1_features, self.in2_features, self.out_features, self.bias is not None
        )


# TODO: PartialLinear - maybe in sparse?