[MPS/inductor] Add support for modified_scaled_bessel_k{0,1} (#149794)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149794
Approved by: https://github.com/malfet
This commit is contained in:
Davide Italiano 2025-03-22 15:41:40 +00:00 committed by PyTorch MergeBot
parent 6bbe8dbd63
commit 2b848ab192
2 changed files with 10 additions and 0 deletions

View File

@ -99,6 +99,8 @@ class MPSBasicTests(TestCase):
"modified_bessel_i1",
"modified_bessel_k0",
"modified_bessel_k1",
"scaled_modified_bessel_k0",
"scaled_modified_bessel_k1",
"entr",
]

View File

@ -418,6 +418,14 @@ class MetalOverrides(OpOverrides):
def modified_bessel_k1(x: CSEVariable) -> str:
return f"c10::metal::modified_bessel_k1_forward({x})"
@staticmethod
def scaled_modified_bessel_k0(x: CSEVariable) -> str:
return f"c10::metal::scaled_modified_bessel_k0_forward({x})"
@staticmethod
def scaled_modified_bessel_k1(x: CSEVariable) -> str:
return f"c10::metal::scaled_modified_bessel_k1_forward({x})"
MetalOverrides._initialize_pointwise_overrides("mps")