Enable dist autograd tests (#28606)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28606

Without passing setup_model_parallel=True to dist_init, the decorator factory
actually receives the decorated test function object as the value for the
flag instead of a bool.
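
To illustrate (a minimal standalone sketch; this dist_init is a stand-in for
the real decorator changed below, with all RPC setup elided): a decorator
factory must be called to produce the decorator, so applying it bare hands the
test function itself to the flag parameter.

def dist_init(setup_model_parallel=True):
    def decorator(old_test_method):
        def new_test_method(*args, **kwargs):
            if setup_model_parallel:
                pass  # the real decorator would set up RPC / model parallel here
            return old_test_method(*args, **kwargs)
        return new_test_method
    return decorator

@dist_init                             # bug: test_without_parens becomes setup_model_parallel
def test_without_parens():
    return "ran"

@dist_init(setup_model_parallel=True)  # correct: dist_init(...) returns the actual decorator
def test_with_parens():
    return "ran"

print(test_with_parens())              # "ran"
print(test_without_parens.__name__)    # "decorator" -- the test body was never wrapped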

Test Plan: Imported from OSS

Differential Revision: D18120507

Pulled By: mrshenli

fbshipit-source-id: afbaa381647e8f284e28fa9dbdd2a7c411073b3f
Shen Li 2019-10-24 15:28:50 -07:00 committed by Facebook Github Bot
parent 70e4548fd7
commit 261a13a84b
3 changed files with 44 additions and 24 deletions

View File

@@ -125,7 +125,7 @@ class DistAutogradTest(object):
def init_method(self):
return INIT_METHOD_TEMPLATE.format(file_name=self.file_name)
@dist_init
@dist_init(setup_model_parallel=True)
def test_autograd_context(self):
# Verify max possible id.
max_auto_increment = 281474976710655
@@ -151,7 +151,7 @@ class DistAutogradTest(object):
):
dist_autograd._retrieve_context(context_id)
@dist_init
@dist_init(setup_model_parallel=True)
def test_nested_context(self):
with dist_autograd.context() as context_id:
# Nested contexts not supported.
@@ -295,16 +295,16 @@ class DistAutogradTest(object):
with self.assertRaises(RuntimeError):
ctx = dist_autograd._current_context()
@dist_init
@dist_init(setup_model_parallel=True)
def test_graph_for_builtin_call(self):
self._test_graph(torch.add)
@dist_init
@dist_init(setup_model_parallel=True)
def test_graph_for_python_call(self):
self._test_graph(my_py_add)
# 3-layer nested calls
@dist_init
@dist_init(setup_model_parallel=True)
def test_graph_for_py_nested_call(self):
dst_rank = (self.rank + 1) % self.world_size
with dist_autograd.context() as context_id:
@@ -358,7 +358,7 @@ class DistAutogradTest(object):
dist.barrier()
# Rank0->Rank1->Rank0
@dist_init
@dist_init(setup_model_parallel=True)
def test_graph_for_py_nested_call_itself(self):
dst_rank = (self.rank + 1) % self.world_size
with dist_autograd.context() as context_id:
@@ -395,7 +395,7 @@ class DistAutogradTest(object):
# autograd context before another worker tries to access it.
dist.barrier()
@dist_init
@dist_init(setup_model_parallel=True)
def test_no_graph_with_tensors_not_require_grad(self):
dst_rank = (self.rank + 1) % self.world_size
with dist_autograd.context() as context_id:
@@ -417,7 +417,7 @@ class DistAutogradTest(object):
with self.assertRaises(RuntimeError):
ctx = dist_autograd._retrieve_context(ctx_ids[1])
@dist_init
@dist_init(setup_model_parallel=True)
def test_rpc_complex_args(self):
with dist_autograd.context() as context_id:
num_tensors = 10
@@ -451,7 +451,7 @@ class DistAutogradTest(object):
self.assertEqual(worker_ids[0], dst_rank)
@dist_init
@dist_init(setup_model_parallel=True)
def test_context_cleanup_many_workers(self):
global known_context_ids
dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
@@ -468,7 +468,7 @@ class DistAutogradTest(object):
success = _all_contexts_cleaned_up(num_contexts=len(dst_ranks))
self.assertTrue(success)
@dist_init
@dist_init(setup_model_parallel=True)
def test_worker_ids_recorded(self):
dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
with dist_autograd.context() as context_id:
@@ -499,7 +499,7 @@ class DistAutogradTest(object):
self.assertEqual(len(worker_ids), len(dst_ranks))
self.assertEqual(set(worker_ids), dst_ranks)
@dist_init
@dist_init(setup_model_parallel=True)
def test_error_in_context(self):
with dist_autograd.context() as context_id:
t1 = torch.rand(3, 3, requires_grad=True)
@@ -536,7 +536,7 @@ class DistAutogradTest(object):
self.assertEqual(ngrads, len(grads))
@dist_init
@dist_init(setup_model_parallel=True)
def test_backward_simple(self):
# Run the same code locally and with dist autograd and verify gradients
# are same.
@@ -549,7 +549,7 @@ class DistAutogradTest(object):
loss = ret.sum()
local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2)
@dist_init
@dist_init(setup_model_parallel=True)
def test_backward_multiple_round_trips(self):
local_grads = None
t1 = torch.rand((3, 3), requires_grad=True)
@@ -571,7 +571,7 @@ class DistAutogradTest(object):
local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5)
@dist_init
@dist_init(setup_model_parallel=True)
def test_backward_different_tensor_dims(self):
local_grads = None
t1 = torch.rand((4, 6), requires_grad=True)
@@ -587,7 +587,7 @@ class DistAutogradTest(object):
local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2, t2, t3, t4)
@dist_init
@dist_init(setup_model_parallel=True)
def test_backward_unused_tensors(self):
local_grads = None
t1 = torch.rand((3, 3), requires_grad=True)
@@ -601,7 +601,7 @@ class DistAutogradTest(object):
loss = val.sum()
local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2, t3)
@dist_init
@dist_init(setup_model_parallel=True)
def test_backward_multiple_output_tensors(self):
local_grads = None
t = torch.rand((10, 2), requires_grad=True)
@@ -633,7 +633,7 @@ class DistAutogradTest(object):
dist_autograd.backward([val.sum()])
@dist_init
@dist_init(setup_model_parallel=True)
def test_backward_unused_send_function(self):
# Run the test in a thread which would never finish.
t = threading.Thread(target=self._run_test_backward_unused_send_function_in_thread)
@@ -644,7 +644,7 @@ class DistAutogradTest(object):
# Verify thread is still alive (indicating backward hasn't completed yet).
self.assertTrue(t.is_alive())
@dist_init
@dist_init(setup_model_parallel=True)
def test_backward_autograd_engine_error(self):
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
@@ -666,7 +666,7 @@ class DistAutogradTest(object):
# Run backwards, and validate we receive an error.
dist_autograd.backward([val.sum()])
@dist_init
@dist_init(setup_model_parallel=True)
@unittest.skip("Skipping this test temporarily since ProcessGroupAgent does not report errors on node failures")
def test_backward_node_failure(self):
with dist_autograd.context() as context_id:
@@ -687,7 +687,7 @@ class DistAutogradTest(object):
# Kill all other nodes.
sys.exit(0)
@dist_init
@dist_init(setup_model_parallel=True)
def test_backward_without_context(self):
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
@@ -697,7 +697,7 @@ class DistAutogradTest(object):
args=(t1, t2))
dist_autograd.backward([res.sum()])
@dist_init
@dist_init(setup_model_parallel=True)
def test_backward_without_rpc(self):
dst_rank = self.rank
with dist_autograd.context() as context_id:
@@ -713,7 +713,7 @@ class DistAutogradTest(object):
self.assertEqual(torch.ones(3, 3), grads[t1])
self.assertEqual(torch.ones(3, 3), grads[t2])
@dist_init
@dist_init(setup_model_parallel=True)
def test_backward_invalid_args(self):
with dist_autograd.context() as context_id:
@@ -735,7 +735,7 @@ class DistAutogradTest(object):
t = torch.rand(1, requires_grad=True)
dist_autograd.backward([t])
@dist_init
@dist_init(setup_model_parallel=True)
def test_backward_multiple_roots(self):
local_grads = None
t1 = torch.rand((3, 3), requires_grad=True)

View File

@@ -52,6 +52,10 @@ def set_termination_signal():
def dist_init(setup_model_parallel=True):
assert isinstance(setup_model_parallel, bool), (
"setup_model_parallel must be a bool value"
)
def decorator(old_test_method):
"""
We use this decorator for setting up and tearing down state since
@@ -87,7 +91,7 @@ def dist_init(setup_model_parallel=True):
num_send_recv_threads=16,
)
old_test_method(self, *arg, **kwargs)
ret = old_test_method(self, *arg, **kwargs)
if setup_model_parallel:
# Follower reports done.
@@ -119,6 +123,8 @@ def dist_init(setup_model_parallel=True):
# Close RPC.
rpc.join_rpc()
return ret
return new_test_method
return decorator
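
In short, dist_init is a decorator factory: the new assertion rejects a bare
@dist_init up front (where setup_model_parallel would be the test function
rather than a bool), and the wrapper now forwards the wrapped test's return
value. A condensed sketch of that shape, with the RPC/model-parallel setup and
teardown elided (the real decorator in the diff above does considerably more):

def dist_init(setup_model_parallel=True):
    # New guard: a bare @dist_init lands the test function here and fails loudly.
    assert isinstance(setup_model_parallel, bool), (
        "setup_model_parallel must be a bool value"
    )

    def decorator(old_test_method):
        def new_test_method(self, *args, **kwargs):
            if setup_model_parallel:
                pass  # RPC / model-parallel setup happens here in the real decorator
            # New: capture the test's result instead of discarding it ...
            ret = old_test_method(self, *args, **kwargs)
            if setup_model_parallel:
                pass  # barrier and rpc.join_rpc() happen here in the real decorator
            # ... and return it to the caller.
            return ret
        return new_test_method

    return decorator

With the assertion in place, a bare @dist_init now fails immediately with
"setup_model_parallel must be a bool value", which is exactly what the new
test_dist_init_decorator in the third file checks.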

View File

@ -932,3 +932,17 @@ class RpcTest(object):
if TEST_CONFIG.rpc_backend == RpcBackend.PROCESS_GROUP:
self.assertEqual(test_func(), "expected result")
def test_dist_init_decorator(self):
@dist_init(setup_model_parallel=False)
def test_func(self):
return "expected result"
self.assertEqual(test_func(self), "expected result")
with self.assertRaisesRegex(
AssertionError, "setup_model_parallel must be a bool value"
):
@dist_init
def test_func(self):
return "expected result"