annotate a few torch.nn.modules.* modules (#45772)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/45771

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45772

Reviewed By: mruberry

Differential Revision: D24682013

Pulled By: albanD

fbshipit-source-id: e32bc4fe9c586c079f7070924a874c70f3d127fa
Authored by Guilherme Leobas on 2020-11-02 13:03:02 -08:00, committed by Facebook GitHub Bot
parent 7178790381
commit 9b52654620
7 changed files with 25 additions and 26 deletions

mypy.ini

@@ -77,9 +77,6 @@ ignore_errors = True
 [mypy-torch._tensor_str]
 ignore_errors = True
-[mypy-torch.nn.modules.batchnorm]
-ignore_errors = True
 [mypy-torch.nn.modules.container]
 ignore_errors = True
@@ -89,12 +86,6 @@ ignore_errors = True
 [mypy-torch.nn.modules.fold]
 ignore_errors = True
-[mypy-torch.nn.modules.instancenorm]
-ignore_errors = True
-[mypy-torch.nn.modules.linear]
-ignore_errors = True
 [mypy-torch.nn.modules.loss]
 ignore_errors = True
@@ -113,9 +104,6 @@ ignore_errors = True
 [mypy-torch.nn.modules.rnn]
 ignore_errors = True
-[mypy-torch.nn.modules.sparse]
-ignore_errors = True
 [mypy-torch.nn.parallel._functions]
 ignore_errors = True
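With those ignore_errors overrides gone, mypy now type-checks batchnorm, instancenorm, linear and sparse directly. A minimal sketch of how one might confirm that locally through mypy's Python API, assuming mypy is installed and the snippet is run from the repository root so mypy.ini is picked up automatically (the file list is illustrative):

from mypy import api

# Check just the newly un-ignored modules; mypy.api.run returns
# a (stdout, stderr, exit_status) tuple.
stdout, stderr, exit_status = api.run([
    "torch/nn/modules/batchnorm.py",
    "torch/nn/modules/instancenorm.py",
    "torch/nn/modules/linear.py",
    "torch/nn/modules/sparse.py",
])
print(stdout)
print("exit status:", exit_status)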

torch/nn/functional.pyi.in

@@ -189,7 +189,8 @@ def embedding(input: Tensor, weight: Tensor, padding_idx: Optional[int] = ..., m
 def embedding_bag(input: Tensor, weight: Tensor, offsets: Optional[Tensor] = ..., max_norm: Optional[float] = ...,
                   norm_type: float = ..., scale_grad_by_freq: bool = ..., mode: str = ...,
-                  sparse: bool = ...) -> Tensor: ...
+                  sparse: bool = ..., per_sample_weights: Optional[Tensor] = ...,
+                  include_last_offset: bool = ...) -> Tensor: ...
 def batch_norm(input: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor],
                weight: Optional[Tensor] = ..., bias: Optional[Tensor] = ..., training: bool = ...,
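The stub now exposes per_sample_weights and include_last_offset, matching the runtime signature of torch.nn.functional.embedding_bag. A small hedged usage sketch of the two newly stubbed arguments (shapes and values are illustrative only):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 3)                       # embedding table: 10 rows of dim 3
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])  # flat bag contents
offsets = torch.tensor([0, 4, 8])                 # bag boundaries; the trailing 8 ends the last bag
per_sample_weights = torch.rand(8)                # one weight per index; requires mode="sum"

out = F.embedding_bag(indices, weight, offsets, mode="sum",
                      per_sample_weights=per_sample_weights,
                      include_last_offset=True)   # offsets has B+1 entries, so B=2 bags
print(out.shape)  # torch.Size([2, 3])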

torch/nn/modules/batchnorm.py

@@ -54,9 +54,11 @@ class _NormBase(Module):
     def reset_running_stats(self) -> None:
         if self.track_running_stats:
-            self.running_mean.zero_()
-            self.running_var.fill_(1)
-            self.num_batches_tracked.zero_()
+            # running_mean/running_var/num_batches... are registered at runtime depending
+            # on if self.track_running_stats is on
+            self.running_mean.zero_()  # type: ignore[operator]
+            self.running_var.fill_(1)  # type: ignore[operator]
+            self.num_batches_tracked.zero_()  # type: ignore[operator]
     def reset_parameters(self) -> None:
         self.reset_running_stats()
@@ -107,8 +109,8 @@ class _BatchNorm(_NormBase):
         if self.training and self.track_running_stats:
             # TODO: if statement only here to tell the jit to skip emitting this when it is None
-            if self.num_batches_tracked is not None:
-                self.num_batches_tracked = self.num_batches_tracked + 1
+            if self.num_batches_tracked is not None:  # type: ignore
+                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore
             if self.momentum is None:  # use cumulative moving average
                 exponential_average_factor = 1.0 / float(self.num_batches_tracked)
             else:  # use exponential moving average
@@ -128,6 +130,8 @@ class _BatchNorm(_NormBase):
         passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
         used for normalization (i.e. in eval mode when buffers are not None).
         """
+        assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
+        assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
         return F.batch_norm(
             input,
             # If buffers are not to be tracked, ensure that they won't be updated
@@ -487,6 +491,7 @@ class SyncBatchNorm(_BatchNorm):
             exponential_average_factor = self.momentum
         if self.training and self.track_running_stats:
+            assert self.num_batches_tracked is not None
             self.num_batches_tracked = self.num_batches_tracked + 1
             if self.momentum is None:  # use cumulative moving average
                 exponential_average_factor = 1.0 / self.num_batches_tracked.item()
@@ -508,6 +513,8 @@ class SyncBatchNorm(_BatchNorm):
         used for normalization (i.e. in eval mode when buffers are not None).
         """
         # If buffers are not to be tracked, ensure that they won't be updated
+        assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
+        assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
         running_mean = self.running_mean if not self.training or self.track_running_stats else None
         running_var = self.running_var if not self.training or self.track_running_stats else None
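A note on the pattern: running_mean, running_var and num_batches_tracked are created with register_buffer, so they are not declared class attributes, and a static checker only sees whatever nn.Module exposes for dynamically looked-up attributes. The asserts narrow them to a concrete Optional[Tensor] before they reach F.batch_norm, while the # type: ignore[operator] comments cover the in-place updates. A hedged, minimal sketch of the same idea with a hypothetical module (not PyTorch's _NormBase):

import torch
import torch.nn.functional as F
from torch import Tensor, nn

class TinyNorm(nn.Module):
    def __init__(self, num_features: int, track_running_stats: bool = True) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        if track_running_stats:
            self.register_buffer("running_mean", torch.zeros(num_features))
            self.register_buffer("running_var", torch.ones(num_features))
        else:
            # Buffers may legitimately be None when stats are not tracked.
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)

    def forward(self, input: Tensor) -> Tensor:
        # Same narrowing trick as in the diff above: assert a concrete type
        # before handing the buffers to a function expecting Optional[Tensor].
        assert self.running_mean is None or isinstance(self.running_mean, Tensor)
        assert self.running_var is None or isinstance(self.running_var, Tensor)
        return F.batch_norm(input, self.running_mean, self.running_var,
                            self.weight, self.bias, training=self.training)

print(TinyNorm(4)(torch.randn(8, 4)).shape)  # torch.Size([8, 4])

The identical assert pattern appears in the instancenorm.py change below.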

torch/nn/modules/instancenorm.py

@@ -52,6 +52,8 @@ class _InstanceNorm(_NormBase):
     def forward(self, input: Tensor) -> Tensor:
         self._check_input_dim(input)
+        assert self.running_mean is None or isinstance(self.running_mean, Tensor)
+        assert self.running_var is None or isinstance(self.running_var, Tensor)
         return F.instance_norm(
             input, self.running_mean, self.running_var, self.weight, self.bias,
             self.training or not self.track_running_stats, self.momentum, self.eps)

torch/nn/modules/linear.py

@@ -102,10 +102,10 @@ class Linear(Module):
 # This class exists solely for Transformer; it has an annotation stating
 # that bias is never None, which appeases TorchScript
 class _LinearWithBias(Linear):
-    bias: Tensor
+    bias: Tensor  # type: ignore
     def __init__(self, in_features: int, out_features: int) -> None:
-        super().__init__(in_features, out_features, bias=True)
+        super().__init__(in_features, out_features, bias=True)  # type: ignore
 class Bilinear(Module):
@@ -208,7 +208,8 @@ class LazyLinear(LazyModuleMixin, Linear):
     """
-    cls_to_become = Linear
+    cls_to_become = Linear  # type: ignore[assignment]
+    weight: UninitializedParameter
     def __init__(self, out_features: int, bias: bool = True) -> None:
         super().__init__(0, out_features, bias)
@@ -218,7 +219,7 @@ class LazyLinear(LazyModuleMixin, Linear):
         if not self.has_uninitialized_params() and self.in_features != 0:
             super().reset_parameters()
-    def initialize_parameters(self, input) -> None:
+    def initialize_parameters(self, input) -> None:  # type: ignore
         if self.has_uninitialized_params():
             with torch.no_grad():
                 self.in_features = input.shape[-1]
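The LazyLinear change uses an error-code-scoped ignore: # type: ignore[assignment] silences only that specific mypy diagnostic, whereas a bare # type: ignore (as on the other lines) suppresses every error on its line. A hedged sketch of the situation it addresses, using hypothetical classes rather than the ones in linear.py:

class LazyBase:
    # Without an explicit annotation, mypy infers this attribute's type from
    # the assignment: it is just "None".
    cls_to_become = None

class Concrete(LazyBase):
    pass

class LazyConcrete(LazyBase):
    # Overriding with a class object is incompatible with the inferred "None"
    # type, which mypy reports under the [assignment] error code; scoping the
    # ignore to that code keeps any other diagnostics on the line visible.
    cls_to_become = Concrete  # type: ignore[assignment]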

torch/nn/modules/sparse.py

@@ -99,8 +99,8 @@ class Embedding(Module):
     num_embeddings: int
     embedding_dim: int
-    padding_idx: int
-    max_norm: float
+    padding_idx: Optional[int]
+    max_norm: Optional[float]
     norm_type: float
     scale_grad_by_freq: bool
     weight: Tensor
@@ -284,7 +284,7 @@ class EmbeddingBag(Module):
     num_embeddings: int
     embedding_dim: int
-    max_norm: float
+    max_norm: Optional[float]
     norm_type: float
     scale_grad_by_freq: bool
     weight: Tensor
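Both constructors default padding_idx and max_norm to None, so the class-level annotations have to be Optional; with plain int/float annotations, mypy rejects storing the None default. A hedged micro-example with a hypothetical class (not torch.nn.Embedding):

from typing import Optional

class EmbeddingLike:
    padding_idx: Optional[int]
    max_norm: Optional[float]

    def __init__(self, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None) -> None:
        # With "padding_idx: int" / "max_norm: float" these assignments would
        # be [assignment] errors whenever the defaults are left as None.
        self.padding_idx = padding_idx
        self.max_norm = max_norm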

torch/nn/parameter.pyi

@@ -11,5 +11,5 @@ class Parameter(Tensor):
 class UninitializedParameter(Tensor):
     def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ...
-    def materialize(self, shape: Tuple[int], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
+    def materialize(self, shape: Tuple[int, ...], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
     ...
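Tuple[int] describes a tuple of exactly one int, while Tuple[int, ...] describes an int tuple of any length, which is what a shape argument needs. A hedged illustration with a stand-in function (not UninitializedParameter.materialize):

from typing import Tuple

def takes_shape(shape: Tuple[int, ...]) -> int:
    # Multiply out the dimensions just to use the value.
    numel = 1
    for dim in shape:
        numel *= dim
    return numel

takes_shape((3,))       # accepted under both annotations
takes_shape((2, 3, 4))  # only well-typed once the annotation is Tuple[int, ...]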