[MPS/Inductor] Add support for xlog1py. (#147709)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147709
Approved by: https://github.com/jansel
This commit is contained in:
Davide Italiano 2025-02-24 05:28:52 +00:00 committed by PyTorch MergeBot
parent baccadb2f1
commit 8b65dbad13
2 changed files with 11 additions and 0 deletions

View File

@ -119,6 +119,13 @@ class MPSBasicTests(TestCase):
torch.special.spherical_bessel_j0, (torch.rand(128, 128),), check_lowp=False
)
def test_pointwise_xlog1py(self):
self.common(
torch.special.xlog1py,
(torch.rand(128, 128), torch.rand(128, 128)),
check_lowp=False,
)
def test_broadcast(self):
self.common(torch.add, (torch.rand(32, 1024), torch.rand(1024)))

View File

@ -367,6 +367,10 @@ class MetalOverrides(OpOverrides):
def spherical_bessel_j0(x: CSEVariable) -> str:
return f"c10::metal::spherical_bessel_j0({x})"
@staticmethod
def xlog1py(x: CSEVariable) -> str:
return f"c10::metal::xlog1py({x})"
MetalOverrides._initialize_pointwise_overrides("mps")