[BE][PYFMT] migrate PYFMT for test/[a-h]*/ to ruff format (#144555)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144555
Approved by: https://github.com/ezyang
ghstack dependencies: #144551, #144554
Xuehai Pan 2025-06-24 00:18:25 +00:00 committed by PyTorch MergeBot
parent e600e044a7
commit 6d5c789ad5
84 changed files with 585 additions and 520 deletions
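Note: this commit is a pure reformatting pass; the hunks below change layout only, never behavior. The single most common rewrite is how the two formatters wrap an `assert` that carries a long message: black parenthesizes the condition, while ruff format keeps the condition on the header line and parenthesizes the message. A minimal before/after sketch (illustrative names, not taken from any file below):

    # black (old)
    assert (
        actual == expected
    ), f"Expected {expected}, got {actual}"

    # ruff format (new)
    assert actual == expected, (
        f"Expected {expected}, got {actual}"
    )

Both spellings parse to the same AST; only the line wrapping differs.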

View File

@@ -494,9 +494,7 @@ class TestBaseDataSparsifier(_BaseDataSparsiferTestCase):
 (
     emb1,
     emb2,
-) = nn.Embedding(
-    10, 3
-), nn.Embedding(20, 3)
+) = nn.Embedding(10, 3), nn.Embedding(20, 3)
 emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3)
 emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3)
@@ -627,9 +625,7 @@ class TestNormDataSparsifiers(_NormDataSparsifierTestCase):
 (
     emb1,
     emb2,
-) = nn.Embedding(
-    10, 3
-), nn.Embedding(20, 3)
+) = nn.Embedding(10, 3), nn.Embedding(20, 3)
 emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3)
 emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3)

View File

@@ -218,12 +218,12 @@ def _sparse_layer_test_helper(
 qmodule_to_check = fqn_to_module(qmodel, fqn_to_check)
 # check that the modules were converted as expected
-assert isinstance(
-    sqmodule_to_check, sqmodule_expected_converted_class
-), "Convert failed"
-assert isinstance(
-    qmodule_to_check, qmodule_expected_converted_class
-), "Mapping failed"
+assert isinstance(sqmodule_to_check, sqmodule_expected_converted_class), (
+    "Convert failed"
+)
+assert isinstance(qmodule_to_check, qmodule_expected_converted_class), (
+    "Mapping failed"
+)
 row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias()[
     2:

View File

@@ -1055,9 +1055,9 @@ class TestFPGMPruner(TestCase):
 mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1]
 mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2]
 # Check if either of the least-norm filters is not pruned
-assert (
-    mask1.item() is not False or mask2.item() is not False
-), "Do not prune all least-norm filters"
+assert mask1.item() is not False or mask2.item() is not False, (
+    "Do not prune all least-norm filters"
+)
 # fusion step
 pruned_model = pruner.prune()

View File

@@ -66,9 +66,10 @@ def generate_callgrind_artifacts() -> None:
 json.dump(artifacts, f, indent=4)
-def load_callgrind_artifacts() -> (
-    tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]
-):
+def load_callgrind_artifacts() -> tuple[
+    benchmark_utils.CallgrindStats,
+    benchmark_utils.CallgrindStats,
+]:
     """Hermetic artifact to unit test Callgrind wrapper.
     In addition to collecting counts, this wrapper provides some facilities for
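The hunk above shows the two formatters' strategies for an over-long return annotation: black parenthesizes the entire annotation, while ruff format breaks inside the subscripted generic type. Schematically (hypothetical type names):

    # black (old)
    def load() -> (
        tuple[LongTypeName, LongTypeName]
    ): ...

    # ruff format (new)
    def load() -> tuple[
        LongTypeName,
        LongTypeName,
    ]: ...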

View File

@@ -158,7 +158,8 @@ def compute_functional_name(test_params_dict):
     return test_params_dict["cpp_function_call"].split("(")[0].replace("F::", "")
 else:
     raise RuntimeError(
-        f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}"  # noqa: B950
+        "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n"
+        f"{pprint.pformat(test_params_dict)}"
     )
@@ -179,7 +180,8 @@ def compute_cpp_function_call(test_params_dict, arg_dict, functional_name):
     )
 else:
     raise RuntimeError(
-        f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}"  # noqa: B950
+        "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n"
+        f"{pprint.pformat(test_params_dict)}"
     )
@@ -217,7 +219,8 @@ def write_test_to_test_class(
     or "cpp_function_call" in test_params_dict
 ), (
     "To enable C++ API parity test, "
-    f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}. \n"  # noqa: B950
+    "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n"
+    f"{pprint.pformat(test_params_dict)}. \n"
     "If you are interested in adding the C++ API parity test, please see:\n"
     "NOTE [How to check NN module / functional API parity between Python and C++ frontends]. \n"
     "If not, please add `test_cpp_api_parity=False` to the test params dict and file an issue about this."
@@ -233,14 +236,16 @@ def write_test_to_test_class(
 functional_name = compute_functional_name(test_params_dict)
-assert hasattr(
-    torch.nn.functional, functional_name
-), f"`torch.nn.functional` doesn't have function `{functional_name}`. (Discovered while processing\n{pprint.pformat(test_params_dict)}.)"  # noqa: B950
+assert hasattr(torch.nn.functional, functional_name), (
+    f"`torch.nn.functional` doesn't have function `{functional_name}`. "
+    f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)"
+)
 functional_full_name = "F::" + functional_name
 assert functional_full_name in parity_table["torch::nn::functional"], (
-    f"Please add `{functional_full_name}` entry to `torch::nn::functional` section of `test/cpp_api_parity/parity-tracker.md`. "
+    f"Please add `{functional_full_name}` entry to `torch::nn::functional` "
+    "section of `test/cpp_api_parity/parity-tracker.md`. "
     f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)"
 )
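Alongside the mechanical reformatting, the over-length messages in this file were split by hand into implicitly concatenated literals so the `# noqa: B950` (line-too-long) suppressions could be dropped. Adjacent string literals are merged at compile time, and the `f` prefix is only needed on the pieces that interpolate; a sketch reusing the message above:

    msg = (
        "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n"
        f"{pprint.pformat(test_params_dict)}"
    )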

View File

@@ -78,9 +78,9 @@ def _kernel_fallback(op, *args, **kwargs):
 elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default:
     # Only handle inplace ops returning their first arg
     assert len(args) >= 1, f"Inplace {op} needs at least one arg"
-    assert (
-        len(op._schema.returns) == 1
-    ), f"NYI Inplace {op} with more than one return"
+    assert len(op._schema.returns) == 1, (
+        f"NYI Inplace {op} with more than one return"
+    )
     op_name = op.overloadpacket._qualified_op_name
     real_res = args[0]
 elif any(r.alias_info is not None for r in op._schema.returns):

View File

@@ -67,7 +67,7 @@ def prepare_for_sending(args, kwargs):
 def convert(obj):
     if type(obj) not in VALID_QUEUE_TYPES_IN:
         raise RuntimeError(
-            f"Cannot send object of type {type(obj)} " "over openreg device pipe."
+            f"Cannot send object of type {type(obj)} over openreg device pipe."
         )
     if isinstance(obj, torch.Tensor):
@@ -82,8 +82,7 @@ def receive_after_sending(allocator, args, kwargs):
 def convert(obj):
     if type(obj) not in VALID_QUEUE_TYPES_OUT:
         raise RuntimeError(
-            f"Received invalid object of type {type(obj)} "
-            "over openreg device pipe."
+            f"Received invalid object of type {type(obj)} over openreg device pipe."
         )
     if isinstance(obj, OpenRegTensorMeta):
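This hunk shows the inverse cleanup: ruff format joins an implicit string concatenation back onto one line once the merged literal fits the line limit. A toy example:

    part = "Hello, " "world"   # implicit concatenation of two literals
    whole = "Hello, world"     # what ruff format produces when the result fits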

View File

@@ -561,8 +561,9 @@ class TestFullyShardPrefetch(FSDPTest):
     FSDPParamGroup.post_backward, events
 )
 # Check the order for normal 1 forward, 1 backward, 1 optimizer step
-with patch_unshard(unshard_with_record), patch_post_backward(
-    post_backward_with_record
+with (
+    patch_unshard(unshard_with_record),
+    patch_post_backward(post_backward_with_record),
 ):
     for iter_idx in range(3):
         loss = model(inp)
@@ -617,8 +618,9 @@ class TestFullyShardPrefetch(FSDPTest):
     FSDPParamGroup.post_backward, events
 )
 # Check the order for multiple forwards before 1 backward
-with patch_unshard(unshard_with_record), patch_post_backward(
-    post_backward_with_record
+with (
+    patch_unshard(unshard_with_record),
+    patch_post_backward(post_backward_with_record),
 ):
     loss1 = model(inp)
     loss2 = model(inp)
@@ -703,8 +705,9 @@ class TestFullyShardPrefetch(FSDPTest):
 post_backward_with_record = self._get_post_backward_with_record(
     FSDPParamGroup.post_backward, events
 )
-with patch_unshard(unshard_with_record), patch_post_backward(
-    post_backward_with_record
+with (
+    patch_unshard(unshard_with_record),
+    patch_post_backward(post_backward_with_record),
 ):
     loss1, loss2 = model(inp)
     expected_events = [
@@ -794,9 +797,11 @@ class TestFullyShardPrefetch(FSDPTest):
     ("reshard", "", TrainingState.POST_BACKWARD),
     ("post_backward", "", TrainingState.POST_BACKWARD),
 ]
-with patch_unshard(unshard_with_record), patch_reshard(
-    reshard_with_record
-), patch_post_backward(post_backward_with_record):
+with (
+    patch_unshard(unshard_with_record),
+    patch_reshard(reshard_with_record),
+    patch_post_backward(post_backward_with_record),
+):
     set_forward_prefetch(model, num_to_prefetch=1)
     loss = model(inp)
     expected_forward_events = [
@@ -882,9 +887,11 @@ class TestFullyShardPrefetch(FSDPTest):
     ("reshard", "layers.3", TrainingState.FORWARD),
     ("reshard", "", TrainingState.FORWARD),
 ]
-with patch_unshard(unshard_with_record), patch_reshard(
-    reshard_with_record
-), patch_post_backward(post_backward_with_record):
+with (
+    patch_unshard(unshard_with_record),
+    patch_reshard(reshard_with_record),
+    patch_post_backward(post_backward_with_record),
+):
     set_backward_prefetch(model, num_to_prefetch=1)
     loss = model(inp)
     self.assertEqual(events, expected_forward_events)
@@ -967,8 +974,9 @@ class TestFullyShardPrefetch(FSDPTest):
     (2, model_args.max_seq_len),
     device=device_type.type,
 )
-with patch_unshard(unshard_with_record), patch_post_backward(
-    post_backward_with_record
+with (
+    patch_unshard(unshard_with_record),
+    patch_post_backward(post_backward_with_record),
 ):
     for _ in range(3):
         loss = model(inp)
@@ -1046,8 +1054,9 @@ class TestFullyShardPrefetch(FSDPTest):
     FSDPParamGroup.post_backward, events
 )
 inp = torch.randn((2, 16), device=device_type.type)
-with patch_unshard(unshard_with_record), patch_post_backward(
-    post_backward_with_record
+with (
+    patch_unshard(unshard_with_record),
+    patch_post_backward(post_backward_with_record),
 ):
     for _ in range(3):
         loss = model(inp)
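Every `with` rewrite in this file follows the same rule: where black broke the line inside one context manager's call parentheses, ruff format parenthesizes the whole list of context managers, one per line with a trailing comma. Using the names from the hunks above:

    # black (old)
    with patch_unshard(unshard_with_record), patch_post_backward(
        post_backward_with_record
    ):
        ...

    # ruff format (new)
    with (
        patch_unshard(unshard_with_record),
        patch_post_backward(post_backward_with_record),
    ):
        ...

Parenthesized context managers require a grammar that accepts them (officially documented in Python 3.10), which the target version configured here evidently allows.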

View File

@@ -222,9 +222,7 @@ class TestFullyShardCompile(FSDPTest):
 ):
     unsharded_param_graph_inputs.add(node.args[0])
 assert len(unsharded_param_graph_inputs) > 0
-assert len(unsharded_param_graph_inputs) == len(
-    list(model.parameters())
-), """\
+assert len(unsharded_param_graph_inputs) == len(list(model.parameters())), """\
 Expected all model parameters to be wrapped by FSDP2 and
 have their unsharded version as graph input, but it's not true!
 """
@@ -237,7 +235,7 @@ have their unsharded version as graph input, but it's not true!
 no_aliased_unsharded_params_in_graph_inputs = False
 err_msg += f"""\n
 Found aliased unsharded param in graph inputs: {aliased_graph_inputs},
-val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
+val.shape: {[node.meta["val"].shape for node in aliased_graph_inputs]},
 """
 self.assertTrue(no_aliased_unsharded_params_in_graph_inputs, err_msg)
@@ -466,10 +464,9 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
 @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
 def test_compiled_autograd_ctx(self):
     self.skipTestForOldSm()
-    with torch._dynamo.config.patch(
-        skip_fsdp_hooks=False,
-    ), torch._functorch.config.patch(
-        recompute_views=True,
+    with (
+        torch._dynamo.config.patch(skip_fsdp_hooks=False),
+        torch._functorch.config.patch(recompute_views=True),
     ):
         inputs = torch.randn(8, 8)
         model = torch.nn.Linear(8, 8)
@@ -567,17 +564,20 @@ Unsupported Tensor.backward() call
 torch._dynamo.reset()
 torch._dynamo.compiled_autograd.reset()
-with torch._dynamo.config.patch(
+with (
+    torch._dynamo.config.patch(
         compiled_autograd=True,
         compiled_autograd_kwargs_override={
            "fullgraph": True,
        },
        inline_inbuilt_nn_modules=True,
        skip_fsdp_hooks=False,
-), torch._functorch.config.patch(
+    ),
+    torch._functorch.config.patch(
        enable_autograd_cache=False,
        recompute_views=True,
-), torch._inductor.config.patch(
+    ),
+    torch._inductor.config.patch(
        force_disable_caches=True,
        reorder_for_compute_comm_overlap=True,
        reorder_for_compute_comm_overlap_passes=[
@@ -585,6 +585,7 @@ Unsupported Tensor.backward() call
            "raise_comms",
            "reorder_compute_for_overlap",
        ],
+    ),
 ):
     losses_compiled = test_compiled()
     losses_eager = test_eager()
@@ -741,9 +742,9 @@ Unsupported Tensor.backward() call
 def _test_nested_fully_shard_backend_inductor_fullgraph_True(self):
     self.skipTestForOldSm()
     for fwd_fullgraph in [True]:
-        with self._reinplace_all_gather_with_optional_checks(
-            fwd_fullgraph
-        ), torch._inductor.config.patch(
+        with (
+            self._reinplace_all_gather_with_optional_checks(fwd_fullgraph),
+            torch._inductor.config.patch(
                 post_grad_custom_post_pass=(
                     functools.partial(
                         self._check_fsdp_copy_and_resize_ops_count_in_graph,
@@ -755,6 +756,7 @@ Unsupported Tensor.backward() call
                     if fwd_fullgraph
                     else None
                 )
+            ),
        ):
            _, triton_codes = run_and_get_code(
                lambda: self._test_traceable_fsdp(
@@ -943,9 +945,10 @@ Unsupported Tensor.backward() call
 for fwd_fullgraph, all_requires_grad in itertools.product(
     [True], [True, False]
 ):
-    with self._maybe_add_graph_break_to_sdpa(
-        fwd_fullgraph
-    ), self._reinplace_all_gather_with_optional_checks(fwd_fullgraph):
+    with (
+        self._maybe_add_graph_break_to_sdpa(fwd_fullgraph),
+        self._reinplace_all_gather_with_optional_checks(fwd_fullgraph),
+    ):
         self._test_traceable_fsdp(
             *self._create_transformer_factory_fns(
                 all_requires_grad=all_requires_grad
@@ -982,9 +985,9 @@ Unsupported Tensor.backward() call
 log.warning(
     f"fwd_fullgraph={fwd_fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}"  # noqa: G004, G001, B950
 )
-with self._reinplace_all_gather_with_optional_checks(
-    fwd_fullgraph
-), torch._inductor.config.patch(
+with (
+    self._reinplace_all_gather_with_optional_checks(fwd_fullgraph),
+    torch._inductor.config.patch(
         post_grad_custom_post_pass=(
             functools.partial(
                 self._check_fsdp_copy_and_resize_ops_count_in_graph,
@@ -999,6 +1002,7 @@ Unsupported Tensor.backward() call
             if fwd_fullgraph
             else None
         )
+    ),
 ):
     _, triton_codes = run_and_get_code(
         lambda: self._test_traceable_fsdp(

View File

@@ -385,9 +385,9 @@ class TestFullyShardAllGatherExtensionsMultiThread(
 only some ranks may require padding, in which case only those ranks
 will error out and the all-gather will timeout.
 """
-assert (
-    self.world_size >= 2
-), f"Assumes world size of at least 2 but got {self.world_size=}"
+assert self.world_size >= 2, (
+    f"Assumes world size of at least 2 but got {self.world_size=}"
+)
 model = MLP(dim=3, dim_multiplier=3)
 for module in model.modules():
     for param_name, param in module.named_parameters(recurse=False):

View File

@@ -115,9 +115,10 @@ class TestFullyShardFrozen(FSDPTest):
 torch.manual_seed(42 + self.rank + 1)
 device = device_type
-with patch_reduce_scatter(
-    reduce_scatter
-), patch_register_post_backward_hook_backward(backward_with_count):
+with (
+    patch_reduce_scatter(reduce_scatter),
+    patch_register_post_backward_hook_backward(backward_with_count),
+):
     for iter_idx in range(10):
         inp = torch.randn((8, lin_dim), device=device)
         losses: list[torch.Tensor] = []

View File

@@ -910,9 +910,9 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
 @skip_if_lt_x_gpu(1)
 def test_2d_process_group_init(self):
     shard_mesh_dim_size = 2
-    assert (
-        self.world_size % shard_mesh_dim_size == 0
-    ), f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}"
+    assert self.world_size % shard_mesh_dim_size == 0, (
+        f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}"
+    )
     replicate_mesh_dim_size = self.world_size // shard_mesh_dim_size
     mesh_dim_names = ("replicate", "shard")
     ref_mesh = init_device_mesh(

View File

@@ -54,9 +54,9 @@ class TestFullyShardMemory(FSDPTest):
     )
 ):
     return  # skip since not a common use case
-assert (
-    self.world_size == 2
-), f"Requires world size of 2 since some values are hard coded: {self.world_size}"
+assert self.world_size == 2, (
+    f"Requires world size of 2 since some values are hard coded: {self.world_size}"
+)
 torch.manual_seed(42)
 # Pre-run a linear forward (gemm and bias) and backward (gemm) to
 # allocate the cuBLAS workspaces before measuring the memory usage

View File

@@ -284,9 +284,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
 )  # bf16 reduction
 param.grad = funcol.all_gather_tensor(
     sharded_grad, gather_dim=0, group=group
-).to(
-    param.dtype
-)  # upcast to fp32
+).to(param.dtype)  # upcast to fp32
 ref_optim.step()  # fp32 optimizer step
 self.assertEqual(fsdp_loss, ref_loss)

View File

@@ -139,8 +139,9 @@ class TestFullyShardOverlap(FSDPTest):
     dist.reduce_scatter_tensor(dummy_rs_output, dummy_rs_input)
 def fwd_bwd():
-    with patch_all_gather(delayed_all_gather), patch_reduce_scatter(
-        delayed_reduce_scatter
+    with (
+        patch_all_gather(delayed_all_gather),
+        patch_reduce_scatter(delayed_reduce_scatter),
     ):
         loss = model(inp).sum()
         loss.backward()

View File

@@ -74,12 +74,12 @@ class TestFullyShardForwardInputs(FSDPTestMultiThread):
 # Check that FSDP moved the inputs to GPU, including recursing
 # into the tuple data structure
 assert x.device == device, f"Expects {device} but got {x.device}"
-assert (
-    ys[0].device == device
-), f"Expects {device} but got {ys[0].device}"
-assert (
-    ys[1].device == device
-), f"Expects {device} but got {ys[1].device}"
+assert ys[0].device == device, (
+    f"Expects {device} but got {ys[0].device}"
+)
+assert ys[1].device == device, (
+    f"Expects {device} but got {ys[1].device}"
+)
 y = ys[0] + ys[1]
 return x + y + 1

View File

@@ -234,8 +234,8 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
     run_tests()

View File

@@ -250,9 +250,11 @@ class TestJoin(MultiProcessTestCase):
     else "Detected at least one rank that exhausted inputs. "
     "Throwing across all ranks."
 )
-with self.assertRaisesRegex(
-    RuntimeError, expected_msg
-) if throw_on_early_termination else contextlib.nullcontext():
+with (
+    self.assertRaisesRegex(RuntimeError, expected_msg)
+    if throw_on_early_termination
+    else contextlib.nullcontext()
+):
     with Join(
         allreducers,
         enable=enable,

View File

@@ -677,9 +677,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     fully_shard(layer)
 fully_shard(model)
 optim = torch.optim.Adam(model.parameters(), lr=1e-2)
-torch.optim.lr_scheduler.LambdaLR(
-    optim, lr_lambda=[lambda epoch: 0.95**epoch]
-)
+torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=[lambda epoch: 0.95**epoch])
 opt_state_dict = ptd_state_dict.get_optimizer_state_dict(
     model,
     optim,

View File

@@ -228,9 +228,9 @@ class TestStateDictStager(TestCase):
 # Validate tensor count and bytes
 expected_storage_cnt = 2
-assert (
-    num_storages == expected_storage_cnt
-), f"Expected {expected_storage_cnt} storages, got {num_storages}"
+assert num_storages == expected_storage_cnt, (
+    f"Expected {expected_storage_cnt} storages, got {num_storages}"
+)
 # Calculate expected bytes
 # Note: Only unique storages are counted in the byte count
@@ -239,9 +239,9 @@ class TestStateDictStager(TestCase):
     + tensor3.numel()  # tensor1 and tensor2 share storage
     * tensor3.element_size()  # tensor3 and its narrow view share storage
 )
-assert (
-    num_bytes == expected_bytes
-), f"Expected {expected_bytes} bytes, got {num_bytes}"
+assert num_bytes == expected_bytes, (
+    f"Expected {expected_bytes} bytes, got {num_bytes}"
+)
 # Verify that the CPU state dict is equivalent to the original CUDA state dict
 result, error = compare_state_dicts(state_dict, cpu_state_dict)
 assert result, f"State dicts are not equivalent: {error}"
@@ -301,9 +301,9 @@ class TestStateDictStager(TestCase):
 # Verify the first result is correct
 result, error = compare_state_dicts(state_dict, cpu_state_dict1)
-assert (
-    result
-), f"First state dict is not equivalent to original: {error}"
+assert result, (
+    f"First state dict is not equivalent to original: {error}"
+)
 # Modify the original tensors
 tensor1.fill_(0)
@@ -317,14 +317,14 @@ class TestStateDictStager(TestCase):
 # Verify that the second CPU state dict is equivalent to the modified original state dict
 result, error = compare_state_dicts(state_dict, cpu_state_dict2)
-assert (
-    result
-), f"Second state dict is not equivalent to modified original: {error}"
+assert result, (
+    f"Second state dict is not equivalent to modified original: {error}"
+)
 # Verify that the number of cached storages hasn't changed
-assert (
-    num_storages1 == num_storages2
-), f"Storage count changed: {num_storages1} vs {num_storages2}"
+assert num_storages1 == num_storages2, (
+    f"Storage count changed: {num_storages1} vs {num_storages2}"
+)
 # Verify that the tensors in the second state dict have the same storage pointers as the first
 assert (
@@ -347,12 +347,12 @@ class TestStateDictStager(TestCase):
 cpu_state_dict3 = stager.stage(state_dict)
 # Verify that the third CPU state dict reflects the updated values
-assert torch.all(
-    cpu_state_dict3["tensor1"] == 42.0
-), "Updated values should be reflected in the cached state dict"
-assert torch.all(
-    cpu_state_dict3["tensor2"] == 42.0
-), "Updated values should be reflected in the cached state dict"
+assert torch.all(cpu_state_dict3["tensor1"] == 42.0), (
+    "Updated values should be reflected in the cached state dict"
+)
+assert torch.all(cpu_state_dict3["tensor2"] == 42.0), (
+    "Updated values should be reflected in the cached state dict"
+)
 @requires_cuda
 def test_tensor_attrs(self):
@@ -381,24 +381,24 @@ class TestStateDictStager(TestCase):
 cpu_state_dict = stager.stage(state_dict)
 # Verify that tensor attributes are preserved
-assert hasattr(
-    cpu_state_dict["tensor1"], "a"
-), "Tensor attribute 'a' was not preserved"
-assert (
-    cpu_state_dict["tensor1"].a == 42
-), "Tensor attribute 'a' has incorrect value"
-assert hasattr(
-    cpu_state_dict["tensor1"], "b"
-), "Tensor attribute 'b' was not preserved"
-assert (
-    cpu_state_dict["tensor1"].b == 43
-), "Tensor attribute 'b' has incorrect value"
-assert hasattr(
-    cpu_state_dict["recursive"]["tensor3"], "c"
-), "Tensor attribute 'c' was not preserved"
-assert (
-    cpu_state_dict["recursive"]["tensor3"].c == 44
-), "Tensor attribute 'c' has incorrect value"
+assert hasattr(cpu_state_dict["tensor1"], "a"), (
+    "Tensor attribute 'a' was not preserved"
+)
+assert cpu_state_dict["tensor1"].a == 42, (
+    "Tensor attribute 'a' has incorrect value"
+)
+assert hasattr(cpu_state_dict["tensor1"], "b"), (
+    "Tensor attribute 'b' was not preserved"
+)
+assert cpu_state_dict["tensor1"].b == 43, (
+    "Tensor attribute 'b' has incorrect value"
+)
+assert hasattr(cpu_state_dict["recursive"]["tensor3"], "c"), (
+    "Tensor attribute 'c' was not preserved"
+)
+assert cpu_state_dict["recursive"]["tensor3"].c == 44, (
+    "Tensor attribute 'c' has incorrect value"
+)
 @requires_cuda
 def test_different_dtypes(self):

View File

@@ -500,11 +500,13 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
     logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
 )
-with mock.patch.object(
-    mpc, "_is_done", return_value=True
-), mock.patch.object(mpc, "_pc"), mock.patch.object(
+with (
+    mock.patch.object(mpc, "_is_done", return_value=True),
+    mock.patch.object(mpc, "_pc"),
+    mock.patch.object(
         mpc._pc, "join", side_effect=[True, False, False, True]
-) as mock_join:
+    ) as mock_join,
+):
     mpc._poll()
     self.assertEqual(4, mock_join.call_count)

View File

@@ -56,32 +56,36 @@ class TestDistributedCheckpoint(FSDPTest):
 torch.manual_seed(200)
 new_model = wrap(SkipModel(double_nest=True))
-with FullyShardedDataParallel.summon_full_params(
-    model
-), FullyShardedDataParallel.summon_full_params(new_model):
+with (
+    FullyShardedDataParallel.summon_full_params(model),
+    FullyShardedDataParallel.summon_full_params(new_model),
+):
     params = list(model.parameters())
     new_params = list(new_model.parameters())
     self.assertNotEqual(params, new_params)
 writer = FileSystemWriter(self.temp_dir)
 reader = FileSystemReader(self.temp_dir)
-with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
-    new_model, state_dict_type
+with (
+    FSDP.state_dict_type(model, state_dict_type),
+    FSDP.state_dict_type(new_model, state_dict_type),
 ):
     state_dict = model.state_dict()
     save(state_dict, writer)
-with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
-    new_model, state_dict_type
+with (
+    FSDP.state_dict_type(model, state_dict_type),
+    FSDP.state_dict_type(new_model, state_dict_type),
 ):
     state_dict = new_model.state_dict()
     load(state_dict, reader)
     new_model.load_state_dict(state_dict)
-with FullyShardedDataParallel.summon_full_params(
-    model
-), FullyShardedDataParallel.summon_full_params(new_model):
+with (
+    FullyShardedDataParallel.summon_full_params(model),
+    FullyShardedDataParallel.summon_full_params(new_model),
+):
     params = list(model.parameters())
     new_params = list(new_model.parameters())
     self.assertEqual(params, new_params)

View File

@@ -242,11 +242,10 @@ class TestCommunication(FSDPTest):
 # and if `use_no_sync=False`, we only run `num_iters` iterations
 # outside `no_sync()`
 num_iters = 3
-with patch(
-    "torch.distributed.all_gather_into_tensor"
-) as mock_all_gather, patch(
-    "torch.distributed.reduce_scatter_tensor"
-) as mock_reduce_scatter:
+with (
+    patch("torch.distributed.all_gather_into_tensor") as mock_all_gather,
+    patch("torch.distributed.reduce_scatter_tensor") as mock_reduce_scatter,
+):
     def reset_mocks():
         mock_all_gather.reset_mock()

View File

@@ -379,12 +379,15 @@ class TestHooks(FSDPTest):
     register_pre_backward_hooks_call_count += 1
     return orig_register_pre_backward_hooks(*args, **kwargs)
-with mock.patch(
+with (
+    mock.patch(
         "torch.distributed.fsdp._runtime_utils._register_pre_backward_hooks",
         _register_pre_backward_hooks_with_count,
-), mock.patch(
+    ),
+    mock.patch(
         "torch.distributed.fsdp._runtime_utils._register_post_backward_hook"
-) as register_post_bwd_mock:
+    ) as register_post_bwd_mock,
+):
     self.assertEqual(register_pre_backward_hooks_call_count, 0)
     self.assertFalse(register_post_bwd_mock.called)
     fsdp_model(*input)

View File

@@ -152,9 +152,9 @@ class TestGradAcc(FSDPTest):
     batches.append(tuple(permute_tensor(t) for t in batch))
 for batch1, batch2 in itertools.combinations(batches, r=2):
     for t1, t2 in zip(batch1, batch2):
-        assert not torch.all(
-            t1 == t2
-        ), "Check the test to make sure that batches are distinct"
+        assert not torch.all(t1 == t2), (
+            "Check the test to make sure that batches are distinct"
+        )
 # Concatenate the batches along the given batch dimension
 concat_batch: tuple[torch.Tensor, ...] = tuple(

View File

@@ -121,8 +121,9 @@ class TestFSDPHybridShard(FSDPTest):
 def test_hsdp_save_load_state_dict(self):
     model = MyModel().cuda()
     num_node_devices = torch.cuda.device_count()
-    shard_rank_lists = list(range(0, num_node_devices // 2)), list(
-        range(num_node_devices // 2, num_node_devices)
+    shard_rank_lists = (
+        list(range(0, num_node_devices // 2)),
+        list(range(num_node_devices // 2, num_node_devices)),
     )
     shard_groups = (
         dist.new_group(shard_rank_lists[0]),
@@ -171,8 +172,9 @@ class TestFSDPHybridShard(FSDPTest):
 def test_hsdp_sync_module_state(self):
     model = MyModel().cuda()
     num_node_devices = torch.cuda.device_count()
-    shard_rank_lists = list(range(0, num_node_devices // 2)), list(
-        range(num_node_devices // 2, num_node_devices)
+    shard_rank_lists = (
+        list(range(0, num_node_devices // 2)),
+        list(range(num_node_devices // 2, num_node_devices)),
     )
     shard_groups = (
         dist.new_group(shard_rank_lists[0]),
@@ -310,8 +312,9 @@ class TestFSDPHybridShard(FSDPTest):
 cntr = Counter()
 patched_allreduce = partial(patched_collective, orig_ar, cntr)
 patched_reduce_scatter = partial(patched_collective, orig_rs, cntr)
-with patch_allreduce(patched_allreduce), patch_reduce_scatter(
-    patched_reduce_scatter
+with (
+    patch_allreduce(patched_allreduce),
+    patch_reduce_scatter(patched_reduce_scatter),
 ):
     inp = hsdp_model.get_input(device=torch.cuda.current_device())
     out = hsdp_model(inp[0], inp[1])
@@ -355,9 +358,9 @@ class TestFSDPHybridShard(FSDPTest):
     use_orig_params,
     hsdp_process_groups=hsdp_pgs,
 )
-assert (
-    hsdp_model._inter_node_pg.size() > 1
-), "HSDP model initialized without replication"
+assert hsdp_model._inter_node_pg.size() > 1, (
+    "HSDP model initialized without replication"
+)
 fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)
 hsdp_optim = torch.optim.Adam(hsdp_model.parameters(), lr=1e-2)
 torch.manual_seed(global_pg.rank() + 1)

View File

@@ -766,9 +766,9 @@ class TestFSDPMixedPrecisionSharded(TestFSDPMixedPrecision):
 if expect_use_full_prec_in_eval:
     assert x.dtype == torch.float32, f"Expected fp32, got {x.dtype}"
 else:
-    assert (
-        x.dtype == low_prec_dtype
-    ), f"Expected {low_prec_dtype}, got {x.dtype}"
+    assert x.dtype == low_prec_dtype, (
+        f"Expected {low_prec_dtype}, got {x.dtype}"
+    )
 return self.a(x)
 mp_config = MixedPrecision(

View File

@@ -91,9 +91,9 @@ class TestTPFSDPIntegration(FSDPTest):
     tensor_parallel_size: int,
 ) -> tuple[dict[str, int], dict[str, tuple[torch.Size, int]]]:
     """ """
-    assert (
-        type(model) is SimpleModel
-    ), "Expects a `SimpleModel` since the sharding cases on the model definition"
+    assert type(model) is SimpleModel, (
+        "Expects a `SimpleModel` since the sharding cases on the model definition"
+    )
     param_name_to_numel = OrderedDict()
     param_name_to_sharding_info = OrderedDict()
     for param_name, param in model.named_parameters():

View File

@@ -654,9 +654,12 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
 losses1 = []
 losses2 = []
 losses = []
-for _model, _optim in (fsdp_model, optim), (
+for _model, _optim in (
+    (fsdp_model, optim),
+    (
         fsdp_model_orig_params,
         optim_orig_params,
+    ),
 ):
     _optim.zero_grad()
     loss1 = _model(*inp1)
@@ -1166,9 +1169,9 @@ class TestFSDPUseOrigParamsFQNs(FSDPTest):
     clean_tensor_name(tup[0]) for tup in self.named_parameters()
 ]
 params = [tup[1] for tup in self.named_parameters()]
-assert (
-    param_shapes[0] is not None and param_shapes[1] is not None
-), "`param_sizes` should be set"
+assert param_shapes[0] is not None and param_shapes[1] is not None, (
+    "`param_sizes` should be set"
+)
 assert_equal_fn(
     param_names,
     [

View File

@@ -19,6 +19,7 @@ The script itself is not a test case hence no assertions are made in this script
 see: - test/distributed/launcher/run_test.py#test_is_torchelastic_launched()
      - test/distributed/launcher/run_test.py#test_is_not_torchelastic_launched()
 """
+
 import argparse
 import torch.distributed as dist

View File

@@ -292,9 +292,9 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
     betas=BETAS,
     eps=EPS,
 )
-assert (
-    len(o.param_groups) == 2
-), f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}"
+assert len(o.param_groups) == 2, (
+    f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}"
+)
 assert len(o.optim.param_groups) == 2, (
     "Expected 2 local optimizer param groups, but got "
     f"{len(o.optim.param_groups)}"
@@ -713,9 +713,9 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
 LR = 1e-3
 MOMENTUM = 0.99
 REFERENCE_RANK = 0
-assert (
-    REFERENCE_RANK in subgroup_ranks
-), "Reference rank must be in the new process group"
+assert REFERENCE_RANK in subgroup_ranks, (
+    "Reference rank must be in the new process group"
+)
 loss_fn = torch.nn.L1Loss().to(device)
 def check(optimizer):
@@ -1165,22 +1165,28 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
 # Increased tolerances are needed to pass when using TF32
 # See: https://github.com/pytorch/pytorch/issues/67764
 (
     torch.testing.assert_close(
         local_loss.cpu(),
         ddp_loss.cpu(),
         rtol=1e-03,
         atol=1e-08,
-    ), "Losses differ between local optimizer and ZeRO"
+    ),
+    "Losses differ between local optimizer and ZeRO",
 )
 for local_p, ddp_p in zip(
     local_model.parameters(), ddp_model.parameters()
 ):
     (
         torch.testing.assert_close(
             local_p.cpu(),
             ddp_p.cpu(),
             rtol=1e-03,
             atol=1e-04,
-        ), "Models differ after a step"
+        ),
+        "Models differ after a step",
     )
 @skipIfHpu
 @skip_if_lt_x_gpu(4)

View File

@@ -89,9 +89,9 @@ class PipeTests(TestCase):
     mb_args=(x, y),
 )
-assert (
-    pipe.num_stages == EXPECTED_N_STAGES[ModelClass]
-), f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}"
+assert pipe.num_stages == EXPECTED_N_STAGES[ModelClass], (
+    f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}"
+)
 ref_out = mod(x, y)
 out = pipe(x, y)[0]
@@ -109,9 +109,7 @@ class PipeTests(TestCase):
     new_names.update(stage_fqns)
 if CHECK_FQN_SET_EQUALITY:
-    assert (
-        old_names == new_names
-    ), f"""
+    assert old_names == new_names, f"""
 old names {old_names}
 new names {new_names}
 """

View File

@@ -60,9 +60,9 @@ class UnflattenTests(TestCase):
 for stage_idx in range(pipe.num_stages):
     stage_mod = pipe.get_stage_module(stage_idx)
     for param_name, _ in stage_mod.named_parameters():
-        assert (
-            param_name in orig_state_dict
-        ), f"{param_name} not in original state dict"
+        assert param_name in orig_state_dict, (
+            f"{param_name} not in original state dict"
+        )
 print("Param qualname test passed")
 # Check equivalence

View File

@@ -45,9 +45,9 @@ class ShareMemoryRPCPickler(_InternalRPCPickler):
 for t in torch._tensor_classes:
     self._dispatch_table[t] = TorchMpReductions.reduce_tensor
 self._dispatch_table[torch.Tensor] = TorchMpReductions.reduce_tensor
-self._dispatch_table[
-    torch.nn.parameter.Parameter
-] = TorchMpReductions.reduce_tensor
+self._dispatch_table[torch.nn.parameter.Parameter] = (
+    TorchMpReductions.reduce_tensor
+)
 def worker_loop(a):
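The pattern here generalizes: when the right-hand side of a subscript assignment no longer fits on one line, black breaks inside the subscript, while ruff format keeps the target intact and parenthesizes the value. A generic sketch (hypothetical names):

    # black (old)
    dispatch_table[
        some.fairly.long.KeyType
    ] = handler

    # ruff format (new)
    dispatch_table[some.fairly.long.KeyType] = (
        handler
    )

The same rewrite shows up later in this commit for the `os.environ[...] = ...` assignments.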

View File

@@ -872,9 +872,7 @@ def forward(self, primals_1):
     "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(primal"
 ).check("torch.ops._c10d_functional.wait_tensor.default(buf0").check(
     "extern_kernels.mm(buf0,"
-).run(
-    code
-)
+).run(code)
 @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
 @skip_if_lt_x_gpu(1)
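In fluent `FileCheck`-style chains such as this one, ruff format keeps the trailing call on a single line whenever its argument fits, where black had split it across three lines:

    # black (old)
    FileCheck().check("pattern").run(
        code
    )

    # ruff format (new)
    FileCheck().check("pattern").run(code)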

View File

@@ -441,9 +441,11 @@ class DistMathOpsTest(DTensorTestBase):
     out_req_grad: bool
 subtest_fails = {}
-valid_filter = lambda cfg: not (  # noqa: E731
-    cfg.ln_req_grad and not cfg.elementwise_affine
-) and any(cfg[2:])
+valid_filter = (  # noqa: E731
+    lambda cfg: (
+        not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[2:])
+    )
+)
 subtest_cfgs = list(
     filter(
         valid_filter,
@@ -566,9 +568,9 @@ class DistMathOpsTest(DTensorTestBase):
 except Exception as e:
     subtest_fails[subtest_cfg] = e
 # if any subtest fails, provide the failed subtests and report the overall failure
-assert (
-    not subtest_fails
-), f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}"
+assert not subtest_fails, (
+    f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}"
+)
 @with_comms
 def test_topk(self):

View File

@@ -26,7 +26,9 @@ def with_xla(func: Callable) -> Callable:
 @wraps(func)  # pyre-ignore[6]
 def wrapper(
-    self, *args: tuple[object], **kwargs: dict[str, Any]  # type: ignore[misc]
+    self,
+    *args: tuple[object],
+    **kwargs: dict[str, Any],  # type: ignore[misc]
 ) -> None:
     # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
     os.environ["XLA_USE_SPMD"] = "1"

View File

@@ -2234,8 +2234,8 @@ class LocalRankTest(MultiProcessTestCase):
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
     run_tests()

View File

@@ -796,13 +796,11 @@ class CompileTest(TestCase):
 .check("buf6 = empty")
 # Expect in-place with inductor allocated buf
 .check(
-    "torch.ops._c10d_functional.all_reduce_coalesced_"
-    ".default([buf0, buf1]"
+    "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf0, buf1]"
 )
 # Expect no in-place with graph input (buf5, buf6 are clones)
 .check(
-    "torch.ops._c10d_functional.all_reduce_coalesced_"
-    ".default([buf5, buf6]"
+    "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf5, buf6]"
 )
 .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
 .check("torch.ops._c10d_functional.wait_tensor.default(buf1")

View File

@@ -2705,8 +2705,8 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
     run_tests()

View File

@@ -2869,9 +2869,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
     self.assertTrue(t.is_alive())
 if prev_nccl_async_error_handling is not None:
-    os.environ[
-        "TORCH_NCCL_ASYNC_ERROR_HANDLING"
-    ] = prev_nccl_async_error_handling
+    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
+        prev_nccl_async_error_handling
+    )
 @requires_nccl()
 @skip_if_lt_x_gpu(3)
@@ -2931,9 +2931,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
 os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
 self._test_barrier_error()
 if prev_nccl_async_error_handling is not None:
-    os.environ[
-        "TORCH_NCCL_ASYNC_ERROR_HANDLING"
-    ] = prev_nccl_async_error_handling
+    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
+        prev_nccl_async_error_handling
+    )
 @requires_nccl()
 @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@@ -2984,9 +2984,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
 process_group.abort()
 if prev_nccl_async_error_handling is not None:
-    os.environ[
-        "TORCH_NCCL_ASYNC_ERROR_HANDLING"
-    ] = prev_nccl_async_error_handling
+    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
+        prev_nccl_async_error_handling
+    )
 @requires_nccl()
 @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@@ -3065,9 +3065,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
 os.remove(new_file_name)
 if prev_nccl_async_error_handling is not None:
-    os.environ[
-        "TORCH_NCCL_ASYNC_ERROR_HANDLING"
-    ] = prev_nccl_async_error_handling
+    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
+        prev_nccl_async_error_handling
+    )
 def _run_invalid_nccl_blocking_wait_env(self, val):
     os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val
@@ -3360,9 +3360,7 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
 self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
 # Verify that IntraNodeComm is not used beyond 10MB
-t = torch.full(
-    (10 * 1024**2 // 2 + 1,), self.rank, dtype=torch.bfloat16
-).cuda()
+t = torch.full((10 * 1024**2 // 2 + 1,), self.rank, dtype=torch.bfloat16).cuda()
 c10d.all_reduce(t, c10d.ReduceOp.SUM)
 self.assertTrue(t.eq(expect).all())
 self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
@@ -4249,9 +4247,9 @@ class SparseCollective(MultiProcessTestCase):
 class NCCLTraceTestBase(MultiProcessTestCase):
     def setUp(self):
         super().setUp()
-        os.environ[
-            "TORCH_NCCL_ENABLE_TIMING"
-        ] = "0"  # see 'timing_enabled' parametrized tests
+        os.environ["TORCH_NCCL_ENABLE_TIMING"] = (
+            "0"  # see 'timing_enabled' parametrized tests
+        )
         os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "1000"
         os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
         self.tempdir = tempfile.TemporaryDirectory()
@@ -5331,8 +5329,8 @@ class ProcessGroupNCCLLargerScaleTest(MultiProcessTestCase):
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
     run_tests()

View File

@@ -1090,8 +1090,8 @@ class UccProcessGroupWithDispatchedCollectivesTests(
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
     run_tests()

View File

@@ -90,9 +90,9 @@ class TestCollectiveUtils(MultiProcessTestCase):
 res = all_gather(data_or_fn=func, pg=pg)
 func.assert_called_once()
-assert res == list(
-    range(self.world_size)
-), f"Expect res to be list of 0 through {self.world_size} (got {res})"
+assert res == list(range(self.world_size)), (
+    f"Expect res to be list of 0 through {self.world_size} (got {res})"
+)
 def test_all_gather_result_no_pg(self) -> None:
     """

View File

@@ -207,8 +207,8 @@ class TestCollectives(TestCase):
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
     run_tests()

View File

@@ -490,8 +490,9 @@ class DeviceMeshTestNDim(DTensorTestBase):
 # Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7))
 # and assign the correct shard group to each rank
-shard_rank_lists = list(range(0, self.world_size // 2)), list(
-    range(self.world_size // 2, self.world_size)
+shard_rank_lists = (
+    list(range(0, self.world_size // 2)),
+    list(range(self.world_size // 2, self.world_size)),
 )
 shard_groups = (
     new_group(shard_rank_lists[0]),

View File

@@ -1830,9 +1830,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
     f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH"""
 ).check(
     f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH"""
-).run(
-    GUARDS_FILE.getvalue()
-)
+).run(GUARDS_FILE.getvalue())
 self.assertTrue(same(correct_outputs, outputs))

View File

@@ -519,12 +519,13 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
     out = a2a / a2a.sum(dim=0)
     return out
-with _dynamo_dist_per_rank_init(
-    self.rank, self.world_size
-), torch._dynamo.config.patch(
+with (
+    _dynamo_dist_per_rank_init(self.rank, self.world_size),
+    torch._dynamo.config.patch(
         dynamic_shapes=True,
         capture_dynamic_output_shape_ops=True,
         capture_scalar_outputs=True,
+    ),
 ):
     row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
     input_split_sizes_tensor = torch.tensor(
@@ -680,15 +681,15 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
     return torch.ops.custom_ns.foo(a2a)
-with _dynamo_dist_per_rank_init(
-    self.rank, self.world_size
-), torch._dynamo.config.patch(
+with (
+    _dynamo_dist_per_rank_init(self.rank, self.world_size),
+    torch._dynamo.config.patch(
         dynamic_shapes=True,
         capture_dynamic_output_shape_ops=True,
        capture_scalar_outputs=True,
-), torch.library._scoped_library(
-    "custom_ns", "FRAGMENT"
-) as lib:
+    ),
+    torch.library._scoped_library("custom_ns", "FRAGMENT") as lib,
+):
     lib.define(
         "alltoall_autograd(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor"  # noqa: B950
     )

View File

@@ -464,8 +464,8 @@ class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest):
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_pg_wrapper must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_pg_wrapper must not have initialized CUDA context on main process"
+    )
     run_tests()

View File

@@ -372,12 +372,8 @@ class TCPStoreTest(TestCase, StoreTestBase):
 # Use noqa to silence flake8.
 # Need to store in an unused variable here to ensure the first
 # object is not destroyed before the second object is created.
-store1 = dist.TCPStore(
-    addr, port, 1, True, use_libuv=self._use_libuv
-)  # noqa: F841
-store2 = dist.TCPStore(
-    addr, port, 1, True, use_libuv=self._use_libuv
-)  # noqa: F841
+store1 = dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv)  # noqa: F841
+store2 = dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv)  # noqa: F841
 self.assertEqual(store1.libuvBackend, self._use_libuv)
 self.assertEqual(store2.libuvBackend, self._use_libuv)
@@ -767,7 +763,7 @@ class RendezvousFileTest(TestCase):
 def test_nominal(self):
     with tempfile.NamedTemporaryFile(delete=False) as file:
-        url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2'
+        url = f"file:///{file.name.replace(os.path.sep, '/')}?world_size=2"
         gen0 = dist.rendezvous(url + "&rank=0")
         store0, rank0, size0 = next(gen0)
         self.assertEqual(0, rank0)
@@ -1178,8 +1174,8 @@ class TestClientProtocol(TestCase):
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
     run_tests()

View File

@@ -81,9 +81,9 @@ def count_ops(
 for node in gm.graph.nodes:
     if match_rng_op(node, op) or node.target == op:
         actual_count += 1
-assert (
-    actual_count >= freq_ge
-), f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}."
+assert actual_count >= freq_ge, (
+    f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}."
+)
 return gm

View File

@@ -455,9 +455,9 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
 # Modify gradient using .data (Dangerous: Breaks autograd tracking!)
 modified_grad = grad_output.clone()
-modified_grad.data[
-    input_tensor.data < 0
-] = 0  # Zero-out gradients for negative inputs
+modified_grad.data[input_tensor.data < 0] = (
+    0  # Zero-out gradients for negative inputs
+)
 return modified_grad * 3

View File

@@ -195,10 +195,13 @@ class GraphModule(torch.nn.Module):
 def f(x, y):
     return invoke_quant_test(inner, [x, y], scheme="nf4")
-with mock.patch(
+with (
+    mock.patch(
         "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
         True,
-), torch.no_grad():
+    ),
+    torch.no_grad(),
+):
     torch.compile(f, backend=bk, fullgraph=True)(x.clone(), y)
 self.assertEqual(len(bk.graphs), 1)
@@ -319,10 +322,13 @@ class GraphModule(torch.nn.Module):
 x = torch.randn(3, 3, requires_grad=False)
 x_clone = x.clone()
 y = torch.randn(3, 3, requires_grad=True)
-with mock.patch(
+with (
+    mock.patch(
         "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
         True,
-), torch.no_grad():
+    ),
+    torch.no_grad(),
+):
     compiled_out = torch.compile(f, backend=backend, fullgraph=True)(x, y)
     self.assertEqual(x, x_clone + 1)
     self.assertEqual(compiled_out, x_clone + y + 1)

View File

@@ -53,8 +53,8 @@ class BytecodeTests(torch._dynamo.test_case.TestCase):
 fn_str = f"""\
 def fn():
     foo.bar(1, 2, 3)
-{str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))}
-    l = [{' '.join('x' + str(i) + ',' for i in range(1 << 9))}]
+{str(chr(10)).join(" " * 4 + "x" + str(i) + " = 1" for i in range(1 << 9))}
+    l = [{" ".join("x" + str(i) + "," for i in range(1 << 9))}]
 """
 locals = {}
 exec(fn_str, {}, locals)
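ruff format also normalizes the quotes of string literals nested inside f-string replacement fields to the preferred double quotes when that cannot change the string's value; here the outer string is triple-quoted, so the inner double quotes are safe on every supported Python version. A toy example:

    names = ["x0", "x1"]
    row = f"""{', '.join(names)}"""   # old
    row = f"""{", ".join(names)}"""   # new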

View File

@@ -26,9 +26,10 @@ class CallbackTests(TestCase):
 def test_callbacks_with_duplicate_prevention(self) -> None:
     trigger = CallbackTrigger.DYNAMO
     compile_id = CompileId(0, 0)
-    with callback_handler.install_callbacks(
-        trigger, compile_id
-    ), callback_handler.install_callbacks(trigger, compile_id):
+    with (
+        callback_handler.install_callbacks(trigger, compile_id),
+        callback_handler.install_callbacks(trigger, compile_id),
+    ):
         self._on_compile_start.assert_called_once()
         self._on_compile_end.assert_called_once()

View File

@@ -1252,10 +1252,13 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
 def f(x, y):
     return x + y
-x, y = torch.ones(
+x, y = (
+    torch.ones(
         1,
-), torch.zeros(
+    ),
+    torch.zeros(
         1,
+    ),
 )
 return f(x, y)
@@ -1289,10 +1292,13 @@ class GraphModule(torch.nn.Module):
 def f(x, y):
     return x + y
-x, y = torch.ones(
+x, y = (
+    torch.ones(
         1,
-), torch.zeros(
+    ),
+    torch.zeros(
         1,
+    ),
 )
 return f(x, y)
@@ -1335,10 +1341,13 @@ class GraphModule(torch.nn.Module):
     return inner_fn(x, y) + x
-x, y = torch.ones(
+x, y = (
+    torch.ones(
         1,
-), torch.zeros(
+    ),
+    torch.zeros(
         1,
+    ),
 )
 return f(x, y)
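For a multi-target unpacking whose right-hand side overflows the line, black broke inside the first call, leaving the two values visually run together; ruff format instead parenthesizes the whole right-hand tuple, which is what the three hunks above reconstruct:

    # black (old)
    x, y = torch.ones(
        1,
    ), torch.zeros(
        1,
    )

    # ruff format (new)
    x, y = (
        torch.ones(
            1,
        ),
        torch.zeros(
            1,
        ),
    )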

View File

@@ -19,7 +19,7 @@ from torch.testing._internal.common_utils import (
 class CustomException(Exception):
-    ...
+    pass
 class CustomExceptionMeta(type):
@@ -28,7 +28,7 @@ class CustomExceptionMeta(type):
 class CustomExceptionWithInstanceCheck(Exception, metaclass=CustomExceptionMeta):
-    ...
+    pass
 class CustomExceptionWithArgs(Exception):
@@ -358,7 +358,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
 def test_raise_custom_exception(self):
     class Exc(Exception):
-        ...
+        pass
     @torch.compile(backend="eager", fullgraph=True)
     def fn(t):
@@ -375,7 +375,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
 def test_raise_custom_exception_with_args(self):
     class Exc(Exception):
-        ...
+        pass
     @torch.compile(backend="eager", fullgraph=True)
     def fn(t):
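The `...` bodies above were swapped for `pass` by hand rather than by the formatter, presumably because ruff format collapses a body consisting only of `...` onto the header line (as the `def forward(self): ...` hunk later in this commit shows), while `pass` keeps the block form. Both are equivalent no-op bodies:

    class Exc(Exception):
        pass                   # keeps the indented block

    class Exc(Exception): ...  # what ruff format makes of an `...`-only body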

View File

@@ -324,9 +324,13 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
     distributed_state=None,
     package=None,
 )
-with compile_context(CompileContext(CompileId(0, 0))), tracing(
-    tracer.output.tracing_context
-), tracer.set_current_tx(), get_metrics_context(), dynamo_timed(""):
+with (
+    compile_context(CompileContext(CompileId(0, 0))),
+    tracing(tracer.output.tracing_context),
+    tracer.set_current_tx(),
+    get_metrics_context(),
+    dynamo_timed(""),
+):
     tracer.run()
 check_fn_manager = CheckFunctionManager(

View File

@@ -1092,9 +1092,7 @@ not ___dict_contains('bbbbbbbb', G['sys'].modules)
 not ___dict_contains('cccccccc', G['sys'].modules)
 str(L['x'].device) == 'cpu'
 str(L['x'].dtype) == 'torch.float32'
-utils_device.CURRENT_DEVICE == None""".split(
-    "\n"
-):
+utils_device.CURRENT_DEVICE == None""".split("\n"):
     self.assertIn(
         line,
         guard_code_str,
@@ -2806,7 +2804,7 @@ utils_device.CURRENT_DEVICE == None""".split(
     "int",
     np.intp,
     np.int32,
-    np.uint8
+    np.uint8,
     # np.dtype('int')  # XXX: as above
 ]
@@ -5527,9 +5525,9 @@ utils_device.CURRENT_DEVICE == None""".split(
 def forward(self, idx, targets=None):
     b, t = idx.size()
-    assert (
-        t <= self.block_size
-    ), "Cannot forward, model block size is exhausted."
+    assert t <= self.block_size, (
+        "Cannot forward, model block size is exhausted."
+    )
     # forward the GPT model
     token_embeddings = self.tok_emb(
@@ -6075,15 +6073,17 @@ utils_device.CURRENT_DEVICE == None""".split(
 def count_graph_break_msgs(msgs):
     return sum("Graph break in user code" in msg for msg in msgs)
-with self.assertLogs(
-    logger="torch._dynamo", level=logging.DEBUG
-) as log, torch._dynamo.config.patch(verbose=True):
+with (
+    self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log,
+    torch._dynamo.config.patch(verbose=True),
+):
     f1(torch.randn(10), torch.randn(10))
     self.assertGreater(count_graph_break_msgs(log.output), 1)
-with self.assertLogs(
-    logger="torch._dynamo", level=logging.DEBUG
-) as log, torch._dynamo.config.patch(verbose=False):
+with (
+    self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log,
+    torch._dynamo.config.patch(verbose=False),
+):
     g1(torch.randn(10), torch.randn(10))
     self.assertEqual(count_graph_break_msgs(log.output), 1)
@@ -8235,8 +8235,9 @@ utils_device.CURRENT_DEVICE == None""".split(
 def f(a):
     return h(a)
-with warnings.catch_warnings(record=True) as w, self.assertRaises(
-    torch._dynamo.exc.BackendCompilerFailed
+with (
+    warnings.catch_warnings(record=True) as w,
+    self.assertRaises(torch._dynamo.exc.BackendCompilerFailed),
 ):
     f(torch.randn(2, 2, requires_grad=True))
@@ -8429,8 +8430,7 @@ utils_device.CURRENT_DEVICE == None""".split(
 def test_torch_compile_ctx_on_forward_and_training_step(self):
     class MyModel(torch.nn.Module):
-        def forward(self):
-            ...
+        def forward(self): ...
         def training_step(self):
             self()

View File

@@ -2094,11 +2094,12 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
 mod = MockModule()
 # Each submod is compiled separately and has a different nn module
 # guard. Ensure that recompilation logic is handle correctly.
-with unittest.mock.patch(
-    "torch._dynamo.config.error_on_recompile", True
-), unittest.mock.patch(
+with (
+    unittest.mock.patch("torch._dynamo.config.error_on_recompile", True),
+    unittest.mock.patch(
         "torch._dynamo.config.recompile_limit",
         recompile_limit,
+    ),
 ):
     x = torch.randn(*size, requires_grad=True)
     mod(x)
@@ -2160,11 +2161,12 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
 mod = MockModule()
 # Each submod is compiled separately and has a different nn module
 # guard. Ensure that recompilation logic is handle correctly.
-with unittest.mock.patch(
-    "torch._dynamo.config.error_on_recompile", True
-), unittest.mock.patch(
+with (
+    unittest.mock.patch("torch._dynamo.config.error_on_recompile", True),
+    unittest.mock.patch(
        "torch._dynamo.config.recompile_limit",
        recompile_limit,
+    ),
 ):
     x = torch.randn(*size, requires_grad=True)
     mod(x)

View File

@@ -3,6 +3,7 @@
 PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
 with test_adam in OptimizerTests)
 """
+
 import functools
 import torch

View File

@@ -238,9 +238,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
 tensor 'x' size mismatch at index 0. expected 11, actual 12
 tensor 'x' size mismatch at index 0. expected 10, actual 12
 tensor 'x' size mismatch at index 0. expected 9, actual 12
-tensor 'x' size mismatch at index 0. expected 8, actual 12""".split(
-    "\n"
-):
+tensor 'x' size mismatch at index 0. expected 8, actual 12""".split("\n"):
     self.assertIn(
         line,
         failure_str,
@@ -276,9 +274,7 @@ tensor 'x' size mismatch at index 0. expected 8, actual 12""".split(
 opt_f([7, 8])
 for line in """\
-len(x) == 3""".split(
-    "\n"
-):
+len(x) == 3""".split("\n"):
     self.assertIn(line, filter_reasons())
 failure_reasons.clear()
@@ -286,9 +282,7 @@ len(x) == 3""".split(
 for line in """\
 len(x) == 2
-len(x) == 3""".split(
-    "\n"
-):
+len(x) == 3""".split("\n"):
     self.assertIn(line, filter_reasons())
 @torch._dynamo.config.patch(recompile_limit=1)

View File

@@ -179,9 +179,9 @@ def shapes_to_tensor(x, device=None):
 if torch.jit.is_scripting():
     return torch.as_tensor(x, device=device)
 if torch.jit.is_tracing():
-    assert all(
-        isinstance(t, torch.Tensor) for t in x
-    ), "Shape should be tensor during tracing!"
+    assert all(isinstance(t, torch.Tensor) for t in x), (
+        "Shape should be tensor during tracing!"
+    )
     # as_tensor should not be used in tracing because it records a constant
     ret = torch.stack(x)
     if ret.device != device:  # avoid recording a hard-coded device if not necessary
@@ -480,9 +480,9 @@ class PartialT5(torch.nn.Module):
 real_seq_length = seq_length
 if past_key_value is not None:
-    assert (
-        len(past_key_value) == 2
-    ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
+    assert len(past_key_value) == 2, (
+        f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
+    )
     real_seq_length += (
         past_key_value[0].shape[2] if query_length is None else query_length
     )
@@ -4877,9 +4877,9 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
 with warnings.catch_warnings(record=True):
     data_len = len(value)
     if len(self._fields):
-        assert (
-            len(self) == data_len
-        ), f"Adding a field of length {data_len} to a Instances of length {len(self)}"
+        assert len(self) == data_len, (
+            f"Adding a field of length {data_len} to a Instances of length {len(self)}"
+        )
     self._fields[name] = value
 def get(self, name: str) -> Any:

View File

@@ -108,9 +108,10 @@ def get_view_test_cases():
 for requires_grad_1, requires_grad_2 in itertools.product(
     [True, False], repeat=2
 ):
-    yield partial(
-        mk_leaf, base_is_nt, requires_grad_1, requires_grad_2
-    ), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}"
+    yield (
+        partial(mk_leaf, base_is_nt, requires_grad_1, requires_grad_2),
+        f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}",
+    )
 # (3) obscure case:
 # view is not a leaf (implies requires_grad True)
@@ -118,9 +119,10 @@ def get_view_test_cases():
 yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure"
 # Subclass -> Dense
-yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[
-    0
-].clone(), "subclass_dense"
+yield (
+    lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone(),
+    "subclass_dense",
+)
 # Dense -> Subclass -> Dense -> Subclass
 def mk_dense_subclass_dense_subclass():

View File

@@ -151,9 +151,9 @@ def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObject
         types.WrapperDescriptorType,
     ),
 ) or is_special_functions(obj):
-    torch_name_rule_map[
-        f"{module.__name__}.{name}"
-    ] = TorchInGraphFunctionVariable
+    torch_name_rule_map[f"{module.__name__}.{name}"] = (
+        TorchInGraphFunctionVariable
+    )
     if c_binding_only:
         if not hasattr(obj, "__code__"):
             c_binding_in_graph_functions.add(obj)
@@ -398,12 +398,15 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
 )
 self.assertTrue("torch._dynamo" not in torch._dynamo.trace_rules.MOD_INLINELIST)
-with unittest.mock.patch(
+with (
+    unittest.mock.patch(
         "torch._dynamo.trace_rules.torch_name_rule_map",
         _torch_name_rule_map,
-), unittest.mock.patch(
+    ),
+    unittest.mock.patch(
         "torch._dynamo.trace_rules.get_torch_obj_rule_map",
         torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,  # bypass functools.lru_cache
+    ),
 ):
     x = torch.rand(3)
     opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
@@ -419,9 +422,9 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
 _manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
 # Force inline `mod.func` by setting trace rule.
-_manual_torch_name_rule_map[
-    f"{mod.__name__}.{func.__name__}"
-] = UserFunctionVariable
+_manual_torch_name_rule_map[f"{mod.__name__}.{func.__name__}"] = (
+    UserFunctionVariable
+)
 _torch_name_rule_map = [
     _manual_torch_name_rule_map,
@@ -429,12 +432,15 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
     torch_non_c_binding_in_graph_functions,
 ]
-with unittest.mock.patch(
+with (
+    unittest.mock.patch(
         "torch._dynamo.trace_rules.torch_name_rule_map",
         _torch_name_rule_map,
-), unittest.mock.patch(
+    ),
+    unittest.mock.patch(
        "torch._dynamo.trace_rules.get_torch_obj_rule_map",
        torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,
+    ),
 ):
     # First adding the module to SKIP_DIRS so that it will be skipped by default.
     torch._dynamo.trace_rules.add(mod.__name__)

View File

@@ -593,9 +593,10 @@ class TestDynamoTimed(TestCase):
 )
 compilation_events = []
-with dynamo_config.patch({"automatic_dynamic_shapes": False}), mock.patch(
-    "torch._dynamo.utils.log_compilation_event"
-) as log_event:
+with (
+    dynamo_config.patch({"automatic_dynamic_shapes": False}),
+    mock.patch("torch._dynamo.utils.log_compilation_event") as log_event,
+):
     @torch.compile()
     def f(x):

View File

@@ -181,9 +181,12 @@ class TestDraftExport(TestCase):
 self.assertEqual(len(report.op_profiles), 1)
 self.assertEqual(len(report.op_profiles["mylib.foo8.default"]), 1)
-with torch._library.fake_profile.unsafe_generate_fake_kernels(
+with (
+    torch._library.fake_profile.unsafe_generate_fake_kernels(
         report.op_profiles
-), FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()):
+    ),
+    FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()),
+):
     torch.ops.mylib.foo8(*new_inp)
 # Existing registration has been updated to match the new

View File

@@ -12834,12 +12834,14 @@ def forward(self, x, y):
"y": [Dim("dy")], # y & z incorrect, export is supposed to fail.
"z": [Dim("dz")], # suggested fix should be to match these up.
}
with self.assertRaisesRegex( # if disable=True, suggested fixes should not specialize.
with (
self.assertRaisesRegex( # if disable=True, suggested fixes should not specialize.
torch._dynamo.exc.UserError,
r".*Constraints violated(.*\n)*"
r"Suggested fixes:(.*\n)*"
r".*dz = dy(.*\n)*",
) as msg:
) as msg
):
export(
Foo(),
inputs,
@@ -13675,8 +13677,7 @@ def forward(self, x):
"""Make sure the metadata is kept after exported program run_decompositions."""
@torch.library.custom_op("mylib::add", mutates_args=())
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
...
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
@torch.library.register_fake("mylib::add")
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
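
Note: a function whose body is a bare `...` is collapsed onto the definition line, the convention used in `.pyi` stubs; nothing about the function changes. Sketch:

# A body consisting only of Ellipsis is folded onto the def line.
def add(x: int, y: int) -> int: ...

# The stub still exists and returns None when called (annotations are not enforced).
print(add(1, 2))  # None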

View File

@@ -947,10 +947,12 @@ class TestDeserialize(TestCase):
ep = torch.export.export(M(), (torch.ones(3, 3), None, torch.ones(3, 3)))
serialized_program = ExportedProgramSerializer(None, 2).serialize(ep)
serialized_program.exported_program.graph_module.signature.input_specs[
1
] = schema.InputSpec.create(
user_input=schema.UserInputSpec(arg=schema.Argument.create(as_none=True))
serialized_program.exported_program.graph_module.signature.input_specs[1] = (
schema.InputSpec.create(
user_input=schema.UserInputSpec(
arg=schema.Argument.create(as_none=True)
)
)
)
ep = ExportedProgramDeserializer(None).deserialize(
serialized_program.exported_program, {}, {}, {}

View File

@@ -22,7 +22,9 @@ from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
{"strict": False},
{"strict": True},
],
class_name_func=lambda cls, _, params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}",
class_name_func=lambda cls,
_,
params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}",
)
class TestSwap(TestCase):
def test_unflatten_preserve_signature(self):
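
Note: when a lambda passed as a keyword argument overflows the line, ruff breaks after each parameter rather than leaving the long line intact; the parse is unchanged. A sketch with an invented naming helper (assigning a lambda is normally discouraged, and is done here only to show the line breaks):

name_for = (
    lambda cls,
    _,
    params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}"
)
print(name_for(dict, None, {"strict": True}))   # dict_strict
print(name_for(dict, None, {"strict": False}))  # dict_nonstrict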

View File

@@ -348,8 +348,7 @@ def check_fc(existing_schemas):
"\n\t".join(str(s) for s in matching_new_schemas),
)
log.warning(
"Refer to following reasons for failure "
"to find FC schema:\n[\n%s\n]",
"Refer to following reasons for failure to find FC schema:\n[\n%s\n]",
"\n\t".join(str(r) for r in possible_failure_reasons),
)
broken_ops.append(str(existing_schema))

View File

@@ -523,15 +523,15 @@ def decorateForModules(decorator, module_classes, device_type=None, dtypes=None)
dtypes=dtypes,
):
name_parts = fn.__qualname__.split(".")
assert (
len(name_parts) == 2
), "Decorator only applies to a test function of a test class"
assert len(name_parts) == 2, (
"Decorator only applies to a test function of a test class"
)
test_case_name, base_test_name = name_parts
for module_cls in module_classes:
matching_module_infos = [m for m in module_db if m.module_cls == module_cls]
assert (
len(matching_module_infos) == 1
), f"Couldn't find single ModuleInfo for {module_cls}"
assert len(matching_module_infos) == 1, (
f"Couldn't find single ModuleInfo for {module_cls}"
)
module_info = matching_module_infos[0]
decorators = list(module_info.decorators)
new_decorator = DecorateInfo(

View File

@@ -124,9 +124,7 @@ class TestGraphInfoProvider(TestCase):
)
def test_recomputable_node_only_graph_with_larger_graph_context(self):
recomputable_node_only_graph_with_larger_graph_context = (
self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context
)
recomputable_node_only_graph_with_larger_graph_context = self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context # noqa: B950
expected_nodes = self.all_recomputable_banned_nodes
# node1 does not have an indirect path to node5 because of node2
# node2 has an indirect path to node5

View File

@@ -2568,8 +2568,9 @@ def forward(self, primals_1, primals_2):
def fn(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
return torch.ops._test._clone_create_graph(x, x1)
inp_x, inp_x1 = torch.randn(3, requires_grad=True), torch.randn(
3, requires_grad=True
inp_x, inp_x1 = (
torch.randn(3, requires_grad=True),
torch.randn(3, requires_grad=True),
)
ref_x, ref_x1 = inp_x.clone(), inp_x1.clone()
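
Note: multi-target assignments whose right-hand side overflows get explicit parentheses around the tuple instead of a split inside the last call, as above. A self-contained sketch of the resulting shape:

import torch

# The right-hand tuple is parenthesized so each tensor sits on its own line.
inp_x, inp_x1 = (
    torch.randn(3, requires_grad=True),
    torch.randn(3, requires_grad=True),
)
print(inp_x.shape, inp_x1.requires_grad)  # torch.Size([3]) True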
@@ -5283,11 +5284,12 @@ def forward(self, arg0_1):
mod = TestMod(fn)
inp = torch.randn(2)
with patch(
"functorch.compile.config.functionalize_rng_ops", True
), self.assertRaisesRegex(
with (
patch("functorch.compile.config.functionalize_rng_ops", True),
self.assertRaisesRegex(
RuntimeError,
"Functionalized RNG is not currently supported in the aot_export",
),
):
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)

View File

@@ -4324,9 +4324,7 @@ class TestExamplesCorrectness(TestCase):
def lennard_jones_force(r):
"""Get magnitude of LJ force"""
return -epsilon * (
(-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7)
)
return -epsilon * ((-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7))
r = torch.linspace(0.5, 2 * sigma, steps=100, requires_grad=True, device=device)
drs = torch.outer(r, torch.tensor([1.0, 0, 0], device=device))
@@ -4495,8 +4493,9 @@ class TestExamplesCorrectness(TestCase):
# This example mimics what a user might do when trying to find the optimal learning rate. They would
# want to run a bunch of models with the same behavior (including the same dropout!) and have them
# each run with different learning rates. Specifically, this is an example of using same randomness with vmap
points, labels = torch.randn(100, 2, 2, 2, 2, device=device), torch.randint(
0, 2, (100,), device=device
points, labels = (
torch.randn(100, 2, 2, 2, 2, device=device),
torch.randint(0, 2, (100,), device=device),
)
class MLPClassifier(nn.Module):

View File

@@ -208,33 +208,33 @@ def check(f, t, delta, check_val=True, graph_input=False):
old_num_nodes = len(fx_g.graph.nodes)
new_num_nodes = len(new_graph.nodes)
if delta == -1:
assert (
old_num_nodes >= new_num_nodes
), f"number of nodes increased {old_num_nodes}, {new_num_nodes}"
assert old_num_nodes >= new_num_nodes, (
f"number of nodes increased {old_num_nodes}, {new_num_nodes}"
)
else:
assert (
old_num_nodes == new_num_nodes + delta
), f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"
assert old_num_nodes == new_num_nodes + delta, (
f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"
)
# a second pass should not reduce more nodes
pass_2_graph = fx_graph_cse(new_graph)
pass_2_num_nodes = len(pass_2_graph.nodes)
assert (
pass_2_num_nodes == new_num_nodes
), f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"
assert pass_2_num_nodes == new_num_nodes, (
f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"
)
# check correctness
if check_val:
true_result = fx_g(t)
our_result = new_g(t)
if true_result is None: # both return None
assert (
our_result is None
), f"true result is None, CSE result is {our_result}"
assert our_result is None, (
f"true result is None, CSE result is {our_result}"
)
else: # results returned are the same
assert torch.all(
true_result == our_result
), f"results are different {true_result}, {our_result}" # check results are the same
assert torch.all(true_result == our_result), (
f"results are different {true_result}, {our_result}"
) # check results are the same
class NoChangeTestCase(TestCase):

View File

@@ -2154,9 +2154,9 @@ class TestOperators(TestCase):
else:
weight = torch.randn(weight_shape, device=device)
target = torch.randint(0, C, target_shape, device=device)
target[
0
] = 1 # since we're ignoring index 0, at least one element must be non-zero
target[0] = (
1 # since we're ignoring index 0, at least one element must be non-zero
)
fn = functools.partial(
torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs
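
Note: when a trailing comment pushes an assignment over the line limit, ruff wraps the value in parentheses and keeps the comment attached to the value rather than the target. Sketch (values invented):

target = [0] * 5
target[0] = (
    1  # at least one element must be non-zero when index 0 is ignored
)
print(target)  # [1, 0, 0, 0, 0]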

View File

@@ -24,6 +24,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from typing import Any
from unittest import mock
@@ -107,7 +108,7 @@ class TestParsedExpression(TestCase):
ParsedExpression("(a) ((b c) (d ...))")
# invalid identifiers
ParsedExpression("camelCase under_scored cApiTaLs \u00DF ...")
ParsedExpression("camelCase under_scored cApiTaLs \u00df ...")
with self.assertRaises(ValueError):
ParsedExpression("1a")
with self.assertRaises(ValueError):
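
Note: this hunk is purely cosmetic — ruff lowercases the hex digits of unicode escapes, and `\u00DF` and `\u00df` name the same code point (ß, U+00DF). Verifiable in two lines:

# Same code point, different escape casing; ruff normalizes to lowercase hex.
assert "\u00DF" == "\u00df" == "ß"
print(ord("\u00df"))  # 223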

View File

@@ -25,7 +25,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import numpy as np
import torch

View File

@@ -1532,7 +1532,7 @@ class TestVmapOperators(Namespace.TestVmapBase):
self._test_unary(op, getter, "cpu")
# test in-place
method = getattr(Tensor, f'{op.__name__ + "_"}')
method = getattr(Tensor, f"{op.__name__ + '_'}")
self._test_unary(method, getter, "cpu", check_propagates_grad=False)
def test_clone(self):
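
Note: likewise cosmetic — ruff prefers double quotes for the outer f-string and flips the inner literal to single quotes, producing the identical string (and staying valid before Python 3.12, where inner and outer quotes must differ). Sketch:

op_name = "abs"
# Both spellings build "abs_"; only the quote style differs.
assert f'{op_name + "_"}' == f"{op_name + '_'}" == "abs_"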

View File

@@ -2,6 +2,7 @@ r"""
**This file is EXPERIMENTAL and is mostly used for testing purposes! Do not
rely on it for anything!**
"""
import operator
import sys

View File

@@ -46,9 +46,9 @@ def check(self, f, t, delta, check_val=True, graph_input=False, P=None):
old_num_nodes = len(fx_g.graph.nodes)
new_num_nodes = len(new_graph.nodes)
assert (
new_num_nodes < old_num_nodes
) == modified, "modified should be True if the number of nodes decrease"
assert (new_num_nodes < old_num_nodes) == modified, (
"modified should be True if the number of nodes decrease"
)
if delta == -1:
self.assertTrue(

View File

@@ -1783,8 +1783,9 @@ class TestSingleOperation(unittest.TestCase):
self.assertEqual(s.check(), z3.sat)
add_result = z3.Const(3, tensor_type)
broadcast_res1, broadcast_res2 = z3.Const(4, tensor_type), z3.Const(
5, tensor_type
broadcast_res1, broadcast_res2 = (
z3.Const(4, tensor_type),
z3.Const(5, tensor_type),
)
# print(s.model())

View File

@@ -251,10 +251,13 @@ class TestInvokeSubgraphCompile(TestCase):
x_clone = x.detach().clone().requires_grad_(True)
y_clone = y.detach().clone().requires_grad_(True)
backend = EagerAndRecordGraphs()
with mock.patch(
with (
mock.patch(
"torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation",
True,
), torch.no_grad():
),
torch.no_grad(),
):
res = torch.compile(fn, backend=backend, fullgraph=True)(
mod, x_clone, y_clone
)
@@ -2399,7 +2402,9 @@ class GraphModule(torch.nn.Module):
{"strict": False},
{"strict": True},
],
class_name_func=lambda cls, _, params: f"{cls.__name__}{'Strict' if params['strict'] else 'Nonstrict'}",
class_name_func=lambda cls,
_,
params: f"{cls.__name__}{'Strict' if params['strict'] else 'Nonstrict'}",
)
class TestInvokeSubgraphExport(TestCase):
def test_simple_func(self):

View File

@@ -38,7 +38,6 @@ USE_BLACK_FILELIST = re.compile(
# torchgen/**
# test/**
# test/[a-h]*/**
"test/[a-h]*/**",
# test/[i-j]*/**
"test/[i-j]*/**",
# test/[k-m]*/**