mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS/inductor] Add support for hermite_polynomial_h. (#150664)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150664 Approved by: https://github.com/malfet
This commit is contained in:
parent
09c4da9325
commit
295b7e21eb
|
|
@ -132,6 +132,7 @@ class MPSBasicTests(TestCase):
|
||||||
"chebyshev_polynomial_u",
|
"chebyshev_polynomial_u",
|
||||||
"chebyshev_polynomial_v",
|
"chebyshev_polynomial_v",
|
||||||
"chebyshev_polynomial_w",
|
"chebyshev_polynomial_w",
|
||||||
|
"hermite_polynomial_h",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_pointwise_binary_op(self, op_name):
|
def test_pointwise_binary_op(self, op_name):
|
||||||
|
|
|
||||||
|
|
@ -449,6 +449,10 @@ class MetalOverrides(OpOverrides):
|
||||||
def chebyshev_polynomial_w(x: CSEVariable, n: CSEVariable) -> str:
|
def chebyshev_polynomial_w(x: CSEVariable, n: CSEVariable) -> str:
|
||||||
return f"c10::metal::chebyshev_polynomial_w_forward({x}, {n})"
|
return f"c10::metal::chebyshev_polynomial_w_forward({x}, {n})"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hermite_polynomial_h(x: CSEVariable, n: CSEVariable) -> str:
|
||||||
|
return f"c10::metal::hermite_polynomial_h_forward({x}, {n})"
|
||||||
|
|
||||||
|
|
||||||
MetalOverrides._initialize_pointwise_overrides("mps")
|
MetalOverrides._initialize_pointwise_overrides("mps")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user