# Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00).
# Commit: "Let the developer decide! Taken from voz's https://github.com/pytorch/pytorch/pull/89392"
# Signed-off-by: Edward Z. Yang <ezyang@fb.com>
# Pull Request resolved: https://github.com/pytorch/pytorch/pull/89655
# Approved by: https://github.com/albanD
# File stats: 566 lines, 21 KiB, Python.
# Owner(s): ["module: dynamo"]
|
|
import copy
|
|
import functools
|
|
import os
|
|
import random
|
|
import unittest
|
|
from unittest.mock import patch
|
|
import numpy as np
|
|
import torch
|
|
import torch._dynamo
|
|
from torch._dynamo.optimizations.distributed import DDPOptimizer
|
|
import torch._dynamo.test_case
|
|
import torch.distributed as dist
|
|
from contextlib import contextmanager
|
|
from torch import nn
|
|
from torch._dynamo import config
|
|
from torch._dynamo.utils import same
|
|
from torch._dynamo.testing import collect_results
|
|
from torch._inductor.utils import has_triton
|
|
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
from torch.testing._internal.common_utils import TestCase
|
|
from torch.testing._internal.common_distributed import (
|
|
MultiProcessTestCase,
|
|
import_transformers_or_skip,
|
|
skip_if_lt_x_gpu,
|
|
requires_nccl
|
|
)
|
|
import torch._dynamo.logging
|
|
|
|
|
|
def reset_rng_state():
    """Seed torch, Python ``random`` and NumPy with the fixed value 1337.

    Calling this before both the eager and the compiled run makes both runs
    start from an identical RNG state, so their results can be compared.
    """
    seed = 1337
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)
|
|
|
|
def init_weights(m):
    """Initialize an ``nn.Linear`` with Xavier-uniform weights and a 0.01 bias.

    Meant to be passed to ``nn.Module.apply``. Modules that are not
    ``nn.Linear`` are left untouched. A Linear built with ``bias=False`` has
    ``m.bias is None``; only its weight is initialized.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    # A bias-free Linear stores None here; dereferencing it would raise.
    if m.bias is not None:
        m.bias.data.fill_(0.01)
|
|
|
|
class ToyModel(nn.Module):
    """Four-layer MLP where every ``nn.Linear`` is followed by an ``nn.ReLU``.

    Layer widths are ``in_feat -> hidden_feat -> hidden_feat -> hidden_feat ->
    out_feat``. All layers live in ``self.net`` (an ``nn.Sequential``), so the
    Linear parameters are named ``net.0``, ``net.2``, ``net.4`` and ``net.6``.
    """

    def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5):
        super().__init__()
        widths = [in_feat, hidden_feat, hidden_feat, hidden_feat, out_feat]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, inputs):
        return self.net(inputs)
|
|
|
|
def get_model(device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5):
    """Build a ``ToyModel`` on *device*, a random input batch, and its eager output.

    Returns:
        ``(model, inputs, outputs)``. ``inputs`` has shape ``(bsz, in_feat)``
        and ``outputs`` is the eager forward result, used as the reference
        when checking compiled or wrapped variants of the model.
    """
    model = ToyModel(in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat)
    model = model.to(device)
    model.apply(init_weights)
    batch = torch.rand(bsz, in_feat).to(device)
    reference = model(batch)
    return model, batch, reference
|
|
|
|
def get_custom_model(device):
    """Build a 512x512 three-stage model mixing ``nn.Linear`` and a hand-rolled linear.

    Layout of ``model.seq``:

    * ``seq.0`` is ``MyLinear`` (wraps an ``nn.Linear``); ``seq.1`` is ReLU.
    * ``seq.2`` is ``MyCustomLinear`` (a bare weight ``nn.Parameter``, no bias); ``seq.3`` is ReLU.
    * ``seq.4`` is ``MyLinear``; ``seq.5`` is ReLU.

    ``init_weights`` only touches ``nn.Linear`` modules, so ``MyCustomLinear``
    keeps its ``torch.randn`` initialization.

    Args:
        device: Device the model and the inputs are moved to.

    Returns:
        ``(model, inputs, correct_outputs)``. ``inputs`` has shape ``(512, 512)``
        and ``correct_outputs`` is the eager forward result.
    """

    class MyCustomLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(512, 512))

        def forward(self, x):
            return torch.mm(x, self.weight.t())

    class MyLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(512, 512)

        def forward(self, x):
            return self.linear(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            mods = [
                (MyLinear(), torch.nn.ReLU()),
                # Sandwich the custom layer in the middle so standard layers
                # come both before and after it.
                (MyCustomLinear(), torch.nn.ReLU()),
                (MyLinear(), torch.nn.ReLU()),
            ]
            self.seq = torch.nn.Sequential(*[x for items in mods for x in items])

        def forward(self, x):
            return self.seq(x)

    m = MyModule().to(device)
    m.apply(init_weights)
    inputs = torch.rand((512, 512)).to(device)
    correct_outputs = m(inputs)
    return m, inputs, correct_outputs
|
|
|
|
def get_hf_bert(rank):
    """Build a HuggingFace BERT masked-LM on ``cuda:{rank}`` with random inputs.

    Note: use ``@import_transformers_or_skip`` on your test case if you use
    this in a multiprocessing test.

    Args:
        rank: CUDA device index the model and inputs are placed on.

    Returns:
        ``(model, inputs)``. ``inputs`` is a dict with ``input_ids`` and
        ``labels`` tensors of shape ``(4, 512)``; the model is in train mode.

    Raises:
        unittest.SkipTest: if ``transformers`` cannot be imported.
    """
    try:
        from transformers import AutoModelForMaskedLM, BertConfig
    except ImportError as err:
        raise unittest.SkipTest("Unable to import transformers") from err

    batch_size, max_length = 4, 512
    # Named bert_config so it does not shadow the dynamo `config` module that
    # this file imports at top level.
    bert_config = BertConfig()
    device = f"cuda:{rank}"
    model = AutoModelForMaskedLM.from_config(bert_config).to(device)
    input_ids = torch.randint(0, bert_config.vocab_size, (batch_size, max_length)).to(device)
    decoder_ids = torch.randint(0, bert_config.vocab_size, (batch_size, max_length)).to(device)
    inputs = {'input_ids': input_ids, 'labels': decoder_ids}
    model.train()
    return model, inputs
|
|
|
|
class CheckSplitsCompiler:
    """Pass-through dynamo backend that counts how many graphs it is handed.

    ``compile_fn`` returns each graph module unchanged. Tests read
    ``compiler_called`` to learn how many subgraphs reached the backend.
    """

    def __init__(self):
        # Number of compile_fn invocations so far.
        self.compiler_called = 0

    def compile_fn(self, gm, example_inputs):
        """Record one compilation and return *gm* itself as the compiled callable."""
        self.compiler_called = self.compiler_called + 1
        return gm
|
|
|
|
@contextmanager
def _per_rank_init(rank, world_size):
    """Initialize one rank of an NCCL process group around a test body.

    To avoid multiple inheritance from ``_dynamo.test_case.TestCase`` and
    ``MultiProcessTestCase``, this manually implements the most important part
    of the dynamo behavior: resetting dynamo and clearing its counters before
    and after the body.

    Teardown runs in ``finally``. A failing test body therefore still resets
    dynamo and destroys the process group instead of leaking them.

    Args:
        rank: This process's rank; also used as its CUDA device index.
        world_size: Total number of ranks in the process group.
    """
    torch.cuda.set_device(rank)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '6789'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch._dynamo.reset()
    torch._dynamo.utils.counters.clear()
    try:
        yield
    finally:
        torch._dynamo.reset()
        torch._dynamo.utils.counters.clear()
        dist.destroy_process_group()
|
|
|
|
|
|
# This simulates DDP, but it doesn't actually do any process communication;
|
|
# it just has enough properties so that the dynamo distributed optimization is
|
|
# able to optimize. Feel free to simulate more properties as necessary. The
|
|
# other important thing is patching _active_ddp_module, which is what actually
|
|
# triggers DDP optimization
|
|
class FakeDDP(nn.Module):
    """Stand-in for ``DistributedDataParallel`` that performs no communication.

    It exposes the wrapped ``module`` and a ``bucket_bytes_cap`` of 25 MiB. For
    the duration of ``forward`` it also sets ``DDP._active_ddp_module`` to
    itself, which is what triggers the dynamo DDP optimization.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module
        bucket_cap_mb = 25
        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)

    @contextmanager
    def _inside_ddp_forward(self):
        """Mark this wrapper as the active DDP module while the ``with`` body runs."""
        DDP._active_ddp_module = self
        try:
            yield
        finally:
            # Clear the marker even if forward raised, so later code does not
            # see a stale active DDP module.
            DDP._active_ddp_module = None

    def forward(self, *inputs, **kwargs):
        with self._inside_ddp_forward():
            return self.module.forward(*inputs, **kwargs)
