From 792f1c47e93c54c8433acb8e2435baaf9a8bd2e2 Mon Sep 17 00:00:00 2001 From: albanD Date: Mon, 16 Dec 2024 19:20:42 -0500 Subject: [PATCH] No actual change, just remove variables containing Tensors from global scope (#143225) Pull Request resolved: https://github.com/pytorch/pytorch/pull/143225 Approved by: https://github.com/ezyang --- test/jit/test_complexity.py | 4 +- test/test_cpp_api_parity.py | 2 +- test/test_expanded_weights.py | 8 +- test/test_fx_experimental.py | 4 +- test/test_jit.py | 6 +- test/test_nn.py | 4 +- torch/ao/quantization/pt2e/qat_utils.py | 65 +- .../pt2e/representation/rewrite.py | 368 +- torch/ao/quantization/pt2e/utils.py | 22 - .../quantizer/xnnpack_quantizer_utils.py | 24 +- torch/testing/_internal/common_nn.py | 3263 +++++++++-------- .../_internal/jit_metaprogramming_utils.py | 525 +-- 12 files changed, 2157 insertions(+), 2138 deletions(-) diff --git a/test/jit/test_complexity.py b/test/jit/test_complexity.py index 2812ce13c1b..cd022eb5224 100644 --- a/test/jit/test_complexity.py +++ b/test/jit/test_complexity.py @@ -20,8 +20,8 @@ from torch.testing._internal.common_utils import ( from torch.testing._internal.jit_metaprogramming_utils import ( get_all_nn_module_tests, get_nn_functional_compiled_fn_and_inputs, + get_nn_functional_tests, get_nn_mod_test_name, - nn_functional_tests, try_get_nn_module_compiled_mod_and_inputs, ) from torch.testing._internal.jit_utils import enable_profiling_mode, JitTestCase @@ -70,7 +70,7 @@ class TestComplexity(JitTestCase): def test_generated_functional_tests(self): with enable_profiling_mode(): stats = [("Name", "Ifs/Loops", "non-tensor ops")] - for test in nn_functional_tests: + for test in get_nn_functional_tests(): test_name = test[0] fn, inputs = get_nn_functional_compiled_fn_and_inputs(*test) diff --git a/test/test_cpp_api_parity.py b/test/test_cpp_api_parity.py index df510792a2a..0c27051e8e5 100644 --- a/test/test_cpp_api_parity.py +++ b/test/test_cpp_api_parity.py @@ -42,7 +42,7 @@ if not common.IS_ARM64: (sample_module.module_tests, common_nn.NewModuleTest), (sample_functional.functional_tests, common_nn.NewModuleTest), (common_nn.module_tests, common_nn.NewModuleTest), - (common_nn.new_module_tests, common_nn.NewModuleTest), + (common_nn.get_new_module_tests(), common_nn.NewModuleTest), (common_nn.criterion_tests, common_nn.CriterionTest), ]: for test_params_dict in test_params_dicts: diff --git a/test/test_expanded_weights.py b/test/test_expanded_weights.py index fbeb8f77cb8..7f210bf79a2 100644 --- a/test/test_expanded_weights.py +++ b/test/test_expanded_weights.py @@ -25,7 +25,11 @@ from torch.testing._internal.common_device_type import ( ) from torch.testing._internal.common_methods_invocations import op_db, SampleInput from torch.testing._internal.common_modules import module_db, modules -from torch.testing._internal.common_nn import module_tests, new_module_tests, TestBase +from torch.testing._internal.common_nn import ( + get_new_module_tests, + module_tests, + TestBase, +) from torch.testing._internal.common_utils import ( freeze_rng_state, make_tensor, @@ -1011,7 +1015,7 @@ def filter_supported_tests(t): # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests # These currently use the legacy nn tests supported_tests = [ - t for t in module_tests + new_module_tests if filter_supported_tests(t) + t for t in module_tests + get_new_module_tests() if filter_supported_tests(t) ] for test_param in supported_tests: if "constructor" not in test_param: diff --git a/test/test_fx_experimental.py 
b/test/test_fx_experimental.py index 40cc6f1ad11..fac9365e60a 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -50,7 +50,7 @@ from torch.testing._internal.common_device_type import ( ops, ) from torch.testing._internal.common_methods_invocations import op_db -from torch.testing._internal.common_nn import module_tests, new_module_tests +from torch.testing._internal.common_nn import module_tests, get_new_module_tests from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase from torch.testing._internal.jit_utils import JitTestCase import torch.utils._pytree as pytree @@ -1006,7 +1006,7 @@ terrible spacing Exhaustively test `Node.normalized_arguments` on all standard torch.nn Module classes """ - for test_params in module_tests + new_module_tests: + for test_params in module_tests + get_new_module_tests(): if "constructor" not in test_params: constructor = getattr(torch.nn, test_params["module_name"]) else: diff --git a/test/test_jit.py b/test/test_jit.py index af5f194b7f3..fd76a922536 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -107,10 +107,10 @@ from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, dis from torch.testing._internal.jit_metaprogramming_utils import ( get_script_args, create_input, unpack_variables, - additional_module_tests, EXCLUDE_SCRIPT_MODULES, + get_all_nn_module_tests, EXCLUDE_SCRIPT_MODULES, get_nn_module_name_from_kwargs, get_nn_mod_test_name, script_method_template) -from torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests +from torch.testing._internal.common_nn import criterion_tests # For testing truediv in python 2 from torch.testing._internal.test_module.future_div import div_int_future, div_float_future @@ -16247,7 +16247,7 @@ class TestProducerVersion(TestCase): # issue gh-32561 self.assertTrue(torch.__version__.startswith(torch.onnx.producer_version)) -for test in module_tests + new_module_tests + additional_module_tests: +for test in get_all_nn_module_tests(): add_nn_module_test(**test) for test in criterion_tests: diff --git a/test/test_nn.py b/test/test_nn.py index 0af76d427e2..af30d2cf23a 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -38,7 +38,7 @@ from torch.testing._internal.common_utils import freeze_rng_state, run_tests, Te from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \ module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \ - ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input + ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input from torch.testing._internal.common_device_type import dtypesIfMPS, instantiate_device_type_tests, dtypes, \ dtypesIfCUDA, precisionOverride, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \ skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \ @@ -7332,7 +7332,7 @@ def add_test(test, decorator=None): else: add(cuda_test_name, with_tf32_off) -for test_params in module_tests + new_module_tests: +for test_params in module_tests + get_new_module_tests(): # TODO: CUDA is not implemented yet if 'constructor' not in test_params: name = test_params.pop('module_name') diff --git a/torch/ao/quantization/pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py index 746ef788cd5..7c479550a3e 100644 --- 
a/torch/ao/quantization/pt2e/qat_utils.py +++ b/torch/ao/quantization/pt2e/qat_utils.py @@ -19,8 +19,6 @@ from torch.fx import Graph, GraphModule, Node from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns from .utils import ( - _conv1d_bn_example_inputs, - _conv2d_bn_example_inputs, _get_aten_graph_module_for_pattern, _is_bn_node, _is_conv_or_conv_transpose_node, @@ -35,27 +33,6 @@ if TYPE_CHECKING: __all__ = [] # type: ignore[var-annotated] -# Example inputs for quantized and folded conv-bn1d patterns used in convert -_quantized_conv1d_bn_example_inputs = ( - torch.randn(1, 1, 3), # x - torch.randn(1, 1, 1), # conv_weight - torch.randn(1), # bn_weight - torch.randn(1), # bn_bias - torch.randn(1), # bn_running_mean - torch.randn(1), # bn_running_var -) - -# Example inputs for quantized and folded conv-bn2d patterns used in convert -_quantized_conv2d_bn_example_inputs = ( - torch.randn(1, 1, 3, 3), # x - torch.randn(1, 1, 1, 1), # conv_weight - torch.randn(1), # bn_weight - torch.randn(1), # bn_bias - torch.randn(1), # bn_running_mean - torch.randn(1), # bn_running_var -) - - def _get_quantized_conv_bn_example_inputs_kwargs( is_per_channel: bool, has_bias: bool, @@ -631,6 +608,28 @@ def _update_special_qspecs_after_replacement( def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule: + # Example inputs for conv-bn1d patterns + _conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + # Example inputs for conv-bn2d patterns + _conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) if not has_bn: return m @@ -859,6 +858,26 @@ def _copy_over_q_dq_args(original_node: Node, replacement_node: Node): def _fold_conv_bn_qat(m: GraphModule) -> GraphModule: + # Example inputs for quantized and folded conv-bn1d patterns used in convert + _quantized_conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + # Example inputs for quantized and folded conv-bn2d patterns used in convert + _quantized_conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) if not has_bn: return m diff --git a/torch/ao/quantization/pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py index 6609032e928..bd35798d239 100644 --- a/torch/ao/quantization/pt2e/representation/rewrite.py +++ b/torch/ao/quantization/pt2e/representation/rewrite.py @@ -22,25 +22,6 @@ __all__ = [ ] -_QUANTIZED_LINEAR_EXAMPLE_INPUTS = ( - torch.randint(-128, 127, (2, 5), dtype=torch.int8), - torch.randn(1, dtype=torch.float), - torch.zeros(1, dtype=torch.int), - torch.tensor([-128], dtype=torch.int), - torch.tensor([127], dtype=torch.int), - torch.randint(-128, 127, (5, 5), dtype=torch.int8), - torch.randn(1, dtype=torch.float), - 
torch.zeros(1, dtype=torch.int), - torch.tensor([-127], dtype=torch.int), - torch.tensor([127], dtype=torch.int), - torch.randn(1, dtype=torch.float), - torch.randn(1, dtype=torch.float), - torch.zeros(1, dtype=torch.int), - torch.tensor([-128], dtype=torch.int), - torch.tensor([127], dtype=torch.int), -) - - def _qdq_quantized_linear( x_i8, x_scale, @@ -129,20 +110,6 @@ def _reference_quantized_linear( return out_i8 -_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = ( - torch.randn((2, 5), dtype=torch.float), - -128, - 127, - torch.finfo(torch.float32).eps, - torch.randint(-128, 127, (5, 5), dtype=torch.int8), - torch.randn(1, dtype=torch.float), - torch.zeros(1, dtype=torch.int), - torch.tensor([-127], dtype=torch.int), - torch.tensor([127], dtype=torch.int), - torch.randn(1, dtype=torch.float), -) - - def _qdq_dynamic_quantized_linear( x_fp32, x_quant_min, @@ -223,25 +190,6 @@ def _reference_dynamic_quantized_linear( return out_fp32 -_QUANTIZED_CONV2d_EXAMPLE_INPUTS = ( - torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), - torch.randn(1, dtype=torch.float), - torch.zeros(1, dtype=torch.int), - torch.tensor([-128], dtype=torch.int), - torch.tensor([127], dtype=torch.int), - torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), - torch.randn(1, dtype=torch.float), - torch.zeros(1, dtype=torch.int), - torch.tensor([-127], dtype=torch.int), - torch.tensor([127], dtype=torch.int), - torch.randn(1, dtype=torch.float), - torch.randn(1, dtype=torch.float), - torch.zeros(1, dtype=torch.int), - torch.tensor([-128], dtype=torch.int), - torch.tensor([127], dtype=torch.int), -) - - def _qdq_quantized_conv2d( x_i8, x_scale, @@ -375,20 +323,6 @@ def _reference_quantized_conv2d( return out_i8 -_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = ( - torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), - torch.randn(1, dtype=torch.float), - torch.zeros(1, dtype=torch.int), - torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), - torch.randn(1, dtype=torch.float), - torch.zeros(1, dtype=torch.int), - torch.randn(1, dtype=torch.float), - torch.zeros(1, dtype=torch.int), - torch.tensor([-128], dtype=torch.int), - torch.tensor([127], dtype=torch.int), -) - - def _qdq_quantized_add_relu( x_i8, x_scale, @@ -518,19 +452,6 @@ def _reference_quantized_add( return out_i8 -_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = ( - torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), - torch.randn(1, dtype=torch.float), - torch.zeros(1, dtype=torch.int), - torch.tensor([-128], dtype=torch.int), - torch.tensor([127], dtype=torch.int), - torch.randn(1, dtype=torch.float), - torch.zeros(1, dtype=torch.int), - torch.tensor([-128], dtype=torch.int), - torch.tensor([127], dtype=torch.int), -) - - def _qdq_quantized_max_pool2d( x_i8, x_scale, @@ -587,15 +508,6 @@ def _reference_quantized_max_pool2d( return out_i8 -_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = ( - torch.randn(1, 3, 3, 3, dtype=torch.float), - torch.randn(1, dtype=torch.float), - torch.zeros(1, dtype=torch.int), - torch.tensor([-128], dtype=torch.int), - torch.tensor([127], dtype=torch.int), -) - - def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max): x = torch.ops.quantized_decomposed.quantize_per_tensor( x_fp32, scale, zero_point, quant_min, quant_max, torch.int8 @@ -619,15 +531,6 @@ def _reference_quantize_per_tensor_int8( return x -_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = ( - torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), - torch.randn(1, dtype=torch.float), - torch.zeros(1, dtype=torch.int), - torch.tensor([-128], 
dtype=torch.int), - torch.tensor([127], dtype=torch.int), -) - - def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max): x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( x_i8, scale, zero_point, quant_min, quant_max, torch.int8 @@ -648,16 +551,6 @@ def _reference_dequantize_per_tensor_int8( return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32) -_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = ( - torch.randn(1, 3, 3, 3, dtype=torch.float), - torch.randn(3, dtype=torch.float), - torch.zeros(3, dtype=torch.int), - 1, - -128, - 127, -) - - def _quantize_per_channel_int8( x_fp32, scales, zero_points, ch_axis, quant_min, quant_max ): @@ -678,16 +571,6 @@ def _reference_quantize_per_channel_int8( return out_i32.to(torch.int8) -_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = ( - torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), - torch.randn(3, dtype=torch.float), - torch.zeros(3, dtype=torch.int), - 1, - -128, - 127, -) - - def _dequantize_per_channel_int8( x_i8, scales, zero_points, ch_axis, quant_min, quant_max ): @@ -733,79 +616,186 @@ class _RewriteInfo: replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None -_REWRITE_INFO_LIST = [ - _RewriteInfo( - _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS, - _WrapperModule(_qdq_dynamic_quantized_linear), - _WrapperModule(_reference_dynamic_quantized_linear), - partial( - _replace_literals_with_existing_placeholders, - literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3}, - ), - partial( - _replace_literals_with_existing_placeholders, - literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3}, - ), - ), - _RewriteInfo( - _QUANTIZED_LINEAR_EXAMPLE_INPUTS, - _WrapperModule(_qdq_quantized_linear), - _WrapperModule(_reference_quantized_linear), - _replace_literals_with_new_placeholders, - _replace_literals_with_new_placeholders, - ), - _RewriteInfo( - _QUANTIZED_CONV2d_EXAMPLE_INPUTS, - _WrapperModule(_qdq_quantized_conv2d), - _WrapperModule(_reference_quantized_conv2d), - partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]), - partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]), - ), - _RewriteInfo( - _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, - _WrapperModule(_qdq_quantized_add_relu), - _WrapperModule(_reference_quantized_add_relu), - ), - _RewriteInfo( - _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, - _WrapperModule(_qdq_quantized_add), - _WrapperModule(_reference_quantized_add), - ), - _RewriteInfo( - _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS, - _WrapperModule(_qdq_quantized_max_pool2d), - _WrapperModule(_reference_quantized_max_pool2d), - _replace_literals_with_new_placeholders, - _replace_literals_with_new_placeholders, - ), - _RewriteInfo( - _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS, - _WrapperModule(_quantize_per_tensor_int8), - _WrapperModule(_reference_quantize_per_tensor_int8), - ), - _RewriteInfo( - _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS, - _WrapperModule(_dequantize_per_tensor_int8), - _WrapperModule(_reference_dequantize_per_tensor_int8), - ), - _RewriteInfo( - _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS, - _WrapperModule(_quantize_per_channel_int8), - _WrapperModule(_reference_quantize_per_channel_int8), - _replace_ph_qdq_per_channel_replacement, - _replace_ph_qdq_per_channel_replacement, - ), - _RewriteInfo( - _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS, - _WrapperModule(_dequantize_per_channel_int8), - _WrapperModule(_reference_dequantize_per_channel_int8), - 
_replace_ph_qdq_per_channel_replacement, - _replace_ph_qdq_per_channel_replacement, - ), -] - - def reference_representation_rewrite(model: GraphModule) -> GraphModule: + _QUANTIZED_LINEAR_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (2, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randint(-128, 127, (5, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = ( + torch.randn((2, 5), dtype=torch.float), + -128, + 127, + torch.finfo(torch.float32).eps, + torch.randint(-128, 127, (5, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + ) + + _QUANTIZED_CONV2d_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = ( + torch.randn(1, 3, 3, 3, dtype=torch.float), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = ( + torch.randn(1, 3, 3, 3, dtype=torch.float), + torch.randn(3, dtype=torch.float), + torch.zeros(3, dtype=torch.int), + 1, + -128, + 127, + ) + + _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 
3), dtype=torch.int8), + torch.randn(3, dtype=torch.float), + torch.zeros(3, dtype=torch.int), + 1, + -128, + 127, + ) + + _REWRITE_INFO_LIST = [ + _RewriteInfo( + _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS, + _WrapperModule(_qdq_dynamic_quantized_linear), + _WrapperModule(_reference_dynamic_quantized_linear), + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3}, + ), + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3}, + ), + ), + _RewriteInfo( + _QUANTIZED_LINEAR_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_linear), + _WrapperModule(_reference_quantized_linear), + _replace_literals_with_new_placeholders, + _replace_literals_with_new_placeholders, + ), + _RewriteInfo( + _QUANTIZED_CONV2d_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_conv2d), + _WrapperModule(_reference_quantized_conv2d), + partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]), + partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]), + ), + _RewriteInfo( + _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_add_relu), + _WrapperModule(_reference_quantized_add_relu), + ), + _RewriteInfo( + _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_add), + _WrapperModule(_reference_quantized_add), + ), + _RewriteInfo( + _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_max_pool2d), + _WrapperModule(_reference_quantized_max_pool2d), + _replace_literals_with_new_placeholders, + _replace_literals_with_new_placeholders, + ), + _RewriteInfo( + _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS, + _WrapperModule(_quantize_per_tensor_int8), + _WrapperModule(_reference_quantize_per_tensor_int8), + ), + _RewriteInfo( + _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS, + _WrapperModule(_dequantize_per_tensor_int8), + _WrapperModule(_reference_dequantize_per_tensor_int8), + ), + _RewriteInfo( + _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS, + _WrapperModule(_quantize_per_channel_int8), + _WrapperModule(_reference_quantize_per_channel_int8), + _replace_ph_qdq_per_channel_replacement, + _replace_ph_qdq_per_channel_replacement, + ), + _RewriteInfo( + _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS, + _WrapperModule(_dequantize_per_channel_int8), + _WrapperModule(_reference_dequantize_per_channel_int8), + _replace_ph_qdq_per_channel_replacement, + _replace_ph_qdq_per_channel_replacement, + ), + ] + remove_tensor_overload_for_qdq_ops(model) from torch._export import gm_using_training_ir diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index 7b22bacbe57..d0801142bd1 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -33,28 +33,6 @@ _DEQUANTIZE_OPS = [ torch.ops.quantized_decomposed.dequantize_per_channel.default, ] -# Example inputs for conv-bn1d patterns -_conv1d_bn_example_inputs = ( - torch.randn(1, 1, 3), # x - torch.randn(1, 1, 1), # conv_weight - torch.randn(1), # conv_bias - torch.randn(1), # bn_weight - torch.randn(1), # bn_bias - torch.randn(1), # bn_running_mean - torch.randn(1), # bn_running_var -) - -# Example inputs for conv-bn2d patterns -_conv2d_bn_example_inputs = ( - torch.randn(1, 1, 3, 3), # x - torch.randn(1, 1, 1, 1), # conv_weight - torch.randn(1), # conv_bias - torch.randn(1), # bn_weight - torch.randn(1), # bn_bias - torch.randn(1), # bn_running_mean - torch.randn(1), # bn_running_var -) - def 
_is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool: """ diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py index 2bbf4fffee7..fa86c6f4c30 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py @@ -10,8 +10,6 @@ from torch._subclasses import FakeTensor from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix from torch.ao.quantization.pt2e.export_utils import _WrapperModule from torch.ao.quantization.pt2e.utils import ( - _conv1d_bn_example_inputs, - _conv2d_bn_example_inputs, _get_aten_graph_module_for_pattern, _is_conv_node, _is_conv_transpose_node, @@ -487,6 +485,28 @@ def _do_annotate_conv_bn( for the following names: "input", "conv", "weight", "bias", and "output". """ + # Example inputs for conv-bn1d patterns + _conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + # Example inputs for conv-bn2d patterns + _conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + def get_pattern(conv_fn: Callable, relu_is_inplace: bool): def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv): conv = conv_fn(x, conv_weight, conv_bias) diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index aab3c42076d..eea6e5788ce 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -1076,1638 +1076,1643 @@ def single_batch_reference_fn(input, parameters, module): return module(*single_batch_input).squeeze(0) -new_module_tests = [ - poissonnllloss_no_reduce_test(), - bceloss_no_reduce_test(), - bceloss_weights_no_reduce_test(), - bce_with_logistic_legacy_enum_test(), - bce_with_logistic_no_reduce_test(), - bceloss_no_reduce_scalar_test(), - bceloss_weights_no_reduce_scalar_test(), - bce_with_logistic_no_reduce_scalar_test(), - kldivloss_with_target_no_reduce_test(), - kldivloss_no_reduce_test(), - kldivloss_no_reduce_scalar_test(), - kldivloss_with_log_target_no_reduce_test(), - kldivloss_no_reduce_log_target_test(), - kldivloss_no_reduce_scalar_log_target_test(), - l1loss_no_reduce_test(), - l1loss_no_reduce_complex_test(), - l1loss_no_reduce_scalar_test(), - mseloss_no_reduce_test(), - mseloss_no_reduce_scalar_test(), - nllloss_no_reduce_test(), - nllloss_no_reduce_ignore_index_test(), - nllloss_no_reduce_weights_test(), - nllloss_no_reduce_weights_ignore_index_test(), - nllloss_no_reduce_weights_ignore_index_neg_test(), - nllloss2d_no_reduce_test(), - nllloss2d_no_reduce_weights_test(), - nllloss2d_no_reduce_ignore_index_test(), - nlllossNd_no_reduce_test(), - nlllossNd_no_reduce_weights_test(), - nlllossNd_no_reduce_ignore_index_test(), - smoothl1loss_no_reduce_test(), - smoothl1loss_no_reduce_scalar_test(), - smoothl1loss_beta_test(), - smoothl1loss_zero_beta_test(), - huberloss_delta_test(), - multilabelmarginloss_0d_no_reduce_test(), - multilabelmarginloss_1d_no_reduce_test(), - multilabelmarginloss_index_neg_test(), - multilabelmarginloss_no_reduce_test(), - hingeembeddingloss_no_reduce_test(), - 
hingeembeddingloss_margin_no_reduce_test(), - softmarginloss_no_reduce_test(), - multilabelsoftmarginloss_no_reduce_test(), - multilabelsoftmarginloss_weights_no_reduce_test(), - multimarginloss_no_reduce_test(), - multimarginloss_1d_no_reduce_test(), - multimarginloss_1d_input_0d_target_no_reduce_test(), - multimarginloss_p_no_reduce_test(), - multimarginloss_margin_no_reduce_test(), - multimarginloss_weights_no_reduce_test(), - dict( - module_name='Conv1d', - constructor_args=(4, 5, 3), - cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)', - input_size=(2, 4, 10), - cudnn=True, - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - module_name='Conv1d', - constructor_args=(4, 5, 3, 2), - cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(2)', - input_size=(2, 4, 10), - cudnn=True, - desc='stride', - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - module_name='Conv1d', - constructor_args=(4, 5, 3, 1, 1), - cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(1).padding(1)', - input_size=(2, 4, 10), - cudnn=True, - desc='pad1', - with_tf32=True, - tf32_precision=0.01, - default_dtype=torch.double, - ), - dict( - module_name='Conv1d', - constructor_args=(4, 5, 5, 1, 2), - cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 5).stride(1).padding(2)', - input_size=(2, 4, 10), - cudnn=True, - desc='pad2', - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - module_name='Conv1d', - constructor_args=(4, 4, 3, 1, 1), - cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 3).stride(1).padding(1)', - input_size=(1, 4, 1), - cudnn=True, - desc='pad1size1', - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - module_name='Conv1d', - constructor_args=(4, 4, 5, 1, 2), - cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 5).stride(1).padding(2)', - input_size=(1, 4, 1), - cudnn=True, - desc='pad2size1', - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - module_name='Conv1d', - constructor_args=(4, 5, 3), - cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)', - input_size=(0, 4, 10), - cudnn=True, - desc='zero_batch', - with_tf32=True, - tf32_precision=0.005, - ), - dict( - fullname='Conv1d_dilated', - constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2), - cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).dilation(2)', - input_size=(2, 4, 10), - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - fullname='Conv1d_groups', - constructor=lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2), - cpp_constructor_args='torch::nn::Conv1dOptions(4, 6, 3).groups(2)', - input_size=(2, 4, 6), - cudnn=True, - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - fullname='Conv1d_pad_valid', - constructor=lambda: nn.Conv1d(4, 5, 3, padding="valid"), - cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kValid)', - input_size=(2, 4, 10), - cudnn=True, - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - fullname='Conv1d_pad_same', - constructor=lambda: nn.Conv1d(4, 5, 3, padding="same"), - cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame)', - input_size=(2, 4, 10), - cudnn=True, - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - fullname='Conv1d_pad_same2', - constructor=lambda: nn.Conv1d(4, 5, 4, padding="same"), - 
cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame)', - input_size=(2, 4, 10), - cudnn=True, - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - fullname='Conv1d_pad_same_dilated', - constructor=lambda: nn.Conv1d(4, 5, 4, padding="same", dilation=2), - cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame).dilation(2)', - input_size=(2, 4, 10), - cudnn=True, - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - fullname='ConvTranspose1d', - constructor=lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)), - cpp_constructor_args='torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)', - cudnn=True, - input_size=(1, 3, 7), - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - module_name='ConvTranspose1d', - constructor_args=(3, 4, 3, 2, 1, 1, 1, False), - cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3) - .stride(2).padding(1).output_padding(1).groups(1).bias(false)''', - input_size=(1, 3, 6), - cudnn=True, - desc='no_bias', - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - module_name='ConvTranspose1d', - constructor_args=(3, 4, 3, 2, 1, 1, 1, True, 2), - cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3) - .stride(2).padding(1).output_padding(1).groups(1).bias(true).dilation(2)''', - input_size=(1, 3, 6), - cudnn=True, - desc='dilated', - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - fullname='ConvTranspose1d_groups', - constructor=lambda: nn.ConvTranspose1d(4, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2), - cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(4, 6, 3) - .stride(3).padding(1).output_padding(1).groups(2)''', - cudnn=True, - input_size=(2, 4, 7), - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - module_name='Conv2d', - constructor_args=(3, 4, (3, 2)), - cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})', - input_size=(2, 3, 7, 5), - cudnn=True, - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - module_name='Conv2d', - constructor_args=(3, 4, (3, 3), (2, 2)), - cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2})', - input_size=(2, 3, 6, 6), - cudnn=True, - desc='strided', - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - module_name='Conv2d', - constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)), - cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2}).padding({1, 1})', - input_size=(2, 3, 6, 6), - cudnn=True, - desc='padding', - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - module_name='Conv2d', - constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)), - cpp_constructor_args='torch::nn::Conv2dOptions(3, 2, {3, 3}).stride({2, 2}).padding({1, 1}).dilation({2, 2})', - input_size=(2, 3, 8, 8), - cudnn=True, - desc='dilated', - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - module_name='Conv2d', - constructor_args=(3, 4, (3, 2), 1, 0, 1, 1, False), - cpp_constructor_args='''torch::nn::Conv2dOptions(3, 4, {3, 2}) - .stride(1).padding(0).dilation(1).groups(1).bias(false)''', - 
input_size=(2, 3, 6, 5), - cudnn=True, - desc='no_bias', - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.015, - default_dtype=torch.double, - ), - dict( - module_name='Conv2d', - constructor_args=(3, 4, (3, 2)), - cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})', - input_size=(0, 3, 7, 5), - cudnn=True, - desc='zero_batch', - check_with_long_tensor=True, - with_tf32=True, - ), - dict( - fullname='Conv2d_groups', - constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2), - cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)', - input_size=(2, 4, 6, 5), - cudnn=True, - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.015, - default_dtype=torch.double, - ), - dict( - fullname='Conv2d_groups_thnn', - constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2), - cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)', - input_size=(2, 4, 6, 5), - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.015, - default_dtype=torch.double, - ), - dict( - fullname='Conv2d_pad_valid', - constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="valid"), - cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kValid)', - input_size=(2, 2, 6, 5), - cudnn=True, - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - fullname='Conv2d_pad_same', - constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same"), - cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame)', - input_size=(2, 2, 6, 5), - cudnn=True, - with_tf32=True, - tf32_precision=0.01, - default_dtype=torch.double, - ), - dict( - fullname='Conv2d_pad_same_dilated', - constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same", dilation=2), - cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame).dilation(2)', - input_size=(2, 2, 6, 5), - cudnn=True, - with_tf32=True, - tf32_precision=0.01, - default_dtype=torch.double, - ), - dict( - module_name='ConvTranspose2d', - constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)), - cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3) - .stride({3, 2}).padding(1).output_padding({1, 1})''', - cudnn=True, - input_size=(1, 3, 7, 6), - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.01, - default_dtype=torch.double, - ), - dict( - module_name='ConvTranspose2d', - constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False, (2, 2)), - cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3) - .stride({2, 3}) - .padding(1) - .output_padding({1, 1}) - .groups(1) - .bias(false) - .dilation({2, 2})''', - input_size=(1, 3, 6, 7), - cudnn=True, - desc='dilated', - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.01, - default_dtype=torch.double, - ), - dict( - module_name='ConvTranspose2d', - constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False), - cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3) - .stride({2, 3}).padding(1).output_padding({1, 1}).groups(1).bias(false)''', - input_size=(1, 3, 6, 7), - cudnn=True, - desc='no_bias', - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.01, - default_dtype=torch.double, - ), - dict( - fullname='ConvTranspose2d_groups', - constructor=lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2), - cpp_constructor_args='torch::nn::ConvTranspose2dOptions(2, 4, {2, 3}).groups(2)', - input_size=(1, 2, 4, 5), - cudnn=True, - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.01, - 
default_dtype=torch.double, - ), - dict( - fullname='Conv2d_depthwise', - constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4), - cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)', - input_size=(2, 4, 6, 6), - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - fullname='Conv2d_depthwise_with_multiplier', - constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4), - cpp_constructor_args='torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)', - input_size=(2, 4, 6, 6), - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - fullname='Conv2d_depthwise_strided', - constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4), - cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)', - input_size=(2, 4, 6, 6), - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - fullname='Conv2d_depthwise_padded', - constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4), - cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)', - input_size=(2, 4, 6, 6), - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - fullname='Conv2d_depthwise_dilated', - constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4), - cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)', - input_size=(2, 4, 5, 5), - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - module_name='Conv3d', - constructor_args=(2, 3, (2, 3, 2)), - cpp_constructor_args='torch::nn::Conv3dOptions(2, 3, {2, 3, 2})', - input_size=(1, 2, 4, 5, 4), - cudnn=True, - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.05, - default_dtype=torch.double, - ), - dict( - module_name='Conv3d', - constructor_args=(2, 3, (2, 3, 4), 1, 0, 1, 1, False), - cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4}) - .stride(1).padding(0).dilation(1).groups(1).bias(false)''', - input_size=(1, 2, 3, 4, 5), - cudnn=True, - desc='no_bias', - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.05, - default_dtype=torch.double, - ), - dict( - module_name='Conv3d', - constructor_args=(2, 3, (1, 1, 1), 1, 0, 1, 1, False), - cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4}) - .stride(1).padding(0).dilation(1).groups(1).bias(false)''', - input_size=(1, 2, 3, 4, 5), - cudnn=True, - desc='1x1x1_no_bias', - check_with_long_tensor=False, - with_tf32=True, - tf32_precision=0.05, - default_dtype=torch.double, - ), - dict( - module_name='Conv3d', - constructor_args=(3, 4, 2, 2), - cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2)', - input_size=(2, 3, 5, 5, 5), - cudnn=True, - desc='stride', - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.05, - default_dtype=torch.double, - ), - dict( - module_name='Conv3d', - constructor_args=(3, 4, 2, 2, 1), - cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2).padding(1)', - input_size=(2, 3, 5, 5, 5), - cudnn=True, - desc='stride_padding', - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.05, - default_dtype=torch.double, - ), - dict( - module_name='Conv3d', - constructor_args=(3, 4, (2, 3, 4)), - cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4})', - input_size=(0, 3, 3, 4, 5), - cudnn=True, - check_with_long_tensor=True, - desc='zero_batch', - with_tf32=True, - ), - dict( - fullname='Conv3d_groups', - 
constructor=lambda: nn.Conv3d(2, 4, kernel_size=3, groups=2), - cpp_constructor_args='torch::nn::Conv3dOptions(2, 4, 3).groups(2)', - input_size=(1, 2, 4, 5, 4), - cudnn=True, - check_with_long_tensor=True, - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - fullname='Conv3d_dilated', - constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2), - cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2)', - input_size=(2, 3, 5, 5, 5), - with_tf32=True, - tf32_precision=0.05, - default_dtype=torch.double, - ), - dict( - fullname='Conv3d_dilated_strided', - constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2), - cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)', - input_size=(2, 3, 5, 5, 5), - with_tf32=True, - tf32_precision=0.05, - default_dtype=torch.double, - ), - dict( - fullname='Conv3d_pad_valid', - constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="valid"), - cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kValid)', - input_size=(2, 3, 6, 5, 4), - cudnn=True, - with_tf32=True, - tf32_precision=0.05, - default_dtype=torch.double, - ), - dict( - fullname='Conv3d_pad_same', - constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same"), - cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame)', - input_size=(2, 3, 6, 5, 4), - cudnn=True, - with_tf32=True, - tf32_precision=0.05, - default_dtype=torch.double, - ), - dict( - fullname='Conv3d_pad_same_dilated', - constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same", dilation=2), - cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame).dilation(2)', - input_size=(2, 3, 6, 5, 4), - cudnn=True, - with_tf32=True, - tf32_precision=0.05, - default_dtype=torch.double, - ), - dict( - module_name='ConvTranspose3d', - constructor_args=(2, 3, (2, 3, 2)), - cpp_constructor_args='torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})', - cudnn=True, - input_size=(1, 2, 4, 5, 4), - with_tf32=True, - tf32_precision=0.05, - default_dtype=torch.double, - ), - dict( - module_name='ConvTranspose3d', - constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1, True, (2, 2, 2)), - cpp_constructor_args='''torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2}) - .stride(1).padding(0).output_padding(0).groups(1).bias(true).dilation({2, 2, 2})''', - cudnn=True, - input_size=(1, 2, 4, 5, 4), - desc='dilated', - with_tf32=True, - tf32_precision=0.05, - default_dtype=torch.double, - ), - dict( - module_name='ReplicationPad3d', - constructor_args=((1, 2, 3, 3, 2, 1),), - cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})', - input_size=(2, 3, 2, 2, 2), - default_dtype=torch.double, - ), - dict( - module_name='ReplicationPad3d', - constructor_args=((1, 2, 3, 3, 2, 1),), - cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})', - input_size=(3, 2, 2, 2), - reference_fn=single_batch_reference_fn, - desc='no_batch_dim', - default_dtype=torch.double, - ), - dict( - module_name='ReplicationPad3d', - constructor_args=((1, 2, 3, 3, 2, 1),), - cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})', - input_fn=lambda: torch.rand(2, 3, 2, 2, 2, dtype=torch.complex128, requires_grad=True), - skip_half=True, - desc='complex' - ), - dict( - module_name='Embedding', - constructor_args=(4, 3), - cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)', - input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), - 
check_gradgrad=False, - default_dtype=torch.double, - decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971") - ), - dict( - module_name='Embedding', - constructor_args=(4, 3), - cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)', - input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512), - check_gradgrad=False, - desc='discontiguous', - default_dtype=torch.double, - decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971") - ), - dict( - module_name='EmbeddingBag', - constructor_args=(4, 3), - cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)', - input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), - check_gradgrad=False, - desc='mean', - default_dtype=torch.double, - ), - dict( - module_name='EmbeddingBag', - constructor_args=(4, 3), - cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)', - input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512), - check_gradgrad=False, - desc='discontiguous', - default_dtype=torch.double, - ), - dict( - module_name='EmbeddingBag', - constructor_args=(4, 3, None, 2., False, 'sum'), - cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) - .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)''', - input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), - check_gradgrad=False, - desc='sum', - default_dtype=torch.double, - ), - dict( - module_name='EmbeddingBag', - constructor_args=(4, 3, None, 2., False, 'max'), - cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) - .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)''', - input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), - check_gradgrad=False, - desc='max', - default_dtype=torch.double, - ), - dict( - fullname='EmbeddingBag_mean_padding_idx', - constructor=lambda: nn.EmbeddingBag(4, 3, padding_idx=1), - cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).padding_idx(1)', - input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]), - check_gradgrad=False, - default_dtype=torch.double, - ), - dict( - fullname='EmbeddingBag_sum_padding_idx', - constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'sum', padding_idx=1), - cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) - .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum).padding_idx(1)''', - input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]), - check_gradgrad=False, - default_dtype=torch.double, - ), - dict( - fullname='EmbeddingBag_max_padding_idx', - constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'max', padding_idx=1), - cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) - .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax).padding_idx(1)''', - input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]), - check_gradgrad=False, - default_dtype=torch.double, - ), - dict( - fullname='EmbeddingBag_sparse', - constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double), - cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))', - input_fn=lambda: torch.randperm(2).repeat(1, 2), - check_gradgrad=False, - has_sparse_gradients=True, - ), - dict( - constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True), - cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)._weight(torch::rand({4, 
3}).to(torch::kFloat64))', - input_fn=lambda: torch.randperm(2).repeat(1, 2), - fullname='Embedding_sparse', - check_gradgrad=False, - has_sparse_gradients=True, - ), - dict( - module_name='PixelShuffle', - constructor_args=(3,), - cpp_constructor_args='torch::nn::PixelShuffleOptions(3)', - input_size=(1, 9, 4, 4), - default_dtype=torch.double, - ), - dict( - module_name='PixelUnshuffle', - constructor_args=(3,), - cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)', - input_size=(1, 1, 12, 12), - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''', - input_size=(1, 2, 4), - fullname='interpolate_nearest_1d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''', - input_size=(0, 2, 4), - fullname='interpolate_nearest_1d_zero_dim', - pickle=False, - ), - dict( - constructor=wrap_functional(F.interpolate, size=(12, ), scale_factor=None, mode='nearest'), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''', - input_size=(1, 2, 3), - fullname='interpolate_nearest_tuple_1d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::nullopt).scale_factor(std::vector({4.})).mode(torch::kNearest)''', - input_size=(1, 2, 4), - fullname='interpolate_nearest_scale_1d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12})) - .scale_factor(std::nullopt) - .mode(torch::kLinear) - .align_corners(false)''', - input_size=(1, 2, 4), - fullname='interpolate_linear_1d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=(4, ), scale_factor=None, mode='linear', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({4})) - .scale_factor(std::nullopt) - .mode(torch::kLinear) - .align_corners(false)''', - input_size=(1, 2, 3), - fullname='interpolate_linear_tuple_1d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::nullopt) - .scale_factor(std::vector({4.})) - .mode(torch::kLinear) - .align_corners(false)''', - input_size=(1, 2, 4), - fullname='interpolate_linear_scale_1d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12})) - .scale_factor(std::nullopt) - .mode(torch::kLinear) - .align_corners(false)''', - input_size=(0, 2, 4), - fullname='interpolate_linear_1d_zero_dim', - pickle=False, - ), - dict( - constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', 
align_corners=True), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12})) - .scale_factor(std::nullopt) - .mode(torch::kLinear) - .align_corners(true)''', - input_size=(1, 2, 4), - fullname='interpolate_linear_1d_align_corners', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=True), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::nullopt) - .scale_factor(std::vector({4.})) - .mode(torch::kLinear) - .align_corners(true)''', - input_size=(1, 2, 4), - fullname='interpolate_linear_scale_1d_align_corners', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=2, scale_factor=None, mode='nearest'), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({2, 2})) - .scale_factor(std::nullopt) - .mode(torch::kNearest)''', - input_size=(1, 128, 1, 1), - fullname='interpolate_nearest_2d_launch_configs', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12, 12})) - .scale_factor(std::nullopt) - .mode(torch::kNearest)''', - input_size=(1, 2, 4, 4), - fullname='interpolate_nearest_2d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=(12, 16), scale_factor=None, mode='nearest'), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12, 16})) - .scale_factor(std::nullopt) - .mode(torch::kNearest)''', - input_size=(1, 2, 3, 4), - fullname='interpolate_nearest_tuple_2d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::nullopt) - .scale_factor(std::vector({4., 4.})) - .mode(torch::kNearest)''', - input_size=(1, 2, 4, 4), - fullname='interpolate_nearest_scale_2d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12, 12})) - .scale_factor(std::nullopt) - .mode(torch::kNearest)''', - input_size=(0, 2, 4, 4), - fullname='interpolate_nearest_2d_zero_dim', - pickle=False, - ), - dict( - constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12, 12})) - .scale_factor(std::nullopt) - .mode(torch::kBilinear) - .align_corners(false)''', - input_size=(1, 2, 4, 4), - fullname='interpolate_bilinear_2d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12, 12})) - .scale_factor(std::nullopt) - .mode(torch::kBilinear) - .align_corners(false)''', - input_size=(0, 2, 4, 4), - fullname='interpolate_bilinear_2d_zero_dim', - pickle=False, - ), - dict( - constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, - mode='bilinear', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({4, 6})) - .scale_factor(std::nullopt) - .mode(torch::kBilinear) - 
.align_corners(false)''', - input_size=(1, 2, 2, 3), - fullname='interpolate_bilinear_tuple_2d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., - mode='bilinear', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::nullopt) - .scale_factor(std::vector({4., 4.})) - .mode(torch::kBilinear) - .align_corners(false)''', - input_size=(1, 2, 4, 4), - fullname='interpolate_bilinear_scale_2d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.), - mode='bilinear', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::nullopt) - .scale_factor(std::vector({2., 2.})) - .mode(torch::kBilinear) - .align_corners(false)''', - input_size=(1, 2, 4, 4), - fullname='interpolate_bilinear_scale_tuple_shared_2d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), - mode='bilinear', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::nullopt) - .scale_factor(std::vector({2., 1.})) - .mode(torch::kBilinear) - .align_corners(false)''', - input_size=(1, 2, 4, 4), - fullname='interpolate_bilinear_scale_tuple_skewed_2d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bilinear', align_corners=True), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({4, 6})) - .scale_factor(std::nullopt) - .mode(torch::kBilinear) - .align_corners(true)''', - input_size=(1, 2, 4, 4), - fullname='interpolate_bilinear_tuple_2d_align_corners', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), - mode='bilinear', align_corners=True), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::nullopt) - .scale_factor(std::vector({2., 1.})) - .mode(torch::kBilinear) - .align_corners(true)''', - input_size=(1, 2, 4, 4), - fullname='interpolate_bilinear_scale_tuple_skewed_2d_align_corners', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12, 12})) - .scale_factor(std::nullopt) - .mode(torch::kBicubic) - .align_corners(false)''', - input_size=(1, 2, 4, 4), - fullname='interpolate_bicubic_2d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12, 12})) - .scale_factor(std::nullopt) - .mode(torch::kBicubic) - .align_corners(false)''', - input_size=(0, 2, 4, 4), - fullname='interpolate_bicubic_2d_zero_dim', - pickle=False, - ), - dict( - constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, - mode='bicubic', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({4, 6})) - .scale_factor(std::nullopt) - .mode(torch::kBicubic) - .align_corners(false)''', - input_size=(1, 2, 2, 3), - fullname='interpolate_bicubic_tuple_2d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='bicubic', 
align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::nullopt) - .scale_factor(std::vector({4., 4.})) - .mode(torch::kBicubic) - .align_corners(false)''', - input_size=(1, 2, 4, 4), - fullname='interpolate_bicubic_scale_2d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.), - mode='bicubic', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::nullopt) - .scale_factor(std::vector({2., 2.})) - .mode(torch::kBicubic) - .align_corners(false)''', - input_size=(1, 2, 4, 4), - fullname='interpolate_bicubic_scale_tuple_shared_2d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), - mode='bicubic', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::nullopt) - .scale_factor(std::vector({2., 1.})) - .mode(torch::kBicubic) - .align_corners(false)''', - input_size=(1, 2, 4, 4), - fullname='interpolate_bicubic_scale_tuple_skewed_2d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bicubic', align_corners=True), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({4, 6})) - .scale_factor(std::nullopt) - .mode(torch::kBicubic) - .align_corners(true)''', - input_size=(1, 2, 4, 4), - fullname='interpolate_bicubic_tuple_2d_align_corners', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), - mode='bicubic', align_corners=True), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::nullopt) - .scale_factor(std::vector({2., 1.})) - .mode(torch::kBicubic) - .align_corners(true)''', - input_size=(1, 2, 4, 4), - fullname='interpolate_bicubic_scale_tuple_skewed_2d_align_corners', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12, 12, 12})) - .scale_factor(std::nullopt) - .mode(torch::kNearest)''', - input_size=(1, 2, 4, 4, 4), - fullname='interpolate_nearest_3d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12, 12, 12})) - .scale_factor(std::nullopt) - .mode(torch::kNearest)''', - input_size=(0, 2, 4, 4, 4), - fullname='interpolate_nearest_3d_zero_dim', - pickle=False, - ), - dict( - constructor=wrap_functional(F.interpolate, size=(12, 16, 16), scale_factor=None, mode='nearest'), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12, 16, 16})) - .scale_factor(std::nullopt) - .mode(torch::kNearest)''', - input_size=(1, 2, 3, 4, 4), - fullname='interpolate_nearest_tuple_3d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::nullopt) - .scale_factor(std::vector({4., 4., 4.})) - .mode(torch::kNearest)''', - input_size=(1, 2, 4, 4, 4), - fullname='interpolate_nearest_scale_3d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', 
align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12, 12, 12})) - .scale_factor(std::nullopt) - .mode(torch::kTrilinear) - .align_corners(false)''', - input_size=(1, 2, 4, 4, 4), - fullname='interpolate_trilinear_3d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({12, 12, 12})) - .scale_factor(std::nullopt) - .mode(torch::kTrilinear) - .align_corners(false)''', - input_size=(0, 2, 4, 4, 4), - fullname='interpolate_trilinear_3d_zero_dim', - pickle=False, - ), - dict( - constructor=wrap_functional(F.interpolate, size=(4, 6, 6), - scale_factor=None, mode='trilinear', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({4, 6, 6})) - .scale_factor(std::nullopt) - .mode(torch::kTrilinear) - .align_corners(false)''', - input_size=(1, 2, 2, 3, 3), - fullname='interpolate_trilinear_tuple_3d', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=False), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::nullopt) - .scale_factor(std::vector({3., 3., 3.})) - .mode(torch::kTrilinear) - .align_corners(false)''', - input_size=(1, 2, 3, 4, 5), - fullname='interpolate_trilinear_scale_3d', - # See https://github.com/pytorch/pytorch/issues/5006 - precision=3e-4, - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=None, - mode='trilinear', align_corners=True), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::vector({4, 6, 6})) - .scale_factor(std::nullopt) - .mode(torch::kTrilinear) - .align_corners(true)''', - input_size=(1, 2, 2, 3, 3), - fullname='interpolate_trilinear_tuple_3d_align_corners', - pickle=False, - default_dtype=torch.double - ), - dict( - constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=True), - cpp_options_args='''F::InterpolateFuncOptions() - .size(std::nullopt) - .scale_factor(std::vector({3., 3., 3.})) - .mode(torch::kTrilinear) - .align_corners(true)''', - input_size=(1, 2, 3, 4, 4), - fullname='interpolate_trilinear_scale_3d_align_corners', - # See https://github.com/pytorch/pytorch/issues/5006 - precision=3e-4, - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.softmax, dim=-1), - cpp_options_args='F::SoftmaxFuncOptions(-1)', - input_size=(2, 128), # trigger the last-dim algo in CUDA - fullname='softmax_lastdim', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64), - cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)', - input_size=(2, 128), - fullname='softmax_lastdim_dtype', - pickle=False, - test_cuda=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.softmax, dim=1), - cpp_options_args='F::SoftmaxFuncOptions(1)', - input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo - fullname='softmax_spatial_special', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.softmax, dim=1), - cpp_options_args='F::SoftmaxFuncOptions(1)', - input_size=(2, 2, 4, 4), # regular spatial algorithm - fullname='softmax_spatial', - pickle=False, - 
default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64), - cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)', - input_size=(2, 2, 4, 4), # regular spatial algorithm - fullname='softmax_spatial_dtype', - pickle=False, - test_cuda=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.softmax, dim=0), - cpp_options_args='F::SoftmaxFuncOptions(0)', - input_size=(2, 3, 4, 5), - fullname='softmax_functional_dim0', - test_cuda=False, - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.softmax, dim=3), - cpp_options_args='F::SoftmaxFuncOptions(3)', - input_size=(2, 3, 4, 5), - fullname='softmax_functional_dim3', - test_cuda=False, - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.softmax, dim=-1), - cpp_options_args='F::SoftmaxFuncOptions(-1)', - input_size=(), - fullname='softmax_functional_scalar', - test_cuda=False, - pickle=False, - ), - dict( - constructor=wrap_functional(F.log_softmax, dim=-1), - cpp_options_args='F::LogSoftmaxFuncOptions(-1)', - input_size=(2, 128), # trigger the last-dim algo in CUDA - fullname='log_softmax_lastdim', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.log_softmax, dim=1), - cpp_options_args='F::LogSoftmaxFuncOptions(1)', - input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo - fullname='log_softmax_spatial_special', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.log_softmax, dim=1), - cpp_options_args='F::LogSoftmaxFuncOptions(1)', - input_size=(2, 2, 4, 4), # regular spatial algorithm - fullname='log_softmax_spatial', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.log_softmax, dim=0), - cpp_options_args='F::LogSoftmaxFuncOptions(0)', - input_size=(2, 3, 4, 5), - fullname='log_softmax_dim0', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.log_softmax, dim=3), - cpp_options_args='F::LogSoftmaxFuncOptions(3)', - input_size=(2, 3, 4, 5), - fullname='log_softmax_dim3', - pickle=False, - default_dtype=torch.double, - ), - dict( - constructor=wrap_functional(F.log_softmax, dim=0), - cpp_options_args='F::LogSoftmaxFuncOptions(0)', - input_size=(), - fullname='log_softmax_scalar', - pickle=False, - ), - dict( - fullname='Unfold', - constructor=lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)), - cpp_constructor_args='torch::nn::UnfoldOptions({2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})', - input_size=(2, 4, 3, 3), - check_gradgrad=False, - test_cuda=True, - default_dtype=torch.double, - ), - dict( - fullname='Fold', - constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)), - cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})', - input_size=(2, 16, 4), - check_gradgrad=False, - test_cuda=True, - default_dtype=torch.double, - ), - dict( - fullname='Fold_no_batch_dim_input', - constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)), - cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})', - input_size=(16, 4), - check_gradgrad=False, - ref=single_batch_reference_fn, - test_cuda=True, - default_dtype=torch.double, - ), - dict( - fullname='Unfold_int_input', - constructor=lambda: nn.Unfold(2, 1, 0, 1), - 
cpp_constructor_args='torch::nn::UnfoldOptions(2).dilation(1).padding(0).stride(1)', - input_size=(2, 4, 3, 3), - check_gradgrad=False, - test_cuda=True, - default_dtype=torch.double, - ), - dict( - fullname='Fold_int_input', - constructor=lambda: nn.Fold(3, 2, 1, 0, 1), - cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)', - input_size=(2, 16, 4), - check_gradgrad=False, - test_cuda=True, - default_dtype=torch.double, - ), - dict( - fullname='Fold_no_batch_dim_int_input', - constructor=lambda: nn.Fold(3, 2, 1, 0, 1), - cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)', - input_size=(16, 4), - ref=single_batch_reference_fn, - check_gradgrad=False, - test_cuda=True, - default_dtype=torch.double, - ), - dict( - module_name='RReLU', - constructor_args=(0.1, 0.9), - cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)', - input_size=(), - desc='with_up_down_scalar', - test_cuda=False, - default_dtype=torch.double, - ), - dict( - module_name='PairwiseDistance', - input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)), - default_dtype=torch.double, - ), - dict( - module_name='PairwiseDistance', - input_fn=lambda: (torch.randn(10, 1), torch.randn(10, 8)), - desc='broadcast_lhs', - default_dtype=torch.double, - ), - dict( - module_name='PairwiseDistance', - input_fn=lambda: (torch.randn(10, 8), torch.randn(1, 8)), - desc='broadcast_rhs', - default_dtype=torch.double, - ), - dict( - module_name='PairwiseDistance', - constructor_args=(1.5, 1e-05, True), - cpp_constructor_args='torch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)', - input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)), - desc='with_non_default_args', - default_dtype=torch.double, - ), - dict( - module_name='PairwiseDistance', - input_fn=lambda: (torch.randn(8), torch.randn(8)), - reference_fn=single_batch_reference_fn, - desc='no_batch_dim', - default_dtype=torch.double, - ), - dict( - module_name='TransformerEncoderLayer', - constructor_args=(4, 2, 16, 0.0), - cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2) - .dim_feedforward(16) - .dropout(0.0)''', - input_size=(2, 3, 4), - desc='relu_activation', - with_tf32=True, - tf32_precision=0.1, - # TODO(#50743): figure out the error - # RuntimeError: The size of tensor a (6) must match the size of tensor b (4) - # at non-singleton dimension 2 - check_batched_grad=False, - check_gradgrad=False, - default_dtype=torch.double, - ), - dict( - module_name='TransformerEncoderLayer', - constructor_args=(4, 2, 8, 0.0, F.gelu), - cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2) - .dim_feedforward(8) - .dropout(0.0) - .activation(torch::kGELU)''', - input_size=(2, 3, 4), - check_gradgrad=False, - desc='gelu_activation', - with_tf32=True, - tf32_precision=0.08 if SM90OrLater else 0.05, - default_dtype=torch.double, - ), - dict( - module_name='TransformerDecoderLayer', - constructor_args=(4, 2, 8, 0.0), - cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2) - .dim_feedforward(8) - .dropout(0.0)''', - input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)), - check_gradgrad=False, - desc='relu_activation', - with_tf32=True, - tf32_precision=0.05, - default_dtype=torch.double, - ), - dict( - module_name='TransformerDecoderLayer', - constructor_args=(4, 2, 8, 0.0, F.gelu), - cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2) - .dim_feedforward(8) - .dropout(0.0) - .activation(torch::kGELU)''', - 
input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)), - check_gradgrad=False, - desc='gelu_activation', - with_tf32=True, - tf32_precision=0.05, - default_dtype=torch.double, - ), - dict( - module_name='Transformer', - constructor_args=(4, 2, 2, 2, 8, 0.0, F.relu), - cpp_constructor_args='''torch::nn::TransformerOptions() - .d_model(4) - .nhead(2) - .num_encoder_layers(2) - .num_decoder_layers(2) - .dim_feedforward(8) - .dropout(0.0) - .activation(torch::kReLU)''', - input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)), - check_gradgrad=False, - desc='multilayer_coder', - with_tf32=True, - tf32_precision=0.05 if SM90OrLater else 0.03, - default_dtype=torch.double, - ), - dict( - module_name='Linear', - constructor_args=(3, 5), - cpp_constructor_args='torch::nn::LinearOptions(3, 5)', - input_fn=lambda: torch.rand(3), - reference_fn=lambda i, p, _: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1], - desc="no_batch_dim", - with_tf32=True, - tf32_precision=0.005, - default_dtype=torch.double, - ), - dict( - module_name='Flatten', - cpp_constructor_args='torch::nn::FlattenOptions().start_dim(-3).end_dim(-1)', - constructor_args=(-3, -1), - input_size=(3, 4, 5), - reference_fn=single_batch_reference_fn, - desc="no_batch_dim", - default_dtype=torch.double, - ), - dict( - module_name='Unflatten', - cpp_constructor_args='torch::nn::UnflattenOptions(-2, {2, 2})', - constructor_args=(-2, torch.Size([2, 2])), - input_size=(3, 4, 5), - reference_fn=single_batch_reference_fn, - desc="no_batch_dim", - default_dtype=torch.double, - ), - dict( - module_name='LayerNorm', - constructor_args=([56, 56, 56], 1e-5, False), - cpp_constructor_args='torch::nn::LayerNormOptions({56, 56, 56}).eps(1e-5).elementwise_affine(false)', - input_size=(4, 56, 56, 56), - cudnn=True, - check_eval=True, - gradcheck_fast_mode=True, - check_half=True, - desc='3d_no_affine_large_feature', - ), -] +def get_new_module_tests(): + new_module_tests = [ + poissonnllloss_no_reduce_test(), + bceloss_no_reduce_test(), + bceloss_weights_no_reduce_test(), + bce_with_logistic_legacy_enum_test(), + bce_with_logistic_no_reduce_test(), + bceloss_no_reduce_scalar_test(), + bceloss_weights_no_reduce_scalar_test(), + bce_with_logistic_no_reduce_scalar_test(), + kldivloss_with_target_no_reduce_test(), + kldivloss_no_reduce_test(), + kldivloss_no_reduce_scalar_test(), + kldivloss_with_log_target_no_reduce_test(), + kldivloss_no_reduce_log_target_test(), + kldivloss_no_reduce_scalar_log_target_test(), + l1loss_no_reduce_test(), + l1loss_no_reduce_complex_test(), + l1loss_no_reduce_scalar_test(), + mseloss_no_reduce_test(), + mseloss_no_reduce_scalar_test(), + nllloss_no_reduce_test(), + nllloss_no_reduce_ignore_index_test(), + nllloss_no_reduce_weights_test(), + nllloss_no_reduce_weights_ignore_index_test(), + nllloss_no_reduce_weights_ignore_index_neg_test(), + nllloss2d_no_reduce_test(), + nllloss2d_no_reduce_weights_test(), + nllloss2d_no_reduce_ignore_index_test(), + nlllossNd_no_reduce_test(), + nlllossNd_no_reduce_weights_test(), + nlllossNd_no_reduce_ignore_index_test(), + smoothl1loss_no_reduce_test(), + smoothl1loss_no_reduce_scalar_test(), + smoothl1loss_beta_test(), + smoothl1loss_zero_beta_test(), + huberloss_delta_test(), + multilabelmarginloss_0d_no_reduce_test(), + multilabelmarginloss_1d_no_reduce_test(), + multilabelmarginloss_index_neg_test(), + multilabelmarginloss_no_reduce_test(), + hingeembeddingloss_no_reduce_test(), + hingeembeddingloss_margin_no_reduce_test(), + softmarginloss_no_reduce_test(), 
+ multilabelsoftmarginloss_no_reduce_test(), + multilabelsoftmarginloss_weights_no_reduce_test(), + multimarginloss_no_reduce_test(), + multimarginloss_1d_no_reduce_test(), + multimarginloss_1d_input_0d_target_no_reduce_test(), + multimarginloss_p_no_reduce_test(), + multimarginloss_margin_no_reduce_test(), + multimarginloss_weights_no_reduce_test(), + dict( + module_name='Conv1d', + constructor_args=(4, 5, 3), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)', + input_size=(2, 4, 10), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv1d', + constructor_args=(4, 5, 3, 2), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(2)', + input_size=(2, 4, 10), + cudnn=True, + desc='stride', + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv1d', + constructor_args=(4, 5, 3, 1, 1), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(1).padding(1)', + input_size=(2, 4, 10), + cudnn=True, + desc='pad1', + with_tf32=True, + tf32_precision=0.01, + default_dtype=torch.double, + ), + dict( + module_name='Conv1d', + constructor_args=(4, 5, 5, 1, 2), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 5).stride(1).padding(2)', + input_size=(2, 4, 10), + cudnn=True, + desc='pad2', + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv1d', + constructor_args=(4, 4, 3, 1, 1), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 3).stride(1).padding(1)', + input_size=(1, 4, 1), + cudnn=True, + desc='pad1size1', + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv1d', + constructor_args=(4, 4, 5, 1, 2), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 5).stride(1).padding(2)', + input_size=(1, 4, 1), + cudnn=True, + desc='pad2size1', + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv1d', + constructor_args=(4, 5, 3), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)', + input_size=(0, 4, 10), + cudnn=True, + desc='zero_batch', + with_tf32=True, + tf32_precision=0.005, + ), + dict( + fullname='Conv1d_dilated', + constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).dilation(2)', + input_size=(2, 4, 10), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv1d_groups', + constructor=lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 6, 3).groups(2)', + input_size=(2, 4, 6), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv1d_pad_valid', + constructor=lambda: nn.Conv1d(4, 5, 3, padding="valid"), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kValid)', + input_size=(2, 4, 10), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv1d_pad_same', + constructor=lambda: nn.Conv1d(4, 5, 3, padding="same"), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame)', + input_size=(2, 4, 10), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv1d_pad_same2', + constructor=lambda: nn.Conv1d(4, 5, 4, padding="same"), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame)', + 
input_size=(2, 4, 10), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv1d_pad_same_dilated', + constructor=lambda: nn.Conv1d(4, 5, 4, padding="same", dilation=2), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame).dilation(2)', + input_size=(2, 4, 10), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='ConvTranspose1d', + constructor=lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)), + cpp_constructor_args='torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)', + cudnn=True, + input_size=(1, 3, 7), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='ConvTranspose1d', + constructor_args=(3, 4, 3, 2, 1, 1, 1, False), + cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3) + .stride(2).padding(1).output_padding(1).groups(1).bias(false)''', + input_size=(1, 3, 6), + cudnn=True, + desc='no_bias', + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='ConvTranspose1d', + constructor_args=(3, 4, 3, 2, 1, 1, 1, True, 2), + cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3) + .stride(2).padding(1).output_padding(1).groups(1).bias(true).dilation(2)''', + input_size=(1, 3, 6), + cudnn=True, + desc='dilated', + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='ConvTranspose1d_groups', + constructor=lambda: nn.ConvTranspose1d(4, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2), + cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(4, 6, 3) + .stride(3).padding(1).output_padding(1).groups(2)''', + cudnn=True, + input_size=(2, 4, 7), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv2d', + constructor_args=(3, 4, (3, 2)), + cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})', + input_size=(2, 3, 7, 5), + cudnn=True, + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv2d', + constructor_args=(3, 4, (3, 3), (2, 2)), + cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2})', + input_size=(2, 3, 6, 6), + cudnn=True, + desc='strided', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv2d', + constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)), + cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2}).padding({1, 1})', + input_size=(2, 3, 6, 6), + cudnn=True, + desc='padding', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv2d', + constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)), + cpp_constructor_args='torch::nn::Conv2dOptions(3, 2, {3, 3}).stride({2, 2}).padding({1, 1}).dilation({2, 2})', + input_size=(2, 3, 8, 8), + cudnn=True, + desc='dilated', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv2d', + constructor_args=(3, 4, (3, 2), 1, 0, 1, 1, False), + cpp_constructor_args='''torch::nn::Conv2dOptions(3, 4, {3, 2}) + .stride(1).padding(0).dilation(1).groups(1).bias(false)''', + input_size=(2, 3, 6, 5), + cudnn=True, + desc='no_bias', + 
check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.015, + default_dtype=torch.double, + ), + dict( + module_name='Conv2d', + constructor_args=(3, 4, (3, 2)), + cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})', + input_size=(0, 3, 7, 5), + cudnn=True, + desc='zero_batch', + check_with_long_tensor=True, + with_tf32=True, + ), + dict( + fullname='Conv2d_groups', + constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2), + cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)', + input_size=(2, 4, 6, 5), + cudnn=True, + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.015, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_groups_thnn', + constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2), + cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)', + input_size=(2, 4, 6, 5), + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.015, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_pad_valid', + constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="valid"), + cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kValid)', + input_size=(2, 2, 6, 5), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_pad_same', + constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same"), + cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame)', + input_size=(2, 2, 6, 5), + cudnn=True, + with_tf32=True, + tf32_precision=0.01, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_pad_same_dilated', + constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same", dilation=2), + cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame).dilation(2)', + input_size=(2, 2, 6, 5), + cudnn=True, + with_tf32=True, + tf32_precision=0.01, + default_dtype=torch.double, + ), + dict( + module_name='ConvTranspose2d', + constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)), + cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3) + .stride({3, 2}).padding(1).output_padding({1, 1})''', + cudnn=True, + input_size=(1, 3, 7, 6), + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.01, + default_dtype=torch.double, + ), + dict( + module_name='ConvTranspose2d', + constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False, (2, 2)), + cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3) + .stride({2, 3}) + .padding(1) + .output_padding({1, 1}) + .groups(1) + .bias(false) + .dilation({2, 2})''', + input_size=(1, 3, 6, 7), + cudnn=True, + desc='dilated', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.01, + default_dtype=torch.double, + ), + dict( + module_name='ConvTranspose2d', + constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False), + cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3) + .stride({2, 3}).padding(1).output_padding({1, 1}).groups(1).bias(false)''', + input_size=(1, 3, 6, 7), + cudnn=True, + desc='no_bias', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.01, + default_dtype=torch.double, + ), + dict( + fullname='ConvTranspose2d_groups', + constructor=lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2), + cpp_constructor_args='torch::nn::ConvTranspose2dOptions(2, 4, {2, 3}).groups(2)', + input_size=(1, 2, 4, 5), + cudnn=True, + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.01, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_depthwise', + 
constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4), + cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)', + input_size=(2, 4, 6, 6), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_depthwise_with_multiplier', + constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4), + cpp_constructor_args='torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)', + input_size=(2, 4, 6, 6), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_depthwise_strided', + constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4), + cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)', + input_size=(2, 4, 6, 6), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_depthwise_padded', + constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4), + cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)', + input_size=(2, 4, 6, 6), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_depthwise_dilated', + constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4), + cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)', + input_size=(2, 4, 5, 5), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv3d', + constructor_args=(2, 3, (2, 3, 2)), + cpp_constructor_args='torch::nn::Conv3dOptions(2, 3, {2, 3, 2})', + input_size=(1, 2, 4, 5, 4), + cudnn=True, + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='Conv3d', + constructor_args=(2, 3, (2, 3, 4), 1, 0, 1, 1, False), + cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4}) + .stride(1).padding(0).dilation(1).groups(1).bias(false)''', + input_size=(1, 2, 3, 4, 5), + cudnn=True, + desc='no_bias', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='Conv3d', + constructor_args=(2, 3, (1, 1, 1), 1, 0, 1, 1, False), + cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4}) + .stride(1).padding(0).dilation(1).groups(1).bias(false)''', + input_size=(1, 2, 3, 4, 5), + cudnn=True, + desc='1x1x1_no_bias', + check_with_long_tensor=False, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='Conv3d', + constructor_args=(3, 4, 2, 2), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2)', + input_size=(2, 3, 5, 5, 5), + cudnn=True, + desc='stride', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='Conv3d', + constructor_args=(3, 4, 2, 2, 1), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2).padding(1)', + input_size=(2, 3, 5, 5, 5), + cudnn=True, + desc='stride_padding', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='Conv3d', + constructor_args=(3, 4, (2, 3, 4)), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4})', + input_size=(0, 3, 3, 4, 5), + cudnn=True, + check_with_long_tensor=True, + desc='zero_batch', + with_tf32=True, + ), + dict( + fullname='Conv3d_groups', + constructor=lambda: nn.Conv3d(2, 4, kernel_size=3, groups=2), + 
cpp_constructor_args='torch::nn::Conv3dOptions(2, 4, 3).groups(2)', + input_size=(1, 2, 4, 5, 4), + cudnn=True, + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv3d_dilated', + constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2)', + input_size=(2, 3, 5, 5, 5), + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + fullname='Conv3d_dilated_strided', + constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)', + input_size=(2, 3, 5, 5, 5), + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + fullname='Conv3d_pad_valid', + constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="valid"), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kValid)', + input_size=(2, 3, 6, 5, 4), + cudnn=True, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + fullname='Conv3d_pad_same', + constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same"), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame)', + input_size=(2, 3, 6, 5, 4), + cudnn=True, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + fullname='Conv3d_pad_same_dilated', + constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same", dilation=2), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame).dilation(2)', + input_size=(2, 3, 6, 5, 4), + cudnn=True, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='ConvTranspose3d', + constructor_args=(2, 3, (2, 3, 2)), + cpp_constructor_args='torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})', + cudnn=True, + input_size=(1, 2, 4, 5, 4), + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='ConvTranspose3d', + constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1, True, (2, 2, 2)), + cpp_constructor_args='''torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2}) + .stride(1).padding(0).output_padding(0).groups(1).bias(true).dilation({2, 2, 2})''', + cudnn=True, + input_size=(1, 2, 4, 5, 4), + desc='dilated', + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='ReplicationPad3d', + constructor_args=((1, 2, 3, 3, 2, 1),), + cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})', + input_size=(2, 3, 2, 2, 2), + default_dtype=torch.double, + ), + dict( + module_name='ReplicationPad3d', + constructor_args=((1, 2, 3, 3, 2, 1),), + cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})', + input_size=(3, 2, 2, 2), + reference_fn=single_batch_reference_fn, + desc='no_batch_dim', + default_dtype=torch.double, + ), + dict( + module_name='ReplicationPad3d', + constructor_args=((1, 2, 3, 3, 2, 1),), + cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})', + input_fn=lambda: torch.rand(2, 3, 2, 2, 2, dtype=torch.complex128, requires_grad=True), + skip_half=True, + desc='complex' + ), + dict( + module_name='Embedding', + constructor_args=(4, 3), + cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)', + input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), + check_gradgrad=False, + default_dtype=torch.double, + 
decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971") + ), + dict( + module_name='Embedding', + constructor_args=(4, 3), + cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)', + input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512), + check_gradgrad=False, + desc='discontiguous', + default_dtype=torch.double, + decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971") + ), + dict( + module_name='EmbeddingBag', + constructor_args=(4, 3), + cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)', + input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), + check_gradgrad=False, + desc='mean', + default_dtype=torch.double, + ), + dict( + module_name='EmbeddingBag', + constructor_args=(4, 3), + cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)', + input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512), + check_gradgrad=False, + desc='discontiguous', + default_dtype=torch.double, + ), + dict( + module_name='EmbeddingBag', + constructor_args=(4, 3, None, 2., False, 'sum'), + cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) + .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)''', + input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), + check_gradgrad=False, + desc='sum', + default_dtype=torch.double, + ), + dict( + module_name='EmbeddingBag', + constructor_args=(4, 3, None, 2., False, 'max'), + cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) + .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)''', + input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), + check_gradgrad=False, + desc='max', + default_dtype=torch.double, + ), + dict( + fullname='EmbeddingBag_mean_padding_idx', + constructor=lambda: nn.EmbeddingBag(4, 3, padding_idx=1), + cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).padding_idx(1)', + input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]), + check_gradgrad=False, + default_dtype=torch.double, + ), + dict( + fullname='EmbeddingBag_sum_padding_idx', + constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'sum', padding_idx=1), + cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) + .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum).padding_idx(1)''', + input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]), + check_gradgrad=False, + default_dtype=torch.double, + ), + dict( + fullname='EmbeddingBag_max_padding_idx', + constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'max', padding_idx=1), + cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) + .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax).padding_idx(1)''', + input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]), + check_gradgrad=False, + default_dtype=torch.double, + ), + dict( + fullname='EmbeddingBag_sparse', + constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double), + cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) + .sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))''', + input_fn=lambda: torch.randperm(2).repeat(1, 2), + check_gradgrad=False, + has_sparse_gradients=True, + ), + dict( + constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True), + cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))', + input_fn=lambda: 
torch.randperm(2).repeat(1, 2), + fullname='Embedding_sparse', + check_gradgrad=False, + has_sparse_gradients=True, + ), + dict( + module_name='PixelShuffle', + constructor_args=(3,), + cpp_constructor_args='torch::nn::PixelShuffleOptions(3)', + input_size=(1, 9, 4, 4), + default_dtype=torch.double, + ), + dict( + module_name='PixelUnshuffle', + constructor_args=(3,), + cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)', + input_size=(1, 1, 12, 12), + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''', + input_size=(1, 2, 4), + fullname='interpolate_nearest_1d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''', + input_size=(0, 2, 4), + fullname='interpolate_nearest_1d_zero_dim', + pickle=False, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(12, ), scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''', + input_size=(1, 2, 3), + fullname='interpolate_nearest_tuple_1d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt).scale_factor(std::vector({4.})).mode(torch::kNearest)''', + input_size=(1, 2, 4), + fullname='interpolate_nearest_scale_1d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12})) + .scale_factor(std::nullopt) + .mode(torch::kLinear) + .align_corners(false)''', + input_size=(1, 2, 4), + fullname='interpolate_linear_1d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(4, ), scale_factor=None, mode='linear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({4})) + .scale_factor(std::nullopt) + .mode(torch::kLinear) + .align_corners(false)''', + input_size=(1, 2, 3), + fullname='interpolate_linear_tuple_1d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({4.})) + .mode(torch::kLinear) + .align_corners(false)''', + input_size=(1, 2, 4), + fullname='interpolate_linear_scale_1d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12})) + .scale_factor(std::nullopt) + .mode(torch::kLinear) + .align_corners(false)''', + input_size=(0, 2, 4), + fullname='interpolate_linear_1d_zero_dim', + pickle=False, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=True), + 
cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12})) + .scale_factor(std::nullopt) + .mode(torch::kLinear) + .align_corners(true)''', + input_size=(1, 2, 4), + fullname='interpolate_linear_1d_align_corners', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=True), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({4.})) + .mode(torch::kLinear) + .align_corners(true)''', + input_size=(1, 2, 4), + fullname='interpolate_linear_scale_1d_align_corners', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=2, scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({2, 2})) + .scale_factor(std::nullopt) + .mode(torch::kNearest)''', + input_size=(1, 128, 1, 1), + fullname='interpolate_nearest_2d_launch_configs', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kNearest)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_nearest_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(12, 16), scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 16})) + .scale_factor(std::nullopt) + .mode(torch::kNearest)''', + input_size=(1, 2, 3, 4), + fullname='interpolate_nearest_tuple_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({4., 4.})) + .mode(torch::kNearest)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_nearest_scale_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kNearest)''', + input_size=(0, 2, 4, 4), + fullname='interpolate_nearest_2d_zero_dim', + pickle=False, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kBilinear) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bilinear_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kBilinear) + .align_corners(false)''', + input_size=(0, 2, 4, 4), + fullname='interpolate_bilinear_2d_zero_dim', + pickle=False, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, + mode='bilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({4, 6})) + .scale_factor(std::nullopt) + .mode(torch::kBilinear) + .align_corners(false)''', + input_size=(1, 2, 
2, 3), + fullname='interpolate_bilinear_tuple_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., + mode='bilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({4., 4.})) + .mode(torch::kBilinear) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bilinear_scale_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.), + mode='bilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({2., 2.})) + .mode(torch::kBilinear) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bilinear_scale_tuple_shared_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), + mode='bilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({2., 1.})) + .mode(torch::kBilinear) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bilinear_scale_tuple_skewed_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bilinear', align_corners=True), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({4, 6})) + .scale_factor(std::nullopt) + .mode(torch::kBilinear) + .align_corners(true)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bilinear_tuple_2d_align_corners', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), + mode='bilinear', align_corners=True), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({2., 1.})) + .mode(torch::kBilinear) + .align_corners(true)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bilinear_scale_tuple_skewed_2d_align_corners', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kBicubic) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bicubic_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kBicubic) + .align_corners(false)''', + input_size=(0, 2, 4, 4), + fullname='interpolate_bicubic_2d_zero_dim', + pickle=False, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, + mode='bicubic', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({4, 6})) + .scale_factor(std::nullopt) + .mode(torch::kBicubic) + .align_corners(false)''', + input_size=(1, 2, 2, 3), + fullname='interpolate_bicubic_tuple_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='bicubic', align_corners=False), + 
cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({4., 4.})) + .mode(torch::kBicubic) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bicubic_scale_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.), + mode='bicubic', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({2., 2.})) + .mode(torch::kBicubic) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bicubic_scale_tuple_shared_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), + mode='bicubic', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({2., 1.})) + .mode(torch::kBicubic) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bicubic_scale_tuple_skewed_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bicubic', align_corners=True), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({4, 6})) + .scale_factor(std::nullopt) + .mode(torch::kBicubic) + .align_corners(true)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bicubic_tuple_2d_align_corners', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), + mode='bicubic', align_corners=True), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({2., 1.})) + .mode(torch::kBicubic) + .align_corners(true)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bicubic_scale_tuple_skewed_2d_align_corners', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kNearest)''', + input_size=(1, 2, 4, 4, 4), + fullname='interpolate_nearest_3d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kNearest)''', + input_size=(0, 2, 4, 4, 4), + fullname='interpolate_nearest_3d_zero_dim', + pickle=False, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(12, 16, 16), scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 16, 16})) + .scale_factor(std::nullopt) + .mode(torch::kNearest)''', + input_size=(1, 2, 3, 4, 4), + fullname='interpolate_nearest_tuple_3d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({4., 4., 4.})) + .mode(torch::kNearest)''', + input_size=(1, 2, 4, 4, 4), + fullname='interpolate_nearest_scale_3d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False), + 
cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kTrilinear) + .align_corners(false)''', + input_size=(1, 2, 4, 4, 4), + fullname='interpolate_trilinear_3d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kTrilinear) + .align_corners(false)''', + input_size=(0, 2, 4, 4, 4), + fullname='interpolate_trilinear_3d_zero_dim', + pickle=False, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(4, 6, 6), + scale_factor=None, mode='trilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({4, 6, 6})) + .scale_factor(std::nullopt) + .mode(torch::kTrilinear) + .align_corners(false)''', + input_size=(1, 2, 2, 3, 3), + fullname='interpolate_trilinear_tuple_3d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({3., 3., 3.})) + .mode(torch::kTrilinear) + .align_corners(false)''', + input_size=(1, 2, 3, 4, 5), + fullname='interpolate_trilinear_scale_3d', + # See https://github.com/pytorch/pytorch/issues/5006 + precision=3e-4, + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=None, + mode='trilinear', align_corners=True), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({4, 6, 6})) + .scale_factor(std::nullopt) + .mode(torch::kTrilinear) + .align_corners(true)''', + input_size=(1, 2, 2, 3, 3), + fullname='interpolate_trilinear_tuple_3d_align_corners', + pickle=False, + default_dtype=torch.double + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=True), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({3., 3., 3.})) + .mode(torch::kTrilinear) + .align_corners(true)''', + input_size=(1, 2, 3, 4, 4), + fullname='interpolate_trilinear_scale_3d_align_corners', + # See https://github.com/pytorch/pytorch/issues/5006 + precision=3e-4, + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.softmax, dim=-1), + cpp_options_args='F::SoftmaxFuncOptions(-1)', + input_size=(2, 128), # trigger the last-dim algo in CUDA + fullname='softmax_lastdim', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64), + cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)', + input_size=(2, 128), + fullname='softmax_lastdim_dtype', + pickle=False, + test_cuda=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.softmax, dim=1), + cpp_options_args='F::SoftmaxFuncOptions(1)', + input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo + fullname='softmax_spatial_special', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.softmax, dim=1), + cpp_options_args='F::SoftmaxFuncOptions(1)', + input_size=(2, 2, 4, 4), # regular spatial algorithm + fullname='softmax_spatial', + pickle=False, + default_dtype=torch.double, + 
), + dict( + constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64), + cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)', + input_size=(2, 2, 4, 4), # regular spatial algorithm + fullname='softmax_spatial_dtype', + pickle=False, + test_cuda=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.softmax, dim=0), + cpp_options_args='F::SoftmaxFuncOptions(0)', + input_size=(2, 3, 4, 5), + fullname='softmax_functional_dim0', + test_cuda=False, + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.softmax, dim=3), + cpp_options_args='F::SoftmaxFuncOptions(3)', + input_size=(2, 3, 4, 5), + fullname='softmax_functional_dim3', + test_cuda=False, + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.softmax, dim=-1), + cpp_options_args='F::SoftmaxFuncOptions(-1)', + input_size=(), + fullname='softmax_functional_scalar', + test_cuda=False, + pickle=False, + ), + dict( + constructor=wrap_functional(F.log_softmax, dim=-1), + cpp_options_args='F::LogSoftmaxFuncOptions(-1)', + input_size=(2, 128), # trigger the last-dim algo in CUDA + fullname='log_softmax_lastdim', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.log_softmax, dim=1), + cpp_options_args='F::LogSoftmaxFuncOptions(1)', + input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo + fullname='log_softmax_spatial_special', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.log_softmax, dim=1), + cpp_options_args='F::LogSoftmaxFuncOptions(1)', + input_size=(2, 2, 4, 4), # regular spatial algorithm + fullname='log_softmax_spatial', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.log_softmax, dim=0), + cpp_options_args='F::LogSoftmaxFuncOptions(0)', + input_size=(2, 3, 4, 5), + fullname='log_softmax_dim0', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.log_softmax, dim=3), + cpp_options_args='F::LogSoftmaxFuncOptions(3)', + input_size=(2, 3, 4, 5), + fullname='log_softmax_dim3', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.log_softmax, dim=0), + cpp_options_args='F::LogSoftmaxFuncOptions(0)', + input_size=(), + fullname='log_softmax_scalar', + pickle=False, + ), + dict( + fullname='Unfold', + constructor=lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)), + cpp_constructor_args='torch::nn::UnfoldOptions({2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})', + input_size=(2, 4, 3, 3), + check_gradgrad=False, + test_cuda=True, + default_dtype=torch.double, + ), + dict( + fullname='Fold', + constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)), + cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})', + input_size=(2, 16, 4), + check_gradgrad=False, + test_cuda=True, + default_dtype=torch.double, + ), + dict( + fullname='Fold_no_batch_dim_input', + constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)), + cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})', + input_size=(16, 4), + check_gradgrad=False, + ref=single_batch_reference_fn, + test_cuda=True, + default_dtype=torch.double, + ), + dict( + fullname='Unfold_int_input', + constructor=lambda: nn.Unfold(2, 1, 0, 1), + cpp_constructor_args='torch::nn::UnfoldOptions(2).dilation(1).padding(0).stride(1)', + 
input_size=(2, 4, 3, 3), + check_gradgrad=False, + test_cuda=True, + default_dtype=torch.double, + ), + dict( + fullname='Fold_int_input', + constructor=lambda: nn.Fold(3, 2, 1, 0, 1), + cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)', + input_size=(2, 16, 4), + check_gradgrad=False, + test_cuda=True, + default_dtype=torch.double, + ), + dict( + fullname='Fold_no_batch_dim_int_input', + constructor=lambda: nn.Fold(3, 2, 1, 0, 1), + cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)', + input_size=(16, 4), + ref=single_batch_reference_fn, + check_gradgrad=False, + test_cuda=True, + default_dtype=torch.double, + ), + dict( + module_name='RReLU', + constructor_args=(0.1, 0.9), + cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)', + input_size=(), + desc='with_up_down_scalar', + test_cuda=False, + default_dtype=torch.double, + ), + dict( + module_name='PairwiseDistance', + input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)), + default_dtype=torch.double, + ), + dict( + module_name='PairwiseDistance', + input_fn=lambda: (torch.randn(10, 1), torch.randn(10, 8)), + desc='broadcast_lhs', + default_dtype=torch.double, + ), + dict( + module_name='PairwiseDistance', + input_fn=lambda: (torch.randn(10, 8), torch.randn(1, 8)), + desc='broadcast_rhs', + default_dtype=torch.double, + ), + dict( + module_name='PairwiseDistance', + constructor_args=(1.5, 1e-05, True), + cpp_constructor_args='torch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)', + input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)), + desc='with_non_default_args', + default_dtype=torch.double, + ), + dict( + module_name='PairwiseDistance', + input_fn=lambda: (torch.randn(8), torch.randn(8)), + reference_fn=single_batch_reference_fn, + desc='no_batch_dim', + default_dtype=torch.double, + ), + dict( + module_name='TransformerEncoderLayer', + constructor_args=(4, 2, 16, 0.0), + cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2) + .dim_feedforward(16) + .dropout(0.0)''', + input_size=(2, 3, 4), + desc='relu_activation', + with_tf32=True, + tf32_precision=0.1, + # TODO(#50743): figure out the error + # RuntimeError: The size of tensor a (6) must match the size of tensor b (4) + # at non-singleton dimension 2 + check_batched_grad=False, + check_gradgrad=False, + default_dtype=torch.double, + ), + dict( + module_name='TransformerEncoderLayer', + constructor_args=(4, 2, 8, 0.0, F.gelu), + cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2) + .dim_feedforward(8) + .dropout(0.0) + .activation(torch::kGELU)''', + input_size=(2, 3, 4), + check_gradgrad=False, + desc='gelu_activation', + with_tf32=True, + tf32_precision=0.08 if SM90OrLater else 0.05, + default_dtype=torch.double, + ), + dict( + module_name='TransformerDecoderLayer', + constructor_args=(4, 2, 8, 0.0), + cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2) + .dim_feedforward(8) + .dropout(0.0)''', + input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)), + check_gradgrad=False, + desc='relu_activation', + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='TransformerDecoderLayer', + constructor_args=(4, 2, 8, 0.0, F.gelu), + cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2) + .dim_feedforward(8) + .dropout(0.0) + .activation(torch::kGELU)''', + input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)), + check_gradgrad=False, + 
desc='gelu_activation', + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='Transformer', + constructor_args=(4, 2, 2, 2, 8, 0.0, F.relu), + cpp_constructor_args='''torch::nn::TransformerOptions() + .d_model(4) + .nhead(2) + .num_encoder_layers(2) + .num_decoder_layers(2) + .dim_feedforward(8) + .dropout(0.0) + .activation(torch::kReLU)''', + input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)), + check_gradgrad=False, + desc='multilayer_coder', + with_tf32=True, + tf32_precision=0.05 if SM90OrLater else 0.03, + default_dtype=torch.double, + ), + dict( + module_name='Linear', + constructor_args=(3, 5), + cpp_constructor_args='torch::nn::LinearOptions(3, 5)', + input_fn=lambda: torch.rand(3), + reference_fn=lambda i, p, _: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1], + desc="no_batch_dim", + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Flatten', + cpp_constructor_args='torch::nn::FlattenOptions().start_dim(-3).end_dim(-1)', + constructor_args=(-3, -1), + input_size=(3, 4, 5), + reference_fn=single_batch_reference_fn, + desc="no_batch_dim", + default_dtype=torch.double, + ), + dict( + module_name='Unflatten', + cpp_constructor_args='torch::nn::UnflattenOptions(-2, {2, 2})', + constructor_args=(-2, torch.Size([2, 2])), + input_size=(3, 4, 5), + reference_fn=single_batch_reference_fn, + desc="no_batch_dim", + default_dtype=torch.double, + ), + dict( + module_name='LayerNorm', + constructor_args=([56, 56, 56], 1e-5, False), + cpp_constructor_args='torch::nn::LayerNormOptions({56, 56, 56}).eps(1e-5).elementwise_affine(false)', + input_size=(4, 56, 56, 56), + cudnn=True, + check_eval=True, + gradcheck_fast_mode=True, + check_half=True, + desc='3d_no_affine_large_feature', + ), + ] -# add conv padding mode tests: -for padding_mode, cpp_padding_mode in zip( - ['reflect', 'circular', 'replicate', 'zeros'], - ['torch::kReflect', 'torch::kCircular', 'torch::kReplicate', 'torch::kZeros']): - # conv signature: - # in_channels, out_channels, kernel_size, stride=1, - # padding=0, dilation=1, groups=1, - # bias=True, padding_mode='zeros' - for d in (1, 2, 3): - if d == 3 and padding_mode == 'reflect': - # FIXME: remove after implementing reflection pad 3d - # https://github.com/pytorch/pytorch/issues/27655 - continue - padding = tuple(range(1, d + 1)) - cpp_padding = '{' + ', '.join(map(str, padding)) + '}' - input_size = (2, 2) + (4,) * d - output_size = (2, 3) + tuple(p + 1 for p in padding) # simplified from `(4 + 2 * p - 3) // 2 + 1` - new_module_tests.append( - dict( - module_name=f'Conv{d}d', - constructor_args=(2, 3, 3, 2, padding, 1, 1, True, padding_mode), - cpp_constructor_args=f'''torch::nn::Conv{d}dOptions(2, 3, 3) - .stride(2) - .padding({cpp_padding}) - .dilation(1) - .groups(1) - .bias(true) - .padding_mode({cpp_padding_mode})''', - input_size=input_size, - output_size=output_size, - cudnn=True, - desc=f'{padding_mode}_stride2_pad2', - with_tf32=True, - tf32_precision=0.05, - default_dtype=torch.double, - ), + # add conv padding mode tests: + for padding_mode, cpp_padding_mode in zip( + ['reflect', 'circular', 'replicate', 'zeros'], + ['torch::kReflect', 'torch::kCircular', 'torch::kReplicate', 'torch::kZeros']): + # conv signature: + # in_channels, out_channels, kernel_size, stride=1, + # padding=0, dilation=1, groups=1, + # bias=True, padding_mode='zeros' + for d in (1, 2, 3): + if d == 3 and padding_mode == 'reflect': + # FIXME: remove after implementing 
reflection pad 3d + # https://github.com/pytorch/pytorch/issues/27655 + continue + padding = tuple(range(1, d + 1)) + cpp_padding = '{' + ', '.join(map(str, padding)) + '}' + input_size = (2, 2) + (4,) * d + output_size = (2, 3) + tuple(p + 1 for p in padding) # simplified from `(4 + 2 * p - 3) // 2 + 1` + new_module_tests.append( + dict( + module_name=f'Conv{d}d', + constructor_args=(2, 3, 3, 2, padding, 1, 1, True, padding_mode), + cpp_constructor_args=f'''torch::nn::Conv{d}dOptions(2, 3, 3) + .stride(2) + .padding({cpp_padding}) + .dilation(1) + .groups(1) + .bias(true) + .padding_mode({cpp_padding_mode})''', + input_size=input_size, + output_size=output_size, + cudnn=True, + desc=f'{padding_mode}_stride2_pad2', + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + ) + + # Check that non linear activations work with no batch dimensions + non_linear_activations_no_batch = [ + 'ELU', 'Hardshrink', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'LeakyReLU', + 'LogSigmoid', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', 'GLU', + 'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh', + 'Tanhshrink', 'Threshold' + ] + non_linear_activations_extra_info: Dict[str, dict] = { + 'CELU': {'constructor_args': (2.,), 'default_dtype': torch.double}, + 'Threshold': {'constructor_args': (2., 1.)}, + 'Hardsigmoid': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double}, + 'Hardswish': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double}, + # For RRelu, test that compare CPU and GPU results fail because RNG + # is different between CPU and GPU + 'RReLU': {'test_cuda': False, 'default_dtype': torch.double}, + 'ELU': {'default_dtype': torch.double}, + 'GELU': {'default_dtype': torch.double}, + 'GLU': {'default_dtype': torch.double}, + 'Hardshrink': {'default_dtype': torch.double}, + 'Hardtanh': {'default_dtype': torch.double}, + 'LeakyReLU': {'default_dtype': torch.double}, + 'LogSigmoid': {'default_dtype': torch.double}, + 'Mish': {'default_dtype': torch.double}, + 'PReLU': {'default_dtype': torch.double}, + 'ReLU6': {'default_dtype': torch.double}, + 'ReLU': {'default_dtype': torch.double}, + 'SELU': {'default_dtype': torch.double}, + 'SiLU': {'default_dtype': torch.double}, + 'Sigmoid': {'default_dtype': torch.double}, + 'Softplus': {'default_dtype': torch.double}, + 'Softshrink': {'default_dtype': torch.double}, + 'Softsign': {'default_dtype': torch.double}, + 'Tanh': {'default_dtype': torch.double}, + 'Tanhshrink': {'default_dtype': torch.double}, + } + for non_linear_activation in non_linear_activations_no_batch: + activation_test_info = dict( + module_name=non_linear_activation, + input_size=(4,), + reference_fn=single_batch_reference_fn, + desc='no_batch_dim', + test_cpp_api_parity=False, ) + extra_info = non_linear_activations_extra_info.get(non_linear_activation, {}) + activation_test_info.update(extra_info) + new_module_tests.append(activation_test_info) -# Check that non linear activations work with no batch dimensions -non_linear_activations_no_batch = [ - 'ELU', 'Hardshrink', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'LeakyReLU', - 'LogSigmoid', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', 'GLU', - 'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh', - 'Tanhshrink', 'Threshold' -] -non_linear_activations_extra_info: Dict[str, dict] = { - 'CELU': {'constructor_args': (2.,), 'default_dtype': torch.double}, - 'Threshold': {'constructor_args': (2., 1.)}, - 'Hardsigmoid': 
{'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double}, - 'Hardswish': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double}, - # For RRelu, test that compare CPU and GPU results fail because RNG - # is different between CPU and GPU - 'RReLU': {'test_cuda': False, 'default_dtype': torch.double}, - 'ELU': {'default_dtype': torch.double}, - 'GELU': {'default_dtype': torch.double}, - 'GLU': {'default_dtype': torch.double}, - 'Hardshrink': {'default_dtype': torch.double}, - 'Hardtanh': {'default_dtype': torch.double}, - 'LeakyReLU': {'default_dtype': torch.double}, - 'LogSigmoid': {'default_dtype': torch.double}, - 'Mish': {'default_dtype': torch.double}, - 'PReLU': {'default_dtype': torch.double}, - 'ReLU6': {'default_dtype': torch.double}, - 'ReLU': {'default_dtype': torch.double}, - 'SELU': {'default_dtype': torch.double}, - 'SiLU': {'default_dtype': torch.double}, - 'Sigmoid': {'default_dtype': torch.double}, - 'Softplus': {'default_dtype': torch.double}, - 'Softshrink': {'default_dtype': torch.double}, - 'Softsign': {'default_dtype': torch.double}, - 'Tanh': {'default_dtype': torch.double}, - 'Tanhshrink': {'default_dtype': torch.double}, -} -for non_linear_activation in non_linear_activations_no_batch: - activation_test_info = dict( - module_name=non_linear_activation, - input_size=(4,), - reference_fn=single_batch_reference_fn, - desc='no_batch_dim', - test_cpp_api_parity=False, - ) - extra_info = non_linear_activations_extra_info.get(non_linear_activation, {}) - activation_test_info.update(extra_info) - new_module_tests.append(activation_test_info) + + return new_module_tests def kldivloss_reference(input, target, reduction='mean', log_target=False): diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py index 30a6b8f8e06..9c128493699 100644 --- a/torch/testing/_internal/jit_metaprogramming_utils.py +++ b/torch/testing/_internal/jit_metaprogramming_utils.py @@ -8,7 +8,7 @@ import torch.cuda import torch.jit import torch.jit._logging import torch.jit.frontend -from torch.testing._internal.common_nn import module_tests, new_module_tests +from torch.testing._internal.common_nn import module_tests, get_new_module_tests from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like import collections @@ -95,226 +95,228 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg # fn mapping output to part that should be gradcheck'ed, // optional # kwargs for function, // optional # ) -nn_functional_tests = [ - ('conv1d', (S, S, S), ((S, S, S),)), - ('conv2d', (S, S, S, S), ((S, S, S, S),)), - ('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)), - ('conv_transpose1d', (S, S, S), ((S, S, S),)), - ('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)), - ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)), - ('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)), - ('avg_pool1d', (S, S, S), (3,)), - ('avg_pool2d', (S, S, S, S), (3,), '', (True,)), - ('avg_pool3d', (S, S, S, S, S), (3,)), - ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)), - ('max_pool1d', (S, S, S), (2, 1)), - ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'), - ('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices')), - ('max_pool2d', (S, S, S, S), (2, 1, 1, 1, False, True), 'with_indices', (True, 'aten::max_pool2d_with_indices')), - ('max_pool3d', (S, S, S, S, S), (2, 1)), - ('max_unpool1d', torch.tensor([[[2., 4]]]), 
(torch.tensor([[[1, 3]]]), 2, 2, 0)), - ('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)), - ('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)), - ('lp_pool1d', (S, S, S), (2., 3, 2,)), - ('lp_pool2d', (S, S, S, S), (2., 3, 2,)), - ('lp_pool3d', (S, S, S, S, S), (2., 3, 2,)), - ('adaptive_max_pool1d', (S, S, S), (5,)), - ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)), - ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)), - ('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)), - ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)), - ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)), - ('dropout', (S, S, S), (0.5,), '', (True, 'aten::native_dropout')), - ('alpha_dropout', (S, S, S), (0.5,)), - ('dropout2d', (S, S, S), (0.5,)), - ('dropout2d', (S, S, S, S), (0.5,), 'batched'), - ('dropout3d', (S, S, S, S), (0.5,)), - ('dropout3d', (S, S, S, S, S), (0.5,), 'batched'), - ('feature_alpha_dropout', (S, S, S), (0.5,)), - ('threshold', (S, S, S), (0.1, 2.), '', (True,)), - ('threshold', (S, S, S), (0.1, 2., True), 'inplace'), - ('relu', (S, S, S), (), '', (True,)), - ('relu', (S, S, S), (), 'inplace'), - ('glu', (S - 1, S - 1, S - 1), (),), - ('hardtanh', (S, S, S), (-0.5, 0.5), '', (True,)), - ('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'), - ('relu6', (S, S, S), (), '', (True,)), - ('relu6', (S, S, S), (True), 'inplace'), - ('elu', (S, S, S), (0.9,),), - ('elu', (S, S, S), (0.9, True), 'inplace'), - ('selu', (S, S, S), (),), - ('selu', (S, S, S), (True), 'inplace'), - ('celu', (S, S, S), (0.9,),), - ('celu', (S, S, S), (0.9, True), 'inplace'), - ('leaky_relu', (S, S, S), (0.02,), '', (True,)), - ('leaky_relu', (S, S, S), (0.02,), 'inplace'), - ('rrelu', (S, S), (0.1, 0.3, False),), - ('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'), - ('hardshrink', (S, S, S), (0.4,), '', (True,)), - ('tanhshrink', (S, S, S), (),), - ('softsign', (S, S, S), (),), - ('softplus', (S, S, S), (), '', (True,)), - ('softmin', (S, S, S), (0,),), - ('softmax', (S, S, S), (0,), '', (True,)), - ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)), - ('tanh', (S, S, S), (), '', (True,)), - ('sigmoid', (S, S, S), (), '', (True,)), - ('silu', (S, S, S), (), '', (True,)), - ('log_softmax', (S, S, S), (0,), '', (True,)), - ('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])), - ('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])), - ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),), - ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)), - ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),), - ('batch_norm', (S, S), - (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ), - 'training', (True, 'aten::_batch_norm_impl_index')), - ('batch_norm', (0, S, S, S), - (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), - non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ), - 'size_zero', (True, 'aten::_batch_norm_impl_index')), - ('batch_norm', (0, S, S, S), - (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), - non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ), - 'size_zero_inference', (True, 'aten::_batch_norm_impl_index')), - ('batch_norm', (S, S), - (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), - 
non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ), - 'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')), - ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), - None, non_differentiable(torch.ones(S)), True, ), - 'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')), - ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), - non_differentiable(torch.randn(S)), None, True, ), - 'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')), - ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), - None, None, False, ), - 'inference', (True, 'aten::_batch_norm_impl_index')), - ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), - non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ), - 'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')), - ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), - None, non_differentiable(torch.ones(S)), False, ), - 'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')), - ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), - non_differentiable(torch.randn(S)), None, False, ), - 'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')), - ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),), - ('layer_norm', (S, S, S, S), ([5],), '', - (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])), - ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight', - (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])), - ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias', - (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])), - ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)), - non_differentiable(torch.rand(S))), 'with_weight_and_bias', - (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])), - ('group_norm', (S, S, S), (1, torch.rand(5),),), - ('local_response_norm', (S, S, S), (2, ),), - ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',), - ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),), - ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'), - ('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),), - ('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),), - ('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),), - ('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), - ('huber_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), - ('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), - ('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), - ('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'), - ('huber_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'), - ('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'), - ('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'), - ('margin_ranking_loss', (S,), ((S,), (S,)),), - ('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), - ('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), - 
('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), - ('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),), - ('pixel_shuffle', (1, 9, 4, 4), (3,),), - ('pixel_unshuffle', (1, 1, 12, 12), (3,),), - ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),), - ('pad', (3, 3, 4, 2), ([1, 1],),), - ('pairwise_distance', (S, S), ((S, S),),), - ('pdist', (S, S), (),), - ('cosine_similarity', (S, S), ((S, S),),), - ('triplet_margin_loss', (S, S), ((S, S), (S, S)),), - ('normalize', (S, S, S), (),), - ('unfold', (S, S, S, S), ([2, 3]),), - ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),), - ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),), - ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])), - ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])), - ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),), - ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)), - 1, 1., non_differentiable(torch.randn(S))),), - ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)), - non_differentiable(torch.randn(3, 2))),), - ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), - (non_differentiable(torch.rand(3, 2)), - non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'), - ('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(), - (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long), - torch.randint(1, S, (S,), dtype=torch.long))), - ('upsample', torch.randn(S, S, M, M), (None, 2.), 'with_scale'), - ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'), - ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'), - ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'), - ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'), - ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'), - ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'), - ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'), - ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'), - ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'), - ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'), - ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'), - ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'), - ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'), - ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'), - ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'), - ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'), - ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'), - ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'), - ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'), - ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'), - ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'), - ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'), - ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'), - ('interpolate', torch.randn(S, M, M, M, M), (4,), 
'nearest_5d_with_size'), - ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'), - ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'), - ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'), - ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'), - ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'), - ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'), - ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2, None, 'nearest', None, False), - 'nearest_4d_not_recompute_scale_factor'), - ('interpolate', torch.randn(S, S, M, M), (4, None, 'nearest', None, False), - 'nearest_4d_with_size_not_recompute_scale_factor'), - ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bilinear', None, False), - 'bilinear_4d_with_scale_not_recompute_scale_factor'), - ('interpolate', torch.randn(S, S, M, M), (4, None, 'bilinear', None, False), - 'bilinear_4d_with_size_not_recompute_scale_factor'), - ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bicubic', None, False), - 'bicubic_4d_with_scale_not_recompute_scale_factor'), - ('interpolate', torch.randn(S, S, M, M), (4, None, 'bicubic', None, False), - 'bicubic_4d_with_size_not_recompute_scale_factor'), - ('interpolate', torch.randn(S, M, M), (None, 2., 'nearest', None, False), - 'nearest_3d_with_scale_not_recompute_scale_factor'), - ('interpolate', torch.randn(S, M, M), (4, None, 'nearest', None, False), - 'nearest_3d_with_size_not_recompute_scale_factor'), - ('interpolate', torch.randn(S, M, M), (None, 2., 'linear', None, False), - 'linear_3d_with_scale_not_recompute_scale_factor'), - ('interpolate', torch.randn(S, M, M), (4, None, 'linear', None, False), - 'linear_3d_with_size_not_recompute_scale_factor'), - ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'nearest', None, False), - 'nearest_5d_with_scale_not_recompute_scale_factor'), - ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'nearest', None, False), - 'nearest_5d_with_size_not_recompute_scale_factor'), - ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'trilinear', None, False), - 'trilinear_5d_with_scale_not_recompute_scale_factor'), - ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False), - 'trilinear_5d_with_size_not_recompute_scale_factor'), -] +def get_nn_functional_tests(): + nn_functional_tests = [ + ('conv1d', (S, S, S), ((S, S, S),)), + ('conv2d', (S, S, S, S), ((S, S, S, S),)), + ('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)), + ('conv_transpose1d', (S, S, S), ((S, S, S),)), + ('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)), + ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)), + ('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)), + ('avg_pool1d', (S, S, S), (3,)), + ('avg_pool2d', (S, S, S, S), (3,), '', (True,)), + ('avg_pool3d', (S, S, S, S, S), (3,)), + ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)), + ('max_pool1d', (S, S, S), (2, 1)), + ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'), + ('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices')), + ('max_pool2d', (S, S, S, S), (2, 1, 1, 1, False, True), 'with_indices', (True, 'aten::max_pool2d_with_indices')), + ('max_pool3d', (S, S, S, S, S), (2, 1)), + ('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)), + ('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)), + ('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), 
(torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)), + ('lp_pool1d', (S, S, S), (2., 3, 2,)), + ('lp_pool2d', (S, S, S, S), (2., 3, 2,)), + ('lp_pool3d', (S, S, S, S, S), (2., 3, 2,)), + ('adaptive_max_pool1d', (S, S, S), (5,)), + ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)), + ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)), + ('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)), + ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)), + ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)), + ('dropout', (S, S, S), (0.5,), '', (True, 'aten::native_dropout')), + ('alpha_dropout', (S, S, S), (0.5,)), + ('dropout2d', (S, S, S), (0.5,)), + ('dropout2d', (S, S, S, S), (0.5,), 'batched'), + ('dropout3d', (S, S, S, S), (0.5,)), + ('dropout3d', (S, S, S, S, S), (0.5,), 'batched'), + ('feature_alpha_dropout', (S, S, S), (0.5,)), + ('threshold', (S, S, S), (0.1, 2.), '', (True,)), + ('threshold', (S, S, S), (0.1, 2., True), 'inplace'), + ('relu', (S, S, S), (), '', (True,)), + ('relu', (S, S, S), (), 'inplace'), + ('glu', (S - 1, S - 1, S - 1), (),), + ('hardtanh', (S, S, S), (-0.5, 0.5), '', (True,)), + ('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'), + ('relu6', (S, S, S), (), '', (True,)), + ('relu6', (S, S, S), (True), 'inplace'), + ('elu', (S, S, S), (0.9,),), + ('elu', (S, S, S), (0.9, True), 'inplace'), + ('selu', (S, S, S), (),), + ('selu', (S, S, S), (True), 'inplace'), + ('celu', (S, S, S), (0.9,),), + ('celu', (S, S, S), (0.9, True), 'inplace'), + ('leaky_relu', (S, S, S), (0.02,), '', (True,)), + ('leaky_relu', (S, S, S), (0.02,), 'inplace'), + ('rrelu', (S, S), (0.1, 0.3, False),), + ('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'), + ('hardshrink', (S, S, S), (0.4,), '', (True,)), + ('tanhshrink', (S, S, S), (),), + ('softsign', (S, S, S), (),), + ('softplus', (S, S, S), (), '', (True,)), + ('softmin', (S, S, S), (0,),), + ('softmax', (S, S, S), (0,), '', (True,)), + ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)), + ('tanh', (S, S, S), (), '', (True,)), + ('sigmoid', (S, S, S), (), '', (True,)), + ('silu', (S, S, S), (), '', (True,)), + ('log_softmax', (S, S, S), (0,), '', (True,)), + ('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])), + ('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])), + ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),), + ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)), + ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),), + ('batch_norm', (S, S), + (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ), + 'training', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (0, S, S, S), + (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ), + 'size_zero', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (0, S, S, S), + (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ), + 'size_zero_inference', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), + (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ), + 'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), 
(non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + None, non_differentiable(torch.ones(S)), True, ), + 'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), None, True, ), + 'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + None, None, False, ), + 'inference', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ), + 'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + None, non_differentiable(torch.ones(S)), False, ), + 'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), None, False, ), + 'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')), + ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),), + ('layer_norm', (S, S, S, S), ([5],), '', + (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])), + ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight', + (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])), + ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias', + (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])), + ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)), + non_differentiable(torch.rand(S))), 'with_weight_and_bias', + (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])), + ('group_norm', (S, S, S), (1, torch.rand(5),),), + ('local_response_norm', (S, S, S), (2, ),), + ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',), + ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),), + ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'), + ('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),), + ('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),), + ('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),), + ('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), + ('huber_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), + ('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), + ('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), + ('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'), + ('huber_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'), + ('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'), + ('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'), + ('margin_ranking_loss', (S,), ((S,), (S,)),), + ('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), + ('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), + ('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), + ('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),), + ('pixel_shuffle', (1, 9, 4, 4), 
(3,),), + ('pixel_unshuffle', (1, 1, 12, 12), (3,),), + ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),), + ('pad', (3, 3, 4, 2), ([1, 1],),), + ('pairwise_distance', (S, S), ((S, S),),), + ('pdist', (S, S), (),), + ('cosine_similarity', (S, S), ((S, S),),), + ('triplet_margin_loss', (S, S), ((S, S), (S, S)),), + ('normalize', (S, S, S), (),), + ('unfold', (S, S, S, S), ([2, 3]),), + ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),), + ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),), + ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])), + ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])), + ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),), + ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)), + 1, 1., non_differentiable(torch.randn(S))),), + ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)), + non_differentiable(torch.randn(3, 2))),), + ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), + (non_differentiable(torch.rand(3, 2)), + non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'), + ('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(), + (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long), + torch.randint(1, S, (S,), dtype=torch.long))), + ('upsample', torch.randn(S, S, M, M), (None, 2.), 'with_scale'), + ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'), + ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'), + ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'), + ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'), + ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'), + ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'), + ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'), + ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'), + ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'), + ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'), + ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'), + ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'), + ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'), + ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'), + ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'), + ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'), + ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'), + ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'), + ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'), + ('interpolate', 
torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'), + ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'), + ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'), + ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2, None, 'nearest', None, False), + 'nearest_4d_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, S, M, M), (4, None, 'nearest', None, False), + 'nearest_4d_with_size_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bilinear', None, False), + 'bilinear_4d_with_scale_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, S, M, M), (4, None, 'bilinear', None, False), + 'bilinear_4d_with_size_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bicubic', None, False), + 'bicubic_4d_with_scale_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, S, M, M), (4, None, 'bicubic', None, False), + 'bicubic_4d_with_size_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M), (None, 2., 'nearest', None, False), + 'nearest_3d_with_scale_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M), (4, None, 'nearest', None, False), + 'nearest_3d_with_size_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M), (None, 2., 'linear', None, False), + 'linear_3d_with_scale_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M), (4, None, 'linear', None, False), + 'linear_3d_with_size_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'nearest', None, False), + 'nearest_5d_with_scale_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'nearest', None, False), + 'nearest_5d_with_size_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'trilinear', None, False), + 'trilinear_5d_with_scale_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False), + 'trilinear_5d_with_size_not_recompute_scale_factor'), + ] + return nn_functional_tests script_template = ''' def the_method({}): @@ -523,45 +525,6 @@ def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name return script_fn, inputs -# additional modules test -# TODO: delete this list once we make all nn_tests work -additional_module_tests = [ - { - 'module_name': 'Bilinear', - 'constructor_args': (S, S, M), - 'input_size': (S, S), - 'extra_args': ((S, S),) - }, - { - 'module_name': 'RNNCell', - 'constructor_args': (S, S), - 'input_size': (S, S), - }, - { - 'module_name': 'LSTMCell', - 'constructor_args': (S, S), - 'input_size': (S, S), - }, - { - 'module_name': 'GRUCell', - 'constructor_args': (S, S), - 'input_size': (S, S), - }, - { - 'module_name': 'MultiheadAttention', - 'constructor_args': (128, 8), - 'input_size': (10, 8, 128), - 'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)), - 'slowTest': True - }, - { - 'module_name': 'Transformer', - 'constructor_args': (1, 1, 1, 1, 2), - 'input_size': (3, 1, 1), - 'extra_args': (torch.randn(1, 1, 1),), - 'slowTest': True - } -] EXCLUDE_SCRIPT_MODULES = { 'test_nn_AdaptiveAvgPool2d_tuple_none', @@ -719,4 +682,44 @@ def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs): def get_all_nn_module_tests(): - return module_tests + new_module_tests + additional_module_tests + # additional modules test + # TODO: delete this list 
once we make all nn_tests work
+    additional_module_tests = [
+        {
+            'module_name': 'Bilinear',
+            'constructor_args': (S, S, M),
+            'input_size': (S, S),
+            'extra_args': ((S, S),)
+        },
+        {
+            'module_name': 'RNNCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'LSTMCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'GRUCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'MultiheadAttention',
+            'constructor_args': (128, 8),
+            'input_size': (10, 8, 128),
+            'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
+            'slowTest': True
+        },
+        {
+            'module_name': 'Transformer',
+            'constructor_args': (1, 1, 1, 1, 2),
+            'input_size': (3, 1, 1),
+            'extra_args': (torch.randn(1, 1, 1),),
+            'slowTest': True
+        }
+    ]
+
+    return module_tests + get_new_module_tests() + additional_module_tests
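
Editor's note on the pattern used throughout the hunks above: lists of test definitions that eagerly construct torch.Tensor objects (new_module_tests, nn_functional_tests, additional_module_tests) are wrapped in getter functions (get_new_module_tests, get_nn_functional_tests, get_all_nn_module_tests), so the tensors are only allocated when a test suite actually requests the cases. A minimal, self-contained sketch of that pattern follows; EAGER_CASES and get_cases are illustrative names, not names from this patch, and the single entry mirrors the max_unpool1d tuple shown above.

    import torch

    # Before: importing the module allocates these tensors immediately and
    # keeps them alive in the module's global namespace for the whole process.
    # EAGER_CASES = [
    #     ('max_unpool1d', torch.tensor([[[2., 4.]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
    # ]

    # After: building the list inside a getter defers tensor construction until
    # a caller asks for the cases, and lets the tensors be freed afterwards.
    def get_cases():
        return [
            ('max_unpool1d', torch.tensor([[[2., 4.]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
        ]

    for name, self_arg, extra_args in get_cases():
        print(name, self_arg.shape, len(extra_args))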
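At call sites the change is mechanical: iterate over the getter's return value instead of the old global, as the test-file hunks earlier in the patch show. One consequence worth noting is that each call rebuilds the list and re-creates any tensors in it, so a caller that needs the cases repeatedly may want to fetch them once and reuse the result. A hypothetical call site under that assumption:

    from torch.testing._internal.jit_metaprogramming_utils import get_all_nn_module_tests

    # Fetch once and reuse; each call rebuilds the dicts and re-allocates any
    # tensors referenced by fields such as 'extra_args'.
    all_tests = get_all_nn_module_tests()
    slow_tests = [t for t in all_tests if t.get('slowTest')]
    print(f"{len(all_tests)} module tests, {len(slow_tests)} marked slowTest")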
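A second note, on the conv padding-mode cases generated in the loop inside get_new_module_tests above: the expected output_size is derived from the standard convolution length formula, and the inline comment records the simplification from (4 + 2 * p - 3) // 2 + 1 to p + 1 for kernel 3, stride 2, and input length 4. A small stand-alone check of that simplification (plain Python, independent of the test harness):

    # Spatial output length of a convolution:
    #   out = (in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
    # With in=4, kernel=3, dilation=1, stride=2 this is (4 + 2*p - 3) // 2 + 1,
    # which equals p + 1 for the paddings used in the loop.
    def conv_out_len(in_len, kernel, stride, padding, dilation=1):
        return (in_len + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

    for p in range(1, 4):
        assert conv_out_len(4, kernel=3, stride=2, padding=p) == p + 1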