diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index e6af8d088b1..355ae719eb1 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -3349,6 +3349,21 @@ class TestNestedTensorSubclass(TestCase): t = torch.rand(t_size, requires_grad=True, device=device, dtype=torch.float64) gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False) + def test_threshold_backward(self, device): + ts1 = self._get_list_for_jagged_tensor(((2, 3, 4), 16), device=device, requires_grad=False) + ts2 = self._get_list_for_jagged_tensor(((2, 3, 4), 16), device=device, requires_grad=False) + + nt1, offsets = jagged_from_list(ts1, None) + nt2, offsets = jagged_from_list(ts2, offsets) + buf1 = buffer_from_jagged(nt1).detach().clone() + buf2 = buffer_from_jagged(nt2).detach().clone() + + res_nt = torch.ops.aten.threshold_backward(nt1, nt2, 0.0) + res_dense = torch.ops.aten.threshold_backward(buf1, buf2, 0.0) + + self.assertEqual(res_dense, buffer_from_jagged(res_nt)) + + @parametrize("keepdim", [False, True]) def test_sum_int_DimList(self, device, keepdim): # (B, j0, 3, 4) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index c0e060d7e87..5a5accafc98 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -194,7 +194,7 @@ def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]: check_schema("self: jt_all, ...", func, *args, **kwargs) return functools.partial(jagged_unary_pointwise, func) elif num_tensor_args == 2: - check_schema("lhs: any, rhs: any", func, *args, **kwargs) + check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs) return functools.partial(jagged_binary_pointwise, func) return None