From 3fe437b24ba70e0ac7af8b6664c66875cf91731b Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Wed, 3 Jan 2024 06:04:44 +0000 Subject: [PATCH] [BE]: Update flake8 to v6.1.0 and fix lints (#116591) Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling. - Replace `assert(0)` with `raise AssertionError()` - Remove extraneous parenthesis i.e. - `assert(a == b)` -> `assert a == b` - `if(x > y or y < z):`->`if x > y or y < z:` - And `return('...')` -> `return '...'` Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591 Approved by: https://github.com/albanD, https://github.com/malfet --- .flake8 | 2 - .lintrunner.toml | 6 +-- .../ao/sparsity/test_structured_sparsifier.py | 2 +- test/distributed/test_c10d_spawn.py | 2 +- test/functorch/common_utils.py | 8 ++-- test/fx/test_z3_gradual_types.py | 2 +- test/jit/test_autodiff_subgraph_slicing.py | 6 +-- test/jit/test_scriptmod_ann.py | 14 +++--- test/mobile/test_bytecode.py | 18 +++---- test/mobile/test_lite_script_module.py | 2 +- test/nn/test_dropout.py | 4 +- test/optim/test_optim.py | 2 +- test/quantization/core/test_quantized_op.py | 12 ++--- test/quantization/core/test_workflow_ops.py | 2 +- .../eager/test_quantize_eager_qat.py | 2 +- test/test_autograd.py | 4 +- test/test_bundled_images.py | 2 +- test/test_cuda.py | 2 +- test/test_fx.py | 24 +++++----- test/test_jit.py | 22 ++++----- test/test_mps.py | 6 +-- test/test_nn.py | 4 +- test/test_ops.py | 2 +- test/test_overrides.py | 4 +- test/test_python_dispatch.py | 2 +- test/test_reductions.py | 2 +- test/test_sparse.py | 6 +-- test/test_torch.py | 2 +- torch/__init__.py | 2 +- torch/_functorch/benchmark_utils.py | 4 +- torch/_higher_order_ops/out_dtype.py | 2 +- .../quantized/modules/functional_modules.py | 2 +- torch/ao/nn/sparse/quantized/utils.py | 2 +- .../ao/quantization/backend_config/onednn.py | 2 +- .../ao/quantization/experimental/observer.py | 4 +- torch/ao/quantization/fake_quantize.py | 4 +- .../ao/quantization/fuser_method_mappings.py | 12 ++--- torch/ao/quantization/fx/_equalize.py | 44 ++++++++--------- .../fx/_lower_to_native_backend.py | 48 +++++++++---------- torch/ao/quantization/fx/convert.py | 2 +- torch/ao/quantization/fx/prepare.py | 8 ++-- .../_shard/sharded_tensor/_ops/binary_cmp.py | 2 +- .../constraint_transformation.py | 2 +- .../migrate_gradual_types/transform_to_z3.py | 2 +- torch/fx/experimental/optimization.py | 18 +++---- torch/fx/experimental/symbolic_shapes.py | 2 +- torch/fx/graph.py | 2 +- torch/fx/node.py | 2 +- torch/fx/passes/net_min_base.py | 2 +- torch/fx/passes/split_module.py | 2 +- torch/fx/proxy.py | 2 +- torch/jit/frontend.py | 2 +- torch/testing/_internal/common_nn.py | 2 +- torch/testing/_internal/common_quantized.py | 2 +- torch/testing/_internal/common_utils.py | 8 ++-- torch/testing/_internal/hypothesis_utils.py | 6 +-- .../_internal/jit_metaprogramming_utils.py | 2 +- torch/utils/_python_dispatch.py | 2 +- torch/utils/bundled_inputs.py | 4 +- torch/utils/collect_env.py | 2 +- torch/utils/flop_counter.py | 4 +- torch/utils/mkldnn.py | 6 +-- 62 files changed, 188 insertions(+), 190 deletions(-) diff --git a/.flake8 b/.flake8 index 57a3a211360..c59af78be7b 100644 --- a/.flake8 +++ b/.flake8 @@ -8,8 +8,6 @@ max-line-length = 120 # E501 is not flexible enough, we're using B950 instead ignore = E203,E305,E402,E501,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303, - # fix these lints in the future - E275, # shebang has extra 
meaning in fbcode lints, so I think it's not worth trying # to line this up with executable bit EXE001, diff --git a/.lintrunner.toml b/.lintrunner.toml index 41c56ee529d..5e3cd708544 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -38,7 +38,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'flake8==6.0.0', + 'flake8==6.1.0', 'flake8-bugbear==23.3.23', 'flake8-comprehensions==3.12.0', 'flake8-executable==2.1.3', @@ -46,8 +46,8 @@ init_command = [ 'flake8-pyi==23.3.1', 'flake8-simplify==0.19.3', 'mccabe==0.7.0', - 'pycodestyle==2.10.0', - 'pyflakes==3.0.1', + 'pycodestyle==2.11.1', + 'pyflakes==3.1.0', 'torchfix==0.2.0', ] diff --git a/test/ao/sparsity/test_structured_sparsifier.py b/test/ao/sparsity/test_structured_sparsifier.py index c11d0179a67..7587f5ca5c3 100644 --- a/test/ao/sparsity/test_structured_sparsifier.py +++ b/test/ao/sparsity/test_structured_sparsifier.py @@ -981,7 +981,7 @@ class TestFPGMPruner(TestCase): pruner.prepare(model, config) pruner.enable_mask_update = True pruner.step() - assert pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item() is not False,\ + assert pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item() is not False, \ "do not prune the least-norm filter" # fusion step diff --git a/test/distributed/test_c10d_spawn.py b/test/distributed/test_c10d_spawn.py index 7c09a8af7ab..a6fb33d4ead 100644 --- a/test/distributed/test_c10d_spawn.py +++ b/test/distributed/test_c10d_spawn.py @@ -9,7 +9,7 @@ import torch.distributed as c10d import torch.multiprocessing as mp from torch.testing._internal.common_distributed import \ MultiProcessTestCase -from torch.testing._internal.common_utils import load_tests,\ +from torch.testing._internal.common_utils import load_tests, \ NO_MULTIPROCESSING_SPAWN # Torch distributed.nn is not available in windows diff --git a/test/functorch/common_utils.py b/test/functorch/common_utils.py index 77a8114cc72..46555f5d51e 100644 --- a/test/functorch/common_utils.py +++ b/test/functorch/common_utils.py @@ -27,7 +27,7 @@ def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values): for idx in range(batch_size): flat_args, args_spec = pytree.tree_flatten(batched_args) flat_dims, dims_spec = pytree.tree_flatten(in_dims) - assert(args_spec == dims_spec) + assert args_spec == dims_spec new_args = [a.select(in_dim, idx) if in_dim is not None else a for a, in_dim in zip(flat_args, flat_dims)] out = op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values) flat_out, out_spec = pytree.tree_flatten(out) @@ -45,9 +45,9 @@ def loop2(op, in_dims1, in_dims2, out_dim1, out_dim2, batch_size1, batch_size2, flat_args, args_spec = pytree.tree_flatten(batched_args) flat_dims1, dims_spec1 = pytree.tree_flatten(in_dims1) flat_dims2, dims_spec2 = pytree.tree_flatten(in_dims2) - assert(args_spec == dims_spec1) - assert(args_spec == dims_spec2) - assert(len(flat_dims1) == len(flat_dims2)) + assert args_spec == dims_spec1 + assert args_spec == dims_spec2 + assert len(flat_dims1) == len(flat_dims2) for idx1 in range(batch_size1): out_split = [] arg_split = [a.select(in_dim1, idx1) if in_dim1 is not None else a for a, in_dim1 in zip(flat_args, flat_dims1)] diff --git a/test/fx/test_z3_gradual_types.py b/test/fx/test_z3_gradual_types.py index 12d57b57ae1..4e5eb07767b 100644 --- a/test/fx/test_z3_gradual_types.py +++ b/test/fx/test_z3_gradual_types.py @@ -7,7 +7,7 @@ from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraint from 
torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint from torch.fx.experimental.migrate_gradual_types.operation import op_precision, op_matching, op_consistency -from torch.fx.experimental.migrate_gradual_types.transform_to_z3 import transform_all_constraints,\ +from torch.fx.experimental.migrate_gradual_types.transform_to_z3 import transform_all_constraints, \ evaluate_conditional_with_constraints from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, D, z3_dyn from torch.fx.experimental.rewriter import RewritingTracer diff --git a/test/jit/test_autodiff_subgraph_slicing.py b/test/jit/test_autodiff_subgraph_slicing.py index fbdcc190914..1fb40b62297 100644 --- a/test/jit/test_autodiff_subgraph_slicing.py +++ b/test/jit/test_autodiff_subgraph_slicing.py @@ -154,9 +154,9 @@ class TestAutodiffSubgraphSlicing(JitTestCase): output_ref = func(input0, input1) for i in range(2): output = jit_f(input0, input1) - assert(output_ref[0].requires_grad == output[0].requires_grad) - assert(output_ref[1][0].requires_grad == output[1][0].requires_grad) - assert(output_ref[1][1].requires_grad == output[1][1].requires_grad) + assert output_ref[0].requires_grad == output[0].requires_grad + assert output_ref[1][0].requires_grad == output[1][0].requires_grad + assert output_ref[1][1].requires_grad == output[1][1].requires_grad @unittest.skip("disable until we property handle tensor lists with undefined gradients") def test_differentiable_graph_ops_requires_grad(self): diff --git a/test/jit/test_scriptmod_ann.py b/test/jit/test_scriptmod_ann.py index 47e010e6122..3a9f2fd4d2c 100644 --- a/test/jit/test_scriptmod_ann.py +++ b/test/jit/test_scriptmod_ann.py @@ -35,7 +35,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase): with warnings.catch_warnings(record=True) as w: self.checkModule(M(), (1,)) - assert(len(w) == 0) + assert len(w) == 0 def test_annotated_nonempty_container(self): class M(torch.nn.Module): @@ -49,7 +49,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase): with warnings.catch_warnings(record=True) as w: self.checkModule(M(), ([1, 2, 3],)) - assert(len(w) == 0) + assert len(w) == 0 def test_annotated_empty_tensor(self): class M(torch.nn.Module): @@ -63,7 +63,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase): with warnings.catch_warnings(record=True) as w: self.checkModule(M(), (torch.rand(2, 3),)) - assert(len(w) == 0) + assert len(w) == 0 def test_annotated_with_jit_attribute(self): class M(torch.nn.Module): @@ -77,7 +77,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase): with warnings.catch_warnings(record=True) as w: self.checkModule(M(), ([1, 2, 3],)) - assert(len(w) == 0) + assert len(w) == 0 def test_annotated_class_level_annotation_only(self): class M(torch.nn.Module): @@ -94,7 +94,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase): with warnings.catch_warnings(record=True) as w: self.checkModule(M(), ([1, 2, 3],)) - assert(len(w) == 0) + assert len(w) == 0 def test_annotated_class_level_annotation_and_init_annotation(self): @@ -112,7 +112,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase): with warnings.catch_warnings(record=True) as w: self.checkModule(M(), ([1, 2, 3],)) - assert(len(w) == 0) + assert len(w) == 0 def test_annotated_class_level_jit_annotation(self): class M(torch.nn.Module): @@ -129,7 +129,7 
@@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase): with warnings.catch_warnings(record=True) as w: self.checkModule(M(), ([1, 2, 3],)) - assert(len(w) == 0) + assert len(w) == 0 def test_annotated_empty_list(self): class M(torch.nn.Module): diff --git a/test/mobile/test_bytecode.py b/test/mobile/test_bytecode.py index b5a493e1103..a2f2a93eb00 100644 --- a/test/mobile/test_bytecode.py +++ b/test/mobile/test_bytecode.py @@ -147,7 +147,7 @@ class testVariousModelVersions(TestCase): def test_get_model_bytecode_version(self): def check_model_version(model_path, expect_version): actual_version = _get_model_bytecode_version(model_path) - assert(actual_version == expect_version) + assert actual_version == expect_version for version, model_info in SCRIPT_MODULE_BYTECODE_PKL.items(): model_path = pytorch_test_dir / "cpp" / "jit" / model_info["model_name"] check_model_version(model_path, version) @@ -174,7 +174,7 @@ class testVariousModelVersions(TestCase): current_to_version = current_from_version - 1 backport_success = _backport_for_mobile(input_model_path, tmp_output_model_path_backport, current_to_version) - assert(backport_success) + assert backport_success expect_bytecode_pkl = SCRIPT_MODULE_BYTECODE_PKL[current_to_version]["bytecode_pkl"] @@ -187,7 +187,7 @@ class testVariousModelVersions(TestCase): acutal_result_clean = "".join(output.split()) expect_result_clean = "".join(expect_bytecode_pkl.split()) isMatch = fnmatch.fnmatch(acutal_result_clean, expect_result_clean) - assert(isMatch) + assert isMatch current_from_version -= 1 shutil.rmtree(tmpdirname) @@ -254,7 +254,7 @@ class testVariousModelVersions(TestCase): script_module_v5_path, tmp_backport_model_path, maximum_checked_in_model_version - 1) - assert(success) + assert success buf = io.StringIO() torch.utils.show_pickle.main( @@ -266,7 +266,7 @@ class testVariousModelVersions(TestCase): acutal_result_clean = "".join(output.split()) expect_result_clean = "".join(expected_result.split()) isMatch = fnmatch.fnmatch(acutal_result_clean, expect_result_clean) - assert(isMatch) + assert isMatch # Load model v4 and run forward method mobile_module = _load_for_lite_interpreter(str(tmp_backport_model_path)) @@ -291,7 +291,7 @@ class testVariousModelVersions(TestCase): # Check version of the model v4 from backport bytesio = io.BytesIO(script_module_v4_buffer) backport_version = _get_model_bytecode_version(bytesio) - assert(backport_version == maximum_checked_in_model_version - 1) + assert backport_version == maximum_checked_in_model_version - 1 # Load model v4 from backport and run forward method bytesio = io.BytesIO(script_module_v4_buffer) @@ -306,8 +306,8 @@ class testVariousModelVersions(TestCase): # TODO update this to be more in the style of the above tests after a backport from 6 -> 5 exists script_module_v6 = pytorch_test_dir / "cpp" / "jit" / "script_module_v6.ptl" ops_v6 = _get_model_ops_and_info(script_module_v6) - assert(ops_v6["aten::add.int"].num_schema_args == 2) - assert(ops_v6["aten::add.Scalar"].num_schema_args == 2) + assert ops_v6["aten::add.int"].num_schema_args == 2 + assert ops_v6["aten::add.Scalar"].num_schema_args == 2 def test_get_mobile_model_contained_types(self): class MyTestModule(torch.nn.Module): @@ -322,7 +322,7 @@ class testVariousModelVersions(TestCase): buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter()) buffer.seek(0) type_list = _get_mobile_model_contained_types(buffer) - assert(len(type_list) >= 0) + assert len(type_list) >= 0 if __name__ == '__main__': run_tests() 
diff --git a/test/mobile/test_lite_script_module.py b/test/mobile/test_lite_script_module.py index 377a0237f82..7a742c55ab7 100644 --- a/test/mobile/test_lite_script_module.py +++ b/test/mobile/test_lite_script_module.py @@ -82,7 +82,7 @@ class TestLiteScriptModule(TestCase): buffer = io.BytesIO(exported_module) buffer.seek(0) - assert(b"callstack_debug_map.pkl" in exported_module) + assert b"callstack_debug_map.pkl" in exported_module mobile_module = _load_for_lite_interpreter(buffer) with self.assertRaisesRegex(RuntimeError, r"Module hierarchy:top\(B\)::.A0\(A\)::forward.aten::mul"): diff --git a/test/nn/test_dropout.py b/test/nn/test_dropout.py index 2c9e463318f..5337b0eb76f 100644 --- a/test/nn/test_dropout.py +++ b/test/nn/test_dropout.py @@ -65,8 +65,8 @@ class TestDropoutNN(NNTestCase): o_ref = torch.dropout(x_ref, p, train) o.sum().backward() o_ref.sum().backward() - assert(o.equal(o_ref)) - assert(x.grad.equal(x_ref.grad)) + assert o.equal(o_ref) + assert x.grad.equal(x_ref.grad) def test_invalid_dropout_p(self): v = torch.ones(1) diff --git a/test/optim/test_optim.py b/test/optim/test_optim.py index aacd837166d..fcc0b9b21e6 100644 --- a/test/optim/test_optim.py +++ b/test/optim/test_optim.py @@ -166,7 +166,7 @@ class TestOptim(TestCase): pass elif constructor_accepts_maximize: - def four_arg_constructor(weight, bias, maximize, foreach): + def four_arg_constructor(weight, bias, maximize, foreach): # noqa: F811 self.assertFalse(foreach) return constructor(weight, bias, maximize) diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index c458b5764e2..47564b1ccac 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -2127,7 +2127,7 @@ class TestQuantizedOps(TestCase): quantized_out = torch.topk(qX, k, dim=dim, largest=larg, sorted=sort) - assert(len(unquantized_out) == len(quantized_out)) + assert len(unquantized_out) == len(quantized_out) torch.testing.assert_close(quantized_out[0].dequantize(), unquantized_out[0]) torch.testing.assert_close(quantized_out[1], unquantized_out[1]) @@ -3813,7 +3813,7 @@ class TestQuantizedLinear(TestCase): post_op = 'none' cases = itertools.product(batch_size_list, input_channels_list, output_channels_list, use_bias_list, use_multi_dim_input_list, use_channelwise_list) - for batch_size, input_channels, output_channels, use_bias,\ + for batch_size, input_channels, output_channels, use_bias, \ use_multi_dim_input, use_channelwise in cases: self._test_qlinear_impl(batch_size, input_channels, output_channels, use_bias, post_op, use_multi_dim_input, use_channelwise) @@ -3830,7 +3830,7 @@ class TestQuantizedLinear(TestCase): post_op = 'relu' cases = itertools.product(batch_size_list, input_channels_list, output_channels_list, use_bias_list, use_multi_dim_input_list, use_channelwise_list) - for batch_size, input_channels, output_channels, use_bias,\ + for batch_size, input_channels, output_channels, use_bias, \ use_multi_dim_input, use_channelwise in cases: self._test_qlinear_impl(batch_size, input_channels, output_channels, use_bias, post_op, use_multi_dim_input, use_channelwise) @@ -4101,7 +4101,7 @@ class TestQuantizedLinear(TestCase): qparams=hu.qparams(dtypes=torch.qint8))) @override_qengines def test_qlinear_qnnpack_free_memory_and_unpack(self, W): - assert(qengine_is_qnnpack) + assert qengine_is_qnnpack W, (W_scale, W_zp, torch_type) = W qlinear_prepack = torch.ops.quantized.linear_prepack qlinear_unpack = torch.ops.quantized.linear_unpack @@ 
-4140,7 +4140,7 @@ class TestQuantizedLinear(TestCase): cases = itertools.product(batch_size_list, input_channels_list, output_channels_list, use_bias_list, use_multi_dim_input_list, use_channelwise_list, negative_slopes_list) - for batch_size, input_channels, output_channels, use_bias,\ + for batch_size, input_channels, output_channels, use_bias, \ use_multi_dim_input, use_channelwise, neg_slope in cases: self._test_qlinear_impl(batch_size, input_channels, output_channels, use_bias, post_op, use_multi_dim_input, @@ -4159,7 +4159,7 @@ class TestQuantizedLinear(TestCase): cases = itertools.product(batch_size_list, input_channels_list, output_channels_list, use_bias_list, use_multi_dim_input_list, use_channelwise_list) - for batch_size, input_channels, output_channels, use_bias,\ + for batch_size, input_channels, output_channels, use_bias, \ use_multi_dim_input, use_channelwise in cases: self._test_qlinear_impl(batch_size, input_channels, output_channels, use_bias, post_op, use_multi_dim_input, diff --git a/test/quantization/core/test_workflow_ops.py b/test/quantization/core/test_workflow_ops.py index 2e22d49af4c..0000ef9bbc2 100644 --- a/test/quantization/core/test_workflow_ops.py +++ b/test/quantization/core/test_workflow_ops.py @@ -879,7 +879,7 @@ class TestFakeQuantizeOps(TestCase): Y_prime.backward(dout) np.testing.assert_allclose( dX.cpu().detach().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance) - assert(X.grad.dtype == float_type) + assert X.grad.dtype == float_type def test_backward_per_channel_cachemask_cpu(self): diff --git a/test/quantization/eager/test_quantize_eager_qat.py b/test/quantization/eager/test_quantize_eager_qat.py index d51fcbb9997..52f169b1d5b 100644 --- a/test/quantization/eager/test_quantize_eager_qat.py +++ b/test/quantization/eager/test_quantize_eager_qat.py @@ -839,7 +839,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase): return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x) if not use_relu: - def relu_op(x): + def relu_op(x): # noqa: F811 return x if freeze_bn: diff --git a/test/test_autograd.py b/test/test_autograd.py index 872a8016cc4..7836b59388e 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -4319,7 +4319,7 @@ Done""") ] for thread, ranges in threads: for range in ranges: - assert(len(range) == 3) + assert len(range) == 3 events.append( FunctionEvent( id=range[2], @@ -4340,7 +4340,7 @@ Done""") def get_children_ids(event): return [child.id for child in event.cpu_children] - assert([get_children_ids(event) for event in events] == res) + assert [get_children_ids(event) for event in events] == res def test_profiler_aggregation_table(self): """ diff --git a/test/test_bundled_images.py b/test/test_bundled_images.py index 0dcba86bd12..118e276a30b 100644 --- a/test/test_bundled_images.py +++ b/test/test_bundled_images.py @@ -30,7 +30,7 @@ def save_and_load(sm): def bundle_jpeg_image(img_tensor, quality): # turn NCHW to HWC if img_tensor.dim() == 4: - assert(img_tensor.size(0) == 1) + assert img_tensor.size(0) == 1 img_tensor = img_tensor[0].permute(1, 2, 0) pixels = img_tensor.numpy() encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] diff --git a/test/test_cuda.py b/test/test_cuda.py index d818916082e..5dfb24cd495 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1388,7 +1388,7 @@ torch.cuda.synchronize() scaler.scale(l).backward() scaler.step(optimizer) scaler.update() - assert(scaler._scale != float('inf') and scaler._scale != float('nan')) + assert scaler._scale != 
float('inf') and scaler._scale != float('nan') def test_grad_scaling_clipping(self): def run(data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api): diff --git a/test/test_fx.py b/test/test_fx.py index 13c748beeef..652fe6cd34a 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -1676,8 +1676,8 @@ class TestFX(JitTestCase): x = torch.randn(5, 5, 224, 224) shape_prop.ShapeProp(traced).propagate(x) - assert(all(node.meta['tensor_meta'].memory_format is torch.contiguous_format - for node in traced.graph.nodes)) + assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format + for node in traced.graph.nodes) x_channels_last = x.contiguous(memory_format=torch.channels_last) traced.to(memory_format=torch.channels_last) @@ -1733,8 +1733,8 @@ class TestFX(JitTestCase): traced_3d = symbolic_trace(test_mod_3d) x_3d = torch.randn(5, 5, 224, 224, 15) shape_prop.ShapeProp(traced_3d).propagate(x_3d) - assert(all(node.meta['tensor_meta'].memory_format is torch.contiguous_format - for node in traced_3d.graph.nodes)) + assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format + for node in traced_3d.graph.nodes) x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d) traced_3d.to(memory_format=torch.channels_last_3d) @@ -2548,7 +2548,7 @@ class TestFX(JitTestCase): return self.a.b, self.a.b.t(), self.a.b.view(12) traced = torch.fx.symbolic_trace(Foo()) - assert(all('constant' not in node.target for node in traced.graph.nodes)) + assert all('constant' not in node.target for node in traced.graph.nodes) def test_single_default_arg(self): class M(torch.nn.Module): @@ -2824,7 +2824,7 @@ class TestFX(JitTestCase): mod_false = symbolic_trace(mod, concrete_args={'y': False}) self.assertEqual(mod_true(3, True), 6) print(mod_true.code) - assert(any(i.target == torch._assert for i in mod_true.graph.nodes)) + assert any(i.target == torch._assert for i in mod_true.graph.nodes) with self.assertRaises(AssertionError): mod_true(3, False) self.assertEqual(mod_false(3, False), 3) @@ -3607,17 +3607,17 @@ class TestFX(JitTestCase): self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out) assert num_flat_args == 0 or "tree_flatten_spec" in nf.code - assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args) + assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args nf = symbolic_trace(nf) self.assertEqual(nf(val), orig_out) assert "tree_flatten_spec" not in nf.code - assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == 1) + assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == 1 nf = symbolic_trace(nf, concrete_args={'x': inp}) self.assertEqual(nf(val), orig_out) assert num_flat_args == 0 or "tree_flatten_spec" in nf.code - assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args) + assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args pickled = pickle.dumps(nf) nf = pickle.loads(pickled) @@ -3708,7 +3708,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: return [('List', typing.List)] def process_inputs(self, *inputs): - assert(len(inputs) == 1) + assert len(inputs) == 1 return inputs[0] def f(a, b): @@ -3743,7 +3743,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: return [('List', typing.List)] def process_inputs(self, *inputs): - assert(len(inputs) == 1) + assert len(inputs) == 1 return inputs[0] def f(a, b): @@ -3772,7 +3772,7 @@ def forward(self, args_list: 
List[torch.Tensor]){maybe_return_annotation}: return [('List', typing.List)] def process_inputs(self, *inputs): - assert(len(inputs) == 1) + assert len(inputs) == 1 return inputs[0] def generate_output(self, output_args): diff --git a/test/test_jit.py b/test/test_jit.py index c7149524506..64cd80c10d9 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1875,8 +1875,8 @@ graph(%Ra, %Rb): o_ref = t(x_ref, p, train) o.sum().backward() o_ref.sum().backward() - assert(o.equal(o_ref)) - assert(x.grad.equal(x_ref.grad)) + assert o.equal(o_ref) + assert x.grad.equal(x_ref.grad) @slowTest @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, 'Testing differentiable graph') @@ -5746,7 +5746,7 @@ a") x = torch.zeros(3, 4, dtype=torch.long) graph = _propagate_shapes(fn.graph, (x,), False) default = torch.get_default_dtype() - if(default == torch.float): + if default == torch.float: FileCheck().check('Float(*, *, requires_grad=0, device=cpu) = aten::add').run(graph) else: FileCheck().check('Double(*, *, requires_grad=0, device=cpu) = aten::add').run(graph) @@ -6169,7 +6169,7 @@ a") with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[Tensor, Tensor\]"): # noqa: W605 @torch.jit.script def test_mult(x, y): - return not(x, y) + return not (x, y) def test_cast_int(x): # type: (int) -> int @@ -8557,15 +8557,15 @@ dedent """ fb.run("22 1 22") def _dtype_to_jit_name(self, dtype): - if(dtype == torch.float32): + if dtype == torch.float32: return "Float" - if(dtype == torch.float64): + if dtype == torch.float64: return "Double" - if(dtype == torch.int64): + if dtype == torch.int64: return "Long" - if(dtype == torch.int32): + if dtype == torch.int32: return "Int" - if(dtype == torch.bool): + if dtype == torch.bool: return "Bool" raise RuntimeError('dtype not handled') @@ -8595,7 +8595,7 @@ dedent """ for dtype in (dtypes + [None]): for tensor_type in dtypes: # a couple of ops aren't implemented for non-floating types - if(not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point)): + if not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point): if op in ['mean', 'softmax', 'log_softmax']: continue return_line = f"torch.tensor({tensor_data}, dtype={tensor_type}).{op}({str_args}dtype={dtype})" @@ -12837,7 +12837,7 @@ dedent """ for i, offset in enumerate(parsed_serialized_offsets): data = reader.get_record(str(offset)) - assert(data == buffers[i]) + assert data == buffers[i] def test_file_reader_no_memory_leak(self): num_iters = 10000 diff --git a/test/test_mps.py b/test/test_mps.py index 639746379fa..fed7a3d91e0 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -1032,7 +1032,7 @@ class MpsMemoryLeakCheck: if driver_mem_allocated > self.driver_before: driver_discrepancy = True - if not(caching_allocator_discrepancy or driver_discrepancy): + if not (caching_allocator_discrepancy or driver_discrepancy): # Leak was false positive, exit loop discrepancy_detected = False break @@ -4122,7 +4122,7 @@ class TestMPS(TestCaseMPS): def helper(dtype): # Test with axis -1 cpu_x = None - if(dtype == torch.float32): + if dtype == torch.float32: cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32) else: cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32) @@ -4157,7 +4157,7 @@ class TestMPS(TestCaseMPS): def helper(dtype): # Test with axis -1 cpu_x = None - if(dtype == torch.float32): + if dtype == torch.float32: cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32) else: cpu_x = 
torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32) diff --git a/test/test_nn.py b/test/test_nn.py index 81fd93bbcd1..12d3fa91221 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -31,7 +31,7 @@ from torch.nn.parallel._functions import Broadcast from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes, floating_types from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \ TEST_NUMPY, TEST_SCIPY, TEST_WITH_CROSSREF, TEST_WITH_ROCM, \ - download_file, get_function_arglist, load_tests, skipIfMps,\ + download_file, get_function_arglist, load_tests, skipIfMps, \ IS_PPC, \ parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \ skipIfTorchDynamo, IS_WINDOWS, gcIfJetson, set_default_dtype @@ -6865,7 +6865,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""") elif weight_layout == torch.sparse_coo: module.weight = nn.Parameter(module.weight.to_sparse_coo()) else: - assert(0) + raise AssertionError() inp = torch.randn(4, requires_grad=True, device=device) res = module(inp) diff --git a/test/test_ops.py b/test/test_ops.py index 80422c87a62..fc286d43059 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -108,7 +108,7 @@ _ref_test_ops = tuple( ) def reduction_dtype_filter(op): - if(not isinstance(op, ReductionPythonRefInfo) or not op.supports_out + if (not isinstance(op, ReductionPythonRefInfo) or not op.supports_out or torch.int16 not in op.dtypes): return False diff --git a/test/test_overrides.py b/test/test_overrides.py index f9988197a08..1a065d99f68 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -217,7 +217,7 @@ class SubTensor(torch.Tensor): """ @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): - if(kwargs is None): + if kwargs is None: kwargs = {} if func not in HANDLED_FUNCTIONS_SUB: @@ -368,7 +368,7 @@ class TensorLike: """ @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): - if(kwargs is None): + if kwargs is None: kwargs = {} if func not in HANDLED_FUNCTIONS_TENSOR_LIKE: diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 6e358034b87..7b62b387b2b 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -465,7 +465,7 @@ class TestPythonRegistration(TestCase): # check rest of functional_result is the mutated args mutated_args = [maybe_mutated_arg for maybe_mutated_arg, arg in zip(cloned_args, args) - if not(maybe_mutated_arg is not None and arg is not None and torch.allclose(maybe_mutated_arg, arg))] + if not (maybe_mutated_arg is not None and arg is not None and torch.allclose(maybe_mutated_arg, arg))] self.assertEqual(flat_functional_result[len(flat_mutable_result):], mutated_args) # check that functionalization kernel was indeed registered diff --git a/test/test_reductions.py b/test/test_reductions.py index 272769808d6..45bff791594 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -1353,7 +1353,7 @@ class TestReductions(TestCase): def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn, memory_format, compare_data=True, default_is_preserve=False): - assert(memory_format == torch.channels_last or memory_format == torch.channels_last_3d) + assert memory_format == torch.channels_last or memory_format == torch.channels_last_3d # xc is a channels last tensor xc = input_generator_fn(device) diff --git a/test/test_sparse.py b/test/test_sparse.py index fc7c2e54b6e..737811abae2 
100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -3988,11 +3988,11 @@ class TestSparse(TestSparseBase): yield (make_diags((1, 3)), make_offsets([0]), (3, 2, 3)), "Output shape must be 2d" yield (make_diags((2, 3)), make_offsets([[1, 2], [0, 3]]), (3, 3)), "Offsets must be scalar or vector" yield (make_diags((3, 2, 3)), make_offsets([0, 1, 2]), (4, 4)), "Diagonals must be vector or matrix" - yield (make_diags((3, 3)), make_offsets([-1, 0]), (3, 3)),\ + yield (make_diags((3, 3)), make_offsets([-1, 0]), (3, 3)), \ r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)" - yield (make_diags((5,)), make_offsets([0, 1, 2, 3, 4]), (3, 3)),\ + yield (make_diags((5,)), make_offsets([0, 1, 2, 3, 4]), (3, 3)), \ r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)" - yield (make_diags((2, 2)), make_offsets([-1, 0]), (2, 3), torch.strided),\ + yield (make_diags((2, 2)), make_offsets([-1, 0]), (2, 3), torch.strided), \ r"Only output layouts \(\w+, \w+, \w+\) are supported, got \w+" yield (make_diags((2, 5)), make_offsets([0, 0]), (5, 5)), "Offset tensor contains duplicate values" yield (make_diags((1, 5)), make_offsets([0]).to(torch.int32), (5, 5)), r"Offset Tensor must have dtype Long but got \w+" diff --git a/test/test_torch.py b/test/test_torch.py index eb34da4663b..529e32c5aeb 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -5210,7 +5210,7 @@ else: def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn, memory_format, compare_data=True, default_is_preserve=False): - assert(memory_format == torch.channels_last or memory_format == torch.channels_last_3d) + assert memory_format == torch.channels_last or memory_format == torch.channels_last_3d # xc is a channels last tensor xc = input_generator_fn(device) diff --git a/torch/__init__.py b/torch/__init__.py index 281dc31204a..98c9a43511c 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -191,7 +191,7 @@ def _load_global_deps() -> None: 'nccl': 'libnccl.so.*[0-9]', 'nvtx': 'libnvToolsExt.so.*[0-9]', } - is_cuda_lib_err = [lib for lib in cuda_libs.values() if(lib.split('.')[0] in err.args[0])] + is_cuda_lib_err = [lib for lib in cuda_libs.values() if lib.split('.')[0] in err.args[0]] if not is_cuda_lib_err: raise err for lib_folder, lib_name in cuda_libs.items(): diff --git a/torch/_functorch/benchmark_utils.py b/torch/_functorch/benchmark_utils.py index a0837010662..78111275008 100644 --- a/torch/_functorch/benchmark_utils.py +++ b/torch/_functorch/benchmark_utils.py @@ -75,7 +75,7 @@ def is_gpu_compute_event(event): def get_sorted_gpu_events(events): sorted_gpu_events = [] for event in events: - if(not is_gpu_compute_event(event)): + if not is_gpu_compute_event(event): continue sorted_gpu_events.append(event) return sorted(sorted_gpu_events, key=lambda x: x["ts"]) @@ -102,7 +102,7 @@ def get_sorted_gpu_mm_conv_events(events): gpu_events = get_sorted_gpu_events(events) sorted_events = [] for event in gpu_events: - if(not is_mm_conv_event(event)): + if not is_mm_conv_event(event): continue sorted_events.append(event) return sorted_events diff --git a/torch/_higher_order_ops/out_dtype.py b/torch/_higher_order_ops/out_dtype.py index 3c53f1b1375..bfce30bb06d 100644 --- a/torch/_higher_order_ops/out_dtype.py +++ b/torch/_higher_order_ops/out_dtype.py @@ -55,7 +55,7 @@ class OutDtypeOperator(HigherOrderOperator): raise ValueError("out_dtype's first argument must be an OpOverload") if op._schema.is_mutable: raise ValueError("out_dtype's first argument needs 
to be a functional operator") - if not( + if not ( len(op._schema.returns) == 1 and isinstance(op._schema.returns[0].type, torch.TensorType) ): diff --git a/torch/ao/nn/quantized/modules/functional_modules.py b/torch/ao/nn/quantized/modules/functional_modules.py index 56778c59217..96408457a44 100644 --- a/torch/ao/nn/quantized/modules/functional_modules.py +++ b/torch/ao/nn/quantized/modules/functional_modules.py @@ -240,7 +240,7 @@ class QFunctional(torch.nn.Module): @classmethod def from_float(cls, mod): - assert type(mod) == FloatFunctional,\ + assert type(mod) == FloatFunctional, \ "QFunctional.from_float expects an instance of FloatFunctional" scale, zero_point = mod.activation_post_process.calculate_qparams() # type: ignore[operator] new_mod = QFunctional() diff --git a/torch/ao/nn/sparse/quantized/utils.py b/torch/ao/nn/sparse/quantized/utils.py index b7afcf186de..3d934f57857 100644 --- a/torch/ao/nn/sparse/quantized/utils.py +++ b/torch/ao/nn/sparse/quantized/utils.py @@ -22,7 +22,7 @@ class LinearBlockSparsePattern: prev_col_block_size = 4 def __init__(self, row_block_size=1, col_block_size=4): - assert(_is_valid_linear_block_sparse_pattern(row_block_size, col_block_size)) + assert _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size) LinearBlockSparsePattern.rlock.acquire() LinearBlockSparsePattern.prev_row_block_size = LinearBlockSparsePattern.row_block_size LinearBlockSparsePattern.prev_col_block_size = LinearBlockSparsePattern.col_block_size diff --git a/torch/ao/quantization/backend_config/onednn.py b/torch/ao/quantization/backend_config/onednn.py index 8c14637ae3d..6eab945f7d7 100644 --- a/torch/ao/quantization/backend_config/onednn.py +++ b/torch/ao/quantization/backend_config/onednn.py @@ -85,7 +85,7 @@ def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu): >>> lr = nn.LeakyReLU(0.01) >>> m2 = _fuse_linear_bn_leaky_relu(m1, b1, lr) """ - assert(linear.training == bn.training and bn.training == leaky_relu.training),\ + assert linear.training == bn.training and bn.training == leaky_relu.training, \ "Linear, BN and LeakyReLU all must be in the same mode (train or eval)." 
if is_qat: diff --git a/torch/ao/quantization/experimental/observer.py b/torch/ao/quantization/experimental/observer.py index 85b29c84007..76a63815bdc 100644 --- a/torch/ao/quantization/experimental/observer.py +++ b/torch/ao/quantization/experimental/observer.py @@ -58,8 +58,8 @@ class APoTObserver(ObserverBase): alpha = torch.max(-self.min_val, self.max_val) # check for valid inputs of b, k - assert(self.k and self.k != 0) - assert(self.b % self.k == 0) + assert self.k and self.k != 0 + assert self.b % self.k == 0 # compute n and store as member variable self.n = self.b // self.k diff --git a/torch/ao/quantization/fake_quantize.py b/torch/ao/quantization/fake_quantize.py index d616afa557d..e0a169956fd 100644 --- a/torch/ao/quantization/fake_quantize.py +++ b/torch/ao/quantization/fake_quantize.py @@ -274,7 +274,7 @@ class FixedQParamsFakeQuantize(FakeQuantize): # TODO: rename observer to observer_ctr def __init__(self, observer): super().__init__(observer=observer) - assert type(self.activation_post_process) == FixedQParamsObserver,\ + assert type(self.activation_post_process) == FixedQParamsObserver, \ f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}" self._observer_ctr = observer self.scale = self.activation_post_process.scale @@ -322,7 +322,7 @@ class FusedMovingAvgObsFakeQuantize(FakeQuantize): **observer_kwargs: Any ) -> None: super().__init__(observer, quant_min, quant_max, **observer_kwargs) - assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)),\ + assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)), \ "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver" self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long)) self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long)) diff --git a/torch/ao/quantization/fuser_method_mappings.py b/torch/ao/quantization/fuser_method_mappings.py index 7381c057141..a160346d8a8 100644 --- a/torch/ao/quantization/fuser_method_mappings.py +++ b/torch/ao/quantization/fuser_method_mappings.py @@ -31,7 +31,7 @@ def fuse_conv_bn(is_qat, conv, bn): >>> # xdoctest: +SKIP >>> m2 = fuse_conv_bn(m1, b1) """ - assert(conv.training == bn.training),\ + assert conv.training == bn.training, \ "Conv and BN both must be in the same mode (train or eval)." fused_module_class_map = { @@ -71,7 +71,7 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu): >>> # xdoctest: +SKIP >>> m2 = fuse_conv_bn_relu(m1, b1, r1) """ - assert(conv.training == bn.training == relu.training),\ + assert conv.training == bn.training == relu.training, \ "Conv and BN both must be in the same mode (train or eval)." fused_module : Optional[Type[nn.Sequential]] = None if is_qat: @@ -118,14 +118,14 @@ def fuse_linear_bn(is_qat, linear, bn): >>> # xdoctest: +SKIP >>> m2 = fuse_linear_bn(m1, b1) """ - assert(linear.training == bn.training),\ + assert linear.training == bn.training, \ "Linear and BN both must be in the same mode (train or eval)." 
if is_qat: - assert bn.num_features == linear.out_features,\ + assert bn.num_features == linear.out_features, \ "Output features of Linear must match num_features of BatchNorm1d" assert bn.affine, "Only support fusing BatchNorm1d with affine set to True" - assert bn.track_running_stats,\ + assert bn.track_running_stats, \ "Only support fusing BatchNorm1d with tracking_running_stats set to True" return nni.LinearBn1d(linear, bn) else: @@ -147,7 +147,7 @@ def fuse_convtranspose_bn(is_qat, convt, bn): >>> # xdoctest: +SKIP >>> m2 = fuse_convtranspose_bn(m1, b1) """ - assert(convt.training == bn.training),\ + assert convt.training == bn.training, \ "ConvTranspose and BN both must be in the same mode (train or eval)." if is_qat: diff --git a/torch/ao/quantization/fx/_equalize.py b/torch/ao/quantization/fx/_equalize.py index 883ddf19682..55bcb52576b 100644 --- a/torch/ao/quantization/fx/_equalize.py +++ b/torch/ao/quantization/fx/_equalize.py @@ -291,24 +291,24 @@ def get_op_node_and_weight_eq_obs( op_node = user break - assert(op_node is not None) + assert op_node is not None if op_node.op == 'call_module': # If the op_node is a nn.Linear layer, then it must have a # WeightEqualizationObserver configuration maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(model, "equalization_node_name_to_qconfig") assert maybe_equalization_node_name_to_config is not None equalization_node_name_to_qconfig: Dict[str, Any] = maybe_equalization_node_name_to_config # type: ignore[assignment] - assert(equalization_node_name_to_qconfig.get(op_node.name, None) is not None) + assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None weight_eq_obs = equalization_node_name_to_qconfig.get(op_node.name, None).weight() - assert(isinstance(weight_eq_obs, _WeightEqualizationObserver)) + assert isinstance(weight_eq_obs, _WeightEqualizationObserver) return op_node, weight_eq_obs elif op_node.op == 'call_function': weight_node = maybe_get_weight_eq_obs_node(op_node, modules) if weight_node is not None: weight_eq_obs = modules[str(weight_node.target)] - assert(isinstance(weight_eq_obs, _WeightEqualizationObserver)) + assert isinstance(weight_eq_obs, _WeightEqualizationObserver) return op_node, weight_eq_obs return None, None @@ -316,10 +316,10 @@ def get_op_node_and_weight_eq_obs( def maybe_get_weight_eq_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> Optional[Node]: """ Gets the weight equalization observer node if it exists. """ - assert(op_node.op == 'call_function') + assert op_node.op == 'call_function' for node_arg in op_node.args: if node_arg_is_weight(op_node, node_arg): - assert(isinstance(node_arg, Node) and node_arg.op == 'call_module' and + assert (isinstance(node_arg, Node) and node_arg.op == 'call_module' and isinstance(modules[str(node_arg.target)], _WeightEqualizationObserver)) return node_arg return None @@ -342,7 +342,7 @@ def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Op the following equalization observer for linear2. 
""" - assert(node_supports_equalization(node, modules)) + assert node_supports_equalization(node, modules) # Locate the following nn.ReLU or F.relu node if it exists maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU) @@ -364,7 +364,7 @@ def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Op return None maybe_eq_obs = modules[str(maybe_eq_obs_node)] - assert(isinstance(maybe_eq_obs, _InputEqualizationObserver)) + assert isinstance(maybe_eq_obs, _InputEqualizationObserver) return maybe_eq_obs def maybe_get_next_equalization_scale(node: Node, modules: Dict[str, nn.Module]) -> Optional[torch.Tensor]: @@ -389,10 +389,10 @@ def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None: equalization observer """ input_eq_obs = modules[str(node.target)] - assert(isinstance(input_eq_obs, _InputEqualizationObserver)) + assert isinstance(input_eq_obs, _InputEqualizationObserver) input_quant_obs_node = node.args[0] - assert(isinstance(input_quant_obs_node, Node)) + assert isinstance(input_quant_obs_node, Node) input_quant_obs = modules[str(input_quant_obs_node.target)] if not isinstance(input_quant_obs, ObserverBase): @@ -426,12 +426,12 @@ def scale_weight_node( op_module = modules[str(node.target)][0] # type: ignore[index] else: op_module = modules[str(node.target)] - assert(nn_module_supports_equalization(op_module) or custom_module_supports_equalization(op_module)) + assert nn_module_supports_equalization(op_module) or custom_module_supports_equalization(op_module) # Scale the weights for input-weight equalization # If the following layer needs to be equalized then we will multiply its scale weight = op_module.weight - assert(isinstance(weight, torch.Tensor)) + assert isinstance(weight, torch.Tensor) # Scale the weights by the reciprocal of the equalization scale # Reshape the equalization scale so that we can multiply it to the weight along axis=1 @@ -453,7 +453,7 @@ def scale_weight_node( bias = op_module.bias if bias is None: return - assert(isinstance(bias, torch.Tensor)) + assert isinstance(bias, torch.Tensor) # Reshape the equalization scale so that we can multiply it element-wise to the bias next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias) @@ -487,14 +487,14 @@ def scale_weight_functional( weight_quant_obs_node = weight_eq_obs_node.args[0] if weight_quant_obs_node is None: return - assert(isinstance(weight_quant_obs_node, Node) and + assert (isinstance(weight_quant_obs_node, Node) and isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)) # Get the get_attr(weight) node weight_node = weight_quant_obs_node.args[0] if weight_node is None: return - assert(isinstance(weight_node, Node) and weight_node.op == 'get_attr') + assert isinstance(weight_node, Node) and weight_node.op == 'get_attr' weight_parent_name, weight_name = _parent_name(weight_node.target) weight = getattr(modules[weight_parent_name], weight_name) @@ -515,7 +515,7 @@ def scale_weight_functional( scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped) setattr(modules[weight_parent_name], weight_name, scaled_weight) - assert(torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight)) + assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight) # Multiply the bias element wise by the next equalization scale bias_node = None @@ -546,10 +546,10 @@ def clear_weight_quant_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> weight_quant_obs_node = weight_eq_obs_node.args[0] if 
weight_quant_obs_node is None: return - assert(isinstance(weight_quant_obs_node, Node)) + assert isinstance(weight_quant_obs_node, Node) weight_quant_obs = modules[str(weight_quant_obs_node.target)] - assert(isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)) + assert isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase) weight_quant_obs.reset_min_max_vals() # type: ignore[operator] def remove_node(model: GraphModule, node: Node, prev_node: Node): @@ -578,7 +578,7 @@ def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module for node in model.graph.nodes: if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver): input_eq_obs = modules[node.target] - assert(isinstance(input_eq_obs, _InputEqualizationObserver)) + assert isinstance(input_eq_obs, _InputEqualizationObserver) op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules) if op_node is None or weight_eq_obs is None: @@ -589,7 +589,7 @@ def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module # been created if fused_module_supports_equalization(modules[str(op_node.target)]): module = modules[str(op_node.target)][0] # type: ignore[index] - assert(nn_module_supports_equalization(module)) + assert nn_module_supports_equalization(module) weight_eq_obs(module.weight) else: weight_eq_obs(modules[str(op_node.target)].weight) @@ -696,7 +696,7 @@ def convert_eq_obs( elif weight_eq_obs_dict.get(node.name, None) is not None: weight_eq_obs = weight_eq_obs_dict.get(node.name) - assert(isinstance(weight_eq_obs, _WeightEqualizationObserver)) + assert isinstance(weight_eq_obs, _WeightEqualizationObserver) equalization_scale = weight_eq_obs.equalization_scale if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1): @@ -712,7 +712,7 @@ def convert_eq_obs( weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules) if weight_eq_obs_node is None: return - assert(isinstance(modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver)) + assert isinstance(modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver) # Clear the quantization observer's min/max values so that they # can get updated later based on the new scale values diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py index b562a9e8398..728506037b5 100644 --- a/torch/ao/quantization/fx/_lower_to_native_backend.py +++ b/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -487,14 +487,14 @@ def _match_static_pattern( return SKIP_LOWERING_VALUE q_node = node ref_node = q_node.args[0] - assert(isinstance(ref_node, Node)) + assert isinstance(ref_node, Node) # Handle cases where the node is wrapped in a ReLU if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or\ (ref_node.op == "call_module" and type(_get_module(ref_node, modules)) == nn.ReLU): relu_node = ref_node ref_node = relu_node.args[0] - assert(isinstance(ref_node, Node)) + assert isinstance(ref_node, Node) else: relu_node = None if should_skip_lowering(ref_node, qconfig_map): @@ -515,7 +515,7 @@ def _match_static_pattern( # (2) There must be at least one dequantize node matched_dequantize = False for i in dequantize_node_arg_indices: - assert i < len(ref_node.args),\ + assert i < len(ref_node.args), \ f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}" arg = ref_node.args[i] if is_dequantize_node(arg): @@ -557,7 +557,7 @@ def 
_match_static_pattern_with_two_inputs( return SKIP_LOWERING_VALUE q_node = node ref_node = q_node.args[0] - assert(isinstance(ref_node, Node)) + assert isinstance(ref_node, Node) if should_skip_lowering(ref_node, qconfig_map): return SKIP_LOWERING_VALUE @@ -599,13 +599,13 @@ def _lower_static_weighted_ref_module( n, modules, qconfig_map, matching_modules, dequantize_node_arg_indices=[0]) # type: ignore[arg-type] if q_node is None: continue - assert(ref_node is not None) + assert ref_node is not None (_, scale_node, zero_point_node, _) = q_node.args ref_module = _get_module(ref_node, modules) ref_class = type(ref_module) - assert(isinstance(scale_node, Node)) - assert(isinstance(zero_point_node, Node)) - assert(issubclass(ref_class, nn.Module)) + assert isinstance(scale_node, Node) + assert isinstance(zero_point_node, Node) + assert issubclass(ref_class, nn.Module) # Step 1: Change this pattern to use the corresponding quantized module # For fused modules, we also check whether the inner module is a reference module @@ -624,9 +624,9 @@ def _lower_static_weighted_ref_module( setattr(modules[parent_name], module_name, q_module) # Step 2: Reroute around dq_node, and remove q_node and its args - assert(len(ref_node.args) == 1) + assert len(ref_node.args) == 1 dq_node = ref_node.args[0] - assert(isinstance(dq_node, Node)) + assert isinstance(dq_node, Node) ref_node.replace_input_with(dq_node, dq_node.args[0]) q_node.replace_all_uses_with(ref_node) model.graph.erase_node(q_node) @@ -655,13 +655,13 @@ def _lower_static_weighted_ref_module_with_two_inputs( n, modules, qconfig_map, matching_modules) # type: ignore[arg-type] if q_node is None: continue - assert(ref_node is not None) + assert ref_node is not None (_, scale_node, zero_point_node, _) = q_node.args ref_module = _get_module(ref_node, modules) ref_class = type(ref_module) - assert(isinstance(scale_node, Node)) - assert(isinstance(zero_point_node, Node)) - assert(issubclass(ref_class, nn.Module)) + assert isinstance(scale_node, Node) + assert isinstance(zero_point_node, Node) + assert issubclass(ref_class, nn.Module) # Step 1: Change this pattern to use the corresponding quantized module # For fused modules, we also check whether the inner module is a reference module @@ -680,12 +680,12 @@ def _lower_static_weighted_ref_module_with_two_inputs( setattr(modules[parent_name], module_name, q_module) # Step 2: Reroute around dq_node, and remove q_node and its args - assert(len(ref_node.args) == 2) + assert len(ref_node.args) == 2 for arg in ref_node.args: if not is_dequantize_node(arg): continue dq_node = arg - assert(isinstance(dq_node, Node)) + assert isinstance(dq_node, Node) ref_node.replace_input_with(dq_node, dq_node.args[0]) q_node.replace_all_uses_with(ref_node) @@ -778,14 +778,14 @@ def _lower_static_weighted_ref_functional( n, modules, qconfig_map, matching_ops, dequantize_node_arg_indices=[0, 1]) if q_node is None: continue - assert(func_node is not None) + assert func_node is not None (_, output_scale_node, output_zp_node, _) = q_node.args (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args - assert(isinstance(output_zp_node, Node)) - assert(isinstance(input_dq_node, Node)) - assert(isinstance(weight_dq_node, Node)) + assert isinstance(output_zp_node, Node) + assert isinstance(input_dq_node, Node) + assert isinstance(weight_dq_node, Node) quantized_weight = weight_dq_node.args[0] - assert(isinstance(quantized_weight, Node)) + assert isinstance(quantized_weight, Node) if quantized_weight.op != "call_function" or\ 
                quantized_weight.target not in (torch.quantize_per_tensor, torch.quantize_per_channel):
             continue
@@ -966,7 +966,7 @@ def _lower_quantized_binary_op(
             n, modules, qconfig_map, binary_ops_to_lower, dequantize_node_arg_indices=[0, 1])
         if q_node is None:
             continue
-        assert(bop_node is not None)
+        assert bop_node is not None
         (_, scale_node, zero_point_node, _) = q_node.args

         # Step 1: Remove dequant nodes
@@ -975,11 +975,11 @@ def _lower_quantized_binary_op(
             if not is_dequantize_node(arg):
                 continue
             dq_node = arg
-            assert(isinstance(dq_node, Node))
+            assert isinstance(dq_node, Node)
             dn_input = dq_node.args[0]
             bop_node.replace_input_with(dq_node, dn_input)
             num_dq_nodes += 1
-        assert(num_dq_nodes > 0)
+        assert num_dq_nodes > 0

         # Step 2: Swap binary op to quantized binary op
         assert bop_node.target in QBIN_OP_MAPPING
diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py
index 73779fb6467..11fb2c9a892 100644
--- a/torch/ao/quantization/fx/convert.py
+++ b/torch/ao/quantization/fx/convert.py
@@ -956,7 +956,7 @@ def convert(
             "in a future version. Please pass in a QConfigMapping instead.")
         qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
     qconfig_mapping = copy.deepcopy(qconfig_mapping)
-    assert(qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping))
+    assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)

     if isinstance(backend_config, Dict):
         warnings.warn(
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index 59c9bec8c57..a8db114b2b4 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -1585,7 +1585,7 @@ def insert_observers_for_model(
                     # should resolve this inconsistency by inserting DeQuantStubs for all custom
                     # modules, not just for LSTM.
                     _insert_dequant_stubs_for_custom_module_lstm_output(node, model, named_modules, model.graph)
-                    if(node.target not in custom_module_names_already_swapped):
+                    if node.target not in custom_module_names_already_swapped:
                         custom_module_names_already_swapped.add(node.target)
                         _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
                 else:
@@ -1627,7 +1627,7 @@ def insert_observers_for_model(
                     _remove_output_observer(node, model, named_modules)

                 if qhandler is not None and qhandler.is_custom_module():
-                    if(node.target not in custom_module_names_already_swapped):
+                    if node.target not in custom_module_names_already_swapped:
                         custom_module_names_already_swapped.add(node.target)
                         _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)

@@ -1771,8 +1771,8 @@ def prepare(
             "in a future version. Please pass in a BackendConfig instead.")
         backend_config = BackendConfig.from_dict(backend_config)

-    assert(isinstance(qconfig_mapping, QConfigMapping))
-    assert(isinstance(_equalization_config, QConfigMapping))
+    assert isinstance(qconfig_mapping, QConfigMapping)
+    assert isinstance(_equalization_config, QConfigMapping)
     qconfig_mapping = copy.deepcopy(qconfig_mapping)
     _equalization_config = copy.deepcopy(_equalization_config)
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py
index fa1eded53b5..0a7999a4c26 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py
@@ -26,7 +26,7 @@ def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None):
     result = True
     st1 = args[0]
     st2 = args[1]
-    if not(isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)):
+    if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)):
         raise TypeError(f'Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor')

     # Verify same PG
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
index d866117ba14..439e3d6195e 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
@@ -403,7 +403,7 @@ def generate_calc_product(constraint, counter):
     for p in all_possibilities:
         p = list(p)
         # this tells us there is a dynamic variable
