mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Remove __torch_function__ handling for torch.nn.init functions
ghstack-source-id: 32bbef8d50
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72195
This commit is contained in:
parent
e19f2e52ad
commit
28128f0c1a
|
|
@ -4,10 +4,6 @@ import warnings
|
|||
from torch import Tensor
|
||||
import torch
|
||||
|
||||
from ..overrides import (
|
||||
has_torch_function_variadic,
|
||||
handle_torch_function)
|
||||
|
||||
# These no_grad_* functions are necessary as wrappers around the parts of these
|
||||
# functions that use `with torch.no_grad()`. The JIT doesn't support context
|
||||
# managers, so these need to be implemented as builtins. Using these wrappers
|
||||
|
|
@ -135,8 +131,6 @@ def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
|
|||
>>> w = torch.empty(3, 5)
|
||||
>>> nn.init.uniform_(w)
|
||||
"""
|
||||
if has_torch_function_variadic(tensor):
|
||||
return handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
|
||||
return _no_grad_uniform_(tensor, a, b)
|
||||
|
||||
|
||||
|
|
@ -153,8 +147,6 @@ def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
|
|||
>>> w = torch.empty(3, 5)
|
||||
>>> nn.init.normal_(w)
|
||||
"""
|
||||
if has_torch_function_variadic(tensor):
|
||||
return handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
|
||||
return _no_grad_normal_(tensor, mean, std)
|
||||
|
||||
def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor:
|
||||
|
|
@ -391,9 +383,6 @@ def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
|||
>>> w = torch.empty(3, 5)
|
||||
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
|
||||
"""
|
||||
if has_torch_function_variadic(tensor):
|
||||
return handle_torch_function(kaiming_uniform_, (tensor,), tensor=tensor, a=a, mode=mode, nonlinearity=nonlinearity)
|
||||
|
||||
if 0 in tensor.shape:
|
||||
warnings.warn("Initializing zero-element tensors is a no-op")
|
||||
return tensor
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user