mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[mps/inductor] Add support for round() (#144731)
With this change, inductor/test_view_on_aliased passes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144731 Approved by: https://github.com/malfet
This commit is contained in:
parent
17e05cde0c
commit
35b46a75f1
|
|
@ -64,6 +64,7 @@ class MPSBasicTests(TestCase):
|
|||
test_slice_scatter4 = CommonTemplate.test_slice_scatter4
|
||||
test_tanh = CommonTemplate.test_tanh
|
||||
test_view_as_complex = CommonTemplate.test_view_as_complex
|
||||
test_view_on_aliased = CommonTemplate.test_view_on_aliased
|
||||
test_views3 = CommonTemplate.test_views3
|
||||
test_views6 = CommonTemplate.test_views6
|
||||
test_views7 = CommonTemplate.test_views7
|
||||
|
|
|
|||
|
|
@ -239,6 +239,10 @@ class MetalOverrides(OpOverrides):
|
|||
def ceil(x: CSEVariable) -> str:
|
||||
return f"metal::ceil({x})"
|
||||
|
||||
@staticmethod
|
||||
def round(x: CSEVariable) -> str:
|
||||
return f"metal::round({x})"
|
||||
|
||||
|
||||
class MetalKernel(SIMDKernel):
|
||||
overrides = MetalOverrides # type: ignore[assignment]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user