Support both train / eval modes for ModuleInfo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78735

Approved by: https://github.com/albanD
This commit is contained in:
Joel Benjamin Schlosser 2022-06-09 12:37:29 -04:00 committed by PyTorch MergeBot
parent 79f18c1aee
commit 70d6446a3d
3 changed files with 159 additions and 64 deletions

View File

@ -9,7 +9,7 @@ from operator import methodcaller
import torch
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol, skipMeta)
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
from torch.testing._internal.common_utils import (
TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck, skipIfMps)
from unittest.mock import patch, call
@ -42,10 +42,10 @@ class TestModule(TestCase):
@skipIfMps # the test doesn't work on MPS as double types are not supported
@modules(module_db)
def test_forward(self, device, dtype, module_info):
def test_forward(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False)
requires_grad=False, training=training)
dtype_to_method_caller = {
torch.float32: methodcaller("float"),
torch.float64: methodcaller("double"),
@ -59,6 +59,7 @@ class TestModule(TestCase):
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
# === Do forward pass. ===
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
@ -80,10 +81,10 @@ class TestModule(TestCase):
# Tests passing factory kwargs (e.g. device / dtype) during module instantiation.
# They should be applied to any created parameters and buffers.
@modules(module_db)
def test_factory_kwargs(self, device, dtype, module_info):
def test_factory_kwargs(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False)
requires_grad=False, training=training)
for module_input in module_inputs:
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
@ -96,6 +97,7 @@ class TestModule(TestCase):
register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
m = module_cls(*args, **kwargs)
m.train(training)
# Check if a parameter or buffer was created with a tensor not passed to the constructor.
constructor_tensors = get_tensors_from(args, kwargs)
@ -122,6 +124,7 @@ class TestModule(TestCase):
uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
m = module_cls(*args, **kwargs)
m.train(training)
uninit_param_new.mock.assert_has_calls(
[call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
uninit_buffer_new.mock.assert_has_calls(
@ -130,16 +133,17 @@ class TestModule(TestCase):
# Check device placement and dtype for created parameters and buffers.
# Only verify floating point dtypes since that's what the kwarg applies to.
m = module_cls(*args, **kwargs)
m.train(training)
self._assert_module_parameters_and_buffer_are(m, device, dtype)
@onlyCUDA
@modules(module_db)
def test_multiple_device_transfer(self, device, dtype, module_info):
def test_multiple_device_transfer(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
module_inputs_device = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False)
requires_grad=False, training=training)
module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
requires_grad=False)
requires_grad=False, training=training)
for module_input_device, module_input_cpu in zip(module_inputs_device, module_inputs_cpu):
if module_input_device.forward_input is None:
continue
@ -149,6 +153,7 @@ class TestModule(TestCase):
args, kwargs = module_input_device.constructor_input.args, module_input_device.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
# === Do forward pass on GPU ===
input_device_args = module_input_device.forward_input.args
@ -189,14 +194,16 @@ class TestModule(TestCase):
@modules(module_db)
def test_repr(self, device, dtype, module_info):
def test_repr(self, device, dtype, module_info, training):
# Test module can be represented with repr and str without errors.
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False)
requires_grad=False, training=training)
for module_input in module_inputs:
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
# Check that these methods do not raise errors
m.__repr__()
@ -204,11 +211,11 @@ class TestModule(TestCase):
@skipIfMps
@modules(module_db)
def test_pickle(self, device, dtype, module_info):
def test_pickle(self, device, dtype, module_info, training):
# Test that module can be pickled and unpickled.
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False)
requires_grad=False, training=training)
for module_input in module_inputs:
if module_input.forward_input is None:
continue
@ -220,6 +227,7 @@ class TestModule(TestCase):
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
# === Do forward pass. ===
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
@ -233,15 +241,15 @@ class TestModule(TestCase):
output_from_copy = m_copy(*args, **kwargs)
self.assertEqual(output, output_from_copy)
@skipMeta
@modules([module_info for module_info in module_db
if 'inplace' in signature(module_info.module_cls).parameters])
@skipMeta
def test_check_inplace(self, device, dtype, module_info):
def test_check_inplace(self, device, dtype, module_info, training):
# Check if the inplace variant of the module gives the same result as the out of place
# variant.
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=True)
requires_grad=True, training=training)
for module_input in module_inputs:
if module_input.forward_input is None:
continue
@ -250,8 +258,10 @@ class TestModule(TestCase):
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m_op = module_cls(*args, **kwargs, inplace=False)
m_op.to(device).to(dtype)
m_op.train(training)
m_inplace = module_cls(*args, **kwargs, inplace=True)
m_inplace.to(device).to(dtype)
m_inplace.train(training)
# === Inplace modules only supports inplace operations on the first argument ===
input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
@ -315,12 +325,12 @@ class TestModule(TestCase):
@skipIfMps
@modules(module_db)
def test_non_contiguous_tensors(self, device, dtype, module_info):
def test_non_contiguous_tensors(self, device, dtype, module_info, training):
# Check modules work with non-contiguous tensors
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=True)
requires_grad=True, training=training)
def _make_non_contiguous(obj):
def inner_make_non_contiguous(obj):
@ -357,6 +367,7 @@ class TestModule(TestCase):
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
self._retain_grad((input_args, input_kwargs))
@ -409,11 +420,11 @@ class TestModule(TestCase):
self.assertEqual(param_grad, default_param_grad)
def _test_gradients_helper(self, device, dtype, module_info, check):
def _test_gradients_helper(self, device, dtype, module_info, training, check):
# Check gradients
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=True)
requires_grad=True, training=training)
# === Set nondet tol for gradcheck to user-defined value if on CUDA and cudNN is enabled
gradcheck_nondet_tol = 0.0
if (torch.device(device).type == 'cuda' and torch.backends.cudnn.enabled):
@ -427,6 +438,7 @@ class TestModule(TestCase):
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
params = tuple(m.parameters())
@ -464,23 +476,33 @@ class TestModule(TestCase):
@modules(module_db, allowed_dtypes=[torch.double])
def test_grad(self, device, dtype, module_info):
self._test_gradients_helper(device, dtype, module_info, gradcheck)
def test_grad(self, device, dtype, module_info, training):
self._test_gradients_helper(device, dtype, module_info, training, gradcheck)
@modules([m for m in module_db if m.supports_gradgrad],
allowed_dtypes=[torch.double])
def test_gradgrad(self, device, dtype, module_info):
self._test_gradients_helper(device, dtype, module_info, gradgradcheck)
def test_gradgrad(self, device, dtype, module_info, training):
self._test_gradients_helper(device, dtype, module_info, training, gradgradcheck)
@onlyCUDA
@toleranceOverride({torch.float32: tol(5e-2, 0),
torch.float64: tol(4e-4, 0)})
@modules(module_db)
def test_cpu_gpu_parity(self, device, dtype, module_info):
def test_cpu_gpu_parity(self, device, dtype, module_info, training):
# TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a
# nicer way for eval mode only.
# See https://github.com/pytorch/pytorch/issues/79161
rnn_modules = set([torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM])
if (module_info.module_cls in rnn_modules
and not training
and 'cuda' in device
and torch.backends.cudnn.enabled):
return
# Test cpu and gpu results are the same
module_cls = module_info.module_cls
module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
requires_grad=True)
requires_grad=True, training=training)
def _to_device(obj):
if isinstance(obj, torch.Tensor):
@ -495,7 +517,6 @@ class TestModule(TestCase):
return deepcopy(obj)
for module_input in module_inputs_cpu:
# === Move input from cpu to device ===
cpu_forward_args = module_input.forward_input.args
cpu_forward_kwargs = module_input.forward_input.kwargs
@ -508,7 +529,9 @@ class TestModule(TestCase):
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
cpu_module = module_cls(*args, **kwargs).to(dtype).to("cpu")
cpu_module.train(training)
gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)
gpu_module.train(training)
for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
gpu_p.data.copy_(cpu_p)
@ -549,10 +572,10 @@ class TestModule(TestCase):
@skipIfMps
@modules(module_db)
def test_memory_format(self, device, dtype, module_info):
def test_memory_format(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False)
requires_grad=False, training=training)
module_memformat_affects_out = module_info.module_memformat_affects_out
def _get_mem_formats(channels_last=False, channels_last_3d=False):
@ -613,6 +636,7 @@ class TestModule(TestCase):
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
# === Get output in (contiguous, contiguous) configuration. ===
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
@ -640,6 +664,42 @@ class TestModule(TestCase):
# === Check mem format of output. ===
_check_out_mem_format(outputs, input_mem_format, module_mem_format)
# Test whether train and eval modes differ for each module. Use to verify
# that the ModuleInfo entry flag is correct.
@skipIfMps # the test doesn't work on MPS as double types are not supported
@modules(module_db, train_eval_mode=TrainEvalMode.train_only)
def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False, training=training)
# Run forward inputs through to see if the training flag is accessed during forward.
for module_input in module_inputs:
if module_input.forward_input is None:
continue
# === Instantiate the module. ===
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
# Remove training attribute and see if forward still works.
delattr(m, 'training')
# === Do forward pass. ===
try:
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
m(*args, **kwargs)
except AttributeError as e:
if "'training'" in str(e):
self.assertTrue(module_info.train_and_eval_differ,
f"The ModuleInfo entry for {module_info.name} has "
"train_and_eval_differ=False, but the training mode was found to "
"affect the forward pass. Consider setting train_and_eval_differ=True "
"for this ModuleInfo entry.")
else:
raise e
instantiate_device_type_tests(TestModule, globals())

