[BE]: Update flake8 to v6.1.0 and fix lints (#116591)
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parentheses, e.g.:
- `assert(a == b)` -> `assert a == b`
- `if(x > y or y < z):` -> `if x > y or y < z:`
- And `return('...')` -> `return '...'`
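For illustration, a minimal, hedged sketch of the kinds of rewrites described above (the function and variable names are made up for this example and are not code from the PR):

```python
# Illustrative only: before/after of the lint fixes listed above.
# `classify_pair` is a hypothetical example, not a function in PyTorch.

def classify_pair_before(a, b):
    assert(a == b)           # E275-style: no space after the assert keyword
    if(a > 0 or b > 0):      # extraneous parentheses around the condition
        return('positive')   # extraneous parentheses around the return value
    assert(0)                # bare failure sentinel

def classify_pair_after(a, b):
    assert a == b
    if a > 0 or b > 0:
        return 'positive'
    raise AssertionError()   # explicit error instead of assert(0)
```

With E275 removed from the ignore list in `.flake8` (see the diff below), the un-spaced `assert(...)` and `if(...)` forms above would now be flagged.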
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
This commit is contained in:
parent 09ee96b69d
commit 3fe437b24b
.flake8 | 2
@@ -8,8 +8,6 @@ max-line-length = 120
 # E501 is not flexible enough, we're using B950 instead
 ignore =
 E203,E305,E402,E501,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,
-# fix these lints in the future
-E275,
 # shebang has extra meaning in fbcode lints, so I think it's not worth trying
 # to line this up with executable bit
 EXE001,
@@ -38,7 +38,7 @@ init_command = [
 'python3',
 'tools/linter/adapters/pip_init.py',
 '--dry-run={{DRYRUN}}',
-'flake8==6.0.0',
+'flake8==6.1.0',
 'flake8-bugbear==23.3.23',
 'flake8-comprehensions==3.12.0',
 'flake8-executable==2.1.3',
@@ -46,8 +46,8 @@ init_command = [
 'flake8-pyi==23.3.1',
 'flake8-simplify==0.19.3',
 'mccabe==0.7.0',
-'pycodestyle==2.10.0',
-'pyflakes==3.0.1',
+'pycodestyle==2.11.1',
+'pyflakes==3.1.0',
 'torchfix==0.2.0',
 ]
@@ -981,7 +981,7 @@ class TestFPGMPruner(TestCase):
 pruner.prepare(model, config)
 pruner.enable_mask_update = True
 pruner.step()
-assert pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item() is not False,\
+assert pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item() is not False, \
 "do not prune the least-norm filter"

 # fusion step
@@ -9,7 +9,7 @@ import torch.distributed as c10d
 import torch.multiprocessing as mp
 from torch.testing._internal.common_distributed import \
 MultiProcessTestCase
-from torch.testing._internal.common_utils import load_tests,\
+from torch.testing._internal.common_utils import load_tests, \
 NO_MULTIPROCESSING_SPAWN

 # Torch distributed.nn is not available in windows
@@ -27,7 +27,7 @@ def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
 for idx in range(batch_size):
 flat_args, args_spec = pytree.tree_flatten(batched_args)
 flat_dims, dims_spec = pytree.tree_flatten(in_dims)
-assert(args_spec == dims_spec)
+assert args_spec == dims_spec
 new_args = [a.select(in_dim, idx) if in_dim is not None else a for a, in_dim in zip(flat_args, flat_dims)]
 out = op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values)
 flat_out, out_spec = pytree.tree_flatten(out)
@@ -45,9 +45,9 @@ def loop2(op, in_dims1, in_dims2, out_dim1, out_dim2, batch_size1, batch_size2,
 flat_args, args_spec = pytree.tree_flatten(batched_args)
 flat_dims1, dims_spec1 = pytree.tree_flatten(in_dims1)
 flat_dims2, dims_spec2 = pytree.tree_flatten(in_dims2)
-assert(args_spec == dims_spec1)
-assert(args_spec == dims_spec2)
-assert(len(flat_dims1) == len(flat_dims2))
+assert args_spec == dims_spec1
+assert args_spec == dims_spec2
+assert len(flat_dims1) == len(flat_dims2)
 for idx1 in range(batch_size1):
 out_split = []
 arg_split = [a.select(in_dim1, idx1) if in_dim1 is not None else a for a, in_dim1 in zip(flat_args, flat_dims1)]
@@ -7,7 +7,7 @@ from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraint
 from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator
 from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint
 from torch.fx.experimental.migrate_gradual_types.operation import op_precision, op_matching, op_consistency
-from torch.fx.experimental.migrate_gradual_types.transform_to_z3 import transform_all_constraints,\
+from torch.fx.experimental.migrate_gradual_types.transform_to_z3 import transform_all_constraints, \
 evaluate_conditional_with_constraints
 from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, D, z3_dyn
 from torch.fx.experimental.rewriter import RewritingTracer
@@ -154,9 +154,9 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
 output_ref = func(input0, input1)
 for i in range(2):
 output = jit_f(input0, input1)
-assert(output_ref[0].requires_grad == output[0].requires_grad)
-assert(output_ref[1][0].requires_grad == output[1][0].requires_grad)
-assert(output_ref[1][1].requires_grad == output[1][1].requires_grad)
+assert output_ref[0].requires_grad == output[0].requires_grad
+assert output_ref[1][0].requires_grad == output[1][0].requires_grad
+assert output_ref[1][1].requires_grad == output[1][1].requires_grad

 @unittest.skip("disable until we property handle tensor lists with undefined gradients")
 def test_differentiable_graph_ops_requires_grad(self):
@@ -35,7 +35,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):

 with warnings.catch_warnings(record=True) as w:
 self.checkModule(M(), (1,))
-assert(len(w) == 0)
+assert len(w) == 0

 def test_annotated_nonempty_container(self):
 class M(torch.nn.Module):
@@ -49,7 +49,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):

 with warnings.catch_warnings(record=True) as w:
 self.checkModule(M(), ([1, 2, 3],))
-assert(len(w) == 0)
+assert len(w) == 0

 def test_annotated_empty_tensor(self):
 class M(torch.nn.Module):
@@ -63,7 +63,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):

 with warnings.catch_warnings(record=True) as w:
 self.checkModule(M(), (torch.rand(2, 3),))
-assert(len(w) == 0)
+assert len(w) == 0

 def test_annotated_with_jit_attribute(self):
 class M(torch.nn.Module):
@@ -77,7 +77,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):

 with warnings.catch_warnings(record=True) as w:
 self.checkModule(M(), ([1, 2, 3],))
-assert(len(w) == 0)
+assert len(w) == 0

 def test_annotated_class_level_annotation_only(self):
 class M(torch.nn.Module):
@@ -94,7 +94,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):

 with warnings.catch_warnings(record=True) as w:
 self.checkModule(M(), ([1, 2, 3],))
-assert(len(w) == 0)
+assert len(w) == 0


 def test_annotated_class_level_annotation_and_init_annotation(self):
@@ -112,7 +112,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):

 with warnings.catch_warnings(record=True) as w:
 self.checkModule(M(), ([1, 2, 3],))
-assert(len(w) == 0)
+assert len(w) == 0

 def test_annotated_class_level_jit_annotation(self):
 class M(torch.nn.Module):
@@ -129,7 +129,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):

 with warnings.catch_warnings(record=True) as w:
 self.checkModule(M(), ([1, 2, 3],))
-assert(len(w) == 0)
+assert len(w) == 0

 def test_annotated_empty_list(self):
 class M(torch.nn.Module):
@@ -147,7 +147,7 @@ class testVariousModelVersions(TestCase):
 def test_get_model_bytecode_version(self):
 def check_model_version(model_path, expect_version):
 actual_version = _get_model_bytecode_version(model_path)
-assert(actual_version == expect_version)
+assert actual_version == expect_version
 for version, model_info in SCRIPT_MODULE_BYTECODE_PKL.items():
 model_path = pytorch_test_dir / "cpp" / "jit" / model_info["model_name"]
 check_model_version(model_path, version)
@@ -174,7 +174,7 @@ class testVariousModelVersions(TestCase):

 current_to_version = current_from_version - 1
 backport_success = _backport_for_mobile(input_model_path, tmp_output_model_path_backport, current_to_version)
-assert(backport_success)
+assert backport_success

 expect_bytecode_pkl = SCRIPT_MODULE_BYTECODE_PKL[current_to_version]["bytecode_pkl"]
@@ -187,7 +187,7 @@ class testVariousModelVersions(TestCase):
 acutal_result_clean = "".join(output.split())
 expect_result_clean = "".join(expect_bytecode_pkl.split())
 isMatch = fnmatch.fnmatch(acutal_result_clean, expect_result_clean)
-assert(isMatch)
+assert isMatch

 current_from_version -= 1
 shutil.rmtree(tmpdirname)
@@ -254,7 +254,7 @@ class testVariousModelVersions(TestCase):
 script_module_v5_path,
 tmp_backport_model_path,
 maximum_checked_in_model_version - 1)
-assert(success)
+assert success

 buf = io.StringIO()
 torch.utils.show_pickle.main(
@@ -266,7 +266,7 @@ class testVariousModelVersions(TestCase):
 acutal_result_clean = "".join(output.split())
 expect_result_clean = "".join(expected_result.split())
 isMatch = fnmatch.fnmatch(acutal_result_clean, expect_result_clean)
-assert(isMatch)
+assert isMatch

 # Load model v4 and run forward method
 mobile_module = _load_for_lite_interpreter(str(tmp_backport_model_path))
@@ -291,7 +291,7 @@ class testVariousModelVersions(TestCase):
 # Check version of the model v4 from backport
 bytesio = io.BytesIO(script_module_v4_buffer)
 backport_version = _get_model_bytecode_version(bytesio)
-assert(backport_version == maximum_checked_in_model_version - 1)
+assert backport_version == maximum_checked_in_model_version - 1

 # Load model v4 from backport and run forward method
 bytesio = io.BytesIO(script_module_v4_buffer)
@@ -306,8 +306,8 @@ class testVariousModelVersions(TestCase):
 # TODO update this to be more in the style of the above tests after a backport from 6 -> 5 exists
 script_module_v6 = pytorch_test_dir / "cpp" / "jit" / "script_module_v6.ptl"
 ops_v6 = _get_model_ops_and_info(script_module_v6)
-assert(ops_v6["aten::add.int"].num_schema_args == 2)
-assert(ops_v6["aten::add.Scalar"].num_schema_args == 2)
+assert ops_v6["aten::add.int"].num_schema_args == 2
+assert ops_v6["aten::add.Scalar"].num_schema_args == 2

 def test_get_mobile_model_contained_types(self):
 class MyTestModule(torch.nn.Module):
@@ -322,7 +322,7 @@ class testVariousModelVersions(TestCase):
 buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
 buffer.seek(0)
 type_list = _get_mobile_model_contained_types(buffer)
-assert(len(type_list) >= 0)
+assert len(type_list) >= 0

 if __name__ == '__main__':
 run_tests()
@@ -82,7 +82,7 @@ class TestLiteScriptModule(TestCase):
 buffer = io.BytesIO(exported_module)
 buffer.seek(0)

-assert(b"callstack_debug_map.pkl" in exported_module)
+assert b"callstack_debug_map.pkl" in exported_module

 mobile_module = _load_for_lite_interpreter(buffer)
 with self.assertRaisesRegex(RuntimeError, r"Module hierarchy:top\(B\)::<unknown>.A0\(A\)::forward.aten::mul"):
@@ -65,8 +65,8 @@ class TestDropoutNN(NNTestCase):
 o_ref = torch.dropout(x_ref, p, train)
 o.sum().backward()
 o_ref.sum().backward()
-assert(o.equal(o_ref))
-assert(x.grad.equal(x_ref.grad))
+assert o.equal(o_ref)
+assert x.grad.equal(x_ref.grad)

 def test_invalid_dropout_p(self):
 v = torch.ones(1)
@@ -166,7 +166,7 @@ class TestOptim(TestCase):
 pass
 elif constructor_accepts_maximize:

-def four_arg_constructor(weight, bias, maximize, foreach):
+def four_arg_constructor(weight, bias, maximize, foreach):  # noqa: F811
 self.assertFalse(foreach)
 return constructor(weight, bias, maximize)
@@ -2127,7 +2127,7 @@ class TestQuantizedOps(TestCase):

 quantized_out = torch.topk(qX, k, dim=dim, largest=larg, sorted=sort)

-assert(len(unquantized_out) == len(quantized_out))
+assert len(unquantized_out) == len(quantized_out)
 torch.testing.assert_close(quantized_out[0].dequantize(), unquantized_out[0])
 torch.testing.assert_close(quantized_out[1], unquantized_out[1])
@@ -3813,7 +3813,7 @@ class TestQuantizedLinear(TestCase):
 post_op = 'none'
 cases = itertools.product(batch_size_list, input_channels_list, output_channels_list,
 use_bias_list, use_multi_dim_input_list, use_channelwise_list)
-for batch_size, input_channels, output_channels, use_bias,\
+for batch_size, input_channels, output_channels, use_bias, \
 use_multi_dim_input, use_channelwise in cases:
 self._test_qlinear_impl(batch_size, input_channels, output_channels,
 use_bias, post_op, use_multi_dim_input, use_channelwise)
@@ -3830,7 +3830,7 @@ class TestQuantizedLinear(TestCase):
 post_op = 'relu'
 cases = itertools.product(batch_size_list, input_channels_list, output_channels_list,
 use_bias_list, use_multi_dim_input_list, use_channelwise_list)
-for batch_size, input_channels, output_channels, use_bias,\
+for batch_size, input_channels, output_channels, use_bias, \
 use_multi_dim_input, use_channelwise in cases:
 self._test_qlinear_impl(batch_size, input_channels, output_channels,
 use_bias, post_op, use_multi_dim_input, use_channelwise)
@@ -4101,7 +4101,7 @@ class TestQuantizedLinear(TestCase):
 qparams=hu.qparams(dtypes=torch.qint8)))
 @override_qengines
 def test_qlinear_qnnpack_free_memory_and_unpack(self, W):
-assert(qengine_is_qnnpack)
+assert qengine_is_qnnpack
 W, (W_scale, W_zp, torch_type) = W
 qlinear_prepack = torch.ops.quantized.linear_prepack
 qlinear_unpack = torch.ops.quantized.linear_unpack
@@ -4140,7 +4140,7 @@ class TestQuantizedLinear(TestCase):
 cases = itertools.product(batch_size_list, input_channels_list, output_channels_list,
 use_bias_list, use_multi_dim_input_list,
 use_channelwise_list, negative_slopes_list)
-for batch_size, input_channels, output_channels, use_bias,\
+for batch_size, input_channels, output_channels, use_bias, \
 use_multi_dim_input, use_channelwise, neg_slope in cases:
 self._test_qlinear_impl(batch_size, input_channels, output_channels,
 use_bias, post_op, use_multi_dim_input,
@@ -4159,7 +4159,7 @@ class TestQuantizedLinear(TestCase):
 cases = itertools.product(batch_size_list, input_channels_list,
 output_channels_list, use_bias_list,
 use_multi_dim_input_list, use_channelwise_list)
-for batch_size, input_channels, output_channels, use_bias,\
+for batch_size, input_channels, output_channels, use_bias, \
 use_multi_dim_input, use_channelwise in cases:
 self._test_qlinear_impl(batch_size, input_channels, output_channels,
 use_bias, post_op, use_multi_dim_input,
@@ -879,7 +879,7 @@ class TestFakeQuantizeOps(TestCase):
 Y_prime.backward(dout)
 np.testing.assert_allclose(
 dX.cpu().detach().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)
-assert(X.grad.dtype == float_type)
+assert X.grad.dtype == float_type


 def test_backward_per_channel_cachemask_cpu(self):
@@ -839,7 +839,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
 return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)

 if not use_relu:
-def relu_op(x):
+def relu_op(x):  # noqa: F811
 return x

 if freeze_bn:
@@ -4319,7 +4319,7 @@ Done""")
 ]
 for thread, ranges in threads:
 for range in ranges:
-assert(len(range) == 3)
+assert len(range) == 3
 events.append(
 FunctionEvent(
 id=range[2],
@@ -4340,7 +4340,7 @@ Done""")
 def get_children_ids(event):
 return [child.id for child in event.cpu_children]

-assert([get_children_ids(event) for event in events] == res)
+assert [get_children_ids(event) for event in events] == res

 def test_profiler_aggregation_table(self):
 """
@@ -30,7 +30,7 @@ def save_and_load(sm):
 def bundle_jpeg_image(img_tensor, quality):
 # turn NCHW to HWC
 if img_tensor.dim() == 4:
-assert(img_tensor.size(0) == 1)
+assert img_tensor.size(0) == 1
 img_tensor = img_tensor[0].permute(1, 2, 0)
 pixels = img_tensor.numpy()
 encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
@@ -1388,7 +1388,7 @@ torch.cuda.synchronize()
 scaler.scale(l).backward()
 scaler.step(optimizer)
 scaler.update()
-assert(scaler._scale != float('inf') and scaler._scale != float('nan'))
+assert scaler._scale != float('inf') and scaler._scale != float('nan')

 def test_grad_scaling_clipping(self):
 def run(data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
@@ -1676,8 +1676,8 @@ class TestFX(JitTestCase):
 x = torch.randn(5, 5, 224, 224)
 shape_prop.ShapeProp(traced).propagate(x)

-assert(all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
-for node in traced.graph.nodes))
+assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
+for node in traced.graph.nodes)

 x_channels_last = x.contiguous(memory_format=torch.channels_last)
 traced.to(memory_format=torch.channels_last)
@@ -1733,8 +1733,8 @@ class TestFX(JitTestCase):
 traced_3d = symbolic_trace(test_mod_3d)
 x_3d = torch.randn(5, 5, 224, 224, 15)
 shape_prop.ShapeProp(traced_3d).propagate(x_3d)
-assert(all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
-for node in traced_3d.graph.nodes))
+assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
+for node in traced_3d.graph.nodes)

 x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d)
 traced_3d.to(memory_format=torch.channels_last_3d)
@@ -2548,7 +2548,7 @@ class TestFX(JitTestCase):
 return self.a.b, self.a.b.t(), self.a.b.view(12)

 traced = torch.fx.symbolic_trace(Foo())
-assert(all('constant' not in node.target for node in traced.graph.nodes))
+assert all('constant' not in node.target for node in traced.graph.nodes)

 def test_single_default_arg(self):
 class M(torch.nn.Module):
@@ -2824,7 +2824,7 @@ class TestFX(JitTestCase):
 mod_false = symbolic_trace(mod, concrete_args={'y': False})
 self.assertEqual(mod_true(3, True), 6)
 print(mod_true.code)
-assert(any(i.target == torch._assert for i in mod_true.graph.nodes))
+assert any(i.target == torch._assert for i in mod_true.graph.nodes)
 with self.assertRaises(AssertionError):
 mod_true(3, False)
 self.assertEqual(mod_false(3, False), 3)
@@ -3607,17 +3607,17 @@ class TestFX(JitTestCase):
 self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out)

 assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
-assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args)
+assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args

 nf = symbolic_trace(nf)
 self.assertEqual(nf(val), orig_out)
 assert "tree_flatten_spec" not in nf.code
-assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == 1)
+assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == 1

 nf = symbolic_trace(nf, concrete_args={'x': inp})
 self.assertEqual(nf(val), orig_out)
 assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
-assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args)
+assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args

 pickled = pickle.dumps(nf)
 nf = pickle.loads(pickled)
@@ -3708,7 +3708,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
 return [('List', typing.List)]

 def process_inputs(self, *inputs):
-assert(len(inputs) == 1)
+assert len(inputs) == 1
 return inputs[0]

 def f(a, b):
@@ -3743,7 +3743,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
 return [('List', typing.List)]

 def process_inputs(self, *inputs):
-assert(len(inputs) == 1)
+assert len(inputs) == 1
 return inputs[0]

 def f(a, b):
@@ -3772,7 +3772,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
 return [('List', typing.List)]

 def process_inputs(self, *inputs):
-assert(len(inputs) == 1)
+assert len(inputs) == 1
 return inputs[0]

 def generate_output(self, output_args):
@@ -1875,8 +1875,8 @@ graph(%Ra, %Rb):
 o_ref = t(x_ref, p, train)
 o.sum().backward()
 o_ref.sum().backward()
-assert(o.equal(o_ref))
-assert(x.grad.equal(x_ref.grad))
+assert o.equal(o_ref)
+assert x.grad.equal(x_ref.grad)

 @slowTest
 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, 'Testing differentiable graph')
@@ -5746,7 +5746,7 @@ a")
 x = torch.zeros(3, 4, dtype=torch.long)
 graph = _propagate_shapes(fn.graph, (x,), False)
 default = torch.get_default_dtype()
-if(default == torch.float):
+if default == torch.float:
 FileCheck().check('Float(*, *, requires_grad=0, device=cpu) = aten::add').run(graph)
 else:
 FileCheck().check('Double(*, *, requires_grad=0, device=cpu) = aten::add').run(graph)
@@ -6169,7 +6169,7 @@ a")
 with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[Tensor, Tensor\]"):  # noqa: W605
 @torch.jit.script
 def test_mult(x, y):
-return not(x, y)
+return not (x, y)

 def test_cast_int(x):
 # type: (int) -> int
@@ -8557,15 +8557,15 @@ dedent """
 fb.run("22 1 22")

 def _dtype_to_jit_name(self, dtype):
-if(dtype == torch.float32):
+if dtype == torch.float32:
 return "Float"
-if(dtype == torch.float64):
+if dtype == torch.float64:
 return "Double"
-if(dtype == torch.int64):
+if dtype == torch.int64:
 return "Long"
-if(dtype == torch.int32):
+if dtype == torch.int32:
 return "Int"
-if(dtype == torch.bool):
+if dtype == torch.bool:
 return "Bool"
 raise RuntimeError('dtype not handled')
@@ -8595,7 +8595,7 @@ dedent """
 for dtype in (dtypes + [None]):
 for tensor_type in dtypes:
 # a couple of ops aren't implemented for non-floating types
-if(not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point)):
+if not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point):
 if op in ['mean', 'softmax', 'log_softmax']:
 continue
 return_line = f"torch.tensor({tensor_data}, dtype={tensor_type}).{op}({str_args}dtype={dtype})"
@@ -12837,7 +12837,7 @@ dedent """

 for i, offset in enumerate(parsed_serialized_offsets):
 data = reader.get_record(str(offset))
-assert(data == buffers[i])
+assert data == buffers[i]

 def test_file_reader_no_memory_leak(self):
 num_iters = 10000
@@ -1032,7 +1032,7 @@ class MpsMemoryLeakCheck:
 if driver_mem_allocated > self.driver_before:
 driver_discrepancy = True

-if not(caching_allocator_discrepancy or driver_discrepancy):
+if not (caching_allocator_discrepancy or driver_discrepancy):
 # Leak was false positive, exit loop
 discrepancy_detected = False
 break
@@ -4122,7 +4122,7 @@ class TestMPS(TestCaseMPS):
 def helper(dtype):
 # Test with axis -1
 cpu_x = None
-if(dtype == torch.float32):
+if dtype == torch.float32:
 cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
 else:
 cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
@@ -4157,7 +4157,7 @@ class TestMPS(TestCaseMPS):
 def helper(dtype):
 # Test with axis -1
 cpu_x = None
-if(dtype == torch.float32):
+if dtype == torch.float32:
 cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
 else:
 cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
@@ -31,7 +31,7 @@ from torch.nn.parallel._functions import Broadcast
 from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes, floating_types
 from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
 TEST_NUMPY, TEST_SCIPY, TEST_WITH_CROSSREF, TEST_WITH_ROCM, \
-download_file, get_function_arglist, load_tests, skipIfMps,\
+download_file, get_function_arglist, load_tests, skipIfMps, \
 IS_PPC, \
 parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
 skipIfTorchDynamo, IS_WINDOWS, gcIfJetson, set_default_dtype
@@ -6865,7 +6865,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
 elif weight_layout == torch.sparse_coo:
 module.weight = nn.Parameter(module.weight.to_sparse_coo())
 else:
-assert(0)
+raise AssertionError()

 inp = torch.randn(4, requires_grad=True, device=device)
 res = module(inp)
@@ -108,7 +108,7 @@ _ref_test_ops = tuple(
 )

 def reduction_dtype_filter(op):
-if(not isinstance(op, ReductionPythonRefInfo) or not op.supports_out
+if (not isinstance(op, ReductionPythonRefInfo) or not op.supports_out
 or torch.int16 not in op.dtypes):
 return False
@@ -217,7 +217,7 @@ class SubTensor(torch.Tensor):
 """
 @classmethod
 def __torch_function__(cls, func, types, args=(), kwargs=None):
-if(kwargs is None):
+if kwargs is None:
 kwargs = {}

 if func not in HANDLED_FUNCTIONS_SUB:
@@ -368,7 +368,7 @@ class TensorLike:
 """
 @classmethod
 def __torch_function__(cls, func, types, args=(), kwargs=None):
-if(kwargs is None):
+if kwargs is None:
 kwargs = {}

 if func not in HANDLED_FUNCTIONS_TENSOR_LIKE:
@@ -465,7 +465,7 @@ class TestPythonRegistration(TestCase):

 # check rest of functional_result is the mutated args
 mutated_args = [maybe_mutated_arg for maybe_mutated_arg, arg in zip(cloned_args, args)
-if not(maybe_mutated_arg is not None and arg is not None and torch.allclose(maybe_mutated_arg, arg))]
+if not (maybe_mutated_arg is not None and arg is not None and torch.allclose(maybe_mutated_arg, arg))]
 self.assertEqual(flat_functional_result[len(flat_mutable_result):], mutated_args)

 # check that functionalization kernel was indeed registered
@@ -1353,7 +1353,7 @@ class TestReductions(TestCase):
 def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn,
 memory_format, compare_data=True, default_is_preserve=False):

-assert(memory_format == torch.channels_last or memory_format == torch.channels_last_3d)
+assert memory_format == torch.channels_last or memory_format == torch.channels_last_3d

 # xc is a channels last tensor
 xc = input_generator_fn(device)
@@ -3988,11 +3988,11 @@ class TestSparse(TestSparseBase):
 yield (make_diags((1, 3)), make_offsets([0]), (3, 2, 3)), "Output shape must be 2d"
 yield (make_diags((2, 3)), make_offsets([[1, 2], [0, 3]]), (3, 3)), "Offsets must be scalar or vector"
 yield (make_diags((3, 2, 3)), make_offsets([0, 1, 2]), (4, 4)), "Diagonals must be vector or matrix"
-yield (make_diags((3, 3)), make_offsets([-1, 0]), (3, 3)),\
+yield (make_diags((3, 3)), make_offsets([-1, 0]), (3, 3)), \
 r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
-yield (make_diags((5,)), make_offsets([0, 1, 2, 3, 4]), (3, 3)),\
+yield (make_diags((5,)), make_offsets([0, 1, 2, 3, 4]), (3, 3)), \
 r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
-yield (make_diags((2, 2)), make_offsets([-1, 0]), (2, 3), torch.strided),\
+yield (make_diags((2, 2)), make_offsets([-1, 0]), (2, 3), torch.strided), \
 r"Only output layouts \(\w+, \w+, \w+\) are supported, got \w+"
 yield (make_diags((2, 5)), make_offsets([0, 0]), (5, 5)), "Offset tensor contains duplicate values"
 yield (make_diags((1, 5)), make_offsets([0]).to(torch.int32), (5, 5)), r"Offset Tensor must have dtype Long but got \w+"
@@ -5210,7 +5210,7 @@ else:
 def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn,
 memory_format, compare_data=True, default_is_preserve=False):

-assert(memory_format == torch.channels_last or memory_format == torch.channels_last_3d)
+assert memory_format == torch.channels_last or memory_format == torch.channels_last_3d

 # xc is a channels last tensor
 xc = input_generator_fn(device)
@@ -191,7 +191,7 @@ def _load_global_deps() -> None:
 'nccl': 'libnccl.so.*[0-9]',
 'nvtx': 'libnvToolsExt.so.*[0-9]',
 }
-is_cuda_lib_err = [lib for lib in cuda_libs.values() if(lib.split('.')[0] in err.args[0])]
+is_cuda_lib_err = [lib for lib in cuda_libs.values() if lib.split('.')[0] in err.args[0]]
 if not is_cuda_lib_err:
 raise err
 for lib_folder, lib_name in cuda_libs.items():
@@ -75,7 +75,7 @@ def is_gpu_compute_event(event):
 def get_sorted_gpu_events(events):
 sorted_gpu_events = []
 for event in events:
-if(not is_gpu_compute_event(event)):
+if not is_gpu_compute_event(event):
 continue
 sorted_gpu_events.append(event)
 return sorted(sorted_gpu_events, key=lambda x: x["ts"])
@@ -102,7 +102,7 @@ def get_sorted_gpu_mm_conv_events(events):
 gpu_events = get_sorted_gpu_events(events)
 sorted_events = []
 for event in gpu_events:
-if(not is_mm_conv_event(event)):
+if not is_mm_conv_event(event):
 continue
 sorted_events.append(event)
 return sorted_events
@@ -55,7 +55,7 @@ class OutDtypeOperator(HigherOrderOperator):
 raise ValueError("out_dtype's first argument must be an OpOverload")
 if op._schema.is_mutable:
 raise ValueError("out_dtype's first argument needs to be a functional operator")
-if not(
+if not (
 len(op._schema.returns) == 1 and
 isinstance(op._schema.returns[0].type, torch.TensorType)
 ):
@@ -240,7 +240,7 @@ class QFunctional(torch.nn.Module):

 @classmethod
 def from_float(cls, mod):
-assert type(mod) == FloatFunctional,\
+assert type(mod) == FloatFunctional, \
 "QFunctional.from_float expects an instance of FloatFunctional"
 scale, zero_point = mod.activation_post_process.calculate_qparams()  # type: ignore[operator]
 new_mod = QFunctional()
@@ -22,7 +22,7 @@ class LinearBlockSparsePattern:
 prev_col_block_size = 4

 def __init__(self, row_block_size=1, col_block_size=4):
-assert(_is_valid_linear_block_sparse_pattern(row_block_size, col_block_size))
+assert _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size)
 LinearBlockSparsePattern.rlock.acquire()
 LinearBlockSparsePattern.prev_row_block_size = LinearBlockSparsePattern.row_block_size
 LinearBlockSparsePattern.prev_col_block_size = LinearBlockSparsePattern.col_block_size
@@ -85,7 +85,7 @@ def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu):
 >>> lr = nn.LeakyReLU(0.01)
 >>> m2 = _fuse_linear_bn_leaky_relu(m1, b1, lr)
 """
-assert(linear.training == bn.training and bn.training == leaky_relu.training),\
+assert linear.training == bn.training and bn.training == leaky_relu.training, \
 "Linear, BN and LeakyReLU all must be in the same mode (train or eval)."

 if is_qat:
@@ -58,8 +58,8 @@ class APoTObserver(ObserverBase):
 alpha = torch.max(-self.min_val, self.max_val)

 # check for valid inputs of b, k
-assert(self.k and self.k != 0)
-assert(self.b % self.k == 0)
+assert self.k and self.k != 0
+assert self.b % self.k == 0

 # compute n and store as member variable
 self.n = self.b // self.k
@@ -274,7 +274,7 @@ class FixedQParamsFakeQuantize(FakeQuantize):
 # TODO: rename observer to observer_ctr
 def __init__(self, observer):
 super().__init__(observer=observer)
-assert type(self.activation_post_process) == FixedQParamsObserver,\
+assert type(self.activation_post_process) == FixedQParamsObserver, \
 f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
 self._observer_ctr = observer
 self.scale = self.activation_post_process.scale
@@ -322,7 +322,7 @@ class FusedMovingAvgObsFakeQuantize(FakeQuantize):
 **observer_kwargs: Any
 ) -> None:
 super().__init__(observer, quant_min, quant_max, **observer_kwargs)
-assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)),\
+assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)), \
 "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
 self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
 self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
@@ -31,7 +31,7 @@ def fuse_conv_bn(is_qat, conv, bn):
 >>> # xdoctest: +SKIP
 >>> m2 = fuse_conv_bn(m1, b1)
 """
-assert(conv.training == bn.training),\
+assert conv.training == bn.training, \
 "Conv and BN both must be in the same mode (train or eval)."

 fused_module_class_map = {
@@ -71,7 +71,7 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
 >>> # xdoctest: +SKIP
 >>> m2 = fuse_conv_bn_relu(m1, b1, r1)
 """
-assert(conv.training == bn.training == relu.training),\
+assert conv.training == bn.training == relu.training, \
 "Conv and BN both must be in the same mode (train or eval)."
 fused_module : Optional[Type[nn.Sequential]] = None
 if is_qat:
@@ -118,14 +118,14 @@ def fuse_linear_bn(is_qat, linear, bn):
 >>> # xdoctest: +SKIP
 >>> m2 = fuse_linear_bn(m1, b1)
 """
-assert(linear.training == bn.training),\
+assert linear.training == bn.training, \
 "Linear and BN both must be in the same mode (train or eval)."

 if is_qat:
-assert bn.num_features == linear.out_features,\
+assert bn.num_features == linear.out_features, \
 "Output features of Linear must match num_features of BatchNorm1d"
 assert bn.affine, "Only support fusing BatchNorm1d with affine set to True"
-assert bn.track_running_stats,\
+assert bn.track_running_stats, \
 "Only support fusing BatchNorm1d with tracking_running_stats set to True"
 return nni.LinearBn1d(linear, bn)
 else:
@@ -147,7 +147,7 @@ def fuse_convtranspose_bn(is_qat, convt, bn):
 >>> # xdoctest: +SKIP
 >>> m2 = fuse_convtranspose_bn(m1, b1)
 """
-assert(convt.training == bn.training),\
+assert convt.training == bn.training, \
 "ConvTranspose and BN both must be in the same mode (train or eval)."

 if is_qat:
@ -291,24 +291,24 @@ def get_op_node_and_weight_eq_obs(
|
||||||
op_node = user
|
op_node = user
|
||||||
break
|
break
|
||||||
|
|
||||||
assert(op_node is not None)
|
assert op_node is not None
|
||||||
if op_node.op == 'call_module':
|
if op_node.op == 'call_module':
|
||||||
# If the op_node is a nn.Linear layer, then it must have a
|
# If the op_node is a nn.Linear layer, then it must have a
|
||||||
# WeightEqualizationObserver configuration
|
# WeightEqualizationObserver configuration
|
||||||
maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(model, "equalization_node_name_to_qconfig")
|
maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(model, "equalization_node_name_to_qconfig")
|
||||||
assert maybe_equalization_node_name_to_config is not None
|
assert maybe_equalization_node_name_to_config is not None
|
||||||
equalization_node_name_to_qconfig: Dict[str, Any] = maybe_equalization_node_name_to_config # type: ignore[assignment]
|
equalization_node_name_to_qconfig: Dict[str, Any] = maybe_equalization_node_name_to_config # type: ignore[assignment]
|
||||||
assert(equalization_node_name_to_qconfig.get(op_node.name, None) is not None)
|
assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None
|
||||||
weight_eq_obs = equalization_node_name_to_qconfig.get(op_node.name, None).weight()
|
weight_eq_obs = equalization_node_name_to_qconfig.get(op_node.name, None).weight()
|
||||||
|
|
||||||
assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
|
assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
|
||||||
return op_node, weight_eq_obs
|
return op_node, weight_eq_obs
|
||||||
|
|
||||||
elif op_node.op == 'call_function':
|
elif op_node.op == 'call_function':
|
||||||
weight_node = maybe_get_weight_eq_obs_node(op_node, modules)
|
weight_node = maybe_get_weight_eq_obs_node(op_node, modules)
|
||||||
if weight_node is not None:
|
if weight_node is not None:
|
||||||
weight_eq_obs = modules[str(weight_node.target)]
|
weight_eq_obs = modules[str(weight_node.target)]
|
||||||
assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
|
assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
|
||||||
return op_node, weight_eq_obs
|
return op_node, weight_eq_obs
|
||||||
|
|
||||||
return None, None
|
return None, None
|
||||||
|
|
@ -316,10 +316,10 @@ def get_op_node_and_weight_eq_obs(
|
||||||
def maybe_get_weight_eq_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> Optional[Node]:
|
def maybe_get_weight_eq_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> Optional[Node]:
|
||||||
""" Gets the weight equalization observer node if it exists.
|
""" Gets the weight equalization observer node if it exists.
|
||||||
"""
|
"""
|
||||||
assert(op_node.op == 'call_function')
|
assert op_node.op == 'call_function'
|
||||||
for node_arg in op_node.args:
|
for node_arg in op_node.args:
|
||||||
if node_arg_is_weight(op_node, node_arg):
|
if node_arg_is_weight(op_node, node_arg):
|
||||||
assert(isinstance(node_arg, Node) and node_arg.op == 'call_module' and
|
assert (isinstance(node_arg, Node) and node_arg.op == 'call_module' and
|
||||||
isinstance(modules[str(node_arg.target)], _WeightEqualizationObserver))
|
isinstance(modules[str(node_arg.target)], _WeightEqualizationObserver))
|
||||||
return node_arg
|
return node_arg
|
||||||
return None
|
return None
|
||||||
|
|
@ -342,7 +342,7 @@ def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Op
|
||||||
the following equalization observer for linear2.
|
the following equalization observer for linear2.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert(node_supports_equalization(node, modules))
|
assert node_supports_equalization(node, modules)
|
||||||
|
|
||||||
# Locate the following nn.ReLU or F.relu node if it exists
|
# Locate the following nn.ReLU or F.relu node if it exists
|
||||||
maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU)
|
maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU)
|
||||||
|
|
@ -364,7 +364,7 @@ def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Op
|
||||||
return None
|
return None
|
||||||
|
|
||||||
maybe_eq_obs = modules[str(maybe_eq_obs_node)]
|
maybe_eq_obs = modules[str(maybe_eq_obs_node)]
|
||||||
assert(isinstance(maybe_eq_obs, _InputEqualizationObserver))
|
assert isinstance(maybe_eq_obs, _InputEqualizationObserver)
|
||||||
return maybe_eq_obs
|
return maybe_eq_obs
|
||||||
|
|
||||||
def maybe_get_next_equalization_scale(node: Node, modules: Dict[str, nn.Module]) -> Optional[torch.Tensor]:
|
def maybe_get_next_equalization_scale(node: Node, modules: Dict[str, nn.Module]) -> Optional[torch.Tensor]:
|
||||||
|
|
@ -389,10 +389,10 @@ def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None:
|
||||||
equalization observer
|
equalization observer
|
||||||
"""
|
"""
|
||||||
input_eq_obs = modules[str(node.target)]
|
input_eq_obs = modules[str(node.target)]
|
||||||
assert(isinstance(input_eq_obs, _InputEqualizationObserver))
|
assert isinstance(input_eq_obs, _InputEqualizationObserver)
|
||||||
|
|
||||||
input_quant_obs_node = node.args[0]
|
input_quant_obs_node = node.args[0]
|
||||||
assert(isinstance(input_quant_obs_node, Node))
|
assert isinstance(input_quant_obs_node, Node)
|
||||||
|
|
||||||
input_quant_obs = modules[str(input_quant_obs_node.target)]
|
input_quant_obs = modules[str(input_quant_obs_node.target)]
|
||||||
if not isinstance(input_quant_obs, ObserverBase):
|
if not isinstance(input_quant_obs, ObserverBase):
|
||||||
|
|
@ -426,12 +426,12 @@ def scale_weight_node(
|
||||||
op_module = modules[str(node.target)][0] # type: ignore[index]
|
op_module = modules[str(node.target)][0] # type: ignore[index]
|
||||||
else:
|
else:
|
||||||
op_module = modules[str(node.target)]
|
op_module = modules[str(node.target)]
|
||||||
assert(nn_module_supports_equalization(op_module) or custom_module_supports_equalization(op_module))
|
assert nn_module_supports_equalization(op_module) or custom_module_supports_equalization(op_module)
|
||||||
|
|
||||||
# Scale the weights for input-weight equalization
|
# Scale the weights for input-weight equalization
|
||||||
# If the following layer needs to be equalized then we will multiply its scale
|
# If the following layer needs to be equalized then we will multiply its scale
|
||||||
weight = op_module.weight
|
weight = op_module.weight
|
||||||
assert(isinstance(weight, torch.Tensor))
|
assert isinstance(weight, torch.Tensor)
|
||||||
|
|
||||||
# Scale the weights by the reciprocal of the equalization scale
|
# Scale the weights by the reciprocal of the equalization scale
|
||||||
# Reshape the equalization scale so that we can multiply it to the weight along axis=1
|
# Reshape the equalization scale so that we can multiply it to the weight along axis=1
|
||||||
|
|
@ -453,7 +453,7 @@ def scale_weight_node(
|
||||||
bias = op_module.bias
|
bias = op_module.bias
|
||||||
if bias is None:
|
if bias is None:
|
||||||
return
|
return
|
||||||
assert(isinstance(bias, torch.Tensor))
|
assert isinstance(bias, torch.Tensor)
|
||||||
|
|
||||||
# Reshape the equalization scale so that we can multiply it element-wise to the bias
|
# Reshape the equalization scale so that we can multiply it element-wise to the bias
|
||||||
next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
|
next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
|
||||||
|
|
@ -487,14 +487,14 @@ def scale_weight_functional(
|
||||||
weight_quant_obs_node = weight_eq_obs_node.args[0]
|
weight_quant_obs_node = weight_eq_obs_node.args[0]
|
||||||
if weight_quant_obs_node is None:
|
if weight_quant_obs_node is None:
|
||||||
return
|
return
|
||||||
assert(isinstance(weight_quant_obs_node, Node) and
|
assert (isinstance(weight_quant_obs_node, Node) and
|
||||||
isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase))
|
isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase))
|
||||||
|
|
||||||
# Get the get_attr(weight) node
|
# Get the get_attr(weight) node
|
||||||
weight_node = weight_quant_obs_node.args[0]
|
weight_node = weight_quant_obs_node.args[0]
|
||||||
if weight_node is None:
|
if weight_node is None:
|
||||||
return
|
return
|
||||||
assert(isinstance(weight_node, Node) and weight_node.op == 'get_attr')
|
assert isinstance(weight_node, Node) and weight_node.op == 'get_attr'
|
||||||
|
|
||||||
weight_parent_name, weight_name = _parent_name(weight_node.target)
|
weight_parent_name, weight_name = _parent_name(weight_node.target)
|
||||||
weight = getattr(modules[weight_parent_name], weight_name)
|
weight = getattr(modules[weight_parent_name], weight_name)
|
||||||
|
|
@ -515,7 +515,7 @@ def scale_weight_functional(
|
||||||
scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)
|
scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)
|
||||||
|
|
||||||
setattr(modules[weight_parent_name], weight_name, scaled_weight)
|
setattr(modules[weight_parent_name], weight_name, scaled_weight)
|
||||||
assert(torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight))
|
assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight)
|
||||||
|
|
||||||
# Multiply the bias element wise by the next equalization scale
|
# Multiply the bias element wise by the next equalization scale
|
||||||
bias_node = None
|
bias_node = None
|
||||||
|
|
@ -546,10 +546,10 @@ def clear_weight_quant_obs_node(op_node: Node, modules: Dict[str, nn.Module]) ->
|
||||||
weight_quant_obs_node = weight_eq_obs_node.args[0]
|
weight_quant_obs_node = weight_eq_obs_node.args[0]
|
||||||
if weight_quant_obs_node is None:
|
if weight_quant_obs_node is None:
|
||||||
return
|
return
|
||||||
assert(isinstance(weight_quant_obs_node, Node))
|
assert isinstance(weight_quant_obs_node, Node)
|
||||||
|
|
||||||
weight_quant_obs = modules[str(weight_quant_obs_node.target)]
|
weight_quant_obs = modules[str(weight_quant_obs_node.target)]
|
||||||
assert(isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase))
|
assert isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)
|
||||||
weight_quant_obs.reset_min_max_vals() # type: ignore[operator]
|
weight_quant_obs.reset_min_max_vals() # type: ignore[operator]
|
||||||
|
|
||||||
def remove_node(model: GraphModule, node: Node, prev_node: Node):
|
def remove_node(model: GraphModule, node: Node, prev_node: Node):
|
||||||
|
|
@ -578,7 +578,7 @@ def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module
|
||||||
for node in model.graph.nodes:
|
for node in model.graph.nodes:
|
||||||
if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
|
if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
|
||||||
input_eq_obs = modules[node.target]
|
input_eq_obs = modules[node.target]
|
||||||
assert(isinstance(input_eq_obs, _InputEqualizationObserver))
|
assert isinstance(input_eq_obs, _InputEqualizationObserver)
|
||||||
op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)
|
op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)
|
||||||
|
|
||||||
if op_node is None or weight_eq_obs is None:
|
if op_node is None or weight_eq_obs is None:
|
||||||
|
|
@ -589,7 +589,7 @@ def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module
|
||||||
# been created
|
# been created
|
||||||
if fused_module_supports_equalization(modules[str(op_node.target)]):
|
if fused_module_supports_equalization(modules[str(op_node.target)]):
|
||||||
module = modules[str(op_node.target)][0] # type: ignore[index]
|
module = modules[str(op_node.target)][0] # type: ignore[index]
|
||||||
assert(nn_module_supports_equalization(module))
|
assert nn_module_supports_equalization(module)
|
||||||
weight_eq_obs(module.weight)
|
weight_eq_obs(module.weight)
|
||||||
else:
|
else:
|
||||||
weight_eq_obs(modules[str(op_node.target)].weight)
|
weight_eq_obs(modules[str(op_node.target)].weight)
|
||||||
|
|
@ -696,7 +696,7 @@ def convert_eq_obs(
|
||||||
|
|
||||||
elif weight_eq_obs_dict.get(node.name, None) is not None:
|
elif weight_eq_obs_dict.get(node.name, None) is not None:
|
||||||
weight_eq_obs = weight_eq_obs_dict.get(node.name)
|
weight_eq_obs = weight_eq_obs_dict.get(node.name)
|
||||||
assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
|
assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
|
||||||
equalization_scale = weight_eq_obs.equalization_scale
|
equalization_scale = weight_eq_obs.equalization_scale
|
||||||
|
|
||||||
if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
|
if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
|
||||||
|
|
@ -712,7 +712,7 @@ def convert_eq_obs(
|
||||||
weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules)
|
weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules)
|
||||||
if weight_eq_obs_node is None:
|
if weight_eq_obs_node is None:
|
||||||
return
|
return
|
||||||
assert(isinstance(modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver))
|
assert isinstance(modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver)
|
||||||
|
|
||||||
# Clear the quantization observer's min/max values so that they
|
# Clear the quantization observer's min/max values so that they
|
||||||
# can get updated later based on the new scale values
|
# can get updated later based on the new scale values
|
||||||
|
|
|
||||||
|
|
@ -487,14 +487,14 @@ def _match_static_pattern(
|
||||||
return SKIP_LOWERING_VALUE
|
return SKIP_LOWERING_VALUE
|
||||||
q_node = node
|
q_node = node
|
||||||
ref_node = q_node.args[0]
|
ref_node = q_node.args[0]
|
||||||
assert(isinstance(ref_node, Node))
|
assert isinstance(ref_node, Node)
|
||||||
|
|
||||||
# Handle cases where the node is wrapped in a ReLU
|
# Handle cases where the node is wrapped in a ReLU
|
||||||
if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or\
|
if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or\
|
||||||
(ref_node.op == "call_module" and type(_get_module(ref_node, modules)) == nn.ReLU):
|
(ref_node.op == "call_module" and type(_get_module(ref_node, modules)) == nn.ReLU):
|
||||||
relu_node = ref_node
|
relu_node = ref_node
|
||||||
ref_node = relu_node.args[0]
|
ref_node = relu_node.args[0]
|
||||||
assert(isinstance(ref_node, Node))
|
assert isinstance(ref_node, Node)
|
||||||
else:
|
else:
|
||||||
relu_node = None
|
relu_node = None
|
||||||
if should_skip_lowering(ref_node, qconfig_map):
|
if should_skip_lowering(ref_node, qconfig_map):
|
||||||
|
|
@ -515,7 +515,7 @@ def _match_static_pattern(
|
||||||
# (2) There must be at least one dequantize node
|
# (2) There must be at least one dequantize node
|
||||||
matched_dequantize = False
|
matched_dequantize = False
|
||||||
for i in dequantize_node_arg_indices:
|
for i in dequantize_node_arg_indices:
|
||||||
assert i < len(ref_node.args),\
|
assert i < len(ref_node.args), \
|
||||||
f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}"
|
f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}"
|
||||||
arg = ref_node.args[i]
|
arg = ref_node.args[i]
|
||||||
if is_dequantize_node(arg):
|
if is_dequantize_node(arg):
|
||||||
|
|
@ -557,7 +557,7 @@ def _match_static_pattern_with_two_inputs(
|
||||||
return SKIP_LOWERING_VALUE
|
return SKIP_LOWERING_VALUE
|
||||||
q_node = node
|
q_node = node
|
||||||
ref_node = q_node.args[0]
|
ref_node = q_node.args[0]
|
||||||
assert(isinstance(ref_node, Node))
|
assert isinstance(ref_node, Node)
|
||||||
|
|
||||||
if should_skip_lowering(ref_node, qconfig_map):
|
if should_skip_lowering(ref_node, qconfig_map):
|
||||||
return SKIP_LOWERING_VALUE
|
return SKIP_LOWERING_VALUE
|
||||||
|
|
@ -599,13 +599,13 @@ def _lower_static_weighted_ref_module(
|
||||||
n, modules, qconfig_map, matching_modules, dequantize_node_arg_indices=[0]) # type: ignore[arg-type]
|
n, modules, qconfig_map, matching_modules, dequantize_node_arg_indices=[0]) # type: ignore[arg-type]
|
||||||
if q_node is None:
|
if q_node is None:
|
||||||
continue
|
continue
|
||||||
assert(ref_node is not None)
|
assert ref_node is not None
|
||||||
(_, scale_node, zero_point_node, _) = q_node.args
|
(_, scale_node, zero_point_node, _) = q_node.args
|
||||||
ref_module = _get_module(ref_node, modules)
|
ref_module = _get_module(ref_node, modules)
|
||||||
ref_class = type(ref_module)
|
ref_class = type(ref_module)
|
||||||
assert(isinstance(scale_node, Node))
|
assert isinstance(scale_node, Node)
|
||||||
assert(isinstance(zero_point_node, Node))
|
assert isinstance(zero_point_node, Node)
|
||||||
assert(issubclass(ref_class, nn.Module))
|
assert issubclass(ref_class, nn.Module)
|
||||||
|
|
||||||
# Step 1: Change this pattern to use the corresponding quantized module
|
# Step 1: Change this pattern to use the corresponding quantized module
|
||||||
# For fused modules, we also check whether the inner module is a reference module
|
# For fused modules, we also check whether the inner module is a reference module
|
||||||
|
|
@ -624,9 +624,9 @@ def _lower_static_weighted_ref_module(
|
||||||
setattr(modules[parent_name], module_name, q_module)
|
setattr(modules[parent_name], module_name, q_module)
|
||||||
|
|
||||||
# Step 2: Reroute around dq_node, and remove q_node and its args
|
# Step 2: Reroute around dq_node, and remove q_node and its args
|
||||||
assert(len(ref_node.args) == 1)
|
assert len(ref_node.args) == 1
|
||||||
dq_node = ref_node.args[0]
|
dq_node = ref_node.args[0]
|
||||||
assert(isinstance(dq_node, Node))
|
assert isinstance(dq_node, Node)
|
||||||
ref_node.replace_input_with(dq_node, dq_node.args[0])
|
ref_node.replace_input_with(dq_node, dq_node.args[0])
|
||||||
q_node.replace_all_uses_with(ref_node)
|
q_node.replace_all_uses_with(ref_node)
|
||||||
model.graph.erase_node(q_node)
|
model.graph.erase_node(q_node)
|
||||||
|
|
@ -655,13 +655,13 @@ def _lower_static_weighted_ref_module_with_two_inputs(
|
||||||
n, modules, qconfig_map, matching_modules) # type: ignore[arg-type]
|
n, modules, qconfig_map, matching_modules) # type: ignore[arg-type]
|
||||||
if q_node is None:
|
if q_node is None:
|
||||||
continue
|
continue
|
||||||
assert(ref_node is not None)
|
assert ref_node is not None
|
||||||
(_, scale_node, zero_point_node, _) = q_node.args
|
(_, scale_node, zero_point_node, _) = q_node.args
|
||||||
ref_module = _get_module(ref_node, modules)
|
ref_module = _get_module(ref_node, modules)
|
||||||
ref_class = type(ref_module)
|
ref_class = type(ref_module)
|
||||||
assert(isinstance(scale_node, Node))
|
assert isinstance(scale_node, Node)
|
||||||
assert(isinstance(zero_point_node, Node))
|
assert isinstance(zero_point_node, Node)
|
||||||
assert(issubclass(ref_class, nn.Module))
|
assert issubclass(ref_class, nn.Module)
|
||||||
|
|
||||||
# Step 1: Change this pattern to use the corresponding quantized module
|
# Step 1: Change this pattern to use the corresponding quantized module
|
||||||
# For fused modules, we also check whether the inner module is a reference module
|
# For fused modules, we also check whether the inner module is a reference module
|
||||||
|
|
@ -680,12 +680,12 @@ def _lower_static_weighted_ref_module_with_two_inputs(
|
||||||
setattr(modules[parent_name], module_name, q_module)
|
setattr(modules[parent_name], module_name, q_module)
|
||||||
|
|
||||||
# Step 2: Reroute around dq_node, and remove q_node and its args
|
# Step 2: Reroute around dq_node, and remove q_node and its args
|
||||||
assert(len(ref_node.args) == 2)
|
assert len(ref_node.args) == 2
|
||||||
for arg in ref_node.args:
|
for arg in ref_node.args:
|
||||||
if not is_dequantize_node(arg):
|
if not is_dequantize_node(arg):
|
||||||
continue
|
continue
|
||||||
dq_node = arg
|
dq_node = arg
|
||||||
assert(isinstance(dq_node, Node))
|
assert isinstance(dq_node, Node)
|
||||||
ref_node.replace_input_with(dq_node, dq_node.args[0])
|
ref_node.replace_input_with(dq_node, dq_node.args[0])
|
||||||
|
|
||||||
q_node.replace_all_uses_with(ref_node)
|
q_node.replace_all_uses_with(ref_node)
|
||||||
|
|
@ -778,14 +778,14 @@ def _lower_static_weighted_ref_functional(
|
||||||
n, modules, qconfig_map, matching_ops, dequantize_node_arg_indices=[0, 1])
|
n, modules, qconfig_map, matching_ops, dequantize_node_arg_indices=[0, 1])
|
||||||
if q_node is None:
|
if q_node is None:
|
||||||
continue
|
continue
|
||||||
assert(func_node is not None)
|
assert func_node is not None
|
||||||
(_, output_scale_node, output_zp_node, _) = q_node.args
|
(_, output_scale_node, output_zp_node, _) = q_node.args
|
||||||
(input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
|
(input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
|
||||||
assert(isinstance(output_zp_node, Node))
|
assert isinstance(output_zp_node, Node)
|
||||||
assert(isinstance(input_dq_node, Node))
|
assert isinstance(input_dq_node, Node)
|
||||||
assert(isinstance(weight_dq_node, Node))
|
assert isinstance(weight_dq_node, Node)
|
||||||
quantized_weight = weight_dq_node.args[0]
|
quantized_weight = weight_dq_node.args[0]
|
||||||
assert(isinstance(quantized_weight, Node))
|
assert isinstance(quantized_weight, Node)
|
||||||
if quantized_weight.op != "call_function" or\
|
if quantized_weight.op != "call_function" or\
|
||||||
quantized_weight.target not in (torch.quantize_per_tensor, torch.quantize_per_channel):
|
quantized_weight.target not in (torch.quantize_per_tensor, torch.quantize_per_channel):
|
||||||
continue
|
continue
|
||||||
|
|
@ -966,7 +966,7 @@ def _lower_quantized_binary_op(
|
||||||
n, modules, qconfig_map, binary_ops_to_lower, dequantize_node_arg_indices=[0, 1])
|
n, modules, qconfig_map, binary_ops_to_lower, dequantize_node_arg_indices=[0, 1])
|
||||||
if q_node is None:
|
if q_node is None:
|
||||||
continue
|
continue
|
||||||
assert(bop_node is not None)
|
assert bop_node is not None
|
||||||
(_, scale_node, zero_point_node, _) = q_node.args
|
(_, scale_node, zero_point_node, _) = q_node.args
|
||||||
|
|
||||||
# Step 1: Remove dequant nodes
|
# Step 1: Remove dequant nodes
|
||||||
|
|
@ -975,11 +975,11 @@ def _lower_quantized_binary_op(
|
||||||
if not is_dequantize_node(arg):
|
if not is_dequantize_node(arg):
|
||||||
continue
|
continue
|
||||||
dq_node = arg
|
dq_node = arg
|
||||||
assert(isinstance(dq_node, Node))
|
assert isinstance(dq_node, Node)
|
||||||
dn_input = dq_node.args[0]
|
dn_input = dq_node.args[0]
|
||||||
bop_node.replace_input_with(dq_node, dn_input)
|
bop_node.replace_input_with(dq_node, dn_input)
|
||||||
num_dq_nodes += 1
|
num_dq_nodes += 1
|
||||||
assert(num_dq_nodes > 0)
|
assert num_dq_nodes > 0
|
||||||
|
|
||||||
# Step 2: Swap binary op to quantized binary op
|
# Step 2: Swap binary op to quantized binary op
|
||||||
assert bop_node.target in QBIN_OP_MAPPING
|
assert bop_node.target in QBIN_OP_MAPPING
|
||||||
|
|
|
||||||
|
|
@ -956,7 +956,7 @@ def convert(
|
||||||
"in a future version. Please pass in a QConfigMapping instead.")
|
"in a future version. Please pass in a QConfigMapping instead.")
|
||||||
qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
|
qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
|
||||||
qconfig_mapping = copy.deepcopy(qconfig_mapping)
|
qconfig_mapping = copy.deepcopy(qconfig_mapping)
|
||||||
assert(qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping))
|
assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)
|
||||||
|
|
||||||
if isinstance(backend_config, Dict):
|
if isinstance(backend_config, Dict):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
|
|
||||||
|
|
@ -1585,7 +1585,7 @@ def insert_observers_for_model(
|
||||||
# should resolve this inconsistency by inserting DeQuantStubs for all custom
|
# should resolve this inconsistency by inserting DeQuantStubs for all custom
|
||||||
# modules, not just for LSTM.
|
# modules, not just for LSTM.
|
||||||
_insert_dequant_stubs_for_custom_module_lstm_output(node, model, named_modules, model.graph)
|
_insert_dequant_stubs_for_custom_module_lstm_output(node, model, named_modules, model.graph)
|
||||||
if(node.target not in custom_module_names_already_swapped):
|
if node.target not in custom_module_names_already_swapped:
|
||||||
custom_module_names_already_swapped.add(node.target)
|
custom_module_names_already_swapped.add(node.target)
|
||||||
_swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
|
_swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
|
||||||
else:
|
else:
|
||||||
|
|
@ -1627,7 +1627,7 @@ def insert_observers_for_model(
|
||||||
_remove_output_observer(node, model, named_modules)
|
_remove_output_observer(node, model, named_modules)
|
||||||
|
|
||||||
if qhandler is not None and qhandler.is_custom_module():
|
if qhandler is not None and qhandler.is_custom_module():
|
||||||
if(node.target not in custom_module_names_already_swapped):
|
if node.target not in custom_module_names_already_swapped:
|
||||||
custom_module_names_already_swapped.add(node.target)
|
custom_module_names_already_swapped.add(node.target)
|
||||||
_swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
|
_swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
|
||||||
|
|
||||||
|
|
@ -1771,8 +1771,8 @@ def prepare(
|
||||||
"in a future version. Please pass in a BackendConfig instead.")
|
"in a future version. Please pass in a BackendConfig instead.")
|
||||||
backend_config = BackendConfig.from_dict(backend_config)
|
backend_config = BackendConfig.from_dict(backend_config)
|
||||||
|
|
||||||
assert(isinstance(qconfig_mapping, QConfigMapping))
|
assert isinstance(qconfig_mapping, QConfigMapping)
|
||||||
assert(isinstance(_equalization_config, QConfigMapping))
|
assert isinstance(_equalization_config, QConfigMapping)
|
||||||
qconfig_mapping = copy.deepcopy(qconfig_mapping)
|
qconfig_mapping = copy.deepcopy(qconfig_mapping)
|
||||||
_equalization_config = copy.deepcopy(_equalization_config)
|
_equalization_config = copy.deepcopy(_equalization_config)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None):
|
||||||
result = True
|
result = True
|
||||||
st1 = args[0]
|
st1 = args[0]
|
||||||
st2 = args[1]
|
st2 = args[1]
|
||||||
if not(isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)):
|
if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)):
|
||||||
raise TypeError(f'Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor')
|
raise TypeError(f'Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor')
|
||||||
|
|
||||||
# Verify same PG
|
# Verify same PG
|
||||||
|
|
|
||||||
|
|
@ -403,7 +403,7 @@ def generate_calc_product(constraint, counter):
|
||||||
for p in all_possibilities:
|
for p in all_possibilities:
|
||||||
p = list(p)
|
p = list(p)
|
||||||
# this tells us there is a dynamic variable
|
# this tells us there is a dynamic variable
|
||||||
contains_dyn = not(all(constraint.op == op_neq for constraint in p))
|
contains_dyn = not all(constraint.op == op_neq for constraint in p)
|
||||||
if contains_dyn:
|
if contains_dyn:
|
||||||
mid_var = [Dyn]
|
mid_var = [Dyn]
|
||||||
total_constraints = lhs + mid_var + rhs
|
total_constraints = lhs + mid_var + rhs
|
||||||
|
|
|
||||||
|
|
@ -311,7 +311,7 @@ try:
|
||||||
|
|
||||||
negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint)
|
negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint)
|
||||||
|
|
||||||
return z3.And([transformed, transformed_condition_constraint]),\
|
return z3.And([transformed, transformed_condition_constraint]), \
|
||||||
z3.And([transformed, negation_transformed_condition_constraint])
|
z3.And([transformed, negation_transformed_condition_constraint])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict
|
||||||
|
|
||||||
|
|
||||||
def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
|
def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
|
||||||
assert(isinstance(node.target, str))
|
assert isinstance(node.target, str)
|
||||||
parent_name, name = _parent_name(node.target)
|
parent_name, name = _parent_name(node.target)
|
||||||
modules[node.target] = new_module
|
modules[node.target] = new_module
|
||||||
setattr(modules[parent_name], name, new_module)
|
setattr(modules[parent_name], name, new_module)
|
||||||
|
|
@ -133,11 +133,11 @@ def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
|
||||||
old_modules: Dict[nn.Module, nn.Module] = {}
|
old_modules: Dict[nn.Module, nn.Module] = {}
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
if node.op == 'call_module':
|
if node.op == 'call_module':
|
||||||
assert(isinstance(node.target, str))
|
assert isinstance(node.target, str)
|
||||||
cur_module = modules[node.target]
|
cur_module = modules[node.target]
|
||||||
if type(cur_module) in mkldnn_map:
|
if type(cur_module) in mkldnn_map:
|
||||||
new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
|
new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
|
||||||
assert(isinstance(new_module, nn.Module))
|
assert isinstance(new_module, nn.Module)
|
||||||
old_modules[new_module] = copy.deepcopy(cur_module)
|
old_modules[new_module] = copy.deepcopy(cur_module)
|
||||||
replace_node_module(node, modules, new_module)
|
replace_node_module(node, modules, new_module)
|
||||||
return old_modules
|
return old_modules
|
||||||
|
|
@ -149,7 +149,7 @@ def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modul
|
||||||
"""
|
"""
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
if node.op == 'call_module':
|
if node.op == 'call_module':
|
||||||
assert(isinstance(node.target, str))
|
assert (isinstance(node.target, str))
|
||||||
cur_module = modules[node.target]
|
cur_module = modules[node.target]
|
||||||
if cur_module in old_modules:
|
if cur_module in old_modules:
|
||||||
replace_node_module(node, modules, old_modules[cur_module])
|
replace_node_module(node, modules, old_modules[cur_module])
|
||||||
|
|
@ -220,7 +220,7 @@ class UnionFind:
|
||||||
par = self.parent[v]
|
par = self.parent[v]
|
||||||
if v == par:
|
if v == par:
|
||||||
return v
|
return v
|
||||||
assert(par is not None)
|
assert par is not None
|
||||||
self.parent[v] = self.find(par)
|
self.parent[v] = self.find(par)
|
||||||
return cast(int, self.parent[v])
|
return cast(int, self.parent[v])
|
||||||
|
|
||||||
|
|
@ -294,8 +294,8 @@ def optimize_for_inference(
|
||||||
supports_mkldnn = MklSupport.YES
|
supports_mkldnn = MklSupport.YES
|
||||||
sample_parameter = next(cur_module.parameters(), None)
|
sample_parameter = next(cur_module.parameters(), None)
|
||||||
if sample_parameter is not None:
|
if sample_parameter is not None:
|
||||||
assert(sample_parameter.dtype == torch.float), "this pass is only for torch.float modules"
|
assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules"
|
||||||
assert(sample_parameter.device == torch.device('cpu')), "this pass is only for CPU modules"
|
assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules"
|
||||||
elif node.op == 'call_function':
|
elif node.op == 'call_function':
|
||||||
if node.target in mkldnn_supported:
|
if node.target in mkldnn_supported:
|
||||||
supports_mkldnn = MklSupport.YES
|
supports_mkldnn = MklSupport.YES
|
||||||
|
|
@ -360,14 +360,14 @@ def optimize_for_inference(
|
||||||
node.start_color = cur_idx
|
node.start_color = cur_idx
|
||||||
uf.make_set(cur_idx)
|
uf.make_set(cur_idx)
|
||||||
elif node.op == 'call_method' and node.target == 'to_dense':
|
elif node.op == 'call_method' and node.target == 'to_dense':
|
||||||
assert(get_color(node.args[0]) is not None)
|
assert get_color(node.args[0]) is not None
|
||||||
node.end_color = get_color(node.args[0])
|
node.end_color = get_color(node.args[0])
|
||||||
else:
|
else:
|
||||||
cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None]
|
cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None]
|
||||||
|
|
||||||
if len(cur_colors) == 0:
|
if len(cur_colors) == 0:
|
||||||
continue
|
continue
|
||||||
assert(not any(i is None for i in cur_colors))
|
assert not any(i is None for i in cur_colors)
|
||||||
cur_colors = sorted(cur_colors)
|
cur_colors = sorted(cur_colors)
|
||||||
node.color = cur_colors[0]
|
node.color = cur_colors[0]
|
||||||
for other_color in cur_colors[1:]:
|
for other_color in cur_colors[1:]:
|
||||||
|
|
|
||||||
|
|
@ -1359,7 +1359,7 @@ class DimConstraints:
|
||||||
self.raise_inconsistencies()
|
self.raise_inconsistencies()
|
||||||
# as long as there are symbols with equalities, solve for them
|
# as long as there are symbols with equalities, solve for them
|
||||||
# NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
|
# NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
|
||||||
while(self._symbols_with_equalities):
|
while self._symbols_with_equalities:
|
||||||
s = self._symbols_with_equalities.pop()
|
s = self._symbols_with_equalities.pop()
|
||||||
exprs = self._univariate_inequalities.pop(s)
|
exprs = self._univariate_inequalities.pop(s)
|
||||||
solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
|
solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
|
||||||
|
|
|
||||||
|
|
@ -670,7 +670,7 @@ class _PyTreeCodeGen(CodeGen):
|
||||||
return out
|
return out
|
||||||
if not isinstance(out, (list, tuple)):
|
if not isinstance(out, (list, tuple)):
|
||||||
out = [out]
|
out = [out]
|
||||||
assert(self.pytree_info.out_spec is not None)
|
assert self.pytree_info.out_spec is not None
|
||||||
return pytree.tree_unflatten(out, self.pytree_info.out_spec)
|
return pytree.tree_unflatten(out, self.pytree_info.out_spec)
|
||||||
|
|
||||||
def gen_fn_def(self, free_vars, maybe_return_annotation):
|
def gen_fn_def(self, free_vars, maybe_return_annotation):
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,7 @@ def _type_repr(obj):
|
||||||
return obj.__qualname__
|
return obj.__qualname__
|
||||||
return f'{obj.__module__}.{obj.__qualname__}'
|
return f'{obj.__module__}.{obj.__qualname__}'
|
||||||
if obj is ...:
|
if obj is ...:
|
||||||
return('...')
|
return '...'
|
||||||
if isinstance(obj, types.FunctionType):
|
if isinstance(obj, types.FunctionType):
|
||||||
return obj.__name__
|
return obj.__name__
|
||||||
return repr(obj)
|
return repr(obj)
|
||||||
|
|
|
||||||
|
|
@ -698,7 +698,7 @@ class _MinimizerBase:
|
||||||
if self.settings.traverse_method == "accumulate":
|
if self.settings.traverse_method == "accumulate":
|
||||||
return self._accumulate_traverse(nodes)
|
return self._accumulate_traverse(nodes)
|
||||||
|
|
||||||
if(self.settings.traverse_method == "skip"):
|
if self.settings.traverse_method == "skip":
|
||||||
if (skip_nodes is None):
|
if (skip_nodes is None):
|
||||||
raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.")
|
raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.")
|
||||||
return self._skip_traverse(nodes, skip_nodes)
|
return self._skip_traverse(nodes, skip_nodes)
|
||||||
|
|
|
||||||
|
|
@ -281,7 +281,7 @@ def split_module(
|
||||||
|
|
||||||
if assert_monotonically_increasing:
|
if assert_monotonically_increasing:
|
||||||
pid = split_callback(node)
|
pid = split_callback(node)
|
||||||
assert highest_partition <= pid,\
|
assert highest_partition <= pid, \
|
||||||
("autocast or set_grad_enabled require monotonically increasing partitions:"
|
("autocast or set_grad_enabled require monotonically increasing partitions:"
|
||||||
f"highest: {highest_partition}, this node's: {pid}")
|
f"highest: {highest_partition}, this node's: {pid}")
|
||||||
highest_partition = pid
|
highest_partition = pid
|
||||||
|
|
|
||||||
|
|
@ -510,7 +510,7 @@ class ParameterProxy(Proxy):
|
||||||
"""
|
"""
|
||||||
def __init__(self, tracer: TracerBase, node: Node, name, param):
|
def __init__(self, tracer: TracerBase, node: Node, name, param):
|
||||||
super().__init__(node, tracer)
|
super().__init__(node, tracer)
|
||||||
assert(isinstance(param, torch.nn.Parameter))
|
assert isinstance(param, torch.nn.Parameter)
|
||||||
self.param = param
|
self.param = param
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -817,7 +817,7 @@ class StmtBuilder(Builder):
|
||||||
if is_torch_jit_ignore_context_manager(stmt):
|
if is_torch_jit_ignore_context_manager(stmt):
|
||||||
if not _IS_ASTUNPARSE_INSTALLED:
|
if not _IS_ASTUNPARSE_INSTALLED:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"torch.jit._IgnoreContextManager requires installing Python library `astunparse`,\
|
"torch.jit._IgnoreContextManager requires installing Python library `astunparse`, \
|
||||||
please install it in your Python environment"
|
please install it in your Python environment"
|
||||||
)
|
)
|
||||||
assign_ast = build_ignore_context_manager(ctx, stmt)
|
assign_ast = build_ignore_context_manager(ctx, stmt)
|
||||||
|
|
|
||||||
|
|
@ -2571,7 +2571,7 @@ new_module_tests = [
|
||||||
.dim_feedforward(8)
|
.dim_feedforward(8)
|
||||||
.dropout(0.0)
|
.dropout(0.0)
|
||||||
.activation(torch::kReLU)''',
|
.activation(torch::kReLU)''',
|
||||||
input_fn=lambda:(torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)),
|
input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)),
|
||||||
check_gradgrad=False,
|
check_gradgrad=False,
|
||||||
desc='multilayer_coder',
|
desc='multilayer_coder',
|
||||||
with_tf32=True,
|
with_tf32=True,
|
||||||
|
|
|
||||||
|
|
@ -125,7 +125,7 @@ def _snr(x, x_hat):
|
||||||
signal, noise, SNR(in dB): Either floats or a nested list of floats
|
signal, noise, SNR(in dB): Either floats or a nested list of floats
|
||||||
"""
|
"""
|
||||||
if isinstance(x, (list, tuple)):
|
if isinstance(x, (list, tuple)):
|
||||||
assert(len(x) == len(x_hat))
|
assert len(x) == len(x_hat)
|
||||||
res = []
|
res = []
|
||||||
for idx in range(len(x)):
|
for idx in range(len(x)):
|
||||||
res.append(_snr(x[idx], x_hat[idx]))
|
res.append(_snr(x[idx], x_hat[idx]))
|
||||||
|
|
|
||||||
|
|
@ -1352,7 +1352,7 @@ def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
|
||||||
fn(*args, **kwargs)
|
fn(*args, **kwargs)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
assert(isinstance(fn, type))
|
assert isinstance(fn, type)
|
||||||
if TEST_WITH_TORCHDYNAMO: # noqa: F821
|
if TEST_WITH_TORCHDYNAMO: # noqa: F821
|
||||||
fn.__unittest_skip__ = True
|
fn.__unittest_skip__ = True
|
||||||
fn.__unittest_skip_why__ = msg
|
fn.__unittest_skip_why__ = msg
|
||||||
|
|
@ -1374,7 +1374,7 @@ def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",
|
||||||
fn(*args, **kwargs)
|
fn(*args, **kwargs)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
assert(isinstance(fn, type))
|
assert isinstance(fn, type)
|
||||||
if condition:
|
if condition:
|
||||||
fn.__unittest_skip__ = True
|
fn.__unittest_skip__ = True
|
||||||
fn.__unittest_skip_why__ = msg
|
fn.__unittest_skip_why__ = msg
|
||||||
|
|
@ -1440,7 +1440,7 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe
|
||||||
fn(*args, **kwargs)
|
fn(*args, **kwargs)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
assert(isinstance(fn, type))
|
assert isinstance(fn, type)
|
||||||
if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
|
if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
|
||||||
fn.__unittest_skip__ = True
|
fn.__unittest_skip__ = True
|
||||||
fn.__unittest_skip_why__ = msg
|
fn.__unittest_skip_why__ = msg
|
||||||
|
|
@ -2080,7 +2080,7 @@ class CudaMemoryLeakCheck:
|
||||||
if driver_mem_allocated > self.driver_befores[i]:
|
if driver_mem_allocated > self.driver_befores[i]:
|
||||||
driver_discrepancy = True
|
driver_discrepancy = True
|
||||||
|
|
||||||
if not(caching_allocator_discrepancy or driver_discrepancy):
|
if not (caching_allocator_discrepancy or driver_discrepancy):
|
||||||
# Leak was false positive, exit loop
|
# Leak was false positive, exit loop
|
||||||
discrepancy_detected = False
|
discrepancy_detected = False
|
||||||
break
|
break
|
||||||
|
|
|
||||||
|
|
@ -159,10 +159,10 @@ Example:
|
||||||
@st.composite
|
@st.composite
|
||||||
def array_shapes(draw, min_dims=1, max_dims=None, min_side=1, max_side=None, max_numel=None):
|
def array_shapes(draw, min_dims=1, max_dims=None, min_side=1, max_side=None, max_numel=None):
|
||||||
"""Return a strategy for array shapes (tuples of int >= 1)."""
|
"""Return a strategy for array shapes (tuples of int >= 1)."""
|
||||||
assert(min_dims < 32)
|
assert min_dims < 32
|
||||||
if max_dims is None:
|
if max_dims is None:
|
||||||
max_dims = min(min_dims + 2, 32)
|
max_dims = min(min_dims + 2, 32)
|
||||||
assert(max_dims < 32)
|
assert max_dims < 32
|
||||||
if max_side is None:
|
if max_side is None:
|
||||||
max_side = min_side + 5
|
max_side = min_side + 5
|
||||||
candidate = st.lists(st.integers(min_side, max_side), min_size=min_dims, max_size=max_dims)
|
candidate = st.lists(st.integers(min_side, max_side), min_size=min_dims, max_size=max_dims)
|
||||||
|
|
@ -334,7 +334,7 @@ def tensor_conv(
|
||||||
# Resolve the tensors
|
# Resolve the tensors
|
||||||
if qparams is not None:
|
if qparams is not None:
|
||||||
if isinstance(qparams, (list, tuple)):
|
if isinstance(qparams, (list, tuple)):
|
||||||
assert(len(qparams) == 3), "Need 3 qparams for X, w, b"
|
assert len(qparams) == 3, "Need 3 qparams for X, w, b"
|
||||||
else:
|
else:
|
||||||
qparams = [qparams] * 3
|
qparams = [qparams] * 3
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -331,7 +331,7 @@ def value_to_literal(value):
|
||||||
def get_call(method_name, func_type, args, kwargs):
|
def get_call(method_name, func_type, args, kwargs):
|
||||||
kwargs_str = ', '.join([k + '=' + value_to_literal(v) for k, v in kwargs.items()])
|
kwargs_str = ', '.join([k + '=' + value_to_literal(v) for k, v in kwargs.items()])
|
||||||
self_arg = args[0]
|
self_arg = args[0]
|
||||||
if(func_type == 'method'):
|
if func_type == 'method':
|
||||||
args = args[1:]
|
args = args[1:]
|
||||||
|
|
||||||
argument_str = ', '.join(args)
|
argument_str = ', '.join(args)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import torch
|
import torch
|
||||||
import torchgen
|
import torchgen
|
||||||
from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at,\
|
from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at, \
|
||||||
_pop_torch_dispatch_stack, _push_on_torch_dispatch_stack, DispatchKey
|
_pop_torch_dispatch_stack, _push_on_torch_dispatch_stack, DispatchKey
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -117,10 +117,10 @@ def bundle_inputs(
|
||||||
# Fortunately theres a function in _recursive that does exactly that conversion.
|
# Fortunately theres a function in _recursive that does exactly that conversion.
|
||||||
cloned_module = wrap_cpp_module(clone)
|
cloned_module = wrap_cpp_module(clone)
|
||||||
if isinstance(inputs, dict):
|
if isinstance(inputs, dict):
|
||||||
assert(isinstance(info, dict) or info is None)
|
assert isinstance(info, dict) or info is None
|
||||||
augment_many_model_functions_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
|
augment_many_model_functions_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
|
||||||
else:
|
else:
|
||||||
assert(isinstance(info, list) or info is None)
|
assert isinstance(info, list) or info is None
|
||||||
augment_model_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
|
augment_model_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
|
||||||
return cloned_module
|
return cloned_module
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -304,7 +304,7 @@ def get_cpu_info(run_lambda):
|
||||||
if get_platform() == 'linux':
|
if get_platform() == 'linux':
|
||||||
rc, out, err = run_lambda('lscpu')
|
rc, out, err = run_lambda('lscpu')
|
||||||
elif get_platform() == 'win32':
|
elif get_platform() == 'win32':
|
||||||
rc, out, err = run_lambda('wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID,\
|
rc, out, err = run_lambda('wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \
|
||||||
CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE')
|
CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE')
|
||||||
elif get_platform() == 'darwin':
|
elif get_platform() == 'darwin':
|
||||||
rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
|
rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
|
||||||
|
|
|
||||||
|
|
@ -329,7 +329,7 @@ class FlopCounterMode(TorchDispatchMode):
|
||||||
class PushState(torch.autograd.Function):
|
class PushState(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, *args):
|
def forward(ctx, *args):
|
||||||
assert(self.parents[-1] == name)
|
assert self.parents[-1] == name
|
||||||
self.parents.pop()
|
self.parents.pop()
|
||||||
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
|
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
|
||||||
return args
|
return args
|
||||||
|
|
@ -351,7 +351,7 @@ class FlopCounterMode(TorchDispatchMode):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, *grad_outs):
|
def backward(ctx, *grad_outs):
|
||||||
assert(self.parents[-1] == name)
|
assert self.parents[-1] == name
|
||||||
self.parents.pop()
|
self.parents.pop()
|
||||||
return grad_outs
|
return grad_outs
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -136,9 +136,9 @@ class MkldnnBatchNorm(torch.jit.ScriptModule):
|
||||||
def __init__(self, dense_module):
|
def __init__(self, dense_module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert(not dense_module.training)
|
assert not dense_module.training
|
||||||
assert(dense_module.track_running_stats)
|
assert dense_module.track_running_stats
|
||||||
assert(dense_module.affine)
|
assert dense_module.affine
|
||||||
|
|
||||||
if dense_module.momentum is None:
|
if dense_module.momentum is None:
|
||||||
self.exponential_average_factor = 0.0
|
self.exponential_average_factor = 0.0
|
||||||
|
|
|
||||||