mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[MPS/inductor] Add support for modified_scaled_bessel_k{0,1} (#149794)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149794 Approved by: https://github.com/malfet
This commit is contained in:
parent
6bbe8dbd63
commit
2b848ab192
|
|
@ -99,6 +99,8 @@ class MPSBasicTests(TestCase):
|
||||||
"modified_bessel_i1",
|
"modified_bessel_i1",
|
||||||
"modified_bessel_k0",
|
"modified_bessel_k0",
|
||||||
"modified_bessel_k1",
|
"modified_bessel_k1",
|
||||||
|
"scaled_modified_bessel_k0",
|
||||||
|
"scaled_modified_bessel_k1",
|
||||||
"entr",
|
"entr",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -418,6 +418,14 @@ class MetalOverrides(OpOverrides):
|
||||||
def modified_bessel_k1(x: CSEVariable) -> str:
|
def modified_bessel_k1(x: CSEVariable) -> str:
|
||||||
return f"c10::metal::modified_bessel_k1_forward({x})"
|
return f"c10::metal::modified_bessel_k1_forward({x})"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def scaled_modified_bessel_k0(x: CSEVariable) -> str:
|
||||||
|
return f"c10::metal::scaled_modified_bessel_k0_forward({x})"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def scaled_modified_bessel_k1(x: CSEVariable) -> str:
|
||||||
|
return f"c10::metal::scaled_modified_bessel_k1_forward({x})"
|
||||||
|
|
||||||
|
|
||||||
MetalOverrides._initialize_pointwise_overrides("mps")
|
MetalOverrides._initialize_pointwise_overrides("mps")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user