|
|
|
|
def run_hf_bert_ddp(self, model, inputs, backend):
    """Assert that compiling *model* with *backend* matches its eager forward+backward.

    ``self`` is the calling ``TestCase`` and is used only for ``assertTrue``.
    The same *model* is first run eagerly, then wrapped with
    ``torch._dynamo.optimize(backend)`` and run again. Both runs begin from the
    same RNG seed. Logits, loss and whatever ``collect_results`` gathers from
    the model and inputs are compared with ``same``.
    """
    reset_rng_state()
    eager_out = model(**inputs)
    eager_loss = eager_out.loss
    eager_loss.backward()

    reset_rng_state()
    compiled = torch._dynamo.optimize(backend)(model)
    compiled_out = compiled(**inputs)
    compiled_loss = compiled_out.loss
    compiled_loss.backward()

    flat_inputs = list(inputs.values())
    expected = collect_results(model, eager_out.logits, eager_loss, flat_inputs)
    actual = collect_results(compiled, compiled_out.logits, compiled_loss, flat_inputs)
    self.assertTrue(same(expected, actual))
|
|
|
|
class TestFakeDistributedSingleProc(TestCase):
    """Single-process DDP-optimizer tests that wrap models in ``FakeDDP``.

    No process group is created. ``FakeDDP`` only supplies what the dynamo DDP
    optimization needs, so DDPOptimizer issues can be reproduced in one process.
    """

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor(self):
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "inductor")

    # get_hf_bert places the model on cuda:0, so this test cannot run on a
    # CPU-only machine even though the aot_eager backend itself needs no GPU.
    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    @patch.object(config, "optimize_ddp", True)
    def test_hf_bert_ddp_aot_eager(self):
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "aot_eager")
|
|
|
|
|
|
# Are these tests failing? Check and see if TestFakeDistributedSingleProc has a
|
|
# single process version; if it's just a problem in the Dynamo distributed
|
|
# optimizer, you should be able to repro it single process!
|
|
@requires_nccl()
class TestDistributedMultiProc(MultiProcessTestCase):
    """Real multi-process DDP/FSDP + dynamo tests, one rank per visible CUDA device.

    Each test body runs inside ``_per_rank_init`` in every spawned rank.
    """

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            # Best effort: the file may never have been created.
            pass

    @property
    def world_size(self) -> int:
        # One rank per visible CUDA device.
        return torch.cuda.device_count()

    @classmethod
    def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None:
        # Don't enable DDP + ReplicatedTensor, as that breaks Dynamo+DDP
        # TODO(whc) why is ReplicatedTensor defaulted=True in MultiProcessTestCase, and should we support it?
        # from torch.nn.parallel._replicated_tensor_ddp_utils import _set_ddp_with_replicated_tensor
        # _set_ddp_with_replicated_tensor(True)

        # The rest is copypasta from MultiProcessTestCase._run
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name
        self.run_test(test_name, parent_pipe)

    def _check_toy_model_fsdp(self, backend):
        """Compare ``backend``-compiled FSDP(ToyModel) forward output with eager output.

        Runs twice: once with a single outer FSDP wrap around the whole model,
        and once with recursive wrapping that nests an FSDP unit around each
        ``nn.Linear``. Must be called inside ``_per_rank_init``.
        """
        wrap_policies = (
            # Basic FSDP wrapping (outer wrap around whole model).
            None,
            # Recursive wrapping, nested FSDP around each Linear.
            functools.partial(
                transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear, )
            ),
        )
        for wrap_policy in wrap_policies:
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, auto_wrap_policy=wrap_policy, use_orig_params=True)
            fsdp_m = torch._dynamo.optimize(backend)(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @skip_if_lt_x_gpu(2)
    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_aot_eager_multiprocess(self):
        with _per_rank_init(self.rank, self.world_size):
            self.assertFalse(config.optimize_ddp)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            m = DDP(m, device_ids=[self.rank])
            m = torch._dynamo.optimize("aot_eager")(m)
            outputs = m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor(self):
        with _per_rank_init(self.rank, self.world_size):
            model, inputs = get_hf_bert(self.rank)
            model = DDP(model)
            run_hf_bert_ddp(self, model, inputs, "inductor")

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @patch.object(config, "optimize_ddp", True)
    def test_hf_bert_ddp_aot_eager(self):
        with _per_rank_init(self.rank, self.world_size):
            model, inputs = get_hf_bert(self.rank)
            model = DDP(model)
            run_hf_bert_ddp(self, model, inputs, "aot_eager")

    @skip_if_lt_x_gpu(1)
    def test_fsdp_aot_eager(self):
        with _per_rank_init(self.rank, self.world_size):
            self._check_toy_model_fsdp("aot_eager")

    @skip_if_lt_x_gpu(1)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_fsdp_inductor(self):
        with _per_rank_init(self.rank, self.world_size):
            self._check_toy_model_fsdp("inductor")

    # Like the other FSDP tests, this places everything on cuda:{rank}.
    @skip_if_lt_x_gpu(1)
    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
    @patch.object(torch._inductor.config.triton, "cudagraphs", False)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_fsdp(self):
        from transformers.models.bert.modeling_bert import BertLayer

        def apply_fsdp(model, wrap_policy):
            # deepcopy so the eager and the compiled run each wrap an
            # independent copy of the same initial model.
            return FSDP(
                copy.deepcopy(model),
                auto_wrap_policy=wrap_policy,
                use_orig_params=True
            )

        with _per_rank_init(self.rank, self.world_size):
            for (wrap_policy, test_instance) in (
                (
                    None,
                    "FSDP without recursive wrapping"
                ),
                (
                    functools.partial(
                        transformer_auto_wrap_policy, transformer_layer_cls=(BertLayer, )
                    ),
                    "FSDP with recursive wrapping BertLayer instances"
                )
            ):
                print(f"Running hf_bert test for {test_instance}")
                model, inputs = get_hf_bert(self.rank)
                reset_rng_state()
                eager_model = apply_fsdp(model, wrap_policy)
                correct_outputs = eager_model(**inputs)
                correct_loss = correct_outputs.loss
                correct_loss.backward()

                reset_rng_state()
                opt_model = apply_fsdp(model, wrap_policy)
                opt_model = torch._dynamo.optimize("inductor")(opt_model)
                opt_outputs = opt_model(**inputs)
                opt_loss = opt_outputs.loss
                opt_loss.backward()

                inputs_flat = list(inputs.values())
                correct_results = collect_results(eager_model, correct_outputs.logits, correct_loss, inputs_flat)
                opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat)
                self.assertTrue(same(correct_results, opt_results))
