Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Summary: Fixes https://github.com/pytorch/pytorch/issues/58839

After discussing with albanD, he proposed this simple design. Let's iterate on the idea here :). Thanks.

The main point of this PR is to use a reparametrization that is reverted at the end of the functional call. This leaves the original model with its state unchanged. Also, in this scenario the module is created without parameters, so the forward pass will hard-error if not all parameters are specified.

```python
import torch
import torch.nn.utils._stateless

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.l1(x)

mod = MyModule()
print('weight before', mod.l1.weight)
x = torch.rand((1, 1))
parameters = {"l1.weight": torch.nn.Parameter(torch.tensor([[1.0]])),
              "l1.bias": torch.nn.Parameter(torch.tensor([0.0]))}
res = torch.nn.utils._stateless.functional_call(mod, parameters, x)
print('Functional call input ', x, ' and result ', res)
print('weight after', mod.l1.weight)
```

Output

```
weight before Parameter containing:
tensor([[-0.4419]], requires_grad=True)
Functional call input  tensor([[0.3531]])  and result  tensor([[0.3531]], grad_fn=<AddmmBackward>)
weight after Parameter containing:
tensor([[-0.4419]], requires_grad=True)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61447

Reviewed By: soulitzer

Differential Revision: D31082765

Pulled By: albanD

fbshipit-source-id: ba814d0f9162fb39c59989ca9a8efe160405ba76
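The motivating use case in https://github.com/pytorch/pytorch/issues/58839 is taking gradients with respect to externally supplied parameters. A minimal sketch of that usage (not from the PR; it assumes gradients flow to the tensors passed into `functional_call`, which is my reading of the reparametrization approach, and the variable names are illustrative):

```python
import torch
import torch.nn.utils._stateless

mod = torch.nn.Linear(1, 1)
# Externally supplied parameters; keys follow the module's parameter names.
weight = torch.nn.Parameter(torch.tensor([[1.0]]))
bias = torch.nn.Parameter(torch.tensor([0.0]))

x = torch.rand(1, 1)
out = torch.nn.utils._stateless.functional_call(
    mod, {"weight": weight, "bias": bias}, x)
# Gradients are taken w.r.t. the supplied tensors, not mod's own parameters.
grads = torch.autograd.grad(out.sum(), (weight, bias))
print(grads)
```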
46 lines
1.4 KiB
Python
import contextlib

import torch


@contextlib.contextmanager
def reparametrize_module(module, parameters_and_buffers):
    # Parametrization does not support changing submodules directly,
    # so walk down to the owning submodule for each dotted name.
    for name, tensor in parameters_and_buffers.items():
        _apply_func_submodules(
            torch.nn.utils.parametrize.register_parametrization,
            module, name.split("."), (_ReparametrizedTensor(tensor),))
    yield
    # Undo the substitution so the module's original state is restored.
    for name in parameters_and_buffers:
        _apply_func_submodules(
            torch.nn.utils.parametrize.remove_parametrizations,
            module, name.split("."), (False,))


class _ReparametrizedTensor(torch.nn.Module):
    def __init__(self, tensor):
        super().__init__()
        self._tensor = tensor

    def forward(self, original):
        # Ignore the original value and return the substituted tensor.
        return self._tensor


def _apply_func_submodules(func, module, path, args):
    # Recurse along the dotted path (e.g. "l1.weight") to the leaf attribute.
    if len(path) == 1:
        func(module, path[0], *args)
    else:
        _apply_func_submodules(func, getattr(module, path[0]), path[1:], args)


def functional_call(module, parameters_and_buffers, args, kwargs=None):
    # TODO allow kwargs such as unsafe and others for parametrization
    if kwargs is None:
        kwargs = {}
    with reparametrize_module(module, parameters_and_buffers):
        if isinstance(args, tuple):
            out = module(*args, **kwargs)
        else:
            out = module(args, **kwargs)
    return out
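For reference, a minimal sketch of what the helper context manager does in isolation (an assumed usage pattern, not part of the file): inside the `with` block the attribute resolves to the substituted tensor via parametrization, and the original parameter is restored on exit.

```python
import torch
import torch.nn.utils._stateless as _stateless

lin = torch.nn.Linear(1, 1)
original = lin.weight.clone()

# Temporarily substitute "weight"; the dotted-path helper also handles
# nested names such as "l1.weight".
with _stateless.reparametrize_module(
        lin, {"weight": torch.nn.Parameter(torch.tensor([[2.0]]))}):
    print(lin.weight)                         # the substituted value

print(torch.equal(lin.weight, original))      # True: original value restored
```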