Support both train / eval modes for ModuleInfo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78735

Approved by: https://github.com/albanD
Author: Joel Benjamin Schlosser, 2022-06-09 12:37:29 -04:00 (committed by PyTorch MergeBot)
parent 79f18c1aee
commit 70d6446a3d
3 changed files with 159 additions and 64 deletions
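
In short, the @modules decorator now parametrizes each ModuleInfo-based test over training mode and threads a `training` flag into both the test itself and every `module_inputs_func`. The snippet below is a minimal sketch of a consumer written against the changes in this commit; the test class, test name, and body are illustrative only and are not part of the diff.

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
from torch.testing._internal.common_utils import TestCase, run_tests


class TestTrainEvalSketch(TestCase):
    # train_eval_mode defaults to TrainEvalMode.train_and_eval; entries with
    # train_and_eval_differ=False still generate only a single variant.
    @modules(module_db, train_eval_mode=TrainEvalMode.train_and_eval)
    def test_forward_smoke(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs).to(device).to(dtype)
            m.train(training)  # put the module into the parametrized train / eval mode
            m(*module_input.forward_input.args, **module_input.forward_input.kwargs)


instantiate_device_type_tests(TestTrainEvalSketch, globals())

if __name__ == '__main__':
    run_tests()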

View File

@@ -9,7 +9,7 @@ from operator import methodcaller
 import torch
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol, skipMeta)
-from torch.testing._internal.common_modules import module_db, modules
+from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck, skipIfMps)
 from unittest.mock import patch, call
@@ -42,10 +42,10 @@ class TestModule(TestCase):
     @skipIfMps  # the test doesn't work on MPS as double types are not supported
     @modules(module_db)
-    def test_forward(self, device, dtype, module_info):
+    def test_forward(self, device, dtype, module_info, training):
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=False)
+                                                       requires_grad=False, training=training)
         dtype_to_method_caller = {
             torch.float32: methodcaller("float"),
             torch.float64: methodcaller("double"),
@@ -59,6 +60,7 @@ class TestModule(TestCase):
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
             m = module_cls(*args, **kwargs)
             m.to(device).to(dtype)
+            m.train(training)

             # === Do forward pass. ===
             args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
@@ -80,10 +81,10 @@ class TestModule(TestCase):
     # Tests passing factory kwargs (e.g. device / dtype) during module instantiation.
     # They should be applied to any created parameters and buffers.
     @modules(module_db)
-    def test_factory_kwargs(self, device, dtype, module_info):
+    def test_factory_kwargs(self, device, dtype, module_info, training):
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=False)
+                                                       requires_grad=False, training=training)

         for module_input in module_inputs:
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
@@ -96,6 +97,7 @@ class TestModule(TestCase):
                 register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
                 with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
                     m = module_cls(*args, **kwargs)
+                    m.train(training)

                     # Check if a parameter or buffer was created with a tensor not passed to the constructor.
                     constructor_tensors = get_tensors_from(args, kwargs)
@@ -122,6 +124,7 @@ class TestModule(TestCase):
                     uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
                     with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
                         m = module_cls(*args, **kwargs)
+                        m.train(training)
                         uninit_param_new.mock.assert_has_calls(
                             [call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
                         uninit_buffer_new.mock.assert_has_calls(
@@ -130,16 +133,17 @@ class TestModule(TestCase):
                 # Check device placement and dtype for created parameters and buffers.
                 # Only verify floating point dtypes since that's what the kwarg applies to.
                 m = module_cls(*args, **kwargs)
+                m.train(training)
                 self._assert_module_parameters_and_buffer_are(m, device, dtype)

     @onlyCUDA
     @modules(module_db)
-    def test_multiple_device_transfer(self, device, dtype, module_info):
+    def test_multiple_device_transfer(self, device, dtype, module_info, training):
         module_cls = module_info.module_cls
         module_inputs_device = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                              requires_grad=False)
+                                                              requires_grad=False, training=training)
         module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
-                                                           requires_grad=False)
+                                                           requires_grad=False, training=training)
         for module_input_device, module_input_cpu in zip(module_inputs_device, module_inputs_cpu):
             if module_input_device.forward_input is None:
                 continue
@@ -149,6 +153,7 @@ class TestModule(TestCase):
             args, kwargs = module_input_device.constructor_input.args, module_input_device.constructor_input.kwargs
             m = module_cls(*args, **kwargs)
             m.to(device).to(dtype)
+            m.train(training)

             # === Do forward pass on GPU ===
             input_device_args = module_input_device.forward_input.args
@@ -189,14 +194,16 @@ class TestModule(TestCase):
     @modules(module_db)
-    def test_repr(self, device, dtype, module_info):
+    def test_repr(self, device, dtype, module_info, training):
         # Test module can be represented with repr and str without errors.
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=False)
+                                                       requires_grad=False, training=training)
         for module_input in module_inputs:
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
             m = module_cls(*args, **kwargs)
+            m.to(device).to(dtype)
+            m.train(training)

             # Check that these methods do not raise errors
             m.__repr__()
@@ -204,11 +211,11 @@ class TestModule(TestCase):
     @skipIfMps
     @modules(module_db)
-    def test_pickle(self, device, dtype, module_info):
+    def test_pickle(self, device, dtype, module_info, training):
         # Test that module can be pickled and unpickled.
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=False)
+                                                       requires_grad=False, training=training)
         for module_input in module_inputs:
             if module_input.forward_input is None:
                 continue
@@ -220,6 +227,7 @@ class TestModule(TestCase):
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
             m = module_cls(*args, **kwargs)
             m.to(device).to(dtype)
+            m.train(training)

             # === Do forward pass. ===
             args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
@@ -233,15 +241,15 @@ class TestModule(TestCase):
             output_from_copy = m_copy(*args, **kwargs)
             self.assertEqual(output, output_from_copy)

+    @skipMeta
     @modules([module_info for module_info in module_db
               if 'inplace' in signature(module_info.module_cls).parameters])
-    @skipMeta
-    def test_check_inplace(self, device, dtype, module_info):
+    def test_check_inplace(self, device, dtype, module_info, training):
         # Check if the inplace variant of the module gives the same result as the out of place
         # variant.
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=True)
+                                                       requires_grad=True, training=training)
         for module_input in module_inputs:
             if module_input.forward_input is None:
                 continue
@@ -250,8 +258,10 @@ class TestModule(TestCase):
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
             m_op = module_cls(*args, **kwargs, inplace=False)
             m_op.to(device).to(dtype)
+            m_op.train(training)
             m_inplace = module_cls(*args, **kwargs, inplace=True)
             m_inplace.to(device).to(dtype)
+            m_inplace.train(training)

             # === Inplace modules only supports inplace operations on the first argument ===
             input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
@@ -315,12 +325,12 @@ class TestModule(TestCase):
     @skipIfMps
     @modules(module_db)
-    def test_non_contiguous_tensors(self, device, dtype, module_info):
+    def test_non_contiguous_tensors(self, device, dtype, module_info, training):
         # Check modules work with non-contiguous tensors
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=True)
+                                                       requires_grad=True, training=training)

         def _make_non_contiguous(obj):
             def inner_make_non_contiguous(obj):
@@ -357,6 +367,7 @@ class TestModule(TestCase):
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
             m = module_cls(*args, **kwargs)
             m.to(device).to(dtype)
+            m.train(training)

             self._retain_grad((input_args, input_kwargs))
@@ -409,11 +420,11 @@ class TestModule(TestCase):
                     self.assertEqual(param_grad, default_param_grad)

-    def _test_gradients_helper(self, device, dtype, module_info, check):
+    def _test_gradients_helper(self, device, dtype, module_info, training, check):
         # Check gradients
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=True)
+                                                       requires_grad=True, training=training)

         # === Set nondet tol for gradcheck to user-defined value if on CUDA and cudNN is enabled
         gradcheck_nondet_tol = 0.0
         if (torch.device(device).type == 'cuda' and torch.backends.cudnn.enabled):
@@ -427,6 +438,7 @@ class TestModule(TestCase):
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
             m = module_cls(*args, **kwargs)
             m.to(device).to(dtype)
+            m.train(training)

             params = tuple(m.parameters())
@@ -464,23 +476,33 @@ class TestModule(TestCase):
     @modules(module_db, allowed_dtypes=[torch.double])
-    def test_grad(self, device, dtype, module_info):
-        self._test_gradients_helper(device, dtype, module_info, gradcheck)
+    def test_grad(self, device, dtype, module_info, training):
+        self._test_gradients_helper(device, dtype, module_info, training, gradcheck)

     @modules([m for m in module_db if m.supports_gradgrad],
              allowed_dtypes=[torch.double])
-    def test_gradgrad(self, device, dtype, module_info):
-        self._test_gradients_helper(device, dtype, module_info, gradgradcheck)
+    def test_gradgrad(self, device, dtype, module_info, training):
+        self._test_gradients_helper(device, dtype, module_info, training, gradgradcheck)

     @onlyCUDA
     @toleranceOverride({torch.float32: tol(5e-2, 0),
                         torch.float64: tol(4e-4, 0)})
     @modules(module_db)
-    def test_cpu_gpu_parity(self, device, dtype, module_info):
+    def test_cpu_gpu_parity(self, device, dtype, module_info, training):
+        # TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a
+        # nicer way for eval mode only.
+        # See https://github.com/pytorch/pytorch/issues/79161
+        rnn_modules = set([torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM])
+        if (module_info.module_cls in rnn_modules
+                and not training
+                and 'cuda' in device
+                and torch.backends.cudnn.enabled):
+            return
+
         # Test cpu and gpu results are the same
         module_cls = module_info.module_cls
         module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
-                                                           requires_grad=True)
+                                                           requires_grad=True, training=training)

         def _to_device(obj):
             if isinstance(obj, torch.Tensor):
@@ -495,7 +517,6 @@ class TestModule(TestCase):
                 return deepcopy(obj)

         for module_input in module_inputs_cpu:
-
             # === Move input from cpu to device ===
             cpu_forward_args = module_input.forward_input.args
             cpu_forward_kwargs = module_input.forward_input.kwargs
@@ -508,7 +529,9 @@ class TestModule(TestCase):
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
             cpu_module = module_cls(*args, **kwargs).to(dtype).to("cpu")
+            cpu_module.train(training)
             gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)
+            gpu_module.train(training)

             for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
                 gpu_p.data.copy_(cpu_p)
@@ -549,10 +572,10 @@ class TestModule(TestCase):
     @skipIfMps
     @modules(module_db)
-    def test_memory_format(self, device, dtype, module_info):
+    def test_memory_format(self, device, dtype, module_info, training):
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=False)
+                                                       requires_grad=False, training=training)
         module_memformat_affects_out = module_info.module_memformat_affects_out

         def _get_mem_formats(channels_last=False, channels_last_3d=False):
@@ -613,6 +636,7 @@ class TestModule(TestCase):
                     m = module_cls(*args, **kwargs)
                     m.to(device).to(dtype)
+                    m.train(training)

                     # === Get output in (contiguous, contiguous) configuration. ===
                     args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
@@ -640,6 +664,42 @@ class TestModule(TestCase):
                     # === Check mem format of output. ===
                     _check_out_mem_format(outputs, input_mem_format, module_mem_format)

+    # Test whether train and eval modes differ for each module. Use to verify
+    # that the ModuleInfo entry flag is correct.
+    @skipIfMps  # the test doesn't work on MPS as double types are not supported
+    @modules(module_db, train_eval_mode=TrainEvalMode.train_only)
+    def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training):
+        module_cls = module_info.module_cls
+        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
+                                                       requires_grad=False, training=training)
+
+        # Run forward inputs through to see if the training flag is accessed during forward.
+        for module_input in module_inputs:
+            if module_input.forward_input is None:
+                continue
+
+            # === Instantiate the module. ===
+            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
+            m = module_cls(*args, **kwargs)
+            m.to(device).to(dtype)
+            m.train(training)
+
+            # Remove training attribute and see if forward still works.
+            delattr(m, 'training')
+
+            # === Do forward pass. ===
+            try:
+                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
+                m(*args, **kwargs)
+            except AttributeError as e:
+                if "'training'" in str(e):
+                    self.assertTrue(module_info.train_and_eval_differ,
+                                    f"The ModuleInfo entry for {module_info.name} has "
+                                    "train_and_eval_differ=False, but the training mode was found to "
+                                    "affect the forward pass. Consider setting train_and_eval_differ=True "
+                                    "for this ModuleInfo entry.")
+                else:
+                    raise e
+

 instantiate_device_type_tests(TestModule, globals())
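
When a ModuleInfo entry opts into both modes, the generated test names pick up a _train_mode / _eval_mode suffix ahead of the usual device / dtype suffixes added by the device-type framework. Assuming that naming, selecting a single mode from the command line might look like the following (the exact generated name depends on the locally available devices and dtypes):

python test/test_modules.py -k test_forward_nn_BatchNorm2d_train_mode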

View File

@@ -273,7 +273,7 @@ def _update_param_kwargs(param_kwargs, name, value):
     if isinstance(value, list) or isinstance(value, tuple):
         # Make name plural (e.g. devices / dtypes) if the value is composite.
         param_kwargs['{}s'.format(name)] = value
-    elif value:
+    elif value is not None:
         param_kwargs[name] = value

     # Leave param_kwargs as-is when value is None.
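
The `elif value is not None:` change is what lets a bare boolean flag reach the test: under the old truthiness check, `training=False` would have been silently dropped from param_kwargs, so eval-mode tests would never have received the kwarg. A standalone illustration of the two behaviors (not the framework code itself):

param_kwargs = {}
value = False  # e.g. the new training flag in eval mode

# Old check: any falsy value (False, 0, '') was dropped along with None.
if value:
    param_kwargs['training'] = value
assert 'training' not in param_kwargs

# New check: only None means "leave param_kwargs as-is".
if value is not None:
    param_kwargs['training'] = value
assert param_kwargs == {'training': False}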

View File

@@ -1,6 +1,7 @@
 import torch
 import unittest
 from copy import deepcopy
+from enum import Enum
 from functools import wraps, partial
 from itertools import chain, product
 import itertools
@@ -50,12 +51,33 @@ for namespace in MODULE_NAMESPACES:
         MODULE_CLASS_NAMES[module_cls] = f'{namespace_name}.{module_name}'

+# Specifies the modes (i.e. train, eval) to test over.
+TrainEvalMode = Enum('TrainEvalMode', ('train_only', 'eval_only', 'train_and_eval'))
+

 class modules(_TestParametrizer):
     """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """

-    def __init__(self, module_info_list, allowed_dtypes=None):
+    def __init__(self, module_info_list, allowed_dtypes=None, train_eval_mode=TrainEvalMode.train_and_eval):
         self.module_info_list = module_info_list
         self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None
+        self.train_eval_mode = train_eval_mode
+
+    def _get_training_flags(self, module_info):
+        training_flags = []
+        if (self.train_eval_mode == TrainEvalMode.train_only or
+                self.train_eval_mode == TrainEvalMode.train_and_eval):
+            training_flags.append(True)
+
+        if (self.train_eval_mode == TrainEvalMode.eval_only or
+                self.train_eval_mode == TrainEvalMode.train_and_eval):
+            training_flags.append(False)
+
+        # If train and eval modes don't differ for the module, don't bother using more than one.
+        if not module_info.train_and_eval_differ:
+            training_flags = training_flags[:1]
+        return training_flags

     def _parametrize_test(self, test, generic_cls, device_cls):
         if device_cls is None:
@@ -64,18 +86,22 @@ class modules(_TestParametrizer):
                                'instantiate_parametrized_tests()')

         for module_info in self.module_info_list:
-            # Construct the test name; device / dtype parts are handled outside.
-            # See [Note: device and dtype suffix placement]
-            test_name = module_info.name.replace('.', '_')
-
             dtypes = set(module_info.dtypes)
             if self.allowed_dtypes is not None:
                 dtypes = dtypes.intersection(self.allowed_dtypes)

-            for dtype in dtypes:
+            training_flags = self._get_training_flags(module_info)
+            for (training, dtype) in product(training_flags, dtypes):
+                # Construct the test name; device / dtype parts are handled outside.
+                # See [Note: device and dtype suffix placement]
+                test_name = module_info.formatted_name
+                if len(training_flags) > 1:
+                    test_name += f"_{'train_mode' if training else 'eval_mode'}"
+
                 # Construct parameter kwargs to pass to the test.
                 param_kwargs = {'module_info': module_info}
                 _update_param_kwargs(param_kwargs, 'dtype', dtype)
+                _update_param_kwargs(param_kwargs, 'training', training)

                 try:
                     active_decorators = [set_single_threaded_if_parallel_tbb]
@@ -106,9 +132,9 @@ class modules(_TestParametrizer):
                     raise ex


-def formatted_module_name(module_cls):
+def get_module_fully_qualified_name(module_cls):
     """ Returns the common name of the module class formatted for use in test names. """
-    return MODULE_CLASS_NAMES[module_cls].replace('.', '_')
+    return MODULE_CLASS_NAMES[module_cls]


 class FunctionInput(object):
@@ -157,6 +183,7 @@ class ModuleInfo(object):
                  gradcheck_nondet_tol=0.0,  # tolerance for nondeterminism while performing gradcheck
                  module_memformat_affects_out=False,  # whether converting module to channels last will generate
                                                       # channels last output
+                 train_and_eval_differ=False,  # whether the module has differing behavior between train and eval
                  ):
         self.module_cls = module_cls
         self.module_inputs_func = module_inputs_func
@@ -166,20 +193,21 @@ class ModuleInfo(object):
         self.supports_gradgrad = supports_gradgrad
         self.gradcheck_nondet_tol = gradcheck_nondet_tol
         self.module_memformat_affects_out = module_memformat_affects_out
+        self.train_and_eval_differ = train_and_eval_differ

     def should_skip(self, cls_name, test_name, device_type, dtype):
         return any(si.is_active(cls_name, test_name, device_type, dtype) for si in self.skips)

     @property
     def name(self):
-        return formatted_module_name(self.module_cls)
+        return get_module_fully_qualified_name(self.module_cls)

     @property
     def formatted_name(self):
         return self.name.replace('.', '_')


-def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     module_inputs = [
@@ -199,7 +227,7 @@ def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **k
     return module_inputs


-def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     def bilinear_reference_fn(m, p, x1, x2, bias=True):
@@ -228,7 +256,7 @@ def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, *
     return module_inputs


-def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
@@ -263,7 +291,7 @@ def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **
     return module_inputs


-def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
@@ -424,7 +452,7 @@ def generate_regression_criterion_inputs(make_input):
         ) for reduction in ['none', 'mean', 'sum']]


-def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [
@@ -434,7 +462,7 @@ def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad,
                     reference_fn=no_batch_dim_reference_fn)]


-def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [
@@ -443,7 +471,7 @@ def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, require
                     desc='single')]


-def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [
@@ -451,7 +479,7 @@ def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad
                     forward_input=FunctionInput(make_input((2, 3, 6, 6))))]


-def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [
@@ -459,7 +487,7 @@ def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad
                     forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))))]


-def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, training, **kwargs):
     N = kwargs['N']
     lazy = kwargs.get('lazy', False)
     transposed = kwargs.get('transposed', False)
@@ -479,7 +507,7 @@ def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, **k
     ]


-def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [
@@ -498,7 +526,7 @@ def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, **kwar
                     desc='4d_input')]


-def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [
@@ -515,7 +543,7 @@ def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, **kwa
                     reference_fn=no_batch_dim_reference_fn)]


-def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad):
+def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [
@@ -530,7 +558,7 @@ def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad):
                     desc='channels_last_3d_mem_format')]


-def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [
@@ -545,7 +573,7 @@ def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **k
                     desc='scalar')] + generate_regression_criterion_inputs(make_input)


-def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
     make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
@@ -579,7 +607,7 @@ def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires
     return samples


-def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [
@@ -596,7 +624,7 @@ def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad,
     ]


-def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [
@@ -615,7 +643,7 @@ def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad,
     ]


-def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [
@@ -632,7 +660,7 @@ def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, **
     ]


-def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     samples = [
@@ -681,7 +709,7 @@ def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, r
     return samples


-def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     samples = [
@@ -737,7 +765,7 @@ def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, r
     return samples


-def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     samples = []
     # Samples below are for validating the no-batch-dim support.
@@ -780,7 +808,7 @@ def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad
     return samples


-def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, training, **kwargs):
     make_empty = partial(torch.empty, device=device, dtype=torch.long, requires_grad=False)
     return [
         ModuleInput(
@@ -795,7 +823,7 @@ def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad,
     ]


-def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requires_grad, training, **kwargs):
     # Currently all samples below are for validating the no-batch-dim support.
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     samples = []
@@ -826,7 +854,7 @@ def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requir
     return samples


-def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
     # Currently all samples below are for validating the no-batch-dim support.
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     samples = [
@@ -857,7 +885,7 @@ def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_gra
     return samples


-def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs):
     # Currently all samples below are for validating the no-batch-dim support.
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     samples = (
@@ -876,7 +904,7 @@ def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, *
     return samples


-def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, **kwargs):
     # Currently all samples below are for validating the no-batch-dim support.
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     is_rnn = kwargs['is_rnn']
@@ -923,7 +951,7 @@ def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, **
     return samples


-def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, training, **kwargs):
     # Currently all samples below are for validating the no-batch-dim support.
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     bias = (False, True)
@@ -1005,6 +1033,7 @@ module_db: List[ModuleInfo] = [
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
                ),
     ModuleInfo(torch.nn.BatchNorm2d,
+               train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_BatchNorm2d,
                skips=(
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
@@ -1013,6 +1042,7 @@ module_db: List[ModuleInfo] = [
                    DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),)
                ),
     ModuleInfo(torch.nn.BatchNorm3d,
+               train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_BatchNorm3d,
                skips=(
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
@@ -1273,6 +1303,7 @@ module_db: List[ModuleInfo] = [
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
                supports_gradgrad=False),
     ModuleInfo(torch.nn.TransformerEncoderLayer,
+               train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer,
                skips=(
                    # No channels_last support for TransformerEncoderLayer currently.
@@ -1294,6 +1325,7 @@ module_db: List[ModuleInfo] = [
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
                ),
     ModuleInfo(torch.nn.MultiheadAttention,
+               train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_MultiheadAttention,
                skips=(
                    # No channels_last support for MultiheadAttention currently.
@@ -1332,17 +1364,20 @@ module_db: List[ModuleInfo] = [
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
                ),
     ModuleInfo(torch.nn.RNN,
+               train_and_eval_differ=True,
                module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=True),
                skips=(
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
                decorators=rnn_gru_lstm_module_info_decorators
                ),
     ModuleInfo(torch.nn.GRU,
+               train_and_eval_differ=True,
                module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=False),
                skips=(
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
                decorators=rnn_gru_lstm_module_info_decorators),
     ModuleInfo(torch.nn.LSTM,
+               train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_LSTM,
                skips=(
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
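
For reference, opting a new entry into separate train / eval coverage only needs the widened module_inputs_func signature plus the train_and_eval_differ flag. The sketch below uses torch.nn.Dropout as a stand-in; it is illustrative only and does not reflect the actual Dropout entry in module_db.

def module_inputs_torch_nn_Dropout(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # A single sample is enough for illustration; real entries usually cover more shapes.
    return [
        ModuleInput(constructor_input=FunctionInput(p=0.5),
                    forward_input=FunctionInput(make_input((2, 3, 4))))]


ModuleInfo(torch.nn.Dropout,
           train_and_eval_differ=True,   # Dropout behaves differently in train vs. eval mode
           module_inputs_func=module_inputs_torch_nn_Dropout)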