mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Convert assert -> cast. (#57458)
Summary: Fixes https://github.com/pytorch/pytorch/issues/55868. Pull Request resolved: https://github.com/pytorch/pytorch/pull/57458 Reviewed By: mruberry Differential Revision: D28365745 Pulled By: walterddr fbshipit-source-id: 35cc3fa85f87b0ef98cf970f620ab909d240c7be
This commit is contained in:
parent
614437751f
commit
46e4b2dbda
|
|
@ -978,5 +978,14 @@ class TestIterator(TestCase):
|
|||
self.assertIs(type(next(it)), SubTensor2)
|
||||
self.assertIs(type(next(it)), SubTensor2)
|
||||
|
||||
|
||||
class TestRNN(TestCase):
|
||||
# Regression test for gh-55868
|
||||
def test_rnn(self):
|
||||
model = torch.nn.RNN(10, 20, 2)
|
||||
input = Wrapper(torch.randn(1, 5, 10))
|
||||
model(input)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ except ImportError:
|
|||
)
|
||||
|
||||
try:
|
||||
import rich
|
||||
import rich # type: ignore[import]
|
||||
except ImportError:
|
||||
print("rich not found, for color output use 'pip install rich'")
|
||||
|
||||
|
|
|
|||
|
|
@ -151,8 +151,6 @@ class DeferredBatchNorm(_BatchNorm):
|
|||
if module.affine:
|
||||
module_output.register_parameter("weight", module.weight)
|
||||
module_output.register_parameter("bias", module.bias)
|
||||
assert isinstance(module.running_mean, Tensor)
|
||||
assert isinstance(module.running_var, Tensor)
|
||||
module_output.register_buffer("running_mean", module.running_mean)
|
||||
module_output.register_buffer("running_var", module.running_var)
|
||||
module_output.register_buffer("num_batches_tracked", module.num_batches_tracked)
|
||||
|
|
|
|||
|
|
@ -45,9 +45,9 @@ def _load_for_lite_interpreter(f, map_location=None):
|
|||
map_location = validate_map_location(map_location)
|
||||
|
||||
if isinstance(f, str) or isinstance(f, pathlib.Path):
|
||||
cpp_module = torch._C._load_for_lite_interpreter(f, map_location)
|
||||
cpp_module = torch._C._load_for_lite_interpreter(f, map_location) # type: ignore[attr-defined]
|
||||
else:
|
||||
cpp_module = torch._C._load_for_lite_interpreter_from_buffer(f.read(), map_location)
|
||||
cpp_module = torch._C._load_for_lite_interpreter_from_buffer(f.read(), map_location) # type: ignore[attr-defined]
|
||||
|
||||
return LiteScriptModule(cpp_module)
|
||||
|
||||
|
|
@ -102,9 +102,9 @@ def _get_model_bytecode_version(f_input) -> int:
|
|||
raise ValueError(f"The provided filename {f_input} is a directory")
|
||||
|
||||
if (isinstance(f_input, str) or isinstance(f_input, pathlib.Path)):
|
||||
return torch._C._get_model_bytecode_version(str(f_input))
|
||||
return torch._C._get_model_bytecode_version(str(f_input)) # type: ignore[attr-defined]
|
||||
else:
|
||||
return torch._C._get_model_bytecode_version_from_buffer(f_input.read())
|
||||
return torch._C._get_model_bytecode_version_from_buffer(f_input.read()) # type: ignore[attr-defined]
|
||||
|
||||
def _backport_for_mobile(f_input, f_output, to_version):
|
||||
r"""
|
||||
|
|
@ -124,9 +124,9 @@ def _backport_for_mobile(f_input, f_output, to_version):
|
|||
|
||||
if ((isinstance(f_input, str) or isinstance(f_input, pathlib.Path)) and (
|
||||
isinstance(f_output, str) or isinstance(f_output, pathlib.Path))):
|
||||
return torch._C._backport_for_mobile(str(f_input), str(f_output), to_version)
|
||||
return torch._C._backport_for_mobile(str(f_input), str(f_output), to_version) # type: ignore[attr-defined]
|
||||
else:
|
||||
return torch._C._backport_for_mobile_from_buffer(f_input.read(), str(f_output), to_version)
|
||||
return torch._C._backport_for_mobile_from_buffer(f_input.read(), str(f_output), to_version) # type: ignore[attr-defined]
|
||||
|
||||
def _backport_for_mobile_to_buffer(f_input, to_version):
|
||||
r"""
|
||||
|
|
@ -142,6 +142,6 @@ def _backport_for_mobile_to_buffer(f_input, to_version):
|
|||
raise ValueError(f"The provided filename {f_input} is a directory")
|
||||
|
||||
if (isinstance(f_input, str) or isinstance(f_input, pathlib.Path)):
|
||||
return torch._C._backport_for_mobile_to_buffer(str(f_input), to_version)
|
||||
return torch._C._backport_for_mobile_to_buffer(str(f_input), to_version) # type: ignore[attr-defined]
|
||||
else:
|
||||
return torch._C._backport_for_mobile_from_buffer_to_buffer(f_input.read(), to_version)
|
||||
return torch._C._backport_for_mobile_from_buffer_to_buffer(f_input.read(), to_version) # type: ignore[attr-defined]
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
|
|||
return self
|
||||
|
||||
def _forward(self, input):
|
||||
assert isinstance(self.bn.running_var, torch.Tensor)
|
||||
assert self.bn.running_var is not None
|
||||
running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
|
||||
scale_factor = self.bn.weight / running_std
|
||||
weight_shape = [1] * len(self.weight.shape)
|
||||
|
|
|
|||
|
|
@ -50,6 +50,8 @@ class _NormBase(Module):
|
|||
if self.track_running_stats:
|
||||
self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
|
||||
self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
|
||||
self.running_mean: Optional[Tensor]
|
||||
self.running_var: Optional[Tensor]
|
||||
self.register_buffer('num_batches_tracked',
|
||||
torch.tensor(0, dtype=torch.long,
|
||||
**{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
|
||||
|
|
@ -63,9 +65,9 @@ class _NormBase(Module):
|
|||
if self.track_running_stats:
|
||||
# running_mean/running_var/num_batches... are registered at runtime depending
|
||||
# if self.track_running_stats is on
|
||||
self.running_mean.zero_() # type: ignore[operator]
|
||||
self.running_var.fill_(1) # type: ignore[operator]
|
||||
self.num_batches_tracked.zero_() # type: ignore[operator]
|
||||
self.running_mean.zero_() # type: ignore[union-attr]
|
||||
self.running_var.fill_(1) # type: ignore[union-attr]
|
||||
self.num_batches_tracked.zero_() # type: ignore[union-attr,operator]
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
self.reset_running_stats()
|
||||
|
|
@ -162,8 +164,6 @@ class _BatchNorm(_NormBase):
|
|||
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
|
||||
used for normalization (i.e. in eval mode when buffers are not None).
|
||||
"""
|
||||
assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
|
||||
assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
|
||||
return F.batch_norm(
|
||||
input,
|
||||
# If buffers are not to be tracked, ensure that they won't be updated
|
||||
|
|
@ -221,8 +221,8 @@ class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):
|
|||
self.weight.materialize((self.num_features,))
|
||||
self.bias.materialize((self.num_features,))
|
||||
if self.track_running_stats:
|
||||
self.running_mean.materialize((self.num_features,))
|
||||
self.running_var.materialize((self.num_features,))
|
||||
self.running_mean.materialize((self.num_features,)) # type:ignore[union-attr]
|
||||
self.running_var.materialize((self.num_features,)) # type:ignore[union-attr]
|
||||
self.reset_parameters()
|
||||
|
||||
|
||||
|
|
@ -715,8 +715,6 @@ class SyncBatchNorm(_BatchNorm):
|
|||
used for normalization (i.e. in eval mode when buffers are not None).
|
||||
"""
|
||||
# If buffers are not to be tracked, ensure that they won't be updated
|
||||
assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
|
||||
assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
|
||||
running_mean = (
|
||||
self.running_mean if not self.training or self.track_running_stats else None
|
||||
)
|
||||
|
|
|
|||
|
|
@ -54,9 +54,6 @@ class _InstanceNorm(_NormBase):
|
|||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
self._check_input_dim(input)
|
||||
|
||||
assert self.running_mean is None or isinstance(self.running_mean, Tensor)
|
||||
assert self.running_var is None or isinstance(self.running_var, Tensor)
|
||||
return F.instance_norm(
|
||||
input, self.running_mean, self.running_var, self.weight, self.bias,
|
||||
self.training or not self.track_running_stats, self.momentum, self.eps)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class _Loss(Module):
|
|||
def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
|
||||
super(_Loss, self).__init__()
|
||||
if size_average is not None or reduce is not None:
|
||||
self.reduction = _Reduction.legacy_get_string(size_average, reduce)
|
||||
self.reduction: str = _Reduction.legacy_get_string(size_average, reduce)
|
||||
else:
|
||||
self.reduction = reduction
|
||||
|
||||
|
|
@ -24,6 +24,7 @@ class _WeightedLoss(_Loss):
|
|||
def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
|
||||
super(_WeightedLoss, self).__init__(size_average, reduce, reduction)
|
||||
self.register_buffer('weight', weight)
|
||||
self.weight: Optional[Tensor]
|
||||
|
||||
|
||||
class L1Loss(_Loss):
|
||||
|
|
@ -212,7 +213,6 @@ class NLLLoss(_WeightedLoss):
|
|||
self.ignore_index = ignore_index
|
||||
|
||||
def forward(self, input: Tensor, target: Tensor) -> Tensor:
|
||||
assert self.weight is None or isinstance(self.weight, Tensor)
|
||||
return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
|
||||
|
||||
|
||||
|
|
@ -609,7 +609,6 @@ class BCELoss(_WeightedLoss):
|
|||
super(BCELoss, self).__init__(weight, size_average, reduce, reduction)
|
||||
|
||||
def forward(self, input: Tensor, target: Tensor) -> Tensor:
|
||||
assert self.weight is None or isinstance(self.weight, Tensor)
|
||||
return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
|
||||
|
||||
|
||||
|
|
@ -707,10 +706,10 @@ class BCEWithLogitsLoss(_Loss):
|
|||
super(BCEWithLogitsLoss, self).__init__(size_average, reduce, reduction)
|
||||
self.register_buffer('weight', weight)
|
||||
self.register_buffer('pos_weight', pos_weight)
|
||||
self.weight: Optional[Tensor]
|
||||
self.pos_weight: Optional[Tensor]
|
||||
|
||||
def forward(self, input: Tensor, target: Tensor) -> Tensor:
|
||||
assert self.weight is None or isinstance(self.weight, Tensor)
|
||||
assert self.pos_weight is None or isinstance(self.pos_weight, Tensor)
|
||||
return F.binary_cross_entropy_with_logits(input, target,
|
||||
self.weight,
|
||||
pos_weight=self.pos_weight,
|
||||
|
|
@ -1118,7 +1117,6 @@ class CrossEntropyLoss(_WeightedLoss):
|
|||
self.ignore_index = ignore_index
|
||||
|
||||
def forward(self, input: Tensor, target: Tensor) -> Tensor:
|
||||
assert self.weight is None or isinstance(self.weight, Tensor)
|
||||
return F.cross_entropy(input, target, weight=self.weight,
|
||||
ignore_index=self.ignore_index, reduction=self.reduction)
|
||||
|
||||
|
|
@ -1167,7 +1165,6 @@ class MultiLabelSoftMarginLoss(_WeightedLoss):
|
|||
super(MultiLabelSoftMarginLoss, self).__init__(weight, size_average, reduce, reduction)
|
||||
|
||||
def forward(self, input: Tensor, target: Tensor) -> Tensor:
|
||||
assert self.weight is None or isinstance(self.weight, Tensor)
|
||||
return F.multilabel_soft_margin_loss(input, target, weight=self.weight, reduction=self.reduction)
|
||||
|
||||
|
||||
|
|
@ -1335,7 +1332,6 @@ class MultiMarginLoss(_WeightedLoss):
|
|||
self.margin = margin
|
||||
|
||||
def forward(self, input: Tensor, target: Tensor) -> Tensor:
|
||||
assert self.weight is None or isinstance(self.weight, Tensor)
|
||||
return F.multi_margin_loss(input, target, p=self.p, margin=self.margin,
|
||||
weight=self.weight, reduction=self.reduction)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import math
|
||||
import warnings
|
||||
import numbers
|
||||
from typing import List, Tuple, Optional, overload, Union
|
||||
from typing import List, Tuple, Optional, overload, Union, cast
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
|
@ -244,14 +244,13 @@ class RNNBase(Module):
|
|||
input, batch_sizes, sorted_indices, unsorted_indices = input
|
||||
max_batch_size = int(batch_sizes[0])
|
||||
else:
|
||||
assert isinstance(input, Tensor)
|
||||
input = cast(Tensor, input)
|
||||
batch_sizes = None
|
||||
max_batch_size = input.size(0) if self.batch_first else input.size(1)
|
||||
sorted_indices = None
|
||||
unsorted_indices = None
|
||||
|
||||
assert isinstance(input, Tensor)
|
||||
if hx is None:
|
||||
input = cast(Tensor, input)
|
||||
num_directions = 2 if self.bidirectional else 1
|
||||
hx = torch.zeros(self.num_layers * num_directions,
|
||||
max_batch_size, self.hidden_size,
|
||||
|
|
@ -262,6 +261,7 @@ class RNNBase(Module):
|
|||
hx = self.permute_hidden(hx, sorted_indices)
|
||||
|
||||
assert hx is not None
|
||||
input = cast(Tensor, input)
|
||||
self.check_forward_args(input, hx, batch_sizes)
|
||||
_impl = _rnn_impls[self.mode]
|
||||
if batch_sizes is None:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user