[Testing] Enable test_mutations_loop_fusion_mps (#151872)
By testing it against float32 rather than the double dtype.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151872
Approved by: https://github.com/Skylion007, https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #151869, #151871
This commit is contained in:
parent 2f851ac8f8
commit c0b70f94e2
@@ -12114,16 +12114,16 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         self.assertTrue(called)

     @skip_if_gpu_halide  # cuda error
-    @xfail_if_mps  # float64 is not MPS type
     def test_mutations_loop_fusion(self):
         def fn(tensor, index, source):
             out = tensor.index_add(0, index, source, alpha=2.0) / 2
             return out

         device = "cpu"
-        tensor = torch.rand((1,), dtype=torch.double, device=device)
+        dtype = torch.double if self.device != "mps" else torch.float32
+        tensor = torch.rand((1,), dtype=dtype, device=device)
         index = torch.tensor([0], dtype=torch.long, device=device)
-        source = torch.rand((1,), dtype=torch.double, device=device)
+        source = torch.rand((1,), dtype=dtype, device=device)
         self.common(
             fn,
             (
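For reference, below is a minimal standalone sketch of the dtype fallback this change relies on. It is illustrative only and not part of the PR: the pick_dtype helper and the device-selection line are assumptions added here, while the tensor setup and the computation mirror the test above. The point is that float64 is not an MPS dtype (the reason the removed xfail decorator cited), so the test substitutes float32 when the target device is "mps" and keeps double precision elsewhere.

import torch

def pick_dtype(device: str) -> torch.dtype:
    # Hypothetical helper mirroring the inline conditional in the test:
    # the MPS backend cannot allocate float64 tensors, so fall back to float32.
    return torch.float32 if device == "mps" else torch.double

# Run on MPS when available so the fallback actually kicks in, else on CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
dtype = pick_dtype(device)

tensor = torch.rand((1,), dtype=dtype, device=device)
index = torch.tensor([0], dtype=torch.long, device=device)
source = torch.rand((1,), dtype=dtype, device=device)

# Same computation as fn() in the test: an index_add mutation scaled by alpha,
# followed by an elementwise divide, the pattern the loop-fusion test exercises.
out = tensor.index_add(0, index, source, alpha=2.0) / 2
print(out)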