-        contains_dyn = not(all(constraint.op == op_neq for constraint in p))
+        contains_dyn = not all(constraint.op == op_neq for constraint in p)
         if contains_dyn:
             mid_var = [Dyn]
             total_constraints = lhs + mid_var + rhs
diff --git a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
index c5c6436c2bc..15af0241ec5 100644
--- a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
+++ b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
@@ -311,7 +311,7 @@ try:
             negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint)

-            return z3.And([transformed, transformed_condition_constraint]),\
+            return z3.And([transformed, transformed_condition_constraint]), \
                 z3.And([transformed, negation_transformed_condition_constraint])

diff --git a/torch/fx/experimental/optimization.py b/torch/fx/experimental/optimization.py
index fc7cbbb22fa..e62b2836fff 100644
--- a/torch/fx/experimental/optimization.py
+++ b/torch/fx/experimental/optimization.py
@@ -42,7 +42,7 @@ def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict


 def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
-    assert(isinstance(node.target, str))
+    assert isinstance(node.target, str)
     parent_name, name = _parent_name(node.target)
     modules[node.target] = new_module
     setattr(modules[parent_name], name, new_module)
@@ -133,11 +133,11 @@ def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
     old_modules: Dict[nn.Module, nn.Module] = {}
     for node in nodes:
         if node.op == 'call_module':
