annotate a few torch.nn.modules.* modules (#45772)
Summary: Fixes https://github.com/pytorch/pytorch/issues/45771

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45772
Reviewed By: mruberry
Differential Revision: D24682013
Pulled By: albanD
fbshipit-source-id: e32bc4fe9c586c079f7070924a874c70f3d127fa
mypy.ini (12 lines changed)
@@ -77,9 +77,6 @@ ignore_errors = True
 [mypy-torch._tensor_str]
 ignore_errors = True
 
-[mypy-torch.nn.modules.batchnorm]
-ignore_errors = True
-
 [mypy-torch.nn.modules.container]
 ignore_errors = True
 
@@ -89,12 +86,6 @@ ignore_errors = True
 [mypy-torch.nn.modules.fold]
 ignore_errors = True
 
-[mypy-torch.nn.modules.instancenorm]
-ignore_errors = True
-
-[mypy-torch.nn.modules.linear]
-ignore_errors = True
-
 [mypy-torch.nn.modules.loss]
 ignore_errors = True
 
@@ -113,9 +104,6 @@ ignore_errors = True
 [mypy-torch.nn.modules.rnn]
 ignore_errors = True
 
-[mypy-torch.nn.modules.sparse]
-ignore_errors = True
-
 [mypy-torch.nn.parallel._functions]
 ignore_errors = True
 
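With those `ignore_errors` overrides gone, mypy now checks batchnorm, instancenorm, linear, and sparse for real; the rest of the diff is mostly about satisfying it. A minimal sketch (my own illustration, not code from this commit) of the class of error these overrides used to hide:

```python
from typing import Optional

import torch


class Stats:
    # A buffer that is only registered under some configurations is
    # effectively Optional at type-check time.
    running_mean: Optional[torch.Tensor]

    def reset(self) -> None:
        # mypy: Item "None" of "Optional[Tensor]" has no attribute "zero_",
        # unless the access is narrowed or ignored, as the hunks below do.
        self.running_mean.zero_()
```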
torch/nn/functional.pyi.in
@@ -189,7 +189,8 @@ def embedding(input: Tensor, weight: Tensor, padding_idx: Optional[int] = ..., m
 
 def embedding_bag(input: Tensor, weight: Tensor, offsets: Optional[Tensor] = ..., max_norm: Optional[float] = ...,
                   norm_type: float = ..., scale_grad_by_freq: bool = ..., mode: str = ...,
-                  sparse: bool = ...) -> Tensor: ...
+                  sparse: bool = ..., per_sample_weights: Optional[Tensor] = ...,
+                  include_last_offset: bool = ...) -> Tensor: ...
 
 def batch_norm(input: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor],
                weight: Optional[Tensor] = ..., bias: Optional[Tensor] = ..., training: bool = ...,
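For reference, a quick usage sketch of the two parameters the stub now covers; behavior is as documented for torch.nn.functional.embedding_bag (`per_sample_weights` requires `mode='sum'`):

```python
import torch
import torch.nn.functional as F

weight = torch.randn(10, 3)                      # 10 embeddings of dim 3
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])

# With include_last_offset=True, offsets also carries the final end
# position, so three offsets describe two bags: [0, 4) and [4, 8).
offsets = torch.tensor([0, 4, 8])

out = F.embedding_bag(indices, weight, offsets, mode='sum',
                      per_sample_weights=torch.rand(8),
                      include_last_offset=True)
print(out.shape)  # torch.Size([2, 3])
```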
torch/nn/modules/batchnorm.py
@@ -54,9 +54,11 @@ class _NormBase(Module):
 
     def reset_running_stats(self) -> None:
         if self.track_running_stats:
-            self.running_mean.zero_()
-            self.running_var.fill_(1)
-            self.num_batches_tracked.zero_()
+            # running_mean/running_var/num_batches... are registered at runtime
+            # depending on whether self.track_running_stats is on
+            self.running_mean.zero_()  # type: ignore[operator]
+            self.running_var.fill_(1)  # type: ignore[operator]
+            self.num_batches_tracked.zero_()  # type: ignore[operator]
 
     def reset_parameters(self) -> None:
         self.reset_running_stats()
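The `[operator]` ignores are needed because these buffers are created conditionally via `register_buffer`, so mypy cannot see them as plain `Tensor` attributes. A rough sketch of the registration pattern (simplified, not the actual `_NormBase.__init__`):

```python
import torch
from torch import nn


class Norm(nn.Module):
    def __init__(self, num_features: int, track_running_stats: bool = True) -> None:
        super().__init__()
        self.track_running_stats = track_running_stats
        if track_running_stats:
            # Only registered on this branch, so static analysis sees the
            # attribute as possibly absent rather than as a plain Tensor.
            self.register_buffer('running_mean', torch.zeros(num_features))
        else:
            self.register_buffer('running_mean', None)
```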
@@ -107,8 +109,8 @@ class _BatchNorm(_NormBase):
 
         if self.training and self.track_running_stats:
             # TODO: if statement only here to tell the jit to skip emitting this when it is None
-            if self.num_batches_tracked is not None:
-                self.num_batches_tracked = self.num_batches_tracked + 1
+            if self.num_batches_tracked is not None:  # type: ignore
+                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore
                 if self.momentum is None:  # use cumulative moving average
                     exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                 else:  # use exponential moving average
@@ -128,6 +130,8 @@ class _BatchNorm(_NormBase):
         passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
         used for normalization (i.e. in eval mode when buffers are not None).
         """
+        assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
+        assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
         return F.batch_norm(
             input,
             # If buffers are not to be tracked, ensure that they won't be updated
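These asserts exist purely for the type checker: a dynamically registered buffer reaches mypy through `Module.__getattr__`, whose annotated return type is a union rather than a plain `Tensor`, while `F.batch_norm` expects `Optional[Tensor]`. A hedged sketch of the narrowing (the names here are mine):

```python
from typing import Optional, Union

import torch
from torch import Tensor, nn


def takes_stats(running_mean: Optional[Tensor]) -> None:
    """Stand-in for F.batch_norm's Optional[Tensor] parameters."""


def forward_sketch(buf: Union[Tensor, nn.Module]) -> None:
    # The assert rules out the Module arm of the union, so mypy accepts
    # passing buf where Optional[Tensor] is expected.
    assert buf is None or isinstance(buf, Tensor)
    takes_stats(buf)
```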
@@ -487,6 +491,7 @@ class SyncBatchNorm(_BatchNorm):
             exponential_average_factor = self.momentum
 
         if self.training and self.track_running_stats:
+            assert self.num_batches_tracked is not None
             self.num_batches_tracked = self.num_batches_tracked + 1
             if self.momentum is None:  # use cumulative moving average
                 exponential_average_factor = 1.0 / self.num_batches_tracked.item()
@@ -508,6 +513,8 @@ class SyncBatchNorm(_BatchNorm):
         used for normalization (i.e. in eval mode when buffers are not None).
         """
         # If buffers are not to be tracked, ensure that they won't be updated
+        assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
+        assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
         running_mean = self.running_mean if not self.training or self.track_running_stats else None
         running_var = self.running_var if not self.training or self.track_running_stats else None
torch/nn/modules/instancenorm.py
@@ -52,6 +52,8 @@ class _InstanceNorm(_NormBase):
     def forward(self, input: Tensor) -> Tensor:
         self._check_input_dim(input)
 
+        assert self.running_mean is None or isinstance(self.running_mean, Tensor)
+        assert self.running_var is None or isinstance(self.running_var, Tensor)
         return F.instance_norm(
             input, self.running_mean, self.running_var, self.weight, self.bias,
             self.training or not self.track_running_stats, self.momentum, self.eps)
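Same narrowing as in batchnorm, and here the `None` arm is the common case: `track_running_stats` defaults to `False` for instance norm, so the buffers genuinely may be absent. A quick illustration:

```python
import torch
from torch import nn

m = nn.InstanceNorm1d(4)                            # no running stats kept
print(m.running_mean)                               # None
m = nn.InstanceNorm1d(4, track_running_stats=True)
m(torch.randn(2, 4, 8))
print(m.running_mean.shape)                         # torch.Size([4])
```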
torch/nn/modules/linear.py
@@ -102,10 +102,10 @@ class Linear(Module):
 # This class exists solely for Transformer; it has an annotation stating
 # that bias is never None, which appeases TorchScript
 class _LinearWithBias(Linear):
-    bias: Tensor
+    bias: Tensor  # type: ignore
 
     def __init__(self, in_features: int, out_features: int) -> None:
-        super().__init__(in_features, out_features, bias=True)
+        super().__init__(in_features, out_features, bias=True)  # type: ignore
 
 
 class Bilinear(Module):
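The ignore on the `bias` annotation is, as far as I can tell, down to mypy treating mutable attributes as invariant: a subclass may not re-declare an inherited `Optional[Tensor]` attribute with the narrower `Tensor`. A minimal illustration (my names, not the module's):

```python
from typing import Optional

import torch


class Base:
    bias: Optional[torch.Tensor]


class WithBias(Base):
    # mypy flags this as an incompatible override even though Tensor is
    # a subtype of Optional[Tensor], because the attribute is mutable.
    bias: torch.Tensor  # type: ignore
```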
@@ -208,7 +208,8 @@ class LazyLinear(LazyModuleMixin, Linear):
     """
 
-    cls_to_become = Linear
+    cls_to_become = Linear  # type: ignore[assignment]
     weight: UninitializedParameter
 
     def __init__(self, out_features: int, bias: bool = True) -> None:
         super().__init__(0, out_features, bias)
@@ -218,7 +219,7 @@ class LazyLinear(LazyModuleMixin, Linear):
         if not self.has_uninitialized_params() and self.in_features != 0:
             super().reset_parameters()
 
-    def initialize_parameters(self, input) -> None:
+    def initialize_parameters(self, input) -> None:  # type: ignore
         if self.has_uninitialized_params():
             with torch.no_grad():
                 self.in_features = input.shape[-1]
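Context for the two `LazyLinear` hunks: `initialize_parameters` runs on the first forward pass to fill in `in_features`, and `cls_to_become` then swaps the module's class over to plain `Linear`. A usage sketch, with behavior as documented for lazy modules:

```python
import torch
from torch import nn

layer = nn.LazyLinear(out_features=8)   # in_features not known yet
y = layer(torch.randn(4, 16))           # first call infers in_features=16
print(layer.in_features)                # 16
print(type(layer))                      # <class 'torch.nn.modules.linear.Linear'>
```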
torch/nn/modules/sparse.py
@@ -99,8 +99,8 @@ class Embedding(Module):
 
     num_embeddings: int
     embedding_dim: int
-    padding_idx: int
-    max_norm: float
+    padding_idx: Optional[int]
+    max_norm: Optional[float]
     norm_type: float
     scale_grad_by_freq: bool
     weight: Tensor
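`padding_idx` and `max_norm` both default to `None`, hence the `Optional` annotations. For example:

```python
import torch
from torch import nn

emb = nn.Embedding(10, 3, padding_idx=0)  # padding_idx may also be left as None
out = emb(torch.tensor([[0, 2, 0, 5]]))
print(out[0, 0])  # the padding row is all zeros and receives no gradient
```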
@@ -284,7 +284,7 @@ class EmbeddingBag(Module):
 
     num_embeddings: int
     embedding_dim: int
-    max_norm: float
+    max_norm: Optional[float]
     norm_type: float
     scale_grad_by_freq: bool
     weight: Tensor
torch/nn/parameter.pyi
@@ -11,5 +11,5 @@ class Parameter(Tensor):
 
 class UninitializedParameter(Tensor):
     def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ...
 
-    def materialize(self, shape: Tuple[int], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
+    def materialize(self, shape: Tuple[int, ...], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
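The distinction matters because `Tuple[int]` means a one-element tuple, while `Tuple[int, ...]` is a homogeneous tuple of any length, which is what a shape argument needs:

```python
from typing import Tuple


def materialize(shape: Tuple[int, ...]) -> None:
    """Accepts shapes of any rank."""


materialize((3,))       # OK
materialize((3, 4, 5))  # OK; would be rejected under Tuple[int]
```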