Add forward AD formulas for {adaptive_,fractional_,}max_pool{2,3}d_{backward,} (#69884)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69884

Also fixes: https://github.com/pytorch/pytorch/issues/69322, https://github.com/pytorch/pytorch/issues/69325

Test Plan: Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D33093039
Pulled By: soulitzer
fbshipit-source-id: b9a522a00f4e9e85974888de5058de07280f8f66

This commit is contained in:
parent 6925576e88
commit 3116d87024
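The core of the change is the new `result0:` entries in the derivatives.yaml hunks below: the forward-mode tangent of a max-pooling output is the input tangent gathered at the returned argmax indices, mirroring `gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1)`. A minimal standalone sketch of that identity in plain PyTorch (not the generated kernels; the finite-difference check assumes eps is small enough that no argmax changes):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 4, 4, dtype=torch.double)
t = torch.randn_like(x)  # tangent: the direction in which x is perturbed

out, idx = F.max_pool2d(x, 2, stride=2, return_indices=True)

# Forward-mode tangent of max pooling: the output moves exactly as the selected
# input elements move, so gather the tangent at the argmax indices.
out_tangent = t.flatten(-2).gather(-1, idx.flatten(-2)).view_as(idx)

# Finite-difference cross-check of the same directional derivative.
eps = 1e-6
out_eps = F.max_pool2d(x + eps * t, 2, stride=2)
print(torch.allclose((out_eps - out) / eps, out_tangent))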
test/test_nn.py
@@ -11941,18 +11941,21 @@ class TestNNInit(TestCase):
        self.assertEqual(F.max_unpool1d(output, indices, 2), F.max_unpool1d(output, indices, 2, stride=2))

        # Test list / tuple passed as argument to max_unpool1d
        input = torch.randn([1, 1, 5])
        input = torch.randn([1, 1, 5], requires_grad=True)
        output, indices = F.max_pool1d(input, 2, stride=2, return_indices=True)
        self.assertEqual(F.max_unpool1d(output, indices, 2, stride=2, output_size=input.shape),
                         F.max_unpool1d(output, indices, 2, stride=2, output_size=input.size()))
        gradcheck(F.max_unpool1d, (output, indices, 2), check_forward_ad=True)

        # Test 2D
        output, indices = F.max_pool2d(torch.randn([1, 1, 4, 4]), 2, stride=2, return_indices=True)
        output, indices = F.max_pool2d(torch.randn([1, 1, 4, 4], requires_grad=True), 2, stride=2, return_indices=True)
        self.assertEqual(F.max_unpool2d(output, indices, 2), F.max_unpool2d(output, indices, 2, stride=2))
        gradcheck(F.max_unpool2d, (output, indices, 2), check_forward_ad=True)

        # Test 3D
        output, indices = F.max_pool3d(torch.randn([4, 4, 4, 4, 4]), 2, stride=2, return_indices=True)
        output, indices = F.max_pool3d(torch.randn([4, 4, 4, 4, 4], requires_grad=True), 2, stride=2, return_indices=True)
        self.assertEqual(F.max_unpool3d(output, indices, 2), F.max_unpool3d(output, indices, 2, stride=2))
        gradcheck(F.max_unpool3d, (output, indices, 2), check_forward_ad=True)

    def test_dirac_properties(self):
        for dims in [3, 4, 5]:
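For context on what `check_forward_ad=True` exercises here: gradcheck pushes dual tensors (primal plus tangent) through the op and compares the forward-mode result against numerical derivatives. A minimal sketch of the same mechanism outside of gradcheck, assuming a build that includes the formulas added in this commit (shapes are illustrative):

import torch
import torch.nn.functional as F
import torch.autograd.forward_ad as fwAD

x = torch.randn(1, 1, 8, dtype=torch.double)
pooled, indices = F.max_pool1d(x, 2, stride=2, return_indices=True)
tangent = torch.randn_like(pooled)

# Push a dual tensor (primal + tangent) through max_unpool1d.
with fwAD.dual_level():
    dual_out = F.max_unpool1d(fwAD.make_dual(pooled, tangent), indices, 2)
    _, out_tangent = fwAD.unpack_dual(dual_out)

# max_unpool is linear in its input (`result: auto_linear` below), so the
# output tangent is just the tangent unpooled with the same indices.
print(torch.allclose(out_tangent, F.max_unpool1d(tangent, indices, 2)))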
tools/autograd/derivatives.yaml
@@ -1962,9 +1962,13 @@
- name: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
  self: adaptive_max_pool2d_backward(grad, self, result1)
  result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1)
  output_differentiability: [True, False]

- name: adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)
  self: adaptive_max_pool3d_backward(grad, self, result1)
  result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1)
  output_differentiability: [True, False]

- name: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor
  self: avg_pool2d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
@@ -1976,25 +1980,33 @@
- name: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor)
  self: fractional_max_pool2d_backward(grad, self, kernel_size, output_size, result1)
  result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1)
  output_differentiability: [True, False]

- name: fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor)
  self: fractional_max_pool3d_backward(grad, self, kernel_size, output_size, result1)
  result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1)
  output_differentiability: [True, False]

- name: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
  self: max_pool2d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1)
  result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1)
  output_differentiability: [True, False]

- name: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
  self: max_pool3d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1)
  result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1)
  output_differentiability: [True, False]

- name: max_unpool2d(Tensor self, Tensor indices, int[2] output_size) -> Tensor
  self: max_unpool2d_backward(grad, self, indices, output_size)
  indices: non_differentiable
  result: auto_linear

- name: max_unpool3d(Tensor self, Tensor indices, int[3] output_size, int[3] stride, int[3] padding) -> Tensor
  self: max_unpool3d_backward(grad, self, indices, output_size, stride, padding)
  indices: non_differentiable
  result: auto_linear

- name: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor
  input, weight, bias: "grad.defined() ? convolution_backward(grad, input, weight, bias->sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
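The `result0`/`result` entries above are what `torch.autograd.gradcheck(..., check_forward_ad=True)` validates against numerical Jacobians. A minimal sketch for one of the affected ops (illustrative shapes; double precision keeps the finite-difference comparison tight):

import torch
import torch.nn.functional as F
from torch.autograd import gradcheck

x = torch.randn(1, 1, 6, 6, dtype=torch.double, requires_grad=True)

# Checks reverse mode as before and, with check_forward_ad=True, also the
# forward-mode formula added for adaptive_max_pool2d.
assert gradcheck(lambda inp: F.adaptive_max_pool2d(inp, 3), (x,), check_forward_ad=True)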
@@ -2086,10 +2098,12 @@
- name: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
  grad_output: max_pool_double_backward(grad, indices, 2)
  self: zeros_like(self)
  result: auto_linear

- name: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
  grad_output: max_pool_double_backward(grad, indices, 3)
  self: zeros_like(self)
  result: auto_linear

- name: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor
  grad_output: avg_pool2d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
@@ -2108,10 +2122,12 @@
- name: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor
  grad_output: max_pool_double_backward(grad, indices, 2)
  self: zeros_like(self)
  result: auto_linear

- name: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor
  grad_output: max_pool_double_backward(grad, indices, 3)
  self: zeros_like(self)
  result: auto_linear

- name: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor
  grad_output: glu_double_backward_grad_output(grad, self, dim)
@@ -2148,11 +2164,13 @@
  grad_output: max_pool_double_backward(grad, indices, 2)
  self: zeros_like(self)
  indices: non_differentiable
  result: auto_linear

- name: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor
  grad_output: max_pool_double_backward(grad, indices, 3)
  self: zeros_like(self)
  indices: non_differentiable
  result: auto_linear

- name: max_unpool2d_backward(Tensor grad_output, Tensor self, Tensor indices, int[2] output_size) -> Tensor
  grad_output: max_unpool2d(grad, indices, output_size)
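`result: auto_linear` on these backward functions says the op is linear in its differentiable input, so its forward-mode derivative can be obtained by re-running the op on the tangent. A small sketch that checks this linearity for the max-pool backward (plain autograd, illustrative shapes):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 4, 4, dtype=torch.double, requires_grad=True)
out = F.max_pool2d(x, 2, stride=2)

def pool_vjp(grad_output):
    # Reverse mode scatters grad_output back to the argmax locations.
    (gx,) = torch.autograd.grad(out, x, grad_output, retain_graph=True)
    return gx

g1, g2 = torch.randn_like(out), torch.randn_like(out)
a, b = 2.0, -3.0

# Linearity in grad_output is exactly what `result: auto_linear` relies on.
print(torch.allclose(pool_vjp(a * g1 + b * g2), a * pool_vjp(g1) + b * pool_vjp(g2)))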
torch/testing/_internal/common_methods_invocations.py
@@ -10775,6 +10775,10 @@ op_db: List[OpInfo] = [
           dtypes=floating_types(),
           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           # got: Batching rule not implemented for aten::flatten.using_ints
           check_batched_forward_grad=False,
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
           sample_inputs_func=sample_inputs_adaptive_max_pool1d),
    OpInfo('nn.functional.adaptive_max_pool2d',
@@ -10792,6 +10796,10 @@ op_db: List[OpInfo] = [
               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
           ),
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           # got: Batching rule not implemented for aten::flatten.using_ints
           check_batched_forward_grad=False,
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
           sample_inputs_func=sample_inputs_adaptive_max_pool2d),
    OpInfo('nn.functional.adaptive_max_pool3d',
@@ -10811,6 +10819,10 @@ op_db: List[OpInfo] = [
               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
           ),
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           # got: Batching rule not implemented for aten::flatten.using_ints
           check_batched_forward_grad=False,
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
           sample_inputs_func=sample_inputs_adaptive_max_pool3d),
    OpInfo('nn.functional.avg_pool1d',
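`supports_fwgrad_bwgrad=True` declares that forward-mode AD can be run through the op's backward pass, which is what the new `result:` entries on the `*_backward` functions enable. A hand-rolled sketch of that check, assuming the op's double backward exists (it does for these pools); gradcheck's own forward-over-reverse test works along the same lines:

import torch
import torch.nn.functional as F
from torch.autograd import gradcheck

x = torch.randn(1, 1, 4, 4, dtype=torch.double, requires_grad=True)
out = F.max_pool2d(x, 2, stride=2)

def vjp(grad_output):
    # The VJP of max_pool2d as a function of grad_output; create_graph=True
    # keeps it differentiable so gradcheck can probe it.
    (gx,) = torch.autograd.grad(out, x, grad_output, create_graph=True)
    return gx

g = torch.randn_like(out, requires_grad=True)
# Forward AD through the backward: this exercises `result: auto_linear` on
# max_pool2d_with_indices_backward.
assert gradcheck(vjp, (g,), check_forward_ad=True)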
@@ -11201,49 +11213,54 @@ op_db: List[OpInfo] = [
    OpInfo('nn.functional.fractional_max_pool2d',
           supports_autograd=True,
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           op=lambda input, *args, **kwargs:
               wrapper_set_seed(torch.nn.functional.fractional_max_pool2d, input, *args, **kwargs),
           # vmap does not support random operations
           check_batched_forward_grad=False,
           dtypes=floating_types(),
           dtypesIfCUDA=floating_types_and(torch.float16),
           test_neg_view=False,
           sample_inputs_func=sample_inputs_fractional_max_pool2d,
           decorators=[
               # FIXME: both derivatives are implemented incorrectly
               # https://github.com/pytorch/pytorch/issues/69322
               DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_grad'),
               DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_gradgrad'),
               # FIXME: produces incorrect output on non-contiguous inputs
               # https://github.com/pytorch/pytorch/issues/69325
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
           decorators=(
               # FIXME: AssertionError: False is not true : Tensors failed to compare as equal!
               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
           ], ),
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'))),
    OpInfo('nn.functional.fractional_max_pool3d',
           supports_autograd=True,
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           op=lambda input, *args, **kwargs:
               wrapper_set_seed(torch.nn.functional.fractional_max_pool3d, input, *args, **kwargs),
           # vmap does not support random operations
           check_batched_forward_grad=False,
           dtypes=floating_types(),
           dtypesIfCUDA=floating_types_and(torch.float16),
           test_neg_view=False,
           sample_inputs_func=sample_inputs_fractional_max_pool3d,
           decorators=[
           decorators=(
               # FIXME: both derivatives are implemented incorrectly
               # https://github.com/pytorch/pytorch/issues/69322
               DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_grad'),
               # RuntimeError: cannot reshape tensor of 0 elements into shape [0, 1, -1] because the
               # unspecified dimension size -1 can be any value and is ambiguous
               DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_gradgrad'),
               # FIXME: produces incorrect output on non-contiguous inputs
               # https://github.com/pytorch/pytorch/issues/69325
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
               # FIXME: AssertionError: False is not true : Tensors failed to compare as equal!
               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
           ], ),
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),)),
    OpInfo('nn.functional.max_pool1d',
           aten_name='max_pool1d',
           supports_autograd=True,
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           # got: Batching rule not implemented for aten::flatten.using_ints
           check_batched_forward_grad=False,
           # TODO: add shape checks
           assert_jit_shape_analysis=False,
           dtypes=floating_types(),
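The lambdas above go through `wrapper_set_seed` because fractional max pooling draws random pooling regions, which is also why `check_batched_forward_grad=False` is set (per the comment, vmap does not support random operations). Outside the test suite, passing explicit samples gives the same determinism; a small sketch with illustrative shapes:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 8, 8, dtype=torch.double, requires_grad=True)
# One (h, w) sample in [0, 1) per (batch, channel) plane fixes the pooling regions.
samples = torch.rand(1, 1, 2, dtype=torch.double)

out1 = F.fractional_max_pool2d(x, kernel_size=2, output_size=(4, 4), _random_samples=samples)
out2 = F.fractional_max_pool2d(x, kernel_size=2, output_size=(4, 4), _random_samples=samples)
print(torch.equal(out1, out2))  # True: identical regions, identical outputs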
@@ -11259,6 +11276,10 @@ op_db: List[OpInfo] = [
           # Vmap is not happy with non-contiguous (channels_last) inputs
           check_batched_gradgrad=False,
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           # got: Batching rule not implemented for aten::flatten.using_ints
           check_batched_forward_grad=False,
           assert_jit_shape_analysis=True,
           dtypes=floating_types(),
           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
@@ -11267,6 +11288,10 @@ op_db: List[OpInfo] = [
           aten_name='max_pool3d',
           supports_autograd=True,
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           # got: Batching rule not implemented for aten::flatten.using_ints
           check_batched_forward_grad=False,
           # TODO: add shape checks
           assert_jit_shape_analysis=False,
           dtypes=floating_types(),