Add type annotations for torch.nn.utils.* (#43080)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/43013

Redo of gh-42954

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43080

Reviewed By: albanD

Differential Revision: D23681334

Pulled By: malfet

fbshipit-source-id: 20ec78aa3bfecb7acffc12eb89d3ad833024394c
This commit is contained in:
Guilherme Leobas 2020-09-14 17:50:05 -07:00 committed by Facebook GitHub Bot
parent 551494b01d
commit e107ef5ca2
3 changed files with 9 additions and 15 deletions

View File

@ -255,9 +255,6 @@ ignore_errors = True
[mypy-torch.nn.utils.prune]
ignore_errors = True
[mypy-torch.nn.utils.memory_format]
ignore_errors = True
[mypy-torch.nn.cpp]
ignore_errors = True

View File

@ -560,7 +560,7 @@ def gen_pyi(declarations_path, out):
'def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...',
],
'get_device': ['def get_device(self) -> _int: ...'],
'contiguous': ['def contiguous(self) -> Tensor: ...'],
'contiguous': ['def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ...'],
'is_contiguous': ['def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ...'],
'is_cuda': ['is_cuda: _bool'],
'is_leaf': ['is_leaf: _bool'],

View File

@ -4,14 +4,9 @@ Pruning methods
from abc import abstractmethod
import numbers
import torch
# For Python 2 and 3 support
try:
from abc import ABC
from collections.abc import Iterable
except ImportError:
from abc import ABCMeta
ABC = ABCMeta('ABC', (), {})
from collections import Iterable
from abc import ABC
from collections.abc import Iterable
from typing import Tuple
class BasePruningMethod(ABC):
r"""Abstract base class for creation of new pruning techniques.
@ -19,6 +14,8 @@ class BasePruningMethod(ABC):
Provides a skeleton for customization requiring the overriding of methods
such as :meth:`compute_mask` and :meth:`apply`.
"""
_tensor_name: str
def __init__(self):
pass
@ -252,7 +249,7 @@ class PruningContainer(BasePruningMethod):
"""
def __init__(self, *args):
self._pruning_methods = tuple()
self._pruning_methods: Tuple['BasePruningMethod', ...] = tuple()
if not isinstance(args, Iterable): # only 1 item
self._tensor_name = args._tensor_name
self.add_pruning_method(args)
@ -275,7 +272,7 @@ class PruningContainer(BasePruningMethod):
raise TypeError(
"{} is not a BasePruningMethod subclass".format(type(method))
)
elif self._tensor_name != method._tensor_name:
elif method is not None and self._tensor_name != method._tensor_name:
raise ValueError(
"Can only add pruning methods acting on "
"the parameter named '{}' to PruningContainer {}.".format(
@ -1182,7 +1179,7 @@ def _validate_pruning_amount_init(amount):
if (isinstance(amount, numbers.Integral) and amount < 0) or (
not isinstance(amount, numbers.Integral) # so it's a float
and (amount > 1.0 or amount < 0.0)
and (float(amount) > 1.0 or float(amount) < 0.0)
):
raise ValueError(
"amount={} should either be a float in the "