View File

@ -273,7 +273,7 @@ def _update_param_kwargs(param_kwargs, name, value):
if isinstance(value, list) or isinstance(value, tuple):
# Make name plural (e.g. devices / dtypes) if the value is composite.
param_kwargs['{}s'.format(name)] = value
elif value:
elif value is not None:
param_kwargs[name] = value
# Leave param_kwargs as-is when value is None.

View File

@ -1,6 +1,7 @@
import torch
import unittest
from copy import deepcopy
from enum import Enum
from functools import wraps, partial
from itertools import chain, product
import itertools
@ -50,12 +51,33 @@ for namespace in MODULE_NAMESPACES:
MODULE_CLASS_NAMES[module_cls] = f'{namespace_name}.{module_name}'
# Specifies the modes (i.e. train, eval) to test over.
TrainEvalMode = Enum('TrainEvalMode', ('train_only', 'eval_only', 'train_and_eval'))
class modules(_TestParametrizer):
""" PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """
def __init__(self, module_info_list, allowed_dtypes=None):
def __init__(self, module_info_list, allowed_dtypes=None, train_eval_mode=TrainEvalMode.train_and_eval):
self.module_info_list = module_info_list
self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None
self.train_eval_mode = train_eval_mode
def _get_training_flags(self, module_info):
training_flags = []
if (self.train_eval_mode == TrainEvalMode.train_only or
self.train_eval_mode == TrainEvalMode.train_and_eval):
training_flags.append(True)
if (self.train_eval_mode == TrainEvalMode.eval_only or
self.train_eval_mode == TrainEvalMode.train_and_eval):
training_flags.append(False)
# If train and eval modes don't differ for the module, don't bother using more than one.
if not module_info.train_and_eval_differ:
training_flags = training_flags[:1]
return training_flags
def _parametrize_test(self, test, generic_cls, device_cls):
if device_cls is None:
@ -64,18 +86,22 @@ class modules(_TestParametrizer):
'instantiate_parametrized_tests()')
for module_info in self.module_info_list:
# Construct the test name; device / dtype parts are handled outside.
# See [Note: device and dtype suffix placement]
test_name = module_info.name.replace('.', '_')
dtypes = set(module_info.dtypes)
if self.allowed_dtypes is not None:
dtypes = dtypes.intersection(self.allowed_dtypes)
for dtype in dtypes:
training_flags = self._get_training_flags(module_info)
for (training, dtype) in product(training_flags, dtypes):
# Construct the test name; device / dtype parts are handled outside.
# See [Note: device and dtype suffix placement]
test_name = module_info.formatted_name
if len(training_flags) > 1:
test_name += f"_{'train_mode' if training else 'eval_mode'}"
# Construct parameter kwargs to pass to the test.
param_kwargs = {'module_info': module_info}
_update_param_kwargs(param_kwargs, 'dtype', dtype)
_update_param_kwargs(param_kwargs, 'training', training)
try:
active_decorators = [set_single_threaded_if_parallel_tbb]
@ -106,9 +132,9 @@ class modules(_TestParametrizer):
raise ex
def formatted_module_name(module_cls):
def get_module_fully_qualified_name(module_cls):
""" Returns the common name of the module class formatted for use in test names. """
return MODULE_CLASS_NAMES[module_cls].replace('.', '_')
return MODULE_CLASS_NAMES[module_cls]
class FunctionInput(object):
@ -157,6 +183,7 @@ class ModuleInfo(object):
gradcheck_nondet_tol=0.0, # tolerance for nondeterminism while performing gradcheck
module_memformat_affects_out=False, # whether converting module to channels last will generate
# channels last output
train_and_eval_differ=False, # whether the module has differing behavior between train and eval
):
self.module_cls = module_cls
self.module_inputs_func = module_inputs_func
@ -166,20 +193,21 @@ class ModuleInfo(object):
self.supports_gradgrad = supports_gradgrad
self.gradcheck_nondet_tol = gradcheck_nondet_tol
self.module_memformat_affects_out = module_memformat_affects_out
self.train_and_eval_differ = train_and_eval_differ
def should_skip(self, cls_name, test_name, device_type, dtype):
return any(si.is_active(cls_name, test_name, device_type, dtype) for si in self.skips)
@property
def name(self):
return formatted_module_name(self.module_cls)
return get_module_fully_qualified_name(self.module_cls)
@property
def formatted_name(self):
return self.name.replace('.', '_')
def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
module_inputs = [
@ -199,7 +227,7 @@ def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **k
return module_inputs
def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
def bilinear_reference_fn(m, p, x1, x2, bias=True):
@ -228,7 +256,7 @@ def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, *
return module_inputs
def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
@ -263,7 +291,7 @@ def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **
return module_inputs
def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
@ -424,7 +452,7 @@ def generate_regression_criterion_inputs(make_input):
) for reduction in ['none', 'mean', 'sum']]
def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
return [
@ -434,7 +462,7 @@ def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad,
reference_fn=no_batch_dim_reference_fn)]
def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
return [
@ -443,7 +471,7 @@ def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, require
desc='single')]
def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
return [
@ -451,7 +479,7 @@ def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad
forward_input=FunctionInput(make_input((2, 3, 6, 6))))]
def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
return [
@ -459,7 +487,7 @@ def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad
forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))))]
def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, training, **kwargs):
N = kwargs['N']
lazy = kwargs.get('lazy', False)
transposed = kwargs.get('transposed', False)
@ -479,7 +507,7 @@ def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, **k
]
def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
return [
@ -498,7 +526,7 @@ def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, **kwar
desc='4d_input')]
def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
return [
@ -515,7 +543,7 @@ def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, **kwa
reference_fn=no_batch_dim_reference_fn)]
def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad):
def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
return [
@ -530,7 +558,7 @@ def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad):
desc='channels_last_3d_mem_format')]
def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
return [
@ -545,7 +573,7 @@ def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **k
desc='scalar')] + generate_regression_criterion_inputs(make_input)
def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
@ -579,7 +607,7 @@ def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires
return samples
def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
return [
@ -596,7 +624,7 @@ def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad,
]
def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
return [
@ -615,7 +643,7 @@ def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad,
]
def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
return [
@ -632,7 +660,7 @@ def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, **
]
def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
samples = [
@ -681,7 +709,7 @@ def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, r
return samples
def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
samples = [
@ -737,7 +765,7 @@ def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, r
return samples
def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
samples = []
# Samples below are for validating the no-batch-dim support.
@ -780,7 +808,7 @@ def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad
return samples
def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, training, **kwargs):
make_empty = partial(torch.empty, device=device, dtype=torch.long, requires_grad=False)
return [
ModuleInput(
@ -795,7 +823,7 @@ def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad,
]
def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requires_grad, training, **kwargs):
# Currently all samples below are for validating the no-batch-dim support.
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
samples = []
@ -826,7 +854,7 @@ def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requir
return samples
def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
# Currently all samples below are for validating the no-batch-dim support.
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
samples = [
@ -857,7 +885,7 @@ def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_gra
return samples
def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs):
# Currently all samples below are for validating the no-batch-dim support.
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
samples = (
@ -876,7 +904,7 @@ def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, *
return samples
def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, **kwargs):
# Currently all samples below are for validating the no-batch-dim support.
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
is_rnn = kwargs['is_rnn']
@ -923,7 +951,7 @@ def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, **
return samples
def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, training, **kwargs):
# Currently all samples below are for validating the no-batch-dim support.
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
bias = (False, True)
@ -1005,6 +1033,7 @@ module_db: List[ModuleInfo] = [
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
),
ModuleInfo(torch.nn.BatchNorm2d,
train_and_eval_differ=True,
module_inputs_func=module_inputs_torch_nn_BatchNorm2d,
skips=(
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
@ -1013,6 +1042,7 @@ module_db: List[ModuleInfo] = [
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),)
),
ModuleInfo(torch.nn.BatchNorm3d,
train_and_eval_differ=True,
module_inputs_func=module_inputs_torch_nn_BatchNorm3d,
skips=(
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
@ -1273,6 +1303,7 @@ module_db: List[ModuleInfo] = [
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
supports_gradgrad=False),
ModuleInfo(torch.nn.TransformerEncoderLayer,
train_and_eval_differ=True,
module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer,
skips=(
# No channels_last support for TransformerEncoderLayer currently.
@ -1294,6 +1325,7 @@ module_db: List[ModuleInfo] = [
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
),
ModuleInfo(torch.nn.MultiheadAttention,
train_and_eval_differ=True,
module_inputs_func=module_inputs_torch_nn_MultiheadAttention,
skips=(
# No channels_last support for MultiheadAttention currently.
@ -1332,17 +1364,20 @@ module_db: List[ModuleInfo] = [
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
),
ModuleInfo(torch.nn.RNN,
train_and_eval_differ=True,
module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=True),
skips=(
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
decorators=rnn_gru_lstm_module_info_decorators
),
ModuleInfo(torch.nn.GRU,
train_and_eval_differ=True,
module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=False),
skips=(
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
decorators=rnn_gru_lstm_module_info_decorators),
ModuleInfo(torch.nn.LSTM,
train_and_eval_differ=True,
module_inputs_func=module_inputs_torch_nn_LSTM,
skips=(
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),