avoiding adding some functions to the public python API before 1.11 release (#72543)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72543

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D34085724

Pulled By: bdhirsh

fbshipit-source-id: 941d5a90a6fa5328268d623e0e2b01577e4132ca
This commit is contained in:
Brian Hirsh 2022-02-14 11:39:05 -08:00 committed by Facebook GitHub Bot
parent 963027c7f2
commit 6676a0c79a
7 changed files with 23 additions and 19 deletions

View File

@ -6061,7 +6061,7 @@
- func: scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor
variants: function, method
- func: scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
- func: _scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
variants: function, method
dispatch:
CPU: scatter_reduce_two_cpu

View File

@ -106,6 +106,7 @@ ALLOW_LIST = [
("aten::_scatter_reduce", datetime.date(2022, 1, 31)),
("aten::native_multi_head_self_attention", datetime.date(9999, 1, 1)),
("aten::_native_multi_head_self_attention", datetime.date(9999, 1, 1)),
("aten::scatter_reduce.two", datetime.date(2022, 3, 15)),
]
ALLOW_LIST_COMPILED = [

View File

@ -5773,7 +5773,7 @@ class TestTorch(TestCase):
for reduce in reduces:
for dim in range(len(shape)):
output = input.scatter_reduce(dim, index, reduce, output_size=output_size)
output = input._scatter_reduce(dim, index, reduce, output_size=output_size)
# Check that output is of the correct size
output_shape = copy.copy(shape)
@ -5807,16 +5807,16 @@ class TestTorch(TestCase):
self.assertTrue(torch.allclose(output, expected))
with self.assertRaisesRegex(RuntimeError, "Expected `dim` to be in range -3 to 2"):
torch.scatter_reduce(input, 4, index, "sum")
torch._scatter_reduce(input, 4, index, "sum")
with self.assertRaisesRegex(RuntimeError, "Shape mismatch"):
index2 = torch.randint(0, output_size, (10, ), dtype=torch.long, device=device)
torch.scatter_reduce(input, 0, index2, "sum")
torch._scatter_reduce(input, 0, index2, "sum")
with self.assertRaisesRegex(RuntimeError, "Expected `index` values to be in range 0 to 2"):
input2 = torch.randn(10, dtype=dtype, device=device)
index2 = torch.tensor([0, 1, 0, 1, 2, 3, 3, 4, 4, 3])
torch.scatter_reduce(input2, 0, index2, "sum", output_size=2)
torch._scatter_reduce(input2, 0, index2, "sum", output_size=2)
def test_structseq_repr(self):
a = torch.arange(250).reshape(5, 5, 10)

View File

@ -2595,6 +2595,6 @@
- name: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
output_differentiability: [False]
- name: scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
- name: _scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
self: scatter_reduce_backward(grad, self, dim, index, reduce, result)
index: non_differentiable

View File

@ -4,9 +4,6 @@ import warnings
from torch import Tensor
import torch
from ..overrides import (
has_torch_function_variadic,
handle_torch_function)
# These no_grad_* functions are necessary as wrappers around the parts of these
# functions that use `with torch.no_grad()`. The JIT doesn't support context
@ -135,8 +132,8 @@ def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
>>> w = torch.empty(3, 5)
>>> nn.init.uniform_(w)
"""
if has_torch_function_variadic(tensor):
return handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
if torch.overrides.has_torch_function_variadic(tensor):
return torch.overrides.handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
return _no_grad_uniform_(tensor, a, b)
@ -153,8 +150,8 @@ def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
>>> w = torch.empty(3, 5)
>>> nn.init.normal_(w)
"""
if has_torch_function_variadic(tensor):
return handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
if torch.overrides.has_torch_function_variadic(tensor):
return torch.overrides.handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
return _no_grad_normal_(tensor, mean, std)
def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor:
@ -190,8 +187,8 @@ def constant_(tensor: Tensor, val: float) -> Tensor:
>>> w = torch.empty(3, 5)
>>> nn.init.constant_(w, 0.3)
"""
if has_torch_function_variadic(tensor):
return handle_torch_function(constant_, (tensor,), tensor=tensor, val=val)
if torch.overrides.has_torch_function_variadic(tensor):
return torch.overrides.handle_torch_function(constant_, (tensor,), tensor=tensor, val=val)
return _no_grad_fill_(tensor, val)
@ -393,8 +390,14 @@ def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
"""
if has_torch_function_variadic(tensor):
return handle_torch_function(kaiming_uniform_, (tensor,), tensor=tensor, a=a, mode=mode, nonlinearity=nonlinearity)
if torch.overrides.has_torch_function_variadic(tensor):
return torch.overrides.handle_torch_function(
kaiming_uniform_,
(tensor,),
tensor=tensor,
a=a,
mode=mode,
nonlinearity=nonlinearity)
if 0 in tensor.shape:
warnings.warn("Initializing zero-element tensors is a no-op")

View File

@ -897,7 +897,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
torch.scatter: lambda input, dim, index, src: -1,
torch.scatter_add: lambda input, dim, index, src: -1,
torch.scatter_reduce: lambda input, dim, index, reduce, output_size=None: -1,
torch._scatter_reduce: lambda input, dim, index, reduce, output_size=None: -1,
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
torch.select: lambda input, dim, index: -1,

View File

@ -15517,7 +15517,7 @@ op_db: List[OpInfo] = [
supports_fwgrad_bwgrad=True,
),
OpInfo(
'scatter_reduce',
'_scatter_reduce',
dtypes=all_types_and(torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_scatter_reduce,
supports_out=False,