diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index 187ca3e22d5..74b827be4dc 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -99,6 +99,8 @@ class MPSBasicTests(TestCase): "modified_bessel_i1", "modified_bessel_k0", "modified_bessel_k1", + "scaled_modified_bessel_k0", + "scaled_modified_bessel_k1", "entr", ] diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 50de653fa52..dc438b0f9e7 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -418,6 +418,14 @@ class MetalOverrides(OpOverrides): def modified_bessel_k1(x: CSEVariable) -> str: return f"c10::metal::modified_bessel_k1_forward({x})" + @staticmethod + def scaled_modified_bessel_k0(x: CSEVariable) -> str: + return f"c10::metal::scaled_modified_bessel_k0_forward({x})" + + @staticmethod + def scaled_modified_bessel_k1(x: CSEVariable) -> str: + return f"c10::metal::scaled_modified_bessel_k1_forward({x})" + MetalOverrides._initialize_pointwise_overrides("mps")