Make tensordot resize output tensor's size if out= argument is specified & make it safely cast & copy output (#56286)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/56022.
Fixes https://github.com/pytorch/pytorch/issues/56316.

For `torch.tensordot`,
1. `tensordot`'s out variant now resizes the output tensor provided as the `out` argument if necessary.
2. Added a check that the output tensor provided as the `out` argument is on the same device as the input tensors.
3. Added a check that the dtype of the result is castable to the dtype of the output tensor provided as the `out` argument.
4. Because of (2) & (3), `tensordot`'s out variant now [safely casts & copies output](https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch) (see the usage sketch after this list).
5. `test_tensordot` in `test_linalg.py` had a bug - the output tensor wasn't being created on the same device as the input tensors. This is fixed by simply passing the `device` argument when constructing it.
6. Added an `OpInfo` for `tensordot` and modified the `OpInfo` for `inner`.
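
A short usage sketch of (1)-(4). The shapes and dims mirror the existing case in `test_linalg.py`; the error-message comments are paraphrases, not the exact strings, and the CUDA branch is only an example:

import torch

a = torch.randn(3, 4, 5)
b = torch.randn(4, 3, 2)

# (1) A wrongly sized out tensor is now resized to the result's shape, (5, 2).
#     Resizing a non-empty tensor may additionally emit a UserWarning.
out = torch.zeros(1)
torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=out)
assert out.shape == (5, 2)

# (2) An out tensor on a different device than the inputs is rejected.
if torch.cuda.is_available():
    out_cuda = torch.zeros((5, 2), device='cuda')
    try:
        torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=out_cuda)
    except RuntimeError as e:
        print(e)  # output and input tensors must be on the same device

# (3) An out tensor whose dtype does not match the result's dtype is rejected.
out_int = torch.zeros((5, 2), dtype=torch.int64)
try:
    torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=out_int)
except RuntimeError as e:
    print(e)  # expected the output tensor to have dtype torch.float32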

cc heitorschueroff mruberry

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56286

Reviewed By: ngimel

Differential Revision: D27845980

Pulled By: mruberry

fbshipit-source-id: 134ab163f05c31a6900dd65aefc745803019e037
Winston Smith 2021-04-19 04:19:06 -07:00 committed by Facebook GitHub Bot
parent 0e106fce9c
commit 7513455c74
3 changed files with 55 additions and 15 deletions


@@ -1,4 +1,5 @@
 #include <ATen/ATen.h>
+#include <ATen/native/Resize.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/native/xnnpack/Engine.h>
 #include <ATen/WrapDimUtilsMulti.h>
@@ -641,9 +642,26 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,
 }
 Tensor &tensordot_out(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2, Tensor& result) {
-  result.copy_(at::native::tensordot(input1, input2, dims1, dims2));
+  Tensor result_tmp = at::native::tensordot(input1, input2, dims1, dims2);
+  auto result_dtype = result_tmp.scalar_type();
+  auto output_tensor_dtype = result.scalar_type();
+  auto output_device = result.device();
+  auto input_device = input1.device();
+  // check if the input & output tensors are on the same device.
+  TORCH_CHECK(
+    output_device == input_device,
+    "tensordot: Expected the output and input tensors to be on the "
+    "same device, but got output on ", output_device, " and inputs on ",
+    input_device);
+  // check if the computed result has the same dtype as the out tensor
+  // (because tensordot does not support type promotion)
+  TORCH_CHECK(
+    result_dtype == output_tensor_dtype, "tensordot",
+    ": Expected the output tensor to have dtype ", result_dtype,
+    ", but got an output tensor with dtype ", output_tensor_dtype);
+  at::native::resize_output(result, result_tmp.sizes());
+  result.copy_(result_tmp);
   return result;
 }
 }} // namespace at::native
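
In Python terms, the new out variant above behaves roughly like the sketch below. tensordot_out_sketch is a hypothetical helper written only to illustrate the control flow; the real code is the C++ shown here and goes through at::native::resize_output rather than Tensor.resize_.

import torch

def tensordot_out_sketch(input1, input2, dims1, dims2, out):
    # Compute the result into a temporary, exactly like the functional variant.
    result_tmp = torch.tensordot(input1, input2, dims=(dims1, dims2))
    # The out tensor must live on the same device as the inputs.
    if out.device != input1.device:
        raise RuntimeError(
            "tensordot: Expected the output and input tensors to be on the same "
            "device, but got output on {} and inputs on {}".format(out.device, input1.device))
    # tensordot does not support type promotion, so the dtypes must match.
    if out.dtype != result_tmp.dtype:
        raise RuntimeError(
            "tensordot: Expected the output tensor to have dtype {}, but got an "
            "output tensor with dtype {}".format(result_tmp.dtype, out.dtype))
    # Resize out to the result's shape if necessary, then copy the result into it.
    out.resize_(result_tmp.shape)
    out.copy_(result_tmp)
    return out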


@@ -7600,7 +7600,7 @@ else:
                                            axes=([1, 0], [0, 1])))
         self.assertEqual(c, cn)
-        cout = torch.zeros((5, 2))
+        cout = torch.zeros((5, 2), device=device)
         torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu()
         self.assertEqual(c, cout)


@@ -2632,6 +2632,20 @@ def sample_inputs_lerp(op_info, device, dtype, requires_grad):
     return samples
+def sample_inputs_tensordot(self, device, dtype, requires_grad, **kwargs):
+    cases = (
+        ((2, 2, 2), (2, 2, 2), (2)),
+        ((2, 2, 1), (2, 1, 2), ([0, 1], [2, 0])),
+    )
+    samples = []
+    for first_shape, second_shape, dims in cases:
+        samples.append(SampleInput(make_tensor(first_shape, device, dtype,
+                                               requires_grad=requires_grad),
+                                   args=(make_tensor(second_shape, device, dtype,
+                                                     requires_grad=requires_grad),),
+                                   kwargs=dict(dims=dims,)))
+    return tuple(samples)
 def sample_inputs_kron(op_info, device, dtype, requires_grad):
     test_cases = (
         ((S, S), (M, L)),
@@ -4582,19 +4596,27 @@ op_db: List[OpInfo] = [
            supports_inplace_autograd=False,
            sample_inputs_func=sample_inputs_kron),
     OpInfo('inner',
-           dtypes=floating_types(),
-           dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16),
-           # BFloat16 support on CUDA requires CUDA 11 and SM53
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
-                                           *[torch.bfloat16] if CUDA11OrLater else []),
-           dtypesIfROCM=floating_types_and(torch.half),
-           supports_out=True,
-           sample_inputs_func=sample_inputs_inner,
-           ),
+           dtypes=floating_and_complex_types_and(torch.half),
+           dtypesIfCPU=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           sample_inputs_func=sample_inputs_inner),
+    OpInfo('tensordot',
+           dtypes=floating_and_complex_types_and(torch.half),
+           dtypesIfCPU=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           safe_casts_outputs=True,
+           sample_inputs_func=sample_inputs_tensordot,
+           skips=(
+               # Reference Issue: https://github.com/pytorch/pytorch/issues/56022
+               # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               SkipInfo('TestCommon', 'test_out', dtypes=[torch.float32]),
+               # Currently failing due to an INTERNAL_ASSERT_FAILED error.
+               # Reference: https://github.com/pytorch/pytorch/issues/56314
+               SkipInfo("TestCommon", "test_variant_consistency_jit", dtypes=[torch.float32]),
+               # Skip operator schema test because this is a functional and not an operator.
+               # Reference: https://github.com/pytorch/pytorch/issues/54574
+               SkipInfo('TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+           )
+           ),
     OpInfo('logcumsumexp',
            dtypes=floating_types_and(),
            dtypesIfCUDA=floating_types_and(torch.half),