mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Add type annotations for torch.nn.utils.* (#43080)
Summary: Fixes https://github.com/pytorch/pytorch/issues/43013 Redo of gh-42954 Pull Request resolved: https://github.com/pytorch/pytorch/pull/43080 Reviewed By: albanD Differential Revision: D23681334 Pulled By: malfet fbshipit-source-id: 20ec78aa3bfecb7acffc12eb89d3ad833024394c
This commit is contained in:
parent
551494b01d
commit
e107ef5ca2
3
mypy.ini
3
mypy.ini
|
|
@ -255,9 +255,6 @@ ignore_errors = True
|
|||
[mypy-torch.nn.utils.prune]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.nn.utils.memory_format]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.nn.cpp]
|
||||
ignore_errors = True
|
||||
|
||||
|
|
|
|||
|
|
@ -560,7 +560,7 @@ def gen_pyi(declarations_path, out):
|
|||
'def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...',
|
||||
],
|
||||
'get_device': ['def get_device(self) -> _int: ...'],
|
||||
'contiguous': ['def contiguous(self) -> Tensor: ...'],
|
||||
'contiguous': ['def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ...'],
|
||||
'is_contiguous': ['def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ...'],
|
||||
'is_cuda': ['is_cuda: _bool'],
|
||||
'is_leaf': ['is_leaf: _bool'],
|
||||
|
|
|
|||
|
|
@ -4,14 +4,9 @@ Pruning methods
|
|||
from abc import abstractmethod
|
||||
import numbers
|
||||
import torch
|
||||
# For Python 2 and 3 support
|
||||
try:
|
||||
from abc import ABC
|
||||
from collections.abc import Iterable
|
||||
except ImportError:
|
||||
from abc import ABCMeta
|
||||
ABC = ABCMeta('ABC', (), {})
|
||||
from collections import Iterable
|
||||
from abc import ABC
|
||||
from collections.abc import Iterable
|
||||
from typing import Tuple
|
||||
|
||||
class BasePruningMethod(ABC):
|
||||
r"""Abstract base class for creation of new pruning techniques.
|
||||
|
|
@ -19,6 +14,8 @@ class BasePruningMethod(ABC):
|
|||
Provides a skeleton for customization requiring the overriding of methods
|
||||
such as :meth:`compute_mask` and :meth:`apply`.
|
||||
"""
|
||||
_tensor_name: str
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
|
@ -252,7 +249,7 @@ class PruningContainer(BasePruningMethod):
|
|||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
self._pruning_methods = tuple()
|
||||
self._pruning_methods: Tuple['BasePruningMethod', ...] = tuple()
|
||||
if not isinstance(args, Iterable): # only 1 item
|
||||
self._tensor_name = args._tensor_name
|
||||
self.add_pruning_method(args)
|
||||
|
|
@ -275,7 +272,7 @@ class PruningContainer(BasePruningMethod):
|
|||
raise TypeError(
|
||||
"{} is not a BasePruningMethod subclass".format(type(method))
|
||||
)
|
||||
elif self._tensor_name != method._tensor_name:
|
||||
elif method is not None and self._tensor_name != method._tensor_name:
|
||||
raise ValueError(
|
||||
"Can only add pruning methods acting on "
|
||||
"the parameter named '{}' to PruningContainer {}.".format(
|
||||
|
|
@ -1182,7 +1179,7 @@ def _validate_pruning_amount_init(amount):
|
|||
|
||||
if (isinstance(amount, numbers.Integral) and amount < 0) or (
|
||||
not isinstance(amount, numbers.Integral) # so it's a float
|
||||
and (amount > 1.0 or amount < 0.0)
|
||||
and (float(amount) > 1.0 or float(amount) < 0.0)
|
||||
):
|
||||
raise ValueError(
|
||||
"amount={} should either be a float in the "
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user