mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
No actual change, just remove variable contain Tensors from global scope (#143225)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143225 Approved by: https://github.com/ezyang
This commit is contained in:
parent
afa313e669
commit
792f1c47e9
|
|
@ -20,8 +20,8 @@ from torch.testing._internal.common_utils import (
|
||||||
from torch.testing._internal.jit_metaprogramming_utils import (
|
from torch.testing._internal.jit_metaprogramming_utils import (
|
||||||
get_all_nn_module_tests,
|
get_all_nn_module_tests,
|
||||||
get_nn_functional_compiled_fn_and_inputs,
|
get_nn_functional_compiled_fn_and_inputs,
|
||||||
|
get_nn_functional_tests,
|
||||||
get_nn_mod_test_name,
|
get_nn_mod_test_name,
|
||||||
nn_functional_tests,
|
|
||||||
try_get_nn_module_compiled_mod_and_inputs,
|
try_get_nn_module_compiled_mod_and_inputs,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.jit_utils import enable_profiling_mode, JitTestCase
|
from torch.testing._internal.jit_utils import enable_profiling_mode, JitTestCase
|
||||||
|
|
@ -70,7 +70,7 @@ class TestComplexity(JitTestCase):
|
||||||
def test_generated_functional_tests(self):
|
def test_generated_functional_tests(self):
|
||||||
with enable_profiling_mode():
|
with enable_profiling_mode():
|
||||||
stats = [("Name", "Ifs/Loops", "non-tensor ops")]
|
stats = [("Name", "Ifs/Loops", "non-tensor ops")]
|
||||||
for test in nn_functional_tests:
|
for test in get_nn_functional_tests():
|
||||||
test_name = test[0]
|
test_name = test[0]
|
||||||
|
|
||||||
fn, inputs = get_nn_functional_compiled_fn_and_inputs(*test)
|
fn, inputs = get_nn_functional_compiled_fn_and_inputs(*test)
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ if not common.IS_ARM64:
|
||||||
(sample_module.module_tests, common_nn.NewModuleTest),
|
(sample_module.module_tests, common_nn.NewModuleTest),
|
||||||
(sample_functional.functional_tests, common_nn.NewModuleTest),
|
(sample_functional.functional_tests, common_nn.NewModuleTest),
|
||||||
(common_nn.module_tests, common_nn.NewModuleTest),
|
(common_nn.module_tests, common_nn.NewModuleTest),
|
||||||
(common_nn.new_module_tests, common_nn.NewModuleTest),
|
(common_nn.get_new_module_tests(), common_nn.NewModuleTest),
|
||||||
(common_nn.criterion_tests, common_nn.CriterionTest),
|
(common_nn.criterion_tests, common_nn.CriterionTest),
|
||||||
]:
|
]:
|
||||||
for test_params_dict in test_params_dicts:
|
for test_params_dict in test_params_dicts:
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,11 @@ from torch.testing._internal.common_device_type import (
|
||||||
)
|
)
|
||||||
from torch.testing._internal.common_methods_invocations import op_db, SampleInput
|
from torch.testing._internal.common_methods_invocations import op_db, SampleInput
|
||||||
from torch.testing._internal.common_modules import module_db, modules
|
from torch.testing._internal.common_modules import module_db, modules
|
||||||
from torch.testing._internal.common_nn import module_tests, new_module_tests, TestBase
|
from torch.testing._internal.common_nn import (
|
||||||
|
get_new_module_tests,
|
||||||
|
module_tests,
|
||||||
|
TestBase,
|
||||||
|
)
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
freeze_rng_state,
|
freeze_rng_state,
|
||||||
make_tensor,
|
make_tensor,
|
||||||
|
|
@ -1011,7 +1015,7 @@ def filter_supported_tests(t):
|
||||||
# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
|
# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
|
||||||
# These currently use the legacy nn tests
|
# These currently use the legacy nn tests
|
||||||
supported_tests = [
|
supported_tests = [
|
||||||
t for t in module_tests + new_module_tests if filter_supported_tests(t)
|
t for t in module_tests + get_new_module_tests() if filter_supported_tests(t)
|
||||||
]
|
]
|
||||||
for test_param in supported_tests:
|
for test_param in supported_tests:
|
||||||
if "constructor" not in test_param:
|
if "constructor" not in test_param:
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ from torch.testing._internal.common_device_type import (
|
||||||
ops,
|
ops,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.common_methods_invocations import op_db
|
from torch.testing._internal.common_methods_invocations import op_db
|
||||||
from torch.testing._internal.common_nn import module_tests, new_module_tests
|
from torch.testing._internal.common_nn import module_tests, get_new_module_tests
|
||||||
from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
|
from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
|
||||||
from torch.testing._internal.jit_utils import JitTestCase
|
from torch.testing._internal.jit_utils import JitTestCase
|
||||||
import torch.utils._pytree as pytree
|
import torch.utils._pytree as pytree
|
||||||
|
|
@ -1006,7 +1006,7 @@ terrible spacing
|
||||||
Exhaustively test `Node.normalized_arguments` on all standard
|
Exhaustively test `Node.normalized_arguments` on all standard
|
||||||
torch.nn Module classes
|
torch.nn Module classes
|
||||||
"""
|
"""
|
||||||
for test_params in module_tests + new_module_tests:
|
for test_params in module_tests + get_new_module_tests():
|
||||||
if "constructor" not in test_params:
|
if "constructor" not in test_params:
|
||||||
constructor = getattr(torch.nn, test_params["module_name"])
|
constructor = getattr(torch.nn, test_params["module_name"])
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -107,10 +107,10 @@ from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, dis
|
||||||
from torch.testing._internal.jit_metaprogramming_utils import (
|
from torch.testing._internal.jit_metaprogramming_utils import (
|
||||||
get_script_args,
|
get_script_args,
|
||||||
create_input, unpack_variables,
|
create_input, unpack_variables,
|
||||||
additional_module_tests, EXCLUDE_SCRIPT_MODULES,
|
get_all_nn_module_tests, EXCLUDE_SCRIPT_MODULES,
|
||||||
get_nn_module_name_from_kwargs, get_nn_mod_test_name, script_method_template)
|
get_nn_module_name_from_kwargs, get_nn_mod_test_name, script_method_template)
|
||||||
|
|
||||||
from torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests
|
from torch.testing._internal.common_nn import criterion_tests
|
||||||
|
|
||||||
# For testing truediv in python 2
|
# For testing truediv in python 2
|
||||||
from torch.testing._internal.test_module.future_div import div_int_future, div_float_future
|
from torch.testing._internal.test_module.future_div import div_int_future, div_float_future
|
||||||
|
|
@ -16247,7 +16247,7 @@ class TestProducerVersion(TestCase):
|
||||||
# issue gh-32561
|
# issue gh-32561
|
||||||
self.assertTrue(torch.__version__.startswith(torch.onnx.producer_version))
|
self.assertTrue(torch.__version__.startswith(torch.onnx.producer_version))
|
||||||
|
|
||||||
for test in module_tests + new_module_tests + additional_module_tests:
|
for test in get_all_nn_module_tests():
|
||||||
add_nn_module_test(**test)
|
add_nn_module_test(**test)
|
||||||
|
|
||||||
for test in criterion_tests:
|
for test in criterion_tests:
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ from torch.testing._internal.common_utils import freeze_rng_state, run_tests, Te
|
||||||
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION
|
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION
|
||||||
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
|
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
|
||||||
module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
|
module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
|
||||||
ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
|
ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
|
||||||
from torch.testing._internal.common_device_type import dtypesIfMPS, instantiate_device_type_tests, dtypes, \
|
from torch.testing._internal.common_device_type import dtypesIfMPS, instantiate_device_type_tests, dtypes, \
|
||||||
dtypesIfCUDA, precisionOverride, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
|
dtypesIfCUDA, precisionOverride, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
|
||||||
skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \
|
skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \
|
||||||
|
|
@ -7332,7 +7332,7 @@ def add_test(test, decorator=None):
|
||||||
else:
|
else:
|
||||||
add(cuda_test_name, with_tf32_off)
|
add(cuda_test_name, with_tf32_off)
|
||||||
|
|
||||||
for test_params in module_tests + new_module_tests:
|
for test_params in module_tests + get_new_module_tests():
|
||||||
# TODO: CUDA is not implemented yet
|
# TODO: CUDA is not implemented yet
|
||||||
if 'constructor' not in test_params:
|
if 'constructor' not in test_params:
|
||||||
name = test_params.pop('module_name')
|
name = test_params.pop('module_name')
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,6 @@ from torch.fx import Graph, GraphModule, Node
|
||||||
from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns
|
from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns
|
||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
_conv1d_bn_example_inputs,
|
|
||||||
_conv2d_bn_example_inputs,
|
|
||||||
_get_aten_graph_module_for_pattern,
|
_get_aten_graph_module_for_pattern,
|
||||||
_is_bn_node,
|
_is_bn_node,
|
||||||
_is_conv_or_conv_transpose_node,
|
_is_conv_or_conv_transpose_node,
|
||||||
|
|
@ -35,27 +33,6 @@ if TYPE_CHECKING:
|
||||||
__all__ = [] # type: ignore[var-annotated]
|
__all__ = [] # type: ignore[var-annotated]
|
||||||
|
|
||||||
|
|
||||||
# Example inputs for quantized and folded conv-bn1d patterns used in convert
|
|
||||||
_quantized_conv1d_bn_example_inputs = (
|
|
||||||
torch.randn(1, 1, 3), # x
|
|
||||||
torch.randn(1, 1, 1), # conv_weight
|
|
||||||
torch.randn(1), # bn_weight
|
|
||||||
torch.randn(1), # bn_bias
|
|
||||||
torch.randn(1), # bn_running_mean
|
|
||||||
torch.randn(1), # bn_running_var
|
|
||||||
)
|
|
||||||
|
|
||||||
# Example inputs for quantized and folded conv-bn2d patterns used in convert
|
|
||||||
_quantized_conv2d_bn_example_inputs = (
|
|
||||||
torch.randn(1, 1, 3, 3), # x
|
|
||||||
torch.randn(1, 1, 1, 1), # conv_weight
|
|
||||||
torch.randn(1), # bn_weight
|
|
||||||
torch.randn(1), # bn_bias
|
|
||||||
torch.randn(1), # bn_running_mean
|
|
||||||
torch.randn(1), # bn_running_var
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_quantized_conv_bn_example_inputs_kwargs(
|
def _get_quantized_conv_bn_example_inputs_kwargs(
|
||||||
is_per_channel: bool,
|
is_per_channel: bool,
|
||||||
has_bias: bool,
|
has_bias: bool,
|
||||||
|
|
@ -631,6 +608,28 @@ def _update_special_qspecs_after_replacement(
|
||||||
|
|
||||||
|
|
||||||
def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
|
def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
|
||||||
|
# Example inputs for conv-bn1d patterns
|
||||||
|
_conv1d_bn_example_inputs = (
|
||||||
|
torch.randn(1, 1, 3), # x
|
||||||
|
torch.randn(1, 1, 1), # conv_weight
|
||||||
|
torch.randn(1), # conv_bias
|
||||||
|
torch.randn(1), # bn_weight
|
||||||
|
torch.randn(1), # bn_bias
|
||||||
|
torch.randn(1), # bn_running_mean
|
||||||
|
torch.randn(1), # bn_running_var
|
||||||
|
)
|
||||||
|
|
||||||
|
# Example inputs for conv-bn2d patterns
|
||||||
|
_conv2d_bn_example_inputs = (
|
||||||
|
torch.randn(1, 1, 3, 3), # x
|
||||||
|
torch.randn(1, 1, 1, 1), # conv_weight
|
||||||
|
torch.randn(1), # conv_bias
|
||||||
|
torch.randn(1), # bn_weight
|
||||||
|
torch.randn(1), # bn_bias
|
||||||
|
torch.randn(1), # bn_running_mean
|
||||||
|
torch.randn(1), # bn_running_var
|
||||||
|
)
|
||||||
|
|
||||||
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
|
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
|
||||||
if not has_bn:
|
if not has_bn:
|
||||||
return m
|
return m
|
||||||
|
|
@ -859,6 +858,26 @@ def _copy_over_q_dq_args(original_node: Node, replacement_node: Node):
|
||||||
|
|
||||||
|
|
||||||
def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
|
def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
|
||||||
|
# Example inputs for quantized and folded conv-bn1d patterns used in convert
|
||||||
|
_quantized_conv1d_bn_example_inputs = (
|
||||||
|
torch.randn(1, 1, 3), # x
|
||||||
|
torch.randn(1, 1, 1), # conv_weight
|
||||||
|
torch.randn(1), # bn_weight
|
||||||
|
torch.randn(1), # bn_bias
|
||||||
|
torch.randn(1), # bn_running_mean
|
||||||
|
torch.randn(1), # bn_running_var
|
||||||
|
)
|
||||||
|
|
||||||
|
# Example inputs for quantized and folded conv-bn2d patterns used in convert
|
||||||
|
_quantized_conv2d_bn_example_inputs = (
|
||||||
|
torch.randn(1, 1, 3, 3), # x
|
||||||
|
torch.randn(1, 1, 1, 1), # conv_weight
|
||||||
|
torch.randn(1), # bn_weight
|
||||||
|
torch.randn(1), # bn_bias
|
||||||
|
torch.randn(1), # bn_running_mean
|
||||||
|
torch.randn(1), # bn_running_var
|
||||||
|
)
|
||||||
|
|
||||||
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
|
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
|
||||||
if not has_bn:
|
if not has_bn:
|
||||||
return m
|
return m
|
||||||
|
|
|
||||||
|
|
@ -22,25 +22,6 @@ __all__ = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
|
|
||||||
torch.randint(-128, 127, (2, 5), dtype=torch.int8),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.zeros(1, dtype=torch.int),
|
|
||||||
torch.tensor([-128], dtype=torch.int),
|
|
||||||
torch.tensor([127], dtype=torch.int),
|
|
||||||
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.zeros(1, dtype=torch.int),
|
|
||||||
torch.tensor([-127], dtype=torch.int),
|
|
||||||
torch.tensor([127], dtype=torch.int),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.zeros(1, dtype=torch.int),
|
|
||||||
torch.tensor([-128], dtype=torch.int),
|
|
||||||
torch.tensor([127], dtype=torch.int),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _qdq_quantized_linear(
|
def _qdq_quantized_linear(
|
||||||
x_i8,
|
x_i8,
|
||||||
x_scale,
|
x_scale,
|
||||||
|
|
@ -129,20 +110,6 @@ def _reference_quantized_linear(
|
||||||
return out_i8
|
return out_i8
|
||||||
|
|
||||||
|
|
||||||
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
|
|
||||||
torch.randn((2, 5), dtype=torch.float),
|
|
||||||
-128,
|
|
||||||
127,
|
|
||||||
torch.finfo(torch.float32).eps,
|
|
||||||
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.zeros(1, dtype=torch.int),
|
|
||||||
torch.tensor([-127], dtype=torch.int),
|
|
||||||
torch.tensor([127], dtype=torch.int),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _qdq_dynamic_quantized_linear(
|
def _qdq_dynamic_quantized_linear(
|
||||||
x_fp32,
|
x_fp32,
|
||||||
x_quant_min,
|
x_quant_min,
|
||||||
|
|
@ -223,25 +190,6 @@ def _reference_dynamic_quantized_linear(
|
||||||
return out_fp32
|
return out_fp32
|
||||||
|
|
||||||
|
|
||||||
_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
|
|
||||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.zeros(1, dtype=torch.int),
|
|
||||||
torch.tensor([-128], dtype=torch.int),
|
|
||||||
torch.tensor([127], dtype=torch.int),
|
|
||||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.zeros(1, dtype=torch.int),
|
|
||||||
torch.tensor([-127], dtype=torch.int),
|
|
||||||
torch.tensor([127], dtype=torch.int),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.zeros(1, dtype=torch.int),
|
|
||||||
torch.tensor([-128], dtype=torch.int),
|
|
||||||
torch.tensor([127], dtype=torch.int),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _qdq_quantized_conv2d(
|
def _qdq_quantized_conv2d(
|
||||||
x_i8,
|
x_i8,
|
||||||
x_scale,
|
x_scale,
|
||||||
|
|
@ -375,20 +323,6 @@ def _reference_quantized_conv2d(
|
||||||
return out_i8
|
return out_i8
|
||||||
|
|
||||||
|
|
||||||
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
|
|
||||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.zeros(1, dtype=torch.int),
|
|
||||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.zeros(1, dtype=torch.int),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.zeros(1, dtype=torch.int),
|
|
||||||
torch.tensor([-128], dtype=torch.int),
|
|
||||||
torch.tensor([127], dtype=torch.int),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _qdq_quantized_add_relu(
|
def _qdq_quantized_add_relu(
|
||||||
x_i8,
|
x_i8,
|
||||||
x_scale,
|
x_scale,
|
||||||
|
|
@ -518,19 +452,6 @@ def _reference_quantized_add(
|
||||||
return out_i8
|
return out_i8
|
||||||
|
|
||||||
|
|
||||||
_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
|
|
||||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.zeros(1, dtype=torch.int),
|
|
||||||
torch.tensor([-128], dtype=torch.int),
|
|
||||||
torch.tensor([127], dtype=torch.int),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.zeros(1, dtype=torch.int),
|
|
||||||
torch.tensor([-128], dtype=torch.int),
|
|
||||||
torch.tensor([127], dtype=torch.int),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _qdq_quantized_max_pool2d(
|
def _qdq_quantized_max_pool2d(
|
||||||
x_i8,
|
x_i8,
|
||||||
x_scale,
|
x_scale,
|
||||||
|
|
@ -587,15 +508,6 @@ def _reference_quantized_max_pool2d(
|
||||||
return out_i8
|
return out_i8
|
||||||
|
|
||||||
|
|
||||||
_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
|
|
||||||
torch.randn(1, 3, 3, 3, dtype=torch.float),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.zeros(1, dtype=torch.int),
|
|
||||||
torch.tensor([-128], dtype=torch.int),
|
|
||||||
torch.tensor([127], dtype=torch.int),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
|
def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
|
||||||
x = torch.ops.quantized_decomposed.quantize_per_tensor(
|
x = torch.ops.quantized_decomposed.quantize_per_tensor(
|
||||||
x_fp32, scale, zero_point, quant_min, quant_max, torch.int8
|
x_fp32, scale, zero_point, quant_min, quant_max, torch.int8
|
||||||
|
|
@ -619,15 +531,6 @@ def _reference_quantize_per_tensor_int8(
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
|
|
||||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
|
||||||
torch.randn(1, dtype=torch.float),
|
|
||||||
torch.zeros(1, dtype=torch.int),
|
|
||||||
torch.tensor([-128], dtype=torch.int),
|
|
||||||
torch.tensor([127], dtype=torch.int),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
|
def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
|
||||||
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||||
x_i8, scale, zero_point, quant_min, quant_max, torch.int8
|
x_i8, scale, zero_point, quant_min, quant_max, torch.int8
|
||||||
|
|
@ -648,16 +551,6 @@ def _reference_dequantize_per_tensor_int8(
|
||||||
return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
|
return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
|
|
||||||
torch.randn(1, 3, 3, 3, dtype=torch.float),
|
|
||||||
torch.randn(3, dtype=torch.float),
|
|
||||||
torch.zeros(3, dtype=torch.int),
|
|
||||||
1,
|
|
||||||
-128,
|
|
||||||
127,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _quantize_per_channel_int8(
|
def _quantize_per_channel_int8(
|
||||||
x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
|
x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
|
||||||
):
|
):
|
||||||
|
|
@ -678,16 +571,6 @@ def _reference_quantize_per_channel_int8(
|
||||||
return out_i32.to(torch.int8)
|
return out_i32.to(torch.int8)
|
||||||
|
|
||||||
|
|
||||||
_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
|
|
||||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
|
||||||
torch.randn(3, dtype=torch.float),
|
|
||||||
torch.zeros(3, dtype=torch.int),
|
|
||||||
1,
|
|
||||||
-128,
|
|
||||||
127,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _dequantize_per_channel_int8(
|
def _dequantize_per_channel_int8(
|
||||||
x_i8, scales, zero_points, ch_axis, quant_min, quant_max
|
x_i8, scales, zero_points, ch_axis, quant_min, quant_max
|
||||||
):
|
):
|
||||||
|
|
@ -733,79 +616,186 @@ class _RewriteInfo:
|
||||||
replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
|
replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
|
||||||
|
|
||||||
|
|
||||||
_REWRITE_INFO_LIST = [
|
|
||||||
_RewriteInfo(
|
|
||||||
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
|
|
||||||
_WrapperModule(_qdq_dynamic_quantized_linear),
|
|
||||||
_WrapperModule(_reference_dynamic_quantized_linear),
|
|
||||||
partial(
|
|
||||||
_replace_literals_with_existing_placeholders,
|
|
||||||
literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
|
|
||||||
),
|
|
||||||
partial(
|
|
||||||
_replace_literals_with_existing_placeholders,
|
|
||||||
literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
|
|
||||||
),
|
|
||||||
),
|
|
||||||
_RewriteInfo(
|
|
||||||
_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
|
|
||||||
_WrapperModule(_qdq_quantized_linear),
|
|
||||||
_WrapperModule(_reference_quantized_linear),
|
|
||||||
_replace_literals_with_new_placeholders,
|
|
||||||
_replace_literals_with_new_placeholders,
|
|
||||||
),
|
|
||||||
_RewriteInfo(
|
|
||||||
_QUANTIZED_CONV2d_EXAMPLE_INPUTS,
|
|
||||||
_WrapperModule(_qdq_quantized_conv2d),
|
|
||||||
_WrapperModule(_reference_quantized_conv2d),
|
|
||||||
partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
|
|
||||||
partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
|
|
||||||
),
|
|
||||||
_RewriteInfo(
|
|
||||||
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
|
|
||||||
_WrapperModule(_qdq_quantized_add_relu),
|
|
||||||
_WrapperModule(_reference_quantized_add_relu),
|
|
||||||
),
|
|
||||||
_RewriteInfo(
|
|
||||||
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
|
|
||||||
_WrapperModule(_qdq_quantized_add),
|
|
||||||
_WrapperModule(_reference_quantized_add),
|
|
||||||
),
|
|
||||||
_RewriteInfo(
|
|
||||||
_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS,
|
|
||||||
_WrapperModule(_qdq_quantized_max_pool2d),
|
|
||||||
_WrapperModule(_reference_quantized_max_pool2d),
|
|
||||||
_replace_literals_with_new_placeholders,
|
|
||||||
_replace_literals_with_new_placeholders,
|
|
||||||
),
|
|
||||||
_RewriteInfo(
|
|
||||||
_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
|
|
||||||
_WrapperModule(_quantize_per_tensor_int8),
|
|
||||||
_WrapperModule(_reference_quantize_per_tensor_int8),
|
|
||||||
),
|
|
||||||
_RewriteInfo(
|
|
||||||
_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
|
|
||||||
_WrapperModule(_dequantize_per_tensor_int8),
|
|
||||||
_WrapperModule(_reference_dequantize_per_tensor_int8),
|
|
||||||
),
|
|
||||||
_RewriteInfo(
|
|
||||||
_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
|
|
||||||
_WrapperModule(_quantize_per_channel_int8),
|
|
||||||
_WrapperModule(_reference_quantize_per_channel_int8),
|
|
||||||
_replace_ph_qdq_per_channel_replacement,
|
|
||||||
_replace_ph_qdq_per_channel_replacement,
|
|
||||||
),
|
|
||||||
_RewriteInfo(
|
|
||||||
_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
|
|
||||||
_WrapperModule(_dequantize_per_channel_int8),
|
|
||||||
_WrapperModule(_reference_dequantize_per_channel_int8),
|
|
||||||
_replace_ph_qdq_per_channel_replacement,
|
|
||||||
_replace_ph_qdq_per_channel_replacement,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
|
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
|
||||||
|
_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
|
||||||
|
torch.randint(-128, 127, (2, 5), dtype=torch.int8),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.zeros(1, dtype=torch.int),
|
||||||
|
torch.tensor([-128], dtype=torch.int),
|
||||||
|
torch.tensor([127], dtype=torch.int),
|
||||||
|
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.zeros(1, dtype=torch.int),
|
||||||
|
torch.tensor([-127], dtype=torch.int),
|
||||||
|
torch.tensor([127], dtype=torch.int),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.zeros(1, dtype=torch.int),
|
||||||
|
torch.tensor([-128], dtype=torch.int),
|
||||||
|
torch.tensor([127], dtype=torch.int),
|
||||||
|
)
|
||||||
|
|
||||||
|
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
|
||||||
|
torch.randn((2, 5), dtype=torch.float),
|
||||||
|
-128,
|
||||||
|
127,
|
||||||
|
torch.finfo(torch.float32).eps,
|
||||||
|
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.zeros(1, dtype=torch.int),
|
||||||
|
torch.tensor([-127], dtype=torch.int),
|
||||||
|
torch.tensor([127], dtype=torch.int),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
)
|
||||||
|
|
||||||
|
_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
|
||||||
|
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.zeros(1, dtype=torch.int),
|
||||||
|
torch.tensor([-128], dtype=torch.int),
|
||||||
|
torch.tensor([127], dtype=torch.int),
|
||||||
|
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.zeros(1, dtype=torch.int),
|
||||||
|
torch.tensor([-127], dtype=torch.int),
|
||||||
|
torch.tensor([127], dtype=torch.int),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.zeros(1, dtype=torch.int),
|
||||||
|
torch.tensor([-128], dtype=torch.int),
|
||||||
|
torch.tensor([127], dtype=torch.int),
|
||||||
|
)
|
||||||
|
|
||||||
|
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
|
||||||
|
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.zeros(1, dtype=torch.int),
|
||||||
|
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.zeros(1, dtype=torch.int),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.zeros(1, dtype=torch.int),
|
||||||
|
torch.tensor([-128], dtype=torch.int),
|
||||||
|
torch.tensor([127], dtype=torch.int),
|
||||||
|
)
|
||||||
|
|
||||||
|
_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
|
||||||
|
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.zeros(1, dtype=torch.int),
|
||||||
|
torch.tensor([-128], dtype=torch.int),
|
||||||
|
torch.tensor([127], dtype=torch.int),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.zeros(1, dtype=torch.int),
|
||||||
|
torch.tensor([-128], dtype=torch.int),
|
||||||
|
torch.tensor([127], dtype=torch.int),
|
||||||
|
)
|
||||||
|
|
||||||
|
_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
|
||||||
|
torch.randn(1, 3, 3, 3, dtype=torch.float),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.zeros(1, dtype=torch.int),
|
||||||
|
torch.tensor([-128], dtype=torch.int),
|
||||||
|
torch.tensor([127], dtype=torch.int),
|
||||||
|
)
|
||||||
|
|
||||||
|
_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
|
||||||
|
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||||
|
torch.randn(1, dtype=torch.float),
|
||||||
|
torch.zeros(1, dtype=torch.int),
|
||||||
|
torch.tensor([-128], dtype=torch.int),
|
||||||
|
torch.tensor([127], dtype=torch.int),
|
||||||
|
)
|
||||||
|
|
||||||
|
_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
|
||||||
|
torch.randn(1, 3, 3, 3, dtype=torch.float),
|
||||||
|
torch.randn(3, dtype=torch.float),
|
||||||
|
torch.zeros(3, dtype=torch.int),
|
||||||
|
1,
|
||||||
|
-128,
|
||||||
|
127,
|
||||||
|
)
|
||||||
|
|
||||||
|
_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
|
||||||
|
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||||
|
torch.randn(3, dtype=torch.float),
|
||||||
|
torch.zeros(3, dtype=torch.int),
|
||||||
|
1,
|
||||||
|
-128,
|
||||||
|
127,
|
||||||
|
)
|
||||||
|
|
||||||
|
_REWRITE_INFO_LIST = [
|
||||||
|
_RewriteInfo(
|
||||||
|
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
|
||||||
|
_WrapperModule(_qdq_dynamic_quantized_linear),
|
||||||
|
_WrapperModule(_reference_dynamic_quantized_linear),
|
||||||
|
partial(
|
||||||
|
_replace_literals_with_existing_placeholders,
|
||||||
|
literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
|
||||||
|
),
|
||||||
|
partial(
|
||||||
|
_replace_literals_with_existing_placeholders,
|
||||||
|
literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
_RewriteInfo(
|
||||||
|
_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
|
||||||
|
_WrapperModule(_qdq_quantized_linear),
|
||||||
|
_WrapperModule(_reference_quantized_linear),
|
||||||
|
_replace_literals_with_new_placeholders,
|
||||||
|
_replace_literals_with_new_placeholders,
|
||||||
|
),
|
||||||
|
_RewriteInfo(
|
||||||
|
_QUANTIZED_CONV2d_EXAMPLE_INPUTS,
|
||||||
|
_WrapperModule(_qdq_quantized_conv2d),
|
||||||
|
_WrapperModule(_reference_quantized_conv2d),
|
||||||
|
partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
|
||||||
|
partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
|
||||||
|
),
|
||||||
|
_RewriteInfo(
|
||||||
|
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
|
||||||
|
_WrapperModule(_qdq_quantized_add_relu),
|
||||||
|
_WrapperModule(_reference_quantized_add_relu),
|
||||||
|
),
|
||||||
|
_RewriteInfo(
|
||||||
|
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
|
||||||
|
_WrapperModule(_qdq_quantized_add),
|
||||||
|
_WrapperModule(_reference_quantized_add),
|
||||||
|
),
|
||||||
|
_RewriteInfo(
|
||||||
|
_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS,
|
||||||
|
_WrapperModule(_qdq_quantized_max_pool2d),
|
||||||
|
_WrapperModule(_reference_quantized_max_pool2d),
|
||||||
|
_replace_literals_with_new_placeholders,
|
||||||
|
_replace_literals_with_new_placeholders,
|
||||||
|
),
|
||||||
|
_RewriteInfo(
|
||||||
|
_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
|
||||||
|
_WrapperModule(_quantize_per_tensor_int8),
|
||||||
|
_WrapperModule(_reference_quantize_per_tensor_int8),
|
||||||
|
),
|
||||||
|
_RewriteInfo(
|
||||||
|
_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
|
||||||
|
_WrapperModule(_dequantize_per_tensor_int8),
|
||||||
|
_WrapperModule(_reference_dequantize_per_tensor_int8),
|
||||||
|
),
|
||||||
|
_RewriteInfo(
|
||||||
|
_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
|
||||||
|
_WrapperModule(_quantize_per_channel_int8),
|
||||||
|
_WrapperModule(_reference_quantize_per_channel_int8),
|
||||||
|
_replace_ph_qdq_per_channel_replacement,
|
||||||
|
_replace_ph_qdq_per_channel_replacement,
|
||||||
|
),
|
||||||
|
_RewriteInfo(
|
||||||
|
_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
|
||||||
|
_WrapperModule(_dequantize_per_channel_int8),
|
||||||
|
_WrapperModule(_reference_dequantize_per_channel_int8),
|
||||||
|
_replace_ph_qdq_per_channel_replacement,
|
||||||
|
_replace_ph_qdq_per_channel_replacement,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
remove_tensor_overload_for_qdq_ops(model)
|
remove_tensor_overload_for_qdq_ops(model)
|
||||||
from torch._export import gm_using_training_ir
|
from torch._export import gm_using_training_ir
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,28 +33,6 @@ _DEQUANTIZE_OPS = [
|
||||||
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
||||||
]
|
]
|
||||||
|
|
||||||
# Example inputs for conv-bn1d patterns
|
|
||||||
_conv1d_bn_example_inputs = (
|
|
||||||
torch.randn(1, 1, 3), # x
|
|
||||||
torch.randn(1, 1, 1), # conv_weight
|
|
||||||
torch.randn(1), # conv_bias
|
|
||||||
torch.randn(1), # bn_weight
|
|
||||||
torch.randn(1), # bn_bias
|
|
||||||
torch.randn(1), # bn_running_mean
|
|
||||||
torch.randn(1), # bn_running_var
|
|
||||||
)
|
|
||||||
|
|
||||||
# Example inputs for conv-bn2d patterns
|
|
||||||
_conv2d_bn_example_inputs = (
|
|
||||||
torch.randn(1, 1, 3, 3), # x
|
|
||||||
torch.randn(1, 1, 1, 1), # conv_weight
|
|
||||||
torch.randn(1), # conv_bias
|
|
||||||
torch.randn(1), # bn_weight
|
|
||||||
torch.randn(1), # bn_bias
|
|
||||||
torch.randn(1), # bn_running_mean
|
|
||||||
torch.randn(1), # bn_running_var
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
|
def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,6 @@ from torch._subclasses import FakeTensor
|
||||||
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
|
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
|
||||||
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
|
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
|
||||||
from torch.ao.quantization.pt2e.utils import (
|
from torch.ao.quantization.pt2e.utils import (
|
||||||
_conv1d_bn_example_inputs,
|
|
||||||
_conv2d_bn_example_inputs,
|
|
||||||
_get_aten_graph_module_for_pattern,
|
_get_aten_graph_module_for_pattern,
|
||||||
_is_conv_node,
|
_is_conv_node,
|
||||||
_is_conv_transpose_node,
|
_is_conv_transpose_node,
|
||||||
|
|
@ -487,6 +485,28 @@ def _do_annotate_conv_bn(
|
||||||
for the following names: "input", "conv", "weight", "bias", and "output".
|
for the following names: "input", "conv", "weight", "bias", and "output".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Example inputs for conv-bn1d patterns
|
||||||
|
_conv1d_bn_example_inputs = (
|
||||||
|
torch.randn(1, 1, 3), # x
|
||||||
|
torch.randn(1, 1, 1), # conv_weight
|
||||||
|
torch.randn(1), # conv_bias
|
||||||
|
torch.randn(1), # bn_weight
|
||||||
|
torch.randn(1), # bn_bias
|
||||||
|
torch.randn(1), # bn_running_mean
|
||||||
|
torch.randn(1), # bn_running_var
|
||||||
|
)
|
||||||
|
|
||||||
|
# Example inputs for conv-bn2d patterns
|
||||||
|
_conv2d_bn_example_inputs = (
|
||||||
|
torch.randn(1, 1, 3, 3), # x
|
||||||
|
torch.randn(1, 1, 1, 1), # conv_weight
|
||||||
|
torch.randn(1), # conv_bias
|
||||||
|
torch.randn(1), # bn_weight
|
||||||
|
torch.randn(1), # bn_bias
|
||||||
|
torch.randn(1), # bn_running_mean
|
||||||
|
torch.randn(1), # bn_running_var
|
||||||
|
)
|
||||||
|
|
||||||
def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
|
def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
|
||||||
def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
|
def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
|
||||||
conv = conv_fn(x, conv_weight, conv_bias)
|
conv = conv_fn(x, conv_weight, conv_bias)
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -8,7 +8,7 @@ import torch.cuda
|
||||||
import torch.jit
|
import torch.jit
|
||||||
import torch.jit._logging
|
import torch.jit._logging
|
||||||
import torch.jit.frontend
|
import torch.jit.frontend
|
||||||
from torch.testing._internal.common_nn import module_tests, new_module_tests
|
from torch.testing._internal.common_nn import module_tests, get_new_module_tests
|
||||||
from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like
|
from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
|
@ -95,226 +95,228 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
|
||||||
# fn mapping output to part that should be gradcheck'ed, // optional
|
# fn mapping output to part that should be gradcheck'ed, // optional
|
||||||
# kwargs for function, // optional
|
# kwargs for function, // optional
|
||||||
# )
|
# )
|
||||||
nn_functional_tests = [
|
def get_nn_functional_tests():
|
||||||
('conv1d', (S, S, S), ((S, S, S),)),
|
nn_functional_tests = [
|
||||||
('conv2d', (S, S, S, S), ((S, S, S, S),)),
|
('conv1d', (S, S, S), ((S, S, S),)),
|
||||||
('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
|
('conv2d', (S, S, S, S), ((S, S, S, S),)),
|
||||||
('conv_transpose1d', (S, S, S), ((S, S, S),)),
|
('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
|
||||||
('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)),
|
('conv_transpose1d', (S, S, S), ((S, S, S),)),
|
||||||
('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
|
('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)),
|
||||||
('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
|
('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
|
||||||
('avg_pool1d', (S, S, S), (3,)),
|
('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
|
||||||
('avg_pool2d', (S, S, S, S), (3,), '', (True,)),
|
('avg_pool1d', (S, S, S), (3,)),
|
||||||
('avg_pool3d', (S, S, S, S, S), (3,)),
|
('avg_pool2d', (S, S, S, S), (3,), '', (True,)),
|
||||||
('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
|
('avg_pool3d', (S, S, S, S, S), (3,)),
|
||||||
('max_pool1d', (S, S, S), (2, 1)),
|
('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
|
||||||
('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
|
('max_pool1d', (S, S, S), (2, 1)),
|
||||||
('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices')),
|
('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
|
||||||
('max_pool2d', (S, S, S, S), (2, 1, 1, 1, False, True), 'with_indices', (True, 'aten::max_pool2d_with_indices')),
|
('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices')),
|
||||||
('max_pool3d', (S, S, S, S, S), (2, 1)),
|
('max_pool2d', (S, S, S, S), (2, 1, 1, 1, False, True), 'with_indices', (True, 'aten::max_pool2d_with_indices')),
|
||||||
('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
|
('max_pool3d', (S, S, S, S, S), (2, 1)),
|
||||||
('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)),
|
('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
|
||||||
('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
|
('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)),
|
||||||
('lp_pool1d', (S, S, S), (2., 3, 2,)),
|
('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
|
||||||
('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
|
('lp_pool1d', (S, S, S), (2., 3, 2,)),
|
||||||
('lp_pool3d', (S, S, S, S, S), (2., 3, 2,)),
|
('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
|
||||||
('adaptive_max_pool1d', (S, S, S), (5,)),
|
('lp_pool3d', (S, S, S, S, S), (2., 3, 2,)),
|
||||||
('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
|
('adaptive_max_pool1d', (S, S, S), (5,)),
|
||||||
('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
|
('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
|
||||||
('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)),
|
('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
|
||||||
('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)),
|
('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)),
|
||||||
('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)),
|
('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)),
|
||||||
('dropout', (S, S, S), (0.5,), '', (True, 'aten::native_dropout')),
|
('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)),
|
||||||
('alpha_dropout', (S, S, S), (0.5,)),
|
('dropout', (S, S, S), (0.5,), '', (True, 'aten::native_dropout')),
|
||||||
('dropout2d', (S, S, S), (0.5,)),
|
('alpha_dropout', (S, S, S), (0.5,)),
|
||||||
('dropout2d', (S, S, S, S), (0.5,), 'batched'),
|
('dropout2d', (S, S, S), (0.5,)),
|
||||||
('dropout3d', (S, S, S, S), (0.5,)),
|
('dropout2d', (S, S, S, S), (0.5,), 'batched'),
|
||||||
('dropout3d', (S, S, S, S, S), (0.5,), 'batched'),
|
('dropout3d', (S, S, S, S), (0.5,)),
|
||||||
('feature_alpha_dropout', (S, S, S), (0.5,)),
|
('dropout3d', (S, S, S, S, S), (0.5,), 'batched'),
|
||||||
('threshold', (S, S, S), (0.1, 2.), '', (True,)),
|
('feature_alpha_dropout', (S, S, S), (0.5,)),
|
||||||
('threshold', (S, S, S), (0.1, 2., True), 'inplace'),
|
('threshold', (S, S, S), (0.1, 2.), '', (True,)),
|
||||||
('relu', (S, S, S), (), '', (True,)),
|
('threshold', (S, S, S), (0.1, 2., True), 'inplace'),
|
||||||
('relu', (S, S, S), (), 'inplace'),
|
('relu', (S, S, S), (), '', (True,)),
|
||||||
('glu', (S - 1, S - 1, S - 1), (),),
|
('relu', (S, S, S), (), 'inplace'),
|
||||||
('hardtanh', (S, S, S), (-0.5, 0.5), '', (True,)),
|
('glu', (S - 1, S - 1, S - 1), (),),
|
||||||
('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'),
|
('hardtanh', (S, S, S), (-0.5, 0.5), '', (True,)),
|
||||||
('relu6', (S, S, S), (), '', (True,)),
|
('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'),
|
||||||
('relu6', (S, S, S), (True), 'inplace'),
|
('relu6', (S, S, S), (), '', (True,)),
|
||||||
('elu', (S, S, S), (0.9,),),
|
('relu6', (S, S, S), (True), 'inplace'),
|
||||||
('elu', (S, S, S), (0.9, True), 'inplace'),
|
('elu', (S, S, S), (0.9,),),
|
||||||
('selu', (S, S, S), (),),
|
('elu', (S, S, S), (0.9, True), 'inplace'),
|
||||||
('selu', (S, S, S), (True), 'inplace'),
|
('selu', (S, S, S), (),),
|
||||||
('celu', (S, S, S), (0.9,),),
|
('selu', (S, S, S), (True), 'inplace'),
|
||||||
('celu', (S, S, S), (0.9, True), 'inplace'),
|
('celu', (S, S, S), (0.9,),),
|
||||||
('leaky_relu', (S, S, S), (0.02,), '', (True,)),
|
('celu', (S, S, S), (0.9, True), 'inplace'),
|
||||||
('leaky_relu', (S, S, S), (0.02,), 'inplace'),
|
('leaky_relu', (S, S, S), (0.02,), '', (True,)),
|
||||||
('rrelu', (S, S), (0.1, 0.3, False),),
|
('leaky_relu', (S, S, S), (0.02,), 'inplace'),
|
||||||
('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'),
|
('rrelu', (S, S), (0.1, 0.3, False),),
|
||||||
('hardshrink', (S, S, S), (0.4,), '', (True,)),
|
('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'),
|
||||||
('tanhshrink', (S, S, S), (),),
|
('hardshrink', (S, S, S), (0.4,), '', (True,)),
|
||||||
('softsign', (S, S, S), (),),
|
('tanhshrink', (S, S, S), (),),
|
||||||
('softplus', (S, S, S), (), '', (True,)),
|
('softsign', (S, S, S), (),),
|
||||||
('softmin', (S, S, S), (0,),),
|
('softplus', (S, S, S), (), '', (True,)),
|
||||||
('softmax', (S, S, S), (0,), '', (True,)),
|
('softmin', (S, S, S), (0,),),
|
||||||
('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)),
|
('softmax', (S, S, S), (0,), '', (True,)),
|
||||||
('tanh', (S, S, S), (), '', (True,)),
|
('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)),
|
||||||
('sigmoid', (S, S, S), (), '', (True,)),
|
('tanh', (S, S, S), (), '', (True,)),
|
||||||
('silu', (S, S, S), (), '', (True,)),
|
('sigmoid', (S, S, S), (), '', (True,)),
|
||||||
('log_softmax', (S, S, S), (0,), '', (True,)),
|
('silu', (S, S, S), (), '', (True,)),
|
||||||
('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])),
|
('log_softmax', (S, S, S), (0,), '', (True,)),
|
||||||
('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])),
|
('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])),
|
||||||
('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
|
('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])),
|
||||||
('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
|
('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
|
||||||
('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
|
('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
|
||||||
('batch_norm', (S, S),
|
('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
|
||||||
(non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ),
|
('batch_norm', (S, S),
|
||||||
'training', (True, 'aten::_batch_norm_impl_index')),
|
(non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ),
|
||||||
('batch_norm', (0, S, S, S),
|
'training', (True, 'aten::_batch_norm_impl_index')),
|
||||||
(non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
('batch_norm', (0, S, S, S),
|
||||||
non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
|
(non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
||||||
'size_zero', (True, 'aten::_batch_norm_impl_index')),
|
non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
|
||||||
('batch_norm', (0, S, S, S),
|
'size_zero', (True, 'aten::_batch_norm_impl_index')),
|
||||||
(non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
('batch_norm', (0, S, S, S),
|
||||||
non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
|
(non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
||||||
'size_zero_inference', (True, 'aten::_batch_norm_impl_index')),
|
non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
|
||||||
('batch_norm', (S, S),
|
'size_zero_inference', (True, 'aten::_batch_norm_impl_index')),
|
||||||
(non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
('batch_norm', (S, S),
|
||||||
non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
|
(non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
||||||
'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')),
|
non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
|
||||||
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')),
|
||||||
None, non_differentiable(torch.ones(S)), True, ),
|
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
||||||
'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')),
|
None, non_differentiable(torch.ones(S)), True, ),
|
||||||
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')),
|
||||||
non_differentiable(torch.randn(S)), None, True, ),
|
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
||||||
'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')),
|
non_differentiable(torch.randn(S)), None, True, ),
|
||||||
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')),
|
||||||
None, None, False, ),
|
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
||||||
'inference', (True, 'aten::_batch_norm_impl_index')),
|
None, None, False, ),
|
||||||
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
'inference', (True, 'aten::_batch_norm_impl_index')),
|
||||||
non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ),
|
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
||||||
'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')),
|
non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ),
|
||||||
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')),
|
||||||
None, non_differentiable(torch.ones(S)), False, ),
|
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
||||||
'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')),
|
None, non_differentiable(torch.ones(S)), False, ),
|
||||||
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')),
|
||||||
non_differentiable(torch.randn(S)), None, False, ),
|
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
|
||||||
'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')),
|
non_differentiable(torch.randn(S)), None, False, ),
|
||||||
('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
|
'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')),
|
||||||
('layer_norm', (S, S, S, S), ([5],), '',
|
('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
|
||||||
(False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
|
('layer_norm', (S, S, S, S), ([5],), '',
|
||||||
('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
|
(False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
|
||||||
(False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
|
('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
|
||||||
('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
|
(False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
|
||||||
(False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
|
('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
|
||||||
('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
|
(False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
|
||||||
non_differentiable(torch.rand(S))), 'with_weight_and_bias',
|
('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
|
||||||
(False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])),
|
non_differentiable(torch.rand(S))), 'with_weight_and_bias',
|
||||||
('group_norm', (S, S, S), (1, torch.rand(5),),),
|
(False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])),
|
||||||
('local_response_norm', (S, S, S), (2, ),),
|
('group_norm', (S, S, S), (1, torch.rand(5),),),
|
||||||
('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',),
|
('local_response_norm', (S, S, S), (2, ),),
|
||||||
('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
|
('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',),
|
||||||
('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
|
('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
|
||||||
('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
|
('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
|
||||||
('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),),
|
('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
|
||||||
('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),),
|
('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),),
|
||||||
('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),),
|
||||||
('huber_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
||||||
('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
('huber_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
||||||
('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
||||||
('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
|
('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
||||||
('huber_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
|
('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
|
||||||
('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
|
('huber_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
|
||||||
('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
|
('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
|
||||||
('margin_ranking_loss', (S,), ((S,), (S,)),),
|
('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
|
||||||
('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
('margin_ranking_loss', (S,), ((S,), (S,)),),
|
||||||
('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
||||||
('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
||||||
('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
|
('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
||||||
('pixel_shuffle', (1, 9, 4, 4), (3,),),
|
('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
|
||||||
('pixel_unshuffle', (1, 1, 12, 12), (3,),),
|
('pixel_shuffle', (1, 9, 4, 4), (3,),),
|
||||||
('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
|
('pixel_unshuffle', (1, 1, 12, 12), (3,),),
|
||||||
('pad', (3, 3, 4, 2), ([1, 1],),),
|
('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
|
||||||
('pairwise_distance', (S, S), ((S, S),),),
|
('pad', (3, 3, 4, 2), ([1, 1],),),
|
||||||
('pdist', (S, S), (),),
|
('pairwise_distance', (S, S), ((S, S),),),
|
||||||
('cosine_similarity', (S, S), ((S, S),),),
|
('pdist', (S, S), (),),
|
||||||
('triplet_margin_loss', (S, S), ((S, S), (S, S)),),
|
('cosine_similarity', (S, S), ((S, S),),),
|
||||||
('normalize', (S, S, S), (),),
|
('triplet_margin_loss', (S, S), ((S, S), (S, S)),),
|
||||||
('unfold', (S, S, S, S), ([2, 3]),),
|
('normalize', (S, S, S), (),),
|
||||||
('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
|
('unfold', (S, S, S, S), ([2, 3]),),
|
||||||
('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
|
('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
|
||||||
('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
|
('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
|
||||||
('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
|
('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
|
||||||
('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
|
('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
|
||||||
('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
|
('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
|
||||||
1, 1., non_differentiable(torch.randn(S))),),
|
('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
|
||||||
('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)),
|
1, 1., non_differentiable(torch.randn(S))),),
|
||||||
non_differentiable(torch.randn(3, 2))),),
|
('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)),
|
||||||
('binary_cross_entropy', torch.randn(3, 2).sigmoid(),
|
non_differentiable(torch.randn(3, 2))),),
|
||||||
(non_differentiable(torch.rand(3, 2)),
|
('binary_cross_entropy', torch.randn(3, 2).sigmoid(),
|
||||||
non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'),
|
(non_differentiable(torch.rand(3, 2)),
|
||||||
('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
|
non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'),
|
||||||
(torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
|
('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
|
||||||
torch.randint(1, S, (S,), dtype=torch.long))),
|
(torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
|
||||||
('upsample', torch.randn(S, S, M, M), (None, 2.), 'with_scale'),
|
torch.randint(1, S, (S,), dtype=torch.long))),
|
||||||
('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
|
('upsample', torch.randn(S, S, M, M), (None, 2.), 'with_scale'),
|
||||||
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
|
('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
|
||||||
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
|
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
|
||||||
('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
|
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
|
||||||
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
|
('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
|
||||||
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
|
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
|
||||||
('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
|
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
|
||||||
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
|
('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
|
||||||
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
|
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
|
||||||
('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
|
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
|
||||||
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
|
('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
|
||||||
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
|
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
|
||||||
('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
|
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
|
||||||
('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
|
('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
|
||||||
('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
|
('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
|
||||||
('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
|
('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
|
||||||
('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
|
('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
|
||||||
('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
|
('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
|
||||||
('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
|
('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
|
||||||
('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
|
('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
|
||||||
('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
|
('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
|
||||||
('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
|
('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
|
||||||
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
|
('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
|
||||||
('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
|
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
|
||||||
('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
|
('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
|
||||||
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
|
('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
|
||||||
('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
|
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
|
||||||
('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
|
('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
|
||||||
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
|
('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
|
||||||
('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
|
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
|
||||||
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2, None, 'nearest', None, False),
|
('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
|
||||||
'nearest_4d_not_recompute_scale_factor'),
|
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2, None, 'nearest', None, False),
|
||||||
('interpolate', torch.randn(S, S, M, M), (4, None, 'nearest', None, False),
|
'nearest_4d_not_recompute_scale_factor'),
|
||||||
'nearest_4d_with_size_not_recompute_scale_factor'),
|
('interpolate', torch.randn(S, S, M, M), (4, None, 'nearest', None, False),
|
||||||
('interpolate', torch.randn(S, S, M, M), (None, 2., 'bilinear', None, False),
|
'nearest_4d_with_size_not_recompute_scale_factor'),
|
||||||
'bilinear_4d_with_scale_not_recompute_scale_factor'),
|
('interpolate', torch.randn(S, S, M, M), (None, 2., 'bilinear', None, False),
|
||||||
('interpolate', torch.randn(S, S, M, M), (4, None, 'bilinear', None, False),
|
'bilinear_4d_with_scale_not_recompute_scale_factor'),
|
||||||
'bilinear_4d_with_size_not_recompute_scale_factor'),
|
('interpolate', torch.randn(S, S, M, M), (4, None, 'bilinear', None, False),
|
||||||
('interpolate', torch.randn(S, S, M, M), (None, 2., 'bicubic', None, False),
|
'bilinear_4d_with_size_not_recompute_scale_factor'),
|
||||||
'bicubic_4d_with_scale_not_recompute_scale_factor'),
|
('interpolate', torch.randn(S, S, M, M), (None, 2., 'bicubic', None, False),
|
||||||
('interpolate', torch.randn(S, S, M, M), (4, None, 'bicubic', None, False),
|
'bicubic_4d_with_scale_not_recompute_scale_factor'),
|
||||||
'bicubic_4d_with_size_not_recompute_scale_factor'),
|
('interpolate', torch.randn(S, S, M, M), (4, None, 'bicubic', None, False),
|
||||||
('interpolate', torch.randn(S, M, M), (None, 2., 'nearest', None, False),
|
'bicubic_4d_with_size_not_recompute_scale_factor'),
|
||||||
'nearest_3d_with_scale_not_recompute_scale_factor'),
|
('interpolate', torch.randn(S, M, M), (None, 2., 'nearest', None, False),
|
||||||
('interpolate', torch.randn(S, M, M), (4, None, 'nearest', None, False),
|
'nearest_3d_with_scale_not_recompute_scale_factor'),
|
||||||
'nearest_3d_with_size_not_recompute_scale_factor'),
|
('interpolate', torch.randn(S, M, M), (4, None, 'nearest', None, False),
|
||||||
('interpolate', torch.randn(S, M, M), (None, 2., 'linear', None, False),
|
'nearest_3d_with_size_not_recompute_scale_factor'),
|
||||||
'linear_3d_with_scale_not_recompute_scale_factor'),
|
('interpolate', torch.randn(S, M, M), (None, 2., 'linear', None, False),
|
||||||
('interpolate', torch.randn(S, M, M), (4, None, 'linear', None, False),
|
'linear_3d_with_scale_not_recompute_scale_factor'),
|
||||||
'linear_3d_with_size_not_recompute_scale_factor'),
|
('interpolate', torch.randn(S, M, M), (4, None, 'linear', None, False),
|
||||||
('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'nearest', None, False),
|
'linear_3d_with_size_not_recompute_scale_factor'),
|
||||||
'nearest_5d_with_scale_not_recompute_scale_factor'),
|
('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'nearest', None, False),
|
||||||
('interpolate', torch.randn(S, M, M, M, M), (4, None, 'nearest', None, False),
|
'nearest_5d_with_scale_not_recompute_scale_factor'),
|
||||||
'nearest_5d_with_size_not_recompute_scale_factor'),
|
('interpolate', torch.randn(S, M, M, M, M), (4, None, 'nearest', None, False),
|
||||||
('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'trilinear', None, False),
|
'nearest_5d_with_size_not_recompute_scale_factor'),
|
||||||
'trilinear_5d_with_scale_not_recompute_scale_factor'),
|
('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'trilinear', None, False),
|
||||||
('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False),
|
'trilinear_5d_with_scale_not_recompute_scale_factor'),
|
||||||
'trilinear_5d_with_size_not_recompute_scale_factor'),
|
('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False),
|
||||||
]
|
'trilinear_5d_with_size_not_recompute_scale_factor'),
|
||||||
|
]
|
||||||
|
return nn_functional_tests
|
||||||
|
|
||||||
script_template = '''
|
script_template = '''
|
||||||
def the_method({}):
|
def the_method({}):
|
||||||
|
|
@ -523,45 +525,6 @@ def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name
|
||||||
return script_fn, inputs
|
return script_fn, inputs
|
||||||
|
|
||||||
|
|
||||||
# additional modules test
|
|
||||||
# TODO: delete this list once we make all nn_tests work
|
|
||||||
additional_module_tests = [
|
|
||||||
{
|
|
||||||
'module_name': 'Bilinear',
|
|
||||||
'constructor_args': (S, S, M),
|
|
||||||
'input_size': (S, S),
|
|
||||||
'extra_args': ((S, S),)
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'module_name': 'RNNCell',
|
|
||||||
'constructor_args': (S, S),
|
|
||||||
'input_size': (S, S),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'module_name': 'LSTMCell',
|
|
||||||
'constructor_args': (S, S),
|
|
||||||
'input_size': (S, S),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'module_name': 'GRUCell',
|
|
||||||
'constructor_args': (S, S),
|
|
||||||
'input_size': (S, S),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'module_name': 'MultiheadAttention',
|
|
||||||
'constructor_args': (128, 8),
|
|
||||||
'input_size': (10, 8, 128),
|
|
||||||
'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
|
|
||||||
'slowTest': True
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'module_name': 'Transformer',
|
|
||||||
'constructor_args': (1, 1, 1, 1, 2),
|
|
||||||
'input_size': (3, 1, 1),
|
|
||||||
'extra_args': (torch.randn(1, 1, 1),),
|
|
||||||
'slowTest': True
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
EXCLUDE_SCRIPT_MODULES = {
|
EXCLUDE_SCRIPT_MODULES = {
|
||||||
'test_nn_AdaptiveAvgPool2d_tuple_none',
|
'test_nn_AdaptiveAvgPool2d_tuple_none',
|
||||||
|
|
@ -719,4 +682,44 @@ def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):
|
||||||
|
|
||||||
|
|
||||||
def get_all_nn_module_tests():
|
def get_all_nn_module_tests():
|
||||||
return module_tests + new_module_tests + additional_module_tests
|
# additional modules test
|
||||||
|
# TODO: delete this list once we make all nn_tests work
|
||||||
|
additional_module_tests = [
|
||||||
|
{
|
||||||
|
'module_name': 'Bilinear',
|
||||||
|
'constructor_args': (S, S, M),
|
||||||
|
'input_size': (S, S),
|
||||||
|
'extra_args': ((S, S),)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'module_name': 'RNNCell',
|
||||||
|
'constructor_args': (S, S),
|
||||||
|
'input_size': (S, S),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'module_name': 'LSTMCell',
|
||||||
|
'constructor_args': (S, S),
|
||||||
|
'input_size': (S, S),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'module_name': 'GRUCell',
|
||||||
|
'constructor_args': (S, S),
|
||||||
|
'input_size': (S, S),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'module_name': 'MultiheadAttention',
|
||||||
|
'constructor_args': (128, 8),
|
||||||
|
'input_size': (10, 8, 128),
|
||||||
|
'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
|
||||||
|
'slowTest': True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'module_name': 'Transformer',
|
||||||
|
'constructor_args': (1, 1, 1, 1, 2),
|
||||||
|
'input_size': (3, 1, 1),
|
||||||
|
'extra_args': (torch.randn(1, 1, 1),),
|
||||||
|
'slowTest': True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
return module_tests + get_new_module_tests() + additional_module_tests
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user