add type annotations to common_nn.py (#48190)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/48189

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48190

Reviewed By: walterddr, zhangguanheng66

Differential Revision: D25245261

Pulled By: malfet

fbshipit-source-id: 0eabaed54996be83ead0fd7668f4d2be20adfc17
Guilherme Leobas 2020-12-02 14:43:03 -08:00 committed by Facebook GitHub Bot
parent a49e2c5ce6
commit a4e13fcf3f
3 changed files with 61 additions and 34 deletions

mypy.ini

@@ -56,9 +56,6 @@ ignore_errors = True
[mypy-torch.testing._internal.hypothesis_utils.*]
ignore_errors = True
[mypy-torch.testing._internal.common_nn.*]
ignore_errors = True
[mypy-torch.testing._internal.common_quantization.*]
ignore_errors = True
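Removing the ignore_errors override for torch.testing._internal.common_nn.* is what makes the rest of this diff necessary: mypy now type-checks that module instead of skipping it. Assuming the config above lives in mypy.ini at the repository root, the check can be reproduced locally with `mypy --config-file mypy.ini torch/testing/_internal/common_nn.py`.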

torch/serialization.py

@@ -13,7 +13,7 @@ from ._utils import _import_dotted_name
from ._six import string_classes as _string_classes
from torch._utils_internal import get_source_lines_and_file
from torch.types import Storage
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
import copyreg
import pickle
import pathlib
@@ -330,7 +330,7 @@ def _check_dill_version(pickle_module) -> None:
pickle_module.__version__
))
def save(obj, f: Union[str, os.PathLike, BinaryIO],
def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
"""Saves an object to a disk file.

torch/testing/_internal/common_nn.py

@@ -1,3 +1,4 @@
from abc import abstractmethod
import math
import tempfile
import unittest
@@ -13,7 +14,7 @@ import torch
import torch.cuda
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import _Reduction
from torch.nn import _reduction as _Reduction
from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
TEST_WITH_ROCM, _assertGradAndGradgradChecks
from torch.testing._internal.common_cuda import TEST_CUDA
@@ -24,6 +25,8 @@ from torch.autograd import Variable
from torch.types import _TensorOrTensors
import torch.backends.cudnn
from typing import Dict, Callable, Tuple, List, Sequence, Union, Any
TemporaryFile = tempfile.TemporaryFile
PRECISION = 1e-5
@@ -640,7 +643,7 @@ def nllloss_no_reduce_test():
return dict(
fullname='NLLLoss_no_reduce',
constructor=wrap_functional(
lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
cpp_function_call='''F::nll_loss(
i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
input_fn=lambda: torch.rand(15, 10).log(),
@@ -652,11 +655,12 @@ def nllloss_no_reduce_test():
def nllloss_no_reduce_ignore_index_test():
t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
kwargs = {'ignore_index': 2, 'reduction': 'none'}
kwargs: Dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}
return dict(
fullname='NLLLoss_no_reduce_ignore_index',
constructor=wrap_functional(
lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
reduction=str(kwargs['reduction']))),
cpp_function_call='''F::nll_loss(
i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))''',
input_fn=lambda: torch.rand(15, 10).log(),
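These kwargs hunks all work around the same limitation: once kwargs is annotated as Dict[str, Union[int, str]], mypy can no longer prove that ignore_index is an int and reduction a str, so spreading **kwargs into F.nll_loss fails to type-check. The lambdas therefore pass each argument explicitly with an int()/str() cast, and the same rewrite repeats for the 2d and Nd variants below. A stripped-down sketch of the issue, using hypothetical names:

    from typing import Dict, Union

    def nll_like(ignore_index: int, reduction: str) -> None:
        pass

    kwargs: Dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}
    # nll_like(**kwargs)          # rejected by mypy: both values are typed Union[int, str]
    nll_like(ignore_index=int(kwargs['ignore_index']),
             reduction=str(kwargs['reduction']))      # narrowed per argument, accepted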
@@ -737,7 +741,7 @@ def nllloss2d_no_reduce_test():
return dict(
fullname='NLLLoss2d_no_reduce',
constructor=wrap_functional(
lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
cpp_function_call='''F::nll_loss(
i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
@@ -749,11 +753,12 @@ def nllloss2d_no_reduce_test():
def nllloss2d_no_reduce_ignore_index_test():
t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
kwargs = {'ignore_index': 1, 'reduction': 'none'}
kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
return dict(
fullname='NLLLoss2d_no_reduce_ignore_index',
constructor=wrap_functional(
lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
reduction=str(kwargs['reduction']))),
cpp_function_call='''F::nll_loss(
i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
@@ -790,7 +795,7 @@ def nlllossNd_no_reduce_test():
return dict(
fullname='NLLLossNd_no_reduce',
constructor=wrap_functional(
lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
cpp_function_call='''F::nll_loss(
i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
@@ -802,11 +807,12 @@ def nlllossNd_no_reduce_test():
def nlllossNd_no_reduce_ignore_index_test():
t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
kwargs = {'ignore_index': 1, 'reduction': 'none'}
kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
return dict(
fullname='NLLLossNd_no_reduce_ignore_index',
constructor=wrap_functional(
lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
reduction=str(kwargs['reduction']))),
cpp_function_call='''F::nll_loss(
i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
@@ -3691,7 +3697,7 @@ def kldivloss_reference(input, target, reduction='mean'):
return result.mean()
elif reduction == 'sum':
return result.sum()
elif reduction == 'batchmean' and results.dim() != 0:
elif reduction == 'batchmean' and result.dim() != 0:
return result.sum() / result.size(0)
return result
@@ -3701,7 +3707,7 @@ def kldivloss_log_target_reference(input, target, reduction='mean'):
return result.mean()
elif reduction == 'sum':
return result.sum()
elif reduction == 'batchmean' and results.dim() != 0:
elif reduction == 'batchmean' and result.dim() != 0:
return result.sum() / result.size(0)
return result
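Both KL-divergence reference helpers had the same typo in the rarely exercised 'batchmean' branch: results instead of result, which mypy reports as an undefined name once the file is type-checked and which would raise NameError at runtime. A stripped-down illustration of the fixed branch, as a hypothetical helper rather than the test code:

    import torch

    def batchmean_sketch(result: torch.Tensor, reduction: str = 'batchmean') -> torch.Tensor:
        # Referring to `results` here would be flagged by mypy as an undefined
        # name and would fail at runtime; the diff renames it to `result`.
        if reduction == 'batchmean' and result.dim() != 0:
            return result.sum() / result.size(0)
        return result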
@@ -4017,7 +4023,7 @@ def padding3d_circular(input, pad):
return torch.cat([input[:, :, :, :, -pad[0]:], input, input[:, :, :, :, 0:pad[1]]], dim=4)
loss_reference_fns = {
loss_reference_fns: Dict['str', Callable] = {
'KLDivLoss': kldivloss_reference,
'KLDivLoss_log_target': kldivloss_log_target_reference,
'NLLLoss': nllloss_reference,
@@ -4631,6 +4637,26 @@ criterion_tests = [
class NNTestCase(TestCase):
# _forward is defined in classes inheriting from NNTestCase
@abstractmethod
def _forward(self, *args, **kwargs):
raise NotImplementedError
@abstractmethod
def _get_parameters(self, module: nn.Module) -> Tuple[List[nn.Parameter], List[nn.Parameter]]:
raise NotImplementedError
@abstractmethod
def _zero_grad_parameters(self, module: nn.Module) -> None:
raise NotImplementedError
@abstractmethod
def _backward(self, module: nn.Module,
input: _TensorOrTensors, output: torch.Tensor,
grad_output: Union[torch.Tensor, Sequence[torch.Tensor]],
create_graph: bool = False):
raise NotImplementedError
def _jacobian(self, input, num_out):
if isinstance(input, tuple):
return tuple(self._jacobian(elem, num_out) for elem in input)
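The new @abstractmethod stubs spell out the hooks that concrete test cases must provide, which also lets mypy see those attributes on NNTestCase. A hypothetical minimal subclass showing the shape of the hooks (illustrative only, not taken from the PR):

    import torch
    import torch.nn as nn
    from torch.testing._internal.common_nn import NNTestCase

    class ExampleNNTestCase(NNTestCase):
        def _forward(self, module, input):
            return module(input)

        def _get_parameters(self, module):
            params = list(module.parameters())
            return params, [p.grad for p in params]

        def _zero_grad_parameters(self, module):
            module.zero_grad()

        def _backward(self, module, input, output, grad_output, create_graph=False):
            output.backward(grad_output, retain_graph=True, create_graph=create_graph)
            return input.grad if isinstance(input, torch.Tensor) else None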
@@ -4691,7 +4717,7 @@ class NNTestCase(TestCase):
if jacobian_parameters:
jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)
res = tuple()
res: Tuple[torch.Tensor, ...] = tuple()
if jacobian_input:
res += jacobian_inp,
if jacobian_parameters:
@@ -4703,7 +4729,7 @@ class NNTestCase(TestCase):
def fw(input):
return self._forward(module, input).detach()
res = tuple()
res: Tuple[torch.Tensor, ...] = tuple()
if jacobian_input:
res += get_numerical_jacobian(fw, input, eps=1e-6),
if jacobian_parameters:
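The two res annotations address the same mypy complaint: a bare tuple() gives mypy no element type to work with, so the later res += tensor, appends do not type-check without an explicit annotation. A sketch of the annotated pattern, outside the test code:

    from typing import Tuple
    import torch

    res: Tuple[torch.Tensor, ...] = tuple()   # element type stated up front
    res += torch.zeros(2, 3),                 # each += keeps the declared tuple type
    res += torch.ones(4),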
@@ -4724,7 +4750,7 @@ class NNTestCase(TestCase):
differences.append(a.add(n, alpha=-1).abs().max())
# TODO: compare structure (ensure analytic jacobian has correct shape)
if len(differences) > 0:
self.assertLessEqual(max(differences), PRECISION)
self.assertLessEqual(max(differences), PRECISION) # type: ignore[type-var]
class TestBase(object):
@@ -4807,6 +4833,10 @@ class TestBase(object):
class ModuleTest(TestBase):
@abstractmethod
def _do_test(self, test_case: Any, module: nn.Module, input: Any) -> Any:
raise NotImplementedError
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.jacobian_input = kwargs.get('jacobian_input', True)
@@ -4903,7 +4933,7 @@ class ModuleTest(TestBase):
raise unittest.SkipTest('Excluded from CUDA tests')
cpu_input = self._get_input()
type_map = {'torch.DoubleTensor': torch.cuda.FloatTensor}
type_map = {'torch.DoubleTensor': torch.cuda.FloatTensor} # type: ignore[attr-defined]
cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,)
gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map)
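The attr-defined ignores here and in the following hunks exist because the legacy per-device tensor types such as torch.cuda.FloatTensor are not declared in mypy's view of torch.cuda, even though they work at runtime. A small sketch of that gap (illustrative; the PR only adds the ignore comments):

    import torch

    if torch.cuda.is_available():
        # Exists at runtime but not in the stubs, hence the ignore:
        cuda_float = torch.cuda.FloatTensor  # type: ignore[attr-defined]
        t = cuda_float(3).zero_()            # legacy constructor, roughly
                                             # torch.zeros(3, device='cuda')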
@@ -4980,7 +5010,7 @@ class ModuleTest(TestBase):
class InputVariableMixin(object):
def _get_input(self):
input = TestBase._get_input(self, False)
input = TestBase._get_input(self, False) # type: ignore[arg-type]
def map_variables(i):
if isinstance(i, torch.Tensor):
@@ -4993,7 +5023,7 @@ class InputVariableMixin(object):
return map_variables(input)
class NewModuleTest(InputVariableMixin, ModuleTest):
class NewModuleTest(InputVariableMixin, ModuleTest): # type: ignore[misc]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cudnn = kwargs.get('cudnn', False)
@@ -5059,14 +5089,14 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
input_tuple = tuple(t.cuda() for t in input_tuple)
module.float().cuda()
module(*input_tuple)
assert_module_parameters_are(torch.cuda.FloatTensor, 0)
assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined]
if torch.cuda.device_count() > 1:
input_tuple = tuple(t.cuda(1) for t in input_tuple)
module.cuda(1)
with torch.cuda.device(1):
module(*input_tuple)
assert_module_parameters_are(torch.cuda.FloatTensor, 1)
assert_module_parameters_are(torch.cuda.FloatTensor, 1) # type: ignore[attr-defined]
else:
# check that float()/double() casters work correctly
@@ -5091,7 +5121,7 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
t.float().cuda() if not isinstance(t, torch.LongTensor) else t.cuda() for t in input_tuple)
module.float().cuda()
module(*input_tuple)
assert_module_parameters_are(torch.cuda.FloatTensor, 0)
assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined]
# to CPU
input_tuple = tuple(t.cpu() for t in input_tuple)
@@ -5103,13 +5133,13 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
input_tuple = tuple(t.cuda() for t in input_tuple)
module.cuda()
module(*input_tuple)
assert_module_parameters_are(torch.cuda.FloatTensor, 0)
assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined]
# test that forwards of module runs correctly without cuDNN
if self.cudnn:
with torch.backends.cudnn.flags(enabled=False):
module(*input_tuple)
assert_module_parameters_are(torch.cuda.FloatTensor, 0)
assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined]
if torch.cuda.device_count() >= 2:
# test cross-GPU transfer works
@@ -5118,7 +5148,7 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
module.cuda(1)
with torch.cuda.device(1):
module(*input_tuple)
assert_module_parameters_are(torch.cuda.FloatTensor, 1)
assert_module_parameters_are(torch.cuda.FloatTensor, 1) # type: ignore[attr-defined]
if not self.skip_double:
# test double()
@@ -5126,14 +5156,14 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
t.double().cuda() if not isinstance(t, torch.LongTensor) else t.cuda() for t in input_tuple)
module.double().cuda()
module(*input_tuple)
assert_module_parameters_are(torch.cuda.DoubleTensor, 0)
assert_module_parameters_are(torch.cuda.DoubleTensor, 0) # type: ignore[attr-defined]
# test half()
input_tuple = tuple(
t.half().cuda() if not isinstance(t, torch.LongTensor) else t.cuda() for t in input_tuple)
module.half().cuda()
module(*input_tuple)
assert_module_parameters_are(torch.cuda.HalfTensor, 0)
assert_module_parameters_are(torch.cuda.HalfTensor, 0) # type: ignore[attr-defined]
torch.set_num_threads(num_threads)
def _get_target(self):
@@ -5144,7 +5174,7 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
return self._get_arg('constructor_args', False)
class CriterionTest(InputVariableMixin, TestBase):
class CriterionTest(InputVariableMixin, TestBase): # type: ignore[misc]
# TODO: check that criterions don't ignore grad_output
_required_arg_names = TestBase._required_arg_names.union({'target'})
@@ -5188,7 +5218,7 @@ class CriterionTest(InputVariableMixin, TestBase):
else:
inputs = input + params + (target,)
def apply_fn(input1, input2, target, *params):
def apply_fn(input1, input2, target, *params): # type: ignore[misc]
return module(input1, input2, target)
gradcheck(apply_fn, inputs)
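gradcheck compares the analytical gradients of apply_fn against a numerical finite-difference estimate, which is why the test wraps the module call in a plain function of its tensor inputs. A self-contained usage sketch with a hypothetical function (double precision, as gradcheck expects):

    import torch
    from torch.autograd import gradcheck

    def fn(x):
        return (x * x).sum()

    x = torch.randn(3, dtype=torch.double, requires_grad=True)
    assert gradcheck(fn, (x,))   # raises if analytical and numerical gradients disagree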