[complex] conv_transpose2d (#81805)
Reference: https://github.com/pytorch/pytorch/issues/71108
Fixes: #86414
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81805
Approved by: https://github.com/anjali411
This commit is contained in:
parent 232fbd90ff
commit 528dd05108
@@ -1066,8 +1066,14 @@ at::Tensor conv_transpose2d(
   Tensor input;
   bool is_batched;
   std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv_transpose2d");
-  auto output = at::convolution(
-    input, weight, bias, stride, padding, dilation, true, output_padding, groups);
+  Tensor output;
+  if (at::isComplexType(input_.scalar_type())) {
+    output = complex_convolution(
+      input, weight, bias, stride, padding, dilation, true, output_padding, groups);
+  } else {
+    output = at::convolution(
+      input, weight, bias, stride, padding, dilation, true, output_padding, groups);
+  }
   return is_batched ? output : output.squeeze(0);
 }
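
With the dispatch above, complex inputs now route through complex_convolution instead of erroring out. A minimal sketch of the user-visible behavior this enables (shapes borrowed from the sample inputs further down; this snippet is illustrative and not part of the diff):

import torch
import torch.nn.functional as F

# Complex transposed convolution now works directly (sketch).
x = torch.randn(1, 4, 5, 5, dtype=torch.complex64)   # (N, C_in, H, W)
w = torch.randn(4, 8, 3, 3, dtype=torch.complex64)   # (C_in, C_out, kH, kW)
out = F.conv_transpose2d(x, w)
print(out.shape, out.dtype)                           # torch.Size([1, 8, 7, 7]) torch.complex64
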
@@ -3069,8 +3069,12 @@ def conv_transpose_ref(input, weight, bias, stride=1, padding=0,
 
     assert fn is not None
 
-    grad_fn_map = {torch.nn.functional.conv_transpose1d: torch.nn.grad.conv1d_input}
-    batched_dim_map = {torch.nn.functional.conv_transpose1d: 3}
+    grad_fn_map = {torch.nn.functional.conv_transpose1d: torch.nn.grad.conv1d_input,
+                   torch.nn.functional.conv_transpose2d: torch.nn.grad.conv2d_input,
+                   torch.nn.functional.conv_transpose3d: torch.nn.grad.conv3d_input}
+    batched_dim_map = {torch.nn.functional.conv_transpose1d: 3,
+                       torch.nn.functional.conv_transpose2d: 4,
+                       torch.nn.functional.conv_transpose3d: 5}
 
     # Input for `ref` is ndarray.
     input, weight = torch.from_numpy(input), torch.from_numpy(weight)
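
The map above is the core of conv_transpose_ref: the transposed convolution is checked against the input-gradient of the corresponding forward convolution. A standalone sketch of that equivalence (default stride/padding/dilation/groups assumed; not the helper itself):

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 5, 5)     # (N, C_in, H, W)
w = torch.randn(4, 8, 3, 3)     # conv_transpose2d weight: (C_in, C_out, kH, kW)

out = F.conv_transpose2d(x, w)                      # shape (1, 8, 7, 7)
# Same result via the gradient of conv2d w.r.t. its input,
# which is what grad_fn_map points conv_transpose2d at.
ref = torch.nn.grad.conv2d_input(out.shape, w, x)
print(torch.allclose(out, ref, atol=1e-5))          # True up to float error
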
@@ -3080,7 +3084,10 @@ def conv_transpose_ref(input, weight, bias, stride=1, padding=0,
         input = input.unsqueeze(0)
 
     if bias is not None:
-        bias = torch.from_numpy(bias).unsqueeze(1)
+        bias = torch.from_numpy(bias)
+        unsqueeze_dims = input.ndim - 2
+        for _ in range(unsqueeze_dims):
+            bias = bias.unsqueeze(1)
 
     grad_output = input
     # Get the input shape for grad_fn.
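
The reworked bias handling simply makes a 1-D bias broadcast over however many spatial dimensions the input has (one for conv_transpose1d, two for 2d, three for 3d). A tiny illustration, independent of the helper itself:

import torch

x = torch.randn(2, 3, 5, 5)        # 4-D input => two spatial dims
bias = torch.randn(3)
for _ in range(x.ndim - 2):        # same loop as in conv_transpose_ref above
    bias = bias.unsqueeze(1)       # (3,) -> (3, 1) -> (3, 1, 1)
print(bias.shape)                  # torch.Size([3, 1, 1])
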
@@ -3147,8 +3154,7 @@ def sample_inputs_conv_transpose2d(op_info, device, dtype, requires_grad, **kwar
         ((1, 1, 4, 3), (1, 2, 3, 4), None,
          {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
         ((2, 8, 4, 4), (8, 1, 3, 3), None, {'groups': 4}),
-        ((1, 4, 5, 5), (4, 8, 3, 3), None,
-         {})
+        ((1, 4, 5, 5), (4, 8, 3, 3), None, {})
     )
 
     for input_shape, weight, bias, kwargs in cases:
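
Each tuple in cases is (input_shape, weight_shape, bias, kwargs). Expanding the grouped case into an actual call, as a sketch (the complex dtype is used only because that is what this PR enables; any supported dtype would do):

import torch
import torch.nn.functional as F

input_shape, weight_shape, bias, kwargs = (2, 8, 4, 4), (8, 1, 3, 3), None, {'groups': 4}
x = torch.randn(input_shape, dtype=torch.complex64)
w = torch.randn(weight_shape, dtype=torch.complex64)
out = F.conv_transpose2d(x, w, bias, **kwargs)
print(out.shape)   # torch.Size([2, 4, 6, 6]): 8 input channels, 4 groups, 1 output channel each
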
@@ -10635,9 +10641,12 @@ op_db: List[OpInfo] = [
     OpInfo('nn.functional.conv_transpose2d',
            aten_name='conv_transpose2d',
            aliases=('conv_transpose2d',),
-           dtypes=floating_types_and(torch.int64),
-           dtypesIfCUDA=floating_types_and(torch.float16,
-                                           *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
+           # `ref` for this function is backward of
+           # corresponding `conv*d`
+           ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose2d),
+           dtypes=floating_and_complex_types_and(torch.int64),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf,
+                                                       *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
            sample_inputs_func=sample_inputs_conv_transpose2d,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
@@ -10646,11 +10655,32 @@ op_db: List[OpInfo] = [
            decorators=[
                DecorateInfo(
                    toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }),
-                   'TestCommon', 'test_variant_consistency_eager', device_type='cuda')],
+                   'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=2e-05, rtol=5e-05), }),
+                   'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.complex32: tol(atol=5e-2, rtol=5e-2)}),
+                   "TestCudaFuserOpInfo", "test_nvfuser_correctness"),
+               DecorateInfo(
+                   toleranceOverride({torch.chalf: tol(atol=5e-2, rtol=5e-2), }),
+                   'TestCommon', 'test_complex_half_reference_testing')],
            skips=(
                # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
                # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
                DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # RuntimeError: UNSUPPORTED DTYPE: complex
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
+                            dtypes=(torch.complex64, torch.complex128)),
+               # RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long'
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
+                            dtypes=(torch.int64,)),
+               # Reference: https://github.com/pytorch/pytorch/issues/86356
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
+                            dtypes=(torch.double, torch.cdouble)),
+               # AssertionError: None mismatch: torch.complex64 is not None
+               DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', 'test_custom_rules',
+                            dtypes=(torch.complex64, torch.complex128)),
            ),
            supports_out=False,),
     OpInfo('nn.functional.conv_transpose3d',
@@ -1143,6 +1143,7 @@ module_db: List[ModuleInfo] = [
                module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False, transposed=True),
                gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
                module_memformat_affects_out=True,
+               dtypes=floating_and_complex_types_and(torch.chalf),
                skips=(
                    # channels_last support on cuda requires cudnn >= 7603
                    DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
@@ -1153,7 +1154,25 @@ module_db: List[ModuleInfo] = [
                    # See https://github.com/pytorch/pytorch/issues/80247
                    DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cpu'),
                    DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
-                                dtypes=[torch.float64]),
+                                dtypes=[torch.float64, torch.complex128]),
+                   # These fail only on ROCm
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
+                                dtypes=[torch.complex32, torch.complex64], active_if=TEST_WITH_ROCM),
+                   # Not implmented for chalf on CPU
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_forward',
+                                dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format',
+                                dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule',
+                                'test_if_train_and_eval_modes_differ', dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_non_contiguous_tensors',
+                                dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_multiple_device_transfer',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+                   # Ref: https://github.com/pytorch/pytorch/issues/73502
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_pickle', dtypes=(torch.chalf,)),
                ),
                decorators=(
                    DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
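
Since the ModuleInfo above now lists complex dtypes, the module form of the op is exercised as well. A minimal usage sketch (cfloat on CPU; per the skips above, chalf is only expected to work on CUDA):

import torch

m = torch.nn.ConvTranspose2d(4, 8, kernel_size=3, dtype=torch.cfloat)
x = torch.randn(1, 4, 5, 5, dtype=torch.cfloat)
print(m(x).shape)   # torch.Size([1, 8, 7, 7])
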