mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Forgot to mirror the `nn/__init__.py` semantics in the new `nn` type stub. Pull Request resolved: https://github.com/pytorch/pytorch/pull/22411 Differential Revision: D16149798 Pulled By: ezyang fbshipit-source-id: 0ffa256fbdc5e5383a7b9c9c3ae61acd11de1dba
34 lines
1.0 KiB
Python
34 lines
1.0 KiB
Python
from typing import Any, Optional, TypeVar
|
|
from ... import Tensor
|
|
from ..modules import Module
|
|
|
|
|
|
class SpectralNorm:
    # Type-stub declaration of the SpectralNorm helper class. Every body here is
    # ``...``: this file only declares names and signatures for type checkers;
    # the implementation lives in the corresponding runtime module (not shown).

    # Configuration attributes; their names/types mirror the parameters of
    # ``__init__`` and ``apply`` below.
    name: str = ...
    dim: int = ...
    n_power_iterations: int = ...
    eps: float = ...

    # NOTE: positional order here is (name, n_power_iterations, dim, eps), while
    # the module-level ``spectral_norm()`` declares (..., eps, dim) -- prefer
    # keyword arguments to avoid swapping ``dim`` and ``eps``.
    def __init__(self, name: str = ..., n_power_iterations: int = ..., dim: int = ...,
                 eps: float = ...) -> None: ...

    # Tensor -> Tensor: presumably flattens ``weight`` to a 2-D matrix, with the
    # layout driven by ``dim`` -- TODO confirm against the implementation.
    def reshape_weight_to_matrix(self, weight: Tensor) -> Tensor: ...

    # Returns the weight Tensor for ``module``. Judging by its name,
    # ``do_power_iteration`` toggles whether power-iteration steps are run --
    # confirm in the implementation.
    def compute_weight(self, module: Module, do_power_iteration: bool) -> Tensor: ...

    # Presumably undoes whatever ``apply`` installed on ``module`` (returns None).
    def remove(self, module: Module) -> None: ...

    # The ``(module, inputs) -> None`` shape matches a forward pre-hook;
    # NOTE(review): presumably registered as one by ``apply`` -- confirm.
    def __call__(self, module: Module, inputs: Any) -> None: ...

    # Static factory: takes the target module plus the same configuration values
    # as ``__init__`` (all required here) and returns a ``SpectralNorm`` instance.
    @staticmethod
    def apply(module: Module, name: str, n_power_iterations: int, dim: int, eps: float) -> 'SpectralNorm': ...
|
|
|
|
|
|
T_module = TypeVar("T_module", bound=Module)  # any Module subclass; lets signatures return the caller's exact module type
|
|
|
|
|
|
# Stub declaration of ``spectral_norm`` (body elided with ``...``).
#
# Parameters -- runtime default values are elided as ``...`` in this stub:
#   module             -- typed as the ``T_module`` TypeVar (bound to Module), so
#                         the declared return type is the caller's module type.
#   name               -- str; presumably the attribute name of the tensor to
#                         normalize -- confirm against the implementation.
#   n_power_iterations -- int
#   eps                -- float
#   dim                -- int or None; the meaning of ``None`` is not visible
#                         here -- TODO confirm in the implementation.
# NOTE: ``eps`` precedes ``dim`` here, the reverse of ``SpectralNorm.__init__``;
# pass these by keyword.
# Returns: a ``T_module`` (presumably the same module object -- confirm).
def spectral_norm(
    module: T_module,
    name: str = ...,
    n_power_iterations: int = ...,
    eps: float = ...,
    dim: Optional[int] = ...,
) -> T_module: ...
|
|
|
|
|
|
# Stub declaration of ``remove_spectral_norm``: accepts a module (typed with the
# ``T_module`` TypeVar) and an optional ``name`` string (default elided as
# ``...``), and is declared to return that same module type.
def remove_spectral_norm(
    module: T_module,
    name: str = ...,
) -> T_module:
    ...
|