mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[CI] Fix GPUTests.test_scheduler_vertical_fusion1 (#151166)
By enabling the test_operators on MPS device Pull Request resolved: https://github.com/pytorch/pytorch/pull/151166 Approved by: https://github.com/dcci
This commit is contained in:
parent
9699cc3eb9
commit
d289d1177c
|
|
@ -223,6 +223,7 @@ for test_name in [
|
||||||
"test_rsqrt",
|
"test_rsqrt",
|
||||||
"test_scalar_cpu_tensor_arg",
|
"test_scalar_cpu_tensor_arg",
|
||||||
"test_scalar_output",
|
"test_scalar_output",
|
||||||
|
"test_scheduler_vertical_fusion1",
|
||||||
"test_setitem_with_int_parameter",
|
"test_setitem_with_int_parameter",
|
||||||
"test_signbit",
|
"test_signbit",
|
||||||
"test_silu",
|
"test_silu",
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ if not torch._running_with_deploy():
|
||||||
)
|
)
|
||||||
|
|
||||||
_test_lib_impl = torch.library.Library("_inductor_test", "IMPL")
|
_test_lib_impl = torch.library.Library("_inductor_test", "IMPL")
|
||||||
for dispatch_key in ("CPU", "CUDA", "Meta"):
|
for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"):
|
||||||
_test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key)
|
_test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key)
|
||||||
|
|
||||||
class Realize(Function):
|
class Realize(Function):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user