mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add justknobs for static cuda launcher (#153400)
Summary: This diff adds a JustKnobs check for the static CUDA launcher. In particular, it supports a fractional rollout in which each MAST job/version is consistently enrolled with the config either on or off. It also adds a set_feature_use call so we can track whether the static CUDA launcher is enabled on a given Dynamo compile. Test Plan: Existing unit tests. The JustKnobs in question are currently set to disabled, so this diff does not launch the feature yet. Differential Revision: D74599203 Pull Request resolved: https://github.com/pytorch/pytorch/pull/153400 Approved by: https://github.com/oulgen
This commit is contained in:
parent
20ba8fe7e6
commit
5ff2cb8587
|
|
@ -40,11 +40,18 @@ def bundle_triton_into_fx_graph_cache_default() -> Optional[bool]:
|
|||
|
||||
|
||||
def static_cuda_launcher_default() -> bool:
    """Return the default on/off setting for the static CUDA launcher.

    Resolution order:

    1. If the ``TORCHINDUCTOR_USE_STATIC_CUDA_LAUNCHER`` environment variable
       is set, it wins: the launcher is enabled only when the value is exactly
       ``"1"``; any other value (including the empty string) disables it.
    2. Otherwise, in fbcode, the integer JustKnob
       ``pytorch/inductor:static_cuda_launcher_version`` is read and the
       launcher is enabled when that value is less than or equal to the
       version hard-coded in this function.
    3. Otherwise (OSS builds) the launcher is enabled.

    Returns:
        ``True`` if the static CUDA launcher should be used by default.
    """
    # Version of the launcher shipped with this build; compared against the
    # JustKnob value so a rollout can be gated per version.
    STATIC_CUDA_LAUNCHER_VERSION = 0

    # A single lookup covers both "is it set?" and "what is it?".
    env_override = os.environ.get("TORCHINDUCTOR_USE_STATIC_CUDA_LAUNCHER")
    if env_override is not None:
        return env_override == "1"

    if is_fbcode():
        version = torch._utils_internal.justknobs_getval_int(
            "pytorch/inductor:static_cuda_launcher_version"
        )
        return version <= STATIC_CUDA_LAUNCHER_VERSION

    # Default true in OSS
    return True
|
||||
|
||||
|
||||
def prologue_fusion_enabled() -> bool:
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from typing import (
|
|||
)
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import set_feature_use
|
||||
from torch._prims_common import compute_required_storage_length
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
|
@ -1288,6 +1289,9 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
|
|||
self.kernel.cubin_path = cubin_location
|
||||
|
||||
def make_launcher(self) -> LauncherType:
|
||||
# If at least one static make_launcher call occurs,
|
||||
# we're sure static cuda launcher was used for this compile
|
||||
set_feature_use("static_cuda_launcher", True)
|
||||
# Load the binary on the parent
|
||||
if not self.kernel.cubin_path:
|
||||
self.reload_cubin_path()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user