[mps/inductor] Add support for round() (#144731)

With this change, inductor/test_view_on_aliased passes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144731
Approved by: https://github.com/malfet
This commit is contained in:
Davide Italiano 2025-01-14 05:56:13 +00:00 committed by PyTorch MergeBot
parent 17e05cde0c
commit 35b46a75f1
2 changed files with 5 additions and 0 deletions

View File

@ -64,6 +64,7 @@ class MPSBasicTests(TestCase):
test_slice_scatter4 = CommonTemplate.test_slice_scatter4
test_tanh = CommonTemplate.test_tanh
test_view_as_complex = CommonTemplate.test_view_as_complex
test_view_on_aliased = CommonTemplate.test_view_on_aliased
test_views3 = CommonTemplate.test_views3
test_views6 = CommonTemplate.test_views6
test_views7 = CommonTemplate.test_views7

View File

@ -239,6 +239,10 @@ class MetalOverrides(OpOverrides):
def ceil(x: CSEVariable) -> str:
return f"metal::ceil({x})"
@staticmethod
def round(x: CSEVariable) -> str:
return f"metal::round({x})"
class MetalKernel(SIMDKernel):
overrides = MetalOverrides # type: ignore[assignment]