Fix lerp weight type promotion (#141117)

Fixes #140601

Enable `promote_inputs_to_common_dtype` when the tensors passed to `lerp` do not share the same dtype.

For `lerp_Tensor`
- Enable `promote_inputs_to_common_dtype` on the `TensorIteratorConfig` when `weight` is a 0-dim tensor
- Skip the dtype-check assert on `weight` in that case
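
A minimal sketch of the resulting behavior (illustrative values, mirroring the new tests below):

```python
import torch

start = torch.ones(2, 2, dtype=torch.float16)
end = torch.full((2, 2), 3.0, dtype=torch.float16)

# 0-dim weight with a different dtype: inputs are promoted to a common
# dtype and the result is cast back to the float16 output.
start.clone().lerp_(end, torch.tensor(0.5))  # float32 scalar tensor, no error

# A non-scalar weight with a mismatched dtype still fails the dtype check.
try:
    start.lerp_(end, torch.rand(2, 2))  # float32 (2, 2) weight tensor
except RuntimeError as err:
    print(err)  # expected dtype ... for `weight` but got dtype ...
```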

For `lerp_Scalar`
- `promote_inputs_to_common_dtype` already appears to be enabled by default, so only the type check is removed, keeping the promotion behavior consistent with `lerp_Tensor`

`lerp_Scalar` gets its `TensorIteratorConfig` from here:
c37185c76a/aten/src/ATen/TensorIterator.cpp (L979-L985)
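
For intuition, a rough Python sketch of what the promotion amounts to (an approximation for illustration, not the actual TensorIterator code; `lerp_promoted` is a made-up helper):

```python
import torch

def lerp_promoted(self, end, weight):
    # Rough model of promote_inputs_to_common_dtype +
    # cast_common_dtype_to_outputs for an in-place lerp_: compute in the
    # common dtype, then cast the result back to self's dtype.
    common = torch.promote_types(
        torch.promote_types(self.dtype, end.dtype), weight.dtype
    )
    out = self.to(common) + weight.to(common) * (end.to(common) - self.to(common))
    return out.to(self.dtype)

x = torch.ones(2, 2, dtype=torch.float16)
y = torch.full((2, 2), 3.0, dtype=torch.float16)
w = torch.tensor(2.2)  # float32 scalar tensor
print(lerp_promoted(x, y, w))  # matches the line below up to rounding
print(x.clone().lerp_(y, w))   # uses the promoted path added here
```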

**Test Result**

The test case from the issue now passes:

```python
>>> import torch
>>>
>>> x = torch.ones(2, 2, dtype=torch.float64)
>>> w = torch.ones(2, 2, dtype=torch.float64)
>>> s = torch.tensor(2.2)
>>> x.lerp_(w, s)
tensor([[1., 1.],
        [1., 1.]], dtype=torch.float64)

>>> x = torch.ones(2, 2, dtype=torch.float16)
>>> w = torch.ones(2, 2, dtype=torch.float16)
>>> s = torch.tensor(2.2)
>>> x.lerp_(w, s)
tensor([[1., 1.],
        [1., 1.]], dtype=torch.float16)

```

```bash
$ pytest test/test_binary_ufuncs.py -k 'test_lerp_tensor_type_promotion or test_lerp_scalar_type_promotion'
```
![image](https://github.com/user-attachments/assets/288a5294-a9ee-47f3-bbf7-d4ff986f3ba8)

```bash
$ lintrunner
```
![image](https://github.com/user-attachments/assets/d469836f-5c49-4d89-a2fd-379cad4db3af)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141117
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
zeshengzong 2025-01-24 01:18:18 +00:00 committed by PyTorch MergeBot
parent b2c89bc115
commit 54e2f4b201
3 changed files with 31 additions and 6 deletions

```diff
@@ -16,10 +16,16 @@ TORCH_META_FUNC(lerp_Tensor)(
     const Tensor& self, const Tensor& end, const Tensor& weight) {
   TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(),
               " for `end` but got dtype ", end.dtype());
-  TORCH_CHECK(self.dtype() == weight.dtype(), "expected dtype ", self.dtype(),
-              " for `weight` but got dtype ", weight.dtype());
+  bool promote_weight = weight.dim() == 0;
+  if (!promote_weight) {
+    TORCH_CHECK(self.dtype() == weight.dtype(), "expected dtype ", self.dtype(),
+                " for `weight` but got dtype ", weight.dtype());
+  }
   build(at::TensorIteratorConfig()
       .allow_cpu_scalars(true)
+      .promote_inputs_to_common_dtype(promote_weight)
+      .enforce_safe_casting_to_output(promote_weight)
+      .cast_common_dtype_to_outputs(promote_weight)
       .add_output(maybe_get_output())
       .add_const_input(self)
       .add_const_input(end)
```

```diff
@@ -3519,6 +3519,24 @@ class TestBinaryUfuncs(TestCase):
             expected = torch.lerp(xref, yref, wref).to(dtype)
             self.assertEqual(actual, expected, atol=0.0, rtol=0.0)
+
+    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
+    def test_lerp_weight_scalar_tensor_promotion(self, device, dtype):
+        start = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100)
+        end = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100)
+        weight = torch.rand((), dtype=torch.float, device=device)
+        actual = torch.lerp(start, end, weight)
+        expected = start + weight.to(dtype) * (end - start)
+        self.assertEqual(expected, actual)
+
+    @dtypes(torch.double, torch.cfloat, torch.cdouble)
+    def test_lerp_weight_tensor_promotion_error(self, device, dtype):
+        start = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100)
+        end = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100)
+        weight = torch.rand((5, 5), dtype=torch.float, device=device)
+        with self.assertRaisesRegex(RuntimeError, "expected dtype"):
+            torch.lerp(start, end, weight)
 
     def _test_logaddexp(self, device, dtype, base2):
         if base2:
             ref_func = np.logaddexp2
```

```diff
@@ -6972,10 +6972,11 @@ def lerp(start, end, weight):
     )
     args = [start, end]
     if isinstance(weight, TensorLike):
-        torch._check(
-            start.dtype == weight.dtype,
-            lambda: f"expected dtype {start.dtype} for `weight`, but got dtype {weight.dtype}",
-        )
+        if weight.ndim != 0:
+            torch._check(
+                start.dtype == weight.dtype,
+                lambda: f"expected dtype {start.dtype} for `weight`, but got dtype {weight.dtype}",
+            )
         args.append(weight)
     return elementwise_meta(
         *args, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
```