-            assert(isinstance(node.target, str))
+            assert isinstance(node.target, str)
             cur_module = modules[node.target]
             if type(cur_module) in mkldnn_map:
                 new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
-                assert(isinstance(new_module, nn.Module))
+                assert isinstance(new_module, nn.Module)
                 old_modules[new_module] = copy.deepcopy(cur_module)
                 replace_node_module(node, modules, new_module)
     return old_modules
@@ -149,7 +149,7 @@ def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modul
     """
     for node in nodes:
         if node.op == 'call_module':
-            assert(isinstance(node.target, str))
+            assert (isinstance(node.target, str))
             cur_module = modules[node.target]
             if cur_module in old_modules:
                 replace_node_module(node, modules, old_modules[cur_module])
@@ -220,7 +220,7 @@ class UnionFind:
         par = self.parent[v]
         if v == par:
             return v
-        assert(par is not None)
+        assert par is not None
         self.parent[v] = self.find(par)
         return cast(int, self.parent[v])

@@ -294,8 +294,8 @@ def optimize_for_inference(
                 supports_mkldnn = MklSupport.YES
                 sample_parameter = next(cur_module.parameters(), None)
                 if sample_parameter is not None:
-                    assert(sample_parameter.dtype == torch.float), "this pass is only for torch.float modules"
-                    assert(sample_parameter.device == torch.device('cpu')), "this pass is only for CPU modules"
+                    assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules"
+                    assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules"
         elif node.op == 'call_function':
             if node.target in mkldnn_supported:
                 supports_mkldnn = MklSupport.YES
@@ -360,14 +360,14 @@ def optimize_for_inference(
             node.start_color = cur_idx
             uf.make_set(cur_idx)
         elif node.op == 'call_method' and node.target == 'to_dense':
-            assert(get_color(node.args[0]) is not None)
+            assert get_color(node.args[0]) is not None
             node.end_color = get_color(node.args[0])
         else:
             cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None]

             if len(cur_colors) == 0:
                 continue
-            assert(not any(i is None for i in cur_colors))
+            assert not any(i is None for i in cur_colors)
             cur_colors = sorted(cur_colors)
             node.color = cur_colors[0]
             for other_color in cur_colors[1:]:
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 8f40b277b5d..5c02aa7cfa2 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -1359,7 +1359,7 @@ class DimConstraints:
         self.raise_inconsistencies()
         # as long as there are symbols with equalities, solve for them
         # NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
-        while(self._symbols_with_equalities):
+        while self._symbols_with_equalities:
             s = self._symbols_with_equalities.pop()
             exprs = self._univariate_inequalities.pop(s)
             solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 89e61f1444d..ffc731d62d6 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -670,7 +670,7 @@ class _PyTreeCodeGen(CodeGen):
             return out
         if not isinstance(out, (list, tuple)):
             out = [out]
-        assert(self.pytree_info.out_spec is not None)
+        assert self.pytree_info.out_spec is not None
         return pytree.tree_unflatten(out, self.pytree_info.out_spec)

     def gen_fn_def(self, free_vars, maybe_return_annotation):
diff --git a/torch/fx/node.py b/torch/fx/node.py
index 89969c1e0d5..0c7b8fe9865 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -82,7 +82,7 @@ def _type_repr(obj):
             return obj.__qualname__
         return f'{obj.__module__}.{obj.__qualname__}'
     if obj is ...:
-        return('...')
+        return '...'
     if isinstance(obj, types.FunctionType):
         return obj.__name__
     return repr(obj)
diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py
index 3fcd096b80e..3790acd3432 100644
--- a/torch/fx/passes/net_min_base.py
+++ b/torch/fx/passes/net_min_base.py
@@ -698,7 +698,7 @@ class _MinimizerBase:
         if self.settings.traverse_method == "accumulate":
             return self._accumulate_traverse(nodes)

-        if(self.settings.traverse_method == "skip"):
+        if self.settings.traverse_method == "skip":
             if (skip_nodes is None):
                 raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.")
             return self._skip_traverse(nodes, skip_nodes)
diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py
index d175d6d1de8..2a798426ed1 100644
--- a/torch/fx/passes/split_module.py
+++ b/torch/fx/passes/split_module.py
@@ -281,7 +281,7 @@ def split_module(

         if assert_monotonically_increasing:
             pid = split_callback(node)
-            assert highest_partition <= pid,\
+            assert highest_partition <= pid, \
                 ("autocast or set_grad_enabled require monotonically increasing partitions:"
                  f"highest: {highest_partition}, this node's: {pid}")
             highest_partition = pid
diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py
index 66b785b8b03..1054b195e82 100644
--- a/torch/fx/proxy.py
+++ b/torch/fx/proxy.py
@@ -510,7 +510,7 @@ class ParameterProxy(Proxy):
     """
     def __init__(self, tracer: TracerBase, node: Node, name, param):
         super().__init__(node, tracer)
