No actual change, just remove variables containing Tensors from global scope (#143225)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143225
Approved by: https://github.com/ezyang
albanD 2024-12-16 19:20:42 -05:00 committed by PyTorch MergeBot
parent afa313e669
commit 792f1c47e9
12 changed files with 2157 additions and 2138 deletions
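
All of the diffs below apply the same refactor: tuples and lists of example tensors that used to be built at module import time are moved inside the functions that consume them, or exposed through get_*() helpers, so that importing these testing and quantization modules no longer allocates tensors. A minimal sketch of the pattern, with illustrative names rather than the actual PyTorch code:

import torch

# Before: the tensors are allocated as a side effect of importing the module.
_EXAMPLE_INPUTS = (
    torch.randn(1, 1, 3),  # x
    torch.randn(1, 1, 1),  # weight
)

# After: the same data is produced on demand, so import stays side-effect free.
def get_example_inputs():
    return (
        torch.randn(1, 1, 3),  # x
        torch.randn(1, 1, 1),  # weight
    )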

View File

@ -20,8 +20,8 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.jit_metaprogramming_utils import (
get_all_nn_module_tests,
get_nn_functional_compiled_fn_and_inputs,
get_nn_functional_tests,
get_nn_mod_test_name,
nn_functional_tests,
try_get_nn_module_compiled_mod_and_inputs,
)
from torch.testing._internal.jit_utils import enable_profiling_mode, JitTestCase
@ -70,7 +70,7 @@ class TestComplexity(JitTestCase):
def test_generated_functional_tests(self):
with enable_profiling_mode():
stats = [("Name", "Ifs/Loops", "non-tensor ops")]
for test in nn_functional_tests:
for test in get_nn_functional_tests():
test_name = test[0]
fn, inputs = get_nn_functional_compiled_fn_and_inputs(*test)
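
Call sites that used to iterate over the removed nn_functional_tests global now call get_nn_functional_tests() instead, as in the hunk above. Since each call rebuilds the example tensors, a caller that needs the list more than once can fetch it a single time; a small sketch, assuming torch is installed with its internal testing helpers:

from torch.testing._internal.jit_metaprogramming_utils import (
    get_nn_functional_compiled_fn_and_inputs,
    get_nn_functional_tests,
)

# Build the list once and reuse it; every call creates fresh example tensors.
tests = get_nn_functional_tests()
for test in tests:
    test_name = test[0]
    fn, inputs = get_nn_functional_compiled_fn_and_inputs(*test)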

View File

@ -42,7 +42,7 @@ if not common.IS_ARM64:
(sample_module.module_tests, common_nn.NewModuleTest),
(sample_functional.functional_tests, common_nn.NewModuleTest),
(common_nn.module_tests, common_nn.NewModuleTest),
(common_nn.new_module_tests, common_nn.NewModuleTest),
(common_nn.get_new_module_tests(), common_nn.NewModuleTest),
(common_nn.criterion_tests, common_nn.CriterionTest),
]:
for test_params_dict in test_params_dicts:

View File

@ -25,7 +25,11 @@ from torch.testing._internal.common_device_type import (
)
from torch.testing._internal.common_methods_invocations import op_db, SampleInput
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_nn import module_tests, new_module_tests, TestBase
from torch.testing._internal.common_nn import (
get_new_module_tests,
module_tests,
TestBase,
)
from torch.testing._internal.common_utils import (
freeze_rng_state,
make_tensor,
@ -1011,7 +1015,7 @@ def filter_supported_tests(t):
# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
# These currently use the legacy nn tests
supported_tests = [
t for t in module_tests + new_module_tests if filter_supported_tests(t)
t for t in module_tests + get_new_module_tests() if filter_supported_tests(t)
]
for test_param in supported_tests:
if "constructor" not in test_param:

View File

@ -50,7 +50,7 @@ from torch.testing._internal.common_device_type import (
ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_nn import module_tests, new_module_tests
from torch.testing._internal.common_nn import module_tests, get_new_module_tests
from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
from torch.testing._internal.jit_utils import JitTestCase
import torch.utils._pytree as pytree
@ -1006,7 +1006,7 @@ terrible spacing
Exhaustively test `Node.normalized_arguments` on all standard
torch.nn Module classes
"""
for test_params in module_tests + new_module_tests:
for test_params in module_tests + get_new_module_tests():
if "constructor" not in test_params:
constructor = getattr(torch.nn, test_params["module_name"])
else:

View File

@ -107,10 +107,10 @@ from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, dis
from torch.testing._internal.jit_metaprogramming_utils import (
get_script_args,
create_input, unpack_variables,
additional_module_tests, EXCLUDE_SCRIPT_MODULES,
get_all_nn_module_tests, EXCLUDE_SCRIPT_MODULES,
get_nn_module_name_from_kwargs, get_nn_mod_test_name, script_method_template)
from torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests
from torch.testing._internal.common_nn import criterion_tests
# For testing truediv in python 2
from torch.testing._internal.test_module.future_div import div_int_future, div_float_future
@ -16247,7 +16247,7 @@ class TestProducerVersion(TestCase):
# issue gh-32561
self.assertTrue(torch.__version__.startswith(torch.onnx.producer_version))
for test in module_tests + new_module_tests + additional_module_tests:
for test in get_all_nn_module_tests():
add_nn_module_test(**test)
for test in criterion_tests:

View File

@ -38,7 +38,7 @@ from torch.testing._internal.common_utils import freeze_rng_state, run_tests, Te
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
from torch.testing._internal.common_device_type import dtypesIfMPS, instantiate_device_type_tests, dtypes, \
dtypesIfCUDA, precisionOverride, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \
@ -7332,7 +7332,7 @@ def add_test(test, decorator=None):
else:
add(cuda_test_name, with_tf32_off)
for test_params in module_tests + new_module_tests:
for test_params in module_tests + get_new_module_tests():
# TODO: CUDA is not implemented yet
if 'constructor' not in test_params:
name = test_params.pop('module_name')

View File

@ -19,8 +19,6 @@ from torch.fx import Graph, GraphModule, Node
from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns
from .utils import (
_conv1d_bn_example_inputs,
_conv2d_bn_example_inputs,
_get_aten_graph_module_for_pattern,
_is_bn_node,
_is_conv_or_conv_transpose_node,
@ -35,27 +33,6 @@ if TYPE_CHECKING:
__all__ = [] # type: ignore[var-annotated]
# Example inputs for quantized and folded conv-bn1d patterns used in convert
_quantized_conv1d_bn_example_inputs = (
torch.randn(1, 1, 3), # x
torch.randn(1, 1, 1), # conv_weight
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
# Example inputs for quantized and folded conv-bn2d patterns used in convert
_quantized_conv2d_bn_example_inputs = (
torch.randn(1, 1, 3, 3), # x
torch.randn(1, 1, 1, 1), # conv_weight
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
def _get_quantized_conv_bn_example_inputs_kwargs(
is_per_channel: bool,
has_bias: bool,
@ -631,6 +608,28 @@ def _update_special_qspecs_after_replacement(
def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
# Example inputs for conv-bn1d patterns
_conv1d_bn_example_inputs = (
torch.randn(1, 1, 3), # x
torch.randn(1, 1, 1), # conv_weight
torch.randn(1), # conv_bias
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
# Example inputs for conv-bn2d patterns
_conv2d_bn_example_inputs = (
torch.randn(1, 1, 3, 3), # x
torch.randn(1, 1, 1, 1), # conv_weight
torch.randn(1), # conv_bias
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
if not has_bn:
return m
@ -859,6 +858,26 @@ def _copy_over_q_dq_args(original_node: Node, replacement_node: Node):
def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
# Example inputs for quantized and folded conv-bn1d patterns used in convert
_quantized_conv1d_bn_example_inputs = (
torch.randn(1, 1, 3), # x
torch.randn(1, 1, 1), # conv_weight
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
# Example inputs for quantized and folded conv-bn2d patterns used in convert
_quantized_conv2d_bn_example_inputs = (
torch.randn(1, 1, 3, 3), # x
torch.randn(1, 1, 1, 1), # conv_weight
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
if not has_bn:
return m
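
In these two QAT passes the conv-bn example-input tuples are now defined locally in both _fuse_conv_bn_qat and _fold_conv_bn_qat, duplicating a few lines in exchange for a tensor-free import. If that duplication ever becomes bothersome, a local factory could build either variant on demand; a hypothetical helper, not part of this PR:

import torch

def _make_conv_bn_example_inputs(conv_dims: int, with_bias: bool = True):
    # conv_dims=1 gives the conv-bn1d shapes, conv_dims=2 the conv-bn2d shapes.
    x = torch.randn(1, 1, *([3] * conv_dims))
    conv_weight = torch.randn(1, 1, *([1] * conv_dims))
    # bn_weight, bn_bias, bn_running_mean, bn_running_var
    bn_params = tuple(torch.randn(1) for _ in range(4))
    if with_bias:
        return (x, conv_weight, torch.randn(1)) + bn_params
    return (x, conv_weight) + bn_params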

View File

@ -22,25 +22,6 @@ __all__ = [
]
_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (2, 5), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-127], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randn(1, dtype=torch.float),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
def _qdq_quantized_linear(
x_i8,
x_scale,
@ -129,20 +110,6 @@ def _reference_quantized_linear(
return out_i8
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
torch.randn((2, 5), dtype=torch.float),
-128,
127,
torch.finfo(torch.float32).eps,
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-127], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randn(1, dtype=torch.float),
)
def _qdq_dynamic_quantized_linear(
x_fp32,
x_quant_min,
@ -223,25 +190,6 @@ def _reference_dynamic_quantized_linear(
return out_fp32
_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-127], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randn(1, dtype=torch.float),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
def _qdq_quantized_conv2d(
x_i8,
x_scale,
@ -375,20 +323,6 @@ def _reference_quantized_conv2d(
return out_i8
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
def _qdq_quantized_add_relu(
x_i8,
x_scale,
@ -518,19 +452,6 @@ def _reference_quantized_add(
return out_i8
_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
def _qdq_quantized_max_pool2d(
x_i8,
x_scale,
@ -587,15 +508,6 @@ def _reference_quantized_max_pool2d(
return out_i8
_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
torch.randn(1, 3, 3, 3, dtype=torch.float),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x_fp32, scale, zero_point, quant_min, quant_max, torch.int8
@ -619,15 +531,6 @@ def _reference_quantize_per_tensor_int8(
return x
_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
x_i8, scale, zero_point, quant_min, quant_max, torch.int8
@ -648,16 +551,6 @@ def _reference_dequantize_per_tensor_int8(
return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
torch.randn(1, 3, 3, 3, dtype=torch.float),
torch.randn(3, dtype=torch.float),
torch.zeros(3, dtype=torch.int),
1,
-128,
127,
)
def _quantize_per_channel_int8(
x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
):
@ -678,16 +571,6 @@ def _reference_quantize_per_channel_int8(
return out_i32.to(torch.int8)
_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(3, dtype=torch.float),
torch.zeros(3, dtype=torch.int),
1,
-128,
127,
)
def _dequantize_per_channel_int8(
x_i8, scales, zero_points, ch_axis, quant_min, quant_max
):
@ -733,7 +616,116 @@ class _RewriteInfo:
replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
_REWRITE_INFO_LIST = [
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (2, 5), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-127], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randn(1, dtype=torch.float),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
torch.randn((2, 5), dtype=torch.float),
-128,
127,
torch.finfo(torch.float32).eps,
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-127], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randn(1, dtype=torch.float),
)
_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-127], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randn(1, dtype=torch.float),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
torch.randn(1, 3, 3, 3, dtype=torch.float),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
torch.randn(1, 3, 3, 3, dtype=torch.float),
torch.randn(3, dtype=torch.float),
torch.zeros(3, dtype=torch.int),
1,
-128,
127,
)
_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(3, dtype=torch.float),
torch.zeros(3, dtype=torch.int),
1,
-128,
127,
)
_REWRITE_INFO_LIST = [
_RewriteInfo(
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
_WrapperModule(_qdq_dynamic_quantized_linear),
@ -802,10 +794,8 @@ _REWRITE_INFO_LIST = [
_replace_ph_qdq_per_channel_replacement,
_replace_ph_qdq_per_channel_replacement,
),
]
]
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
remove_tensor_overload_for_qdq_ops(model)
from torch._export import gm_using_training_ir

View File

@ -33,28 +33,6 @@ _DEQUANTIZE_OPS = [
torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
# Example inputs for conv-bn1d patterns
_conv1d_bn_example_inputs = (
torch.randn(1, 1, 3), # x
torch.randn(1, 1, 1), # conv_weight
torch.randn(1), # conv_bias
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
# Example inputs for conv-bn2d patterns
_conv2d_bn_example_inputs = (
torch.randn(1, 1, 3, 3), # x
torch.randn(1, 1, 1, 1), # conv_weight
torch.randn(1), # conv_bias
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
"""

View File

@ -10,8 +10,6 @@ from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
from torch.ao.quantization.pt2e.utils import (
_conv1d_bn_example_inputs,
_conv2d_bn_example_inputs,
_get_aten_graph_module_for_pattern,
_is_conv_node,
_is_conv_transpose_node,
@ -487,6 +485,28 @@ def _do_annotate_conv_bn(
for the following names: "input", "conv", "weight", "bias", and "output".
"""
# Example inputs for conv-bn1d patterns
_conv1d_bn_example_inputs = (
torch.randn(1, 1, 3), # x
torch.randn(1, 1, 1), # conv_weight
torch.randn(1), # conv_bias
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
# Example inputs for conv-bn2d patterns
_conv2d_bn_example_inputs = (
torch.randn(1, 1, 3, 3), # x
torch.randn(1, 1, 1, 1), # conv_weight
torch.randn(1), # conv_bias
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
conv = conv_fn(x, conv_weight, conv_bias)

View File

@ -1076,7 +1076,8 @@ def single_batch_reference_fn(input, parameters, module):
return module(*single_batch_input).squeeze(0)
new_module_tests = [
def get_new_module_tests():
new_module_tests = [
poissonnllloss_no_reduce_test(),
bceloss_no_reduce_test(),
bceloss_weights_no_reduce_test(),
@ -1788,7 +1789,8 @@ new_module_tests = [
dict(
fullname='EmbeddingBag_sparse',
constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double),
cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))',
cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
.sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))''',
input_fn=lambda: torch.randperm(2).repeat(1, 2),
check_gradgrad=False,
has_sparse_gradients=True,
@ -2622,10 +2624,10 @@ new_module_tests = [
check_half=True,
desc='3d_no_affine_large_feature',
),
]
]
# add conv padding mode tests:
for padding_mode, cpp_padding_mode in zip(
# add conv padding mode tests:
for padding_mode, cpp_padding_mode in zip(
['reflect', 'circular', 'replicate', 'zeros'],
['torch::kReflect', 'torch::kCircular', 'torch::kReplicate', 'torch::kZeros']):
# conv signature:
@ -2662,14 +2664,14 @@ for padding_mode, cpp_padding_mode in zip(
),
)
# Check that non linear activations work with no batch dimensions
non_linear_activations_no_batch = [
# Check that non linear activations work with no batch dimensions
non_linear_activations_no_batch = [
'ELU', 'Hardshrink', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'LeakyReLU',
'LogSigmoid', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', 'GLU',
'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh',
'Tanhshrink', 'Threshold'
]
non_linear_activations_extra_info: Dict[str, dict] = {
]
non_linear_activations_extra_info: Dict[str, dict] = {
'CELU': {'constructor_args': (2.,), 'default_dtype': torch.double},
'Threshold': {'constructor_args': (2., 1.)},
'Hardsigmoid': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double},
@ -2696,8 +2698,8 @@ non_linear_activations_extra_info: Dict[str, dict] = {
'Softsign': {'default_dtype': torch.double},
'Tanh': {'default_dtype': torch.double},
'Tanhshrink': {'default_dtype': torch.double},
}
for non_linear_activation in non_linear_activations_no_batch:
}
for non_linear_activation in non_linear_activations_no_batch:
activation_test_info = dict(
module_name=non_linear_activation,
input_size=(4,),
@ -2710,6 +2712,9 @@ for non_linear_activation in non_linear_activations_no_batch:
new_module_tests.append(activation_test_info)
return new_module_tests
def kldivloss_reference(input, target, reduction='mean', log_target=False):
if log_target:
result = torch.exp(target) * (target - input)

View File

@ -8,7 +8,7 @@ import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
from torch.testing._internal.common_nn import module_tests, new_module_tests
from torch.testing._internal.common_nn import module_tests, get_new_module_tests
from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like
import collections
@ -95,7 +95,8 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
# fn mapping output to part that should be gradcheck'ed, // optional
# kwargs for function, // optional
# )
nn_functional_tests = [
def get_nn_functional_tests():
nn_functional_tests = [
('conv1d', (S, S, S), ((S, S, S),)),
('conv2d', (S, S, S, S), ((S, S, S, S),)),
('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
@ -314,7 +315,8 @@ nn_functional_tests = [
'trilinear_5d_with_scale_not_recompute_scale_factor'),
('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False),
'trilinear_5d_with_size_not_recompute_scale_factor'),
]
]
return nn_functional_tests
script_template = '''
def the_method({}):
@ -523,45 +525,6 @@ def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name
return script_fn, inputs
# additional modules test
# TODO: delete this list once we make all nn_tests work
additional_module_tests = [
{
'module_name': 'Bilinear',
'constructor_args': (S, S, M),
'input_size': (S, S),
'extra_args': ((S, S),)
},
{
'module_name': 'RNNCell',
'constructor_args': (S, S),
'input_size': (S, S),
},
{
'module_name': 'LSTMCell',
'constructor_args': (S, S),
'input_size': (S, S),
},
{
'module_name': 'GRUCell',
'constructor_args': (S, S),
'input_size': (S, S),
},
{
'module_name': 'MultiheadAttention',
'constructor_args': (128, 8),
'input_size': (10, 8, 128),
'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
'slowTest': True
},
{
'module_name': 'Transformer',
'constructor_args': (1, 1, 1, 1, 2),
'input_size': (3, 1, 1),
'extra_args': (torch.randn(1, 1, 1),),
'slowTest': True
}
]
EXCLUDE_SCRIPT_MODULES = {
'test_nn_AdaptiveAvgPool2d_tuple_none',
@ -719,4 +682,44 @@ def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):
def get_all_nn_module_tests():
return module_tests + new_module_tests + additional_module_tests
# additional modules test
# TODO: delete this list once we make all nn_tests work
additional_module_tests = [
{
'module_name': 'Bilinear',
'constructor_args': (S, S, M),
'input_size': (S, S),
'extra_args': ((S, S),)
},
{
'module_name': 'RNNCell',
'constructor_args': (S, S),
'input_size': (S, S),
},
{
'module_name': 'LSTMCell',
'constructor_args': (S, S),
'input_size': (S, S),
},
{
'module_name': 'GRUCell',
'constructor_args': (S, S),
'input_size': (S, S),
},
{
'module_name': 'MultiheadAttention',
'constructor_args': (128, 8),
'input_size': (10, 8, 128),
'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
'slowTest': True
},
{
'module_name': 'Transformer',
'constructor_args': (1, 1, 1, 1, 2),
'input_size': (3, 1, 1),
'extra_args': (torch.randn(1, 1, 1),),
'slowTest': True
}
]
return module_tests + get_new_module_tests() + additional_module_tests
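
Since the point of the change is that importing these modules no longer creates tensors, a lightweight guard could keep the regression from returning: scan a module's globals after import and flag any Tensor (or tuple/list of Tensors) left at module scope. A hypothetical check, not part of this PR:

import importlib
import torch

def assert_no_module_level_tensors(module_name: str) -> None:
    # Import the module and make sure nothing in its global namespace is a
    # Tensor or a tuple/list containing one.
    mod = importlib.import_module(module_name)
    for name, value in vars(mod).items():
        leaves = value if isinstance(value, (list, tuple)) else (value,)
        if any(isinstance(v, torch.Tensor) for v in leaves):
            raise AssertionError(f"{module_name}.{name} holds a Tensor at module scope")

# Example usage (assumes the target module is now tensor-free at import):
# assert_no_module_level_tensors("torch.testing._internal.common_nn")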