|
|
|
|
|
|
@requires_nccl()
class TestDistributed(torch._dynamo.test_case.TestCase):
    """Single-process (world_size=1) NCCL tests of dynamo on real DDP modules.

    ``setUpClass`` initializes the process group once for the whole class and
    ``tearDownClass`` destroys it.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # _exit_stack is set up in TestCase
        cls._exit_stack.enter_context(
            patch.dict(
                os.environ,
                {
                    "MASTER_ADDR": "localhost",
                    "MASTER_PORT": "12355",
                },
            )
        )
        cls.rank = 0
        cls.device = f"cuda:{cls.rank}"
        # DDP accepts device_ids=[rank] for a single-GPU module, but requires
        # device_ids=None for CPU modules (the previous condition was inverted).
        cls.device_ids = [cls.rank] if "cuda" in cls.device else None
        dist.init_process_group("nccl", rank=cls.rank, world_size=1)

    @classmethod
    def tearDownClass(cls):
        dist.destroy_process_group()
        super().tearDownClass()

    def get_model(self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5):
        """Module-level ``get_model`` on this class's device; returns ``(model, inputs, outputs)``."""
        return get_model(
            self.device,
            bsz=bsz,
            in_feat=in_feat,
            hidden_feat=hidden_feat,
            out_feat=out_feat,
        )

    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_aot_eager(self):
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch._dynamo.optimize("aot_eager")(ddp_m)
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_inductor(self):
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch._dynamo.optimize("inductor")(ddp_m)
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_graph_split(self):
        """
        Just ensures that the appropriate number of splits happen (based on
        bucket size and model parameters) - verifies the number of times
        the user-provided compiler is called by the DDPOptimizer which is
        doing the graph splitting
        """

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

    @patch.object(config, "optimize_ddp", True)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor(self):
        """
        Same as above, but using inductor backend.
        We observed issues with inductor/fx interface in the past.
        """
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        @torch._dynamo.optimize("inductor")
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_no_split(self):
        """
        Ensures the DDPOptimizer returns a correct, compiled module without
        introducing graph splits. (Based on model parameters fitting in the bucket)
        """
        # DDP will always do a 'first bucket' with a really small size; so only a tiny model will escape this
        m, inputs, correct_outputs = self.get_model(hidden_feat=5)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250)
        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 1)

    @patch.object(config, "optimize_ddp", True)
    def test_aot_autograd(self):
        """
        Explicitly check AotAutograd family of compilers work,
        since they require example inputs propagated between graph splits.
        """
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        @torch._dynamo.optimize("aot_eager")
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        opt_outputs.sum().backward()
        self.assertTrue(same(correct_outputs, opt_outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_custom_layer(self):
        """
        Graph splitting on a model that contains a custom (non-nn.Linear) layer.
        get_custom_model sandwiches a hand-rolled linear between two nn.Linear
        stages; with bucket_cap_mb=1 the DDPOptimizer is expected to hand the
        backend compiler three graphs.
        """
        m, inputs, correct_outputs = get_custom_model(self.device)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1)

        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_empty_graph_inductor(self):
        def fn():
            get_world_size = torch.distributed.distributed_c10d.get_world_size()
            return (get_world_size,)

        opt_fn = torch._dynamo.optimize("inductor")(fn)
        # Let any compilation/runtime error propagate: swallowing it only
        # turned the real cause into an uninformative "None != 1" mismatch.
        res = opt_fn()[0]
        self.assertEqual(res, 1)

    @patch.object(config, "optimize_ddp", False)
    def test_ignored_parameters(self):
        """
        Verifies ddp graph-split logic ignores parameters marked to ignore on DDP module.
        Hooks up graph-split optimizer manually so it can peek at internal state.
        """
        m, inputs, correct_outputs = get_custom_model(self.device)
        parameters_to_ignore = ["seq.2.weight", "seq.4.linear.bias"]
        DDP._set_params_and_buffers_to_ignore_for_model(m, parameters_to_ignore)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
        parameter_ids_to_ignore = {
            id(ddp_m.module.get_parameter(p))
            for p in ddp_m.parameters_to_ignore
        }

        check_splits_compiler = CheckSplitsCompiler()
        ddp_optimizer = DDPOptimizer(
            bucket_bytes_cap=ddp_m.bucket_bytes_cap,
            backend_compile_fn=check_splits_compiler.compile_fn
        )

        @torch._dynamo.optimize(ddp_optimizer.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 2)
        for b in ddp_optimizer.buckets:
            for p_id in b.param_ids:
                self.assertNotIn(p_id, parameter_ids_to_ignore)
|
|
|
|
|
|
if __name__ == "__main__":
    # run_tests comes from torch._dynamo.test_case, already imported at the
    # top of this file.
    torch._dynamo.test_case.run_tests()
|