add type annotations to common_nn.py (#48190)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/48189

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48190

Reviewed By: walterddr, zhangguanheng66

Differential Revision: D25245261

Pulled By: malfet

fbshipit-source-id: 0eabaed54996be83ead0fd7668f4d2be20adfc17
This commit is contained in:
parent a49e2c5ce6
commit a4e13fcf3f
mypy.ini (3 lines removed)
@@ -56,9 +56,6 @@ ignore_errors = True
 [mypy-torch.testing._internal.hypothesis_utils.*]
 ignore_errors = True
 
-[mypy-torch.testing._internal.common_nn.*]
-ignore_errors = True
-
 [mypy-torch.testing._internal.common_quantization.*]
 ignore_errors = True
 
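For context: a `[mypy-…]` block with `ignore_errors = True` suppresses every mypy error in the matching modules, so deleting the `common_nn` entry is what forces the annotations added in the rest of this commit. A minimal sketch of the effect, using a hypothetical module name:

# hypothetical_module.py -- suppose mypy.ini previously contained:
#     [mypy-hypothetical_module.*]
#     ignore_errors = True
# With that block removed, mypy checks the file and reports errors
# like the one noted in the comment below.

def scale(x: float) -> float:
    return x * 2.0

# scale("three")  # mypy: Argument 1 to "scale" has incompatible type "str"; expected "float"
print(scale(1.5))  # 3.0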

torch/serialization.py
@@ -13,7 +13,7 @@ from ._utils import _import_dotted_name
 from ._six import string_classes as _string_classes
 from torch._utils_internal import get_source_lines_and_file
 from torch.types import Storage
-from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union
+from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
 import copyreg
 import pickle
 import pathlib
 
@@ -330,7 +330,7 @@ def _check_dill_version(pickle_module) -> None:
             pickle_module.__version__
         ))
 
-def save(obj, f: Union[str, os.PathLike, BinaryIO],
+def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
          pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
     """Saves an object to a disk file.
 
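Why widen `f` beyond `BinaryIO`: callers routinely pass in-memory buffers and other file-like objects that typeshed types as `IO[bytes]` rather than `BinaryIO`, and mypy rejected those calls under the old annotation even though the runtime already accepted them. A minimal usage sketch (standard `io.BytesIO`, not code from this diff):

import io
import torch

buffer = io.BytesIO()                     # a file-like object carrying bytes
torch.save({'w': torch.ones(3)}, buffer)  # accepted by the widened annotation

buffer.seek(0)                            # rewind before deserializing
state = torch.load(buffer)
print(state['w'])                         # tensor([1., 1., 1.])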

torch/testing/_internal/common_nn.py
@@ -1,3 +1,4 @@
+from abc import abstractmethod
 import math
 import tempfile
 import unittest
 
@@ -13,7 +14,7 @@ import torch
 import torch.cuda
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn.functional import _Reduction
+from torch.nn import _reduction as _Reduction
 from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
     TEST_WITH_ROCM, _assertGradAndGradgradChecks
 from torch.testing._internal.common_cuda import TEST_CUDA
 
@@ -24,6 +25,8 @@ from torch.autograd import Variable
 from torch.types import _TensorOrTensors
 import torch.backends.cudnn
 
+from typing import Dict, Callable, Tuple, List, Sequence, Union, Any
+
 TemporaryFile = tempfile.TemporaryFile
 PRECISION = 1e-5
 
@@ -640,7 +643,7 @@ def nllloss_no_reduce_test():
     return dict(
         fullname='NLLLoss_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
+            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
         cpp_function_call='''F::nll_loss(
             i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
         input_fn=lambda: torch.rand(15, 10).log(),
 
@@ -652,11 +655,12 @@ def nllloss_no_reduce_test():
 
 def nllloss_no_reduce_ignore_index_test():
     t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
-    kwargs = {'ignore_index': 2, 'reduction': 'none'}
+    kwargs: Dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}
     return dict(
         fullname='NLLLoss_no_reduce_ignore_index',
         constructor=wrap_functional(
-            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
+            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
+                                 reduction=str(kwargs['reduction']))),
         cpp_function_call='''F::nll_loss(
             i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))''',
         input_fn=lambda: torch.rand(15, 10).log(),
 
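The pattern behind all six `nll_loss` rewrites in this file: mypy types the annotated dict's values as `Union[int, str]`, and neither `**kwargs` unpacking nor direct indexing can prove that `ignore_index` is an `int` and `reduction` a `str`, so each value is narrowed explicitly at the call site. A minimal self-contained sketch (the `loss` function is a hypothetical stand-in for `F.nll_loss`):

from typing import Dict, Union

def loss(ignore_index: int = -100, reduction: str = 'mean') -> str:
    # hypothetical stand-in for F.nll_loss's keyword parameters
    return f'ignore_index={ignore_index}, reduction={reduction}'

kwargs: Dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}

# loss(**kwargs)  # mypy error: Union[int, str] matches neither parameter type
print(loss(ignore_index=int(kwargs['ignore_index']),  # int()/str() narrow the Union
           reduction=str(kwargs['reduction'])))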
@@ -737,7 +741,7 @@ def nllloss2d_no_reduce_test():
     return dict(
         fullname='NLLLoss2d_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
+            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
         cpp_function_call='''F::nll_loss(
             i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
         input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
 
@@ -749,11 +753,12 @@ def nllloss2d_no_reduce_test():
 
 def nllloss2d_no_reduce_ignore_index_test():
     t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
-    kwargs = {'ignore_index': 1, 'reduction': 'none'}
+    kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
     return dict(
         fullname='NLLLoss2d_no_reduce_ignore_index',
         constructor=wrap_functional(
-            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
+            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
+                                 reduction=str(kwargs['reduction']))),
         cpp_function_call='''F::nll_loss(
             i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
         input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
 
@@ -790,7 +795,7 @@ def nlllossNd_no_reduce_test():
     return dict(
         fullname='NLLLossNd_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
+            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
         cpp_function_call='''F::nll_loss(
             i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
         input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
 
@@ -802,11 +807,12 @@ def nlllossNd_no_reduce_test():
 
 def nlllossNd_no_reduce_ignore_index_test():
     t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
-    kwargs = {'ignore_index': 1, 'reduction': 'none'}
+    kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
     return dict(
         fullname='NLLLossNd_no_reduce_ignore_index',
         constructor=wrap_functional(
-            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
+            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
+                                 reduction=str(kwargs['reduction']))),
         cpp_function_call='''F::nll_loss(
             i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
         input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
 
@@ -3691,7 +3697,7 @@ def kldivloss_reference(input, target, reduction='mean'):
         return result.mean()
     elif reduction == 'sum':
         return result.sum()
-    elif reduction == 'batchmean' and results.dim() != 0:
+    elif reduction == 'batchmean' and result.dim() != 0:
         return result.sum() / result.size(0)
     return result
 
@@ -3701,7 +3707,7 @@ def kldivloss_log_target_reference(input, target, reduction='mean'):
         return result.mean()
     elif reduction == 'sum':
         return result.sum()
-    elif reduction == 'batchmean' and results.dim() != 0:
+    elif reduction == 'batchmean' and result.dim() != 0:
         return result.sum() / result.size(0)
     return result
 
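Worth noting: these two hunks are genuine bug fixes that the type-checking pass surfaced. `results` was never defined, so the `'batchmean'` branch of both reference implementations would have raised a `NameError` instead of dividing by batch size. What the corrected branch computes, as a standalone sketch (shapes assumed for illustration):

import torch

result = torch.rand(4, 10)                # per-element loss terms, batch size 4

# 'batchmean' divides the total by the size of dim 0 (the batch),
# which is the definition the reference implementation encodes.
batchmean = result.sum() / result.size(0)
print(torch.isclose(batchmean, result.sum() / 4))  # tensor(True)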
@@ -4017,7 +4023,7 @@ def padding3d_circular(input, pad):
     return torch.cat([input[:, :, :, :, -pad[0]:], input, input[:, :, :, :, 0:pad[1]]], dim=4)
 
 
-loss_reference_fns = {
+loss_reference_fns: Dict['str', Callable] = {
     'KLDivLoss': kldivloss_reference,
     'KLDivLoss_log_target': kldivloss_log_target_reference,
     'NLLLoss': nllloss_reference,
 
@@ -4631,6 +4637,26 @@ criterion_tests = [
 
 class NNTestCase(TestCase):
 
+    # _forward is defined in classes inheriting from NNTestCase
+    @abstractmethod
+    def _forward(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def _get_parameters(self, module: nn.Module) -> Tuple[List[nn.Parameter], List[nn.Parameter]]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _zero_grad_parameters(self, module: nn.Module) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _backward(self, module: nn.Module,
+                  input: _TensorOrTensors, output: torch.Tensor,
+                  grad_output: Union[torch.Tensor, Sequence[torch.Tensor]],
+                  create_graph: bool = False):
+        raise NotImplementedError
+
     def _jacobian(self, input, num_out):
         if isinstance(input, tuple):
             return tuple(self._jacobian(elem, num_out) for elem in input)
 
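The `@abstractmethod` stubs exist so mypy knows `_forward` and friends are available on `self` inside `NNTestCase`, even though only subclasses define them. Since `TestCase` does not use `ABCMeta`, the decorator is advisory here: it documents the contract and gives mypy a signature without preventing instantiation. A minimal sketch of the pattern (hypothetical class names):

from abc import abstractmethod

class Base:
    # No ABCMeta metaclass, so @abstractmethod is not enforced at
    # instantiation time; it exists for documentation and for mypy.
    @abstractmethod
    def _forward(self, *args, **kwargs):
        raise NotImplementedError

    def run(self):
        return self._forward()  # mypy can now resolve this attribute

class Concrete(Base):
    def _forward(self, *args, **kwargs):
        return 42

print(Concrete().run())  # 42; Base() would also instantiate (advisory only)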
@@ -4691,7 +4717,7 @@ class NNTestCase(TestCase):
             if jacobian_parameters:
                 jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)
 
-        res = tuple()
+        res: Tuple[torch.Tensor, ...] = tuple()
         if jacobian_input:
             res += jacobian_inp,
         if jacobian_parameters:
 
@@ -4703,7 +4729,7 @@ class NNTestCase(TestCase):
         def fw(input):
             return self._forward(module, input).detach()
 
-        res = tuple()
+        res: Tuple[torch.Tensor, ...] = tuple()
         if jacobian_input:
             res += get_numerical_jacobian(fw, input, eps=1e-6),
         if jacobian_parameters:
 
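Both `res = tuple()` sites fail mypy the same way: an empty tuple gives no element type to infer, so the later `res += (tensor,)` concatenations cannot be checked. Annotating the variable as a homogeneous variadic tuple fixes both. A minimal sketch:

from typing import Tuple

res: Tuple[int, ...] = tuple()  # without the annotation, mypy asks for one
res += (1,)                     # tuple concatenation now type-checks
res += (2, 3)
print(res)                      # (1, 2, 3)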
@@ -4724,7 +4750,7 @@ class NNTestCase(TestCase):
             differences.append(a.add(n, alpha=-1).abs().max())
             # TODO: compare structure (ensure analytic jacobian has correct shape)
         if len(differences) > 0:
-            self.assertLessEqual(max(differences), PRECISION)
+            self.assertLessEqual(max(differences), PRECISION)  # type: ignore[type-var]
 
 
 class TestBase(object):
 
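From here on the diff uses error-code-scoped suppressions (`[type-var]`, `[attr-defined]`, `[arg-type]`, `[misc]`) rather than bare `# type: ignore`. The scoped form silences only the named error code, so any new, unrelated error on the same line still gets reported. A sketch of the difference, reusing a line from this diff:

import torch

# Scoped: suppresses only the attr-defined complaint about
# torch.cuda.FloatTensor; other errors on this line still surface.
type_map = {'torch.DoubleTensor': torch.cuda.FloatTensor}  # type: ignore[attr-defined]

# Bare form (discouraged) would hide every error on the line:
# type_map = {'torch.DoubleTensor': torch.cuda.FloatTensor}  # type: ignore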
@@ -4807,6 +4833,10 @@ class TestBase(object):
 
 class ModuleTest(TestBase):
 
+    @abstractmethod
+    def _do_test(self, test_case: Any, module: nn.Module, input: Any) -> Any:
+        raise NotImplementedError
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.jacobian_input = kwargs.get('jacobian_input', True)
 
@@ -4903,7 +4933,7 @@ class ModuleTest(TestBase):
             raise unittest.SkipTest('Excluded from CUDA tests')
 
         cpu_input = self._get_input()
-        type_map = {'torch.DoubleTensor': torch.cuda.FloatTensor}
+        type_map = {'torch.DoubleTensor': torch.cuda.FloatTensor}  # type: ignore[attr-defined]
         cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,)
         gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map)
 
@@ -4980,7 +5010,7 @@ class ModuleTest(TestBase):
 
 class InputVariableMixin(object):
     def _get_input(self):
-        input = TestBase._get_input(self, False)
+        input = TestBase._get_input(self, False)  # type: ignore[arg-type]
 
         def map_variables(i):
             if isinstance(i, torch.Tensor):
 
@@ -4993,7 +5023,7 @@ class InputVariableMixin(object):
         return map_variables(input)
 
 
-class NewModuleTest(InputVariableMixin, ModuleTest):
+class NewModuleTest(InputVariableMixin, ModuleTest):  # type: ignore[misc]
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.cudnn = kwargs.get('cudnn', False)
 
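The `# type: ignore[misc]` on the mixin-based class definitions (`NewModuleTest` here, `CriterionTest` below) is most plausibly suppressing mypy's multiple-inheritance compatibility check: when two base classes define the same method with signatures mypy considers incompatible, the complaint is filed under the `misc` code. A rough sketch of the shape of the problem (hypothetical classes, not the actual ones from this file):

class Base:
    def _get_input(self, unpack: bool) -> int:
        return 0

class Mixin:
    def _get_input(self) -> int:   # different signature than Base's
        return 1

class Combined(Mixin, Base):  # type: ignore[misc]
    # mypy: Definition of "_get_input" in base class "Mixin" is
    # incompatible with definition in base class "Base"
    pass

print(Combined()._get_input())  # 1 -- the MRO picks Mixin's version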
@@ -5059,14 +5089,14 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
                 input_tuple = tuple(t.cuda() for t in input_tuple)
                 module.float().cuda()
                 module(*input_tuple)
-                assert_module_parameters_are(torch.cuda.FloatTensor, 0)
+                assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
 
                 if torch.cuda.device_count() > 1:
                     input_tuple = tuple(t.cuda(1) for t in input_tuple)
                     module.cuda(1)
                     with torch.cuda.device(1):
                         module(*input_tuple)
-                        assert_module_parameters_are(torch.cuda.FloatTensor, 1)
+                        assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]
             else:
                 # check that float()/double() casters work correctly
 
@@ -5091,7 +5121,7 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
                         t.float().cuda() if not isinstance(t, torch.LongTensor) else t.cuda() for t in input_tuple)
                     module.float().cuda()
                     module(*input_tuple)
-                    assert_module_parameters_are(torch.cuda.FloatTensor, 0)
+                    assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
 
                     # to CPU
                     input_tuple = tuple(t.cpu() for t in input_tuple)
 
@@ -5103,13 +5133,13 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
                     input_tuple = tuple(t.cuda() for t in input_tuple)
                     module.cuda()
                     module(*input_tuple)
-                    assert_module_parameters_are(torch.cuda.FloatTensor, 0)
+                    assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
 
                     # test that forwards of module runs correctly without cuDNN
                     if self.cudnn:
                         with torch.backends.cudnn.flags(enabled=False):
                             module(*input_tuple)
-                            assert_module_parameters_are(torch.cuda.FloatTensor, 0)
+                            assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
 
                     if torch.cuda.device_count() >= 2:
                         # test cross-GPU transfer works
 
@@ -5118,7 +5148,7 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
                         module.cuda(1)
                         with torch.cuda.device(1):
                             module(*input_tuple)
-                            assert_module_parameters_are(torch.cuda.FloatTensor, 1)
+                            assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]
 
                     if not self.skip_double:
                         # test double()
 
@@ -5126,14 +5156,14 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
                            t.double().cuda() if not isinstance(t, torch.LongTensor) else t.cuda() for t in input_tuple)
                        module.double().cuda()
                        module(*input_tuple)
-                       assert_module_parameters_are(torch.cuda.DoubleTensor, 0)
+                       assert_module_parameters_are(torch.cuda.DoubleTensor, 0)  # type: ignore[attr-defined]
 
                        # test half()
                        input_tuple = tuple(
                            t.half().cuda() if not isinstance(t, torch.LongTensor) else t.cuda() for t in input_tuple)
                        module.half().cuda()
                        module(*input_tuple)
-                       assert_module_parameters_are(torch.cuda.HalfTensor, 0)
+                       assert_module_parameters_are(torch.cuda.HalfTensor, 0)  # type: ignore[attr-defined]
             torch.set_num_threads(num_threads)
 
     def _get_target(self):
 
@@ -5144,7 +5174,7 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
        return self._get_arg('constructor_args', False)
 
 
-class CriterionTest(InputVariableMixin, TestBase):
+class CriterionTest(InputVariableMixin, TestBase):  # type: ignore[misc]
     # TODO: check that criterions don't ignore grad_output
 
     _required_arg_names = TestBase._required_arg_names.union({'target'})
 
@@ -5188,7 +5218,7 @@ class CriterionTest(InputVariableMixin, TestBase):
         else:
             inputs = input + params + (target,)
 
-            def apply_fn(input1, input2, target, *params):
+            def apply_fn(input1, input2, target, *params):  # type: ignore[misc]
                 return module(input1, input2, target)
 
         gradcheck(apply_fn, inputs)