No actual change, just remove variables containing Tensors from global scope (#143225)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143225
Approved by: https://github.com/ezyang
Authored by albanD on 2024-12-16 19:20:42 -05:00, committed by PyTorch MergeBot
parent afa313e669
commit 792f1c47e9
12 changed files with 2157 additions and 2138 deletions
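The pattern applied across all twelve files is the same: tuples of example/test inputs that used to be built at module scope (so `torch.randn(...)` and friends ran as a side effect of importing the file) are now built inside the function that consumes them, or behind a getter such as `get_new_module_tests()` / `get_nn_functional_tests()`. A minimal sketch of that before/after shape, with illustrative names rather than code taken from the diff:

```python
import torch

# Before: tensors are allocated as soon as the module is imported.
_example_inputs = (
    torch.randn(1, 1, 3),  # x
    torch.randn(1, 1, 1),  # weight
)

# After: nothing tensor-related runs at import time; callers build the
# inputs only when they actually need them.
def get_example_inputs():
    return (
        torch.randn(1, 1, 3),  # x
        torch.randn(1, 1, 1),  # weight
    )

inputs = get_example_inputs()  # allocated here, not at import
```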


@@ -20,8 +20,8 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.jit_metaprogramming_utils import (
     get_all_nn_module_tests,
     get_nn_functional_compiled_fn_and_inputs,
+    get_nn_functional_tests,
     get_nn_mod_test_name,
-    nn_functional_tests,
     try_get_nn_module_compiled_mod_and_inputs,
 )
 from torch.testing._internal.jit_utils import enable_profiling_mode, JitTestCase
@@ -70,7 +70,7 @@ class TestComplexity(JitTestCase):
     def test_generated_functional_tests(self):
         with enable_profiling_mode():
             stats = [("Name", "Ifs/Loops", "non-tensor ops")]
-            for test in nn_functional_tests:
+            for test in get_nn_functional_tests():
                 test_name = test[0]
                 fn, inputs = get_nn_functional_compiled_fn_and_inputs(*test)


@@ -42,7 +42,7 @@ if not common.IS_ARM64:
         (sample_module.module_tests, common_nn.NewModuleTest),
         (sample_functional.functional_tests, common_nn.NewModuleTest),
         (common_nn.module_tests, common_nn.NewModuleTest),
-        (common_nn.new_module_tests, common_nn.NewModuleTest),
+        (common_nn.get_new_module_tests(), common_nn.NewModuleTest),
         (common_nn.criterion_tests, common_nn.CriterionTest),
     ]:
         for test_params_dict in test_params_dicts:


@@ -25,7 +25,11 @@ from torch.testing._internal.common_device_type import (
 )
 from torch.testing._internal.common_methods_invocations import op_db, SampleInput
 from torch.testing._internal.common_modules import module_db, modules
-from torch.testing._internal.common_nn import module_tests, new_module_tests, TestBase
+from torch.testing._internal.common_nn import (
+    get_new_module_tests,
+    module_tests,
+    TestBase,
+)
 from torch.testing._internal.common_utils import (
     freeze_rng_state,
     make_tensor,
@@ -1011,7 +1015,7 @@ def filter_supported_tests(t):
 # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
 # These currently use the legacy nn tests
 supported_tests = [
-    t for t in module_tests + new_module_tests if filter_supported_tests(t)
+    t for t in module_tests + get_new_module_tests() if filter_supported_tests(t)
 ]
 for test_param in supported_tests:
     if "constructor" not in test_param:


@@ -50,7 +50,7 @@ from torch.testing._internal.common_device_type import (
     ops,
 )
 from torch.testing._internal.common_methods_invocations import op_db
-from torch.testing._internal.common_nn import module_tests, new_module_tests
+from torch.testing._internal.common_nn import module_tests, get_new_module_tests
 from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
 from torch.testing._internal.jit_utils import JitTestCase
 import torch.utils._pytree as pytree
@@ -1006,7 +1006,7 @@ terrible spacing
         Exhaustively test `Node.normalized_arguments` on all standard
         torch.nn Module classes
         """
-        for test_params in module_tests + new_module_tests:
+        for test_params in module_tests + get_new_module_tests():
             if "constructor" not in test_params:
                 constructor = getattr(torch.nn, test_params["module_name"])
             else:


@@ -107,10 +107,10 @@ from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, dis
 from torch.testing._internal.jit_metaprogramming_utils import (
     get_script_args,
     create_input, unpack_variables,
-    additional_module_tests, EXCLUDE_SCRIPT_MODULES,
+    get_all_nn_module_tests, EXCLUDE_SCRIPT_MODULES,
     get_nn_module_name_from_kwargs, get_nn_mod_test_name, script_method_template)
-from torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests
+from torch.testing._internal.common_nn import criterion_tests

 # For testing truediv in python 2
 from torch.testing._internal.test_module.future_div import div_int_future, div_float_future
@@ -16247,7 +16247,7 @@ class TestProducerVersion(TestCase):
         # issue gh-32561
         self.assertTrue(torch.__version__.startswith(torch.onnx.producer_version))

-for test in module_tests + new_module_tests + additional_module_tests:
+for test in get_all_nn_module_tests():
     add_nn_module_test(**test)

 for test in criterion_tests:


@@ -38,7 +38,7 @@ from torch.testing._internal.common_utils import freeze_rng_state, run_tests, Te
 from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
     module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
-    ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
+    ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
 from torch.testing._internal.common_device_type import dtypesIfMPS, instantiate_device_type_tests, dtypes, \
     dtypesIfCUDA, precisionOverride, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
     skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \
@@ -7332,7 +7332,7 @@ def add_test(test, decorator=None):
         else:
             add(cuda_test_name, with_tf32_off)

-for test_params in module_tests + new_module_tests:
+for test_params in module_tests + get_new_module_tests():
     # TODO: CUDA is not implemented yet
     if 'constructor' not in test_params:
         name = test_params.pop('module_name')


@@ -19,8 +19,6 @@ from torch.fx import Graph, GraphModule, Node
 from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns

 from .utils import (
-    _conv1d_bn_example_inputs,
-    _conv2d_bn_example_inputs,
     _get_aten_graph_module_for_pattern,
     _is_bn_node,
     _is_conv_or_conv_transpose_node,
@@ -35,27 +33,6 @@ if TYPE_CHECKING:

 __all__ = []  # type: ignore[var-annotated]

-# Example inputs for quantized and folded conv-bn1d patterns used in convert
-_quantized_conv1d_bn_example_inputs = (
-    torch.randn(1, 1, 3),  # x
-    torch.randn(1, 1, 1),  # conv_weight
-    torch.randn(1),  # bn_weight
-    torch.randn(1),  # bn_bias
-    torch.randn(1),  # bn_running_mean
-    torch.randn(1),  # bn_running_var
-)
-
-# Example inputs for quantized and folded conv-bn2d patterns used in convert
-_quantized_conv2d_bn_example_inputs = (
-    torch.randn(1, 1, 3, 3),  # x
-    torch.randn(1, 1, 1, 1),  # conv_weight
-    torch.randn(1),  # bn_weight
-    torch.randn(1),  # bn_bias
-    torch.randn(1),  # bn_running_mean
-    torch.randn(1),  # bn_running_var
-)
-
 def _get_quantized_conv_bn_example_inputs_kwargs(
     is_per_channel: bool,
     has_bias: bool,
@@ -631,6 +608,28 @@ def _update_special_qspecs_after_replacement(


 def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
+    # Example inputs for conv-bn1d patterns
+    _conv1d_bn_example_inputs = (
+        torch.randn(1, 1, 3),  # x
+        torch.randn(1, 1, 1),  # conv_weight
+        torch.randn(1),  # conv_bias
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
+    # Example inputs for conv-bn2d patterns
+    _conv2d_bn_example_inputs = (
+        torch.randn(1, 1, 3, 3),  # x
+        torch.randn(1, 1, 1, 1),  # conv_weight
+        torch.randn(1),  # conv_bias
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
     has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
     if not has_bn:
         return m
@@ -859,6 +858,26 @@ def _copy_over_q_dq_args(original_node: Node, replacement_node: Node):


 def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
+    # Example inputs for quantized and folded conv-bn1d patterns used in convert
+    _quantized_conv1d_bn_example_inputs = (
+        torch.randn(1, 1, 3),  # x
+        torch.randn(1, 1, 1),  # conv_weight
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
+    # Example inputs for quantized and folded conv-bn2d patterns used in convert
+    _quantized_conv2d_bn_example_inputs = (
+        torch.randn(1, 1, 3, 3),  # x
+        torch.randn(1, 1, 1, 1),  # conv_weight
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
     has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
     if not has_bn:
         return m


@@ -22,25 +22,6 @@ __all__ = [
 ]


-_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
-    torch.randint(-128, 127, (2, 5), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-127], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-    torch.randn(1, dtype=torch.float),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-)
-
-
 def _qdq_quantized_linear(
     x_i8,
     x_scale,
@@ -129,20 +110,6 @@ def _reference_quantized_linear(
     return out_i8


-_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
-    torch.randn((2, 5), dtype=torch.float),
-    -128,
-    127,
-    torch.finfo(torch.float32).eps,
-    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-127], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-    torch.randn(1, dtype=torch.float),
-)
-
-
 def _qdq_dynamic_quantized_linear(
     x_fp32,
     x_quant_min,
@@ -223,25 +190,6 @@ def _reference_dynamic_quantized_linear(
     return out_fp32


-_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
-    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-127], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-    torch.randn(1, dtype=torch.float),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-)
-
-
 def _qdq_quantized_conv2d(
     x_i8,
     x_scale,
@@ -375,20 +323,6 @@ def _reference_quantized_conv2d(
     return out_i8


-_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
-    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-)
-
-
 def _qdq_quantized_add_relu(
     x_i8,
     x_scale,
@@ -518,19 +452,6 @@ def _reference_quantized_add(
     return out_i8


-_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
-    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-)
-
-
 def _qdq_quantized_max_pool2d(
     x_i8,
     x_scale,
@@ -587,15 +508,6 @@ def _reference_quantized_max_pool2d(
     return out_i8


-_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
-    torch.randn(1, 3, 3, 3, dtype=torch.float),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-)
-
-
 def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
     x = torch.ops.quantized_decomposed.quantize_per_tensor(
         x_fp32, scale, zero_point, quant_min, quant_max, torch.int8
@@ -619,15 +531,6 @@ def _reference_quantize_per_tensor_int8(
     return x


-_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
-    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-)
-
-
 def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
     x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
         x_i8, scale, zero_point, quant_min, quant_max, torch.int8
@@ -648,16 +551,6 @@ def _reference_dequantize_per_tensor_int8(
     return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)


-_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
-    torch.randn(1, 3, 3, 3, dtype=torch.float),
-    torch.randn(3, dtype=torch.float),
-    torch.zeros(3, dtype=torch.int),
-    1,
-    -128,
-    127,
-)
-
-
 def _quantize_per_channel_int8(
     x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
 ):
@@ -678,16 +571,6 @@ def _reference_quantize_per_channel_int8(
     return out_i32.to(torch.int8)


-_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
-    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
-    torch.randn(3, dtype=torch.float),
-    torch.zeros(3, dtype=torch.int),
-    1,
-    -128,
-    127,
-)
-
-
 def _dequantize_per_channel_int8(
     x_i8, scales, zero_points, ch_axis, quant_min, quant_max
 ):
@@ -733,79 +616,186 @@ class _RewriteInfo:
     replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None


-_REWRITE_INFO_LIST = [
-    _RewriteInfo(
-        _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
-        _WrapperModule(_qdq_dynamic_quantized_linear),
-        _WrapperModule(_reference_dynamic_quantized_linear),
-        partial(
-            _replace_literals_with_existing_placeholders,
-            literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
-        ),
-        partial(
-            _replace_literals_with_existing_placeholders,
-            literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
-        ),
-    ),
-    _RewriteInfo(
-        _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
-        _WrapperModule(_qdq_quantized_linear),
-        _WrapperModule(_reference_quantized_linear),
-        _replace_literals_with_new_placeholders,
-        _replace_literals_with_new_placeholders,
-    ),
-    _RewriteInfo(
-        _QUANTIZED_CONV2d_EXAMPLE_INPUTS,
-        _WrapperModule(_qdq_quantized_conv2d),
-        _WrapperModule(_reference_quantized_conv2d),
-        partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
-        partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
-    ),
-    _RewriteInfo(
-        _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
-        _WrapperModule(_qdq_quantized_add_relu),
-        _WrapperModule(_reference_quantized_add_relu),
-    ),
-    _RewriteInfo(
-        _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
-        _WrapperModule(_qdq_quantized_add),
-        _WrapperModule(_reference_quantized_add),
-    ),
-    _RewriteInfo(
-        _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS,
-        _WrapperModule(_qdq_quantized_max_pool2d),
-        _WrapperModule(_reference_quantized_max_pool2d),
-        _replace_literals_with_new_placeholders,
-        _replace_literals_with_new_placeholders,
-    ),
-    _RewriteInfo(
-        _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
-        _WrapperModule(_quantize_per_tensor_int8),
-        _WrapperModule(_reference_quantize_per_tensor_int8),
-    ),
-    _RewriteInfo(
-        _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
-        _WrapperModule(_dequantize_per_tensor_int8),
-        _WrapperModule(_reference_dequantize_per_tensor_int8),
-    ),
-    _RewriteInfo(
-        _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
-        _WrapperModule(_quantize_per_channel_int8),
-        _WrapperModule(_reference_quantize_per_channel_int8),
-        _replace_ph_qdq_per_channel_replacement,
-        _replace_ph_qdq_per_channel_replacement,
-    ),
-    _RewriteInfo(
-        _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
-        _WrapperModule(_dequantize_per_channel_int8),
-        _WrapperModule(_reference_dequantize_per_channel_int8),
-        _replace_ph_qdq_per_channel_replacement,
-        _replace_ph_qdq_per_channel_replacement,
-    ),
-]
-
-
 def reference_representation_rewrite(model: GraphModule) -> GraphModule:
+    _QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (2, 5), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randint(-128, 127, (5, 5), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-127], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randn(1, dtype=torch.float),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
+        torch.randn((2, 5), dtype=torch.float),
+        -128,
+        127,
+        torch.finfo(torch.float32).eps,
+        torch.randint(-128, 127, (5, 5), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-127], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randn(1, dtype=torch.float),
+    )
+
+    _QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-127], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randn(1, dtype=torch.float),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
+        torch.randn(1, 3, 3, 3, dtype=torch.float),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
+        torch.randn(1, 3, 3, 3, dtype=torch.float),
+        torch.randn(3, dtype=torch.float),
+        torch.zeros(3, dtype=torch.int),
+        1,
+        -128,
+        127,
+    )
+
+    _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(3, dtype=torch.float),
+        torch.zeros(3, dtype=torch.int),
+        1,
+        -128,
+        127,
+    )
+
+    _REWRITE_INFO_LIST = [
+        _RewriteInfo(
+            _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_dynamic_quantized_linear),
+            _WrapperModule(_reference_dynamic_quantized_linear),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
+            ),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
+            ),
+        ),
+        _RewriteInfo(
+            _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_quantized_linear),
+            _WrapperModule(_reference_quantized_linear),
+            _replace_literals_with_new_placeholders,
+            _replace_literals_with_new_placeholders,
+        ),
+        _RewriteInfo(
+            _QUANTIZED_CONV2d_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_quantized_conv2d),
+            _WrapperModule(_reference_quantized_conv2d),
+            partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
+            partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
+        ),
+        _RewriteInfo(
+            _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_quantized_add_relu),
+            _WrapperModule(_reference_quantized_add_relu),
+        ),
+        _RewriteInfo(
+            _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_quantized_add),
+            _WrapperModule(_reference_quantized_add),
+        ),
+        _RewriteInfo(
+            _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_quantized_max_pool2d),
+            _WrapperModule(_reference_quantized_max_pool2d),
+            _replace_literals_with_new_placeholders,
+            _replace_literals_with_new_placeholders,
+        ),
+        _RewriteInfo(
+            _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
+            _WrapperModule(_quantize_per_tensor_int8),
+            _WrapperModule(_reference_quantize_per_tensor_int8),
+        ),
+        _RewriteInfo(
+            _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
+            _WrapperModule(_dequantize_per_tensor_int8),
+            _WrapperModule(_reference_dequantize_per_tensor_int8),
+        ),
+        _RewriteInfo(
+            _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
+            _WrapperModule(_quantize_per_channel_int8),
+            _WrapperModule(_reference_quantize_per_channel_int8),
+            _replace_ph_qdq_per_channel_replacement,
+            _replace_ph_qdq_per_channel_replacement,
+        ),
+        _RewriteInfo(
+            _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
+            _WrapperModule(_dequantize_per_channel_int8),
+            _WrapperModule(_reference_dequantize_per_channel_int8),
+            _replace_ph_qdq_per_channel_replacement,
+            _replace_ph_qdq_per_channel_replacement,
+        ),
+    ]
+
     remove_tensor_overload_for_qdq_ops(model)
     from torch._export import gm_using_training_ir


@@ -33,28 +33,6 @@ _DEQUANTIZE_OPS = [
     torch.ops.quantized_decomposed.dequantize_per_channel.default,
 ]

-# Example inputs for conv-bn1d patterns
-_conv1d_bn_example_inputs = (
-    torch.randn(1, 1, 3),  # x
-    torch.randn(1, 1, 1),  # conv_weight
-    torch.randn(1),  # conv_bias
-    torch.randn(1),  # bn_weight
-    torch.randn(1),  # bn_bias
-    torch.randn(1),  # bn_running_mean
-    torch.randn(1),  # bn_running_var
-)
-
-# Example inputs for conv-bn2d patterns
-_conv2d_bn_example_inputs = (
-    torch.randn(1, 1, 3, 3),  # x
-    torch.randn(1, 1, 1, 1),  # conv_weight
-    torch.randn(1),  # conv_bias
-    torch.randn(1),  # bn_weight
-    torch.randn(1),  # bn_bias
-    torch.randn(1),  # bn_running_mean
-    torch.randn(1),  # bn_running_var
-)
-

 def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
     """

@@ -10,8 +10,6 @@ from torch._subclasses import FakeTensor
 from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
 from torch.ao.quantization.pt2e.export_utils import _WrapperModule
 from torch.ao.quantization.pt2e.utils import (
-    _conv1d_bn_example_inputs,
-    _conv2d_bn_example_inputs,
     _get_aten_graph_module_for_pattern,
     _is_conv_node,
     _is_conv_transpose_node,
@@ -487,6 +485,28 @@ def _do_annotate_conv_bn(
     for the following names: "input", "conv", "weight", "bias", and "output".
     """

+    # Example inputs for conv-bn1d patterns
+    _conv1d_bn_example_inputs = (
+        torch.randn(1, 1, 3),  # x
+        torch.randn(1, 1, 1),  # conv_weight
+        torch.randn(1),  # conv_bias
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
+    # Example inputs for conv-bn2d patterns
+    _conv2d_bn_example_inputs = (
+        torch.randn(1, 1, 3, 3),  # x
+        torch.randn(1, 1, 1, 1),  # conv_weight
+        torch.randn(1),  # conv_bias
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
     def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
         def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
             conv = conv_fn(x, conv_weight, conv_bias)

File diff suppressed because it is too large.


@@ -8,7 +8,7 @@ import torch.cuda
 import torch.jit
 import torch.jit._logging
 import torch.jit.frontend
-from torch.testing._internal.common_nn import module_tests, new_module_tests
+from torch.testing._internal.common_nn import module_tests, get_new_module_tests
 from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like

 import collections
@@ -95,226 +95,228 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
 # fn mapping output to part that should be gradcheck'ed, // optional
 # kwargs for function, // optional
 # )
-nn_functional_tests = [
+def get_nn_functional_tests():
+    nn_functional_tests = [
     ('conv1d', (S, S, S), ((S, S, S),)),
     ('conv2d', (S, S, S, S), ((S, S, S, S),)),
     ('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
     ('conv_transpose1d', (S, S, S), ((S, S, S),)),
     ('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)),
     ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
     ('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
     ('avg_pool1d', (S, S, S), (3,)),
     ('avg_pool2d', (S, S, S, S), (3,), '', (True,)),
     ('avg_pool3d', (S, S, S, S, S), (3,)),
     ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
     ('max_pool1d', (S, S, S), (2, 1)),
     ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
     ('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices')),
     ('max_pool2d', (S, S, S, S), (2, 1, 1, 1, False, True), 'with_indices', (True, 'aten::max_pool2d_with_indices')),
     ('max_pool3d', (S, S, S, S, S), (2, 1)),
     ('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
     ('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)),
     ('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
     ('lp_pool1d', (S, S, S), (2., 3, 2,)),
     ('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
     ('lp_pool3d', (S, S, S, S, S), (2., 3, 2,)),
     ('adaptive_max_pool1d', (S, S, S), (5,)),
     ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
     ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
     ('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)),
     ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)),
     ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)),
     ('dropout', (S, S, S), (0.5,), '', (True, 'aten::native_dropout')),
     ('alpha_dropout', (S, S, S), (0.5,)),
     ('dropout2d', (S, S, S), (0.5,)),
     ('dropout2d', (S, S, S, S), (0.5,), 'batched'),
     ('dropout3d', (S, S, S, S), (0.5,)),
     ('dropout3d', (S, S, S, S, S), (0.5,), 'batched'),
     ('feature_alpha_dropout', (S, S, S), (0.5,)),
     ('threshold', (S, S, S), (0.1, 2.), '', (True,)),
     ('threshold', (S, S, S), (0.1, 2., True), 'inplace'),
     ('relu', (S, S, S), (), '', (True,)),
     ('relu', (S, S, S), (), 'inplace'),
     ('glu', (S - 1, S - 1, S - 1), (),),
     ('hardtanh', (S, S, S), (-0.5, 0.5), '', (True,)),
     ('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'),
     ('relu6', (S, S, S), (), '', (True,)),
     ('relu6', (S, S, S), (True), 'inplace'),
     ('elu', (S, S, S), (0.9,),),
     ('elu', (S, S, S), (0.9, True), 'inplace'),
     ('selu', (S, S, S), (),),
     ('selu', (S, S, S), (True), 'inplace'),
     ('celu', (S, S, S), (0.9,),),
     ('celu', (S, S, S), (0.9, True), 'inplace'),
     ('leaky_relu', (S, S, S), (0.02,), '', (True,)),
     ('leaky_relu', (S, S, S), (0.02,), 'inplace'),
     ('rrelu', (S, S), (0.1, 0.3, False),),
     ('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'),
     ('hardshrink', (S, S, S), (0.4,), '', (True,)),
     ('tanhshrink', (S, S, S), (),),
     ('softsign', (S, S, S), (),),
     ('softplus', (S, S, S), (), '', (True,)),
     ('softmin', (S, S, S), (0,),),
     ('softmax', (S, S, S), (0,), '', (True,)),
     ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)),
     ('tanh', (S, S, S), (), '', (True,)),
     ('sigmoid', (S, S, S), (), '', (True,)),
     ('silu', (S, S, S), (), '', (True,)),
     ('log_softmax', (S, S, S), (0,), '', (True,)),
     ('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])),
     ('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])),
     ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
     ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
     ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
     ('batch_norm', (S, S),
         (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ),
         'training', (True, 'aten::_batch_norm_impl_index')),
     ('batch_norm', (0, S, S, S),
         (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
         'size_zero', (True, 'aten::_batch_norm_impl_index')),
     ('batch_norm', (0, S, S, S),
         (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
         'size_zero_inference', (True, 'aten::_batch_norm_impl_index')),
     ('batch_norm', (S, S),
         (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
         'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')),
     ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         None, non_differentiable(torch.ones(S)), True, ),
         'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')),
     ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         non_differentiable(torch.randn(S)), None, True, ),
         'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')),
     ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         None, None, False, ),
         'inference', (True, 'aten::_batch_norm_impl_index')),
     ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ),
         'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')),
     ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         None, non_differentiable(torch.ones(S)), False, ),
         'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')),
     ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         non_differentiable(torch.randn(S)), None, False, ),
         'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')),
     ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
     ('layer_norm', (S, S, S, S), ([5],), '',
         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
     ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
     ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
     ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
         non_differentiable(torch.rand(S))), 'with_weight_and_bias',
         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])),
     ('group_norm', (S, S, S), (1, torch.rand(5),),),
     ('local_response_norm', (S, S, S), (2, ),),
     ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',),
     ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
     ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
     ('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
     ('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),),
     ('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),),
     ('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
     ('huber_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
     ('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
     ('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
     ('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
     ('huber_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
     ('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
     ('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
     ('margin_ranking_loss', (S,), ((S,), (S,)),),
     ('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
     ('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
     ('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
     ('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
     ('pixel_shuffle', (1, 9, 4, 4), (3,),),
     ('pixel_unshuffle', (1, 1, 12, 12), (3,),),
     ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
     ('pad', (3, 3, 4, 2), ([1, 1],),),
     ('pairwise_distance', (S, S), ((S, S),),),
     ('pdist', (S, S), (),),
     ('cosine_similarity', (S, S), ((S, S),),),
     ('triplet_margin_loss', (S, S), ((S, S), (S, S)),),
     ('normalize', (S, S, S), (),),
     ('unfold', (S, S, S, S), ([2, 3]),),
     ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
     ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
     ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
     ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
     ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
     ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
         1, 1., non_differentiable(torch.randn(S))),),
     ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)),
         non_differentiable(torch.randn(3, 2))),),
     ('binary_cross_entropy', torch.randn(3, 2).sigmoid(),
         (non_differentiable(torch.rand(3, 2)),
         non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'),
     ('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
         (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
         torch.randint(1, S, (S,), dtype=torch.long))),
     ('upsample', torch.randn(S, S, M, M), (None, 2.), 'with_scale'),
     ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
     ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
     ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
     ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
     ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
     ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
     ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
     ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
     ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
     ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
     ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
     ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
     ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
     ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
     ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
     ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
     ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
     ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
     ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
     ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
     ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
     ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
     ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
     ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
     ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
     ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
     ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
     ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
     ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
     ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
     ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2, None, 'nearest', None, False),
         'nearest_4d_not_recompute_scale_factor'),
     ('interpolate', torch.randn(S, S, M, M), (4, None, 'nearest', None, False),
         'nearest_4d_with_size_not_recompute_scale_factor'),
     ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bilinear', None, False),
         'bilinear_4d_with_scale_not_recompute_scale_factor'),
     ('interpolate', torch.randn(S, S, M, M), (4, None, 'bilinear', None, False),
         'bilinear_4d_with_size_not_recompute_scale_factor'),
     ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bicubic', None, False),
         'bicubic_4d_with_scale_not_recompute_scale_factor'),
     ('interpolate', torch.randn(S, S, M, M), (4, None, 'bicubic', None, False),
         'bicubic_4d_with_size_not_recompute_scale_factor'),
     ('interpolate', torch.randn(S, M, M), (None, 2., 'nearest', None, False),
         'nearest_3d_with_scale_not_recompute_scale_factor'),
     ('interpolate', torch.randn(S, M, M), (4, None, 'nearest', None, False),
         'nearest_3d_with_size_not_recompute_scale_factor'),
     ('interpolate', torch.randn(S, M, M), (None, 2., 'linear', None, False),
         'linear_3d_with_scale_not_recompute_scale_factor'),
     ('interpolate', torch.randn(S, M, M), (4, None, 'linear', None, False),
         'linear_3d_with_size_not_recompute_scale_factor'),
     ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'nearest', None, False),
         'nearest_5d_with_scale_not_recompute_scale_factor'),
     ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'nearest', None, False),
         'nearest_5d_with_size_not_recompute_scale_factor'),
     ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'trilinear', None, False),
         'trilinear_5d_with_scale_not_recompute_scale_factor'),
     ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False),
         'trilinear_5d_with_size_not_recompute_scale_factor'),
-]
+    ]
+    return nn_functional_tests

 script_template = '''
 def the_method({}):
@@ -523,45 +525,6 @@ def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name
     return script_fn, inputs


-# additional modules test
-# TODO: delete this list once we make all nn_tests work
-additional_module_tests = [
-    {
-        'module_name': 'Bilinear',
-        'constructor_args': (S, S, M),
-        'input_size': (S, S),
-        'extra_args': ((S, S),)
-    },
-    {
-        'module_name': 'RNNCell',
-        'constructor_args': (S, S),
-        'input_size': (S, S),
-    },
-    {
-        'module_name': 'LSTMCell',
-        'constructor_args': (S, S),
-        'input_size': (S, S),
-    },
-    {
-        'module_name': 'GRUCell',
-        'constructor_args': (S, S),
-        'input_size': (S, S),
-    },
-    {
-        'module_name': 'MultiheadAttention',
-        'constructor_args': (128, 8),
-        'input_size': (10, 8, 128),
-        'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
-        'slowTest': True
-    },
-    {
-        'module_name': 'Transformer',
-        'constructor_args': (1, 1, 1, 1, 2),
-        'input_size': (3, 1, 1),
-        'extra_args': (torch.randn(1, 1, 1),),
-        'slowTest': True
-    }
-]
-
 EXCLUDE_SCRIPT_MODULES = {
     'test_nn_AdaptiveAvgPool2d_tuple_none',
@@ -719,4 +682,44 @@ def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):


 def get_all_nn_module_tests():
-    return module_tests + new_module_tests + additional_module_tests
+    # additional modules test
+    # TODO: delete this list once we make all nn_tests work
+    additional_module_tests = [
+        {
+            'module_name': 'Bilinear',
+            'constructor_args': (S, S, M),
+            'input_size': (S, S),
+            'extra_args': ((S, S),)
+        },
+        {
+            'module_name': 'RNNCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'LSTMCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'GRUCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'MultiheadAttention',
+            'constructor_args': (128, 8),
+            'input_size': (10, 8, 128),
+            'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
+            'slowTest': True
+        },
+        {
+            'module_name': 'Transformer',
+            'constructor_args': (1, 1, 1, 1, 2),
+            'input_size': (3, 1, 1),
+            'extra_args': (torch.randn(1, 1, 1),),
+            'slowTest': True
+        }
+    ]
+    return module_tests + get_new_module_tests() + additional_module_tests