mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
No actual change, just remove variable contain Tensors from global scope (#143225)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143225 Approved by: https://github.com/ezyang
This commit is contained in:
parent
afa313e669
commit
792f1c47e9
|
|
@ -20,8 +20,8 @@ from torch.testing._internal.common_utils import (
|
|||
from torch.testing._internal.jit_metaprogramming_utils import (
|
||||
get_all_nn_module_tests,
|
||||
get_nn_functional_compiled_fn_and_inputs,
|
||||
get_nn_functional_tests,
|
||||
get_nn_mod_test_name,
|
||||
nn_functional_tests,
|
||||
try_get_nn_module_compiled_mod_and_inputs,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import enable_profiling_mode, JitTestCase
|
||||
|
|
@ -70,7 +70,7 @@ class TestComplexity(JitTestCase):
|
|||
def test_generated_functional_tests(self):
|
||||
with enable_profiling_mode():
|
||||
stats = [("Name", "Ifs/Loops", "non-tensor ops")]
|
||||
for test in nn_functional_tests:
|
||||
for test in get_nn_functional_tests():
|
||||
test_name = test[0]
|
||||
|
||||
fn, inputs = get_nn_functional_compiled_fn_and_inputs(*test)
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ if not common.IS_ARM64:
|
|||
(sample_module.module_tests, common_nn.NewModuleTest),
|
||||
(sample_functional.functional_tests, common_nn.NewModuleTest),
|
||||
(common_nn.module_tests, common_nn.NewModuleTest),
|
||||
(common_nn.new_module_tests, common_nn.NewModuleTest),
|
||||
(common_nn.get_new_module_tests(), common_nn.NewModuleTest),
|
||||
(common_nn.criterion_tests, common_nn.CriterionTest),
|
||||
]:
|
||||
for test_params_dict in test_params_dicts:
|
||||
|
|
|
|||
|
|
@ -25,7 +25,11 @@ from torch.testing._internal.common_device_type import (
|
|||
)
|
||||
from torch.testing._internal.common_methods_invocations import op_db, SampleInput
|
||||
from torch.testing._internal.common_modules import module_db, modules
|
||||
from torch.testing._internal.common_nn import module_tests, new_module_tests, TestBase
|
||||
from torch.testing._internal.common_nn import (
|
||||
get_new_module_tests,
|
||||
module_tests,
|
||||
TestBase,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
freeze_rng_state,
|
||||
make_tensor,
|
||||
|
|
@ -1011,7 +1015,7 @@ def filter_supported_tests(t):
|
|||
# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
|
||||
# These currently use the legacy nn tests
|
||||
supported_tests = [
|
||||
t for t in module_tests + new_module_tests if filter_supported_tests(t)
|
||||
t for t in module_tests + get_new_module_tests() if filter_supported_tests(t)
|
||||
]
|
||||
for test_param in supported_tests:
|
||||
if "constructor" not in test_param:
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ from torch.testing._internal.common_device_type import (
|
|||
ops,
|
||||
)
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.common_nn import module_tests, new_module_tests
|
||||
from torch.testing._internal.common_nn import module_tests, get_new_module_tests
|
||||
from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
import torch.utils._pytree as pytree
|
||||
|
|
@ -1006,7 +1006,7 @@ terrible spacing
|
|||
Exhaustively test `Node.normalized_arguments` on all standard
|
||||
torch.nn Module classes
|
||||
"""
|
||||
for test_params in module_tests + new_module_tests:
|
||||
for test_params in module_tests + get_new_module_tests():
|
||||
if "constructor" not in test_params:
|
||||
constructor = getattr(torch.nn, test_params["module_name"])
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -107,10 +107,10 @@ from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, dis
|
|||
from torch.testing._internal.jit_metaprogramming_utils import (
|
||||
get_script_args,
|
||||
create_input, unpack_variables,
|
||||
additional_module_tests, EXCLUDE_SCRIPT_MODULES,
|
||||
get_all_nn_module_tests, EXCLUDE_SCRIPT_MODULES,
|
||||
get_nn_module_name_from_kwargs, get_nn_mod_test_name, script_method_template)
|
||||
|
||||
from torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests
|
||||
from torch.testing._internal.common_nn import criterion_tests
|
||||
|
||||
# For testing truediv in python 2
|
||||
from torch.testing._internal.test_module.future_div import div_int_future, div_float_future
|
||||
|
|
@ -16247,7 +16247,7 @@ class TestProducerVersion(TestCase):
|
|||
# issue gh-32561
|
||||
self.assertTrue(torch.__version__.startswith(torch.onnx.producer_version))
|
||||
|
||||
for test in module_tests + new_module_tests + additional_module_tests:
|
||||
for test in get_all_nn_module_tests():
|
||||
add_nn_module_test(**test)
|
||||
|
||||
for test in criterion_tests:
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ from torch.testing._internal.common_utils import freeze_rng_state, run_tests, Te
|
|||
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION
|
||||
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
|
||||
module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
|
||||
ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
|
||||
ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
|
||||
from torch.testing._internal.common_device_type import dtypesIfMPS, instantiate_device_type_tests, dtypes, \
|
||||
dtypesIfCUDA, precisionOverride, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
|
||||
skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \
|
||||
|
|
@ -7332,7 +7332,7 @@ def add_test(test, decorator=None):
|
|||
else:
|
||||
add(cuda_test_name, with_tf32_off)
|
||||
|
||||
for test_params in module_tests + new_module_tests:
|
||||
for test_params in module_tests + get_new_module_tests():
|
||||
# TODO: CUDA is not implemented yet
|
||||
if 'constructor' not in test_params:
|
||||
name = test_params.pop('module_name')
|
||||
|
|
|
|||
|
|
@ -19,8 +19,6 @@ from torch.fx import Graph, GraphModule, Node
|
|||
from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns
|
||||
|
||||
from .utils import (
|
||||
_conv1d_bn_example_inputs,
|
||||
_conv2d_bn_example_inputs,
|
||||
_get_aten_graph_module_for_pattern,
|
||||
_is_bn_node,
|
||||
_is_conv_or_conv_transpose_node,
|
||||
|
|
@ -35,27 +33,6 @@ if TYPE_CHECKING:
|
|||
__all__ = [] # type: ignore[var-annotated]
|
||||
|
||||
|
||||
# Example inputs for quantized and folded conv-bn1d patterns used in convert
|
||||
_quantized_conv1d_bn_example_inputs = (
|
||||
torch.randn(1, 1, 3), # x
|
||||
torch.randn(1, 1, 1), # conv_weight
|
||||
torch.randn(1), # bn_weight
|
||||
torch.randn(1), # bn_bias
|
||||
torch.randn(1), # bn_running_mean
|
||||
torch.randn(1), # bn_running_var
|
||||
)
|
||||
|
||||
# Example inputs for quantized and folded conv-bn2d patterns used in convert
|
||||
_quantized_conv2d_bn_example_inputs = (
|
||||
torch.randn(1, 1, 3, 3), # x
|
||||
torch.randn(1, 1, 1, 1), # conv_weight
|
||||
torch.randn(1), # bn_weight
|
||||
torch.randn(1), # bn_bias
|
||||
torch.randn(1), # bn_running_mean
|
||||
torch.randn(1), # bn_running_var
|
||||
)
|
||||
|
||||
|
||||
def _get_quantized_conv_bn_example_inputs_kwargs(
|
||||
is_per_channel: bool,
|
||||
has_bias: bool,
|
||||
|
|
@ -631,6 +608,28 @@ def _update_special_qspecs_after_replacement(
|
|||
|
||||
|
||||
def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
|
||||
# Example inputs for conv-bn1d patterns
|
||||
_conv1d_bn_example_inputs = (
|
||||
torch.randn(1, 1, 3), # x
|
||||
torch.randn(1, 1, 1), # conv_weight
|
||||
torch.randn(1), # conv_bias
|
||||
torch.randn(1), # bn_weight
|
||||
torch.randn(1), # bn_bias
|
||||
torch.randn(1), # bn_running_mean
|
||||
torch.randn(1), # bn_running_var
|
||||
)
|
||||
|
||||
# Example inputs for conv-bn2d patterns
|
||||
_conv2d_bn_example_inputs = (
|
||||
torch.randn(1, 1, 3, 3), # x
|
||||
torch.randn(1, 1, 1, 1), # conv_weight
|
||||
torch.randn(1), # conv_bias
|
||||
torch.randn(1), # bn_weight
|
||||
torch.randn(1), # bn_bias
|
||||
torch.randn(1), # bn_running_mean
|
||||
torch.randn(1), # bn_running_var
|
||||
)
|
||||
|
||||
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
|
||||
if not has_bn:
|
||||
return m
|
||||
|
|
@ -859,6 +858,26 @@ def _copy_over_q_dq_args(original_node: Node, replacement_node: Node):
|
|||
|
||||
|
||||
def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
|
||||
# Example inputs for quantized and folded conv-bn1d patterns used in convert
|
||||
_quantized_conv1d_bn_example_inputs = (
|
||||
torch.randn(1, 1, 3), # x
|
||||
torch.randn(1, 1, 1), # conv_weight
|
||||
torch.randn(1), # bn_weight
|
||||
torch.randn(1), # bn_bias
|
||||
torch.randn(1), # bn_running_mean
|
||||
torch.randn(1), # bn_running_var
|
||||
)
|
||||
|
||||
# Example inputs for quantized and folded conv-bn2d patterns used in convert
|
||||
_quantized_conv2d_bn_example_inputs = (
|
||||
torch.randn(1, 1, 3, 3), # x
|
||||
torch.randn(1, 1, 1, 1), # conv_weight
|
||||
torch.randn(1), # bn_weight
|
||||
torch.randn(1), # bn_bias
|
||||
torch.randn(1), # bn_running_mean
|
||||
torch.randn(1), # bn_running_var
|
||||
)
|
||||
|
||||
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
|
||||
if not has_bn:
|
||||
return m
|
||||
|
|
|
|||
|
|
@ -22,25 +22,6 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (2, 5), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-127], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
|
||||
def _qdq_quantized_linear(
|
||||
x_i8,
|
||||
x_scale,
|
||||
|
|
@ -129,20 +110,6 @@ def _reference_quantized_linear(
|
|||
return out_i8
|
||||
|
||||
|
||||
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
|
||||
torch.randn((2, 5), dtype=torch.float),
|
||||
-128,
|
||||
127,
|
||||
torch.finfo(torch.float32).eps,
|
||||
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-127], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
)
|
||||
|
||||
|
||||
def _qdq_dynamic_quantized_linear(
|
||||
x_fp32,
|
||||
x_quant_min,
|
||||
|
|
@ -223,25 +190,6 @@ def _reference_dynamic_quantized_linear(
|
|||
return out_fp32
|
||||
|
||||
|
||||
_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-127], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
|
||||
def _qdq_quantized_conv2d(
|
||||
x_i8,
|
||||
x_scale,
|
||||
|
|
@ -375,20 +323,6 @@ def _reference_quantized_conv2d(
|
|||
return out_i8
|
||||
|
||||
|
||||
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
|
||||
def _qdq_quantized_add_relu(
|
||||
x_i8,
|
||||
x_scale,
|
||||
|
|
@ -518,19 +452,6 @@ def _reference_quantized_add(
|
|||
return out_i8
|
||||
|
||||
|
||||
_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
|
||||
def _qdq_quantized_max_pool2d(
|
||||
x_i8,
|
||||
x_scale,
|
||||
|
|
@ -587,15 +508,6 @@ def _reference_quantized_max_pool2d(
|
|||
return out_i8
|
||||
|
||||
|
||||
_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
|
||||
torch.randn(1, 3, 3, 3, dtype=torch.float),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
|
||||
def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
|
||||
x = torch.ops.quantized_decomposed.quantize_per_tensor(
|
||||
x_fp32, scale, zero_point, quant_min, quant_max, torch.int8
|
||||
|
|
@ -619,15 +531,6 @@ def _reference_quantize_per_tensor_int8(
|
|||
return x
|
||||
|
||||
|
||||
_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
|
||||
def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
|
||||
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||
x_i8, scale, zero_point, quant_min, quant_max, torch.int8
|
||||
|
|
@ -648,16 +551,6 @@ def _reference_dequantize_per_tensor_int8(
|
|||
return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
|
||||
|
||||
|
||||
_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
|
||||
torch.randn(1, 3, 3, 3, dtype=torch.float),
|
||||
torch.randn(3, dtype=torch.float),
|
||||
torch.zeros(3, dtype=torch.int),
|
||||
1,
|
||||
-128,
|
||||
127,
|
||||
)
|
||||
|
||||
|
||||
def _quantize_per_channel_int8(
|
||||
x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
|
||||
):
|
||||
|
|
@ -678,16 +571,6 @@ def _reference_quantize_per_channel_int8(
|
|||
return out_i32.to(torch.int8)
|
||||
|
||||
|
||||
_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(3, dtype=torch.float),
|
||||
torch.zeros(3, dtype=torch.int),
|
||||
1,
|
||||
-128,
|
||||
127,
|
||||
)
|
||||
|
||||
|
||||
def _dequantize_per_channel_int8(
|
||||
x_i8, scales, zero_points, ch_axis, quant_min, quant_max
|
||||
):
|
||||
|
|
@ -733,6 +616,115 @@ class _RewriteInfo:
|
|||
replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
|
||||
|
||||
|
||||
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
|
||||
_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (2, 5), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-127], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
|
||||
torch.randn((2, 5), dtype=torch.float),
|
||||
-128,
|
||||
127,
|
||||
torch.finfo(torch.float32).eps,
|
||||
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-127], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
)
|
||||
|
||||
_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-127], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
|
||||
torch.randn(1, 3, 3, 3, dtype=torch.float),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
|
||||
torch.randn(1, 3, 3, 3, dtype=torch.float),
|
||||
torch.randn(3, dtype=torch.float),
|
||||
torch.zeros(3, dtype=torch.int),
|
||||
1,
|
||||
-128,
|
||||
127,
|
||||
)
|
||||
|
||||
_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(3, dtype=torch.float),
|
||||
torch.zeros(3, dtype=torch.int),
|
||||
1,
|
||||
-128,
|
||||
127,
|
||||
)
|
||||
|
||||
_REWRITE_INFO_LIST = [
|
||||
_RewriteInfo(
|
||||
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
|
||||
|
|
@ -804,8 +796,6 @@ _REWRITE_INFO_LIST = [
|
|||
),
|
||||
]
|
||||
|
||||
|
||||
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
|
||||
remove_tensor_overload_for_qdq_ops(model)
|
||||
from torch._export import gm_using_training_ir
|
||||
|
||||
|
|
|
|||
|
|
@ -33,28 +33,6 @@ _DEQUANTIZE_OPS = [
|
|||
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
||||
]
|
||||
|
||||
# Example inputs for conv-bn1d patterns
|
||||
_conv1d_bn_example_inputs = (
|
||||
torch.randn(1, 1, 3), # x
|
||||
torch.randn(1, 1, 1), # conv_weight
|
||||
torch.randn(1), # conv_bias
|
||||
torch.randn(1), # bn_weight
|
||||
torch.randn(1), # bn_bias
|
||||
torch.randn(1), # bn_running_mean
|
||||
torch.randn(1), # bn_running_var
|
||||
)
|
||||
|
||||
# Example inputs for conv-bn2d patterns
|
||||
_conv2d_bn_example_inputs = (
|
||||
torch.randn(1, 1, 3, 3), # x
|
||||
torch.randn(1, 1, 1, 1), # conv_weight
|
||||
torch.randn(1), # conv_bias
|
||||
torch.randn(1), # bn_weight
|
||||
torch.randn(1), # bn_bias
|
||||
torch.randn(1), # bn_running_mean
|
||||
torch.randn(1), # bn_running_var
|
||||
)
|
||||
|
||||
|
||||
def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -10,8 +10,6 @@ from torch._subclasses import FakeTensor
|
|||
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
|
||||
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
|
||||
from torch.ao.quantization.pt2e.utils import (
|
||||
_conv1d_bn_example_inputs,
|
||||
_conv2d_bn_example_inputs,
|
||||
_get_aten_graph_module_for_pattern,
|
||||
_is_conv_node,
|
||||
_is_conv_transpose_node,
|
||||
|
|
@ -487,6 +485,28 @@ def _do_annotate_conv_bn(
|
|||
for the following names: "input", "conv", "weight", "bias", and "output".
|
||||
"""
|
||||
|
||||
# Example inputs for conv-bn1d patterns
|
||||
_conv1d_bn_example_inputs = (
|
||||
torch.randn(1, 1, 3), # x
|
||||
torch.randn(1, 1, 1), # conv_weight
|
||||
torch.randn(1), # conv_bias
|
||||
torch.randn(1), # bn_weight
|
||||
torch.randn(1), # bn_bias
|
||||
torch.randn(1), # bn_running_mean
|
||||
torch.randn(1), # bn_running_var
|
||||
)
|
||||
|
||||
# Example inputs for conv-bn2d patterns
|
||||
_conv2d_bn_example_inputs = (
|
||||
torch.randn(1, 1, 3, 3), # x
|
||||
torch.randn(1, 1, 1, 1), # conv_weight
|
||||
torch.randn(1), # conv_bias
|
||||
torch.randn(1), # bn_weight
|
||||
torch.randn(1), # bn_bias
|
||||
torch.randn(1), # bn_running_mean
|
||||
torch.randn(1), # bn_running_var
|
||||
)
|
||||
|
||||
def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
|
||||
def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
|
||||
conv = conv_fn(x, conv_weight, conv_bias)
|
||||
|
|
|
|||
|
|
@ -1076,6 +1076,7 @@ def single_batch_reference_fn(input, parameters, module):
|
|||
return module(*single_batch_input).squeeze(0)
|
||||
|
||||
|
||||
def get_new_module_tests():
|
||||
new_module_tests = [
|
||||
poissonnllloss_no_reduce_test(),
|
||||
bceloss_no_reduce_test(),
|
||||
|
|
@ -1788,7 +1789,8 @@ new_module_tests = [
|
|||
dict(
|
||||
fullname='EmbeddingBag_sparse',
|
||||
constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double),
|
||||
cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))',
|
||||
cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
|
||||
.sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))''',
|
||||
input_fn=lambda: torch.randperm(2).repeat(1, 2),
|
||||
check_gradgrad=False,
|
||||
has_sparse_gradients=True,
|
||||
|
|
@ -2710,6 +2712,9 @@ for non_linear_activation in non_linear_activations_no_batch:
|
|||
new_module_tests.append(activation_test_info)
|
||||
|
||||
|
||||
return new_module_tests
|
||||
|
||||
|
||||
def kldivloss_reference(input, target, reduction='mean', log_target=False):
|
||||
if log_target:
|
||||
result = torch.exp(target) * (target - input)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import torch.cuda
|
|||
import torch.jit
|
||||
import torch.jit._logging
|
||||
import torch.jit.frontend
|
||||
from torch.testing._internal.common_nn import module_tests, new_module_tests
|
||||
from torch.testing._internal.common_nn import module_tests, get_new_module_tests
|
||||
from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like
|
||||
|
||||
import collections
|
||||
|
|
@ -95,6 +95,7 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
|
|||
# fn mapping output to part that should be gradcheck'ed, // optional
|
||||
# kwargs for function, // optional
|
||||
# )
|
||||
def get_nn_functional_tests():
|
||||
nn_functional_tests = [
|
||||
('conv1d', (S, S, S), ((S, S, S),)),
|
||||
('conv2d', (S, S, S, S), ((S, S, S, S),)),
|
||||
|
|
@ -315,6 +316,7 @@ nn_functional_tests = [
|
|||
('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False),
|
||||
'trilinear_5d_with_size_not_recompute_scale_factor'),
|
||||
]
|
||||
return nn_functional_tests
|
||||
|
||||
script_template = '''
|
||||
def the_method({}):
|
||||
|
|
@ -523,45 +525,6 @@ def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name
|
|||
return script_fn, inputs
|
||||
|
||||
|
||||
# additional modules test
|
||||
# TODO: delete this list once we make all nn_tests work
|
||||
additional_module_tests = [
|
||||
{
|
||||
'module_name': 'Bilinear',
|
||||
'constructor_args': (S, S, M),
|
||||
'input_size': (S, S),
|
||||
'extra_args': ((S, S),)
|
||||
},
|
||||
{
|
||||
'module_name': 'RNNCell',
|
||||
'constructor_args': (S, S),
|
||||
'input_size': (S, S),
|
||||
},
|
||||
{
|
||||
'module_name': 'LSTMCell',
|
||||
'constructor_args': (S, S),
|
||||
'input_size': (S, S),
|
||||
},
|
||||
{
|
||||
'module_name': 'GRUCell',
|
||||
'constructor_args': (S, S),
|
||||
'input_size': (S, S),
|
||||
},
|
||||
{
|
||||
'module_name': 'MultiheadAttention',
|
||||
'constructor_args': (128, 8),
|
||||
'input_size': (10, 8, 128),
|
||||
'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
|
||||
'slowTest': True
|
||||
},
|
||||
{
|
||||
'module_name': 'Transformer',
|
||||
'constructor_args': (1, 1, 1, 1, 2),
|
||||
'input_size': (3, 1, 1),
|
||||
'extra_args': (torch.randn(1, 1, 1),),
|
||||
'slowTest': True
|
||||
}
|
||||
]
|
||||
|
||||
EXCLUDE_SCRIPT_MODULES = {
|
||||
'test_nn_AdaptiveAvgPool2d_tuple_none',
|
||||
|
|
@ -719,4 +682,44 @@ def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):
|
|||
|
||||
|
||||
def get_all_nn_module_tests():
|
||||
return module_tests + new_module_tests + additional_module_tests
|
||||
# additional modules test
|
||||
# TODO: delete this list once we make all nn_tests work
|
||||
additional_module_tests = [
|
||||
{
|
||||
'module_name': 'Bilinear',
|
||||
'constructor_args': (S, S, M),
|
||||
'input_size': (S, S),
|
||||
'extra_args': ((S, S),)
|
||||
},
|
||||
{
|
||||
'module_name': 'RNNCell',
|
||||
'constructor_args': (S, S),
|
||||
'input_size': (S, S),
|
||||
},
|
||||
{
|
||||
'module_name': 'LSTMCell',
|
||||
'constructor_args': (S, S),
|
||||
'input_size': (S, S),
|
||||
},
|
||||
{
|
||||
'module_name': 'GRUCell',
|
||||
'constructor_args': (S, S),
|
||||
'input_size': (S, S),
|
||||
},
|
||||
{
|
||||
'module_name': 'MultiheadAttention',
|
||||
'constructor_args': (128, 8),
|
||||
'input_size': (10, 8, 128),
|
||||
'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
|
||||
'slowTest': True
|
||||
},
|
||||
{
|
||||
'module_name': 'Transformer',
|
||||
'constructor_args': (1, 1, 1, 1, 2),
|
||||
'input_size': (3, 1, 1),
|
||||
'extra_args': (torch.randn(1, 1, 1),),
|
||||
'slowTest': True
|
||||
}
|
||||
]
|
||||
|
||||
return module_tests + get_new_module_tests() + additional_module_tests
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user