Support both train / eval modes for ModuleInfo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78735
Approved by: https://github.com/albanD

This commit is contained in:
parent 79f18c1aee
commit 70d6446a3d
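In short, the commit threads a training flag through the ModuleInfo test machinery: the @modules decorator now parametrizes tests over train / eval mode, every module_inputs_func accepts a training argument, each test calls m.train(training) after constructing the module, and ModuleInfo gains a train_and_eval_differ flag plus a new test that validates it. A minimal sketch of the resulting test pattern follows; the class and test names here are illustrative, not part of the diff.

import torch
from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode

class TestModuleSketch:
    # Hypothetical example: `training` is injected by the @modules decorator.
    @modules(module_db, train_eval_mode=TrainEvalMode.train_and_eval)
    def test_example(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)  # put the module into the parametrized train / eval mode
            m(*module_input.forward_input.args, **module_input.forward_input.kwargs)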
@@ -9,7 +9,7 @@ from operator import methodcaller
 import torch
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol, skipMeta)
-from torch.testing._internal.common_modules import module_db, modules
+from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck, skipIfMps)
 from unittest.mock import patch, call

@@ -42,10 +42,10 @@ class TestModule(TestCase):

     @skipIfMps  # the test doesn't work on MPS as double types are not supported
     @modules(module_db)
-    def test_forward(self, device, dtype, module_info):
+    def test_forward(self, device, dtype, module_info, training):
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=False)
+                                                       requires_grad=False, training=training)
         dtype_to_method_caller = {
             torch.float32: methodcaller("float"),
             torch.float64: methodcaller("double"),

@@ -59,6 +59,7 @@ class TestModule(TestCase):
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
             m = module_cls(*args, **kwargs)
             m.to(device).to(dtype)
+            m.train(training)

             # === Do forward pass. ===
             args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

@@ -80,10 +81,10 @@ class TestModule(TestCase):
     # Tests passing factory kwargs (e.g. device / dtype) during module instantiation.
     # They should be applied to any created parameters and buffers.
     @modules(module_db)
-    def test_factory_kwargs(self, device, dtype, module_info):
+    def test_factory_kwargs(self, device, dtype, module_info, training):
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=False)
+                                                       requires_grad=False, training=training)
         for module_input in module_inputs:
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

@@ -96,6 +97,7 @@ class TestModule(TestCase):
             register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
             with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
                 m = module_cls(*args, **kwargs)
+                m.train(training)

             # Check if a parameter or buffer was created with a tensor not passed to the constructor.
             constructor_tensors = get_tensors_from(args, kwargs)

@@ -122,6 +124,7 @@ class TestModule(TestCase):
                 uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
                 with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
                     m = module_cls(*args, **kwargs)
+                    m.train(training)
                     uninit_param_new.mock.assert_has_calls(
                         [call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
                     uninit_buffer_new.mock.assert_has_calls(

@@ -130,16 +133,17 @@ class TestModule(TestCase):
             # Check device placement and dtype for created parameters and buffers.
             # Only verify floating point dtypes since that's what the kwarg applies to.
             m = module_cls(*args, **kwargs)
+            m.train(training)
             self._assert_module_parameters_and_buffer_are(m, device, dtype)

     @onlyCUDA
     @modules(module_db)
-    def test_multiple_device_transfer(self, device, dtype, module_info):
+    def test_multiple_device_transfer(self, device, dtype, module_info, training):
         module_cls = module_info.module_cls
         module_inputs_device = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                              requires_grad=False)
+                                                              requires_grad=False, training=training)
         module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
-                                                           requires_grad=False)
+                                                           requires_grad=False, training=training)
         for module_input_device, module_input_cpu in zip(module_inputs_device, module_inputs_cpu):
             if module_input_device.forward_input is None:
                 continue

@@ -149,6 +153,7 @@ class TestModule(TestCase):
             args, kwargs = module_input_device.constructor_input.args, module_input_device.constructor_input.kwargs
             m = module_cls(*args, **kwargs)
             m.to(device).to(dtype)
+            m.train(training)

             # === Do forward pass on GPU ===
             input_device_args = module_input_device.forward_input.args

@@ -189,14 +194,16 @@ class TestModule(TestCase):

     @modules(module_db)
-    def test_repr(self, device, dtype, module_info):
+    def test_repr(self, device, dtype, module_info, training):
         # Test module can be represented with repr and str without errors.
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=False)
+                                                       requires_grad=False, training=training)
         for module_input in module_inputs:
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
             m = module_cls(*args, **kwargs)
             m.to(device).to(dtype)
+            m.train(training)

             # Check that these methods do not raise errors
             m.__repr__()

@@ -204,11 +211,11 @@ class TestModule(TestCase):

     @skipIfMps
     @modules(module_db)
-    def test_pickle(self, device, dtype, module_info):
+    def test_pickle(self, device, dtype, module_info, training):
         # Test that module can be pickled and unpickled.
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=False)
+                                                       requires_grad=False, training=training)
         for module_input in module_inputs:
             if module_input.forward_input is None:
                 continue

@@ -220,6 +227,7 @@ class TestModule(TestCase):
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
             m = module_cls(*args, **kwargs)
             m.to(device).to(dtype)
+            m.train(training)

             # === Do forward pass. ===
             args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

@@ -233,15 +241,15 @@ class TestModule(TestCase):
             output_from_copy = m_copy(*args, **kwargs)
             self.assertEqual(output, output_from_copy)

+    @skipMeta
     @modules([module_info for module_info in module_db
               if 'inplace' in signature(module_info.module_cls).parameters])
-    @skipMeta
-    def test_check_inplace(self, device, dtype, module_info):
+    def test_check_inplace(self, device, dtype, module_info, training):
         # Check if the inplace variant of the module gives the same result as the out of place
         # variant.
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=True)
+                                                       requires_grad=True, training=training)
         for module_input in module_inputs:
             if module_input.forward_input is None:
                 continue
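The @modules filter above keeps only modules whose constructor exposes an inplace argument; inspect.signature on the class reports the __init__ parameters. A quick standalone check (not part of the diff):

from inspect import signature
import torch

print('inplace' in signature(torch.nn.ReLU).parameters)    # True  -> gets the inplace test
print('inplace' in signature(torch.nn.Linear).parameters)  # False -> filtered out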
@@ -250,8 +258,10 @@ class TestModule(TestCase):
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
             m_op = module_cls(*args, **kwargs, inplace=False)
             m_op.to(device).to(dtype)
+            m_op.train(training)
             m_inplace = module_cls(*args, **kwargs, inplace=True)
             m_inplace.to(device).to(dtype)
+            m_inplace.train(training)

             # === Inplace modules only supports inplace operations on the first argument ===
             input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

@@ -315,12 +325,12 @@ class TestModule(TestCase):

     @skipIfMps
     @modules(module_db)
-    def test_non_contiguous_tensors(self, device, dtype, module_info):
+    def test_non_contiguous_tensors(self, device, dtype, module_info, training):
         # Check modules work with non-contiguous tensors

         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=True)
+                                                       requires_grad=True, training=training)

         def _make_non_contiguous(obj):
             def inner_make_non_contiguous(obj):

@@ -357,6 +367,7 @@ class TestModule(TestCase):
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
             m = module_cls(*args, **kwargs)
             m.to(device).to(dtype)
+            m.train(training)

             self._retain_grad((input_args, input_kwargs))

@@ -409,11 +420,11 @@ class TestModule(TestCase):
                     self.assertEqual(param_grad, default_param_grad)


-    def _test_gradients_helper(self, device, dtype, module_info, check):
+    def _test_gradients_helper(self, device, dtype, module_info, training, check):
         # Check gradients
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=True)
+                                                       requires_grad=True, training=training)
         # === Set nondet tol for gradcheck to user-defined value if on CUDA and cudNN is enabled
         gradcheck_nondet_tol = 0.0
         if (torch.device(device).type == 'cuda' and torch.backends.cudnn.enabled):

@@ -427,6 +438,7 @@ class TestModule(TestCase):
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
             m = module_cls(*args, **kwargs)
             m.to(device).to(dtype)
+            m.train(training)

             params = tuple(m.parameters())

@@ -464,23 +476,33 @@ class TestModule(TestCase):


     @modules(module_db, allowed_dtypes=[torch.double])
-    def test_grad(self, device, dtype, module_info):
-        self._test_gradients_helper(device, dtype, module_info, gradcheck)
+    def test_grad(self, device, dtype, module_info, training):
+        self._test_gradients_helper(device, dtype, module_info, training, gradcheck)

     @modules([m for m in module_db if m.supports_gradgrad],
              allowed_dtypes=[torch.double])
-    def test_gradgrad(self, device, dtype, module_info):
-        self._test_gradients_helper(device, dtype, module_info, gradgradcheck)
+    def test_gradgrad(self, device, dtype, module_info, training):
+        self._test_gradients_helper(device, dtype, module_info, training, gradgradcheck)

     @onlyCUDA
     @toleranceOverride({torch.float32: tol(5e-2, 0),
                         torch.float64: tol(4e-4, 0)})
     @modules(module_db)
-    def test_cpu_gpu_parity(self, device, dtype, module_info):
+    def test_cpu_gpu_parity(self, device, dtype, module_info, training):
+        # TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a
+        # nicer way for eval mode only.
+        # See https://github.com/pytorch/pytorch/issues/79161
+        rnn_modules = set([torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM])
+        if (module_info.module_cls in rnn_modules
+                and not training
+                and 'cuda' in device
+                and torch.backends.cudnn.enabled):
+            return
+
         # Test cpu and gpu results are the same
         module_cls = module_info.module_cls
         module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
-                                                           requires_grad=True)
+                                                           requires_grad=True, training=training)

         def _to_device(obj):
             if isinstance(obj, torch.Tensor):

@@ -495,7 +517,6 @@ class TestModule(TestCase):
             return deepcopy(obj)

         for module_input in module_inputs_cpu:
-
             # === Move input from cpu to device ===
             cpu_forward_args = module_input.forward_input.args
             cpu_forward_kwargs = module_input.forward_input.kwargs

@@ -508,7 +529,9 @@ class TestModule(TestCase):
             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

             cpu_module = module_cls(*args, **kwargs).to(dtype).to("cpu")
+            cpu_module.train(training)
             gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)
+            gpu_module.train(training)

             for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
                 gpu_p.data.copy_(cpu_p)

@@ -549,10 +572,10 @@ class TestModule(TestCase):

     @skipIfMps
     @modules(module_db)
-    def test_memory_format(self, device, dtype, module_info):
+    def test_memory_format(self, device, dtype, module_info, training):
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=False)
+                                                       requires_grad=False, training=training)
         module_memformat_affects_out = module_info.module_memformat_affects_out

         def _get_mem_formats(channels_last=False, channels_last_3d=False):

@@ -613,6 +636,7 @@ class TestModule(TestCase):

             m = module_cls(*args, **kwargs)
             m.to(device).to(dtype)
+            m.train(training)

             # === Get output in (contiguous, contiguous) configuration. ===
             args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

@@ -640,6 +664,42 @@ class TestModule(TestCase):
             # === Check mem format of output. ===
             _check_out_mem_format(outputs, input_mem_format, module_mem_format)

+    # Test whether train and eval modes differ for each module. Use to verify
+    # that the ModuleInfo entry flag is correct.
+    @skipIfMps  # the test doesn't work on MPS as double types are not supported
+    @modules(module_db, train_eval_mode=TrainEvalMode.train_only)
+    def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training):
+        module_cls = module_info.module_cls
+        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
+                                                       requires_grad=False, training=training)
+
+        # Run forward inputs through to see if the training flag is accessed during forward.
+        for module_input in module_inputs:
+            if module_input.forward_input is None:
+                continue
+
+            # === Instantiate the module. ===
+            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
+            m = module_cls(*args, **kwargs)
+            m.to(device).to(dtype)
+            m.train(training)
+
+            # Remove training attribute and see if forward still works.
+            delattr(m, 'training')
+
+            # === Do forward pass. ===
+            try:
+                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
+                m(*args, **kwargs)
+            except AttributeError as e:
+                if "'training'" in str(e):
+                    self.assertTrue(module_info.train_and_eval_differ,
+                                    f"The ModuleInfo entry for {module_info.name} has "
+                                    "train_and_eval_differ=False, but the training mode was found to "
+                                    "affect the forward pass. Consider setting train_and_eval_differ=True "
+                                    "for this ModuleInfo entry.")
+                else:
+                    raise e
+

 instantiate_device_type_tests(TestModule, globals())
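A note on the delattr(m, 'training') trick in the new test: nn.Module.__init__ sets training as an instance attribute, so once it is deleted any self.training lookup inside forward falls through to nn.Module.__getattr__ and raises an AttributeError that names 'training'. That is how the test detects whether a module's forward actually consults the flag. A standalone sketch, using Dropout and ReLU purely as examples:

import torch

m = torch.nn.Dropout(p=0.5)
delattr(m, 'training')        # remove the instance attribute set in nn.Module.__init__
try:
    m(torch.randn(3))         # Dropout's forward reads self.training ...
except AttributeError as e:
    print(e)                  # ... so this reports "... has no attribute 'training'"

relu = torch.nn.ReLU()
delattr(relu, 'training')
relu(torch.randn(3))          # ReLU's forward never reads self.training, so this still works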
@@ -273,7 +273,7 @@ def _update_param_kwargs(param_kwargs, name, value):
     if isinstance(value, list) or isinstance(value, tuple):
         # Make name plural (e.g. devices / dtypes) if the value is composite.
         param_kwargs['{}s'.format(name)] = value
-    elif value:
+    elif value is not None:
         param_kwargs[name] = value

     # Leave param_kwargs as-is when value is None.
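This one-line change matters because the newly parametrized training value is a plain boolean: under the old truthiness check, passing False silently dropped the kwarg, so eval-mode tests would never receive training=False. A small standalone illustration, re-implemented here for clarity rather than copied from the actual helper:

def update_old(param_kwargs, name, value):
    if isinstance(value, (list, tuple)):
        param_kwargs['{}s'.format(name)] = value
    elif value:                      # falsy values such as False are dropped
        param_kwargs[name] = value

def update_new(param_kwargs, name, value):
    if isinstance(value, (list, tuple)):
        param_kwargs['{}s'.format(name)] = value
    elif value is not None:          # only None means "no value provided"
        param_kwargs[name] = value

old_kwargs, new_kwargs = {}, {}
update_old(old_kwargs, 'training', False)
update_new(new_kwargs, 'training', False)
print(old_kwargs)  # {}                  -> eval mode lost
print(new_kwargs)  # {'training': False} -> eval mode passed through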
@@ -1,6 +1,7 @@
 import torch
 import unittest
 from copy import deepcopy
+from enum import Enum
 from functools import wraps, partial
 from itertools import chain, product
 import itertools

@@ -50,12 +51,33 @@ for namespace in MODULE_NAMESPACES:
         MODULE_CLASS_NAMES[module_cls] = f'{namespace_name}.{module_name}'


+# Specifies the modes (i.e. train, eval) to test over.
+TrainEvalMode = Enum('TrainEvalMode', ('train_only', 'eval_only', 'train_and_eval'))
+
+
 class modules(_TestParametrizer):
     """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """

-    def __init__(self, module_info_list, allowed_dtypes=None):
+    def __init__(self, module_info_list, allowed_dtypes=None, train_eval_mode=TrainEvalMode.train_and_eval):
         self.module_info_list = module_info_list
         self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None
+        self.train_eval_mode = train_eval_mode
+
+    def _get_training_flags(self, module_info):
+        training_flags = []
+        if (self.train_eval_mode == TrainEvalMode.train_only or
+                self.train_eval_mode == TrainEvalMode.train_and_eval):
+            training_flags.append(True)
+
+        if (self.train_eval_mode == TrainEvalMode.eval_only or
+                self.train_eval_mode == TrainEvalMode.train_and_eval):
+            training_flags.append(False)
+
+        # If train and eval modes don't differ for the module, don't bother using more than one.
+        if not module_info.train_and_eval_differ:
+            training_flags = training_flags[:1]
+
+        return training_flags

     def _parametrize_test(self, test, generic_cls, device_cls):
         if device_cls is None:
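Given the enum and _get_training_flags above, each ModuleInfo expands into at most two training flags, and modules whose behavior does not depend on the mode are only run once. A quick sketch of the flag expansion; the FakeInfo class is a stand-in for a ModuleInfo entry and is not part of the diff:

from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode

class FakeInfo:
    def __init__(self, train_and_eval_differ):
        self.train_and_eval_differ = train_and_eval_differ

deco = modules(module_db, train_eval_mode=TrainEvalMode.train_and_eval)
print(deco._get_training_flags(FakeInfo(True)))    # [True, False] -> both modes are tested
print(deco._get_training_flags(FakeInfo(False)))   # [True]        -> a single mode is enough

eval_only = modules(module_db, train_eval_mode=TrainEvalMode.eval_only)
print(eval_only._get_training_flags(FakeInfo(True)))  # [False]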
@@ -64,18 +86,22 @@ class modules(_TestParametrizer):
                                'instantiate_parametrized_tests()')

         for module_info in self.module_info_list:
-            # Construct the test name; device / dtype parts are handled outside.
-            # See [Note: device and dtype suffix placement]
-            test_name = module_info.name.replace('.', '_')
-
             dtypes = set(module_info.dtypes)
             if self.allowed_dtypes is not None:
                 dtypes = dtypes.intersection(self.allowed_dtypes)

-            for dtype in dtypes:
+            training_flags = self._get_training_flags(module_info)
+            for (training, dtype) in product(training_flags, dtypes):
+                # Construct the test name; device / dtype parts are handled outside.
+                # See [Note: device and dtype suffix placement]
+                test_name = module_info.formatted_name
+                if len(training_flags) > 1:
+                    test_name += f"_{'train_mode' if training else 'eval_mode'}"
+
                 # Construct parameter kwargs to pass to the test.
                 param_kwargs = {'module_info': module_info}
                 _update_param_kwargs(param_kwargs, 'dtype', dtype)
+                _update_param_kwargs(param_kwargs, 'training', training)

                 try:
                     active_decorators = [set_single_threaded_if_parallel_tbb]
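With this loop, each ModuleInfo is parametrized over the cross product of training flags and dtypes, and the _train_mode / _eval_mode suffix is only appended when both modes are generated; device and dtype suffixes are added elsewhere, per the note in the code. Roughly, for a module with train_and_eval_differ=True this yields names along the following lines (the exact strings are illustrative):

from itertools import product

training_flags = [True, False]          # both modes, since train_and_eval_differ=True
dtypes = ['float32', 'float64']
for training, dtype in product(training_flags, dtypes):
    suffix = '_train_mode' if training else '_eval_mode'
    # e.g. 'nn_LSTM_train_mode' / 'nn_LSTM_eval_mode'; device/dtype parts appended later
    print('nn_LSTM' + suffix, dtype)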
@@ -106,9 +132,9 @@ class modules(_TestParametrizer):
                     raise ex


-def formatted_module_name(module_cls):
+def get_module_fully_qualified_name(module_cls):
     """ Returns the common name of the module class formatted for use in test names. """
-    return MODULE_CLASS_NAMES[module_cls].replace('.', '_')
+    return MODULE_CLASS_NAMES[module_cls]


 class FunctionInput(object):

@@ -157,6 +183,7 @@ class ModuleInfo(object):
                  gradcheck_nondet_tol=0.0,  # tolerance for nondeterminism while performing gradcheck
                  module_memformat_affects_out=False,  # whether converting module to channels last will generate
                                                       # channels last output
+                 train_and_eval_differ=False,  # whether the module has differing behavior between train and eval
                  ):
         self.module_cls = module_cls
         self.module_inputs_func = module_inputs_func

@@ -166,20 +193,21 @@ class ModuleInfo(object):
         self.supports_gradgrad = supports_gradgrad
         self.gradcheck_nondet_tol = gradcheck_nondet_tol
         self.module_memformat_affects_out = module_memformat_affects_out
+        self.train_and_eval_differ = train_and_eval_differ

     def should_skip(self, cls_name, test_name, device_type, dtype):
         return any(si.is_active(cls_name, test_name, device_type, dtype) for si in self.skips)

     @property
     def name(self):
-        return formatted_module_name(self.module_cls)
+        return get_module_fully_qualified_name(self.module_cls)

+    @property
+    def formatted_name(self):
+        return self.name.replace('.', '_')
+

-def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     module_inputs = [
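Since MODULE_CLASS_NAMES maps each class to a dotted 'namespace.ClassName' string, the renamed helper now returns that dotted form via name, while the new formatted_name property keeps the underscore form used when building test names. A hypothetical lookup; the exact strings depend on the module_db entries:

import torch
from torch.testing._internal.common_modules import module_db

info = next(mi for mi in module_db if mi.module_cls is torch.nn.BatchNorm2d)
print(info.name)            # e.g. 'nn.BatchNorm2d' -- dotted, human-readable
print(info.formatted_name)  # e.g. 'nn_BatchNorm2d' -- dots replaced for test names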
@@ -199,7 +227,7 @@ def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **k
     return module_inputs


-def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     def bilinear_reference_fn(m, p, x1, x2, bias=True):

@@ -228,7 +256,7 @@ def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, *
     return module_inputs


-def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)


@@ -263,7 +291,7 @@ def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **
     return module_inputs


-def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)


@@ -424,7 +452,7 @@ def generate_regression_criterion_inputs(make_input):
     ) for reduction in ['none', 'mean', 'sum']]


-def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [

@@ -434,7 +462,7 @@ def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad,
                     reference_fn=no_batch_dim_reference_fn)]


-def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [

@@ -443,7 +471,7 @@ def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, require
                     desc='single')]


-def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [

@@ -451,7 +479,7 @@ def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad
                     forward_input=FunctionInput(make_input((2, 3, 6, 6))))]


-def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [

@@ -459,7 +487,7 @@ def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad
                     forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))))]


-def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, training, **kwargs):
     N = kwargs['N']
     lazy = kwargs.get('lazy', False)
     transposed = kwargs.get('transposed', False)

@@ -479,7 +507,7 @@ def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, **k
     ]


-def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [

@@ -498,7 +526,7 @@ def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, **kwar
                     desc='4d_input')]


-def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [

@@ -515,7 +543,7 @@ def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, **kwa
                     reference_fn=no_batch_dim_reference_fn)]


-def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad):
+def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [

@@ -530,7 +558,7 @@ def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad):
                     desc='channels_last_3d_mem_format')]


-def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [

@@ -545,7 +573,7 @@ def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **k
                     desc='scalar')] + generate_regression_criterion_inputs(make_input)


-def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
     make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

@@ -579,7 +607,7 @@ def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires
     return samples


-def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [

@@ -596,7 +624,7 @@ def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad,
     ]


-def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [

@@ -615,7 +643,7 @@ def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad,
     ]


-def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     return [

@@ -632,7 +660,7 @@ def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, **
     ]


-def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     samples = [

@@ -681,7 +709,7 @@ def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, r
     return samples


-def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     samples = [

@@ -737,7 +765,7 @@ def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, r
     return samples


-def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     samples = []
     # Samples below are for validating the no-batch-dim support.

@@ -780,7 +808,7 @@ def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad
     return samples


-def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, training, **kwargs):
     make_empty = partial(torch.empty, device=device, dtype=torch.long, requires_grad=False)
     return [
         ModuleInput(

@@ -795,7 +823,7 @@ def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad,
     ]


-def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requires_grad, training, **kwargs):
     # Currently all samples below are for validating the no-batch-dim support.
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     samples = []

@@ -826,7 +854,7 @@ def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requir
     return samples


-def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
     # Currently all samples below are for validating the no-batch-dim support.
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     samples = [

@@ -857,7 +885,7 @@ def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_gra
     return samples


-def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs):
     # Currently all samples below are for validating the no-batch-dim support.
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     samples = (

@@ -876,7 +904,7 @@ def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, *
     return samples


-def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, **kwargs):
     # Currently all samples below are for validating the no-batch-dim support.
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     is_rnn = kwargs['is_rnn']

@@ -923,7 +951,7 @@ def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, **
     return samples


-def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, **kwargs):
+def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, training, **kwargs):
     # Currently all samples below are for validating the no-batch-dim support.
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     bias = (False, True)
@@ -1005,6 +1033,7 @@ module_db: List[ModuleInfo] = [
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
                ),
     ModuleInfo(torch.nn.BatchNorm2d,
+               train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_BatchNorm2d,
                skips=(
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),

@@ -1013,6 +1042,7 @@ module_db: List[ModuleInfo] = [
                    DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),)
                ),
     ModuleInfo(torch.nn.BatchNorm3d,
+               train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_BatchNorm3d,
                skips=(
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),

@@ -1273,6 +1303,7 @@ module_db: List[ModuleInfo] = [
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
                supports_gradgrad=False),
     ModuleInfo(torch.nn.TransformerEncoderLayer,
+               train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer,
                skips=(
                    # No channels_last support for TransformerEncoderLayer currently.

@@ -1294,6 +1325,7 @@ module_db: List[ModuleInfo] = [
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
                ),
     ModuleInfo(torch.nn.MultiheadAttention,
+               train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_MultiheadAttention,
                skips=(
                    # No channels_last support for MultiheadAttention currently.

@@ -1332,17 +1364,20 @@ module_db: List[ModuleInfo] = [
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
                ),
     ModuleInfo(torch.nn.RNN,
+               train_and_eval_differ=True,
                module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=True),
                skips=(
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
                decorators=rnn_gru_lstm_module_info_decorators
                ),
     ModuleInfo(torch.nn.GRU,
+               train_and_eval_differ=True,
                module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=False),
                skips=(
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
                decorators=rnn_gru_lstm_module_info_decorators),
     ModuleInfo(torch.nn.LSTM,
+               train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_LSTM,
                skips=(
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),