pytorch/test/test_module_init.py
Philip Meier bc73affdad prepare removal of deprecated functionality in torch.testing (#87969)
_Redo of #86586 with all BC breaking changes granularly placed into separate commits._

---

Per title. Deprecation happened on Feb 25, 2022 in c6f1bbc0ac, which made it into the 1.12 release. Since it is now 245 days later and the next release will be 1.14, the removals later in the stack comply with the [BC policy](https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#minimizing-the-disruption-of-bc-breaking-changes).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87969
Approved by: https://github.com/mruberry
2022-11-02 14:04:48 +00:00

535 lines
24 KiB
Python

# Owner(s): ["module: nn"]
import inspect
import torch
from unittest import mock
from unittest.mock import MagicMock, patch
from torch.testing._internal.common_dtype import floating_types
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_utils import TestCase, run_tests
# Returns a database of args & kwargs that can be used to construct each module.
# Each entry is in class -> (args, kwargs) format.
# Example: torch.nn.Linear -> ([10, 5], {})
# TODO: Merge this in with the initial ModuleInfo implementation.
def build_constructor_arg_db():
return {
torch.nn.AdaptiveAvgPool1d: ((5,), {}),
torch.nn.AdaptiveAvgPool2d: ((5,), {}),
torch.nn.AdaptiveAvgPool3d: ((5,), {}),
torch.nn.AdaptiveLogSoftmaxWithLoss: ((100, 20, [5, 10, 15]), {}),
torch.nn.AdaptiveMaxPool1d: ((5,), {}),
torch.nn.AdaptiveMaxPool2d: ((5,), {}),
torch.nn.AdaptiveMaxPool3d: ((5,), {}),
torch.nn.AlphaDropout: ((), {}),
torch.nn.AvgPool1d: ((3,), {}),
torch.nn.AvgPool2d: ((3,), {}),
torch.nn.AvgPool3d: ((3,), {}),
torch.nn.BCELoss: ((), {}),
torch.nn.BCEWithLogitsLoss: ((), {}),
torch.nn.BatchNorm1d: ((5,), {}),
torch.nn.BatchNorm2d: ((5,), {}),
torch.nn.BatchNorm3d: ((5,), {}),
torch.nn.Bilinear: ((2, 3, 4), {}),
torch.nn.CELU: ((), {}),
torch.nn.CTCLoss: ((), {}),
torch.nn.ChannelShuffle: ((4,), {}),
torch.nn.ConstantPad1d: ((2, 3.5), {}),
torch.nn.ConstantPad2d: ((2, 3.5), {}),
torch.nn.ConstantPad3d: ((2, 3.5), {}),
torch.nn.Conv1d: ((3, 3, 3), {}),
torch.nn.Conv2d: ((3, 3, 3), {}),
torch.nn.Conv3d: ((3, 3, 3), {}),
torch.nn.ConvTranspose1d: ((3, 3, 3), {}),
torch.nn.ConvTranspose2d: ((3, 3, 3), {}),
torch.nn.ConvTranspose3d: ((3, 3, 3), {}),
torch.nn.CosineEmbeddingLoss: ((), {}),
torch.nn.CosineSimilarity: ((), {}),
torch.nn.CrossEntropyLoss: ((), {}),
torch.nn.CrossMapLRN2d: ((5,), {}),
torch.nn.Dropout1d: ((), {}),
torch.nn.Dropout2d: ((), {}),
torch.nn.Dropout3d: ((), {}),
torch.nn.Dropout: ((), {}),
torch.nn.ELU: ((), {}),
torch.nn.Embedding: ((10, 5), {}),
torch.nn.EmbeddingBag: ((10, 5), {}),
torch.nn.FeatureAlphaDropout: ((), {}),
torch.nn.Flatten: ((), {}),
torch.nn.Fold: ((5, 2), {}),
torch.nn.FractionalMaxPool2d: ((5, 2), {}),
torch.nn.FractionalMaxPool3d: ((5, 2), {}),
torch.nn.GELU: ((), {}),
torch.nn.GLU: ((), {}),
torch.nn.GRU: ((5, 10), {}),
torch.nn.GRUCell: ((5, 10), {}),
torch.nn.GaussianNLLLoss: ((), {}),
torch.nn.GroupNorm: ((3, 6, 1e-5, True), {}),
torch.nn.Hardshrink: ((), {}),
torch.nn.Hardsigmoid: ((), {}),
torch.nn.Hardswish: ((), {}),
torch.nn.Hardtanh: ((), {}),
torch.nn.HingeEmbeddingLoss: ((), {}),
torch.nn.HuberLoss: ((), {}),
torch.nn.Identity: ((), {}),
torch.nn.InstanceNorm1d: ((5, 1e-5, 0.1, True), {}),
torch.nn.InstanceNorm2d: ((5, 1e-5, 0.1, True), {}),
torch.nn.InstanceNorm3d: ((5, 1e-5, 0.1, True), {}),
torch.nn.KLDivLoss: ((), {}),
torch.nn.L1Loss: ((), {}),
torch.nn.LPPool1d: ((2, 3), {}),
torch.nn.LPPool2d: ((2, 3), {}),
torch.nn.LSTM: ((5, 10), {}),
torch.nn.LSTMCell: ((5, 10), {}),
torch.nn.LayerNorm: ((2,), {}),
torch.nn.LazyBatchNorm1d: ((), {}),
torch.nn.LazyBatchNorm2d: ((), {}),
torch.nn.LazyBatchNorm3d: ((), {}),
torch.nn.LazyConv1d: ((5, 2), {}),
torch.nn.LazyConv2d: ((5, 2), {}),
torch.nn.LazyConv3d: ((5, 2), {}),
torch.nn.LazyConvTranspose1d: ((5, 2), {}),
torch.nn.LazyConvTranspose2d: ((5, 2), {}),
torch.nn.LazyConvTranspose3d: ((5, 2), {}),
torch.nn.LazyInstanceNorm1d: ((), {}),
torch.nn.LazyInstanceNorm2d: ((), {}),
torch.nn.LazyInstanceNorm3d: ((), {}),
torch.nn.LazyLinear: ((5,), {}),
torch.nn.LeakyReLU: ((), {}),
torch.nn.Linear: ((10, 5), {}),
torch.nn.LocalResponseNorm: ((2,), {}),
torch.nn.LogSigmoid: ((), {}),
torch.nn.LogSoftmax: ((), {}),
torch.nn.MSELoss: ((), {}),
torch.nn.MarginRankingLoss: ((), {}),
torch.nn.MaxPool1d: ((3,), {}),
torch.nn.MaxPool2d: ((3,), {}),
torch.nn.MaxPool3d: ((3,), {}),
torch.nn.MaxUnpool1d: ((5,), {}),
torch.nn.MaxUnpool2d: ((5,), {}),
torch.nn.MaxUnpool3d: ((5,), {}),
torch.nn.Mish: ((), {}),
torch.nn.ModuleDict: ((), {}),
torch.nn.ModuleList: ((), {}),
torch.nn.MultiLabelMarginLoss: ((), {}),
torch.nn.MultiLabelSoftMarginLoss: ((), {}),
torch.nn.MultiMarginLoss: ((), {}),
torch.nn.MultiheadAttention: ((100, 2), {}),
torch.nn.NLLLoss2d: ((), {}),
torch.nn.NLLLoss: ((), {}),
torch.nn.PReLU: ((), {}),
torch.nn.PairwiseDistance: ((), {}),
torch.nn.ParameterDict: ((), {}),
torch.nn.ParameterList: ((), {}),
torch.nn.PixelShuffle: ((2,), {}),
torch.nn.PixelUnshuffle: ((2,), {}),
torch.nn.PoissonNLLLoss: ((), {}),
torch.nn.RNN: ((5, 10), {}),
torch.nn.RNNBase: (('LSTM', 5, 10), {}),
torch.nn.RNNCell: ((5, 10), {}),
torch.nn.RNNCellBase: ((5, 10, True, 2), {}),
torch.nn.RReLU: ((), {}),
torch.nn.ReLU6: ((), {}),
torch.nn.ReLU: ((), {}),
torch.nn.ReflectionPad1d: ((2,), {}),
torch.nn.ReflectionPad2d: ((2,), {}),
torch.nn.ReflectionPad3d: ((2,), {}),
torch.nn.ReplicationPad1d: ((2,), {}),
torch.nn.ReplicationPad2d: ((2,), {}),
torch.nn.ReplicationPad3d: ((2,), {}),
torch.nn.SELU: ((), {}),
torch.nn.Sequential: ((), {}),
torch.nn.SiLU: ((), {}),
torch.nn.Sigmoid: ((), {}),
torch.nn.SmoothL1Loss: ((), {}),
torch.nn.SoftMarginLoss: ((), {}),
torch.nn.Softmax2d: ((), {}),
torch.nn.Softmax: ((), {}),
torch.nn.Softmin: ((), {}),
torch.nn.Softplus: ((), {}),
torch.nn.Softshrink: ((), {}),
torch.nn.Softsign: ((), {}),
torch.nn.SyncBatchNorm: ((5,), {}),
torch.nn.Tanh: ((), {}),
torch.nn.Tanhshrink: ((), {}),
torch.nn.Threshold: ((0.1, 20), {}),
torch.nn.Transformer: ((), {}),
torch.nn.TransformerDecoder: ((torch.nn.TransformerDecoderLayer, 3), {}),
torch.nn.TransformerDecoderLayer: ((10, 2), {}),
torch.nn.TransformerEncoder: ((torch.nn.TransformerEncoderLayer, 3), {}),
torch.nn.TransformerEncoderLayer: ((10, 2), {}),
torch.nn.TripletMarginLoss: ((), {}),
torch.nn.TripletMarginWithDistanceLoss: ((), {}),
torch.nn.Unflatten: ((1, (2, 5, 5)), {}),
torch.nn.Unfold: ((3,), {}),
torch.nn.Upsample: ((), {}),
torch.nn.UpsamplingBilinear2d: ((), {}),
torch.nn.UpsamplingNearest2d: ((), {}),
torch.nn.ZeroPad2d: ((0,), {}),
torch.ao.nn.qat.Conv1d: ((3, 3, 3), {
'qconfig': torch.ao.quantization.default_qconfig,
}),
torch.ao.nn.qat.Conv2d: ((3, 3, 3), {
'qconfig': torch.ao.quantization.default_qconfig,
}),
torch.ao.nn.qat.Conv3d: ((3, 3, 3), {
'qconfig': torch.ao.quantization.default_qconfig,
}),
torch.ao.nn.qat.Linear: ((5, 2), {
'qconfig': torch.ao.quantization.default_qconfig,
}),
torch.ao.nn.qat.Embedding: ((10, 12), {
'qconfig': torch.ao.quantization.float_qparams_weight_only_qconfig,
}),
torch.ao.nn.qat.EmbeddingBag: ((10, 12), {
'qconfig': torch.ao.quantization.float_qparams_weight_only_qconfig,
}),
torch.nn.quantizable.LSTM: ((5, 6), {}),
torch.nn.quantizable.LSTMCell: ((5, 6), {}),
torch.nn.quantizable.MultiheadAttention: ((10, 2), {}),
torch.ao.nn.quantized.BatchNorm2d: ((2,), {}),
torch.ao.nn.quantized.BatchNorm3d: ((2,), {}),
torch.ao.nn.quantized.Dropout: ((), {}),
torch.ao.nn.quantized.Conv1d: ((3, 3, 3), {}),
torch.ao.nn.quantized.Conv2d: ((3, 3, 3), {}),
torch.ao.nn.quantized.Conv3d: ((3, 3, 3), {}),
torch.ao.nn.quantized.ConvTranspose1d: ((3, 3, 3), {}),
torch.ao.nn.quantized.ConvTranspose2d: ((3, 3, 3), {}),
torch.ao.nn.quantized.ConvTranspose3d: ((16, 33, (3, 3, 5)), {
'stride': (2, 1, 1),
'padding': (4, 2, 2),
'output_padding': (2, 2, 2),
'dilation': (1, 1, 1),
}),
torch.ao.nn.quantized.DeQuantize: ((), {}),
torch.ao.nn.quantized.ELU: ((0.01, 0), {}),
torch.ao.nn.quantized.Embedding: ((10, 3), {
'factory_kwargs': {},
}),
torch.ao.nn.quantized.EmbeddingBag: ((10, 3), {
'factory_kwargs': {},
}),
torch.ao.nn.quantized.GroupNorm: ((2, 4, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.ao.nn.quantized.Hardswish: ((0.1, 0,), {}),
torch.ao.nn.quantized.InstanceNorm1d: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.ao.nn.quantized.InstanceNorm2d: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.ao.nn.quantized.InstanceNorm3d: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.ao.nn.quantized.LayerNorm: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.ao.nn.quantized.LeakyReLU: ((0.01, 0), {}),
torch.ao.nn.quantized.Linear: ((5, 2), {
'factory_kwargs': {},
}),
torch.ao.nn.quantized.MaxPool2d: ((3,), {}),
torch.ao.nn.quantized.Quantize: ((0.1, 0), {
'dtype': torch.int16,
'factory_kwargs': {},
}),
torch.ao.nn.quantized.ReLU6: ((), {}),
torch.ao.nn.quantized.Sigmoid: ((0.1, 0), {}),
torch.ao.nn.quantized.Softmax: ((), {}),
torch.ao.nn.quantized.FloatFunctional: ((), {}),
torch.ao.nn.quantized.FXFloatFunctional: ((), {}),
torch.ao.nn.quantized.QFunctional: ((), {}),
# Remove torch.nn.quantized after the migration completes:
torch.nn.qat.Conv1d: ((3, 3, 3), {
'qconfig': torch.ao.quantization.default_qconfig,
}),
torch.nn.qat.Conv2d: ((3, 3, 3), {
'qconfig': torch.ao.quantization.default_qconfig,
}),
torch.nn.qat.Conv3d: ((3, 3, 3), {
'qconfig': torch.ao.quantization.default_qconfig,
}),
torch.nn.qat.Linear: ((5, 2), {
'qconfig': torch.ao.quantization.default_qconfig,
}),
torch.nn.qat.Embedding: ((10, 12), {
'qconfig': torch.ao.quantization.float_qparams_weight_only_qconfig,
}),
torch.nn.qat.EmbeddingBag: ((10, 12), {
'qconfig': torch.ao.quantization.float_qparams_weight_only_qconfig,
}),
torch.nn.quantized.BatchNorm2d: ((2,), {}),
torch.nn.quantized.BatchNorm3d: ((2,), {}),
torch.nn.quantized.Dropout: ((), {}),
torch.nn.quantized.Conv1d: ((3, 3, 3), {}),
torch.nn.quantized.Conv2d: ((3, 3, 3), {}),
torch.nn.quantized.Conv3d: ((3, 3, 3), {}),
torch.nn.quantized.ConvTranspose1d: ((3, 3, 3), {}),
torch.nn.quantized.ConvTranspose2d: ((3, 3, 3), {}),
torch.nn.quantized.ConvTranspose3d: ((16, 33, (3, 3, 5)), {
'stride': (2, 1, 1),
'padding': (4, 2, 2),
'output_padding': (2, 2, 2),
'dilation': (1, 1, 1),
}),
torch.nn.quantized.DeQuantize: ((), {}),
torch.nn.quantized.ELU: ((0.01, 0), {}),
torch.nn.quantized.Embedding: ((10, 3), {
'factory_kwargs': {},
}),
torch.nn.quantized.EmbeddingBag: ((10, 3), {
'factory_kwargs': {},
}),
torch.nn.quantized.GroupNorm: ((2, 4, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.nn.quantized.Hardswish: ((0.1, 0,), {}),
torch.nn.quantized.InstanceNorm1d: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.nn.quantized.InstanceNorm2d: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.nn.quantized.InstanceNorm3d: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.nn.quantized.LayerNorm: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.nn.quantized.LeakyReLU: ((0.01, 0), {}),
torch.nn.quantized.Linear: ((5, 2), {
'factory_kwargs': {},
}),
torch.nn.quantized.MaxPool2d: ((3,), {}),
torch.nn.quantized.PReLU: ((0.01, 0), {}),
torch.nn.quantized.Quantize: ((0.1, 0), {
'dtype': torch.int16,
'factory_kwargs': {},
}),
torch.nn.quantized.ReLU6: ((), {}),
torch.nn.quantized.Sigmoid: ((0.1, 0), {}),
torch.nn.quantized.Softmax: ((), {}),
torch.nn.quantized.FloatFunctional: ((), {}),
torch.nn.quantized.FXFloatFunctional: ((), {}),
torch.nn.quantized.QFunctional: ((), {}),
}
# Instantiates the given class with the given args, kwargs, optionally on a given device.
def instantiate_class(cls, args, kwargs, extra_kwargs):
return cls(*args, **kwargs) if extra_kwargs is None else cls(*args, **kwargs, **extra_kwargs)
# Returns a function that calls the real implementation of a method
# in addition to passing args to a mock object.
def mock_wrapper(method):
mock = MagicMock()
def wrapper(self, *args, **kwargs):
mock(*args, **kwargs)
return method(self, *args, **kwargs)
wrapper.mock = mock
return wrapper
# Returns a set of args / kwargs that can be used to construct the module.
def get_example_args(module_cls, constructor_arg_db, extra_kwargs=None):
assert module_cls in constructor_arg_db, \
f"No entry for {module_cls} in the constructor arg DB. Please add it to pass these tests."
args, kwargs = constructor_arg_db[module_cls]
extra_kwargs = {} if extra_kwargs is None else extra_kwargs
# Recursively instantiate args / kwargs that are class objects.
args = [instantiate_class(arg, *get_example_args(arg, constructor_arg_db), extra_kwargs=extra_kwargs)
if inspect.isclass(arg) else torch.nn.Parameter(arg.to(**extra_kwargs))
if isinstance(arg, torch.nn.Parameter) else arg for arg in args]
kwargs = {k: instantiate_class(v, *get_example_args(v, constructor_arg_db), extra_kwargs=extra_kwargs)
if inspect.isclass(v) else torch.nn.Parameter(v.to(*extra_kwargs))
if isinstance(v, torch.nn.Parameter) else v for k, v in kwargs.items()}
kwargs.update(extra_kwargs)
return args, kwargs
def generate_test_func(test_cls, module_cls, constructor_arg_db,
verify_kwargs=True, module_is_lazy=False, check_nonexistent_arg=True):
# Generate a function for testing the given module.
@dtypes(*floating_types())
def run_test(test_cls, device, dtype, module_cls=module_cls):
# Check if this module creates parameters or registers buffers.
# The mock magic here passes through to the real Parameter / register_buffer
# logic and is only used to check for calls.
args, kwargs = get_example_args(module_cls, constructor_arg_db)
# Some modules need to pass factory_kwargs so as not to conflict with existing args such as dtype.
module_needs_factory_kwargs = 'factory_kwargs' in kwargs
if module_needs_factory_kwargs:
del kwargs['factory_kwargs']
extra_kwargs = {
'factory_kwargs': {
'device': device,
'dtype': dtype,
}
}
else:
extra_kwargs = {
'device': device,
'dtype': dtype,
}
parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
with patch.object(torch.nn.Parameter, '__new__', parameter_new):
register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
m = module_cls(*args, **kwargs)
module_creates_params_or_buffers = parameter_new.mock.called or register_buffer.mock.called
# == Verify factory kwargs are supported. ==
if verify_kwargs and module_creates_params_or_buffers:
args, kwargs = get_example_args(module_cls, constructor_arg_db,
extra_kwargs=extra_kwargs)
if module_is_lazy:
# Ensure device and dtype are passed to all UninitializedParameters and UninitializedBuffers.
uninit_param_new = mock_wrapper(torch.nn.UninitializedParameter.__new__)
with patch.object(torch.nn.UninitializedParameter, '__new__', uninit_param_new):
uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
m = module_cls(*args, **kwargs)
uninit_param_new.mock.assert_has_calls(
[mock.call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
uninit_buffer_new.mock.assert_has_calls(
[mock.call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
else:
# Check device placement and dtype for parameters and buffers.
# Only verify floating point dtypes since that's what the kwarg applies to.
# Note that dtype verification is also skipped if the module requires factory_kwargs.
m = module_cls(*args, **kwargs)
for name, param in m.named_parameters():
test_cls.assertEqual(
str(param.device), device,
f'Parameter {name} is on {param.device.type} instead of the expected device {device}')
if param.dtype.is_floating_point and not module_needs_factory_kwargs:
test_cls.assertEqual(
param.dtype, dtype,
f'Parameter {name} is of dtype {param.dtype} instead of the expected dtype {dtype}')
for name, buffer in m.named_buffers():
test_cls.assertEqual(
str(buffer.device), device,
f'Buffer {name} is on {buffer.device.type} instead of the expected device {device}')
if buffer.dtype.is_floating_point and not module_needs_factory_kwargs:
test_cls.assertEqual(
buffer.dtype, dtype,
f'Buffer {name} is of dtype {buffer.dtype} instead of the expected dtype {dtype}')
# == Verify passing a nonexistent arg errors out. ==
if check_nonexistent_arg:
with test_cls.assertRaises(TypeError):
m = module_cls(*args, **kwargs, nonexistent_arg='foo')
return run_test
def generate_tests(test_cls, constructor_arg_db):
# test all modules underneath these namespaces...
NAMESPACES = [
torch.nn,
torch.ao.nn.qat,
torch.ao.nn.quantized,
torch.nn.qat,
torch.nn.quantizable,
torch.nn.quantized,
]
# ...except these
MODULES_TO_SKIP = {
torch.nn.Module,
torch.nn.Container, # deprecated
torch.nn.NLLLoss2d, # deprecated
# TODO: Remove these 4 from this list once the ASan issue is fixed.
# See https://github.com/pytorch/pytorch/issues/55396
torch.ao.nn.quantized.Embedding,
torch.ao.nn.quantized.EmbeddingBag,
torch.nn.quantized.Embedding,
torch.nn.quantized.EmbeddingBag,
torch.nn.quantized.LSTM,
torch.nn.quantized.MultiheadAttention,
}
# no need to support kwargs for these modules even though
# they have parameters / buffers because they are passed in
# already instantiated s
MODULES_WITHOUT_KWARGS_SUPPORT = {
torch.nn.BCELoss,
torch.nn.BCEWithLogitsLoss,
torch.nn.CrossEntropyLoss,
torch.nn.FractionalMaxPool2d,
torch.nn.FractionalMaxPool3d,
torch.nn.MultiLabelSoftMarginLoss,
torch.nn.MultiMarginLoss,
torch.nn.NLLLoss,
torch.nn.TransformerDecoder,
torch.nn.TransformerEncoder,
}
# modules that supported kwargs before
MODULES_WITH_PREVIOUS_KWARGS = {
torch.nn.Identity,
}
# lazy modules don't instantiate parameters right away
LAZY_MODULES = {
torch.nn.LazyBatchNorm1d,
torch.nn.LazyBatchNorm2d,
torch.nn.LazyBatchNorm3d,
torch.nn.LazyConv1d,
torch.nn.LazyConv2d,
torch.nn.LazyConv3d,
torch.nn.LazyConvTranspose1d,
torch.nn.LazyConvTranspose2d,
torch.nn.LazyConvTranspose3d,
torch.nn.LazyConvTranspose3d,
torch.nn.LazyInstanceNorm1d,
torch.nn.LazyInstanceNorm2d,
torch.nn.LazyInstanceNorm3d,
torch.nn.LazyLinear,
}
# these modules requires FBGEMM backend to instantiate
MODULES_THAT_REQUIRE_FBGEMM = {
torch.ao.nn.quantized.Conv1d,
torch.ao.nn.quantized.Conv2d,
torch.ao.nn.quantized.Conv3d,
torch.ao.nn.quantized.ConvTranspose1d,
torch.ao.nn.quantized.ConvTranspose2d,
torch.ao.nn.quantized.ConvTranspose3d,
torch.ao.nn.quantized.Linear,
# Remove the lines below after AO migration is complete
torch.nn.quantized.Conv1d,
torch.nn.quantized.Conv2d,
torch.nn.quantized.Conv3d,
torch.nn.quantized.ConvTranspose1d,
torch.nn.quantized.ConvTranspose2d,
torch.nn.quantized.ConvTranspose3d,
torch.nn.quantized.Linear,
}
for namespace in NAMESPACES:
# the "nn" in "torch.nn"
namespace_basename = namespace.__name__.split('.')[-1]
for module_name in namespace.modules.__all__:
# class object for this module (e.g. torch.nn.Linear)
module_cls = getattr(namespace.modules, module_name)
if module_cls in MODULES_TO_SKIP:
continue
verify_kwargs = module_cls not in MODULES_WITHOUT_KWARGS_SUPPORT
module_is_lazy = module_cls in LAZY_MODULES
check_nonexistent_arg = module_cls not in MODULES_WITH_PREVIOUS_KWARGS
# Generate a function for testing this module and setattr it onto the test class.
run_test = generate_test_func(test_cls, module_cls, constructor_arg_db,
verify_kwargs=verify_kwargs,
module_is_lazy=module_is_lazy,
check_nonexistent_arg=check_nonexistent_arg)
test_name = f'test_{namespace_basename}_{module_name}'
if module_cls in MODULES_THAT_REQUIRE_FBGEMM:
run_test = skipIfNoFBGEMM(run_test)
setattr(TestModuleInit, test_name, run_test)
class TestModuleInit(TestCase):
_ignore_not_implemented_error = False
generate_tests(TestModuleInit, build_constructor_arg_db())
instantiate_device_type_tests(TestModuleInit, globals())
if __name__ == '__main__':
run_tests()