Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert D27855386: [pytorch][PR] Support factory kwargs in torch.nn modules
Test Plan: revert-hammer
Differential Revision: D27855386 (40483acc51)
Original commit changeset: dabd505d2a04
fbshipit-source-id: f5bf3120d87861b30a8e1bf11977ad7d27cd8500
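For context, the change being reverted here let torch.nn module constructors accept device and dtype "factory kwargs" so that parameters and buffers are allocated directly with the requested device/dtype. A minimal sketch of the usage pattern this revert removes (illustrative only, not part of the commit):

    import torch

    # With the reverted feature, weight and bias are created directly with the
    # requested dtype instead of being converted after construction.
    layer = torch.nn.Linear(10, 5, dtype=torch.float64)
    assert layer.weight.dtype == torch.float64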
This commit is contained in:
parent b1282bc109
commit 92d24e3060
@@ -62,7 +62,6 @@ TESTS = [
    'test_linalg',
    'test_logging',
    'test_mkldnn',
    'test_module_init',
    'test_multiprocessing',
    'test_multiprocessing_spawn',
    'distributed/test_nccl',
@@ -1,428 +0,0 @@
import inspect
import torch
from unittest import mock
from unittest.mock import MagicMock, patch
from torch.testing import floating_types
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_utils import TestCase, run_tests


# Returns a database of args & kwargs that can be used to construct each module.
# Each entry is in class -> (args, kwargs) format.
# Example: torch.nn.Linear -> ([10, 5], {})
# TODO: Merge this in with the initial ModuleInfo implementation.
def build_constructor_arg_db():
    return {
torch.nn.AdaptiveAvgPool1d: ((5,), {}),
|
||||
torch.nn.AdaptiveAvgPool2d: ((5,), {}),
|
||||
torch.nn.AdaptiveAvgPool3d: ((5,), {}),
|
||||
torch.nn.AdaptiveLogSoftmaxWithLoss: ((100, 20, [5, 10, 15]), {}),
|
||||
torch.nn.AdaptiveMaxPool1d: ((5,), {}),
|
||||
torch.nn.AdaptiveMaxPool2d: ((5,), {}),
|
||||
torch.nn.AdaptiveMaxPool3d: ((5,), {}),
|
||||
torch.nn.AlphaDropout: ((), {}),
|
||||
torch.nn.AvgPool1d: ((3,), {}),
|
||||
torch.nn.AvgPool2d: ((3,), {}),
|
||||
torch.nn.AvgPool3d: ((3,), {}),
|
||||
torch.nn.BCELoss: ((), {}),
|
||||
torch.nn.BCEWithLogitsLoss: ((), {}),
|
||||
torch.nn.BatchNorm1d: ((5,), {}),
|
||||
torch.nn.BatchNorm2d: ((5,), {}),
|
||||
torch.nn.BatchNorm3d: ((5,), {}),
|
||||
torch.nn.Bilinear: ((2, 3, 4), {}),
|
||||
torch.nn.CELU: ((), {}),
|
||||
torch.nn.CTCLoss: ((), {}),
|
||||
torch.nn.ChannelShuffle: ((4,), {}),
|
||||
torch.nn.ConstantPad1d: ((2, 3.5), {}),
|
||||
torch.nn.ConstantPad2d: ((2, 3.5), {}),
|
||||
torch.nn.ConstantPad3d: ((2, 3.5), {}),
|
||||
torch.nn.Conv1d: ((3, 3, 3), {}),
|
||||
torch.nn.Conv2d: ((3, 3, 3), {}),
|
||||
torch.nn.Conv3d: ((3, 3, 3), {}),
|
||||
torch.nn.ConvTranspose1d: ((3, 3, 3), {}),
|
||||
torch.nn.ConvTranspose2d: ((3, 3, 3), {}),
|
||||
torch.nn.ConvTranspose3d: ((3, 3, 3), {}),
|
||||
torch.nn.CosineEmbeddingLoss: ((), {}),
|
||||
torch.nn.CosineSimilarity: ((), {}),
|
||||
torch.nn.CrossEntropyLoss: ((), {}),
|
||||
torch.nn.CrossMapLRN2d: ((5,), {}),
|
||||
torch.nn.Dropout2d: ((), {}),
|
||||
torch.nn.Dropout3d: ((), {}),
|
||||
torch.nn.Dropout: ((), {}),
|
||||
torch.nn.ELU: ((), {}),
|
||||
torch.nn.Embedding: ((10, 5), {}),
|
||||
torch.nn.EmbeddingBag: ((10, 5), {}),
|
||||
torch.nn.FeatureAlphaDropout: ((), {}),
|
||||
torch.nn.Flatten: ((), {}),
|
||||
torch.nn.Fold: ((5, 2), {}),
|
||||
torch.nn.FractionalMaxPool2d: ((5, 2), {}),
|
||||
torch.nn.FractionalMaxPool3d: ((5, 2), {}),
|
||||
torch.nn.GELU: ((), {}),
|
||||
torch.nn.GLU: ((), {}),
|
||||
torch.nn.GRU: ((5, 10), {}),
|
||||
torch.nn.GRUCell: ((5, 10), {}),
|
||||
torch.nn.GaussianNLLLoss: ((), {}),
|
||||
torch.nn.GroupNorm: ((3, 6, 1e-5, True), {}),
|
||||
torch.nn.Hardshrink: ((), {}),
|
||||
torch.nn.Hardsigmoid: ((), {}),
|
||||
torch.nn.Hardswish: ((), {}),
|
||||
torch.nn.Hardtanh: ((), {}),
|
||||
torch.nn.HingeEmbeddingLoss: ((), {}),
|
||||
torch.nn.HuberLoss: ((), {}),
|
||||
torch.nn.Identity: ((), {}),
|
||||
torch.nn.InstanceNorm1d: ((5, 1e-5, 0.1, True), {}),
|
||||
torch.nn.InstanceNorm2d: ((5, 1e-5, 0.1, True), {}),
|
||||
torch.nn.InstanceNorm3d: ((5, 1e-5, 0.1, True), {}),
|
||||
torch.nn.KLDivLoss: ((), {}),
|
||||
torch.nn.L1Loss: ((), {}),
|
||||
torch.nn.LPPool1d: ((2, 3), {}),
|
||||
torch.nn.LPPool2d: ((2, 3), {}),
|
||||
torch.nn.LSTM: ((5, 10), {}),
|
||||
torch.nn.LSTMCell: ((5, 10), {}),
|
||||
torch.nn.LayerNorm: ((2,), {}),
|
||||
torch.nn.LazyBatchNorm1d: ((), {}),
|
||||
torch.nn.LazyBatchNorm2d: ((), {}),
|
||||
torch.nn.LazyBatchNorm3d: ((), {}),
|
||||
torch.nn.LazyConv1d: ((5, 2), {}),
|
||||
torch.nn.LazyConv2d: ((5, 2), {}),
|
||||
torch.nn.LazyConv3d: ((5, 2), {}),
|
||||
torch.nn.LazyConvTranspose1d: ((5, 2), {}),
|
||||
torch.nn.LazyConvTranspose2d: ((5, 2), {}),
|
||||
torch.nn.LazyConvTranspose3d: ((5, 2), {}),
|
||||
torch.nn.LazyLinear: ((5,), {}),
|
||||
torch.nn.LeakyReLU: ((), {}),
|
||||
torch.nn.Linear: ((10, 5), {}),
|
||||
torch.nn.LocalResponseNorm: ((2,), {}),
|
||||
torch.nn.LogSigmoid: ((), {}),
|
||||
torch.nn.LogSoftmax: ((), {}),
|
||||
torch.nn.MSELoss: ((), {}),
|
||||
torch.nn.MarginRankingLoss: ((), {}),
|
||||
torch.nn.MaxPool1d: ((3,), {}),
|
||||
torch.nn.MaxPool2d: ((3,), {}),
|
||||
torch.nn.MaxPool3d: ((3,), {}),
|
||||
torch.nn.MaxUnpool1d: ((5,), {}),
|
||||
torch.nn.MaxUnpool2d: ((5,), {}),
|
||||
torch.nn.MaxUnpool3d: ((5,), {}),
|
||||
torch.nn.ModuleDict: ((), {}),
|
||||
torch.nn.ModuleList: ((), {}),
|
||||
torch.nn.MultiLabelMarginLoss: ((), {}),
|
||||
torch.nn.MultiLabelSoftMarginLoss: ((), {}),
|
||||
torch.nn.MultiMarginLoss: ((), {}),
|
||||
torch.nn.MultiheadAttention: ((100, 2), {}),
|
||||
torch.nn.NLLLoss2d: ((), {}),
|
||||
torch.nn.NLLLoss: ((), {}),
|
||||
torch.nn.PReLU: ((), {}),
|
||||
torch.nn.PairwiseDistance: ((), {}),
|
||||
torch.nn.ParameterDict: ((), {}),
|
||||
torch.nn.ParameterList: ((), {}),
|
||||
torch.nn.PixelShuffle: ((2,), {}),
|
||||
torch.nn.PixelUnshuffle: ((2,), {}),
|
||||
torch.nn.PoissonNLLLoss: ((), {}),
|
||||
torch.nn.RNN: ((5, 10), {}),
|
||||
torch.nn.RNNBase: (('LSTM', 5, 10), {}),
|
||||
torch.nn.RNNCell: ((5, 10), {}),
|
||||
torch.nn.RNNCellBase: ((5, 10, True, 2), {}),
|
||||
torch.nn.RReLU: ((), {}),
|
||||
torch.nn.ReLU6: ((), {}),
|
||||
torch.nn.ReLU: ((), {}),
|
||||
torch.nn.ReflectionPad1d: ((2,), {}),
|
||||
torch.nn.ReflectionPad2d: ((2,), {}),
|
||||
torch.nn.ReplicationPad1d: ((2,), {}),
|
||||
torch.nn.ReplicationPad2d: ((2,), {}),
|
||||
torch.nn.ReplicationPad3d: ((2,), {}),
|
||||
torch.nn.SELU: ((), {}),
|
||||
torch.nn.Sequential: ((), {}),
|
||||
torch.nn.SiLU: ((), {}),
|
||||
torch.nn.Sigmoid: ((), {}),
|
||||
torch.nn.SmoothL1Loss: ((), {}),
|
||||
torch.nn.SoftMarginLoss: ((), {}),
|
||||
torch.nn.Softmax2d: ((), {}),
|
||||
torch.nn.Softmax: ((), {}),
|
||||
torch.nn.Softmin: ((), {}),
|
||||
torch.nn.Softplus: ((), {}),
|
||||
torch.nn.Softshrink: ((), {}),
|
||||
torch.nn.Softsign: ((), {}),
|
||||
torch.nn.SyncBatchNorm: ((5,), {}),
|
||||
torch.nn.Tanh: ((), {}),
|
||||
torch.nn.Tanhshrink: ((), {}),
|
||||
torch.nn.Threshold: ((0.1, 20), {}),
|
||||
torch.nn.Transformer: ((), {}),
|
||||
torch.nn.TransformerDecoder: ((torch.nn.TransformerDecoderLayer, 3), {}),
|
||||
torch.nn.TransformerDecoderLayer: ((10, 2), {}),
|
||||
torch.nn.TransformerEncoder: ((torch.nn.TransformerEncoderLayer, 3), {}),
|
||||
torch.nn.TransformerEncoderLayer: ((10, 2), {}),
|
||||
torch.nn.TripletMarginLoss: ((), {}),
|
||||
torch.nn.TripletMarginWithDistanceLoss: ((), {}),
|
||||
torch.nn.Unflatten: ((1, (2, 5, 5)), {}),
|
||||
torch.nn.Unfold: ((3,), {}),
|
||||
torch.nn.Upsample: ((), {}),
|
||||
torch.nn.UpsamplingBilinear2d: ((), {}),
|
||||
torch.nn.UpsamplingNearest2d: ((), {}),
|
||||
torch.nn.ZeroPad2d: ((0,), {}),
|
||||
torch.nn.qat.Conv2d: ((3, 3, 3), {
|
||||
'qconfig': torch.quantization.default_qconfig,
|
||||
}),
|
||||
torch.nn.qat.Conv3d: ((3, 3, 3), {
|
||||
'qconfig': torch.quantization.default_qconfig,
|
||||
}),
|
||||
torch.nn.qat.Linear: ((5, 2), {
|
||||
'qconfig': torch.quantization.default_qconfig,
|
||||
}),
|
||||
torch.nn.quantizable.LSTM: ((5, 6), {}),
|
||||
torch.nn.quantizable.LSTMCell: ((5, 6), {}),
|
||||
torch.nn.quantizable.MultiheadAttention: ((10, 2), {}),
|
||||
torch.nn.quantized.BatchNorm2d: ((2,), {}),
|
||||
torch.nn.quantized.BatchNorm3d: ((2,), {}),
|
||||
torch.nn.quantized.Conv1d: ((3, 3, 3), {}),
|
||||
torch.nn.quantized.Conv2d: ((3, 3, 3), {}),
|
||||
torch.nn.quantized.Conv3d: ((3, 3, 3), {}),
|
||||
torch.nn.quantized.ConvTranspose1d: ((3, 3, 3), {}),
|
||||
torch.nn.quantized.ConvTranspose2d: ((3, 3, 3), {}),
|
||||
torch.nn.quantized.ConvTranspose3d: ((16, 33, (3, 3, 5)), {
|
||||
'stride': (2, 1, 1),
|
||||
'padding': (4, 2, 2),
|
||||
'output_padding': (2, 2, 2),
|
||||
'dilation': (1, 1, 1),
|
||||
}),
|
||||
torch.nn.quantized.DeQuantize: ((), {}),
|
||||
torch.nn.quantized.ELU: ((0.01, 0), {}),
|
||||
torch.nn.quantized.Embedding: ((10, 3), {
|
||||
'factory_kwargs': {},
|
||||
}),
|
||||
torch.nn.quantized.EmbeddingBag: ((10, 3), {
|
||||
'factory_kwargs': {},
|
||||
}),
|
||||
torch.nn.quantized.GroupNorm: ((2, 3, torch.nn.Parameter(torch.tensor(2.)),
|
||||
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
|
||||
torch.nn.quantized.Hardswish: ((0.1, 0,), {}),
|
||||
torch.nn.quantized.InstanceNorm1d: ((2, torch.nn.Parameter(torch.tensor(2.)),
|
||||
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
|
||||
torch.nn.quantized.InstanceNorm2d: ((2, torch.nn.Parameter(torch.tensor(2.)),
|
||||
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
|
||||
torch.nn.quantized.InstanceNorm3d: ((2, torch.nn.Parameter(torch.tensor(2.)),
|
||||
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
|
||||
torch.nn.quantized.LayerNorm: ((2, torch.nn.Parameter(torch.tensor(2.)),
|
||||
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
|
||||
torch.nn.quantized.LeakyReLU: ((0.01, 0), {}),
|
||||
torch.nn.quantized.Linear: ((5, 2), {
|
||||
'factory_kwargs': {},
|
||||
}),
|
||||
torch.nn.quantized.MaxPool2d: ((3,), {}),
|
||||
torch.nn.quantized.Quantize: ((0.1, 0), {
|
||||
'dtype': torch.int16,
|
||||
'factory_kwargs': {},
|
||||
}),
|
||||
torch.nn.quantized.ReLU6: ((), {}),
|
||||
torch.nn.quantized.Sigmoid: ((0.1, 0), {}),
|
||||
torch.nn.quantized.FloatFunctional: ((), {}),
|
||||
torch.nn.quantized.FXFloatFunctional: ((), {}),
|
||||
torch.nn.quantized.QFunctional: ((), {}),
    }


# Instantiates the given class with the given args, kwargs, optionally on a given device.
def instantiate_class(cls, args, kwargs, extra_kwargs):
    return cls(*args, **kwargs) if extra_kwargs is None else cls(*args, **kwargs, **extra_kwargs)


# Returns a function that calls the real implementation of a method
# in addition to passing args to a mock object.
def mock_wrapper(method):
    mock = MagicMock()

    def wrapper(self, *args, **kwargs):
        mock(*args, **kwargs)
        return method(self, *args, **kwargs)
    wrapper.mock = mock
    return wrapper
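
# Illustrative sketch (not part of the original test file): how mock_wrapper()
# is used by the generated test further below -- the patched method still runs
# for real, while the attached MagicMock records every call it receives.
def _example_mock_wrapper_usage():
    parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
    with patch.object(torch.nn.Parameter, '__new__', parameter_new):
        torch.nn.Linear(10, 5)            # real Parameters are still created
    assert parameter_new.mock.called      # ...and the mock observed the calls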

# Returns a set of args / kwargs that can be used to construct the module.
def get_example_args(module_cls, constructor_arg_db, extra_kwargs=None):
    assert module_cls in constructor_arg_db, \
        f"No entry for {module_cls} in the constructor arg DB. Please add it to pass these tests."
    args, kwargs = constructor_arg_db[module_cls]
    extra_kwargs = {} if extra_kwargs is None else extra_kwargs

    # Recursively instantiate args / kwargs that are class objects.
    args = [instantiate_class(arg, *get_example_args(arg, constructor_arg_db), extra_kwargs=extra_kwargs)
            if inspect.isclass(arg) else torch.nn.Parameter(arg.to(**extra_kwargs))
            if isinstance(arg, torch.nn.Parameter) else arg for arg in args]
    kwargs = {k: instantiate_class(v, *get_example_args(v, constructor_arg_db), extra_kwargs=extra_kwargs)
              if inspect.isclass(v) else torch.nn.Parameter(v.to(*extra_kwargs))
              if isinstance(v, torch.nn.Parameter) else v for k, v in kwargs.items()}
    kwargs.update(extra_kwargs)
    return args, kwargs
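
# Illustrative sketch (not part of the original test file): what the helper
# above returns for a simple DB entry such as torch.nn.Linear -> ((10, 5), {}).
def _example_get_example_args_usage():
    args, kwargs = get_example_args(torch.nn.Linear, build_constructor_arg_db())
    # args == [10, 5], kwargs == {} -> construct the module straight from the DB entry
    return torch.nn.Linear(*args, **kwargs)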
def generate_test_func(test_cls, module_cls, constructor_arg_db,
|
||||
verify_kwargs=True, module_is_lazy=False, check_nonexistent_arg=True):
|
||||
# Generate a function for testing the given module.
|
||||
@dtypes(*floating_types())
|
||||
def run_test(test_cls, device, dtype, module_cls=module_cls):
|
||||
# Check if this module creates parameters or registers buffers.
|
||||
# The mock magic here passes through to the real Parameter / register_buffer
|
||||
# logic and is only used to check for calls.
|
||||
args, kwargs = get_example_args(module_cls, constructor_arg_db)
|
||||
|
||||
# Some modules need to pass factory_kwargs so as not to conflict with existing args such as dtype.
|
||||
module_needs_factory_kwargs = 'factory_kwargs' in kwargs
|
||||
if module_needs_factory_kwargs:
|
||||
del kwargs['factory_kwargs']
|
||||
extra_kwargs = {
|
||||
'factory_kwargs': {
|
||||
'device': device,
|
||||
'dtype': dtype,
|
||||
}
|
||||
}
|
||||
else:
|
||||
extra_kwargs = {
|
||||
'device': device,
|
||||
'dtype': dtype,
|
||||
}
|
||||
|
||||
parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
|
||||
with patch.object(torch.nn.Parameter, '__new__', parameter_new):
|
||||
register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
|
||||
with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
|
||||
m = module_cls(*args, **kwargs)
|
||||
module_creates_params_or_buffers = parameter_new.mock.called or register_buffer.mock.called
|
||||
|
||||
# == Verify factory kwargs are supported. ==
|
||||
if verify_kwargs and module_creates_params_or_buffers:
|
||||
args, kwargs = get_example_args(module_cls, constructor_arg_db,
|
||||
extra_kwargs=extra_kwargs)
|
||||
|
||||
if module_is_lazy:
|
||||
# Ensure device and dtype are passed to all UninitializedParameters and UninitializedBuffers.
|
||||
uninit_param_new = mock_wrapper(torch.nn.UninitializedParameter.__new__)
|
||||
with patch.object(torch.nn.UninitializedParameter, '__new__', uninit_param_new):
|
||||
uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
|
||||
with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
|
||||
m = module_cls(*args, **kwargs)
|
||||
uninit_param_new.mock.assert_has_calls(
|
||||
[mock.call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
|
||||
uninit_buffer_new.mock.assert_has_calls(
|
||||
[mock.call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
|
||||
else:
|
||||
# Check device placement and dtype for parameters and buffers.
|
||||
# Only verify floating point dtypes since that's what the kwarg applies to.
|
||||
# Note that dtype verification is also skipped if the module requires factory_kwargs.
|
||||
m = module_cls(*args, **kwargs)
|
||||
for name, param in m.named_parameters():
|
||||
test_cls.assertEqual(
|
||||
str(param.device), device,
|
||||
f'Parameter {name} is on {param.device.type} instead of the expected device {device}')
|
||||
if param.dtype.is_floating_point and not module_needs_factory_kwargs:
|
||||
test_cls.assertEqual(
|
||||
param.dtype, dtype,
|
||||
f'Parameter {name} is of dtype {param.dtype} instead of the expected dtype {dtype}')
|
||||
for name, buffer in m.named_buffers():
|
||||
test_cls.assertEqual(
|
||||
str(buffer.device), device,
|
||||
f'Buffer {name} is on {buffer.device.type} instead of the expected device {device}')
|
||||
if buffer.dtype.is_floating_point and not module_needs_factory_kwargs:
|
||||
test_cls.assertEqual(
|
||||
buffer.dtype, dtype,
|
||||
f'Buffer {name} is of dtype {buffer.dtype} instead of the expected dtype {dtype}')
|
||||
|
||||
# == Verify passing a nonexistent arg errors out. ==
|
||||
if check_nonexistent_arg:
|
||||
with test_cls.assertRaises(TypeError):
|
||||
m = module_cls(*args, **kwargs, nonexistent_arg='foo')
|
||||
|
||||
return run_test
|
||||
|
||||
|
||||
def generate_tests(test_cls, constructor_arg_db):
|
||||
# test all modules underneath these namespaces...
|
||||
NAMESPACES = [
|
||||
torch.nn,
|
||||
torch.nn.qat,
|
||||
torch.nn.quantizable,
|
||||
torch.nn.quantized,
|
||||
]
|
||||
# ...except these
|
||||
MODULES_TO_SKIP = {
|
||||
torch.nn.Module,
|
||||
torch.nn.Container, # deprecated
|
||||
torch.nn.NLLLoss2d, # deprecated
|
||||
torch.nn.quantized._ConvNd # base class in __all__ for some reason
|
||||
}
|
||||
# no need to support kwargs for these modules even though
|
||||
# they have parameters / buffers because they are passed in
|
||||
# already instantiated
|
||||
MODULES_WITHOUT_KWARGS_SUPPORT = {
|
||||
torch.nn.BCELoss,
|
||||
torch.nn.BCEWithLogitsLoss,
|
||||
torch.nn.CrossEntropyLoss,
|
||||
torch.nn.FractionalMaxPool2d,
|
||||
torch.nn.FractionalMaxPool3d,
|
||||
torch.nn.MultiLabelSoftMarginLoss,
|
||||
torch.nn.MultiMarginLoss,
|
||||
torch.nn.NLLLoss,
|
||||
torch.nn.TransformerDecoder,
|
||||
torch.nn.TransformerEncoder,
|
||||
}
|
||||
# modules that supported kwargs before
|
||||
MODULES_WITH_PREVIOUS_KWARGS = {
|
||||
torch.nn.Identity,
|
||||
}
|
||||
# lazy modules don't instantiate parameters right away
|
||||
LAZY_MODULES = {
|
||||
torch.nn.LazyBatchNorm1d,
|
||||
torch.nn.LazyBatchNorm2d,
|
||||
torch.nn.LazyBatchNorm3d,
|
||||
torch.nn.LazyConv1d,
|
||||
torch.nn.LazyConv2d,
|
||||
torch.nn.LazyConv3d,
|
||||
torch.nn.LazyConvTranspose1d,
|
||||
torch.nn.LazyConvTranspose2d,
|
||||
torch.nn.LazyConvTranspose3d,
|
||||
torch.nn.LazyConvTranspose3d,
|
||||
torch.nn.LazyLinear,
|
||||
}
|
||||
# these modules requires FBGEMM backend to instantiate
|
||||
MODULES_THAT_REQUIRE_FBGEMM = {
|
||||
torch.nn.quantized.Conv1d,
|
||||
torch.nn.quantized.Conv2d,
|
||||
torch.nn.quantized.Conv3d,
|
||||
torch.nn.quantized.ConvTranspose1d,
|
||||
torch.nn.quantized.ConvTranspose2d,
|
||||
torch.nn.quantized.ConvTranspose3d,
|
||||
torch.nn.quantized.Linear,
|
||||
}
|
||||
|
||||
for namespace in NAMESPACES:
|
||||
# the "nn" in "torch.nn"
|
||||
namespace_basename = namespace.__name__.split('.')[-1]
|
||||
for module_name in namespace.modules.__all__:
|
||||
# class object for this module (e.g. torch.nn.Linear)
|
||||
module_cls = getattr(namespace.modules, module_name)
|
||||
if module_cls in MODULES_TO_SKIP:
|
||||
continue
|
||||
verify_kwargs = module_cls not in MODULES_WITHOUT_KWARGS_SUPPORT
|
||||
module_is_lazy = module_cls in LAZY_MODULES
|
||||
check_nonexistent_arg = module_cls not in MODULES_WITH_PREVIOUS_KWARGS
|
||||
# Generate a function for testing this module and setattr it onto the test class.
|
||||
run_test = generate_test_func(test_cls, module_cls, constructor_arg_db,
|
||||
verify_kwargs=verify_kwargs,
|
||||
module_is_lazy=module_is_lazy,
|
||||
check_nonexistent_arg=check_nonexistent_arg)
|
||||
test_name = f'test_{namespace_basename}_{module_name}'
|
||||
if module_cls in MODULES_THAT_REQUIRE_FBGEMM:
|
||||
run_test = skipIfNoFBGEMM(run_test)
|
||||
setattr(TestModuleInit, test_name, run_test)
|
||||
|
||||
|
||||
class TestModuleInit(TestCase):
|
||||
_ignore_not_implemented_error = False
|
||||
|
||||
|
||||
generate_tests(TestModuleInit, build_constructor_arg_db())
|
||||
instantiate_device_type_tests(TestModuleInit, globals())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
@@ -3,44 +3,3 @@ from .parameter import Parameter, UninitializedParameter, UninitializedBuffer
from .parallel import DataParallel
from . import init
from . import utils


def factory_kwargs(kwargs):
    r"""
    Given kwargs, returns a canonicalized dict of factory kwargs that can be directly passed
    to factory functions like torch.empty, or errors if unrecognized kwargs are present.

    This function makes it simple to write code like this::

        class MyModule(nn.Module):
            def __init__(self, **kwargs):
                factory_kwargs = torch.nn.factory_kwargs(kwargs)
                self.weight = Parameter(torch.empty(10, **factory_kwargs))

    Why should you use this function instead of just passing `kwargs` along directly?

    1. This function does error validation, so if there are unexpected kwargs we will
       immediately report an error, instead of deferring it to the factory call
    2. This function supports a special `factory_kwargs` argument, which can be used to
       explicitly specify a kwarg to be used for factory functions, in the event one of the
       factory kwargs conflicts with an already existing argument in the signature (e.g.
       in the signature ``def f(dtype, **kwargs)``, you can specify ``dtype`` for factory
       functions, as distinct from the dtype argument, by saying
       ``f(dtype1, factory_kwargs={"dtype": dtype2})``)
    """
    if kwargs is None:
        return {}
    simple_keys = {"device", "dtype", "memory_format"}
    expected_keys = simple_keys | {"factory_kwargs"}
    if not kwargs.keys() <= expected_keys:
        raise TypeError(f"unexpected kwargs {kwargs.keys() - expected_keys}")

    # guarantee no input kwargs is untouched
    r = dict(kwargs.get("factory_kwargs", {}))
    for k in simple_keys:
        if k in kwargs:
            if k in r:
                raise TypeError(f"{k} specified twice, in **kwargs and in factory_kwargs")
            r[k] = kwargs[k]

    return r
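
# Illustrative sketch (not part of the diff): expected behavior of the helper
# above, derived only from the implementation shown; values are placeholders.
def _factory_kwargs_examples():
    # Simple keys pass straight through.
    assert factory_kwargs({'device': 'cpu'}) == {'device': 'cpu'}
    # The special 'factory_kwargs' entry is merged into the result.
    assert factory_kwargs({'factory_kwargs': {'device': 'cpu'}}) == {'device': 'cpu'}
    # Specifying the same key both ways raises.
    try:
        factory_kwargs({'device': 'cpu', 'factory_kwargs': {'device': 'cuda'}})
    except TypeError:
        pass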
@ -872,8 +872,7 @@ class MultiheadAttention(Module):
|
|||
bias_v: Optional[torch.Tensor]
|
||||
|
||||
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
|
||||
kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
kdim=None, vdim=None, batch_first=False):
|
||||
super(MultiheadAttention, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.kdim = kdim if kdim is not None else embed_dim
|
||||
|
|
@ -887,25 +886,25 @@ class MultiheadAttention(Module):
|
|||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
|
||||
if self._qkv_same_embed_dim is False:
|
||||
self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
|
||||
self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
|
||||
self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
|
||||
self.q_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
|
||||
self.k_proj_weight = Parameter(torch.empty(embed_dim, self.kdim))
|
||||
self.v_proj_weight = Parameter(torch.empty(embed_dim, self.vdim))
|
||||
self.register_parameter('in_proj_weight', None)
|
||||
else:
|
||||
self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
|
||||
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
|
||||
self.register_parameter('q_proj_weight', None)
|
||||
self.register_parameter('k_proj_weight', None)
|
||||
self.register_parameter('v_proj_weight', None)
|
||||
|
||||
if bias:
|
||||
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
|
||||
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
|
||||
else:
|
||||
self.register_parameter('in_proj_bias', None)
|
||||
self.out_proj = Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
|
||||
self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
if add_bias_kv:
|
||||
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
|
||||
self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
|
||||
else:
|
||||
self.bias_k = self.bias_v = None
|
||||
|
||||
|
|
@ -1058,12 +1057,10 @@ class PReLU(Module):
|
|||
__constants__ = ['num_parameters']
|
||||
num_parameters: int
|
||||
|
||||
def __init__(self, num_parameters: int = 1, init: float = 0.25,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
def __init__(self, num_parameters: int = 1, init: float = 0.25) -> None:
|
||||
self.num_parameters = num_parameters
|
||||
super(PReLU, self).__init__()
|
||||
self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs).fill_(init))
|
||||
self.weight = Parameter(torch.empty(num_parameters).fill_(init))
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return F.prelu(input, self.weight)
|
||||
|
|
|
|||
|
|
@ -115,11 +115,8 @@ class AdaptiveLogSoftmaxWithLoss(Module):
|
|||
n_classes: int,
|
||||
cutoffs: Sequence[int],
|
||||
div_value: float = 4.,
|
||||
head_bias: bool = False,
|
||||
device=None,
|
||||
dtype=None
|
||||
head_bias: bool = False
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super(AdaptiveLogSoftmaxWithLoss, self).__init__()
|
||||
|
||||
cutoffs = list(cutoffs)
|
||||
|
|
@ -144,8 +141,7 @@ class AdaptiveLogSoftmaxWithLoss(Module):
|
|||
self.n_clusters = len(self.cutoffs) - 1
|
||||
self.head_size = self.shortlist_size + self.n_clusters
|
||||
|
||||
self.head = Linear(self.in_features, self.head_size, bias=self.head_bias,
|
||||
**factory_kwargs)
|
||||
self.head = Linear(self.in_features, self.head_size, bias=self.head_bias)
|
||||
self.tail = ModuleList()
|
||||
|
||||
for i in range(self.n_clusters):
|
||||
|
|
@ -154,8 +150,8 @@ class AdaptiveLogSoftmaxWithLoss(Module):
|
|||
osz = self.cutoffs[i + 1] - self.cutoffs[i]
|
||||
|
||||
projection = Sequential(
|
||||
Linear(self.in_features, hsz, bias=False, **factory_kwargs),
|
||||
Linear(hsz, osz, bias=False, **factory_kwargs),
|
||||
Linear(self.in_features, hsz, bias=False),
|
||||
Linear(hsz, osz, bias=False)
|
||||
)
|
||||
|
||||
self.tail.append(projection)
|
||||
|
|
|
|||
|
|
@ -31,10 +31,7 @@ class _NormBase(Module):
|
|||
momentum: float = 0.1,
|
||||
affine: bool = True,
|
||||
track_running_stats: bool = True,
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super(_NormBase, self).__init__()
|
||||
self.num_features = num_features
|
||||
self.eps = eps
|
||||
|
|
@ -42,17 +39,17 @@ class _NormBase(Module):
|
|||
self.affine = affine
|
||||
self.track_running_stats = track_running_stats
|
||||
if self.affine:
|
||||
self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
|
||||
self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
|
||||
self.weight = Parameter(torch.empty(num_features))
|
||||
self.bias = Parameter(torch.empty(num_features))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.register_parameter("bias", None)
|
||||
if self.track_running_stats:
|
||||
self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
|
||||
self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
|
||||
self.register_buffer('num_batches_tracked',
|
||||
torch.tensor(0, dtype=torch.long,
|
||||
**{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
|
||||
self.register_buffer("running_mean", torch.zeros(num_features))
|
||||
self.register_buffer("running_var", torch.ones(num_features))
|
||||
self.register_buffer(
|
||||
"num_batches_tracked", torch.tensor(0, dtype=torch.long)
|
||||
)
|
||||
else:
|
||||
self.register_buffer("running_mean", None)
|
||||
self.register_buffer("running_var", None)
|
||||
|
|
@ -120,12 +117,9 @@ class _BatchNorm(_NormBase):
|
|||
momentum=0.1,
|
||||
affine=True,
|
||||
track_running_stats=True,
|
||||
device=None,
|
||||
dtype=None
|
||||
):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super(_BatchNorm, self).__init__(
|
||||
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
|
||||
num_features, eps, momentum, affine, track_running_stats
|
||||
)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
|
|
@ -184,9 +178,7 @@ class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):
|
|||
weight: UninitializedParameter # type: ignore[assignment]
|
||||
bias: UninitializedParameter # type: ignore[assignment]
|
||||
|
||||
def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
|
||||
super(_LazyBatchNorm, self).__init__(
|
||||
# affine and track_running_stats are hardcoded to False to
|
||||
# avoid creating tensors that will soon be overwritten.
|
||||
|
|
@ -195,18 +187,16 @@ class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):
|
|||
momentum,
|
||||
False,
|
||||
False,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.affine = affine
|
||||
self.track_running_stats = track_running_stats
|
||||
if self.affine:
|
||||
self.weight = UninitializedParameter(**factory_kwargs)
|
||||
self.bias = UninitializedParameter(**factory_kwargs)
|
||||
self.weight = UninitializedParameter()
|
||||
self.bias = UninitializedParameter()
|
||||
if self.track_running_stats:
|
||||
self.running_mean = UninitializedBuffer(**factory_kwargs)
|
||||
self.running_var = UninitializedBuffer(**factory_kwargs)
|
||||
self.num_batches_tracked = torch.tensor(
|
||||
0, dtype=torch.long, **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})
|
||||
self.running_mean = UninitializedBuffer()
|
||||
self.running_var = UninitializedBuffer()
|
||||
self.num_batches_tracked = torch.tensor(0, dtype=torch.long)
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
if not self.has_uninitialized_params() and self.num_features != 0:
|
||||
|
|
@ -650,12 +640,9 @@ class SyncBatchNorm(_BatchNorm):
|
|||
affine: bool = True,
|
||||
track_running_stats: bool = True,
|
||||
process_group: Optional[Any] = None,
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super(SyncBatchNorm, self).__init__(
|
||||
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
|
||||
num_features, eps, momentum, affine, track_running_stats
|
||||
)
|
||||
self.process_group = process_group
|
||||
# gpu_size is set through DistributedDataParallel initialization. This is to ensure that SyncBatchNorm is used
|
||||
|
|
|
|||
|
|
@ -75,10 +75,7 @@ class _ConvNd(Module):
|
|||
output_padding: Tuple[int, ...],
|
||||
groups: int,
|
||||
bias: bool,
|
||||
padding_mode: str,
|
||||
device=None,
|
||||
dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
padding_mode: str) -> None:
|
||||
super(_ConvNd, self).__init__()
|
||||
if in_channels % groups != 0:
|
||||
raise ValueError('in_channels must be divisible by groups')
|
||||
|
|
@ -126,15 +123,14 @@ class _ConvNd(Module):
|
|||
|
||||
if transposed:
|
||||
self.weight = Parameter(torch.empty(
|
||||
(in_channels, out_channels // groups, *kernel_size), **factory_kwargs))
|
||||
in_channels, out_channels // groups, *kernel_size))
|
||||
else:
|
||||
self.weight = Parameter(torch.empty(
|
||||
(out_channels, in_channels // groups, *kernel_size), **factory_kwargs))
|
||||
out_channels, in_channels // groups, *kernel_size))
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(out_channels, **factory_kwargs))
|
||||
self.bias = Parameter(torch.empty(out_channels))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
|
|
@ -271,11 +267,8 @@ class Conv1d(_ConvNd):
|
|||
dilation: _size_1_t = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = 'zeros', # TODO: refine this type
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
padding_mode: str = 'zeros' # TODO: refine this type
|
||||
):
|
||||
# we create new variables below to make mypy happy since kernel_size has
|
||||
# type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int]
|
||||
kernel_size_ = _single(kernel_size)
|
||||
|
|
@ -284,7 +277,7 @@ class Conv1d(_ConvNd):
|
|||
dilation_ = _single(dilation)
|
||||
super(Conv1d, self).__init__(
|
||||
in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
|
||||
False, _single(0), groups, bias, padding_mode, **factory_kwargs)
|
||||
False, _single(0), groups, bias, padding_mode)
|
||||
|
||||
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
|
||||
if self.padding_mode != 'zeros':
|
||||
|
|
@ -418,18 +411,15 @@ class Conv2d(_ConvNd):
|
|||
dilation: _size_2_t = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = 'zeros', # TODO: refine this type
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
padding_mode: str = 'zeros' # TODO: refine this type
|
||||
):
|
||||
kernel_size_ = _pair(kernel_size)
|
||||
stride_ = _pair(stride)
|
||||
padding_ = padding if isinstance(padding, str) else _pair(padding)
|
||||
dilation_ = _pair(dilation)
|
||||
super(Conv2d, self).__init__(
|
||||
in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
|
||||
False, _pair(0), groups, bias, padding_mode, **factory_kwargs)
|
||||
False, _pair(0), groups, bias, padding_mode)
|
||||
|
||||
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
|
||||
if self.padding_mode != 'zeros':
|
||||
|
|
@ -553,18 +543,15 @@ class Conv3d(_ConvNd):
|
|||
dilation: _size_3_t = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = 'zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
padding_mode: str = 'zeros'
|
||||
):
|
||||
kernel_size_ = _triple(kernel_size)
|
||||
stride_ = _triple(stride)
|
||||
padding_ = padding if isinstance(padding, str) else _triple(padding)
|
||||
dilation_ = _triple(dilation)
|
||||
super(Conv3d, self).__init__(
|
||||
in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
|
||||
False, _triple(0), groups, bias, padding_mode, **factory_kwargs)
|
||||
False, _triple(0), groups, bias, padding_mode)
|
||||
|
||||
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
|
||||
if self.padding_mode != "zeros":
|
||||
|
|
@ -591,15 +578,14 @@ class Conv3d(_ConvNd):
|
|||
class _ConvTransposeNd(_ConvNd):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride,
|
||||
padding, dilation, transposed, output_padding,
|
||||
groups, bias, padding_mode, device=None, dtype=None) -> None:
|
||||
groups, bias, padding_mode):
|
||||
if padding_mode != 'zeros':
|
||||
raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__))
|
||||
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super(_ConvTransposeNd, self).__init__(
|
||||
in_channels, out_channels, kernel_size, stride,
|
||||
padding, dilation, transposed, output_padding,
|
||||
groups, bias, padding_mode, **factory_kwargs)
|
||||
groups, bias, padding_mode)
|
||||
|
||||
# dilation being an optional parameter is for backwards
|
||||
# compatibility
|
||||
|
|
@ -741,11 +727,8 @@ class ConvTranspose1d(_ConvTransposeNd):
|
|||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
dilation: _size_1_t = 1,
|
||||
padding_mode: str = 'zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
padding_mode: str = 'zeros'
|
||||
):
|
||||
kernel_size = _single(kernel_size)
|
||||
stride = _single(stride)
|
||||
padding = _single(padding)
|
||||
|
|
@ -753,7 +736,7 @@ class ConvTranspose1d(_ConvTransposeNd):
|
|||
output_padding = _single(output_padding)
|
||||
super(ConvTranspose1d, self).__init__(
|
||||
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||
True, output_padding, groups, bias, padding_mode, **factory_kwargs)
|
||||
True, output_padding, groups, bias, padding_mode)
|
||||
|
||||
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
|
||||
if self.padding_mode != 'zeros':
|
||||
|
|
@ -889,11 +872,8 @@ class ConvTranspose2d(_ConvTransposeNd):
|
|||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
dilation: int = 1,
|
||||
padding_mode: str = 'zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
padding_mode: str = 'zeros'
|
||||
):
|
||||
kernel_size = _pair(kernel_size)
|
||||
stride = _pair(stride)
|
||||
padding = _pair(padding)
|
||||
|
|
@ -901,7 +881,7 @@ class ConvTranspose2d(_ConvTransposeNd):
|
|||
output_padding = _pair(output_padding)
|
||||
super(ConvTranspose2d, self).__init__(
|
||||
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||
True, output_padding, groups, bias, padding_mode, **factory_kwargs)
|
||||
True, output_padding, groups, bias, padding_mode)
|
||||
|
||||
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
|
||||
if self.padding_mode != 'zeros':
|
||||
|
|
@ -1034,11 +1014,8 @@ class ConvTranspose3d(_ConvTransposeNd):
|
|||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
dilation: _size_3_t = 1,
|
||||
padding_mode: str = 'zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
padding_mode: str = 'zeros'
|
||||
):
|
||||
kernel_size = _triple(kernel_size)
|
||||
stride = _triple(stride)
|
||||
padding = _triple(padding)
|
||||
|
|
@ -1046,7 +1023,7 @@ class ConvTranspose3d(_ConvTransposeNd):
|
|||
output_padding = _triple(output_padding)
|
||||
super(ConvTranspose3d, self).__init__(
|
||||
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||
True, output_padding, groups, bias, padding_mode, **factory_kwargs)
|
||||
True, output_padding, groups, bias, padding_mode)
|
||||
|
||||
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
|
||||
if self.padding_mode != 'zeros':
|
||||
|
|
@ -1169,11 +1146,8 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
|
|||
dilation: _size_1_t = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = 'zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
padding_mode: str = 'zeros'
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__(
|
||||
0,
|
||||
0,
|
||||
|
|
@ -1185,13 +1159,12 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
|
|||
# bias is hardcoded to False to avoid creating tensor
|
||||
# that will soon be overwritten.
|
||||
False,
|
||||
padding_mode,
|
||||
**factory_kwargs
|
||||
padding_mode
|
||||
)
|
||||
self.weight = UninitializedParameter(**factory_kwargs)
|
||||
self.weight = UninitializedParameter()
|
||||
self.out_channels = out_channels
|
||||
if bias:
|
||||
self.bias = UninitializedParameter(**factory_kwargs)
|
||||
self.bias = UninitializedParameter()
|
||||
|
||||
|
||||
# LazyConv2d defines weight as a Tensor but derived class defines it as UnitializeParameter
|
||||
|
|
@ -1235,11 +1208,8 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
|
|||
dilation: _size_2_t = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = 'zeros', # TODO: refine this type
|
||||
device=None,
|
||||
dtype=None
|
||||
padding_mode: str = 'zeros' # TODO: refine this type
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__(
|
||||
0,
|
||||
0,
|
||||
|
|
@ -1251,13 +1221,12 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
|
|||
# bias is hardcoded to False to avoid creating tensor
|
||||
# that will soon be overwritten.
|
||||
False,
|
||||
padding_mode,
|
||||
**factory_kwargs
|
||||
padding_mode
|
||||
)
|
||||
self.weight = UninitializedParameter(**factory_kwargs)
|
||||
self.weight = UninitializedParameter()
|
||||
self.out_channels = out_channels
|
||||
if bias:
|
||||
self.bias = UninitializedParameter(**factory_kwargs)
|
||||
self.bias = UninitializedParameter()
|
||||
|
||||
|
||||
# LazyConv3d defines weight as a Tensor but derived class defines it as UnitializeParameter
|
||||
|
|
@ -1301,11 +1270,8 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
|
|||
dilation: _size_3_t = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = 'zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
padding_mode: str = 'zeros'
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__(
|
||||
0,
|
||||
0,
|
||||
|
|
@ -1317,13 +1283,12 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
|
|||
# bias is hardcoded to False to avoid creating tensor
|
||||
# that will soon be overwritten.
|
||||
False,
|
||||
padding_mode,
|
||||
**factory_kwargs
|
||||
padding_mode
|
||||
)
|
||||
self.weight = UninitializedParameter(**factory_kwargs)
|
||||
self.weight = UninitializedParameter()
|
||||
self.out_channels = out_channels
|
||||
if bias:
|
||||
self.bias = UninitializedParameter(**factory_kwargs)
|
||||
self.bias = UninitializedParameter()
|
||||
|
||||
|
||||
# LazyConvTranspose1d defines weight as a Tensor but derived class defines it as UnitializeParameter
|
||||
|
|
@ -1365,11 +1330,8 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
|
|||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
dilation: _size_1_t = 1,
|
||||
padding_mode: str = 'zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
padding_mode: str = 'zeros'
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__(
|
||||
0,
|
||||
0,
|
||||
|
|
@ -1382,13 +1344,12 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
|
|||
# that will soon be overwritten.
|
||||
False,
|
||||
dilation,
|
||||
padding_mode,
|
||||
**factory_kwargs
|
||||
padding_mode
|
||||
)
|
||||
self.weight = UninitializedParameter(**factory_kwargs)
|
||||
self.weight = UninitializedParameter()
|
||||
self.out_channels = out_channels
|
||||
if bias:
|
||||
self.bias = UninitializedParameter(**factory_kwargs)
|
||||
self.bias = UninitializedParameter()
|
||||
|
||||
|
||||
# LazyConvTranspose2d defines weight as a Tensor but derived class defines it as UnitializeParameter
|
||||
|
|
@ -1430,11 +1391,8 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
|
|||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
dilation: int = 1,
|
||||
padding_mode: str = 'zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
padding_mode: str = 'zeros'
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__(
|
||||
0,
|
||||
0,
|
||||
|
|
@ -1447,13 +1405,12 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
|
|||
# that will soon be overwritten.
|
||||
False,
|
||||
dilation,
|
||||
padding_mode,
|
||||
**factory_kwargs
|
||||
padding_mode
|
||||
)
|
||||
self.weight = UninitializedParameter(**factory_kwargs)
|
||||
self.weight = UninitializedParameter()
|
||||
self.out_channels = out_channels
|
||||
if bias:
|
||||
self.bias = UninitializedParameter(**factory_kwargs)
|
||||
self.bias = UninitializedParameter()
|
||||
|
||||
|
||||
# LazyConvTranspose3d defines weight as a Tensor but derived class defines it as UnitializeParameter
|
||||
|
|
@ -1495,11 +1452,8 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi
|
|||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
dilation: _size_3_t = 1,
|
||||
padding_mode: str = 'zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
padding_mode: str = 'zeros'
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__(
|
||||
0,
|
||||
0,
|
||||
|
|
@ -1512,10 +1466,9 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi
|
|||
# that will soon be overwritten.
|
||||
False,
|
||||
dilation,
|
||||
padding_mode,
|
||||
**factory_kwargs
|
||||
padding_mode
|
||||
)
|
||||
self.weight = UninitializedParameter(**factory_kwargs)
|
||||
self.weight = UninitializedParameter()
|
||||
self.out_channels = out_channels
|
||||
if bias:
|
||||
self.bias = UninitializedParameter(**factory_kwargs)
|
||||
self.bias = UninitializedParameter()
|
||||
|
|
|
|||
|
|
@ -11,13 +11,10 @@ class _InstanceNorm(_NormBase):
|
|||
eps: float = 1e-5,
|
||||
momentum: float = 0.1,
|
||||
affine: bool = False,
|
||||
track_running_stats: bool = False,
|
||||
device=None,
|
||||
dtype=None
|
||||
track_running_stats: bool = False
|
||||
) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super(_InstanceNorm, self).__init__(
|
||||
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
|
||||
num_features, eps, momentum, affine, track_running_stats)
|
||||
|
||||
def _check_input_dim(self, input):
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -72,15 +72,13 @@ class Linear(Module):
|
|||
out_features: int
|
||||
weight: Tensor
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
|
||||
super(Linear, self).__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
|
||||
self.weight = Parameter(torch.empty(out_features, in_features))
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
|
||||
self.bias = Parameter(torch.empty(out_features))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.reset_parameters()
|
||||
|
|
@ -154,17 +152,15 @@ class Bilinear(Module):
|
|||
out_features: int
|
||||
weight: Tensor
|
||||
|
||||
def __init__(self, in1_features: int, in2_features: int, out_features: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
def __init__(self, in1_features: int, in2_features: int, out_features: int, bias: bool = True) -> None:
|
||||
super(Bilinear, self).__init__()
|
||||
self.in1_features = in1_features
|
||||
self.in2_features = in2_features
|
||||
self.out_features = out_features
|
||||
self.weight = Parameter(torch.empty((out_features, in1_features, in2_features), **factory_kwargs))
|
||||
self.weight = Parameter(torch.empty(out_features, in1_features, in2_features))
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
|
||||
self.bias = Parameter(torch.empty(out_features))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.reset_parameters()
|
||||
|
|
@ -217,16 +213,14 @@ class LazyLinear(LazyModuleMixin, Linear):
|
|||
weight: UninitializedParameter
|
||||
bias: UninitializedParameter # type: ignore[assignment]
|
||||
|
||||
def __init__(self, out_features: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
def __init__(self, out_features: int, bias: bool = True) -> None:
|
||||
# bias is hardcoded to False to avoid creating tensor
|
||||
# that will soon be overwritten.
|
||||
super().__init__(0, 0, False)
|
||||
self.weight = UninitializedParameter(**factory_kwargs)
|
||||
self.weight = UninitializedParameter()
|
||||
self.out_features = out_features
|
||||
if bias:
|
||||
self.bias = UninitializedParameter(**factory_kwargs)
|
||||
self.bias = UninitializedParameter()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
if not self.has_uninitialized_params() and self.in_features != 0:
|
||||
|
|
|
|||
|
|
@ -145,9 +145,7 @@ class LayerNorm(Module):
|
|||
eps: float
|
||||
elementwise_affine: bool
|
||||
|
||||
def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
|
||||
super(LayerNorm, self).__init__()
|
||||
if isinstance(normalized_shape, numbers.Integral):
|
||||
# mypy error: incompatible types in assignment
|
||||
|
|
@ -156,12 +154,11 @@ class LayerNorm(Module):
|
|||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if self.elementwise_affine:
|
||||
self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
||||
self.bias = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
||||
self.weight = Parameter(torch.empty(self.normalized_shape))
|
||||
self.bias = Parameter(torch.empty(self.normalized_shape))
|
||||
else:
|
||||
self.register_parameter('weight', None)
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
|
|
@ -226,21 +223,18 @@ class GroupNorm(Module):
|
|||
eps: float
|
||||
affine: bool
|
||||
|
||||
def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True) -> None:
|
||||
super(GroupNorm, self).__init__()
|
||||
self.num_groups = num_groups
|
||||
self.num_channels = num_channels
|
||||
self.eps = eps
|
||||
self.affine = affine
|
||||
if self.affine:
|
||||
self.weight = Parameter(torch.empty(num_channels, **factory_kwargs))
|
||||
self.bias = Parameter(torch.empty(num_channels, **factory_kwargs))
|
||||
self.weight = Parameter(torch.empty(num_channels))
|
||||
self.bias = Parameter(torch.empty(num_channels))
|
||||
else:
|
||||
self.register_parameter('weight', None)
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
|
|
|
|||
|
|
@ -38,9 +38,7 @@ class RNNBase(Module):
|
|||
|
||||
def __init__(self, mode: str, input_size: int, hidden_size: int,
|
||||
num_layers: int = 1, bias: bool = True, batch_first: bool = False,
|
||||
dropout: float = 0., bidirectional: bool = False, proj_size: int = 0,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
dropout: float = 0., bidirectional: bool = False, proj_size: int = 0) -> None:
|
||||
super(RNNBase, self).__init__()
|
||||
self.mode = mode
|
||||
self.input_size = input_size
|
||||
|
|
@ -86,12 +84,12 @@ class RNNBase(Module):
|
|||
real_hidden_size = proj_size if proj_size > 0 else hidden_size
|
||||
layer_input_size = input_size if layer == 0 else real_hidden_size * num_directions
|
||||
|
||||
w_ih = Parameter(torch.empty((gate_size, layer_input_size), **factory_kwargs))
|
||||
w_hh = Parameter(torch.empty((gate_size, real_hidden_size), **factory_kwargs))
|
||||
b_ih = Parameter(torch.empty(gate_size, **factory_kwargs))
|
||||
w_ih = Parameter(torch.empty(gate_size, layer_input_size))
|
||||
w_hh = Parameter(torch.empty(gate_size, real_hidden_size))
|
||||
b_ih = Parameter(torch.empty(gate_size))
|
||||
# Second bias vector included for CuDNN compatibility. Only one
|
||||
# bias vector is needed in standard definition.
|
||||
b_hh = Parameter(torch.empty(gate_size, **factory_kwargs))
|
||||
b_hh = Parameter(torch.empty(gate_size))
|
||||
layer_params: Tuple[Tensor, ...] = ()
|
||||
if self.proj_size == 0:
|
||||
if bias:
|
||||
|
|
@ -99,7 +97,7 @@ class RNNBase(Module):
|
|||
else:
|
||||
layer_params = (w_ih, w_hh)
|
||||
else:
|
||||
w_hr = Parameter(torch.empty((proj_size, hidden_size), **factory_kwargs))
|
||||
w_hr = Parameter(torch.empty(proj_size, hidden_size))
|
||||
if bias:
|
||||
layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr)
|
||||
else:
|
||||
|
|
@ -120,7 +118,6 @@ class RNNBase(Module):
|
|||
|
||||
self._flat_weights = [(lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn) for wn in self._flat_weights_names]
|
||||
self.flatten_parameters()
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def __setattr__(self, attr, value):
|
||||
|
|
@ -848,22 +845,19 @@ class RNNCellBase(Module):
|
|||
# WARNING: bias_ih and bias_hh purposely not defined here.
|
||||
# See https://github.com/pytorch/pytorch/issues/39670
|
||||
|
||||
def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int) -> None:
|
||||
super(RNNCellBase, self).__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.bias = bias
|
||||
self.weight_ih = Parameter(torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs))
|
||||
self.weight_hh = Parameter(torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs))
|
||||
self.weight_ih = Parameter(torch.empty(num_chunks * hidden_size, input_size))
|
||||
self.weight_hh = Parameter(torch.empty(num_chunks * hidden_size, hidden_size))
|
||||
if bias:
|
||||
self.bias_ih = Parameter(torch.empty(num_chunks * hidden_size, **factory_kwargs))
|
||||
self.bias_hh = Parameter(torch.empty(num_chunks * hidden_size, **factory_kwargs))
|
||||
self.bias_ih = Parameter(torch.empty(num_chunks * hidden_size))
|
||||
self.bias_hh = Parameter(torch.empty(num_chunks * hidden_size))
|
||||
else:
|
||||
self.register_parameter('bias_ih', None)
|
||||
self.register_parameter('bias_hh', None)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
|
|
@ -940,10 +934,8 @@ class RNNCell(RNNCellBase):
|
|||
__constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']
|
||||
nonlinearity: str
|
||||
|
||||
def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh",
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
|
||||
def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh") -> None:
|
||||
super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
|
||||
self.nonlinearity = nonlinearity
|
||||
|
||||
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
|
||||
|
|
@ -1030,10 +1022,8 @@ class LSTMCell(RNNCellBase):
|
|||
>>> output = torch.stack(output, dim=0)
|
||||
"""
|
||||
|
||||
def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)
|
||||
def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
|
||||
super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)
|
||||
|
||||
def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
|
||||
if hx is None:
|
||||
|
|
@ -1108,10 +1098,8 @@ class GRUCell(RNNCellBase):
|
|||
output.append(hx)
|
||||
"""
|
||||
|
||||
def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)
|
||||
def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
|
||||
super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)
|
||||
|
||||
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
|
||||
if hx is None:
|
||||
|
|
|
|||
|
|
@@ -119,9 +119,7 @@ class Embedding(Module):

def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
sparse: bool = False, _weight: Optional[Tensor] = None,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
super(Embedding, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim

@@ -136,7 +134,7 @@ class Embedding(Module):
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
if _weight is None:
self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
self.weight = Parameter(torch.empty(num_embeddings, embedding_dim))
self.reset_parameters()
else:
assert list(_weight.shape) == [num_embeddings, embedding_dim], \

@@ -309,9 +307,7 @@ class EmbeddingBag(Module):
def __init__(self, num_embeddings: int, embedding_dim: int,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
mode: str = 'mean', sparse: bool = False, _weight: Optional[Tensor] = None,
include_last_offset: bool = False, padding_idx: Optional[int] = None,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
include_last_offset: bool = False, padding_idx: Optional[int] = None) -> None:
super(EmbeddingBag, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim

@@ -326,7 +322,7 @@ class EmbeddingBag(Module):
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
if _weight is None:
self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
self.weight = Parameter(torch.empty(num_embeddings, embedding_dim))
self.reset_parameters()
else:
assert list(_weight.shape) == [num_embeddings, embedding_dim], \
@@ -48,27 +48,23 @@ class Transformer(Module):
def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: str = "relu", custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
layer_norm_eps: float = 1e-5, batch_first: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
layer_norm_eps: float = 1e-5, batch_first: bool = False) -> None:
super(Transformer, self).__init__()

if custom_encoder is not None:
self.encoder = custom_encoder
else:
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
activation, layer_norm_eps, batch_first,
**factory_kwargs)
encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
activation, layer_norm_eps, batch_first)
encoder_norm = LayerNorm(d_model, eps=layer_norm_eps)
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

if custom_decoder is not None:
self.decoder = custom_decoder
else:
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
activation, layer_norm_eps, batch_first,
**factory_kwargs)
decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
activation, layer_norm_eps, batch_first)
decoder_norm = LayerNorm(d_model, eps=layer_norm_eps)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

self._reset_parameters()

@@ -283,19 +279,16 @@ class TransformerEncoderLayer(Module):
__constants__ = ['batch_first']

def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
layer_norm_eps=1e-5, batch_first=False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
layer_norm_eps=1e-5, batch_first=False):
super(TransformerEncoderLayer, self).__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
**factory_kwargs)
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
# Implementation of Feedforward model
self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
self.linear1 = Linear(d_model, dim_feedforward)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
self.linear2 = Linear(dim_feedforward, d_model)

self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)

@@ -360,21 +353,18 @@ class TransformerDecoderLayer(Module):
__constants__ = ['batch_first']

def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
layer_norm_eps=1e-5, batch_first=False, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
layer_norm_eps=1e-5, batch_first=False):
super(TransformerDecoderLayer, self).__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
**factory_kwargs)
self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
**factory_kwargs)
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
# Implementation of Feedforward model
self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
self.linear1 = Linear(d_model, dim_feedforward)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
self.linear2 = Linear(dim_feedforward, d_model)

self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
self.dropout3 = Dropout(dropout)
@@ -141,16 +141,12 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter):
will throw a runtime error. The only operations that can be performed on a uninitialized
parameter are changing its datatype, moving it to a different device and
converting it to a regular :class:`torch.nn.Parameter`.

The default device or dtype to use when the parameter is materialized can be set
during construction using e.g. ``device='cuda'``.
"""

cls_to_become = Parameter

def __new__(cls, requires_grad=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
data = torch.tensor([], **factory_kwargs)
def __new__(cls, requires_grad=True):
data = torch.tensor([])
return torch.Tensor._make_subclass(cls, data, requires_grad)

@@ -165,14 +161,10 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
will throw a runtime error. The only operations that can be performed on a uninitialized
parameter are changing its datatype, moving it to a different device and
converting it to a regular :class:`torch.Tensor`.

The default device or dtype to use when the buffer is materialized can be set
during construction using e.g. ``device='cuda'``.
"""

cls_to_become = torch.Tensor

def __new__(cls, requires_grad=False, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
data = torch.tensor([], **factory_kwargs)
def __new__(cls, requires_grad=False):
data = torch.tensor([])
return torch.Tensor._make_subclass(cls, data, requires_grad)
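The docstrings above mention materialization: an uninitialized parameter only allocates storage once it is given a shape. A small usage sketch, independent of this diff and shown only for orientation:

import torch
from torch.nn.parameter import UninitializedParameter

# No storage yet; shape-dependent work is deferred until materialize().
p = UninitializedParameter()
p.materialize((3, 4))          # allocates with the default device and dtype
print(p.shape, p.dtype)        # torch.Size([3, 4]) torch.float32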
@@ -21,16 +21,13 @@ class Conv2d(nn.Conv2d):

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=True, padding_mode='zeros', qconfig=None,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
bias=True, padding_mode='zeros', qconfig=None):
super().__init__(in_channels, out_channels, kernel_size,
stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias, padding_mode=padding_mode,
**factory_kwargs)
groups=groups, bias=bias, padding_mode=padding_mode)
assert qconfig, 'qconfig must be provided for QAT module'
self.qconfig = qconfig
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
self.weight_fake_quant = qconfig.weight()

def forward(self, input):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

@@ -87,10 +84,7 @@ class Conv3d(nn.Conv3d):
bias=True,
padding_mode="zeros",
qconfig=None,
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
):
super().__init__(
in_channels,
out_channels,

@@ -101,11 +95,10 @@ class Conv3d(nn.Conv3d):
groups=groups,
bias=bias,
padding_mode=padding_mode,
**factory_kwargs
)
assert qconfig, "qconfig must be provided for QAT module"
self.qconfig = qconfig
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
self.weight_fake_quant = qconfig.weight()

def forward(self, input):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
@@ -20,12 +20,11 @@ class Linear(nn.Linear):
_FLOAT_MODULE = nn.Linear

def __init__(self, in_features, out_features, bias=True,
qconfig=None, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(in_features, out_features, bias, **factory_kwargs)
qconfig=None):
super().__init__(in_features, out_features, bias)
assert qconfig, 'qconfig must be provided for QAT module'
self.qconfig = qconfig
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
self.weight_fake_quant = qconfig.weight()

def forward(self, input):
return F.linear(input, self.weight_fake_quant(self.weight), self.bias)
@@ -59,21 +59,18 @@ class MultiheadAttention(nn.MultiheadAttention):
def __init__(self, embed_dim: int, num_heads: int,
dropout: float = 0., bias: bool = True,
add_bias_kv: bool = False, add_zero_attn: bool = False,
kdim: int = None, vdim: int = None, batch_first: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
kdim: int = None, vdim: int = None, batch_first: bool = False):
super(MultiheadAttention, self).__init__(embed_dim, num_heads, dropout,
bias, add_bias_kv,
add_zero_attn, kdim, vdim, batch_first,
**factory_kwargs)
self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias, **factory_kwargs)
self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias, **factory_kwargs)
add_zero_attn, kdim, vdim, batch_first)
self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias)
self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias)

# TODO: The use of the `_LinearWithBias` increases the quantization noise
# The `out_proj` in the parent is ``_LinearWithBias`, so need to ignore
# the type for mypy not to complain.
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs) # type: ignore
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias) # type: ignore

# Functionals
self.q_scaling_product = nnq.FloatFunctional()
@@ -29,16 +29,14 @@ class LSTMCell(torch.nn.Module):
"""
_FLOAT_MODULE = torch.nn.LSTMCell

def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True):
super().__init__()
self.input_size = input_dim
self.hidden_size = hidden_dim
self.bias = bias

self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias)
self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias)
self.gates = torch.nn.quantized.FloatFunctional()

self.fgate_cx = torch.nn.quantized.FloatFunctional()

@@ -121,11 +119,9 @@ class _LSTMSingleLayer(torch.nn.Module):
The difference between a layer and a cell is that the layer can process a
sequence, while the cell only expects an instantaneous value.
"""
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True):
super().__init__()
self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)
self.cell = LSTMCell(input_dim, hidden_dim, bias=bias)

def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
result = []

@@ -146,15 +142,13 @@ class _LSTMSingleLayer(torch.nn.Module):
class _LSTMLayer(torch.nn.Module):
r"""A single bi-directional LSTM layer."""
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
batch_first: bool = False, bidirectional: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
batch_first: bool = False, bidirectional: bool = False):
super().__init__()
self.batch_first = batch_first
self.bidirectional = bidirectional
self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias)
if self.bidirectional:
self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias)

def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
if self.batch_first:

@@ -299,9 +293,7 @@ class LSTM(torch.nn.Module):
def __init__(self, input_size: int, hidden_size: int,
num_layers: int = 1, bias: bool = True,
batch_first: bool = False, dropout: float = 0.,
bidirectional: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
bidirectional: bool = False):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size

@@ -330,12 +322,11 @@ class LSTM(torch.nn.Module):

layers = [_LSTMLayer(self.input_size, self.hidden_size,
self.bias, batch_first=False,
bidirectional=self.bidirectional, **factory_kwargs)]
bidirectional=self.bidirectional)]
for layer in range(1, num_layers):
layers.append(_LSTMLayer(self.hidden_size, self.hidden_size,
self.bias, batch_first=False,
bidirectional=self.bidirectional,
**factory_kwargs))
bidirectional=self.bidirectional))
self.layers = torch.nn.ModuleList(layers)

def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
@@ -37,13 +37,10 @@ class Quantize(torch.nn.Module):
scale: torch.Tensor
zero_point: torch.Tensor

def __init__(self, scale, zero_point, dtype, factory_kwargs=None):
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
def __init__(self, scale, zero_point, dtype):
super(Quantize, self).__init__()
self.register_buffer('scale', torch.tensor([scale], **factory_kwargs))
self.register_buffer('zero_point',
torch.tensor([zero_point], dtype=torch.long,
**{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
self.register_buffer('scale', torch.tensor([scale]))
self.register_buffer('zero_point', torch.tensor([zero_point], dtype=torch.long))
self.dtype = dtype

def forward(self, X):
@@ -95,12 +95,10 @@ class LeakyReLU(torch.nn.LeakyReLU):
zero_point: quantization zero point of the output tensor
negative_slope: Controls the angle of the negative slope. Default: 1e-2
"""
def __init__(self, scale: float, zero_point: int, negative_slope: float = 1e-2,
inplace: bool = False, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(self, scale: float, zero_point: int, negative_slope: float = 1e-2, inplace: bool = False):
super().__init__(negative_slope, inplace)
self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
self.register_buffer('scale', torch.tensor(scale))
self.register_buffer('zero_point', torch.tensor(zero_point))

def forward(self, input):
return torch.ops.quantized.leaky_relu(
@@ -6,9 +6,8 @@ class BatchNorm2d(torch.nn.BatchNorm2d):
r"""This is the quantized version of :class:`~torch.nn.BatchNorm2d`.
"""

def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(BatchNorm2d, self).__init__(num_features, **factory_kwargs)
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super(BatchNorm2d, self).__init__(num_features)
self.eps = eps
self.scale = 1.0
self.zero_point = 0

@@ -40,9 +39,8 @@ class BatchNorm3d(torch.nn.BatchNorm3d):
r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`.
"""

def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super(BatchNorm3d, self).__init__(num_features, **factory_kwargs)
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super(BatchNorm3d, self).__init__(num_features)
self.eps = eps
self.scale = 1.0
self.zero_point = 0
@@ -33,7 +33,7 @@ def _reverse_repeat_padding(padding: List[int]) -> List[int]:
class _ConvNd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True,
padding_mode='zeros', device=None, dtype=None):
padding_mode='zeros'):
# All subclasses have this signature - See PR #49702s
raise NotImplementedError

@@ -41,10 +41,7 @@ class _ConvNd(nn.Module):
padding, dilation,
transposed, output_padding,
groups, bias,
padding_mode='zeros',
device=None,
dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
padding_mode='zeros'):
super(_ConvNd, self).__init__()

if in_channels % groups != 0:

@@ -70,11 +67,9 @@ class _ConvNd(nn.Module):
weight_shape = [out_channels, in_channels // self.groups]
qweight = torch._empty_affine_quantized(
weight_shape + list(kernel_size),
scale=1, zero_point=0, dtype=torch.qint8,
**{k: v for k, v in factory_kwargs.items() if k != 'dtype'})
scale=1, zero_point=0, dtype=torch.qint8)
bias_float = (
torch.zeros(out_channels, dtype=torch.float,
**{k: v for k, v in factory_kwargs.items() if k != 'dtype'}) if bias else None)
torch.zeros(out_channels, dtype=torch.float) if bias else None)

self.set_weight_bias(qweight, bias_float)
self.scale = 1.0

@@ -280,10 +275,7 @@ class Conv1d(_ConvNd):
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
padding_mode: str = 'zeros'):
kernel_size = _pair_from_first(kernel_size)
stride = _pair_from_first(stride)
padding = _pair_from_first(padding)

@@ -293,7 +285,7 @@ class Conv1d(_ConvNd):
# discussion on PR #49702
super(Conv1d, self)._init(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _single(0), groups, bias, padding_mode, **factory_kwargs)
False, _single(0), groups, bias, padding_mode)

def _get_name(self):
return 'QuantizedConv1d'

@@ -382,8 +374,7 @@ class Conv2d(_ConvNd):

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True,
padding_mode='zeros', device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
padding_mode='zeros'):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)

@@ -392,7 +383,7 @@ class Conv2d(_ConvNd):
# discussion on PR #49702
super(Conv2d, self)._init(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _pair(0), groups, bias, padding_mode, **factory_kwargs)
False, _pair(0), groups, bias, padding_mode)

def _get_name(self):
return 'QuantizedConv2d'

@@ -479,9 +470,8 @@ class Conv3d(_ConvNd):

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True,
padding_mode='zeros', device=None, dtype=None):
padding_mode='zeros'):
assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
factory_kwargs = {'device': device, 'dtype': dtype}
kernel_size = _triple(kernel_size)
stride = _triple(stride)
padding = _triple(padding)

@@ -490,7 +480,7 @@ class Conv3d(_ConvNd):
# discussion on PR #49702
super(Conv3d, self)._init(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _triple(0), groups, bias, padding_mode, **factory_kwargs)
False, _triple(0), groups, bias, padding_mode)

def _get_name(self):
return 'QuantizedConv3d'

@@ -543,16 +533,15 @@ class _ConvTransposeNd(_ConvNd):

def __init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
groups, bias, padding_mode, device=None, dtype=None):
groups, bias, padding_mode):
if padding_mode != 'zeros':
raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__))
factory_kwargs = {'device': device, 'dtype': dtype}
# Subclasses of _ConvNd need to call _init rather than __init__. See
# discussion on PR #49702
super(_ConvTransposeNd, self)._init(
in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
groups, bias, padding_mode, **factory_kwargs)
groups, bias, padding_mode)

def _input_padding(self, kernel_size: List[int], dilation: List[int], padding: List[int]) -> List[int]:
res = torch.jit.annotate(List[int], [])

@@ -636,8 +625,7 @@ class ConvTranspose1d(_ConvTransposeNd):

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, output_padding=0, groups=1, bias=True,
dilation=1, padding_mode='zeros', device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
dilation=1, padding_mode='zeros'):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)

@@ -646,7 +634,7 @@ class ConvTranspose1d(_ConvTransposeNd):

super(ConvTranspose1d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode, **factory_kwargs)
True, output_padding, groups, bias, padding_mode)

def _get_name(self):
return 'QuantizedConvTranpose1d'

@@ -720,8 +708,7 @@ class ConvTranspose2d(_ConvTransposeNd):

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, output_padding=0, groups=1, bias=True,
dilation=1, padding_mode='zeros', device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
dilation=1, padding_mode='zeros'):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)

@@ -730,7 +717,7 @@ class ConvTranspose2d(_ConvTransposeNd):

super(ConvTranspose2d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode, **factory_kwargs)
True, output_padding, groups, bias, padding_mode)

def _get_name(self):
return 'QuantizedConvTranpose2d'

@@ -805,8 +792,7 @@ class ConvTranspose3d(_ConvTransposeNd):

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, output_padding=0, groups=1, bias=True,
dilation=1, padding_mode='zeros', device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
dilation=1, padding_mode='zeros'):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)

@@ -815,7 +801,7 @@ class ConvTranspose3d(_ConvTransposeNd):

super(ConvTranspose3d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode, **factory_kwargs)
True, output_padding, groups, bias, padding_mode)

def _get_name(self):
return 'QuantizedConvTranpose3d'
@@ -11,15 +11,13 @@ class LayerNorm(torch.nn.LayerNorm):
"""

def __init__(self, normalized_shape, weight, bias, scale, zero_point, eps=1e-5,
elementwise_affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
elementwise_affine=True):
super(LayerNorm, self).__init__(
normalized_shape, eps=eps, elementwise_affine=elementwise_affine,
**factory_kwargs)
normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
self.weight = weight
self.bias = bias
self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
self.register_buffer('scale', torch.tensor(scale))
self.register_buffer('zero_point', torch.tensor(zero_point))

def forward(self, input):
return torch.ops.quantized.layer_norm(

@@ -47,15 +45,12 @@ class GroupNorm(torch.nn.GroupNorm):
"""
__constants__ = ['num_groups', 'num_channels', 'eps', 'affine']

def __init__(self, num_groups, num_channels, weight, bias, scale, zero_point, eps=1e-5,
affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(GroupNorm, self).__init__(num_groups, num_channels, eps, affine,
**factory_kwargs)
def __init__(self, num_groups, num_channels, weight, bias, scale, zero_point, eps=1e-5, affine=True):
super(GroupNorm, self).__init__(num_groups, num_channels, eps, affine)
self.weight = weight
self.bias = bias
self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
self.register_buffer('scale', torch.tensor(scale))
self.register_buffer('zero_point', torch.tensor(zero_point))

def forward(self, input):
return torch.ops.quantized.group_norm(

@@ -83,14 +78,13 @@ class InstanceNorm1d(torch.nn.InstanceNorm1d):
"""
def __init__(self, num_features, weight, bias, scale, zero_point,
eps=1e-5, momentum=0.1, affine=False,
track_running_stats=False, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
track_running_stats=False):
super(InstanceNorm1d, self).__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
num_features, eps, momentum, affine, track_running_stats)
self.weight = weight
self.bias = bias
self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
self.register_buffer('scale', torch.tensor(scale))
self.register_buffer('zero_point', torch.tensor(zero_point))

def forward(self, input):
return torch.ops.quantized.instance_norm(

@@ -118,14 +112,13 @@ class InstanceNorm2d(torch.nn.InstanceNorm2d):
"""
def __init__(self, num_features, weight, bias, scale, zero_point,
eps=1e-5, momentum=0.1, affine=False,
track_running_stats=False, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
track_running_stats=False):
super(InstanceNorm2d, self).__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
num_features, eps, momentum, affine, track_running_stats)
self.weight = weight
self.bias = bias
self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
self.register_buffer('scale', torch.tensor(scale))
self.register_buffer('zero_point', torch.tensor(zero_point))

def forward(self, input):
return torch.ops.quantized.instance_norm(

@@ -153,14 +146,13 @@ class InstanceNorm3d(torch.nn.InstanceNorm3d):
"""
def __init__(self, num_features, weight, bias, scale, zero_point,
eps=1e-5, momentum=0.1, affine=False,
track_running_stats=False, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
track_running_stats=False):
super(InstanceNorm3d, self).__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
num_features, eps, momentum, affine, track_running_stats)
self.weight = weight
self.bias = bias
self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
self.register_buffer('scale', torch.tensor(scale))
self.register_buffer('zero_point', torch.tensor(zero_point))

def forward(self, input):
return torch.ops.quantized.instance_norm(
@@ -114,8 +114,7 @@ class _ObserverBase(ObserverBase):
eps: torch.Tensor

def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
reduce_range=False, quant_min=None, quant_max=None, factory_kwargs=None) -> None:
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
reduce_range=False, quant_min=None, quant_max=None):
super(_ObserverBase, self).__init__(dtype=dtype)
self.qscheme = qscheme
if reduce_range:

@@ -124,7 +123,7 @@ class _ObserverBase(ObserverBase):
reduce_range will be deprecated in a future release of PyTorch."
)
self.reduce_range = reduce_range
self.register_buffer('eps', torch.tensor([torch.finfo(torch.float32).eps], **factory_kwargs))
self.register_buffer('eps', torch.tensor([torch.finfo(torch.float32).eps]))
assert self.qscheme in (
torch.per_tensor_affine,
torch.per_tensor_symmetric,

@@ -368,23 +367,21 @@ class MinMaxObserver(_ObserverBase):
max_val: torch.Tensor

def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
reduce_range=False, quant_min=None, quant_max=None, factory_kwargs=None) -> None:

reduce_range=False, quant_min=None, quant_max=None):
# For x86 quantized kernels, we need to ensure that the vpmaddubsw
# instruction does not overflow. We allow for a reduce_range argument to
# observers that reduces the quantized range to (0,127) or (-64, 63).
# For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp
# This is not an optimal choice for non x86 backends as it loses a bit
# of precision for activations.

super(MinMaxObserver, self).__init__(dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs)
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
self.register_buffer('min_val', torch.tensor(float('inf'), **factory_kwargs))
self.register_buffer('max_val', torch.tensor(float('-inf'), **factory_kwargs))
quant_max=quant_max)
self.register_buffer('min_val', torch.tensor(float('inf')))
self.register_buffer('max_val', torch.tensor(float('-inf')))
if self.qscheme == torch.per_tensor_symmetric and \
self.reduce_range and \
self.dtype == torch.quint8:

@@ -459,14 +456,13 @@ class MovingAverageMinMaxObserver(MinMaxObserver):
"""
def __init__(self, averaging_constant=0.01, dtype=torch.quint8,
qscheme=torch.per_tensor_affine, reduce_range=False,
quant_min=None, quant_max=None, **kwargs) -> None:
quant_min=None, quant_max=None):
self.averaging_constant = averaging_constant
super(MovingAverageMinMaxObserver, self).__init__(dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
quant_min=quant_min,
quant_max=quant_max,
**kwargs)
quant_max=quant_max)

def forward(self, x_orig):
if x_orig.numel() == 0:

@@ -518,17 +514,15 @@ class PerChannelMinMaxObserver(_ObserverBase):

def __init__(self, ch_axis=0, dtype=torch.quint8,
qscheme=torch.per_channel_affine, reduce_range=False,
quant_min=None, quant_max=None, factory_kwargs=None) -> None:
quant_min=None, quant_max=None):
super(PerChannelMinMaxObserver, self).__init__(dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs)
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
quant_max=quant_max)
self.ch_axis = ch_axis
self.register_buffer('min_vals', torch.tensor([], **factory_kwargs))
self.register_buffer('max_vals', torch.tensor([], **factory_kwargs))
self.register_buffer('min_vals', torch.tensor([]))
self.register_buffer('max_vals', torch.tensor([]))
if (
self.qscheme == torch.per_channel_symmetric
and self.reduce_range

@@ -643,10 +637,10 @@ class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver):

def __init__(self, averaging_constant=0.01, ch_axis=0, dtype=torch.quint8,
qscheme=torch.per_channel_affine, reduce_range=False,
quant_min=None, quant_max=None, **kwargs) -> None:
quant_min=None, quant_max=None):
super(MovingAveragePerChannelMinMaxObserver, self).__init__(
ch_axis=ch_axis, dtype=dtype, qscheme=qscheme,
reduce_range=reduce_range, quant_min=quant_min, quant_max=quant_max, **kwargs)
reduce_range=reduce_range, quant_min=quant_min, quant_max=quant_max)
self.averaging_constant = averaging_constant

def forward(self, x_orig):

@@ -709,19 +703,16 @@ class HistogramObserver(_ObserverBase):
upsample_rate: int = 128,
dtype: torch.dtype = torch.quint8,
qscheme=torch.per_tensor_affine,
reduce_range=False,
factory_kwargs=None,
) -> None:
reduce_range=False
):
# bins: The number of bins used for histogram calculation.
super(HistogramObserver, self).__init__(dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
factory_kwargs=factory_kwargs)
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
reduce_range=reduce_range)
self.bins = bins
self.register_buffer('histogram', torch.zeros(self.bins, **factory_kwargs))
self.register_buffer('min_val', torch.tensor(float('inf'), **factory_kwargs))
self.register_buffer('max_val', torch.tensor(float('-inf'), **factory_kwargs))
self.register_buffer('histogram', torch.zeros(self.bins))
self.register_buffer('min_val', torch.tensor(float('inf')))
self.register_buffer('max_val', torch.tensor(float('-inf')))
self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
self.upsample_rate = upsample_rate

@@ -1019,7 +1010,7 @@ class PlaceholderObserver(ObserverBase):
custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
(Can be used in Graph Mode Passes for special case ops).
"""
def __init__(self, dtype=torch.float32, custom_op_name="", compute_dtype=None) -> None:
def __init__(self, dtype=torch.float32, custom_op_name="", compute_dtype=None):
super(PlaceholderObserver, self).__init__(dtype=dtype)
# dtype of input of the target operator, e.g. for dynamic quantization
# ops, the dtype will be float32

@@ -1079,7 +1070,7 @@ class NoopObserver(ObserverBase):
custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
(Can be used in Graph Mode Passes for special case ops).
"""
def __init__(self, dtype=torch.float16, custom_op_name="") -> None:
def __init__(self, dtype=torch.float16, custom_op_name=""):
super(NoopObserver, self).__init__(dtype=dtype)
self.dtype = dtype
self.custom_op = custom_op_name
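For orientation on the observer classes above: after the revert they no longer accept a factory_kwargs argument, so their min/max buffers are created on the default device with the default dtype. A small usage sketch, assuming the standard torch.quantization API and not part of this diff:

import torch
from torch.quantization import MinMaxObserver

obs = MinMaxObserver()                       # post-revert signature: no device/dtype
obs(torch.randn(16))                         # record the running min/max of a batch
scale, zero_point = obs.calculate_qparams()  # derive quantization parameters
print(obs.min_val, obs.max_val, scale, zero_point)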