Remove __torch_function__ handling for torch.nn.init functions

ghstack-source-id: 32bbef8d50
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72195
Joel Benjamin Schlosser 2022-02-02 15:08:48 -05:00
parent e19f2e52ad
commit 28128f0c1a

@@ -4,10 +4,6 @@ import warnings
 from torch import Tensor
 import torch
-from ..overrides import (
-    has_torch_function_variadic,
-    handle_torch_function)
 # These no_grad_* functions are necessary as wrappers around the parts of these
 # functions that use `with torch.no_grad()`. The JIT doesn't support context
 # managers, so these need to be implemented as builtins. Using these wrappers
@@ -135,8 +131,6 @@ def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
         >>> w = torch.empty(3, 5)
         >>> nn.init.uniform_(w)
     """
-    if has_torch_function_variadic(tensor):
-        return handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
     return _no_grad_uniform_(tensor, a, b)
@@ -153,8 +147,6 @@ def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
         >>> w = torch.empty(3, 5)
         >>> nn.init.normal_(w)
     """
-    if has_torch_function_variadic(tensor):
-        return handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
     return _no_grad_normal_(tensor, mean, std)
 
 def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor:
@@ -391,9 +383,6 @@ def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
         >>> w = torch.empty(3, 5)
         >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
     """
-    if has_torch_function_variadic(tensor):
-        return handle_torch_function(kaiming_uniform_, (tensor,), tensor=tensor, a=a, mode=mode, nonlinearity=nonlinearity)
     if 0 in tensor.shape:
         warnings.warn("Initializing zero-element tensors is a no-op")
         return tensor
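
For reference, the lines deleted in each hunk all follow the same Python-level __torch_function__ dispatch pattern. Below is a minimal sketch of what an init function such as uniform_ looked like before this change, reconstructed from the removed lines above; the in-place fill under torch.no_grad() stands in for the private _no_grad_uniform_ helper and is an assumption about its behavior, not a copy of it.

import torch
from torch import Tensor
from torch.overrides import has_torch_function_variadic, handle_torch_function

def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
    # Python-level dispatch guard (the part this PR removes): if the tensor is a
    # subclass that defines __torch_function__, hand the whole call over to it.
    if has_torch_function_variadic(tensor):
        return handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
    # Otherwise fill the tensor in place without tracking gradients, which is
    # roughly what the private _no_grad_uniform_ helper does.
    with torch.no_grad():
        return tensor.uniform_(a, b)

A plain torch.Tensor never takes the guarded branch; only tensor-like subclasses that override __torch_function__ ever reached the handle_torch_function path.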