diff --git a/torch/csrc/jit/passes/cuda_graph_fuser.h b/torch/csrc/jit/passes/cuda_graph_fuser.h index 104a437104a..4bdf83e2b91 100644 --- a/torch/csrc/jit/passes/cuda_graph_fuser.h +++ b/torch/csrc/jit/passes/cuda_graph_fuser.h @@ -12,11 +12,11 @@ namespace jit { struct C10_EXPORT RegisterCudaFuseGraph : public PassManager { static bool registerPass(bool enabled) { - TORCH_CHECK( - at::globalContext().hasCUDA() && !at::globalContext().hasHIP(), - "Running CUDA fuser is only supported on CUDA builds."); bool old_flag = PassManager::isRegistered(); if (enabled) { + TORCH_CHECK( + at::globalContext().hasCUDA() && !at::globalContext().hasHIP(), + "Running CUDA fuser is only supported on CUDA builds."); PassManager::registerPass(fuser::cuda::fuseGraph); } else { PassManager::clearPass();