Revert D27855386: [pytorch][PR] Support factory kwargs in torch.nn modules

Test Plan: revert-hammer

Differential Revision: D27855386 (40483acc51)

Original commit changeset: dabd505d2a04

fbshipit-source-id: f5bf3120d87861b30a8e1bf11977ad7d27cd8500
Natalia Gimelshein authored on 2021-04-19 20:06:17 -07:00, committed by Facebook GitHub Bot
parent b1282bc109
commit 92d24e3060
24 changed files with 235 additions and 879 deletions
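For context: the reverted change (D27855386) let most torch.nn modules accept device and dtype constructor arguments, which is why device=None, dtype=None parameters and factory_kwargs plumbing disappear throughout the diff below. A minimal illustration, not taken from the diff:

import torch

# With the reverted feature, parameters are created directly with the requested settings.
layer = torch.nn.Linear(10, 5, device='cpu', dtype=torch.float64)
print(layer.weight.device, layer.weight.dtype)  # cpu torch.float64

# After this revert, the equivalent is to construct first and convert afterwards.
layer = torch.nn.Linear(10, 5).to(device='cpu', dtype=torch.float64)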

View File

@ -62,7 +62,6 @@ TESTS = [
'test_linalg',
'test_logging',
'test_mkldnn',
'test_module_init',
'test_multiprocessing',
'test_multiprocessing_spawn',
'distributed/test_nccl',

View File

@ -1,428 +0,0 @@
import inspect
import torch
from unittest import mock
from unittest.mock import MagicMock, patch
from torch.testing import floating_types
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_utils import TestCase, run_tests
# Returns a database of args & kwargs that can be used to construct each module.
# Each entry is in class -> (args, kwargs) format.
# Example: torch.nn.Linear -> ([10, 5], {})
# TODO: Merge this in with the initial ModuleInfo implementation.
def build_constructor_arg_db():
return {
torch.nn.AdaptiveAvgPool1d: ((5,), {}),
torch.nn.AdaptiveAvgPool2d: ((5,), {}),
torch.nn.AdaptiveAvgPool3d: ((5,), {}),
torch.nn.AdaptiveLogSoftmaxWithLoss: ((100, 20, [5, 10, 15]), {}),
torch.nn.AdaptiveMaxPool1d: ((5,), {}),
torch.nn.AdaptiveMaxPool2d: ((5,), {}),
torch.nn.AdaptiveMaxPool3d: ((5,), {}),
torch.nn.AlphaDropout: ((), {}),
torch.nn.AvgPool1d: ((3,), {}),
torch.nn.AvgPool2d: ((3,), {}),
torch.nn.AvgPool3d: ((3,), {}),
torch.nn.BCELoss: ((), {}),
torch.nn.BCEWithLogitsLoss: ((), {}),
torch.nn.BatchNorm1d: ((5,), {}),
torch.nn.BatchNorm2d: ((5,), {}),
torch.nn.BatchNorm3d: ((5,), {}),
torch.nn.Bilinear: ((2, 3, 4), {}),
torch.nn.CELU: ((), {}),
torch.nn.CTCLoss: ((), {}),
torch.nn.ChannelShuffle: ((4,), {}),
torch.nn.ConstantPad1d: ((2, 3.5), {}),
torch.nn.ConstantPad2d: ((2, 3.5), {}),
torch.nn.ConstantPad3d: ((2, 3.5), {}),
torch.nn.Conv1d: ((3, 3, 3), {}),
torch.nn.Conv2d: ((3, 3, 3), {}),
torch.nn.Conv3d: ((3, 3, 3), {}),
torch.nn.ConvTranspose1d: ((3, 3, 3), {}),
torch.nn.ConvTranspose2d: ((3, 3, 3), {}),
torch.nn.ConvTranspose3d: ((3, 3, 3), {}),
torch.nn.CosineEmbeddingLoss: ((), {}),
torch.nn.CosineSimilarity: ((), {}),
torch.nn.CrossEntropyLoss: ((), {}),
torch.nn.CrossMapLRN2d: ((5,), {}),
torch.nn.Dropout2d: ((), {}),
torch.nn.Dropout3d: ((), {}),
torch.nn.Dropout: ((), {}),
torch.nn.ELU: ((), {}),
torch.nn.Embedding: ((10, 5), {}),
torch.nn.EmbeddingBag: ((10, 5), {}),
torch.nn.FeatureAlphaDropout: ((), {}),
torch.nn.Flatten: ((), {}),
torch.nn.Fold: ((5, 2), {}),
torch.nn.FractionalMaxPool2d: ((5, 2), {}),
torch.nn.FractionalMaxPool3d: ((5, 2), {}),
torch.nn.GELU: ((), {}),
torch.nn.GLU: ((), {}),
torch.nn.GRU: ((5, 10), {}),
torch.nn.GRUCell: ((5, 10), {}),
torch.nn.GaussianNLLLoss: ((), {}),
torch.nn.GroupNorm: ((3, 6, 1e-5, True), {}),
torch.nn.Hardshrink: ((), {}),
torch.nn.Hardsigmoid: ((), {}),
torch.nn.Hardswish: ((), {}),
torch.nn.Hardtanh: ((), {}),
torch.nn.HingeEmbeddingLoss: ((), {}),
torch.nn.HuberLoss: ((), {}),
torch.nn.Identity: ((), {}),
torch.nn.InstanceNorm1d: ((5, 1e-5, 0.1, True), {}),
torch.nn.InstanceNorm2d: ((5, 1e-5, 0.1, True), {}),
torch.nn.InstanceNorm3d: ((5, 1e-5, 0.1, True), {}),
torch.nn.KLDivLoss: ((), {}),
torch.nn.L1Loss: ((), {}),
torch.nn.LPPool1d: ((2, 3), {}),
torch.nn.LPPool2d: ((2, 3), {}),
torch.nn.LSTM: ((5, 10), {}),
torch.nn.LSTMCell: ((5, 10), {}),
torch.nn.LayerNorm: ((2,), {}),
torch.nn.LazyBatchNorm1d: ((), {}),
torch.nn.LazyBatchNorm2d: ((), {}),
torch.nn.LazyBatchNorm3d: ((), {}),
torch.nn.LazyConv1d: ((5, 2), {}),
torch.nn.LazyConv2d: ((5, 2), {}),
torch.nn.LazyConv3d: ((5, 2), {}),
torch.nn.LazyConvTranspose1d: ((5, 2), {}),
torch.nn.LazyConvTranspose2d: ((5, 2), {}),
torch.nn.LazyConvTranspose3d: ((5, 2), {}),
torch.nn.LazyLinear: ((5,), {}),
torch.nn.LeakyReLU: ((), {}),
torch.nn.Linear: ((10, 5), {}),
torch.nn.LocalResponseNorm: ((2,), {}),
torch.nn.LogSigmoid: ((), {}),
torch.nn.LogSoftmax: ((), {}),
torch.nn.MSELoss: ((), {}),
torch.nn.MarginRankingLoss: ((), {}),
torch.nn.MaxPool1d: ((3,), {}),
torch.nn.MaxPool2d: ((3,), {}),
torch.nn.MaxPool3d: ((3,), {}),
torch.nn.MaxUnpool1d: ((5,), {}),
torch.nn.MaxUnpool2d: ((5,), {}),
torch.nn.MaxUnpool3d: ((5,), {}),
torch.nn.ModuleDict: ((), {}),
torch.nn.ModuleList: ((), {}),
torch.nn.MultiLabelMarginLoss: ((), {}),
torch.nn.MultiLabelSoftMarginLoss: ((), {}),
torch.nn.MultiMarginLoss: ((), {}),
torch.nn.MultiheadAttention: ((100, 2), {}),
torch.nn.NLLLoss2d: ((), {}),
torch.nn.NLLLoss: ((), {}),
torch.nn.PReLU: ((), {}),
torch.nn.PairwiseDistance: ((), {}),
torch.nn.ParameterDict: ((), {}),
torch.nn.ParameterList: ((), {}),
torch.nn.PixelShuffle: ((2,), {}),
torch.nn.PixelUnshuffle: ((2,), {}),
torch.nn.PoissonNLLLoss: ((), {}),
torch.nn.RNN: ((5, 10), {}),
torch.nn.RNNBase: (('LSTM', 5, 10), {}),
torch.nn.RNNCell: ((5, 10), {}),
torch.nn.RNNCellBase: ((5, 10, True, 2), {}),
torch.nn.RReLU: ((), {}),
torch.nn.ReLU6: ((), {}),
torch.nn.ReLU: ((), {}),
torch.nn.ReflectionPad1d: ((2,), {}),
torch.nn.ReflectionPad2d: ((2,), {}),
torch.nn.ReplicationPad1d: ((2,), {}),
torch.nn.ReplicationPad2d: ((2,), {}),
torch.nn.ReplicationPad3d: ((2,), {}),
torch.nn.SELU: ((), {}),
torch.nn.Sequential: ((), {}),
torch.nn.SiLU: ((), {}),
torch.nn.Sigmoid: ((), {}),
torch.nn.SmoothL1Loss: ((), {}),
torch.nn.SoftMarginLoss: ((), {}),
torch.nn.Softmax2d: ((), {}),
torch.nn.Softmax: ((), {}),
torch.nn.Softmin: ((), {}),
torch.nn.Softplus: ((), {}),
torch.nn.Softshrink: ((), {}),
torch.nn.Softsign: ((), {}),
torch.nn.SyncBatchNorm: ((5,), {}),
torch.nn.Tanh: ((), {}),
torch.nn.Tanhshrink: ((), {}),
torch.nn.Threshold: ((0.1, 20), {}),
torch.nn.Transformer: ((), {}),
torch.nn.TransformerDecoder: ((torch.nn.TransformerDecoderLayer, 3), {}),
torch.nn.TransformerDecoderLayer: ((10, 2), {}),
torch.nn.TransformerEncoder: ((torch.nn.TransformerEncoderLayer, 3), {}),
torch.nn.TransformerEncoderLayer: ((10, 2), {}),
torch.nn.TripletMarginLoss: ((), {}),
torch.nn.TripletMarginWithDistanceLoss: ((), {}),
torch.nn.Unflatten: ((1, (2, 5, 5)), {}),
torch.nn.Unfold: ((3,), {}),
torch.nn.Upsample: ((), {}),
torch.nn.UpsamplingBilinear2d: ((), {}),
torch.nn.UpsamplingNearest2d: ((), {}),
torch.nn.ZeroPad2d: ((0,), {}),
torch.nn.qat.Conv2d: ((3, 3, 3), {
'qconfig': torch.quantization.default_qconfig,
}),
torch.nn.qat.Conv3d: ((3, 3, 3), {
'qconfig': torch.quantization.default_qconfig,
}),
torch.nn.qat.Linear: ((5, 2), {
'qconfig': torch.quantization.default_qconfig,
}),
torch.nn.quantizable.LSTM: ((5, 6), {}),
torch.nn.quantizable.LSTMCell: ((5, 6), {}),
torch.nn.quantizable.MultiheadAttention: ((10, 2), {}),
torch.nn.quantized.BatchNorm2d: ((2,), {}),
torch.nn.quantized.BatchNorm3d: ((2,), {}),
torch.nn.quantized.Conv1d: ((3, 3, 3), {}),
torch.nn.quantized.Conv2d: ((3, 3, 3), {}),
torch.nn.quantized.Conv3d: ((3, 3, 3), {}),
torch.nn.quantized.ConvTranspose1d: ((3, 3, 3), {}),
torch.nn.quantized.ConvTranspose2d: ((3, 3, 3), {}),
torch.nn.quantized.ConvTranspose3d: ((16, 33, (3, 3, 5)), {
'stride': (2, 1, 1),
'padding': (4, 2, 2),
'output_padding': (2, 2, 2),
'dilation': (1, 1, 1),
}),
torch.nn.quantized.DeQuantize: ((), {}),
torch.nn.quantized.ELU: ((0.01, 0), {}),
torch.nn.quantized.Embedding: ((10, 3), {
'factory_kwargs': {},
}),
torch.nn.quantized.EmbeddingBag: ((10, 3), {
'factory_kwargs': {},
}),
torch.nn.quantized.GroupNorm: ((2, 3, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.nn.quantized.Hardswish: ((0.1, 0,), {}),
torch.nn.quantized.InstanceNorm1d: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.nn.quantized.InstanceNorm2d: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.nn.quantized.InstanceNorm3d: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.nn.quantized.LayerNorm: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.nn.quantized.LeakyReLU: ((0.01, 0), {}),
torch.nn.quantized.Linear: ((5, 2), {
'factory_kwargs': {},
}),
torch.nn.quantized.MaxPool2d: ((3,), {}),
torch.nn.quantized.Quantize: ((0.1, 0), {
'dtype': torch.int16,
'factory_kwargs': {},
}),
torch.nn.quantized.ReLU6: ((), {}),
torch.nn.quantized.Sigmoid: ((0.1, 0), {}),
torch.nn.quantized.FloatFunctional: ((), {}),
torch.nn.quantized.FXFloatFunctional: ((), {}),
torch.nn.quantized.QFunctional: ((), {}),
}
# Instantiates the given class with the given args, kwargs, optionally on a given device.
def instantiate_class(cls, args, kwargs, extra_kwargs):
return cls(*args, **kwargs) if extra_kwargs is None else cls(*args, **kwargs, **extra_kwargs)
# Returns a function that calls the real implementation of a method
# in addition to passing args to a mock object.
def mock_wrapper(method):
mock = MagicMock()
def wrapper(self, *args, **kwargs):
mock(*args, **kwargs)
return method(self, *args, **kwargs)
wrapper.mock = mock
return wrapper
# Returns a set of args / kwargs that can be used to construct the module.
def get_example_args(module_cls, constructor_arg_db, extra_kwargs=None):
assert module_cls in constructor_arg_db, \
f"No entry for {module_cls} in the constructor arg DB. Please add it to pass these tests."
args, kwargs = constructor_arg_db[module_cls]
extra_kwargs = {} if extra_kwargs is None else extra_kwargs
# Recursively instantiate args / kwargs that are class objects.
args = [instantiate_class(arg, *get_example_args(arg, constructor_arg_db), extra_kwargs=extra_kwargs)
if inspect.isclass(arg) else torch.nn.Parameter(arg.to(**extra_kwargs))
if isinstance(arg, torch.nn.Parameter) else arg for arg in args]
kwargs = {k: instantiate_class(v, *get_example_args(v, constructor_arg_db), extra_kwargs=extra_kwargs)
if inspect.isclass(v) else torch.nn.Parameter(v.to(**extra_kwargs))
if isinstance(v, torch.nn.Parameter) else v for k, v in kwargs.items()}
kwargs.update(extra_kwargs)
return args, kwargs
def generate_test_func(test_cls, module_cls, constructor_arg_db,
verify_kwargs=True, module_is_lazy=False, check_nonexistent_arg=True):
# Generate a function for testing the given module.
@dtypes(*floating_types())
def run_test(test_cls, device, dtype, module_cls=module_cls):
# Check if this module creates parameters or registers buffers.
# The mock magic here passes through to the real Parameter / register_buffer
# logic and is only used to check for calls.
args, kwargs = get_example_args(module_cls, constructor_arg_db)
# Some modules need to pass factory_kwargs so as not to conflict with existing args such as dtype.
module_needs_factory_kwargs = 'factory_kwargs' in kwargs
if module_needs_factory_kwargs:
del kwargs['factory_kwargs']
extra_kwargs = {
'factory_kwargs': {
'device': device,
'dtype': dtype,
}
}
else:
extra_kwargs = {
'device': device,
'dtype': dtype,
}
parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
with patch.object(torch.nn.Parameter, '__new__', parameter_new):
register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
m = module_cls(*args, **kwargs)
module_creates_params_or_buffers = parameter_new.mock.called or register_buffer.mock.called
# == Verify factory kwargs are supported. ==
if verify_kwargs and module_creates_params_or_buffers:
args, kwargs = get_example_args(module_cls, constructor_arg_db,
extra_kwargs=extra_kwargs)
if module_is_lazy:
# Ensure device and dtype are passed to all UninitializedParameters and UninitializedBuffers.
uninit_param_new = mock_wrapper(torch.nn.UninitializedParameter.__new__)
with patch.object(torch.nn.UninitializedParameter, '__new__', uninit_param_new):
uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
m = module_cls(*args, **kwargs)
uninit_param_new.mock.assert_has_calls(
[mock.call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
uninit_buffer_new.mock.assert_has_calls(
[mock.call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
else:
# Check device placement and dtype for parameters and buffers.
# Only verify floating point dtypes since that's what the kwarg applies to.
# Note that dtype verification is also skipped if the module requires factory_kwargs.
m = module_cls(*args, **kwargs)
for name, param in m.named_parameters():
test_cls.assertEqual(
str(param.device), device,
f'Parameter {name} is on {param.device.type} instead of the expected device {device}')
if param.dtype.is_floating_point and not module_needs_factory_kwargs:
test_cls.assertEqual(
param.dtype, dtype,
f'Parameter {name} is of dtype {param.dtype} instead of the expected dtype {dtype}')
for name, buffer in m.named_buffers():
test_cls.assertEqual(
str(buffer.device), device,
f'Buffer {name} is on {buffer.device.type} instead of the expected device {device}')
if buffer.dtype.is_floating_point and not module_needs_factory_kwargs:
test_cls.assertEqual(
buffer.dtype, dtype,
f'Buffer {name} is of dtype {buffer.dtype} instead of the expected dtype {dtype}')
# == Verify passing a nonexistent arg errors out. ==
if check_nonexistent_arg:
with test_cls.assertRaises(TypeError):
m = module_cls(*args, **kwargs, nonexistent_arg='foo')
return run_test
def generate_tests(test_cls, constructor_arg_db):
# test all modules underneath these namespaces...
NAMESPACES = [
torch.nn,
torch.nn.qat,
torch.nn.quantizable,
torch.nn.quantized,
]
# ...except these
MODULES_TO_SKIP = {
torch.nn.Module,
torch.nn.Container, # deprecated
torch.nn.NLLLoss2d, # deprecated
torch.nn.quantized._ConvNd # base class in __all__ for some reason
}
# no need to support kwargs for these modules even though
# they have parameters / buffers because they are passed in
# already instantiated
MODULES_WITHOUT_KWARGS_SUPPORT = {
torch.nn.BCELoss,
torch.nn.BCEWithLogitsLoss,
torch.nn.CrossEntropyLoss,
torch.nn.FractionalMaxPool2d,
torch.nn.FractionalMaxPool3d,
torch.nn.MultiLabelSoftMarginLoss,
torch.nn.MultiMarginLoss,
torch.nn.NLLLoss,
torch.nn.TransformerDecoder,
torch.nn.TransformerEncoder,
}
# modules that supported kwargs before
MODULES_WITH_PREVIOUS_KWARGS = {
torch.nn.Identity,
}
# lazy modules don't instantiate parameters right away
LAZY_MODULES = {
torch.nn.LazyBatchNorm1d,
torch.nn.LazyBatchNorm2d,
torch.nn.LazyBatchNorm3d,
torch.nn.LazyConv1d,
torch.nn.LazyConv2d,
torch.nn.LazyConv3d,
torch.nn.LazyConvTranspose1d,
torch.nn.LazyConvTranspose2d,
torch.nn.LazyConvTranspose3d,
torch.nn.LazyLinear,
}
# these modules require FBGEMM backend to instantiate
MODULES_THAT_REQUIRE_FBGEMM = {
torch.nn.quantized.Conv1d,
torch.nn.quantized.Conv2d,
torch.nn.quantized.Conv3d,
torch.nn.quantized.ConvTranspose1d,
torch.nn.quantized.ConvTranspose2d,
torch.nn.quantized.ConvTranspose3d,
torch.nn.quantized.Linear,
}
for namespace in NAMESPACES:
# the "nn" in "torch.nn"
namespace_basename = namespace.__name__.split('.')[-1]
for module_name in namespace.modules.__all__:
# class object for this module (e.g. torch.nn.Linear)
module_cls = getattr(namespace.modules, module_name)
if module_cls in MODULES_TO_SKIP:
continue
verify_kwargs = module_cls not in MODULES_WITHOUT_KWARGS_SUPPORT
module_is_lazy = module_cls in LAZY_MODULES
check_nonexistent_arg = module_cls not in MODULES_WITH_PREVIOUS_KWARGS
# Generate a function for testing this module and setattr it onto the test class.
run_test = generate_test_func(test_cls, module_cls, constructor_arg_db,
verify_kwargs=verify_kwargs,
module_is_lazy=module_is_lazy,
check_nonexistent_arg=check_nonexistent_arg)
test_name = f'test_{namespace_basename}_{module_name}'
if module_cls in MODULES_THAT_REQUIRE_FBGEMM:
run_test = skipIfNoFBGEMM(run_test)
setattr(TestModuleInit, test_name, run_test)
class TestModuleInit(TestCase):
_ignore_not_implemented_error = False
generate_tests(TestModuleInit, build_constructor_arg_db())
instantiate_device_type_tests(TestModuleInit, globals())
if __name__ == '__main__':
run_tests()
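The deleted test above detects whether a module creates parameters or buffers by patching torch.nn.Parameter.__new__ (and torch.nn.Module.register_buffer) with pass-through mocks. A condensed, standalone sketch of that pattern, using torch.nn.Linear purely as an example module:

import torch
from unittest.mock import MagicMock, patch

def mock_wrapper(method):
    # Record every call on a MagicMock while still running the real implementation.
    mock = MagicMock()
    def wrapper(self, *args, **kwargs):
        mock(*args, **kwargs)
        return method(self, *args, **kwargs)
    wrapper.mock = mock
    return wrapper

parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
with patch.object(torch.nn.Parameter, '__new__', parameter_new):
    m = torch.nn.Linear(10, 5)      # constructs weight and bias Parameters
print(parameter_new.mock.called)    # True, so the test proceeds to verify factory kwargs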

View File

@ -3,44 +3,3 @@ from .parameter import Parameter, UninitializedParameter, UninitializedBuffer
from .parallel import DataParallel
from . import init
from . import utils
def factory_kwargs(kwargs):
r"""
Given kwargs, returns a canonicalized dict of factory kwargs that can be directly passed
to factory functions like torch.empty, or errors if unrecognized kwargs are present.
This function makes it simple to write code like this::
class MyModule(nn.Module):
def __init__(self, **kwargs):
factory_kwargs = torch.nn.factory_kwargs(kwargs)
self.weight = Parameter(torch.empty(10, **factory_kwargs))
Why should you use this function instead of just passing `kwargs` along directly?
1. This function does error validation, so if there are unexpected kwargs we will
immediately report an error, instead of deferring it to the factory call
2. This function supports a special `factory_kwargs` argument, which can be used to
explicitly specify a kwarg to be used for factory functions, in the event one of the
factory kwargs conflicts with an already existing argument in the signature (e.g.
in the signature ``def f(dtype, **kwargs)``, you can specify ``dtype`` for factory
functions, as distinct from the dtype argument, by saying
``f(dtype1, factory_kwargs={"dtype": dtype2})``)
"""
if kwargs is None:
return {}
simple_keys = {"device", "dtype", "memory_format"}
expected_keys = simple_keys | {"factory_kwargs"}
if not kwargs.keys() <= expected_keys:
raise TypeError(f"unexpected kwargs {kwargs.keys() - expected_keys}")
# guarantee no input kwargs is untouched
r = dict(kwargs.get("factory_kwargs", {}))
for k in simple_keys:
if k in kwargs:
if k in r:
raise TypeError(f"{k} specified twice, in **kwargs and in factory_kwargs")
r[k] = kwargs[k]
return r
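The docstring above describes a disambiguation rule: when a module's own signature already has an argument named dtype (or device), the tensor-factory value is passed separately under factory_kwargs. A self-contained sketch of that calling convention; the function f below is hypothetical, mirroring the def f(dtype, **kwargs) example in the docstring, and is not code from the diff:

import torch

def f(dtype, **kwargs):
    # dtype here is the function's own argument (e.g. a quantized output dtype);
    # tensor-creation settings arrive separately under factory_kwargs.
    fkw = dict(kwargs.pop('factory_kwargs', {}))
    fkw.update(kwargs)  # plain device=/dtype= kwargs are still fine when there is no clash
    weight = torch.empty(3, **fkw)
    return dtype, weight.dtype

print(f(torch.quint8, factory_kwargs={'dtype': torch.float64}))  # (torch.quint8, torch.float64)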

View File

@ -872,8 +872,7 @@ class MultiheadAttention(Module):
bias_v: Optional[torch.Tensor]
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
kdim=None, vdim=None, batch_first=False):
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
@ -887,25 +886,25 @@ class MultiheadAttention(Module):
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if self._qkv_same_embed_dim is False:
self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
self.q_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
self.k_proj_weight = Parameter(torch.empty(embed_dim, self.kdim))
self.v_proj_weight = Parameter(torch.empty(embed_dim, self.vdim))
self.register_parameter('in_proj_weight', None)
else:
self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
self.register_parameter('q_proj_weight', None)
self.register_parameter('k_proj_weight', None)
self.register_parameter('v_proj_weight', None)
if bias:
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
else:
self.register_parameter('in_proj_bias', None)
self.out_proj = Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
if add_bias_kv:
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
else:
self.bias_k = self.bias_v = None
@ -1058,12 +1057,10 @@ class PReLU(Module):
__constants__ = ['num_parameters']
num_parameters: int
def __init__(self, num_parameters: int = 1, init: float = 0.25,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(self, num_parameters: int = 1, init: float = 0.25) -> None:
self.num_parameters = num_parameters
super(PReLU, self).__init__()
self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs).fill_(init))
self.weight = Parameter(torch.empty(num_parameters).fill_(init))
def forward(self, input: Tensor) -> Tensor:
return F.prelu(input, self.weight)

View File

@ -115,11 +115,8 @@ class AdaptiveLogSoftmaxWithLoss(Module):
n_classes: int,
cutoffs: Sequence[int],
div_value: float = 4.,
head_bias: bool = False,
device=None,
dtype=None
head_bias: bool = False
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(AdaptiveLogSoftmaxWithLoss, self).__init__()
cutoffs = list(cutoffs)
@ -144,8 +141,7 @@ class AdaptiveLogSoftmaxWithLoss(Module):
self.n_clusters = len(self.cutoffs) - 1
self.head_size = self.shortlist_size + self.n_clusters
self.head = Linear(self.in_features, self.head_size, bias=self.head_bias,
**factory_kwargs)
self.head = Linear(self.in_features, self.head_size, bias=self.head_bias)
self.tail = ModuleList()
for i in range(self.n_clusters):
@ -154,8 +150,8 @@ class AdaptiveLogSoftmaxWithLoss(Module):
osz = self.cutoffs[i + 1] - self.cutoffs[i]
projection = Sequential(
Linear(self.in_features, hsz, bias=False, **factory_kwargs),
Linear(hsz, osz, bias=False, **factory_kwargs),
Linear(self.in_features, hsz, bias=False),
Linear(hsz, osz, bias=False)
)
self.tail.append(projection)

View File

@ -31,10 +31,7 @@ class _NormBase(Module):
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True,
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(_NormBase, self).__init__()
self.num_features = num_features
self.eps = eps
@ -42,17 +39,17 @@ class _NormBase(Module):
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
self.weight = Parameter(torch.empty(num_features))
self.bias = Parameter(torch.empty(num_features))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
if self.track_running_stats:
self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
self.register_buffer('num_batches_tracked',
torch.tensor(0, dtype=torch.long,
**{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
self.register_buffer("running_mean", torch.zeros(num_features))
self.register_buffer("running_var", torch.ones(num_features))
self.register_buffer(
"num_batches_tracked", torch.tensor(0, dtype=torch.long)
)
else:
self.register_buffer("running_mean", None)
self.register_buffer("running_var", None)
@ -120,12 +117,9 @@ class _BatchNorm(_NormBase):
momentum=0.1,
affine=True,
track_running_stats=True,
device=None,
dtype=None
):
factory_kwargs = {'device': device, 'dtype': dtype}
super(_BatchNorm, self).__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
num_features, eps, momentum, affine, track_running_stats
)
def forward(self, input: Tensor) -> Tensor:
@ -184,9 +178,7 @@ class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):
weight: UninitializedParameter # type: ignore[assignment]
bias: UninitializedParameter # type: ignore[assignment]
def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
super(_LazyBatchNorm, self).__init__(
# affine and track_running_stats are hardcoded to False to
# avoid creating tensors that will soon be overwritten.
@ -195,18 +187,16 @@ class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):
momentum,
False,
False,
**factory_kwargs,
)
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
self.weight = UninitializedParameter(**factory_kwargs)
self.bias = UninitializedParameter(**factory_kwargs)
self.weight = UninitializedParameter()
self.bias = UninitializedParameter()
if self.track_running_stats:
self.running_mean = UninitializedBuffer(**factory_kwargs)
self.running_var = UninitializedBuffer(**factory_kwargs)
self.num_batches_tracked = torch.tensor(
0, dtype=torch.long, **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})
self.running_mean = UninitializedBuffer()
self.running_var = UninitializedBuffer()
self.num_batches_tracked = torch.tensor(0, dtype=torch.long)
def reset_parameters(self) -> None:
if not self.has_uninitialized_params() and self.num_features != 0:
@ -650,12 +640,9 @@ class SyncBatchNorm(_BatchNorm):
affine: bool = True,
track_running_stats: bool = True,
process_group: Optional[Any] = None,
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(SyncBatchNorm, self).__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
num_features, eps, momentum, affine, track_running_stats
)
self.process_group = process_group
# gpu_size is set through DistributedDataParallel initialization. This is to ensure that SyncBatchNorm is used
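One detail in the reverted batchnorm code above: num_batches_tracked must remain a torch.long tensor, so the factory dtype was filtered out and only the remaining settings (such as device) were forwarded. A standalone sketch of that filtering, with hypothetical factory settings:

import torch

factory_kwargs = {'device': 'cpu', 'dtype': torch.float64}  # hypothetical factory settings
non_dtype_kwargs = {k: v for k, v in factory_kwargs.items() if k != 'dtype'}
num_batches_tracked = torch.tensor(0, dtype=torch.long, **non_dtype_kwargs)
print(num_batches_tracked.dtype, num_batches_tracked.device)  # torch.int64 cpu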

View File

@ -75,10 +75,7 @@ class _ConvNd(Module):
output_padding: Tuple[int, ...],
groups: int,
bias: bool,
padding_mode: str,
device=None,
dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
padding_mode: str) -> None:
super(_ConvNd, self).__init__()
if in_channels % groups != 0:
raise ValueError('in_channels must be divisible by groups')
@ -126,15 +123,14 @@ class _ConvNd(Module):
if transposed:
self.weight = Parameter(torch.empty(
(in_channels, out_channels // groups, *kernel_size), **factory_kwargs))
in_channels, out_channels // groups, *kernel_size))
else:
self.weight = Parameter(torch.empty(
(out_channels, in_channels // groups, *kernel_size), **factory_kwargs))
out_channels, in_channels // groups, *kernel_size))
if bias:
self.bias = Parameter(torch.empty(out_channels, **factory_kwargs))
self.bias = Parameter(torch.empty(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self) -> None:
@ -271,11 +267,8 @@ class Conv1d(_ConvNd):
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros', # TODO: refine this type
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
padding_mode: str = 'zeros' # TODO: refine this type
):
# we create new variables below to make mypy happy since kernel_size has
# type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int]
kernel_size_ = _single(kernel_size)
@ -284,7 +277,7 @@ class Conv1d(_ConvNd):
dilation_ = _single(dilation)
super(Conv1d, self).__init__(
in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
False, _single(0), groups, bias, padding_mode, **factory_kwargs)
False, _single(0), groups, bias, padding_mode)
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
if self.padding_mode != 'zeros':
@ -418,18 +411,15 @@ class Conv2d(_ConvNd):
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros', # TODO: refine this type
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
padding_mode: str = 'zeros' # TODO: refine this type
):
kernel_size_ = _pair(kernel_size)
stride_ = _pair(stride)
padding_ = padding if isinstance(padding, str) else _pair(padding)
dilation_ = _pair(dilation)
super(Conv2d, self).__init__(
in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
False, _pair(0), groups, bias, padding_mode, **factory_kwargs)
False, _pair(0), groups, bias, padding_mode)
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
if self.padding_mode != 'zeros':
@ -553,18 +543,15 @@ class Conv3d(_ConvNd):
dilation: _size_3_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
padding_mode: str = 'zeros'
):
kernel_size_ = _triple(kernel_size)
stride_ = _triple(stride)
padding_ = padding if isinstance(padding, str) else _triple(padding)
dilation_ = _triple(dilation)
super(Conv3d, self).__init__(
in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
False, _triple(0), groups, bias, padding_mode, **factory_kwargs)
False, _triple(0), groups, bias, padding_mode)
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
if self.padding_mode != "zeros":
@ -591,15 +578,14 @@ class Conv3d(_ConvNd):
class _ConvTransposeNd(_ConvNd):
def __init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
groups, bias, padding_mode, device=None, dtype=None) -> None:
groups, bias, padding_mode):
if padding_mode != 'zeros':
raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__))
factory_kwargs = {'device': device, 'dtype': dtype}
super(_ConvTransposeNd, self).__init__(
in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
groups, bias, padding_mode, **factory_kwargs)
groups, bias, padding_mode)
# dilation being an optional parameter is for backwards
# compatibility
@ -741,11 +727,8 @@ class ConvTranspose1d(_ConvTransposeNd):
groups: int = 1,
bias: bool = True,
dilation: _size_1_t = 1,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
padding_mode: str = 'zeros'
):
kernel_size = _single(kernel_size)
stride = _single(stride)
padding = _single(padding)
@ -753,7 +736,7 @@ class ConvTranspose1d(_ConvTransposeNd):
output_padding = _single(output_padding)
super(ConvTranspose1d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode, **factory_kwargs)
True, output_padding, groups, bias, padding_mode)
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
if self.padding_mode != 'zeros':
@ -889,11 +872,8 @@ class ConvTranspose2d(_ConvTransposeNd):
groups: int = 1,
bias: bool = True,
dilation: int = 1,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
padding_mode: str = 'zeros'
):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
@ -901,7 +881,7 @@ class ConvTranspose2d(_ConvTransposeNd):
output_padding = _pair(output_padding)
super(ConvTranspose2d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode, **factory_kwargs)
True, output_padding, groups, bias, padding_mode)
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
if self.padding_mode != 'zeros':
@ -1034,11 +1014,8 @@ class ConvTranspose3d(_ConvTransposeNd):
groups: int = 1,
bias: bool = True,
dilation: _size_3_t = 1,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
padding_mode: str = 'zeros'
):
kernel_size = _triple(kernel_size)
stride = _triple(stride)
padding = _triple(padding)
@ -1046,7 +1023,7 @@ class ConvTranspose3d(_ConvTransposeNd):
output_padding = _triple(output_padding)
super(ConvTranspose3d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode, **factory_kwargs)
True, output_padding, groups, bias, padding_mode)
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
if self.padding_mode != 'zeros':
@ -1169,11 +1146,8 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
padding_mode: str = 'zeros'
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(
0,
0,
@ -1185,13 +1159,12 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
# bias is hardcoded to False to avoid creating tensor
# that will soon be overwritten.
False,
padding_mode,
**factory_kwargs
padding_mode
)
self.weight = UninitializedParameter(**factory_kwargs)
self.weight = UninitializedParameter()
self.out_channels = out_channels
if bias:
self.bias = UninitializedParameter(**factory_kwargs)
self.bias = UninitializedParameter()
# LazyConv2d defines weight as a Tensor but derived class defines it as UnitializeParameter
@ -1235,11 +1208,8 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros', # TODO: refine this type
device=None,
dtype=None
padding_mode: str = 'zeros' # TODO: refine this type
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(
0,
0,
@ -1251,13 +1221,12 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
# bias is hardcoded to False to avoid creating tensor
# that will soon be overwritten.
False,
padding_mode,
**factory_kwargs
padding_mode
)
self.weight = UninitializedParameter(**factory_kwargs)
self.weight = UninitializedParameter()
self.out_channels = out_channels
if bias:
self.bias = UninitializedParameter(**factory_kwargs)
self.bias = UninitializedParameter()
# LazyConv3d defines weight as a Tensor but derived class defines it as UnitializeParameter
@ -1301,11 +1270,8 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
dilation: _size_3_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
padding_mode: str = 'zeros'
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(
0,
0,
@ -1317,13 +1283,12 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
# bias is hardcoded to False to avoid creating tensor
# that will soon be overwritten.
False,
padding_mode,
**factory_kwargs
padding_mode
)
self.weight = UninitializedParameter(**factory_kwargs)
self.weight = UninitializedParameter()
self.out_channels = out_channels
if bias:
self.bias = UninitializedParameter(**factory_kwargs)
self.bias = UninitializedParameter()
# LazyConvTranspose1d defines weight as a Tensor but derived class defines it as UnitializeParameter
@ -1365,11 +1330,8 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
groups: int = 1,
bias: bool = True,
dilation: _size_1_t = 1,
padding_mode: str = 'zeros',
device=None,
dtype=None
padding_mode: str = 'zeros'
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(
0,
0,
@ -1382,13 +1344,12 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
# that will soon be overwritten.
False,
dilation,
padding_mode,
**factory_kwargs
padding_mode
)
self.weight = UninitializedParameter(**factory_kwargs)
self.weight = UninitializedParameter()
self.out_channels = out_channels
if bias:
self.bias = UninitializedParameter(**factory_kwargs)
self.bias = UninitializedParameter()
# LazyConvTranspose2d defines weight as a Tensor but derived class defines it as UnitializeParameter
@ -1430,11 +1391,8 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
groups: int = 1,
bias: bool = True,
dilation: int = 1,
padding_mode: str = 'zeros',
device=None,
dtype=None
padding_mode: str = 'zeros'
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(
0,
0,
@ -1447,13 +1405,12 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
# that will soon be overwritten.
False,
dilation,
padding_mode,
**factory_kwargs
padding_mode
)
self.weight = UninitializedParameter(**factory_kwargs)
self.weight = UninitializedParameter()
self.out_channels = out_channels
if bias:
self.bias = UninitializedParameter(**factory_kwargs)
self.bias = UninitializedParameter()
# LazyConvTranspose3d defines weight as a Tensor but derived class defines it as UnitializeParameter
@ -1495,11 +1452,8 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi
groups: int = 1,
bias: bool = True,
dilation: _size_3_t = 1,
padding_mode: str = 'zeros',
device=None,
dtype=None
padding_mode: str = 'zeros'
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(
0,
0,
@ -1512,10 +1466,9 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi
# that will soon be overwritten.
False,
dilation,
padding_mode,
**factory_kwargs
padding_mode
)
self.weight = UninitializedParameter(**factory_kwargs)
self.weight = UninitializedParameter()
self.out_channels = out_channels
if bias:
self.bias = UninitializedParameter(**factory_kwargs)
self.bias = UninitializedParameter()

View File

@ -11,13 +11,10 @@ class _InstanceNorm(_NormBase):
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = False,
track_running_stats: bool = False,
device=None,
dtype=None
track_running_stats: bool = False
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(_InstanceNorm, self).__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
num_features, eps, momentum, affine, track_running_stats)
def _check_input_dim(self, input):
raise NotImplementedError

View File

@ -72,15 +72,13 @@ class Linear(Module):
out_features: int
weight: Tensor
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
super(Linear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
self.weight = Parameter(torch.empty(out_features, in_features))
if bias:
self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
self.bias = Parameter(torch.empty(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
@ -154,17 +152,15 @@ class Bilinear(Module):
out_features: int
weight: Tensor
def __init__(self, in1_features: int, in2_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(self, in1_features: int, in2_features: int, out_features: int, bias: bool = True) -> None:
super(Bilinear, self).__init__()
self.in1_features = in1_features
self.in2_features = in2_features
self.out_features = out_features
self.weight = Parameter(torch.empty((out_features, in1_features, in2_features), **factory_kwargs))
self.weight = Parameter(torch.empty(out_features, in1_features, in2_features))
if bias:
self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
self.bias = Parameter(torch.empty(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
@ -217,16 +213,14 @@ class LazyLinear(LazyModuleMixin, Linear):
weight: UninitializedParameter
bias: UninitializedParameter # type: ignore[assignment]
def __init__(self, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(self, out_features: int, bias: bool = True) -> None:
# bias is hardcoded to False to avoid creating tensor
# that will soon be overwritten.
super().__init__(0, 0, False)
self.weight = UninitializedParameter(**factory_kwargs)
self.weight = UninitializedParameter()
self.out_features = out_features
if bias:
self.bias = UninitializedParameter(**factory_kwargs)
self.bias = UninitializedParameter()
def reset_parameters(self) -> None:
if not self.has_uninitialized_params() and self.in_features != 0:

View File

@ -145,9 +145,7 @@ class LayerNorm(Module):
eps: float
elementwise_affine: bool
def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
super(LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
@ -156,12 +154,11 @@ class LayerNorm(Module):
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
self.bias = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
self.weight = Parameter(torch.empty(self.normalized_shape))
self.bias = Parameter(torch.empty(self.normalized_shape))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self) -> None:
@ -226,21 +223,18 @@ class GroupNorm(Module):
eps: float
affine: bool
def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True) -> None:
super(GroupNorm, self).__init__()
self.num_groups = num_groups
self.num_channels = num_channels
self.eps = eps
self.affine = affine
if self.affine:
self.weight = Parameter(torch.empty(num_channels, **factory_kwargs))
self.bias = Parameter(torch.empty(num_channels, **factory_kwargs))
self.weight = Parameter(torch.empty(num_channels))
self.bias = Parameter(torch.empty(num_channels))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self) -> None:

View File

@ -38,9 +38,7 @@ class RNNBase(Module):
def __init__(self, mode: str, input_size: int, hidden_size: int,
num_layers: int = 1, bias: bool = True, batch_first: bool = False,
dropout: float = 0., bidirectional: bool = False, proj_size: int = 0,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
dropout: float = 0., bidirectional: bool = False, proj_size: int = 0) -> None:
super(RNNBase, self).__init__()
self.mode = mode
self.input_size = input_size
@ -86,12 +84,12 @@ class RNNBase(Module):
real_hidden_size = proj_size if proj_size > 0 else hidden_size
layer_input_size = input_size if layer == 0 else real_hidden_size * num_directions
w_ih = Parameter(torch.empty((gate_size, layer_input_size), **factory_kwargs))
w_hh = Parameter(torch.empty((gate_size, real_hidden_size), **factory_kwargs))
b_ih = Parameter(torch.empty(gate_size, **factory_kwargs))
w_ih = Parameter(torch.empty(gate_size, layer_input_size))
w_hh = Parameter(torch.empty(gate_size, real_hidden_size))
b_ih = Parameter(torch.empty(gate_size))
# Second bias vector included for CuDNN compatibility. Only one
# bias vector is needed in standard definition.
b_hh = Parameter(torch.empty(gate_size, **factory_kwargs))
b_hh = Parameter(torch.empty(gate_size))
layer_params: Tuple[Tensor, ...] = ()
if self.proj_size == 0:
if bias:
@ -99,7 +97,7 @@ class RNNBase(Module):
else:
layer_params = (w_ih, w_hh)
else:
w_hr = Parameter(torch.empty((proj_size, hidden_size), **factory_kwargs))
w_hr = Parameter(torch.empty(proj_size, hidden_size))
if bias:
layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr)
else:
@ -120,7 +118,6 @@ class RNNBase(Module):
self._flat_weights = [(lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn) for wn in self._flat_weights_names]
self.flatten_parameters()
self.reset_parameters()
def __setattr__(self, attr, value):
@ -848,22 +845,19 @@ class RNNCellBase(Module):
# WARNING: bias_ih and bias_hh purposely not defined here.
# See https://github.com/pytorch/pytorch/issues/39670
def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int) -> None:
super(RNNCellBase, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.weight_ih = Parameter(torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs))
self.weight_hh = Parameter(torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs))
self.weight_ih = Parameter(torch.empty(num_chunks * hidden_size, input_size))
self.weight_hh = Parameter(torch.empty(num_chunks * hidden_size, hidden_size))
if bias:
self.bias_ih = Parameter(torch.empty(num_chunks * hidden_size, **factory_kwargs))
self.bias_hh = Parameter(torch.empty(num_chunks * hidden_size, **factory_kwargs))
self.bias_ih = Parameter(torch.empty(num_chunks * hidden_size))
self.bias_hh = Parameter(torch.empty(num_chunks * hidden_size))
else:
self.register_parameter('bias_ih', None)
self.register_parameter('bias_hh', None)
self.reset_parameters()
def extra_repr(self) -> str:
@ -940,10 +934,8 @@ class RNNCell(RNNCellBase):
__constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']
nonlinearity: str
def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh",
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh") -> None:
super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
self.nonlinearity = nonlinearity
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
@ -1030,10 +1022,8 @@ class LSTMCell(RNNCellBase):
>>> output = torch.stack(output, dim=0)
"""
def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)
def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)
def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
if hx is None:
@ -1108,10 +1098,8 @@ class GRUCell(RNNCellBase):
output.append(hx)
"""
def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)
def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
if hx is None:

View File

@ -119,9 +119,7 @@ class Embedding(Module):
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
sparse: bool = False, _weight: Optional[Tensor] = None,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
super(Embedding, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
@ -136,7 +134,7 @@ class Embedding(Module):
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
if _weight is None:
self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
self.weight = Parameter(torch.empty(num_embeddings, embedding_dim))
self.reset_parameters()
else:
assert list(_weight.shape) == [num_embeddings, embedding_dim], \
@ -309,9 +307,7 @@ class EmbeddingBag(Module):
def __init__(self, num_embeddings: int, embedding_dim: int,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
mode: str = 'mean', sparse: bool = False, _weight: Optional[Tensor] = None,
include_last_offset: bool = False, padding_idx: Optional[int] = None,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
include_last_offset: bool = False, padding_idx: Optional[int] = None) -> None:
super(EmbeddingBag, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
@ -326,7 +322,7 @@ class EmbeddingBag(Module):
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
if _weight is None:
self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
self.weight = Parameter(torch.empty(num_embeddings, embedding_dim))
self.reset_parameters()
else:
assert list(_weight.shape) == [num_embeddings, embedding_dim], \

View File

@ -48,27 +48,23 @@ class Transformer(Module):
def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: str = "relu", custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
layer_norm_eps: float = 1e-5, batch_first: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
layer_norm_eps: float = 1e-5, batch_first: bool = False) -> None:
super(Transformer, self).__init__()
if custom_encoder is not None:
self.encoder = custom_encoder
else:
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
activation, layer_norm_eps, batch_first,
**factory_kwargs)
encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
activation, layer_norm_eps, batch_first)
encoder_norm = LayerNorm(d_model, eps=layer_norm_eps)
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
if custom_decoder is not None:
self.decoder = custom_decoder
else:
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
activation, layer_norm_eps, batch_first,
**factory_kwargs)
decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
activation, layer_norm_eps, batch_first)
decoder_norm = LayerNorm(d_model, eps=layer_norm_eps)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
self._reset_parameters()
@ -283,19 +279,16 @@ class TransformerEncoderLayer(Module):
__constants__ = ['batch_first']
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
layer_norm_eps=1e-5, batch_first=False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
layer_norm_eps=1e-5, batch_first=False):
super(TransformerEncoderLayer, self).__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
**factory_kwargs)
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
# Implementation of Feedforward model
self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
self.linear1 = Linear(d_model, dim_feedforward)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
self.linear2 = Linear(dim_feedforward, d_model)
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
@ -360,21 +353,18 @@ class TransformerDecoderLayer(Module):
__constants__ = ['batch_first']
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
layer_norm_eps=1e-5, batch_first=False, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
layer_norm_eps=1e-5, batch_first=False):
super(TransformerDecoderLayer, self).__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
**factory_kwargs)
self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
**factory_kwargs)
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
# Implementation of Feedforward model
self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
self.linear1 = Linear(d_model, dim_feedforward)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
self.linear2 = Linear(dim_feedforward, d_model)
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
self.dropout3 = Dropout(dropout)

View File

@ -141,16 +141,12 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter):
will throw a runtime error. The only operations that can be performed on a uninitialized
parameter are changing its datatype, moving it to a different device and
converting it to a regular :class:`torch.nn.Parameter`.
The default device or dtype to use when the parameter is materialized can be set
during construction using e.g. ``device='cuda'``.
"""
cls_to_become = Parameter
def __new__(cls, requires_grad=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
data = torch.tensor([], **factory_kwargs)
def __new__(cls, requires_grad=True):
data = torch.tensor([])
return torch.Tensor._make_subclass(cls, data, requires_grad)
@ -165,14 +161,10 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
will throw a runtime error. The only operations that can be performed on a uninitialized
parameter are changing its datatype, moving it to a different device and
converting it to a regular :class:`torch.Tensor`.
The default device or dtype to use when the buffer is materialized can be set
during construction using e.g. ``device='cuda'``.
"""
cls_to_become = torch.Tensor
def __new__(cls, requires_grad=False, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
data = torch.tensor([], **factory_kwargs)
def __new__(cls, requires_grad=False):
data = torch.tensor([])
return torch.Tensor._make_subclass(cls, data, requires_grad)
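The removed docstring text above noted that a default device or dtype could be set when constructing an uninitialized parameter or buffer. After this revert, lazy modules start on the default device and dtype, and the parameter only takes a concrete shape on first use. A small sketch of that materialization flow (behavior as of PyTorch around this commit):

import torch

m = torch.nn.LazyLinear(out_features=5)
print(type(m.weight).__name__)         # UninitializedParameter: shape not known yet
m(torch.randn(2, 10))                  # first forward infers in_features=10 and materializes
print(m.weight.shape, m.weight.dtype)  # torch.Size([5, 10]) torch.float32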

View File

@ -21,16 +21,13 @@ class Conv2d(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=True, padding_mode='zeros', qconfig=None,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
bias=True, padding_mode='zeros', qconfig=None):
super().__init__(in_channels, out_channels, kernel_size,
stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias, padding_mode=padding_mode,
**factory_kwargs)
groups=groups, bias=bias, padding_mode=padding_mode)
assert qconfig, 'qconfig must be provided for QAT module'
self.qconfig = qconfig
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
self.weight_fake_quant = qconfig.weight()
def forward(self, input):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
@ -87,10 +84,7 @@ class Conv3d(nn.Conv3d):
bias=True,
padding_mode="zeros",
qconfig=None,
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
):
super().__init__(
in_channels,
out_channels,
@ -101,11 +95,10 @@ class Conv3d(nn.Conv3d):
groups=groups,
bias=bias,
padding_mode=padding_mode,
**factory_kwargs
)
assert qconfig, "qconfig must be provided for QAT module"
self.qconfig = qconfig
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
self.weight_fake_quant = qconfig.weight()
def forward(self, input):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

View File

@ -20,12 +20,11 @@ class Linear(nn.Linear):
_FLOAT_MODULE = nn.Linear
def __init__(self, in_features, out_features, bias=True,
qconfig=None, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(in_features, out_features, bias, **factory_kwargs)
qconfig=None):
super().__init__(in_features, out_features, bias)
assert qconfig, 'qconfig must be provided for QAT module'
self.qconfig = qconfig
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
self.weight_fake_quant = qconfig.weight()
def forward(self, input):
return F.linear(input, self.weight_fake_quant(self.weight), self.bias)
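
Similarly for the QAT Linear; the commented-out call is the removed pre-revert form and is
shown only for contrast:

    import torch.nn.qat as nnqat
    from torch.quantization import get_default_qat_qconfig

    qconfig = get_default_qat_qconfig('fbgemm')
    fc = nnqat.Linear(10, 5, qconfig=qconfig)                   # post-revert signature
    # fc = nnqat.Linear(10, 5, qconfig=qconfig, device='cuda')  # pre-revert only (removed)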

View File

@ -59,21 +59,18 @@ class MultiheadAttention(nn.MultiheadAttention):
def __init__(self, embed_dim: int, num_heads: int,
dropout: float = 0., bias: bool = True,
add_bias_kv: bool = False, add_zero_attn: bool = False,
kdim: int = None, vdim: int = None, batch_first: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
kdim: int = None, vdim: int = None, batch_first: bool = False):
super(MultiheadAttention, self).__init__(embed_dim, num_heads, dropout,
bias, add_bias_kv,
add_zero_attn, kdim, vdim, batch_first,
**factory_kwargs)
self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias, **factory_kwargs)
self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias, **factory_kwargs)
add_zero_attn, kdim, vdim, batch_first)
self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias)
self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias)
# TODO: The use of `_LinearWithBias` increases the quantization noise.
# The `out_proj` in the parent is `_LinearWithBias`, so we need to ignore
# the type for mypy not to complain.
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs) # type: ignore
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias) # type: ignore
# Functionals
self.q_scaling_product = nnq.FloatFunctional()

View File

@ -29,16 +29,14 @@ class LSTMCell(torch.nn.Module):
"""
_FLOAT_MODULE = torch.nn.LSTMCell
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True):
super().__init__()
self.input_size = input_dim
self.hidden_size = hidden_dim
self.bias = bias
self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias)
self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias)
self.gates = torch.nn.quantized.FloatFunctional()
self.fgate_cx = torch.nn.quantized.FloatFunctional()
@ -121,11 +119,9 @@ class _LSTMSingleLayer(torch.nn.Module):
The difference between a layer and a cell is that the layer can process a
sequence, while the cell only expects an instantaneous value.
"""
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True):
super().__init__()
self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)
self.cell = LSTMCell(input_dim, hidden_dim, bias=bias)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
result = []
@ -146,15 +142,13 @@ class _LSTMSingleLayer(torch.nn.Module):
class _LSTMLayer(torch.nn.Module):
r"""A single bi-directional LSTM layer."""
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
batch_first: bool = False, bidirectional: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
batch_first: bool = False, bidirectional: bool = False):
super().__init__()
self.batch_first = batch_first
self.bidirectional = bidirectional
self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias)
if self.bidirectional:
self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
if self.batch_first:
@ -299,9 +293,7 @@ class LSTM(torch.nn.Module):
def __init__(self, input_size: int, hidden_size: int,
num_layers: int = 1, bias: bool = True,
batch_first: bool = False, dropout: float = 0.,
bidirectional: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
bidirectional: bool = False):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
@ -330,12 +322,11 @@ class LSTM(torch.nn.Module):
layers = [_LSTMLayer(self.input_size, self.hidden_size,
self.bias, batch_first=False,
bidirectional=self.bidirectional, **factory_kwargs)]
bidirectional=self.bidirectional)]
for layer in range(1, num_layers):
layers.append(_LSTMLayer(self.hidden_size, self.hidden_size,
self.bias, batch_first=False,
bidirectional=self.bidirectional,
**factory_kwargs))
bidirectional=self.bidirectional))
self.layers = torch.nn.ModuleList(layers)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
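
A float-mode usage sketch for the quantizable LSTM above; shapes follow the usual
``nn.LSTM`` convention with ``batch_first=False`` (a sketch, not part of the diff):

    import torch
    from torch.nn.quantizable import LSTM

    lstm = LSTM(input_size=4, hidden_size=8, num_layers=2, bidirectional=True)
    x = torch.randn(5, 1, 4)       # (seq_len, batch, input_size)
    out, (h, c) = lstm(x)          # out feature dim is 16: bidirectional doubles hidden_size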

View File

@ -37,13 +37,10 @@ class Quantize(torch.nn.Module):
scale: torch.Tensor
zero_point: torch.Tensor
def __init__(self, scale, zero_point, dtype, factory_kwargs=None):
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
def __init__(self, scale, zero_point, dtype):
super(Quantize, self).__init__()
self.register_buffer('scale', torch.tensor([scale], **factory_kwargs))
self.register_buffer('zero_point',
torch.tensor([zero_point], dtype=torch.long,
**{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
self.register_buffer('scale', torch.tensor([scale]))
self.register_buffer('zero_point', torch.tensor([zero_point], dtype=torch.long))
self.dtype = dtype
def forward(self, X):
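
For reference, a small sketch of the Quantize module above in use (post-revert constructor,
no factory kwargs):

    import torch
    from torch.nn.quantized import Quantize

    q = Quantize(scale=0.1, zero_point=0, dtype=torch.quint8)
    xq = q(torch.randn(2, 3))      # quantizes the input with the stored scale/zero_point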

View File

@ -95,12 +95,10 @@ class LeakyReLU(torch.nn.LeakyReLU):
zero_point: quantization zero point of the output tensor
negative_slope: Controls the angle of the negative slope. Default: 1e-2
"""
def __init__(self, scale: float, zero_point: int, negative_slope: float = 1e-2,
inplace: bool = False, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(self, scale: float, zero_point: int, negative_slope: float = 1e-2, inplace: bool = False):
super().__init__(negative_slope, inplace)
self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
self.register_buffer('scale', torch.tensor(scale))
self.register_buffer('zero_point', torch.tensor(zero_point))
def forward(self, input):
return torch.ops.quantized.leaky_relu(
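
A usage sketch for the quantized LeakyReLU above; the input must already be quantized, and
scale/zero_point describe the output tensor as the docstring states:

    import torch
    from torch.nn.quantized import LeakyReLU

    m = LeakyReLU(scale=0.1, zero_point=0, negative_slope=0.01)
    x = torch.quantize_per_tensor(torch.randn(2, 3), scale=0.1, zero_point=0, dtype=torch.quint8)
    y = m(x)                       # quantized output with scale=0.1, zero_point=0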

View File

@ -6,9 +6,8 @@ class BatchNorm2d(torch.nn.BatchNorm2d):
r"""This is the quantized version of :class:`~torch.nn.BatchNorm2d`.
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(BatchNorm2d, self).__init__(num_features, **factory_kwargs)
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super(BatchNorm2d, self).__init__(num_features)
self.eps = eps
self.scale = 1.0
self.zero_point = 0
@ -40,9 +39,8 @@ class BatchNorm3d(torch.nn.BatchNorm3d):
r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`.
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super(BatchNorm3d, self).__init__(num_features, **factory_kwargs)
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super(BatchNorm3d, self).__init__(num_features)
self.eps = eps
self.scale = 1.0
self.zero_point = 0
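
A usage sketch for the quantized BatchNorm above (default scale/zero_point as set in
__init__; a quantized CPU engine such as FBGEMM or QNNPACK is assumed):

    import torch
    from torch.nn.quantized import BatchNorm2d

    bn = BatchNorm2d(4)            # post-revert: no device/dtype kwargs
    x = torch.quantize_per_tensor(torch.randn(1, 4, 8, 8), scale=0.1, zero_point=0, dtype=torch.quint8)
    y = bn(x)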

View File

@ -33,7 +33,7 @@ def _reverse_repeat_padding(padding: List[int]) -> List[int]:
class _ConvNd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True,
padding_mode='zeros', device=None, dtype=None):
padding_mode='zeros'):
# All subclasses have this signature - See PR #49702
raise NotImplementedError
@ -41,10 +41,7 @@ class _ConvNd(nn.Module):
padding, dilation,
transposed, output_padding,
groups, bias,
padding_mode='zeros',
device=None,
dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
padding_mode='zeros'):
super(_ConvNd, self).__init__()
if in_channels % groups != 0:
@ -70,11 +67,9 @@ class _ConvNd(nn.Module):
weight_shape = [out_channels, in_channels // self.groups]
qweight = torch._empty_affine_quantized(
weight_shape + list(kernel_size),
scale=1, zero_point=0, dtype=torch.qint8,
**{k: v for k, v in factory_kwargs.items() if k != 'dtype'})
scale=1, zero_point=0, dtype=torch.qint8)
bias_float = (
torch.zeros(out_channels, dtype=torch.float,
**{k: v for k, v in factory_kwargs.items() if k != 'dtype'}) if bias else None)
torch.zeros(out_channels, dtype=torch.float) if bias else None)
self.set_weight_bias(qweight, bias_float)
self.scale = 1.0
@ -280,10 +275,7 @@ class Conv1d(_ConvNd):
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
padding_mode: str = 'zeros'):
kernel_size = _pair_from_first(kernel_size)
stride = _pair_from_first(stride)
padding = _pair_from_first(padding)
@ -293,7 +285,7 @@ class Conv1d(_ConvNd):
# discussion on PR #49702
super(Conv1d, self)._init(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _single(0), groups, bias, padding_mode, **factory_kwargs)
False, _single(0), groups, bias, padding_mode)
def _get_name(self):
return 'QuantizedConv1d'
@ -382,8 +374,7 @@ class Conv2d(_ConvNd):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True,
padding_mode='zeros', device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
padding_mode='zeros'):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
@ -392,7 +383,7 @@ class Conv2d(_ConvNd):
# discussion on PR #49702
super(Conv2d, self)._init(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _pair(0), groups, bias, padding_mode, **factory_kwargs)
False, _pair(0), groups, bias, padding_mode)
def _get_name(self):
return 'QuantizedConv2d'
@ -479,9 +470,8 @@ class Conv3d(_ConvNd):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True,
padding_mode='zeros', device=None, dtype=None):
padding_mode='zeros'):
assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
factory_kwargs = {'device': device, 'dtype': dtype}
kernel_size = _triple(kernel_size)
stride = _triple(stride)
padding = _triple(padding)
@ -490,7 +480,7 @@ class Conv3d(_ConvNd):
# discussion on PR #49702
super(Conv3d, self)._init(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _triple(0), groups, bias, padding_mode, **factory_kwargs)
False, _triple(0), groups, bias, padding_mode)
def _get_name(self):
return 'QuantizedConv3d'
@ -543,16 +533,15 @@ class _ConvTransposeNd(_ConvNd):
def __init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
groups, bias, padding_mode, device=None, dtype=None):
groups, bias, padding_mode):
if padding_mode != 'zeros':
raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__))
factory_kwargs = {'device': device, 'dtype': dtype}
# Subclasses of _ConvNd need to call _init rather than __init__. See
# discussion on PR #49702
super(_ConvTransposeNd, self)._init(
in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
groups, bias, padding_mode, **factory_kwargs)
groups, bias, padding_mode)
def _input_padding(self, kernel_size: List[int], dilation: List[int], padding: List[int]) -> List[int]:
res = torch.jit.annotate(List[int], [])
@ -636,8 +625,7 @@ class ConvTranspose1d(_ConvTransposeNd):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, output_padding=0, groups=1, bias=True,
dilation=1, padding_mode='zeros', device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
dilation=1, padding_mode='zeros'):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
@ -646,7 +634,7 @@ class ConvTranspose1d(_ConvTransposeNd):
super(ConvTranspose1d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode, **factory_kwargs)
True, output_padding, groups, bias, padding_mode)
def _get_name(self):
return 'QuantizedConvTranspose1d'
@ -720,8 +708,7 @@ class ConvTranspose2d(_ConvTransposeNd):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, output_padding=0, groups=1, bias=True,
dilation=1, padding_mode='zeros', device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
dilation=1, padding_mode='zeros'):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
@ -730,7 +717,7 @@ class ConvTranspose2d(_ConvTransposeNd):
super(ConvTranspose2d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode, **factory_kwargs)
True, output_padding, groups, bias, padding_mode)
def _get_name(self):
return 'QuantizedConvTranspose2d'
@ -805,8 +792,7 @@ class ConvTranspose3d(_ConvTransposeNd):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, output_padding=0, groups=1, bias=True,
dilation=1, padding_mode='zeros', device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
dilation=1, padding_mode='zeros'):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
@ -815,7 +801,7 @@ class ConvTranspose3d(_ConvTransposeNd):
super(ConvTranspose3d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode, **factory_kwargs)
True, output_padding, groups, bias, padding_mode)
def _get_name(self):
return 'QuantizedConvTranspose3d'
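
A quick construction sketch for the quantized Conv2d defined in this file; in practice these
modules usually come out of torch.quantization.convert, but direct construction with the
post-revert signature looks like this (quantized input and an FBGEMM/QNNPACK engine assumed):

    import torch
    from torch.nn.quantized import Conv2d

    m = Conv2d(3, 3, 3)            # qint8 weight, scale=1.0, zero_point=0 by default
    x = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), scale=0.1, zero_point=0, dtype=torch.quint8)
    y = m(x)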

View File

@ -11,15 +11,13 @@ class LayerNorm(torch.nn.LayerNorm):
"""
def __init__(self, normalized_shape, weight, bias, scale, zero_point, eps=1e-5,
elementwise_affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
elementwise_affine=True):
super(LayerNorm, self).__init__(
normalized_shape, eps=eps, elementwise_affine=elementwise_affine,
**factory_kwargs)
normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
self.weight = weight
self.bias = bias
self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
self.register_buffer('scale', torch.tensor(scale))
self.register_buffer('zero_point', torch.tensor(zero_point))
def forward(self, input):
return torch.ops.quantized.layer_norm(
@ -47,15 +45,12 @@ class GroupNorm(torch.nn.GroupNorm):
"""
__constants__ = ['num_groups', 'num_channels', 'eps', 'affine']
def __init__(self, num_groups, num_channels, weight, bias, scale, zero_point, eps=1e-5,
affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(GroupNorm, self).__init__(num_groups, num_channels, eps, affine,
**factory_kwargs)
def __init__(self, num_groups, num_channels, weight, bias, scale, zero_point, eps=1e-5, affine=True):
super(GroupNorm, self).__init__(num_groups, num_channels, eps, affine)
self.weight = weight
self.bias = bias
self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
self.register_buffer('scale', torch.tensor(scale))
self.register_buffer('zero_point', torch.tensor(zero_point))
def forward(self, input):
return torch.ops.quantized.group_norm(
@ -83,14 +78,13 @@ class InstanceNorm1d(torch.nn.InstanceNorm1d):
"""
def __init__(self, num_features, weight, bias, scale, zero_point,
eps=1e-5, momentum=0.1, affine=False,
track_running_stats=False, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
track_running_stats=False):
super(InstanceNorm1d, self).__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
num_features, eps, momentum, affine, track_running_stats)
self.weight = weight
self.bias = bias
self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
self.register_buffer('scale', torch.tensor(scale))
self.register_buffer('zero_point', torch.tensor(zero_point))
def forward(self, input):
return torch.ops.quantized.instance_norm(
@ -118,14 +112,13 @@ class InstanceNorm2d(torch.nn.InstanceNorm2d):
"""
def __init__(self, num_features, weight, bias, scale, zero_point,
eps=1e-5, momentum=0.1, affine=False,
track_running_stats=False, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
track_running_stats=False):
super(InstanceNorm2d, self).__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
num_features, eps, momentum, affine, track_running_stats)
self.weight = weight
self.bias = bias
self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
self.register_buffer('scale', torch.tensor(scale))
self.register_buffer('zero_point', torch.tensor(zero_point))
def forward(self, input):
return torch.ops.quantized.instance_norm(
@ -153,14 +146,13 @@ class InstanceNorm3d(torch.nn.InstanceNorm3d):
"""
def __init__(self, num_features, weight, bias, scale, zero_point,
eps=1e-5, momentum=0.1, affine=False,
track_running_stats=False, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
track_running_stats=False):
super(InstanceNorm3d, self).__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
num_features, eps, momentum, affine, track_running_stats)
self.weight = weight
self.bias = bias
self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
self.register_buffer('scale', torch.tensor(scale))
self.register_buffer('zero_point', torch.tensor(zero_point))
def forward(self, input):
return torch.ops.quantized.instance_norm(
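
A usage sketch for the quantized LayerNorm in this file; unlike the float module it takes
explicit weight, bias, scale and zero_point, matching the constructor shown above:

    import torch
    from torch.nn.quantized import LayerNorm

    ln = LayerNorm([4], weight=torch.ones(4), bias=torch.zeros(4), scale=0.1, zero_point=0)
    x = torch.quantize_per_tensor(torch.randn(2, 4), scale=0.1, zero_point=0, dtype=torch.quint8)
    y = ln(x)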

View File

@ -114,8 +114,7 @@ class _ObserverBase(ObserverBase):
eps: torch.Tensor
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
reduce_range=False, quant_min=None, quant_max=None, factory_kwargs=None) -> None:
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
reduce_range=False, quant_min=None, quant_max=None):
super(_ObserverBase, self).__init__(dtype=dtype)
self.qscheme = qscheme
if reduce_range:
@ -124,7 +123,7 @@ class _ObserverBase(ObserverBase):
reduce_range will be deprecated in a future release of PyTorch."
)
self.reduce_range = reduce_range
self.register_buffer('eps', torch.tensor([torch.finfo(torch.float32).eps], **factory_kwargs))
self.register_buffer('eps', torch.tensor([torch.finfo(torch.float32).eps]))
assert self.qscheme in (
torch.per_tensor_affine,
torch.per_tensor_symmetric,
@ -368,23 +367,21 @@ class MinMaxObserver(_ObserverBase):
max_val: torch.Tensor
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
reduce_range=False, quant_min=None, quant_max=None, factory_kwargs=None) -> None:
reduce_range=False, quant_min=None, quant_max=None):
# For x86 quantized kernels, we need to ensure that the vpmaddubsw
# instruction does not overflow. We allow for a reduce_range argument to
# observers that reduces the quantized range to (0, 127) or (-64, 63).
# For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp
# This is not an optimal choice for non-x86 backends as it loses a bit
# of precision for activations.
super(MinMaxObserver, self).__init__(dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs)
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
self.register_buffer('min_val', torch.tensor(float('inf'), **factory_kwargs))
self.register_buffer('max_val', torch.tensor(float('-inf'), **factory_kwargs))
quant_max=quant_max)
self.register_buffer('min_val', torch.tensor(float('inf')))
self.register_buffer('max_val', torch.tensor(float('-inf')))
if self.qscheme == torch.per_tensor_symmetric and \
self.reduce_range and \
self.dtype == torch.quint8:
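
To make the reduce_range comment above concrete, a small sketch: with quint8 the observer
targets 0..127 instead of 0..255, trading one bit of precision for overflow headroom in the
x86 vpmaddubsw path:

    import torch
    from torch.quantization import MinMaxObserver

    obs = MinMaxObserver(dtype=torch.quint8, reduce_range=True)
    obs(torch.randn(64))                          # record running min/max
    scale, zero_point = obs.calculate_qparams()   # qparams for the reduced range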
@ -459,14 +456,13 @@ class MovingAverageMinMaxObserver(MinMaxObserver):
"""
def __init__(self, averaging_constant=0.01, dtype=torch.quint8,
qscheme=torch.per_tensor_affine, reduce_range=False,
quant_min=None, quant_max=None, **kwargs) -> None:
quant_min=None, quant_max=None):
self.averaging_constant = averaging_constant
super(MovingAverageMinMaxObserver, self).__init__(dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
quant_min=quant_min,
quant_max=quant_max,
**kwargs)
quant_max=quant_max)
def forward(self, x_orig):
if x_orig.numel() == 0:
@ -518,17 +514,15 @@ class PerChannelMinMaxObserver(_ObserverBase):
def __init__(self, ch_axis=0, dtype=torch.quint8,
qscheme=torch.per_channel_affine, reduce_range=False,
quant_min=None, quant_max=None, factory_kwargs=None) -> None:
quant_min=None, quant_max=None):
super(PerChannelMinMaxObserver, self).__init__(dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs)
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
quant_max=quant_max)
self.ch_axis = ch_axis
self.register_buffer('min_vals', torch.tensor([], **factory_kwargs))
self.register_buffer('max_vals', torch.tensor([], **factory_kwargs))
self.register_buffer('min_vals', torch.tensor([]))
self.register_buffer('max_vals', torch.tensor([]))
if (
self.qscheme == torch.per_channel_symmetric
and self.reduce_range
@ -643,10 +637,10 @@ class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver):
def __init__(self, averaging_constant=0.01, ch_axis=0, dtype=torch.quint8,
qscheme=torch.per_channel_affine, reduce_range=False,
quant_min=None, quant_max=None, **kwargs) -> None:
quant_min=None, quant_max=None):
super(MovingAveragePerChannelMinMaxObserver, self).__init__(
ch_axis=ch_axis, dtype=dtype, qscheme=qscheme,
reduce_range=reduce_range, quant_min=quant_min, quant_max=quant_max, **kwargs)
reduce_range=reduce_range, quant_min=quant_min, quant_max=quant_max)
self.averaging_constant = averaging_constant
def forward(self, x_orig):
@ -709,19 +703,16 @@ class HistogramObserver(_ObserverBase):
upsample_rate: int = 128,
dtype: torch.dtype = torch.quint8,
qscheme=torch.per_tensor_affine,
reduce_range=False,
factory_kwargs=None,
) -> None:
reduce_range=False
):
# bins: The number of bins used for histogram calculation.
super(HistogramObserver, self).__init__(dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
factory_kwargs=factory_kwargs)
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
reduce_range=reduce_range)
self.bins = bins
self.register_buffer('histogram', torch.zeros(self.bins, **factory_kwargs))
self.register_buffer('min_val', torch.tensor(float('inf'), **factory_kwargs))
self.register_buffer('max_val', torch.tensor(float('-inf'), **factory_kwargs))
self.register_buffer('histogram', torch.zeros(self.bins))
self.register_buffer('min_val', torch.tensor(float('inf')))
self.register_buffer('max_val', torch.tensor(float('-inf')))
self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
self.upsample_rate = upsample_rate
@ -1019,7 +1010,7 @@ class PlaceholderObserver(ObserverBase):
custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
(Can be used in Graph Mode Passes for special case ops).
"""
def __init__(self, dtype=torch.float32, custom_op_name="", compute_dtype=None) -> None:
def __init__(self, dtype=torch.float32, custom_op_name="", compute_dtype=None):
super(PlaceholderObserver, self).__init__(dtype=dtype)
# dtype of input of the target operator, e.g. for dynamic quantization
# ops, the dtype will be float32
@ -1079,7 +1070,7 @@ class NoopObserver(ObserverBase):
custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
(Can be used in Graph Mode Passes for special case ops).
"""
def __init__(self, dtype=torch.float16, custom_op_name="") -> None:
def __init__(self, dtype=torch.float16, custom_op_name=""):
super(NoopObserver, self).__init__(dtype=dtype)
self.dtype = dtype
self.custom_op = custom_op_name