-        assert(isinstance(param, torch.nn.Parameter))
+        assert isinstance(param, torch.nn.Parameter)
         self.param = param
         self.name = name

diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py
index 3062e9a9a00..400d2b1bef6 100644
--- a/torch/jit/frontend.py
+++ b/torch/jit/frontend.py
@@ -817,7 +817,7 @@ class StmtBuilder(Builder):
         if is_torch_jit_ignore_context_manager(stmt):
             if not _IS_ASTUNPARSE_INSTALLED:
                 raise RuntimeError(
-                    "torch.jit._IgnoreContextManager requires installing Python library `astunparse`,\
+                    "torch.jit._IgnoreContextManager requires installing Python library `astunparse`, \
                     please install it in your Python environment"
                 )
             assign_ast = build_ignore_context_manager(ctx, stmt)
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index a45b4fa627a..c8c05660f15 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -2571,7 +2571,7 @@ new_module_tests = [
             .dim_feedforward(8)
             .dropout(0.0)
             .activation(torch::kReLU)''',
-        input_fn=lambda:(torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)),
+        input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)),
         check_gradgrad=False,
         desc='multilayer_coder',
         with_tf32=True,
diff --git a/torch/testing/_internal/common_quantized.py b/torch/testing/_internal/common_quantized.py
index 0db312f2c20..bd02761d3a2 100644
--- a/torch/testing/_internal/common_quantized.py
+++ b/torch/testing/_internal/common_quantized.py
@@ -125,7 +125,7 @@ def _snr(x, x_hat):
         signal, noise, SNR(in dB): Either floats or a nested list of floats
     """
     if isinstance(x, (list, tuple)):
-        assert(len(x) == len(x_hat))
+        assert len(x) == len(x_hat)
         res = []
         for idx in range(len(x)):
             res.append(_snr(x[idx], x_hat[idx]))
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 229142bfa51..ee67c343487 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1352,7 +1352,7 @@ def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
                     fn(*args, **kwargs)
             return wrapper

-        assert(isinstance(fn, type))
+        assert isinstance(fn, type)
         if TEST_WITH_TORCHDYNAMO:  # noqa: F821
             fn.__unittest_skip__ = True
             fn.__unittest_skip_why__ = msg
@@ -1374,7 +1374,7 @@ def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",
                     fn(*args, **kwargs)
             return wrapper

-        assert(isinstance(fn, type))
+        assert isinstance(fn, type)
         if condition:
             fn.__unittest_skip__ = True
             fn.__unittest_skip_why__ = msg
@@ -1440,7 +1440,7 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe
                     fn(*args, **kwargs)
             return wrapper

-        assert(isinstance(fn, type))
+        assert isinstance(fn, type)
         if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
             fn.__unittest_skip__ = True
             fn.__unittest_skip_why__ = msg
@@ -2080,7 +2080,7 @@ class CudaMemoryLeakCheck:
                 if driver_mem_allocated > self.driver_befores[i]:
                     driver_discrepancy = True

-                if not(caching_allocator_discrepancy or driver_discrepancy):
+                if not (caching_allocator_discrepancy or driver_discrepancy):
                     # Leak was false positive, exit loop
                     discrepancy_detected = False
                     break
diff --git a/torch/testing/_internal/hypothesis_utils.py b/torch/testing/_internal/hypothesis_utils.py
index 0654a64b96b..f243e43007d 100644
--- a/torch/testing/_internal/hypothesis_utils.py
+++ b/torch/testing/_internal/hypothesis_utils.py
@@ -159,10 +159,10 @@ Example:
 @st.composite
 def array_shapes(draw, min_dims=1, max_dims=None, min_side=1, max_side=None, max_numel=None):
     """Return a strategy for array shapes (tuples of int >= 1)."""
-    assert(min_dims < 32)
+    assert min_dims < 32
     if max_dims is None:
         max_dims = min(min_dims + 2, 32)
-    assert(max_dims < 32)
+    assert max_dims < 32
     if max_side is None:
         max_side = min_side + 5
     candidate = st.lists(st.integers(min_side, max_side), min_size=min_dims, max_size=max_dims)
@@ -334,7 +334,7 @@ def tensor_conv(
     # Resolve the tensors
     if qparams is not None:
         if isinstance(qparams, (list, tuple)):
-            assert(len(qparams) == 3), "Need 3 qparams for X, w, b"
+            assert len(qparams) == 3, "Need 3 qparams for X, w, b"
         else:
             qparams = [qparams] * 3
diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py
index 82c1f7846ce..fe7927edf66 100644
--- a/torch/testing/_internal/jit_metaprogramming_utils.py
+++ b/torch/testing/_internal/jit_metaprogramming_utils.py
@@ -331,7 +331,7 @@ def value_to_literal(value):
 def get_call(method_name, func_type, args, kwargs):
     kwargs_str = ', '.join([k + '=' + value_to_literal(v) for k, v in kwargs.items()])
     self_arg = args[0]
-    if(func_type == 'method'):
+    if func_type == 'method':
         args = args[1:]

     argument_str = ', '.join(args)
diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py
index 0680d8a8597..434a88ab292 100644
--- a/torch/utils/_python_dispatch.py
+++ b/torch/utils/_python_dispatch.py
@@ -5,7 +5,7 @@ import warnings
 from dataclasses import dataclass
 import torch
 import torchgen
-from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at,\
+from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at, \
     _pop_torch_dispatch_stack, _push_on_torch_dispatch_stack, DispatchKey

diff --git a/torch/utils/bundled_inputs.py b/torch/utils/bundled_inputs.py
index d67e4481109..2a95c782884 100644
--- a/torch/utils/bundled_inputs.py
+++ b/torch/utils/bundled_inputs.py
@@ -117,10 +117,10 @@ def bundle_inputs(
     # Fortunately theres a function in _recursive that does exactly that conversion.
     cloned_module = wrap_cpp_module(clone)
     if isinstance(inputs, dict):
-        assert(isinstance(info, dict) or info is None)
+        assert isinstance(info, dict) or info is None
         augment_many_model_functions_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
     else:
-        assert(isinstance(info, list) or info is None)
+        assert isinstance(info, list) or info is None
         augment_model_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
     return cloned_module
diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py
index 058e92edb6f..6cbf598156b 100644
--- a/torch/utils/collect_env.py
+++ b/torch/utils/collect_env.py
@@ -304,7 +304,7 @@ def get_cpu_info(run_lambda):
     if get_platform() == 'linux':
         rc, out, err = run_lambda('lscpu')
     elif get_platform() == 'win32':
-        rc, out, err = run_lambda('wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID,\
+        rc, out, err = run_lambda('wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \
                        CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE')
     elif get_platform() == 'darwin':
         rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py
index 42005b5ad33..76c7fb2b0d9 100644
--- a/torch/utils/flop_counter.py
+++ b/torch/utils/flop_counter.py
@@ -329,7 +329,7 @@ class FlopCounterMode(TorchDispatchMode):
         class PushState(torch.autograd.Function):
             @staticmethod
             def forward(ctx, *args):
-                assert(self.parents[-1] == name)
+                assert self.parents[-1] == name
                 self.parents.pop()
                 args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                 return args
@@ -351,7 +351,7 @@ class FlopCounterMode(TorchDispatchMode):

             @staticmethod
             def backward(ctx, *grad_outs):
-                assert(self.parents[-1] == name)
+                assert self.parents[-1] == name
                 self.parents.pop()
                 return grad_outs

diff --git a/torch/utils/mkldnn.py b/torch/utils/mkldnn.py
index 35541ebb5f2..2d1d8cd89ff 100644
--- a/torch/utils/mkldnn.py
+++ b/torch/utils/mkldnn.py
@@ -136,9 +136,9 @@ class MkldnnBatchNorm(torch.jit.ScriptModule):
     def __init__(self, dense_module):
         super().__init__()

-        assert(not dense_module.training)
-        assert(dense_module.track_running_stats)
-        assert(dense_module.affine)
+        assert not dense_module.training
+        assert dense_module.track_running_stats
+        assert dense_module.affine

         if dense_module.momentum is None:
             self.exponential_average_factor = 0.0