Enable dist autograd tests (#28606)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28606

Without passing setup_model_parallel=True to dist_init, the decorator actually takes the test function object as the value for the flag, so the decorated tests silently pass without running: the test attribute becomes the inner decorator, and calling it merely returns a wrapper instead of executing the test body.

Test Plan: Imported from OSS

Differential Revision: D18120507

Pulled By: mrshenli

fbshipit-source-id: afbaa381647e8f284e28fa9dbdd2a7c411073b3f
This commit is contained in:
parent 70e4548fd7
commit 261a13a84b
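For context, the bug is the classic decorator-factory pitfall: dist_init takes a flag and returns the real decorator, so applying it bare hands the test function to the factory in place of the flag. A minimal standalone sketch of the failure mode (simplified; the real decorator's RPC setup/teardown is elided, and the names Demo/test_noop/test_works are illustrative):

    import functools

    def dist_init(setup_model_parallel=True):
        # Simplified stand-in for the real decorator factory.
        def decorator(old_test_method):
            @functools.wraps(old_test_method)
            def new_test_method(self, *args, **kwargs):
                if setup_model_parallel:  # any function object is truthy!
                    pass  # ... RPC / model parallel setup would go here ...
                return old_test_method(self, *args, **kwargs)
            return new_test_method
        return decorator

    class Demo:
        @dist_init  # BUG: dist_init(test_noop) returns `decorator` itself,
        def test_noop(self):  # so the attribute is not a wrapped test
            raise AssertionError("this body never runs")

        @dist_init(setup_model_parallel=True)  # correct: call the factory
        def test_works(self):
            return "ran"

    d = Demo()
    d.test_noop()          # silently returns a wrapper; the body never executes
    print(d.test_works())  # -> ran

This is why the affected dist autograd tests were passing vacuously, and why the diff below switches every bare @dist_init to @dist_init(setup_model_parallel=True).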
@@ -125,7 +125,7 @@ class DistAutogradTest(object):
     def init_method(self):
         return INIT_METHOD_TEMPLATE.format(file_name=self.file_name)
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_autograd_context(self):
         # Verify max possible id.
         max_auto_increment = 281474976710655
@@ -151,7 +151,7 @@ class DistAutogradTest(object):
         ):
             dist_autograd._retrieve_context(context_id)
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_nested_context(self):
         with dist_autograd.context() as context_id:
             # Nested contexts not supported.
@@ -295,16 +295,16 @@ class DistAutogradTest(object):
         with self.assertRaises(RuntimeError):
             ctx = dist_autograd._current_context()
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_graph_for_builtin_call(self):
         self._test_graph(torch.add)
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_graph_for_python_call(self):
         self._test_graph(my_py_add)
 
     # 3-layer nested calls
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_graph_for_py_nested_call(self):
         dst_rank = (self.rank + 1) % self.world_size
         with dist_autograd.context() as context_id:
@@ -358,7 +358,7 @@ class DistAutogradTest(object):
         dist.barrier()
 
     # Rank0->Rank1->Rank0
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_graph_for_py_nested_call_itself(self):
         dst_rank = (self.rank + 1) % self.world_size
         with dist_autograd.context() as context_id:
@@ -395,7 +395,7 @@ class DistAutogradTest(object):
         # autograd context before another worker tries to access it.
         dist.barrier()
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_no_graph_with_tensors_not_require_grad(self):
         dst_rank = (self.rank + 1) % self.world_size
         with dist_autograd.context() as context_id:
@@ -417,7 +417,7 @@ class DistAutogradTest(object):
         with self.assertRaises(RuntimeError):
             ctx = dist_autograd._retrieve_context(ctx_ids[1])
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_rpc_complex_args(self):
         with dist_autograd.context() as context_id:
             num_tensors = 10
@@ -451,7 +451,7 @@ class DistAutogradTest(object):
         self.assertEqual(worker_ids[0], dst_rank)
 
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_context_cleanup_many_workers(self):
         global known_context_ids
         dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
@@ -468,7 +468,7 @@ class DistAutogradTest(object):
         success = _all_contexts_cleaned_up(num_contexts=len(dst_ranks))
         self.assertTrue(success)
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_worker_ids_recorded(self):
         dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
         with dist_autograd.context() as context_id:
@@ -499,7 +499,7 @@ class DistAutogradTest(object):
         self.assertEqual(len(worker_ids), len(dst_ranks))
         self.assertEqual(set(worker_ids), dst_ranks)
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_error_in_context(self):
         with dist_autograd.context() as context_id:
             t1 = torch.rand(3, 3, requires_grad=True)
@@ -536,7 +536,7 @@ class DistAutogradTest(object):
         self.assertEqual(ngrads, len(grads))
 
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_backward_simple(self):
         # Run the same code locally and with dist autograd and verify gradients
         # are same.
@@ -549,7 +549,7 @@ class DistAutogradTest(object):
                 loss = ret.sum()
                 local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2)
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_backward_multiple_round_trips(self):
         local_grads = None
         t1 = torch.rand((3, 3), requires_grad=True)
@@ -571,7 +571,7 @@ class DistAutogradTest(object):
 
                 local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5)
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_backward_different_tensor_dims(self):
         local_grads = None
         t1 = torch.rand((4, 6), requires_grad=True)
@@ -587,7 +587,7 @@ class DistAutogradTest(object):
 
                 local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2, t2, t3, t4)
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_backward_unused_tensors(self):
         local_grads = None
         t1 = torch.rand((3, 3), requires_grad=True)
@@ -601,7 +601,7 @@ class DistAutogradTest(object):
                 loss = val.sum()
                 local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2, t3)
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_backward_multiple_output_tensors(self):
         local_grads = None
         t = torch.rand((10, 2), requires_grad=True)
@@ -633,7 +633,7 @@ class DistAutogradTest(object):
             dist_autograd.backward([val.sum()])
 
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_backward_unused_send_function(self):
         # Run the test in a thread which would never finish.
         t = threading.Thread(target=self._run_test_backward_unused_send_function_in_thread)
@@ -644,7 +644,7 @@ class DistAutogradTest(object):
         # Verify thread is still alive (indicating backward hasn't completed yet).
         self.assertTrue(t.is_alive())
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_backward_autograd_engine_error(self):
         with dist_autograd.context() as context_id:
             t1 = torch.rand((3, 3), requires_grad=True)
@@ -666,7 +666,7 @@ class DistAutogradTest(object):
                 # Run backwards, and validate we receive an error.
                 dist_autograd.backward([val.sum()])
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     @unittest.skip("Skipping this test temporarily since ProcessGroupAgent does not report errors on node failures")
     def test_backward_node_failure(self):
         with dist_autograd.context() as context_id:
@@ -687,7 +687,7 @@ class DistAutogradTest(object):
             # Kill all other nodes.
             sys.exit(0)
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_backward_without_context(self):
         t1 = torch.rand((3, 3), requires_grad=True)
         t2 = torch.rand((3, 3), requires_grad=True)
@@ -697,7 +697,7 @@ class DistAutogradTest(object):
                                args=(t1, t2))
             dist_autograd.backward([res.sum()])
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_backward_without_rpc(self):
         dst_rank = self.rank
         with dist_autograd.context() as context_id:
@@ -713,7 +713,7 @@ class DistAutogradTest(object):
             self.assertEqual(torch.ones(3, 3), grads[t1])
             self.assertEqual(torch.ones(3, 3), grads[t2])
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_backward_invalid_args(self):
         with dist_autograd.context() as context_id:
 
@@ -735,7 +735,7 @@ class DistAutogradTest(object):
                 t = torch.rand(1, requires_grad=True)
                 dist_autograd.backward([t])
 
-    @dist_init
+    @dist_init(setup_model_parallel=True)
     def test_backward_multiple_roots(self):
         local_grads = None
         t1 = torch.rand((3, 3), requires_grad=True)
@@ -52,6 +52,10 @@ def set_termination_signal():
 
 
 def dist_init(setup_model_parallel=True):
+    assert isinstance(setup_model_parallel, bool), (
+        "setup_model_parallel must be a bool value"
+    )
+
     def decorator(old_test_method):
         """
         We use this decorator for setting up and tearing down state since
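With this guard in place, bare use fails loudly at decoration time instead of silently disabling the test. A minimal illustration (the assert is copied from the hunk above; the real setup/teardown is elided and test_foo is illustrative):

    def dist_init(setup_model_parallel=True):
        # Guard added in this diff: a bare @dist_init passes the function
        # object as setup_model_parallel, which is not a bool.
        assert isinstance(setup_model_parallel, bool), (
            "setup_model_parallel must be a bool value"
        )
        def decorator(old_test_method):
            return old_test_method  # real setup/teardown elided
        return decorator

    try:
        @dist_init          # bare use: the function lands in setup_model_parallel
        def test_foo(self):
            pass
    except AssertionError as e:
        print(e)            # -> setup_model_parallel must be a bool value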
@@ -87,7 +91,7 @@ def dist_init(setup_model_parallel=True):
                 num_send_recv_threads=16,
             )
 
-            old_test_method(self, *arg, **kwargs)
+            ret = old_test_method(self, *arg, **kwargs)
 
             if setup_model_parallel:
                 # Follower reports done.
@@ -119,6 +123,8 @@ def dist_init(setup_model_parallel=True):
             # Close RPC.
             rpc.join_rpc()
 
+            return ret
+
         return new_test_method
 
     return decorator
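The ret = .../return ret change makes the wrapper transparent to the wrapped method's return value, which the new test below relies on (test_func(self) must yield "expected result" through the decorator). A generic sketch of the pattern, not the actual dist_utils code (passthrough is a hypothetical name):

    import functools

    def passthrough(old_test_method):
        @functools.wraps(old_test_method)
        def new_test_method(self, *args, **kwargs):
            # ... setup would run here ...
            ret = old_test_method(self, *args, **kwargs)  # capture, don't discard
            # ... teardown still runs before the value is handed back ...
            return ret
        return new_test_method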
@@ -932,3 +932,17 @@ class RpcTest(object):
 
         if TEST_CONFIG.rpc_backend == RpcBackend.PROCESS_GROUP:
             self.assertEqual(test_func(), "expected result")
+
+    def test_dist_init_decorator(self):
+        @dist_init(setup_model_parallel=False)
+        def test_func(self):
+            return "expected result"
+
+        self.assertEqual(test_func(self), "expected result")
+
+        with self.assertRaisesRegex(
+            AssertionError, "setup_model_parallel must be a bool value"
+        ):
+            @dist_init
+            def test_func(self):
+                return "expected result"