Avoid adding some functions to the public Python API before the 1.11 release (#72543)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72543
Test Plan: Imported from OSS
Reviewed By: ejguan
Differential Revision: D34085724
Pulled By: bdhirsh
fbshipit-source-id: 941d5a90a6fa5328268d623e0e2b01577e4132ca
Parent: 963027c7f2
Commit: 6676a0c79a
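The change itself is mechanical: the not-yet-finalized scatter_reduce operator is renamed to _scatter_reduce so it does not land in the public Python API with the 1.11 release. A minimal sketch of the user-visible effect, assuming a build at this commit; the tensors and call below are illustrative, based only on the schema shown in the diff, not code from the PR:

    import torch

    # Illustrative data; the schema (dim, index, reduce, optional output_size)
    # is taken from the diff below.
    src = torch.tensor([1., 2., 3., 4., 5., 6.])
    index = torch.tensor([0, 1, 0, 1, 2, 1])

    # Before this commit the op was exposed as torch.scatter_reduce(...).
    # After it, only the underscored variant is available until the API is
    # finalized for a later release:
    out = torch._scatter_reduce(src, 0, index, "sum")
    print(out)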
@@ -6061,7 +6061,7 @@
 - func: scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor
   variants: function, method

-- func: scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
+- func: _scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
   variants: function, method
   dispatch:
     CPU: scatter_reduce_two_cpu
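The hunk above appears to be the operator schema (native_functions.yaml); besides the rename, "variants: function, method" means the generated binding is callable both from the torch namespace and as a Tensor method. A hedged sketch of what that implies at this commit, with made-up inputs:

    import torch

    # Illustrative data; only the function/method duality is the point here.
    x = torch.randn(6)
    index = torch.randint(0, 3, (6,))

    a = torch._scatter_reduce(x, 0, index, "sum")  # function variant
    b = x._scatter_reduce(0, index, "sum")         # method variant
    assert torch.equal(a, b)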
@@ -106,6 +106,7 @@ ALLOW_LIST = [
     ("aten::_scatter_reduce", datetime.date(2022, 1, 31)),
     ("aten::native_multi_head_self_attention", datetime.date(9999, 1, 1)),
     ("aten::_native_multi_head_self_attention", datetime.date(9999, 1, 1)),
+    ("aten::scatter_reduce.two", datetime.date(2022, 3, 15)),
 ]

 ALLOW_LIST_COMPILED = [
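This hunk extends the schema allow-list used by the compatibility check: the disappearance of the old public scatter_reduce.two schema is tolerated until the listed date. A rough sketch of how such a list is typically consulted; the function and logic here are hypothetical, not the actual checker:

    import datetime

    # Hypothetical checker logic, for illustration only.
    ALLOW_LIST = [
        ("aten::_scatter_reduce", datetime.date(2022, 1, 31)),
        ("aten::scatter_reduce.two", datetime.date(2022, 3, 15)),
    ]

    def schema_change_allowed(op_name: str, today: datetime.date) -> bool:
        # A schema change for a listed op is tolerated until its entry expires.
        return any(op_name.startswith(prefix) and today < expiry
                   for prefix, expiry in ALLOW_LIST)

    print(schema_change_allowed("aten::scatter_reduce.two", datetime.date(2022, 2, 10)))  # True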
@@ -5773,7 +5773,7 @@ class TestTorch(TestCase):

         for reduce in reduces:
             for dim in range(len(shape)):
-                output = input.scatter_reduce(dim, index, reduce, output_size=output_size)
+                output = input._scatter_reduce(dim, index, reduce, output_size=output_size)

                 # Check that output is of the correct size
                 output_shape = copy.copy(shape)
@@ -5807,16 +5807,16 @@ class TestTorch(TestCase):
                 self.assertTrue(torch.allclose(output, expected))

         with self.assertRaisesRegex(RuntimeError, "Expected `dim` to be in range -3 to 2"):
-            torch.scatter_reduce(input, 4, index, "sum")
+            torch._scatter_reduce(input, 4, index, "sum")

         with self.assertRaisesRegex(RuntimeError, "Shape mismatch"):
             index2 = torch.randint(0, output_size, (10, ), dtype=torch.long, device=device)
-            torch.scatter_reduce(input, 0, index2, "sum")
+            torch._scatter_reduce(input, 0, index2, "sum")

         with self.assertRaisesRegex(RuntimeError, "Expected `index` values to be in range 0 to 2"):
             input2 = torch.randn(10, dtype=dtype, device=device)
             index2 = torch.tensor([0, 1, 0, 1, 2, 3, 3, 4, 4, 3])
-            torch.scatter_reduce(input2, 0, index2, "sum", output_size=2)
+            torch._scatter_reduce(input2, 0, index2, "sum", output_size=2)

     def test_structseq_repr(self):
         a = torch.arange(250).reshape(5, 5, 10)
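The updated tests keep exercising the same validation paths under the new name: dim must be in range, index must be shape-compatible with the input, and index values must fit within output_size. A small sketch of the output_size keyword those tests rely on; data is illustrative and assumes the op as defined at this commit:

    import torch

    # Illustrative data; the point is that output_size fixes the length of the
    # reduced dimension instead of inferring it from index.max() + 1.
    x = torch.arange(6, dtype=torch.float)
    index = torch.tensor([0, 1, 0, 1, 0, 1])

    out = torch._scatter_reduce(x, 0, index, "sum", output_size=4)
    print(out.shape)  # torch.Size([4])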
@@ -2595,6 +2595,6 @@
 - name: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   output_differentiability: [False]

-- name: scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
+- name: _scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
   self: scatter_reduce_backward(grad, self, dim, index, reduce, result)
   index: non_differentiable
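This is the autograd entry (derivatives.yaml): the backward formula for self is scatter_reduce_backward, and index is marked non-differentiable, so only the rename is needed. A hedged illustration of what the entry expresses in practice, with made-up tensors and assuming a build at this commit:

    import torch

    # Illustrative: gradients flow back to the floating-point input, while the
    # integer index tensor carries no gradient.
    x = torch.randn(6, requires_grad=True)
    index = torch.tensor([0, 1, 0, 1, 2, 1])

    out = torch._scatter_reduce(x, 0, index, "sum")
    out.sum().backward()

    print(x.grad is not None)   # True: self is differentiable
    print(index.grad)           # None: index is non-differentiable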
@@ -4,9 +4,6 @@ import warnings
 from torch import Tensor
 import torch

-from ..overrides import (
-    has_torch_function_variadic,
-    handle_torch_function)

 # These no_grad_* functions are necessary as wrappers around the parts of these
 # functions that use `with torch.no_grad()`. The JIT doesn't support context
@@ -135,8 +132,8 @@ def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
         >>> w = torch.empty(3, 5)
         >>> nn.init.uniform_(w)
     """
-    if has_torch_function_variadic(tensor):
-        return handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
+    if torch.overrides.has_torch_function_variadic(tensor):
+        return torch.overrides.handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
     return _no_grad_uniform_(tensor, a, b)
@@ -153,8 +150,8 @@ def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
         >>> w = torch.empty(3, 5)
         >>> nn.init.normal_(w)
     """
-    if has_torch_function_variadic(tensor):
-        return handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
+    if torch.overrides.has_torch_function_variadic(tensor):
+        return torch.overrides.handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
     return _no_grad_normal_(tensor, mean, std)

 def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor:
@@ -190,8 +187,8 @@ def constant_(tensor: Tensor, val: float) -> Tensor:
         >>> w = torch.empty(3, 5)
         >>> nn.init.constant_(w, 0.3)
     """
-    if has_torch_function_variadic(tensor):
-        return handle_torch_function(constant_, (tensor,), tensor=tensor, val=val)
+    if torch.overrides.has_torch_function_variadic(tensor):
+        return torch.overrides.handle_torch_function(constant_, (tensor,), tensor=tensor, val=val)
     return _no_grad_fill_(tensor, val)
@@ -393,8 +390,14 @@ def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
         >>> w = torch.empty(3, 5)
         >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
     """
-    if has_torch_function_variadic(tensor):
-        return handle_torch_function(kaiming_uniform_, (tensor,), tensor=tensor, a=a, mode=mode, nonlinearity=nonlinearity)
+    if torch.overrides.has_torch_function_variadic(tensor):
+        return torch.overrides.handle_torch_function(
+            kaiming_uniform_,
+            (tensor,),
+            tensor=tensor,
+            a=a,
+            mode=mode,
+            nonlinearity=nonlinearity)

     if 0 in tensor.shape:
         warnings.warn("Initializing zero-element tensors is a no-op")
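The nn.init hunks drop the relative "from ..overrides import ..." and call the same helpers through torch.overrides instead; behavior is unchanged, and __torch_function__ handlers still intercept the init calls. A sketch of the interception these checks enable; the LoggingTensor subclass is illustrative and not part of the PR:

    import torch
    import torch.nn.init as init

    class LoggingTensor(torch.Tensor):
        # Illustrative subclass: any torch API applied to it is reported before
        # being re-dispatched to the default implementation.
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print(f"intercepted {func.__name__}")
            return super().__torch_function__(func, types, args, kwargs)

    t = torch.zeros(3, 5).as_subclass(LoggingTensor)
    init.uniform_(t)  # handle_torch_function routes this call through __torch_function__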
@@ -897,7 +897,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
         torch.scatter: lambda input, dim, index, src: -1,
         torch.scatter_add: lambda input, dim, index, src: -1,
-        torch.scatter_reduce: lambda input, dim, index, reduce, output_size=None: -1,
+        torch._scatter_reduce: lambda input, dim, index, reduce, output_size=None: -1,
         torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
         torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
         torch.select: lambda input, dim, index: -1,
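get_testing_overrides() maps every overridable callable to a dummy lambda whose signature mirrors the real op; the entry above is renamed so override-coverage tests track torch._scatter_reduce instead of the removed public name. A brief usage sketch, shown here only to illustrate how such entries are consumed:

    import inspect
    import torch
    from torch.overrides import get_testing_overrides

    # Each entry's lambda mirrors the real signature of the overridable op.
    overrides = get_testing_overrides()
    dummy = overrides[torch.scatter_add]
    print(inspect.signature(dummy))  # (input, dim, index, src)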
@@ -15517,7 +15517,7 @@ op_db: List[OpInfo] = [
         supports_fwgrad_bwgrad=True,
     ),
     OpInfo(
-        'scatter_reduce',
+        '_scatter_reduce',
         dtypes=all_types_and(torch.float16, torch.bfloat16),
         sample_inputs_func=sample_inputs_scatter_reduce,
         supports_out=False,
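Finally, the OpInfo entry that drives the generated operator tests is renamed so the test suite resolves the underscored op. A rough sketch of that name resolution; the helper below is hypothetical, and the real OpInfo machinery does considerably more:

    import torch

    def resolve_op(qualified_name: str):
        # Hypothetical stand-in for how an OpInfo name is looked up on the
        # torch namespace (dots allow names such as "nn.functional.relu").
        obj = torch
        for part in qualified_name.split("."):
            obj = getattr(obj, part)
        return obj

    op = resolve_op("_scatter_reduce")  # resolves torch._scatter_reduce at this commit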