[MPS] Type-promote tensor-iterator common dtype (#160334)

Otherwise, `torch.add(FloatTensor, IntTensor, alpha=2)` and `torch.add(FloatTensor, IntTensor, alpha=2)` were dispatched to different kernels

Fixes https://github.com/pytorch/pytorch/issues/160208
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160334
Approved by: https://github.com/Skylion007, https://github.com/dcci
This commit is contained in:
Nikita Shulga 2025-08-11 09:57:30 -07:00 committed by PyTorch MergeBot
parent d0e2240f68
commit d25c4f954d
2 changed files with 3 additions and 0 deletions

View File

@ -53,6 +53,7 @@ void binary_op_kernel(const std::string func_name,
.add_input(input)
.add_input(other)
.check_all_same_dtype(false)
.promote_inputs_to_common_dtype(true)
.build();
lib.exec_binary_kernel(iter, func_name, alpha);

View File

@ -7736,6 +7736,8 @@ class TestMPS(TestCaseMPS):
y = torch.arange(32, device='mps', dtype=torch.int32)
self.assertEqual(torch.add(x, y, alpha=2).cpu(), torch.add(x.cpu(), y.cpu(), alpha=2))
self.assertEqual(torch.add(x, 3, alpha=2).cpu(), torch.add(x.cpu(), 3, alpha=2))
# Regression test for https://github.com/pytorch/pytorch/issues/160208
self.assertEqual(torch.add(y, x, alpha=2).cpu(), torch.add(y.cpu(), x.cpu(), alpha=2))
# Test add
def test_add_scalars(self):