mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[NestedTensor] Support binary pointwise ops with >2 inputs (if inputs are non-tensors) (#119419)
It should usually be safe to run pointwise binary ops with >2 inputs. e.g. threshold_backward(tensor, tensor, scalar): we just operate on the values of the nested tensors, and pass in the other args as-is. Pull Request resolved: https://github.com/pytorch/pytorch/pull/119419 Approved by: https://github.com/soulitzer
This commit is contained in:
parent
cd9a1934fb
commit
278a0e1600
|
|
@ -3349,6 +3349,21 @@ class TestNestedTensorSubclass(TestCase):
|
|||
t = torch.rand(t_size, requires_grad=True, device=device, dtype=torch.float64)
|
||||
gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False)
|
||||
|
||||
def test_threshold_backward(self, device):
|
||||
ts1 = self._get_list_for_jagged_tensor(((2, 3, 4), 16), device=device, requires_grad=False)
|
||||
ts2 = self._get_list_for_jagged_tensor(((2, 3, 4), 16), device=device, requires_grad=False)
|
||||
|
||||
nt1, offsets = jagged_from_list(ts1, None)
|
||||
nt2, offsets = jagged_from_list(ts2, offsets)
|
||||
buf1 = buffer_from_jagged(nt1).detach().clone()
|
||||
buf2 = buffer_from_jagged(nt2).detach().clone()
|
||||
|
||||
res_nt = torch.ops.aten.threshold_backward(nt1, nt2, 0.0)
|
||||
res_dense = torch.ops.aten.threshold_backward(buf1, buf2, 0.0)
|
||||
|
||||
self.assertEqual(res_dense, buffer_from_jagged(res_nt))
|
||||
|
||||
|
||||
@parametrize("keepdim", [False, True])
|
||||
def test_sum_int_DimList(self, device, keepdim):
|
||||
# (B, j0, 3, 4)
|
||||
|
|
|
|||
|
|
@ -194,7 +194,7 @@ def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
|
|||
check_schema("self: jt_all, ...", func, *args, **kwargs)
|
||||
return functools.partial(jagged_unary_pointwise, func)
|
||||
elif num_tensor_args == 2:
|
||||
check_schema("lhs: any, rhs: any", func, *args, **kwargs)
|
||||
check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs)
|
||||
return functools.partial(jagged_binary_pointwise, func)
|
||||
|
||||
return None
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user