mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS/Inductor] Add support for xlog1py. (#147709)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147709 Approved by: https://github.com/jansel
This commit is contained in:
parent
baccadb2f1
commit
8b65dbad13
|
|
@ -119,6 +119,13 @@ class MPSBasicTests(TestCase):
|
|||
torch.special.spherical_bessel_j0, (torch.rand(128, 128),), check_lowp=False
|
||||
)
|
||||
|
||||
def test_pointwise_xlog1py(self):
|
||||
self.common(
|
||||
torch.special.xlog1py,
|
||||
(torch.rand(128, 128), torch.rand(128, 128)),
|
||||
check_lowp=False,
|
||||
)
|
||||
|
||||
def test_broadcast(self):
|
||||
self.common(torch.add, (torch.rand(32, 1024), torch.rand(1024)))
|
||||
|
||||
|
|
|
|||
|
|
@ -367,6 +367,10 @@ class MetalOverrides(OpOverrides):
|
|||
def spherical_bessel_j0(x: CSEVariable) -> str:
|
||||
return f"c10::metal::spherical_bessel_j0({x})"
|
||||
|
||||
@staticmethod
|
||||
def xlog1py(x: CSEVariable) -> str:
|
||||
return f"c10::metal::xlog1py({x})"
|
||||
|
||||
|
||||
MetalOverrides._initialize_pointwise_overrides("mps")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user