Horace He 2022-05-17 05:40:46 +00:00 committed by PyTorch MergeBot
parent e517fc8b28
commit 0e351c7df9
3 changed files with 61 additions and 22 deletions

test/test_stateless.py

@@ -73,6 +73,7 @@ class TestStatelessFunctionalAPI(TestCase):
         self._run_call_with_mock_module(traced_module)
 
     @unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported')
+    @unittest.skip("This doesn't work right now")
     def test_functional_call_with_data_parallel(self):
         module = MockModule()
         module.cuda()
@@ -155,6 +156,23 @@ class TestStatelessFunctionalAPI(TestCase):
         self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
         self.assertEqual(orig_sn_weight, module.l1.weight)
 
+    def test_setattr(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer('foo', torch.zeros(()))
+
+            def forward(self, x):
+                self.foo = self.foo + 1
+                return x + self.foo
+
+        a = {'foo': torch.zeros(())}
+        mod = Foo()
+        stateless.functional_call(mod, a, torch.ones(()))
+        self.assertEqual(mod.foo, torch.zeros(()))
+        self.assertEqual(a['foo'], torch.ones(()))
+
 
 class TestStatelessDeprecation(TestCase):
     def test_private_stateless_warns(self):
         script = """
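For readers skimming the diff, the new test_setattr above pins down the behaviour this commit adds: while stateless.functional_call is active, attribute writes inside forward() are redirected into the caller's dict instead of mutating the module. A minimal standalone sketch of that behaviour (the Counter/state names are illustrative, not from the diff):

    import torch
    from torch.nn.utils import stateless

    class Counter(torch.nn.Module):  # mirrors the Foo module in the test above
        def __init__(self):
            super().__init__()
            self.register_buffer('foo', torch.zeros(()))

        def forward(self, x):
            self.foo = self.foo + 1   # during functional_call this write goes to the dict
            return x + self.foo

    state = {'foo': torch.zeros(())}
    mod = Counter()
    out = stateless.functional_call(mod, state, torch.ones(()))
    print(mod.foo)       # tensor(0.) - the module's real buffer is untouched
    print(state['foo'])  # tensor(1.) - the in-place update landed in the caller's dict
    print(out)           # tensor(2.)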

torch/nn/utils/_stateless.py

@@ -9,7 +9,7 @@ warnings.warn("The `torch.nn.utils._stateless` code is deprecated now that "
 # Import * wouldn't work as most things are private and thus wouldn't be imported
 # here.
 from torch.nn.utils.stateless import functional_call  # noqa: F401
-from torch.nn.utils.stateless import _apply_func_submodules, _remove_swap, _swap_parameters, _change_class  # noqa: F401
+from torch.nn.utils.stateless import _apply_func_submodules, _change_class  # noqa: F401
 # This one used to look public but should actually be private. This was fixed when making the module
 # public and is kept here for BC
 from torch.nn.utils.stateless import _reparametrize_module as reparametrize_module  # noqa: F401
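The shim above keeps backwards compatibility for the old private module path while dropping `_remove_swap` and `_swap_parameters` from the re-export (the former's signature changes below, and the latter is now built per-call by `_create_swap_params`). A hedged sketch of the migration, assuming a fresh interpreter since the warning only fires the first time the module is imported:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        from torch.nn.utils._stateless import functional_call  # deprecated path, still works
    print(any("deprecated" in str(w.message) for w in caught))  # True

    from torch.nn.utils.stateless import functional_call  # preferred import going forward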

torch/nn/utils/stateless.py

@@ -8,43 +8,51 @@ __all__ = ["functional_call"]
 # We avoid typing module here because module attributes are declared as Union[Parameter, Tensor] by default
 # and using other types causes mypy errors
-def _change_class(module) -> None:
+def _change_class(module, params_and_buffers) -> None:
     cls = module.__class__
-    func_params : Dict[str, Tensor] = module._functional_parameters
+    attr_to_path : Dict[str, str] = module._attr_to_path
 
     def _getattribute(self, name: str) -> Any:
-        if name in func_params:
-            return func_params[name]
+        if name in attr_to_path:
+            return params_and_buffers[attr_to_path[name]]
         return cls.__getattribute__(self, name)
 
+    def _setattr(self, name: str, value: Any) -> None:
+        if name in attr_to_path:
+            params_and_buffers[attr_to_path[name]] = value
+        else:
+            return cls.__setattr__(self, name, value)
+
     param_cls = type(
         f"StatelessReplacer{cls.__name__}",
         (cls,),
         {
             "__getattribute__": _getattribute,
+            "__setattr__": _setattr,
         },
     )
 
     module.__class__ = param_cls
     module._orig_class = cls
 
 
-def _swap_parameters(module, tensor_name: str, tensor: Tensor) -> None:
-    # Changes the module class to get a new __getattr__ dunder method
-    # that looks for the reparametrized tensor
-    if hasattr(module, "_functional_parameters"):
-        module._functional_parameters[tensor_name] = tensor
-    else:
-        module._functional_parameters = {}
-        module._functional_parameters[tensor_name] = tensor
-        _change_class(module)
+def _create_swap_params(params_and_buffers):
+    def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
+        # Changes the module class to get a new __getattr__ dunder method
+        # that looks for the reparametrized tensor
+        if hasattr(module, "_attr_to_path"):
+            module._attr_to_path[tensor_name] = full_path
+        else:
+            module._attr_to_path = {}
+            module._attr_to_path[tensor_name] = full_path
+            _change_class(module, params_and_buffers)
+    return _swap_parameters
 
 
-def _remove_swap(module, name: str) -> None:
+def _remove_swap(module, name: str, full_path: str) -> None:
     if hasattr(module, "_orig_class"):
         module.__class__ = module._orig_class
         delattr(module, "_orig_class")
-        delattr(module, "_functional_parameters")
+        delattr(module, "_attr_to_path")
 
 
 @contextlib.contextmanager
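The core trick in `_change_class` predates this commit but is worth spelling out, since the new `__setattr__` hook builds on it: a throwaway subclass is created with `type()`, the attribute dunders are overridden to consult a dict first, and the instance's `__class__` is pointed at that subclass. A self-contained sketch of the same technique, with illustrative names (Plain, make_stateless, overlay) that are not part of PyTorch:

    from typing import Any, Dict

    class Plain:
        def __init__(self) -> None:
            self.weight = 0.0

    def make_stateless(obj: Any, overlay: Dict[str, Any]) -> None:
        cls = obj.__class__

        def _getattribute(self, name: str) -> Any:
            # reads of overlaid names come from the dict, everything else is normal
            if name in overlay:
                return overlay[name]
            return cls.__getattribute__(self, name)

        def _setattr(self, name: str, value: Any) -> None:
            # writes to overlaid names also land in the dict
            if name in overlay:
                overlay[name] = value
            else:
                cls.__setattr__(self, name, value)

        obj.__class__ = type(f"Stateless{cls.__name__}", (cls,), {
            "__getattribute__": _getattribute,
            "__setattr__": _setattr,
        })
        obj._orig_class = cls  # stash the original class so the swap can be undone

    p = Plain()
    make_stateless(p, {"weight": 1.0})
    print(p.weight)              # 1.0 - read comes from the overlay dict
    p.weight = 2.0               # write goes to the overlay, not the instance
    p.__class__ = p._orig_class  # undo the swap, as _remove_swap does
    print(p.weight)              # 0.0 - the real attribute was never touched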
@@ -54,25 +62,26 @@ def _reparametrize_module(
 ) -> Iterator[None]:
     for name, tensor in parameters_and_buffers.items():
         _apply_func_submodules(
-            _swap_parameters,
-            module, name.split("."), (tensor,))
+            _create_swap_params(parameters_and_buffers),
+            module, name.split("."), name, (tensor,))
     yield
     for name in parameters_and_buffers:
         _apply_func_submodules(
             _remove_swap,
-            module, name.split("."), ())
+            module, name.split("."), name, ())
 
 
 def _apply_func_submodules(
     func: Callable[..., None],
     module: 'torch.nn.Module',
     path: List[str],
+    full_path: str,
     args: Tuple,
 ):
     if len(path) == 1:
-        func(module, path[0], *args)
+        func(module, path[0], full_path, *args)
     else:
-        _apply_func_submodules(func, getattr(module, path[0]), path[1:], args)
+        _apply_func_submodules(func, getattr(module, path[0]), path[1:], full_path, args)
 
 
 def functional_call(
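`_apply_func_submodules` walks a dotted key such as 'l1.weight' down to the submodule that owns the leaf attribute; the commit threads the untouched `full_path` along so the swapped class can index the flat `parameters_and_buffers` dict by its original key. A small sketch of that walk, using illustrative names rather than the PyTorch internals:

    from typing import Any, Callable, List, Tuple

    class Node:
        def __init__(self, **children: Any) -> None:
            for k, v in children.items():
                setattr(self, k, v)

    def apply_on_owner(func: Callable[..., None], root: Any, path: List[str],
                       full_path: str, args: Tuple) -> None:
        if len(path) == 1:
            # reached the object that owns the attribute; hand it the leaf name
            # plus the original dotted key so it can index the flat dict
            func(root, path[0], full_path, *args)
        else:
            apply_on_owner(func, getattr(root, path[0]), path[1:], full_path, args)

    visited = []
    def record(owner, leaf, full_path, payload):
        visited.append((type(owner).__name__, leaf, full_path, payload))

    leaf = Node(weight="w")
    root = Node(l1=leaf)
    apply_on_owner(record, root, "l1.weight".split("."), "l1.weight", ("tensor",))
    print(visited)  # [('Node', 'weight', 'l1.weight', 'tensor')]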
@@ -90,6 +99,18 @@ def functional_call(
         If you want to apply the parametrization function to the value passed
         please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.
 
+    .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected
+        in the `parameters_and_buffers` input.
+
+        Example::
+
+            >>> a = {'foo': torch.zeros(())}
+            >>> mod = Foo()  # does self.foo = self.foo + 1
+            >>> print(mod.foo)  # tensor(0.)
+            >>> functional_call(mod, a, torch.ones(()))
+            >>> print(mod.foo)  # tensor(0.)
+            >>> print(a['foo'])  # tensor(1.)
+
     Args:
         module (torch.nn.Module): the module to call
         parameters_and_buffers (dict of str and Tensor): the parameters that will be used in