Workaround for mtia double init issue in has_triton (#162974)

This change adds a new environment variable (`TORCHINDUCTOR_TRITON_DISABLE_DEVICE_DETECTION`) and configuration in `torch._inductor.config` which can be set to `"1"` to allow a user to disable triton's device detection logic in [torch/utils/_triton.py:has_triton()](c9e57d7e9f/torch/utils/_triton.py (L128)). This function is used at import scope in several places but the function has a side effect of initializing the mtia device if it is available which is causing some of our autotuning workflows to crash.
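
For example, since `has_triton()` is called at import scope, the variable has to be set before the relevant imports run (a minimal sketch based on the behavior added in this PR):

```python
import os

# Must be set before importing torch (or any module that calls
# has_triton() at import scope), so device detection never runs.
os.environ["TORCHINDUCTOR_TRITON_DISABLE_DEVICE_DETECTION"] = "1"

from torch.utils._triton import has_triton

# With detection disabled, has_triton() short-circuits to False.
assert not has_triton()
```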

Worth noting: when enabled, this configuration disables all device detection, not just mtia. This is because the logic in has_triton will initialize the mtia device as a side effect even when checking for a cuda or other device via the [get_interface_for_device()](c9e57d7e9f/torch/_dynamo/device_interface.py (L570)) function.
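
For illustration, this is the call pattern that triggers the side effect (a sketch only; `get_interface_for_device()` is an internal API and its exact initialization behavior may change):

```python
from torch._dynamo.device_interface import get_interface_for_device

# Even a CUDA-only query goes through the shared device-interface
# registry, which is where the mtia initialization side effect occurs.
cuda_interface = get_interface_for_device("cuda")
print(cuda_interface.is_available())
```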

I've tagged it `topic: not user facing` since I don't anticipate users outside of Meta making use of this; however, this is my first PR here, so please indicate if it should be handled differently.

Test Plan: This has been tested in the context of internal workflows.

Differential Revision: D82347853

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162974
Approved by: https://github.com/xmfan
Author: Scott Rostrup, 2025-09-16 04:46:07 +00:00 (committed by PyTorch MergeBot)
parent 2c45628813
commit b68a5115a4
2 changed files with 10 additions and 0 deletions

torch/_inductor/config.py

@@ -467,6 +467,11 @@ max_autotune_prune_choices_based_on_shared_mem = (
     == "1"
 )
 
+# Disable triton from trying to initialize and detect devices on the host
+triton_disable_device_detection = (
+    os.environ.get("TORCHINDUCTOR_TRITON_DISABLE_DEVICE_DETECTION", "0") == "1"
+)
+
 # enable inductor graph partition to allow multiple inductor graphs for the same dynamo graph
 graph_partition: bool = (
     os.environ.get("TORCHINDUCTOR_GRAPH_PARTITION", "1" if not is_fbcode() else "0")

torch/utils/_triton.py

@@ -144,6 +144,11 @@ def has_triton() -> bool:
     if not has_triton_package():
         return False
 
+    from torch._inductor.config import triton_disable_device_detection
+
+    if triton_disable_device_detection:
+        return False
+
     from torch._dynamo.device_interface import get_interface_for_device
 
     def cuda_extra_check(device_interface: Any) -> bool:
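
Because `has_triton()` re-imports the flag from `torch._inductor.config` on each call, it should also be possible to flip it at runtime for callers that invoke `has_triton()` lazily (a sketch, not from the PR itself; import-scope callers will already have run with detection enabled):

```python
import torch._inductor.config as inductor_config
from torch.utils._triton import has_triton

# Runtime equivalent of TORCHINDUCTOR_TRITON_DISABLE_DEVICE_DETECTION=1.
inductor_config.triton_disable_device_detection = True

# has_triton() now early-returns before touching any device interface.
assert not has_triton()
```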