Revert "[inductor] Remove usage of device_interface from _inductor.runtime (#124592)"

This reverts commit 5d45eb77f1.

Reverted https://github.com/pytorch/pytorch/pull/124592 on behalf of https://github.com/jeanschmidt due to breaking internal tests, check D56522594 ([comment](https://github.com/pytorch/pytorch/pull/124592#issuecomment-2076957668))
PyTorch MergeBot 2024-04-25 11:28:21 +00:00
parent 58806d6531
commit f6ce94dca5
8 changed files with 79 additions and 100 deletions
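
In substance, the revert changes `triton_meta["device"]` back from a `DeviceProperties` named tuple to a bare device index plus a separate `"device_type"` key, and restores the `device_interface` lookups that #124592 had removed from `_inductor.runtime`. A minimal sketch of the two metadata shapes, pieced together from the hunks below (illustrative only, not code taken from the commit):

```python
import torch

# Shape introduced by #124592 (removed by this revert): a single
# NamedTuple snapshots device type, index, and compute capability.
#   from torch._inductor.runtime.hints import DeviceProperties
#   meta = {"device": DeviceProperties.create(torch.device("cuda", 0))}

# Shape restored by this revert: a plain index plus a "device_type"
# string, resolved lazily through torch._dynamo.device_interface.
meta = {
    "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
    "device": 0,            # bare device index
    "device_type": "cuda",  # consumed via get_interface_for_device()
    "constants": {},
}
```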

View File

@@ -14,7 +14,6 @@ from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.compile_fx import compile_fx_inner
-from torch._inductor.runtime.hints import DeviceProperties
from torch._inductor.utils import run_and_get_code
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
@@ -406,7 +405,7 @@ class CudaReproTests(TestCase):
],
meta={
"signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
"device": DeviceProperties.create(torch.device("cuda")),
"device": 0,
"configs": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())],
"constants": {},
},

View File

@@ -45,12 +45,16 @@ from typing import (
Optional,
Set,
Tuple,
+Type,
TYPE_CHECKING,
Union,
)
import torch
-from torch._dynamo.device_interface import get_registered_device_interfaces
+from torch._dynamo.device_interface import (
+get_interface_for_device,
+get_registered_device_interfaces,
+)
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor import config, exc, metrics
from torch._inductor.codegen.cuda import cuda_env
@@ -66,6 +70,7 @@ from torch._subclasses.fake_tensor import (
from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv
if TYPE_CHECKING:
+from torch._dynamo.device_interface import DeviceInterface
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import ChoiceCaller
@@ -2766,9 +2771,14 @@ def _set_triton_ptxas_path() -> None:
def _worker_compile_triton(
load_kernel: Callable[[], Any],
+cc: int,
+device: torch.device,
+device_interface: Type[DeviceInterface],
):
_set_triton_ptxas_path()
-load_kernel().precompile(warm_cache_only=True)
+device_interface.Worker.set_device(device.index)
+kernel = load_kernel()
+kernel.precompile(warm_cache_only_with_cc=cc)
class CodeCacheFuture:
@@ -2931,13 +2941,17 @@ class AsyncCompile:
kernel = TritonCodeCache.load(kernel_name, source_code)
if config.compile_threads > 1:
-return TritonFuture(
-kernel,
-self.process_pool().submit(
-_worker_compile_triton,
-kernel._reload_in_subproc,
-),
-)
+device_interface = get_interface_for_device(device_str)
+device = torch.device(device_str, device_interface.current_device())
+cc = device_interface.get_compute_capability(device)
+future = self.process_pool().submit(
+_worker_compile_triton,
+kernel._reload_in_subproc,
+cc,
+device,
+device_interface,
+)
+return TritonFuture(kernel, future)
else:
kernel.precompile()
return kernel

View File

@@ -34,7 +34,7 @@ import torch.utils._pytree as pytree
from torch._dynamo.utils import preserve_rng_state
from torch._inductor.metrics import is_metric_table_enabled, log_kernel_metadata
-from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
+from torch._inductor.runtime.hints import AutotuneHint
from torch._prims_common import is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.value_ranges import ValueRanges
@@ -125,7 +125,7 @@ def gen_common_triton_imports():
"""
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
-from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor
"""
)
return imports.getvalue()
@@ -2833,7 +2833,8 @@ class TritonKernel(Kernel):
)
triton_meta = {
"signature": triton_meta_signature,
"device": DeviceProperties.create(V.graph.scheduler.current_device),
"device": V.graph.scheduler.current_device.index,
"device_type": V.graph.scheduler.current_device.type,
"constants": {},
}

View File

@@ -6,7 +6,6 @@ from typing import Dict, List, Tuple
from sympy import Integer
from .. import metrics
-from ..runtime.hints import DeviceProperties
from ..scheduler import SchedulerNode
from ..utils import ceildiv, Placeholder
from ..virtualized import V
@@ -158,7 +157,8 @@ class ForeachKernel(Kernel):
_, _, signature = self.args.python_argdefs()
triton_meta = {
"signature": signature_to_meta(signature, size_dtype=size_dtype),
"device": DeviceProperties.create(V.graph.scheduler.current_device),
"device": V.graph.scheduler.current_device.index,
"device_type": V.graph.scheduler.current_device.type,
"constants": {},
}
triton_meta["configs"] = [config_of(signature)]

View File

@@ -35,7 +35,6 @@ from torch.utils._sympy.singleton_int import SingletonInt
from .. import codecache, config, ir
from ..ir import ReinterpretView
from ..runtime import triton_heuristics
-from ..runtime.hints import DeviceProperties
from ..utils import (
cache_on_self,
get_benchmark_name,
@@ -1107,7 +1106,8 @@ class WrapperCodeGen(CodeGen):
size_dtype=index_dtype,
indices=non_constant_indices,
),
"device": DeviceProperties.create(V.graph.scheduler.current_device),
"device": V.graph.scheduler.current_device.index,
"device_type": V.graph.scheduler.current_device.type,
# Triton compiler includes equal_to_1 args into constants even
# when they are not constexpr. otherwise there may be a segfault
# during launching the Inductor-compiled Triton kernel.

View File

@@ -1,8 +1,6 @@
import collections
-import typing
from dataclasses import fields
from enum import auto, Enum
-from typing import Optional
# NOTE: if these fail asserts submit a PR to increase them
@@ -91,39 +89,3 @@ class AutotuneHint(Enum):
# which isn't valid python.
# Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
__repr__ = Enum.__str__
-class DeviceProperties(typing.NamedTuple):
-"""Copy device properties into a data structure not requiring torch to be imported"""
-type: str # type: ignore[assignment]
-index: int # type: ignore[assignment]
-cc: int
-major: Optional[int] = None
-regs_per_multiprocessor: Optional[int] = None
-max_threads_per_multi_processor: Optional[int] = None
-multi_processor_count: Optional[int] = None
-@classmethod
-def create(cls, device):
-import torch
-from torch._dynamo.device_interface import get_interface_for_device
-device_type = device.type if torch.version.hip is None else "hip"
-device_interface = get_interface_for_device(device)
-if device_type == "cuda":
-props = device_interface.get_device_properties(device)
-return cls(
-type=device_type,
-index=device.index,
-cc=device_interface.get_compute_capability(device),
-major=props.major,
-regs_per_multiprocessor=props.regs_per_multiprocessor,
-max_threads_per_multi_processor=props.max_threads_per_multi_processor,
-multi_processor_count=props.multi_processor_count,
-)
-return cls(
-type=device_type,
-index=device.index,
-cc=device_interface.get_compute_capability(device),
-)
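
With the `DeviceProperties` tuple gone, the restored code paths query device properties on demand through the device interface instead, as the `AsyncCompile` and autotuner hunks show. A minimal sketch of that lookup pattern, assuming a CUDA device is available:

```python
import torch
from torch._dynamo.device_interface import get_interface_for_device

# Resolve the interface class for a device type, then query properties
# lazily rather than snapshotting them into a NamedTuple up front.
device_interface = get_interface_for_device("cuda")
device = torch.device("cuda", device_interface.current_device())
cc = device_interface.get_compute_capability(device)
props = device_interface.Worker.get_device_properties(device.index)
print(cc, props.major, props.multi_processor_count)
```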

View File

@@ -16,12 +16,12 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple
import torch
+from torch._dynamo.device_interface import DeviceGuard, get_interface_for_device
from .coordinate_descent_tuner import CoordescTuner
from .hints import (
_NUM_THREADS_PER_WARP,
AutotuneHint,
-DeviceProperties,
HeuristicType,
ReductionHint,
TileHint,
@@ -144,12 +144,7 @@ class CachingAutotuner(KernelInterface):
assert len(configs) > 0, "Non-empty TritonConfig list required for compiling"
self.fn = fn
-self.device_props: DeviceProperties = triton_meta["device"]
-self.triton_meta = {
-**triton_meta,
-"device": self.device_props.index,
-"device_type": self.device_props.type,
-}
+self.triton_meta = triton_meta
self.inductor_meta = {} if inductor_meta is None else inductor_meta
self.save_cache_hook = save_cache_hook
self.mutated_arg_names = mutated_arg_names
@@ -157,6 +152,13 @@ class CachingAutotuner(KernelInterface):
self.heuristic_type = heuristic_type
self.custom_kernel = custom_kernel
self.cuda_kernel_saved = False
+# Align the default design that default as cuda
+self.device_type = (
+triton_meta["device_type"] if "device_type" in triton_meta else "cuda"
+)
+self.device_interface = get_interface_for_device(self.device_type)
if log.isEnabledFor(logging.DEBUG):
log.debug(
"CachingAutotuner gets %d configs for %s",
@@ -184,7 +186,7 @@
)
self.filename = filename
-def precompile(self, warm_cache_only=False):
+def precompile(self, warm_cache_only_with_cc=None):
with self.lock:
if self.launchers:
return
@@ -196,7 +198,7 @@
for c in self.configs:
try:
compiled_binary, launcher = self._precompile_config(
-c, warm_cache_only
+c, warm_cache_only_with_cc
)
except OutOfResources as e:
if len(self.configs) == 1:
@@ -217,19 +219,19 @@
seen_configs = set(self.configs)
-device_prop = self.device_props
+device_prop = self.device_interface.Worker.get_device_properties(
+self.triton_meta["device"]
+)
if (
self.inductor_meta.get("dynamic_scale_rblock", True)
and self.heuristic_type == HeuristicType.REDUCTION
and self.size_hints is not None
-# Disable for AMDGPU/Intel as Triton is not ready to return n_regs for a compiled_binary.
-and device_prop.type == "cuda"
-and device_prop.major
+# Disable for AMDGPU as Triton is not ready to return n_regs for a compiled_binary.
+and not self.inductor_meta.get("is_hip")
+# Disable for Intel GPU as Triton is not ready to return n_regs for a compiled_binary.
+and self.device_type != "xpu"
and device_prop.major >= 8
):
-assert device_prop.regs_per_multiprocessor
-assert device_prop.max_threads_per_multi_processor
-assert device_prop.multi_processor_count
for triton_config, compiled_binary in zip(
self.configs, compiled_binaries
):
@@ -290,21 +292,15 @@
continue
seen_configs.add(new_config)
self.launchers.append(
-self._precompile_config(new_config, warm_cache_only)[1]
+self._precompile_config(new_config, warm_cache_only_with_cc)[1]
)
self.configs = None
-def get_device_interface(self):
-# this code cannot run in compile workers, because it imports from torch
-from torch._dynamo.device_interface import get_interface_for_device
-return get_interface_for_device(self.device_props.type.replace("hip", "cuda"))
-def _precompile_config(self, cfg: Config, warm_cache_only: bool):
+def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]):
"""Ahead of time compile a given autotuner config."""
compile_meta = copy.deepcopy(self.triton_meta)
for k, v in cfg.kwargs.items():
if self.device_props.type != "hip":
if torch.version.hip is not None:
if k == "matrix_instr_nonkdim":
compile_meta["matrix_instr_nonkdim"] = v
continue
@@ -318,9 +314,22 @@
"assert_indirect_indexing", True
) and not self.inductor_meta.get("is_hip", False)
-# device type will be "hip" rather than "cuda" here
-compile_meta["device_type"] = self.device_props.type
-compile_meta["cc"] = self.device_props.cc
+# Setting device_type="hip" required on ROCm to pass down to triton
+compile_meta["device_type"] = (
+self.device_type if torch.version.hip is None else "hip"
+)
+if warm_cache_only_with_cc:
+cc = warm_cache_only_with_cc
+else:
+# Use device_type 'cuda' for both cuda and hip devices to retrieve
+# the compute capability.
+device_type = self.device_type if torch.version.hip is None else "cuda"
+device_id = compile_meta["device"]
+device = torch.device(device_type, device_id)
+cc = self.device_interface.get_compute_capability(device)
+compile_meta["cc"] = cc
if ASTSource:
compile_args = (
@@ -332,13 +341,13 @@
),
)
-target = (compile_meta["device_type"], compile_meta["cc"])
+target = (compile_meta["device_type"], cc)
options = {
"num_warps": compile_meta["num_warps"],
"num_stages": compile_meta["num_stages"],
"debug": compile_meta["debug"],
}
-if self.device_props.type != "hip":
+if torch.version.hip is not None:
if "waves_per_eu" in compile_meta:
options["waves_per_eu"] = compile_meta["waves_per_eu"]
if "matrix_instr_nonkdim" in compile_meta:
@@ -353,21 +362,16 @@
compile_args = (self.fn,)
compile_kwargs = compile_meta
-if warm_cache_only:
+if warm_cache_only_with_cc:
return (
triton.compile(*compile_args, **compile_kwargs),
None,
)
-# importing from torch is safe now that precompile has returned
-from torch._dynamo.device_interface import DeviceGuard
-device_interface = self.get_device_interface()
# load binary to the correct device
-with DeviceGuard(device_interface, compile_meta["device"]): # type: ignore[attr-defined]
+with DeviceGuard(self.device_interface, compile_meta["device"]): # type: ignore[attr-defined]
# need to initialize context
-device_interface.synchronize(device_interface.current_device())
+self.device_interface.synchronize(self.device_interface.current_device())
try:
binary = triton.compile(*compile_args, **compile_kwargs)
@@ -585,9 +589,8 @@
)
return float("inf")
-device_interface = self.get_device_interface()
-stream = device_interface.get_raw_stream( # type: ignore[call-arg]
-device_interface.current_device()
+stream = self.device_interface.get_raw_stream( # type: ignore[call-arg]
+self.device_interface.current_device()
)
def kernel_call():
@@ -694,7 +697,7 @@
from torch._inductor.codecache import CudaKernelParamCache
-if self.device_props.type != "hip":
+if torch.version.hip is None:
CudaKernelParamCache.set(key, params, launcher.bin.asm["cubin"])
else:
# There is some divergence between CUDA and ROCm here.
@@ -732,7 +735,7 @@
def benchmark_one_config(config):
with self.lock:
-_, launcher = self._precompile_config(config, False)
+_, launcher = self._precompile_config(config, None)
config2launcher[config] = launcher
out = self.bench(launcher, *cloned_args, **kwargs)

View File

@@ -35,7 +35,6 @@ from .codegen.triton import (
from .codegen.triton_utils import config_of, signature_to_meta
from .exc import CUDACompileError
from .ir import ChoiceCaller, PrimitiveInfoType
-from .runtime.hints import DeviceProperties
from .runtime.runtime_utils import do_bench
from .utils import get_dtype_size, Placeholder, sympy_dot, sympy_product, unique
from .virtualized import V
@@ -148,7 +147,8 @@ class TritonTemplateKernel(TritonKernel):
argdefs, _, signature = self.args.python_argdefs()
triton_meta = {
"signature": signature_to_meta(signature, size_dtype=self.index_dtype),
"device": DeviceProperties.create(self.output_node.get_device()),
"device": self.output_node.get_device().index,
"device_type": self.output_node.get_device().type,
"constants": {},
}
triton_meta["configs"] = [config_of(signature)]