mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
avoiding adding some functions to the public python API before 1.11 release (#72543)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72543 Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D34085724 Pulled By: bdhirsh fbshipit-source-id: 941d5a90a6fa5328268d623e0e2b01577e4132ca
This commit is contained in:
parent
963027c7f2
commit
6676a0c79a
|
|
@ -6061,7 +6061,7 @@
|
||||||
- func: scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor
|
- func: scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor
|
||||||
variants: function, method
|
variants: function, method
|
||||||
|
|
||||||
- func: scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
|
- func: _scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
|
||||||
variants: function, method
|
variants: function, method
|
||||||
dispatch:
|
dispatch:
|
||||||
CPU: scatter_reduce_two_cpu
|
CPU: scatter_reduce_two_cpu
|
||||||
|
|
|
||||||
|
|
@ -106,6 +106,7 @@ ALLOW_LIST = [
|
||||||
("aten::_scatter_reduce", datetime.date(2022, 1, 31)),
|
("aten::_scatter_reduce", datetime.date(2022, 1, 31)),
|
||||||
("aten::native_multi_head_self_attention", datetime.date(9999, 1, 1)),
|
("aten::native_multi_head_self_attention", datetime.date(9999, 1, 1)),
|
||||||
("aten::_native_multi_head_self_attention", datetime.date(9999, 1, 1)),
|
("aten::_native_multi_head_self_attention", datetime.date(9999, 1, 1)),
|
||||||
|
("aten::scatter_reduce.two", datetime.date(2022, 3, 15)),
|
||||||
]
|
]
|
||||||
|
|
||||||
ALLOW_LIST_COMPILED = [
|
ALLOW_LIST_COMPILED = [
|
||||||
|
|
|
||||||
|
|
@ -5773,7 +5773,7 @@ class TestTorch(TestCase):
|
||||||
|
|
||||||
for reduce in reduces:
|
for reduce in reduces:
|
||||||
for dim in range(len(shape)):
|
for dim in range(len(shape)):
|
||||||
output = input.scatter_reduce(dim, index, reduce, output_size=output_size)
|
output = input._scatter_reduce(dim, index, reduce, output_size=output_size)
|
||||||
|
|
||||||
# Check that output is of the correct size
|
# Check that output is of the correct size
|
||||||
output_shape = copy.copy(shape)
|
output_shape = copy.copy(shape)
|
||||||
|
|
@ -5807,16 +5807,16 @@ class TestTorch(TestCase):
|
||||||
self.assertTrue(torch.allclose(output, expected))
|
self.assertTrue(torch.allclose(output, expected))
|
||||||
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "Expected `dim` to be in range -3 to 2"):
|
with self.assertRaisesRegex(RuntimeError, "Expected `dim` to be in range -3 to 2"):
|
||||||
torch.scatter_reduce(input, 4, index, "sum")
|
torch._scatter_reduce(input, 4, index, "sum")
|
||||||
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "Shape mismatch"):
|
with self.assertRaisesRegex(RuntimeError, "Shape mismatch"):
|
||||||
index2 = torch.randint(0, output_size, (10, ), dtype=torch.long, device=device)
|
index2 = torch.randint(0, output_size, (10, ), dtype=torch.long, device=device)
|
||||||
torch.scatter_reduce(input, 0, index2, "sum")
|
torch._scatter_reduce(input, 0, index2, "sum")
|
||||||
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "Expected `index` values to be in range 0 to 2"):
|
with self.assertRaisesRegex(RuntimeError, "Expected `index` values to be in range 0 to 2"):
|
||||||
input2 = torch.randn(10, dtype=dtype, device=device)
|
input2 = torch.randn(10, dtype=dtype, device=device)
|
||||||
index2 = torch.tensor([0, 1, 0, 1, 2, 3, 3, 4, 4, 3])
|
index2 = torch.tensor([0, 1, 0, 1, 2, 3, 3, 4, 4, 3])
|
||||||
torch.scatter_reduce(input2, 0, index2, "sum", output_size=2)
|
torch._scatter_reduce(input2, 0, index2, "sum", output_size=2)
|
||||||
|
|
||||||
def test_structseq_repr(self):
|
def test_structseq_repr(self):
|
||||||
a = torch.arange(250).reshape(5, 5, 10)
|
a = torch.arange(250).reshape(5, 5, 10)
|
||||||
|
|
|
||||||
|
|
@ -2595,6 +2595,6 @@
|
||||||
- name: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
- name: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||||
output_differentiability: [False]
|
output_differentiability: [False]
|
||||||
|
|
||||||
- name: scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
|
- name: _scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
|
||||||
self: scatter_reduce_backward(grad, self, dim, index, reduce, result)
|
self: scatter_reduce_backward(grad, self, dim, index, reduce, result)
|
||||||
index: non_differentiable
|
index: non_differentiable
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,6 @@ import warnings
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..overrides import (
|
|
||||||
has_torch_function_variadic,
|
|
||||||
handle_torch_function)
|
|
||||||
|
|
||||||
# These no_grad_* functions are necessary as wrappers around the parts of these
|
# These no_grad_* functions are necessary as wrappers around the parts of these
|
||||||
# functions that use `with torch.no_grad()`. The JIT doesn't support context
|
# functions that use `with torch.no_grad()`. The JIT doesn't support context
|
||||||
|
|
@ -135,8 +132,8 @@ def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
|
||||||
>>> w = torch.empty(3, 5)
|
>>> w = torch.empty(3, 5)
|
||||||
>>> nn.init.uniform_(w)
|
>>> nn.init.uniform_(w)
|
||||||
"""
|
"""
|
||||||
if has_torch_function_variadic(tensor):
|
if torch.overrides.has_torch_function_variadic(tensor):
|
||||||
return handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
|
return torch.overrides.handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
|
||||||
return _no_grad_uniform_(tensor, a, b)
|
return _no_grad_uniform_(tensor, a, b)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -153,8 +150,8 @@ def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
|
||||||
>>> w = torch.empty(3, 5)
|
>>> w = torch.empty(3, 5)
|
||||||
>>> nn.init.normal_(w)
|
>>> nn.init.normal_(w)
|
||||||
"""
|
"""
|
||||||
if has_torch_function_variadic(tensor):
|
if torch.overrides.has_torch_function_variadic(tensor):
|
||||||
return handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
|
return torch.overrides.handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
|
||||||
return _no_grad_normal_(tensor, mean, std)
|
return _no_grad_normal_(tensor, mean, std)
|
||||||
|
|
||||||
def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor:
|
def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor:
|
||||||
|
|
@ -190,8 +187,8 @@ def constant_(tensor: Tensor, val: float) -> Tensor:
|
||||||
>>> w = torch.empty(3, 5)
|
>>> w = torch.empty(3, 5)
|
||||||
>>> nn.init.constant_(w, 0.3)
|
>>> nn.init.constant_(w, 0.3)
|
||||||
"""
|
"""
|
||||||
if has_torch_function_variadic(tensor):
|
if torch.overrides.has_torch_function_variadic(tensor):
|
||||||
return handle_torch_function(constant_, (tensor,), tensor=tensor, val=val)
|
return torch.overrides.handle_torch_function(constant_, (tensor,), tensor=tensor, val=val)
|
||||||
return _no_grad_fill_(tensor, val)
|
return _no_grad_fill_(tensor, val)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -393,8 +390,14 @@ def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||||
>>> w = torch.empty(3, 5)
|
>>> w = torch.empty(3, 5)
|
||||||
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
|
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
|
||||||
"""
|
"""
|
||||||
if has_torch_function_variadic(tensor):
|
if torch.overrides.has_torch_function_variadic(tensor):
|
||||||
return handle_torch_function(kaiming_uniform_, (tensor,), tensor=tensor, a=a, mode=mode, nonlinearity=nonlinearity)
|
return torch.overrides.handle_torch_function(
|
||||||
|
kaiming_uniform_,
|
||||||
|
(tensor,),
|
||||||
|
tensor=tensor,
|
||||||
|
a=a,
|
||||||
|
mode=mode,
|
||||||
|
nonlinearity=nonlinearity)
|
||||||
|
|
||||||
if 0 in tensor.shape:
|
if 0 in tensor.shape:
|
||||||
warnings.warn("Initializing zero-element tensors is a no-op")
|
warnings.warn("Initializing zero-element tensors is a no-op")
|
||||||
|
|
|
||||||
|
|
@ -897,7 +897,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
|
torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
|
||||||
torch.scatter: lambda input, dim, index, src: -1,
|
torch.scatter: lambda input, dim, index, src: -1,
|
||||||
torch.scatter_add: lambda input, dim, index, src: -1,
|
torch.scatter_add: lambda input, dim, index, src: -1,
|
||||||
torch.scatter_reduce: lambda input, dim, index, reduce, output_size=None: -1,
|
torch._scatter_reduce: lambda input, dim, index, reduce, output_size=None: -1,
|
||||||
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
|
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
|
||||||
torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
|
torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
|
||||||
torch.select: lambda input, dim, index: -1,
|
torch.select: lambda input, dim, index: -1,
|
||||||
|
|
|
||||||
|
|
@ -15517,7 +15517,7 @@ op_db: List[OpInfo] = [
|
||||||
supports_fwgrad_bwgrad=True,
|
supports_fwgrad_bwgrad=True,
|
||||||
),
|
),
|
||||||
OpInfo(
|
OpInfo(
|
||||||
'scatter_reduce',
|
'_scatter_reduce',
|
||||||
dtypes=all_types_and(torch.float16, torch.bfloat16),
|
dtypes=all_types_and(torch.float16, torch.bfloat16),
|
||||||
sample_inputs_func=sample_inputs_scatter_reduce,
|
sample_inputs_func=sample_inputs_scatter_reduce,
|
||||||
supports_out=False,
|
supports_out=False,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user