pytorch/test/distributed/test_dynamo_distributed.py
Will Constable a3f3ec8fac [FSDP+dynamo]: forward treats parameter-views as params (#88781)
Dynamo+AotAutograd needs a way to wrap all tensors (whether
inputs or params/buffers) in FakeTensor wrappers, and
FSDP's mangling of parameters hides them from this wrapping.

This PR unblocks running hf_bert and hf_T5 with FSDP under dynamo, whether using recursive wrapping around transformer layers or only applying FSDP around the whole model.  Perf/memory validation and possibly optimization are the next steps.
`python benchmarks/dynamo/distributed.py --torchbench_model hf_Bert --fsdp --dynamo aot_eager`
`python benchmarks/dynamo/distributed.py --torchbench_model hf_Bert --fsdp --dynamo aot_eager --fsdp_wrap`
`python benchmarks/dynamo/distributed.py --torchbench_model hf_T5 --fsdp --dynamo aot_eager`
`python benchmarks/dynamo/distributed.py --torchbench_model hf_T5 --fsdp --dynamo aot_eager --fsdp_wrap`

The problem:
Dynamo (actually aot_autograd) trips up with FSDP because it must
wrap all input tensors in FakeTensor wrappers, and it only knows
to wrap graph inputs or named_(parameters, buffers).  FSDP's
pre-forward hook sets views into the FlatParameter (which are not
nn.Parameters) as attrs on the module with the same name as the
original param, but they will not show up in named_parameters.

- in use_orig_params mode, FSDP still de-registers
  params during pre-forward hook, then re-registers them
  post-forward
- during forward (between the hooks), the params are setattr'd
  on the module as regular view tensors, not nn.Parameters
- note: use_orig_params is the recommended way to use FSDP, and
  use_orig_params=False is being deprecated, so I only consider
  use_orig_params=True for this enablement (see the toy sketch below)
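A minimal hand-rolled sketch of the state aot_autograd observes during forward, simulated on a plain nn.Linear rather than actual FSDP code (the `flat_param`/`lin` names are purely illustrative):

```python
import torch
from torch import nn

lin = nn.Linear(4, 4)
# Stand-in for FSDP's FlatParameter; the real one lives on the FSDP wrapper.
flat_param = nn.Parameter(lin.weight.detach().clone().flatten())

# Pre-forward-hook style: de-register the original nn.Parameter ...
del lin._parameters["weight"]
# ... and expose a plain view of the flat parameter under the old name.
lin.weight = flat_param.view(4, 4)  # a Tensor view, not an nn.Parameter

# The view is reachable as an attribute but invisible to named_parameters(),
# which is what aot_autograd walks when deciding what to wrap in FakeTensors.
print([name for name, _ in lin.named_parameters()])  # only 'bias' remains
print(isinstance(lin.weight, nn.Parameter))          # False
```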

The solution:
- adding them to named_buffers is not possible because it interferes
  with how FSDP's `_apply` works
- since they are not actual nn.Parameters, register_parameter will
  complain about registering them
- simply setting `module._parameters[name] = view` seems to be a viable
  workaround (as sketched below), despite being hacky, and FSDP code does
  modify _parameters directly already.
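Continuing the toy module above, an illustrative-only sketch of that workaround (in the PR itself the analogous assignment happens inside FSDP, which already modifies _parameters directly):

```python
# register_parameter() would reject the plain view, but writing it into the
# private _parameters dict is enough for it to appear in named_parameters(),
# so dynamo/aot_autograd can fake-tensor-wrap it like any other param.
lin._parameters["weight"] = lin.weight  # still a view tensor, not an nn.Parameter

print(sorted(name for name, _ in lin.named_parameters()))  # ['bias', 'weight']
```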

Note: Manual checkpointing still isn't working with FSDP+dynamo,
so that will have to be addressed in a follow-up.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88781
Approved by: https://github.com/ezyang, https://github.com/awgu
2022-11-12 01:17:23 +00:00


# Owner(s): ["module: dynamo"]
import copy
import functools
import logging
import os
import random
import unittest
from unittest.mock import patch
import numpy as np
import torch
import torch._dynamo
from torch._dynamo.optimizations.distributed import DDPOptimizer
import torch._dynamo.test_case
import torch.distributed as dist
from contextlib import contextmanager
from torch import nn
from torch._dynamo import config
from torch._dynamo.utils import same
from torch._dynamo.testing import collect_results
from torch._inductor.utils import has_triton
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    import_transformers_or_skip,
    skip_if_lt_x_gpu,
    requires_nccl
)
import torch._dynamo.logging
def reset_rng_state():
    torch.manual_seed(1337)
    random.seed(1337)
    np.random.seed(1337)


def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


class ToyModel(nn.Module):
    def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5):
        super().__init__()
        self.net = nn.Sequential(
            *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
        )

    def forward(self, inputs):
        return self.net(inputs)


def get_model(device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5):
    m = ToyModel(in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat).to(device)
    m.apply(init_weights)
    inputs = torch.rand(bsz, in_feat).to(device)
    outputs = m(inputs)
    return m, inputs, outputs
def get_custom_model(device):
    class MyCustomLinear(torch.nn.Module):
        def __init__(self):
            super(MyCustomLinear, self).__init__()
            self.weight = nn.Parameter(torch.randn(512, 512))

        def forward(self, x):
            return torch.mm(x, self.weight.t())

    class MyLinear(torch.nn.Module):
        def __init__(self):
            super(MyLinear, self).__init__()
            self.linear = torch.nn.Linear(512, 512)

        def forward(self, x):
            return self.linear(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            mods = [
                (MyLinear(), torch.nn.ReLU()),
                # sandwich the custom linear in the middle so it comes before and after
                (MyCustomLinear(), torch.nn.ReLU()),
                (MyLinear(), torch.nn.ReLU()),
            ]
            self.seq = torch.nn.Sequential(*[x for items in mods for x in items])

        def forward(self, x):
            return self.seq(x)

    m = MyModule().to(device)
    m.apply(init_weights)
    inputs = torch.rand((512, 512)).to(device)
    correct_outputs = m(inputs)
    return m, inputs, correct_outputs
def get_hf_bert(rank):
    # Note: use @import_transformers_or_skip on your test case if you use this
    try:
        from transformers import BertConfig, AutoModelForMaskedLM
    except ImportError:
        # unittest.skip() alone is a no-op here; raise SkipTest so the caller is skipped
        raise unittest.SkipTest("Unable to import transformers")

    batch_size, max_length, config, device = 4, 512, BertConfig(), f"cuda:{rank}"
    model = AutoModelForMaskedLM.from_config(config).to(device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(device)
    decoder_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(device)
    inputs = {'input_ids': input_ids, 'labels': decoder_ids}
    model.train()
    return model, inputs
class CheckSplitsCompiler:
    # Counts how many times the backend "compiler" is invoked -- in these tests,
    # once per subgraph produced by DDPOptimizer's bucket-based graph splitting.
    def __init__(self):
        self.compiler_called = 0

    def compile_fn(self, gm, example_inputs):
        self.compiler_called += 1
        return gm


@contextmanager
def _per_rank_init(rank, world_size):
    # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
    # just manually implement the most important part of the dynamo behavior to reset/clear.
    torch.cuda.set_device(rank)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '6789'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch._dynamo.reset()
    torch._dynamo.utils.counters.clear()
    yield
    torch._dynamo.reset()
    torch._dynamo.utils.counters.clear()
    dist.destroy_process_group()
@requires_nccl()
class TestDistributedMultiProc(MultiProcessTestCase):
    def setUp(self):
        super(TestDistributedMultiProc, self).setUp()
        self._spawn_processes()

    def tearDown(self):
        super(TestDistributedMultiProc, self).tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self) -> int:
        return torch.cuda.device_count()

    @classmethod
    def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None:
        # Don't enable DDP + ReplicatedTensor, as that breaks Dynamo+DDP
        # TODO(whc) why is ReplicatedTensor defaulted=True in MultiProcessTestCase, and should we support it?
        # from torch.nn.parallel._replicated_tensor_ddp_utils import _set_ddp_with_replicated_tensor
        # _set_ddp_with_replicated_tensor(True)

        # The rest is copypasta from MultiProcessTestCase._run
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name
        self.run_test(test_name, parent_pipe)
    @skip_if_lt_x_gpu(2)
    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_aot_eager_multiprocess(self):
        with _per_rank_init(self.rank, self.world_size):
            self.assertFalse(config.optimize_ddp)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            m = DDP(m, device_ids=[self.rank])
            m = torch._dynamo.optimize("aot_eager")(m)
            outputs = m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp(self):
        with _per_rank_init(self.rank, self.world_size):
            model, inputs = get_hf_bert(self.rank)
            model = DDP(model)
            reset_rng_state()
            correct_outputs = model(**inputs)
            correct_loss = correct_outputs.loss
            correct_loss.backward()

            reset_rng_state()
            opt_model = torch._dynamo.optimize("inductor")(model)
            opt_outputs = opt_model(**inputs)
            opt_loss = opt_outputs.loss
            opt_loss.backward()

            inputs_flat = [inputs[k] for k in inputs]
            correct_results = collect_results(model, correct_outputs.logits, correct_loss, inputs_flat)
            opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat)
            self.assertTrue(same(correct_results, opt_results))
    @skip_if_lt_x_gpu(1)
    # TODO(whc) delete aot_eager test, if inductor test lands stably
    def test_fsdp_aot_eager(self):
        with _per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            # Test with recursive wrapping, nested FSDP around each Linear
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(
                m,
                auto_wrap_policy=functools.partial(
                    transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear, )
                ),
                use_orig_params=True
            )
            fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))
    @skip_if_lt_x_gpu(1)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_fsdp_inductor(self):
        with _per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            # Test with recursive wrapping, nested FSDP around each Linear
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(
                m,
                auto_wrap_policy=functools.partial(
                    transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear, )
                ),
                use_orig_params=True
            )
            fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))
    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
    @patch.object(torch._inductor.config.triton, "cudagraphs", False)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_fsdp(self):
        from transformers.models.bert.modeling_bert import BertLayer

        def apply_fsdp(model, wrap_policy):
            model = FSDP(
                copy.deepcopy(model),
                auto_wrap_policy=wrap_policy,
                use_orig_params=True
            )
            return model

        with _per_rank_init(self.rank, self.world_size):
            for (wrap_policy, test_instance) in (
                (
                    None,
                    "FSDP without recursive wrapping"
                ),
                (
                    functools.partial(
                        transformer_auto_wrap_policy, transformer_layer_cls=(BertLayer, )
                    ),
                    "FSDP with recursive wrapping BertLayer instances"
                )
            ):
                print(f"Running hf_bert test for {test_instance}")
                model, inputs = get_hf_bert(self.rank)
                reset_rng_state()
                eager_model = apply_fsdp(model, wrap_policy)
                correct_outputs = eager_model(**inputs)
                correct_loss = correct_outputs.loss
                correct_loss.backward()

                reset_rng_state()
                opt_model = apply_fsdp(model, wrap_policy)
                opt_model = torch._dynamo.optimize("inductor")(opt_model)
                opt_outputs = opt_model(**inputs)
                opt_loss = opt_outputs.loss
                opt_loss.backward()

                inputs_flat = [inputs[k] for k in inputs]
                correct_results = collect_results(eager_model, correct_outputs.logits, correct_loss, inputs_flat)
                opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat)
                self.assertTrue(same(correct_results, opt_results))
@requires_nccl()
class TestDistributed(torch._dynamo.test_case.TestCase):
    """
    Test harness that initializes the dist process group.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # _exit_stack is set up in TestCase
        cls._exit_stack.enter_context(
            patch.dict(
                os.environ,
                {
                    "MASTER_ADDR": "localhost",
                    "MASTER_PORT": "12355",
                },
            )
        )
        cls._exit_stack.enter_context(patch.object(config, "log_level", logging.DEBUG))
        cls.rank = 0
        cls.device = f"cuda:{cls.rank}"
        cls.device_ids = None if "cuda" in cls.device else [cls.rank]
        dist.init_process_group("nccl", rank=cls.rank, world_size=1)

    @classmethod
    def tearDownClass(cls):
        dist.destroy_process_group()
        super().tearDownClass()

    def get_model(self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5):
        m = ToyModel(in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat).to(self.device)
        m.apply(init_weights)
        inputs = torch.rand(bsz, in_feat).to(self.device)
        outputs = m(inputs)
        return m, inputs, outputs
@patch.object(config, "optimize_ddp", False)
def test_ddp_baseline_aot_eager(self):
from torch.nn.parallel import DistributedDataParallel as DDP
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=self.device_ids)
ddp_m = torch._dynamo.optimize("aot_eager")(ddp_m)
outputs = ddp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(config, "optimize_ddp", False)
def test_ddp_baseline_inductor(self):
from torch.nn.parallel import DistributedDataParallel as DDP
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=self.device_ids)
ddp_m = torch._dynamo.optimize("inductor")(ddp_m)
outputs = ddp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))
@patch.object(config, "optimize_ddp", True)
def test_graph_split(self):
"""
Just ensures that the appropriate number of splits happen (based on
bucket size and model parameters) - verifies the number of times
the user-provided compiler is called by the DDPOptimizer which is
doing the graph splitting
"""
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
check_splits_compiler = CheckSplitsCompiler()
@torch._dynamo.optimize(check_splits_compiler.compile_fn)
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 3)
@patch.object(config, "optimize_ddp", True)
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_graph_split_inductor(self):
"""
Same as above, but using inductor backend.
We observed issues with inductor/fx interface in the past.
"""
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
@torch._dynamo.optimize("inductor")
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
@patch.object(config, "optimize_ddp", True)
def test_no_split(self):
"""
Ensures the DDPOptimizer returns a correct, compiled module without
introducing graph splits. (Based on model parmeters fitting in the bucket)
"""
# DDP will always do a 'first bucket' with a really small size; so only a tiny model will escape this
m, inputs, correct_outputs = self.get_model(hidden_feat=5)
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250)
check_splits_compiler = CheckSplitsCompiler()
@torch._dynamo.optimize(check_splits_compiler.compile_fn)
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 1)
@patch.object(config, "optimize_ddp", True)
def test_aot_autograd(self):
"""
Explicitly check AotAutograd family of compilers work,
since they require example inputs propagated between graph splits.
"""
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
@torch._dynamo.optimize("aot_eager")
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
opt_outputs.sum().backward()
self.assertTrue(same(correct_outputs, opt_outputs))
@patch.object(config, "optimize_ddp", True)
def test_custom_layer(self):
"""
Just ensures that the appropriate number of splits happen (based on
bucket size and model parameters) - verifies the number of times
the user-provided compiler is called by the DDPOptimizer which is
doing the graph splitting
"""
m, inputs, correct_outputs = get_custom_model(self.device)
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1)
check_splits_compiler = CheckSplitsCompiler()
@torch._dynamo.optimize(check_splits_compiler.compile_fn)
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 3)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_empty_graph_inductor(self):
        def fn():
            get_world_size = torch.distributed.distributed_c10d.get_world_size()
            return (get_world_size,)

        opt_fn = torch._dynamo.optimize("inductor")(fn)
        res = None
        try:
            res = opt_fn()[0]
        except Exception:
            pass
        self.assertEqual(res, 1)
@patch.object(config, "optimize_ddp", False)
def test_ignored_parameters(self):
"""
Verifies ddp graph-split logic ignores parameters marked to ignore on DDP module.
Hooks up graph-split optimizer manually so it can peek at internal state.
"""
m, inputs, correct_outputs = get_custom_model(self.device)
parameters_to_ignore = ["seq.2.weight", "seq.4.linear.bias"]
DDP._set_params_and_buffers_to_ignore_for_model(m, parameters_to_ignore)
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
parameter_ids_to_ignore = [
id(ddp_m.module.get_parameter(p))
for p in ddp_m.parameters_to_ignore
]
check_splits_compiler = CheckSplitsCompiler()
ddp_optimizer = DDPOptimizer(
bucket_bytes_cap=ddp_m.bucket_bytes_cap,
backend_compile_fn=check_splits_compiler.compile_fn
)
@torch._dynamo.optimize(ddp_optimizer.compile_fn)
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 2)
for b in ddp_optimizer.buckets:
for p_id in b.param_ids:
self.assertFalse(p_id in parameter_ids_to_ignore)
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()