diff --git a/test/ao/sparsity/test_data_sparsifier.py b/test/ao/sparsity/test_data_sparsifier.py index efd75b057aa..99943821574 100644 --- a/test/ao/sparsity/test_data_sparsifier.py +++ b/test/ao/sparsity/test_data_sparsifier.py @@ -494,9 +494,7 @@ class TestBaseDataSparsifier(_BaseDataSparsiferTestCase): ( emb1, emb2, - ) = nn.Embedding( - 10, 3 - ), nn.Embedding(20, 3) + ) = nn.Embedding(10, 3), nn.Embedding(20, 3) emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3) emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3) @@ -627,9 +625,7 @@ class TestNormDataSparsifiers(_NormDataSparsifierTestCase): ( emb1, emb2, - ) = nn.Embedding( - 10, 3 - ), nn.Embedding(20, 3) + ) = nn.Embedding(10, 3), nn.Embedding(20, 3) emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3) emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3) diff --git a/test/ao/sparsity/test_kernels.py b/test/ao/sparsity/test_kernels.py index be2f3b8cd98..1ffdca5fd34 100644 --- a/test/ao/sparsity/test_kernels.py +++ b/test/ao/sparsity/test_kernels.py @@ -218,12 +218,12 @@ def _sparse_layer_test_helper( qmodule_to_check = fqn_to_module(qmodel, fqn_to_check) # check that the modules were converted as expected - assert isinstance( - sqmodule_to_check, sqmodule_expected_converted_class - ), "Convert failed" - assert isinstance( - qmodule_to_check, qmodule_expected_converted_class - ), "Mapping failed" + assert isinstance(sqmodule_to_check, sqmodule_expected_converted_class), ( + "Convert failed" + ) + assert isinstance(qmodule_to_check, qmodule_expected_converted_class), ( + "Mapping failed" + ) row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias()[ 2: diff --git a/test/ao/sparsity/test_structured_sparsifier.py b/test/ao/sparsity/test_structured_sparsifier.py index 21a951ddec2..c62cc3d3053 100644 --- a/test/ao/sparsity/test_structured_sparsifier.py +++ b/test/ao/sparsity/test_structured_sparsifier.py @@ -1055,9 +1055,9 @@ class TestFPGMPruner(TestCase): mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1] mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2] # Check if either of the least-norm filters is not pruned - assert ( - mask1.item() is not False or mask2.item() is not False - ), "Do not prune all least-norm filters" + assert mask1.item() is not False or mask2.item() is not False, ( + "Do not prune all least-norm filters" + ) # fusion step pruned_model = pruner.prune() diff --git a/test/benchmark_utils/test_benchmark_utils.py b/test/benchmark_utils/test_benchmark_utils.py index 969a1584b68..1d8d8e7e359 100644 --- a/test/benchmark_utils/test_benchmark_utils.py +++ b/test/benchmark_utils/test_benchmark_utils.py @@ -66,9 +66,10 @@ def generate_callgrind_artifacts() -> None: json.dump(artifacts, f, indent=4) -def load_callgrind_artifacts() -> ( - tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats] -): +def load_callgrind_artifacts() -> tuple[ + benchmark_utils.CallgrindStats, + benchmark_utils.CallgrindStats, +]: """Hermetic artifact to unit test Callgrind wrapper. In addition to collecting counts, this wrapper provides some facilities for diff --git a/test/cpp_api_parity/functional_impl_check.py b/test/cpp_api_parity/functional_impl_check.py index b4272a2df1b..34b9ac15812 100644 --- a/test/cpp_api_parity/functional_impl_check.py +++ b/test/cpp_api_parity/functional_impl_check.py @@ -158,7 +158,8 @@ def compute_functional_name(test_params_dict): return test_params_dict["cpp_function_call"].split("(")[0].replace("F::", "") else: raise RuntimeError( - f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}" # noqa: B950 + "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n" + f"{pprint.pformat(test_params_dict)}" ) @@ -179,7 +180,8 @@ def compute_cpp_function_call(test_params_dict, arg_dict, functional_name): ) else: raise RuntimeError( - f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}" # noqa: B950 + "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n" + f"{pprint.pformat(test_params_dict)}" ) @@ -217,7 +219,8 @@ def write_test_to_test_class( or "cpp_function_call" in test_params_dict ), ( "To enable C++ API parity test, " - f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}. \n" # noqa: B950 + "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n" + f"{pprint.pformat(test_params_dict)}. \n" "If you are interested in adding the C++ API parity test, please see:\n" "NOTE [How to check NN module / functional API parity between Python and C++ frontends]. \n" "If not, please add `test_cpp_api_parity=False` to the test params dict and file an issue about this." @@ -233,14 +236,16 @@ def write_test_to_test_class( functional_name = compute_functional_name(test_params_dict) - assert hasattr( - torch.nn.functional, functional_name - ), f"`torch.nn.functional` doesn't have function `{functional_name}`. (Discovered while processing\n{pprint.pformat(test_params_dict)}.)" # noqa: B950 + assert hasattr(torch.nn.functional, functional_name), ( + f"`torch.nn.functional` doesn't have function `{functional_name}`. " + f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)" + ) functional_full_name = "F::" + functional_name assert functional_full_name in parity_table["torch::nn::functional"], ( - f"Please add `{functional_full_name}` entry to `torch::nn::functional` section of `test/cpp_api_parity/parity-tracker.md`. " + f"Please add `{functional_full_name}` entry to `torch::nn::functional` " + "section of `test/cpp_api_parity/parity-tracker.md`. " f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)" ) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py index 8d98387cf5f..d4c49bd28d4 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py @@ -78,9 +78,9 @@ def _kernel_fallback(op, *args, **kwargs): elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default: # Only handle inplace ops returning their first arg assert len(args) >= 1, f"Inplace {op} needs at least one arg" - assert ( - len(op._schema.returns) == 1 - ), f"NYI Inplace {op} with more than one return" + assert len(op._schema.returns) == 1, ( + f"NYI Inplace {op} with more than one return" + ) op_name = op.overloadpacket._qualified_op_name real_res = args[0] elif any(r.alias_info is not None for r in op._schema.returns): diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py index 80194b38aae..0f54f2ec4df 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py @@ -67,7 +67,7 @@ def prepare_for_sending(args, kwargs): def convert(obj): if type(obj) not in VALID_QUEUE_TYPES_IN: raise RuntimeError( - f"Cannot send object of type {type(obj)} " "over openreg device pipe." + f"Cannot send object of type {type(obj)} over openreg device pipe." ) if isinstance(obj, torch.Tensor): @@ -82,8 +82,7 @@ def receive_after_sending(allocator, args, kwargs): def convert(obj): if type(obj) not in VALID_QUEUE_TYPES_OUT: raise RuntimeError( - f"Received invalid object of type {type(obj)} " - "over openreg device pipe." + f"Received invalid object of type {type(obj)} over openreg device pipe." ) if isinstance(obj, OpenRegTensorMeta): diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index cd37107e5a9..dcc34b5489a 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -561,8 +561,9 @@ class TestFullyShardPrefetch(FSDPTest): FSDPParamGroup.post_backward, events ) # Check the order for normal 1 forward, 1 backward, 1 optimizer step - with patch_unshard(unshard_with_record), patch_post_backward( - post_backward_with_record + with ( + patch_unshard(unshard_with_record), + patch_post_backward(post_backward_with_record), ): for iter_idx in range(3): loss = model(inp) @@ -617,8 +618,9 @@ class TestFullyShardPrefetch(FSDPTest): FSDPParamGroup.post_backward, events ) # Check the order for multiple forwards before 1 backward - with patch_unshard(unshard_with_record), patch_post_backward( - post_backward_with_record + with ( + patch_unshard(unshard_with_record), + patch_post_backward(post_backward_with_record), ): loss1 = model(inp) loss2 = model(inp) @@ -703,8 +705,9 @@ class TestFullyShardPrefetch(FSDPTest): post_backward_with_record = self._get_post_backward_with_record( FSDPParamGroup.post_backward, events ) - with patch_unshard(unshard_with_record), patch_post_backward( - post_backward_with_record + with ( + patch_unshard(unshard_with_record), + patch_post_backward(post_backward_with_record), ): loss1, loss2 = model(inp) expected_events = [ @@ -794,9 +797,11 @@ class TestFullyShardPrefetch(FSDPTest): ("reshard", "", TrainingState.POST_BACKWARD), ("post_backward", "", TrainingState.POST_BACKWARD), ] - with patch_unshard(unshard_with_record), patch_reshard( - reshard_with_record - ), patch_post_backward(post_backward_with_record): + with ( + patch_unshard(unshard_with_record), + patch_reshard(reshard_with_record), + patch_post_backward(post_backward_with_record), + ): set_forward_prefetch(model, num_to_prefetch=1) loss = model(inp) expected_forward_events = [ @@ -882,9 +887,11 @@ class TestFullyShardPrefetch(FSDPTest): ("reshard", "layers.3", TrainingState.FORWARD), ("reshard", "", TrainingState.FORWARD), ] - with patch_unshard(unshard_with_record), patch_reshard( - reshard_with_record - ), patch_post_backward(post_backward_with_record): + with ( + patch_unshard(unshard_with_record), + patch_reshard(reshard_with_record), + patch_post_backward(post_backward_with_record), + ): set_backward_prefetch(model, num_to_prefetch=1) loss = model(inp) self.assertEqual(events, expected_forward_events) @@ -967,8 +974,9 @@ class TestFullyShardPrefetch(FSDPTest): (2, model_args.max_seq_len), device=device_type.type, ) - with patch_unshard(unshard_with_record), patch_post_backward( - post_backward_with_record + with ( + patch_unshard(unshard_with_record), + patch_post_backward(post_backward_with_record), ): for _ in range(3): loss = model(inp) @@ -1046,8 +1054,9 @@ class TestFullyShardPrefetch(FSDPTest): FSDPParamGroup.post_backward, events ) inp = torch.randn((2, 16), device=device_type.type) - with patch_unshard(unshard_with_record), patch_post_backward( - post_backward_with_record + with ( + patch_unshard(unshard_with_record), + patch_post_backward(post_backward_with_record), ): for _ in range(3): loss = model(inp) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index 02a3377babf..c376aa0e1aa 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -222,9 +222,7 @@ class TestFullyShardCompile(FSDPTest): ): unsharded_param_graph_inputs.add(node.args[0]) assert len(unsharded_param_graph_inputs) > 0 - assert len(unsharded_param_graph_inputs) == len( - list(model.parameters()) - ), """\ + assert len(unsharded_param_graph_inputs) == len(list(model.parameters())), """\ Expected all model parameters to be wrapped by FSDP2 and have their unsharded version as graph input, but it's not true! """ @@ -237,7 +235,7 @@ have their unsharded version as graph input, but it's not true! no_aliased_unsharded_params_in_graph_inputs = False err_msg += f"""\n Found aliased unsharded param in graph inputs: {aliased_graph_inputs}, -val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]}, +val.shape: {[node.meta["val"].shape for node in aliased_graph_inputs]}, """ self.assertTrue(no_aliased_unsharded_params_in_graph_inputs, err_msg) @@ -466,10 +464,9 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]}, @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_compiled_autograd_ctx(self): self.skipTestForOldSm() - with torch._dynamo.config.patch( - skip_fsdp_hooks=False, - ), torch._functorch.config.patch( - recompute_views=True, + with ( + torch._dynamo.config.patch(skip_fsdp_hooks=False), + torch._functorch.config.patch(recompute_views=True), ): inputs = torch.randn(8, 8) model = torch.nn.Linear(8, 8) @@ -567,24 +564,28 @@ Unsupported Tensor.backward() call torch._dynamo.reset() torch._dynamo.compiled_autograd.reset() - with torch._dynamo.config.patch( - compiled_autograd=True, - compiled_autograd_kwargs_override={ - "fullgraph": True, - }, - inline_inbuilt_nn_modules=True, - skip_fsdp_hooks=False, - ), torch._functorch.config.patch( - enable_autograd_cache=False, - recompute_views=True, - ), torch._inductor.config.patch( - force_disable_caches=True, - reorder_for_compute_comm_overlap=True, - reorder_for_compute_comm_overlap_passes=[ - "sink_waits", - "raise_comms", - "reorder_compute_for_overlap", - ], + with ( + torch._dynamo.config.patch( + compiled_autograd=True, + compiled_autograd_kwargs_override={ + "fullgraph": True, + }, + inline_inbuilt_nn_modules=True, + skip_fsdp_hooks=False, + ), + torch._functorch.config.patch( + enable_autograd_cache=False, + recompute_views=True, + ), + torch._inductor.config.patch( + force_disable_caches=True, + reorder_for_compute_comm_overlap=True, + reorder_for_compute_comm_overlap_passes=[ + "sink_waits", + "raise_comms", + "reorder_compute_for_overlap", + ], + ), ): losses_compiled = test_compiled() losses_eager = test_eager() @@ -741,20 +742,21 @@ Unsupported Tensor.backward() call def _test_nested_fully_shard_backend_inductor_fullgraph_True(self): self.skipTestForOldSm() for fwd_fullgraph in [True]: - with self._reinplace_all_gather_with_optional_checks( - fwd_fullgraph - ), torch._inductor.config.patch( - post_grad_custom_post_pass=( - functools.partial( - self._check_fsdp_copy_and_resize_ops_count_in_graph, - fwd_copy_count=0, - fwd_resize_count=0, - bwd_copy_count=0, - bwd_resize_count=0, + with ( + self._reinplace_all_gather_with_optional_checks(fwd_fullgraph), + torch._inductor.config.patch( + post_grad_custom_post_pass=( + functools.partial( + self._check_fsdp_copy_and_resize_ops_count_in_graph, + fwd_copy_count=0, + fwd_resize_count=0, + bwd_copy_count=0, + bwd_resize_count=0, + ) + if fwd_fullgraph + else None ) - if fwd_fullgraph - else None - ) + ), ): _, triton_codes = run_and_get_code( lambda: self._test_traceable_fsdp( @@ -943,9 +945,10 @@ Unsupported Tensor.backward() call for fwd_fullgraph, all_requires_grad in itertools.product( [True], [True, False] ): - with self._maybe_add_graph_break_to_sdpa( - fwd_fullgraph - ), self._reinplace_all_gather_with_optional_checks(fwd_fullgraph): + with ( + self._maybe_add_graph_break_to_sdpa(fwd_fullgraph), + self._reinplace_all_gather_with_optional_checks(fwd_fullgraph), + ): self._test_traceable_fsdp( *self._create_transformer_factory_fns( all_requires_grad=all_requires_grad @@ -982,23 +985,24 @@ Unsupported Tensor.backward() call log.warning( f"fwd_fullgraph={fwd_fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}" # noqa: G004, G001, B950 ) - with self._reinplace_all_gather_with_optional_checks( - fwd_fullgraph - ), torch._inductor.config.patch( - post_grad_custom_post_pass=( - functools.partial( - self._check_fsdp_copy_and_resize_ops_count_in_graph, - # NOTE: For the root unsharded params, we don't reshard after forward since for training, - # the parameters would be freed and all-gathered immediately. Hence we still have - # their resize and copy ops in the graph. - fwd_copy_count=4, - fwd_resize_count=4, - bwd_copy_count=0, - bwd_resize_count=4, + with ( + self._reinplace_all_gather_with_optional_checks(fwd_fullgraph), + torch._inductor.config.patch( + post_grad_custom_post_pass=( + functools.partial( + self._check_fsdp_copy_and_resize_ops_count_in_graph, + # NOTE: For the root unsharded params, we don't reshard after forward since for training, + # the parameters would be freed and all-gathered immediately. Hence we still have + # their resize and copy ops in the graph. + fwd_copy_count=4, + fwd_resize_count=4, + bwd_copy_count=0, + bwd_resize_count=4, + ) + if fwd_fullgraph + else None ) - if fwd_fullgraph - else None - ) + ), ): _, triton_codes = run_and_get_code( lambda: self._test_traceable_fsdp( diff --git a/test/distributed/_composable/fsdp/test_fully_shard_extensions.py b/test/distributed/_composable/fsdp/test_fully_shard_extensions.py index f8888d12fc9..0b25e09b3de 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_extensions.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_extensions.py @@ -385,9 +385,9 @@ class TestFullyShardAllGatherExtensionsMultiThread( only some ranks may require padding, in which case only those ranks will error out and the all-gather will timeout. """ - assert ( - self.world_size >= 2 - ), f"Assumes world size of at least 2 but got {self.world_size=}" + assert self.world_size >= 2, ( + f"Assumes world size of at least 2 but got {self.world_size=}" + ) model = MLP(dim=3, dim_multiplier=3) for module in model.modules(): for param_name, param in module.named_parameters(recurse=False): diff --git a/test/distributed/_composable/fsdp/test_fully_shard_frozen.py b/test/distributed/_composable/fsdp/test_fully_shard_frozen.py index 467b63563b8..f56c5e76c12 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_frozen.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_frozen.py @@ -115,9 +115,10 @@ class TestFullyShardFrozen(FSDPTest): torch.manual_seed(42 + self.rank + 1) device = device_type - with patch_reduce_scatter( - reduce_scatter - ), patch_register_post_backward_hook_backward(backward_with_count): + with ( + patch_reduce_scatter(reduce_scatter), + patch_register_post_backward_hook_backward(backward_with_count), + ): for iter_idx in range(10): inp = torch.randn((8, lin_dim), device=device) losses: list[torch.Tensor] = [] diff --git a/test/distributed/_composable/fsdp/test_fully_shard_init.py b/test/distributed/_composable/fsdp/test_fully_shard_init.py index 6f5326dab5a..714145f8b97 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_init.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_init.py @@ -910,9 +910,9 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread): @skip_if_lt_x_gpu(1) def test_2d_process_group_init(self): shard_mesh_dim_size = 2 - assert ( - self.world_size % shard_mesh_dim_size == 0 - ), f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}" + assert self.world_size % shard_mesh_dim_size == 0, ( + f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}" + ) replicate_mesh_dim_size = self.world_size // shard_mesh_dim_size mesh_dim_names = ("replicate", "shard") ref_mesh = init_device_mesh( diff --git a/test/distributed/_composable/fsdp/test_fully_shard_memory.py b/test/distributed/_composable/fsdp/test_fully_shard_memory.py index c3b8f04688e..44d05ade98f 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_memory.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_memory.py @@ -54,9 +54,9 @@ class TestFullyShardMemory(FSDPTest): ) ): return # skip since not a common use case - assert ( - self.world_size == 2 - ), f"Requires world size of 2 since some values are hard coded: {self.world_size}" + assert self.world_size == 2, ( + f"Requires world size of 2 since some values are hard coded: {self.world_size}" + ) torch.manual_seed(42) # Pre-run a linear forward (gemm and bias) and backward (gemm) to # allocate the cuBLAS workspaces before measuring the memory usage diff --git a/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py b/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py index 4e5bf9465b4..06881442b74 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py @@ -284,9 +284,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest): ) # bf16 reduction param.grad = funcol.all_gather_tensor( sharded_grad, gather_dim=0, group=group - ).to( - param.dtype - ) # upcast to fp32 + ).to(param.dtype) # upcast to fp32 ref_optim.step() # fp32 optimizer step self.assertEqual(fsdp_loss, ref_loss) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py index c9653d06ade..e8d52f70e0f 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py @@ -139,8 +139,9 @@ class TestFullyShardOverlap(FSDPTest): dist.reduce_scatter_tensor(dummy_rs_output, dummy_rs_input) def fwd_bwd(): - with patch_all_gather(delayed_all_gather), patch_reduce_scatter( - delayed_reduce_scatter + with ( + patch_all_gather(delayed_all_gather), + patch_reduce_scatter(delayed_reduce_scatter), ): loss = model(inp).sum() loss.backward() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index b3a32575fc0..96b2a8b4dfd 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -74,12 +74,12 @@ class TestFullyShardForwardInputs(FSDPTestMultiThread): # Check that FSDP moved the inputs to GPU, including recursing # into the tuple data structure assert x.device == device, f"Expects {device} but got {x.device}" - assert ( - ys[0].device == device - ), f"Expects {device} but got {ys[0].device}" - assert ( - ys[1].device == device - ), f"Expects {device} but got {ys[1].device}" + assert ys[0].device == device, ( + f"Expects {device} but got {ys[0].device}" + ) + assert ys[1].device == device, ( + f"Expects {device} but got {ys[1].device}" + ) y = ys[0] + ys[1] return x + y + 1 diff --git a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py index ec85a668d74..89a893037c3 100644 --- a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py +++ b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py @@ -234,8 +234,8 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/algorithms/test_join.py b/test/distributed/algorithms/test_join.py index 60982d29cc6..8fd613a47d7 100644 --- a/test/distributed/algorithms/test_join.py +++ b/test/distributed/algorithms/test_join.py @@ -250,9 +250,11 @@ class TestJoin(MultiProcessTestCase): else "Detected at least one rank that exhausted inputs. " "Throwing across all ranks." ) - with self.assertRaisesRegex( - RuntimeError, expected_msg - ) if throw_on_early_termination else contextlib.nullcontext(): + with ( + self.assertRaisesRegex(RuntimeError, expected_msg) + if throw_on_early_termination + else contextlib.nullcontext() + ): with Join( allreducers, enable=enable, diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index b0a8ae3f58c..9c4f6fb005a 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -677,9 +677,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin): fully_shard(layer) fully_shard(model) optim = torch.optim.Adam(model.parameters(), lr=1e-2) - torch.optim.lr_scheduler.LambdaLR( - optim, lr_lambda=[lambda epoch: 0.95**epoch] - ) + torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=[lambda epoch: 0.95**epoch]) opt_state_dict = ptd_state_dict.get_optimizer_state_dict( model, optim, diff --git a/test/distributed/checkpoint/test_state_dict_stager.py b/test/distributed/checkpoint/test_state_dict_stager.py index 4b721bc4d19..57f3c014e88 100644 --- a/test/distributed/checkpoint/test_state_dict_stager.py +++ b/test/distributed/checkpoint/test_state_dict_stager.py @@ -228,9 +228,9 @@ class TestStateDictStager(TestCase): # Validate tensor count and bytes expected_storage_cnt = 2 - assert ( - num_storages == expected_storage_cnt - ), f"Expected {expected_storage_cnt} storages, got {num_storages}" + assert num_storages == expected_storage_cnt, ( + f"Expected {expected_storage_cnt} storages, got {num_storages}" + ) # Calculate expected bytes # Note: Only unique storages are counted in the byte count @@ -239,9 +239,9 @@ class TestStateDictStager(TestCase): + tensor3.numel() # tensor1 and tensor2 share storage * tensor3.element_size() # tensor3 and its narrow view share storage ) - assert ( - num_bytes == expected_bytes - ), f"Expected {expected_bytes} bytes, got {num_bytes}" + assert num_bytes == expected_bytes, ( + f"Expected {expected_bytes} bytes, got {num_bytes}" + ) # Verify that the CPU state dict is equivalent to the original CUDA state dict result, error = compare_state_dicts(state_dict, cpu_state_dict) assert result, f"State dicts are not equivalent: {error}" @@ -301,9 +301,9 @@ class TestStateDictStager(TestCase): # Verify the first result is correct result, error = compare_state_dicts(state_dict, cpu_state_dict1) - assert ( - result - ), f"First state dict is not equivalent to original: {error}" + assert result, ( + f"First state dict is not equivalent to original: {error}" + ) # Modify the original tensors tensor1.fill_(0) @@ -317,14 +317,14 @@ class TestStateDictStager(TestCase): # Verify that the second CPU state dict is equivalent to the modified original state dict result, error = compare_state_dicts(state_dict, cpu_state_dict2) - assert ( - result - ), f"Second state dict is not equivalent to modified original: {error}" + assert result, ( + f"Second state dict is not equivalent to modified original: {error}" + ) # Verify that the number of cached storages hasn't changed - assert ( - num_storages1 == num_storages2 - ), f"Storage count changed: {num_storages1} vs {num_storages2}" + assert num_storages1 == num_storages2, ( + f"Storage count changed: {num_storages1} vs {num_storages2}" + ) # Verify that the tensors in the second state dict have the same storage pointers as the first assert ( @@ -347,12 +347,12 @@ class TestStateDictStager(TestCase): cpu_state_dict3 = stager.stage(state_dict) # Verify that the third CPU state dict reflects the updated values - assert torch.all( - cpu_state_dict3["tensor1"] == 42.0 - ), "Updated values should be reflected in the cached state dict" - assert torch.all( - cpu_state_dict3["tensor2"] == 42.0 - ), "Updated values should be reflected in the cached state dict" + assert torch.all(cpu_state_dict3["tensor1"] == 42.0), ( + "Updated values should be reflected in the cached state dict" + ) + assert torch.all(cpu_state_dict3["tensor2"] == 42.0), ( + "Updated values should be reflected in the cached state dict" + ) @requires_cuda def test_tensor_attrs(self): @@ -381,24 +381,24 @@ class TestStateDictStager(TestCase): cpu_state_dict = stager.stage(state_dict) # Verify that tensor attributes are preserved - assert hasattr( - cpu_state_dict["tensor1"], "a" - ), "Tensor attribute 'a' was not preserved" - assert ( - cpu_state_dict["tensor1"].a == 42 - ), "Tensor attribute 'a' has incorrect value" - assert hasattr( - cpu_state_dict["tensor1"], "b" - ), "Tensor attribute 'b' was not preserved" - assert ( - cpu_state_dict["tensor1"].b == 43 - ), "Tensor attribute 'b' has incorrect value" - assert hasattr( - cpu_state_dict["recursive"]["tensor3"], "c" - ), "Tensor attribute 'c' was not preserved" - assert ( - cpu_state_dict["recursive"]["tensor3"].c == 44 - ), "Tensor attribute 'c' has incorrect value" + assert hasattr(cpu_state_dict["tensor1"], "a"), ( + "Tensor attribute 'a' was not preserved" + ) + assert cpu_state_dict["tensor1"].a == 42, ( + "Tensor attribute 'a' has incorrect value" + ) + assert hasattr(cpu_state_dict["tensor1"], "b"), ( + "Tensor attribute 'b' was not preserved" + ) + assert cpu_state_dict["tensor1"].b == 43, ( + "Tensor attribute 'b' has incorrect value" + ) + assert hasattr(cpu_state_dict["recursive"]["tensor3"], "c"), ( + "Tensor attribute 'c' was not preserved" + ) + assert cpu_state_dict["recursive"]["tensor3"].c == 44, ( + "Tensor attribute 'c' has incorrect value" + ) @requires_cuda def test_different_dtypes(self): diff --git a/test/distributed/elastic/multiprocessing/api_test.py b/test/distributed/elastic/multiprocessing/api_test.py index a6acc177ec8..6e0f273a7c8 100644 --- a/test/distributed/elastic/multiprocessing/api_test.py +++ b/test/distributed/elastic/multiprocessing/api_test.py @@ -500,11 +500,13 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), ) - with mock.patch.object( - mpc, "_is_done", return_value=True - ), mock.patch.object(mpc, "_pc"), mock.patch.object( - mpc._pc, "join", side_effect=[True, False, False, True] - ) as mock_join: + with ( + mock.patch.object(mpc, "_is_done", return_value=True), + mock.patch.object(mpc, "_pc"), + mock.patch.object( + mpc._pc, "join", side_effect=[True, False, False, True] + ) as mock_join, + ): mpc._poll() self.assertEqual(4, mock_join.call_count) diff --git a/test/distributed/fsdp/test_distributed_checkpoint.py b/test/distributed/fsdp/test_distributed_checkpoint.py index 42111efc892..ac34246ee64 100644 --- a/test/distributed/fsdp/test_distributed_checkpoint.py +++ b/test/distributed/fsdp/test_distributed_checkpoint.py @@ -56,32 +56,36 @@ class TestDistributedCheckpoint(FSDPTest): torch.manual_seed(200) new_model = wrap(SkipModel(double_nest=True)) - with FullyShardedDataParallel.summon_full_params( - model - ), FullyShardedDataParallel.summon_full_params(new_model): + with ( + FullyShardedDataParallel.summon_full_params(model), + FullyShardedDataParallel.summon_full_params(new_model), + ): params = list(model.parameters()) new_params = list(new_model.parameters()) self.assertNotEqual(params, new_params) writer = FileSystemWriter(self.temp_dir) reader = FileSystemReader(self.temp_dir) - with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type( - new_model, state_dict_type + with ( + FSDP.state_dict_type(model, state_dict_type), + FSDP.state_dict_type(new_model, state_dict_type), ): state_dict = model.state_dict() save(state_dict, writer) - with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type( - new_model, state_dict_type + with ( + FSDP.state_dict_type(model, state_dict_type), + FSDP.state_dict_type(new_model, state_dict_type), ): state_dict = new_model.state_dict() load(state_dict, reader) new_model.load_state_dict(state_dict) - with FullyShardedDataParallel.summon_full_params( - model - ), FullyShardedDataParallel.summon_full_params(new_model): + with ( + FullyShardedDataParallel.summon_full_params(model), + FullyShardedDataParallel.summon_full_params(new_model), + ): params = list(model.parameters()) new_params = list(new_model.parameters()) self.assertEqual(params, new_params) diff --git a/test/distributed/fsdp/test_fsdp_comm.py b/test/distributed/fsdp/test_fsdp_comm.py index aedeb688977..42fa0316230 100644 --- a/test/distributed/fsdp/test_fsdp_comm.py +++ b/test/distributed/fsdp/test_fsdp_comm.py @@ -242,11 +242,10 @@ class TestCommunication(FSDPTest): # and if `use_no_sync=False`, we only run `num_iters` iterations # outside `no_sync()` num_iters = 3 - with patch( - "torch.distributed.all_gather_into_tensor" - ) as mock_all_gather, patch( - "torch.distributed.reduce_scatter_tensor" - ) as mock_reduce_scatter: + with ( + patch("torch.distributed.all_gather_into_tensor") as mock_all_gather, + patch("torch.distributed.reduce_scatter_tensor") as mock_reduce_scatter, + ): def reset_mocks(): mock_all_gather.reset_mock() diff --git a/test/distributed/fsdp/test_fsdp_core.py b/test/distributed/fsdp/test_fsdp_core.py index 5f8b88bb6e5..d6ee32c1f2e 100644 --- a/test/distributed/fsdp/test_fsdp_core.py +++ b/test/distributed/fsdp/test_fsdp_core.py @@ -379,12 +379,15 @@ class TestHooks(FSDPTest): register_pre_backward_hooks_call_count += 1 return orig_register_pre_backward_hooks(*args, **kwargs) - with mock.patch( - "torch.distributed.fsdp._runtime_utils._register_pre_backward_hooks", - _register_pre_backward_hooks_with_count, - ), mock.patch( - "torch.distributed.fsdp._runtime_utils._register_post_backward_hook" - ) as register_post_bwd_mock: + with ( + mock.patch( + "torch.distributed.fsdp._runtime_utils._register_pre_backward_hooks", + _register_pre_backward_hooks_with_count, + ), + mock.patch( + "torch.distributed.fsdp._runtime_utils._register_post_backward_hook" + ) as register_post_bwd_mock, + ): self.assertEqual(register_pre_backward_hooks_call_count, 0) self.assertFalse(register_post_bwd_mock.called) fsdp_model(*input) diff --git a/test/distributed/fsdp/test_fsdp_grad_acc.py b/test/distributed/fsdp/test_fsdp_grad_acc.py index 1e51938a033..b674b408462 100644 --- a/test/distributed/fsdp/test_fsdp_grad_acc.py +++ b/test/distributed/fsdp/test_fsdp_grad_acc.py @@ -152,9 +152,9 @@ class TestGradAcc(FSDPTest): batches.append(tuple(permute_tensor(t) for t in batch)) for batch1, batch2 in itertools.combinations(batches, r=2): for t1, t2 in zip(batch1, batch2): - assert not torch.all( - t1 == t2 - ), "Check the test to make sure that batches are distinct" + assert not torch.all(t1 == t2), ( + "Check the test to make sure that batches are distinct" + ) # Concatenate the batches along the given batch dimension concat_batch: tuple[torch.Tensor, ...] = tuple( diff --git a/test/distributed/fsdp/test_fsdp_hybrid_shard.py b/test/distributed/fsdp/test_fsdp_hybrid_shard.py index dc9b54be2dd..70c415ae1fe 100644 --- a/test/distributed/fsdp/test_fsdp_hybrid_shard.py +++ b/test/distributed/fsdp/test_fsdp_hybrid_shard.py @@ -121,8 +121,9 @@ class TestFSDPHybridShard(FSDPTest): def test_hsdp_save_load_state_dict(self): model = MyModel().cuda() num_node_devices = torch.cuda.device_count() - shard_rank_lists = list(range(0, num_node_devices // 2)), list( - range(num_node_devices // 2, num_node_devices) + shard_rank_lists = ( + list(range(0, num_node_devices // 2)), + list(range(num_node_devices // 2, num_node_devices)), ) shard_groups = ( dist.new_group(shard_rank_lists[0]), @@ -171,8 +172,9 @@ class TestFSDPHybridShard(FSDPTest): def test_hsdp_sync_module_state(self): model = MyModel().cuda() num_node_devices = torch.cuda.device_count() - shard_rank_lists = list(range(0, num_node_devices // 2)), list( - range(num_node_devices // 2, num_node_devices) + shard_rank_lists = ( + list(range(0, num_node_devices // 2)), + list(range(num_node_devices // 2, num_node_devices)), ) shard_groups = ( dist.new_group(shard_rank_lists[0]), @@ -310,8 +312,9 @@ class TestFSDPHybridShard(FSDPTest): cntr = Counter() patched_allreduce = partial(patched_collective, orig_ar, cntr) patched_reduce_scatter = partial(patched_collective, orig_rs, cntr) - with patch_allreduce(patched_allreduce), patch_reduce_scatter( - patched_reduce_scatter + with ( + patch_allreduce(patched_allreduce), + patch_reduce_scatter(patched_reduce_scatter), ): inp = hsdp_model.get_input(device=torch.cuda.current_device()) out = hsdp_model(inp[0], inp[1]) @@ -355,9 +358,9 @@ class TestFSDPHybridShard(FSDPTest): use_orig_params, hsdp_process_groups=hsdp_pgs, ) - assert ( - hsdp_model._inter_node_pg.size() > 1 - ), "HSDP model initialized without replication" + assert hsdp_model._inter_node_pg.size() > 1, ( + "HSDP model initialized without replication" + ) fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2) hsdp_optim = torch.optim.Adam(hsdp_model.parameters(), lr=1e-2) torch.manual_seed(global_pg.rank() + 1) diff --git a/test/distributed/fsdp/test_fsdp_mixed_precision.py b/test/distributed/fsdp/test_fsdp_mixed_precision.py index bb54f1c2d2c..dee38d04034 100644 --- a/test/distributed/fsdp/test_fsdp_mixed_precision.py +++ b/test/distributed/fsdp/test_fsdp_mixed_precision.py @@ -766,9 +766,9 @@ class TestFSDPMixedPrecisionSharded(TestFSDPMixedPrecision): if expect_use_full_prec_in_eval: assert x.dtype == torch.float32, f"Expected fp32, got {x.dtype}" else: - assert ( - x.dtype == low_prec_dtype - ), f"Expected {low_prec_dtype}, got {x.dtype}" + assert x.dtype == low_prec_dtype, ( + f"Expected {low_prec_dtype}, got {x.dtype}" + ) return self.a(x) mp_config = MixedPrecision( diff --git a/test/distributed/fsdp/test_fsdp_tp_integration.py b/test/distributed/fsdp/test_fsdp_tp_integration.py index 326157ec9e4..2cc3858e126 100644 --- a/test/distributed/fsdp/test_fsdp_tp_integration.py +++ b/test/distributed/fsdp/test_fsdp_tp_integration.py @@ -91,9 +91,9 @@ class TestTPFSDPIntegration(FSDPTest): tensor_parallel_size: int, ) -> tuple[dict[str, int], dict[str, tuple[torch.Size, int]]]: """ """ - assert ( - type(model) is SimpleModel - ), "Expects a `SimpleModel` since the sharding cases on the model definition" + assert type(model) is SimpleModel, ( + "Expects a `SimpleModel` since the sharding cases on the model definition" + ) param_name_to_numel = OrderedDict() param_name_to_sharding_info = OrderedDict() for param_name, param in model.named_parameters(): diff --git a/test/distributed/fsdp/test_fsdp_use_orig_params.py b/test/distributed/fsdp/test_fsdp_use_orig_params.py index a0e1d0a50cc..7efe6ec6661 100644 --- a/test/distributed/fsdp/test_fsdp_use_orig_params.py +++ b/test/distributed/fsdp/test_fsdp_use_orig_params.py @@ -654,9 +654,12 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest): losses1 = [] losses2 = [] losses = [] - for _model, _optim in (fsdp_model, optim), ( - fsdp_model_orig_params, - optim_orig_params, + for _model, _optim in ( + (fsdp_model, optim), + ( + fsdp_model_orig_params, + optim_orig_params, + ), ): _optim.zero_grad() loss1 = _model(*inp1) @@ -1166,9 +1169,9 @@ class TestFSDPUseOrigParamsFQNs(FSDPTest): clean_tensor_name(tup[0]) for tup in self.named_parameters() ] params = [tup[1] for tup in self.named_parameters()] - assert ( - param_shapes[0] is not None and param_shapes[1] is not None - ), "`param_sizes` should be set" + assert param_shapes[0] is not None and param_shapes[1] is not None, ( + "`param_sizes` should be set" + ) assert_equal_fn( param_names, [ diff --git a/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py b/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py index 691c43ddb54..f3ab4090e8d 100755 --- a/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py +++ b/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py @@ -19,6 +19,7 @@ The script itself is not a test case hence no assertions are made in this script see: - test/distributed/launcher/run_test.py#test_is_torchelastic_launched() - test/distributed/launcher/run_test.py#test_is_not_torchelastic_launched() """ + import argparse import torch.distributed as dist diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py index 0f0ee84cee3..603f671546a 100644 --- a/test/distributed/optim/test_zero_redundancy_optimizer.py +++ b/test/distributed/optim/test_zero_redundancy_optimizer.py @@ -292,9 +292,9 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer): betas=BETAS, eps=EPS, ) - assert ( - len(o.param_groups) == 2 - ), f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}" + assert len(o.param_groups) == 2, ( + f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}" + ) assert len(o.optim.param_groups) == 2, ( "Expected 2 local optimizer param groups, but got " f"{len(o.optim.param_groups)}" @@ -713,9 +713,9 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer): LR = 1e-3 MOMENTUM = 0.99 REFERENCE_RANK = 0 - assert ( - REFERENCE_RANK in subgroup_ranks - ), "Reference rank must be in the new process group" + assert REFERENCE_RANK in subgroup_ranks, ( + "Reference rank must be in the new process group" + ) loss_fn = torch.nn.L1Loss().to(device) def check(optimizer): @@ -1165,22 +1165,28 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer): # Increased tolerances are needed to pass when using TF32 # See: https://github.com/pytorch/pytorch/issues/67764 - torch.testing.assert_close( - local_loss.cpu(), - ddp_loss.cpu(), - rtol=1e-03, - atol=1e-08, - ), "Losses differ between local optimizer and ZeRO" + ( + torch.testing.assert_close( + local_loss.cpu(), + ddp_loss.cpu(), + rtol=1e-03, + atol=1e-08, + ), + "Losses differ between local optimizer and ZeRO", + ) for local_p, ddp_p in zip( local_model.parameters(), ddp_model.parameters() ): - torch.testing.assert_close( - local_p.cpu(), - ddp_p.cpu(), - rtol=1e-03, - atol=1e-04, - ), "Models differ after a step" + ( + torch.testing.assert_close( + local_p.cpu(), + ddp_p.cpu(), + rtol=1e-03, + atol=1e-04, + ), + "Models differ after a step", + ) @skipIfHpu @skip_if_lt_x_gpu(4) diff --git a/test/distributed/pipelining/test_pipe.py b/test/distributed/pipelining/test_pipe.py index 3e02c4de3c9..8ddb5634811 100644 --- a/test/distributed/pipelining/test_pipe.py +++ b/test/distributed/pipelining/test_pipe.py @@ -89,9 +89,9 @@ class PipeTests(TestCase): mb_args=(x, y), ) - assert ( - pipe.num_stages == EXPECTED_N_STAGES[ModelClass] - ), f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}" + assert pipe.num_stages == EXPECTED_N_STAGES[ModelClass], ( + f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}" + ) ref_out = mod(x, y) out = pipe(x, y)[0] @@ -109,9 +109,7 @@ class PipeTests(TestCase): new_names.update(stage_fqns) if CHECK_FQN_SET_EQUALITY: - assert ( - old_names == new_names - ), f""" + assert old_names == new_names, f""" old names {old_names} new names {new_names} """ diff --git a/test/distributed/pipelining/test_unflatten.py b/test/distributed/pipelining/test_unflatten.py index 5fb30b5e1d1..ae1e684d7c2 100644 --- a/test/distributed/pipelining/test_unflatten.py +++ b/test/distributed/pipelining/test_unflatten.py @@ -60,9 +60,9 @@ class UnflattenTests(TestCase): for stage_idx in range(pipe.num_stages): stage_mod = pipe.get_stage_module(stage_idx) for param_name, _ in stage_mod.named_parameters(): - assert ( - param_name in orig_state_dict - ), f"{param_name} not in original state dict" + assert param_name in orig_state_dict, ( + f"{param_name} not in original state dict" + ) print("Param qualname test passed") # Check equivalence diff --git a/test/distributed/rpc/test_share_memory.py b/test/distributed/rpc/test_share_memory.py index bda98b1df94..97273981d08 100644 --- a/test/distributed/rpc/test_share_memory.py +++ b/test/distributed/rpc/test_share_memory.py @@ -45,9 +45,9 @@ class ShareMemoryRPCPickler(_InternalRPCPickler): for t in torch._tensor_classes: self._dispatch_table[t] = TorchMpReductions.reduce_tensor self._dispatch_table[torch.Tensor] = TorchMpReductions.reduce_tensor - self._dispatch_table[ - torch.nn.parameter.Parameter - ] = TorchMpReductions.reduce_tensor + self._dispatch_table[torch.nn.parameter.Parameter] = ( + TorchMpReductions.reduce_tensor + ) def worker_loop(a): diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index c0859b5925f..23114f87f46 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -872,9 +872,7 @@ def forward(self, primals_1): "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(primal" ).check("torch.ops._c10d_functional.wait_tensor.default(buf0").check( "extern_kernels.mm(buf0," - ).run( - code - ) + ).run(code) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(1) diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py index 0ce1206ae1b..48f92c4ecd7 100644 --- a/test/distributed/tensor/test_math_ops.py +++ b/test/distributed/tensor/test_math_ops.py @@ -441,9 +441,11 @@ class DistMathOpsTest(DTensorTestBase): out_req_grad: bool subtest_fails = {} - valid_filter = lambda cfg: not ( # noqa: E731 - cfg.ln_req_grad and not cfg.elementwise_affine - ) and any(cfg[2:]) + valid_filter = ( # noqa: E731 + lambda cfg: ( + not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[2:]) + ) + ) subtest_cfgs = list( filter( valid_filter, @@ -566,9 +568,9 @@ class DistMathOpsTest(DTensorTestBase): except Exception as e: subtest_fails[subtest_cfg] = e # if any subtest fails, provide the failed subtests and report the overall failure - assert ( - not subtest_fails - ), f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}" + assert not subtest_fails, ( + f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}" + ) @with_comms def test_topk(self): diff --git a/test/distributed/tensor/test_xla_integration.py b/test/distributed/tensor/test_xla_integration.py index 179b5bc796c..3fbfcffbd76 100644 --- a/test/distributed/tensor/test_xla_integration.py +++ b/test/distributed/tensor/test_xla_integration.py @@ -26,7 +26,9 @@ def with_xla(func: Callable) -> Callable: @wraps(func) # pyre-ignore[6] def wrapper( - self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc] + self, + *args: tuple[object], + **kwargs: dict[str, Any], # type: ignore[misc] ) -> None: # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag. os.environ["XLA_USE_SPMD"] = "1" diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 8590f25a351..efac131e6c3 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -2234,8 +2234,8 @@ class LocalRankTest(MultiProcessTestCase): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index ebdf2c0dcdc..e49fb2b1036 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -796,13 +796,11 @@ class CompileTest(TestCase): .check("buf6 = empty") # Expect in-place with inductor allocated buf .check( - "torch.ops._c10d_functional.all_reduce_coalesced_" - ".default([buf0, buf1]" + "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf0, buf1]" ) # Expect no in-place with graph input (buf5, buf6 are clones) .check( - "torch.ops._c10d_functional.all_reduce_coalesced_" - ".default([buf5, buf6]" + "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf5, buf6]" ) .check("torch.ops._c10d_functional.wait_tensor.default(buf0") .check("torch.ops._c10d_functional.wait_tensor.default(buf1") diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index b770d52c01e..96ad01b95b1 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -2705,8 +2705,8 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 6240f131518..c02e968e23f 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2869,9 +2869,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase): self.assertTrue(t.is_alive()) if prev_nccl_async_error_handling is not None: - os.environ[ - "TORCH_NCCL_ASYNC_ERROR_HANDLING" - ] = prev_nccl_async_error_handling + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( + prev_nccl_async_error_handling + ) @requires_nccl() @skip_if_lt_x_gpu(3) @@ -2931,9 +2931,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase): os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0" self._test_barrier_error() if prev_nccl_async_error_handling is not None: - os.environ[ - "TORCH_NCCL_ASYNC_ERROR_HANDLING" - ] = prev_nccl_async_error_handling + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( + prev_nccl_async_error_handling + ) @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @@ -2984,9 +2984,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase): process_group.abort() if prev_nccl_async_error_handling is not None: - os.environ[ - "TORCH_NCCL_ASYNC_ERROR_HANDLING" - ] = prev_nccl_async_error_handling + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( + prev_nccl_async_error_handling + ) @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @@ -3065,9 +3065,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase): os.remove(new_file_name) if prev_nccl_async_error_handling is not None: - os.environ[ - "TORCH_NCCL_ASYNC_ERROR_HANDLING" - ] = prev_nccl_async_error_handling + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( + prev_nccl_async_error_handling + ) def _run_invalid_nccl_blocking_wait_env(self, val): os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val @@ -3360,9 +3360,7 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase): self.assertEqual(_get_intra_node_comm_usage_counter(), 3) # Verify that IntraNodeComm is not used beyond 10MB - t = torch.full( - (10 * 1024**2 // 2 + 1,), self.rank, dtype=torch.bfloat16 - ).cuda() + t = torch.full((10 * 1024**2 // 2 + 1,), self.rank, dtype=torch.bfloat16).cuda() c10d.all_reduce(t, c10d.ReduceOp.SUM) self.assertTrue(t.eq(expect).all()) self.assertEqual(_get_intra_node_comm_usage_counter(), 3) @@ -4249,9 +4247,9 @@ class SparseCollective(MultiProcessTestCase): class NCCLTraceTestBase(MultiProcessTestCase): def setUp(self): super().setUp() - os.environ[ - "TORCH_NCCL_ENABLE_TIMING" - ] = "0" # see 'timing_enabled' parametrized tests + os.environ["TORCH_NCCL_ENABLE_TIMING"] = ( + "0" # see 'timing_enabled' parametrized tests + ) os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "1000" os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1" self.tempdir = tempfile.TemporaryDirectory() @@ -5331,8 +5329,8 @@ class ProcessGroupNCCLLargerScaleTest(MultiProcessTestCase): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_c10d_ucc.py b/test/distributed/test_c10d_ucc.py index 5e7af710c8a..e3a4764d594 100644 --- a/test/distributed/test_c10d_ucc.py +++ b/test/distributed/test_c10d_ucc.py @@ -1090,8 +1090,8 @@ class UccProcessGroupWithDispatchedCollectivesTests( if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_collective_utils.py b/test/distributed/test_collective_utils.py index 2999f318d69..a150a55f77b 100644 --- a/test/distributed/test_collective_utils.py +++ b/test/distributed/test_collective_utils.py @@ -90,9 +90,9 @@ class TestCollectiveUtils(MultiProcessTestCase): res = all_gather(data_or_fn=func, pg=pg) func.assert_called_once() - assert res == list( - range(self.world_size) - ), f"Expect res to be list of 0 through {self.world_size} (got {res})" + assert res == list(range(self.world_size)), ( + f"Expect res to be list of 0 through {self.world_size} (got {res})" + ) def test_all_gather_result_no_pg(self) -> None: """ diff --git a/test/distributed/test_control_collectives.py b/test/distributed/test_control_collectives.py index 594c028ae9d..8e48735c777 100644 --- a/test/distributed/test_control_collectives.py +++ b/test/distributed/test_control_collectives.py @@ -207,8 +207,8 @@ class TestCollectives(TestCase): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 7ad4c33de43..06502943934 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -490,8 +490,9 @@ class DeviceMeshTestNDim(DTensorTestBase): # Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7)) # and assign the correct shard group to each rank - shard_rank_lists = list(range(0, self.world_size // 2)), list( - range(self.world_size // 2, self.world_size) + shard_rank_lists = ( + list(range(0, self.world_size // 2)), + list(range(self.world_size // 2, self.world_size)), ) shard_groups = ( new_group(shard_rank_lists[0]), diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index d43c0c3e3e0..8446282c84f 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -1830,9 +1830,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase): f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH""" ).check( f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH""" - ).run( - GUARDS_FILE.getvalue() - ) + ).run(GUARDS_FILE.getvalue()) self.assertTrue(same(correct_outputs, outputs)) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 6c33a6031d2..77dd871a552 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -519,12 +519,13 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase): out = a2a / a2a.sum(dim=0) return out - with _dynamo_dist_per_rank_init( - self.rank, self.world_size - ), torch._dynamo.config.patch( - dynamic_shapes=True, - capture_dynamic_output_shape_ops=True, - capture_scalar_outputs=True, + with ( + _dynamo_dist_per_rank_init(self.rank, self.world_size), + torch._dynamo.config.patch( + dynamic_shapes=True, + capture_dynamic_output_shape_ops=True, + capture_scalar_outputs=True, + ), ): row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2 input_split_sizes_tensor = torch.tensor( @@ -680,15 +681,15 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase): return torch.ops.custom_ns.foo(a2a) - with _dynamo_dist_per_rank_init( - self.rank, self.world_size - ), torch._dynamo.config.patch( - dynamic_shapes=True, - capture_dynamic_output_shape_ops=True, - capture_scalar_outputs=True, - ), torch.library._scoped_library( - "custom_ns", "FRAGMENT" - ) as lib: + with ( + _dynamo_dist_per_rank_init(self.rank, self.world_size), + torch._dynamo.config.patch( + dynamic_shapes=True, + capture_dynamic_output_shape_ops=True, + capture_scalar_outputs=True, + ), + torch.library._scoped_library("custom_ns", "FRAGMENT") as lib, + ): lib.define( "alltoall_autograd(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor" # noqa: B950 ) diff --git a/test/distributed/test_pg_wrapper.py b/test/distributed/test_pg_wrapper.py index d7e59f1c90a..4c96d4b564d 100644 --- a/test/distributed/test_pg_wrapper.py +++ b/test/distributed/test_pg_wrapper.py @@ -464,8 +464,8 @@ class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_pg_wrapper must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_pg_wrapper must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index 25a554942c8..e9abb1d9071 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -372,12 +372,8 @@ class TCPStoreTest(TestCase, StoreTestBase): # Use noqa to silence flake8. # Need to store in an unused variable here to ensure the first # object is not destroyed before the second object is created. - store1 = dist.TCPStore( - addr, port, 1, True, use_libuv=self._use_libuv - ) # noqa: F841 - store2 = dist.TCPStore( - addr, port, 1, True, use_libuv=self._use_libuv - ) # noqa: F841 + store1 = dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv) # noqa: F841 + store2 = dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv) # noqa: F841 self.assertEqual(store1.libuvBackend, self._use_libuv) self.assertEqual(store2.libuvBackend, self._use_libuv) @@ -767,7 +763,7 @@ class RendezvousFileTest(TestCase): def test_nominal(self): with tempfile.NamedTemporaryFile(delete=False) as file: - url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2' + url = f"file:///{file.name.replace(os.path.sep, '/')}?world_size=2" gen0 = dist.rendezvous(url + "&rank=0") store0, rank0, size0 = next(gen0) self.assertEqual(0, rank0) @@ -1178,8 +1174,8 @@ class TestClientProtocol(TestCase): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index c1ab329a137..6699c973052 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -81,9 +81,9 @@ def count_ops( for node in gm.graph.nodes: if match_rng_op(node, op) or node.target == op: actual_count += 1 - assert ( - actual_count >= freq_ge - ), f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}." + assert actual_count >= freq_ge, ( + f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}." + ) return gm diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index 80aa5c1025f..978a23f8094 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -455,9 +455,9 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase): # Modify gradient using .data (Dangerous: Breaks autograd tracking!) modified_grad = grad_output.clone() - modified_grad.data[ - input_tensor.data < 0 - ] = 0 # Zero-out gradients for negative inputs + modified_grad.data[input_tensor.data < 0] = ( + 0 # Zero-out gradients for negative inputs + ) return modified_grad * 3 diff --git a/test/dynamo/test_base_hop.py b/test/dynamo/test_base_hop.py index b185a1a1333..18cdf78c61f 100644 --- a/test/dynamo/test_base_hop.py +++ b/test/dynamo/test_base_hop.py @@ -195,10 +195,13 @@ class GraphModule(torch.nn.Module): def f(x, y): return invoke_quant_test(inner, [x, y], scheme="nf4") - with mock.patch( - "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation", - True, - ), torch.no_grad(): + with ( + mock.patch( + "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation", + True, + ), + torch.no_grad(), + ): torch.compile(f, backend=bk, fullgraph=True)(x.clone(), y) self.assertEqual(len(bk.graphs), 1) @@ -319,10 +322,13 @@ class GraphModule(torch.nn.Module): x = torch.randn(3, 3, requires_grad=False) x_clone = x.clone() y = torch.randn(3, 3, requires_grad=True) - with mock.patch( - "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation", - True, - ), torch.no_grad(): + with ( + mock.patch( + "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation", + True, + ), + torch.no_grad(), + ): compiled_out = torch.compile(f, backend=backend, fullgraph=True)(x, y) self.assertEqual(x, x_clone + 1) self.assertEqual(compiled_out, x_clone + y + 1) diff --git a/test/dynamo/test_bytecode_utils.py b/test/dynamo/test_bytecode_utils.py index fa906a2ac16..b91b8156ec1 100644 --- a/test/dynamo/test_bytecode_utils.py +++ b/test/dynamo/test_bytecode_utils.py @@ -53,8 +53,8 @@ class BytecodeTests(torch._dynamo.test_case.TestCase): fn_str = f"""\ def fn(): foo.bar(1, 2, 3) -{str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))} - l = [{' '.join('x' + str(i) + ',' for i in range(1 << 9))}] +{str(chr(10)).join(" " * 4 + "x" + str(i) + " = 1" for i in range(1 << 9))} + l = [{" ".join("x" + str(i) + "," for i in range(1 << 9))}] """ locals = {} exec(fn_str, {}, locals) diff --git a/test/dynamo/test_callback.py b/test/dynamo/test_callback.py index 86bd692e8d5..1d221f63553 100644 --- a/test/dynamo/test_callback.py +++ b/test/dynamo/test_callback.py @@ -26,9 +26,10 @@ class CallbackTests(TestCase): def test_callbacks_with_duplicate_prevention(self) -> None: trigger = CallbackTrigger.DYNAMO compile_id = CompileId(0, 0) - with callback_handler.install_callbacks( - trigger, compile_id - ), callback_handler.install_callbacks(trigger, compile_id): + with ( + callback_handler.install_callbacks(trigger, compile_id), + callback_handler.install_callbacks(trigger, compile_id), + ): self._on_compile_start.assert_called_once() self._on_compile_end.assert_called_once() diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index d0216ed5903..3f0edd939a5 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -1252,10 +1252,13 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): def f(x, y): return x + y - x, y = torch.ones( - 1, - ), torch.zeros( - 1, + x, y = ( + torch.ones( + 1, + ), + torch.zeros( + 1, + ), ) return f(x, y) @@ -1289,10 +1292,13 @@ class GraphModule(torch.nn.Module): def f(x, y): return x + y - x, y = torch.ones( - 1, - ), torch.zeros( - 1, + x, y = ( + torch.ones( + 1, + ), + torch.zeros( + 1, + ), ) return f(x, y) @@ -1335,10 +1341,13 @@ class GraphModule(torch.nn.Module): return inner_fn(x, y) + x - x, y = torch.ones( - 1, - ), torch.zeros( - 1, + x, y = ( + torch.ones( + 1, + ), + torch.zeros( + 1, + ), ) return f(x, y) diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index 07bb5760326..f3a2a1d7c77 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -19,7 +19,7 @@ from torch.testing._internal.common_utils import ( class CustomException(Exception): - ... + pass class CustomExceptionMeta(type): @@ -28,7 +28,7 @@ class CustomExceptionMeta(type): class CustomExceptionWithInstanceCheck(Exception, metaclass=CustomExceptionMeta): - ... + pass class CustomExceptionWithArgs(Exception): @@ -358,7 +358,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase): def test_raise_custom_exception(self): class Exc(Exception): - ... + pass @torch.compile(backend="eager", fullgraph=True) def fn(t): @@ -375,7 +375,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase): def test_raise_custom_exception_with_args(self): class Exc(Exception): - ... + pass @torch.compile(backend="eager", fullgraph=True) def fn(t): diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index 6fa40064beb..9e93f3048ea 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -324,9 +324,13 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase): distributed_state=None, package=None, ) - with compile_context(CompileContext(CompileId(0, 0))), tracing( - tracer.output.tracing_context - ), tracer.set_current_tx(), get_metrics_context(), dynamo_timed(""): + with ( + compile_context(CompileContext(CompileId(0, 0))), + tracing(tracer.output.tracing_context), + tracer.set_current_tx(), + get_metrics_context(), + dynamo_timed(""), + ): tracer.run() check_fn_manager = CheckFunctionManager( diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index ced119969ac..50ede0b5465 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1092,9 +1092,7 @@ not ___dict_contains('bbbbbbbb', G['sys'].modules) not ___dict_contains('cccccccc', G['sys'].modules) str(L['x'].device) == 'cpu' str(L['x'].dtype) == 'torch.float32' -utils_device.CURRENT_DEVICE == None""".split( - "\n" - ): +utils_device.CURRENT_DEVICE == None""".split("\n"): self.assertIn( line, guard_code_str, @@ -2806,7 +2804,7 @@ utils_device.CURRENT_DEVICE == None""".split( "int", np.intp, np.int32, - np.uint8 + np.uint8, # np.dtype('int') # XXX: as above ] @@ -5527,9 +5525,9 @@ utils_device.CURRENT_DEVICE == None""".split( def forward(self, idx, targets=None): b, t = idx.size() - assert ( - t <= self.block_size - ), "Cannot forward, model block size is exhausted." + assert t <= self.block_size, ( + "Cannot forward, model block size is exhausted." + ) # forward the GPT model token_embeddings = self.tok_emb( @@ -6075,15 +6073,17 @@ utils_device.CURRENT_DEVICE == None""".split( def count_graph_break_msgs(msgs): return sum("Graph break in user code" in msg for msg in msgs) - with self.assertLogs( - logger="torch._dynamo", level=logging.DEBUG - ) as log, torch._dynamo.config.patch(verbose=True): + with ( + self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log, + torch._dynamo.config.patch(verbose=True), + ): f1(torch.randn(10), torch.randn(10)) self.assertGreater(count_graph_break_msgs(log.output), 1) - with self.assertLogs( - logger="torch._dynamo", level=logging.DEBUG - ) as log, torch._dynamo.config.patch(verbose=False): + with ( + self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log, + torch._dynamo.config.patch(verbose=False), + ): g1(torch.randn(10), torch.randn(10)) self.assertEqual(count_graph_break_msgs(log.output), 1) @@ -8235,8 +8235,9 @@ utils_device.CURRENT_DEVICE == None""".split( def f(a): return h(a) - with warnings.catch_warnings(record=True) as w, self.assertRaises( - torch._dynamo.exc.BackendCompilerFailed + with ( + warnings.catch_warnings(record=True) as w, + self.assertRaises(torch._dynamo.exc.BackendCompilerFailed), ): f(torch.randn(2, 2, requires_grad=True)) @@ -8429,8 +8430,7 @@ utils_device.CURRENT_DEVICE == None""".split( def test_torch_compile_ctx_on_forward_and_training_step(self): class MyModel(torch.nn.Module): - def forward(self): - ... + def forward(self): ... def training_step(self): self() diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 9a8fe50bc8e..b6cb548647a 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -2094,11 +2094,12 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): mod = MockModule() # Each submod is compiled separately and has a different nn module # guard. Ensure that recompilation logic is handle correctly. - with unittest.mock.patch( - "torch._dynamo.config.error_on_recompile", True - ), unittest.mock.patch( - "torch._dynamo.config.recompile_limit", - recompile_limit, + with ( + unittest.mock.patch("torch._dynamo.config.error_on_recompile", True), + unittest.mock.patch( + "torch._dynamo.config.recompile_limit", + recompile_limit, + ), ): x = torch.randn(*size, requires_grad=True) mod(x) @@ -2160,11 +2161,12 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): mod = MockModule() # Each submod is compiled separately and has a different nn module # guard. Ensure that recompilation logic is handle correctly. - with unittest.mock.patch( - "torch._dynamo.config.error_on_recompile", True - ), unittest.mock.patch( - "torch._dynamo.config.recompile_limit", - recompile_limit, + with ( + unittest.mock.patch("torch._dynamo.config.error_on_recompile", True), + unittest.mock.patch( + "torch._dynamo.config.recompile_limit", + recompile_limit, + ), ): x = torch.randn(*size, requires_grad=True) mod(x) diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py index 614baec1e3d..e74ebc22587 100644 --- a/test/dynamo/test_optimizers.py +++ b/test/dynamo/test_optimizers.py @@ -3,6 +3,7 @@ PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes with test_adam in OptimizerTests) """ + import functools import torch diff --git a/test/dynamo/test_recompile_ux.py b/test/dynamo/test_recompile_ux.py index 4507d339462..e69c23c9524 100644 --- a/test/dynamo/test_recompile_ux.py +++ b/test/dynamo/test_recompile_ux.py @@ -238,9 +238,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase): tensor 'x' size mismatch at index 0. expected 11, actual 12 tensor 'x' size mismatch at index 0. expected 10, actual 12 tensor 'x' size mismatch at index 0. expected 9, actual 12 -tensor 'x' size mismatch at index 0. expected 8, actual 12""".split( - "\n" - ): +tensor 'x' size mismatch at index 0. expected 8, actual 12""".split("\n"): self.assertIn( line, failure_str, @@ -276,9 +274,7 @@ tensor 'x' size mismatch at index 0. expected 8, actual 12""".split( opt_f([7, 8]) for line in """\ -len(x) == 3""".split( - "\n" - ): +len(x) == 3""".split("\n"): self.assertIn(line, filter_reasons()) failure_reasons.clear() @@ -286,9 +282,7 @@ len(x) == 3""".split( for line in """\ len(x) == 2 -len(x) == 3""".split( - "\n" - ): +len(x) == 3""".split("\n"): self.assertIn(line, filter_reasons()) @torch._dynamo.config.patch(recompile_limit=1) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 49fcacd3342..c9aa02a44ad 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -179,9 +179,9 @@ def shapes_to_tensor(x, device=None): if torch.jit.is_scripting(): return torch.as_tensor(x, device=device) if torch.jit.is_tracing(): - assert all( - isinstance(t, torch.Tensor) for t in x - ), "Shape should be tensor during tracing!" + assert all(isinstance(t, torch.Tensor) for t in x), ( + "Shape should be tensor during tracing!" + ) # as_tensor should not be used in tracing because it records a constant ret = torch.stack(x) if ret.device != device: # avoid recording a hard-coded device if not necessary @@ -480,9 +480,9 @@ class PartialT5(torch.nn.Module): real_seq_length = seq_length if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" + assert len(past_key_value) == 2, ( + f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" + ) real_seq_length += ( past_key_value[0].shape[2] if query_length is None else query_length ) @@ -4877,9 +4877,9 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor): with warnings.catch_warnings(record=True): data_len = len(value) if len(self._fields): - assert ( - len(self) == data_len - ), f"Adding a field of length {data_len} to a Instances of length {len(self)}" + assert len(self) == data_len, ( + f"Adding a field of length {data_len} to a Instances of length {len(self)}" + ) self._fields[name] = value def get(self, name: str) -> Any: diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index b741c6b5b9c..c4e0fdadeed 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -108,9 +108,10 @@ def get_view_test_cases(): for requires_grad_1, requires_grad_2 in itertools.product( [True, False], repeat=2 ): - yield partial( - mk_leaf, base_is_nt, requires_grad_1, requires_grad_2 - ), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}" + yield ( + partial(mk_leaf, base_is_nt, requires_grad_1, requires_grad_2), + f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}", + ) # (3) obscure case: # view is not a leaf (implies requires_grad True) @@ -118,9 +119,10 @@ def get_view_test_cases(): yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure" # Subclass -> Dense - yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[ - 0 - ].clone(), "subclass_dense" + yield ( + lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone(), + "subclass_dense", + ) # Dense -> Subclass -> Dense -> Subclass def mk_dense_subclass_dense_subclass(): diff --git a/test/dynamo/test_trace_rules.py b/test/dynamo/test_trace_rules.py index 90aa18caee4..0125b06c64b 100644 --- a/test/dynamo/test_trace_rules.py +++ b/test/dynamo/test_trace_rules.py @@ -151,9 +151,9 @@ def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObject types.WrapperDescriptorType, ), ) or is_special_functions(obj): - torch_name_rule_map[ - f"{module.__name__}.{name}" - ] = TorchInGraphFunctionVariable + torch_name_rule_map[f"{module.__name__}.{name}"] = ( + TorchInGraphFunctionVariable + ) if c_binding_only: if not hasattr(obj, "__code__"): c_binding_in_graph_functions.add(obj) @@ -398,12 +398,15 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase): ) self.assertTrue("torch._dynamo" not in torch._dynamo.trace_rules.MOD_INLINELIST) - with unittest.mock.patch( - "torch._dynamo.trace_rules.torch_name_rule_map", - _torch_name_rule_map, - ), unittest.mock.patch( - "torch._dynamo.trace_rules.get_torch_obj_rule_map", - torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, # bypass functools.lru_cache + with ( + unittest.mock.patch( + "torch._dynamo.trace_rules.torch_name_rule_map", + _torch_name_rule_map, + ), + unittest.mock.patch( + "torch._dynamo.trace_rules.get_torch_obj_rule_map", + torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, # bypass functools.lru_cache + ), ): x = torch.rand(3) opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) @@ -419,9 +422,9 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase): _manual_torch_name_rule_map = manual_torch_name_rule_map.copy() # Force inline `mod.func` by setting trace rule. - _manual_torch_name_rule_map[ - f"{mod.__name__}.{func.__name__}" - ] = UserFunctionVariable + _manual_torch_name_rule_map[f"{mod.__name__}.{func.__name__}"] = ( + UserFunctionVariable + ) _torch_name_rule_map = [ _manual_torch_name_rule_map, @@ -429,12 +432,15 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase): torch_non_c_binding_in_graph_functions, ] - with unittest.mock.patch( - "torch._dynamo.trace_rules.torch_name_rule_map", - _torch_name_rule_map, - ), unittest.mock.patch( - "torch._dynamo.trace_rules.get_torch_obj_rule_map", - torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, + with ( + unittest.mock.patch( + "torch._dynamo.trace_rules.torch_name_rule_map", + _torch_name_rule_map, + ), + unittest.mock.patch( + "torch._dynamo.trace_rules.get_torch_obj_rule_map", + torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, + ), ): # First adding the module to SKIP_DIRS so that it will be skipped by default. torch._dynamo.trace_rules.add(mod.__name__) diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index 007c56e6a26..c9ab3b78188 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -593,9 +593,10 @@ class TestDynamoTimed(TestCase): ) compilation_events = [] - with dynamo_config.patch({"automatic_dynamic_shapes": False}), mock.patch( - "torch._dynamo.utils.log_compilation_event" - ) as log_event: + with ( + dynamo_config.patch({"automatic_dynamic_shapes": False}), + mock.patch("torch._dynamo.utils.log_compilation_event") as log_event, + ): @torch.compile() def f(x): diff --git a/test/export/test_draft_export.py b/test/export/test_draft_export.py index 9f2dd833b3b..6cf819958fc 100644 --- a/test/export/test_draft_export.py +++ b/test/export/test_draft_export.py @@ -181,9 +181,12 @@ class TestDraftExport(TestCase): self.assertEqual(len(report.op_profiles), 1) self.assertEqual(len(report.op_profiles["mylib.foo8.default"]), 1) - with torch._library.fake_profile.unsafe_generate_fake_kernels( - report.op_profiles - ), FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()): + with ( + torch._library.fake_profile.unsafe_generate_fake_kernels( + report.op_profiles + ), + FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()), + ): torch.ops.mylib.foo8(*new_inp) # Existing registration has been updated to match the new diff --git a/test/export/test_export.py b/test/export/test_export.py index f6ba272bc91..dda4ddacd2e 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -12834,12 +12834,14 @@ def forward(self, x, y): "y": [Dim("dy")], # y & z incorrect, export is supposed to fail. "z": [Dim("dz")], # suggested fix should be to match these up. } - with self.assertRaisesRegex( # if disable=True, suggested fixes should not specialize. - torch._dynamo.exc.UserError, - r".*Constraints violated(.*\n)*" - r"Suggested fixes:(.*\n)*" - r".*dz = dy(.*\n)*", - ) as msg: + with ( + self.assertRaisesRegex( # if disable=True, suggested fixes should not specialize. + torch._dynamo.exc.UserError, + r".*Constraints violated(.*\n)*" + r"Suggested fixes:(.*\n)*" + r".*dz = dy(.*\n)*", + ) as msg + ): export( Foo(), inputs, @@ -13675,8 +13677,7 @@ def forward(self, x): """Make sure the metadata is kept after exported program run_decompositions.""" @torch.library.custom_op("mylib::add", mutates_args=()) - def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - ... + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ... @torch.library.register_fake("mylib::add") def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index f28611f04ba..75a30ccf3da 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -947,10 +947,12 @@ class TestDeserialize(TestCase): ep = torch.export.export(M(), (torch.ones(3, 3), None, torch.ones(3, 3))) serialized_program = ExportedProgramSerializer(None, 2).serialize(ep) - serialized_program.exported_program.graph_module.signature.input_specs[ - 1 - ] = schema.InputSpec.create( - user_input=schema.UserInputSpec(arg=schema.Argument.create(as_none=True)) + serialized_program.exported_program.graph_module.signature.input_specs[1] = ( + schema.InputSpec.create( + user_input=schema.UserInputSpec( + arg=schema.Argument.create(as_none=True) + ) + ) ) ep = ExportedProgramDeserializer(None).deserialize( serialized_program.exported_program, {}, {}, {} diff --git a/test/export/test_swap.py b/test/export/test_swap.py index 8833c3c94ae..d9b2269dc32 100644 --- a/test/export/test_swap.py +++ b/test/export/test_swap.py @@ -22,7 +22,9 @@ from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase {"strict": False}, {"strict": True}, ], - class_name_func=lambda cls, _, params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}", + class_name_func=lambda cls, + _, + params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}", ) class TestSwap(TestCase): def test_unflatten_preserve_signature(self): diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 3fe6b66039c..d6cf2df4343 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -348,8 +348,7 @@ def check_fc(existing_schemas): "\n\t".join(str(s) for s in matching_new_schemas), ) log.warning( - "Refer to following reasons for failure " - "to find FC schema:\n[\n%s\n]", + "Refer to following reasons for failure to find FC schema:\n[\n%s\n]", "\n\t".join(str(r) for r in possible_failure_reasons), ) broken_ops.append(str(existing_schema)) diff --git a/test/functorch/common_utils.py b/test/functorch/common_utils.py index 72a41dad777..4fa17b89f19 100644 --- a/test/functorch/common_utils.py +++ b/test/functorch/common_utils.py @@ -523,15 +523,15 @@ def decorateForModules(decorator, module_classes, device_type=None, dtypes=None) dtypes=dtypes, ): name_parts = fn.__qualname__.split(".") - assert ( - len(name_parts) == 2 - ), "Decorator only applies to a test function of a test class" + assert len(name_parts) == 2, ( + "Decorator only applies to a test function of a test class" + ) test_case_name, base_test_name = name_parts for module_cls in module_classes: matching_module_infos = [m for m in module_db if m.module_cls == module_cls] - assert ( - len(matching_module_infos) == 1 - ), f"Couldn't find single ModuleInfo for {module_cls}" + assert len(matching_module_infos) == 1, ( + f"Couldn't find single ModuleInfo for {module_cls}" + ) module_info = matching_module_infos[0] decorators = list(module_info.decorators) new_decorator = DecorateInfo( diff --git a/test/functorch/test_ac_knapsack.py b/test/functorch/test_ac_knapsack.py index f0a3c3916e6..751a4c4d218 100644 --- a/test/functorch/test_ac_knapsack.py +++ b/test/functorch/test_ac_knapsack.py @@ -124,9 +124,7 @@ class TestGraphInfoProvider(TestCase): ) def test_recomputable_node_only_graph_with_larger_graph_context(self): - recomputable_node_only_graph_with_larger_graph_context = ( - self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context - ) + recomputable_node_only_graph_with_larger_graph_context = self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context # noqa: B950 expected_nodes = self.all_recomputable_banned_nodes # node1 does not have an indirect path to node5 because of node2 # node2 has an indirect path to node5 diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 1edd3845df9..9bd326304fa 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -2568,8 +2568,9 @@ def forward(self, primals_1, primals_2): def fn(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: return torch.ops._test._clone_create_graph(x, x1) - inp_x, inp_x1 = torch.randn(3, requires_grad=True), torch.randn( - 3, requires_grad=True + inp_x, inp_x1 = ( + torch.randn(3, requires_grad=True), + torch.randn(3, requires_grad=True), ) ref_x, ref_x1 = inp_x.clone(), inp_x1.clone() @@ -5283,11 +5284,12 @@ def forward(self, arg0_1): mod = TestMod(fn) inp = torch.randn(2) - with patch( - "functorch.compile.config.functionalize_rng_ops", True - ), self.assertRaisesRegex( - RuntimeError, - "Functionalized RNG is not currently supported in the aot_export", + with ( + patch("functorch.compile.config.functionalize_rng_ops", True), + self.assertRaisesRegex( + RuntimeError, + "Functionalized RNG is not currently supported in the aot_export", + ), ): aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 1c53063b7a7..bd8abbc3ea8 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -4324,9 +4324,7 @@ class TestExamplesCorrectness(TestCase): def lennard_jones_force(r): """Get magnitude of LJ force""" - return -epsilon * ( - (-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7) - ) + return -epsilon * ((-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7)) r = torch.linspace(0.5, 2 * sigma, steps=100, requires_grad=True, device=device) drs = torch.outer(r, torch.tensor([1.0, 0, 0], device=device)) @@ -4495,8 +4493,9 @@ class TestExamplesCorrectness(TestCase): # This example mimics what a user might do when trying to find the optimal learning rate. They would # want to run a bunch of models with the same behavior (including the same dropout!) and have them # each run with different learning rates. Specifically, this is an example of using same randomness with vmap - points, labels = torch.randn(100, 2, 2, 2, 2, device=device), torch.randint( - 0, 2, (100,), device=device + points, labels = ( + torch.randn(100, 2, 2, 2, 2, device=device), + torch.randint(0, 2, (100,), device=device), ) class MLPClassifier(nn.Module): diff --git a/test/functorch/test_memory_efficient_fusion.py b/test/functorch/test_memory_efficient_fusion.py index 7bf263431ad..4926781d7f6 100644 --- a/test/functorch/test_memory_efficient_fusion.py +++ b/test/functorch/test_memory_efficient_fusion.py @@ -208,33 +208,33 @@ def check(f, t, delta, check_val=True, graph_input=False): old_num_nodes = len(fx_g.graph.nodes) new_num_nodes = len(new_graph.nodes) if delta == -1: - assert ( - old_num_nodes >= new_num_nodes - ), f"number of nodes increased {old_num_nodes}, {new_num_nodes}" + assert old_num_nodes >= new_num_nodes, ( + f"number of nodes increased {old_num_nodes}, {new_num_nodes}" + ) else: - assert ( - old_num_nodes == new_num_nodes + delta - ), f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}" + assert old_num_nodes == new_num_nodes + delta, ( + f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}" + ) # a second pass should not reduce more nodes pass_2_graph = fx_graph_cse(new_graph) pass_2_num_nodes = len(pass_2_graph.nodes) - assert ( - pass_2_num_nodes == new_num_nodes - ), f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}" + assert pass_2_num_nodes == new_num_nodes, ( + f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}" + ) # check correctness if check_val: true_result = fx_g(t) our_result = new_g(t) if true_result is None: # both return None - assert ( - our_result is None - ), f"true result is None, CSE result is {our_result}" + assert our_result is None, ( + f"true result is None, CSE result is {our_result}" + ) else: # results returned are the same - assert torch.all( - true_result == our_result - ), f"results are different {true_result}, {our_result}" # check results are the same + assert torch.all(true_result == our_result), ( + f"results are different {true_result}, {our_result}" + ) # check results are the same class NoChangeTestCase(TestCase): diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 8a0bf6ad40f..cef00f83eb7 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -2154,9 +2154,9 @@ class TestOperators(TestCase): else: weight = torch.randn(weight_shape, device=device) target = torch.randint(0, C, target_shape, device=device) - target[ - 0 - ] = 1 # since we're ignoring index 0, at least one element must be non-zero + target[0] = ( + 1 # since we're ignoring index 0, at least one element must be non-zero + ) fn = functools.partial( torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs diff --git a/test/functorch/test_parsing.py b/test/functorch/test_parsing.py index 46c9b340c59..8183755ebd4 100644 --- a/test/functorch/test_parsing.py +++ b/test/functorch/test_parsing.py @@ -24,6 +24,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from typing import Any from unittest import mock @@ -107,7 +108,7 @@ class TestParsedExpression(TestCase): ParsedExpression("(a) ((b c) (d ...))") # invalid identifiers - ParsedExpression("camelCase under_scored cApiTaLs \u00DF ...") + ParsedExpression("camelCase under_scored cApiTaLs \u00df ...") with self.assertRaises(ValueError): ParsedExpression("1a") with self.assertRaises(ValueError): diff --git a/test/functorch/test_rearrange.py b/test/functorch/test_rearrange.py index d5f55d7e7a3..b3c8f775368 100644 --- a/test/functorch/test_rearrange.py +++ b/test/functorch/test_rearrange.py @@ -25,7 +25,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - import numpy as np import torch diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index dc4f239ca2d..1222e890597 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -1532,7 +1532,7 @@ class TestVmapOperators(Namespace.TestVmapBase): self._test_unary(op, getter, "cpu") # test in-place - method = getattr(Tensor, f'{op.__name__ + "_"}') + method = getattr(Tensor, f"{op.__name__ + '_'}") self._test_unary(method, getter, "cpu", check_propagates_grad=False) def test_clone(self): diff --git a/test/fx/quantization.py b/test/fx/quantization.py index 3daa4da479e..33550702ca6 100644 --- a/test/fx/quantization.py +++ b/test/fx/quantization.py @@ -2,6 +2,7 @@ r""" **This file is EXPERIMENTAL and is mostly used for testing purposes! Do not rely on it for anything!** """ + import operator import sys diff --git a/test/fx/test_cse_pass.py b/test/fx/test_cse_pass.py index 06ecd2e1428..74eb2ca3af4 100644 --- a/test/fx/test_cse_pass.py +++ b/test/fx/test_cse_pass.py @@ -46,9 +46,9 @@ def check(self, f, t, delta, check_val=True, graph_input=False, P=None): old_num_nodes = len(fx_g.graph.nodes) new_num_nodes = len(new_graph.nodes) - assert ( - new_num_nodes < old_num_nodes - ) == modified, "modified should be True if the number of nodes decrease" + assert (new_num_nodes < old_num_nodes) == modified, ( + "modified should be True if the number of nodes decrease" + ) if delta == -1: self.assertTrue( diff --git a/test/fx/test_z3_gradual_types.py b/test/fx/test_z3_gradual_types.py index 70430e03c3a..9b1a3878ed6 100644 --- a/test/fx/test_z3_gradual_types.py +++ b/test/fx/test_z3_gradual_types.py @@ -1783,8 +1783,9 @@ class TestSingleOperation(unittest.TestCase): self.assertEqual(s.check(), z3.sat) add_result = z3.Const(3, tensor_type) - broadcast_res1, broadcast_res2 = z3.Const(4, tensor_type), z3.Const( - 5, tensor_type + broadcast_res1, broadcast_res2 = ( + z3.Const(4, tensor_type), + z3.Const(5, tensor_type), ) # print(s.model()) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index 0fee2cf0953..f7efc393697 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -251,10 +251,13 @@ class TestInvokeSubgraphCompile(TestCase): x_clone = x.detach().clone().requires_grad_(True) y_clone = y.detach().clone().requires_grad_(True) backend = EagerAndRecordGraphs() - with mock.patch( - "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation", - True, - ), torch.no_grad(): + with ( + mock.patch( + "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation", + True, + ), + torch.no_grad(), + ): res = torch.compile(fn, backend=backend, fullgraph=True)( mod, x_clone, y_clone ) @@ -2399,7 +2402,9 @@ class GraphModule(torch.nn.Module): {"strict": False}, {"strict": True}, ], - class_name_func=lambda cls, _, params: f"{cls.__name__}{'Strict' if params['strict'] else 'Nonstrict'}", + class_name_func=lambda cls, + _, + params: f"{cls.__name__}{'Strict' if params['strict'] else 'Nonstrict'}", ) class TestInvokeSubgraphExport(TestCase): def test_simple_func(self): diff --git a/tools/linter/adapters/pyfmt_linter.py b/tools/linter/adapters/pyfmt_linter.py index 88f04145f89..59044534158 100644 --- a/tools/linter/adapters/pyfmt_linter.py +++ b/tools/linter/adapters/pyfmt_linter.py @@ -38,7 +38,6 @@ USE_BLACK_FILELIST = re.compile( # torchgen/** # test/** # test/[a-h]*/** - "test/[a-h]*/**", # test/[i-j]*/** "test/[i-j]*/**", # test/[k-m]*/**