Fix unused Python variables in test/[a-d]* (#134665)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134665
Approved by: https://github.com/albanD
commit d25e6e623f (parent e19f493f02)
Author: Tom Ritchford
Date: 2024-12-13 18:35:20 +00:00
Committed by: PyTorch MergeBot
120 changed files with 410 additions and 522 deletions
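Most of the hunks below apply one of a few mechanical rewrites: an unused local is deleted outright, an unused loop or unpacking target is renamed to `_`, a dead assignment is dropped so only the call's side effect remains, or the assignment is kept and the lint is silenced explicitly. A minimal sketch of those patterns (the names below are invented for illustration and are not taken from this diff):

    def _examples():
        # Unused loop index: keep the loop, bind the index to `_`.
        for _ in range(3):  # was: for i in range(3):
            pass

        # Unused element of an unpacking: replace only that name with `_`.
        first, _ = (1, 2)  # was: first, second = (1, 2)

        # Value only needed for its side effect: drop the dead assignment.
        print(first)  # was: result = print(first)

        # Assignment that must stay (e.g. keeps an object alive): suppress
        # Ruff's F841 ("local variable ... assigned to but never used").
        keep_alive = object()  # noqa: F841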

View File

@@ -147,7 +147,6 @@ def _sparse_layer_test_helper(
 W_zp = 0
 X_fp32 = torch.randn(batch_size, input_channels, dtype=torch.float32)
-float_bias = torch.randn(output_channels, dtype=torch.float32)
 # generate a weight which we'll insert into the model
 W_fp32 = torch.randn(output_channels, input_channels, dtype=torch.float32)

View File

@@ -30,7 +30,6 @@ class TestQlinearPackedParams(TestCase):
 row_block_size = 1
 col_block_size = 4
 out_features = weight_fp32.shape[0]
-in_features = weight_fp32.shape[1]
 scales = [2.0, 6.0, 12.0]
 zero_points = [
@@ -201,14 +200,11 @@ class TestQlinearPackedParams(TestCase):
 row_block_size = 1
 col_block_size = 4
 out_features = weight_fp32.shape[0]
-in_features = weight_fp32.shape[1]
 scales = [2.0, 3.0, 7.0]
 zero_points = [0 for _ in range(out_features)]
 dtype = torch.qint8
-x = torch.rand(size=(1, weight_fp32.shape[1]))
 def make_lin_get_state_weight_bias_and_save():
     weight = torch.quantize_per_tensor(
         weight_fp32,

View File

@@ -86,7 +86,7 @@ class TestBaseSparsifier(TestCase):
 sparsifier0.prepare(model0, [{"tensor_fqn": "linear1.weight"}])
 mask = model0.linear1.parametrizations["weight"][0].mask
 mask.data = torch.arange(mask.shape[0] * mask.shape[1]).reshape(mask.shape)
-for step in range(step_count):
+for _ in range(step_count):
     sparsifier0.step()
 state_dict = sparsifier0.state_dict()

View File

@@ -124,7 +124,7 @@ class TestSparsityUtilFunctions(TestCase):
 list_of_modules = [m for _, m in model.named_modules()] + [model]
 for module in list_of_modules:
     module_fqn = module_to_fqn(model, module)
-    for tensor_name, tensor in module.named_parameters(recurse=False):
+    for tensor_name, _ in module.named_parameters(recurse=False):
         tensor_fqn = (
             module_fqn + ("." if module_fqn != "" else "") + tensor_name
         )

View File

@@ -269,7 +269,6 @@ class TestBaseStructuredSparsifier(TestCase):
 def _test_step_linear_on_device(self, model, device):
     model = model.to(device)
-    x = torch.ones(7, 7, device=device)
     pruner = SimplePruner(None)
     pruner.prepare(model, None)
     pruner.enable_mask_update = True
@@ -808,7 +807,7 @@ class TestBaseStructuredSparsifier(TestCase):
 pruned_model = fx_pruner.prune()
 pruned_model.eval()
 out_pruned, lstm_out_pruned = pruned_model(lstm_input)
-r, c = lstm_out_expected.size()
+_, c = lstm_out_expected.size()
 # We cannot check that y_expected == y_pruned as usual because
 # zeros vs. missing elements yield different numerical results.
@@ -891,7 +890,7 @@ class TestBaseStructuredSparsifier(TestCase):
 pruned_model = fx_pruner.prune()
 pruned_model.eval()
 out_pruned, lstm_out_pruned = pruned_model(lstm_input)
-r, c = lstm_out_expected.size()
+_, c = lstm_out_expected.size()
 # We cannot check that y_expected == y_pruned as usual because
 # zeros vs. missing elements yield different numerical results.

View File

@@ -670,7 +670,7 @@ class TestAutogradFunctional(TestCase):
 x = ctors.randn(3)
 with warnings.catch_warnings(record=True) as wa:
-    result = api(foo, x, vectorize=True)
+    api(foo, x, vectorize=True)
 self.assertEqual(len(wa), 0)
 @base_and_logging_tensor
@@ -762,7 +762,7 @@ class TestAutogradFunctional(TestCase):
 inp = ctors.rand(4)
 with self.assertRaisesRegex(RuntimeError, "not supported together"):
-    res = autogradF.jacobian(foo, inp, strict=True, vectorize=True)
+    autogradF.jacobian(foo, inp, strict=True, vectorize=True)
 @base_and_logging_tensor
 def test_jacobian_no_grad(self, ctors):
@@ -1122,7 +1122,7 @@ class TestAutogradFunctional(TestCase):
 inp = ctors.rand(4)
 with self.assertRaisesRegex(RuntimeError, "not supported together"):
-    res = autogradF.hessian(foo, inp, strict=True, vectorize=True)
+    autogradF.hessian(foo, inp, strict=True, vectorize=True)
 @base_and_logging_tensor
 def test_hessian_no_grad(self, ctors):

View File

@@ -18,7 +18,7 @@ def main():
 data = torch.randn(10, 50).cuda()
 model = Model().cuda()
 optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
-for i in range(10):
+for _ in range(10):
     optimizer.zero_grad()
     loss = model(data)
     loss.backward()

View File

@@ -78,9 +78,9 @@ def forward(self, arg0_1):
 x = torch.randn(3, device="meta")
 self.assertNotIn("my_custom_ops2", sys.modules.keys())
 with self.assertRaisesRegex(NotImplementedError, r"'my_custom_ops2'"):
-    y = torch.ops.custom.sin.default(x)
+    torch.ops.custom.sin.default(x)
 torch.ops.import_module("my_custom_ops2")
-y = torch.ops.custom.sin.default(x)
+torch.ops.custom.sin.default(x)
 def test_calling_custom_op_string(self):
     output = ops.custom.op2("abc", "def")

View File

@@ -35,7 +35,7 @@ class _TestClipGradNormBase(FSDPTest):
 vector_norm_fn = functools.partial(torch.linalg.vector_norm, ord=norm_type)
 dp_mesh = dp_mesh or init_device_mesh("cuda", (self.world_size,))
 torch.manual_seed(42 + dp_mesh.get_local_rank() + 1)
-for iter_idx in range(10):
+for _ in range(10):
     ref_optim.zero_grad()
     ref_model(inp).sum().backward()
     optim.zero_grad()

View File

@@ -250,8 +250,8 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
 self.assertEqual(group.size(), self.world_size)
 all_reduce_stream = torch.cuda.Stream()
 (
-    reduce_scatter_input,
-    reduce_scatter_event,
+    _,
+    _,
     post_reduce_event,
     _,
     _,
@@ -406,7 +406,7 @@ class TestFullyShardCommunication(FSDPTest):
 torch.manual_seed(42 + self.rank)
 inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
-for iter_idx in range(10):
+for _ in range(10):
     ref_loss = ref_model(inp).sum()
     ref_loss.backward()
     for param in ref_model.parameters():
@@ -501,7 +501,7 @@ class TestFullyShardPrefetch(FSDPTest):
 self, reshard_after_forward: Union[bool, int], checkpoint_impl: Optional[str]
 ):
 n_layers = 3
-model, optim, inp = self._init_transformer(
+model, _, inp = self._init_transformer(
     n_layers, reshard_after_forward, checkpoint_impl
 )
 events: List[EventType] = []
@@ -843,7 +843,7 @@ class TestFullyShardPrefetch(FSDPTest):
 with patch_unshard(unshard_with_record), patch_post_backward(
     post_backward_with_record
 ):
-    for iter_idx in range(3):
+    for _ in range(3):
        loss = model(inp)
        expected_events = [
            (
@@ -922,7 +922,7 @@ class TestFullyShardPrefetch(FSDPTest):
 with patch_unshard(unshard_with_record), patch_post_backward(
     post_backward_with_record
 ):
-    for iter_idx in range(3):
+    for _ in range(3):
        loss = model(inp)
        expected_events = [
            ("unshard", "", TrainingState.FORWARD),

View File

@@ -662,7 +662,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
 def __init__(self, n_layers):
     super().__init__()
     self.layers = torch.nn.ModuleList()
-    for layer_id in range(n_layers):
+    for _ in range(n_layers):
         self.layers.append(TestSubmodule(hidden_dim))
 def forward(self, x):
@@ -684,7 +684,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
 fsdp_config = {}
 mesh = init_device_mesh("cuda", (self.world_size,))
 model = TestModule(n_layers=3)
-for layer_id, mod in enumerate(model.layers):
+for mod in model.layers:
     fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
 model = fully_shard(
     model, mesh=mesh, reshard_after_forward=True, **fsdp_config
@@ -871,7 +871,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
 else:
     v.requires_grad_(False)
 assert requires_grad_param_count == n_layers * len(requires_grad_params)
-for layer_id, mod in enumerate(model.layers):
+for _, mod in enumerate(model.layers):
     fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
 model = fully_shard(
     model, mesh=mesh, reshard_after_forward=True, **fsdp_config
@@ -1087,7 +1087,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
 setattr(m.encoder, name, new_child)
 m = FSDP(m, sharding_strategy=ShardingStrategy.FULL_SHARD, use_orig_params=True)
 inp = torch.randn(32, 784, device="cuda")
-out = m(inp)
+m(inp)
 if __name__ == "__main__":

View File

@@ -241,7 +241,7 @@ class TestFullyShardAllGatherExtensionsMultiProcess(
 losses.append(_model(inp).sum())
 losses[-1].backward()
 if _model is ref_model:
-    for param_name, param in _model.named_parameters():
+    for _, param in _model.named_parameters():
         dist.all_reduce(param.grad)
         param.grad.detach().div_(self.world_size)
 self.assertEqual(losses[0], losses[1])

View File

@@ -904,7 +904,7 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
 )
 self.assertEqual(mesh.mesh, ref_mesh.mesh)
 self.assertEqual(mesh._coordinate_on_dim, ref_mesh._coordinate_on_dim)
-for (tag, ranks, group_name), (ref_tag, ref_ranks, ref_group_name) in zip(
+for (_, ranks, _), (_, ref_ranks, _) in zip(
     mesh._dim_group_infos, ref_mesh._dim_group_infos
 ):
     # Since we manually constructed new subgroups, the test and ref

View File

@@ -26,7 +26,7 @@ class LoggingTests(LoggingTestCase):
 env["WORLD_SIZE"] = "1"
 env["MASTER_PORT"] = "34715"
 env["MASTER_ADDR"] = "localhost"
-stdout, stderr = self.run_process_no_exception(
+_, stderr = self.run_process_no_exception(
     """\
 import logging
 import torch

View File

@@ -590,7 +590,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
 torch.manual_seed(42 + self.rank)
 inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
-for iter_idx in range(10):
+for _ in range(10):
     losses: List[torch.Tensor] = []
     for _model, _optim in ((ref_model, ref_optim), (model, optim)):
         _optim.zero_grad()
@@ -624,12 +624,12 @@ class TestFullyShard1DTrainingCore(FSDPTest):
 # sync point after each iteration
 ref_losses: List[torch.Tensor] = []
 losses: List[torch.Tensor] = []
-for iter_idx in range(10):
+for _ in range(10):
     ref_optim.zero_grad()
     ref_losses.append(ref_model(inp).sum())
     ref_losses[-1].backward()
     ref_optim.step()
-for iter_idx in range(10):
+for _ in range(10):
     optim.zero_grad()
     losses.append(model(inp).sum())
     losses[-1].backward()
@@ -1185,7 +1185,7 @@ class TestFullyShardNDTraining(FSDPTest):
 foreach: bool,
 ):
 global_mesh = self.init_global_mesh()
-pp_mesh, dp_mesh, tp_mesh = (
+_, dp_mesh, tp_mesh = (
     global_mesh["pp"],
     global_mesh["dp"],
     global_mesh["tp"],
@@ -1217,7 +1217,7 @@ class TestFullyShardNDTraining(FSDPTest):
 _optim.step()
 self.assertEqual(losses[0], losses[1])
-for n, p in model.named_parameters():
+for _, p in model.named_parameters():
     self.assertIsInstance(p, DTensor)
     self.assertEqual(p.device_mesh.ndim, 2)
     self.assertEqual(len(p.placements), 2)
@@ -1288,7 +1288,7 @@ class TestFullyShardHSDP3DTraining(FSDPTest):
 _optim.step()
 self.assertEqual(losses[0], losses[1])
-for n, p in model.named_parameters():
+for _, p in model.named_parameters():
     self.assertIsInstance(p, DTensor)
     self.assertEqual(p.device_mesh.ndim, 3)
     self.assertEqual(len(p.placements), 3)

View File

@@ -119,7 +119,6 @@ class TestCheckpoint(TestCase):
 # no checkpoint
 with MemoryDelta(x.device) as mem1:
     loss1 = net1(x1).sum()
-graph_size1 = self._get_graph_size(loss1)
 loss1.backward()
 # with checkpoint

View File

@@ -244,7 +244,6 @@ class TestFullyShard2DTraining(FSDPTest):
 ref_model.parameters(), model.named_parameters()
 ):
     full_grad = param.grad.full_tensor()
-    ref_grad = ref_param.grad
     self.assertEqual(ref_param.grad, full_grad)
 ref_optim.step()
@@ -285,7 +284,7 @@ class TestFullyShard2DTraining(FSDPTest):
 # called, but they will just be no-ops without issuing any kernels.
 # We prefer to keep the no-op check at the c10d level, not in FSDP.
 inp = torch.randn((4, mlp_dim), device="cuda") # same on all ranks
-for iter_idx in range(10):
+for _ in range(10):
     ref_optim.zero_grad()
     optim.zero_grad()
@@ -583,9 +582,7 @@ class TestNew2dParallelTraining(DTensorTestBase):
 "net1": ColwiseParallel(),
 "net2": RowwiseParallel(),
 }
-model_2d = parallelize_module(
-    SimpleModel().cuda(), mesh_2d["tp"], parallelize_plan
-)
+parallelize_module(SimpleModel().cuda(), mesh_2d["tp"], parallelize_plan)
 @with_comms
 @skip_if_lt_x_gpu(4)
@@ -833,7 +830,6 @@ class TestNew2dParallelStateDict(DTensorTestBase):
 # Create a model without wrapper
 torch.manual_seed(0)
 no_wrap_model = simple_model().cuda(self.rank)
-no_wrap_state_dict = no_wrap_model.state_dict()
 no_wrap_optim = torch.optim.Adam(no_wrap_model.parameters(), lr=0.01)
 no_wrap_model(no_wrap_model.get_input().cuda(self.rank)).sum().backward()
 no_wrap_optim.step()
@@ -890,8 +886,6 @@ class TestNew2dParallelStateDict(DTensorTestBase):
 set_optimizer_state_dict(
     model_2d, optimizers=optim_2d, optim_state_dict=ref_optim_2d_osd
 )
-new_optim_2d_osd = get_optimizer_state_dict(model_2d, optimizers=optim_2d)
 ref_optim_2d_osd_states = ref_optim_2d_osd["state"]
 new_optim_2d_osd_states = optim_2d_osd["state"]

View File

@@ -119,7 +119,7 @@ class ComposabilityTest(MultiProcessTestCase):
 )
 @parametrize("use_new_runtime", [False, True])
 def test_manual_with_data_parallel(self, dp_type, ScheduleClass, use_new_runtime):
-    device = torch.device("cuda", self.device)
+    _device_raii = torch.device("cuda", self.device)
     torch.cuda.set_device(self.device)
     store = torch.distributed.FileStore(self.file_name, self.world_size)
     torch.distributed.init_process_group(
@@ -398,7 +398,7 @@ class ComposabilityTest(MultiProcessTestCase):
 ],
 )
 def test_3d_with_tp_dp_pp(self, ScheduleClass, MixedPrecisionParam):
-    device = torch.device("cuda", self.device)
+    _device_raii = torch.device("cuda", self.device)
     torch.cuda.set_device(self.device)
     store = torch.distributed.FileStore(self.file_name, self.world_size)
     torch.distributed.init_process_group(

View File

@@ -329,11 +329,11 @@ class ReplicateTest(MultiProcessInductorTestCase):
 code = self._test_bucketing()
 self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
 fc = FileCheck()
-for i in range(3):
+for _ in range(3):
     fc.check("cpp_fused_").check(
         "torch.ops._c10d_functional.all_reduce_coalesced_.default("
     )
-for i in range(3):
+for _ in range(3):
     fc.check("torch.ops._c10d_functional.wait_tensor.default")
 fc.run(code)
@@ -342,11 +342,11 @@ class ReplicateTest(MultiProcessInductorTestCase):
 code = self._test_bucketing(init_process_group=False, loop=2)
 self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
 fc = FileCheck()
-for i in range(3):
+for _ in range(3):
     fc.check("cpp_fused_").check(
         "torch.ops._c10d_functional.all_reduce_coalesced_.default("
     )
-for i in range(3):
+for _ in range(3):
     fc.check("torch.ops._c10d_functional.wait_tensor.default")
 fc.run(code)
@@ -371,11 +371,11 @@ class ReplicateTest(MultiProcessInductorTestCase):
 code = self._test_bucketing()
 self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
 fc = FileCheck()
-for i in range(3):
+for _ in range(3):
     fc.check("aten.flatten.using_ints(").check("cpp_fused_").check(
         "torch.ops._c10d_functional.all_reduce_.default("
     )
-for i in range(3):
+for _ in range(3):
     fc.check("torch.ops._c10d_functional.wait_tensor.default")
 fc.run(code)
@@ -383,11 +383,11 @@ class ReplicateTest(MultiProcessInductorTestCase):
 code = self._test_bucketing(init_process_group=False, loop=2)
 self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
 fc = FileCheck()
-for i in range(3):
+for _ in range(3):
     fc.check("aten.flatten.using_ints(").check("cpp_fused_").check(
         "torch.ops._c10d_functional.all_reduce_.default("
     )
-for i in range(3):
+for _ in range(3):
     fc.check("torch.ops._c10d_functional.wait_tensor.default")
 fc.run(code)

View File

@@ -129,7 +129,7 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase):
 def test_torch_equal(self):
     """Test torch.equal(ShardedTensor, ShardedTensor)"""
-    spec, alt_spec = self.get_gpu_specs()
+    spec, _ = self.get_gpu_specs()
     st1, st2 = self.get_random_tensors(spec, spec, 10, 10)
     self.assertTrue(torch.equal(st1, st2))
@@ -145,7 +145,7 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase):
 def test_torch_allclose(self):
     """Test torch.allclose(ShardedTensor, ShardedTensor)"""
-    spec, alt_spec = self.get_gpu_specs()
+    spec, _ = self.get_gpu_specs()
     st1, st2 = self.get_random_tensors(spec, spec, 10, 10)
     self.assertTrue(torch.allclose(st1, st2))

View File

@@ -40,8 +40,6 @@ class TestShardedTensorNNInit(ShardedTensorTestBase):
 ],
 )
 h, w = 8, 2
-expected_h = 2
-expected_device = torch.device(f"cuda:{self.rank}")
 a, b = 10, 20
 seed = 1234
@@ -75,8 +73,6 @@ class TestShardedTensorNNInit(ShardedTensorTestBase):
 ],
 )
 h, w = 8, 2
-expected_h = 2
-expected_device = torch.device(f"cuda:{self.rank}")
 mean, std = 10, 5
 seed = 1234
@@ -110,8 +106,6 @@ class TestShardedTensorNNInit(ShardedTensorTestBase):
 ],
 )
 h, w = 8, 2
-expected_h = 2
-expected_device = torch.device(f"cuda:{self.rank}")
 a, mode, nonlinearity = 0, "fan_in", "leaky_relu"
 seed = 1234

View File

@@ -456,7 +456,7 @@ class TestLocalTensor(ShardedTensorTestBase):
 with self.assertRaisesRegex(
     NotImplementedError, "Only single local shard is supported."
 ):
-    local_shard = st.local_tensor()
+    st.local_tensor()
 class TestShardedTensorChunked(ShardedTensorTestBase):
@@ -981,7 +981,6 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
 # Validate remote shards.
 remote_shards = st.remote_shards()
 self.assertEqual(3, len(remote_shards))
-owners = {}
 for rpc_rank, shards in remote_shards.items():
     self.assertEqual(2, len(shards))
     for remote_shard in shards:
@@ -1364,14 +1363,14 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
 with self.assertRaisesRegex(RuntimeError, "Local rank at save time was"):
     with load_with_process_group(pg):
         # ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
-        state_dict_deser = torch.load(buffer, weights_only=False)
+        torch.load(buffer, weights_only=False)
 else:
     with self.assertRaisesRegex(
         RuntimeError, "Local world size at save time was"
     ):
         with load_with_process_group(pg):
             # ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
-            state_dict_deser = torch.load(buffer, weights_only=False)
+            torch.load(buffer, weights_only=False)
 dist.destroy_process_group()
 buffer.seek(0)
@@ -1379,7 +1378,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
 RuntimeError, "Need to initialize default process group"
 ):
     # ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
-    state_dict_deser = torch.load(buffer, weights_only=False)
+    torch.load(buffer, weights_only=False)
 rpc.shutdown()
 @with_comms
@@ -1396,8 +1395,8 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
 "rank:3/cuda:3",
 ],
 )
-st1 = sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
-st2 = sharded_tensor.empty(spec, 10, 20)
+sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
+sharded_tensor.empty(spec, 10, 20)
 create_tensors()
 self.assertEqual(0, len(sharded_tensor.api._sharded_tensor_map))
@@ -2204,7 +2203,6 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase):
 else:
     self.assertEqual(2, len(remote_shards))
-owners = {}
 for rpc_rank, shards in remote_shards.items():
     self.assertEqual(2, len(shards))
     for remote_shard in shards:
@@ -2418,10 +2416,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
 placement=f"rank:{self.rank}/cuda:{self.rank}",
 )
 with self.assertRaisesRegex(ValueError, "Shard tensor size does not match"):
-    local_shard_from_wrong_meta = sharded_tensor.Shard(
-        local_tensor,
-        metadata=wrong_local_shard_metadata,
-    )
+    sharded_tensor.Shard(local_tensor, metadata=wrong_local_shard_metadata)
 @with_comms
 @skip_if_lt_x_gpu(4)
@@ -2696,7 +2691,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
 empty_local_shards = []
 with self.assertRaisesRegex(ValueError, "have no local shards on all ranks"):
-    st = sharded_tensor.init_from_local_shards(
+    sharded_tensor.init_from_local_shards(
         empty_local_shards, [10, 10], init_rrefs=True
     )
@@ -2706,7 +2701,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
 with self.assertRaisesRegex(
     ValueError, "Only torch.strided layout is currently supported"
 ):
-    st = sharded_tensor.init_from_local_shards(
+    sharded_tensor.init_from_local_shards(
         wrong_layout_shards, [10, 10], init_rrefs=True
     )
@@ -2719,23 +2714,19 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
 ValueError,
 "Only torch.contiguous_format memory_format is currently supported",
 ):
-    st = sharded_tensor.init_from_local_shards(
+    sharded_tensor.init_from_local_shards(
         wrong_memory_format_shards, [10, 10], init_rrefs=True
     )
 with self.assertRaisesRegex(ValueError, "Shard tensor size does not match"):
-    wrong_size_shards = [
-        sharded_tensor.Shard(
-            torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata
-        )
-    ]
+    sharded_tensor.Shard(
+        torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata
+    )
 with self.assertRaisesRegex(
     ValueError, "Local shard tensor device does not match"
 ):
-    wrong_device_shards = [
-        sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata)
-    ]
+    sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata)
 @with_comms
 @skip_if_lt_x_gpu(4)
@@ -2756,7 +2747,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
 ValueError,
 "ShardedTensor global_size property does not match from different ranks!",
 ):
-    st = sharded_tensor.init_from_local_shards(
+    sharded_tensor.init_from_local_shards(
         wrong_dtype_shards, tensor_overall_size, init_rrefs=True
     )
@@ -2771,7 +2762,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
 ValueError,
 "ShardedTensor dtype property does not match from different ranks!",
 ):
-    st = sharded_tensor.init_from_local_shards(
+    sharded_tensor.init_from_local_shards(
         wrong_dtype_shards, [10, 10], init_rrefs=True
     )
@@ -2788,7 +2779,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
 ValueError,
 "ShardedTensor requires_grad property does not match from different ranks!",
 ):
-    st = sharded_tensor.init_from_local_shards(
+    sharded_tensor.init_from_local_shards(
         wrong_requires_grad_shards, [10, 10], init_rrefs=True
     )
@@ -2818,7 +2809,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
 with self.assertRaisesRegex(
     ValueError, "Local shards' tensor pin_memory property need to be the same"
 ):
-    st = sharded_tensor.init_from_local_shards(
+    sharded_tensor.init_from_local_shards(
         wrong_pin_memory_local_shards, [10, 10], init_rrefs=True
     )
@@ -2832,7 +2823,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
 ValueError,
 "ShardedTensor pin_memory property does not match from different ranks!",
 ):
-    st = sharded_tensor.init_from_local_shards(
+    sharded_tensor.init_from_local_shards(
         wrong_pin_memory_shards_cross_ranks, [10, 10], init_rrefs=True
     )
@@ -2945,19 +2936,15 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
 with self.assertRaisesRegex(
     ValueError, "Shard tensor size does not match with metadata.shard_lengths"
 ):
-    wrong_size_shards = [
-        sharded_tensor.Shard(
-            torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata
-        )
-    ]
+    sharded_tensor.Shard(
+        torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata
+    )
 with self.assertRaisesRegex(
     ValueError,
     "Local shard tensor device does not match with local Shard's placement",
 ):
-    wrong_device_shards = [
-        sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata)
-    ]
+    sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata)
 wrong_dtype_shards = [
     sharded_tensor.Shard(

View File

@@ -42,7 +42,7 @@ class ChunkAllShardingPlanner(ShardingPlanner):
 def build_plan(self, module: nn.Module) -> ShardingPlan:
     named_params = module.named_parameters()
     plan = {}
-    for name, param in named_params:
+    for name, _ in named_params:
         plan[name] = ChunkShardingSpec(self.dim, placements=self.devices)
     return ShardingPlan(plan=plan)

View File

@@ -92,7 +92,6 @@ class TestCommMode(TestCase):
 self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 1)
 def test_comm_mode_with_dtensor(self):
-    world_pg = self.world_pg
     mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
     def f(x, y):
@@ -118,8 +117,6 @@ class TestCommMode(TestCase):
 if not torch.cuda.is_available():
     return
-world_pg = self.world_pg
 inp = torch.rand(2, 8, 16).cuda()
 all_gather_out = inp.new_empty(self.world_size * 2, 8, 16)
@@ -202,7 +199,7 @@ class TestCommMode(TestCase):
 self.checksAssert(comm_mode, c10d_ops.reduce_scatter_, 1, 1)
 # tests c10d reduce_scatter_tensor_coalesced
-with comm_mode as A, dist._coalescing_manager() as B:
+with comm_mode, dist._coalescing_manager():
     dist.reduce_scatter_tensor(all_gather_out, inp)
 self.checksAssert(comm_mode, c10d_ops.reduce_scatter_tensor_coalesced_, 1, 1)

View File

@@ -251,7 +251,7 @@ class TestCommModeFeatures(DTensorTestBase):
 comm_mode.comm_module_counts,
 {"Global": {"forward": {}, "backward": {}}},
 )
-output_tp = model(inp)
+model(inp)
 model_args = ModelArgs(dropout_p=0.0)
 model2 = Transformer(model_args).to(device=self.device_type)
@@ -264,7 +264,7 @@ class TestCommModeFeatures(DTensorTestBase):
 comm_mode = CommDebugMode()
 with comm_mode:
-    output = model2(inp)
+    model2(inp)
 # checks to see if all collectives were correctly traced at the module-level
 self.assertEqual(

View File

@@ -155,14 +155,12 @@ class DTensorTest(DTensorTestBase):
 device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
 shard0_spec = [Shard(0)]
 local_tensor = torch.randn(4, 8)
-global_shape = torch.Size([self.world_size * 4, 8])
 dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard0_spec)
 # won't affect stride
 self.assertEqual(dist_tensor.stride(), (8, 1))
 shard1_spec = [Shard(1)]
 local_tensor = torch.randn(8, 4)
-global_shape = torch.Size([8, self.world_size * 4])
 dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard1_spec)
 # will affect stride after DT initialized
 self.assertEqual(dist_tensor.stride(), (4 * self.world_size, 1))
@@ -170,7 +168,6 @@ class DTensorTest(DTensorTestBase):
 # if initialized from a transposed mat
 local_tensor = torch.randn(8, 4, 8)
 local_tensor_t = local_tensor.permute(1, 2, 0)
-global_shape = torch.Size([4, self.world_size * 8, 8])
 self.assertEqual(local_tensor_t.stride(), (8, 1, 32))
 dist_tensor = DTensor.from_local(local_tensor_t, device_mesh, shard1_spec)
 global_stride = (8 * self.world_size, 1, 32 * self.world_size)
@@ -257,7 +254,7 @@ class DTensorTest(DTensorTestBase):
 with self.assertRaisesRegex(
     RuntimeError, "Please pass both shape and stride at the same time."
 ):
-    dtensor = DTensor.from_local(
+    DTensor.from_local(
         tensor_list[self.rank],
         device_mesh,
         (Shard(0),),
@@ -267,7 +264,7 @@ class DTensorTest(DTensorTestBase):
 with self.assertRaisesRegex(
     RuntimeError, "Please pass both shape and stride at the same time."
 ):
-    dtensor = DTensor.from_local(
+    DTensor.from_local(
         tensor_list[self.rank],
         device_mesh,
         (Shard(0),),
@@ -1043,7 +1040,7 @@ class DTensorLogTest(LoggingTestCase):
 env["MASTER_PORT"] = "12345"
 env["MASTER_ADDR"] = "localhost"
-stdout, stderr = self.run_process_no_exception(
+_, stderr = self.run_process_no_exception(
     """\
 import logging
 import torch

View File

@@ -234,8 +234,8 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
 requires_grad=x.requires_grad,
 )
-out = fn(x)
-out2 = torch.compile(fn, backend="eager")(x)
+fn(x)
+torch.compile(fn, backend="eager")(x)
 def test_dtensor_constructor_w_dynamo_disable(self):
     mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@@ -599,7 +599,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
 @torch.compile(backend=cnt)
 def fn(x):
-    dt = DTensor.from_local(x, mesh, [placement], run_check=False)
+    DTensor.from_local(x, mesh, [placement], run_check=False)
 x = torch.ones(4, 4, requires_grad=True)
@@ -659,7 +659,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
 x2 = x_dt.redistribute(mesh, [Replicate()], async_op=True)
 x2 = x2.to_local()
 self.assertTrue(isinstance(x2, AsyncCollectiveTensor))
-out = opt_fn(x2)
+opt_fn(x2)
 # The important part: we get a wait_tensor() in the graph.
 # At runtime, the input to the graph is an AsyncCollectiveTensor,
 # and inside the graph we need to issue a wait() to synchronize.
@@ -880,8 +880,6 @@ class TestDTensorCompileE2E(DTensorTestBase):
 mesh_dim_names=["dp", "tp"],
 )
-fsdp_pg = twod_mesh.get_group(mesh_dim=0)
 inp = torch.rand(20, 10, device=self.device_type)
 parallelize_plan = {
     "mlp_0.net1": ColwiseParallel(),

View File

@@ -249,7 +249,7 @@ class DistTensorRandomOpTest(DTensorTestBase):
 device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
 # seed synchronization happens after the first `distribute_tensor` call
-dtensor = distribute_tensor(
+distribute_tensor(
     torch.empty([self.world_size], device="cuda"), device_mesh, [Shard(0)]
 )
 self.assertEqual(seed_from_rank_0, random._rng_tracker.get_seed("parallel-rng"))

View File

@@ -309,7 +309,7 @@ class RedistributeTest(DTensorTestBase):
 shard_tensor = distribute_tensor(local_tensor, device_mesh, shard_spec)
 self.assertEqual(shard_tensor.placements[0].dim, 1)
 reshard_tensor = shard_tensor.redistribute(device_mesh, shard_minus_spec)
-self.assertEqual(shard_tensor.placements[0].dim, 1)
+self.assertEqual(reshard_tensor.placements[0].dim, 1)
 @with_comms
 def test_redistribute_uneven_sharding(self):

View File

@@ -622,7 +622,7 @@ class DistTensorOpsTest(DTensorTestBase):
 self.assertEqual(misses, 2)
 # convert to fp32 again and see if there's cache hit
-fp32_sharded_dtensor1 = bf16_sharded_dtensor1.float()
+bf16_sharded_dtensor1.float()
 hits, misses, _, _ = _get_sharding_prop_cache_info()
 # by now we should have cache hit
 self.assertEqual(hits, 1)

View File

@@ -133,7 +133,6 @@ class UtilTest(DTensorTestBase):
 global_tensor_shape, global_mesh, placements
 )
 assert global_mesh.get_coordinate is not None
-dp_replic_rank = global_mesh.get_local_rank("dp_replic")
 dp_shard_rank = global_mesh.get_local_rank("dp_shard")
 tp_rank = global_mesh.get_local_rank("tp")
 shard_idx_on_dim_0 = tp_rank * dp_shard_size + dp_shard_rank

View File

@@ -150,7 +150,7 @@ class DTensorXLAIntegrationTest(TestCase):
 shard_spec = [Shard(0)]
 # annoate fc1 and fc2
 if isinstance(mod, nn.Linear):
-    for name, param in mod.named_parameters():
+    for _, param in mod.named_parameters():
         # annotate the parameter tensors directly
         distribute_tensor(param, mesh, shard_spec)

View File

@@ -1,4 +1,5 @@
 # Owner(s): ["oncall: distributed"]
+# ruff: noqa: F841
 import os
 import sys
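Unlike the rest of the PR, this file keeps its unused locals and instead opts the whole file out of the check: Ruff reads a `# ruff: noqa: F841` comment at the top of a file as a file-wide exemption for that rule, whereas a trailing `# noqa: F841` only exempts the line it sits on. A generic sketch of the two forms (not code from this file):

    # ruff: noqa: F841  (file-level: every F841 finding in this file is ignored)

    def _sketch():
        unused_a = 1  # not flagged, thanks to the file-level directive above
        unused_b = 2  # noqa: F841  (line-level form, when only one line should be exempt)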

View File

@@ -277,7 +277,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
 self.assertEqual(loss, dist_loss)
 dist_msd, dist_osd = get_state_dict(dist_model, optimizers=dist_optim)
-model_sd, optim_sd = get_state_dict(model, optimizers=optim)
+model_sd, _ = get_state_dict(model, optimizers=optim)
 self._verify_msd(model_sd, dist_msd)
 self._verify_osd_by_load(model, optim, self._optim(model), dist_osd)

View File

@@ -96,7 +96,7 @@ class TestFineTuning(DTensorTestBase):
 optim = torch.optim.Adam(model.parameters(), lr=1e-3)
 # Training
-for i in range(3):
+for _ in range(3):
     batch = torch.rand(32, DIM, device="cuda")
     loss = model(batch).sum()
     loss.backward()
@@ -161,7 +161,7 @@ class TestFineTuning(DTensorTestBase):
 self.assertEqual(i, 0)
 # Training
-for j in range(3):
+for _ in range(3):
     batch = torch.rand(32, DIM, device="cuda")
     loss = model(batch).sum()
     loss.backward()

View File

@@ -85,11 +85,9 @@ class TestDistributedCheckpointing(ShardedTensorTestBase):
 )
 st = sharded_tensor.zeros(spec, 4, 4, dtype=torch.float64)
-mapping = {}
 md = _create_default_local_metadata({"st": st})
 st_md = md.state_dict_metadata["st"]
 self.assertEqual(1, len(st_md.chunks))
 @with_comms(init_rpc=False)

View File

@@ -86,7 +86,6 @@ class TestFsdpTpCheckpointConversion(DTensorTestBase):
 tp_model.load_state_dict(tp_state_dict)
 # Check parameters are equal after loading.
-tp_state_dict_after_load = tp_model.state_dict()
 for fsdp_item, tp_item in zip(fsdp_state_dict.items(), tp_state_dict.items()):
     fsdp_k, fsdp_v = fsdp_item
     tp_k, tp_v = tp_item

View File

@@ -120,7 +120,6 @@ class TestHSDPCheckpoint(DTensorTestBase):
 )
 model.load_state_dict(state_dict_to_save["model"])
-state_dict_after_load = model.state_dict()
 # After loading, the current model state dict should be the same as state_dict_to_save.
 for (k1, v1), (k2, v2) in zip(
     state_dict_to_save["model"].items(), model.state_dict().items()

View File

@@ -43,7 +43,7 @@ class TestFlattening(TestCase):
 "k3": ["x", 99, [{"k3": "y"}]],
 }
-flatten_dict, mapping = flatten_state_dict(state_dict)
+_, mapping = flatten_state_dict(state_dict)
 """
 flatten_dict:
 {'k0': [1], 'k2.0': tensor([1]), 'k2.1': 99, 'k2.2.0.k3': tensor(1), 'k3': ['x', 99, [{'k3': 'y'}]]}

View File

@@ -40,21 +40,19 @@ class TestSaveAndLoadAPI(DTensorTestBase):
 device_mesh = init_device_mesh(self.device_type, (self.world_size,))
 model = FSDP(model, device_mesh=device_mesh)
 dcp.save(model.state_dict(), checkpoint_id=os.path.join(self.temp_dir, "first"))
-sd = dcp.load(
-    model.state_dict(), checkpoint_id=os.path.join(self.temp_dir, "first")
-)
+dcp.load(model.state_dict(), checkpoint_id=os.path.join(self.temp_dir, "first"))
 with patch.object(
     dcp.FileSystemReader, "validate_checkpoint_id", return_value=False
-) as m1:
+):
     with patch.object(
         dcp.FileSystemWriter, "validate_checkpoint_id", return_value=False
-    ) as m2:
+    ):
         dcp.save(
             model.state_dict(),
             checkpoint_id=os.path.join(self.temp_dir, "second"),
         )
-        sd = dcp.load(
+        dcp.load(
             model.state_dict(),
             checkpoint_id=os.path.join(self.temp_dir, "second"),
         )
@@ -62,7 +60,7 @@ class TestSaveAndLoadAPI(DTensorTestBase):
 with self.assertRaisesRegex(RuntimeError, "Cannot detect"):
     dcp.save(model.state_dict(), checkpoint_id="abc://abc.abc")
 with self.assertRaisesRegex(RuntimeError, "Cannot detect"):
-    sd = dcp.load(model.state_dict(), checkpoint_id="abc://abc.abc")
+    dcp.load(model.state_dict(), checkpoint_id="abc://abc.abc")
 if __name__ == "__main__":

View File

@@ -81,7 +81,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
 # Train 10 steps.
 _dist_optim = [dist_optim] if not isinstance(dist_optim, list) else dist_optim
-for i in range(10):
+for _ in range(10):
     optim.zero_grad()
     for d_optim in _dist_optim:
         d_optim.zero_grad()

View File

@@ -104,7 +104,7 @@ class TestStateDictUtils(DTensorTestBase):
 return tensor, dist_tensor
 ltensor, ldtensor = [], []
-for i in range(10):
+for _ in range(10):
     tensor, dtensor = create_dtensor()
     ltensor.append(tensor)
     ltensor.append(torch.ones(10, device=torch.device("cuda")))

View File

@@ -259,7 +259,7 @@ class _StartProcessesTest(TestCase):
 ) -> None:
 mp_queue = mp.get_context("spawn").Queue()
 child_nproc = 2
-ctx = mp.spawn(
+mp.spawn(
     start_processes_zombie_test,
     nprocs=1,
     args=(entrypoint, mp_queue, self.log_dir(), child_nproc),

View File

@@ -165,7 +165,7 @@ class CreateBackendTest(TestCase):
 def test_create_backend_returns_backend_if_is_host_is_not_specified_and_store_already_exists(
     self,
 ) -> None:
-    store = TCPStore( # type: ignore[call-arg] # noqa: F841
+    TCPStore( # type: ignore[call-arg] # noqa: F841
         self._expected_endpoint_host, self._expected_endpoint_port, is_master=True
     )

View File

@@ -99,7 +99,7 @@ class RendezvousTimeoutTest(TestCase):
 ValueError,
 rf"^The join timeout \({join_timeout}\) must be positive.$",
 ):
-    timeout = RendezvousTimeout(join_timeout)
+    RendezvousTimeout(join_timeout)
 class NodeDescTest(TestCase):
@@ -1637,7 +1637,7 @@ class CreateHandlerTest(TestCase):
 def _ignore_exception(exception_type: Exception, fn: Callable):
     try:
         fn()
-    except exception_type as e:
+    except exception_type:
         pass

View File

@@ -70,7 +70,7 @@ class RendezvousBackendTestMixin(ABC):
 self.assertTrue(has_set)
 def test_set_state_sets_backend_state_if_token_is_current(self) -> None:
-    state1, token1, has_set1 = self._set_state(b"x")
+    _, token1, has_set1 = self._set_state(b"x")
     state2, token2, has_set2 = self._set_state(b"y", token1)
@@ -80,7 +80,7 @@ class RendezvousBackendTestMixin(ABC):
 self.assertTrue(has_set2)
 def test_set_state_returns_current_backend_state_if_token_is_old(self) -> None:
-    state1, token1, _ = self._set_state(b"x")
+    _, token1, _ = self._set_state(b"x")
     state2, token2, _ = self._set_state(b"y", token1)

View File

@@ -113,7 +113,7 @@ if not (IS_WINDOWS or IS_MACOS):
 num_clients = 10
 num_requests_per_client = 10
 processes = []
-for i in range(num_clients):
+for _ in range(num_clients):
     p = mp.Process(
         target=func, args=(num_requests_per_client, self.file_path)
     )
@@ -190,7 +190,7 @@ if not (IS_WINDOWS or IS_MACOS):
 """
 client = timer.FileTimerClient(file_path)
 sem.release()
-for i in range(0, n):
+for _ in range(0, n):
     client.acquire("test_scope", 0)
     time.sleep(interval)

View File

@@ -159,7 +159,7 @@ class CheckpointWrapperTest(TestCase):
 if use_reentrant
 else CheckpointImpl.NO_REENTRANT,
 )
-for i in range(self.n):
+for _ in range(self.n):
     l = nn.Sequential(
         nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256)
     )

View File

@@ -303,13 +303,13 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
 RuntimeError, "DeviceMesh is not compatible with LOCAL_STATE_DICT."
 ):
     with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
-        state_dict = model.state_dict()
+        model.state_dict()
 with self.assertRaisesRegex(
     RuntimeError, "DeviceMesh is not compatible with LOCAL_STATE_DICT."
 ):
     with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
-        optim_state_dict = FSDP.optim_state_dict(model, optim)
+        FSDP.optim_state_dict(model, optim)
 instantiate_parametrized_tests(TestFSDPWithDeviceMeshAndDTensor)

View File

@@ -364,9 +364,8 @@ class TestFSDPFineTune(FSDPTest):
) )
torch.manual_seed(self.rank + 1) torch.manual_seed(self.rank + 1)
losses = [] losses = []
for idx in range(6): for _ in range(6):
frozen_input = torch.randn((4, 4), device="cuda", requires_grad=False) frozen_input = torch.randn((4, 4), device="cuda", requires_grad=False)
learnable_input = torch.randn((4, 4), device="cuda", requires_grad=True)
for _model, _optim in ((model, model_optim), (ref_model, ref_model_optim)): for _model, _optim in ((model, model_optim), (ref_model, ref_model_optim)):
loss = _model(frozen_input, frozen_input).sum() loss = _model(frozen_input, frozen_input).sum()
losses.append(loss) losses.append(loss)

View File

@@ -182,7 +182,7 @@ class TestFreezingWeights(FSDPTest):
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9) optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
for iteration in range(3): for _ in range(3):
out = model(batch) out = model(batch)
fake_loss = criterion(out, target) fake_loss = criterion(out, target)
optimizer.zero_grad() optimizer.zero_grad()

View File

@@ -108,8 +108,6 @@ class TestFSDPMemory(FSDPTest):
def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations): def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations):
gpu_id = self.rank gpu_id = self.rank
world_size = self.world_size
batch = torch.randn(size=(2, 3, 224, 224)).cuda() batch = torch.randn(size=(2, 3, 224, 224)).cuda()
model = create_model( model = create_model(

View File

@@ -278,9 +278,9 @@ class TestFSDPMiscMultiProcess(FSDPTest):
) )
x = torch.randn(10, 10, device="cuda") x = torch.randn(10, 10, device="cuda")
y = torch.randn(10, 10, device="cuda") y = torch.randn(10, 10, device="cuda")
for i in range(4): for _ in range(4):
if use_second_layer: if use_second_layer:
a, b = fsdp(x, y) a, _ = fsdp(x, y)
else: else:
a = fsdp(x, y) a = fsdp(x, y)
loss = a.sum() loss = a.sum()
@@ -509,7 +509,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
def test_fsdp_cpu_training(self): def test_fsdp_cpu_training(self):
"""Tests FSDP training on CPU.""" """Tests FSDP training on CPU."""
gloo_pg = dist.new_group(backend="gloo") gloo_pg = dist.new_group(backend="gloo")
for ss in [ for ss in [ # noqa: F841
ShardingStrategy.NO_SHARD, ShardingStrategy.NO_SHARD,
ShardingStrategy.FULL_SHARD, ShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP, ShardingStrategy.SHARD_GRAD_OP,
@@ -857,13 +857,13 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
torch.cuda.set_device(self.rank) torch.cuda.set_device(self.rank)
# Test CPU # Test CPU
no_params = nn.ReLU() no_params = nn.ReLU()
module = FSDP(no_params) FSDP(no_params)
# Test CUDA # Test CUDA
no_params = nn.ReLU().cuda() no_params = nn.ReLU().cuda()
module = FSDP(no_params) FSDP(no_params)
# Test CPU + device_id # Test CPU + device_id
no_params = nn.ReLU() no_params = nn.ReLU()
module = FSDP(no_params, device_id=torch.cuda.current_device()) FSDP(no_params, device_id=torch.cuda.current_device())
# For modules with no params, wrong device_id will raise error about # For modules with no params, wrong device_id will raise error about
# inconsistency between compute_device and device_id, since compute_device # inconsistency between compute_device and device_id, since compute_device
# is computed as torch.cuda.current_device when there are no params. # is computed as torch.cuda.current_device when there are no params.

View File

@@ -1139,7 +1139,6 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
model = SaveForwardInputsModel( model = SaveForwardInputsModel(
forward_inputs=forward_inputs, cast_forward_inputs=False forward_inputs=forward_inputs, cast_forward_inputs=False
).cuda() ).cuda()
c1, c2 = model.c1, model.c2
x = torch.zeros(2, 100, device="cuda") x = torch.zeros(2, 100, device="cuda")
# float16 on one submodule and float32 on everything else # float16 on one submodule and float32 on everything else

View File

@@ -45,7 +45,7 @@ class TestMultipleWrapping(FSDPTest):
model = FSDP(inner_model).cuda() model = FSDP(inner_model).cuda()
optim = SGD(model.parameters(), lr=0.1) optim = SGD(model.parameters(), lr=0.1)
for i in range(3): for _ in range(3):
input = torch.rand((1, 5), dtype=torch.float).cuda() input = torch.rand((1, 5), dtype=torch.float).cuda()
input.requires_grad = True input.requires_grad = True
output = model(input) output = model(input)

View File

@@ -1510,7 +1510,7 @@ class TestFSDPOptimState(FSDPTest):
) = self._init_nested_model(wrap=False, use_multiple_param_groups=False) ) = self._init_nested_model(wrap=False, use_multiple_param_groups=False)
if should_check_method_fn("rekey_optim_state_dict"): if should_check_method_fn("rekey_optim_state_dict"):
with context_fn(): with context_fn():
rekeyed_osd = FSDP.rekey_optim_state_dict( FSDP.rekey_optim_state_dict(
fsdp_osd, # from `full_optim_state_dict()` fsdp_osd, # from `full_optim_state_dict()`
OptimStateKeyType.PARAM_ID, OptimStateKeyType.PARAM_ID,
nonwrapped_model, nonwrapped_model,
@@ -1650,7 +1650,7 @@ class TestFSDPOptimState(FSDPTest):
) )
# Make optim1 has a different state. # Make optim1 has a different state.
for i in range(5): for _ in range(5):
batch = torch.rand(5, 8).cuda() batch = torch.rand(5, 8).cuda()
loss = models[1](batch).sum() loss = models[1](batch).sum()
loss.backward() loss.backward()
@@ -1765,7 +1765,7 @@ class TestFSDPOptimState(FSDPTest):
initializer = self._model_class[model_class] initializer = self._model_class[model_class]
# First, run a wrapped model with full world size for a few iterations # First, run a wrapped model with full world size for a few iterations
model1, optim1, optim_input1 = initializer( model1, optim1, _ = initializer(
wrap=True, wrap=True,
use_multiple_param_groups=use_multiple_param_groups, use_multiple_param_groups=use_multiple_param_groups,
) )
@@ -1788,7 +1788,7 @@ class TestFSDPOptimState(FSDPTest):
new_group = dist.distributed_c10d._get_default_group() new_group = dist.distributed_c10d._get_default_group()
# Second, run a wrapped model with (possibly) halved world size and # Second, run a wrapped model with (possibly) halved world size and
# (possibly) differing `optim_input` across ranks # (possibly) differing `optim_input` across ranks
model2, optim2, optim_input2 = initializer( model2, optim2, _ = initializer(
wrap=True, wrap=True,
group=new_group, group=new_group,
use_multiple_param_groups=use_multiple_param_groups, use_multiple_param_groups=use_multiple_param_groups,
@@ -1861,7 +1861,8 @@ class TestFSDPOptimState(FSDPTest):
FSDP.optim_state_dict(model, optim), osd, check_same_param_keys=True FSDP.optim_state_dict(model, optim), osd, check_same_param_keys=True
) )
step() step()
osd_to_load = FSDP.optim_state_dict_to_load(
osd_to_load = FSDP.optim_state_dict_to_load( # noqa: F841
model, optim, osd, load_directly=True model, optim, osd, load_directly=True
) )
self._check_same_state( self._check_same_state(
@@ -1994,7 +1995,7 @@ class TestFSDPOptimState(FSDPTest):
loss.backward() loss.backward()
fsdp_optim.step() fsdp_optim.step()
orig_state_dict = deepcopy(fsdp_optim.state_dict()) orig_state_dict = deepcopy(fsdp_optim.state_dict())
optim_state_dict = FSDP.optim_state_dict(fsdp_model, fsdp_optim) FSDP.optim_state_dict(fsdp_model, fsdp_optim)
FSDP.optim_state_dict_to_load( FSDP.optim_state_dict_to_load(
fsdp_model, fsdp_model,
fsdp_optim, fsdp_optim,

View File

@@ -966,7 +966,7 @@ class TestFSDPStateDict(FSDPTest):
setattr(module, LINEAR_SKIP, linear_skip) setattr(module, LINEAR_SKIP, linear_skip)
return fsdp, linear_skip_tensor_names return fsdp, linear_skip_tensor_names
fsdp, linear_skip_tensor_names = _create_module() fsdp, _ = _create_module()
# Run a forward pass # Run a forward pass
inp = torch.randn((1, 10), device=torch.cuda.current_device()) inp = torch.randn((1, 10), device=torch.cuda.current_device())
loss = fsdp(inp) loss = fsdp(inp)

View File

@@ -634,7 +634,7 @@ class TestUnshardParams(TestUnshardParamsBase):
model = FSDP(model, auto_wrap_policy=ModuleWrapPolicy((nn.Sequential,))) model = FSDP(model, auto_wrap_policy=ModuleWrapPolicy((nn.Sequential,)))
with FSDP.summon_full_params(model[0]): with FSDP.summon_full_params(model[0]):
# Check that the summoned module does not have its flat parameter # Check that the summoned module does not have its flat parameter
for param_name, param in model[0].named_parameters(): for param_name, _ in model[0].named_parameters():
self.assertFalse(FLAT_PARAM in param_name) self.assertFalse(FLAT_PARAM in param_name)
self.assertGreater(len(list(model[0].parameters())), 1) self.assertGreater(len(list(model[0].parameters())), 1)

View File

@@ -260,7 +260,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
model = FSDP(copy.deepcopy(base_model), self.process_group, **fsdp_kwargs) model = FSDP(copy.deepcopy(base_model), self.process_group, **fsdp_kwargs)
model = torch.compile(model) model = torch.compile(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2) optim = torch.optim.Adam(model.parameters(), lr=1e-2)
for i in range(10): for _ in range(10):
losses = [] losses = []
inp = ref_model.get_input(torch.device("cuda")) inp = ref_model.get_input(torch.device("cuda"))
for _model, _optim in ((ref_model, ref_optim), (model, optim)): for _model, _optim in ((ref_model, ref_optim), (model, optim)):

View File

@@ -118,7 +118,7 @@ class TestUtils(TestCase):
x.fill_(0) x.fill_(0)
x = nn.utils.rnn.pack_padded_sequence(x, seq_length) x = nn.utils.rnn.pack_padded_sequence(x, seq_length)
x, h = rnn(x) x, _ = rnn(x)
x = _apply_to_tensors(fill_fn, x) x = _apply_to_tensors(fill_fn, x)
x, _ = nn.utils.rnn.pad_packed_sequence(x) x, _ = nn.utils.rnn.pad_packed_sequence(x)
self.assertEqual(torch.sum(x), 0) self.assertEqual(torch.sum(x), 0)

View File

@@ -41,7 +41,6 @@ class LaunchTest(unittest.TestCase):
def test_launch_without_env(self): def test_launch_without_env(self):
nnodes = 1 nnodes = 1
nproc_per_node = 4 nproc_per_node = 4
world_size = nnodes * nproc_per_node
sock = get_socket_with_port() sock = get_socket_with_port()
with closing(sock): with closing(sock):
master_port = sock.getsockname()[1] master_port = sock.getsockname()[1]

View File

@@ -114,7 +114,7 @@ class CustomLinearDx(Function):
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
input_val, weight, bias = ctx.saved_tensors input_val, weight, _ = ctx.saved_tensors
grad_input = grad_output.mm(weight) grad_input = grad_output.mm(weight)
ctx.module.cached_context[ctx.layer_idx].append(grad_output.clone()) ctx.module.cached_context[ctx.layer_idx].append(grad_output.clone())
ctx.module.cached_context[str(ctx.layer_idx) + "_input"].append( ctx.module.cached_context[str(ctx.layer_idx) + "_input"].append(
@@ -131,7 +131,7 @@ class CustomLinearDxDw(Function):
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
input_val, weight, bias = ctx.saved_tensors input_val, weight, _ = ctx.saved_tensors
grad_input = grad_output.mm(weight) grad_input = grad_output.mm(weight)
grad_weight = grad_output.t().mm(input_val) grad_weight = grad_output.t().mm(input_val)
grad_bias = grad_output.sum(0) grad_bias = grad_output.sum(0)

View File

@@ -74,7 +74,7 @@ class StageBackwardTests(TestCase):
# Forward, then backward of loss with respect to inputs # Forward, then backward of loss with respect to inputs
out = mod(x) out = mod(x)
loss = loss_fn(out, target) loss = loss_fn(out, target)
dinputs, param_groups = stage_backward_input( dinputs, _param_groups = stage_backward_input(
stage_outputs_or_loss=(loss,), stage_outputs_or_loss=(loss,),
output_grads=None, output_grads=None,
input_values=[x], input_values=[x],
@@ -88,7 +88,7 @@ class StageBackwardTests(TestCase):
torch.testing.assert_close(x.grad, ref_x.grad) torch.testing.assert_close(x.grad, ref_x.grad)
torch.testing.assert_close(dinputs[0], ref_x.grad) torch.testing.assert_close(dinputs[0], ref_x.grad)
for name, p in mod.named_parameters(): for _, p in mod.named_parameters():
# Check that the weight gradients were not updated # Check that the weight gradients were not updated
self.assertEqual(p.grad, None) self.assertEqual(p.grad, None)
@@ -109,7 +109,7 @@ class StageBackwardTests(TestCase):
# Forward, then backward of loss with respect to inputs # Forward, then backward of loss with respect to inputs
out = mod(x) out = mod(x)
loss = loss_fn(out, target) loss = loss_fn(out, target)
dinputs, param_groups = stage_backward_input( _dinputs, param_groups = stage_backward_input(
stage_outputs_or_loss=(loss,), stage_outputs_or_loss=(loss,),
output_grads=None, output_grads=None,
input_values=[x], input_values=[x],
@@ -157,7 +157,7 @@ class StageBackwardTests(TestCase):
for x in inputs: for x in inputs:
out = mod(x) out = mod(x)
loss = loss_fn(out, target) loss = loss_fn(out, target)
dinputs, param_groups = stage_backward_input( _dinputs, param_groups = stage_backward_input(
stage_outputs_or_loss=(loss,), stage_outputs_or_loss=(loss,),
output_grads=None, output_grads=None,
input_values=[x], input_values=[x],

View File

@@ -264,7 +264,7 @@ class TestSchedulePlan(TestCase):
] ]
schedule = ScheduleClass(stages, num_microbatches) schedule = ScheduleClass(stages, num_microbatches)
formatted_pipeline_order = _format_pipeline_order( _formatted_pipeline_order = _format_pipeline_order(
schedule.pipeline_order schedule.pipeline_order
) )
@@ -305,10 +305,7 @@ class TestSchedulePlan(TestCase):
for i in range(num_local_stages) for i in range(num_local_stages)
] ]
schedule = ScheduleClass(stages, num_microbatches) schedule = ScheduleClass(stages, num_microbatches)
formatted_pipeline_order = _format_pipeline_order( _format_pipeline_order(schedule.pipeline_order)
schedule.pipeline_order
)
# print(formatted_pipeline_order)
def stage_to_rank(stage): def stage_to_rank(stage):
return stage % group_size return stage % group_size

View File

@@ -151,7 +151,7 @@ class ScheduleTest(MultiProcContinousTest):
schedule.step(x) schedule.step(x)
elif self.rank == self.world_size - 1: elif self.rank == self.world_size - 1:
losses = [] losses = []
out = schedule.step(target=target, losses=losses) schedule.step(target=target, losses=losses)
else: else:
schedule.step() schedule.step()
@@ -412,7 +412,6 @@ class ScheduleTest(MultiProcContinousTest):
if hasattr(ScheduleClass, "num_microbatches") if hasattr(ScheduleClass, "num_microbatches")
else 8 else 8
) )
input_args = x.chunk(num_microbatches)[0]
stages = [ stages = [
PipelineStage( PipelineStage(
stage_module, stage_module,
@@ -548,7 +547,6 @@ class ScheduleTest(MultiProcContinousTest):
loss_fn = torch.nn.MSELoss(reduction="sum") loss_fn = torch.nn.MSELoss(reduction="sum")
# Create a pipeline stage to wrap that submodule # Create a pipeline stage to wrap that submodule
input_args = x.chunk(num_microbatches)[0]
stage_indices = rank_stages[self.rank] stage_indices = rank_stages[self.rank]
print(f"Rank {self.rank} stages: {stage_indices}") print(f"Rank {self.rank} stages: {stage_indices}")
submod_names = [f"layers.{i}" for i in stage_indices] submod_names = [f"layers.{i}" for i in stage_indices]
@@ -582,7 +580,7 @@ class ScheduleTest(MultiProcContinousTest):
schedule.step(x) schedule.step(x)
elif self.rank == self.world_size - 1: elif self.rank == self.world_size - 1:
losses = [] losses = []
out = schedule.step(target=target, losses=losses) schedule.step(target=target, losses=losses)
else: else:
schedule.step() schedule.step()
self.assertEqual( self.assertEqual(
@@ -887,7 +885,6 @@ class ScheduleTest(MultiProcContinousTest):
# Create a pipeline stage to wrap that submodule # Create a pipeline stage to wrap that submodule
chunks = 2 chunks = 2
input_args = x.chunk(chunks)[0]
stages = [ stages = [
PipelineStage( PipelineStage(
stage_module, stage_module,

View File

@@ -310,9 +310,6 @@ class StageTest(MultiProcContinousTest):
full_mod.to(self.device) full_mod.to(self.device)
stage_mod = full_mod.get_submodule(f"layers.{self.rank}") stage_mod = full_mod.get_submodule(f"layers.{self.rank}")
x = torch.randn(batch_size, d_hid, device=self.device)
target = torch.randn(batch_size, d_hid, device=self.device)
stage_with_dw_builder = PipelineStage( stage_with_dw_builder = PipelineStage(
stage_mod, stage_mod,
self.rank, self.rank,

View File

@@ -58,7 +58,7 @@ class UnflattenTests(TestCase):
# Check qualnames # Check qualnames
for stage_idx in range(pipe.num_stages): for stage_idx in range(pipe.num_stages):
stage_mod = pipe.get_stage_module(stage_idx) stage_mod = pipe.get_stage_module(stage_idx)
for param_name, param in stage_mod.named_parameters(): for param_name, _ in stage_mod.named_parameters():
assert ( assert (
param_name in orig_state_dict param_name in orig_state_dict
), f"{param_name} not in original state dict" ), f"{param_name} not in original state dict"

View File

@@ -87,7 +87,9 @@ class MicroPipelineTPTest(TestCase):
a = all_gather_tensor(inp, gather_dim=0, group=group.group_name) a = all_gather_tensor(inp, gather_dim=0, group=group.group_name)
b = all_gather_tensor(inp, gather_dim=1, group=group.group_name) b = all_gather_tensor(inp, gather_dim=1, group=group.group_name)
c = _fp8_all_gather(inp, gather_dim=0, group_name=group.group_name) c = _fp8_all_gather(inp, gather_dim=0, group_name=group.group_name)
d = _fp8_all_gather(inp, gather_dim=1, group_name=group.group_name) d = _fp8_all_gather( # noqa: F841
inp, gather_dim=1, group_name=group.group_name
)
return a, b, c return a, b, c
inp = torch.rand(64, 32, device="cuda") inp = torch.rand(64, 32, device="cuda")

View File

@@ -311,7 +311,7 @@ class DistTensorParallelExampleTest(DTensorTestBase):
torch.manual_seed(0) torch.manual_seed(0)
steps = 10 if type(model) is torch.float64 else 1 steps = 10 if type(model) is torch.float64 else 1
for iter in range(steps): for _ in range(steps):
inp = torch.randint( inp = torch.randint(
model_args.vocab_size, inp_size, device=self.device_type model_args.vocab_size, inp_size, device=self.device_type
) )

View File

@@ -223,7 +223,7 @@ class TensorParallelStyleTest(DTensorTestBase):
AssertionError, AssertionError,
"input_layouts and desired_input_layouts should have same length!", "input_layouts and desired_input_layouts should have same length!",
): ):
prepare_inps_dimension_mismatch = PrepareModuleInput( PrepareModuleInput(
input_layouts=Shard(0), desired_input_layouts=(Replicate(), None) input_layouts=Shard(0), desired_input_layouts=(Replicate(), None)
) )
# Raise assertion error if module inputs and input_layouts do not have same length. # Raise assertion error if module inputs and input_layouts do not have same length.

View File

@@ -182,7 +182,7 @@ class TimeoutTest(TestCase):
threads.append(t) threads.append(t)
t.start() t.start()
for i, thread in enumerate(threads): for _, thread in enumerate(threads):
thread.join() thread.join()
# we expect the world_size-1 threads to have failed # we expect the world_size-1 threads to have failed
@@ -583,14 +583,14 @@ class CommonDistributedDataParallelTest:
) )
) )
with err_ctx: with err_ctx:
model = self._test_ddp_checkpointing( self._test_ddp_checkpointing(
self.CheckpointOnceModule(use_reentrant=use_reentrant), self.CheckpointOnceModule(use_reentrant=use_reentrant),
process_group=process_group, process_group=process_group,
use_bucket_view=use_bucket_view, use_bucket_view=use_bucket_view,
find_unused_parameters=True, find_unused_parameters=True,
) )
# test passes when static_graph is true # test passes when static_graph is true
model = self._test_ddp_checkpointing( self._test_ddp_checkpointing(
self.CheckpointOnceModule(use_reentrant=use_reentrant), self.CheckpointOnceModule(use_reentrant=use_reentrant),
process_group=process_group, process_group=process_group,
use_bucket_view=use_bucket_view, use_bucket_view=use_bucket_view,
@@ -615,7 +615,7 @@ class CommonDistributedDataParallelTest:
) )
) )
with err_ctx: with err_ctx:
model = self._test_ddp_checkpointing( self._test_ddp_checkpointing(
self.CheckpointTwiceModule(use_reentrant=use_reentrant), self.CheckpointTwiceModule(use_reentrant=use_reentrant),
process_group=process_group, process_group=process_group,
use_bucket_view=use_bucket_view, use_bucket_view=use_bucket_view,
@@ -623,7 +623,7 @@ class CommonDistributedDataParallelTest:
) )
with err_ctx: with err_ctx:
model = self._test_ddp_checkpointing( self._test_ddp_checkpointing(
self.CheckpointTwiceModule(use_reentrant=use_reentrant), self.CheckpointTwiceModule(use_reentrant=use_reentrant),
process_group=process_group, process_group=process_group,
use_bucket_view=use_bucket_view, use_bucket_view=use_bucket_view,
@@ -641,7 +641,7 @@ class CommonDistributedDataParallelTest:
process_group = self._get_process_group() process_group = self._get_process_group()
for use_bucket_view in (True, False): for use_bucket_view in (True, False):
# Test passes when static_graph=True. # Test passes when static_graph=True.
model = self._test_ddp_checkpointing( self._test_ddp_checkpointing(
self.CheckpointTwiceModule(use_reentrant=use_reentrant), self.CheckpointTwiceModule(use_reentrant=use_reentrant),
process_group=process_group, process_group=process_group,
use_bucket_view=use_bucket_view, use_bucket_view=use_bucket_view,
@@ -656,7 +656,7 @@ class CommonDistributedDataParallelTest:
""" """
process_group = self._get_process_group() process_group = self._get_process_group()
for use_bucket_view in (True, False): for use_bucket_view in (True, False):
model = self._test_ddp_checkpointing( self._test_ddp_checkpointing(
self.DynamicCheckpointTwiceModule(use_reentrant=False), self.DynamicCheckpointTwiceModule(use_reentrant=False),
process_group=process_group, process_group=process_group,
use_bucket_view=use_bucket_view, use_bucket_view=use_bucket_view,
@@ -675,7 +675,7 @@ class CommonDistributedDataParallelTest:
""" """
process_group = self._get_process_group() process_group = self._get_process_group()
for use_bucket_view in (True, False): for use_bucket_view in (True, False):
model = self._test_ddp_checkpointing( self._test_ddp_checkpointing(
self.DynamicCheckpointTwiceModuleWeightSharing(use_reentrant=False), self.DynamicCheckpointTwiceModuleWeightSharing(use_reentrant=False),
process_group=process_group, process_group=process_group,
use_bucket_view=use_bucket_view, use_bucket_view=use_bucket_view,
@@ -719,7 +719,7 @@ class CommonDistributedDataParallelTest:
process_group = self._get_process_group() process_group = self._get_process_group()
torch.cuda.set_device(self.rank) torch.cuda.set_device(self.rank)
for use_bucket_view in (True, False): for use_bucket_view in (True, False):
model = self._test_ddp_checkpointing( self._test_ddp_checkpointing(
self.CheckpointTwiceModuleWeightSharing(), self.CheckpointTwiceModuleWeightSharing(),
process_group=process_group, process_group=process_group,
use_bucket_view=use_bucket_view, use_bucket_view=use_bucket_view,
@@ -737,7 +737,7 @@ class CommonDistributedDataParallelTest:
"Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, " "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
"because PowerSGD can only be applied after the first two iterations in DDP.", "because PowerSGD can only be applied after the first two iterations in DDP.",
): ):
state = powerSGD.PowerSGDState( powerSGD.PowerSGDState(
process_group=None, process_group=None,
matrix_approximation_rank=1, matrix_approximation_rank=1,
start_powerSGD_iter=start_powerSGD_iter, start_powerSGD_iter=start_powerSGD_iter,

View File

@@ -429,7 +429,7 @@ class TestWithNCCL(MultiProcessTestCase):
input = torch.full((10, 10), float(self.rank), device=self.device) input = torch.full((10, 10), float(self.rank), device=self.device)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
output = torch.ops._c10d_functional.all_reduce( torch.ops._c10d_functional.all_reduce(
input, input,
"avg", "avg",
"default", "default",
@@ -550,7 +550,7 @@ class CompileTest(TestCase):
assert "= torch.ops._c10d_functional.wait_tensor.default" not in code assert "= torch.ops._c10d_functional.wait_tensor.default" not in code
# Test aoti # Test aoti
out = AOTIRunnerUtil.run("cuda", func, (arg,)) AOTIRunnerUtil.run("cuda", func, (arg,))
torch.cuda.synchronize() torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@@ -596,7 +596,7 @@ class CompileTest(TestCase):
assert "= torch.ops._c10d_functional.wait_tensor.default" not in code assert "= torch.ops._c10d_functional.wait_tensor.default" not in code
# Test aoti # Test aoti
out = AOTIRunnerUtil.run("cuda", func, (args,)) out = AOTIRunnerUtil.run("cuda", func, (args,)) # noqa: F841
torch.cuda.synchronize() torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@@ -708,7 +708,7 @@ class CompileTest(TestCase):
assert "= torch.ops._c10d_functional.wait_tensor.default" not in code assert "= torch.ops._c10d_functional.wait_tensor.default" not in code
# Test aoti # Test aoti
out = AOTIRunnerUtil.run("cuda", func, (arg,)) AOTIRunnerUtil.run("cuda", func, (arg,))
torch.cuda.synchronize() torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@@ -742,7 +742,7 @@ class CompileTest(TestCase):
) )
# Test aoti # Test aoti
out = AOTIRunnerUtil.run("cuda", func, (args,)) out = AOTIRunnerUtil.run("cuda", func, (args,)) # noqa: F841
torch.cuda.synchronize() torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "This is a GPU test!") @unittest.skipIf(not HAS_GPU, "This is a GPU test!")
@@ -764,7 +764,7 @@ class CompileTest(TestCase):
) )
# Test aoti # Test aoti
out = AOTIRunnerUtil.run("cuda", func, (arg,)) AOTIRunnerUtil.run("cuda", func, (arg,))
torch.cuda.synchronize() torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@@ -790,7 +790,7 @@ class CompileTest(TestCase):
) )
# Test aoti # Test aoti
out = AOTIRunnerUtil.run("cuda", func, (arg,)) AOTIRunnerUtil.run("cuda", func, (arg,))
torch.cuda.synchronize() torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@@ -910,7 +910,7 @@ class CompileTest(TestCase):
) )
# Test aoti # Test aoti
out = AOTIRunnerUtil.run("cuda", func, (arg,)) AOTIRunnerUtil.run("cuda", func, (arg,))
torch.cuda.synchronize() torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")

View File

@@ -1920,7 +1920,7 @@ class DistributedDataParallelTest(
ddp_state_dict = torch.load(checkpoint_path, map_location=map_location) ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)
for model in [ddp_withload, model_withload]: for model in [ddp_withload, model_withload]:
for p in ddp_withload.parameters(): for p in model.parameters():
with torch.no_grad(): with torch.no_grad():
p.zero_() p.zero_()
ddp_withload.load_state_dict(ddp_state_dict) ddp_withload.load_state_dict(ddp_state_dict)
@@ -1973,7 +1973,8 @@ class DistributedDataParallelTest(
This unit test verifies whether the Future object is passed properly. This unit test verifies whether the Future object is passed properly.
The callback function creates a Future object and sets a value to it. The callback function creates a Future object and sets a value to it.
""" """
store = c10d.FileStore(self.file_name, self.world_size) store = c10d.FileStore(self.file_name, self.world_size) # noqa: F841
process_group = self._get_process_group() process_group = self._get_process_group()
# Test on CPU # Test on CPU

View File

@@ -366,7 +366,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
thread.start() thread.start()
# We would get stuck here due to d2h if we didn't abort. # We would get stuck here due to d2h if we didn't abort.
t_cpu = t.cpu() t.cpu()
thread.join() thread.join()
@@ -741,7 +741,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
# First allreduce to initialize default PG's communicator. # First allreduce to initialize default PG's communicator.
pg.allreduce(t).wait() pg.allreduce(t).wait()
# PG1 is an PG without comms initialized, since we don't call collective on it # PG1 is an PG without comms initialized, since we don't call collective on it
new_pg1 = c10d.new_group([0, 1]) new_pg1 = c10d.new_group([0, 1]) # noqa: F841
new_pg2 = c10d.new_group([0, 1]) new_pg2 = c10d.new_group([0, 1])
t2 = torch.rand(10, 10, device=device) t2 = torch.rand(10, 10, device=device)
@@ -807,7 +807,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
# 'timeout' kwarg (or its kwdefault) taking precedence # 'timeout' kwarg (or its kwdefault) taking precedence
opts = dist.ProcessGroupNCCL.Options() opts = dist.ProcessGroupNCCL.Options()
opts._timeout = timedelta(seconds=123) opts._timeout = timedelta(seconds=123)
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True):
dist.init_process_group(**base_opts, pg_options=opts) dist.init_process_group(**base_opts, pg_options=opts)
# TODO(whc) i verified that we are indeed emitting this warning, and i can't figure out why i can't catch it. # TODO(whc) i verified that we are indeed emitting this warning, and i can't figure out why i can't catch it.
# self.assertEqual(len(w), 1) # self.assertEqual(len(w), 1)
@@ -1266,30 +1266,26 @@ class DistributedDataParallelTest(
"DistributedDataParallel device_ids and output_device arguments only work with " "DistributedDataParallel device_ids and output_device arguments only work with "
"single-device/multiple-device GPU modules or CPU modules", "single-device/multiple-device GPU modules or CPU modules",
): ):
ddp_model = DistributedDataParallel( DistributedDataParallel(
model, output_device=gpus[1], process_group=process_group model, output_device=gpus[1], process_group=process_group
) )
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, "device_ids can only be None or contain a single element." ValueError, "device_ids can only be None or contain a single element."
): ):
ddp_model = DistributedDataParallel( DistributedDataParallel(model, device_ids=gpus, process_group=process_group)
model, device_ids=gpus, process_group=process_group
)
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, "input module must be on the same type of devices" ValueError, "input module must be on the same type of devices"
): ):
model.fc1 = model.fc1.cpu() model.fc1 = model.fc1.cpu()
ddp_model = DistributedDataParallel(model, process_group=process_group) DistributedDataParallel(model, process_group=process_group)
model = model.cpu() model = model.cpu()
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, "device_ids can only be None or contain a single element." ValueError, "device_ids can only be None or contain a single element."
): ):
ddp_model = DistributedDataParallel( DistributedDataParallel(model, device_ids=gpus, process_group=process_group)
model, device_ids=gpus, process_group=process_group
)
def _test_fp16(self, gradient_as_bucket_view=False): def _test_fp16(self, gradient_as_bucket_view=False):
process_group = self._get_process_group() process_group = self._get_process_group()
@@ -1940,11 +1936,9 @@ class DistributedDataParallelTest(
), ),
named_msg, named_msg,
) )
for j, ((param_name, p), p_ddp) in enumerate( for (param_name, p), p_ddp in zip(
zip( m_child.named_parameters(),
m_child.named_parameters(), m_ddp_child.parameters(),
m_ddp_child.parameters(),
)
): ):
named_msg = ( named_msg = (
layer_name + "." + param_name + " " + iter_msg layer_name + "." + param_name + " " + iter_msg
@@ -2010,15 +2004,13 @@ class DistributedDataParallelTest(
m = ConvNet(layer_devs, layer_formats, layer_dtypes) m = ConvNet(layer_devs, layer_formats, layer_dtypes)
if self.rank == 0: if self.rank == 0:
m_ddp = DistributedDataParallel( DistributedDataParallel(m, device_ids=[dev0], process_group=process_group)
m, device_ids=[dev0], process_group=process_group
)
else: else:
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
".* appears not to match strides of the same param in process 0", ".* appears not to match strides of the same param in process 0",
): ):
m_ddp = DistributedDataParallel( DistributedDataParallel(
m, device_ids=[dev0], process_group=process_group m, device_ids=[dev0], process_group=process_group
) )
@@ -2356,7 +2348,7 @@ class DistributedDataParallelTest(
process_group=process_group, process_group=process_group,
) )
for i in range(3): for _ in range(3):
m.zero_grad(set_to_none=try_set_to_none) m.zero_grad(set_to_none=try_set_to_none)
m(1).sum().backward() m(1).sum().backward()
@@ -2701,7 +2693,7 @@ class WorkHookTest(MultiProcessTestCase):
pg._register_on_completion_hook(hook) pg._register_on_completion_hook(hook)
tensor = torch.ones([2, 3]).cuda(self.rank) * self.rank tensor = torch.ones([2, 3]).cuda(self.rank) * self.rank
work_count = 3 work_count = 3
for i in range(work_count): for _ in range(work_count):
work += 1 work += 1
pg.broadcast([tensor]).wait() pg.broadcast([tensor]).wait()
@@ -2806,7 +2798,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
# Run some GPU operations to make sure cuda has not gotten stuck. # Run some GPU operations to make sure cuda has not gotten stuck.
# It was observed cuda could get stuck if NCCL communicators were # It was observed cuda could get stuck if NCCL communicators were
# not properly aborted before throwing RuntimeError. # not properly aborted before throwing RuntimeError.
a = torch.rand(10).cuda(self.rank) torch.rand(10).cuda(self.rank)
elif self.rank == 1: elif self.rank == 1:
# Clean up structures (ex: files for FileStore before going down) # Clean up structures (ex: files for FileStore before going down)
del process_group del process_group
@@ -2947,7 +2939,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val
store = c10d.FileStore(self.file_name, self.world_size) store = c10d.FileStore(self.file_name, self.world_size)
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
@requires_nccl() @requires_nccl()
@skip_if_lt_x_gpu(3) @skip_if_lt_x_gpu(3)
@@ -4223,7 +4215,7 @@ class NCCLTraceTestBase(MultiProcessTestCase):
def _join_processes(self, fn): def _join_processes(self, fn):
# We need to patch sys.exit() as skip_if will use sys.exit() and # We need to patch sys.exit() as skip_if will use sys.exit() and
# the exit code from the this process will not be catched. # the exit code from the this process will not be catched.
with mock.patch("sys.exit") as exit_mock: with mock.patch("sys.exit"):
fn() fn()
super()._join_processes(fn) super()._join_processes(fn)
@@ -4231,7 +4223,7 @@ class NCCLTraceTestBase(MultiProcessTestCase):
proc = torch.multiprocessing.get_context("spawn").Process proc = torch.multiprocessing.get_context("spawn").Process
self.children_pipes = [] self.children_pipes = []
parent_pipes = [] parent_pipes = []
for i in range(self.world_size): for _ in range(self.world_size):
parent_conn, child_conn = torch.multiprocessing.Pipe() parent_conn, child_conn = torch.multiprocessing.Pipe()
self.children_pipes.append(child_conn) self.children_pipes.append(child_conn)
parent_pipes.append(parent_conn) parent_pipes.append(parent_conn)
@@ -4346,7 +4338,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
pg._enable_collectives_timing() pg._enable_collectives_timing()
device = self.local_device device = self.local_device
a = torch.full((3, 4), float(self.rank), device=device) a = torch.full((3, 4), float(self.rank), device=device)
for i in range(2): for _ in range(2):
f = pg.allreduce(a) f = pg.allreduce(a)
f.wait() f.wait()
torch.cuda.synchronize(device=device) torch.cuda.synchronize(device=device)
@@ -4372,7 +4364,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
pg._enable_collectives_timing() pg._enable_collectives_timing()
device = self.local_device device = self.local_device
a = torch.full((3, 4), float(self.rank), device=device) a = torch.full((3, 4), float(self.rank), device=device)
for i in range(2): for _ in range(2):
f = pg.allreduce(a) f = pg.allreduce(a)
f.wait() f.wait()
torch.cuda.synchronize(device=device) torch.cuda.synchronize(device=device)
@@ -4420,7 +4412,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
pg = self._create_process_group_nccl() pg = self._create_process_group_nccl()
device = self.local_device device = self.local_device
a = torch.full((3, 4), float(self.rank), device=device) a = torch.full((3, 4), float(self.rank), device=device)
for i in range(2): for _ in range(2):
f = pg.allreduce(a) f = pg.allreduce(a)
f.wait() f.wait()
torch.cuda.synchronize(device=device) torch.cuda.synchronize(device=device)
@@ -4436,7 +4428,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
pg = self._create_process_group_nccl() pg = self._create_process_group_nccl()
device = self.local_device device = self.local_device
a = torch.full((3, 4), float(self.rank), device=device) a = torch.full((3, 4), float(self.rank), device=device)
for i in range(2): for _ in range(2):
# test some other primitives to make sure # test some other primitives to make sure
# their strings are valid # their strings are valid
xs = [torch.ones(3, 4, device=device)] xs = [torch.ones(3, 4, device=device)]
@@ -4496,7 +4488,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
pg = self._create_process_group_nccl() pg = self._create_process_group_nccl()
device = self.local_device device = self.local_device
# send more works than the buffer size to overwrite the previous entry # send more works than the buffer size to overwrite the previous entry
for i in range(12): for _ in range(12):
a = [torch.ones(3, 4, device=device)] a = [torch.ones(3, 4, device=device)]
pg.broadcast(a).wait() pg.broadcast(a).wait()
torch.cuda.synchronize(device=device) torch.cuda.synchronize(device=device)
@@ -4611,7 +4603,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
th.start() th.start()
# fill the cuda buffer, at around 1024 events # fill the cuda buffer, at around 1024 events
# this will stall # this will stall
for i in range(2000): for _ in range(2000):
a = a + a a = a + a
th.join() th.join()
else: else:
@@ -4646,7 +4638,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
num_coalesced_ops = 20 num_coalesced_ops = 20
ops_per_coalesce = len(op_sizes_per_coalesce) ops_per_coalesce = len(op_sizes_per_coalesce)
for i in range(num_coalesced_ops): for _ in range(num_coalesced_ops):
ops = [] ops = []
for input_sizes in op_sizes_per_coalesce: for input_sizes in op_sizes_per_coalesce:
tensor = torch.zeros(input_sizes).to(self.local_device) tensor = torch.zeros(input_sizes).to(self.local_device)
@@ -4745,7 +4737,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
pg._enable_collectives_timing() pg._enable_collectives_timing()
num_repeats = 10 num_repeats = 10
ops_per_repeat = len(op_sizes) ops_per_repeat = len(op_sizes)
for i in range(num_repeats): for _ in range(num_repeats):
for input_sizes in op_sizes: for input_sizes in op_sizes:
tensor = torch.zeros(input_sizes).to(self.local_device) tensor = torch.zeros(input_sizes).to(self.local_device)
if self.rank == 0: if self.rank == 0:
@@ -5047,7 +5039,7 @@ class NcclErrorDumpTest(NCCLTraceTestBase):
# Block the current stream on the NCCL stream # Block the current stream on the NCCL stream
work.wait() work.wait()
# Run some GPU operations # Run some GPU operations
a = torch.rand(10).cuda(self.rank) torch.rand(10).cuda(self.rank)
elif self.rank == 1: elif self.rank == 1:
# Clean up structures (ex: files for FileStore before going down) # Clean up structures (ex: files for FileStore before going down)
del process_group del process_group
@@ -5108,7 +5100,6 @@ class ProcessGroupNCCLLargerScaleTest(MultiProcessTestCase):
tensor = torch.full((1,), self.rank).cuda(device) tensor = torch.full((1,), self.rank).cuda(device)
ng1 = c10d.split_group(pg, [[0, 1], [2, 3, 4, 5, 6, 7]]) ng1 = c10d.split_group(pg, [[0, 1], [2, 3, 4, 5, 6, 7]])
backend1 = ng1._get_backend(torch.device(device))
# comm split happens eagerly since device_id is passed to init_process_group. # comm split happens eagerly since device_id is passed to init_process_group.
self.assertEqual(backend.comm_split_count(), 1) self.assertEqual(backend.comm_split_count(), 1)

View File

@@ -162,7 +162,6 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
@requires_nccl() @requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_allreduce_ops(self): def test_allreduce_ops(self):
device_count = torch.cuda.device_count()
pg = self.pg pg = self.pg
local_device_id = self.rank_to_GPU[self.rank][0] local_device_id = self.rank_to_GPU[self.rank][0]
@@ -303,9 +302,8 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
pg = self.pg pg = self.pg
rank = self.rank_to_GPU[self.rank][0] rank = self.rank_to_GPU[self.rank][0]
with torch.cuda.device(rank): with torch.cuda.device(rank):
for i in range(10): for _ in range(10):
xs = [torch.FloatTensor([1]).cuda(rank)] xs = [torch.FloatTensor([1]).cuda(rank)]
ys = [torch.FloatTensor([4]).cuda(rank)]
for _ in range(30): for _ in range(30):
pg.allreduce(xs[0]).wait() pg.allreduce(xs[0]).wait()
@@ -410,7 +408,7 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu]) output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu])
expected_output.append([t.cuda(device=gpu) for t in expected_per_gpu]) expected_output.append([t.cuda(device=gpu) for t in expected_per_gpu])
result = allgather(output_tensors, tensors) allgather(output_tensors, tensors)
# Verification # Verification
self.assertEqual(output_tensors, expected_output) self.assertEqual(output_tensors, expected_output)
@@ -558,7 +556,7 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
# init output # init output
output_ts = [] output_ts = []
for rank in range(self.world_size): for _ in range(self.world_size):
output_ts.append(torch.tensor([-1]).cuda(device_id)) output_ts.append(torch.tensor([-1]).cuda(device_id))
with self.assertRaisesRegex(ValueError, "invalid root rank"): with self.assertRaisesRegex(ValueError, "invalid root rank"):
@@ -914,7 +912,6 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
@requires_nccl() @requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_send_recv(self): def test_send_recv(self):
pg = self.pg
device = self.rank_to_GPU[self.rank][0] device = self.rank_to_GPU[self.rank][0]
# Generate the same random tensor # Generate the same random tensor
@@ -930,7 +927,6 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
@requires_nccl() @requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_send_recv_complex(self): def test_send_recv_complex(self):
pg = self.pg
device = self.rank_to_GPU[self.rank][0] device = self.rank_to_GPU[self.rank][0]
# Generate the same random tensor # Generate the same random tensor

View File

@@ -755,7 +755,7 @@ class DistributedDataParallelTest(
ddp_state_dict = torch.load(checkpoint_path, map_location=map_location) ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)
for model in [ddp_withload, model_withload]: for model in [ddp_withload, model_withload]:
for p in ddp_withload.parameters(): for p in model.parameters():
with torch.no_grad(): with torch.no_grad():
p.zero_() p.zero_()
ddp_withload.load_state_dict(ddp_state_dict) ddp_withload.load_state_dict(ddp_state_dict)

View File

@@ -57,7 +57,7 @@ class TestCollectiveUtils(MultiProcessTestCase):
Ensure broadcast has no dependency on torch.distributed when run in single process. Ensure broadcast has no dependency on torch.distributed when run in single process.
""" """
func = mock.MagicMock() func = mock.MagicMock()
res = broadcast(data_or_fn=func, rank=0) broadcast(data_or_fn=func, rank=0)
func.assert_called_once() func.assert_called_once()
def test_broadcast_result_raises_exceptions_from_func( def test_broadcast_result_raises_exceptions_from_func(
@@ -98,7 +98,7 @@ class TestCollectiveUtils(MultiProcessTestCase):
Ensure all_gather has no dependency on torch.distributed when run in single process. Ensure all_gather has no dependency on torch.distributed when run in single process.
""" """
func = mock.MagicMock() func = mock.MagicMock()
res = all_gather(data_or_fn=func) all_gather(data_or_fn=func)
func.assert_called_once() func.assert_called_once()
def test_all_gather_result_raises_exceptions_from_func( def test_all_gather_result_raises_exceptions_from_func(

View File

@@ -791,8 +791,8 @@ class TestDataParallel(TestCase):
), ),
named_msg, named_msg,
) )
for j, ((param_name, p), p_dp) in enumerate( for (param_name, p), p_dp in zip(
zip(m_child.named_parameters(), m_dp_child.parameters()) m_child.named_parameters(), m_dp_child.parameters()
): ):
named_msg = ( named_msg = (
layer_name + "." + param_name + " " + iter_msg layer_name + "." + param_name + " " + iter_msg

View File

@@ -88,7 +88,7 @@ class DeviceMeshTest(DTensorTestBase):
def test_assert_invalid_mesh_tensor(self): def test_assert_invalid_mesh_tensor(self):
mesh = torch.arange(self.world_size).to(self.rank) mesh = torch.arange(self.world_size).to(self.rank)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
device_mesh = DeviceMesh(self.device_type, mesh) DeviceMesh(self.device_type, mesh)
@with_comms() @with_comms()
def test_2d_mesh_non_eager_init_subgroup(self): def test_2d_mesh_non_eager_init_subgroup(self):
@@ -144,7 +144,7 @@ class DeviceMeshTest(DTensorTestBase):
RuntimeError, RuntimeError,
"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
): ):
local_rank = mesh_2d.get_local_rank() mesh_2d.get_local_rank()
@with_comms @with_comms
def test_get_local_rank(self): def test_get_local_rank(self):
@@ -258,7 +258,7 @@ class DeviceMeshTest(DTensorTestBase):
): ):
# test init_device_mesh with an invalid device type that contains a GPU index # test init_device_mesh with an invalid device type that contains a GPU index
mesh_shape = (2, self.world_size // 2) mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh( init_device_mesh(
"cuda:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp") "cuda:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
) )
@@ -453,7 +453,7 @@ class InitDeviceMeshTest(DTensorTestBase):
RuntimeError, RuntimeError,
"Each mesh_dim_name must be unique.", "Each mesh_dim_name must be unique.",
): ):
mesh = init_device_mesh( init_device_mesh(
self.device_type, self.device_type,
(2, 4), (2, 4),
mesh_dim_names=["dp", "dp"], mesh_dim_names=["dp", "dp"],
@@ -465,7 +465,7 @@ class InitDeviceMeshTest(DTensorTestBase):
RuntimeError, RuntimeError,
"mesh_shape and mesh_dim_names should have same length!", "mesh_shape and mesh_dim_names should have same length!",
): ):
mesh = init_device_mesh( init_device_mesh(
self.device_type, self.device_type,
(8,), (8,),
mesh_dim_names=["dp", "tp"], mesh_dim_names=["dp", "tp"],
@@ -483,7 +483,7 @@ class TestDeviceMeshGetItem(DTensorTestBase):
RuntimeError, "Cannot slice a DeviceMesh without mesh_dim_names!" RuntimeError, "Cannot slice a DeviceMesh without mesh_dim_names!"
): ):
mesh = init_device_mesh(self.device_type, (2, 4)) mesh = init_device_mesh(self.device_type, (2, 4))
child_mesh = mesh["DP"] mesh["DP"]
@with_comms @with_comms
def test_raises_invalid_mesh_dim_name(self): def test_raises_invalid_mesh_dim_name(self):
@@ -493,7 +493,7 @@ class TestDeviceMeshGetItem(DTensorTestBase):
mesh = init_device_mesh( mesh = init_device_mesh(
self.device_type, (2, 4), mesh_dim_names=mesh_dim_names self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
) )
child_mesh = mesh[child_mesh_dim_name] mesh[child_mesh_dim_name]
@with_comms @with_comms
def test_get_item_2d(self): def test_get_item_2d(self):
@@ -514,7 +514,6 @@ class TestDeviceMeshGetItem(DTensorTestBase):
tp_group_idx = self.rank // 4 tp_group_idx = self.rank // 4
self.assertEqual(tp_mesh.mesh, pg_ranks_by_dim_name["TP"][tp_group_idx]) self.assertEqual(tp_mesh.mesh, pg_ranks_by_dim_name["TP"][tp_group_idx])
dp_mesh = mesh_2d["DP"]
dp_group_idx = self.rank % 4 dp_group_idx = self.rank % 4
self.assertEqual(mesh_2d["DP"].mesh, pg_ranks_by_dim_name["DP"][dp_group_idx]) self.assertEqual(mesh_2d["DP"].mesh, pg_ranks_by_dim_name["DP"][dp_group_idx])
@@ -564,17 +563,15 @@ class TestDeviceMeshGetItem(DTensorTestBase):
def test_cache_and_reuse_submesh_slice_result(self): def test_cache_and_reuse_submesh_slice_result(self):
mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("dp", "tp")) mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("dp", "tp"))
dp_mesh = mesh["dp"]
ref_pg_count = _world.group_count ref_pg_count = _world.group_count
# When we call the "dp" slice second time, it should not create any new pg. # When we call the "dp" slice second time, it should not create any new pg.
# As we are just using the cached result so the pg count should be the same. # As we are just using the cached result so the pg count should be the same.
dp_mesh_2 = mesh["dp"]
self.assertEqual(ref_pg_count, _world.group_count) self.assertEqual(ref_pg_count, _world.group_count)
# When we call the "tp" slice, it should not create a new pg, as the "tp" slice would # When we call the "tp" slice, it should not create a new pg, as the "tp" slice would
# just reuse the parent mesh pg. # just reuse the parent mesh pg.
tp_mesh = mesh["tp"] mesh["tp"]
self.assertEqual(_world.group_count, ref_pg_count) self.assertEqual(_world.group_count, ref_pg_count)
@with_comms @with_comms
@@ -603,7 +600,7 @@ class TestDeviceMeshGetItem(DTensorTestBase):
KeyError, KeyError,
"Invalid mesh_dim_names", "Invalid mesh_dim_names",
): ):
cp_dp_mesh = mesh_3d["cp", "dp"] mesh_3d["cp", "dp"]
@with_comms @with_comms
def test_flatten_mesh_3d(self): def test_flatten_mesh_3d(self):
@@ -767,9 +764,9 @@ class TestMeshEnv(DTensorTestBase):
) )
with FakeTensorMode(): with FakeTensorMode():
dp_mesh = mesh_2d["DP"] mesh_2d["DP"]
tp_mesh = mesh_2d["TP"] mesh_2d["TP"]
dp_tp_mesh = mesh_2d["DP", "TP"] mesh_2d["DP", "TP"]
class DeviceMeshCollectiveTest(DTensorTestBase): class DeviceMeshCollectiveTest(DTensorTestBase):

View File

@@ -421,7 +421,7 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
self.weight2 = nn.Parameter(torch.randn(512, 512)) self.weight2 = nn.Parameter(torch.randn(512, 512))
def forward(self, x, y): def forward(self, x, y):
u0, u1 = y.tolist() u0, _ = y.tolist()
x = torch.cat([x, x]) x = torch.cat([x, x])
y = x @ self.weight1 y = x @ self.weight1
z = (x + y @ self.weight2) * u0 z = (x + y @ self.weight2) * u0
@@ -442,7 +442,7 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
self.weight2 = nn.Parameter(torch.randn(512, 512)) self.weight2 = nn.Parameter(torch.randn(512, 512))
def forward(self, x, y): def forward(self, x, y):
u0, u1 = y.tolist() u0, _ = y.tolist()
a = torch.ones(u0) a = torch.ones(u0)
x = torch.cat([x, x]) x = torch.cat([x, x])
y = x @ self.weight1 y = x @ self.weight1
@@ -466,7 +466,7 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
def forward(self, x, y): def forward(self, x, y):
# partition one (contains the u0 def) # partition one (contains the u0 def)
u0, u1 = y.tolist() u0, _ = y.tolist()
x = torch.cat([x, x]) x = torch.cat([x, x])
y1 = x @ self.weight1 y1 = x @ self.weight1
# partition two (contains the variable) # partition two (contains the variable)
@@ -511,7 +511,7 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
): ):
super().__init__() super().__init__()
layers = [] layers = []
for l in range(2): for _ in range(2):
layer = nn.ModuleList( layer = nn.ModuleList(
[ [
nn.LayerNorm(96), nn.LayerNorm(96),
@@ -529,7 +529,7 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
for m in self.layers: for m in self.layers:
x = x.reshape(B * F, T, H) x = x.reshape(B * F, T, H)
x = m[0](x) x = m[0](x)
x, attn = m[1].forward(x, x, x) x, _ = m[1].forward(x, x, x)
x = x.reshape(B, F, T, H) x = x.reshape(B, F, T, H)
return x return x
@@ -937,8 +937,8 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
@torch.compile() @torch.compile()
def f(x, y): def f(x, y):
zx = x.shape zx = x.shape # noqa: F841
zy = y.shape zy = y.shape # noqa: F841
return x.sum() + y.sum() return x.sum() + y.sum()
if self.rank == 0: if self.rank == 0:
@@ -967,10 +967,10 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
@torch.compile() @torch.compile()
def f(x, y): def f(x, y):
z = y z = y # noqa: F841
print("woof") print("woof")
zx = x.shape zx = x.shape # noqa: F841
zy = y.shape zy = y.shape # noqa: F841
return x.sum() + y.sum() return x.sum() + y.sum()
if self.rank == 0: if self.rank == 0:
@@ -999,8 +999,8 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
@torch.compile() @torch.compile()
def f(x, y): def f(x, y):
zx = x.shape zx = x.shape # noqa: F841
zy = y.shape zy = y.shape # noqa: F841
return x.sum() + y.sum() return x.sum() + y.sum()
if self.rank == 0: if self.rank == 0:
@@ -1405,7 +1405,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
model = DDP(model, device_ids=self.device_ids) model = DDP(model, device_ids=self.device_ids)
hidden_states = torch.randn(B, S, H * D).to(device) hidden_states = torch.randn(B, S, H * D).to(device)
attention_scores = model(hidden_states) model(hidden_states)
torch.cuda.synchronize() torch.cuda.synchronize()
@patch.object(config, "optimize_ddp", True) @patch.object(config, "optimize_ddp", True)
@@ -1461,7 +1461,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
model = DDP(model, device_ids=self.device_ids) model = DDP(model, device_ids=self.device_ids)
hidden_states = torch.randn(B, S, H * D).to(device) hidden_states = torch.randn(B, S, H * D).to(device)
attention_scores = model(hidden_states) model(hidden_states)
torch.cuda.synchronize() torch.cuda.synchronize()
@patch.object(config, "optimize_ddp", True) @patch.object(config, "optimize_ddp", True)
@@ -1723,7 +1723,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
def test_fsdp_orig_params_assert(self): def test_fsdp_orig_params_assert(self):
# Test with basic FSDP wrapping (outer wrap around whole model) # Test with basic FSDP wrapping (outer wrap around whole model)
m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") m, inputs, _ = get_model(f"cuda:{self.rank}")
fsdp_m = FSDP(m, use_orig_params=False) fsdp_m = FSDP(m, use_orig_params=False)
fsdp_m = torch.compile(fsdp_m) fsdp_m = torch.compile(fsdp_m)
self.assertRaisesRegex( self.assertRaisesRegex(

View File

@@ -130,7 +130,7 @@ class TestExpand(MultiThreadedTestCase):
tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD, "bla") tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD, "bla")
self.assertEqual("bla", tag) self.assertEqual("bla", tag)
my_pg, others = new_subgroups(group_size=2) my_pg, _ = new_subgroups(group_size=2)
tag, rankset, group_size = ft_c._expand_group(my_pg) tag, rankset, group_size = ft_c._expand_group(my_pg)
self.assertEqual(c10d._get_group_tag(my_pg), tag) self.assertEqual(c10d._get_group_tag(my_pg), tag)
self.assertEqual(dist.get_process_group_ranks(my_pg), rankset) self.assertEqual(dist.get_process_group_ranks(my_pg), rankset)
@@ -588,7 +588,7 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
def allreduce(t, pg): def allreduce(t, pg):
return ft_c.all_reduce(t, "sum", pg) return ft_c.all_reduce(t, "sum", pg)
compiled_allreduce = torch.compile(allreduce, fullgraph=True) compiled_allreduce = torch.compile(allreduce, fullgraph=True) # noqa: F841
dist.init_process_group( dist.init_process_group(
backend="fake", backend="fake",
rank=0, rank=0,
@@ -615,9 +615,7 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
return batch * 5 return batch * 5
compiled_func = torch.compile(func) compiled_func = torch.compile(func)
ret = compiled_func( compiled_func(torch.ones((100,), device=device), self.process_group, self.rank)
torch.ones((100,), device=device), self.process_group, self.rank
)
dist.barrier() dist.barrier()
@ -715,7 +713,7 @@ class TestFunctionalAutograd(MultiThreadedTestCase):
out = compiled(t, self.world_size) out = compiled(t, self.world_size)
out.backward() out.backward()
res, codes = run_and_get_code(run_with_backward) _, codes = run_and_get_code(run_with_backward)
for code in codes: for code in codes:
FileCheck().check_count( FileCheck().check_count(
"_c10d_functional.all_to_all_single.default", 1, exactly=True "_c10d_functional.all_to_all_single.default", 1, exactly=True

View File

@ -411,7 +411,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
y = self.emb(x) y = self.emb(x)
last_dim = y.dim() - 1 last_dim = y.dim() - 1
y = y.transpose_(0, last_dim).contiguous() y = y.transpose_(0, last_dim).contiguous()
res = _functional_collectives.all_gather_tensor(y, 0, ranks, tag) _functional_collectives.all_gather_tensor(y, 0, ranks, tag)
out = y.transpose_(0, last_dim).contiguous() out = y.transpose_(0, last_dim).contiguous()
return out return out

View File

@ -35,7 +35,6 @@ class TestDistributedLaunch(TestCase):
def test_launch_user_script(self): def test_launch_user_script(self):
nnodes = 1 nnodes = 1
nproc_per_node = 4 nproc_per_node = 4
world_size = nnodes * nproc_per_node
sock = get_socket_with_port() sock = get_socket_with_port()
with closing(sock): with closing(sock):
master_port = sock.getsockname()[1] master_port = sock.getsockname()[1]
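
When the computed value feeds nothing else, the assignment is deleted outright rather than suppressed; above, only the derived world_size goes away while the inputs remain. A hedged before/after sketch:

# Before: the derived value is never read again, so pyflakes reports F841.
nnodes = 1
nproc_per_node = 4
world_size = nnodes * nproc_per_node  # unused

# After: the dead assignment is removed; the values still consumed later stay.
nnodes = 1
nproc_per_node = 4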

View File

@ -553,7 +553,7 @@ class LibUvTCPStoreTest(TCPStoreTest):
) )
with self.assertRaisesRegex(NotImplementedError, err_msg_reg): with self.assertRaisesRegex(NotImplementedError, err_msg_reg):
store = dist.TCPStore( dist.TCPStore(
addr, addr,
port, port,
1, 1,
@ -748,7 +748,7 @@ class RendezvousTCPTest(TestCase):
url = self.create_tcp_url() url = self.create_tcp_url()
test_store_timeout = timedelta(seconds=0.1) test_store_timeout = timedelta(seconds=0.1)
gen0 = dist.rendezvous(url + "&rank=0", timeout=timedelta(seconds=10)) gen0 = dist.rendezvous(url + "&rank=0", timeout=timedelta(seconds=10))
store0, rank0, size0 = next(gen0) store0, _, _ = next(gen0)
store0.set_timeout(test_store_timeout) store0.set_timeout(test_store_timeout)
# this should time out in 0.1s. If the timeout passed into rendezvous was # this should time out in 0.1s. If the timeout passed into rendezvous was
# not respected, it will take much longer to timeout. # not respected, it will take much longer to timeout.
@ -766,7 +766,7 @@ class RendezvousTCPTest(TestCase):
url = self.create_tcp_url() url = self.create_tcp_url()
test_store_timeout = timedelta(seconds=0.1) test_store_timeout = timedelta(seconds=0.1)
gen0 = dist.rendezvous(url + "&rank=0", timeout=timedelta(seconds=10)) gen0 = dist.rendezvous(url + "&rank=0", timeout=timedelta(seconds=10))
store0, rank0, size0 = next(gen0) store0, _, _ = next(gen0)
store0.set_timeout(test_store_timeout) store0.set_timeout(test_store_timeout)
# this should time out in 10s. If the timeout passed into rendezvous was # this should time out in 10s. If the timeout passed into rendezvous was
# not respected, it will take much longer to timeout. # not respected, it will take much longer to timeout.
@ -787,7 +787,7 @@ class RendezvousTCPTest(TestCase):
def test_tcp_store_url_with_libuv(self): def test_tcp_store_url_with_libuv(self):
url = self.create_tcp_url() url = self.create_tcp_url()
gen0 = dist.rendezvous(url + "&rank=0&use_libuv=1") gen0 = dist.rendezvous(url + "&rank=0&use_libuv=1")
store0, rank0, size0 = next(gen0) store0, _, _ = next(gen0)
self.assertTrue(store0.libuvBackend) self.assertTrue(store0.libuvBackend)
@ -1078,7 +1078,7 @@ class TestClientProtocol(TestCase):
thread = threading.Thread(target=listen) thread = threading.Thread(target=listen)
thread.start() thread.start()
store = dist.TCPStore( dist.TCPStore(
host_name="localhost", host_name="localhost",
port=port, port=port,
world_size=2, world_size=2,
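
Constructors invoked only to assert that they raise no longer bind their result; the bare call is enough for assertRaisesRegex to observe the exception. A minimal unittest-style sketch with a hypothetical factory in place of dist.TCPStore:

import unittest

def make_store(port):
    # Hypothetical stand-in for a constructor that validates its arguments eagerly.
    if port < 0:
        raise ValueError(f"invalid port {port}")

class StoreTest(unittest.TestCase):
    def test_rejects_bad_port(self):
        with self.assertRaisesRegex(ValueError, "invalid port"):
            make_store(port=-1)  # result intentionally not bound; only the raise matters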

View File

@ -332,7 +332,6 @@ class SymmetricMemoryTest(MultiProcessTestCase):
K = 32 K = 32
group = dist.group.WORLD group = dist.group.WORLD
rank = self.rank rank = self.rank
world_size = self.world_size
torch.manual_seed(42 + rank) torch.manual_seed(42 + rank)
A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda") A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
@ -428,7 +427,6 @@ class SymmetricMemoryTest(MultiProcessTestCase):
K = 32 K = 32
group = dist.group.WORLD group = dist.group.WORLD
rank = self.rank rank = self.rank
world_size = self.world_size
if gather_dim == 0: if gather_dim == 0:
leading_dims = (BATCH // self.world_size, M) leading_dims = (BATCH // self.world_size, M)
@ -513,7 +511,6 @@ class SymmetricMemoryTest(MultiProcessTestCase):
K = 32 K = 32
group = dist.group.WORLD group = dist.group.WORLD
rank = self.rank rank = self.rank
world_size = self.world_size
torch.manual_seed(42 + rank) torch.manual_seed(42 + rank)
A = torch.rand(BATCH, M, K, device="cuda") A = torch.rand(BATCH, M, K, device="cuda")
@ -546,7 +543,6 @@ class SymmetricMemoryTest(MultiProcessTestCase):
K = 32 K = 32
group = dist.group.WORLD group = dist.group.WORLD
rank = self.rank rank = self.rank
world_size = self.world_size
torch.manual_seed(42 + rank) torch.manual_seed(42 + rank)
A = torch.rand(BATCH, M, K, device="cuda").to(torch.float8_e4m3fn) A = torch.rand(BATCH, M, K, device="cuda").to(torch.float8_e4m3fn)

View File

@ -1314,7 +1314,7 @@ class TestDistributions(DistributionsTestCase):
if not msk.all(): if not msk.all():
counts = np.concatenate([counts[msk], np.sum(counts[~msk], keepdims=True)]) counts = np.concatenate([counts[msk], np.sum(counts[~msk], keepdims=True)])
pmf = np.concatenate([pmf[msk], np.sum(pmf[~msk], keepdims=True)]) pmf = np.concatenate([pmf[msk], np.sum(pmf[~msk], keepdims=True)])
chisq, p = scipy.stats.chisquare(counts, pmf * num_samples) _, p = scipy.stats.chisquare(counts, pmf * num_samples)
self.assertGreater(p, failure_rate, message) self.assertGreater(p, failure_rate, message)
def _check_enumerate_support(self, dist, examples): def _check_enumerate_support(self, dist, examples):
@ -1912,9 +1912,7 @@ class TestDistributions(DistributionsTestCase):
@set_default_dtype(torch.double) @set_default_dtype(torch.double)
def test_one_hot_categorical_2d(self): def test_one_hot_categorical_2d(self):
probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]] probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
p = torch.tensor(probabilities, requires_grad=True) p = torch.tensor(probabilities, requires_grad=True)
s = torch.tensor(probabilities_1, requires_grad=True)
self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3)) self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
self.assertEqual( self.assertEqual(
OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3) OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3)
@ -2074,13 +2072,11 @@ class TestDistributions(DistributionsTestCase):
@set_default_dtype(torch.double) @set_default_dtype(torch.double)
def test_relaxed_one_hot_categorical_2d(self): def test_relaxed_one_hot_categorical_2d(self):
probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]] probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
temp = torch.tensor([3.0], requires_grad=True) temp = torch.tensor([3.0], requires_grad=True)
# The lower the temperature, the more unstable the log_prob gradcheck is # The lower the temperature, the more unstable the log_prob gradcheck is
# w.r.t. the sample. Values below 0.25 empirically fail the default tol. # w.r.t. the sample. Values below 0.25 empirically fail the default tol.
temp_2 = torch.tensor([0.25], requires_grad=True) temp_2 = torch.tensor([0.25], requires_grad=True)
p = torch.tensor(probabilities, requires_grad=True) p = torch.tensor(probabilities, requires_grad=True)
s = torch.tensor(probabilities_1, requires_grad=True)
self.assertEqual(RelaxedOneHotCategorical(temp, p).sample().size(), (2, 3)) self.assertEqual(RelaxedOneHotCategorical(temp, p).sample().size(), (2, 3))
self.assertEqual( self.assertEqual(
RelaxedOneHotCategorical(temp, p).sample(sample_shape=(3, 4)).size(), RelaxedOneHotCategorical(temp, p).sample(sample_shape=(3, 4)).size(),
@ -3939,7 +3935,7 @@ class TestDistributions(DistributionsTestCase):
for dim in range(2, 5): for dim in range(2, 5):
log_probs = [] log_probs = []
lkj = LKJCholesky(dim, concentration=1.0, validate_args=True) lkj = LKJCholesky(dim, concentration=1.0, validate_args=True)
for i in range(2): for _ in range(2):
sample = lkj.sample() sample = lkj.sample()
sample_tril = tril_matrix_to_vec(sample, diag=-1) sample_tril = tril_matrix_to_vec(sample, diag=-1)
log_prob = lkj.log_prob(sample) log_prob = lkj.log_prob(sample)
@ -6241,7 +6237,7 @@ class TestLazyLogitsInitialization(DistributionsTestCase):
except NotImplementedError: except NotImplementedError:
pass pass
self.assertNotIn("probs", dist.__dict__, msg=message) self.assertNotIn("probs", dist.__dict__, msg=message)
batch_shape, event_shape = dist.batch_shape, dist.event_shape dist.batch_shape, dist.event_shape
self.assertNotIn("probs", dist.__dict__, msg=message) self.assertNotIn("probs", dist.__dict__, msg=message)
def test_lazy_probs_initialization(self): def test_lazy_probs_initialization(self):
@ -6258,7 +6254,7 @@ class TestLazyLogitsInitialization(DistributionsTestCase):
except NotImplementedError: except NotImplementedError:
pass pass
self.assertNotIn("logits", dist.__dict__, msg=message) self.assertNotIn("logits", dist.__dict__, msg=message)
batch_shape, event_shape = dist.batch_shape, dist.event_shape dist.batch_shape, dist.event_shape
self.assertNotIn("logits", dist.__dict__, msg=message) self.assertNotIn("logits", dist.__dict__, msg=message)
@ -6565,6 +6561,7 @@ class TestFunctors(DistributionsTestCase):
expected_jac = sum( expected_jac = sum(
[t1.log_abs_det_jacobian(x1, y1), t2.log_abs_det_jacobian(x2, y2)] [t1.log_abs_det_jacobian(x1, y1), t2.log_abs_det_jacobian(x2, y2)]
) )
self.assertEqual(actual_jac, expected_jac)
def test_stack_transform(self): def test_stack_transform(self):
x1 = -1 * torch.arange(1, 101, dtype=torch.float) x1 = -1 * torch.arange(1, 101, dtype=torch.float)
@ -6628,18 +6625,18 @@ class TestValidation(DistributionsTestCase):
for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]): for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]):
# samples with incorrect shape must throw ValueError only # samples with incorrect shape must throw ValueError only
try: try:
log_prob = d_val.log_prob(v) d_val.log_prob(v)
except ValueError: except ValueError:
pass pass
# get sample of correct shape # get sample of correct shape
val = torch.full(d_val.batch_shape + d_val.event_shape, v) val = torch.full(d_val.batch_shape + d_val.event_shape, v)
# check samples with incorrect support # check samples with incorrect support
try: try:
log_prob = d_val.log_prob(val) d_val.log_prob(val)
except ValueError as e: except ValueError as e:
if e.args and "must be within the support" in e.args[0]: if e.args and "must be within the support" in e.args[0]:
try: try:
log_prob = d_nonval.log_prob(val) d_nonval.log_prob(val)
except RuntimeError: except RuntimeError:
pass pass
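
Two further variants appear in the distributions tests above: loop counters that are never referenced become `_`, and statements whose only purpose is to trigger attribute access (here, lazy shape initialization) keep the expression but drop the bindings. A short runnable sketch of both:

import torch
from torch.distributions import Bernoulli

d = Bernoulli(logits=torch.zeros(3))
for _ in range(2):                # loop index never used, so it is renamed to `_`
    d.sample()
d.batch_shape, d.event_shape      # evaluated only for the lazy-initialization side effect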

View File

@ -1260,7 +1260,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
super().__init__() super().__init__()
def forward(self, x, ys): def forward(self, x, ys):
a = torch.sin(x) a = torch.sin(x) # noqa: F841
b = torch.cos(ys[0]) b = torch.cos(ys[0])
c = torch.cos(ys[1]) c = torch.cos(ys[1])
return (x, [b, c]) return (x, [b, c])

View File

@ -453,7 +453,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
a = torch.randn(3, 3, requires_grad=True) a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True) b = torch.randn(3, 3, requires_grad=True)
a1, a2 = a.clone(), a.clone() a1, a2 = a.clone(), a.clone()
b1, b2 = b.clone(), b.clone() _, b2 = b.clone(), b.clone()
failure_reason = None failure_reason = None
@ -481,7 +481,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
c = torch.randn(3, 3, requires_grad=True) c = torch.randn(3, 3, requires_grad=True)
d = torch.randn(3, 3, requires_grad=True) d = torch.randn(3, 3, requires_grad=True)
c3, c4 = c.clone(), c.clone() c3, c4 = c.clone(), c.clone()
d3, d4 = d.clone(), d.clone() _, d4 = d.clone(), d.clone()
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F()) f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f(c3, c3, 3, 3) f(c3, c3, 3, 3)
@ -507,7 +507,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
b = torch.randn(3, 3, requires_grad=True) b = torch.randn(3, 3, requires_grad=True)
z = a z = a
a1, a2 = a.clone(), a.clone() a1, a2 = a.clone(), a.clone()
b1, b2 = b.clone(), b.clone() _, b2 = b.clone(), b.clone()
failure_reason = None failure_reason = None
@ -543,7 +543,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
a = torch.randn(3, 3, requires_grad=True) a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True) b = torch.randn(3, 3, requires_grad=True)
a1, a2 = a.clone(), a.clone() a1, a2 = a.clone(), a.clone()
b1, b2 = b.clone(), b.clone() _, b2 = b.clone(), b.clone()
failure_reason = None failure_reason = None
@ -571,7 +571,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
c = torch.randn(3, 3, requires_grad=True) c = torch.randn(3, 3, requires_grad=True)
d = torch.randn(3, 3, requires_grad=True) d = torch.randn(3, 3, requires_grad=True)
c3, c4 = c.clone(), c.clone() c3, c4 = c.clone(), c.clone()
d3, d4 = d.clone(), d.clone() _, d4 = d.clone(), d.clone()
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F()) f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f([3, 2, 1], [4, 5, 6], c3, c3) f([3, 2, 1], [4, 5, 6], c3, c3)
@ -593,7 +593,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
a = torch.randn(3, 3, requires_grad=True) a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True) b = torch.randn(3, 3, requires_grad=True)
a1, a2 = a.clone(), a.clone() a1, a2 = a.clone(), a.clone()
b1, b2 = b.clone(), b.clone() _, b2 = b.clone(), b.clone()
failure_reason = None failure_reason = None
@ -621,7 +621,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
c = torch.randn(3, 3, requires_grad=True) c = torch.randn(3, 3, requires_grad=True)
d = torch.randn(3, 3, requires_grad=True) d = torch.randn(3, 3, requires_grad=True)
c3, c4 = c.clone(), c.clone() c3, c4 = c.clone(), c.clone()
d3, d4 = d.clone(), d.clone() _, d4 = d.clone(), d.clone()
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F()) f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f(c3, c3) f(c3, c3)
@ -642,7 +642,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
a = torch.randn(3, 3, requires_grad=True) a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True) b = torch.randn(3, 3, requires_grad=True)
a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone() a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone() _, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
failure_reason = None failure_reason = None
@ -670,7 +670,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
c = torch.randn(3, 3, requires_grad=True) c = torch.randn(3, 3, requires_grad=True)
d = torch.randn(3, 3, requires_grad=True) d = torch.randn(3, 3, requires_grad=True)
c3, c4 = c.clone(), c.clone() c3, c4 = c.clone(), c.clone()
d3, d4 = d.clone(), d.clone() _, d4 = d.clone(), d.clone()
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F()) f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f(a3, b3, c3, c3) f(a3, b3, c3, c3)
@ -1017,7 +1017,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
activities=[torch.profiler.ProfilerActivity.CPU], activities=[torch.profiler.ProfilerActivity.CPU],
record_shapes=True, record_shapes=True,
) as kineto_prof: ) as kineto_prof:
res = model_instance(*args) model_instance(*args)
bwd_set = set() bwd_set = set()
prof_str = "SeqNr|Thread|FwdThread|Name\n" prof_str = "SeqNr|Thread|FwdThread|Name\n"
for event in kineto_prof.events(): for event in kineto_prof.events():
@ -1191,7 +1191,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
x = torch.randn(3, requires_grad=True) x = torch.randn(3, requires_grad=True)
with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer"): with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer"):
y = torch.compile(f, backend="aot_eager", fullgraph=True)(x) torch.compile(f, backend="aot_eager", fullgraph=True)(x)
self.assertTrue(backward_called) self.assertTrue(backward_called)
# We don't know how to catch multiple mutations to the same memory location # We don't know how to catch multiple mutations to the same memory location

View File

@ -157,7 +157,7 @@ class AOTAutogradCacheTests(InductorTestCase):
with torch.autograd._force_original_view_tracking(True): with torch.autograd._force_original_view_tracking(True):
compiled_fn = torch.compile(fn) compiled_fn = torch.compile(fn)
out = compiled_fn(torch.rand(2, 3)) compiled_fn(torch.rand(2, 3))
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
@ -654,7 +654,7 @@ class AOTAutogradCachePicklerTests(torch._dynamo.test_case.TestCase):
def fn(x): def fn(x):
return x.sin().cos() return x.sin().cos()
def fn2(x): def fn2(x): # noqa: F841
y = x.sin() y = x.sin()
z = y.cos() z = y.cos()
return z return z

View File

@ -760,7 +760,7 @@ class GraphModule(torch.nn.Module):
def backward(ctx, gO): def backward(ctx, gO):
return torch.tensor(float("nan")).expand(10, 10) return torch.tensor(float("nan")).expand(10, 10)
def run_fn(a): def run_fn(a): # noqa: F841
out = MyFunc2.apply(a) out = MyFunc2.apply(a)
return out.sum() return out.sum()
@ -837,11 +837,11 @@ class GraphModule(torch.nn.Module):
x = torch.randn(5, 5, requires_grad=True) x = torch.randn(5, 5, requires_grad=True)
y = torch.randn(5, 5, requires_grad=True) y = torch.randn(5, 5, requires_grad=True)
q, p = Identity.apply(x, y) Identity.apply(x, y)
a = torch.rand(1, 2) a = torch.rand(1, 2)
b = torch.rand(1, requires_grad=True) b = torch.rand(1, requires_grad=True)
view_a = MyFn.apply(a) MyFn.apply(a)
a = torch.ones(2, requires_grad=True) a = torch.ones(2, requires_grad=True)
b = torch.ones(2, requires_grad=True) b = torch.ones(2, requires_grad=True)
@ -860,7 +860,7 @@ class GraphModule(torch.nn.Module):
MyFn2.apply(c, d) MyFn2.apply(c, d)
base = torch.rand(10, requires_grad=True) base = torch.rand(10, requires_grad=True)
foo = MyFn3.apply(base, False) MyFn3.apply(base, False)
test() test()
opt_test = torch.compile(test, backend="eager") opt_test = torch.compile(test, backend="eager")

View File

@ -267,9 +267,8 @@ class TestCustomBackendAPI(torch._dynamo.test_case.TestCase):
self.assertTrue(backend_run) self.assertTrue(backend_run)
def test_lookup_backend(self): def test_lookup_backend(self):
from torch._dynamo import list_backends, lookup_backend from torch._dynamo import lookup_backend
backends = list_backends()
backend_run = False backend_run = False
def my_compiler(gm, example_inputs): def my_compiler(gm, example_inputs):

View File

@ -247,8 +247,6 @@ class GraphModule(torch.nn.Module):
with compiled_autograd._enable(compiler_fn): with compiled_autograd._enable(compiler_fn):
out.backward(grad_out) out.backward(grad_out)
graph = None
if __name__ == "__main__": if __name__ == "__main__":
from torch._dynamo.test_case import run_tests from torch._dynamo.test_case import run_tests

View File

@ -518,7 +518,7 @@ def fn():
insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False) insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
self.assertEqual(insts[-1].opname, "NOP") self.assertEqual(insts[-1].opname, "NOP")
insts_i = 0 insts_i = 0
for i, inst in enumerate(dis_insts): for inst in dis_insts:
if inst.opname == "RETURN_CONST": if inst.opname == "RETURN_CONST":
self.assertEqual(insts[insts_i].opname, "LOAD_CONST") self.assertEqual(insts[insts_i].opname, "LOAD_CONST")
insts_i += 1 insts_i += 1
@ -538,7 +538,7 @@ def fn():
x = x + 1 x = x + 1
except NotImplementedError: except NotImplementedError:
x = x + 1 x = x + 1
except Exception as e: except Exception:
x = x + 1 x = x + 1
return x return x
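
Exception handlers that never touch the caught exception drop the `as e` clause, since pyflakes reports an unused exception binding under the same unused-variable rule. A minimal sketch:

def bump(x):
    try:
        x = x + 1
    except Exception:  # no `as e`: the handler does not inspect the exception
        x = x + 1
    return x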

View File

@ -43,7 +43,7 @@ class TestCompilerBisector(TestCase):
return lib return lib
def test_bad_decomp(self): def test_bad_decomp(self):
mod = import_module("torch._inductor.compile_fx") import_module("torch._inductor.compile_fx")
def bad_exp_decomp(self, rate=1, generator=None): def bad_exp_decomp(self, rate=1, generator=None):
assert generator is None assert generator is None
@ -86,7 +86,7 @@ class TestCompilerBisector(TestCase):
vq_compiled = torch.compile(vq) vq_compiled = torch.compile(vq)
x = torch.randn(4, 400, 256).cuda() x = torch.randn(4, 400, 256).cuda()
with torch._dynamo.utils.preserve_rng_state(): with torch._dynamo.utils.preserve_rng_state():
out = vq(x) vq(x)
out_compiled = vq_compiled(x) out_compiled = vq_compiled(x)
return not out_compiled.isnan().any() return not out_compiled.isnan().any()
@ -150,7 +150,6 @@ class TestCompilerBisector(TestCase):
self.assertTrue("inductor_fallback_random" in out.debug_info) self.assertTrue("inductor_fallback_random" in out.debug_info)
def test_crossref(self): def test_crossref(self):
test_ns = "bisect_ops"
with _scoped_library(self.test_ns, "FRAGMENT") as lib: with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define("foo(Tensor x) -> Tensor") lib.define("foo(Tensor x) -> Tensor")
op = self.get_op("foo") op = self.get_op("foo")

View File

@ -117,7 +117,7 @@ def forward(self, L_x_ : torch.Tensor):
return y + 3 return y + 3
def munge_disas(s): def munge_disas(s): # noqa: F841
re.sub( re.sub(
r"^(?: +\d+)?(?: +(-->)) \+\d+ ([A-Za-z0-9_]+)", r"^(?: +\d+)?(?: +(-->)) \+\d+ ([A-Za-z0-9_]+)",
"\1 \3", "\1 \3",
@ -271,7 +271,7 @@ y = FakeTensor(..., size=(2,))
y = g(y) y = g(y)
return y + 3 return y + 3
def munge_filenames(s): def munge_filenames(s): # noqa: F841
return re.sub(r'File "[^"]+", line \d+', 'File "X", line X', s) return re.sub(r'File "[^"]+", line \d+', 'File "X", line X', s)
f(torch.randn(2)) f(torch.randn(2))
@ -389,7 +389,7 @@ y = FakeTensor(..., size=(2,))
@torch.compile(backend=cnt) @torch.compile(backend=cnt)
def f(x): def f(x):
y = x * 2 y = x * 2
lit = 2 lit = 2 # noqa: F841
@comptime @comptime
def _(ctx): def _(ctx):

View File

@ -268,15 +268,13 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
cur_stream.wait_stream(new_stream) cur_stream.wait_stream(new_stream)
x = torch.add(x, 4) x = torch.add(x, 4)
is_idle = cur_stream.query() cur_stream.query()
cur_stream.synchronize() cur_stream.synchronize()
with torch.cuda.stream(new_stream): with torch.cuda.stream(new_stream):
x = torch.add(x, 5) x = torch.add(x, 5)
new_stream.synchronize() new_stream.synchronize()
is_equal = cur_stream == new_stream
x = torch.relu(x) x = torch.relu(x)
x = torch.cos(x) x = torch.cos(x)
return x return x
@ -439,7 +437,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
x = torch.add(x, 3) x = torch.add(x, 3)
event = cur_stream.record_event() event = cur_stream.record_event()
is_idle = event.query() event.query()
new_stream.wait_event(event) new_stream.wait_event(event)
with torch.cuda.stream(new_stream): with torch.cuda.stream(new_stream):
@ -481,7 +479,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
x = torch.add(x, 3) x = torch.add(x, 3)
event = cur_stream.record_event() event = cur_stream.record_event()
is_idle = event.query() event.query()
new_stream.wait_event(event) new_stream.wait_event(event)
with torch.cuda.stream(new_stream): with torch.cuda.stream(new_stream):
@ -567,7 +565,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
real_device = real.device real_device = real.device
real_dtype = real.dtype real_dtype = real.dtype
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5])) exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device) self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype) self.assertEqual(exported.dtype, real_dtype)
@ -676,7 +674,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
real_device = real.device real_device = real.device
real_dtype = real.dtype real_dtype = real.dtype
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5])) exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device) self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype) self.assertEqual(exported.dtype, real_dtype)
@ -850,7 +848,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
real_device = real.device real_device = real.device
real_dtype = real.dtype real_dtype = real.dtype
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5])) exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device) self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype) self.assertEqual(exported.dtype, real_dtype)
@ -876,7 +874,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
real_device = real.device real_device = real.device
real_dtype = real.dtype real_dtype = real.dtype
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5])) exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device) self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype) self.assertEqual(exported.dtype, real_dtype)
@ -1297,7 +1295,7 @@ class GraphModule(torch.nn.Module):
eager = EagerAndRecordGraphs() eager = EagerAndRecordGraphs()
torch.compile(fn, backend=eager, fullgraph=False)(torch.randn(())) torch.compile(fn, backend=eager, fullgraph=False)(torch.randn(()))
def check_graph(actual, expected): def check_graph(actual, expected): # noqa: F841
self.assertExpectedInline(actual, expected) self.assertExpectedInline(actual, expected)
graph = eager.graphs[0] graph = eager.graphs[0]
@ -1342,7 +1340,7 @@ class GraphModule(torch.nn.Module):
for i in range(2): for i in range(2):
torch._dynamo.reset() torch._dynamo.reset()
ctx_wrapper, mode = ctx_wrappers[i] ctx_wrapper, _ = ctx_wrappers[i]
ctx_wrapper_inverse, mode_inverse = ctx_wrappers[(i + 1) % 2] ctx_wrapper_inverse, mode_inverse = ctx_wrappers[(i + 1) % 2]
def fn(x): def fn(x):
@ -1373,7 +1371,7 @@ class GraphModule(torch.nn.Module):
for i in range(2): for i in range(2):
torch._dynamo.reset() torch._dynamo.reset()
ctx_wrapper, mode = ctx_wrappers[i] ctx_wrapper, _ = ctx_wrappers[i]
ctx_wrapper_inverse, mode_inverse = ctx_wrappers[(i + 1) % 2] ctx_wrapper_inverse, mode_inverse = ctx_wrappers[(i + 1) % 2]
def fn(x): def fn(x):

View File

@ -63,7 +63,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
@torch.compile(backend="cudagraphs") @torch.compile(backend="cudagraphs")
def fn(x, y): def fn(x, y):
for i in range(N_ITERS): for _ in range(N_ITERS):
loss = model(x, y).sum() loss = model(x, y).sum()
loss.backward() loss.backward()
@ -80,7 +80,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
@torch.compile(backend="cudagraphs") @torch.compile(backend="cudagraphs")
def fn(x, y): def fn(x, y):
for i in range(N_ITERS): for _ in range(N_ITERS):
loss = model(x, y).sum() loss = model(x, y).sum()
loss.backward() loss.backward()
@ -96,7 +96,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
@torch.compile(backend="cudagraphs") @torch.compile(backend="cudagraphs")
def fn(x, y): def fn(x, y):
for i in range(N_ITERS): for _ in range(N_ITERS):
loss = model(x, y).sum() loss = model(x, y).sum()
loss.backward() loss.backward()

View File

@ -45,7 +45,7 @@ def forward(self, x_1):
""", # NOQA: B950 """, # NOQA: B950
) )
fp64_model, fp64_examples = debug_utils.cast_to_fp64(fx, (x,)) _, fp64_examples = debug_utils.cast_to_fp64(fx, (x,))
self.assertEqual(fp64_examples, (x.to(torch.float64),)) self.assertEqual(fp64_examples, (x.to(torch.float64),))
self.assertExpectedInline( self.assertExpectedInline(
@ -79,7 +79,7 @@ def forward(self, x_1):
_tensor_constant0 _tensor_constant0
) )
_tensor_constant0 = None _tensor_constant0 = None
index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor( index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor( # noqa: F841
primals_48, [None, lift_fresh_copy] primals_48, [None, lift_fresh_copy]
) )
lift_fresh_copy = None lift_fresh_copy = None

View File

@ -83,7 +83,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
# This behavior is not ideal, but supporting it would add overhead # This behavior is not ideal, but supporting it would add overhead
# to callsites of eval_frame.innermost_fn. A warning would also be very noisy. # to callsites of eval_frame.innermost_fn. A warning would also be very noisy.
w = torch._dynamo.disable(fn=wrapper, recursive=True) torch._dynamo.disable(fn=wrapper, recursive=True)
def test_disable_nn_modules_forward_hook(self): def test_disable_nn_modules_forward_hook(self):
class SimpleLinear(torch.nn.Module): class SimpleLinear(torch.nn.Module):
@ -543,7 +543,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
return v1, v2, v3, v4, v5, v6, v7, v8, v9 return v1, v2, v3, v4, v5, v6, v7, v8, v9
a, b, c = A(), B(), C() a, b, c = A(), B(), C()
v1, v2, v3, v4, v5, v6, v7, v8, v9 = fn(a, b, c) v1, v2, v3, v4, v5, _, v7, v8, v9 = fn(a, b, c)
self.assertEqual(v1, (A, 1)) self.assertEqual(v1, (A, 1))
self.assertEqual(v2, (A, 2)) self.assertEqual(v2, (A, 2))

View File

@ -92,7 +92,7 @@ from user code:
raise NotImplementedError raise NotImplementedError
# Ensure graph break is not possible # Ensure graph break is not possible
for i in range(3): for _ in range(3):
comptime(f) comptime(f)
torch.compile(fn001, backend="eager")(torch.randn(1)) torch.compile(fn001, backend="eager")(torch.randn(1))

Some files were not shown because too many files have changed in this diff.