mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Added setattr to functional_call. (#77137)
Fixes https://github.com/pytorch/pytorch/issues/77133 Pull Request resolved: https://github.com/pytorch/pytorch/pull/77137 Approved by: https://github.com/emcastillo, https://github.com/albanD, https://github.com/jbschlosser
This commit is contained in:
parent
e517fc8b28
commit
0e351c7df9
|
|
@ -73,6 +73,7 @@ class TestStatelessFunctionalAPI(TestCase):
|
|||
self._run_call_with_mock_module(traced_module)
|
||||
|
||||
@unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported')
|
||||
@unittest.skip("This doesn't work right now")
|
||||
def test_functional_call_with_data_parallel(self):
|
||||
module = MockModule()
|
||||
module.cuda()
|
||||
|
|
@ -155,6 +156,23 @@ class TestStatelessFunctionalAPI(TestCase):
|
|||
self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
|
||||
self.assertEqual(orig_sn_weight, module.l1.weight)
|
||||
|
||||
def test_setattr(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer('foo', torch.zeros(()))
|
||||
|
||||
def forward(self, x):
|
||||
self.foo = self.foo + 1
|
||||
return x + self.foo
|
||||
|
||||
a = {'foo': torch.zeros(())}
|
||||
mod = Foo()
|
||||
stateless.functional_call(mod, a, torch.ones(()))
|
||||
self.assertEqual(mod.foo, torch.zeros(()))
|
||||
self.assertEqual(a['foo'], torch.ones(()))
|
||||
|
||||
|
||||
class TestStatelessDeprecation(TestCase):
|
||||
def test_private_stateless_warns(self):
|
||||
script = """
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ warnings.warn("The `torch.nn.utils._stateless` code is deprecated now that "
|
|||
# Import * wouldn't work as most things are private and thus wouldn't be imported
|
||||
# here.
|
||||
from torch.nn.utils.stateless import functional_call # noqa: F401
|
||||
from torch.nn.utils.stateless import _apply_func_submodules, _remove_swap, _swap_parameters, _change_class # noqa: F401
|
||||
from torch.nn.utils.stateless import _apply_func_submodules, _change_class # noqa: F401
|
||||
# This one used to look public but should actually be private. This was fixed when making the module
|
||||
# public and is kept here for BC
|
||||
from torch.nn.utils.stateless import _reparametrize_module as reparametrize_module # noqa: F401
|
||||
|
|
|
|||
|
|
@ -8,43 +8,51 @@ __all__ = ["functional_call"]
|
|||
|
||||
# We avoid typing module here because module attributes are declared as Union[Parameter, Tensor] by default
|
||||
# and using other types causes mypy errors
|
||||
def _change_class(module) -> None:
|
||||
def _change_class(module, params_and_buffers) -> None:
|
||||
cls = module.__class__
|
||||
func_params : Dict[str, Tensor] = module._functional_parameters
|
||||
attr_to_path : Dict[str, str] = module._attr_to_path
|
||||
|
||||
def _getattribute(self, name: str) -> Any:
|
||||
if name in func_params:
|
||||
return func_params[name]
|
||||
if name in attr_to_path:
|
||||
return params_and_buffers[attr_to_path[name]]
|
||||
return cls.__getattribute__(self, name)
|
||||
|
||||
def _setattr(self, name: str, value: Any) -> None:
|
||||
if name in attr_to_path:
|
||||
params_and_buffers[attr_to_path[name]] = value
|
||||
else:
|
||||
return cls.__setattr__(self, name, value)
|
||||
|
||||
param_cls = type(
|
||||
f"StatelessReplacer{cls.__name__}",
|
||||
(cls,),
|
||||
{
|
||||
"__getattribute__": _getattribute,
|
||||
"__setattr__": _setattr,
|
||||
},
|
||||
)
|
||||
|
||||
module.__class__ = param_cls
|
||||
module._orig_class = cls
|
||||
|
||||
|
||||
def _swap_parameters(module, tensor_name: str, tensor: Tensor) -> None:
|
||||
# Changes the module class to get a new __getattr__ dunder method
|
||||
# that looks for the reparametrized tensor
|
||||
if hasattr(module, "_functional_parameters"):
|
||||
module._functional_parameters[tensor_name] = tensor
|
||||
else:
|
||||
module._functional_parameters = {}
|
||||
module._functional_parameters[tensor_name] = tensor
|
||||
_change_class(module)
|
||||
def _create_swap_params(params_and_buffers):
|
||||
def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
|
||||
# Changes the module class to get a new __getattr__ dunder method
|
||||
# that looks for the reparametrized tensor
|
||||
if hasattr(module, "_attr_to_path"):
|
||||
module._attr_to_path[tensor_name] = full_path
|
||||
else:
|
||||
module._attr_to_path = {}
|
||||
module._attr_to_path[tensor_name] = full_path
|
||||
_change_class(module, params_and_buffers)
|
||||
return _swap_parameters
|
||||
|
||||
|
||||
def _remove_swap(module, name: str) -> None:
|
||||
def _remove_swap(module, name: str, full_path: str) -> None:
|
||||
if hasattr(module, "_orig_class"):
|
||||
module.__class__ = module._orig_class
|
||||
delattr(module, "_orig_class")
|
||||
delattr(module, "_functional_parameters")
|
||||
delattr(module, "_attr_to_path")
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
|
@ -54,25 +62,26 @@ def _reparametrize_module(
|
|||
) -> Iterator[None]:
|
||||
for name, tensor in parameters_and_buffers.items():
|
||||
_apply_func_submodules(
|
||||
_swap_parameters,
|
||||
module, name.split("."), (tensor,))
|
||||
_create_swap_params(parameters_and_buffers),
|
||||
module, name.split("."), name, (tensor,))
|
||||
yield
|
||||
for name in parameters_and_buffers:
|
||||
_apply_func_submodules(
|
||||
_remove_swap,
|
||||
module, name.split("."), ())
|
||||
module, name.split("."), name, ())
|
||||
|
||||
|
||||
def _apply_func_submodules(
|
||||
func: Callable[..., None],
|
||||
module: 'torch.nn.Module',
|
||||
path: List[str],
|
||||
full_path: str,
|
||||
args: Tuple,
|
||||
):
|
||||
if len(path) == 1:
|
||||
func(module, path[0], *args)
|
||||
func(module, path[0], full_path, *args)
|
||||
else:
|
||||
_apply_func_submodules(func, getattr(module, path[0]), path[1:], args)
|
||||
_apply_func_submodules(func, getattr(module, path[0]), path[1:], full_path, args)
|
||||
|
||||
|
||||
def functional_call(
|
||||
|
|
@ -90,6 +99,18 @@ def functional_call(
|
|||
If you want to apply the parametrization function to the value passed
|
||||
please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.
|
||||
|
||||
.. note:: If the module performs in-place operations on parameters/buffers, these will be reflected
|
||||
in the `parameters_and_buffers` input.
|
||||
|
||||
Example::
|
||||
|
||||
>>> a = {'foo': torch.zeros(())}
|
||||
>>> mod = Foo() # does self.foo = self.foo + 1
|
||||
>>> print(mod.foo) # tensor(0.)
|
||||
>>> functional_call(mod, a, torch.ones(()))
|
||||
>>> print(mod.foo) # tensor(0.)
|
||||
>>> print(a['foo']) # tensor(1.)
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): the module to call
|
||||
parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user