[MPS/inductor] Add support for hermite_polynomial_h. (#150664)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150664
Approved by: https://github.com/malfet
This commit is contained in:
Davide Italiano 2025-04-04 13:14:52 +00:00 committed by PyTorch MergeBot
parent 09c4da9325
commit 295b7e21eb
2 changed files with 5 additions and 0 deletions

View File

@ -132,6 +132,7 @@ class MPSBasicTests(TestCase):
"chebyshev_polynomial_u", "chebyshev_polynomial_u",
"chebyshev_polynomial_v", "chebyshev_polynomial_v",
"chebyshev_polynomial_w", "chebyshev_polynomial_w",
"hermite_polynomial_h",
], ],
) )
def test_pointwise_binary_op(self, op_name): def test_pointwise_binary_op(self, op_name):

View File

@ -449,6 +449,10 @@ class MetalOverrides(OpOverrides):
def chebyshev_polynomial_w(x: CSEVariable, n: CSEVariable) -> str: def chebyshev_polynomial_w(x: CSEVariable, n: CSEVariable) -> str:
return f"c10::metal::chebyshev_polynomial_w_forward({x}, {n})" return f"c10::metal::chebyshev_polynomial_w_forward({x}, {n})"
@staticmethod
def hermite_polynomial_h(x: CSEVariable, n: CSEVariable) -> str:
return f"c10::metal::hermite_polynomial_h_forward({x}, {n})"
MetalOverrides._initialize_pointwise_overrides("mps") MetalOverrides._initialize_pointwise_overrides("mps")