mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix unused Python variables in test/[a-d]* (#134665)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134665 Approved by: https://github.com/albanD
This commit is contained in:
parent
e19f493f02
commit
d25e6e623f
|
|
@ -147,7 +147,6 @@ def _sparse_layer_test_helper(
|
||||||
W_zp = 0
|
W_zp = 0
|
||||||
|
|
||||||
X_fp32 = torch.randn(batch_size, input_channels, dtype=torch.float32)
|
X_fp32 = torch.randn(batch_size, input_channels, dtype=torch.float32)
|
||||||
float_bias = torch.randn(output_channels, dtype=torch.float32)
|
|
||||||
|
|
||||||
# generate a weight which we'll insert into the model
|
# generate a weight which we'll insert into the model
|
||||||
W_fp32 = torch.randn(output_channels, input_channels, dtype=torch.float32)
|
W_fp32 = torch.randn(output_channels, input_channels, dtype=torch.float32)
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,6 @@ class TestQlinearPackedParams(TestCase):
|
||||||
row_block_size = 1
|
row_block_size = 1
|
||||||
col_block_size = 4
|
col_block_size = 4
|
||||||
out_features = weight_fp32.shape[0]
|
out_features = weight_fp32.shape[0]
|
||||||
in_features = weight_fp32.shape[1]
|
|
||||||
|
|
||||||
scales = [2.0, 6.0, 12.0]
|
scales = [2.0, 6.0, 12.0]
|
||||||
zero_points = [
|
zero_points = [
|
||||||
|
|
@ -201,14 +200,11 @@ class TestQlinearPackedParams(TestCase):
|
||||||
row_block_size = 1
|
row_block_size = 1
|
||||||
col_block_size = 4
|
col_block_size = 4
|
||||||
out_features = weight_fp32.shape[0]
|
out_features = weight_fp32.shape[0]
|
||||||
in_features = weight_fp32.shape[1]
|
|
||||||
|
|
||||||
scales = [2.0, 3.0, 7.0]
|
scales = [2.0, 3.0, 7.0]
|
||||||
zero_points = [0 for _ in range(out_features)]
|
zero_points = [0 for _ in range(out_features)]
|
||||||
dtype = torch.qint8
|
dtype = torch.qint8
|
||||||
|
|
||||||
x = torch.rand(size=(1, weight_fp32.shape[1]))
|
|
||||||
|
|
||||||
def make_lin_get_state_weight_bias_and_save():
|
def make_lin_get_state_weight_bias_and_save():
|
||||||
weight = torch.quantize_per_tensor(
|
weight = torch.quantize_per_tensor(
|
||||||
weight_fp32,
|
weight_fp32,
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,7 @@ class TestBaseSparsifier(TestCase):
|
||||||
sparsifier0.prepare(model0, [{"tensor_fqn": "linear1.weight"}])
|
sparsifier0.prepare(model0, [{"tensor_fqn": "linear1.weight"}])
|
||||||
mask = model0.linear1.parametrizations["weight"][0].mask
|
mask = model0.linear1.parametrizations["weight"][0].mask
|
||||||
mask.data = torch.arange(mask.shape[0] * mask.shape[1]).reshape(mask.shape)
|
mask.data = torch.arange(mask.shape[0] * mask.shape[1]).reshape(mask.shape)
|
||||||
for step in range(step_count):
|
for _ in range(step_count):
|
||||||
sparsifier0.step()
|
sparsifier0.step()
|
||||||
state_dict = sparsifier0.state_dict()
|
state_dict = sparsifier0.state_dict()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -124,7 +124,7 @@ class TestSparsityUtilFunctions(TestCase):
|
||||||
list_of_modules = [m for _, m in model.named_modules()] + [model]
|
list_of_modules = [m for _, m in model.named_modules()] + [model]
|
||||||
for module in list_of_modules:
|
for module in list_of_modules:
|
||||||
module_fqn = module_to_fqn(model, module)
|
module_fqn = module_to_fqn(model, module)
|
||||||
for tensor_name, tensor in module.named_parameters(recurse=False):
|
for tensor_name, _ in module.named_parameters(recurse=False):
|
||||||
tensor_fqn = (
|
tensor_fqn = (
|
||||||
module_fqn + ("." if module_fqn != "" else "") + tensor_name
|
module_fqn + ("." if module_fqn != "" else "") + tensor_name
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -269,7 +269,6 @@ class TestBaseStructuredSparsifier(TestCase):
|
||||||
|
|
||||||
def _test_step_linear_on_device(self, model, device):
|
def _test_step_linear_on_device(self, model, device):
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
x = torch.ones(7, 7, device=device)
|
|
||||||
pruner = SimplePruner(None)
|
pruner = SimplePruner(None)
|
||||||
pruner.prepare(model, None)
|
pruner.prepare(model, None)
|
||||||
pruner.enable_mask_update = True
|
pruner.enable_mask_update = True
|
||||||
|
|
@ -808,7 +807,7 @@ class TestBaseStructuredSparsifier(TestCase):
|
||||||
pruned_model = fx_pruner.prune()
|
pruned_model = fx_pruner.prune()
|
||||||
pruned_model.eval()
|
pruned_model.eval()
|
||||||
out_pruned, lstm_out_pruned = pruned_model(lstm_input)
|
out_pruned, lstm_out_pruned = pruned_model(lstm_input)
|
||||||
r, c = lstm_out_expected.size()
|
_, c = lstm_out_expected.size()
|
||||||
|
|
||||||
# We cannot check that y_expected == y_pruned as usual because
|
# We cannot check that y_expected == y_pruned as usual because
|
||||||
# zeros vs. missing elements yield different numerical results.
|
# zeros vs. missing elements yield different numerical results.
|
||||||
|
|
@ -891,7 +890,7 @@ class TestBaseStructuredSparsifier(TestCase):
|
||||||
pruned_model = fx_pruner.prune()
|
pruned_model = fx_pruner.prune()
|
||||||
pruned_model.eval()
|
pruned_model.eval()
|
||||||
out_pruned, lstm_out_pruned = pruned_model(lstm_input)
|
out_pruned, lstm_out_pruned = pruned_model(lstm_input)
|
||||||
r, c = lstm_out_expected.size()
|
_, c = lstm_out_expected.size()
|
||||||
|
|
||||||
# We cannot check that y_expected == y_pruned as usual because
|
# We cannot check that y_expected == y_pruned as usual because
|
||||||
# zeros vs. missing elements yield different numerical results.
|
# zeros vs. missing elements yield different numerical results.
|
||||||
|
|
|
||||||
|
|
@ -670,7 +670,7 @@ class TestAutogradFunctional(TestCase):
|
||||||
|
|
||||||
x = ctors.randn(3)
|
x = ctors.randn(3)
|
||||||
with warnings.catch_warnings(record=True) as wa:
|
with warnings.catch_warnings(record=True) as wa:
|
||||||
result = api(foo, x, vectorize=True)
|
api(foo, x, vectorize=True)
|
||||||
self.assertEqual(len(wa), 0)
|
self.assertEqual(len(wa), 0)
|
||||||
|
|
||||||
@base_and_logging_tensor
|
@base_and_logging_tensor
|
||||||
|
|
@ -762,7 +762,7 @@ class TestAutogradFunctional(TestCase):
|
||||||
|
|
||||||
inp = ctors.rand(4)
|
inp = ctors.rand(4)
|
||||||
with self.assertRaisesRegex(RuntimeError, "not supported together"):
|
with self.assertRaisesRegex(RuntimeError, "not supported together"):
|
||||||
res = autogradF.jacobian(foo, inp, strict=True, vectorize=True)
|
autogradF.jacobian(foo, inp, strict=True, vectorize=True)
|
||||||
|
|
||||||
@base_and_logging_tensor
|
@base_and_logging_tensor
|
||||||
def test_jacobian_no_grad(self, ctors):
|
def test_jacobian_no_grad(self, ctors):
|
||||||
|
|
@ -1122,7 +1122,7 @@ class TestAutogradFunctional(TestCase):
|
||||||
|
|
||||||
inp = ctors.rand(4)
|
inp = ctors.rand(4)
|
||||||
with self.assertRaisesRegex(RuntimeError, "not supported together"):
|
with self.assertRaisesRegex(RuntimeError, "not supported together"):
|
||||||
res = autogradF.hessian(foo, inp, strict=True, vectorize=True)
|
autogradF.hessian(foo, inp, strict=True, vectorize=True)
|
||||||
|
|
||||||
@base_and_logging_tensor
|
@base_and_logging_tensor
|
||||||
def test_hessian_no_grad(self, ctors):
|
def test_hessian_no_grad(self, ctors):
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ def main():
|
||||||
data = torch.randn(10, 50).cuda()
|
data = torch.randn(10, 50).cuda()
|
||||||
model = Model().cuda()
|
model = Model().cuda()
|
||||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
|
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
|
||||||
for i in range(10):
|
for _ in range(10):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss = model(data)
|
loss = model(data)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
|
||||||
|
|
@ -78,9 +78,9 @@ def forward(self, arg0_1):
|
||||||
x = torch.randn(3, device="meta")
|
x = torch.randn(3, device="meta")
|
||||||
self.assertNotIn("my_custom_ops2", sys.modules.keys())
|
self.assertNotIn("my_custom_ops2", sys.modules.keys())
|
||||||
with self.assertRaisesRegex(NotImplementedError, r"'my_custom_ops2'"):
|
with self.assertRaisesRegex(NotImplementedError, r"'my_custom_ops2'"):
|
||||||
y = torch.ops.custom.sin.default(x)
|
torch.ops.custom.sin.default(x)
|
||||||
torch.ops.import_module("my_custom_ops2")
|
torch.ops.import_module("my_custom_ops2")
|
||||||
y = torch.ops.custom.sin.default(x)
|
torch.ops.custom.sin.default(x)
|
||||||
|
|
||||||
def test_calling_custom_op_string(self):
|
def test_calling_custom_op_string(self):
|
||||||
output = ops.custom.op2("abc", "def")
|
output = ops.custom.op2("abc", "def")
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ class _TestClipGradNormBase(FSDPTest):
|
||||||
vector_norm_fn = functools.partial(torch.linalg.vector_norm, ord=norm_type)
|
vector_norm_fn = functools.partial(torch.linalg.vector_norm, ord=norm_type)
|
||||||
dp_mesh = dp_mesh or init_device_mesh("cuda", (self.world_size,))
|
dp_mesh = dp_mesh or init_device_mesh("cuda", (self.world_size,))
|
||||||
torch.manual_seed(42 + dp_mesh.get_local_rank() + 1)
|
torch.manual_seed(42 + dp_mesh.get_local_rank() + 1)
|
||||||
for iter_idx in range(10):
|
for _ in range(10):
|
||||||
ref_optim.zero_grad()
|
ref_optim.zero_grad()
|
||||||
ref_model(inp).sum().backward()
|
ref_model(inp).sum().backward()
|
||||||
optim.zero_grad()
|
optim.zero_grad()
|
||||||
|
|
|
||||||
|
|
@ -250,8 +250,8 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
|
||||||
self.assertEqual(group.size(), self.world_size)
|
self.assertEqual(group.size(), self.world_size)
|
||||||
all_reduce_stream = torch.cuda.Stream()
|
all_reduce_stream = torch.cuda.Stream()
|
||||||
(
|
(
|
||||||
reduce_scatter_input,
|
_,
|
||||||
reduce_scatter_event,
|
_,
|
||||||
post_reduce_event,
|
post_reduce_event,
|
||||||
_,
|
_,
|
||||||
_,
|
_,
|
||||||
|
|
@ -406,7 +406,7 @@ class TestFullyShardCommunication(FSDPTest):
|
||||||
torch.manual_seed(42 + self.rank)
|
torch.manual_seed(42 + self.rank)
|
||||||
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
|
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
|
||||||
|
|
||||||
for iter_idx in range(10):
|
for _ in range(10):
|
||||||
ref_loss = ref_model(inp).sum()
|
ref_loss = ref_model(inp).sum()
|
||||||
ref_loss.backward()
|
ref_loss.backward()
|
||||||
for param in ref_model.parameters():
|
for param in ref_model.parameters():
|
||||||
|
|
@ -501,7 +501,7 @@ class TestFullyShardPrefetch(FSDPTest):
|
||||||
self, reshard_after_forward: Union[bool, int], checkpoint_impl: Optional[str]
|
self, reshard_after_forward: Union[bool, int], checkpoint_impl: Optional[str]
|
||||||
):
|
):
|
||||||
n_layers = 3
|
n_layers = 3
|
||||||
model, optim, inp = self._init_transformer(
|
model, _, inp = self._init_transformer(
|
||||||
n_layers, reshard_after_forward, checkpoint_impl
|
n_layers, reshard_after_forward, checkpoint_impl
|
||||||
)
|
)
|
||||||
events: List[EventType] = []
|
events: List[EventType] = []
|
||||||
|
|
@ -843,7 +843,7 @@ class TestFullyShardPrefetch(FSDPTest):
|
||||||
with patch_unshard(unshard_with_record), patch_post_backward(
|
with patch_unshard(unshard_with_record), patch_post_backward(
|
||||||
post_backward_with_record
|
post_backward_with_record
|
||||||
):
|
):
|
||||||
for iter_idx in range(3):
|
for _ in range(3):
|
||||||
loss = model(inp)
|
loss = model(inp)
|
||||||
expected_events = [
|
expected_events = [
|
||||||
(
|
(
|
||||||
|
|
@ -922,7 +922,7 @@ class TestFullyShardPrefetch(FSDPTest):
|
||||||
with patch_unshard(unshard_with_record), patch_post_backward(
|
with patch_unshard(unshard_with_record), patch_post_backward(
|
||||||
post_backward_with_record
|
post_backward_with_record
|
||||||
):
|
):
|
||||||
for iter_idx in range(3):
|
for _ in range(3):
|
||||||
loss = model(inp)
|
loss = model(inp)
|
||||||
expected_events = [
|
expected_events = [
|
||||||
("unshard", "", TrainingState.FORWARD),
|
("unshard", "", TrainingState.FORWARD),
|
||||||
|
|
|
||||||
|
|
@ -662,7 +662,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
|
||||||
def __init__(self, n_layers):
|
def __init__(self, n_layers):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = torch.nn.ModuleList()
|
self.layers = torch.nn.ModuleList()
|
||||||
for layer_id in range(n_layers):
|
for _ in range(n_layers):
|
||||||
self.layers.append(TestSubmodule(hidden_dim))
|
self.layers.append(TestSubmodule(hidden_dim))
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
@ -684,7 +684,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
|
||||||
fsdp_config = {}
|
fsdp_config = {}
|
||||||
mesh = init_device_mesh("cuda", (self.world_size,))
|
mesh = init_device_mesh("cuda", (self.world_size,))
|
||||||
model = TestModule(n_layers=3)
|
model = TestModule(n_layers=3)
|
||||||
for layer_id, mod in enumerate(model.layers):
|
for mod in model.layers:
|
||||||
fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
|
fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
|
||||||
model = fully_shard(
|
model = fully_shard(
|
||||||
model, mesh=mesh, reshard_after_forward=True, **fsdp_config
|
model, mesh=mesh, reshard_after_forward=True, **fsdp_config
|
||||||
|
|
@ -871,7 +871,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
|
||||||
else:
|
else:
|
||||||
v.requires_grad_(False)
|
v.requires_grad_(False)
|
||||||
assert requires_grad_param_count == n_layers * len(requires_grad_params)
|
assert requires_grad_param_count == n_layers * len(requires_grad_params)
|
||||||
for layer_id, mod in enumerate(model.layers):
|
for _, mod in enumerate(model.layers):
|
||||||
fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
|
fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
|
||||||
model = fully_shard(
|
model = fully_shard(
|
||||||
model, mesh=mesh, reshard_after_forward=True, **fsdp_config
|
model, mesh=mesh, reshard_after_forward=True, **fsdp_config
|
||||||
|
|
@ -1087,7 +1087,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
|
||||||
setattr(m.encoder, name, new_child)
|
setattr(m.encoder, name, new_child)
|
||||||
m = FSDP(m, sharding_strategy=ShardingStrategy.FULL_SHARD, use_orig_params=True)
|
m = FSDP(m, sharding_strategy=ShardingStrategy.FULL_SHARD, use_orig_params=True)
|
||||||
inp = torch.randn(32, 784, device="cuda")
|
inp = torch.randn(32, 784, device="cuda")
|
||||||
out = m(inp)
|
m(inp)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -241,7 +241,7 @@ class TestFullyShardAllGatherExtensionsMultiProcess(
|
||||||
losses.append(_model(inp).sum())
|
losses.append(_model(inp).sum())
|
||||||
losses[-1].backward()
|
losses[-1].backward()
|
||||||
if _model is ref_model:
|
if _model is ref_model:
|
||||||
for param_name, param in _model.named_parameters():
|
for _, param in _model.named_parameters():
|
||||||
dist.all_reduce(param.grad)
|
dist.all_reduce(param.grad)
|
||||||
param.grad.detach().div_(self.world_size)
|
param.grad.detach().div_(self.world_size)
|
||||||
self.assertEqual(losses[0], losses[1])
|
self.assertEqual(losses[0], losses[1])
|
||||||
|
|
|
||||||
|
|
@ -904,7 +904,7 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
|
||||||
)
|
)
|
||||||
self.assertEqual(mesh.mesh, ref_mesh.mesh)
|
self.assertEqual(mesh.mesh, ref_mesh.mesh)
|
||||||
self.assertEqual(mesh._coordinate_on_dim, ref_mesh._coordinate_on_dim)
|
self.assertEqual(mesh._coordinate_on_dim, ref_mesh._coordinate_on_dim)
|
||||||
for (tag, ranks, group_name), (ref_tag, ref_ranks, ref_group_name) in zip(
|
for (_, ranks, _), (_, ref_ranks, _) in zip(
|
||||||
mesh._dim_group_infos, ref_mesh._dim_group_infos
|
mesh._dim_group_infos, ref_mesh._dim_group_infos
|
||||||
):
|
):
|
||||||
# Since we manually constructed new subgroups, the test and ref
|
# Since we manually constructed new subgroups, the test and ref
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ class LoggingTests(LoggingTestCase):
|
||||||
env["WORLD_SIZE"] = "1"
|
env["WORLD_SIZE"] = "1"
|
||||||
env["MASTER_PORT"] = "34715"
|
env["MASTER_PORT"] = "34715"
|
||||||
env["MASTER_ADDR"] = "localhost"
|
env["MASTER_ADDR"] = "localhost"
|
||||||
stdout, stderr = self.run_process_no_exception(
|
_, stderr = self.run_process_no_exception(
|
||||||
"""\
|
"""\
|
||||||
import logging
|
import logging
|
||||||
import torch
|
import torch
|
||||||
|
|
|
||||||
|
|
@ -590,7 +590,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
|
||||||
|
|
||||||
torch.manual_seed(42 + self.rank)
|
torch.manual_seed(42 + self.rank)
|
||||||
inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
|
inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
|
||||||
for iter_idx in range(10):
|
for _ in range(10):
|
||||||
losses: List[torch.Tensor] = []
|
losses: List[torch.Tensor] = []
|
||||||
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
|
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
|
||||||
_optim.zero_grad()
|
_optim.zero_grad()
|
||||||
|
|
@ -624,12 +624,12 @@ class TestFullyShard1DTrainingCore(FSDPTest):
|
||||||
# sync point after each iteration
|
# sync point after each iteration
|
||||||
ref_losses: List[torch.Tensor] = []
|
ref_losses: List[torch.Tensor] = []
|
||||||
losses: List[torch.Tensor] = []
|
losses: List[torch.Tensor] = []
|
||||||
for iter_idx in range(10):
|
for _ in range(10):
|
||||||
ref_optim.zero_grad()
|
ref_optim.zero_grad()
|
||||||
ref_losses.append(ref_model(inp).sum())
|
ref_losses.append(ref_model(inp).sum())
|
||||||
ref_losses[-1].backward()
|
ref_losses[-1].backward()
|
||||||
ref_optim.step()
|
ref_optim.step()
|
||||||
for iter_idx in range(10):
|
for _ in range(10):
|
||||||
optim.zero_grad()
|
optim.zero_grad()
|
||||||
losses.append(model(inp).sum())
|
losses.append(model(inp).sum())
|
||||||
losses[-1].backward()
|
losses[-1].backward()
|
||||||
|
|
@ -1185,7 +1185,7 @@ class TestFullyShardNDTraining(FSDPTest):
|
||||||
foreach: bool,
|
foreach: bool,
|
||||||
):
|
):
|
||||||
global_mesh = self.init_global_mesh()
|
global_mesh = self.init_global_mesh()
|
||||||
pp_mesh, dp_mesh, tp_mesh = (
|
_, dp_mesh, tp_mesh = (
|
||||||
global_mesh["pp"],
|
global_mesh["pp"],
|
||||||
global_mesh["dp"],
|
global_mesh["dp"],
|
||||||
global_mesh["tp"],
|
global_mesh["tp"],
|
||||||
|
|
@ -1217,7 +1217,7 @@ class TestFullyShardNDTraining(FSDPTest):
|
||||||
_optim.step()
|
_optim.step()
|
||||||
self.assertEqual(losses[0], losses[1])
|
self.assertEqual(losses[0], losses[1])
|
||||||
|
|
||||||
for n, p in model.named_parameters():
|
for _, p in model.named_parameters():
|
||||||
self.assertIsInstance(p, DTensor)
|
self.assertIsInstance(p, DTensor)
|
||||||
self.assertEqual(p.device_mesh.ndim, 2)
|
self.assertEqual(p.device_mesh.ndim, 2)
|
||||||
self.assertEqual(len(p.placements), 2)
|
self.assertEqual(len(p.placements), 2)
|
||||||
|
|
@ -1288,7 +1288,7 @@ class TestFullyShardHSDP3DTraining(FSDPTest):
|
||||||
_optim.step()
|
_optim.step()
|
||||||
self.assertEqual(losses[0], losses[1])
|
self.assertEqual(losses[0], losses[1])
|
||||||
|
|
||||||
for n, p in model.named_parameters():
|
for _, p in model.named_parameters():
|
||||||
self.assertIsInstance(p, DTensor)
|
self.assertIsInstance(p, DTensor)
|
||||||
self.assertEqual(p.device_mesh.ndim, 3)
|
self.assertEqual(p.device_mesh.ndim, 3)
|
||||||
self.assertEqual(len(p.placements), 3)
|
self.assertEqual(len(p.placements), 3)
|
||||||
|
|
|
||||||
|
|
@ -119,7 +119,6 @@ class TestCheckpoint(TestCase):
|
||||||
# no checkpoint
|
# no checkpoint
|
||||||
with MemoryDelta(x.device) as mem1:
|
with MemoryDelta(x.device) as mem1:
|
||||||
loss1 = net1(x1).sum()
|
loss1 = net1(x1).sum()
|
||||||
graph_size1 = self._get_graph_size(loss1)
|
|
||||||
loss1.backward()
|
loss1.backward()
|
||||||
|
|
||||||
# with checkpoint
|
# with checkpoint
|
||||||
|
|
|
||||||
|
|
@ -244,7 +244,6 @@ class TestFullyShard2DTraining(FSDPTest):
|
||||||
ref_model.parameters(), model.named_parameters()
|
ref_model.parameters(), model.named_parameters()
|
||||||
):
|
):
|
||||||
full_grad = param.grad.full_tensor()
|
full_grad = param.grad.full_tensor()
|
||||||
ref_grad = ref_param.grad
|
|
||||||
self.assertEqual(ref_param.grad, full_grad)
|
self.assertEqual(ref_param.grad, full_grad)
|
||||||
|
|
||||||
ref_optim.step()
|
ref_optim.step()
|
||||||
|
|
@ -285,7 +284,7 @@ class TestFullyShard2DTraining(FSDPTest):
|
||||||
# called, but they will just be no-ops without issuing any kernels.
|
# called, but they will just be no-ops without issuing any kernels.
|
||||||
# We prefer to keep the no-op check at the c10d level, not in FSDP.
|
# We prefer to keep the no-op check at the c10d level, not in FSDP.
|
||||||
inp = torch.randn((4, mlp_dim), device="cuda") # same on all ranks
|
inp = torch.randn((4, mlp_dim), device="cuda") # same on all ranks
|
||||||
for iter_idx in range(10):
|
for _ in range(10):
|
||||||
ref_optim.zero_grad()
|
ref_optim.zero_grad()
|
||||||
optim.zero_grad()
|
optim.zero_grad()
|
||||||
|
|
||||||
|
|
@ -583,9 +582,7 @@ class TestNew2dParallelTraining(DTensorTestBase):
|
||||||
"net1": ColwiseParallel(),
|
"net1": ColwiseParallel(),
|
||||||
"net2": RowwiseParallel(),
|
"net2": RowwiseParallel(),
|
||||||
}
|
}
|
||||||
model_2d = parallelize_module(
|
parallelize_module(SimpleModel().cuda(), mesh_2d["tp"], parallelize_plan)
|
||||||
SimpleModel().cuda(), mesh_2d["tp"], parallelize_plan
|
|
||||||
)
|
|
||||||
|
|
||||||
@with_comms
|
@with_comms
|
||||||
@skip_if_lt_x_gpu(4)
|
@skip_if_lt_x_gpu(4)
|
||||||
|
|
@ -833,7 +830,6 @@ class TestNew2dParallelStateDict(DTensorTestBase):
|
||||||
# Create a model without wrapper
|
# Create a model without wrapper
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
no_wrap_model = simple_model().cuda(self.rank)
|
no_wrap_model = simple_model().cuda(self.rank)
|
||||||
no_wrap_state_dict = no_wrap_model.state_dict()
|
|
||||||
no_wrap_optim = torch.optim.Adam(no_wrap_model.parameters(), lr=0.01)
|
no_wrap_optim = torch.optim.Adam(no_wrap_model.parameters(), lr=0.01)
|
||||||
no_wrap_model(no_wrap_model.get_input().cuda(self.rank)).sum().backward()
|
no_wrap_model(no_wrap_model.get_input().cuda(self.rank)).sum().backward()
|
||||||
no_wrap_optim.step()
|
no_wrap_optim.step()
|
||||||
|
|
@ -890,8 +886,6 @@ class TestNew2dParallelStateDict(DTensorTestBase):
|
||||||
set_optimizer_state_dict(
|
set_optimizer_state_dict(
|
||||||
model_2d, optimizers=optim_2d, optim_state_dict=ref_optim_2d_osd
|
model_2d, optimizers=optim_2d, optim_state_dict=ref_optim_2d_osd
|
||||||
)
|
)
|
||||||
new_optim_2d_osd = get_optimizer_state_dict(model_2d, optimizers=optim_2d)
|
|
||||||
|
|
||||||
ref_optim_2d_osd_states = ref_optim_2d_osd["state"]
|
ref_optim_2d_osd_states = ref_optim_2d_osd["state"]
|
||||||
new_optim_2d_osd_states = optim_2d_osd["state"]
|
new_optim_2d_osd_states = optim_2d_osd["state"]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -119,7 +119,7 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||||
)
|
)
|
||||||
@parametrize("use_new_runtime", [False, True])
|
@parametrize("use_new_runtime", [False, True])
|
||||||
def test_manual_with_data_parallel(self, dp_type, ScheduleClass, use_new_runtime):
|
def test_manual_with_data_parallel(self, dp_type, ScheduleClass, use_new_runtime):
|
||||||
device = torch.device("cuda", self.device)
|
_device_raii = torch.device("cuda", self.device)
|
||||||
torch.cuda.set_device(self.device)
|
torch.cuda.set_device(self.device)
|
||||||
store = torch.distributed.FileStore(self.file_name, self.world_size)
|
store = torch.distributed.FileStore(self.file_name, self.world_size)
|
||||||
torch.distributed.init_process_group(
|
torch.distributed.init_process_group(
|
||||||
|
|
@ -398,7 +398,7 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_3d_with_tp_dp_pp(self, ScheduleClass, MixedPrecisionParam):
|
def test_3d_with_tp_dp_pp(self, ScheduleClass, MixedPrecisionParam):
|
||||||
device = torch.device("cuda", self.device)
|
_device_raii = torch.device("cuda", self.device)
|
||||||
torch.cuda.set_device(self.device)
|
torch.cuda.set_device(self.device)
|
||||||
store = torch.distributed.FileStore(self.file_name, self.world_size)
|
store = torch.distributed.FileStore(self.file_name, self.world_size)
|
||||||
torch.distributed.init_process_group(
|
torch.distributed.init_process_group(
|
||||||
|
|
|
||||||
|
|
@ -329,11 +329,11 @@ class ReplicateTest(MultiProcessInductorTestCase):
|
||||||
code = self._test_bucketing()
|
code = self._test_bucketing()
|
||||||
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
|
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
|
||||||
fc = FileCheck()
|
fc = FileCheck()
|
||||||
for i in range(3):
|
for _ in range(3):
|
||||||
fc.check("cpp_fused_").check(
|
fc.check("cpp_fused_").check(
|
||||||
"torch.ops._c10d_functional.all_reduce_coalesced_.default("
|
"torch.ops._c10d_functional.all_reduce_coalesced_.default("
|
||||||
)
|
)
|
||||||
for i in range(3):
|
for _ in range(3):
|
||||||
fc.check("torch.ops._c10d_functional.wait_tensor.default")
|
fc.check("torch.ops._c10d_functional.wait_tensor.default")
|
||||||
|
|
||||||
fc.run(code)
|
fc.run(code)
|
||||||
|
|
@ -342,11 +342,11 @@ class ReplicateTest(MultiProcessInductorTestCase):
|
||||||
code = self._test_bucketing(init_process_group=False, loop=2)
|
code = self._test_bucketing(init_process_group=False, loop=2)
|
||||||
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
|
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
|
||||||
fc = FileCheck()
|
fc = FileCheck()
|
||||||
for i in range(3):
|
for _ in range(3):
|
||||||
fc.check("cpp_fused_").check(
|
fc.check("cpp_fused_").check(
|
||||||
"torch.ops._c10d_functional.all_reduce_coalesced_.default("
|
"torch.ops._c10d_functional.all_reduce_coalesced_.default("
|
||||||
)
|
)
|
||||||
for i in range(3):
|
for _ in range(3):
|
||||||
fc.check("torch.ops._c10d_functional.wait_tensor.default")
|
fc.check("torch.ops._c10d_functional.wait_tensor.default")
|
||||||
|
|
||||||
fc.run(code)
|
fc.run(code)
|
||||||
|
|
@ -371,11 +371,11 @@ class ReplicateTest(MultiProcessInductorTestCase):
|
||||||
code = self._test_bucketing()
|
code = self._test_bucketing()
|
||||||
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
|
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
|
||||||
fc = FileCheck()
|
fc = FileCheck()
|
||||||
for i in range(3):
|
for _ in range(3):
|
||||||
fc.check("aten.flatten.using_ints(").check("cpp_fused_").check(
|
fc.check("aten.flatten.using_ints(").check("cpp_fused_").check(
|
||||||
"torch.ops._c10d_functional.all_reduce_.default("
|
"torch.ops._c10d_functional.all_reduce_.default("
|
||||||
)
|
)
|
||||||
for i in range(3):
|
for _ in range(3):
|
||||||
fc.check("torch.ops._c10d_functional.wait_tensor.default")
|
fc.check("torch.ops._c10d_functional.wait_tensor.default")
|
||||||
fc.run(code)
|
fc.run(code)
|
||||||
|
|
||||||
|
|
@ -383,11 +383,11 @@ class ReplicateTest(MultiProcessInductorTestCase):
|
||||||
code = self._test_bucketing(init_process_group=False, loop=2)
|
code = self._test_bucketing(init_process_group=False, loop=2)
|
||||||
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
|
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
|
||||||
fc = FileCheck()
|
fc = FileCheck()
|
||||||
for i in range(3):
|
for _ in range(3):
|
||||||
fc.check("aten.flatten.using_ints(").check("cpp_fused_").check(
|
fc.check("aten.flatten.using_ints(").check("cpp_fused_").check(
|
||||||
"torch.ops._c10d_functional.all_reduce_.default("
|
"torch.ops._c10d_functional.all_reduce_.default("
|
||||||
)
|
)
|
||||||
for i in range(3):
|
for _ in range(3):
|
||||||
fc.check("torch.ops._c10d_functional.wait_tensor.default")
|
fc.check("torch.ops._c10d_functional.wait_tensor.default")
|
||||||
fc.run(code)
|
fc.run(code)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -129,7 +129,7 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase):
|
||||||
def test_torch_equal(self):
|
def test_torch_equal(self):
|
||||||
"""Test torch.equal(ShardedTensor, ShardedTensor)"""
|
"""Test torch.equal(ShardedTensor, ShardedTensor)"""
|
||||||
|
|
||||||
spec, alt_spec = self.get_gpu_specs()
|
spec, _ = self.get_gpu_specs()
|
||||||
st1, st2 = self.get_random_tensors(spec, spec, 10, 10)
|
st1, st2 = self.get_random_tensors(spec, spec, 10, 10)
|
||||||
self.assertTrue(torch.equal(st1, st2))
|
self.assertTrue(torch.equal(st1, st2))
|
||||||
|
|
||||||
|
|
@ -145,7 +145,7 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase):
|
||||||
def test_torch_allclose(self):
|
def test_torch_allclose(self):
|
||||||
"""Test torch.allclose(ShardedTensor, ShardedTensor)"""
|
"""Test torch.allclose(ShardedTensor, ShardedTensor)"""
|
||||||
|
|
||||||
spec, alt_spec = self.get_gpu_specs()
|
spec, _ = self.get_gpu_specs()
|
||||||
|
|
||||||
st1, st2 = self.get_random_tensors(spec, spec, 10, 10)
|
st1, st2 = self.get_random_tensors(spec, spec, 10, 10)
|
||||||
self.assertTrue(torch.allclose(st1, st2))
|
self.assertTrue(torch.allclose(st1, st2))
|
||||||
|
|
|
||||||
|
|
@ -40,8 +40,6 @@ class TestShardedTensorNNInit(ShardedTensorTestBase):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
h, w = 8, 2
|
h, w = 8, 2
|
||||||
expected_h = 2
|
|
||||||
expected_device = torch.device(f"cuda:{self.rank}")
|
|
||||||
a, b = 10, 20
|
a, b = 10, 20
|
||||||
|
|
||||||
seed = 1234
|
seed = 1234
|
||||||
|
|
@ -75,8 +73,6 @@ class TestShardedTensorNNInit(ShardedTensorTestBase):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
h, w = 8, 2
|
h, w = 8, 2
|
||||||
expected_h = 2
|
|
||||||
expected_device = torch.device(f"cuda:{self.rank}")
|
|
||||||
mean, std = 10, 5
|
mean, std = 10, 5
|
||||||
|
|
||||||
seed = 1234
|
seed = 1234
|
||||||
|
|
@ -110,8 +106,6 @@ class TestShardedTensorNNInit(ShardedTensorTestBase):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
h, w = 8, 2
|
h, w = 8, 2
|
||||||
expected_h = 2
|
|
||||||
expected_device = torch.device(f"cuda:{self.rank}")
|
|
||||||
a, mode, nonlinearity = 0, "fan_in", "leaky_relu"
|
a, mode, nonlinearity = 0, "fan_in", "leaky_relu"
|
||||||
|
|
||||||
seed = 1234
|
seed = 1234
|
||||||
|
|
|
||||||
|
|
@ -456,7 +456,7 @@ class TestLocalTensor(ShardedTensorTestBase):
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
NotImplementedError, "Only single local shard is supported."
|
NotImplementedError, "Only single local shard is supported."
|
||||||
):
|
):
|
||||||
local_shard = st.local_tensor()
|
st.local_tensor()
|
||||||
|
|
||||||
|
|
||||||
class TestShardedTensorChunked(ShardedTensorTestBase):
|
class TestShardedTensorChunked(ShardedTensorTestBase):
|
||||||
|
|
@ -981,7 +981,6 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
|
||||||
# Validate remote shards.
|
# Validate remote shards.
|
||||||
remote_shards = st.remote_shards()
|
remote_shards = st.remote_shards()
|
||||||
self.assertEqual(3, len(remote_shards))
|
self.assertEqual(3, len(remote_shards))
|
||||||
owners = {}
|
|
||||||
for rpc_rank, shards in remote_shards.items():
|
for rpc_rank, shards in remote_shards.items():
|
||||||
self.assertEqual(2, len(shards))
|
self.assertEqual(2, len(shards))
|
||||||
for remote_shard in shards:
|
for remote_shard in shards:
|
||||||
|
|
@ -1364,14 +1363,14 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
|
||||||
with self.assertRaisesRegex(RuntimeError, "Local rank at save time was"):
|
with self.assertRaisesRegex(RuntimeError, "Local rank at save time was"):
|
||||||
with load_with_process_group(pg):
|
with load_with_process_group(pg):
|
||||||
# ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
|
# ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
|
||||||
state_dict_deser = torch.load(buffer, weights_only=False)
|
torch.load(buffer, weights_only=False)
|
||||||
else:
|
else:
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError, "Local world size at save time was"
|
RuntimeError, "Local world size at save time was"
|
||||||
):
|
):
|
||||||
with load_with_process_group(pg):
|
with load_with_process_group(pg):
|
||||||
# ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
|
# ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
|
||||||
state_dict_deser = torch.load(buffer, weights_only=False)
|
torch.load(buffer, weights_only=False)
|
||||||
|
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
buffer.seek(0)
|
buffer.seek(0)
|
||||||
|
|
@ -1379,7 +1378,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
|
||||||
RuntimeError, "Need to initialize default process group"
|
RuntimeError, "Need to initialize default process group"
|
||||||
):
|
):
|
||||||
# ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
|
# ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
|
||||||
state_dict_deser = torch.load(buffer, weights_only=False)
|
torch.load(buffer, weights_only=False)
|
||||||
rpc.shutdown()
|
rpc.shutdown()
|
||||||
|
|
||||||
@with_comms
|
@with_comms
|
||||||
|
|
@ -1396,8 +1395,8 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
|
||||||
"rank:3/cuda:3",
|
"rank:3/cuda:3",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
st1 = sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
|
sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
|
||||||
st2 = sharded_tensor.empty(spec, 10, 20)
|
sharded_tensor.empty(spec, 10, 20)
|
||||||
|
|
||||||
create_tensors()
|
create_tensors()
|
||||||
self.assertEqual(0, len(sharded_tensor.api._sharded_tensor_map))
|
self.assertEqual(0, len(sharded_tensor.api._sharded_tensor_map))
|
||||||
|
|
@ -2204,7 +2203,6 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase):
|
||||||
else:
|
else:
|
||||||
self.assertEqual(2, len(remote_shards))
|
self.assertEqual(2, len(remote_shards))
|
||||||
|
|
||||||
owners = {}
|
|
||||||
for rpc_rank, shards in remote_shards.items():
|
for rpc_rank, shards in remote_shards.items():
|
||||||
self.assertEqual(2, len(shards))
|
self.assertEqual(2, len(shards))
|
||||||
for remote_shard in shards:
|
for remote_shard in shards:
|
||||||
|
|
@ -2418,10 +2416,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
|
||||||
placement=f"rank:{self.rank}/cuda:{self.rank}",
|
placement=f"rank:{self.rank}/cuda:{self.rank}",
|
||||||
)
|
)
|
||||||
with self.assertRaisesRegex(ValueError, "Shard tensor size does not match"):
|
with self.assertRaisesRegex(ValueError, "Shard tensor size does not match"):
|
||||||
local_shard_from_wrong_meta = sharded_tensor.Shard(
|
sharded_tensor.Shard(local_tensor, metadata=wrong_local_shard_metadata)
|
||||||
local_tensor,
|
|
||||||
metadata=wrong_local_shard_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
@with_comms
|
@with_comms
|
||||||
@skip_if_lt_x_gpu(4)
|
@skip_if_lt_x_gpu(4)
|
||||||
|
|
@ -2696,7 +2691,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
|
||||||
|
|
||||||
empty_local_shards = []
|
empty_local_shards = []
|
||||||
with self.assertRaisesRegex(ValueError, "have no local shards on all ranks"):
|
with self.assertRaisesRegex(ValueError, "have no local shards on all ranks"):
|
||||||
st = sharded_tensor.init_from_local_shards(
|
sharded_tensor.init_from_local_shards(
|
||||||
empty_local_shards, [10, 10], init_rrefs=True
|
empty_local_shards, [10, 10], init_rrefs=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2706,7 +2701,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, "Only torch.strided layout is currently supported"
|
ValueError, "Only torch.strided layout is currently supported"
|
||||||
):
|
):
|
||||||
st = sharded_tensor.init_from_local_shards(
|
sharded_tensor.init_from_local_shards(
|
||||||
wrong_layout_shards, [10, 10], init_rrefs=True
|
wrong_layout_shards, [10, 10], init_rrefs=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2719,23 +2714,19 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
|
||||||
ValueError,
|
ValueError,
|
||||||
"Only torch.contiguous_format memory_format is currently supported",
|
"Only torch.contiguous_format memory_format is currently supported",
|
||||||
):
|
):
|
||||||
st = sharded_tensor.init_from_local_shards(
|
sharded_tensor.init_from_local_shards(
|
||||||
wrong_memory_format_shards, [10, 10], init_rrefs=True
|
wrong_memory_format_shards, [10, 10], init_rrefs=True
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.assertRaisesRegex(ValueError, "Shard tensor size does not match"):
|
with self.assertRaisesRegex(ValueError, "Shard tensor size does not match"):
|
||||||
wrong_size_shards = [
|
sharded_tensor.Shard(
|
||||||
sharded_tensor.Shard(
|
torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata
|
||||||
torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata
|
)
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, "Local shard tensor device does not match"
|
ValueError, "Local shard tensor device does not match"
|
||||||
):
|
):
|
||||||
wrong_device_shards = [
|
sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata)
|
||||||
sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata)
|
|
||||||
]
|
|
||||||
|
|
||||||
@with_comms
|
@with_comms
|
||||||
@skip_if_lt_x_gpu(4)
|
@skip_if_lt_x_gpu(4)
|
||||||
|
|
@ -2756,7 +2747,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
|
||||||
ValueError,
|
ValueError,
|
||||||
"ShardedTensor global_size property does not match from different ranks!",
|
"ShardedTensor global_size property does not match from different ranks!",
|
||||||
):
|
):
|
||||||
st = sharded_tensor.init_from_local_shards(
|
sharded_tensor.init_from_local_shards(
|
||||||
wrong_dtype_shards, tensor_overall_size, init_rrefs=True
|
wrong_dtype_shards, tensor_overall_size, init_rrefs=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2771,7 +2762,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
|
||||||
ValueError,
|
ValueError,
|
||||||
"ShardedTensor dtype property does not match from different ranks!",
|
"ShardedTensor dtype property does not match from different ranks!",
|
||||||
):
|
):
|
||||||
st = sharded_tensor.init_from_local_shards(
|
sharded_tensor.init_from_local_shards(
|
||||||
wrong_dtype_shards, [10, 10], init_rrefs=True
|
wrong_dtype_shards, [10, 10], init_rrefs=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2788,7 +2779,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
|
||||||
ValueError,
|
ValueError,
|
||||||
"ShardedTensor requires_grad property does not match from different ranks!",
|
"ShardedTensor requires_grad property does not match from different ranks!",
|
||||||
):
|
):
|
||||||
st = sharded_tensor.init_from_local_shards(
|
sharded_tensor.init_from_local_shards(
|
||||||
wrong_requires_grad_shards, [10, 10], init_rrefs=True
|
wrong_requires_grad_shards, [10, 10], init_rrefs=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2818,7 +2809,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, "Local shards' tensor pin_memory property need to be the same"
|
ValueError, "Local shards' tensor pin_memory property need to be the same"
|
||||||
):
|
):
|
||||||
st = sharded_tensor.init_from_local_shards(
|
sharded_tensor.init_from_local_shards(
|
||||||
wrong_pin_memory_local_shards, [10, 10], init_rrefs=True
|
wrong_pin_memory_local_shards, [10, 10], init_rrefs=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2832,7 +2823,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
|
||||||
ValueError,
|
ValueError,
|
||||||
"ShardedTensor pin_memory property does not match from different ranks!",
|
"ShardedTensor pin_memory property does not match from different ranks!",
|
||||||
):
|
):
|
||||||
st = sharded_tensor.init_from_local_shards(
|
sharded_tensor.init_from_local_shards(
|
||||||
wrong_pin_memory_shards_cross_ranks, [10, 10], init_rrefs=True
|
wrong_pin_memory_shards_cross_ranks, [10, 10], init_rrefs=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2945,19 +2936,15 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, "Shard tensor size does not match with metadata.shard_lengths"
|
ValueError, "Shard tensor size does not match with metadata.shard_lengths"
|
||||||
):
|
):
|
||||||
wrong_size_shards = [
|
sharded_tensor.Shard(
|
||||||
sharded_tensor.Shard(
|
torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata
|
||||||
torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata
|
)
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError,
|
ValueError,
|
||||||
"Local shard tensor device does not match with local Shard's placement",
|
"Local shard tensor device does not match with local Shard's placement",
|
||||||
):
|
):
|
||||||
wrong_device_shards = [
|
sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata)
|
||||||
sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata)
|
|
||||||
]
|
|
||||||
|
|
||||||
wrong_dtype_shards = [
|
wrong_dtype_shards = [
|
||||||
sharded_tensor.Shard(
|
sharded_tensor.Shard(
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ class ChunkAllShardingPlanner(ShardingPlanner):
|
||||||
def build_plan(self, module: nn.Module) -> ShardingPlan:
|
def build_plan(self, module: nn.Module) -> ShardingPlan:
|
||||||
named_params = module.named_parameters()
|
named_params = module.named_parameters()
|
||||||
plan = {}
|
plan = {}
|
||||||
for name, param in named_params:
|
for name, _ in named_params:
|
||||||
plan[name] = ChunkShardingSpec(self.dim, placements=self.devices)
|
plan[name] = ChunkShardingSpec(self.dim, placements=self.devices)
|
||||||
|
|
||||||
return ShardingPlan(plan=plan)
|
return ShardingPlan(plan=plan)
|
||||||
|
|
|
||||||
|
|
@ -92,7 +92,6 @@ class TestCommMode(TestCase):
|
||||||
self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 1)
|
self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 1)
|
||||||
|
|
||||||
def test_comm_mode_with_dtensor(self):
|
def test_comm_mode_with_dtensor(self):
|
||||||
world_pg = self.world_pg
|
|
||||||
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||||
|
|
||||||
def f(x, y):
|
def f(x, y):
|
||||||
|
|
@ -118,8 +117,6 @@ class TestCommMode(TestCase):
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
return
|
return
|
||||||
|
|
||||||
world_pg = self.world_pg
|
|
||||||
|
|
||||||
inp = torch.rand(2, 8, 16).cuda()
|
inp = torch.rand(2, 8, 16).cuda()
|
||||||
all_gather_out = inp.new_empty(self.world_size * 2, 8, 16)
|
all_gather_out = inp.new_empty(self.world_size * 2, 8, 16)
|
||||||
|
|
||||||
|
|
@ -202,7 +199,7 @@ class TestCommMode(TestCase):
|
||||||
self.checksAssert(comm_mode, c10d_ops.reduce_scatter_, 1, 1)
|
self.checksAssert(comm_mode, c10d_ops.reduce_scatter_, 1, 1)
|
||||||
|
|
||||||
# tests c10d reduce_scatter_tensor_coalesced
|
# tests c10d reduce_scatter_tensor_coalesced
|
||||||
with comm_mode as A, dist._coalescing_manager() as B:
|
with comm_mode, dist._coalescing_manager():
|
||||||
dist.reduce_scatter_tensor(all_gather_out, inp)
|
dist.reduce_scatter_tensor(all_gather_out, inp)
|
||||||
|
|
||||||
self.checksAssert(comm_mode, c10d_ops.reduce_scatter_tensor_coalesced_, 1, 1)
|
self.checksAssert(comm_mode, c10d_ops.reduce_scatter_tensor_coalesced_, 1, 1)
|
||||||
|
|
|
||||||
|
|
@ -251,7 +251,7 @@ class TestCommModeFeatures(DTensorTestBase):
|
||||||
comm_mode.comm_module_counts,
|
comm_mode.comm_module_counts,
|
||||||
{"Global": {"forward": {}, "backward": {}}},
|
{"Global": {"forward": {}, "backward": {}}},
|
||||||
)
|
)
|
||||||
output_tp = model(inp)
|
model(inp)
|
||||||
|
|
||||||
model_args = ModelArgs(dropout_p=0.0)
|
model_args = ModelArgs(dropout_p=0.0)
|
||||||
model2 = Transformer(model_args).to(device=self.device_type)
|
model2 = Transformer(model_args).to(device=self.device_type)
|
||||||
|
|
@ -264,7 +264,7 @@ class TestCommModeFeatures(DTensorTestBase):
|
||||||
|
|
||||||
comm_mode = CommDebugMode()
|
comm_mode = CommDebugMode()
|
||||||
with comm_mode:
|
with comm_mode:
|
||||||
output = model2(inp)
|
model2(inp)
|
||||||
|
|
||||||
# checks to see if all collectives were correctly traced at the module-level
|
# checks to see if all collectives were correctly traced at the module-level
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
|
|
||||||
|
|
@ -155,14 +155,12 @@ class DTensorTest(DTensorTestBase):
|
||||||
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||||
shard0_spec = [Shard(0)]
|
shard0_spec = [Shard(0)]
|
||||||
local_tensor = torch.randn(4, 8)
|
local_tensor = torch.randn(4, 8)
|
||||||
global_shape = torch.Size([self.world_size * 4, 8])
|
|
||||||
dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard0_spec)
|
dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard0_spec)
|
||||||
# won't affect stride
|
# won't affect stride
|
||||||
self.assertEqual(dist_tensor.stride(), (8, 1))
|
self.assertEqual(dist_tensor.stride(), (8, 1))
|
||||||
|
|
||||||
shard1_spec = [Shard(1)]
|
shard1_spec = [Shard(1)]
|
||||||
local_tensor = torch.randn(8, 4)
|
local_tensor = torch.randn(8, 4)
|
||||||
global_shape = torch.Size([8, self.world_size * 4])
|
|
||||||
dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard1_spec)
|
dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard1_spec)
|
||||||
# will affect stride after DT initialized
|
# will affect stride after DT initialized
|
||||||
self.assertEqual(dist_tensor.stride(), (4 * self.world_size, 1))
|
self.assertEqual(dist_tensor.stride(), (4 * self.world_size, 1))
|
||||||
|
|
@ -170,7 +168,6 @@ class DTensorTest(DTensorTestBase):
|
||||||
# if initialized from a transposed mat
|
# if initialized from a transposed mat
|
||||||
local_tensor = torch.randn(8, 4, 8)
|
local_tensor = torch.randn(8, 4, 8)
|
||||||
local_tensor_t = local_tensor.permute(1, 2, 0)
|
local_tensor_t = local_tensor.permute(1, 2, 0)
|
||||||
global_shape = torch.Size([4, self.world_size * 8, 8])
|
|
||||||
self.assertEqual(local_tensor_t.stride(), (8, 1, 32))
|
self.assertEqual(local_tensor_t.stride(), (8, 1, 32))
|
||||||
dist_tensor = DTensor.from_local(local_tensor_t, device_mesh, shard1_spec)
|
dist_tensor = DTensor.from_local(local_tensor_t, device_mesh, shard1_spec)
|
||||||
global_stride = (8 * self.world_size, 1, 32 * self.world_size)
|
global_stride = (8 * self.world_size, 1, 32 * self.world_size)
|
||||||
|
|
@ -257,7 +254,7 @@ class DTensorTest(DTensorTestBase):
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError, "Please pass both shape and stride at the same time."
|
RuntimeError, "Please pass both shape and stride at the same time."
|
||||||
):
|
):
|
||||||
dtensor = DTensor.from_local(
|
DTensor.from_local(
|
||||||
tensor_list[self.rank],
|
tensor_list[self.rank],
|
||||||
device_mesh,
|
device_mesh,
|
||||||
(Shard(0),),
|
(Shard(0),),
|
||||||
|
|
@ -267,7 +264,7 @@ class DTensorTest(DTensorTestBase):
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError, "Please pass both shape and stride at the same time."
|
RuntimeError, "Please pass both shape and stride at the same time."
|
||||||
):
|
):
|
||||||
dtensor = DTensor.from_local(
|
DTensor.from_local(
|
||||||
tensor_list[self.rank],
|
tensor_list[self.rank],
|
||||||
device_mesh,
|
device_mesh,
|
||||||
(Shard(0),),
|
(Shard(0),),
|
||||||
|
|
@ -1043,7 +1040,7 @@ class DTensorLogTest(LoggingTestCase):
|
||||||
env["MASTER_PORT"] = "12345"
|
env["MASTER_PORT"] = "12345"
|
||||||
env["MASTER_ADDR"] = "localhost"
|
env["MASTER_ADDR"] = "localhost"
|
||||||
|
|
||||||
stdout, stderr = self.run_process_no_exception(
|
_, stderr = self.run_process_no_exception(
|
||||||
"""\
|
"""\
|
||||||
import logging
|
import logging
|
||||||
import torch
|
import torch
|
||||||
|
|
|
||||||
|
|
@ -234,8 +234,8 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
|
||||||
requires_grad=x.requires_grad,
|
requires_grad=x.requires_grad,
|
||||||
)
|
)
|
||||||
|
|
||||||
out = fn(x)
|
fn(x)
|
||||||
out2 = torch.compile(fn, backend="eager")(x)
|
torch.compile(fn, backend="eager")(x)
|
||||||
|
|
||||||
def test_dtensor_constructor_w_dynamo_disable(self):
|
def test_dtensor_constructor_w_dynamo_disable(self):
|
||||||
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
|
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
|
||||||
|
|
@ -599,7 +599,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
|
||||||
|
|
||||||
@torch.compile(backend=cnt)
|
@torch.compile(backend=cnt)
|
||||||
def fn(x):
|
def fn(x):
|
||||||
dt = DTensor.from_local(x, mesh, [placement], run_check=False)
|
DTensor.from_local(x, mesh, [placement], run_check=False)
|
||||||
|
|
||||||
x = torch.ones(4, 4, requires_grad=True)
|
x = torch.ones(4, 4, requires_grad=True)
|
||||||
|
|
||||||
|
|
@ -659,7 +659,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
|
||||||
x2 = x_dt.redistribute(mesh, [Replicate()], async_op=True)
|
x2 = x_dt.redistribute(mesh, [Replicate()], async_op=True)
|
||||||
x2 = x2.to_local()
|
x2 = x2.to_local()
|
||||||
self.assertTrue(isinstance(x2, AsyncCollectiveTensor))
|
self.assertTrue(isinstance(x2, AsyncCollectiveTensor))
|
||||||
out = opt_fn(x2)
|
opt_fn(x2)
|
||||||
# The important part: we get a wait_tensor() in the graph.
|
# The important part: we get a wait_tensor() in the graph.
|
||||||
# At runtime, the input to the graph is an AsyncCollectiveTensor,
|
# At runtime, the input to the graph is an AsyncCollectiveTensor,
|
||||||
# and inside the graph we need to issue a wait() to synchronize.
|
# and inside the graph we need to issue a wait() to synchronize.
|
||||||
|
|
@ -880,8 +880,6 @@ class TestDTensorCompileE2E(DTensorTestBase):
|
||||||
mesh_dim_names=["dp", "tp"],
|
mesh_dim_names=["dp", "tp"],
|
||||||
)
|
)
|
||||||
|
|
||||||
fsdp_pg = twod_mesh.get_group(mesh_dim=0)
|
|
||||||
|
|
||||||
inp = torch.rand(20, 10, device=self.device_type)
|
inp = torch.rand(20, 10, device=self.device_type)
|
||||||
parallelize_plan = {
|
parallelize_plan = {
|
||||||
"mlp_0.net1": ColwiseParallel(),
|
"mlp_0.net1": ColwiseParallel(),
|
||||||
|
|
|
||||||
|
|
@ -249,7 +249,7 @@ class DistTensorRandomOpTest(DTensorTestBase):
|
||||||
|
|
||||||
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
|
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
|
||||||
# seed synchronization happens after the first `distribute_tensor` call
|
# seed synchronization happens after the first `distribute_tensor` call
|
||||||
dtensor = distribute_tensor(
|
distribute_tensor(
|
||||||
torch.empty([self.world_size], device="cuda"), device_mesh, [Shard(0)]
|
torch.empty([self.world_size], device="cuda"), device_mesh, [Shard(0)]
|
||||||
)
|
)
|
||||||
self.assertEqual(seed_from_rank_0, random._rng_tracker.get_seed("parallel-rng"))
|
self.assertEqual(seed_from_rank_0, random._rng_tracker.get_seed("parallel-rng"))
|
||||||
|
|
|
||||||
|
|
@ -309,7 +309,7 @@ class RedistributeTest(DTensorTestBase):
|
||||||
shard_tensor = distribute_tensor(local_tensor, device_mesh, shard_spec)
|
shard_tensor = distribute_tensor(local_tensor, device_mesh, shard_spec)
|
||||||
self.assertEqual(shard_tensor.placements[0].dim, 1)
|
self.assertEqual(shard_tensor.placements[0].dim, 1)
|
||||||
reshard_tensor = shard_tensor.redistribute(device_mesh, shard_minus_spec)
|
reshard_tensor = shard_tensor.redistribute(device_mesh, shard_minus_spec)
|
||||||
self.assertEqual(shard_tensor.placements[0].dim, 1)
|
self.assertEqual(reshard_tensor.placements[0].dim, 1)
|
||||||
|
|
||||||
@with_comms
|
@with_comms
|
||||||
def test_redistribute_uneven_sharding(self):
|
def test_redistribute_uneven_sharding(self):
|
||||||
|
|
|
||||||
|
|
@ -622,7 +622,7 @@ class DistTensorOpsTest(DTensorTestBase):
|
||||||
self.assertEqual(misses, 2)
|
self.assertEqual(misses, 2)
|
||||||
|
|
||||||
# convert to fp32 again and see if there's cache hit
|
# convert to fp32 again and see if there's cache hit
|
||||||
fp32_sharded_dtensor1 = bf16_sharded_dtensor1.float()
|
bf16_sharded_dtensor1.float()
|
||||||
hits, misses, _, _ = _get_sharding_prop_cache_info()
|
hits, misses, _, _ = _get_sharding_prop_cache_info()
|
||||||
# by now we should have cache hit
|
# by now we should have cache hit
|
||||||
self.assertEqual(hits, 1)
|
self.assertEqual(hits, 1)
|
||||||
|
|
|
||||||
|
|
@ -133,7 +133,6 @@ class UtilTest(DTensorTestBase):
|
||||||
global_tensor_shape, global_mesh, placements
|
global_tensor_shape, global_mesh, placements
|
||||||
)
|
)
|
||||||
assert global_mesh.get_coordinate is not None
|
assert global_mesh.get_coordinate is not None
|
||||||
dp_replic_rank = global_mesh.get_local_rank("dp_replic")
|
|
||||||
dp_shard_rank = global_mesh.get_local_rank("dp_shard")
|
dp_shard_rank = global_mesh.get_local_rank("dp_shard")
|
||||||
tp_rank = global_mesh.get_local_rank("tp")
|
tp_rank = global_mesh.get_local_rank("tp")
|
||||||
shard_idx_on_dim_0 = tp_rank * dp_shard_size + dp_shard_rank
|
shard_idx_on_dim_0 = tp_rank * dp_shard_size + dp_shard_rank
|
||||||
|
|
|
||||||
|
|
@ -150,7 +150,7 @@ class DTensorXLAIntegrationTest(TestCase):
|
||||||
shard_spec = [Shard(0)]
|
shard_spec = [Shard(0)]
|
||||||
# annoate fc1 and fc2
|
# annoate fc1 and fc2
|
||||||
if isinstance(mod, nn.Linear):
|
if isinstance(mod, nn.Linear):
|
||||||
for name, param in mod.named_parameters():
|
for _, param in mod.named_parameters():
|
||||||
# annotate the parameter tensors directly
|
# annotate the parameter tensors directly
|
||||||
distribute_tensor(param, mesh, shard_spec)
|
distribute_tensor(param, mesh, shard_spec)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
# Owner(s): ["oncall: distributed"]
|
# Owner(s): ["oncall: distributed"]
|
||||||
|
# ruff: noqa: F841
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
|
||||||
|
|
@ -277,7 +277,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
|
||||||
self.assertEqual(loss, dist_loss)
|
self.assertEqual(loss, dist_loss)
|
||||||
|
|
||||||
dist_msd, dist_osd = get_state_dict(dist_model, optimizers=dist_optim)
|
dist_msd, dist_osd = get_state_dict(dist_model, optimizers=dist_optim)
|
||||||
model_sd, optim_sd = get_state_dict(model, optimizers=optim)
|
model_sd, _ = get_state_dict(model, optimizers=optim)
|
||||||
|
|
||||||
self._verify_msd(model_sd, dist_msd)
|
self._verify_msd(model_sd, dist_msd)
|
||||||
self._verify_osd_by_load(model, optim, self._optim(model), dist_osd)
|
self._verify_osd_by_load(model, optim, self._optim(model), dist_osd)
|
||||||
|
|
|
||||||
|
|
@ -96,7 +96,7 @@ class TestFineTuning(DTensorTestBase):
|
||||||
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
|
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
for i in range(3):
|
for _ in range(3):
|
||||||
batch = torch.rand(32, DIM, device="cuda")
|
batch = torch.rand(32, DIM, device="cuda")
|
||||||
loss = model(batch).sum()
|
loss = model(batch).sum()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
@ -161,7 +161,7 @@ class TestFineTuning(DTensorTestBase):
|
||||||
self.assertEqual(i, 0)
|
self.assertEqual(i, 0)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
for j in range(3):
|
for _ in range(3):
|
||||||
batch = torch.rand(32, DIM, device="cuda")
|
batch = torch.rand(32, DIM, device="cuda")
|
||||||
loss = model(batch).sum()
|
loss = model(batch).sum()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
|
||||||
|
|
@ -85,11 +85,9 @@ class TestDistributedCheckpointing(ShardedTensorTestBase):
|
||||||
)
|
)
|
||||||
|
|
||||||
st = sharded_tensor.zeros(spec, 4, 4, dtype=torch.float64)
|
st = sharded_tensor.zeros(spec, 4, 4, dtype=torch.float64)
|
||||||
mapping = {}
|
|
||||||
|
|
||||||
md = _create_default_local_metadata({"st": st})
|
md = _create_default_local_metadata({"st": st})
|
||||||
|
|
||||||
st_md = md.state_dict_metadata["st"]
|
st_md = md.state_dict_metadata["st"]
|
||||||
|
|
||||||
self.assertEqual(1, len(st_md.chunks))
|
self.assertEqual(1, len(st_md.chunks))
|
||||||
|
|
||||||
@with_comms(init_rpc=False)
|
@with_comms(init_rpc=False)
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,6 @@ class TestFsdpTpCheckpointConversion(DTensorTestBase):
|
||||||
tp_model.load_state_dict(tp_state_dict)
|
tp_model.load_state_dict(tp_state_dict)
|
||||||
|
|
||||||
# Check parameters are equal after loading.
|
# Check parameters are equal after loading.
|
||||||
tp_state_dict_after_load = tp_model.state_dict()
|
|
||||||
for fsdp_item, tp_item in zip(fsdp_state_dict.items(), tp_state_dict.items()):
|
for fsdp_item, tp_item in zip(fsdp_state_dict.items(), tp_state_dict.items()):
|
||||||
fsdp_k, fsdp_v = fsdp_item
|
fsdp_k, fsdp_v = fsdp_item
|
||||||
tp_k, tp_v = tp_item
|
tp_k, tp_v = tp_item
|
||||||
|
|
|
||||||
|
|
@ -120,7 +120,6 @@ class TestHSDPCheckpoint(DTensorTestBase):
|
||||||
)
|
)
|
||||||
model.load_state_dict(state_dict_to_save["model"])
|
model.load_state_dict(state_dict_to_save["model"])
|
||||||
|
|
||||||
state_dict_after_load = model.state_dict()
|
|
||||||
# After loading, the current model state dict should be the same as state_dict_to_save.
|
# After loading, the current model state dict should be the same as state_dict_to_save.
|
||||||
for (k1, v1), (k2, v2) in zip(
|
for (k1, v1), (k2, v2) in zip(
|
||||||
state_dict_to_save["model"].items(), model.state_dict().items()
|
state_dict_to_save["model"].items(), model.state_dict().items()
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ class TestFlattening(TestCase):
|
||||||
"k3": ["x", 99, [{"k3": "y"}]],
|
"k3": ["x", 99, [{"k3": "y"}]],
|
||||||
}
|
}
|
||||||
|
|
||||||
flatten_dict, mapping = flatten_state_dict(state_dict)
|
_, mapping = flatten_state_dict(state_dict)
|
||||||
"""
|
"""
|
||||||
flatten_dict:
|
flatten_dict:
|
||||||
{'k0': [1], 'k2.0': tensor([1]), 'k2.1': 99, 'k2.2.0.k3': tensor(1), 'k3': ['x', 99, [{'k3': 'y'}]]}
|
{'k0': [1], 'k2.0': tensor([1]), 'k2.1': 99, 'k2.2.0.k3': tensor(1), 'k3': ['x', 99, [{'k3': 'y'}]]}
|
||||||
|
|
|
||||||
|
|
@ -40,21 +40,19 @@ class TestSaveAndLoadAPI(DTensorTestBase):
|
||||||
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
|
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||||
model = FSDP(model, device_mesh=device_mesh)
|
model = FSDP(model, device_mesh=device_mesh)
|
||||||
dcp.save(model.state_dict(), checkpoint_id=os.path.join(self.temp_dir, "first"))
|
dcp.save(model.state_dict(), checkpoint_id=os.path.join(self.temp_dir, "first"))
|
||||||
sd = dcp.load(
|
dcp.load(model.state_dict(), checkpoint_id=os.path.join(self.temp_dir, "first"))
|
||||||
model.state_dict(), checkpoint_id=os.path.join(self.temp_dir, "first")
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(
|
with patch.object(
|
||||||
dcp.FileSystemReader, "validate_checkpoint_id", return_value=False
|
dcp.FileSystemReader, "validate_checkpoint_id", return_value=False
|
||||||
) as m1:
|
):
|
||||||
with patch.object(
|
with patch.object(
|
||||||
dcp.FileSystemWriter, "validate_checkpoint_id", return_value=False
|
dcp.FileSystemWriter, "validate_checkpoint_id", return_value=False
|
||||||
) as m2:
|
):
|
||||||
dcp.save(
|
dcp.save(
|
||||||
model.state_dict(),
|
model.state_dict(),
|
||||||
checkpoint_id=os.path.join(self.temp_dir, "second"),
|
checkpoint_id=os.path.join(self.temp_dir, "second"),
|
||||||
)
|
)
|
||||||
sd = dcp.load(
|
dcp.load(
|
||||||
model.state_dict(),
|
model.state_dict(),
|
||||||
checkpoint_id=os.path.join(self.temp_dir, "second"),
|
checkpoint_id=os.path.join(self.temp_dir, "second"),
|
||||||
)
|
)
|
||||||
|
|
@ -62,7 +60,7 @@ class TestSaveAndLoadAPI(DTensorTestBase):
|
||||||
with self.assertRaisesRegex(RuntimeError, "Cannot detect"):
|
with self.assertRaisesRegex(RuntimeError, "Cannot detect"):
|
||||||
dcp.save(model.state_dict(), checkpoint_id="abc://abc.abc")
|
dcp.save(model.state_dict(), checkpoint_id="abc://abc.abc")
|
||||||
with self.assertRaisesRegex(RuntimeError, "Cannot detect"):
|
with self.assertRaisesRegex(RuntimeError, "Cannot detect"):
|
||||||
sd = dcp.load(model.state_dict(), checkpoint_id="abc://abc.abc")
|
dcp.load(model.state_dict(), checkpoint_id="abc://abc.abc")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||||
|
|
||||||
# Train 10 steps.
|
# Train 10 steps.
|
||||||
_dist_optim = [dist_optim] if not isinstance(dist_optim, list) else dist_optim
|
_dist_optim = [dist_optim] if not isinstance(dist_optim, list) else dist_optim
|
||||||
for i in range(10):
|
for _ in range(10):
|
||||||
optim.zero_grad()
|
optim.zero_grad()
|
||||||
for d_optim in _dist_optim:
|
for d_optim in _dist_optim:
|
||||||
d_optim.zero_grad()
|
d_optim.zero_grad()
|
||||||
|
|
|
||||||
|
|
@ -104,7 +104,7 @@ class TestStateDictUtils(DTensorTestBase):
|
||||||
return tensor, dist_tensor
|
return tensor, dist_tensor
|
||||||
|
|
||||||
ltensor, ldtensor = [], []
|
ltensor, ldtensor = [], []
|
||||||
for i in range(10):
|
for _ in range(10):
|
||||||
tensor, dtensor = create_dtensor()
|
tensor, dtensor = create_dtensor()
|
||||||
ltensor.append(tensor)
|
ltensor.append(tensor)
|
||||||
ltensor.append(torch.ones(10, device=torch.device("cuda")))
|
ltensor.append(torch.ones(10, device=torch.device("cuda")))
|
||||||
|
|
|
||||||
|
|
@ -259,7 +259,7 @@ class _StartProcessesTest(TestCase):
|
||||||
) -> None:
|
) -> None:
|
||||||
mp_queue = mp.get_context("spawn").Queue()
|
mp_queue = mp.get_context("spawn").Queue()
|
||||||
child_nproc = 2
|
child_nproc = 2
|
||||||
ctx = mp.spawn(
|
mp.spawn(
|
||||||
start_processes_zombie_test,
|
start_processes_zombie_test,
|
||||||
nprocs=1,
|
nprocs=1,
|
||||||
args=(entrypoint, mp_queue, self.log_dir(), child_nproc),
|
args=(entrypoint, mp_queue, self.log_dir(), child_nproc),
|
||||||
|
|
|
||||||
|
|
@ -165,7 +165,7 @@ class CreateBackendTest(TestCase):
|
||||||
def test_create_backend_returns_backend_if_is_host_is_not_specified_and_store_already_exists(
|
def test_create_backend_returns_backend_if_is_host_is_not_specified_and_store_already_exists(
|
||||||
self,
|
self,
|
||||||
) -> None:
|
) -> None:
|
||||||
store = TCPStore( # type: ignore[call-arg] # noqa: F841
|
TCPStore( # type: ignore[call-arg] # noqa: F841
|
||||||
self._expected_endpoint_host, self._expected_endpoint_port, is_master=True
|
self._expected_endpoint_host, self._expected_endpoint_port, is_master=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -99,7 +99,7 @@ class RendezvousTimeoutTest(TestCase):
|
||||||
ValueError,
|
ValueError,
|
||||||
rf"^The join timeout \({join_timeout}\) must be positive.$",
|
rf"^The join timeout \({join_timeout}\) must be positive.$",
|
||||||
):
|
):
|
||||||
timeout = RendezvousTimeout(join_timeout)
|
RendezvousTimeout(join_timeout)
|
||||||
|
|
||||||
|
|
||||||
class NodeDescTest(TestCase):
|
class NodeDescTest(TestCase):
|
||||||
|
|
@ -1637,7 +1637,7 @@ class CreateHandlerTest(TestCase):
|
||||||
def _ignore_exception(exception_type: Exception, fn: Callable):
|
def _ignore_exception(exception_type: Exception, fn: Callable):
|
||||||
try:
|
try:
|
||||||
fn()
|
fn()
|
||||||
except exception_type as e:
|
except exception_type:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,7 @@ class RendezvousBackendTestMixin(ABC):
|
||||||
self.assertTrue(has_set)
|
self.assertTrue(has_set)
|
||||||
|
|
||||||
def test_set_state_sets_backend_state_if_token_is_current(self) -> None:
|
def test_set_state_sets_backend_state_if_token_is_current(self) -> None:
|
||||||
state1, token1, has_set1 = self._set_state(b"x")
|
_, token1, has_set1 = self._set_state(b"x")
|
||||||
|
|
||||||
state2, token2, has_set2 = self._set_state(b"y", token1)
|
state2, token2, has_set2 = self._set_state(b"y", token1)
|
||||||
|
|
||||||
|
|
@ -80,7 +80,7 @@ class RendezvousBackendTestMixin(ABC):
|
||||||
self.assertTrue(has_set2)
|
self.assertTrue(has_set2)
|
||||||
|
|
||||||
def test_set_state_returns_current_backend_state_if_token_is_old(self) -> None:
|
def test_set_state_returns_current_backend_state_if_token_is_old(self) -> None:
|
||||||
state1, token1, _ = self._set_state(b"x")
|
_, token1, _ = self._set_state(b"x")
|
||||||
|
|
||||||
state2, token2, _ = self._set_state(b"y", token1)
|
state2, token2, _ = self._set_state(b"y", token1)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -113,7 +113,7 @@ if not (IS_WINDOWS or IS_MACOS):
|
||||||
num_clients = 10
|
num_clients = 10
|
||||||
num_requests_per_client = 10
|
num_requests_per_client = 10
|
||||||
processes = []
|
processes = []
|
||||||
for i in range(num_clients):
|
for _ in range(num_clients):
|
||||||
p = mp.Process(
|
p = mp.Process(
|
||||||
target=func, args=(num_requests_per_client, self.file_path)
|
target=func, args=(num_requests_per_client, self.file_path)
|
||||||
)
|
)
|
||||||
|
|
@ -190,7 +190,7 @@ if not (IS_WINDOWS or IS_MACOS):
|
||||||
"""
|
"""
|
||||||
client = timer.FileTimerClient(file_path)
|
client = timer.FileTimerClient(file_path)
|
||||||
sem.release()
|
sem.release()
|
||||||
for i in range(0, n):
|
for _ in range(0, n):
|
||||||
client.acquire("test_scope", 0)
|
client.acquire("test_scope", 0)
|
||||||
time.sleep(interval)
|
time.sleep(interval)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -159,7 +159,7 @@ class CheckpointWrapperTest(TestCase):
|
||||||
if use_reentrant
|
if use_reentrant
|
||||||
else CheckpointImpl.NO_REENTRANT,
|
else CheckpointImpl.NO_REENTRANT,
|
||||||
)
|
)
|
||||||
for i in range(self.n):
|
for _ in range(self.n):
|
||||||
l = nn.Sequential(
|
l = nn.Sequential(
|
||||||
nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256)
|
nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -303,13 +303,13 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
||||||
RuntimeError, "DeviceMesh is not compatible with LOCAL_STATE_DICT."
|
RuntimeError, "DeviceMesh is not compatible with LOCAL_STATE_DICT."
|
||||||
):
|
):
|
||||||
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
|
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
|
||||||
state_dict = model.state_dict()
|
model.state_dict()
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError, "DeviceMesh is not compatible with LOCAL_STATE_DICT."
|
RuntimeError, "DeviceMesh is not compatible with LOCAL_STATE_DICT."
|
||||||
):
|
):
|
||||||
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
|
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
|
||||||
optim_state_dict = FSDP.optim_state_dict(model, optim)
|
FSDP.optim_state_dict(model, optim)
|
||||||
|
|
||||||
|
|
||||||
instantiate_parametrized_tests(TestFSDPWithDeviceMeshAndDTensor)
|
instantiate_parametrized_tests(TestFSDPWithDeviceMeshAndDTensor)
|
||||||
|
|
|
||||||
|
|
@ -364,9 +364,8 @@ class TestFSDPFineTune(FSDPTest):
|
||||||
)
|
)
|
||||||
torch.manual_seed(self.rank + 1)
|
torch.manual_seed(self.rank + 1)
|
||||||
losses = []
|
losses = []
|
||||||
for idx in range(6):
|
for _ in range(6):
|
||||||
frozen_input = torch.randn((4, 4), device="cuda", requires_grad=False)
|
frozen_input = torch.randn((4, 4), device="cuda", requires_grad=False)
|
||||||
learnable_input = torch.randn((4, 4), device="cuda", requires_grad=True)
|
|
||||||
for _model, _optim in ((model, model_optim), (ref_model, ref_model_optim)):
|
for _model, _optim in ((model, model_optim), (ref_model, ref_model_optim)):
|
||||||
loss = _model(frozen_input, frozen_input).sum()
|
loss = _model(frozen_input, frozen_input).sum()
|
||||||
losses.append(loss)
|
losses.append(loss)
|
||||||
|
|
|
||||||
|
|
@ -182,7 +182,7 @@ class TestFreezingWeights(FSDPTest):
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
|
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
|
||||||
|
|
||||||
for iteration in range(3):
|
for _ in range(3):
|
||||||
out = model(batch)
|
out = model(batch)
|
||||||
fake_loss = criterion(out, target)
|
fake_loss = criterion(out, target)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
|
||||||
|
|
@ -108,8 +108,6 @@ class TestFSDPMemory(FSDPTest):
|
||||||
|
|
||||||
def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations):
|
def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations):
|
||||||
gpu_id = self.rank
|
gpu_id = self.rank
|
||||||
world_size = self.world_size
|
|
||||||
|
|
||||||
batch = torch.randn(size=(2, 3, 224, 224)).cuda()
|
batch = torch.randn(size=(2, 3, 224, 224)).cuda()
|
||||||
|
|
||||||
model = create_model(
|
model = create_model(
|
||||||
|
|
|
||||||
|
|
@ -278,9 +278,9 @@ class TestFSDPMiscMultiProcess(FSDPTest):
|
||||||
)
|
)
|
||||||
x = torch.randn(10, 10, device="cuda")
|
x = torch.randn(10, 10, device="cuda")
|
||||||
y = torch.randn(10, 10, device="cuda")
|
y = torch.randn(10, 10, device="cuda")
|
||||||
for i in range(4):
|
for _ in range(4):
|
||||||
if use_second_layer:
|
if use_second_layer:
|
||||||
a, b = fsdp(x, y)
|
a, _ = fsdp(x, y)
|
||||||
else:
|
else:
|
||||||
a = fsdp(x, y)
|
a = fsdp(x, y)
|
||||||
loss = a.sum()
|
loss = a.sum()
|
||||||
|
|
@ -509,7 +509,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
|
||||||
def test_fsdp_cpu_training(self):
|
def test_fsdp_cpu_training(self):
|
||||||
"""Tests FSDP training on CPU."""
|
"""Tests FSDP training on CPU."""
|
||||||
gloo_pg = dist.new_group(backend="gloo")
|
gloo_pg = dist.new_group(backend="gloo")
|
||||||
for ss in [
|
for ss in [ # noqa: F841
|
||||||
ShardingStrategy.NO_SHARD,
|
ShardingStrategy.NO_SHARD,
|
||||||
ShardingStrategy.FULL_SHARD,
|
ShardingStrategy.FULL_SHARD,
|
||||||
ShardingStrategy.SHARD_GRAD_OP,
|
ShardingStrategy.SHARD_GRAD_OP,
|
||||||
|
|
@ -857,13 +857,13 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
|
||||||
torch.cuda.set_device(self.rank)
|
torch.cuda.set_device(self.rank)
|
||||||
# Test CPU
|
# Test CPU
|
||||||
no_params = nn.ReLU()
|
no_params = nn.ReLU()
|
||||||
module = FSDP(no_params)
|
FSDP(no_params)
|
||||||
# Test CUDA
|
# Test CUDA
|
||||||
no_params = nn.ReLU().cuda()
|
no_params = nn.ReLU().cuda()
|
||||||
module = FSDP(no_params)
|
FSDP(no_params)
|
||||||
# Test CPU + device_id
|
# Test CPU + device_id
|
||||||
no_params = nn.ReLU()
|
no_params = nn.ReLU()
|
||||||
module = FSDP(no_params, device_id=torch.cuda.current_device())
|
FSDP(no_params, device_id=torch.cuda.current_device())
|
||||||
# For modules with no params, wrong device_id will raise error about
|
# For modules with no params, wrong device_id will raise error about
|
||||||
# inconsistency between compute_device and device_id, since compute_device
|
# inconsistency between compute_device and device_id, since compute_device
|
||||||
# is computed as torch.cuda.current_device when there are no params.
|
# is computed as torch.cuda.current_device when there are no params.
|
||||||
|
|
|
||||||
|
|
@ -1139,7 +1139,6 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
|
||||||
model = SaveForwardInputsModel(
|
model = SaveForwardInputsModel(
|
||||||
forward_inputs=forward_inputs, cast_forward_inputs=False
|
forward_inputs=forward_inputs, cast_forward_inputs=False
|
||||||
).cuda()
|
).cuda()
|
||||||
c1, c2 = model.c1, model.c2
|
|
||||||
x = torch.zeros(2, 100, device="cuda")
|
x = torch.zeros(2, 100, device="cuda")
|
||||||
|
|
||||||
# float16 on one submodule and float32 on everything else
|
# float16 on one submodule and float32 on everything else
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ class TestMultipleWrapping(FSDPTest):
|
||||||
model = FSDP(inner_model).cuda()
|
model = FSDP(inner_model).cuda()
|
||||||
optim = SGD(model.parameters(), lr=0.1)
|
optim = SGD(model.parameters(), lr=0.1)
|
||||||
|
|
||||||
for i in range(3):
|
for _ in range(3):
|
||||||
input = torch.rand((1, 5), dtype=torch.float).cuda()
|
input = torch.rand((1, 5), dtype=torch.float).cuda()
|
||||||
input.requires_grad = True
|
input.requires_grad = True
|
||||||
output = model(input)
|
output = model(input)
|
||||||
|
|
|
||||||
|
|
@ -1510,7 +1510,7 @@ class TestFSDPOptimState(FSDPTest):
|
||||||
) = self._init_nested_model(wrap=False, use_multiple_param_groups=False)
|
) = self._init_nested_model(wrap=False, use_multiple_param_groups=False)
|
||||||
if should_check_method_fn("rekey_optim_state_dict"):
|
if should_check_method_fn("rekey_optim_state_dict"):
|
||||||
with context_fn():
|
with context_fn():
|
||||||
rekeyed_osd = FSDP.rekey_optim_state_dict(
|
FSDP.rekey_optim_state_dict(
|
||||||
fsdp_osd, # from `full_optim_state_dict()`
|
fsdp_osd, # from `full_optim_state_dict()`
|
||||||
OptimStateKeyType.PARAM_ID,
|
OptimStateKeyType.PARAM_ID,
|
||||||
nonwrapped_model,
|
nonwrapped_model,
|
||||||
|
|
@ -1650,7 +1650,7 @@ class TestFSDPOptimState(FSDPTest):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make optim1 has a different state.
|
# Make optim1 has a different state.
|
||||||
for i in range(5):
|
for _ in range(5):
|
||||||
batch = torch.rand(5, 8).cuda()
|
batch = torch.rand(5, 8).cuda()
|
||||||
loss = models[1](batch).sum()
|
loss = models[1](batch).sum()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
@ -1765,7 +1765,7 @@ class TestFSDPOptimState(FSDPTest):
|
||||||
initializer = self._model_class[model_class]
|
initializer = self._model_class[model_class]
|
||||||
|
|
||||||
# First, run a wrapped model with full world size for a few iterations
|
# First, run a wrapped model with full world size for a few iterations
|
||||||
model1, optim1, optim_input1 = initializer(
|
model1, optim1, _ = initializer(
|
||||||
wrap=True,
|
wrap=True,
|
||||||
use_multiple_param_groups=use_multiple_param_groups,
|
use_multiple_param_groups=use_multiple_param_groups,
|
||||||
)
|
)
|
||||||
|
|
@ -1788,7 +1788,7 @@ class TestFSDPOptimState(FSDPTest):
|
||||||
new_group = dist.distributed_c10d._get_default_group()
|
new_group = dist.distributed_c10d._get_default_group()
|
||||||
# Second, run a wrapped model with (possibly) halved world size and
|
# Second, run a wrapped model with (possibly) halved world size and
|
||||||
# (possibly) differing `optim_input` across ranks
|
# (possibly) differing `optim_input` across ranks
|
||||||
model2, optim2, optim_input2 = initializer(
|
model2, optim2, _ = initializer(
|
||||||
wrap=True,
|
wrap=True,
|
||||||
group=new_group,
|
group=new_group,
|
||||||
use_multiple_param_groups=use_multiple_param_groups,
|
use_multiple_param_groups=use_multiple_param_groups,
|
||||||
|
|
@ -1861,7 +1861,8 @@ class TestFSDPOptimState(FSDPTest):
|
||||||
FSDP.optim_state_dict(model, optim), osd, check_same_param_keys=True
|
FSDP.optim_state_dict(model, optim), osd, check_same_param_keys=True
|
||||||
)
|
)
|
||||||
step()
|
step()
|
||||||
osd_to_load = FSDP.optim_state_dict_to_load(
|
|
||||||
|
osd_to_load = FSDP.optim_state_dict_to_load( # noqa: F841
|
||||||
model, optim, osd, load_directly=True
|
model, optim, osd, load_directly=True
|
||||||
)
|
)
|
||||||
self._check_same_state(
|
self._check_same_state(
|
||||||
|
|
@ -1994,7 +1995,7 @@ class TestFSDPOptimState(FSDPTest):
|
||||||
loss.backward()
|
loss.backward()
|
||||||
fsdp_optim.step()
|
fsdp_optim.step()
|
||||||
orig_state_dict = deepcopy(fsdp_optim.state_dict())
|
orig_state_dict = deepcopy(fsdp_optim.state_dict())
|
||||||
optim_state_dict = FSDP.optim_state_dict(fsdp_model, fsdp_optim)
|
FSDP.optim_state_dict(fsdp_model, fsdp_optim)
|
||||||
FSDP.optim_state_dict_to_load(
|
FSDP.optim_state_dict_to_load(
|
||||||
fsdp_model,
|
fsdp_model,
|
||||||
fsdp_optim,
|
fsdp_optim,
|
||||||
|
|
|
||||||
|
|
@ -966,7 +966,7 @@ class TestFSDPStateDict(FSDPTest):
|
||||||
setattr(module, LINEAR_SKIP, linear_skip)
|
setattr(module, LINEAR_SKIP, linear_skip)
|
||||||
return fsdp, linear_skip_tensor_names
|
return fsdp, linear_skip_tensor_names
|
||||||
|
|
||||||
fsdp, linear_skip_tensor_names = _create_module()
|
fsdp, _ = _create_module()
|
||||||
# Run a forward pass
|
# Run a forward pass
|
||||||
inp = torch.randn((1, 10), device=torch.cuda.current_device())
|
inp = torch.randn((1, 10), device=torch.cuda.current_device())
|
||||||
loss = fsdp(inp)
|
loss = fsdp(inp)
|
||||||
|
|
|
||||||
|
|
@ -634,7 +634,7 @@ class TestUnshardParams(TestUnshardParamsBase):
|
||||||
model = FSDP(model, auto_wrap_policy=ModuleWrapPolicy((nn.Sequential,)))
|
model = FSDP(model, auto_wrap_policy=ModuleWrapPolicy((nn.Sequential,)))
|
||||||
with FSDP.summon_full_params(model[0]):
|
with FSDP.summon_full_params(model[0]):
|
||||||
# Check that the summoned module does not have its flat parameter
|
# Check that the summoned module does not have its flat parameter
|
||||||
for param_name, param in model[0].named_parameters():
|
for param_name, _ in model[0].named_parameters():
|
||||||
self.assertFalse(FLAT_PARAM in param_name)
|
self.assertFalse(FLAT_PARAM in param_name)
|
||||||
self.assertGreater(len(list(model[0].parameters())), 1)
|
self.assertGreater(len(list(model[0].parameters())), 1)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -260,7 +260,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
|
||||||
model = FSDP(copy.deepcopy(base_model), self.process_group, **fsdp_kwargs)
|
model = FSDP(copy.deepcopy(base_model), self.process_group, **fsdp_kwargs)
|
||||||
model = torch.compile(model)
|
model = torch.compile(model)
|
||||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||||
for i in range(10):
|
for _ in range(10):
|
||||||
losses = []
|
losses = []
|
||||||
inp = ref_model.get_input(torch.device("cuda"))
|
inp = ref_model.get_input(torch.device("cuda"))
|
||||||
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
|
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
|
||||||
|
|
|
||||||
|
|
@ -118,7 +118,7 @@ class TestUtils(TestCase):
|
||||||
x.fill_(0)
|
x.fill_(0)
|
||||||
|
|
||||||
x = nn.utils.rnn.pack_padded_sequence(x, seq_length)
|
x = nn.utils.rnn.pack_padded_sequence(x, seq_length)
|
||||||
x, h = rnn(x)
|
x, _ = rnn(x)
|
||||||
x = _apply_to_tensors(fill_fn, x)
|
x = _apply_to_tensors(fill_fn, x)
|
||||||
x, _ = nn.utils.rnn.pad_packed_sequence(x)
|
x, _ = nn.utils.rnn.pad_packed_sequence(x)
|
||||||
self.assertEqual(torch.sum(x), 0)
|
self.assertEqual(torch.sum(x), 0)
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,6 @@ class LaunchTest(unittest.TestCase):
|
||||||
def test_launch_without_env(self):
|
def test_launch_without_env(self):
|
||||||
nnodes = 1
|
nnodes = 1
|
||||||
nproc_per_node = 4
|
nproc_per_node = 4
|
||||||
world_size = nnodes * nproc_per_node
|
|
||||||
sock = get_socket_with_port()
|
sock = get_socket_with_port()
|
||||||
with closing(sock):
|
with closing(sock):
|
||||||
master_port = sock.getsockname()[1]
|
master_port = sock.getsockname()[1]
|
||||||
|
|
|
||||||
|
|
@ -114,7 +114,7 @@ class CustomLinearDx(Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
input_val, weight, bias = ctx.saved_tensors
|
input_val, weight, _ = ctx.saved_tensors
|
||||||
grad_input = grad_output.mm(weight)
|
grad_input = grad_output.mm(weight)
|
||||||
ctx.module.cached_context[ctx.layer_idx].append(grad_output.clone())
|
ctx.module.cached_context[ctx.layer_idx].append(grad_output.clone())
|
||||||
ctx.module.cached_context[str(ctx.layer_idx) + "_input"].append(
|
ctx.module.cached_context[str(ctx.layer_idx) + "_input"].append(
|
||||||
|
|
@ -131,7 +131,7 @@ class CustomLinearDxDw(Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
input_val, weight, bias = ctx.saved_tensors
|
input_val, weight, _ = ctx.saved_tensors
|
||||||
grad_input = grad_output.mm(weight)
|
grad_input = grad_output.mm(weight)
|
||||||
grad_weight = grad_output.t().mm(input_val)
|
grad_weight = grad_output.t().mm(input_val)
|
||||||
grad_bias = grad_output.sum(0)
|
grad_bias = grad_output.sum(0)
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,7 @@ class StageBackwardTests(TestCase):
|
||||||
# Forward, then backward of loss with respect to inputs
|
# Forward, then backward of loss with respect to inputs
|
||||||
out = mod(x)
|
out = mod(x)
|
||||||
loss = loss_fn(out, target)
|
loss = loss_fn(out, target)
|
||||||
dinputs, param_groups = stage_backward_input(
|
dinputs, _param_groups = stage_backward_input(
|
||||||
stage_outputs_or_loss=(loss,),
|
stage_outputs_or_loss=(loss,),
|
||||||
output_grads=None,
|
output_grads=None,
|
||||||
input_values=[x],
|
input_values=[x],
|
||||||
|
|
@ -88,7 +88,7 @@ class StageBackwardTests(TestCase):
|
||||||
|
|
||||||
torch.testing.assert_close(x.grad, ref_x.grad)
|
torch.testing.assert_close(x.grad, ref_x.grad)
|
||||||
torch.testing.assert_close(dinputs[0], ref_x.grad)
|
torch.testing.assert_close(dinputs[0], ref_x.grad)
|
||||||
for name, p in mod.named_parameters():
|
for _, p in mod.named_parameters():
|
||||||
# Check that the weight gradients were not updated
|
# Check that the weight gradients were not updated
|
||||||
self.assertEqual(p.grad, None)
|
self.assertEqual(p.grad, None)
|
||||||
|
|
||||||
|
|
@ -109,7 +109,7 @@ class StageBackwardTests(TestCase):
|
||||||
# Forward, then backward of loss with respect to inputs
|
# Forward, then backward of loss with respect to inputs
|
||||||
out = mod(x)
|
out = mod(x)
|
||||||
loss = loss_fn(out, target)
|
loss = loss_fn(out, target)
|
||||||
dinputs, param_groups = stage_backward_input(
|
_dinputs, param_groups = stage_backward_input(
|
||||||
stage_outputs_or_loss=(loss,),
|
stage_outputs_or_loss=(loss,),
|
||||||
output_grads=None,
|
output_grads=None,
|
||||||
input_values=[x],
|
input_values=[x],
|
||||||
|
|
@ -157,7 +157,7 @@ class StageBackwardTests(TestCase):
|
||||||
for x in inputs:
|
for x in inputs:
|
||||||
out = mod(x)
|
out = mod(x)
|
||||||
loss = loss_fn(out, target)
|
loss = loss_fn(out, target)
|
||||||
dinputs, param_groups = stage_backward_input(
|
_dinputs, param_groups = stage_backward_input(
|
||||||
stage_outputs_or_loss=(loss,),
|
stage_outputs_or_loss=(loss,),
|
||||||
output_grads=None,
|
output_grads=None,
|
||||||
input_values=[x],
|
input_values=[x],
|
||||||
|
|
|
||||||
|
|
@ -264,7 +264,7 @@ class TestSchedulePlan(TestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
schedule = ScheduleClass(stages, num_microbatches)
|
schedule = ScheduleClass(stages, num_microbatches)
|
||||||
formatted_pipeline_order = _format_pipeline_order(
|
_formatted_pipeline_order = _format_pipeline_order(
|
||||||
schedule.pipeline_order
|
schedule.pipeline_order
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -305,10 +305,7 @@ class TestSchedulePlan(TestCase):
|
||||||
for i in range(num_local_stages)
|
for i in range(num_local_stages)
|
||||||
]
|
]
|
||||||
schedule = ScheduleClass(stages, num_microbatches)
|
schedule = ScheduleClass(stages, num_microbatches)
|
||||||
formatted_pipeline_order = _format_pipeline_order(
|
_format_pipeline_order(schedule.pipeline_order)
|
||||||
schedule.pipeline_order
|
|
||||||
)
|
|
||||||
# print(formatted_pipeline_order)
|
|
||||||
|
|
||||||
def stage_to_rank(stage):
|
def stage_to_rank(stage):
|
||||||
return stage % group_size
|
return stage % group_size
|
||||||
|
|
|
||||||
|
|
@ -151,7 +151,7 @@ class ScheduleTest(MultiProcContinousTest):
|
||||||
schedule.step(x)
|
schedule.step(x)
|
||||||
elif self.rank == self.world_size - 1:
|
elif self.rank == self.world_size - 1:
|
||||||
losses = []
|
losses = []
|
||||||
out = schedule.step(target=target, losses=losses)
|
schedule.step(target=target, losses=losses)
|
||||||
else:
|
else:
|
||||||
schedule.step()
|
schedule.step()
|
||||||
|
|
||||||
|
|
@ -412,7 +412,6 @@ class ScheduleTest(MultiProcContinousTest):
|
||||||
if hasattr(ScheduleClass, "num_microbatches")
|
if hasattr(ScheduleClass, "num_microbatches")
|
||||||
else 8
|
else 8
|
||||||
)
|
)
|
||||||
input_args = x.chunk(num_microbatches)[0]
|
|
||||||
stages = [
|
stages = [
|
||||||
PipelineStage(
|
PipelineStage(
|
||||||
stage_module,
|
stage_module,
|
||||||
|
|
@ -548,7 +547,6 @@ class ScheduleTest(MultiProcContinousTest):
|
||||||
loss_fn = torch.nn.MSELoss(reduction="sum")
|
loss_fn = torch.nn.MSELoss(reduction="sum")
|
||||||
|
|
||||||
# Create a pipeline stage to wrap that submodule
|
# Create a pipeline stage to wrap that submodule
|
||||||
input_args = x.chunk(num_microbatches)[0]
|
|
||||||
stage_indices = rank_stages[self.rank]
|
stage_indices = rank_stages[self.rank]
|
||||||
print(f"Rank {self.rank} stages: {stage_indices}")
|
print(f"Rank {self.rank} stages: {stage_indices}")
|
||||||
submod_names = [f"layers.{i}" for i in stage_indices]
|
submod_names = [f"layers.{i}" for i in stage_indices]
|
||||||
|
|
@ -582,7 +580,7 @@ class ScheduleTest(MultiProcContinousTest):
|
||||||
schedule.step(x)
|
schedule.step(x)
|
||||||
elif self.rank == self.world_size - 1:
|
elif self.rank == self.world_size - 1:
|
||||||
losses = []
|
losses = []
|
||||||
out = schedule.step(target=target, losses=losses)
|
schedule.step(target=target, losses=losses)
|
||||||
else:
|
else:
|
||||||
schedule.step()
|
schedule.step()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
|
@ -887,7 +885,6 @@ class ScheduleTest(MultiProcContinousTest):
|
||||||
|
|
||||||
# Create a pipeline stage to wrap that submodule
|
# Create a pipeline stage to wrap that submodule
|
||||||
chunks = 2
|
chunks = 2
|
||||||
input_args = x.chunk(chunks)[0]
|
|
||||||
stages = [
|
stages = [
|
||||||
PipelineStage(
|
PipelineStage(
|
||||||
stage_module,
|
stage_module,
|
||||||
|
|
|
||||||
|
|
@ -310,9 +310,6 @@ class StageTest(MultiProcContinousTest):
|
||||||
full_mod.to(self.device)
|
full_mod.to(self.device)
|
||||||
stage_mod = full_mod.get_submodule(f"layers.{self.rank}")
|
stage_mod = full_mod.get_submodule(f"layers.{self.rank}")
|
||||||
|
|
||||||
x = torch.randn(batch_size, d_hid, device=self.device)
|
|
||||||
target = torch.randn(batch_size, d_hid, device=self.device)
|
|
||||||
|
|
||||||
stage_with_dw_builder = PipelineStage(
|
stage_with_dw_builder = PipelineStage(
|
||||||
stage_mod,
|
stage_mod,
|
||||||
self.rank,
|
self.rank,
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,7 @@ class UnflattenTests(TestCase):
|
||||||
# Check qualnames
|
# Check qualnames
|
||||||
for stage_idx in range(pipe.num_stages):
|
for stage_idx in range(pipe.num_stages):
|
||||||
stage_mod = pipe.get_stage_module(stage_idx)
|
stage_mod = pipe.get_stage_module(stage_idx)
|
||||||
for param_name, param in stage_mod.named_parameters():
|
for param_name, _ in stage_mod.named_parameters():
|
||||||
assert (
|
assert (
|
||||||
param_name in orig_state_dict
|
param_name in orig_state_dict
|
||||||
), f"{param_name} not in original state dict"
|
), f"{param_name} not in original state dict"
|
||||||
|
|
|
||||||
|
|
@ -87,7 +87,9 @@ class MicroPipelineTPTest(TestCase):
|
||||||
a = all_gather_tensor(inp, gather_dim=0, group=group.group_name)
|
a = all_gather_tensor(inp, gather_dim=0, group=group.group_name)
|
||||||
b = all_gather_tensor(inp, gather_dim=1, group=group.group_name)
|
b = all_gather_tensor(inp, gather_dim=1, group=group.group_name)
|
||||||
c = _fp8_all_gather(inp, gather_dim=0, group_name=group.group_name)
|
c = _fp8_all_gather(inp, gather_dim=0, group_name=group.group_name)
|
||||||
d = _fp8_all_gather(inp, gather_dim=1, group_name=group.group_name)
|
d = _fp8_all_gather( # noqa: F841
|
||||||
|
inp, gather_dim=1, group_name=group.group_name
|
||||||
|
)
|
||||||
return a, b, c
|
return a, b, c
|
||||||
|
|
||||||
inp = torch.rand(64, 32, device="cuda")
|
inp = torch.rand(64, 32, device="cuda")
|
||||||
|
|
|
||||||
|
|
@ -311,7 +311,7 @@ class DistTensorParallelExampleTest(DTensorTestBase):
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
steps = 10 if type(model) is torch.float64 else 1
|
steps = 10 if type(model) is torch.float64 else 1
|
||||||
for iter in range(steps):
|
for _ in range(steps):
|
||||||
inp = torch.randint(
|
inp = torch.randint(
|
||||||
model_args.vocab_size, inp_size, device=self.device_type
|
model_args.vocab_size, inp_size, device=self.device_type
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -223,7 +223,7 @@ class TensorParallelStyleTest(DTensorTestBase):
|
||||||
AssertionError,
|
AssertionError,
|
||||||
"input_layouts and desired_input_layouts should have same length!",
|
"input_layouts and desired_input_layouts should have same length!",
|
||||||
):
|
):
|
||||||
prepare_inps_dimension_mismatch = PrepareModuleInput(
|
PrepareModuleInput(
|
||||||
input_layouts=Shard(0), desired_input_layouts=(Replicate(), None)
|
input_layouts=Shard(0), desired_input_layouts=(Replicate(), None)
|
||||||
)
|
)
|
||||||
# Raise assertion error if module inputs and input_layouts do not have same length.
|
# Raise assertion error if module inputs and input_layouts do not have same length.
|
||||||
|
|
|
||||||
|
|
@ -182,7 +182,7 @@ class TimeoutTest(TestCase):
|
||||||
threads.append(t)
|
threads.append(t)
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
for i, thread in enumerate(threads):
|
for _, thread in enumerate(threads):
|
||||||
thread.join()
|
thread.join()
|
||||||
|
|
||||||
# we expect the world_size-1 threads to have failed
|
# we expect the world_size-1 threads to have failed
|
||||||
|
|
@ -583,14 +583,14 @@ class CommonDistributedDataParallelTest:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
with err_ctx:
|
with err_ctx:
|
||||||
model = self._test_ddp_checkpointing(
|
self._test_ddp_checkpointing(
|
||||||
self.CheckpointOnceModule(use_reentrant=use_reentrant),
|
self.CheckpointOnceModule(use_reentrant=use_reentrant),
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
use_bucket_view=use_bucket_view,
|
use_bucket_view=use_bucket_view,
|
||||||
find_unused_parameters=True,
|
find_unused_parameters=True,
|
||||||
)
|
)
|
||||||
# test passes when static_graph is true
|
# test passes when static_graph is true
|
||||||
model = self._test_ddp_checkpointing(
|
self._test_ddp_checkpointing(
|
||||||
self.CheckpointOnceModule(use_reentrant=use_reentrant),
|
self.CheckpointOnceModule(use_reentrant=use_reentrant),
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
use_bucket_view=use_bucket_view,
|
use_bucket_view=use_bucket_view,
|
||||||
|
|
@ -615,7 +615,7 @@ class CommonDistributedDataParallelTest:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
with err_ctx:
|
with err_ctx:
|
||||||
model = self._test_ddp_checkpointing(
|
self._test_ddp_checkpointing(
|
||||||
self.CheckpointTwiceModule(use_reentrant=use_reentrant),
|
self.CheckpointTwiceModule(use_reentrant=use_reentrant),
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
use_bucket_view=use_bucket_view,
|
use_bucket_view=use_bucket_view,
|
||||||
|
|
@ -623,7 +623,7 @@ class CommonDistributedDataParallelTest:
|
||||||
)
|
)
|
||||||
|
|
||||||
with err_ctx:
|
with err_ctx:
|
||||||
model = self._test_ddp_checkpointing(
|
self._test_ddp_checkpointing(
|
||||||
self.CheckpointTwiceModule(use_reentrant=use_reentrant),
|
self.CheckpointTwiceModule(use_reentrant=use_reentrant),
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
use_bucket_view=use_bucket_view,
|
use_bucket_view=use_bucket_view,
|
||||||
|
|
@ -641,7 +641,7 @@ class CommonDistributedDataParallelTest:
|
||||||
process_group = self._get_process_group()
|
process_group = self._get_process_group()
|
||||||
for use_bucket_view in (True, False):
|
for use_bucket_view in (True, False):
|
||||||
# Test passes when static_graph=True.
|
# Test passes when static_graph=True.
|
||||||
model = self._test_ddp_checkpointing(
|
self._test_ddp_checkpointing(
|
||||||
self.CheckpointTwiceModule(use_reentrant=use_reentrant),
|
self.CheckpointTwiceModule(use_reentrant=use_reentrant),
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
use_bucket_view=use_bucket_view,
|
use_bucket_view=use_bucket_view,
|
||||||
|
|
@ -656,7 +656,7 @@ class CommonDistributedDataParallelTest:
|
||||||
"""
|
"""
|
||||||
process_group = self._get_process_group()
|
process_group = self._get_process_group()
|
||||||
for use_bucket_view in (True, False):
|
for use_bucket_view in (True, False):
|
||||||
model = self._test_ddp_checkpointing(
|
self._test_ddp_checkpointing(
|
||||||
self.DynamicCheckpointTwiceModule(use_reentrant=False),
|
self.DynamicCheckpointTwiceModule(use_reentrant=False),
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
use_bucket_view=use_bucket_view,
|
use_bucket_view=use_bucket_view,
|
||||||
|
|
@ -675,7 +675,7 @@ class CommonDistributedDataParallelTest:
|
||||||
"""
|
"""
|
||||||
process_group = self._get_process_group()
|
process_group = self._get_process_group()
|
||||||
for use_bucket_view in (True, False):
|
for use_bucket_view in (True, False):
|
||||||
model = self._test_ddp_checkpointing(
|
self._test_ddp_checkpointing(
|
||||||
self.DynamicCheckpointTwiceModuleWeightSharing(use_reentrant=False),
|
self.DynamicCheckpointTwiceModuleWeightSharing(use_reentrant=False),
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
use_bucket_view=use_bucket_view,
|
use_bucket_view=use_bucket_view,
|
||||||
|
|
@ -719,7 +719,7 @@ class CommonDistributedDataParallelTest:
|
||||||
process_group = self._get_process_group()
|
process_group = self._get_process_group()
|
||||||
torch.cuda.set_device(self.rank)
|
torch.cuda.set_device(self.rank)
|
||||||
for use_bucket_view in (True, False):
|
for use_bucket_view in (True, False):
|
||||||
model = self._test_ddp_checkpointing(
|
self._test_ddp_checkpointing(
|
||||||
self.CheckpointTwiceModuleWeightSharing(),
|
self.CheckpointTwiceModuleWeightSharing(),
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
use_bucket_view=use_bucket_view,
|
use_bucket_view=use_bucket_view,
|
||||||
|
|
@ -737,7 +737,7 @@ class CommonDistributedDataParallelTest:
|
||||||
"Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
|
"Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
|
||||||
"because PowerSGD can only be applied after the first two iterations in DDP.",
|
"because PowerSGD can only be applied after the first two iterations in DDP.",
|
||||||
):
|
):
|
||||||
state = powerSGD.PowerSGDState(
|
powerSGD.PowerSGDState(
|
||||||
process_group=None,
|
process_group=None,
|
||||||
matrix_approximation_rank=1,
|
matrix_approximation_rank=1,
|
||||||
start_powerSGD_iter=start_powerSGD_iter,
|
start_powerSGD_iter=start_powerSGD_iter,
|
||||||
|
|
|
||||||
|
|
@ -429,7 +429,7 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||||
|
|
||||||
input = torch.full((10, 10), float(self.rank), device=self.device)
|
input = torch.full((10, 10), float(self.rank), device=self.device)
|
||||||
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
|
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
|
||||||
output = torch.ops._c10d_functional.all_reduce(
|
torch.ops._c10d_functional.all_reduce(
|
||||||
input,
|
input,
|
||||||
"avg",
|
"avg",
|
||||||
"default",
|
"default",
|
||||||
|
|
@ -550,7 +550,7 @@ class CompileTest(TestCase):
|
||||||
assert "= torch.ops._c10d_functional.wait_tensor.default" not in code
|
assert "= torch.ops._c10d_functional.wait_tensor.default" not in code
|
||||||
|
|
||||||
# Test aoti
|
# Test aoti
|
||||||
out = AOTIRunnerUtil.run("cuda", func, (arg,))
|
AOTIRunnerUtil.run("cuda", func, (arg,))
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||||
|
|
@ -596,7 +596,7 @@ class CompileTest(TestCase):
|
||||||
assert "= torch.ops._c10d_functional.wait_tensor.default" not in code
|
assert "= torch.ops._c10d_functional.wait_tensor.default" not in code
|
||||||
|
|
||||||
# Test aoti
|
# Test aoti
|
||||||
out = AOTIRunnerUtil.run("cuda", func, (args,))
|
out = AOTIRunnerUtil.run("cuda", func, (args,)) # noqa: F841
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||||
|
|
@ -708,7 +708,7 @@ class CompileTest(TestCase):
|
||||||
assert "= torch.ops._c10d_functional.wait_tensor.default" not in code
|
assert "= torch.ops._c10d_functional.wait_tensor.default" not in code
|
||||||
|
|
||||||
# Test aoti
|
# Test aoti
|
||||||
out = AOTIRunnerUtil.run("cuda", func, (arg,))
|
AOTIRunnerUtil.run("cuda", func, (arg,))
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||||
|
|
@ -742,7 +742,7 @@ class CompileTest(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test aoti
|
# Test aoti
|
||||||
out = AOTIRunnerUtil.run("cuda", func, (args,))
|
out = AOTIRunnerUtil.run("cuda", func, (args,)) # noqa: F841
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
@unittest.skipIf(not HAS_GPU, "This is a GPU test!")
|
@unittest.skipIf(not HAS_GPU, "This is a GPU test!")
|
||||||
|
|
@ -764,7 +764,7 @@ class CompileTest(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test aoti
|
# Test aoti
|
||||||
out = AOTIRunnerUtil.run("cuda", func, (arg,))
|
AOTIRunnerUtil.run("cuda", func, (arg,))
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||||
|
|
@ -790,7 +790,7 @@ class CompileTest(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test aoti
|
# Test aoti
|
||||||
out = AOTIRunnerUtil.run("cuda", func, (arg,))
|
AOTIRunnerUtil.run("cuda", func, (arg,))
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||||
|
|
@ -910,7 +910,7 @@ class CompileTest(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test aoti
|
# Test aoti
|
||||||
out = AOTIRunnerUtil.run("cuda", func, (arg,))
|
AOTIRunnerUtil.run("cuda", func, (arg,))
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||||
|
|
|
||||||
|
|
@ -1920,7 +1920,7 @@ class DistributedDataParallelTest(
|
||||||
ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)
|
ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
|
||||||
for model in [ddp_withload, model_withload]:
|
for model in [ddp_withload, model_withload]:
|
||||||
for p in ddp_withload.parameters():
|
for p in model.parameters():
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
p.zero_()
|
p.zero_()
|
||||||
ddp_withload.load_state_dict(ddp_state_dict)
|
ddp_withload.load_state_dict(ddp_state_dict)
|
||||||
|
|
@ -1973,7 +1973,8 @@ class DistributedDataParallelTest(
|
||||||
This unit test verifies whether the Future object is passed properly.
|
This unit test verifies whether the Future object is passed properly.
|
||||||
The callback function creates a Future object and sets a value to it.
|
The callback function creates a Future object and sets a value to it.
|
||||||
"""
|
"""
|
||||||
store = c10d.FileStore(self.file_name, self.world_size)
|
store = c10d.FileStore(self.file_name, self.world_size) # noqa: F841
|
||||||
|
|
||||||
process_group = self._get_process_group()
|
process_group = self._get_process_group()
|
||||||
|
|
||||||
# Test on CPU
|
# Test on CPU
|
||||||
|
|
|
||||||
|
|
@ -366,7 +366,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
# We would get stuck here due to d2h if we didn't abort.
|
# We would get stuck here due to d2h if we didn't abort.
|
||||||
t_cpu = t.cpu()
|
t.cpu()
|
||||||
|
|
||||||
thread.join()
|
thread.join()
|
||||||
|
|
||||||
|
|
@ -741,7 +741,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||||
# First allreduce to initialize default PG's communicator.
|
# First allreduce to initialize default PG's communicator.
|
||||||
pg.allreduce(t).wait()
|
pg.allreduce(t).wait()
|
||||||
# PG1 is an PG without comms initialized, since we don't call collective on it
|
# PG1 is an PG without comms initialized, since we don't call collective on it
|
||||||
new_pg1 = c10d.new_group([0, 1])
|
new_pg1 = c10d.new_group([0, 1]) # noqa: F841
|
||||||
new_pg2 = c10d.new_group([0, 1])
|
new_pg2 = c10d.new_group([0, 1])
|
||||||
t2 = torch.rand(10, 10, device=device)
|
t2 = torch.rand(10, 10, device=device)
|
||||||
|
|
||||||
|
|
@ -807,7 +807,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||||
# 'timeout' kwarg (or its kwdefault) taking precedence
|
# 'timeout' kwarg (or its kwdefault) taking precedence
|
||||||
opts = dist.ProcessGroupNCCL.Options()
|
opts = dist.ProcessGroupNCCL.Options()
|
||||||
opts._timeout = timedelta(seconds=123)
|
opts._timeout = timedelta(seconds=123)
|
||||||
with warnings.catch_warnings(record=True) as w:
|
with warnings.catch_warnings(record=True):
|
||||||
dist.init_process_group(**base_opts, pg_options=opts)
|
dist.init_process_group(**base_opts, pg_options=opts)
|
||||||
# TODO(whc) i verified that we are indeed emitting this warning, and i can't figure out why i can't catch it.
|
# TODO(whc) i verified that we are indeed emitting this warning, and i can't figure out why i can't catch it.
|
||||||
# self.assertEqual(len(w), 1)
|
# self.assertEqual(len(w), 1)
|
||||||
|
|
@ -1266,30 +1266,26 @@ class DistributedDataParallelTest(
|
||||||
"DistributedDataParallel device_ids and output_device arguments only work with "
|
"DistributedDataParallel device_ids and output_device arguments only work with "
|
||||||
"single-device/multiple-device GPU modules or CPU modules",
|
"single-device/multiple-device GPU modules or CPU modules",
|
||||||
):
|
):
|
||||||
ddp_model = DistributedDataParallel(
|
DistributedDataParallel(
|
||||||
model, output_device=gpus[1], process_group=process_group
|
model, output_device=gpus[1], process_group=process_group
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, "device_ids can only be None or contain a single element."
|
ValueError, "device_ids can only be None or contain a single element."
|
||||||
):
|
):
|
||||||
ddp_model = DistributedDataParallel(
|
DistributedDataParallel(model, device_ids=gpus, process_group=process_group)
|
||||||
model, device_ids=gpus, process_group=process_group
|
|
||||||
)
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, "input module must be on the same type of devices"
|
ValueError, "input module must be on the same type of devices"
|
||||||
):
|
):
|
||||||
model.fc1 = model.fc1.cpu()
|
model.fc1 = model.fc1.cpu()
|
||||||
ddp_model = DistributedDataParallel(model, process_group=process_group)
|
DistributedDataParallel(model, process_group=process_group)
|
||||||
|
|
||||||
model = model.cpu()
|
model = model.cpu()
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, "device_ids can only be None or contain a single element."
|
ValueError, "device_ids can only be None or contain a single element."
|
||||||
):
|
):
|
||||||
ddp_model = DistributedDataParallel(
|
DistributedDataParallel(model, device_ids=gpus, process_group=process_group)
|
||||||
model, device_ids=gpus, process_group=process_group
|
|
||||||
)
|
|
||||||
|
|
||||||
def _test_fp16(self, gradient_as_bucket_view=False):
|
def _test_fp16(self, gradient_as_bucket_view=False):
|
||||||
process_group = self._get_process_group()
|
process_group = self._get_process_group()
|
||||||
|
|
@ -1940,11 +1936,9 @@ class DistributedDataParallelTest(
|
||||||
),
|
),
|
||||||
named_msg,
|
named_msg,
|
||||||
)
|
)
|
||||||
for j, ((param_name, p), p_ddp) in enumerate(
|
for (param_name, p), p_ddp in zip(
|
||||||
zip(
|
m_child.named_parameters(),
|
||||||
m_child.named_parameters(),
|
m_ddp_child.parameters(),
|
||||||
m_ddp_child.parameters(),
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
named_msg = (
|
named_msg = (
|
||||||
layer_name + "." + param_name + " " + iter_msg
|
layer_name + "." + param_name + " " + iter_msg
|
||||||
|
|
@ -2010,15 +2004,13 @@ class DistributedDataParallelTest(
|
||||||
|
|
||||||
m = ConvNet(layer_devs, layer_formats, layer_dtypes)
|
m = ConvNet(layer_devs, layer_formats, layer_dtypes)
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
m_ddp = DistributedDataParallel(
|
DistributedDataParallel(m, device_ids=[dev0], process_group=process_group)
|
||||||
m, device_ids=[dev0], process_group=process_group
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
".* appears not to match strides of the same param in process 0",
|
".* appears not to match strides of the same param in process 0",
|
||||||
):
|
):
|
||||||
m_ddp = DistributedDataParallel(
|
DistributedDataParallel(
|
||||||
m, device_ids=[dev0], process_group=process_group
|
m, device_ids=[dev0], process_group=process_group
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2356,7 +2348,7 @@ class DistributedDataParallelTest(
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in range(3):
|
for _ in range(3):
|
||||||
m.zero_grad(set_to_none=try_set_to_none)
|
m.zero_grad(set_to_none=try_set_to_none)
|
||||||
m(1).sum().backward()
|
m(1).sum().backward()
|
||||||
|
|
||||||
|
|
@ -2701,7 +2693,7 @@ class WorkHookTest(MultiProcessTestCase):
|
||||||
pg._register_on_completion_hook(hook)
|
pg._register_on_completion_hook(hook)
|
||||||
tensor = torch.ones([2, 3]).cuda(self.rank) * self.rank
|
tensor = torch.ones([2, 3]).cuda(self.rank) * self.rank
|
||||||
work_count = 3
|
work_count = 3
|
||||||
for i in range(work_count):
|
for _ in range(work_count):
|
||||||
work += 1
|
work += 1
|
||||||
pg.broadcast([tensor]).wait()
|
pg.broadcast([tensor]).wait()
|
||||||
|
|
||||||
|
|
@ -2806,7 +2798,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
|
||||||
# Run some GPU operations to make sure cuda has not gotten stuck.
|
# Run some GPU operations to make sure cuda has not gotten stuck.
|
||||||
# It was observed cuda could get stuck if NCCL communicators were
|
# It was observed cuda could get stuck if NCCL communicators were
|
||||||
# not properly aborted before throwing RuntimeError.
|
# not properly aborted before throwing RuntimeError.
|
||||||
a = torch.rand(10).cuda(self.rank)
|
torch.rand(10).cuda(self.rank)
|
||||||
elif self.rank == 1:
|
elif self.rank == 1:
|
||||||
# Clean up structures (ex: files for FileStore before going down)
|
# Clean up structures (ex: files for FileStore before going down)
|
||||||
del process_group
|
del process_group
|
||||||
|
|
@ -2947,7 +2939,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
|
||||||
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val
|
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val
|
||||||
store = c10d.FileStore(self.file_name, self.world_size)
|
store = c10d.FileStore(self.file_name, self.world_size)
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
|
c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
|
||||||
|
|
||||||
@requires_nccl()
|
@requires_nccl()
|
||||||
@skip_if_lt_x_gpu(3)
|
@skip_if_lt_x_gpu(3)
|
||||||
|
|
@ -4223,7 +4215,7 @@ class NCCLTraceTestBase(MultiProcessTestCase):
|
||||||
def _join_processes(self, fn):
|
def _join_processes(self, fn):
|
||||||
# We need to patch sys.exit() as skip_if will use sys.exit() and
|
# We need to patch sys.exit() as skip_if will use sys.exit() and
|
||||||
# the exit code from the this process will not be catched.
|
# the exit code from the this process will not be catched.
|
||||||
with mock.patch("sys.exit") as exit_mock:
|
with mock.patch("sys.exit"):
|
||||||
fn()
|
fn()
|
||||||
super()._join_processes(fn)
|
super()._join_processes(fn)
|
||||||
|
|
||||||
|
|
@ -4231,7 +4223,7 @@ class NCCLTraceTestBase(MultiProcessTestCase):
|
||||||
proc = torch.multiprocessing.get_context("spawn").Process
|
proc = torch.multiprocessing.get_context("spawn").Process
|
||||||
self.children_pipes = []
|
self.children_pipes = []
|
||||||
parent_pipes = []
|
parent_pipes = []
|
||||||
for i in range(self.world_size):
|
for _ in range(self.world_size):
|
||||||
parent_conn, child_conn = torch.multiprocessing.Pipe()
|
parent_conn, child_conn = torch.multiprocessing.Pipe()
|
||||||
self.children_pipes.append(child_conn)
|
self.children_pipes.append(child_conn)
|
||||||
parent_pipes.append(parent_conn)
|
parent_pipes.append(parent_conn)
|
||||||
|
|
@ -4346,7 +4338,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||||
pg._enable_collectives_timing()
|
pg._enable_collectives_timing()
|
||||||
device = self.local_device
|
device = self.local_device
|
||||||
a = torch.full((3, 4), float(self.rank), device=device)
|
a = torch.full((3, 4), float(self.rank), device=device)
|
||||||
for i in range(2):
|
for _ in range(2):
|
||||||
f = pg.allreduce(a)
|
f = pg.allreduce(a)
|
||||||
f.wait()
|
f.wait()
|
||||||
torch.cuda.synchronize(device=device)
|
torch.cuda.synchronize(device=device)
|
||||||
|
|
@ -4372,7 +4364,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||||
pg._enable_collectives_timing()
|
pg._enable_collectives_timing()
|
||||||
device = self.local_device
|
device = self.local_device
|
||||||
a = torch.full((3, 4), float(self.rank), device=device)
|
a = torch.full((3, 4), float(self.rank), device=device)
|
||||||
for i in range(2):
|
for _ in range(2):
|
||||||
f = pg.allreduce(a)
|
f = pg.allreduce(a)
|
||||||
f.wait()
|
f.wait()
|
||||||
torch.cuda.synchronize(device=device)
|
torch.cuda.synchronize(device=device)
|
||||||
|
|
@ -4420,7 +4412,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||||
pg = self._create_process_group_nccl()
|
pg = self._create_process_group_nccl()
|
||||||
device = self.local_device
|
device = self.local_device
|
||||||
a = torch.full((3, 4), float(self.rank), device=device)
|
a = torch.full((3, 4), float(self.rank), device=device)
|
||||||
for i in range(2):
|
for _ in range(2):
|
||||||
f = pg.allreduce(a)
|
f = pg.allreduce(a)
|
||||||
f.wait()
|
f.wait()
|
||||||
torch.cuda.synchronize(device=device)
|
torch.cuda.synchronize(device=device)
|
||||||
|
|
@ -4436,7 +4428,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||||
pg = self._create_process_group_nccl()
|
pg = self._create_process_group_nccl()
|
||||||
device = self.local_device
|
device = self.local_device
|
||||||
a = torch.full((3, 4), float(self.rank), device=device)
|
a = torch.full((3, 4), float(self.rank), device=device)
|
||||||
for i in range(2):
|
for _ in range(2):
|
||||||
# test some other primitives to make sure
|
# test some other primitives to make sure
|
||||||
# their strings are valid
|
# their strings are valid
|
||||||
xs = [torch.ones(3, 4, device=device)]
|
xs = [torch.ones(3, 4, device=device)]
|
||||||
|
|
@ -4496,7 +4488,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||||
pg = self._create_process_group_nccl()
|
pg = self._create_process_group_nccl()
|
||||||
device = self.local_device
|
device = self.local_device
|
||||||
# send more works than the buffer size to overwrite the previous entry
|
# send more works than the buffer size to overwrite the previous entry
|
||||||
for i in range(12):
|
for _ in range(12):
|
||||||
a = [torch.ones(3, 4, device=device)]
|
a = [torch.ones(3, 4, device=device)]
|
||||||
pg.broadcast(a).wait()
|
pg.broadcast(a).wait()
|
||||||
torch.cuda.synchronize(device=device)
|
torch.cuda.synchronize(device=device)
|
||||||
|
|
@ -4611,7 +4603,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||||
th.start()
|
th.start()
|
||||||
# fill the cuda buffer, at around 1024 events
|
# fill the cuda buffer, at around 1024 events
|
||||||
# this will stall
|
# this will stall
|
||||||
for i in range(2000):
|
for _ in range(2000):
|
||||||
a = a + a
|
a = a + a
|
||||||
th.join()
|
th.join()
|
||||||
else:
|
else:
|
||||||
|
|
@ -4646,7 +4638,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||||
|
|
||||||
num_coalesced_ops = 20
|
num_coalesced_ops = 20
|
||||||
ops_per_coalesce = len(op_sizes_per_coalesce)
|
ops_per_coalesce = len(op_sizes_per_coalesce)
|
||||||
for i in range(num_coalesced_ops):
|
for _ in range(num_coalesced_ops):
|
||||||
ops = []
|
ops = []
|
||||||
for input_sizes in op_sizes_per_coalesce:
|
for input_sizes in op_sizes_per_coalesce:
|
||||||
tensor = torch.zeros(input_sizes).to(self.local_device)
|
tensor = torch.zeros(input_sizes).to(self.local_device)
|
||||||
|
|
@ -4745,7 +4737,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||||
pg._enable_collectives_timing()
|
pg._enable_collectives_timing()
|
||||||
num_repeats = 10
|
num_repeats = 10
|
||||||
ops_per_repeat = len(op_sizes)
|
ops_per_repeat = len(op_sizes)
|
||||||
for i in range(num_repeats):
|
for _ in range(num_repeats):
|
||||||
for input_sizes in op_sizes:
|
for input_sizes in op_sizes:
|
||||||
tensor = torch.zeros(input_sizes).to(self.local_device)
|
tensor = torch.zeros(input_sizes).to(self.local_device)
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
|
|
@ -5047,7 +5039,7 @@ class NcclErrorDumpTest(NCCLTraceTestBase):
|
||||||
# Block the current stream on the NCCL stream
|
# Block the current stream on the NCCL stream
|
||||||
work.wait()
|
work.wait()
|
||||||
# Run some GPU operations
|
# Run some GPU operations
|
||||||
a = torch.rand(10).cuda(self.rank)
|
torch.rand(10).cuda(self.rank)
|
||||||
elif self.rank == 1:
|
elif self.rank == 1:
|
||||||
# Clean up structures (ex: files for FileStore before going down)
|
# Clean up structures (ex: files for FileStore before going down)
|
||||||
del process_group
|
del process_group
|
||||||
|
|
@ -5108,7 +5100,6 @@ class ProcessGroupNCCLLargerScaleTest(MultiProcessTestCase):
|
||||||
|
|
||||||
tensor = torch.full((1,), self.rank).cuda(device)
|
tensor = torch.full((1,), self.rank).cuda(device)
|
||||||
ng1 = c10d.split_group(pg, [[0, 1], [2, 3, 4, 5, 6, 7]])
|
ng1 = c10d.split_group(pg, [[0, 1], [2, 3, 4, 5, 6, 7]])
|
||||||
backend1 = ng1._get_backend(torch.device(device))
|
|
||||||
|
|
||||||
# comm split happens eagerly since device_id is passed to init_process_group.
|
# comm split happens eagerly since device_id is passed to init_process_group.
|
||||||
self.assertEqual(backend.comm_split_count(), 1)
|
self.assertEqual(backend.comm_split_count(), 1)
|
||||||
|
|
|
||||||
|
|
@ -162,7 +162,6 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
|
||||||
@requires_nccl()
|
@requires_nccl()
|
||||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||||
def test_allreduce_ops(self):
|
def test_allreduce_ops(self):
|
||||||
device_count = torch.cuda.device_count()
|
|
||||||
pg = self.pg
|
pg = self.pg
|
||||||
local_device_id = self.rank_to_GPU[self.rank][0]
|
local_device_id = self.rank_to_GPU[self.rank][0]
|
||||||
|
|
||||||
|
|
@ -303,9 +302,8 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
|
||||||
pg = self.pg
|
pg = self.pg
|
||||||
rank = self.rank_to_GPU[self.rank][0]
|
rank = self.rank_to_GPU[self.rank][0]
|
||||||
with torch.cuda.device(rank):
|
with torch.cuda.device(rank):
|
||||||
for i in range(10):
|
for _ in range(10):
|
||||||
xs = [torch.FloatTensor([1]).cuda(rank)]
|
xs = [torch.FloatTensor([1]).cuda(rank)]
|
||||||
ys = [torch.FloatTensor([4]).cuda(rank)]
|
|
||||||
for _ in range(30):
|
for _ in range(30):
|
||||||
pg.allreduce(xs[0]).wait()
|
pg.allreduce(xs[0]).wait()
|
||||||
|
|
||||||
|
|
@ -410,7 +408,7 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
|
||||||
output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu])
|
output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu])
|
||||||
expected_output.append([t.cuda(device=gpu) for t in expected_per_gpu])
|
expected_output.append([t.cuda(device=gpu) for t in expected_per_gpu])
|
||||||
|
|
||||||
result = allgather(output_tensors, tensors)
|
allgather(output_tensors, tensors)
|
||||||
|
|
||||||
# Verification
|
# Verification
|
||||||
self.assertEqual(output_tensors, expected_output)
|
self.assertEqual(output_tensors, expected_output)
|
||||||
|
|
@ -558,7 +556,7 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
|
||||||
|
|
||||||
# init output
|
# init output
|
||||||
output_ts = []
|
output_ts = []
|
||||||
for rank in range(self.world_size):
|
for _ in range(self.world_size):
|
||||||
output_ts.append(torch.tensor([-1]).cuda(device_id))
|
output_ts.append(torch.tensor([-1]).cuda(device_id))
|
||||||
|
|
||||||
with self.assertRaisesRegex(ValueError, "invalid root rank"):
|
with self.assertRaisesRegex(ValueError, "invalid root rank"):
|
||||||
|
|
@ -914,7 +912,6 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
|
||||||
@requires_nccl()
|
@requires_nccl()
|
||||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||||
def test_send_recv(self):
|
def test_send_recv(self):
|
||||||
pg = self.pg
|
|
||||||
device = self.rank_to_GPU[self.rank][0]
|
device = self.rank_to_GPU[self.rank][0]
|
||||||
|
|
||||||
# Generate the same random tensor
|
# Generate the same random tensor
|
||||||
|
|
@ -930,7 +927,6 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
|
||||||
@requires_nccl()
|
@requires_nccl()
|
||||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||||
def test_send_recv_complex(self):
|
def test_send_recv_complex(self):
|
||||||
pg = self.pg
|
|
||||||
device = self.rank_to_GPU[self.rank][0]
|
device = self.rank_to_GPU[self.rank][0]
|
||||||
|
|
||||||
# Generate the same random tensor
|
# Generate the same random tensor
|
||||||
|
|
|
||||||
|
|
@ -755,7 +755,7 @@ class DistributedDataParallelTest(
|
||||||
ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)
|
ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
|
||||||
for model in [ddp_withload, model_withload]:
|
for model in [ddp_withload, model_withload]:
|
||||||
for p in ddp_withload.parameters():
|
for p in model.parameters():
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
p.zero_()
|
p.zero_()
|
||||||
ddp_withload.load_state_dict(ddp_state_dict)
|
ddp_withload.load_state_dict(ddp_state_dict)
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ class TestCollectiveUtils(MultiProcessTestCase):
|
||||||
Ensure broadcast has no dependency on torch.distributed when run in single process.
|
Ensure broadcast has no dependency on torch.distributed when run in single process.
|
||||||
"""
|
"""
|
||||||
func = mock.MagicMock()
|
func = mock.MagicMock()
|
||||||
res = broadcast(data_or_fn=func, rank=0)
|
broadcast(data_or_fn=func, rank=0)
|
||||||
func.assert_called_once()
|
func.assert_called_once()
|
||||||
|
|
||||||
def test_broadcast_result_raises_exceptions_from_func(
|
def test_broadcast_result_raises_exceptions_from_func(
|
||||||
|
|
@ -98,7 +98,7 @@ class TestCollectiveUtils(MultiProcessTestCase):
|
||||||
Ensure all_gather has no dependency on torch.distributed when run in single process.
|
Ensure all_gather has no dependency on torch.distributed when run in single process.
|
||||||
"""
|
"""
|
||||||
func = mock.MagicMock()
|
func = mock.MagicMock()
|
||||||
res = all_gather(data_or_fn=func)
|
all_gather(data_or_fn=func)
|
||||||
func.assert_called_once()
|
func.assert_called_once()
|
||||||
|
|
||||||
def test_all_gather_result_raises_exceptions_from_func(
|
def test_all_gather_result_raises_exceptions_from_func(
|
||||||
|
|
|
||||||
|
|
@ -791,8 +791,8 @@ class TestDataParallel(TestCase):
|
||||||
),
|
),
|
||||||
named_msg,
|
named_msg,
|
||||||
)
|
)
|
||||||
for j, ((param_name, p), p_dp) in enumerate(
|
for (param_name, p), p_dp in zip(
|
||||||
zip(m_child.named_parameters(), m_dp_child.parameters())
|
m_child.named_parameters(), m_dp_child.parameters()
|
||||||
):
|
):
|
||||||
named_msg = (
|
named_msg = (
|
||||||
layer_name + "." + param_name + " " + iter_msg
|
layer_name + "." + param_name + " " + iter_msg
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,7 @@ class DeviceMeshTest(DTensorTestBase):
|
||||||
def test_assert_invalid_mesh_tensor(self):
|
def test_assert_invalid_mesh_tensor(self):
|
||||||
mesh = torch.arange(self.world_size).to(self.rank)
|
mesh = torch.arange(self.world_size).to(self.rank)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
device_mesh = DeviceMesh(self.device_type, mesh)
|
DeviceMesh(self.device_type, mesh)
|
||||||
|
|
||||||
@with_comms()
|
@with_comms()
|
||||||
def test_2d_mesh_non_eager_init_subgroup(self):
|
def test_2d_mesh_non_eager_init_subgroup(self):
|
||||||
|
|
@ -144,7 +144,7 @@ class DeviceMeshTest(DTensorTestBase):
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
|
"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
|
||||||
):
|
):
|
||||||
local_rank = mesh_2d.get_local_rank()
|
mesh_2d.get_local_rank()
|
||||||
|
|
||||||
@with_comms
|
@with_comms
|
||||||
def test_get_local_rank(self):
|
def test_get_local_rank(self):
|
||||||
|
|
@ -258,7 +258,7 @@ class DeviceMeshTest(DTensorTestBase):
|
||||||
):
|
):
|
||||||
# test init_device_mesh with an invalid device type that contains a GPU index
|
# test init_device_mesh with an invalid device type that contains a GPU index
|
||||||
mesh_shape = (2, self.world_size // 2)
|
mesh_shape = (2, self.world_size // 2)
|
||||||
mesh_2d = init_device_mesh(
|
init_device_mesh(
|
||||||
"cuda:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
|
"cuda:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -453,7 +453,7 @@ class InitDeviceMeshTest(DTensorTestBase):
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
"Each mesh_dim_name must be unique.",
|
"Each mesh_dim_name must be unique.",
|
||||||
):
|
):
|
||||||
mesh = init_device_mesh(
|
init_device_mesh(
|
||||||
self.device_type,
|
self.device_type,
|
||||||
(2, 4),
|
(2, 4),
|
||||||
mesh_dim_names=["dp", "dp"],
|
mesh_dim_names=["dp", "dp"],
|
||||||
|
|
@ -465,7 +465,7 @@ class InitDeviceMeshTest(DTensorTestBase):
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
"mesh_shape and mesh_dim_names should have same length!",
|
"mesh_shape and mesh_dim_names should have same length!",
|
||||||
):
|
):
|
||||||
mesh = init_device_mesh(
|
init_device_mesh(
|
||||||
self.device_type,
|
self.device_type,
|
||||||
(8,),
|
(8,),
|
||||||
mesh_dim_names=["dp", "tp"],
|
mesh_dim_names=["dp", "tp"],
|
||||||
|
|
@ -483,7 +483,7 @@ class TestDeviceMeshGetItem(DTensorTestBase):
|
||||||
RuntimeError, "Cannot slice a DeviceMesh without mesh_dim_names!"
|
RuntimeError, "Cannot slice a DeviceMesh without mesh_dim_names!"
|
||||||
):
|
):
|
||||||
mesh = init_device_mesh(self.device_type, (2, 4))
|
mesh = init_device_mesh(self.device_type, (2, 4))
|
||||||
child_mesh = mesh["DP"]
|
mesh["DP"]
|
||||||
|
|
||||||
@with_comms
|
@with_comms
|
||||||
def test_raises_invalid_mesh_dim_name(self):
|
def test_raises_invalid_mesh_dim_name(self):
|
||||||
|
|
@ -493,7 +493,7 @@ class TestDeviceMeshGetItem(DTensorTestBase):
|
||||||
mesh = init_device_mesh(
|
mesh = init_device_mesh(
|
||||||
self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
|
self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
|
||||||
)
|
)
|
||||||
child_mesh = mesh[child_mesh_dim_name]
|
mesh[child_mesh_dim_name]
|
||||||
|
|
||||||
@with_comms
|
@with_comms
|
||||||
def test_get_item_2d(self):
|
def test_get_item_2d(self):
|
||||||
|
|
@ -514,7 +514,6 @@ class TestDeviceMeshGetItem(DTensorTestBase):
|
||||||
tp_group_idx = self.rank // 4
|
tp_group_idx = self.rank // 4
|
||||||
self.assertEqual(tp_mesh.mesh, pg_ranks_by_dim_name["TP"][tp_group_idx])
|
self.assertEqual(tp_mesh.mesh, pg_ranks_by_dim_name["TP"][tp_group_idx])
|
||||||
|
|
||||||
dp_mesh = mesh_2d["DP"]
|
|
||||||
dp_group_idx = self.rank % 4
|
dp_group_idx = self.rank % 4
|
||||||
self.assertEqual(mesh_2d["DP"].mesh, pg_ranks_by_dim_name["DP"][dp_group_idx])
|
self.assertEqual(mesh_2d["DP"].mesh, pg_ranks_by_dim_name["DP"][dp_group_idx])
|
||||||
|
|
||||||
|
|
@ -564,17 +563,15 @@ class TestDeviceMeshGetItem(DTensorTestBase):
|
||||||
def test_cache_and_reuse_submesh_slice_result(self):
|
def test_cache_and_reuse_submesh_slice_result(self):
|
||||||
mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("dp", "tp"))
|
mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("dp", "tp"))
|
||||||
|
|
||||||
dp_mesh = mesh["dp"]
|
|
||||||
ref_pg_count = _world.group_count
|
ref_pg_count = _world.group_count
|
||||||
|
|
||||||
# When we call the "dp" slice second time, it should not create any new pg.
|
# When we call the "dp" slice second time, it should not create any new pg.
|
||||||
# As we are just using the cached result so the pg count should be the same.
|
# As we are just using the cached result so the pg count should be the same.
|
||||||
dp_mesh_2 = mesh["dp"]
|
|
||||||
self.assertEqual(ref_pg_count, _world.group_count)
|
self.assertEqual(ref_pg_count, _world.group_count)
|
||||||
|
|
||||||
# When we call the "tp" slice, it should not create a new pg, as the "tp" slice would
|
# When we call the "tp" slice, it should not create a new pg, as the "tp" slice would
|
||||||
# just reuse the parent mesh pg.
|
# just reuse the parent mesh pg.
|
||||||
tp_mesh = mesh["tp"]
|
mesh["tp"]
|
||||||
self.assertEqual(_world.group_count, ref_pg_count)
|
self.assertEqual(_world.group_count, ref_pg_count)
|
||||||
|
|
||||||
@with_comms
|
@with_comms
|
||||||
|
|
@ -603,7 +600,7 @@ class TestDeviceMeshGetItem(DTensorTestBase):
|
||||||
KeyError,
|
KeyError,
|
||||||
"Invalid mesh_dim_names",
|
"Invalid mesh_dim_names",
|
||||||
):
|
):
|
||||||
cp_dp_mesh = mesh_3d["cp", "dp"]
|
mesh_3d["cp", "dp"]
|
||||||
|
|
||||||
@with_comms
|
@with_comms
|
||||||
def test_flatten_mesh_3d(self):
|
def test_flatten_mesh_3d(self):
|
||||||
|
|
@ -767,9 +764,9 @@ class TestMeshEnv(DTensorTestBase):
|
||||||
)
|
)
|
||||||
|
|
||||||
with FakeTensorMode():
|
with FakeTensorMode():
|
||||||
dp_mesh = mesh_2d["DP"]
|
mesh_2d["DP"]
|
||||||
tp_mesh = mesh_2d["TP"]
|
mesh_2d["TP"]
|
||||||
dp_tp_mesh = mesh_2d["DP", "TP"]
|
mesh_2d["DP", "TP"]
|
||||||
|
|
||||||
|
|
||||||
class DeviceMeshCollectiveTest(DTensorTestBase):
|
class DeviceMeshCollectiveTest(DTensorTestBase):
|
||||||
|
|
|
||||||
|
|
@ -421,7 +421,7 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
|
||||||
self.weight2 = nn.Parameter(torch.randn(512, 512))
|
self.weight2 = nn.Parameter(torch.randn(512, 512))
|
||||||
|
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
u0, u1 = y.tolist()
|
u0, _ = y.tolist()
|
||||||
x = torch.cat([x, x])
|
x = torch.cat([x, x])
|
||||||
y = x @ self.weight1
|
y = x @ self.weight1
|
||||||
z = (x + y @ self.weight2) * u0
|
z = (x + y @ self.weight2) * u0
|
||||||
|
|
@ -442,7 +442,7 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
|
||||||
self.weight2 = nn.Parameter(torch.randn(512, 512))
|
self.weight2 = nn.Parameter(torch.randn(512, 512))
|
||||||
|
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
u0, u1 = y.tolist()
|
u0, _ = y.tolist()
|
||||||
a = torch.ones(u0)
|
a = torch.ones(u0)
|
||||||
x = torch.cat([x, x])
|
x = torch.cat([x, x])
|
||||||
y = x @ self.weight1
|
y = x @ self.weight1
|
||||||
|
|
@ -466,7 +466,7 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
|
||||||
|
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
# partition one (contains the u0 def)
|
# partition one (contains the u0 def)
|
||||||
u0, u1 = y.tolist()
|
u0, _ = y.tolist()
|
||||||
x = torch.cat([x, x])
|
x = torch.cat([x, x])
|
||||||
y1 = x @ self.weight1
|
y1 = x @ self.weight1
|
||||||
# partition two (contains the variable)
|
# partition two (contains the variable)
|
||||||
|
|
@ -511,7 +511,7 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
layers = []
|
layers = []
|
||||||
for l in range(2):
|
for _ in range(2):
|
||||||
layer = nn.ModuleList(
|
layer = nn.ModuleList(
|
||||||
[
|
[
|
||||||
nn.LayerNorm(96),
|
nn.LayerNorm(96),
|
||||||
|
|
@ -529,7 +529,7 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
|
||||||
for m in self.layers:
|
for m in self.layers:
|
||||||
x = x.reshape(B * F, T, H)
|
x = x.reshape(B * F, T, H)
|
||||||
x = m[0](x)
|
x = m[0](x)
|
||||||
x, attn = m[1].forward(x, x, x)
|
x, _ = m[1].forward(x, x, x)
|
||||||
x = x.reshape(B, F, T, H)
|
x = x.reshape(B, F, T, H)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
@ -937,8 +937,8 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||||
|
|
||||||
@torch.compile()
|
@torch.compile()
|
||||||
def f(x, y):
|
def f(x, y):
|
||||||
zx = x.shape
|
zx = x.shape # noqa: F841
|
||||||
zy = y.shape
|
zy = y.shape # noqa: F841
|
||||||
return x.sum() + y.sum()
|
return x.sum() + y.sum()
|
||||||
|
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
|
|
@ -967,10 +967,10 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||||
|
|
||||||
@torch.compile()
|
@torch.compile()
|
||||||
def f(x, y):
|
def f(x, y):
|
||||||
z = y
|
z = y # noqa: F841
|
||||||
print("woof")
|
print("woof")
|
||||||
zx = x.shape
|
zx = x.shape # noqa: F841
|
||||||
zy = y.shape
|
zy = y.shape # noqa: F841
|
||||||
return x.sum() + y.sum()
|
return x.sum() + y.sum()
|
||||||
|
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
|
|
@ -999,8 +999,8 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||||
|
|
||||||
@torch.compile()
|
@torch.compile()
|
||||||
def f(x, y):
|
def f(x, y):
|
||||||
zx = x.shape
|
zx = x.shape # noqa: F841
|
||||||
zy = y.shape
|
zy = y.shape # noqa: F841
|
||||||
return x.sum() + y.sum()
|
return x.sum() + y.sum()
|
||||||
|
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
|
|
@ -1405,7 +1405,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||||
model = DDP(model, device_ids=self.device_ids)
|
model = DDP(model, device_ids=self.device_ids)
|
||||||
|
|
||||||
hidden_states = torch.randn(B, S, H * D).to(device)
|
hidden_states = torch.randn(B, S, H * D).to(device)
|
||||||
attention_scores = model(hidden_states)
|
model(hidden_states)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
@patch.object(config, "optimize_ddp", True)
|
@patch.object(config, "optimize_ddp", True)
|
||||||
|
|
@ -1461,7 +1461,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||||
model = DDP(model, device_ids=self.device_ids)
|
model = DDP(model, device_ids=self.device_ids)
|
||||||
|
|
||||||
hidden_states = torch.randn(B, S, H * D).to(device)
|
hidden_states = torch.randn(B, S, H * D).to(device)
|
||||||
attention_scores = model(hidden_states)
|
model(hidden_states)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
@patch.object(config, "optimize_ddp", True)
|
@patch.object(config, "optimize_ddp", True)
|
||||||
|
|
@ -1723,7 +1723,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||||
|
|
||||||
def test_fsdp_orig_params_assert(self):
|
def test_fsdp_orig_params_assert(self):
|
||||||
# Test with basic FSDP wrapping (outer wrap around whole model)
|
# Test with basic FSDP wrapping (outer wrap around whole model)
|
||||||
m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
|
m, inputs, _ = get_model(f"cuda:{self.rank}")
|
||||||
fsdp_m = FSDP(m, use_orig_params=False)
|
fsdp_m = FSDP(m, use_orig_params=False)
|
||||||
fsdp_m = torch.compile(fsdp_m)
|
fsdp_m = torch.compile(fsdp_m)
|
||||||
self.assertRaisesRegex(
|
self.assertRaisesRegex(
|
||||||
|
|
|
||||||
|
|
@ -130,7 +130,7 @@ class TestExpand(MultiThreadedTestCase):
|
||||||
tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD, "bla")
|
tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD, "bla")
|
||||||
self.assertEqual("bla", tag)
|
self.assertEqual("bla", tag)
|
||||||
|
|
||||||
my_pg, others = new_subgroups(group_size=2)
|
my_pg, _ = new_subgroups(group_size=2)
|
||||||
tag, rankset, group_size = ft_c._expand_group(my_pg)
|
tag, rankset, group_size = ft_c._expand_group(my_pg)
|
||||||
self.assertEqual(c10d._get_group_tag(my_pg), tag)
|
self.assertEqual(c10d._get_group_tag(my_pg), tag)
|
||||||
self.assertEqual(dist.get_process_group_ranks(my_pg), rankset)
|
self.assertEqual(dist.get_process_group_ranks(my_pg), rankset)
|
||||||
|
|
@ -588,7 +588,7 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
|
||||||
def allreduce(t, pg):
|
def allreduce(t, pg):
|
||||||
return ft_c.all_reduce(t, "sum", pg)
|
return ft_c.all_reduce(t, "sum", pg)
|
||||||
|
|
||||||
compiled_allreduce = torch.compile(allreduce, fullgraph=True)
|
compiled_allreduce = torch.compile(allreduce, fullgraph=True) # noqa: F841
|
||||||
dist.init_process_group(
|
dist.init_process_group(
|
||||||
backend="fake",
|
backend="fake",
|
||||||
rank=0,
|
rank=0,
|
||||||
|
|
@ -615,9 +615,7 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
|
||||||
return batch * 5
|
return batch * 5
|
||||||
|
|
||||||
compiled_func = torch.compile(func)
|
compiled_func = torch.compile(func)
|
||||||
ret = compiled_func(
|
compiled_func(torch.ones((100,), device=device), self.process_group, self.rank)
|
||||||
torch.ones((100,), device=device), self.process_group, self.rank
|
|
||||||
)
|
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -715,7 +713,7 @@ class TestFunctionalAutograd(MultiThreadedTestCase):
|
||||||
out = compiled(t, self.world_size)
|
out = compiled(t, self.world_size)
|
||||||
out.backward()
|
out.backward()
|
||||||
|
|
||||||
res, codes = run_and_get_code(run_with_backward)
|
_, codes = run_and_get_code(run_with_backward)
|
||||||
for code in codes:
|
for code in codes:
|
||||||
FileCheck().check_count(
|
FileCheck().check_count(
|
||||||
"_c10d_functional.all_to_all_single.default", 1, exactly=True
|
"_c10d_functional.all_to_all_single.default", 1, exactly=True
|
||||||
|
|
|
||||||
|
|
@ -411,7 +411,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||||
y = self.emb(x)
|
y = self.emb(x)
|
||||||
last_dim = y.dim() - 1
|
last_dim = y.dim() - 1
|
||||||
y = y.transpose_(0, last_dim).contiguous()
|
y = y.transpose_(0, last_dim).contiguous()
|
||||||
res = _functional_collectives.all_gather_tensor(y, 0, ranks, tag)
|
_functional_collectives.all_gather_tensor(y, 0, ranks, tag)
|
||||||
out = y.transpose_(0, last_dim).contiguous()
|
out = y.transpose_(0, last_dim).contiguous()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,6 @@ class TestDistributedLaunch(TestCase):
|
||||||
def test_launch_user_script(self):
|
def test_launch_user_script(self):
|
||||||
nnodes = 1
|
nnodes = 1
|
||||||
nproc_per_node = 4
|
nproc_per_node = 4
|
||||||
world_size = nnodes * nproc_per_node
|
|
||||||
sock = get_socket_with_port()
|
sock = get_socket_with_port()
|
||||||
with closing(sock):
|
with closing(sock):
|
||||||
master_port = sock.getsockname()[1]
|
master_port = sock.getsockname()[1]
|
||||||
|
|
|
||||||
|
|
@ -553,7 +553,7 @@ class LibUvTCPStoreTest(TCPStoreTest):
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.assertRaisesRegex(NotImplementedError, err_msg_reg):
|
with self.assertRaisesRegex(NotImplementedError, err_msg_reg):
|
||||||
store = dist.TCPStore(
|
dist.TCPStore(
|
||||||
addr,
|
addr,
|
||||||
port,
|
port,
|
||||||
1,
|
1,
|
||||||
|
|
@ -748,7 +748,7 @@ class RendezvousTCPTest(TestCase):
|
||||||
url = self.create_tcp_url()
|
url = self.create_tcp_url()
|
||||||
test_store_timeout = timedelta(seconds=0.1)
|
test_store_timeout = timedelta(seconds=0.1)
|
||||||
gen0 = dist.rendezvous(url + "&rank=0", timeout=timedelta(seconds=10))
|
gen0 = dist.rendezvous(url + "&rank=0", timeout=timedelta(seconds=10))
|
||||||
store0, rank0, size0 = next(gen0)
|
store0, _, _ = next(gen0)
|
||||||
store0.set_timeout(test_store_timeout)
|
store0.set_timeout(test_store_timeout)
|
||||||
# this should time out in 0.1s. If the timeout passed into rendezvous was
|
# this should time out in 0.1s. If the timeout passed into rendezvous was
|
||||||
# not respected, it will take much longer to timeout.
|
# not respected, it will take much longer to timeout.
|
||||||
|
|
@ -766,7 +766,7 @@ class RendezvousTCPTest(TestCase):
|
||||||
url = self.create_tcp_url()
|
url = self.create_tcp_url()
|
||||||
test_store_timeout = timedelta(seconds=0.1)
|
test_store_timeout = timedelta(seconds=0.1)
|
||||||
gen0 = dist.rendezvous(url + "&rank=0", timeout=timedelta(seconds=10))
|
gen0 = dist.rendezvous(url + "&rank=0", timeout=timedelta(seconds=10))
|
||||||
store0, rank0, size0 = next(gen0)
|
store0, _, _ = next(gen0)
|
||||||
store0.set_timeout(test_store_timeout)
|
store0.set_timeout(test_store_timeout)
|
||||||
# this should time out in 10s. If the timeout passed into rendezvous was
|
# this should time out in 10s. If the timeout passed into rendezvous was
|
||||||
# not respected, it will take much longer to timeout.
|
# not respected, it will take much longer to timeout.
|
||||||
|
|
@ -787,7 +787,7 @@ class RendezvousTCPTest(TestCase):
|
||||||
def test_tcp_store_url_with_libuv(self):
|
def test_tcp_store_url_with_libuv(self):
|
||||||
url = self.create_tcp_url()
|
url = self.create_tcp_url()
|
||||||
gen0 = dist.rendezvous(url + "&rank=0&use_libuv=1")
|
gen0 = dist.rendezvous(url + "&rank=0&use_libuv=1")
|
||||||
store0, rank0, size0 = next(gen0)
|
store0, _, _ = next(gen0)
|
||||||
self.assertTrue(store0.libuvBackend)
|
self.assertTrue(store0.libuvBackend)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1078,7 +1078,7 @@ class TestClientProtocol(TestCase):
|
||||||
thread = threading.Thread(target=listen)
|
thread = threading.Thread(target=listen)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
store = dist.TCPStore(
|
dist.TCPStore(
|
||||||
host_name="localhost",
|
host_name="localhost",
|
||||||
port=port,
|
port=port,
|
||||||
world_size=2,
|
world_size=2,
|
||||||
|
|
|
||||||
|
|
@ -332,7 +332,6 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
||||||
K = 32
|
K = 32
|
||||||
group = dist.group.WORLD
|
group = dist.group.WORLD
|
||||||
rank = self.rank
|
rank = self.rank
|
||||||
world_size = self.world_size
|
|
||||||
|
|
||||||
torch.manual_seed(42 + rank)
|
torch.manual_seed(42 + rank)
|
||||||
A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
|
A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
|
||||||
|
|
@ -428,7 +427,6 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
||||||
K = 32
|
K = 32
|
||||||
group = dist.group.WORLD
|
group = dist.group.WORLD
|
||||||
rank = self.rank
|
rank = self.rank
|
||||||
world_size = self.world_size
|
|
||||||
|
|
||||||
if gather_dim == 0:
|
if gather_dim == 0:
|
||||||
leading_dims = (BATCH // self.world_size, M)
|
leading_dims = (BATCH // self.world_size, M)
|
||||||
|
|
@ -513,7 +511,6 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
||||||
K = 32
|
K = 32
|
||||||
group = dist.group.WORLD
|
group = dist.group.WORLD
|
||||||
rank = self.rank
|
rank = self.rank
|
||||||
world_size = self.world_size
|
|
||||||
|
|
||||||
torch.manual_seed(42 + rank)
|
torch.manual_seed(42 + rank)
|
||||||
A = torch.rand(BATCH, M, K, device="cuda")
|
A = torch.rand(BATCH, M, K, device="cuda")
|
||||||
|
|
@ -546,7 +543,6 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
||||||
K = 32
|
K = 32
|
||||||
group = dist.group.WORLD
|
group = dist.group.WORLD
|
||||||
rank = self.rank
|
rank = self.rank
|
||||||
world_size = self.world_size
|
|
||||||
|
|
||||||
torch.manual_seed(42 + rank)
|
torch.manual_seed(42 + rank)
|
||||||
A = torch.rand(BATCH, M, K, device="cuda").to(torch.float8_e4m3fn)
|
A = torch.rand(BATCH, M, K, device="cuda").to(torch.float8_e4m3fn)
|
||||||
|
|
|
||||||
|
|
@ -1314,7 +1314,7 @@ class TestDistributions(DistributionsTestCase):
|
||||||
if not msk.all():
|
if not msk.all():
|
||||||
counts = np.concatenate([counts[msk], np.sum(counts[~msk], keepdims=True)])
|
counts = np.concatenate([counts[msk], np.sum(counts[~msk], keepdims=True)])
|
||||||
pmf = np.concatenate([pmf[msk], np.sum(pmf[~msk], keepdims=True)])
|
pmf = np.concatenate([pmf[msk], np.sum(pmf[~msk], keepdims=True)])
|
||||||
chisq, p = scipy.stats.chisquare(counts, pmf * num_samples)
|
_, p = scipy.stats.chisquare(counts, pmf * num_samples)
|
||||||
self.assertGreater(p, failure_rate, message)
|
self.assertGreater(p, failure_rate, message)
|
||||||
|
|
||||||
def _check_enumerate_support(self, dist, examples):
|
def _check_enumerate_support(self, dist, examples):
|
||||||
|
|
@ -1912,9 +1912,7 @@ class TestDistributions(DistributionsTestCase):
|
||||||
@set_default_dtype(torch.double)
|
@set_default_dtype(torch.double)
|
||||||
def test_one_hot_categorical_2d(self):
|
def test_one_hot_categorical_2d(self):
|
||||||
probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
|
probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
|
||||||
probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
|
|
||||||
p = torch.tensor(probabilities, requires_grad=True)
|
p = torch.tensor(probabilities, requires_grad=True)
|
||||||
s = torch.tensor(probabilities_1, requires_grad=True)
|
|
||||||
self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
|
self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3)
|
OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3)
|
||||||
|
|
@ -2074,13 +2072,11 @@ class TestDistributions(DistributionsTestCase):
|
||||||
@set_default_dtype(torch.double)
|
@set_default_dtype(torch.double)
|
||||||
def test_relaxed_one_hot_categorical_2d(self):
|
def test_relaxed_one_hot_categorical_2d(self):
|
||||||
probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
|
probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
|
||||||
probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
|
|
||||||
temp = torch.tensor([3.0], requires_grad=True)
|
temp = torch.tensor([3.0], requires_grad=True)
|
||||||
# The lower the temperature, the more unstable the log_prob gradcheck is
|
# The lower the temperature, the more unstable the log_prob gradcheck is
|
||||||
# w.r.t. the sample. Values below 0.25 empirically fail the default tol.
|
# w.r.t. the sample. Values below 0.25 empirically fail the default tol.
|
||||||
temp_2 = torch.tensor([0.25], requires_grad=True)
|
temp_2 = torch.tensor([0.25], requires_grad=True)
|
||||||
p = torch.tensor(probabilities, requires_grad=True)
|
p = torch.tensor(probabilities, requires_grad=True)
|
||||||
s = torch.tensor(probabilities_1, requires_grad=True)
|
|
||||||
self.assertEqual(RelaxedOneHotCategorical(temp, p).sample().size(), (2, 3))
|
self.assertEqual(RelaxedOneHotCategorical(temp, p).sample().size(), (2, 3))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
RelaxedOneHotCategorical(temp, p).sample(sample_shape=(3, 4)).size(),
|
RelaxedOneHotCategorical(temp, p).sample(sample_shape=(3, 4)).size(),
|
||||||
|
|
@ -3939,7 +3935,7 @@ class TestDistributions(DistributionsTestCase):
|
||||||
for dim in range(2, 5):
|
for dim in range(2, 5):
|
||||||
log_probs = []
|
log_probs = []
|
||||||
lkj = LKJCholesky(dim, concentration=1.0, validate_args=True)
|
lkj = LKJCholesky(dim, concentration=1.0, validate_args=True)
|
||||||
for i in range(2):
|
for _ in range(2):
|
||||||
sample = lkj.sample()
|
sample = lkj.sample()
|
||||||
sample_tril = tril_matrix_to_vec(sample, diag=-1)
|
sample_tril = tril_matrix_to_vec(sample, diag=-1)
|
||||||
log_prob = lkj.log_prob(sample)
|
log_prob = lkj.log_prob(sample)
|
||||||
|
|
@ -6241,7 +6237,7 @@ class TestLazyLogitsInitialization(DistributionsTestCase):
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
pass
|
pass
|
||||||
self.assertNotIn("probs", dist.__dict__, msg=message)
|
self.assertNotIn("probs", dist.__dict__, msg=message)
|
||||||
batch_shape, event_shape = dist.batch_shape, dist.event_shape
|
dist.batch_shape, dist.event_shape
|
||||||
self.assertNotIn("probs", dist.__dict__, msg=message)
|
self.assertNotIn("probs", dist.__dict__, msg=message)
|
||||||
|
|
||||||
def test_lazy_probs_initialization(self):
|
def test_lazy_probs_initialization(self):
|
||||||
|
|
@ -6258,7 +6254,7 @@ class TestLazyLogitsInitialization(DistributionsTestCase):
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
pass
|
pass
|
||||||
self.assertNotIn("logits", dist.__dict__, msg=message)
|
self.assertNotIn("logits", dist.__dict__, msg=message)
|
||||||
batch_shape, event_shape = dist.batch_shape, dist.event_shape
|
dist.batch_shape, dist.event_shape
|
||||||
self.assertNotIn("logits", dist.__dict__, msg=message)
|
self.assertNotIn("logits", dist.__dict__, msg=message)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -6565,6 +6561,7 @@ class TestFunctors(DistributionsTestCase):
|
||||||
expected_jac = sum(
|
expected_jac = sum(
|
||||||
[t1.log_abs_det_jacobian(x1, y1), t2.log_abs_det_jacobian(x2, y2)]
|
[t1.log_abs_det_jacobian(x1, y1), t2.log_abs_det_jacobian(x2, y2)]
|
||||||
)
|
)
|
||||||
|
self.assertEqual(actual_jac, expected_jac)
|
||||||
|
|
||||||
def test_stack_transform(self):
|
def test_stack_transform(self):
|
||||||
x1 = -1 * torch.arange(1, 101, dtype=torch.float)
|
x1 = -1 * torch.arange(1, 101, dtype=torch.float)
|
||||||
|
|
@ -6628,18 +6625,18 @@ class TestValidation(DistributionsTestCase):
|
||||||
for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]):
|
for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]):
|
||||||
# samples with incorrect shape must throw ValueError only
|
# samples with incorrect shape must throw ValueError only
|
||||||
try:
|
try:
|
||||||
log_prob = d_val.log_prob(v)
|
d_val.log_prob(v)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
# get sample of correct shape
|
# get sample of correct shape
|
||||||
val = torch.full(d_val.batch_shape + d_val.event_shape, v)
|
val = torch.full(d_val.batch_shape + d_val.event_shape, v)
|
||||||
# check samples with incorrect support
|
# check samples with incorrect support
|
||||||
try:
|
try:
|
||||||
log_prob = d_val.log_prob(val)
|
d_val.log_prob(val)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
if e.args and "must be within the support" in e.args[0]:
|
if e.args and "must be within the support" in e.args[0]:
|
||||||
try:
|
try:
|
||||||
log_prob = d_nonval.log_prob(val)
|
d_nonval.log_prob(val)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1260,7 +1260,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, x, ys):
|
def forward(self, x, ys):
|
||||||
a = torch.sin(x)
|
a = torch.sin(x) # noqa: F841
|
||||||
b = torch.cos(ys[0])
|
b = torch.cos(ys[0])
|
||||||
c = torch.cos(ys[1])
|
c = torch.cos(ys[1])
|
||||||
return (x, [b, c])
|
return (x, [b, c])
|
||||||
|
|
|
||||||
|
|
@ -453,7 +453,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
||||||
a = torch.randn(3, 3, requires_grad=True)
|
a = torch.randn(3, 3, requires_grad=True)
|
||||||
b = torch.randn(3, 3, requires_grad=True)
|
b = torch.randn(3, 3, requires_grad=True)
|
||||||
a1, a2 = a.clone(), a.clone()
|
a1, a2 = a.clone(), a.clone()
|
||||||
b1, b2 = b.clone(), b.clone()
|
_, b2 = b.clone(), b.clone()
|
||||||
|
|
||||||
failure_reason = None
|
failure_reason = None
|
||||||
|
|
||||||
|
|
@ -481,7 +481,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
||||||
c = torch.randn(3, 3, requires_grad=True)
|
c = torch.randn(3, 3, requires_grad=True)
|
||||||
d = torch.randn(3, 3, requires_grad=True)
|
d = torch.randn(3, 3, requires_grad=True)
|
||||||
c3, c4 = c.clone(), c.clone()
|
c3, c4 = c.clone(), c.clone()
|
||||||
d3, d4 = d.clone(), d.clone()
|
_, d4 = d.clone(), d.clone()
|
||||||
|
|
||||||
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
|
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
|
||||||
f(c3, c3, 3, 3)
|
f(c3, c3, 3, 3)
|
||||||
|
|
@ -507,7 +507,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
||||||
b = torch.randn(3, 3, requires_grad=True)
|
b = torch.randn(3, 3, requires_grad=True)
|
||||||
z = a
|
z = a
|
||||||
a1, a2 = a.clone(), a.clone()
|
a1, a2 = a.clone(), a.clone()
|
||||||
b1, b2 = b.clone(), b.clone()
|
_, b2 = b.clone(), b.clone()
|
||||||
|
|
||||||
failure_reason = None
|
failure_reason = None
|
||||||
|
|
||||||
|
|
@ -543,7 +543,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
||||||
a = torch.randn(3, 3, requires_grad=True)
|
a = torch.randn(3, 3, requires_grad=True)
|
||||||
b = torch.randn(3, 3, requires_grad=True)
|
b = torch.randn(3, 3, requires_grad=True)
|
||||||
a1, a2 = a.clone(), a.clone()
|
a1, a2 = a.clone(), a.clone()
|
||||||
b1, b2 = b.clone(), b.clone()
|
_, b2 = b.clone(), b.clone()
|
||||||
|
|
||||||
failure_reason = None
|
failure_reason = None
|
||||||
|
|
||||||
|
|
@ -571,7 +571,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
||||||
c = torch.randn(3, 3, requires_grad=True)
|
c = torch.randn(3, 3, requires_grad=True)
|
||||||
d = torch.randn(3, 3, requires_grad=True)
|
d = torch.randn(3, 3, requires_grad=True)
|
||||||
c3, c4 = c.clone(), c.clone()
|
c3, c4 = c.clone(), c.clone()
|
||||||
d3, d4 = d.clone(), d.clone()
|
_, d4 = d.clone(), d.clone()
|
||||||
|
|
||||||
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
|
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
|
||||||
f([3, 2, 1], [4, 5, 6], c3, c3)
|
f([3, 2, 1], [4, 5, 6], c3, c3)
|
||||||
|
|
@ -593,7 +593,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
||||||
a = torch.randn(3, 3, requires_grad=True)
|
a = torch.randn(3, 3, requires_grad=True)
|
||||||
b = torch.randn(3, 3, requires_grad=True)
|
b = torch.randn(3, 3, requires_grad=True)
|
||||||
a1, a2 = a.clone(), a.clone()
|
a1, a2 = a.clone(), a.clone()
|
||||||
b1, b2 = b.clone(), b.clone()
|
_, b2 = b.clone(), b.clone()
|
||||||
|
|
||||||
failure_reason = None
|
failure_reason = None
|
||||||
|
|
||||||
|
|
@ -621,7 +621,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
||||||
c = torch.randn(3, 3, requires_grad=True)
|
c = torch.randn(3, 3, requires_grad=True)
|
||||||
d = torch.randn(3, 3, requires_grad=True)
|
d = torch.randn(3, 3, requires_grad=True)
|
||||||
c3, c4 = c.clone(), c.clone()
|
c3, c4 = c.clone(), c.clone()
|
||||||
d3, d4 = d.clone(), d.clone()
|
_, d4 = d.clone(), d.clone()
|
||||||
|
|
||||||
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
|
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
|
||||||
f(c3, c3)
|
f(c3, c3)
|
||||||
|
|
@ -642,7 +642,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
||||||
a = torch.randn(3, 3, requires_grad=True)
|
a = torch.randn(3, 3, requires_grad=True)
|
||||||
b = torch.randn(3, 3, requires_grad=True)
|
b = torch.randn(3, 3, requires_grad=True)
|
||||||
a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
|
a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
|
||||||
b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
|
_, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
|
||||||
|
|
||||||
failure_reason = None
|
failure_reason = None
|
||||||
|
|
||||||
|
|
@ -670,7 +670,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
||||||
c = torch.randn(3, 3, requires_grad=True)
|
c = torch.randn(3, 3, requires_grad=True)
|
||||||
d = torch.randn(3, 3, requires_grad=True)
|
d = torch.randn(3, 3, requires_grad=True)
|
||||||
c3, c4 = c.clone(), c.clone()
|
c3, c4 = c.clone(), c.clone()
|
||||||
d3, d4 = d.clone(), d.clone()
|
_, d4 = d.clone(), d.clone()
|
||||||
|
|
||||||
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
|
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
|
||||||
f(a3, b3, c3, c3)
|
f(a3, b3, c3, c3)
|
||||||
|
|
@ -1017,7 +1017,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
|
||||||
activities=[torch.profiler.ProfilerActivity.CPU],
|
activities=[torch.profiler.ProfilerActivity.CPU],
|
||||||
record_shapes=True,
|
record_shapes=True,
|
||||||
) as kineto_prof:
|
) as kineto_prof:
|
||||||
res = model_instance(*args)
|
model_instance(*args)
|
||||||
bwd_set = set()
|
bwd_set = set()
|
||||||
prof_str = "SeqNr|Thread|FwdThread|Name\n"
|
prof_str = "SeqNr|Thread|FwdThread|Name\n"
|
||||||
for event in kineto_prof.events():
|
for event in kineto_prof.events():
|
||||||
|
|
@ -1191,7 +1191,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
|
||||||
|
|
||||||
x = torch.randn(3, requires_grad=True)
|
x = torch.randn(3, requires_grad=True)
|
||||||
with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer"):
|
with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer"):
|
||||||
y = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
|
torch.compile(f, backend="aot_eager", fullgraph=True)(x)
|
||||||
self.assertTrue(backward_called)
|
self.assertTrue(backward_called)
|
||||||
|
|
||||||
# We don't know how to catch multiple mutations to the same memory location
|
# We don't know how to catch multiple mutations to the same memory location
|
||||||
|
|
|
||||||
|
|
@ -157,7 +157,7 @@ class AOTAutogradCacheTests(InductorTestCase):
|
||||||
|
|
||||||
with torch.autograd._force_original_view_tracking(True):
|
with torch.autograd._force_original_view_tracking(True):
|
||||||
compiled_fn = torch.compile(fn)
|
compiled_fn = torch.compile(fn)
|
||||||
out = compiled_fn(torch.rand(2, 3))
|
compiled_fn(torch.rand(2, 3))
|
||||||
|
|
||||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
|
||||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
|
||||||
|
|
@ -654,7 +654,7 @@ class AOTAutogradCachePicklerTests(torch._dynamo.test_case.TestCase):
|
||||||
def fn(x):
|
def fn(x):
|
||||||
return x.sin().cos()
|
return x.sin().cos()
|
||||||
|
|
||||||
def fn2(x):
|
def fn2(x): # noqa: F841
|
||||||
y = x.sin()
|
y = x.sin()
|
||||||
z = y.cos()
|
z = y.cos()
|
||||||
return z
|
return z
|
||||||
|
|
|
||||||
|
|
@ -760,7 +760,7 @@ class GraphModule(torch.nn.Module):
|
||||||
def backward(ctx, gO):
|
def backward(ctx, gO):
|
||||||
return torch.tensor(float("nan")).expand(10, 10)
|
return torch.tensor(float("nan")).expand(10, 10)
|
||||||
|
|
||||||
def run_fn(a):
|
def run_fn(a): # noqa: F841
|
||||||
out = MyFunc2.apply(a)
|
out = MyFunc2.apply(a)
|
||||||
return out.sum()
|
return out.sum()
|
||||||
|
|
||||||
|
|
@ -837,11 +837,11 @@ class GraphModule(torch.nn.Module):
|
||||||
|
|
||||||
x = torch.randn(5, 5, requires_grad=True)
|
x = torch.randn(5, 5, requires_grad=True)
|
||||||
y = torch.randn(5, 5, requires_grad=True)
|
y = torch.randn(5, 5, requires_grad=True)
|
||||||
q, p = Identity.apply(x, y)
|
Identity.apply(x, y)
|
||||||
|
|
||||||
a = torch.rand(1, 2)
|
a = torch.rand(1, 2)
|
||||||
b = torch.rand(1, requires_grad=True)
|
b = torch.rand(1, requires_grad=True)
|
||||||
view_a = MyFn.apply(a)
|
MyFn.apply(a)
|
||||||
|
|
||||||
a = torch.ones(2, requires_grad=True)
|
a = torch.ones(2, requires_grad=True)
|
||||||
b = torch.ones(2, requires_grad=True)
|
b = torch.ones(2, requires_grad=True)
|
||||||
|
|
@ -860,7 +860,7 @@ class GraphModule(torch.nn.Module):
|
||||||
MyFn2.apply(c, d)
|
MyFn2.apply(c, d)
|
||||||
|
|
||||||
base = torch.rand(10, requires_grad=True)
|
base = torch.rand(10, requires_grad=True)
|
||||||
foo = MyFn3.apply(base, False)
|
MyFn3.apply(base, False)
|
||||||
|
|
||||||
test()
|
test()
|
||||||
opt_test = torch.compile(test, backend="eager")
|
opt_test = torch.compile(test, backend="eager")
|
||||||
|
|
|
||||||
|
|
@ -267,9 +267,8 @@ class TestCustomBackendAPI(torch._dynamo.test_case.TestCase):
|
||||||
self.assertTrue(backend_run)
|
self.assertTrue(backend_run)
|
||||||
|
|
||||||
def test_lookup_backend(self):
|
def test_lookup_backend(self):
|
||||||
from torch._dynamo import list_backends, lookup_backend
|
from torch._dynamo import lookup_backend
|
||||||
|
|
||||||
backends = list_backends()
|
|
||||||
backend_run = False
|
backend_run = False
|
||||||
|
|
||||||
def my_compiler(gm, example_inputs):
|
def my_compiler(gm, example_inputs):
|
||||||
|
|
|
||||||
|
|
@ -247,8 +247,6 @@ class GraphModule(torch.nn.Module):
|
||||||
with compiled_autograd._enable(compiler_fn):
|
with compiled_autograd._enable(compiler_fn):
|
||||||
out.backward(grad_out)
|
out.backward(grad_out)
|
||||||
|
|
||||||
graph = None
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from torch._dynamo.test_case import run_tests
|
from torch._dynamo.test_case import run_tests
|
||||||
|
|
|
||||||
|
|
@ -518,7 +518,7 @@ def fn():
|
||||||
insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
|
insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
|
||||||
self.assertEqual(insts[-1].opname, "NOP")
|
self.assertEqual(insts[-1].opname, "NOP")
|
||||||
insts_i = 0
|
insts_i = 0
|
||||||
for i, inst in enumerate(dis_insts):
|
for inst in dis_insts:
|
||||||
if inst.opname == "RETURN_CONST":
|
if inst.opname == "RETURN_CONST":
|
||||||
self.assertEqual(insts[insts_i].opname, "LOAD_CONST")
|
self.assertEqual(insts[insts_i].opname, "LOAD_CONST")
|
||||||
insts_i += 1
|
insts_i += 1
|
||||||
|
|
@ -538,7 +538,7 @@ def fn():
|
||||||
x = x + 1
|
x = x + 1
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
x = x + 1
|
x = x + 1
|
||||||
except Exception as e:
|
except Exception:
|
||||||
x = x + 1
|
x = x + 1
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ class TestCompilerBisector(TestCase):
|
||||||
return lib
|
return lib
|
||||||
|
|
||||||
def test_bad_decomp(self):
|
def test_bad_decomp(self):
|
||||||
mod = import_module("torch._inductor.compile_fx")
|
import_module("torch._inductor.compile_fx")
|
||||||
|
|
||||||
def bad_exp_decomp(self, rate=1, generator=None):
|
def bad_exp_decomp(self, rate=1, generator=None):
|
||||||
assert generator is None
|
assert generator is None
|
||||||
|
|
@ -86,7 +86,7 @@ class TestCompilerBisector(TestCase):
|
||||||
vq_compiled = torch.compile(vq)
|
vq_compiled = torch.compile(vq)
|
||||||
x = torch.randn(4, 400, 256).cuda()
|
x = torch.randn(4, 400, 256).cuda()
|
||||||
with torch._dynamo.utils.preserve_rng_state():
|
with torch._dynamo.utils.preserve_rng_state():
|
||||||
out = vq(x)
|
vq(x)
|
||||||
out_compiled = vq_compiled(x)
|
out_compiled = vq_compiled(x)
|
||||||
|
|
||||||
return not out_compiled.isnan().any()
|
return not out_compiled.isnan().any()
|
||||||
|
|
@ -150,7 +150,6 @@ class TestCompilerBisector(TestCase):
|
||||||
self.assertTrue("inductor_fallback_random" in out.debug_info)
|
self.assertTrue("inductor_fallback_random" in out.debug_info)
|
||||||
|
|
||||||
def test_crossref(self):
|
def test_crossref(self):
|
||||||
test_ns = "bisect_ops"
|
|
||||||
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
|
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
|
||||||
lib.define("foo(Tensor x) -> Tensor")
|
lib.define("foo(Tensor x) -> Tensor")
|
||||||
op = self.get_op("foo")
|
op = self.get_op("foo")
|
||||||
|
|
|
||||||
|
|
@ -117,7 +117,7 @@ def forward(self, L_x_ : torch.Tensor):
|
||||||
|
|
||||||
return y + 3
|
return y + 3
|
||||||
|
|
||||||
def munge_disas(s):
|
def munge_disas(s): # noqa: F841
|
||||||
re.sub(
|
re.sub(
|
||||||
r"^(?: +\d+)?(?: +(-->)) \+\d+ ([A-Za-z0-9_]+)",
|
r"^(?: +\d+)?(?: +(-->)) \+\d+ ([A-Za-z0-9_]+)",
|
||||||
"\1 \3",
|
"\1 \3",
|
||||||
|
|
@ -271,7 +271,7 @@ y = FakeTensor(..., size=(2,))
|
||||||
y = g(y)
|
y = g(y)
|
||||||
return y + 3
|
return y + 3
|
||||||
|
|
||||||
def munge_filenames(s):
|
def munge_filenames(s): # noqa: F841
|
||||||
return re.sub(r'File "[^"]+", line \d+', 'File "X", line X', s)
|
return re.sub(r'File "[^"]+", line \d+', 'File "X", line X', s)
|
||||||
|
|
||||||
f(torch.randn(2))
|
f(torch.randn(2))
|
||||||
|
|
@ -389,7 +389,7 @@ y = FakeTensor(..., size=(2,))
|
||||||
@torch.compile(backend=cnt)
|
@torch.compile(backend=cnt)
|
||||||
def f(x):
|
def f(x):
|
||||||
y = x * 2
|
y = x * 2
|
||||||
lit = 2
|
lit = 2 # noqa: F841
|
||||||
|
|
||||||
@comptime
|
@comptime
|
||||||
def _(ctx):
|
def _(ctx):
|
||||||
|
|
|
||||||
|
|
@ -268,15 +268,13 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
||||||
cur_stream.wait_stream(new_stream)
|
cur_stream.wait_stream(new_stream)
|
||||||
|
|
||||||
x = torch.add(x, 4)
|
x = torch.add(x, 4)
|
||||||
is_idle = cur_stream.query()
|
cur_stream.query()
|
||||||
cur_stream.synchronize()
|
cur_stream.synchronize()
|
||||||
|
|
||||||
with torch.cuda.stream(new_stream):
|
with torch.cuda.stream(new_stream):
|
||||||
x = torch.add(x, 5)
|
x = torch.add(x, 5)
|
||||||
new_stream.synchronize()
|
new_stream.synchronize()
|
||||||
|
|
||||||
is_equal = cur_stream == new_stream
|
|
||||||
|
|
||||||
x = torch.relu(x)
|
x = torch.relu(x)
|
||||||
x = torch.cos(x)
|
x = torch.cos(x)
|
||||||
return x
|
return x
|
||||||
|
|
@ -439,7 +437,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
||||||
x = torch.add(x, 3)
|
x = torch.add(x, 3)
|
||||||
|
|
||||||
event = cur_stream.record_event()
|
event = cur_stream.record_event()
|
||||||
is_idle = event.query()
|
event.query()
|
||||||
|
|
||||||
new_stream.wait_event(event)
|
new_stream.wait_event(event)
|
||||||
with torch.cuda.stream(new_stream):
|
with torch.cuda.stream(new_stream):
|
||||||
|
|
@ -481,7 +479,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
||||||
x = torch.add(x, 3)
|
x = torch.add(x, 3)
|
||||||
|
|
||||||
event = cur_stream.record_event()
|
event = cur_stream.record_event()
|
||||||
is_idle = event.query()
|
event.query()
|
||||||
|
|
||||||
new_stream.wait_event(event)
|
new_stream.wait_event(event)
|
||||||
with torch.cuda.stream(new_stream):
|
with torch.cuda.stream(new_stream):
|
||||||
|
|
@ -567,7 +565,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
||||||
real_device = real.device
|
real_device = real.device
|
||||||
real_dtype = real.dtype
|
real_dtype = real.dtype
|
||||||
|
|
||||||
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
|
graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
|
||||||
exported = graph(torch.tensor([0.5]))
|
exported = graph(torch.tensor([0.5]))
|
||||||
self.assertEqual(exported.device, real_device)
|
self.assertEqual(exported.device, real_device)
|
||||||
self.assertEqual(exported.dtype, real_dtype)
|
self.assertEqual(exported.dtype, real_dtype)
|
||||||
|
|
@ -676,7 +674,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
||||||
real_device = real.device
|
real_device = real.device
|
||||||
real_dtype = real.dtype
|
real_dtype = real.dtype
|
||||||
|
|
||||||
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
|
graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
|
||||||
exported = graph(torch.tensor([0.5]))
|
exported = graph(torch.tensor([0.5]))
|
||||||
self.assertEqual(exported.device, real_device)
|
self.assertEqual(exported.device, real_device)
|
||||||
self.assertEqual(exported.dtype, real_dtype)
|
self.assertEqual(exported.dtype, real_dtype)
|
||||||
|
|
@ -850,7 +848,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
||||||
real_device = real.device
|
real_device = real.device
|
||||||
real_dtype = real.dtype
|
real_dtype = real.dtype
|
||||||
|
|
||||||
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
|
graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
|
||||||
exported = graph(torch.tensor([0.5]))
|
exported = graph(torch.tensor([0.5]))
|
||||||
self.assertEqual(exported.device, real_device)
|
self.assertEqual(exported.device, real_device)
|
||||||
self.assertEqual(exported.dtype, real_dtype)
|
self.assertEqual(exported.dtype, real_dtype)
|
||||||
|
|
@ -876,7 +874,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
||||||
real_device = real.device
|
real_device = real.device
|
||||||
real_dtype = real.dtype
|
real_dtype = real.dtype
|
||||||
|
|
||||||
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
|
graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
|
||||||
exported = graph(torch.tensor([0.5]))
|
exported = graph(torch.tensor([0.5]))
|
||||||
self.assertEqual(exported.device, real_device)
|
self.assertEqual(exported.device, real_device)
|
||||||
self.assertEqual(exported.dtype, real_dtype)
|
self.assertEqual(exported.dtype, real_dtype)
|
||||||
|
|
@ -1297,7 +1295,7 @@ class GraphModule(torch.nn.Module):
|
||||||
eager = EagerAndRecordGraphs()
|
eager = EagerAndRecordGraphs()
|
||||||
torch.compile(fn, backend=eager, fullgraph=False)(torch.randn(()))
|
torch.compile(fn, backend=eager, fullgraph=False)(torch.randn(()))
|
||||||
|
|
||||||
def check_graph(actual, expected):
|
def check_graph(actual, expected): # noqa: F841
|
||||||
self.assertExpectedInline(actual, expected)
|
self.assertExpectedInline(actual, expected)
|
||||||
|
|
||||||
graph = eager.graphs[0]
|
graph = eager.graphs[0]
|
||||||
|
|
@ -1342,7 +1340,7 @@ class GraphModule(torch.nn.Module):
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
|
|
||||||
ctx_wrapper, mode = ctx_wrappers[i]
|
ctx_wrapper, _ = ctx_wrappers[i]
|
||||||
ctx_wrapper_inverse, mode_inverse = ctx_wrappers[(i + 1) % 2]
|
ctx_wrapper_inverse, mode_inverse = ctx_wrappers[(i + 1) % 2]
|
||||||
|
|
||||||
def fn(x):
|
def fn(x):
|
||||||
|
|
@ -1373,7 +1371,7 @@ class GraphModule(torch.nn.Module):
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
|
|
||||||
ctx_wrapper, mode = ctx_wrappers[i]
|
ctx_wrapper, _ = ctx_wrappers[i]
|
||||||
ctx_wrapper_inverse, mode_inverse = ctx_wrappers[(i + 1) % 2]
|
ctx_wrapper_inverse, mode_inverse = ctx_wrappers[(i + 1) % 2]
|
||||||
|
|
||||||
def fn(x):
|
def fn(x):
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
||||||
|
|
||||||
@torch.compile(backend="cudagraphs")
|
@torch.compile(backend="cudagraphs")
|
||||||
def fn(x, y):
|
def fn(x, y):
|
||||||
for i in range(N_ITERS):
|
for _ in range(N_ITERS):
|
||||||
loss = model(x, y).sum()
|
loss = model(x, y).sum()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
|
@ -80,7 +80,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
||||||
|
|
||||||
@torch.compile(backend="cudagraphs")
|
@torch.compile(backend="cudagraphs")
|
||||||
def fn(x, y):
|
def fn(x, y):
|
||||||
for i in range(N_ITERS):
|
for _ in range(N_ITERS):
|
||||||
loss = model(x, y).sum()
|
loss = model(x, y).sum()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
|
@ -96,7 +96,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
||||||
|
|
||||||
@torch.compile(backend="cudagraphs")
|
@torch.compile(backend="cudagraphs")
|
||||||
def fn(x, y):
|
def fn(x, y):
|
||||||
for i in range(N_ITERS):
|
for _ in range(N_ITERS):
|
||||||
loss = model(x, y).sum()
|
loss = model(x, y).sum()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ def forward(self, x_1):
|
||||||
""", # NOQA: B950
|
""", # NOQA: B950
|
||||||
)
|
)
|
||||||
|
|
||||||
fp64_model, fp64_examples = debug_utils.cast_to_fp64(fx, (x,))
|
_, fp64_examples = debug_utils.cast_to_fp64(fx, (x,))
|
||||||
self.assertEqual(fp64_examples, (x.to(torch.float64),))
|
self.assertEqual(fp64_examples, (x.to(torch.float64),))
|
||||||
|
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
|
|
@ -79,7 +79,7 @@ def forward(self, x_1):
|
||||||
_tensor_constant0
|
_tensor_constant0
|
||||||
)
|
)
|
||||||
_tensor_constant0 = None
|
_tensor_constant0 = None
|
||||||
index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor(
|
index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor( # noqa: F841
|
||||||
primals_48, [None, lift_fresh_copy]
|
primals_48, [None, lift_fresh_copy]
|
||||||
)
|
)
|
||||||
lift_fresh_copy = None
|
lift_fresh_copy = None
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||||
|
|
||||||
# This behavior is not ideal, but supporting it would add overhead
|
# This behavior is not ideal, but supporting it would add overhead
|
||||||
# to callsites of eval_frame.innermost_fn. A warning would also be very noisy.
|
# to callsites of eval_frame.innermost_fn. A warning would also be very noisy.
|
||||||
w = torch._dynamo.disable(fn=wrapper, recursive=True)
|
torch._dynamo.disable(fn=wrapper, recursive=True)
|
||||||
|
|
||||||
def test_disable_nn_modules_forward_hook(self):
|
def test_disable_nn_modules_forward_hook(self):
|
||||||
class SimpleLinear(torch.nn.Module):
|
class SimpleLinear(torch.nn.Module):
|
||||||
|
|
@ -543,7 +543,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||||
return v1, v2, v3, v4, v5, v6, v7, v8, v9
|
return v1, v2, v3, v4, v5, v6, v7, v8, v9
|
||||||
|
|
||||||
a, b, c = A(), B(), C()
|
a, b, c = A(), B(), C()
|
||||||
v1, v2, v3, v4, v5, v6, v7, v8, v9 = fn(a, b, c)
|
v1, v2, v3, v4, v5, _, v7, v8, v9 = fn(a, b, c)
|
||||||
|
|
||||||
self.assertEqual(v1, (A, 1))
|
self.assertEqual(v1, (A, 1))
|
||||||
self.assertEqual(v2, (A, 2))
|
self.assertEqual(v2, (A, 2))
|
||||||
|
|
|
||||||
|
|
@ -92,7 +92,7 @@ from user code:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
# Ensure graph break is not possible
|
# Ensure graph break is not possible
|
||||||
for i in range(3):
|
for _ in range(3):
|
||||||
comptime(f)
|
comptime(f)
|
||||||
|
|
||||||
torch.compile(fn001, backend="eager")(torch.randn(1))
|
torch.compile(fn001, backend="eager")(torch.randn(1))
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user