diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index c19a51ad3dc..66f86f0cc4a 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -119,6 +119,13 @@ class MPSBasicTests(TestCase): torch.special.spherical_bessel_j0, (torch.rand(128, 128),), check_lowp=False ) + def test_pointwise_xlog1py(self): + self.common( + torch.special.xlog1py, + (torch.rand(128, 128), torch.rand(128, 128)), + check_lowp=False, + ) + def test_broadcast(self): self.common(torch.add, (torch.rand(32, 1024), torch.rand(1024))) diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 8af344c35ea..61a33f91989 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -367,6 +367,10 @@ class MetalOverrides(OpOverrides): def spherical_bessel_j0(x: CSEVariable) -> str: return f"c10::metal::spherical_bessel_j0({x})" + @staticmethod + def xlog1py(x: CSEVariable) -> str: + return f"c10::metal::xlog1py({x})" + MetalOverrides._initialize_pointwise_overrides("mps")