pytorch/test/dist_optimizer_test.py
Alisson Gusatti Azzolini b0cf43b2dd Simple distributed optimizer (#29304)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29304

Implements a simple Python distributed optimizer that takes RRefs to the parameters to be optimized.
It keeps optimizer instances on the remote workers that own those parameters, and calling step on the distributed optimizer calls step on each of the remote optimizers in parallel.
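
A minimal usage sketch (for illustration only; the RPC setup and the forward
pass below are hypothetical, not part of this change):

    import torch.distributed.autograd as dist_autograd
    from torch import optim
    from torch.distributed.optim import DistributedOptimizer

    # Assumes the RPC framework has already been initialized on every worker
    # and that param_rref1 / param_rref2 are RRefs to remotely owned parameters.
    dist_optim = DistributedOptimizer(
        optim.SGD,                    # local optimizer class, one instance per owner
        [param_rref1, param_rref2],   # RRefs to the parameters to optimize
        lr=0.05,                      # remaining kwargs go to each remote optimizer
    )
    with dist_autograd.context():
        loss = compute_loss()         # hypothetical forward pass over remote modules
        dist_autograd.backward([loss])
        dist_optim.step()             # calls step() on every remote optimizer in parallel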
ghstack-source-id: 93564364

Test Plan: unit tests.

Differential Revision: D18354586

fbshipit-source-id: 85d4c8bfec4aa38d2863cda704d024692511cff5
2019-11-11 12:02:24 -08:00

from __future__ import absolute_import, division, print_function, unicode_literals

import threading
import unittest

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

from dist_utils import INIT_METHOD_TEMPLATE, dist_init
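

# Small test module (not an nn.Module) holding a single 3x3 weight; the fixed
# seed makes every instance start from identical weights.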
class MyModule:
    lock = threading.Lock()

    def __init__(self):
        # Avoid a race where two modules could be initialized concurrently,
        # which would change the order in which random numbers are drawn.
        with MyModule.lock:
            torch.manual_seed(0)
            self.w = torch.rand((3, 3), requires_grad=True)

    def forward(self, t1):
        return torch.mm(self.w, t1)

    def get_w(self):
        return self.w
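

# Optimizer whose step() always raises, used to verify that a failure on a
# remote worker surfaces as an exception from DistributedOptimizer.step().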
class FailingOptimizer(optim.Optimizer):
    def __init__(self, params):
        super(FailingOptimizer, self).__init__(params, {})

    def step(self, closure=None):
        raise ValueError('Error running optimizer.')
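

# Runs on the owner of obj_rref: dereferences the RRef to its local value and
# invokes the given unbound method on it.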
def _call_method(method, obj_rref, *args, **kwargs):
    return method(obj_rref.local_value(), *args, **kwargs)


def remote_method(method, obj_rref, *args, **kwargs):
    """
    Call rpc.remote on a method in a remote object.

    Args:
        method: the method (for example, Class.method)
        obj_rref (RRef): remote reference to the object
        args: positional arguments to pass to the method
        kwargs: keyword arguments to pass to the method

    Returns an RRef to the result of the remote method call.
    """
    return rpc.remote(
        obj_rref.owner(),
        _call_method,
        args=[method, obj_rref] + list(args),
        kwargs=kwargs,
    )


def rpc_async_method(method, obj_rref, *args, **kwargs):
    """
    Call rpc.rpc_async on a method in a remote object.

    Args:
        method: the method (for example, Class.method)
        obj_rref (RRef): remote reference to the object
        args: positional arguments to pass to the method
        kwargs: keyword arguments to pass to the method

    Returns a Future for the result of the method call.
    """
    return rpc.rpc_async(
        obj_rref.owner(),
        _call_method,
        args=[method, obj_rref] + list(args),
        kwargs=kwargs,
    )
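

# Test mixin: not a TestCase itself. The harness that mixes it in presumably
# provides self.rank and self.file_name for each RPC worker process.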
@unittest.skipIf(
    not torch._six.PY3, "Pytorch distributed optim does not support python2"
)
class DistOptimizerTest(object):
    @property
    def world_size(self):
        return 4

    @property
    def init_method(self):
        return INIT_METHOD_TEMPLATE.format(
            file_name=self.file_name, rank=self.rank, world_size=self.world_size
        )

    @dist_init()
    def test_dist_optim_exception(self):
        # distributed version
        owner1 = 'worker%d' % ((self.rank + 1) % self.world_size)
        owner2 = 'worker%d' % ((self.rank + 2) % self.world_size)
        remote_module1 = rpc.remote(owner1, MyModule)
        remote_module2 = rpc.remote(owner2, MyModule)
        remote_param1 = remote_method(MyModule.get_w, remote_module1)
        remote_param2 = remote_method(MyModule.get_w, remote_module2)
        dist_optim = DistributedOptimizer(
            FailingOptimizer,
            [remote_param1, remote_param2],
        )
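
        # One forward/backward through both remote modules inside a distributed
        # autograd context; the failing optimizer should then raise on step().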
        with dist_autograd.context():
            torch.manual_seed(0)
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
            output2 = rpc_async_method(
                MyModule.forward, remote_module2, output1.wait())
            loss = torch.add(output2.wait(), t1).sum()
            dist_autograd.backward([loss])
            with self.assertRaisesRegex(Exception, "Error running optimizer"):
                dist_optim.step()

    @dist_init()
    def test_dist_optim(self):
        # local version
        module1 = MyModule()
        module2 = MyModule()
        params = [module1.get_w(), module2.get_w()]
        local_optim = optim.SGD(params, lr=0.05)

        old_w1 = module1.w.clone().detach()
        old_w2 = module2.w.clone().detach()

        torch.manual_seed(0)
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        output1 = module1.forward(t2)
        output2 = module2.forward(output1)
        loss = torch.add(output2, t1).sum()

        loss.backward()
        local_optim.step()

        # distributed version
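        # (this run must end up with the same weights as the local run above,
        # since it starts from the same seed and uses the same learning rate)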
        owner1 = 'worker%d' % ((self.rank + 1) % self.world_size)
        owner2 = 'worker%d' % ((self.rank + 2) % self.world_size)
        remote_module1 = rpc.remote(owner1, MyModule)
        remote_module2 = rpc.remote(owner2, MyModule)
        remote_param1 = remote_method(MyModule.get_w, remote_module1)
        remote_param2 = remote_method(MyModule.get_w, remote_module2)
        old_w1_remote = remote_param1.to_here()

        # sanity check: local and remote initial weights should match
        self.assertEqual(old_w1, remote_param1.to_here())
        self.assertEqual(old_w2, remote_param2.to_here())

        dist_optim = DistributedOptimizer(
            optim.SGD,
            [remote_param1, remote_param2],
            lr=0.05,
        )
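
        # Same forward/backward as the local run, but through RPC and the
        # distributed autograd engine.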
        with dist_autograd.context():
            torch.manual_seed(0)
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
            output2 = rpc_async_method(
                MyModule.forward, remote_module2, output1.wait())
            loss = torch.add(output2.wait(), t1)

            dist_autograd.backward([loss.sum()])
            dist_optim.step()

            new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait()
            new_w2 = rpc_async_method(MyModule.get_w, remote_module2).wait()

            # ensure optimizer changed weights
            self.assertNotEqual(old_w1, new_w1)
            self.assertNotEqual(old_w2, new_w2)

            # ensure local equals remote
            self.assertEqual(new_w1, module1.get_w())
            self.assertEqual(new_w2, module2.get_w())