pytorch/torch/nn/parallel/replicate.py
Zachary DeVito 627f2823e0 remove _register_* bindings from python (#29499)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29499

This changes how DataParallel and trace module creation works so that
we no longer need to mutate Module class after it has been created.

The only remaining usage of register_* functions are now inside C++
tests.

Test Plan: Imported from OSS

Differential Revision: D18413652

Pulled By: zdevito

fbshipit-source-id: f039e5400cd016632768be4547892f6a69645c20
2019-11-11 13:52:46 -08:00

169 lines
6.1 KiB
Python

import torch
import torch.cuda.comm as comm
from torch.cuda._utils import _get_device_index
from torch.nn import Parameter
def _is_script_module(module):
import torch.jit
return isinstance(module, torch.jit.ScriptModule)
def _is_script_method(module):
import torch.jit
return isinstance(module, torch._C.ScriptMethod)
def _init_script_module():
import torch.jit
return torch.jit.ScriptModule()
def _is_jit_enabled():
    """Report whether the TorchScript JIT is enabled.

    Reads the private ``torch.jit._enabled`` flag; this attribute is
    internal to torch and may move between torch versions.
    """
    import torch.jit
    enabled = torch.jit._enabled
    return enabled
def _replicatable_module(module, memo=None):
    """Check whether ``module`` can be replicated safely.

    There are two kinds of module: plain Python modules and
    ScriptModules. Replication currently cannot be done properly when
    any ScriptModule has a plain Python module among its descendants,
    so that shape is rejected here.
    """
    if not _is_jit_enabled():
        # Without the JIT there can be no ScriptModules to worry about.
        return True

    if memo is None:
        memo = set()
    # Record every module we have already inspected.
    memo.add(module)

    if _is_script_module(module):
        # module.modules() yields the module itself first; drop it so
        # only true descendants are examined.
        descendants = list(module.modules())[1:]
        memo.update(descendants)
        return all(_is_script_module(d) for d in descendants)

    # Plain Python module: every child subtree must be replicatable.
    # Already-visited children are skipped -- any unreplicatable module
    # would have made an earlier call return False before reaching here.
    return all(
        _replicatable_module(child, memo)
        for child in module.children()
        if child not in memo
    )
def _broadcast_coalesced_reshape(tensors, devices, detach=False):
    """Broadcast ``tensors`` onto every device in ``devices``.

    Returns one list of tensor copies per device. With ``detach=True``
    the copies are produced outside autograd; otherwise they are sent
    through the ``Broadcast`` autograd function so gradients can flow
    back to the original tensors.
    """
    from ._functions import Broadcast

    if detach:
        return comm.broadcast_coalesced(tensors, devices)

    if not tensors:
        # Nothing to broadcast.
        return []

    # Broadcast.apply returns a flat tuple: all copies for device 0,
    # then all copies for device 1, and so on. Regroup it per device.
    flat_copies = Broadcast.apply(devices, *tensors)
    group_size = len(tensors)
    return [flat_copies[start:start + group_size]
            for start in range(0, len(flat_copies), group_size)]
def replicate(network, devices, detach=False):
    """Replicate ``network`` onto each device in ``devices``.

    Parameters and buffers are broadcast to every device (through
    autograd unless detached, so gradients can flow back to the
    originals) and the module tree is rebuilt once per device.

    Arguments:
        network: module to replicate.
        devices: iterable of CUDA devices (indices or device objects).
        detach (bool): if True, replicated parameters/buffers are
            detached from the autograd graph.

    Returns:
        A list with one replica of ``network`` per device.

    Raises:
        RuntimeError: if a ScriptModule inside ``network`` has plain
            Python modules among its descendants (not replicatable).
    """
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "childrens of ScriptModule")

    # Normalize every device spec to an integer device index.
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    num_replicas = len(devices)

    # Broadcast all parameters in one coalesced transfer; remember each
    # parameter's position so replicas can look up their per-device copy.
    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    # Split buffers: ones that require grad (with detach=False) go
    # through the autograd broadcast; all others are always detached.
    buffers = list(network.buffers())
    buffers_rg = []
    buffers_not_rg = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)

    # First pass: create one (still unwired) replica of every module in
    # the tree, for every device.
    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}
    # NOTE(review): this set appears unused in the visible function body;
    # possibly left over from an earlier implementation -- confirm.
    scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules", "forward", "_c"}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            if _is_script_module(module):
                # we have to initialize ScriptModule properly so that
                # it works with pybind11
                def init_fn(script_module):
                    # Don't do anything here, we'll initialize the ScriptModule below
                    return
                replica = torch.jit.RecursiveScriptModule._construct(module._c._replicate_for_data_parallel(), init_fn)
            else:
                replica = module._replicate_for_data_parallel()
            module_copies[j].append(replica)

    # Second pass: wire each replica's children, parameters and buffers
    # to the per-device copies created above.
    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, module_copies[j][module_idx])
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    # Deliberately rebinds the loop variable ``param`` to
                    # the broadcast copy for device j (param_idx was
                    # computed from the original before this inner loop).
                    param = param_copies[j][param_idx]
                    setattr(replica, key, Parameter(param))
                    # TODO: We need to manually set _parameters with a bare
                    # non-parameter Tensor, otherwise gradients don't
                    # accumulate in the original parameters when you call
                    # backwards() on the DataParallel module.
                    replica._parameters[key] = param
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                # Select the copy list matching how this buffer was broadcast.
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, buffer_copies[j][buffer_idx])

    # Entry 0 of each per-device list is the replica of ``network`` itself.
    return [module_copies[j][0] for j in range(num_replicas)]