Make tensordot resize output tensor's size if out= argument is specified & make it safely cast & copy output (#56286)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/56022. Fixes https://github.com/pytorch/pytorch/issues/56316

For `torch.tensordot`:
1. `tensordot`'s out variant now resizes the output tensor provided as the `out` argument if necessary.
2. Added a check to verify that the output tensor provided as the `out` argument is on the same device as the input tensors.
3. Added a check to verify that the dtype of the result is castable to the dtype of the output tensor provided as the `out` argument.
4. Because of (2) & (3), `tensordot`'s out variant now [safely casts & copies output](https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch).
5. `test_tensordot` in `test_linalg.py` had a bug: the output tensor wasn't being created on the same device as the input tensors. It was fixed by passing the `device` argument in its definition.
6. Added an `OpInfo` for `tensordot` and modified the `OpInfo` for `inner`.

cc heitorschueroff mruberry

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56286
Reviewed By: ngimel
Differential Revision: D27845980
Pulled By: mruberry
fbshipit-source-id: 134ab163f05c31a6900dd65aefc745803019e037
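As a quick illustration of points 1 and 4, here is a hedged sketch (not part of the PR) of the new `out=` behaviour. It assumes a PyTorch build that already contains this change; the input shapes are arbitrary and are only chosen so that the contraction yields a `(5, 2)` result, matching the `cout` shape in the updated test below.

```python
import torch

a = torch.randn(3, 4, 5)
b = torch.randn(4, 3, 2)

# Contracting a's dims (1, 0) with b's dims (0, 1) produces a (5, 2) result.
# With this change, an out tensor of the wrong size is resized to that shape
# (and the result is safely copied into it) instead of failing on copy_.
out = torch.empty(0)
torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=out)
print(out.shape)  # torch.Size([5, 2])
```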
This commit is contained in:
parent 0e106fce9c
commit 7513455c74
@@ -1,4 +1,5 @@
 #include <ATen/ATen.h>
+#include <ATen/native/Resize.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/native/xnnpack/Engine.h>
 #include <ATen/WrapDimUtilsMulti.h>
@@ -641,9 +642,26 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,
 }
 
 Tensor &tensordot_out(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2, Tensor& result) {
-  result.copy_(at::native::tensordot(input1, input2, dims1, dims2));
+  Tensor result_tmp = at::native::tensordot(input1, input2, dims1, dims2);
+  auto result_dtype = result_tmp.scalar_type();
+  auto output_tensor_dtype = result.scalar_type();
+  auto output_device = result.device();
+  auto input_device = input1.device();
+  // check if the input & output tensors are on the same device.
+  TORCH_CHECK(
+      output_device == input_device,
+      "tensordot: Expected the output and input tensors to be on the "
+      "same device, but got output on ", output_device, " and inputs on ",
+      input_device);
+  // check if the computed result has the same dtype as the out tensor
+  // (because tensordot does not support type promotion)
+  TORCH_CHECK(
+      result_dtype == output_tensor_dtype, "tensordot",
+      ": Expected the output tensor to have dtype ", result_dtype,
+      ", but got an output tensor with dtype ", output_tensor_dtype);
+  at::native::resize_output(result, result_tmp.sizes());
+  result.copy_(result_tmp);
   return result;
 }
 
 
 }} // namespace at::native
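A hedged Python-level sketch (not from the PR) of how the two new `TORCH_CHECK`s above surface to callers, assuming the checks fire as written in the hunk; the tensor shapes and values are arbitrary.

```python
import torch

a = torch.randn(2, 2, 2)
b = torch.randn(2, 2, 2)

# dtype check: tensordot does no type promotion, so a float64 out tensor
# for float32 inputs is rejected.
bad_dtype_out = torch.empty(0, dtype=torch.float64)
try:
    torch.tensordot(a, b, dims=2, out=bad_dtype_out)
except RuntimeError as err:
    print(err)

# device check: a CPU out tensor for CUDA inputs is rejected
# (this branch only runs when a CUDA device is available).
if torch.cuda.is_available():
    cpu_out = torch.empty(0)
    try:
        torch.tensordot(a.cuda(), b.cuda(), dims=2, out=cpu_out)
    except RuntimeError as err:
        print(err)
```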
@@ -7600,7 +7600,7 @@ else:
                                            axes=([1, 0], [0, 1])))
         self.assertEqual(c, cn)
 
-        cout = torch.zeros((5, 2))
+        cout = torch.zeros((5, 2), device=device)
         torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu()
         self.assertEqual(c, cout)
 
@@ -2632,6 +2632,20 @@ def sample_inputs_lerp(op_info, device, dtype, requires_grad):
 
     return samples
 
+def sample_inputs_tensordot(self, device, dtype, requires_grad, **kwargs):
+    cases = (
+        ((2, 2, 2), (2, 2, 2), (2)),
+        ((2, 2, 1), (2, 1, 2), ([0, 1], [2, 0])),
+    )
+    samples = []
+    for first_shape, second_shape, dims in cases:
+        samples.append(SampleInput(make_tensor(first_shape, device, dtype,
+                                               requires_grad=requires_grad),
+                                   args=(make_tensor(second_shape, device, dtype,
+                                                     requires_grad=requires_grad),),
+                                   kwargs=dict(dims=dims,)))
+    return tuple(samples)
+
 def sample_inputs_kron(op_info, device, dtype, requires_grad):
     test_cases = (
         ((S, S), (M, L)),
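For reference, a sketch (not from the PR) of what the two sample cases above exercise, written with plain `torch.tensordot` calls instead of the `make_tensor`/`SampleInput` helpers used by the OpInfo machinery.

```python
import torch

# Case ((2, 2, 2), (2, 2, 2), (2)): dims=2 contracts the last two dimensions
# of the first tensor with the first two dimensions of the second.
first = torch.randn(2, 2, 2)
second = torch.randn(2, 2, 2)
print(torch.tensordot(first, second, dims=2).shape)  # torch.Size([2, 2])

# Case ((2, 2, 1), (2, 1, 2), ([0, 1], [2, 0])): dims 0 and 1 of the first
# tensor are paired with dims 2 and 0 of the second, leaving a (1, 1) result.
first = torch.randn(2, 2, 1)
second = torch.randn(2, 1, 2)
print(torch.tensordot(first, second, dims=([0, 1], [2, 0])).shape)  # torch.Size([1, 1])
```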
@@ -4582,19 +4596,27 @@ op_db: List[OpInfo] = [
            supports_inplace_autograd=False,
            sample_inputs_func=sample_inputs_kron),
     OpInfo('inner',
-           dtypes=floating_types(),
-           dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16),
-           # BFloat16 support on CUDA requires CUDA 11 and SM53
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
-                                           *[torch.bfloat16] if CUDA11OrLater else []),
-           dtypesIfROCM=floating_types_and(torch.half),
-           supports_out=True,
-           sample_inputs_func=sample_inputs_inner,
+           dtypes=floating_and_complex_types_and(torch.half),
+           dtypesIfCPU=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           sample_inputs_func=sample_inputs_inner),
+    OpInfo('tensordot',
+           dtypes=floating_and_complex_types_and(torch.half),
+           dtypesIfCPU=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           safe_casts_outputs=True,
+           sample_inputs_func=sample_inputs_tensordot,
            skips=(
-               # Reference Issue: https://github.com/pytorch/pytorch/issues/56022
-               # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
-               SkipInfo('TestCommon', 'test_out', dtypes=[torch.float32]),
-           )),
+               # Currently failing due to an INTERNAL_ASSERT_FAILED error.
+               # Reference: https://github.com/pytorch/pytorch/issues/56314
+               SkipInfo("TestCommon", "test_variant_consistency_jit", dtypes=[torch.float32]),
+               # Skip operator schema test because this is a functional and not an operator.
+               # Reference: https://github.com/pytorch/pytorch/issues/54574
+               SkipInfo('TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+           )
+           ),
     OpInfo('logcumsumexp',
            dtypes=floating_types_and(),
            dtypesIfCUDA=floating_types_and(torch.half),