[composite compliance] put, take (#81094)
Reference: #69991
This PR makes `put` CompositeExplicitAutograd, as it is implemented in terms of `put_` (for which we can't handle Composite Compliance at the implementation level).
Ref (put implementation): 478081c698/aten/src/ATen/native/TensorAdvancedIndexing.cpp (L619-L621)
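As a rough Python-level sketch of why this is enough (assuming the native `put` is just a clone followed by the in-place `put_`, per the lines referenced above; `put_sketch` is an illustrative helper, not PyTorch API):

import torch

def put_sketch(self, index, source, accumulate=False):
    # clone first so the input stays untouched, then reuse the in-place kernel
    return self.clone().put_(index, source, accumulate)

x = torch.zeros(2, 3)
index = torch.tensor([0, 4])
source = torch.tensor([1.0, 2.0])
assert torch.equal(put_sketch(x, index, source), x.put(index, source))

With the out-of-place op dispatched as a whole, it also gets its own derivative entry (the `put_` entry is renamed to `put` in the diff below) instead of relying on `put_`'s.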
Also, we update the `take` gradient formula to handle Tensor Subclasses.
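Concretely, `take`'s backward scatters the incoming gradient into a zeros tensor with accumulation (so repeated indices add up); the new `take_backward` helper (added in the C++ hunk below) only switches to the out-of-place `put` when `grad`/`indices` are subclass tensors while `self` is not. A small sanity check of the accumulation with plain tensors:

import torch

x = torch.arange(6.0, requires_grad=True)
index = torch.tensor([1, 1, 4])   # index 1 is taken twice
torch.take(x, index).sum().backward()
print(x.grad)  # tensor([0., 2., 0., 0., 1., 0.]) -- repeated index accumulates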
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81094
Approved by: https://github.com/zou3519
commit db0e121b46
parent d30784be31
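For reference, the backward formulas carried over to the new `put` derivative entry in the diff below zero the gradient at overwritten slots (when `accumulate=False`) and gather the source gradient with `take`; with plain tensors that behaves like:

import torch

x = torch.arange(6.0, requires_grad=True)
source = torch.tensor([10.0, 20.0], requires_grad=True)
index = torch.tensor([1, 4])
x.put(index, source).sum().backward()   # out-of-place put, accumulate=False
print(x.grad)       # tensor([1., 0., 1., 1., 0., 1.]) -- overwritten slots get zero
print(source.grad)  # tensor([1., 1.])                 -- i.e. grad.take(index)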
@@ -6471,6 +6471,8 @@
 - func: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor
   variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: put

 - func: index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
   structured: True

@@ -561,7 +561,6 @@ class TestOperators(TestCase):
         xfail('nanquantile'), # checks q via a .item() call
         xfail('nn.functional.gaussian_nll_loss'), # checks var for if any value < 0
         xfail('prod'), # calls nonzero
-        xfail('put'),
         xfail('quantile'), # checks q via a .item() call
         xfail('stft'),
         xfail('view_as_complex'),

@@ -651,7 +650,6 @@ class TestOperators(TestCase):
         skip('nn.functional.max_unpool2d'), # fails everywhere except on mac
         skip('nn.functional.max_unpool3d'), # fails everywhere except on mac

-        xfail('put'), # calls put_ during vmap with only vmaps over other, not self
         xfail('nn.functional.prelu'), # Call Tensor.as_strided

         # erroring because running_mean and running_var aren't differentiable

@@ -719,6 +717,7 @@ class TestOperators(TestCase):
         xfail('linalg.cholesky_ex'),
         xfail('masked_scatter'),
         xfail('index_fill'),
+        xfail('put'),
         xfail('take'),
         xfail('linalg.eigvals'),
         xfail('linalg.qr'),

@@ -1283,7 +1282,6 @@ class TestOperators(TestCase):
     @skipOps('TestOperators', 'test_vmap_autograd_grad', {
         # call inplace functions
         xfail('linalg.householder_product'), # inplace
-        xfail('take'), # inplace

         xfail('linalg.eig'), # all close?
         # The size of tensor a (4) must match the size of tensor b (10) at non-singleton dimension 0

@@ -1267,11 +1267,11 @@
   self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim)
   result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result, dim, keepdim) * self_t.conj()).sum(dim, keepdim).conj()

-- name: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!)
+- name: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor
   self: "accumulate ? grad : grad.put(index, zeros_like(source), false)"
   index: non_differentiable
   source: grad.take(index).reshape_as(source)
-  result: auto_linear # It is affine, but sure
+  result: self_t.put(index, source_t, accumulate)

 - name: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R)
   A: linalg_qr_backward(grad_Q, grad_R, Q, R, mode)

@@ -1575,7 +1575,7 @@
   result: auto_linear

 - name: take(Tensor self, Tensor index) -> Tensor
-  self: zeros_like(self).put_(index, grad, true)
+  self: take_backward(grad, self, index)
   index: non_differentiable
   result: auto_linear

@@ -288,6 +288,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     "replication_pad2d",
     "replication_pad3d",
     "take",
+    "put",
     "put_",
     "_to_copy",
     "replication_pad1d_backward",

@@ -6561,6 +6561,21 @@ std::tuple<Tensor, Tensor> index_reduce_backward(
   return std::make_tuple(grad_self, grad_src);
 }

+Tensor take_backward(
+    const Tensor& grad,
+    const Tensor& self,
+    const Tensor& indices) {
+  Tensor grad_self = at::zeros_like(self);
+  // For Composite Compliance,
+  // if `grad` and `indices` are CCT but `self` is not
+  // then we use the out-of-place variant of `put`.
+  if (!isTensorSubclassLike(self) &&
+      areAnyTensorSubclassLike({grad, indices})) {
+    return grad_self.put(indices, grad, true);
+  }
+  return grad_self.put_(indices, grad, true);
+}
+
 } // namespace details
 } // namespace generated
 } // namespace autograd

@@ -996,6 +996,11 @@ std::tuple<Tensor, Tensor> index_reduce_backward(
     bool include_self,
     const Tensor& result);

+Tensor take_backward(
+    const Tensor& grad,
+    const Tensor& self,
+    const Tensor& indices);
+
 } // namespace details
 } // namespace generated
 } // namespace autograd

@@ -17044,12 +17044,6 @@ op_db: List[OpInfo] = [
            supports_fwgrad_bwgrad=True,
            check_batched_forward_grad=False,
            check_batched_gradgrad=False, # vmap complains of the sizes
-           skips=(
-               # Problem, needs to be fixed
-               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
-               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
-               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
-           ),
            sample_inputs_func=sample_inputs_put),
     OpInfo('take',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),

@@ -17057,9 +17051,6 @@ op_db: List[OpInfo] = [
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            sample_inputs_func=sample_inputs_take,
-           skips=(
-               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
-           ),
            error_inputs_func=error_inputs_take),
     OpInfo('scatter',
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),