Fix lerp weight type promotion (#141117)
Fixes #140601
Enable `promote_inputs_to_common_dtype` when the input tensors do not have the same dtype when invoking the `lerp` function.
For `lerp_Tensor`
- Check whether the tensors have the same `dtype`; enable promotion if they do not.
- Remove the type-check assert.
For `lerp_Scalar`
- `promote_inputs_to_common_dtype` already appears to be enabled by default, so just remove the type check and make sure the promotion behavior stays consistent with `lerp_Tensor`.
`lerp_Scalar` gets its `TensorIteratorConfig` from here:
c37185c76a/aten/src/ATen/TensorIterator.cpp (L979-L985)
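As a quick illustration of the behavior this enables (a sketch based on the new tests below; the error text comes from the check in the diff), a 0-dim weight with a different dtype is now promoted, while a full-size weight with a mismatched dtype is still rejected:
```python
import torch

start = torch.ones(2, 2, dtype=torch.float64)
end = torch.full((2, 2), 3.0, dtype=torch.float64)

# 0-dim float32 weight: promoted, result keeps the float64 dtype of start/end.
scalar_w = torch.tensor(0.5)
print(torch.lerp(start, end, scalar_w).dtype)  # torch.float64

# Full-size float32 weight with a mismatched dtype: still raises.
full_w = torch.rand(2, 2)
try:
    torch.lerp(start, end, full_w)
except RuntimeError as e:
    print(e)  # expected dtype torch.float64 for `weight` but got dtype torch.float32
```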
**Test Result**
The test cases from the issue pass:
```python
>>> import torch
>>>
>>> x = torch.ones(2, 2, dtype=torch.float64)
>>> w = torch.ones(2, 2, dtype=torch.float64)
>>> s = torch.tensor(2.2)
>>> x.lerp_(w, s)
tensor([[1., 1.],
        [1., 1.]], dtype=torch.float64)
>>> x = torch.ones(2, 2, dtype=torch.float16)
>>> w = torch.ones(2, 2, dtype=torch.float16)
>>> s = torch.tensor(2.2)
>>> x.lerp_(w, s)
tensor([[1., 1.],
        [1., 1.]], dtype=torch.float16)
```
```bash
$ pytest test/test_binary_ufuncs.py -k 'test_lerp_tensor_type_promotion or test_lerp_scalar_type_promotion'
```

```bash
$ lintrunner
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141117
Approved by: https://github.com/janeyx99
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Parent: b2c89bc115
Commit: 54e2f4b201
@@ -16,10 +16,16 @@ TORCH_META_FUNC(lerp_Tensor)(
```cpp
TORCH_META_FUNC(lerp_Tensor)(
    const Tensor& self, const Tensor& end, const Tensor& weight) {
  TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(),
              " for `end` but got dtype ", end.dtype());
  bool promote_weight = weight.dim() == 0;
  if (!promote_weight) {
    TORCH_CHECK(self.dtype() == weight.dtype(), "expected dtype ", self.dtype(),
                " for `weight` but got dtype ", weight.dtype());
  }
  build(at::TensorIteratorConfig()
        .allow_cpu_scalars(true)
        .promote_inputs_to_common_dtype(promote_weight)
        .enforce_safe_casting_to_output(promote_weight)
        .cast_common_dtype_to_outputs(promote_weight)
        .add_output(maybe_get_output())
        .add_const_input(self)
        .add_const_input(end)
        // ...
```
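For intuition about why only a 0-dim (scalar-like) weight is promoted: the common dtype computed by `promote_inputs_to_common_dtype` follows PyTorch's usual type-promotion rules, which `torch.result_type` exposes directly. A quick sketch, not part of the diff:
```python
import torch

# A 0-dim tensor participates in promotion like a Python scalar, so a
# float32 scalar weight does not upcast float16 inputs.
print(torch.result_type(torch.ones(2, dtype=torch.float16), torch.tensor(2.2)))
# torch.float16

# A dimensioned float32 weight would force promotion to float32, which is
# why mismatched non-scalar weights are still rejected by the TORCH_CHECK.
print(torch.result_type(torch.ones(2, dtype=torch.float16), torch.rand(2)))
# torch.float32
```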
@@ -3519,6 +3519,24 @@ class TestBinaryUfuncs(TestCase):
```python
        expected = torch.lerp(xref, yref, wref).to(dtype)
        self.assertEqual(actual, expected, atol=0.0, rtol=0.0)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_lerp_weight_scalar_tensor_promotion(self, device, dtype):
        start = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100)
        end = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100)
        weight = torch.rand((), dtype=torch.float, device=device)

        actual = torch.lerp(start, end, weight)
        expected = start + weight.to(dtype) * (end - start)
        self.assertEqual(expected, actual)

    @dtypes(torch.double, torch.cfloat, torch.cdouble)
    def test_lerp_weight_tensor_promotion_error(self, device, dtype):
        start = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100)
        end = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100)
        weight = torch.rand((5, 5), dtype=torch.float, device=device)
        with self.assertRaisesRegex(RuntimeError, "expected dtype"):
            torch.lerp(start, end, weight)

    def _test_logaddexp(self, device, dtype, base2):
        if base2:
            ref_func = np.logaddexp2
```
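The scalar-weight test above checks `lerp` against its defining formula, start + weight * (end - start). The same identity is easy to verify interactively after this change (a sketch using the dtypes from the test):
```python
import torch

start = torch.rand(5, 5, dtype=torch.double) * 99 + 1
end = torch.rand(5, 5, dtype=torch.double) * 99 + 1
weight = torch.rand((), dtype=torch.float)  # 0-dim float32 weight

actual = torch.lerp(start, end, weight)
expected = start + weight.to(torch.double) * (end - start)
print(torch.allclose(actual, expected))  # True
```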
@@ -6972,6 +6972,7 @@ def lerp(start, end, weight):
```python
    )
    args = [start, end]
    if isinstance(weight, TensorLike):
        if weight.ndim != 0:
            torch._check(
                start.dtype == weight.dtype,
                lambda: f"expected dtype {start.dtype} for `weight`, but got dtype {weight.dtype}",
                # ...
```