[Testing] Enable test_mutations_loop_fusion_mps (#151872)

Test against float32 rather than the double dtype, since float64 is not a supported MPS type.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151872
Approved by: https://github.com/Skylion007, https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #151869, #151871
Author: Nikita Shulga (2025-04-22 14:11:27 -07:00), committed by PyTorch MergeBot
parent 2f851ac8f8
commit c0b70f94e2


@@ -12114,16 +12114,16 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         self.assertTrue(called)
 
     @skip_if_gpu_halide  # cuda error
-    @xfail_if_mps  # float64 is not MPS type
     def test_mutations_loop_fusion(self):
         def fn(tensor, index, source):
             out = tensor.index_add(0, index, source, alpha=2.0) / 2
             return out
 
         device = "cpu"
-        tensor = torch.rand((1,), dtype=torch.double, device=device)
+        dtype = torch.double if self.device != "mps" else torch.float32
+        tensor = torch.rand((1,), dtype=dtype, device=device)
         index = torch.tensor([0], dtype=torch.long, device=device)
-        source = torch.rand((1,), dtype=torch.double, device=device)
+        source = torch.rand((1,), dtype=dtype, device=device)
         self.common(
             fn,
             (