fix searchsorted output type (#42933)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/41389
Make sure that searchsorted, which returns an integer tensor, does not mark its output as requiring gradients.
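
A minimal repro of the underlying issue (a sketch, assuming a build without
this fix; the tensor names `boundaries` and `values` are illustrative):

    import torch

    # Both inputs are floating point and require grad, so autograd would,
    # before this fix, propagate requires_grad to the integer output.
    boundaries = torch.tensor([1., 3., 5., 7., 9.], requires_grad=True)
    values = torch.tensor([2., 6.], requires_grad=True)

    out = torch.searchsorted(boundaries, values)
    print(out.dtype)          # torch.int64: bucket indices, not differentiable
    print(out.requires_grad)  # False with this fix; incorrectly True before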

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42933

Reviewed By: gchanan

Differential Revision: D23109583

Pulled By: albanD

fbshipit-source-id: 5af300b2f7f3c140d39fd7f7d87799f7b93a79c1
albanD, 2020-08-14 12:33:00 -07:00 (committed by Facebook GitHub Bot)
parent 059aa34b12
commit 1f6d0985d7
2 changed files with 15 additions and 1 deletion

test/test_autograd.py

@@ -4434,6 +4434,20 @@ for shape in [(1,), ()]:
             self.assertFalse(out.dtype.is_floating_point)
             self.assertFalse(out.requires_grad)
+
+            out = inp.argmin()
+            self.assertFalse(out.dtype.is_floating_point)
+            self.assertFalse(out.requires_grad)
+
+            out = inp.argsort()
+            self.assertFalse(out.dtype.is_floating_point)
+            self.assertFalse(out.requires_grad)
+
+            val = torch.rand((), requires_grad=True)
+
+            out = torch.searchsorted(inp, val)
+            self.assertFalse(out.dtype.is_floating_point)
+            self.assertFalse(out.requires_grad)
 
 
 def index_variable(shape, max_indices):
     if not isinstance(shape, tuple):

tools/autograd/gen_variable_type.py

@@ -138,7 +138,7 @@ DONT_REQUIRE_DERIVATIVE = {
     # Quantize functions should not record gradients
     'quantize_per_tensor', 'quantize_per_channel',
     # Functions that return integers should not have output that require gradients
-    'argmax', 'argmin', 'argsort',
+    'argmax', 'argmin', 'argsort', 'searchsorted'
 }
 
 # Some operators invalidate the grad_accumulator. Let's reset it.
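
For context, DONT_REQUIRE_DERIVATIVE is a skip list consulted during autograd
code generation. A minimal sketch of how such a list gates the requires_grad
decision (the helper name `output_can_require_grad` is illustrative, not the
actual gen_variable_type.py logic):

    # Hypothetical sketch, not the real codegen: ops that return integer
    # tensors can never carry gradients, so they are excluded up front.
    DONT_REQUIRE_DERIVATIVE = {
        'argmax', 'argmin', 'argsort', 'searchsorted',
    }

    def output_can_require_grad(op_name: str) -> bool:
        # 'searchsorted' now short-circuits here, so its int64 output is
        # never marked with requires_grad=True.
        return op_name not in DONT_REQUIRE_DERIVATIVE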