[BE]: Update flake8 to v6.1.0 and fix lints (#116591)

Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling; a short sketch of the rewrite patterns follows the PR links below.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parentheses, e.g.:
  - `assert(a == b)` -> `assert a == b`
  - `if(x > y or y < z):` -> `if x > y or y < z:`
  - `return('...')` -> `return '...'`

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
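
A minimal, hedged sketch of the rewrite patterns this commit applies (the function names `check_dims` and `describe` are hypothetical stand-ins, not code from the PyTorch tree):

```python
def check_dims(a, b):
    # Before: assert(a == b) -- `assert` is a statement, not a function,
    # so the parentheses are redundant; newer pycodestyle reports this
    # spelling as E275 (missing whitespace after keyword).
    assert a == b

def describe(x, y, z):
    # Before: if(x > y or y < z): -- the same redundancy for `if`.
    if x > y or y < z:
        # Before: return('...') -- and for `return`.
        return '...'
    # Before: assert(0) -- replaced with an explicit raise. The raise
    # survives `python -O` (asserts do not), and it avoids the classic
    # pitfall where assert(0, "msg") asserts a two-element tuple,
    # which is always truthy.
    raise AssertionError()
```
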
Aaron Gokaslan 2024-01-03 06:04:44 +00:00 committed by PyTorch MergeBot
parent 09ee96b69d
commit 3fe437b24b
62 changed files with 188 additions and 190 deletions

View File

@@ -8,8 +8,6 @@ max-line-length = 120
# E501 is not flexible enough, we're using B950 instead
ignore =
E203,E305,E402,E501,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,
# fix these lints in the future
E275,
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
EXE001,
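
For context: E275 is pycodestyle's "missing whitespace after keyword" check, which the pycodestyle shipped alongside flake8 6.1.0 raises for keyword-call spellings; that is presumably why it is parked in the ignore list above for now. A hedged one-line illustration (not code from this repo):

```python
flag = True
bad = not(flag)   # E275: missing whitespace after keyword
good = not flag   # OK
```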

View File

@@ -38,7 +38,7 @@ init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'flake8==6.0.0',
'flake8==6.1.0',
'flake8-bugbear==23.3.23',
'flake8-comprehensions==3.12.0',
'flake8-executable==2.1.3',
@@ -46,8 +46,8 @@ init_command = [
'flake8-pyi==23.3.1',
'flake8-simplify==0.19.3',
'mccabe==0.7.0',
'pycodestyle==2.10.0',
'pyflakes==3.0.1',
'pycodestyle==2.11.1',
'pyflakes==3.1.0',
'torchfix==0.2.0',
]
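
These pins move together by design: to my understanding, flake8 6.1.0 constrains its pycodestyle and pyflakes dependencies to the 2.11.x and 3.1.x ranges, so bumping flake8 alone would not resolve cleanly. A hedged sanity check, assuming the three packages are installed:

```python
# Confirm the linter stack resolved to the pinned versions.
import flake8
import pycodestyle
import pyflakes

print(flake8.__version__)       # expected: 6.1.0
print(pycodestyle.__version__)  # expected: 2.11.1
print(pyflakes.__version__)     # expected: 3.1.0
```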

View File

@@ -27,7 +27,7 @@ def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
for idx in range(batch_size):
flat_args, args_spec = pytree.tree_flatten(batched_args)
flat_dims, dims_spec = pytree.tree_flatten(in_dims)
assert(args_spec == dims_spec)
assert args_spec == dims_spec
new_args = [a.select(in_dim, idx) if in_dim is not None else a for a, in_dim in zip(flat_args, flat_dims)]
out = op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values)
flat_out, out_spec = pytree.tree_flatten(out)
@@ -45,9 +45,9 @@ def loop2(op, in_dims1, in_dims2, out_dim1, out_dim2, batch_size1, batch_size2,
flat_args, args_spec = pytree.tree_flatten(batched_args)
flat_dims1, dims_spec1 = pytree.tree_flatten(in_dims1)
flat_dims2, dims_spec2 = pytree.tree_flatten(in_dims2)
assert(args_spec == dims_spec1)
assert(args_spec == dims_spec2)
assert(len(flat_dims1) == len(flat_dims2))
assert args_spec == dims_spec1
assert args_spec == dims_spec2
assert len(flat_dims1) == len(flat_dims2)
for idx1 in range(batch_size1):
out_split = []
arg_split = [a.select(in_dim1, idx1) if in_dim1 is not None else a for a, in_dim1 in zip(flat_args, flat_dims1)]

View File

@@ -154,9 +154,9 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
output_ref = func(input0, input1)
for i in range(2):
output = jit_f(input0, input1)
assert(output_ref[0].requires_grad == output[0].requires_grad)
assert(output_ref[1][0].requires_grad == output[1][0].requires_grad)
assert(output_ref[1][1].requires_grad == output[1][1].requires_grad)
assert output_ref[0].requires_grad == output[0].requires_grad
assert output_ref[1][0].requires_grad == output[1][0].requires_grad
assert output_ref[1][1].requires_grad == output[1][1].requires_grad
@unittest.skip("disable until we property handle tensor lists with undefined gradients")
def test_differentiable_graph_ops_requires_grad(self):

View File

@@ -35,7 +35,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
with warnings.catch_warnings(record=True) as w:
self.checkModule(M(), (1,))
assert(len(w) == 0)
assert len(w) == 0
def test_annotated_nonempty_container(self):
class M(torch.nn.Module):
@@ -49,7 +49,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
with warnings.catch_warnings(record=True) as w:
self.checkModule(M(), ([1, 2, 3],))
assert(len(w) == 0)
assert len(w) == 0
def test_annotated_empty_tensor(self):
class M(torch.nn.Module):
@@ -63,7 +63,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
with warnings.catch_warnings(record=True) as w:
self.checkModule(M(), (torch.rand(2, 3),))
assert(len(w) == 0)
assert len(w) == 0
def test_annotated_with_jit_attribute(self):
class M(torch.nn.Module):
@@ -77,7 +77,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
with warnings.catch_warnings(record=True) as w:
self.checkModule(M(), ([1, 2, 3],))
assert(len(w) == 0)
assert len(w) == 0
def test_annotated_class_level_annotation_only(self):
class M(torch.nn.Module):
@@ -94,7 +94,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
with warnings.catch_warnings(record=True) as w:
self.checkModule(M(), ([1, 2, 3],))
assert(len(w) == 0)
assert len(w) == 0
def test_annotated_class_level_annotation_and_init_annotation(self):
@@ -112,7 +112,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
with warnings.catch_warnings(record=True) as w:
self.checkModule(M(), ([1, 2, 3],))
assert(len(w) == 0)
assert len(w) == 0
def test_annotated_class_level_jit_annotation(self):
class M(torch.nn.Module):
@@ -129,7 +129,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
with warnings.catch_warnings(record=True) as w:
self.checkModule(M(), ([1, 2, 3],))
assert(len(w) == 0)
assert len(w) == 0
def test_annotated_empty_list(self):
class M(torch.nn.Module):

View File

@@ -147,7 +147,7 @@ class testVariousModelVersions(TestCase):
def test_get_model_bytecode_version(self):
def check_model_version(model_path, expect_version):
actual_version = _get_model_bytecode_version(model_path)
assert(actual_version == expect_version)
assert actual_version == expect_version
for version, model_info in SCRIPT_MODULE_BYTECODE_PKL.items():
model_path = pytorch_test_dir / "cpp" / "jit" / model_info["model_name"]
check_model_version(model_path, version)
@@ -174,7 +174,7 @@ class testVariousModelVersions(TestCase):
current_to_version = current_from_version - 1
backport_success = _backport_for_mobile(input_model_path, tmp_output_model_path_backport, current_to_version)
assert(backport_success)
assert backport_success
expect_bytecode_pkl = SCRIPT_MODULE_BYTECODE_PKL[current_to_version]["bytecode_pkl"]
@@ -187,7 +187,7 @@ class testVariousModelVersions(TestCase):
acutal_result_clean = "".join(output.split())
expect_result_clean = "".join(expect_bytecode_pkl.split())
isMatch = fnmatch.fnmatch(acutal_result_clean, expect_result_clean)
assert(isMatch)
assert isMatch
current_from_version -= 1
shutil.rmtree(tmpdirname)
@@ -254,7 +254,7 @@ class testVariousModelVersions(TestCase):
script_module_v5_path,
tmp_backport_model_path,
maximum_checked_in_model_version - 1)
assert(success)
assert success
buf = io.StringIO()
torch.utils.show_pickle.main(
@@ -266,7 +266,7 @@ class testVariousModelVersions(TestCase):
acutal_result_clean = "".join(output.split())
expect_result_clean = "".join(expected_result.split())
isMatch = fnmatch.fnmatch(acutal_result_clean, expect_result_clean)
assert(isMatch)
assert isMatch
# Load model v4 and run forward method
mobile_module = _load_for_lite_interpreter(str(tmp_backport_model_path))
@@ -291,7 +291,7 @@ class testVariousModelVersions(TestCase):
# Check version of the model v4 from backport
bytesio = io.BytesIO(script_module_v4_buffer)
backport_version = _get_model_bytecode_version(bytesio)
assert(backport_version == maximum_checked_in_model_version - 1)
assert backport_version == maximum_checked_in_model_version - 1
# Load model v4 from backport and run forward method
bytesio = io.BytesIO(script_module_v4_buffer)
@@ -306,8 +306,8 @@ class testVariousModelVersions(TestCase):
# TODO update this to be more in the style of the above tests after a backport from 6 -> 5 exists
script_module_v6 = pytorch_test_dir / "cpp" / "jit" / "script_module_v6.ptl"
ops_v6 = _get_model_ops_and_info(script_module_v6)
assert(ops_v6["aten::add.int"].num_schema_args == 2)
assert(ops_v6["aten::add.Scalar"].num_schema_args == 2)
assert ops_v6["aten::add.int"].num_schema_args == 2
assert ops_v6["aten::add.Scalar"].num_schema_args == 2
def test_get_mobile_model_contained_types(self):
class MyTestModule(torch.nn.Module):
@@ -322,7 +322,7 @@ class testVariousModelVersions(TestCase):
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
type_list = _get_mobile_model_contained_types(buffer)
assert(len(type_list) >= 0)
assert len(type_list) >= 0
if __name__ == '__main__':
run_tests()

View File

@@ -82,7 +82,7 @@ class TestLiteScriptModule(TestCase):
buffer = io.BytesIO(exported_module)
buffer.seek(0)
assert(b"callstack_debug_map.pkl" in exported_module)
assert b"callstack_debug_map.pkl" in exported_module
mobile_module = _load_for_lite_interpreter(buffer)
with self.assertRaisesRegex(RuntimeError, r"Module hierarchy:top\(B\)::<unknown>.A0\(A\)::forward.aten::mul"):

View File

@@ -65,8 +65,8 @@ class TestDropoutNN(NNTestCase):
o_ref = torch.dropout(x_ref, p, train)
o.sum().backward()
o_ref.sum().backward()
assert(o.equal(o_ref))
assert(x.grad.equal(x_ref.grad))
assert o.equal(o_ref)
assert x.grad.equal(x_ref.grad)
def test_invalid_dropout_p(self):
v = torch.ones(1)

View File

@@ -166,7 +166,7 @@ class TestOptim(TestCase):
pass
elif constructor_accepts_maximize:
def four_arg_constructor(weight, bias, maximize, foreach):
def four_arg_constructor(weight, bias, maximize, foreach): # noqa: F811
self.assertFalse(foreach)
return constructor(weight, bias, maximize)

View File

@@ -2127,7 +2127,7 @@ class TestQuantizedOps(TestCase):
quantized_out = torch.topk(qX, k, dim=dim, largest=larg, sorted=sort)
assert(len(unquantized_out) == len(quantized_out))
assert len(unquantized_out) == len(quantized_out)
torch.testing.assert_close(quantized_out[0].dequantize(), unquantized_out[0])
torch.testing.assert_close(quantized_out[1], unquantized_out[1])
@@ -4101,7 +4101,7 @@ class TestQuantizedLinear(TestCase):
qparams=hu.qparams(dtypes=torch.qint8)))
@override_qengines
def test_qlinear_qnnpack_free_memory_and_unpack(self, W):
assert(qengine_is_qnnpack)
assert qengine_is_qnnpack
W, (W_scale, W_zp, torch_type) = W
qlinear_prepack = torch.ops.quantized.linear_prepack
qlinear_unpack = torch.ops.quantized.linear_unpack

View File

@@ -879,7 +879,7 @@ class TestFakeQuantizeOps(TestCase):
Y_prime.backward(dout)
np.testing.assert_allclose(
dX.cpu().detach().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)
assert(X.grad.dtype == float_type)
assert X.grad.dtype == float_type
def test_backward_per_channel_cachemask_cpu(self):

View File

@@ -839,7 +839,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)
if not use_relu:
def relu_op(x):
def relu_op(x): # noqa: F811
return x
if freeze_bn:

View File

@@ -4319,7 +4319,7 @@ Done""")
]
for thread, ranges in threads:
for range in ranges:
assert(len(range) == 3)
assert len(range) == 3
events.append(
FunctionEvent(
id=range[2],
@@ -4340,7 +4340,7 @@ Done""")
def get_children_ids(event):
return [child.id for child in event.cpu_children]
assert([get_children_ids(event) for event in events] == res)
assert [get_children_ids(event) for event in events] == res
def test_profiler_aggregation_table(self):
"""

View File

@@ -30,7 +30,7 @@ def save_and_load(sm):
def bundle_jpeg_image(img_tensor, quality):
# turn NCHW to HWC
if img_tensor.dim() == 4:
assert(img_tensor.size(0) == 1)
assert img_tensor.size(0) == 1
img_tensor = img_tensor[0].permute(1, 2, 0)
pixels = img_tensor.numpy()
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]

View File

@@ -1388,7 +1388,7 @@ torch.cuda.synchronize()
scaler.scale(l).backward()
scaler.step(optimizer)
scaler.update()
assert(scaler._scale != float('inf') and scaler._scale != float('nan'))
assert scaler._scale != float('inf') and scaler._scale != float('nan')
def test_grad_scaling_clipping(self):
def run(data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):

View File

@@ -1676,8 +1676,8 @@ class TestFX(JitTestCase):
x = torch.randn(5, 5, 224, 224)
shape_prop.ShapeProp(traced).propagate(x)
assert(all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
for node in traced.graph.nodes))
assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
for node in traced.graph.nodes)
x_channels_last = x.contiguous(memory_format=torch.channels_last)
traced.to(memory_format=torch.channels_last)
@@ -1733,8 +1733,8 @@ class TestFX(JitTestCase):
traced_3d = symbolic_trace(test_mod_3d)
x_3d = torch.randn(5, 5, 224, 224, 15)
shape_prop.ShapeProp(traced_3d).propagate(x_3d)
assert(all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
for node in traced_3d.graph.nodes))
assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
for node in traced_3d.graph.nodes)
x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d)
traced_3d.to(memory_format=torch.channels_last_3d)
@@ -2548,7 +2548,7 @@ class TestFX(JitTestCase):
return self.a.b, self.a.b.t(), self.a.b.view(12)
traced = torch.fx.symbolic_trace(Foo())
assert(all('constant' not in node.target for node in traced.graph.nodes))
assert all('constant' not in node.target for node in traced.graph.nodes)
def test_single_default_arg(self):
class M(torch.nn.Module):
@@ -2824,7 +2824,7 @@ class TestFX(JitTestCase):
mod_false = symbolic_trace(mod, concrete_args={'y': False})
self.assertEqual(mod_true(3, True), 6)
print(mod_true.code)
assert(any(i.target == torch._assert for i in mod_true.graph.nodes))
assert any(i.target == torch._assert for i in mod_true.graph.nodes)
with self.assertRaises(AssertionError):
mod_true(3, False)
self.assertEqual(mod_false(3, False), 3)
@@ -3607,17 +3607,17 @@ class TestFX(JitTestCase):
self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out)
assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args)
assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args
nf = symbolic_trace(nf)
self.assertEqual(nf(val), orig_out)
assert "tree_flatten_spec" not in nf.code
assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == 1)
assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == 1
nf = symbolic_trace(nf, concrete_args={'x': inp})
self.assertEqual(nf(val), orig_out)
assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args)
assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args
pickled = pickle.dumps(nf)
nf = pickle.loads(pickled)
@@ -3708,7 +3708,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
return [('List', typing.List)]
def process_inputs(self, *inputs):
assert(len(inputs) == 1)
assert len(inputs) == 1
return inputs[0]
def f(a, b):
@@ -3743,7 +3743,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
return [('List', typing.List)]
def process_inputs(self, *inputs):
assert(len(inputs) == 1)
assert len(inputs) == 1
return inputs[0]
def f(a, b):
@@ -3772,7 +3772,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
return [('List', typing.List)]
def process_inputs(self, *inputs):
assert(len(inputs) == 1)
assert len(inputs) == 1
return inputs[0]
def generate_output(self, output_args):

View File

@@ -1875,8 +1875,8 @@ graph(%Ra, %Rb):
o_ref = t(x_ref, p, train)
o.sum().backward()
o_ref.sum().backward()
assert(o.equal(o_ref))
assert(x.grad.equal(x_ref.grad))
assert o.equal(o_ref)
assert x.grad.equal(x_ref.grad)
@slowTest
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, 'Testing differentiable graph')
@@ -5746,7 +5746,7 @@ a")
x = torch.zeros(3, 4, dtype=torch.long)
graph = _propagate_shapes(fn.graph, (x,), False)
default = torch.get_default_dtype()
if(default == torch.float):
if default == torch.float:
FileCheck().check('Float(*, *, requires_grad=0, device=cpu) = aten::add').run(graph)
else:
FileCheck().check('Double(*, *, requires_grad=0, device=cpu) = aten::add').run(graph)
@@ -8557,15 +8557,15 @@ dedent """
fb.run("22 1 22")
def _dtype_to_jit_name(self, dtype):
if(dtype == torch.float32):
if dtype == torch.float32:
return "Float"
if(dtype == torch.float64):
if dtype == torch.float64:
return "Double"
if(dtype == torch.int64):
if dtype == torch.int64:
return "Long"
if(dtype == torch.int32):
if dtype == torch.int32:
return "Int"
if(dtype == torch.bool):
if dtype == torch.bool:
return "Bool"
raise RuntimeError('dtype not handled')
@@ -8595,7 +8595,7 @@ dedent """
for dtype in (dtypes + [None]):
for tensor_type in dtypes:
# a couple of ops aren't implemented for non-floating types
if(not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point)):
if not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point):
if op in ['mean', 'softmax', 'log_softmax']:
continue
return_line = f"torch.tensor({tensor_data}, dtype={tensor_type}).{op}({str_args}dtype={dtype})"
@@ -12837,7 +12837,7 @@ dedent """
for i, offset in enumerate(parsed_serialized_offsets):
data = reader.get_record(str(offset))
assert(data == buffers[i])
assert data == buffers[i]
def test_file_reader_no_memory_leak(self):
num_iters = 10000

View File

@@ -4122,7 +4122,7 @@ class TestMPS(TestCaseMPS):
def helper(dtype):
# Test with axis -1
cpu_x = None
if(dtype == torch.float32):
if dtype == torch.float32:
cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
else:
cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
@@ -4157,7 +4157,7 @@ class TestMPS(TestCaseMPS):
def helper(dtype):
# Test with axis -1
cpu_x = None
if(dtype == torch.float32):
if dtype == torch.float32:
cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
else:
cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)

View File

@@ -6865,7 +6865,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
elif weight_layout == torch.sparse_coo:
module.weight = nn.Parameter(module.weight.to_sparse_coo())
else:
assert(0)
raise AssertionError()
inp = torch.randn(4, requires_grad=True, device=device)
res = module(inp)

View File

@@ -217,7 +217,7 @@ class SubTensor(torch.Tensor):
"""
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if(kwargs is None):
if kwargs is None:
kwargs = {}
if func not in HANDLED_FUNCTIONS_SUB:
@@ -368,7 +368,7 @@ class TensorLike:
"""
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if(kwargs is None):
if kwargs is None:
kwargs = {}
if func not in HANDLED_FUNCTIONS_TENSOR_LIKE:

View File

@@ -1353,7 +1353,7 @@ class TestReductions(TestCase):
def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn,
memory_format, compare_data=True, default_is_preserve=False):
assert(memory_format == torch.channels_last or memory_format == torch.channels_last_3d)
assert memory_format == torch.channels_last or memory_format == torch.channels_last_3d
# xc is a channels last tensor
xc = input_generator_fn(device)

View File

@@ -5210,7 +5210,7 @@ else:
def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn,
memory_format, compare_data=True, default_is_preserve=False):
assert(memory_format == torch.channels_last or memory_format == torch.channels_last_3d)
assert memory_format == torch.channels_last or memory_format == torch.channels_last_3d
# xc is a channels last tensor
xc = input_generator_fn(device)

View File

@@ -191,7 +191,7 @@ def _load_global_deps() -> None:
'nccl': 'libnccl.so.*[0-9]',
'nvtx': 'libnvToolsExt.so.*[0-9]',
}
is_cuda_lib_err = [lib for lib in cuda_libs.values() if(lib.split('.')[0] in err.args[0])]
is_cuda_lib_err = [lib for lib in cuda_libs.values() if lib.split('.')[0] in err.args[0]]
if not is_cuda_lib_err:
raise err
for lib_folder, lib_name in cuda_libs.items():

View File

@@ -75,7 +75,7 @@ def is_gpu_compute_event(event):
def get_sorted_gpu_events(events):
sorted_gpu_events = []
for event in events:
if(not is_gpu_compute_event(event)):
if not is_gpu_compute_event(event):
continue
sorted_gpu_events.append(event)
return sorted(sorted_gpu_events, key=lambda x: x["ts"])
@@ -102,7 +102,7 @@ def get_sorted_gpu_mm_conv_events(events):
gpu_events = get_sorted_gpu_events(events)
sorted_events = []
for event in gpu_events:
if(not is_mm_conv_event(event)):
if not is_mm_conv_event(event):
continue
sorted_events.append(event)
return sorted_events

View File

@@ -22,7 +22,7 @@ class LinearBlockSparsePattern:
prev_col_block_size = 4
def __init__(self, row_block_size=1, col_block_size=4):
assert(_is_valid_linear_block_sparse_pattern(row_block_size, col_block_size))
assert _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size)
LinearBlockSparsePattern.rlock.acquire()
LinearBlockSparsePattern.prev_row_block_size = LinearBlockSparsePattern.row_block_size
LinearBlockSparsePattern.prev_col_block_size = LinearBlockSparsePattern.col_block_size

View File

@@ -85,7 +85,7 @@ def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu):
>>> lr = nn.LeakyReLU(0.01)
>>> m2 = _fuse_linear_bn_leaky_relu(m1, b1, lr)
"""
assert(linear.training == bn.training and bn.training == leaky_relu.training),\
assert linear.training == bn.training and bn.training == leaky_relu.training, \
"Linear, BN and LeakyReLU all must be in the same mode (train or eval)."
if is_qat:

View File

@@ -58,8 +58,8 @@ class APoTObserver(ObserverBase):
alpha = torch.max(-self.min_val, self.max_val)
# check for valid inputs of b, k
assert(self.k and self.k != 0)
assert(self.b % self.k == 0)
assert self.k and self.k != 0
assert self.b % self.k == 0
# compute n and store as member variable
self.n = self.b // self.k

View File

@@ -31,7 +31,7 @@ def fuse_conv_bn(is_qat, conv, bn):
>>> # xdoctest: +SKIP
>>> m2 = fuse_conv_bn(m1, b1)
"""
assert(conv.training == bn.training),\
assert conv.training == bn.training, \
"Conv and BN both must be in the same mode (train or eval)."
fused_module_class_map = {
@@ -71,7 +71,7 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
>>> # xdoctest: +SKIP
>>> m2 = fuse_conv_bn_relu(m1, b1, r1)
"""
assert(conv.training == bn.training == relu.training),\
assert conv.training == bn.training == relu.training, \
"Conv and BN both must be in the same mode (train or eval)."
fused_module : Optional[Type[nn.Sequential]] = None
if is_qat:
@@ -118,7 +118,7 @@ def fuse_linear_bn(is_qat, linear, bn):
>>> # xdoctest: +SKIP
>>> m2 = fuse_linear_bn(m1, b1)
"""
assert(linear.training == bn.training),\
assert linear.training == bn.training, \
"Linear and BN both must be in the same mode (train or eval)."
if is_qat:
@@ -147,7 +147,7 @@ def fuse_convtranspose_bn(is_qat, convt, bn):
>>> # xdoctest: +SKIP
>>> m2 = fuse_convtranspose_bn(m1, b1)
"""
assert(convt.training == bn.training),\
assert convt.training == bn.training, \
"ConvTranspose and BN both must be in the same mode (train or eval)."
if is_qat:

View File

@@ -291,24 +291,24 @@ def get_op_node_and_weight_eq_obs(
op_node = user
break
assert(op_node is not None)
assert op_node is not None
if op_node.op == 'call_module':
# If the op_node is a nn.Linear layer, then it must have a
# WeightEqualizationObserver configuration
maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(model, "equalization_node_name_to_qconfig")
assert maybe_equalization_node_name_to_config is not None
equalization_node_name_to_qconfig: Dict[str, Any] = maybe_equalization_node_name_to_config # type: ignore[assignment]
assert(equalization_node_name_to_qconfig.get(op_node.name, None) is not None)
assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None
weight_eq_obs = equalization_node_name_to_qconfig.get(op_node.name, None).weight()
assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
return op_node, weight_eq_obs
elif op_node.op == 'call_function':
weight_node = maybe_get_weight_eq_obs_node(op_node, modules)
if weight_node is not None:
weight_eq_obs = modules[str(weight_node.target)]
assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
return op_node, weight_eq_obs
return None, None
@@ -316,7 +316,7 @@ def get_op_node_and_weight_eq_obs(
def maybe_get_weight_eq_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> Optional[Node]:
""" Gets the weight equalization observer node if it exists.
"""
assert(op_node.op == 'call_function')
assert op_node.op == 'call_function'
for node_arg in op_node.args:
if node_arg_is_weight(op_node, node_arg):
assert (isinstance(node_arg, Node) and node_arg.op == 'call_module' and
@@ -342,7 +342,7 @@ def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Op
the following equalization observer for linear2.
"""
assert(node_supports_equalization(node, modules))
assert node_supports_equalization(node, modules)
# Locate the following nn.ReLU or F.relu node if it exists
maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU)
@@ -364,7 +364,7 @@ def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Op
return None
maybe_eq_obs = modules[str(maybe_eq_obs_node)]
assert(isinstance(maybe_eq_obs, _InputEqualizationObserver))
assert isinstance(maybe_eq_obs, _InputEqualizationObserver)
return maybe_eq_obs
def maybe_get_next_equalization_scale(node: Node, modules: Dict[str, nn.Module]) -> Optional[torch.Tensor]:
@@ -389,10 +389,10 @@ def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None:
equalization observer
"""
input_eq_obs = modules[str(node.target)]
assert(isinstance(input_eq_obs, _InputEqualizationObserver))
assert isinstance(input_eq_obs, _InputEqualizationObserver)
input_quant_obs_node = node.args[0]
assert(isinstance(input_quant_obs_node, Node))
assert isinstance(input_quant_obs_node, Node)
input_quant_obs = modules[str(input_quant_obs_node.target)]
if not isinstance(input_quant_obs, ObserverBase):
@@ -426,12 +426,12 @@ def scale_weight_node(
op_module = modules[str(node.target)][0] # type: ignore[index]
else:
op_module = modules[str(node.target)]
assert(nn_module_supports_equalization(op_module) or custom_module_supports_equalization(op_module))
assert nn_module_supports_equalization(op_module) or custom_module_supports_equalization(op_module)
# Scale the weights for input-weight equalization
# If the following layer needs to be equalized then we will multiply its scale
weight = op_module.weight
assert(isinstance(weight, torch.Tensor))
assert isinstance(weight, torch.Tensor)
# Scale the weights by the reciprocal of the equalization scale
# Reshape the equalization scale so that we can multiply it to the weight along axis=1
@@ -453,7 +453,7 @@ def scale_weight_node(
bias = op_module.bias
if bias is None:
return
assert(isinstance(bias, torch.Tensor))
assert isinstance(bias, torch.Tensor)
# Reshape the equalization scale so that we can multiply it element-wise to the bias
next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
@@ -494,7 +494,7 @@ def scale_weight_functional(
weight_node = weight_quant_obs_node.args[0]
if weight_node is None:
return
assert(isinstance(weight_node, Node) and weight_node.op == 'get_attr')
assert isinstance(weight_node, Node) and weight_node.op == 'get_attr'
weight_parent_name, weight_name = _parent_name(weight_node.target)
weight = getattr(modules[weight_parent_name], weight_name)
@@ -515,7 +515,7 @@ def scale_weight_functional(
scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)
setattr(modules[weight_parent_name], weight_name, scaled_weight)
assert(torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight))
assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight)
# Multiply the bias element wise by the next equalization scale
bias_node = None
@@ -546,10 +546,10 @@ def clear_weight_quant_obs_node(op_node: Node, modules: Dict[str, nn.Module]) ->
weight_quant_obs_node = weight_eq_obs_node.args[0]
if weight_quant_obs_node is None:
return
assert(isinstance(weight_quant_obs_node, Node))
assert isinstance(weight_quant_obs_node, Node)
weight_quant_obs = modules[str(weight_quant_obs_node.target)]
assert(isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase))
assert isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)
weight_quant_obs.reset_min_max_vals() # type: ignore[operator]
def remove_node(model: GraphModule, node: Node, prev_node: Node):
@@ -578,7 +578,7 @@ def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module
for node in model.graph.nodes:
if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
input_eq_obs = modules[node.target]
assert(isinstance(input_eq_obs, _InputEqualizationObserver))
assert isinstance(input_eq_obs, _InputEqualizationObserver)
op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)
if op_node is None or weight_eq_obs is None:
@@ -589,7 +589,7 @@ def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module
# been created
if fused_module_supports_equalization(modules[str(op_node.target)]):
module = modules[str(op_node.target)][0] # type: ignore[index]
assert(nn_module_supports_equalization(module))
assert nn_module_supports_equalization(module)
weight_eq_obs(module.weight)
else:
weight_eq_obs(modules[str(op_node.target)].weight)
@@ -696,7 +696,7 @@ def convert_eq_obs(
elif weight_eq_obs_dict.get(node.name, None) is not None:
weight_eq_obs = weight_eq_obs_dict.get(node.name)
assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
equalization_scale = weight_eq_obs.equalization_scale
if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
@@ -712,7 +712,7 @@ def convert_eq_obs(
weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules)
if weight_eq_obs_node is None:
return
assert(isinstance(modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver))
assert isinstance(modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver)
# Clear the quantization observer's min/max values so that they
# can get updated later based on the new scale values

View File

@@ -487,14 +487,14 @@ def _match_static_pattern(
return SKIP_LOWERING_VALUE
q_node = node
ref_node = q_node.args[0]
assert(isinstance(ref_node, Node))
assert isinstance(ref_node, Node)
# Handle cases where the node is wrapped in a ReLU
if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or\
(ref_node.op == "call_module" and type(_get_module(ref_node, modules)) == nn.ReLU):
relu_node = ref_node
ref_node = relu_node.args[0]
assert(isinstance(ref_node, Node))
assert isinstance(ref_node, Node)
else:
relu_node = None
if should_skip_lowering(ref_node, qconfig_map):
@@ -557,7 +557,7 @@ def _match_static_pattern_with_two_inputs(
return SKIP_LOWERING_VALUE
q_node = node
ref_node = q_node.args[0]
assert(isinstance(ref_node, Node))
assert isinstance(ref_node, Node)
if should_skip_lowering(ref_node, qconfig_map):
return SKIP_LOWERING_VALUE
@@ -599,13 +599,13 @@ def _lower_static_weighted_ref_module(
n, modules, qconfig_map, matching_modules, dequantize_node_arg_indices=[0]) # type: ignore[arg-type]
if q_node is None:
continue
assert(ref_node is not None)
assert ref_node is not None
(_, scale_node, zero_point_node, _) = q_node.args
ref_module = _get_module(ref_node, modules)
ref_class = type(ref_module)
assert(isinstance(scale_node, Node))
assert(isinstance(zero_point_node, Node))
assert(issubclass(ref_class, nn.Module))
assert isinstance(scale_node, Node)
assert isinstance(zero_point_node, Node)
assert issubclass(ref_class, nn.Module)
# Step 1: Change this pattern to use the corresponding quantized module
# For fused modules, we also check whether the inner module is a reference module
@@ -624,9 +624,9 @@ def _lower_static_weighted_ref_module(
setattr(modules[parent_name], module_name, q_module)
# Step 2: Reroute around dq_node, and remove q_node and its args
assert(len(ref_node.args) == 1)
assert len(ref_node.args) == 1
dq_node = ref_node.args[0]
assert(isinstance(dq_node, Node))
assert isinstance(dq_node, Node)
ref_node.replace_input_with(dq_node, dq_node.args[0])
q_node.replace_all_uses_with(ref_node)
model.graph.erase_node(q_node)
@@ -655,13 +655,13 @@ def _lower_static_weighted_ref_module_with_two_inputs(
n, modules, qconfig_map, matching_modules) # type: ignore[arg-type]
if q_node is None:
continue
assert(ref_node is not None)
assert ref_node is not None
(_, scale_node, zero_point_node, _) = q_node.args
ref_module = _get_module(ref_node, modules)
ref_class = type(ref_module)
assert(isinstance(scale_node, Node))
assert(isinstance(zero_point_node, Node))
assert(issubclass(ref_class, nn.Module))
assert isinstance(scale_node, Node)
assert isinstance(zero_point_node, Node)
assert issubclass(ref_class, nn.Module)
# Step 1: Change this pattern to use the corresponding quantized module
# For fused modules, we also check whether the inner module is a reference module
@@ -680,12 +680,12 @@ def _lower_static_weighted_ref_module_with_two_inputs(
setattr(modules[parent_name], module_name, q_module)
# Step 2: Reroute around dq_node, and remove q_node and its args
assert(len(ref_node.args) == 2)
assert len(ref_node.args) == 2
for arg in ref_node.args:
if not is_dequantize_node(arg):
continue
dq_node = arg
assert(isinstance(dq_node, Node))
assert isinstance(dq_node, Node)
ref_node.replace_input_with(dq_node, dq_node.args[0])
q_node.replace_all_uses_with(ref_node)
@@ -778,14 +778,14 @@ def _lower_static_weighted_ref_functional(
n, modules, qconfig_map, matching_ops, dequantize_node_arg_indices=[0, 1])
if q_node is None:
continue
assert(func_node is not None)
assert func_node is not None
(_, output_scale_node, output_zp_node, _) = q_node.args
(input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
assert(isinstance(output_zp_node, Node))
assert(isinstance(input_dq_node, Node))
assert(isinstance(weight_dq_node, Node))
assert isinstance(output_zp_node, Node)
assert isinstance(input_dq_node, Node)
assert isinstance(weight_dq_node, Node)
quantized_weight = weight_dq_node.args[0]
assert(isinstance(quantized_weight, Node))
assert isinstance(quantized_weight, Node)
if quantized_weight.op != "call_function" or\
quantized_weight.target not in (torch.quantize_per_tensor, torch.quantize_per_channel):
continue
@@ -966,7 +966,7 @@ def _lower_quantized_binary_op(
n, modules, qconfig_map, binary_ops_to_lower, dequantize_node_arg_indices=[0, 1])
if q_node is None:
continue
assert(bop_node is not None)
assert bop_node is not None
(_, scale_node, zero_point_node, _) = q_node.args
# Step 1: Remove dequant nodes
@@ -975,11 +975,11 @@ def _lower_quantized_binary_op(
if not is_dequantize_node(arg):
continue
dq_node = arg
assert(isinstance(dq_node, Node))
assert isinstance(dq_node, Node)
dn_input = dq_node.args[0]
bop_node.replace_input_with(dq_node, dn_input)
num_dq_nodes += 1
assert(num_dq_nodes > 0)
assert num_dq_nodes > 0
# Step 2: Swap binary op to quantized binary op
assert bop_node.target in QBIN_OP_MAPPING

View File

@@ -956,7 +956,7 @@ def convert(
"in a future version. Please pass in a QConfigMapping instead.")
qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
qconfig_mapping = copy.deepcopy(qconfig_mapping)
assert(qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping))
assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)
if isinstance(backend_config, Dict):
warnings.warn(

View File

@@ -1585,7 +1585,7 @@ def insert_observers_for_model(
# should resolve this inconsistency by inserting DeQuantStubs for all custom
# modules, not just for LSTM.
_insert_dequant_stubs_for_custom_module_lstm_output(node, model, named_modules, model.graph)
if(node.target not in custom_module_names_already_swapped):
if node.target not in custom_module_names_already_swapped:
custom_module_names_already_swapped.add(node.target)
_swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
else:
@@ -1627,7 +1627,7 @@ def insert_observers_for_model(
_remove_output_observer(node, model, named_modules)
if qhandler is not None and qhandler.is_custom_module():
if(node.target not in custom_module_names_already_swapped):
if node.target not in custom_module_names_already_swapped:
custom_module_names_already_swapped.add(node.target)
_swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
@@ -1771,8 +1771,8 @@ def prepare(
"in a future version. Please pass in a BackendConfig instead.")
backend_config = BackendConfig.from_dict(backend_config)
assert(isinstance(qconfig_mapping, QConfigMapping))
assert(isinstance(_equalization_config, QConfigMapping))
assert isinstance(qconfig_mapping, QConfigMapping)
assert isinstance(_equalization_config, QConfigMapping)
qconfig_mapping = copy.deepcopy(qconfig_mapping)
_equalization_config = copy.deepcopy(_equalization_config)

View File

@@ -403,7 +403,7 @@ def generate_calc_product(constraint, counter):
for p in all_possibilities:
p = list(p)
# this tells us there is a dynamic variable
contains_dyn = not(all(constraint.op == op_neq for constraint in p))
contains_dyn = not all(constraint.op == op_neq for constraint in p)
if contains_dyn:
mid_var = [Dyn]
total_constraints = lhs + mid_var + rhs

View File

@@ -42,7 +42,7 @@ def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict
def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
assert(isinstance(node.target, str))
assert isinstance(node.target, str)
parent_name, name = _parent_name(node.target)
modules[node.target] = new_module
setattr(modules[parent_name], name, new_module)
@@ -133,11 +133,11 @@ def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
old_modules: Dict[nn.Module, nn.Module] = {}
for node in nodes:
if node.op == 'call_module':
assert(isinstance(node.target, str))
assert isinstance(node.target, str)
cur_module = modules[node.target]
if type(cur_module) in mkldnn_map:
new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
assert(isinstance(new_module, nn.Module))
assert isinstance(new_module, nn.Module)
old_modules[new_module] = copy.deepcopy(cur_module)
replace_node_module(node, modules, new_module)
return old_modules
@@ -220,7 +220,7 @@ class UnionFind:
par = self.parent[v]
if v == par:
return v
assert(par is not None)
assert par is not None
self.parent[v] = self.find(par)
return cast(int, self.parent[v])
@@ -294,8 +294,8 @@ def optimize_for_inference(
supports_mkldnn = MklSupport.YES
sample_parameter = next(cur_module.parameters(), None)
if sample_parameter is not None:
assert(sample_parameter.dtype == torch.float), "this pass is only for torch.float modules"
assert(sample_parameter.device == torch.device('cpu')), "this pass is only for CPU modules"
assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules"
assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules"
elif node.op == 'call_function':
if node.target in mkldnn_supported:
supports_mkldnn = MklSupport.YES
@@ -360,14 +360,14 @@ def optimize_for_inference(
node.start_color = cur_idx
uf.make_set(cur_idx)
elif node.op == 'call_method' and node.target == 'to_dense':
assert(get_color(node.args[0]) is not None)
assert get_color(node.args[0]) is not None
node.end_color = get_color(node.args[0])
else:
cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None]
if len(cur_colors) == 0:
continue
assert(not any(i is None for i in cur_colors))
assert not any(i is None for i in cur_colors)
cur_colors = sorted(cur_colors)
node.color = cur_colors[0]
for other_color in cur_colors[1:]:

View File

@@ -1359,7 +1359,7 @@ class DimConstraints:
self.raise_inconsistencies()
# as long as there are symbols with equalities, solve for them
# NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
while(self._symbols_with_equalities):
while self._symbols_with_equalities:
s = self._symbols_with_equalities.pop()
exprs = self._univariate_inequalities.pop(s)
solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)

View File

@@ -670,7 +670,7 @@ class _PyTreeCodeGen(CodeGen):
return out
if not isinstance(out, (list, tuple)):
out = [out]
assert(self.pytree_info.out_spec is not None)
assert self.pytree_info.out_spec is not None
return pytree.tree_unflatten(out, self.pytree_info.out_spec)
def gen_fn_def(self, free_vars, maybe_return_annotation):

View File

@@ -82,7 +82,7 @@ def _type_repr(obj):
return obj.__qualname__
return f'{obj.__module__}.{obj.__qualname__}'
if obj is ...:
return('...')
return '...'
if isinstance(obj, types.FunctionType):
return obj.__name__
return repr(obj)

View File

@@ -698,7 +698,7 @@ class _MinimizerBase:
if self.settings.traverse_method == "accumulate":
return self._accumulate_traverse(nodes)
if(self.settings.traverse_method == "skip"):
if self.settings.traverse_method == "skip":
if (skip_nodes is None):
raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.")
return self._skip_traverse(nodes, skip_nodes)

View File

@@ -510,7 +510,7 @@ class ParameterProxy(Proxy):
"""
def __init__(self, tracer: TracerBase, node: Node, name, param):
super().__init__(node, tracer)
assert(isinstance(param, torch.nn.Parameter))
assert isinstance(param, torch.nn.Parameter)
self.param = param
self.name = name

View File

@@ -125,7 +125,7 @@ def _snr(x, x_hat):
signal, noise, SNR(in dB): Either floats or a nested list of floats
"""
if isinstance(x, (list, tuple)):
assert(len(x) == len(x_hat))
assert len(x) == len(x_hat)
res = []
for idx in range(len(x)):
res.append(_snr(x[idx], x_hat[idx]))

View File

@@ -1352,7 +1352,7 @@ def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
fn(*args, **kwargs)
return wrapper
assert(isinstance(fn, type))
assert isinstance(fn, type)
if TEST_WITH_TORCHDYNAMO: # noqa: F821
fn.__unittest_skip__ = True
fn.__unittest_skip_why__ = msg
@@ -1374,7 +1374,7 @@ def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",
fn(*args, **kwargs)
return wrapper
assert(isinstance(fn, type))
assert isinstance(fn, type)
if condition:
fn.__unittest_skip__ = True
fn.__unittest_skip_why__ = msg
@@ -1440,7 +1440,7 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe
fn(*args, **kwargs)
return wrapper
assert(isinstance(fn, type))
assert isinstance(fn, type)
if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
fn.__unittest_skip__ = True
fn.__unittest_skip_why__ = msg

View File

@@ -159,10 +159,10 @@ Example:
@st.composite
def array_shapes(draw, min_dims=1, max_dims=None, min_side=1, max_side=None, max_numel=None):
"""Return a strategy for array shapes (tuples of int >= 1)."""
assert(min_dims < 32)
assert min_dims < 32
if max_dims is None:
max_dims = min(min_dims + 2, 32)
assert(max_dims < 32)
assert max_dims < 32
if max_side is None:
max_side = min_side + 5
candidate = st.lists(st.integers(min_side, max_side), min_size=min_dims, max_size=max_dims)
@@ -334,7 +334,7 @@ def tensor_conv(
# Resolve the tensors
if qparams is not None:
if isinstance(qparams, (list, tuple)):
assert(len(qparams) == 3), "Need 3 qparams for X, w, b"
assert len(qparams) == 3, "Need 3 qparams for X, w, b"
else:
qparams = [qparams] * 3

View File

@@ -331,7 +331,7 @@ def value_to_literal(value):
def get_call(method_name, func_type, args, kwargs):
kwargs_str = ', '.join([k + '=' + value_to_literal(v) for k, v in kwargs.items()])
self_arg = args[0]
if(func_type == 'method'):
if func_type == 'method':
args = args[1:]
argument_str = ', '.join(args)

View File

@@ -117,10 +117,10 @@ def bundle_inputs(
# Fortunately theres a function in _recursive that does exactly that conversion.
cloned_module = wrap_cpp_module(clone)
if isinstance(inputs, dict):
assert(isinstance(info, dict) or info is None)
assert isinstance(info, dict) or info is None
augment_many_model_functions_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
else:
assert(isinstance(info, list) or info is None)
assert isinstance(info, list) or info is None
augment_model_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
return cloned_module

View File

@@ -329,7 +329,7 @@ class FlopCounterMode(TorchDispatchMode):
class PushState(torch.autograd.Function):
@staticmethod
def forward(ctx, *args):
assert(self.parents[-1] == name)
assert self.parents[-1] == name
self.parents.pop()
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
return args
@@ -351,7 +351,7 @@ class FlopCounterMode(TorchDispatchMode):
@staticmethod
def backward(ctx, *grad_outs):
assert(self.parents[-1] == name)
assert self.parents[-1] == name
self.parents.pop()
return grad_outs

View File

@@ -136,9 +136,9 @@ class MkldnnBatchNorm(torch.jit.ScriptModule):
def __init__(self, dense_module):
super().__init__()
assert(not dense_module.training)
assert(dense_module.track_running_stats)
assert(dense_module.affine)
assert not dense_module.training
assert dense_module.track_running_stats
assert dense_module.affine
if dense_module.momentum is None:
self.exponential_average_factor = 0.0