diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 531de605c5d..b6a9801792b 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -12114,16 +12114,16 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar self.assertTrue(called) @skip_if_gpu_halide # cuda error - @xfail_if_mps # float64 is not MPS type def test_mutations_loop_fusion(self): def fn(tensor, index, source): out = tensor.index_add(0, index, source, alpha=2.0) / 2 return out device = "cpu" - tensor = torch.rand((1,), dtype=torch.double, device=device) + dtype = torch.double if self.device != "mps" else torch.float32 + tensor = torch.rand((1,), dtype=dtype, device=device) index = torch.tensor([0], dtype=torch.long, device=device) - source = torch.rand((1,), dtype=torch.double, device=device) + source = torch.rand((1,), dtype=dtype, device=device) self.common( fn, (