pytorch/torch/nn/utils/_stateless.py
Emilio Castillo cd813f16bf Add functional api for nn.Module (#61447)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/58839

After some discussion, albanD proposed this simple design.

Let's iterate over the idea here :).

Thanks.

The main idea of this PR is to use reparametrization and revert it at the end of the functional call.
This leaves the original model with its state unchanged. It also covers the scenario where the module is created without parameters: the forward pass will hard-error if not all parameters are specified.

``` python
import torch
import torch.nn.utils._stateless

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.l1(x)

mod = MyModule()
print('weight before', mod.l1.weight)
x = torch.rand((1, 1))
parameters = {"l1.weight": torch.nn.Parameter(torch.tensor([[1.0]])),
              "l1.bias": torch.nn.Parameter(torch.tensor([0.0]))}
res = torch.nn.utils._stateless.functional_call(mod, parameters, x)
print('Functional call input ', x, ' and result ', res)
print('weight after', mod.l1.weight)
```
Output
```
weight before Parameter containing:
tensor([[-0.4419]], requires_grad=True)

Functional call input tensor([[0.3531]]) and result tensor([[0.3531]], grad_fn=<AddmmBackward>)

weight after Parameter containing:
tensor([[-0.4419]], requires_grad=True)
```
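Because the reparametrization makes the forward pass consume the supplied tensors, gradients flow to them rather than to the module's own parameters. A minimal sketch of that use case (not part of the PR itself):

``` python
import torch
import torch.nn.utils._stateless

mod = torch.nn.Linear(1, 1)
x = torch.rand((1, 1))

# Fresh parameters we want gradients for; they are never attached to `mod`.
weight = torch.nn.Parameter(torch.tensor([[1.0]]))
bias = torch.nn.Parameter(torch.tensor([0.0]))

out = torch.nn.utils._stateless.functional_call(
    mod, {"weight": weight, "bias": bias}, x)
out.sum().backward()

print(weight.grad)      # gradient lands on the supplied parameter
print(mod.weight.grad)  # the module's own weight is untouched (None)
```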

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61447

Reviewed By: soulitzer

Differential Revision: D31082765

Pulled By: albanD

fbshipit-source-id: ba814d0f9162fb39c59989ca9a8efe160405ba76
2021-09-21 12:39:43 -07:00


import contextlib

import torch


@contextlib.contextmanager
def reparametrize_module(module, parameters_and_buffers):
    # Parametrization does not support changing submodules directly,
    # so walk the dotted path to the owning submodule for each tensor.
    for name, tensor in parameters_and_buffers.items():
        _apply_func_submodules(
            torch.nn.utils.parametrize.register_parametrization,
            module, name.split("."), (_ReparametrizedTensor(tensor),))
    yield
    # Undo the parametrizations so the module is left in its original state.
    for name in parameters_and_buffers:
        _apply_func_submodules(
            torch.nn.utils.parametrize.remove_parametrizations,
            module, name.split("."), (False,))


class _ReparametrizedTensor(torch.nn.Module):
    def __init__(self, tensor):
        super().__init__()
        self._tensor = tensor

    def forward(self, original):
        # Ignore the original parameter and compute with the supplied tensor.
        return self._tensor


def _apply_func_submodules(func, module, path, args):
    # Recurse down the dotted path until only the tensor name remains.
    if len(path) == 1:
        func(module, path[0], *args)
    else:
        _apply_func_submodules(func, getattr(module, path[0]), path[1:], args)


def functional_call(module, parameters_and_buffers, args, kwargs=None):
    # TODO allow kwargs such as unsafe and others for parametrization
    if kwargs is None:
        kwargs = {}
    with reparametrize_module(module, parameters_and_buffers):
        if isinstance(args, tuple):
            out = module(*args, **kwargs)
        else:
            out = module(args, **kwargs)
    return out