No actual change, just remove variables containing Tensors from global scope (#143225)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143225
Approved by: https://github.com/ezyang
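The whole change is one mechanical pattern: tuples and lists that allocate Tensors at module scope are moved into getter functions, so the tensors are built on demand instead of at import time (and are no longer pinned in memory for the life of the process). A minimal sketch of the pattern; the names _example_inputs and get_example_inputs are illustrative, not from this diff:

import torch

# Before: tensors are allocated as a side effect of importing the module and
# stay alive as module globals until the process exits.
# _example_inputs = (
#     torch.randn(1, 1, 3),   # x
#     torch.randn(1, 1, 1),   # conv_weight
# )

# After: the same tuple is built lazily, only when a caller asks for it.
def get_example_inputs():
    return (
        torch.randn(1, 1, 3),   # x
        torch.randn(1, 1, 1),   # conv_weight
    )

x, conv_weight = get_example_inputs()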
This commit is contained in:
parent afa313e669
commit 792f1c47e9

@@ -20,8 +20,8 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.jit_metaprogramming_utils import (
     get_all_nn_module_tests,
     get_nn_functional_compiled_fn_and_inputs,
+    get_nn_functional_tests,
     get_nn_mod_test_name,
-    nn_functional_tests,
     try_get_nn_module_compiled_mod_and_inputs,
 )
 from torch.testing._internal.jit_utils import enable_profiling_mode, JitTestCase

@@ -70,7 +70,7 @@ class TestComplexity(JitTestCase):
     def test_generated_functional_tests(self):
         with enable_profiling_mode():
             stats = [("Name", "Ifs/Loops", "non-tensor ops")]
-            for test in nn_functional_tests:
+            for test in get_nn_functional_tests():
                 test_name = test[0]

                 fn, inputs = get_nn_functional_compiled_fn_and_inputs(*test)

@@ -42,7 +42,7 @@ if not common.IS_ARM64:
         (sample_module.module_tests, common_nn.NewModuleTest),
         (sample_functional.functional_tests, common_nn.NewModuleTest),
         (common_nn.module_tests, common_nn.NewModuleTest),
-        (common_nn.new_module_tests, common_nn.NewModuleTest),
+        (common_nn.get_new_module_tests(), common_nn.NewModuleTest),
         (common_nn.criterion_tests, common_nn.CriterionTest),
     ]:
         for test_params_dict in test_params_dicts:

@@ -25,7 +25,11 @@ from torch.testing._internal.common_device_type import (
 )
 from torch.testing._internal.common_methods_invocations import op_db, SampleInput
 from torch.testing._internal.common_modules import module_db, modules
-from torch.testing._internal.common_nn import module_tests, new_module_tests, TestBase
+from torch.testing._internal.common_nn import (
+    get_new_module_tests,
+    module_tests,
+    TestBase,
+)
 from torch.testing._internal.common_utils import (
     freeze_rng_state,
     make_tensor,

@@ -1011,7 +1015,7 @@ def filter_supported_tests(t):
 # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
 # These currently use the legacy nn tests
 supported_tests = [
-    t for t in module_tests + new_module_tests if filter_supported_tests(t)
+    t for t in module_tests + get_new_module_tests() if filter_supported_tests(t)
 ]
 for test_param in supported_tests:
     if "constructor" not in test_param:

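Call sites follow the same mechanical edit: where they used to concatenate two module-level lists, they now concatenate a cheap global with a function call, and they bind the result once so the tensors are not rebuilt on every use. A hedged sketch with hypothetical miniature stand-ins for module_tests and get_new_module_tests():

import torch

module_tests = [{'module_name': 'Linear'}]          # tensor-free metadata can stay global

def get_new_module_tests():
    # entries that carry tensors are now built per call, not at import
    return [{'module_name': 'Conv1d', 'input': torch.randn(1, 1, 3)}]

# Call sites change from `module_tests + new_module_tests` to:
supported_tests = module_tests + get_new_module_tests()
for test_param in supported_tests:                  # iterate the local, not a fresh call
    print(test_param['module_name'])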
@@ -50,7 +50,7 @@ from torch.testing._internal.common_device_type import (
     ops,
 )
 from torch.testing._internal.common_methods_invocations import op_db
-from torch.testing._internal.common_nn import module_tests, new_module_tests
+from torch.testing._internal.common_nn import module_tests, get_new_module_tests
 from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
 from torch.testing._internal.jit_utils import JitTestCase
 import torch.utils._pytree as pytree

@@ -1006,7 +1006,7 @@ terrible spacing
         Exhaustively test `Node.normalized_arguments` on all standard
         torch.nn Module classes
         """
-        for test_params in module_tests + new_module_tests:
+        for test_params in module_tests + get_new_module_tests():
             if "constructor" not in test_params:
                 constructor = getattr(torch.nn, test_params["module_name"])
             else:

@@ -107,10 +107,10 @@ from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, dis
 from torch.testing._internal.jit_metaprogramming_utils import (
     get_script_args,
     create_input, unpack_variables,
-    additional_module_tests, EXCLUDE_SCRIPT_MODULES,
+    get_all_nn_module_tests, EXCLUDE_SCRIPT_MODULES,
     get_nn_module_name_from_kwargs, get_nn_mod_test_name, script_method_template)

-from torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests
+from torch.testing._internal.common_nn import criterion_tests

 # For testing truediv in python 2
 from torch.testing._internal.test_module.future_div import div_int_future, div_float_future

@@ -16247,7 +16247,7 @@ class TestProducerVersion(TestCase):
         # issue gh-32561
         self.assertTrue(torch.__version__.startswith(torch.onnx.producer_version))

-for test in module_tests + new_module_tests + additional_module_tests:
+for test in get_all_nn_module_tests():
     add_nn_module_test(**test)

 for test in criterion_tests:

@@ -38,7 +38,7 @@ from torch.testing._internal.common_utils import freeze_rng_state, run_tests, Te
 from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
     module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
-    ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
+    ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
 from torch.testing._internal.common_device_type import dtypesIfMPS, instantiate_device_type_tests, dtypes, \
     dtypesIfCUDA, precisionOverride, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
     skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \

@@ -7332,7 +7332,7 @@ def add_test(test, decorator=None):
         else:
             add(cuda_test_name, with_tf32_off)

-for test_params in module_tests + new_module_tests:
+for test_params in module_tests + get_new_module_tests():
     # TODO: CUDA is not implemented yet
     if 'constructor' not in test_params:
         name = test_params.pop('module_name')

@@ -19,8 +19,6 @@ from torch.fx import Graph, GraphModule, Node
 from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns

 from .utils import (
-    _conv1d_bn_example_inputs,
-    _conv2d_bn_example_inputs,
     _get_aten_graph_module_for_pattern,
     _is_bn_node,
     _is_conv_or_conv_transpose_node,

@@ -35,27 +33,6 @@ if TYPE_CHECKING:
 __all__ = []  # type: ignore[var-annotated]


-# Example inputs for quantized and folded conv-bn1d patterns used in convert
-_quantized_conv1d_bn_example_inputs = (
-    torch.randn(1, 1, 3),  # x
-    torch.randn(1, 1, 1),  # conv_weight
-    torch.randn(1),  # bn_weight
-    torch.randn(1),  # bn_bias
-    torch.randn(1),  # bn_running_mean
-    torch.randn(1),  # bn_running_var
-)
-
-# Example inputs for quantized and folded conv-bn2d patterns used in convert
-_quantized_conv2d_bn_example_inputs = (
-    torch.randn(1, 1, 3, 3),  # x
-    torch.randn(1, 1, 1, 1),  # conv_weight
-    torch.randn(1),  # bn_weight
-    torch.randn(1),  # bn_bias
-    torch.randn(1),  # bn_running_mean
-    torch.randn(1),  # bn_running_var
-)
-
-
 def _get_quantized_conv_bn_example_inputs_kwargs(
     is_per_channel: bool,
     has_bias: bool,

@@ -631,6 +608,28 @@ def _update_special_qspecs_after_replacement(


 def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
+    # Example inputs for conv-bn1d patterns
+    _conv1d_bn_example_inputs = (
+        torch.randn(1, 1, 3),  # x
+        torch.randn(1, 1, 1),  # conv_weight
+        torch.randn(1),  # conv_bias
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
+    # Example inputs for conv-bn2d patterns
+    _conv2d_bn_example_inputs = (
+        torch.randn(1, 1, 3, 3),  # x
+        torch.randn(1, 1, 1, 1),  # conv_weight
+        torch.randn(1),  # conv_bias
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
     has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
     if not has_bn:
         return m

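If rebuilding these tuples on every call ever became a measurable cost, memoizing the construction would be the usual middle ground; note that a cache reintroduces process-lifetime tensors, which is exactly what this commit removes, so it only pays off where allocation cost rather than residency is the problem. A hedged sketch (the function name is hypothetical, not part of this diff):

import functools
import torch

@functools.lru_cache(maxsize=1)
def get_conv1d_bn_example_inputs():
    # Built once on first use, then cached. The cache pins the tensors for the
    # rest of the process, so prefer the plain per-call getter unless this path
    # is demonstrably hot.
    return (
        torch.randn(1, 1, 3),  # x
        torch.randn(1, 1, 1),  # conv_weight
    )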
@@ -859,6 +858,26 @@ def _copy_over_q_dq_args(original_node: Node, replacement_node: Node):


 def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
+    # Example inputs for quantized and folded conv-bn1d patterns used in convert
+    _quantized_conv1d_bn_example_inputs = (
+        torch.randn(1, 1, 3),  # x
+        torch.randn(1, 1, 1),  # conv_weight
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
+    # Example inputs for quantized and folded conv-bn2d patterns used in convert
+    _quantized_conv2d_bn_example_inputs = (
+        torch.randn(1, 1, 3, 3),  # x
+        torch.randn(1, 1, 1, 1),  # conv_weight
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
     has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
     if not has_bn:
         return m

@@ -22,25 +22,6 @@ __all__ = [
 ]


-_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
-    torch.randint(-128, 127, (2, 5), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-127], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-    torch.randn(1, dtype=torch.float),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-)
-
-
 def _qdq_quantized_linear(
     x_i8,
     x_scale,

@@ -129,20 +110,6 @@ def _reference_quantized_linear(
     return out_i8


-_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
-    torch.randn((2, 5), dtype=torch.float),
-    -128,
-    127,
-    torch.finfo(torch.float32).eps,
-    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-127], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-    torch.randn(1, dtype=torch.float),
-)
-
-
 def _qdq_dynamic_quantized_linear(
     x_fp32,
     x_quant_min,

@@ -223,25 +190,6 @@ def _reference_dynamic_quantized_linear(
     return out_fp32


-_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
-    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-127], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-    torch.randn(1, dtype=torch.float),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-)
-
-
 def _qdq_quantized_conv2d(
     x_i8,
     x_scale,

@@ -375,20 +323,6 @@ def _reference_quantized_conv2d(
     return out_i8


-_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
-    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-)
-
-
 def _qdq_quantized_add_relu(
     x_i8,
     x_scale,

@@ -518,19 +452,6 @@ def _reference_quantized_add(
     return out_i8


-_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
-    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-)
-
-
 def _qdq_quantized_max_pool2d(
     x_i8,
     x_scale,

@@ -587,15 +508,6 @@ def _reference_quantized_max_pool2d(
     return out_i8


-_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
-    torch.randn(1, 3, 3, 3, dtype=torch.float),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-)
-
-
 def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
     x = torch.ops.quantized_decomposed.quantize_per_tensor(
         x_fp32, scale, zero_point, quant_min, quant_max, torch.int8

@@ -619,15 +531,6 @@ def _reference_quantize_per_tensor_int8(
     return x


-_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
-    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
-    torch.randn(1, dtype=torch.float),
-    torch.zeros(1, dtype=torch.int),
-    torch.tensor([-128], dtype=torch.int),
-    torch.tensor([127], dtype=torch.int),
-)
-
-
 def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
     x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
         x_i8, scale, zero_point, quant_min, quant_max, torch.int8

@@ -648,16 +551,6 @@ def _reference_dequantize_per_tensor_int8(
     return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)


-_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
-    torch.randn(1, 3, 3, 3, dtype=torch.float),
-    torch.randn(3, dtype=torch.float),
-    torch.zeros(3, dtype=torch.int),
-    1,
-    -128,
-    127,
-)
-
-
 def _quantize_per_channel_int8(
     x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
 ):

@@ -678,16 +571,6 @@ def _reference_quantize_per_channel_int8(
     return out_i32.to(torch.int8)


-_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
-    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
-    torch.randn(3, dtype=torch.float),
-    torch.zeros(3, dtype=torch.int),
-    1,
-    -128,
-    127,
-)
-
-
 def _dequantize_per_channel_int8(
     x_i8, scales, zero_points, ch_axis, quant_min, quant_max
 ):

@@ -733,79 +616,186 @@ class _RewriteInfo:
     replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None


-_REWRITE_INFO_LIST = [
-    _RewriteInfo(
-        _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
-        _WrapperModule(_qdq_dynamic_quantized_linear),
-        _WrapperModule(_reference_dynamic_quantized_linear),
-        partial(
-            _replace_literals_with_existing_placeholders,
-            literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
-        ),
-        partial(
-            _replace_literals_with_existing_placeholders,
-            literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
-        ),
-    ),
-    _RewriteInfo(
-        _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
-        _WrapperModule(_qdq_quantized_linear),
-        _WrapperModule(_reference_quantized_linear),
-        _replace_literals_with_new_placeholders,
-        _replace_literals_with_new_placeholders,
-    ),
-    _RewriteInfo(
-        _QUANTIZED_CONV2d_EXAMPLE_INPUTS,
-        _WrapperModule(_qdq_quantized_conv2d),
-        _WrapperModule(_reference_quantized_conv2d),
-        partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
-        partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
-    ),
-    _RewriteInfo(
-        _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
-        _WrapperModule(_qdq_quantized_add_relu),
-        _WrapperModule(_reference_quantized_add_relu),
-    ),
-    _RewriteInfo(
-        _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
-        _WrapperModule(_qdq_quantized_add),
-        _WrapperModule(_reference_quantized_add),
-    ),
-    _RewriteInfo(
-        _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS,
-        _WrapperModule(_qdq_quantized_max_pool2d),
-        _WrapperModule(_reference_quantized_max_pool2d),
-        _replace_literals_with_new_placeholders,
-        _replace_literals_with_new_placeholders,
-    ),
-    _RewriteInfo(
-        _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
-        _WrapperModule(_quantize_per_tensor_int8),
-        _WrapperModule(_reference_quantize_per_tensor_int8),
-    ),
-    _RewriteInfo(
-        _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
-        _WrapperModule(_dequantize_per_tensor_int8),
-        _WrapperModule(_reference_dequantize_per_tensor_int8),
-    ),
-    _RewriteInfo(
-        _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
-        _WrapperModule(_quantize_per_channel_int8),
-        _WrapperModule(_reference_quantize_per_channel_int8),
-        _replace_ph_qdq_per_channel_replacement,
-        _replace_ph_qdq_per_channel_replacement,
-    ),
-    _RewriteInfo(
-        _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
-        _WrapperModule(_dequantize_per_channel_int8),
-        _WrapperModule(_reference_dequantize_per_channel_int8),
-        _replace_ph_qdq_per_channel_replacement,
-        _replace_ph_qdq_per_channel_replacement,
-    ),
-]
-
-
 def reference_representation_rewrite(model: GraphModule) -> GraphModule:
+    _QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (2, 5), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randint(-128, 127, (5, 5), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-127], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randn(1, dtype=torch.float),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
+        torch.randn((2, 5), dtype=torch.float),
+        -128,
+        127,
+        torch.finfo(torch.float32).eps,
+        torch.randint(-128, 127, (5, 5), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-127], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randn(1, dtype=torch.float),
+    )
+
+    _QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-127], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randn(1, dtype=torch.float),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
+        torch.randn(1, 3, 3, 3, dtype=torch.float),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
+        torch.randn(1, 3, 3, 3, dtype=torch.float),
+        torch.randn(3, dtype=torch.float),
+        torch.zeros(3, dtype=torch.int),
+        1,
+        -128,
+        127,
+    )
+
+    _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(3, dtype=torch.float),
+        torch.zeros(3, dtype=torch.int),
+        1,
+        -128,
+        127,
+    )
+
+    _REWRITE_INFO_LIST = [
+        _RewriteInfo(
+            _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_dynamic_quantized_linear),
+            _WrapperModule(_reference_dynamic_quantized_linear),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
+            ),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
+            ),
+        ),
+        _RewriteInfo(
+            _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_quantized_linear),
+            _WrapperModule(_reference_quantized_linear),
+            _replace_literals_with_new_placeholders,
+            _replace_literals_with_new_placeholders,
+        ),
+        _RewriteInfo(
+            _QUANTIZED_CONV2d_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_quantized_conv2d),
+            _WrapperModule(_reference_quantized_conv2d),
+            partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
+            partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
+        ),
+        _RewriteInfo(
+            _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_quantized_add_relu),
+            _WrapperModule(_reference_quantized_add_relu),
+        ),
+        _RewriteInfo(
+            _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_quantized_add),
+            _WrapperModule(_reference_quantized_add),
+        ),
+        _RewriteInfo(
+            _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_quantized_max_pool2d),
+            _WrapperModule(_reference_quantized_max_pool2d),
+            _replace_literals_with_new_placeholders,
+            _replace_literals_with_new_placeholders,
+        ),
+        _RewriteInfo(
+            _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
+            _WrapperModule(_quantize_per_tensor_int8),
+            _WrapperModule(_reference_quantize_per_tensor_int8),
+        ),
+        _RewriteInfo(
+            _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
+            _WrapperModule(_dequantize_per_tensor_int8),
+            _WrapperModule(_reference_dequantize_per_tensor_int8),
+        ),
+        _RewriteInfo(
+            _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
+            _WrapperModule(_quantize_per_channel_int8),
+            _WrapperModule(_reference_quantize_per_channel_int8),
+            _replace_ph_qdq_per_channel_replacement,
+            _replace_ph_qdq_per_channel_replacement,
+        ),
+        _RewriteInfo(
+            _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
+            _WrapperModule(_dequantize_per_channel_int8),
+            _WrapperModule(_reference_dequantize_per_channel_int8),
+            _replace_ph_qdq_per_channel_replacement,
+            _replace_ph_qdq_per_channel_replacement,
+        ),
+    ]
+
     remove_tensor_overload_for_qdq_ops(model)
     from torch._export import gm_using_training_ir

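One behavioral nuance of building the example inputs inside the function body: torch.randn and torch.randint are re-sampled on every call, so two calls no longer see identical tensors. For the pattern tracing done here that should be immaterial, since matching depends on shapes and dtypes rather than values; a caller who does need reproducible tensors can pin the RNG at the call site. A hedged sketch with a hypothetical getter:

import torch

def get_example_inputs():
    # stand-in for any of the per-call example-input builders above
    return (torch.randn(2, 5),)

torch.manual_seed(0)
a, = get_example_inputs()
torch.manual_seed(0)
b, = get_example_inputs()
assert torch.equal(a, b)  # identical only because the seed was reset between calls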
@@ -33,28 +33,6 @@ _DEQUANTIZE_OPS = [
     torch.ops.quantized_decomposed.dequantize_per_channel.default,
 ]

-# Example inputs for conv-bn1d patterns
-_conv1d_bn_example_inputs = (
-    torch.randn(1, 1, 3),  # x
-    torch.randn(1, 1, 1),  # conv_weight
-    torch.randn(1),  # conv_bias
-    torch.randn(1),  # bn_weight
-    torch.randn(1),  # bn_bias
-    torch.randn(1),  # bn_running_mean
-    torch.randn(1),  # bn_running_var
-)
-
-# Example inputs for conv-bn2d patterns
-_conv2d_bn_example_inputs = (
-    torch.randn(1, 1, 3, 3),  # x
-    torch.randn(1, 1, 1, 1),  # conv_weight
-    torch.randn(1),  # conv_bias
-    torch.randn(1),  # bn_weight
-    torch.randn(1),  # bn_bias
-    torch.randn(1),  # bn_running_mean
-    torch.randn(1),  # bn_running_var
-)
-
-
 def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
     """

@@ -10,8 +10,6 @@ from torch._subclasses import FakeTensor
 from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
 from torch.ao.quantization.pt2e.export_utils import _WrapperModule
 from torch.ao.quantization.pt2e.utils import (
-    _conv1d_bn_example_inputs,
-    _conv2d_bn_example_inputs,
     _get_aten_graph_module_for_pattern,
     _is_conv_node,
     _is_conv_transpose_node,

@@ -487,6 +485,28 @@ def _do_annotate_conv_bn(
     for the following names: "input", "conv", "weight", "bias", and "output".
     """

+    # Example inputs for conv-bn1d patterns
+    _conv1d_bn_example_inputs = (
+        torch.randn(1, 1, 3),  # x
+        torch.randn(1, 1, 1),  # conv_weight
+        torch.randn(1),  # conv_bias
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
+    # Example inputs for conv-bn2d patterns
+    _conv2d_bn_example_inputs = (
+        torch.randn(1, 1, 3, 3),  # x
+        torch.randn(1, 1, 1, 1),  # conv_weight
+        torch.randn(1),  # conv_bias
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
     def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
         def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
             conv = conv_fn(x, conv_weight, conv_bias)

File diff suppressed because it is too large
@@ -8,7 +8,7 @@ import torch.cuda
 import torch.jit
 import torch.jit._logging
 import torch.jit.frontend
-from torch.testing._internal.common_nn import module_tests, new_module_tests
+from torch.testing._internal.common_nn import module_tests, get_new_module_tests
 from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like

 import collections

@@ -95,226 +95,228 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
 # fn mapping output to part that should be gradcheck'ed, // optional
 # kwargs for function, // optional
 # )
-nn_functional_tests = [
-    ('conv1d', (S, S, S), ((S, S, S),)),
-    ('conv2d', (S, S, S, S), ((S, S, S, S),)),
-    ('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
-    ('conv_transpose1d', (S, S, S), ((S, S, S),)),
-    ('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)),
-    ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
-    ('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
-    ('avg_pool1d', (S, S, S), (3,)),
-    ('avg_pool2d', (S, S, S, S), (3,), '', (True,)),
-    ('avg_pool3d', (S, S, S, S, S), (3,)),
-    ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
-    ('max_pool1d', (S, S, S), (2, 1)),
-    ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
-    ('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices')),
-    ('max_pool2d', (S, S, S, S), (2, 1, 1, 1, False, True), 'with_indices', (True, 'aten::max_pool2d_with_indices')),
-    ('max_pool3d', (S, S, S, S, S), (2, 1)),
-    ('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
-    ('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)),
-    ('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
-    ('lp_pool1d', (S, S, S), (2., 3, 2,)),
-    ('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
-    ('lp_pool3d', (S, S, S, S, S), (2., 3, 2,)),
-    ('adaptive_max_pool1d', (S, S, S), (5,)),
-    ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
-    ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
-    ('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)),
-    ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)),
-    ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)),
-    ('dropout', (S, S, S), (0.5,), '', (True, 'aten::native_dropout')),
-    ('alpha_dropout', (S, S, S), (0.5,)),
-    ('dropout2d', (S, S, S), (0.5,)),
-    ('dropout2d', (S, S, S, S), (0.5,), 'batched'),
-    ('dropout3d', (S, S, S, S), (0.5,)),
-    ('dropout3d', (S, S, S, S, S), (0.5,), 'batched'),
-    ('feature_alpha_dropout', (S, S, S), (0.5,)),
-    ('threshold', (S, S, S), (0.1, 2.), '', (True,)),
-    ('threshold', (S, S, S), (0.1, 2., True), 'inplace'),
-    ('relu', (S, S, S), (), '', (True,)),
-    ('relu', (S, S, S), (), 'inplace'),
-    ('glu', (S - 1, S - 1, S - 1), (),),
-    ('hardtanh', (S, S, S), (-0.5, 0.5), '', (True,)),
-    ('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'),
-    ('relu6', (S, S, S), (), '', (True,)),
-    ('relu6', (S, S, S), (True), 'inplace'),
-    ('elu', (S, S, S), (0.9,),),
-    ('elu', (S, S, S), (0.9, True), 'inplace'),
-    ('selu', (S, S, S), (),),
-    ('selu', (S, S, S), (True), 'inplace'),
-    ('celu', (S, S, S), (0.9,),),
-    ('celu', (S, S, S), (0.9, True), 'inplace'),
-    ('leaky_relu', (S, S, S), (0.02,), '', (True,)),
-    ('leaky_relu', (S, S, S), (0.02,), 'inplace'),
-    ('rrelu', (S, S), (0.1, 0.3, False),),
-    ('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'),
-    ('hardshrink', (S, S, S), (0.4,), '', (True,)),
-    ('tanhshrink', (S, S, S), (),),
-    ('softsign', (S, S, S), (),),
-    ('softplus', (S, S, S), (), '', (True,)),
-    ('softmin', (S, S, S), (0,),),
-    ('softmax', (S, S, S), (0,), '', (True,)),
-    ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)),
-    ('tanh', (S, S, S), (), '', (True,)),
-    ('sigmoid', (S, S, S), (), '', (True,)),
-    ('silu', (S, S, S), (), '', (True,)),
-    ('log_softmax', (S, S, S), (0,), '', (True,)),
-    ('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])),
-    ('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])),
-    ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
-    ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
-    ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
-    ('batch_norm', (S, S),
-        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ),
-        'training', (True, 'aten::_batch_norm_impl_index')),
-    ('batch_norm', (0, S, S, S),
-        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
-         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
-        'size_zero', (True, 'aten::_batch_norm_impl_index')),
-    ('batch_norm', (0, S, S, S),
-        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
-         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
-        'size_zero_inference', (True, 'aten::_batch_norm_impl_index')),
-    ('batch_norm', (S, S),
-        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
-         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
-        'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')),
-    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
-                            None, non_differentiable(torch.ones(S)), True, ),
-        'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')),
-    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
-                            non_differentiable(torch.randn(S)), None, True, ),
-        'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')),
-    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
-                            None, None, False, ),
-        'inference', (True, 'aten::_batch_norm_impl_index')),
-    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
-                            non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ),
-        'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')),
-    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
-                            None, non_differentiable(torch.ones(S)), False, ),
-        'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')),
-    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
-                            non_differentiable(torch.randn(S)), None, False, ),
-        'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')),
-    ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
-    ('layer_norm', (S, S, S, S), ([5],), '',
-        (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
-    ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
-        (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
-    ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
-        (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
-    ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
-                                  non_differentiable(torch.rand(S))), 'with_weight_and_bias',
-        (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])),
-    ('group_norm', (S, S, S), (1, torch.rand(5),),),
-    ('local_response_norm', (S, S, S), (2, ),),
-    ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',),
-    ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
-    ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
-    ('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
-    ('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),),
-    ('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),),
-    ('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
-    ('huber_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
-    ('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
-    ('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
-    ('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
-    ('huber_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
-    ('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
-    ('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
-    ('margin_ranking_loss', (S,), ((S,), (S,)),),
-    ('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
-    ('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
-    ('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
-    ('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
-    ('pixel_shuffle', (1, 9, 4, 4), (3,),),
-    ('pixel_unshuffle', (1, 1, 12, 12), (3,),),
-    ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
-    ('pad', (3, 3, 4, 2), ([1, 1],),),
-    ('pairwise_distance', (S, S), ((S, S),),),
-    ('pdist', (S, S), (),),
-    ('cosine_similarity', (S, S), ((S, S),),),
-    ('triplet_margin_loss', (S, S), ((S, S), (S, S)),),
-    ('normalize', (S, S, S), (),),
-    ('unfold', (S, S, S, S), ([2, 3]),),
-    ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
-    ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
-    ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
-    ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
-    ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
-    ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
-                                   1, 1., non_differentiable(torch.randn(S))),),
-    ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)),
-                                                           non_differentiable(torch.randn(3, 2))),),
-    ('binary_cross_entropy', torch.randn(3, 2).sigmoid(),
-        (non_differentiable(torch.rand(3, 2)),
-         non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'),
-    ('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
-        (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
-         torch.randint(1, S, (S,), dtype=torch.long))),
-    ('upsample', torch.randn(S, S, M, M), (None, 2.), 'with_scale'),
-    ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
-    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
-    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
-    ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
-    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
-    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
-    ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
-    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
-    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
-    ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
-    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
-    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
-    ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
-    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
-    ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
-    ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
-    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
-    ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
-    ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
-    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
-    ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
-    ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
-    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
-    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
-    ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
-    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
-    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
-    ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
-    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
-    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
-    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2, None, 'nearest', None, False),
-        'nearest_4d_not_recompute_scale_factor'),
-    ('interpolate', torch.randn(S, S, M, M), (4, None, 'nearest', None, False),
-        'nearest_4d_with_size_not_recompute_scale_factor'),
-    ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bilinear', None, False),
-        'bilinear_4d_with_scale_not_recompute_scale_factor'),
-    ('interpolate', torch.randn(S, S, M, M), (4, None, 'bilinear', None, False),
-        'bilinear_4d_with_size_not_recompute_scale_factor'),
-    ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bicubic', None, False),
-        'bicubic_4d_with_scale_not_recompute_scale_factor'),
-    ('interpolate', torch.randn(S, S, M, M), (4, None, 'bicubic', None, False),
-        'bicubic_4d_with_size_not_recompute_scale_factor'),
-    ('interpolate', torch.randn(S, M, M), (None, 2., 'nearest', None, False),
-        'nearest_3d_with_scale_not_recompute_scale_factor'),
-    ('interpolate', torch.randn(S, M, M), (4, None, 'nearest', None, False),
-        'nearest_3d_with_size_not_recompute_scale_factor'),
-    ('interpolate', torch.randn(S, M, M), (None, 2., 'linear', None, False),
-        'linear_3d_with_scale_not_recompute_scale_factor'),
-    ('interpolate', torch.randn(S, M, M), (4, None, 'linear', None, False),
-        'linear_3d_with_size_not_recompute_scale_factor'),
-    ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'nearest', None, False),
-        'nearest_5d_with_scale_not_recompute_scale_factor'),
-    ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'nearest', None, False),
-        'nearest_5d_with_size_not_recompute_scale_factor'),
-    ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'trilinear', None, False),
-        'trilinear_5d_with_scale_not_recompute_scale_factor'),
-    ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False),
-        'trilinear_5d_with_size_not_recompute_scale_factor'),
-]
+def get_nn_functional_tests():
+    nn_functional_tests = [
+        ('conv1d', (S, S, S), ((S, S, S),)),
+        ('conv2d', (S, S, S, S), ((S, S, S, S),)),
+        ('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
+        ('conv_transpose1d', (S, S, S), ((S, S, S),)),
+        ('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)),
+        ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
+        ('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
+        ('avg_pool1d', (S, S, S), (3,)),
+        ('avg_pool2d', (S, S, S, S), (3,), '', (True,)),
+        ('avg_pool3d', (S, S, S, S, S), (3,)),
+        ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
+        ('max_pool1d', (S, S, S), (2, 1)),
+        ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
+        ('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices')),
+        ('max_pool2d', (S, S, S, S), (2, 1, 1, 1, False, True), 'with_indices', (True, 'aten::max_pool2d_with_indices')),
+        ('max_pool3d', (S, S, S, S, S), (2, 1)),
+        ('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
+        ('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)),
+        ('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
+        ('lp_pool1d', (S, S, S), (2., 3, 2,)),
+        ('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
+        ('lp_pool3d', (S, S, S, S, S), (2., 3, 2,)),
+        ('adaptive_max_pool1d', (S, S, S), (5,)),
+        ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
+        ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
+        ('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)),
+        ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)),
+        ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)),
+        ('dropout', (S, S, S), (0.5,), '', (True, 'aten::native_dropout')),
+        ('alpha_dropout', (S, S, S), (0.5,)),
+        ('dropout2d', (S, S, S), (0.5,)),
+        ('dropout2d', (S, S, S, S), (0.5,), 'batched'),
+        ('dropout3d', (S, S, S, S), (0.5,)),
+        ('dropout3d', (S, S, S, S, S), (0.5,), 'batched'),
+        ('feature_alpha_dropout', (S, S, S), (0.5,)),
+        ('threshold', (S, S, S), (0.1, 2.), '', (True,)),
+        ('threshold', (S, S, S), (0.1, 2., True), 'inplace'),
+        ('relu', (S, S, S), (), '', (True,)),
+        ('relu', (S, S, S), (), 'inplace'),
+        ('glu', (S - 1, S - 1, S - 1), (),),
+        ('hardtanh', (S, S, S), (-0.5, 0.5), '', (True,)),
+        ('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'),
+        ('relu6', (S, S, S), (), '', (True,)),
+        ('relu6', (S, S, S), (True), 'inplace'),
+        ('elu', (S, S, S), (0.9,),),
+        ('elu', (S, S, S), (0.9, True), 'inplace'),
+        ('selu', (S, S, S), (),),
+        ('selu', (S, S, S), (True), 'inplace'),
+        ('celu', (S, S, S), (0.9,),),
+        ('celu', (S, S, S), (0.9, True), 'inplace'),
+        ('leaky_relu', (S, S, S), (0.02,), '', (True,)),
+        ('leaky_relu', (S, S, S), (0.02,), 'inplace'),
+        ('rrelu', (S, S), (0.1, 0.3, False),),
+        ('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'),
+        ('hardshrink', (S, S, S), (0.4,), '', (True,)),
+        ('tanhshrink', (S, S, S), (),),
+        ('softsign', (S, S, S), (),),
+        ('softplus', (S, S, S), (), '', (True,)),
+        ('softmin', (S, S, S), (0,),),
+        ('softmax', (S, S, S), (0,), '', (True,)),
+        ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)),
+        ('tanh', (S, S, S), (), '', (True,)),
+        ('sigmoid', (S, S, S), (), '', (True,)),
+        ('silu', (S, S, S), (), '', (True,)),
+        ('log_softmax', (S, S, S), (0,), '', (True,)),
+        ('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])),
+        ('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])),
+        ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
+        ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
+        ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
+        ('batch_norm', (S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ),
+            'training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (0, S, S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+             non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+            'size_zero', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (0, S, S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+             non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+            'size_zero_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+             non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+            'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                None, non_differentiable(torch.ones(S)), True, ),
+            'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                non_differentiable(torch.randn(S)), None, True, ),
+            'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                None, None, False, ),
+            'inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ),
+            'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                None, non_differentiable(torch.ones(S)), False, ),
+            'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                non_differentiable(torch.randn(S)), None, False, ),
+            'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
+        ('layer_norm', (S, S, S, S), ([5],), '',
+            (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+        ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
+            (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+        ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
+            (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+        ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
+                                      non_differentiable(torch.rand(S))), 'with_weight_and_bias',
+            (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])),
+        ('group_norm', (S, S, S), (1, torch.rand(5),),),
+        ('local_response_norm', (S, S, S), (2, ),),
+        ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',),
+        ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
+        ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
+        ('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
+        ('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),),
+        ('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),),
+        ('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('huber_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('huber_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('margin_ranking_loss', (S,), ((S,), (S,)),),
+        ('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
+        ('pixel_shuffle', (1, 9, 4, 4), (3,),),
+        ('pixel_unshuffle', (1, 1, 12, 12), (3,),),
+        ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
+        ('pad', (3, 3, 4, 2), ([1, 1],),),
+        ('pairwise_distance', (S, S), ((S, S),),),
+        ('pdist', (S, S), (),),
+        ('cosine_similarity', (S, S), ((S, S),),),
+        ('triplet_margin_loss', (S, S), ((S, S), (S, S)),),
+        ('normalize', (S, S, S), (),),
+        ('unfold', (S, S, S, S), ([2, 3]),),
+        ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
+        ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
+        ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
+        ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
+        ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
+        ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
+                                       1, 1., non_differentiable(torch.randn(S))),),
+        ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)),
+                                                               non_differentiable(torch.randn(3, 2))),),
+        ('binary_cross_entropy', torch.randn(3, 2).sigmoid(),
+            (non_differentiable(torch.rand(3, 2)),
+             non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'),
+        ('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
+            (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
+             torch.randint(1, S, (S,), dtype=torch.long))),
+        ('upsample', torch.randn(S, S, M, M), (None, 2.), 'with_scale'),
+        ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
+        ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
+        ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
+        ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
+        ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
+        ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
+        ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
+        ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
+        ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2, None, 'nearest', None, False),
+            'nearest_4d_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (4, None, 'nearest', None, False),
+            'nearest_4d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bilinear', None, False),
+            'bilinear_4d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (4, None, 'bilinear', None, False),
+            'bilinear_4d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bicubic', None, False),
+            'bicubic_4d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (4, None, 'bicubic', None, False),
+            'bicubic_4d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (None, 2., 'nearest', None, False),
+            'nearest_3d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (4, None, 'nearest', None, False),
+            'nearest_3d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (None, 2., 'linear', None, False),
+            'linear_3d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (4, None, 'linear', None, False),
+            'linear_3d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'nearest', None, False),
+            'nearest_5d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'nearest', None, False),
+            'nearest_5d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'trilinear', None, False),
+            'trilinear_5d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False),
+            'trilinear_5d_with_size_not_recompute_scale_factor'),
+    ]
+    return nn_functional_tests

 script_template = '''
 def the_method({}):

@@ -523,45 +525,6 @@ def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name
     return script_fn, inputs


-# additional modules test
-# TODO: delete this list once we make all nn_tests work
-additional_module_tests = [
-    {
-        'module_name': 'Bilinear',
-        'constructor_args': (S, S, M),
-        'input_size': (S, S),
-        'extra_args': ((S, S),)
-    },
-    {
-        'module_name': 'RNNCell',
-        'constructor_args': (S, S),
-        'input_size': (S, S),
-    },
-    {
-        'module_name': 'LSTMCell',
-        'constructor_args': (S, S),
-        'input_size': (S, S),
-    },
-    {
-        'module_name': 'GRUCell',
-        'constructor_args': (S, S),
-        'input_size': (S, S),
-    },
-    {
-        'module_name': 'MultiheadAttention',
-        'constructor_args': (128, 8),
-        'input_size': (10, 8, 128),
-        'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
-        'slowTest': True
-    },
-    {
-        'module_name': 'Transformer',
-        'constructor_args': (1, 1, 1, 1, 2),
-        'input_size': (3, 1, 1),
-        'extra_args': (torch.randn(1, 1, 1),),
-        'slowTest': True
-    }
-]
-
 EXCLUDE_SCRIPT_MODULES = {
     'test_nn_AdaptiveAvgPool2d_tuple_none',

@@ -719,4 +682,44 @@ def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):


 def get_all_nn_module_tests():
-    return module_tests + new_module_tests + additional_module_tests
+    # additional modules test
+    # TODO: delete this list once we make all nn_tests work
+    additional_module_tests = [
+        {
+            'module_name': 'Bilinear',
+            'constructor_args': (S, S, M),
+            'input_size': (S, S),
+            'extra_args': ((S, S),)
+        },
+        {
+            'module_name': 'RNNCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'LSTMCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'GRUCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'MultiheadAttention',
+            'constructor_args': (128, 8),
+            'input_size': (10, 8, 128),
+            'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
+            'slowTest': True
+        },
+        {
+            'module_name': 'Transformer',
+            'constructor_args': (1, 1, 1, 1, 2),
+            'input_size': (3, 1, 1),
+            'extra_args': (torch.randn(1, 1, 1),),
+            'slowTest': True
+        }
+    ]
+
+    return module_tests + get_new_module_tests() + additional_module_tests