[BE]: Update flake8 to v6.1.0 and fix lints (#116591)

Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parentheses, e.g. (see the sketch after this list):
  - `assert(a == b)` -> `assert a == b`
  - `if(x > y or y < z):` -> `if x > y or y < z:`
  - `return('...')` -> `return '...'`
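A minimal sketch of the kinds of rewrites applied (hypothetical example code, not taken from any of the changed files):

```python
# Hypothetical illustration of the lint fixes applied across the tree.

def validate(a: int, b: int) -> str:
    # was: assert(a == b)
    assert a == b
    # was: if(a > 0 or b < 0):
    if a > 0 or b < 0:
        # was: return('positive')
        return 'positive'
    # was: assert(0), an unconditional failure on the fallthrough path
    raise AssertionError()


print(validate(3, 3))  # -> positive
```

One practical upside of the `raise AssertionError()` form is that it is not stripped under `python -O`, unlike a bare `assert`.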

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
Aaron Gokaslan, 2024-01-03 06:04:44 +00:00, committed by PyTorch MergeBot
commit 3fe437b24b (parent 09ee96b69d)
62 changed files with 188 additions and 190 deletions


@@ -8,8 +8,6 @@ max-line-length = 120
 # E501 is not flexible enough, we're using B950 instead
 ignore =
 E203,E305,E402,E501,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,
-# fix these lints in the future
-E275,
 # shebang has extra meaning in fbcode lints, so I think it's not worth trying
 # to line this up with executable bit
 EXE001,
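For context, E275 is pycodestyle's "missing whitespace after keyword" check; dropping it from the ignore list is what drives the `not(...)`-style cleanups elsewhere in this PR. A small hypothetical snippet of the pattern it now flags:

```python
# Hypothetical snippet: with E275 enforced, flake8 reports
# "E275 missing whitespace after keyword" on the commented-out form.

def is_quiet(alarm_on: bool, door_open: bool) -> bool:
    # return not(alarm_on or door_open)   # flagged by E275
    return not (alarm_on or door_open)    # lint-clean
```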


@@ -38,7 +38,7 @@ init_command = [
 'python3',
 'tools/linter/adapters/pip_init.py',
 '--dry-run={{DRYRUN}}',
-'flake8==6.0.0',
+'flake8==6.1.0',
 'flake8-bugbear==23.3.23',
 'flake8-comprehensions==3.12.0',
 'flake8-executable==2.1.3',
@@ -46,8 +46,8 @@ init_command = [
 'flake8-pyi==23.3.1',
 'flake8-simplify==0.19.3',
 'mccabe==0.7.0',
-'pycodestyle==2.10.0',
-'pyflakes==3.0.1',
+'pycodestyle==2.11.1',
+'pyflakes==3.1.0',
 'torchfix==0.2.0',
 ]


@@ -981,7 +981,7 @@ class TestFPGMPruner(TestCase):
 pruner.prepare(model, config)
 pruner.enable_mask_update = True
 pruner.step()
-assert pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item() is not False,\
+assert pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item() is not False, \
 "do not prune the least-norm filter"
 # fusion step


@@ -9,7 +9,7 @@ import torch.distributed as c10d
 import torch.multiprocessing as mp
 from torch.testing._internal.common_distributed import \
 MultiProcessTestCase
-from torch.testing._internal.common_utils import load_tests,\
+from torch.testing._internal.common_utils import load_tests, \
 NO_MULTIPROCESSING_SPAWN
 # Torch distributed.nn is not available in windows


@@ -27,7 +27,7 @@ def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
 for idx in range(batch_size):
 flat_args, args_spec = pytree.tree_flatten(batched_args)
 flat_dims, dims_spec = pytree.tree_flatten(in_dims)
-assert(args_spec == dims_spec)
+assert args_spec == dims_spec
 new_args = [a.select(in_dim, idx) if in_dim is not None else a for a, in_dim in zip(flat_args, flat_dims)]
 out = op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values)
 flat_out, out_spec = pytree.tree_flatten(out)
@@ -45,9 +45,9 @@ def loop2(op, in_dims1, in_dims2, out_dim1, out_dim2, batch_size1, batch_size2,
 flat_args, args_spec = pytree.tree_flatten(batched_args)
 flat_dims1, dims_spec1 = pytree.tree_flatten(in_dims1)
 flat_dims2, dims_spec2 = pytree.tree_flatten(in_dims2)
-assert(args_spec == dims_spec1)
-assert(args_spec == dims_spec2)
-assert(len(flat_dims1) == len(flat_dims2))
+assert args_spec == dims_spec1
+assert args_spec == dims_spec2
+assert len(flat_dims1) == len(flat_dims2)
 for idx1 in range(batch_size1):
 out_split = []
 arg_split = [a.select(in_dim1, idx1) if in_dim1 is not None else a for a, in_dim1 in zip(flat_args, flat_dims1)]


@@ -7,7 +7,7 @@ from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraint
 from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator
 from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint
 from torch.fx.experimental.migrate_gradual_types.operation import op_precision, op_matching, op_consistency
-from torch.fx.experimental.migrate_gradual_types.transform_to_z3 import transform_all_constraints,\
+from torch.fx.experimental.migrate_gradual_types.transform_to_z3 import transform_all_constraints, \
 evaluate_conditional_with_constraints
 from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, D, z3_dyn
 from torch.fx.experimental.rewriter import RewritingTracer


@@ -154,9 +154,9 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
 output_ref = func(input0, input1)
 for i in range(2):
 output = jit_f(input0, input1)
-assert(output_ref[0].requires_grad == output[0].requires_grad)
-assert(output_ref[1][0].requires_grad == output[1][0].requires_grad)
-assert(output_ref[1][1].requires_grad == output[1][1].requires_grad)
+assert output_ref[0].requires_grad == output[0].requires_grad
+assert output_ref[1][0].requires_grad == output[1][0].requires_grad
+assert output_ref[1][1].requires_grad == output[1][1].requires_grad
 @unittest.skip("disable until we property handle tensor lists with undefined gradients")
 def test_differentiable_graph_ops_requires_grad(self):


@@ -35,7 +35,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
 with warnings.catch_warnings(record=True) as w:
 self.checkModule(M(), (1,))
-assert(len(w) == 0)
+assert len(w) == 0
 def test_annotated_nonempty_container(self):
 class M(torch.nn.Module):
@@ -49,7 +49,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
 with warnings.catch_warnings(record=True) as w:
 self.checkModule(M(), ([1, 2, 3],))
-assert(len(w) == 0)
+assert len(w) == 0
 def test_annotated_empty_tensor(self):
 class M(torch.nn.Module):
@@ -63,7 +63,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
 with warnings.catch_warnings(record=True) as w:
 self.checkModule(M(), (torch.rand(2, 3),))
-assert(len(w) == 0)
+assert len(w) == 0
 def test_annotated_with_jit_attribute(self):
 class M(torch.nn.Module):
@@ -77,7 +77,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
 with warnings.catch_warnings(record=True) as w:
 self.checkModule(M(), ([1, 2, 3],))
-assert(len(w) == 0)
+assert len(w) == 0
 def test_annotated_class_level_annotation_only(self):
 class M(torch.nn.Module):
@@ -94,7 +94,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
 with warnings.catch_warnings(record=True) as w:
 self.checkModule(M(), ([1, 2, 3],))
-assert(len(w) == 0)
+assert len(w) == 0
 def test_annotated_class_level_annotation_and_init_annotation(self):
@@ -112,7 +112,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
 with warnings.catch_warnings(record=True) as w:
 self.checkModule(M(), ([1, 2, 3],))
-assert(len(w) == 0)
+assert len(w) == 0
 def test_annotated_class_level_jit_annotation(self):
 class M(torch.nn.Module):
@@ -129,7 +129,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
 with warnings.catch_warnings(record=True) as w:
 self.checkModule(M(), ([1, 2, 3],))
-assert(len(w) == 0)
+assert len(w) == 0
 def test_annotated_empty_list(self):
 class M(torch.nn.Module):


@@ -147,7 +147,7 @@ class testVariousModelVersions(TestCase):
 def test_get_model_bytecode_version(self):
 def check_model_version(model_path, expect_version):
 actual_version = _get_model_bytecode_version(model_path)
-assert(actual_version == expect_version)
+assert actual_version == expect_version
 for version, model_info in SCRIPT_MODULE_BYTECODE_PKL.items():
 model_path = pytorch_test_dir / "cpp" / "jit" / model_info["model_name"]
 check_model_version(model_path, version)
@@ -174,7 +174,7 @@ class testVariousModelVersions(TestCase):
 current_to_version = current_from_version - 1
 backport_success = _backport_for_mobile(input_model_path, tmp_output_model_path_backport, current_to_version)
-assert(backport_success)
+assert backport_success
 expect_bytecode_pkl = SCRIPT_MODULE_BYTECODE_PKL[current_to_version]["bytecode_pkl"]
@@ -187,7 +187,7 @@ class testVariousModelVersions(TestCase):
 acutal_result_clean = "".join(output.split())
 expect_result_clean = "".join(expect_bytecode_pkl.split())
 isMatch = fnmatch.fnmatch(acutal_result_clean, expect_result_clean)
-assert(isMatch)
+assert isMatch
 current_from_version -= 1
 shutil.rmtree(tmpdirname)
@@ -254,7 +254,7 @@ class testVariousModelVersions(TestCase):
 script_module_v5_path,
 tmp_backport_model_path,
 maximum_checked_in_model_version - 1)
-assert(success)
+assert success
 buf = io.StringIO()
 torch.utils.show_pickle.main(
@@ -266,7 +266,7 @@ class testVariousModelVersions(TestCase):
 acutal_result_clean = "".join(output.split())
 expect_result_clean = "".join(expected_result.split())
 isMatch = fnmatch.fnmatch(acutal_result_clean, expect_result_clean)
-assert(isMatch)
+assert isMatch
 # Load model v4 and run forward method
 mobile_module = _load_for_lite_interpreter(str(tmp_backport_model_path))
@@ -291,7 +291,7 @@ class testVariousModelVersions(TestCase):
 # Check version of the model v4 from backport
 bytesio = io.BytesIO(script_module_v4_buffer)
 backport_version = _get_model_bytecode_version(bytesio)
-assert(backport_version == maximum_checked_in_model_version - 1)
+assert backport_version == maximum_checked_in_model_version - 1
 # Load model v4 from backport and run forward method
 bytesio = io.BytesIO(script_module_v4_buffer)
@@ -306,8 +306,8 @@ class testVariousModelVersions(TestCase):
 # TODO update this to be more in the style of the above tests after a backport from 6 -> 5 exists
 script_module_v6 = pytorch_test_dir / "cpp" / "jit" / "script_module_v6.ptl"
 ops_v6 = _get_model_ops_and_info(script_module_v6)
-assert(ops_v6["aten::add.int"].num_schema_args == 2)
-assert(ops_v6["aten::add.Scalar"].num_schema_args == 2)
+assert ops_v6["aten::add.int"].num_schema_args == 2
+assert ops_v6["aten::add.Scalar"].num_schema_args == 2
 def test_get_mobile_model_contained_types(self):
 class MyTestModule(torch.nn.Module):
@@ -322,7 +322,7 @@ class testVariousModelVersions(TestCase):
 buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
 buffer.seek(0)
 type_list = _get_mobile_model_contained_types(buffer)
-assert(len(type_list) >= 0)
+assert len(type_list) >= 0
 if __name__ == '__main__':
 run_tests()


@@ -82,7 +82,7 @@ class TestLiteScriptModule(TestCase):
 buffer = io.BytesIO(exported_module)
 buffer.seek(0)
-assert(b"callstack_debug_map.pkl" in exported_module)
+assert b"callstack_debug_map.pkl" in exported_module
 mobile_module = _load_for_lite_interpreter(buffer)
 with self.assertRaisesRegex(RuntimeError, r"Module hierarchy:top\(B\)::<unknown>.A0\(A\)::forward.aten::mul"):


@@ -65,8 +65,8 @@ class TestDropoutNN(NNTestCase):
 o_ref = torch.dropout(x_ref, p, train)
 o.sum().backward()
 o_ref.sum().backward()
-assert(o.equal(o_ref))
-assert(x.grad.equal(x_ref.grad))
+assert o.equal(o_ref)
+assert x.grad.equal(x_ref.grad)
 def test_invalid_dropout_p(self):
 v = torch.ones(1)


@@ -166,7 +166,7 @@ class TestOptim(TestCase):
 pass
 elif constructor_accepts_maximize:
-def four_arg_constructor(weight, bias, maximize, foreach):
+def four_arg_constructor(weight, bias, maximize, foreach):  # noqa: F811
 self.assertFalse(foreach)
 return constructor(weight, bias, maximize)


@@ -2127,7 +2127,7 @@ class TestQuantizedOps(TestCase):
 quantized_out = torch.topk(qX, k, dim=dim, largest=larg, sorted=sort)
-assert(len(unquantized_out) == len(quantized_out))
+assert len(unquantized_out) == len(quantized_out)
 torch.testing.assert_close(quantized_out[0].dequantize(), unquantized_out[0])
 torch.testing.assert_close(quantized_out[1], unquantized_out[1])
@@ -3813,7 +3813,7 @@ class TestQuantizedLinear(TestCase):
 post_op = 'none'
 cases = itertools.product(batch_size_list, input_channels_list, output_channels_list,
 use_bias_list, use_multi_dim_input_list, use_channelwise_list)
-for batch_size, input_channels, output_channels, use_bias,\
+for batch_size, input_channels, output_channels, use_bias, \
 use_multi_dim_input, use_channelwise in cases:
 self._test_qlinear_impl(batch_size, input_channels, output_channels,
 use_bias, post_op, use_multi_dim_input, use_channelwise)
@@ -3830,7 +3830,7 @@ class TestQuantizedLinear(TestCase):
 post_op = 'relu'
 cases = itertools.product(batch_size_list, input_channels_list, output_channels_list,
 use_bias_list, use_multi_dim_input_list, use_channelwise_list)
-for batch_size, input_channels, output_channels, use_bias,\
+for batch_size, input_channels, output_channels, use_bias, \
 use_multi_dim_input, use_channelwise in cases:
 self._test_qlinear_impl(batch_size, input_channels, output_channels,
 use_bias, post_op, use_multi_dim_input, use_channelwise)
@@ -4101,7 +4101,7 @@ class TestQuantizedLinear(TestCase):
 qparams=hu.qparams(dtypes=torch.qint8)))
 @override_qengines
 def test_qlinear_qnnpack_free_memory_and_unpack(self, W):
-assert(qengine_is_qnnpack)
+assert qengine_is_qnnpack
 W, (W_scale, W_zp, torch_type) = W
 qlinear_prepack = torch.ops.quantized.linear_prepack
 qlinear_unpack = torch.ops.quantized.linear_unpack
@@ -4140,7 +4140,7 @@ class TestQuantizedLinear(TestCase):
 cases = itertools.product(batch_size_list, input_channels_list, output_channels_list,
 use_bias_list, use_multi_dim_input_list,
 use_channelwise_list, negative_slopes_list)
-for batch_size, input_channels, output_channels, use_bias,\
+for batch_size, input_channels, output_channels, use_bias, \
 use_multi_dim_input, use_channelwise, neg_slope in cases:
 self._test_qlinear_impl(batch_size, input_channels, output_channels,
 use_bias, post_op, use_multi_dim_input,
@@ -4159,7 +4159,7 @@ class TestQuantizedLinear(TestCase):
 cases = itertools.product(batch_size_list, input_channels_list,
 output_channels_list, use_bias_list,
 use_multi_dim_input_list, use_channelwise_list)
-for batch_size, input_channels, output_channels, use_bias,\
+for batch_size, input_channels, output_channels, use_bias, \
 use_multi_dim_input, use_channelwise in cases:
 self._test_qlinear_impl(batch_size, input_channels, output_channels,
 use_bias, post_op, use_multi_dim_input,


@@ -879,7 +879,7 @@ class TestFakeQuantizeOps(TestCase):
 Y_prime.backward(dout)
 np.testing.assert_allclose(
 dX.cpu().detach().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)
-assert(X.grad.dtype == float_type)
+assert X.grad.dtype == float_type
 def test_backward_per_channel_cachemask_cpu(self):


@@ -839,7 +839,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
 return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)
 if not use_relu:
-def relu_op(x):
+def relu_op(x):  # noqa: F811
 return x
 if freeze_bn:


@@ -4319,7 +4319,7 @@ Done""")
 ]
 for thread, ranges in threads:
 for range in ranges:
-assert(len(range) == 3)
+assert len(range) == 3
 events.append(
 FunctionEvent(
 id=range[2],
@@ -4340,7 +4340,7 @@ Done""")
 def get_children_ids(event):
 return [child.id for child in event.cpu_children]
-assert([get_children_ids(event) for event in events] == res)
+assert [get_children_ids(event) for event in events] == res
 def test_profiler_aggregation_table(self):
 """


@@ -30,7 +30,7 @@ def save_and_load(sm):
 def bundle_jpeg_image(img_tensor, quality):
 # turn NCHW to HWC
 if img_tensor.dim() == 4:
-assert(img_tensor.size(0) == 1)
+assert img_tensor.size(0) == 1
 img_tensor = img_tensor[0].permute(1, 2, 0)
 pixels = img_tensor.numpy()
 encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]


@@ -1388,7 +1388,7 @@ torch.cuda.synchronize()
 scaler.scale(l).backward()
 scaler.step(optimizer)
 scaler.update()
-assert(scaler._scale != float('inf') and scaler._scale != float('nan'))
+assert scaler._scale != float('inf') and scaler._scale != float('nan')
 def test_grad_scaling_clipping(self):
 def run(data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):


@@ -1676,8 +1676,8 @@ class TestFX(JitTestCase):
 x = torch.randn(5, 5, 224, 224)
 shape_prop.ShapeProp(traced).propagate(x)
-assert(all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
-for node in traced.graph.nodes))
+assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
+for node in traced.graph.nodes)
 x_channels_last = x.contiguous(memory_format=torch.channels_last)
 traced.to(memory_format=torch.channels_last)
@@ -1733,8 +1733,8 @@ class TestFX(JitTestCase):
 traced_3d = symbolic_trace(test_mod_3d)
 x_3d = torch.randn(5, 5, 224, 224, 15)
 shape_prop.ShapeProp(traced_3d).propagate(x_3d)
-assert(all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
-for node in traced_3d.graph.nodes))
+assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
+for node in traced_3d.graph.nodes)
 x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d)
 traced_3d.to(memory_format=torch.channels_last_3d)
@@ -2548,7 +2548,7 @@ class TestFX(JitTestCase):
 return self.a.b, self.a.b.t(), self.a.b.view(12)
 traced = torch.fx.symbolic_trace(Foo())
-assert(all('constant' not in node.target for node in traced.graph.nodes))
+assert all('constant' not in node.target for node in traced.graph.nodes)
 def test_single_default_arg(self):
 class M(torch.nn.Module):
@@ -2824,7 +2824,7 @@ class TestFX(JitTestCase):
 mod_false = symbolic_trace(mod, concrete_args={'y': False})
 self.assertEqual(mod_true(3, True), 6)
 print(mod_true.code)
-assert(any(i.target == torch._assert for i in mod_true.graph.nodes))
+assert any(i.target == torch._assert for i in mod_true.graph.nodes)
 with self.assertRaises(AssertionError):
 mod_true(3, False)
 self.assertEqual(mod_false(3, False), 3)
@@ -3607,17 +3607,17 @@ class TestFX(JitTestCase):
 self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out)
 assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
-assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args)
+assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args
 nf = symbolic_trace(nf)
 self.assertEqual(nf(val), orig_out)
 assert "tree_flatten_spec" not in nf.code
-assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == 1)
+assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == 1
 nf = symbolic_trace(nf, concrete_args={'x': inp})
 self.assertEqual(nf(val), orig_out)
 assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
-assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args)
+assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args
 pickled = pickle.dumps(nf)
 nf = pickle.loads(pickled)
@@ -3708,7 +3708,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
 return [('List', typing.List)]
 def process_inputs(self, *inputs):
-assert(len(inputs) == 1)
+assert len(inputs) == 1
 return inputs[0]
 def f(a, b):
@@ -3743,7 +3743,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
 return [('List', typing.List)]
 def process_inputs(self, *inputs):
-assert(len(inputs) == 1)
+assert len(inputs) == 1
 return inputs[0]
 def f(a, b):
@@ -3772,7 +3772,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
 return [('List', typing.List)]
 def process_inputs(self, *inputs):
-assert(len(inputs) == 1)
+assert len(inputs) == 1
 return inputs[0]
 def generate_output(self, output_args):


@@ -1875,8 +1875,8 @@ graph(%Ra, %Rb):
 o_ref = t(x_ref, p, train)
 o.sum().backward()
 o_ref.sum().backward()
-assert(o.equal(o_ref))
-assert(x.grad.equal(x_ref.grad))
+assert o.equal(o_ref)
+assert x.grad.equal(x_ref.grad)
 @slowTest
 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, 'Testing differentiable graph')
@@ -5746,7 +5746,7 @@ a")
 x = torch.zeros(3, 4, dtype=torch.long)
 graph = _propagate_shapes(fn.graph, (x,), False)
 default = torch.get_default_dtype()
-if(default == torch.float):
+if default == torch.float:
 FileCheck().check('Float(*, *, requires_grad=0, device=cpu) = aten::add').run(graph)
 else:
 FileCheck().check('Double(*, *, requires_grad=0, device=cpu) = aten::add').run(graph)
@@ -6169,7 +6169,7 @@ a")
 with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[Tensor, Tensor\]"): # noqa: W605
 @torch.jit.script
 def test_mult(x, y):
-return not(x, y)
+return not (x, y)
 def test_cast_int(x):
 # type: (int) -> int
@@ -8557,15 +8557,15 @@ dedent """
 fb.run("22 1 22")
 def _dtype_to_jit_name(self, dtype):
-if(dtype == torch.float32):
+if dtype == torch.float32:
 return "Float"
-if(dtype == torch.float64):
+if dtype == torch.float64:
 return "Double"
-if(dtype == torch.int64):
+if dtype == torch.int64:
 return "Long"
-if(dtype == torch.int32):
+if dtype == torch.int32:
 return "Int"
-if(dtype == torch.bool):
+if dtype == torch.bool:
 return "Bool"
 raise RuntimeError('dtype not handled')
@@ -8595,7 +8595,7 @@ dedent """
 for dtype in (dtypes + [None]):
 for tensor_type in dtypes:
 # a couple of ops aren't implemented for non-floating types
-if(not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point)):
+if not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point):
 if op in ['mean', 'softmax', 'log_softmax']:
 continue
 return_line = f"torch.tensor({tensor_data}, dtype={tensor_type}).{op}({str_args}dtype={dtype})"
@@ -12837,7 +12837,7 @@ dedent """
 for i, offset in enumerate(parsed_serialized_offsets):
 data = reader.get_record(str(offset))
-assert(data == buffers[i])
+assert data == buffers[i]
 def test_file_reader_no_memory_leak(self):
 num_iters = 10000


@@ -1032,7 +1032,7 @@ class MpsMemoryLeakCheck:
 if driver_mem_allocated > self.driver_before:
 driver_discrepancy = True
-if not(caching_allocator_discrepancy or driver_discrepancy):
+if not (caching_allocator_discrepancy or driver_discrepancy):
 # Leak was false positive, exit loop
 discrepancy_detected = False
 break
@@ -4122,7 +4122,7 @@ class TestMPS(TestCaseMPS):
 def helper(dtype):
 # Test with axis -1
 cpu_x = None
-if(dtype == torch.float32):
+if dtype == torch.float32:
 cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
 else:
 cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
@@ -4157,7 +4157,7 @@ class TestMPS(TestCaseMPS):
 def helper(dtype):
 # Test with axis -1
 cpu_x = None
-if(dtype == torch.float32):
+if dtype == torch.float32:
 cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
 else:
 cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)


@@ -31,7 +31,7 @@ from torch.nn.parallel._functions import Broadcast
 from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes, floating_types
 from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
 TEST_NUMPY, TEST_SCIPY, TEST_WITH_CROSSREF, TEST_WITH_ROCM, \
-download_file, get_function_arglist, load_tests, skipIfMps,\
+download_file, get_function_arglist, load_tests, skipIfMps, \
 IS_PPC, \
 parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
 skipIfTorchDynamo, IS_WINDOWS, gcIfJetson, set_default_dtype
@@ -6865,7 +6865,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
 elif weight_layout == torch.sparse_coo:
 module.weight = nn.Parameter(module.weight.to_sparse_coo())
 else:
-assert(0)
+raise AssertionError()
 inp = torch.randn(4, requires_grad=True, device=device)
 res = module(inp)


@@ -108,7 +108,7 @@ _ref_test_ops = tuple(
 )
 def reduction_dtype_filter(op):
-if(not isinstance(op, ReductionPythonRefInfo) or not op.supports_out
+if (not isinstance(op, ReductionPythonRefInfo) or not op.supports_out
 or torch.int16 not in op.dtypes):
 return False


@@ -217,7 +217,7 @@ class SubTensor(torch.Tensor):
 """
 @classmethod
 def __torch_function__(cls, func, types, args=(), kwargs=None):
-if(kwargs is None):
+if kwargs is None:
 kwargs = {}
 if func not in HANDLED_FUNCTIONS_SUB:
@@ -368,7 +368,7 @@ class TensorLike:
 """
 @classmethod
 def __torch_function__(cls, func, types, args=(), kwargs=None):
-if(kwargs is None):
+if kwargs is None:
 kwargs = {}
 if func not in HANDLED_FUNCTIONS_TENSOR_LIKE:


@@ -465,7 +465,7 @@ class TestPythonRegistration(TestCase):
 # check rest of functional_result is the mutated args
 mutated_args = [maybe_mutated_arg for maybe_mutated_arg, arg in zip(cloned_args, args)
-if not(maybe_mutated_arg is not None and arg is not None and torch.allclose(maybe_mutated_arg, arg))]
+if not (maybe_mutated_arg is not None and arg is not None and torch.allclose(maybe_mutated_arg, arg))]
 self.assertEqual(flat_functional_result[len(flat_mutable_result):], mutated_args)
# check that functionalization kernel was indeed registered # check that functionalization kernel was indeed registered


@@ -1353,7 +1353,7 @@ class TestReductions(TestCase):
 def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn,
 memory_format, compare_data=True, default_is_preserve=False):
-assert(memory_format == torch.channels_last or memory_format == torch.channels_last_3d)
+assert memory_format == torch.channels_last or memory_format == torch.channels_last_3d
 # xc is a channels last tensor
 xc = input_generator_fn(device)


@@ -3988,11 +3988,11 @@ class TestSparse(TestSparseBase):
 yield (make_diags((1, 3)), make_offsets([0]), (3, 2, 3)), "Output shape must be 2d"
 yield (make_diags((2, 3)), make_offsets([[1, 2], [0, 3]]), (3, 3)), "Offsets must be scalar or vector"
 yield (make_diags((3, 2, 3)), make_offsets([0, 1, 2]), (4, 4)), "Diagonals must be vector or matrix"
-yield (make_diags((3, 3)), make_offsets([-1, 0]), (3, 3)),\
+yield (make_diags((3, 3)), make_offsets([-1, 0]), (3, 3)), \
 r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
-yield (make_diags((5,)), make_offsets([0, 1, 2, 3, 4]), (3, 3)),\
+yield (make_diags((5,)), make_offsets([0, 1, 2, 3, 4]), (3, 3)), \
 r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
-yield (make_diags((2, 2)), make_offsets([-1, 0]), (2, 3), torch.strided),\
+yield (make_diags((2, 2)), make_offsets([-1, 0]), (2, 3), torch.strided), \
 r"Only output layouts \(\w+, \w+, \w+\) are supported, got \w+"
 yield (make_diags((2, 5)), make_offsets([0, 0]), (5, 5)), "Offset tensor contains duplicate values"
 yield (make_diags((1, 5)), make_offsets([0]).to(torch.int32), (5, 5)), r"Offset Tensor must have dtype Long but got \w+"


@@ -5210,7 +5210,7 @@ else:
 def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn,
 memory_format, compare_data=True, default_is_preserve=False):
-assert(memory_format == torch.channels_last or memory_format == torch.channels_last_3d)
+assert memory_format == torch.channels_last or memory_format == torch.channels_last_3d
 # xc is a channels last tensor
 xc = input_generator_fn(device)


@@ -191,7 +191,7 @@ def _load_global_deps() -> None:
 'nccl': 'libnccl.so.*[0-9]',
 'nvtx': 'libnvToolsExt.so.*[0-9]',
 }
-is_cuda_lib_err = [lib for lib in cuda_libs.values() if(lib.split('.')[0] in err.args[0])]
+is_cuda_lib_err = [lib for lib in cuda_libs.values() if lib.split('.')[0] in err.args[0]]
 if not is_cuda_lib_err:
 raise err
 for lib_folder, lib_name in cuda_libs.items():


@@ -75,7 +75,7 @@ def is_gpu_compute_event(event):
 def get_sorted_gpu_events(events):
 sorted_gpu_events = []
 for event in events:
-if(not is_gpu_compute_event(event)):
+if not is_gpu_compute_event(event):
 continue
 sorted_gpu_events.append(event)
 return sorted(sorted_gpu_events, key=lambda x: x["ts"])
@@ -102,7 +102,7 @@ def get_sorted_gpu_mm_conv_events(events):
 gpu_events = get_sorted_gpu_events(events)
 sorted_events = []
 for event in gpu_events:
-if(not is_mm_conv_event(event)):
+if not is_mm_conv_event(event):
 continue
 sorted_events.append(event)
 return sorted_events


@@ -55,7 +55,7 @@ class OutDtypeOperator(HigherOrderOperator):
 raise ValueError("out_dtype's first argument must be an OpOverload")
 if op._schema.is_mutable:
 raise ValueError("out_dtype's first argument needs to be a functional operator")
-if not(
+if not (
 len(op._schema.returns) == 1 and
 isinstance(op._schema.returns[0].type, torch.TensorType)
 ):


@@ -240,7 +240,7 @@ class QFunctional(torch.nn.Module):
 @classmethod
 def from_float(cls, mod):
-assert type(mod) == FloatFunctional,\
+assert type(mod) == FloatFunctional, \
 "QFunctional.from_float expects an instance of FloatFunctional"
 scale, zero_point = mod.activation_post_process.calculate_qparams() # type: ignore[operator]
 new_mod = QFunctional()


@@ -22,7 +22,7 @@ class LinearBlockSparsePattern:
 prev_col_block_size = 4
 def __init__(self, row_block_size=1, col_block_size=4):
-assert(_is_valid_linear_block_sparse_pattern(row_block_size, col_block_size))
+assert _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size)
 LinearBlockSparsePattern.rlock.acquire()
 LinearBlockSparsePattern.prev_row_block_size = LinearBlockSparsePattern.row_block_size
 LinearBlockSparsePattern.prev_col_block_size = LinearBlockSparsePattern.col_block_size


@@ -85,7 +85,7 @@ def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu):
 >>> lr = nn.LeakyReLU(0.01)
 >>> m2 = _fuse_linear_bn_leaky_relu(m1, b1, lr)
 """
-assert(linear.training == bn.training and bn.training == leaky_relu.training),\
+assert linear.training == bn.training and bn.training == leaky_relu.training, \
 "Linear, BN and LeakyReLU all must be in the same mode (train or eval)."
 if is_qat:


@@ -58,8 +58,8 @@ class APoTObserver(ObserverBase):
 alpha = torch.max(-self.min_val, self.max_val)
 # check for valid inputs of b, k
-assert(self.k and self.k != 0)
-assert(self.b % self.k == 0)
+assert self.k and self.k != 0
+assert self.b % self.k == 0
 # compute n and store as member variable
 self.n = self.b // self.k


@@ -274,7 +274,7 @@ class FixedQParamsFakeQuantize(FakeQuantize):
 # TODO: rename observer to observer_ctr
 def __init__(self, observer):
 super().__init__(observer=observer)
-assert type(self.activation_post_process) == FixedQParamsObserver,\
+assert type(self.activation_post_process) == FixedQParamsObserver, \
 f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
 self._observer_ctr = observer
 self.scale = self.activation_post_process.scale
@@ -322,7 +322,7 @@ class FusedMovingAvgObsFakeQuantize(FakeQuantize):
 **observer_kwargs: Any
 ) -> None:
 super().__init__(observer, quant_min, quant_max, **observer_kwargs)
-assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)),\
+assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)), \
 "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
 self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
 self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))


@@ -31,7 +31,7 @@ def fuse_conv_bn(is_qat, conv, bn):
 >>> # xdoctest: +SKIP
 >>> m2 = fuse_conv_bn(m1, b1)
 """
-assert(conv.training == bn.training),\
+assert conv.training == bn.training, \
 "Conv and BN both must be in the same mode (train or eval)."
 fused_module_class_map = {
@@ -71,7 +71,7 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
 >>> # xdoctest: +SKIP
 >>> m2 = fuse_conv_bn_relu(m1, b1, r1)
 """
-assert(conv.training == bn.training == relu.training),\
+assert conv.training == bn.training == relu.training, \
 "Conv and BN both must be in the same mode (train or eval)."
 fused_module : Optional[Type[nn.Sequential]] = None
 if is_qat:
@@ -118,14 +118,14 @@ def fuse_linear_bn(is_qat, linear, bn):
 >>> # xdoctest: +SKIP
 >>> m2 = fuse_linear_bn(m1, b1)
 """
-assert(linear.training == bn.training),\
+assert linear.training == bn.training, \
 "Linear and BN both must be in the same mode (train or eval)."
 if is_qat:
-assert bn.num_features == linear.out_features,\
+assert bn.num_features == linear.out_features, \
 "Output features of Linear must match num_features of BatchNorm1d"
 assert bn.affine, "Only support fusing BatchNorm1d with affine set to True"
-assert bn.track_running_stats,\
+assert bn.track_running_stats, \
 "Only support fusing BatchNorm1d with tracking_running_stats set to True"
 return nni.LinearBn1d(linear, bn)
 else:
@@ -147,7 +147,7 @@ def fuse_convtranspose_bn(is_qat, convt, bn):
 >>> # xdoctest: +SKIP
 >>> m2 = fuse_convtranspose_bn(m1, b1)
 """
-assert(convt.training == bn.training),\
+assert convt.training == bn.training, \
 "ConvTranspose and BN both must be in the same mode (train or eval)."
 if is_qat:


@@ -291,24 +291,24 @@ def get_op_node_and_weight_eq_obs(
 op_node = user
 break
-assert(op_node is not None)
+assert op_node is not None
 if op_node.op == 'call_module':
 # If the op_node is a nn.Linear layer, then it must have a
 # WeightEqualizationObserver configuration
 maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(model, "equalization_node_name_to_qconfig")
 assert maybe_equalization_node_name_to_config is not None
 equalization_node_name_to_qconfig: Dict[str, Any] = maybe_equalization_node_name_to_config # type: ignore[assignment]
-assert(equalization_node_name_to_qconfig.get(op_node.name, None) is not None)
+assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None
 weight_eq_obs = equalization_node_name_to_qconfig.get(op_node.name, None).weight()
-assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
+assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
 return op_node, weight_eq_obs
 elif op_node.op == 'call_function':
 weight_node = maybe_get_weight_eq_obs_node(op_node, modules)
 if weight_node is not None:
 weight_eq_obs = modules[str(weight_node.target)]
-assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
+assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
 return op_node, weight_eq_obs
 return None, None
@@ -316,10 +316,10 @@ def get_op_node_and_weight_eq_obs(
 def maybe_get_weight_eq_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> Optional[Node]:
 """ Gets the weight equalization observer node if it exists.
 """
-assert(op_node.op == 'call_function')
+assert op_node.op == 'call_function'
 for node_arg in op_node.args:
 if node_arg_is_weight(op_node, node_arg):
-assert(isinstance(node_arg, Node) and node_arg.op == 'call_module' and
+assert (isinstance(node_arg, Node) and node_arg.op == 'call_module' and
 isinstance(modules[str(node_arg.target)], _WeightEqualizationObserver))
 return node_arg
 return None
@@ -342,7 +342,7 @@ def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Op
 the following equalization observer for linear2.
 """
-assert(node_supports_equalization(node, modules))
+assert node_supports_equalization(node, modules)
 # Locate the following nn.ReLU or F.relu node if it exists
 maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU)
@@ -364,7 +364,7 @@ def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Op
 return None
 maybe_eq_obs = modules[str(maybe_eq_obs_node)]
-assert(isinstance(maybe_eq_obs, _InputEqualizationObserver))
+assert isinstance(maybe_eq_obs, _InputEqualizationObserver)
 return maybe_eq_obs
 def maybe_get_next_equalization_scale(node: Node, modules: Dict[str, nn.Module]) -> Optional[torch.Tensor]:
@@ -389,10 +389,10 @@ def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None:
 equalization observer
 """
 input_eq_obs = modules[str(node.target)]
-assert(isinstance(input_eq_obs, _InputEqualizationObserver))
+assert isinstance(input_eq_obs, _InputEqualizationObserver)
 input_quant_obs_node = node.args[0]
-assert(isinstance(input_quant_obs_node, Node))
+assert isinstance(input_quant_obs_node, Node)
 input_quant_obs = modules[str(input_quant_obs_node.target)]
 if not isinstance(input_quant_obs, ObserverBase):
@@ -426,12 +426,12 @@ def scale_weight_node(
 op_module = modules[str(node.target)][0] # type: ignore[index]
 else:
 op_module = modules[str(node.target)]
-assert(nn_module_supports_equalization(op_module) or custom_module_supports_equalization(op_module))
+assert nn_module_supports_equalization(op_module) or custom_module_supports_equalization(op_module)
 # Scale the weights for input-weight equalization
 # If the following layer needs to be equalized then we will multiply its scale
 weight = op_module.weight
-assert(isinstance(weight, torch.Tensor))
+assert isinstance(weight, torch.Tensor)
 # Scale the weights by the reciprocal of the equalization scale
 # Reshape the equalization scale so that we can multiply it to the weight along axis=1
@@ -453,7 +453,7 @@ def scale_weight_node(
 bias = op_module.bias
 if bias is None:
 return
-assert(isinstance(bias, torch.Tensor))
+assert isinstance(bias, torch.Tensor)
 # Reshape the equalization scale so that we can multiply it element-wise to the bias
 next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
@@ -487,14 +487,14 @@ def scale_weight_functional(
 weight_quant_obs_node = weight_eq_obs_node.args[0]
 if weight_quant_obs_node is None:
 return
-assert(isinstance(weight_quant_obs_node, Node) and
+assert (isinstance(weight_quant_obs_node, Node) and
 isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase))
 # Get the get_attr(weight) node
 weight_node = weight_quant_obs_node.args[0]
 if weight_node is None:
 return
-assert(isinstance(weight_node, Node) and weight_node.op == 'get_attr')
+assert isinstance(weight_node, Node) and weight_node.op == 'get_attr'
 weight_parent_name, weight_name = _parent_name(weight_node.target)
 weight = getattr(modules[weight_parent_name], weight_name)
@@ -515,7 +515,7 @@ def scale_weight_functional(
 scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)
 setattr(modules[weight_parent_name], weight_name, scaled_weight)
-assert(torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight))
+assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight)
 # Multiply the bias element wise by the next equalization scale
 bias_node = None
@@ -546,10 +546,10 @@ def clear_weight_quant_obs_node(op_node: Node, modules: Dict[str, nn.Module]) ->
 weight_quant_obs_node = weight_eq_obs_node.args[0]
 if weight_quant_obs_node is None:
 return
-assert(isinstance(weight_quant_obs_node, Node))
+assert isinstance(weight_quant_obs_node, Node)
 weight_quant_obs = modules[str(weight_quant_obs_node.target)]
-assert(isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase))
+assert isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)
 weight_quant_obs.reset_min_max_vals() # type: ignore[operator]
 def remove_node(model: GraphModule, node: Node, prev_node: Node):
@@ -578,7 +578,7 @@ def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module
 for node in model.graph.nodes:
 if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
 input_eq_obs = modules[node.target]
-assert(isinstance(input_eq_obs, _InputEqualizationObserver))
+assert isinstance(input_eq_obs, _InputEqualizationObserver)
 op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)
 if op_node is None or weight_eq_obs is None:
@@ -589,7 +589,7 @@ def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module
 # been created
 if fused_module_supports_equalization(modules[str(op_node.target)]):
 module = modules[str(op_node.target)][0] # type: ignore[index]
-assert(nn_module_supports_equalization(module))
+assert nn_module_supports_equalization(module)
 weight_eq_obs(module.weight)
 else:
 weight_eq_obs(modules[str(op_node.target)].weight)
@@ -696,7 +696,7 @@ def convert_eq_obs(
 elif weight_eq_obs_dict.get(node.name, None) is not None:
 weight_eq_obs = weight_eq_obs_dict.get(node.name)
-assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
+assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
 equalization_scale = weight_eq_obs.equalization_scale
if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1): if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
@ -712,7 +712,7 @@ def convert_eq_obs(
weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules) weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules)
if weight_eq_obs_node is None: if weight_eq_obs_node is None:
return return
assert(isinstance(modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver)) assert isinstance(modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver)
# Clear the quantization observer's min/max values so that they # Clear the quantization observer's min/max values so that they
# can get updated later based on the new scale values # can get updated later based on the new scale values


@ -487,14 +487,14 @@ def _match_static_pattern(
return SKIP_LOWERING_VALUE return SKIP_LOWERING_VALUE
q_node = node q_node = node
ref_node = q_node.args[0] ref_node = q_node.args[0]
assert(isinstance(ref_node, Node)) assert isinstance(ref_node, Node)
# Handle cases where the node is wrapped in a ReLU # Handle cases where the node is wrapped in a ReLU
if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or\ if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or\
(ref_node.op == "call_module" and type(_get_module(ref_node, modules)) == nn.ReLU): (ref_node.op == "call_module" and type(_get_module(ref_node, modules)) == nn.ReLU):
relu_node = ref_node relu_node = ref_node
ref_node = relu_node.args[0] ref_node = relu_node.args[0]
assert(isinstance(ref_node, Node)) assert isinstance(ref_node, Node)
else: else:
relu_node = None relu_node = None
if should_skip_lowering(ref_node, qconfig_map): if should_skip_lowering(ref_node, qconfig_map):
@ -515,7 +515,7 @@ def _match_static_pattern(
# (2) There must be at least one dequantize node # (2) There must be at least one dequantize node
matched_dequantize = False matched_dequantize = False
for i in dequantize_node_arg_indices: for i in dequantize_node_arg_indices:
assert i < len(ref_node.args),\ assert i < len(ref_node.args), \
f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}" f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}"
arg = ref_node.args[i] arg = ref_node.args[i]
if is_dequantize_node(arg): if is_dequantize_node(arg):
@ -557,7 +557,7 @@ def _match_static_pattern_with_two_inputs(
return SKIP_LOWERING_VALUE return SKIP_LOWERING_VALUE
q_node = node q_node = node
ref_node = q_node.args[0] ref_node = q_node.args[0]
assert(isinstance(ref_node, Node)) assert isinstance(ref_node, Node)
if should_skip_lowering(ref_node, qconfig_map): if should_skip_lowering(ref_node, qconfig_map):
return SKIP_LOWERING_VALUE return SKIP_LOWERING_VALUE
@ -599,13 +599,13 @@ def _lower_static_weighted_ref_module(
n, modules, qconfig_map, matching_modules, dequantize_node_arg_indices=[0]) # type: ignore[arg-type] n, modules, qconfig_map, matching_modules, dequantize_node_arg_indices=[0]) # type: ignore[arg-type]
if q_node is None: if q_node is None:
continue continue
assert(ref_node is not None) assert ref_node is not None
(_, scale_node, zero_point_node, _) = q_node.args (_, scale_node, zero_point_node, _) = q_node.args
ref_module = _get_module(ref_node, modules) ref_module = _get_module(ref_node, modules)
ref_class = type(ref_module) ref_class = type(ref_module)
assert(isinstance(scale_node, Node)) assert isinstance(scale_node, Node)
assert(isinstance(zero_point_node, Node)) assert isinstance(zero_point_node, Node)
assert(issubclass(ref_class, nn.Module)) assert issubclass(ref_class, nn.Module)
# Step 1: Change this pattern to use the corresponding quantized module # Step 1: Change this pattern to use the corresponding quantized module
# For fused modules, we also check whether the inner module is a reference module # For fused modules, we also check whether the inner module is a reference module
@ -624,9 +624,9 @@ def _lower_static_weighted_ref_module(
setattr(modules[parent_name], module_name, q_module) setattr(modules[parent_name], module_name, q_module)
# Step 2: Reroute around dq_node, and remove q_node and its args # Step 2: Reroute around dq_node, and remove q_node and its args
assert(len(ref_node.args) == 1) assert len(ref_node.args) == 1
dq_node = ref_node.args[0] dq_node = ref_node.args[0]
assert(isinstance(dq_node, Node)) assert isinstance(dq_node, Node)
ref_node.replace_input_with(dq_node, dq_node.args[0]) ref_node.replace_input_with(dq_node, dq_node.args[0])
q_node.replace_all_uses_with(ref_node) q_node.replace_all_uses_with(ref_node)
model.graph.erase_node(q_node) model.graph.erase_node(q_node)
@ -655,13 +655,13 @@ def _lower_static_weighted_ref_module_with_two_inputs(
n, modules, qconfig_map, matching_modules) # type: ignore[arg-type] n, modules, qconfig_map, matching_modules) # type: ignore[arg-type]
if q_node is None: if q_node is None:
continue continue
assert(ref_node is not None) assert ref_node is not None
(_, scale_node, zero_point_node, _) = q_node.args (_, scale_node, zero_point_node, _) = q_node.args
ref_module = _get_module(ref_node, modules) ref_module = _get_module(ref_node, modules)
ref_class = type(ref_module) ref_class = type(ref_module)
assert(isinstance(scale_node, Node)) assert isinstance(scale_node, Node)
assert(isinstance(zero_point_node, Node)) assert isinstance(zero_point_node, Node)
assert(issubclass(ref_class, nn.Module)) assert issubclass(ref_class, nn.Module)
# Step 1: Change this pattern to use the corresponding quantized module # Step 1: Change this pattern to use the corresponding quantized module
# For fused modules, we also check whether the inner module is a reference module # For fused modules, we also check whether the inner module is a reference module
@ -680,12 +680,12 @@ def _lower_static_weighted_ref_module_with_two_inputs(
setattr(modules[parent_name], module_name, q_module) setattr(modules[parent_name], module_name, q_module)
# Step 2: Reroute around dq_node, and remove q_node and its args # Step 2: Reroute around dq_node, and remove q_node and its args
assert(len(ref_node.args) == 2) assert len(ref_node.args) == 2
for arg in ref_node.args: for arg in ref_node.args:
if not is_dequantize_node(arg): if not is_dequantize_node(arg):
continue continue
dq_node = arg dq_node = arg
assert(isinstance(dq_node, Node)) assert isinstance(dq_node, Node)
ref_node.replace_input_with(dq_node, dq_node.args[0]) ref_node.replace_input_with(dq_node, dq_node.args[0])
q_node.replace_all_uses_with(ref_node) q_node.replace_all_uses_with(ref_node)
@ -778,14 +778,14 @@ def _lower_static_weighted_ref_functional(
n, modules, qconfig_map, matching_ops, dequantize_node_arg_indices=[0, 1]) n, modules, qconfig_map, matching_ops, dequantize_node_arg_indices=[0, 1])
if q_node is None: if q_node is None:
continue continue
assert(func_node is not None) assert func_node is not None
(_, output_scale_node, output_zp_node, _) = q_node.args (_, output_scale_node, output_zp_node, _) = q_node.args
(input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
assert(isinstance(output_zp_node, Node)) assert isinstance(output_zp_node, Node)
assert(isinstance(input_dq_node, Node)) assert isinstance(input_dq_node, Node)
assert(isinstance(weight_dq_node, Node)) assert isinstance(weight_dq_node, Node)
quantized_weight = weight_dq_node.args[0] quantized_weight = weight_dq_node.args[0]
assert(isinstance(quantized_weight, Node)) assert isinstance(quantized_weight, Node)
if quantized_weight.op != "call_function" or\ if quantized_weight.op != "call_function" or\
quantized_weight.target not in (torch.quantize_per_tensor, torch.quantize_per_channel): quantized_weight.target not in (torch.quantize_per_tensor, torch.quantize_per_channel):
continue continue
@ -966,7 +966,7 @@ def _lower_quantized_binary_op(
n, modules, qconfig_map, binary_ops_to_lower, dequantize_node_arg_indices=[0, 1]) n, modules, qconfig_map, binary_ops_to_lower, dequantize_node_arg_indices=[0, 1])
if q_node is None: if q_node is None:
continue continue
assert(bop_node is not None) assert bop_node is not None
(_, scale_node, zero_point_node, _) = q_node.args (_, scale_node, zero_point_node, _) = q_node.args
# Step 1: Remove dequant nodes # Step 1: Remove dequant nodes
@ -975,11 +975,11 @@ def _lower_quantized_binary_op(
if not is_dequantize_node(arg): if not is_dequantize_node(arg):
continue continue
dq_node = arg dq_node = arg
assert(isinstance(dq_node, Node)) assert isinstance(dq_node, Node)
dn_input = dq_node.args[0] dn_input = dq_node.args[0]
bop_node.replace_input_with(dq_node, dn_input) bop_node.replace_input_with(dq_node, dn_input)
num_dq_nodes += 1 num_dq_nodes += 1
assert(num_dq_nodes > 0) assert num_dq_nodes > 0
# Step 2: Swap binary op to quantized binary op # Step 2: Swap binary op to quantized binary op
assert bop_node.target in QBIN_OP_MAPPING assert bop_node.target in QBIN_OP_MAPPING


@ -956,7 +956,7 @@ def convert(
"in a future version. Please pass in a QConfigMapping instead.") "in a future version. Please pass in a QConfigMapping instead.")
qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
qconfig_mapping = copy.deepcopy(qconfig_mapping) qconfig_mapping = copy.deepcopy(qconfig_mapping)
assert(qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)) assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)
if isinstance(backend_config, Dict): if isinstance(backend_config, Dict):
warnings.warn( warnings.warn(


@ -1585,7 +1585,7 @@ def insert_observers_for_model(
# should resolve this inconsistency by inserting DeQuantStubs for all custom # should resolve this inconsistency by inserting DeQuantStubs for all custom
# modules, not just for LSTM. # modules, not just for LSTM.
_insert_dequant_stubs_for_custom_module_lstm_output(node, model, named_modules, model.graph) _insert_dequant_stubs_for_custom_module_lstm_output(node, model, named_modules, model.graph)
if(node.target not in custom_module_names_already_swapped): if node.target not in custom_module_names_already_swapped:
custom_module_names_already_swapped.add(node.target) custom_module_names_already_swapped.add(node.target)
_swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config) _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
else: else:
@ -1627,7 +1627,7 @@ def insert_observers_for_model(
_remove_output_observer(node, model, named_modules) _remove_output_observer(node, model, named_modules)
if qhandler is not None and qhandler.is_custom_module(): if qhandler is not None and qhandler.is_custom_module():
if(node.target not in custom_module_names_already_swapped): if node.target not in custom_module_names_already_swapped:
custom_module_names_already_swapped.add(node.target) custom_module_names_already_swapped.add(node.target)
_swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config) _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
@ -1771,8 +1771,8 @@ def prepare(
"in a future version. Please pass in a BackendConfig instead.") "in a future version. Please pass in a BackendConfig instead.")
backend_config = BackendConfig.from_dict(backend_config) backend_config = BackendConfig.from_dict(backend_config)
assert(isinstance(qconfig_mapping, QConfigMapping)) assert isinstance(qconfig_mapping, QConfigMapping)
assert(isinstance(_equalization_config, QConfigMapping)) assert isinstance(_equalization_config, QConfigMapping)
qconfig_mapping = copy.deepcopy(qconfig_mapping) qconfig_mapping = copy.deepcopy(qconfig_mapping)
_equalization_config = copy.deepcopy(_equalization_config) _equalization_config = copy.deepcopy(_equalization_config)


@ -26,7 +26,7 @@ def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None):
result = True result = True
st1 = args[0] st1 = args[0]
st2 = args[1] st2 = args[1]
if not(isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)): if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)):
raise TypeError(f'Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor') raise TypeError(f'Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor')
# Verify same PG # Verify same PG


@ -403,7 +403,7 @@ def generate_calc_product(constraint, counter):
for p in all_possibilities: for p in all_possibilities:
p = list(p) p = list(p)
# this tells us there is a dynamic variable # this tells us there is a dynamic variable
contains_dyn = not(all(constraint.op == op_neq for constraint in p)) contains_dyn = not all(constraint.op == op_neq for constraint in p)
if contains_dyn: if contains_dyn:
mid_var = [Dyn] mid_var = [Dyn]
total_constraints = lhs + mid_var + rhs total_constraints = lhs + mid_var + rhs


@ -311,7 +311,7 @@ try:
negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint) negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint)
return z3.And([transformed, transformed_condition_constraint]),\ return z3.And([transformed, transformed_condition_constraint]), \
z3.And([transformed, negation_transformed_condition_constraint]) z3.And([transformed, negation_transformed_condition_constraint])


@ -42,7 +42,7 @@ def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict
def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module): def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
assert(isinstance(node.target, str)) assert isinstance(node.target, str)
parent_name, name = _parent_name(node.target) parent_name, name = _parent_name(node.target)
modules[node.target] = new_module modules[node.target] = new_module
setattr(modules[parent_name], name, new_module) setattr(modules[parent_name], name, new_module)
@ -133,11 +133,11 @@ def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
old_modules: Dict[nn.Module, nn.Module] = {} old_modules: Dict[nn.Module, nn.Module] = {}
for node in nodes: for node in nodes:
if node.op == 'call_module': if node.op == 'call_module':
assert(isinstance(node.target, str)) assert isinstance(node.target, str)
cur_module = modules[node.target] cur_module = modules[node.target]
if type(cur_module) in mkldnn_map: if type(cur_module) in mkldnn_map:
new_module = mkldnn_map[type(cur_module)](cur_module, torch.float) new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
assert(isinstance(new_module, nn.Module)) assert isinstance(new_module, nn.Module)
old_modules[new_module] = copy.deepcopy(cur_module) old_modules[new_module] = copy.deepcopy(cur_module)
replace_node_module(node, modules, new_module) replace_node_module(node, modules, new_module)
return old_modules return old_modules
@ -149,7 +149,7 @@ def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modul
""" """
for node in nodes: for node in nodes:
if node.op == 'call_module': if node.op == 'call_module':
assert(isinstance(node.target, str)) assert (isinstance(node.target, str))
cur_module = modules[node.target] cur_module = modules[node.target]
if cur_module in old_modules: if cur_module in old_modules:
replace_node_module(node, modules, old_modules[cur_module]) replace_node_module(node, modules, old_modules[cur_module])
@ -220,7 +220,7 @@ class UnionFind:
par = self.parent[v] par = self.parent[v]
if v == par: if v == par:
return v return v
assert(par is not None) assert par is not None
self.parent[v] = self.find(par) self.parent[v] = self.find(par)
return cast(int, self.parent[v]) return cast(int, self.parent[v])
@ -294,8 +294,8 @@ def optimize_for_inference(
supports_mkldnn = MklSupport.YES supports_mkldnn = MklSupport.YES
sample_parameter = next(cur_module.parameters(), None) sample_parameter = next(cur_module.parameters(), None)
if sample_parameter is not None: if sample_parameter is not None:
assert(sample_parameter.dtype == torch.float), "this pass is only for torch.float modules" assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules"
assert(sample_parameter.device == torch.device('cpu')), "this pass is only for CPU modules" assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules"
elif node.op == 'call_function': elif node.op == 'call_function':
if node.target in mkldnn_supported: if node.target in mkldnn_supported:
supports_mkldnn = MklSupport.YES supports_mkldnn = MklSupport.YES
@ -360,14 +360,14 @@ def optimize_for_inference(
node.start_color = cur_idx node.start_color = cur_idx
uf.make_set(cur_idx) uf.make_set(cur_idx)
elif node.op == 'call_method' and node.target == 'to_dense': elif node.op == 'call_method' and node.target == 'to_dense':
assert(get_color(node.args[0]) is not None) assert get_color(node.args[0]) is not None
node.end_color = get_color(node.args[0]) node.end_color = get_color(node.args[0])
else: else:
cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None] cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None]
if len(cur_colors) == 0: if len(cur_colors) == 0:
continue continue
assert(not any(i is None for i in cur_colors)) assert not any(i is None for i in cur_colors)
cur_colors = sorted(cur_colors) cur_colors = sorted(cur_colors)
node.color = cur_colors[0] node.color = cur_colors[0]
for other_color in cur_colors[1:]: for other_color in cur_colors[1:]:


@ -1359,7 +1359,7 @@ class DimConstraints:
self.raise_inconsistencies() self.raise_inconsistencies()
# as long as there are symbols with equalities, solve for them # as long as there are symbols with equalities, solve for them
# NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols) # NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
while(self._symbols_with_equalities): while self._symbols_with_equalities:
s = self._symbols_with_equalities.pop() s = self._symbols_with_equalities.pop()
exprs = self._univariate_inequalities.pop(s) exprs = self._univariate_inequalities.pop(s)
solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s) solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)


@ -670,7 +670,7 @@ class _PyTreeCodeGen(CodeGen):
return out return out
if not isinstance(out, (list, tuple)): if not isinstance(out, (list, tuple)):
out = [out] out = [out]
assert(self.pytree_info.out_spec is not None) assert self.pytree_info.out_spec is not None
return pytree.tree_unflatten(out, self.pytree_info.out_spec) return pytree.tree_unflatten(out, self.pytree_info.out_spec)
def gen_fn_def(self, free_vars, maybe_return_annotation): def gen_fn_def(self, free_vars, maybe_return_annotation):


@ -82,7 +82,7 @@ def _type_repr(obj):
return obj.__qualname__ return obj.__qualname__
return f'{obj.__module__}.{obj.__qualname__}' return f'{obj.__module__}.{obj.__qualname__}'
if obj is ...: if obj is ...:
return('...') return '...'
if isinstance(obj, types.FunctionType): if isinstance(obj, types.FunctionType):
return obj.__name__ return obj.__name__
return repr(obj) return repr(obj)


@ -698,7 +698,7 @@ class _MinimizerBase:
if self.settings.traverse_method == "accumulate": if self.settings.traverse_method == "accumulate":
return self._accumulate_traverse(nodes) return self._accumulate_traverse(nodes)
if(self.settings.traverse_method == "skip"): if self.settings.traverse_method == "skip":
if (skip_nodes is None): if (skip_nodes is None):
raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.") raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.")
return self._skip_traverse(nodes, skip_nodes) return self._skip_traverse(nodes, skip_nodes)


@ -281,7 +281,7 @@ def split_module(
if assert_monotonically_increasing: if assert_monotonically_increasing:
pid = split_callback(node) pid = split_callback(node)
assert highest_partition <= pid,\ assert highest_partition <= pid, \
("autocast or set_grad_enabled require monotonically increasing partitions:" ("autocast or set_grad_enabled require monotonically increasing partitions:"
f"highest: {highest_partition}, this node's: {pid}") f"highest: {highest_partition}, this node's: {pid}")
highest_partition = pid highest_partition = pid


@ -510,7 +510,7 @@ class ParameterProxy(Proxy):
""" """
def __init__(self, tracer: TracerBase, node: Node, name, param): def __init__(self, tracer: TracerBase, node: Node, name, param):
super().__init__(node, tracer) super().__init__(node, tracer)
assert(isinstance(param, torch.nn.Parameter)) assert isinstance(param, torch.nn.Parameter)
self.param = param self.param = param
self.name = name self.name = name


@ -817,7 +817,7 @@ class StmtBuilder(Builder):
if is_torch_jit_ignore_context_manager(stmt): if is_torch_jit_ignore_context_manager(stmt):
if not _IS_ASTUNPARSE_INSTALLED: if not _IS_ASTUNPARSE_INSTALLED:
raise RuntimeError( raise RuntimeError(
"torch.jit._IgnoreContextManager requires installing Python library `astunparse`,\ "torch.jit._IgnoreContextManager requires installing Python library `astunparse`, \
please install it in your Python environment" please install it in your Python environment"
) )
assign_ast = build_ignore_context_manager(ctx, stmt) assign_ast = build_ignore_context_manager(ctx, stmt)


@ -2571,7 +2571,7 @@ new_module_tests = [
.dim_feedforward(8) .dim_feedforward(8)
.dropout(0.0) .dropout(0.0)
.activation(torch::kReLU)''', .activation(torch::kReLU)''',
input_fn=lambda:(torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)), input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)),
check_gradgrad=False, check_gradgrad=False,
desc='multilayer_coder', desc='multilayer_coder',
with_tf32=True, with_tf32=True,


@ -125,7 +125,7 @@ def _snr(x, x_hat):
signal, noise, SNR(in dB): Either floats or a nested list of floats signal, noise, SNR(in dB): Either floats or a nested list of floats
""" """
if isinstance(x, (list, tuple)): if isinstance(x, (list, tuple)):
assert(len(x) == len(x_hat)) assert len(x) == len(x_hat)
res = [] res = []
for idx in range(len(x)): for idx in range(len(x)):
res.append(_snr(x[idx], x_hat[idx])) res.append(_snr(x[idx], x_hat[idx]))


@ -1352,7 +1352,7 @@ def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
fn(*args, **kwargs) fn(*args, **kwargs)
return wrapper return wrapper
assert(isinstance(fn, type)) assert isinstance(fn, type)
if TEST_WITH_TORCHDYNAMO: # noqa: F821 if TEST_WITH_TORCHDYNAMO: # noqa: F821
fn.__unittest_skip__ = True fn.__unittest_skip__ = True
fn.__unittest_skip_why__ = msg fn.__unittest_skip_why__ = msg
@ -1374,7 +1374,7 @@ def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",
fn(*args, **kwargs) fn(*args, **kwargs)
return wrapper return wrapper
assert(isinstance(fn, type)) assert isinstance(fn, type)
if condition: if condition:
fn.__unittest_skip__ = True fn.__unittest_skip__ = True
fn.__unittest_skip_why__ = msg fn.__unittest_skip_why__ = msg
@ -1440,7 +1440,7 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe
fn(*args, **kwargs) fn(*args, **kwargs)
return wrapper return wrapper
assert(isinstance(fn, type)) assert isinstance(fn, type)
if GRAPH_EXECUTOR == ProfilingMode.LEGACY: if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
fn.__unittest_skip__ = True fn.__unittest_skip__ = True
fn.__unittest_skip_why__ = msg fn.__unittest_skip_why__ = msg
@ -2080,7 +2080,7 @@ class CudaMemoryLeakCheck:
if driver_mem_allocated > self.driver_befores[i]: if driver_mem_allocated > self.driver_befores[i]:
driver_discrepancy = True driver_discrepancy = True
if not(caching_allocator_discrepancy or driver_discrepancy): if not (caching_allocator_discrepancy or driver_discrepancy):
# Leak was false positive, exit loop # Leak was false positive, exit loop
discrepancy_detected = False discrepancy_detected = False
break break


@ -159,10 +159,10 @@ Example:
@st.composite @st.composite
def array_shapes(draw, min_dims=1, max_dims=None, min_side=1, max_side=None, max_numel=None): def array_shapes(draw, min_dims=1, max_dims=None, min_side=1, max_side=None, max_numel=None):
"""Return a strategy for array shapes (tuples of int >= 1).""" """Return a strategy for array shapes (tuples of int >= 1)."""
assert(min_dims < 32) assert min_dims < 32
if max_dims is None: if max_dims is None:
max_dims = min(min_dims + 2, 32) max_dims = min(min_dims + 2, 32)
assert(max_dims < 32) assert max_dims < 32
if max_side is None: if max_side is None:
max_side = min_side + 5 max_side = min_side + 5
candidate = st.lists(st.integers(min_side, max_side), min_size=min_dims, max_size=max_dims) candidate = st.lists(st.integers(min_side, max_side), min_size=min_dims, max_size=max_dims)
@ -334,7 +334,7 @@ def tensor_conv(
# Resolve the tensors # Resolve the tensors
if qparams is not None: if qparams is not None:
if isinstance(qparams, (list, tuple)): if isinstance(qparams, (list, tuple)):
assert(len(qparams) == 3), "Need 3 qparams for X, w, b" assert len(qparams) == 3, "Need 3 qparams for X, w, b"
else: else:
qparams = [qparams] * 3 qparams = [qparams] * 3


@ -331,7 +331,7 @@ def value_to_literal(value):
def get_call(method_name, func_type, args, kwargs): def get_call(method_name, func_type, args, kwargs):
kwargs_str = ', '.join([k + '=' + value_to_literal(v) for k, v in kwargs.items()]) kwargs_str = ', '.join([k + '=' + value_to_literal(v) for k, v in kwargs.items()])
self_arg = args[0] self_arg = args[0]
if(func_type == 'method'): if func_type == 'method':
args = args[1:] args = args[1:]
argument_str = ', '.join(args) argument_str = ', '.join(args)


@ -5,7 +5,7 @@ import warnings
from dataclasses import dataclass from dataclasses import dataclass
import torch import torch
import torchgen import torchgen
from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at,\ from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at, \
_pop_torch_dispatch_stack, _push_on_torch_dispatch_stack, DispatchKey _pop_torch_dispatch_stack, _push_on_torch_dispatch_stack, DispatchKey


@ -117,10 +117,10 @@ def bundle_inputs(
# Fortunately theres a function in _recursive that does exactly that conversion. # Fortunately theres a function in _recursive that does exactly that conversion.
cloned_module = wrap_cpp_module(clone) cloned_module = wrap_cpp_module(clone)
if isinstance(inputs, dict): if isinstance(inputs, dict):
assert(isinstance(info, dict) or info is None) assert isinstance(info, dict) or info is None
augment_many_model_functions_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info) augment_many_model_functions_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
else: else:
assert(isinstance(info, list) or info is None) assert isinstance(info, list) or info is None
augment_model_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info) augment_model_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
return cloned_module return cloned_module


@ -304,7 +304,7 @@ def get_cpu_info(run_lambda):
if get_platform() == 'linux': if get_platform() == 'linux':
rc, out, err = run_lambda('lscpu') rc, out, err = run_lambda('lscpu')
elif get_platform() == 'win32': elif get_platform() == 'win32':
rc, out, err = run_lambda('wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID,\ rc, out, err = run_lambda('wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \
CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE') CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE')
elif get_platform() == 'darwin': elif get_platform() == 'darwin':
rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")


@ -329,7 +329,7 @@ class FlopCounterMode(TorchDispatchMode):
class PushState(torch.autograd.Function): class PushState(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, *args): def forward(ctx, *args):
assert(self.parents[-1] == name) assert self.parents[-1] == name
self.parents.pop() self.parents.pop()
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args) args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
return args return args
@ -351,7 +351,7 @@ class FlopCounterMode(TorchDispatchMode):
@staticmethod @staticmethod
def backward(ctx, *grad_outs): def backward(ctx, *grad_outs):
assert(self.parents[-1] == name) assert self.parents[-1] == name
self.parents.pop() self.parents.pop()
return grad_outs return grad_outs


@ -136,9 +136,9 @@ class MkldnnBatchNorm(torch.jit.ScriptModule):
def __init__(self, dense_module): def __init__(self, dense_module):
super().__init__() super().__init__()
assert(not dense_module.training) assert not dense_module.training
assert(dense_module.track_running_stats) assert dense_module.track_running_stats
assert(dense_module.affine) assert dense_module.affine
if dense_module.momentum is None: if dense_module.momentum is None:
self.exponential_average_factor = 0.0 self.exponential_average_factor = 0.0