mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
fix searchsorted output type (#42933)
Summary: Fixes https://github.com/pytorch/pytorch/issues/41389 Make sure searchsorted that returns integer type does not make them require gradients. Pull Request resolved: https://github.com/pytorch/pytorch/pull/42933 Reviewed By: gchanan Differential Revision: D23109583 Pulled By: albanD fbshipit-source-id: 5af300b2f7f3c140d39fd7f7d87799f7b93a79c1
This commit is contained in:
parent
059aa34b12
commit
1f6d0985d7
|
|
@ -4434,6 +4434,20 @@ for shape in [(1,), ()]:
|
|||
self.assertFalse(out.dtype.is_floating_point)
|
||||
self.assertFalse(out.requires_grad)
|
||||
|
||||
out = inp.argmin()
|
||||
self.assertFalse(out.dtype.is_floating_point)
|
||||
self.assertFalse(out.requires_grad)
|
||||
|
||||
out = inp.argsort()
|
||||
self.assertFalse(out.dtype.is_floating_point)
|
||||
self.assertFalse(out.requires_grad)
|
||||
|
||||
val = torch.rand((), requires_grad=True)
|
||||
|
||||
out = torch.searchsorted(inp, val)
|
||||
self.assertFalse(out.dtype.is_floating_point)
|
||||
self.assertFalse(out.requires_grad)
|
||||
|
||||
|
||||
def index_variable(shape, max_indices):
|
||||
if not isinstance(shape, tuple):
|
||||
|
|
|
|||
|
|
@ -138,7 +138,7 @@ DONT_REQUIRE_DERIVATIVE = {
|
|||
# Quantize functions should not record gradients
|
||||
'quantize_per_tensor', 'quantize_per_channel',
|
||||
# Functions that return integers should not have output that require gradients
|
||||
'argmax', 'argmin', 'argsort',
|
||||
'argmax', 'argmin', 'argsort', 'searchsorted'
|
||||
}
|
||||
|
||||
# Some operators invalidate the grad_accumulator. Let's reset it.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user