Convert assert -> cast. (#57458)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/55868.
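
Not part of the patch, just an illustration of the difference between the two constructs: an assert isinstance(x, Tensor) raises at runtime for tensor-like inputs that rely on __torch_function__, while typing.cast (and the scoped "# type: ignore" comments used below) only affect mypy and are no-ops when the code actually runs.

from typing import cast

from torch import Tensor

class NotATensor:
    # Stand-in for a tensor-like object that is not a torch.Tensor subclass.
    pass

obj = NotATensor()
# assert isinstance(obj, Tensor)   # the kind of check this patch removes; it would raise here
t = cast(Tensor, obj)              # runtime no-op: cast() simply returns obj
assert t is obj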

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57458

Reviewed By: mruberry

Differential Revision: D28365745

Pulled By: walterddr

fbshipit-source-id: 35cc3fa85f87b0ef98cf970f620ab909d240c7be
Author: Hameer Abbasi (2021-05-12 13:53:03 -07:00), committed by Facebook GitHub Bot
Commit: 46e4b2dbda (parent: 614437751f)
9 changed files with 34 additions and 36 deletions


@@ -978,5 +978,14 @@ class TestIterator(TestCase):
         self.assertIs(type(next(it)), SubTensor2)
         self.assertIs(type(next(it)), SubTensor2)

+
+class TestRNN(TestCase):
+    # Regression test for gh-55868
+    def test_rnn(self):
+        model = torch.nn.RNN(10, 20, 2)
+        input = Wrapper(torch.randn(1, 5, 10))
+        model(input)
+
+
 if __name__ == '__main__':
     run_tests()
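
Wrapper is defined elsewhere in the test module and is not shown in this hunk. A hypothetical, minimal stand-in (an assumption, not the suite's actual implementation) shows why the old isinstance asserts aborted the call before __torch_function__ dispatch could run:

import torch
from torch import Tensor

class Wrapper:
    # Hypothetical tensor-like type relying on __torch_function__; the real
    # Wrapper used by the test suite is more elaborate.
    def __init__(self, data: Tensor) -> None:
        self.data = data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(obj):
            return obj.data if isinstance(obj, cls) else obj

        return func(*[unwrap(a) for a in args],
                    **{k: unwrap(v) for k, v in kwargs.items()})

# Not a Tensor instance, so "assert isinstance(input, Tensor)" in the RNN
# forward path used to raise for it.
assert not isinstance(Wrapper(torch.randn(1, 5, 10)), Tensor)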


@@ -12,7 +12,7 @@ except ImportError:
     )
 try:
-    import rich
+    import rich  # type: ignore[import]
 except ImportError:
     print("rich not found, for color output use 'pip install rich'")


@@ -151,8 +151,6 @@ class DeferredBatchNorm(_BatchNorm):
         if module.affine:
             module_output.register_parameter("weight", module.weight)
             module_output.register_parameter("bias", module.bias)
-        assert isinstance(module.running_mean, Tensor)
-        assert isinstance(module.running_var, Tensor)
         module_output.register_buffer("running_mean", module.running_mean)
         module_output.register_buffer("running_var", module.running_var)
         module_output.register_buffer("num_batches_tracked", module.num_batches_tracked)
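
The two asserts removed here only existed to narrow the buffer types for mypy. Once running_mean and running_var are annotated as Optional[Tensor] on _NormBase (see the batch-norm hunk further down), they can be handed to register_buffer directly, since its tensor argument is already Optional[Tensor]. A minimal sketch of that reasoning, with hypothetical names:

from typing import Optional

from torch import Tensor, nn

def copy_running_stats(dst: nn.Module,
                       running_mean: Optional[Tensor],
                       running_var: Optional[Tensor]) -> None:
    # register_buffer accepts Optional[Tensor], so an Optional value can be
    # passed straight through without any isinstance narrowing.
    dst.register_buffer("running_mean", running_mean)
    dst.register_buffer("running_var", running_var)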


@@ -45,9 +45,9 @@ def _load_for_lite_interpreter(f, map_location=None):
     map_location = validate_map_location(map_location)

     if isinstance(f, str) or isinstance(f, pathlib.Path):
-        cpp_module = torch._C._load_for_lite_interpreter(f, map_location)
+        cpp_module = torch._C._load_for_lite_interpreter(f, map_location)  # type: ignore[attr-defined]
     else:
-        cpp_module = torch._C._load_for_lite_interpreter_from_buffer(f.read(), map_location)
+        cpp_module = torch._C._load_for_lite_interpreter_from_buffer(f.read(), map_location)  # type: ignore[attr-defined]

     return LiteScriptModule(cpp_module)
@@ -102,9 +102,9 @@ def _get_model_bytecode_version(f_input) -> int:
             raise ValueError(f"The provided filename {f_input} is a directory")

     if (isinstance(f_input, str) or isinstance(f_input, pathlib.Path)):
-        return torch._C._get_model_bytecode_version(str(f_input))
+        return torch._C._get_model_bytecode_version(str(f_input))  # type: ignore[attr-defined]
     else:
-        return torch._C._get_model_bytecode_version_from_buffer(f_input.read())
+        return torch._C._get_model_bytecode_version_from_buffer(f_input.read())  # type: ignore[attr-defined]


 def _backport_for_mobile(f_input, f_output, to_version):
     r"""
@@ -124,9 +124,9 @@ def _backport_for_mobile(f_input, f_output, to_version):
     if ((isinstance(f_input, str) or isinstance(f_input, pathlib.Path)) and (
             isinstance(f_output, str) or isinstance(f_output, pathlib.Path))):
-        return torch._C._backport_for_mobile(str(f_input), str(f_output), to_version)
+        return torch._C._backport_for_mobile(str(f_input), str(f_output), to_version)  # type: ignore[attr-defined]
     else:
-        return torch._C._backport_for_mobile_from_buffer(f_input.read(), str(f_output), to_version)
+        return torch._C._backport_for_mobile_from_buffer(f_input.read(), str(f_output), to_version)  # type: ignore[attr-defined]


 def _backport_for_mobile_to_buffer(f_input, to_version):
     r"""
@@ -142,6 +142,6 @@ def _backport_for_mobile_to_buffer(f_input, to_version):
             raise ValueError(f"The provided filename {f_input} is a directory")

     if (isinstance(f_input, str) or isinstance(f_input, pathlib.Path)):
-        return torch._C._backport_for_mobile_to_buffer(str(f_input), to_version)
+        return torch._C._backport_for_mobile_to_buffer(str(f_input), to_version)  # type: ignore[attr-defined]
     else:
-        return torch._C._backport_for_mobile_from_buffer_to_buffer(f_input.read(), to_version)
+        return torch._C._backport_for_mobile_from_buffer_to_buffer(f_input.read(), to_version)  # type: ignore[attr-defined]
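
For context (a sketch with an assumed helper name, not code from this patch): mypy cannot see every attribute of the torch._C extension module through the bundled stubs and reports them as [attr-defined], so the calls above get a per-line, error-code-scoped ignore that leaves all other checking active.

import torch

def bytecode_version(path: str) -> int:
    # Hypothetical wrapper around the call shown in the hunk above; only the
    # attr-defined diagnostic is suppressed, nothing else.
    return torch._C._get_model_bytecode_version(path)  # type: ignore[attr-defined]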


@@ -91,7 +91,7 @@ class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
         return self

     def _forward(self, input):
-        assert isinstance(self.bn.running_var, torch.Tensor)
+        assert self.bn.running_var is not None
         running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
         scale_factor = self.bn.weight / running_std
         weight_shape = [1] * len(self.weight.shape)
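
Both assert forms narrow Optional[Tensor] to Tensor for mypy, but the new "is not None" check does not reject objects that merely behave like tensors via __torch_function__. A small sketch (hypothetical function, not from the patch):

from typing import Optional

import torch
from torch import Tensor

def fused_scale(running_var: Optional[Tensor], eps: float) -> Tensor:
    # "is not None" is enough for mypy to treat running_var as Tensor below,
    # and it never fails for tensor-like values.
    assert running_var is not None
    return torch.sqrt(running_var + eps)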


@@ -50,6 +50,8 @@ class _NormBase(Module):
         if self.track_running_stats:
             self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
             self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
+            self.running_mean: Optional[Tensor]
+            self.running_var: Optional[Tensor]
             self.register_buffer('num_batches_tracked',
                                  torch.tensor(0, dtype=torch.long,
                                               **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
@@ -63,9 +65,9 @@
         if self.track_running_stats:
             # running_mean/running_var/num_batches... are registered at runtime depending
             # if self.track_running_stats is on
-            self.running_mean.zero_()  # type: ignore[operator]
-            self.running_var.fill_(1)  # type: ignore[operator]
-            self.num_batches_tracked.zero_()  # type: ignore[operator]
+            self.running_mean.zero_()  # type: ignore[union-attr]
+            self.running_var.fill_(1)  # type: ignore[union-attr]
+            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

     def reset_parameters(self) -> None:
         self.reset_running_stats()
@@ -162,8 +164,6 @@ class _BatchNorm(_NormBase):
         passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
         used for normalization (i.e. in eval mode when buffers are not None).
         """
-        assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
-        assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
         return F.batch_norm(
             input,
             # If buffers are not to be tracked, ensure that they won't be updated
@@ -221,8 +221,8 @@ class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):
                 self.weight.materialize((self.num_features,))
                 self.bias.materialize((self.num_features,))
             if self.track_running_stats:
-                self.running_mean.materialize((self.num_features,))
-                self.running_var.materialize((self.num_features,))
+                self.running_mean.materialize((self.num_features,))  # type:ignore[union-attr]
+                self.running_var.materialize((self.num_features,))  # type:ignore[union-attr]
             self.reset_parameters()
@@ -715,8 +715,6 @@ class SyncBatchNorm(_BatchNorm):
         used for normalization (i.e. in eval mode when buffers are not None).
         """
         # If buffers are not to be tracked, ensure that they won't be updated
-        assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
-        assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
         running_mean = (
             self.running_mean if not self.training or self.track_running_stats else None
         )
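
A minimal sketch (module and buffer names assumed, not taken from this file) of the pattern applied throughout this hunk: annotate the registered buffer as Optional[Tensor] so mypy knows its type, and use a scoped ignore at call sites that know the buffer is populated, instead of a runtime isinstance assert.

from typing import Optional

import torch
from torch import Tensor, nn

class RunningStats(nn.Module):
    def __init__(self, num_features: int) -> None:
        super().__init__()
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.running_mean: Optional[Tensor]  # annotation only; no runtime effect

    def reset_running_stats(self) -> None:
        # The buffer is always populated here, but mypy still sees
        # Optional[Tensor], hence the scoped ignore rather than an assert.
        self.running_mean.zero_()  # type: ignore[union-attr]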


@@ -54,9 +54,6 @@ class _InstanceNorm(_NormBase):
     def forward(self, input: Tensor) -> Tensor:
         self._check_input_dim(input)
-
-        assert self.running_mean is None or isinstance(self.running_mean, Tensor)
-        assert self.running_var is None or isinstance(self.running_var, Tensor)
         return F.instance_norm(
             input, self.running_mean, self.running_var, self.weight, self.bias,
             self.training or not self.track_running_stats, self.momentum, self.eps)


@@ -15,7 +15,7 @@ class _Loss(Module):
     def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
         super(_Loss, self).__init__()
         if size_average is not None or reduce is not None:
-            self.reduction = _Reduction.legacy_get_string(size_average, reduce)
+            self.reduction: str = _Reduction.legacy_get_string(size_average, reduce)
         else:
             self.reduction = reduction
@@ -24,6 +24,7 @@ class _WeightedLoss(_Loss):
     def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
         super(_WeightedLoss, self).__init__(size_average, reduce, reduction)
         self.register_buffer('weight', weight)
+        self.weight: Optional[Tensor]


 class L1Loss(_Loss):
@@ -212,7 +213,6 @@ class NLLLoss(_WeightedLoss):
         self.ignore_index = ignore_index

     def forward(self, input: Tensor, target: Tensor) -> Tensor:
-        assert self.weight is None or isinstance(self.weight, Tensor)
         return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
@@ -609,7 +609,6 @@ class BCELoss(_WeightedLoss):
         super(BCELoss, self).__init__(weight, size_average, reduce, reduction)

     def forward(self, input: Tensor, target: Tensor) -> Tensor:
-        assert self.weight is None or isinstance(self.weight, Tensor)
         return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
@@ -707,10 +706,10 @@ class BCEWithLogitsLoss(_Loss):
         super(BCEWithLogitsLoss, self).__init__(size_average, reduce, reduction)
         self.register_buffer('weight', weight)
         self.register_buffer('pos_weight', pos_weight)
+        self.weight: Optional[Tensor]
+        self.pos_weight: Optional[Tensor]

     def forward(self, input: Tensor, target: Tensor) -> Tensor:
-        assert self.weight is None or isinstance(self.weight, Tensor)
-        assert self.pos_weight is None or isinstance(self.pos_weight, Tensor)
         return F.binary_cross_entropy_with_logits(input, target,
                                                   self.weight,
                                                   pos_weight=self.pos_weight,
@@ -1118,7 +1117,6 @@ class CrossEntropyLoss(_WeightedLoss):
         self.ignore_index = ignore_index

     def forward(self, input: Tensor, target: Tensor) -> Tensor:
-        assert self.weight is None or isinstance(self.weight, Tensor)
         return F.cross_entropy(input, target, weight=self.weight,
                                ignore_index=self.ignore_index, reduction=self.reduction)
@@ -1167,7 +1165,6 @@ class MultiLabelSoftMarginLoss(_WeightedLoss):
         super(MultiLabelSoftMarginLoss, self).__init__(weight, size_average, reduce, reduction)

     def forward(self, input: Tensor, target: Tensor) -> Tensor:
-        assert self.weight is None or isinstance(self.weight, Tensor)
         return F.multilabel_soft_margin_loss(input, target, weight=self.weight, reduction=self.reduction)
@@ -1335,7 +1332,6 @@ class MultiMarginLoss(_WeightedLoss):
         self.margin = margin

     def forward(self, input: Tensor, target: Tensor) -> Tensor:
-        assert self.weight is None or isinstance(self.weight, Tensor)
         return F.multi_margin_loss(input, target, p=self.p, margin=self.margin,
                                    weight=self.weight, reduction=self.reduction)
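
Why the per-forward asserts can go (sketch with a hypothetical minimal module): the functional losses already take weight as Optional[Tensor], so once the buffer attribute is annotated Optional[Tensor] it can be passed through unchanged.

from typing import Optional

from torch import Tensor, nn
import torch.nn.functional as F

class TinyNLLLoss(nn.Module):
    def __init__(self, weight: Optional[Tensor] = None) -> None:
        super().__init__()
        self.register_buffer('weight', weight)
        self.weight: Optional[Tensor]

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # F.nll_loss's weight parameter is Optional[Tensor], so no
        # "assert self.weight is None or isinstance(...)" is needed.
        return F.nll_loss(input, target, weight=self.weight)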


@@ -1,7 +1,7 @@
 import math
 import warnings
 import numbers
-from typing import List, Tuple, Optional, overload, Union
+from typing import List, Tuple, Optional, overload, Union, cast

 import torch
 from torch import Tensor
@@ -244,14 +244,13 @@ class RNNBase(Module):
             input, batch_sizes, sorted_indices, unsorted_indices = input
             max_batch_size = int(batch_sizes[0])
         else:
-            assert isinstance(input, Tensor)
+            input = cast(Tensor, input)
             batch_sizes = None
             max_batch_size = input.size(0) if self.batch_first else input.size(1)
             sorted_indices = None
             unsorted_indices = None

-        assert isinstance(input, Tensor)
         if hx is None:
+            input = cast(Tensor, input)
             num_directions = 2 if self.bidirectional else 1
             hx = torch.zeros(self.num_layers * num_directions,
                              max_batch_size, self.hidden_size,
@@ -262,6 +261,7 @@
             hx = self.permute_hidden(hx, sorted_indices)

         assert hx is not None
+        input = cast(Tensor, input)
         self.check_forward_args(input, hx, batch_sizes)
         _impl = _rnn_impls[self.mode]
         if batch_sizes is None: