mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
63 lines
2.4 KiB
Python
63 lines
2.4 KiB
Python
import torch.cuda.comm as comm
|
|
|
|
|
|
def replicate(network, devices):
|
|
from ._functions import Broadcast
|
|
|
|
devices = tuple(devices)
|
|
num_replicas = len(devices)
|
|
|
|
params = list(network.parameters())
|
|
param_indices = {param: idx for idx, param in enumerate(params)}
|
|
param_copies = Broadcast(devices)(*params)
|
|
if len(params) > 0:
|
|
param_copies = [param_copies[i:i + len(params)]
|
|
for i in range(0, len(param_copies), len(params))]
|
|
|
|
buffers = list(network._all_buffers())
|
|
buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
|
|
buffer_copies = comm.broadcast_coalesced(buffers, devices)
|
|
|
|
modules = list(network.modules())
|
|
module_copies = [[] for device in devices]
|
|
module_indices = {}
|
|
|
|
for i, module in enumerate(modules):
|
|
module_indices[module] = i
|
|
for j in range(num_replicas):
|
|
replica = module.__new__(type(module))
|
|
replica.__dict__ = module.__dict__.copy()
|
|
replica._parameters = replica._parameters.copy()
|
|
replica._buffers = replica._buffers.copy()
|
|
replica._modules = replica._modules.copy()
|
|
module_copies[j].append(replica)
|
|
|
|
for i, module in enumerate(modules):
|
|
for key, child in module._modules.items():
|
|
module_idx = module_indices[child]
|
|
for j in range(num_replicas):
|
|
replica = module_copies[j][i]
|
|
replica._modules[key] = module_copies[j][module_idx]
|
|
for key, param in module._parameters.items():
|
|
if param is None:
|
|
for j in range(num_replicas):
|
|
replica = module_copies[j][i]
|
|
replica._parameters[key] = None
|
|
else:
|
|
param_idx = param_indices[param]
|
|
for j in range(num_replicas):
|
|
replica = module_copies[j][i]
|
|
replica._parameters[key] = param_copies[j][param_idx]
|
|
for key, buf in module._buffers.items():
|
|
if buf is None:
|
|
for j in range(num_replicas):
|
|
replica = module_copies[j][i]
|
|
replica._buffers[key] = None
|
|
else:
|
|
buffer_idx = buffer_indices[buf]
|
|
for j in range(num_replicas):
|
|
replica = module_copies[j][i]
|
|
replica._buffers[key] = buffer_copies[j][buffer_idx]
|
|
|
|
return [module_copies[j][0] for j in range(num_replicas)]
|