mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Workaround for mtia double init issue in has_triton (#162974)
Summary: This change adds a new environment variable (`TORCHINDUCTOR_TRITON_DISABLE_DEVICE_DETECTION`) and configuration in `torch._inductor.config` which can be set to `"1"` to allow a user to disable triton's device detection logic in [torch/utils/_triton.py:has_triton()](c9e57d7e9f/torch/utils/_triton.py (L128)). This function is used at import scope in several places but the function has a side effect of initializing the mtia device if it is available which is causing some of our autotuning workflows to crash. Worth noting that when enabled this configuration disables all device detection not just mtia and this is because the logic in has_triton will initialize the mtia device as a side effect even when checking for a cuda or other device via the [get_interface_for_device()](c9e57d7e9f/torch/_dynamo/device_interface.py (L570)) function. I've tagged it `topic: not user facing` since I don't anticipate any outside of meta users making use of this, however this is my first PR here, so please indicate if it should be handled differently. Test Plan: This has been tested in the context of internal workflows. Differential Revision: D82347853 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162974 Approved by: https://github.com/xmfan
This commit is contained in:
parent
2c45628813
commit
b68a5115a4
|
|
@ -467,6 +467,11 @@ max_autotune_prune_choices_based_on_shared_mem = (
|
|||
== "1"
|
||||
)
|
||||
|
||||
# Disable triton from trying to initialize and detect devices on the host
|
||||
triton_disable_device_detection = (
|
||||
os.environ.get("TORCHINDUCTOR_TRITON_DISABLE_DEVICE_DETECTION", "0") == "1"
|
||||
)
|
||||
|
||||
# enable inductor graph partition to allow multiple inductor graphs for the same dynamo graph
|
||||
graph_partition: bool = (
|
||||
os.environ.get("TORCHINDUCTOR_GRAPH_PARTITION", "1" if not is_fbcode() else "0")
|
||||
|
|
|
|||
|
|
@ -144,6 +144,11 @@ def has_triton() -> bool:
|
|||
if not has_triton_package():
|
||||
return False
|
||||
|
||||
from torch._inductor.config import triton_disable_device_detection
|
||||
|
||||
if triton_disable_device_detection:
|
||||
return False
|
||||
|
||||
from torch._dynamo.device_interface import get_interface_for_device
|
||||
|
||||
def cuda_extra_check(device_interface: Any) -> bool:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user