[complex] conv_transpose2d (#81805)

Reference: https://github.com/pytorch/pytorch/issues/71108

Fixes : #86414
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81805
Approved by: https://github.com/anjali411
This commit is contained in:
kshitij12345 2022-10-19 09:12:27 +00:00 committed by PyTorch MergeBot
parent 232fbd90ff
commit 528dd05108
3 changed files with 66 additions and 11 deletions

View File

@ -1066,8 +1066,14 @@ at::Tensor conv_transpose2d(
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv_transpose2d");
auto output = at::convolution(
Tensor output;
if (at::isComplexType(input_.scalar_type())) {
output = complex_convolution(
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
} else {
output = at::convolution(
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
}
return is_batched ? output : output.squeeze(0);
}

View File

@ -3069,8 +3069,12 @@ def conv_transpose_ref(input, weight, bias, stride=1, padding=0,
assert fn is not None
grad_fn_map = {torch.nn.functional.conv_transpose1d: torch.nn.grad.conv1d_input}
batched_dim_map = {torch.nn.functional.conv_transpose1d: 3}
grad_fn_map = {torch.nn.functional.conv_transpose1d: torch.nn.grad.conv1d_input,
torch.nn.functional.conv_transpose2d: torch.nn.grad.conv2d_input,
torch.nn.functional.conv_transpose3d: torch.nn.grad.conv3d_input}
batched_dim_map = {torch.nn.functional.conv_transpose1d: 3,
torch.nn.functional.conv_transpose2d: 4,
torch.nn.functional.conv_transpose3d: 5}
# Input for `ref` is ndarray.
input, weight = torch.from_numpy(input), torch.from_numpy(weight)
@ -3080,7 +3084,10 @@ def conv_transpose_ref(input, weight, bias, stride=1, padding=0,
input = input.unsqueeze(0)
if bias is not None:
bias = torch.from_numpy(bias).unsqueeze(1)
bias = torch.from_numpy(bias)
unsqueeze_dims = input.ndim - 2
for _ in range(unsqueeze_dims):
bias = bias.unsqueeze(1)
grad_output = input
# Get the input shape for grad_fn.
@ -3147,8 +3154,7 @@ def sample_inputs_conv_transpose2d(op_info, device, dtype, requires_grad, **kwar
((1, 1, 4, 3), (1, 2, 3, 4), None,
{'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
((2, 8, 4, 4), (8, 1, 3, 3), None, {'groups': 4}),
((1, 4, 5, 5), (4, 8, 3, 3), None,
{})
((1, 4, 5, 5), (4, 8, 3, 3), None, {})
)
for input_shape, weight, bias, kwargs in cases:
@ -10635,9 +10641,12 @@ op_db: List[OpInfo] = [
OpInfo('nn.functional.conv_transpose2d',
aten_name='conv_transpose2d',
aliases=('conv_transpose2d',),
dtypes=floating_types_and(torch.int64),
dtypesIfCUDA=floating_types_and(torch.float16,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
# `ref` for this function is backward of
# corresponding `conv*d`
ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose2d),
dtypes=floating_and_complex_types_and(torch.int64),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
sample_inputs_func=sample_inputs_conv_transpose2d,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@ -10646,11 +10655,32 @@ op_db: List[OpInfo] = [
decorators=[
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }),
'TestCommon', 'test_variant_consistency_eager', device_type='cuda')],
'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=2e-05, rtol=5e-05), }),
'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
DecorateInfo(
toleranceOverride({torch.complex32: tol(atol=5e-2, rtol=5e-2)}),
"TestCudaFuserOpInfo", "test_nvfuser_correctness"),
DecorateInfo(
toleranceOverride({torch.chalf: tol(atol=5e-2, rtol=5e-2), }),
'TestCommon', 'test_complex_half_reference_testing')],
skips=(
# RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
# "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
# RuntimeError: UNSUPPORTED DTYPE: complex
DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
dtypes=(torch.complex64, torch.complex128)),
# RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long'
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
dtypes=(torch.int64,)),
# Reference: https://github.com/pytorch/pytorch/issues/86356
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
dtypes=(torch.double, torch.cdouble)),
# AssertionError: None mismatch: torch.complex64 is not None
DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', 'test_custom_rules',
dtypes=(torch.complex64, torch.complex128)),
),
supports_out=False,),
OpInfo('nn.functional.conv_transpose3d',

View File

@ -1143,6 +1143,7 @@ module_db: List[ModuleInfo] = [
module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False, transposed=True),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_memformat_affects_out=True,
dtypes=floating_and_complex_types_and(torch.chalf),
skips=(
# channels_last support on cuda requires cudnn >= 7603
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
@ -1153,7 +1154,25 @@ module_db: List[ModuleInfo] = [
# See https://github.com/pytorch/pytorch/issues/80247
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cpu'),
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
dtypes=[torch.float64]),
dtypes=[torch.float64, torch.complex128]),
# These fail only on ROCm
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
dtypes=[torch.complex32, torch.complex64], active_if=TEST_WITH_ROCM),
# Not implmented for chalf on CPU
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_forward',
dtypes=(torch.chalf,), device_type='cpu'),
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format',
dtypes=(torch.chalf,), device_type='cpu'),
DecorateInfo(unittest.expectedFailure, 'TestModule',
'test_if_train_and_eval_modes_differ', dtypes=(torch.chalf,), device_type='cpu'),
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_non_contiguous_tensors',
dtypes=(torch.chalf,), device_type='cpu'),
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
dtypes=(torch.chalf,), device_type='cuda'),
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_multiple_device_transfer',
dtypes=(torch.chalf,), device_type='cuda'),
# Ref: https://github.com/pytorch/pytorch/issues/73502
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_pickle', dtypes=(torch.chalf,)),
),
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),