[NestedTensor] Support binary pointwise ops with >2 inputs (if inputs are non-tensors) (#119419)

It should usually be safe to run pointwise binary ops with >2 inputs. e.g. threshold_backward(tensor, tensor, scalar): we just operate on the values of the nested tensors, and pass in the other args as-is.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119419
Approved by: https://github.com/soulitzer
This commit is contained in:
David Berard 2024-02-07 15:04:36 -08:00 committed by PyTorch MergeBot
parent cd9a1934fb
commit 278a0e1600
2 changed files with 16 additions and 1 deletions

View File

@ -3349,6 +3349,21 @@ class TestNestedTensorSubclass(TestCase):
t = torch.rand(t_size, requires_grad=True, device=device, dtype=torch.float64) t = torch.rand(t_size, requires_grad=True, device=device, dtype=torch.float64)
gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False) gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False)
def test_threshold_backward(self, device):
ts1 = self._get_list_for_jagged_tensor(((2, 3, 4), 16), device=device, requires_grad=False)
ts2 = self._get_list_for_jagged_tensor(((2, 3, 4), 16), device=device, requires_grad=False)
nt1, offsets = jagged_from_list(ts1, None)
nt2, offsets = jagged_from_list(ts2, offsets)
buf1 = buffer_from_jagged(nt1).detach().clone()
buf2 = buffer_from_jagged(nt2).detach().clone()
res_nt = torch.ops.aten.threshold_backward(nt1, nt2, 0.0)
res_dense = torch.ops.aten.threshold_backward(buf1, buf2, 0.0)
self.assertEqual(res_dense, buffer_from_jagged(res_nt))
@parametrize("keepdim", [False, True]) @parametrize("keepdim", [False, True])
def test_sum_int_DimList(self, device, keepdim): def test_sum_int_DimList(self, device, keepdim):
# (B, j0, 3, 4) # (B, j0, 3, 4)

View File

@ -194,7 +194,7 @@ def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
check_schema("self: jt_all, ...", func, *args, **kwargs) check_schema("self: jt_all, ...", func, *args, **kwargs)
return functools.partial(jagged_unary_pointwise, func) return functools.partial(jagged_unary_pointwise, func)
elif num_tensor_args == 2: elif num_tensor_args == 2:
check_schema("lhs: any, rhs: any", func, *args, **kwargs) check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs)
return functools.partial(jagged_binary_pointwise, func) return functools.partial(jagged_binary_pointwise, func)
return None return None