[composite compliance] put, take (#81094)

Reference: #69991

This PR makes `put` CompositeExplicitAutograd as it is implemented in terms of `put_` (for which we can't handle Composite Compliance at the implementation level).

Ref (put implementation)
478081c698/aten/src/ATen/native/TensorAdvancedIndexing.cpp (L619-L621)

Also, we update the `take` gradient formula to handle Tensor Subclasses.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81094
Approved by: https://github.com/zou3519
This commit is contained in:
Kshiteej K 2022-07-25 15:05:13 +00:00 committed by PyTorch MergeBot
parent d30784be31
commit db0e121b46
7 changed files with 27 additions and 15 deletions

View File

@@ -6471,6 +6471,8 @@
- func: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor - func: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor
variants: function, method variants: function, method
dispatch:
CompositeExplicitAutograd: put
- func: index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) - func: index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
structured: True structured: True

View File

@@ -561,7 +561,6 @@ class TestOperators(TestCase):
xfail('nanquantile'), # checks q via a .item() call xfail('nanquantile'), # checks q via a .item() call
xfail('nn.functional.gaussian_nll_loss'), # checks var for if any value < 0 xfail('nn.functional.gaussian_nll_loss'), # checks var for if any value < 0
xfail('prod'), # calls nonzero xfail('prod'), # calls nonzero
xfail('put'),
xfail('quantile'), # checks q via a .item() call xfail('quantile'), # checks q via a .item() call
xfail('stft'), xfail('stft'),
xfail('view_as_complex'), xfail('view_as_complex'),
@@ -651,7 +650,6 @@ class TestOperators(TestCase):
skip('nn.functional.max_unpool2d'), # fails everywhere except on mac skip('nn.functional.max_unpool2d'), # fails everywhere except on mac
skip('nn.functional.max_unpool3d'), # fails everywhere except on mac skip('nn.functional.max_unpool3d'), # fails everywhere except on mac
xfail('put'), # calls put_ during vmap with only vmaps over other, not self
xfail('nn.functional.prelu'), # Call Tensor.as_strided xfail('nn.functional.prelu'), # Call Tensor.as_strided
# erroring because running_mean and running_var aren't differentiable # erroring because running_mean and running_var aren't differentiable
@@ -719,6 +717,7 @@ class TestOperators(TestCase):
xfail('linalg.cholesky_ex'), xfail('linalg.cholesky_ex'),
xfail('masked_scatter'), xfail('masked_scatter'),
xfail('index_fill'), xfail('index_fill'),
xfail('put'),
xfail('take'), xfail('take'),
xfail('linalg.eigvals'), xfail('linalg.eigvals'),
xfail('linalg.qr'), xfail('linalg.qr'),
@@ -1283,7 +1282,6 @@ class TestOperators(TestCase):
@skipOps('TestOperators', 'test_vmap_autograd_grad', { @skipOps('TestOperators', 'test_vmap_autograd_grad', {
# call inplace functions # call inplace functions
xfail('linalg.householder_product'), # inplace xfail('linalg.householder_product'), # inplace
xfail('take'), # inplace
xfail('linalg.eig'), # all close? xfail('linalg.eig'), # all close?
# The size of tensor a (4) must match the size of tensor b (10) at non-singleton dimension 0 # The size of tensor a (4) must match the size of tensor b (10) at non-singleton dimension 0

View File

@@ -1267,11 +1267,11 @@
self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim) self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim)
result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result, dim, keepdim) * self_t.conj()).sum(dim, keepdim).conj() result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result, dim, keepdim) * self_t.conj()).sum(dim, keepdim).conj()
- name: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) - name: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor
self: "accumulate ? grad : grad.put(index, zeros_like(source), false)" self: "accumulate ? grad : grad.put(index, zeros_like(source), false)"
index: non_differentiable index: non_differentiable
source: grad.take(index).reshape_as(source) source: grad.take(index).reshape_as(source)
result: auto_linear # It is affine, but sure result: self_t.put(index, source_t, accumulate)
- name: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R) - name: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R)
A: linalg_qr_backward(grad_Q, grad_R, Q, R, mode) A: linalg_qr_backward(grad_Q, grad_R, Q, R, mode)
@@ -1575,7 +1575,7 @@
result: auto_linear result: auto_linear
- name: take(Tensor self, Tensor index) -> Tensor - name: take(Tensor self, Tensor index) -> Tensor
self: zeros_like(self).put_(index, grad, true) self: take_backward(grad, self, index)
index: non_differentiable index: non_differentiable
result: auto_linear result: auto_linear

View File

@@ -288,6 +288,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
"replication_pad2d", "replication_pad2d",
"replication_pad3d", "replication_pad3d",
"take", "take",
"put",
"put_", "put_",
"_to_copy", "_to_copy",
"replication_pad1d_backward", "replication_pad1d_backward",

View File

@@ -6561,6 +6561,21 @@ std::tuple<Tensor, Tensor> index_reduce_backward(
return std::make_tuple(grad_self, grad_src); return std::make_tuple(grad_self, grad_src);
} }
// Gradient of `take` w.r.t. `self`: scatter `grad` into a zeros tensor
// shaped like `self` at positions `indices`, passing accumulate=true to
// `put`/`put_`. `indices` itself is non-differentiable.
Tensor take_backward(
const Tensor& grad,
const Tensor& self,
const Tensor& indices) {
Tensor grad_self = at::zeros_like(self);
// For Composite Compliance,
// if `grad` and `indices` are CCT (Composite Compliant Tensors, i.e.
// Tensor subclasses) but `self` is not,
// then we use the out-of-place variant of `put`
// (writing in-place into a plain-tensor buffer would bypass the subclass).
if (!isTensorSubclassLike(self) &&
areAnyTensorSubclassLike({grad, indices})) {
return grad_self.put(indices, grad, true);
}
return grad_self.put_(indices, grad, true);
}
} // namespace details } // namespace details
} // namespace generated } // namespace generated
} // namespace autograd } // namespace autograd

View File

@@ -996,6 +996,11 @@ std::tuple<Tensor, Tensor> index_reduce_backward(
bool include_self, bool include_self,
const Tensor& result); const Tensor& result);
// Gradient of `take` w.r.t. `self`; handles Tensor-subclass inputs
// (Composite Compliance) in its definition.
Tensor take_backward(
const Tensor& grad,
const Tensor& self,
const Tensor& indices);
} // namespace details } // namespace details
} // namespace generated } // namespace generated
} // namespace autograd } // namespace autograd

View File

@@ -17044,12 +17044,6 @@ op_db: List[OpInfo] = [
supports_fwgrad_bwgrad=True, supports_fwgrad_bwgrad=True,
check_batched_forward_grad=False, check_batched_forward_grad=False,
check_batched_gradgrad=False, # vmap complains of the sizes check_batched_gradgrad=False, # vmap complains of the sizes
skips=(
# Problem, needs to be fixed
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
),
sample_inputs_func=sample_inputs_put), sample_inputs_func=sample_inputs_put),
OpInfo('take', OpInfo('take',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
@@ -17057,9 +17051,6 @@ op_db: List[OpInfo] = [
supports_forward_ad=True, supports_forward_ad=True,
supports_fwgrad_bwgrad=True, supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_take, sample_inputs_func=sample_inputs_take,
skips=(
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
),
error_inputs_func=error_inputs_take), error_inputs_func=error_inputs_take),
OpInfo('scatter', OpInfo('scatter',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),