mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
TST Adds more modules into common module tests (#62999)
Summary: This PR moves some modules into `common_modules` to see what it looks like. While migrating some no batch modules into `common_modules`, I noticed that `desc` is not used for the name. This means we can not use `-k` to filter tests. This PR moves the sample generation into `_parametrize_test`, and passes in the already generated `module_input` into users of `modules(modules_db)`. I can see this is a little different from opsinfo and would be happy to revert to the original implementation of `modules`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/62999 Reviewed By: heitorschueroff Differential Revision: D30522737 Pulled By: jbschlosser fbshipit-source-id: 7ed1aeb3753fc97a4ad6f1a3c789727c78e1bc73
This commit is contained in:
parent
544af391b5
commit
ba126df614
|
|
@ -5,8 +5,8 @@ from itertools import chain
|
|||
from torch.testing import floating_types
|
||||
from torch.testing._internal.common_device_type import (
|
||||
_TestParametrizer, _dtype_test_suffix, _update_param_kwargs, skipIf)
|
||||
from torch.testing._internal.common_nn import nllloss_reference
|
||||
from torch.testing._internal.common_utils import make_tensor
|
||||
from torch.testing._internal.common_nn import nllloss_reference, get_reduction
|
||||
from torch.testing._internal.common_utils import make_tensor, freeze_rng_state
|
||||
from types import ModuleType
|
||||
from typing import List, Tuple, Type, Set, Dict
|
||||
|
||||
|
|
@ -46,6 +46,7 @@ for namespace in MODULE_NAMESPACES:
|
|||
|
||||
class modules(_TestParametrizer):
|
||||
""" PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """
|
||||
|
||||
def __init__(self, module_info_list):
|
||||
self.module_info_list = module_info_list
|
||||
|
||||
|
|
@ -199,8 +200,103 @@ def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **
|
|||
return module_inputs
|
||||
|
||||
|
||||
def no_batch_dim_reference_fn(m, p, *args, **kwargs):
|
||||
"""Reference function for modules supporting no batch dimensions.
|
||||
|
||||
The module is passed the input and target in batched form with a single item.
|
||||
The output is squeezed to compare with the no-batch input.
|
||||
"""
|
||||
single_batch_input_args = [input.unsqueeze(0) for input in args]
|
||||
with freeze_rng_state():
|
||||
return m(*single_batch_input_args).squeeze(0)
|
||||
|
||||
|
||||
def no_batch_dim_reference_criterion_fn(m, *args, **kwargs):
|
||||
"""Reference function for criterion supporting no batch dimensions."""
|
||||
output = no_batch_dim_reference_fn(m, *args, **kwargs)
|
||||
reduction = get_reduction(m)
|
||||
if reduction == 'none':
|
||||
return output.squeeze(0)
|
||||
# reduction is 'sum' or 'mean' which results in a 0D tensor
|
||||
return output
|
||||
|
||||
|
||||
def generate_regression_criterion_inputs(make_input):
|
||||
return [
|
||||
ModuleInput(
|
||||
constructor_input=FunctionInput(reduction=reduction),
|
||||
forward_input=FunctionInput(make_input(size=(4, )), make_input(size=4,)),
|
||||
reference_fn=no_batch_dim_reference_criterion_fn,
|
||||
desc='no_batch_dim_{}'.format(reduction)
|
||||
) for reduction in ['none', 'mean', 'sum']]
|
||||
|
||||
|
||||
def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, **kwargs):
|
||||
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
|
||||
return [
|
||||
ModuleInput(constructor_input=FunctionInput(kernel_size=2),
|
||||
forward_input=FunctionInput(make_input(size=(3, 6))),
|
||||
desc='no_batch_dim',
|
||||
reference_fn=no_batch_dim_reference_fn)]
|
||||
|
||||
|
||||
def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, **kwargs):
|
||||
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
|
||||
return [
|
||||
ModuleInput(constructor_input=FunctionInput(alpha=2.),
|
||||
forward_input=FunctionInput(make_input(size=(3, 2, 5))),
|
||||
reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1))),
|
||||
ModuleInput(constructor_input=FunctionInput(alpha=2.),
|
||||
forward_input=FunctionInput(make_input(size=())),
|
||||
desc='scalar'),
|
||||
ModuleInput(constructor_input=FunctionInput(),
|
||||
forward_input=FunctionInput(make_input(size=(3,))),
|
||||
desc='no_batch_dim',
|
||||
reference_fn=no_batch_dim_reference_fn)]
|
||||
|
||||
|
||||
def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, **kwargs):
|
||||
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
|
||||
return [
|
||||
ModuleInput(constructor_input=FunctionInput(alpha=2.),
|
||||
forward_input=FunctionInput(make_input(size=(3, 2, 5))),
|
||||
reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1))),
|
||||
ModuleInput(constructor_input=FunctionInput(alpha=2.),
|
||||
forward_input=FunctionInput(make_input(size=())),
|
||||
reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1)),
|
||||
desc='scalar'),
|
||||
ModuleInput(constructor_input=FunctionInput(alpha=2.),
|
||||
forward_input=FunctionInput(make_input(size=(3,))),
|
||||
desc='no_batch_dim',
|
||||
reference_fn=no_batch_dim_reference_fn)]
|
||||
|
||||
|
||||
def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **kwargs):
|
||||
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
|
||||
return [
|
||||
ModuleInput(constructor_input=FunctionInput(),
|
||||
forward_input=FunctionInput(make_input(size=(2, 3, 4)),
|
||||
make_input(size=(2, 3, 4))),
|
||||
reference_fn=lambda m, p, i, t: 1. / i.numel() * sum((a - b).abs().sum()
|
||||
for a, b in zip(i, t))),
|
||||
ModuleInput(constructor_input=FunctionInput(),
|
||||
forward_input=FunctionInput(make_input(size=()), make_input(size=())),
|
||||
reference_fn=lambda m, p, i, t: 1. / i.numel() * (i - t).abs().sum(),
|
||||
desc='scalar')] + generate_regression_criterion_inputs(make_input)
|
||||
|
||||
|
||||
# Database of ModuleInfo entries in alphabetical order.
|
||||
module_db: List[ModuleInfo] = [
|
||||
ModuleInfo(torch.nn.AvgPool1d,
|
||||
module_inputs_func=module_inputs_torch_nn_AvgPool1d),
|
||||
ModuleInfo(torch.nn.ELU,
|
||||
module_inputs_func=module_inputs_torch_nn_ELU),
|
||||
ModuleInfo(torch.nn.L1Loss,
|
||||
module_inputs_func=module_inputs_torch_nn_L1Loss),
|
||||
ModuleInfo(torch.nn.Linear,
|
||||
module_inputs_func=module_inputs_torch_nn_Linear),
|
||||
ModuleInfo(torch.nn.NLLLoss,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user