# Owner(s): ["module: dynamo"] import copy import functools import os import random import unittest from unittest.mock import patch import numpy as np import torch import torch._dynamo from torch._dynamo.optimizations.distributed import DDPOptimizer import torch._dynamo.test_case import torch.distributed as dist from contextlib import contextmanager from torch import nn from torch._dynamo import config from torch._dynamo.utils import same from torch._dynamo.testing import collect_results from torch._inductor.utils import has_triton from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.nn.parallel import DistributedDataParallel as DDP from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.testing._internal.common_distributed import ( MultiProcessTestCase, import_transformers_or_skip, skip_if_lt_x_gpu, requires_nccl ) import torch._dynamo.logging def reset_rng_state(): torch.manual_seed(1337) random.seed(1337) np.random.seed(1337) def init_weights(m): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) m.bias.data.fill_(0.01) class ToyModel(nn.Module): def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5): super().__init__() self.net = nn.Sequential( *[nn.Linear(in_feat, hidden_feat), nn.ReLU()] + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()] + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()] + [nn.Linear(hidden_feat, out_feat), nn.ReLU()] ) def forward(self, inputs): return self.net(inputs) def get_model(device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5): m = ToyModel(in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat).to(device) m.apply(init_weights) inputs = torch.rand(bsz, in_feat).to(device) outputs = m(inputs) return m, inputs, outputs def get_custom_model(device): class MyCustomLinear(torch.nn.Module): def __init__(self): super(MyCustomLinear, self).__init__() self.weight = nn.Parameter(torch.randn(512, 512)) def forward(self, x): return torch.mm(x, self.weight.t()) class MyLinear(torch.nn.Module): def __init__(self): super(MyLinear, self).__init__() self.linear = torch.nn.Linear(512, 512) def forward(self, x): return self.linear(x) class MyModule(torch.nn.Module): def __init__(self): super(MyModule, self).__init__() mods = [ (MyLinear(), torch.nn.ReLU()), # sandwitch the custom in the middle so it comes before and after (MyCustomLinear(), torch.nn.ReLU()), (MyLinear(), torch.nn.ReLU()), ] self.seq = torch.nn.Sequential(*[x for items in mods for x in items]) def forward(self, x): return self.seq(x) m = MyModule().to(device) m.apply(init_weights) inputs = torch.rand((512, 512)).to(device) correct_outputs = m(inputs) return m, inputs, correct_outputs def get_hf_bert(rank): # Note: use @import_transformers_or_skip on your test case if you use this # in a multiprocessing test try: from transformers import BertConfig, AutoModelForMaskedLM except ImportError as e: raise unittest.SkipTest("Unable to import transformers") from e batch_size, max_length, config, device = 4, 512, BertConfig(), f"cuda:{rank}" model = AutoModelForMaskedLM.from_config(config).to(device) input_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(device) decoder_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(device) inputs = {'input_ids': input_ids, 'labels': decoder_ids} model.train() return model, inputs class CheckSplitsCompiler: def __init__(self): self.compiler_called = 0 def compile_fn(self, gm, example_inputs): self.compiler_called += 1 return gm @contextmanager def 
@contextmanager
def _per_rank_init(rank, world_size):
    # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
    # Just manually implement the most important part of the dynamo behavior to reset/clear.
    torch.cuda.set_device(rank)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '6789'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch._dynamo.reset()
    torch._dynamo.utils.counters.clear()
    yield
    torch._dynamo.reset()
    torch._dynamo.utils.counters.clear()
    dist.destroy_process_group()


# This simulates DDP, but it doesn't actually do any process communication;
# it just has enough properties so that the dynamo distributed optimization is
# able to optimize. Feel free to simulate more properties as necessary. The
# other important thing is patching _active_ddp_module, which is what actually
# triggers DDP optimization
class FakeDDP(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        bucket_cap_mb = 25
        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)

    @contextmanager
    def _inside_ddp_forward(self):
        DDP._active_ddp_module = self
        try:
            yield
        except Exception:
            raise
        finally:
            DDP._active_ddp_module = None

    def forward(self, *inputs, **kwargs):
        with self._inside_ddp_forward():
            return self.module.forward(*inputs, **kwargs)


def run_hf_bert_ddp(self, model, inputs, backend):
    reset_rng_state()
    correct_outputs = model(**inputs)
    correct_loss = correct_outputs.loss
    correct_loss.backward()

    reset_rng_state()
    opt_model = torch._dynamo.optimize(backend)(model)
    opt_outputs = opt_model(**inputs)
    opt_loss = opt_outputs.loss
    opt_loss.backward()

    inputs_flat = [inputs[k] for k in inputs]
    correct_results = collect_results(model, correct_outputs.logits, correct_loss, inputs_flat)
    opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat)
    self.assertTrue(same(correct_results, opt_results))


class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor(self):
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "inductor")

    @patch.object(config, "optimize_ddp", True)
    def test_hf_bert_ddp_aot_eager(self):
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "aot_eager")

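# The multi-process tests below spawn one process per visible GPU (see the
# world_size property) via MultiProcessTestCase; each test body wraps itself in
# _per_rank_init to build and tear down its own process group.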
# Are these tests failing?  Check and see if TestFakeDistributedSingleProc has a
# single process version; if it's just a problem in the Dynamo distributed
# optimizer, you should be able to repro it single process!
@requires_nccl()
class TestDistributedMultiProc(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self) -> int:
        return torch.cuda.device_count()

    @classmethod
    def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None:
        # Don't enable DDP + ReplicatedTensor, as that breaks Dynamo+DDP
        # TODO(whc) why is ReplicatedTensor defaulted=True in MultiProcessTestCase, and should we support it?
        # from torch.nn.parallel._replicated_tensor_ddp_utils import _set_ddp_with_replicated_tensor
        # _set_ddp_with_replicated_tensor(True)

        # The rest is copypasta from MultiProcessTestCase._run
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name
        self.run_test(test_name, parent_pipe)

    @skip_if_lt_x_gpu(2)
    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_aot_eager_multiprocess(self):
        with _per_rank_init(self.rank, self.world_size):
            self.assertFalse(config.optimize_ddp)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            m = DDP(m, device_ids=[self.rank])
            m = torch._dynamo.optimize("aot_eager")(m)
            outputs = m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor(self):
        with _per_rank_init(self.rank, self.world_size):
            model, inputs = get_hf_bert(self.rank)
            model = DDP(model)
            run_hf_bert_ddp(self, model, inputs, "inductor")

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @patch.object(config, "optimize_ddp", True)
    def test_hf_bert_ddp_aot_eager(self):
        with _per_rank_init(self.rank, self.world_size):
            model, inputs = get_hf_bert(self.rank)
            model = DDP(model)
            run_hf_bert_ddp(self, model, inputs, "aot_eager")

    @skip_if_lt_x_gpu(1)
    def test_fsdp_aot_eager(self):
        with _per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            # Test with recursive wrapping, nested FSDP around each Linear
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(
                m,
                auto_wrap_policy=functools.partial(
                    transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear, )
                ),
                use_orig_params=True
            )
            fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @skip_if_lt_x_gpu(1)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_fsdp_inductor(self):
        with _per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            # Test with recursive wrapping, nested FSDP around each Linear
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(
                m,
                auto_wrap_policy=functools.partial(
                    transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear, )
                ),
                use_orig_params=True
            )
            fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

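    # test_hf_bert_fsdp (below) runs HF Bert under FSDP twice, once without an
    # auto-wrap policy and once recursively wrapping BertLayer modules, and
    # checks that eager FSDP and dynamo+inductor FSDP produce matching results.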
    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
    @patch.object(torch._inductor.config.triton, "cudagraphs", False)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_fsdp(self):
        from transformers.models.bert.modeling_bert import BertLayer

        def apply_fsdp(model, wrap_policy):
            model = FSDP(
                copy.deepcopy(model),
                auto_wrap_policy=wrap_policy,
                use_orig_params=True
            )
            return model

        with _per_rank_init(self.rank, self.world_size):
            for (wrap_policy, test_instance) in (
                (
                    None,
                    "FSDP without recursive wrapping"
                ),
                (
                    functools.partial(
                        transformer_auto_wrap_policy, transformer_layer_cls=(BertLayer, )
                    ),
                    "FSDP with recursive wrapping BertLayer instances"
                )
            ):
                print(f"Running hf_bert test for {test_instance}")

                model, inputs = get_hf_bert(self.rank)
                reset_rng_state()
                eager_model = apply_fsdp(model, wrap_policy)
                correct_outputs = eager_model(**inputs)
                correct_loss = correct_outputs.loss
                correct_loss.backward()

                reset_rng_state()
                opt_model = apply_fsdp(model, wrap_policy)
                opt_model = torch._dynamo.optimize("inductor")(opt_model)
                opt_outputs = opt_model(**inputs)
                opt_loss = opt_outputs.loss
                opt_loss.backward()

                inputs_flat = [inputs[k] for k in inputs]
                correct_results = collect_results(eager_model, correct_outputs.logits, correct_loss, inputs_flat)
                opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat)
                self.assertTrue(same(correct_results, opt_results))


@requires_nccl()
class TestDistributed(torch._dynamo.test_case.TestCase):
    """
    Test harness initializes dist process group
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # _exit_stack is set up in TestCase
        cls._exit_stack.enter_context(
            patch.dict(
                os.environ,
                {
                    "MASTER_ADDR": "localhost",
                    "MASTER_PORT": "12355",
                },
            )
        )
        cls.rank = 0
        cls.device = f"cuda:{cls.rank}"
        cls.device_ids = None if "cuda" in cls.device else [cls.rank]
        dist.init_process_group("nccl", rank=cls.rank, world_size=1)

    @classmethod
    def tearDownClass(cls):
        dist.destroy_process_group()
        super().tearDownClass()

    def get_model(self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5):
        m = ToyModel(in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat).to(self.device)
        m.apply(init_weights)
        inputs = torch.rand(bsz, in_feat).to(self.device)
        outputs = m(inputs)
        return m, inputs, outputs

    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_aot_eager(self):
        from torch.nn.parallel import DistributedDataParallel as DDP

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch._dynamo.optimize("aot_eager")(ddp_m)
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_inductor(self):
        from torch.nn.parallel import DistributedDataParallel as DDP

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch._dynamo.optimize("inductor")(ddp_m)
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_graph_split(self):
        """
        Just ensures that the appropriate number of splits happen (based on
        bucket size and model parameters) - verifies the number of times
        the user-provided compiler is called by the DDPOptimizer which is
        doing the graph splitting
        """
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

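    # Note on the expected split counts: with hidden_feat=5000, each 5000x5000
    # Linear in ToyModel holds roughly 100MB of fp32 parameters, well over the
    # 25MB bucket cap, so DDPOptimizer is expected to break the graph at the
    # bucket boundaries (hence the compiler_called counts asserted here and below).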
    @patch.object(config, "optimize_ddp", True)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor(self):
        """
        Same as above, but using inductor backend.
        We observed issues with inductor/fx interface in the past.
        """
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        @torch._dynamo.optimize("inductor")
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_no_split(self):
        """
        Ensures the DDPOptimizer returns a correct, compiled module without
        introducing graph splits. (Based on model parameters fitting in the bucket)
        """
        # DDP will always do a 'first bucket' with a really small size; so only a tiny model will escape this
        m, inputs, correct_outputs = self.get_model(hidden_feat=5)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250)
        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 1)

    @patch.object(config, "optimize_ddp", True)
    def test_aot_autograd(self):
        """
        Explicitly check AotAutograd family of compilers work, since they require
        example inputs propagated between graph splits.
        """
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        @torch._dynamo.optimize("aot_eager")
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        opt_outputs.sum().backward()
        self.assertTrue(same(correct_outputs, opt_outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_custom_layer(self):
        """
        Just ensures that the appropriate number of splits happen (based on
        bucket size and model parameters) - verifies the number of times
        the user-provided compiler is called by the DDPOptimizer which is
        doing the graph splitting
        """
        m, inputs, correct_outputs = get_custom_model(self.device)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1)

        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_empty_graph_inductor(self):
        def fn():
            get_world_size = torch.distributed.distributed_c10d.get_world_size()
            return (get_world_size,)

        opt_fn = torch._dynamo.optimize("inductor")(fn)
        res = opt_fn()[0]
        self.assertEqual(res, 1)

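    # test_ignored_parameters (below) bypasses the usual optimize_ddp flow and
    # constructs a DDPOptimizer directly, so the test can inspect its buckets and
    # confirm that parameters registered via
    # DDP._set_params_and_buffers_to_ignore_for_model never land in any bucket.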
""" m, inputs, correct_outputs = get_custom_model(self.device) parameters_to_ignore = ["seq.2.weight", "seq.4.linear.bias"] DDP._set_params_and_buffers_to_ignore_for_model(m, parameters_to_ignore) ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) parameter_ids_to_ignore = [ id(ddp_m.module.get_parameter(p)) for p in ddp_m.parameters_to_ignore ] check_splits_compiler = CheckSplitsCompiler() ddp_optimizer = DDPOptimizer( bucket_bytes_cap=ddp_m.bucket_bytes_cap, backend_compile_fn=check_splits_compiler.compile_fn ) @torch._dynamo.optimize(ddp_optimizer.compile_fn) def opt_fn(inputs): return ddp_m(inputs) opt_outputs = opt_fn(inputs) self.assertTrue(same(correct_outputs, opt_outputs)) self.assertEqual(check_splits_compiler.compiler_called, 2) for b in ddp_optimizer.buckets: for p_id in b.param_ids: self.assertFalse(p_id in parameter_ids_to_ignore) def test_fsdp_orig_params_assert(self): # Test with basic FSDP wrapping (outer wrap around whole model) m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") fsdp_m = FSDP(m, use_orig_params=False) fsdp_m = torch._dynamo.optimize()(fsdp_m) self.assertRaisesRegex(AssertionError, "Dynamo only supports FSDP with use_orig_params=True", fsdp_m, inputs) if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()