diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index f982279915a..841d515f425 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -94,14 +94,6 @@ from .schema import ( # type: ignore[attr-defined] from .union import _Union -if has_triton(): - from triton.runtime.autotuner import Autotuner -else: - - class Autotuner: # type: ignore[no-redef] - pass - - __all__ = [ "serialize", "GraphModuleSerializer", @@ -684,6 +676,7 @@ class GraphModuleSerializer(metaclass=Final): is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional ): assert has_triton(), "triton required to serialize triton kernels" + from triton.runtime.autotuner import Autotuner meta_val = node.meta["val"] assert isinstance(meta_val, dict)