mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:

* Deletes all weak script decorators / associated data structures / methods
* In order to keep supporting the standard library in script, this enables recursive script on any function defined in `torch.nn`
* Most changes in `torch/nn` are the result of `ag -Q "weak" torch/nn/ -l | xargs sed -i '/weak/d'`; only `rnn.py` needed manual editing, using `ignore` and `export` to continue supporting the overloaded `forward` methods
* `Sequential`/`ModuleList` no longer need to be added to constants since they are compiled on demand

This should also fix https://github.com/pytorch/pytorch/issues/22212

Pull Request resolved: https://github.com/pytorch/pytorch/pull/22212
Differential Revision: D15988346
Pulled By: driazati
fbshipit-source-id: af223e3ad0580be895377312949997a70e988e4f
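As a hedged illustration of the last bullet (a minimal sketch; the module, names, and sizes below are made up), a container no longer has to be listed in `__constants__` to be iterated inside a scripted `forward`:

```python
import torch
import torch.nn as nn

class Stack(nn.Module):
    # Before this change, iterating self.layers in script required
    # `__constants__ = ['layers']`; now the container is compiled on demand.
    def __init__(self):
        super(Stack, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

scripted = torch.jit.script(Stack())
print(scripted(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```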
186 lines
6.7 KiB
Python
import torch.cuda.comm as comm
from torch.cuda._utils import _get_device_index


def _is_script_module(module):
    import torch.jit
    return isinstance(module, torch.jit.ScriptModule)


def _is_script_method(module):
    import torch.jit
    return isinstance(module, torch._C.ScriptMethod)


def _init_script_module():
    import torch.jit
    return torch.jit.ScriptModule()


def _is_jit_enabled():
    import torch.jit
    return torch.jit._enabled
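

# Note: the four helpers above import torch.jit lazily inside each function
# body, most likely to avoid a circular import between torch.jit and this
# module at load time (an assumption; the original file does not say why).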


# Check if we can safely replicate the module.
# There are two types of module:
#   1. Python modules
#   2. ScriptModules
#
# Currently a module cannot be replicated properly if the descendants of
# any ScriptModule contain a Python module (type 1 above).
def _replicatable_module(module, memo=None):

    # module.modules() contains module itself as the first element
    def descendant_modules(module):
        gen = module.modules()
        next(gen)
        return gen

    if not _is_jit_enabled():
        return True
    if memo is None:
        memo = set()

    # memoize visited modules
    memo.add(module)
    if _is_script_module(module):
        memo.update(descendant_modules(module))
        return all(_is_script_module(descendant) for
                   descendant in descendant_modules(module))

    for child in module.children():
        # since any unreplicatable module will cause the check to return
        # False early, visited modules here can be safely ignored.
        if child in memo:
            continue
        if not _replicatable_module(child, memo):
            return False

    return True
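

# Usage sketch (illustrative, not part of the original file):
#   net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.ReLU())
#   _replicatable_module(net)  # -> True: pure Python module trees are fine
# The check only returns False when a ScriptModule has a descendant that is
# not itself a ScriptModule.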


def _copy_scriptmodule_methods(modules, module_copies, module_indices):
    for i, module in enumerate(modules):
        if not _is_script_module(module):
            continue
        replica = module_copies[i]
        for method_name in module._c._method_names():
            replica._c.clone_method(module._c, method_name)
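

# Note (illustrative): ScriptModule replicas are created as empty script
# modules in replicate() below, so their methods are cloned here from the
# source module's underlying C++ module (module._c) rather than re-created
# in Python.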


def _broadcast_coalesced_reshape(tensors, devices, detach=False):
    from ._functions import Broadcast
    if detach:
        return comm.broadcast_coalesced(tensors, devices)
    else:
        # Use the autograd function to broadcast if not detach
        if len(tensors) > 0:
            tensor_copies = Broadcast.apply(devices, *tensors)
            return [tensor_copies[i:i + len(tensors)]
                    for i in range(0, len(tensor_copies), len(tensors))]
        else:
            return []
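

# Shape sketch (illustrative): given tensors [a, b] and devices [d0, d1],
# the return value is
#   [[a_on_d0, b_on_d0], [a_on_d1, b_on_d1]]
# i.e. one list per device, preserving the input tensor order, whether the
# copies come from comm.broadcast_coalesced or from Broadcast.apply.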


def replicate(network, devices, detach=False):
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "children of ScriptModule")

    devices = list(map(lambda x: _get_device_index(x, True), devices))
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    buffers = list(network.buffers())
    buffers_rg = []
    buffers_not_rg = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}
    scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules", "forward", "_c"}
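
    # Sketch of the split above (illustrative): buffers that require grad are
    # broadcast through autograd (so gradients can flow back) unless detach is
    # set; all other buffers (e.g. BatchNorm running stats) are always
    # broadcast detached.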

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            if _is_script_module(module):
                # we have to initialize ScriptModule properly so that
                # it works with pybind11
                replica = _init_script_module()

                attribute_names = set(entry[0] for entry in module._c._get_attributes())

                keys = set(module.__dict__.keys()) - scriptmodule_skip_attr - attribute_names
                for key in keys:
                    if not _is_script_method(module.__dict__[key]):
                        replica.__dict__[key] = module.__dict__[key]
                for name, the_type, value in module._c._get_attributes():
                    if name in module._buffers.keys():
                        continue
                    replica._c._register_attribute(name, the_type, value)
            else:
                replica = module.__new__(type(module))
                replica.__dict__ = module.__dict__.copy()
                replica._parameters = replica._parameters.copy()
                replica._buffers = replica._buffers.copy()
                replica._modules = replica._modules.copy()

            module_copies[j].append(replica)
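
    # Layout note (illustrative): after this loop, module_copies[j][i] is the
    # replica of modules[i] destined for devices[j]; the loop below rewires
    # child/parameter/buffer references to point at those copies.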

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx]
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]
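
    # Note (illustrative): param_indices is keyed by the parameter object
    # itself, so modules that share a parameter in `network` receive the same
    # copy in each replica, i.e. weight tying survives replication.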

    for j in range(num_replicas):
        _copy_scriptmodule_methods(modules, module_copies[j], module_indices)

    return [module_copies[j][0] for j in range(num_replicas)]
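

# Usage sketch (illustrative, not part of the original file; assumes at
# least two visible CUDA devices):
#
#   import torch.nn as nn
#   from torch.nn.parallel import replicate
#
#   net = nn.Linear(8, 8).cuda(0)
#   replicas = replicate(net, [0, 1])   # one replica per device, in order
#   assert len(replicas) == 2
#
# With detach=False (the default), replica parameters stay connected to the
# autograd graph of the originals, which is what DataParallel relies on to
# accumulate gradients back onto the source module.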