Revert "[inductor] Remove usage of device_interface from _inductor.runtime (#124592)"
This reverts commit 5d45eb77f1.
Reverted https://github.com/pytorch/pytorch/pull/124592 on behalf of https://github.com/jeanschmidt due to breaking internal tests, check D56522594 ([comment](https://github.com/pytorch/pytorch/pull/124592#issuecomment-2076957668))
This commit is contained in:
parent 58806d6531
commit f6ce94dca5
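
For orientation: the hunks below undo #124592 by switching triton_meta back from a DeviceProperties payload to a plain device index plus a device_type string. A minimal, self-contained sketch of the two shapes (illustrative values only, not Inductor code itself):

    # Convention restored by this revert: the autotuner receives a raw index and a
    # device_type string and resolves a DeviceInterface from the type at runtime.
    triton_meta_restored = {
        "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
        "device": 0,
        "device_type": "cuda",
        "constants": {},
    }

    # Convention being reverted (#124592): "device" carried a DeviceProperties
    # named tuple (type, index, cc, ...) so compile workers never import torch.
    triton_meta_reverted = {
        "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
        "device": ("cuda", 0, 80),  # illustrative stand-in for DeviceProperties
        "constants": {},
    }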
@@ -14,7 +14,6 @@ from torch._dynamo.testing import rand_strided
 from torch._dynamo.utils import same
 from torch._inductor import config
 from torch._inductor.compile_fx import compile_fx_inner
-from torch._inductor.runtime.hints import DeviceProperties
 from torch._inductor.utils import run_and_get_code
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing import FileCheck
@@ -406,7 +405,7 @@ class CudaReproTests(TestCase):
             ],
             meta={
                 "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
-                "device": DeviceProperties.create(torch.device("cuda")),
+                "device": 0,
                 "configs": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())],
                 "constants": {},
             },
@@ -45,12 +45,16 @@ from typing import (
     Optional,
     Set,
     Tuple,
     Type,
     TYPE_CHECKING,
     Union,
 )
+
 import torch
-from torch._dynamo.device_interface import get_registered_device_interfaces
+from torch._dynamo.device_interface import (
+    get_interface_for_device,
+    get_registered_device_interfaces,
+)
 from torch._dynamo.utils import counters, dynamo_timed
 from torch._inductor import config, exc, metrics
 from torch._inductor.codegen.cuda import cuda_env
@@ -66,6 +70,7 @@ from torch._subclasses.fake_tensor import (
 from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv
 
 if TYPE_CHECKING:
+    from torch._dynamo.device_interface import DeviceInterface
     from torch._inductor.graph import GraphLowering
     from torch._inductor.ir import ChoiceCaller
 
@@ -2766,9 +2771,14 @@ def _set_triton_ptxas_path() -> None:
 
 def _worker_compile_triton(
     load_kernel: Callable[[], Any],
+    cc: int,
+    device: torch.device,
+    device_interface: Type[DeviceInterface],
 ):
     _set_triton_ptxas_path()
-    load_kernel().precompile(warm_cache_only=True)
+    device_interface.Worker.set_device(device.index)
+    kernel = load_kernel()
+    kernel.precompile(warm_cache_only_with_cc=cc)
 
 
 class CodeCacheFuture:
@@ -2931,13 +2941,17 @@ class AsyncCompile:
 
         kernel = TritonCodeCache.load(kernel_name, source_code)
         if config.compile_threads > 1:
-            return TritonFuture(
-                kernel,
-                self.process_pool().submit(
-                    _worker_compile_triton,
-                    kernel._reload_in_subproc,
-                ),
-            )
+            device_interface = get_interface_for_device(device_str)
+            device = torch.device(device_str, device_interface.current_device())
+            cc = device_interface.get_compute_capability(device)
+            future = self.process_pool().submit(
+                _worker_compile_triton,
+                kernel._reload_in_subproc,
+                cc,
+                device,
+                device_interface,
+            )
+            return TritonFuture(kernel, future)
         else:
             kernel.precompile()
             return kernel
@@ -34,7 +34,7 @@ import torch.utils._pytree as pytree
 from torch._dynamo.utils import preserve_rng_state
 
 from torch._inductor.metrics import is_metric_table_enabled, log_kernel_metadata
-from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
+from torch._inductor.runtime.hints import AutotuneHint
 from torch._prims_common import is_integer_dtype
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 from torch.utils._sympy.value_ranges import ValueRanges
@@ -125,7 +125,7 @@ def gen_common_triton_imports():
         """
         from torch._inductor.runtime import triton_helpers, triton_heuristics
         from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
-        from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
+        from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor
         """
     )
     return imports.getvalue()
@@ -2833,7 +2833,8 @@ class TritonKernel(Kernel):
         )
         triton_meta = {
             "signature": triton_meta_signature,
-            "device": DeviceProperties.create(V.graph.scheduler.current_device),
+            "device": V.graph.scheduler.current_device.index,
+            "device_type": V.graph.scheduler.current_device.type,
             "constants": {},
         }
 
@@ -6,7 +6,6 @@ from typing import Dict, List, Tuple
 from sympy import Integer
 
 from .. import metrics
-from ..runtime.hints import DeviceProperties
 from ..scheduler import SchedulerNode
 from ..utils import ceildiv, Placeholder
 from ..virtualized import V
@@ -158,7 +157,8 @@ class ForeachKernel(Kernel):
         _, _, signature = self.args.python_argdefs()
         triton_meta = {
             "signature": signature_to_meta(signature, size_dtype=size_dtype),
-            "device": DeviceProperties.create(V.graph.scheduler.current_device),
+            "device": V.graph.scheduler.current_device.index,
+            "device_type": V.graph.scheduler.current_device.type,
             "constants": {},
         }
         triton_meta["configs"] = [config_of(signature)]
@@ -35,7 +35,6 @@ from torch.utils._sympy.singleton_int import SingletonInt
 from .. import codecache, config, ir
 from ..ir import ReinterpretView
 from ..runtime import triton_heuristics
-from ..runtime.hints import DeviceProperties
 from ..utils import (
     cache_on_self,
     get_benchmark_name,
@@ -1107,7 +1106,8 @@ class WrapperCodeGen(CodeGen):
                 size_dtype=index_dtype,
                 indices=non_constant_indices,
             ),
-            "device": DeviceProperties.create(V.graph.scheduler.current_device),
+            "device": V.graph.scheduler.current_device.index,
+            "device_type": V.graph.scheduler.current_device.type,
             # Triton compiler includes equal_to_1 args into constants even
             # when they are not constexpr. otherwise there may be a segfault
             # during launching the Inductor-compiled Triton kernel.
@@ -1,8 +1,6 @@
 import collections
-import typing
 from dataclasses import fields
 from enum import auto, Enum
-from typing import Optional
 
 
 # NOTE: if these fail asserts submit a PR to increase them
@@ -91,39 +89,3 @@ class AutotuneHint(Enum):
     # which isn't valid python.
     # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
     __repr__ = Enum.__str__
-
-
-class DeviceProperties(typing.NamedTuple):
-    """Copy device properties into a data structure not requiring torch to be imported"""
-
-    type: str  # type: ignore[assignment]
-    index: int  # type: ignore[assignment]
-    cc: int
-    major: Optional[int] = None
-    regs_per_multiprocessor: Optional[int] = None
-    max_threads_per_multi_processor: Optional[int] = None
-    multi_processor_count: Optional[int] = None
-
-    @classmethod
-    def create(cls, device):
-        import torch
-        from torch._dynamo.device_interface import get_interface_for_device
-
-        device_type = device.type if torch.version.hip is None else "hip"
-        device_interface = get_interface_for_device(device)
-        if device_type == "cuda":
-            props = device_interface.get_device_properties(device)
-            return cls(
-                type=device_type,
-                index=device.index,
-                cc=device_interface.get_compute_capability(device),
-                major=props.major,
-                regs_per_multiprocessor=props.regs_per_multiprocessor,
-                max_threads_per_multi_processor=props.max_threads_per_multi_processor,
-                multi_processor_count=props.multi_processor_count,
-            )
-        return cls(
-            type=device_type,
-            index=device.index,
-            cc=device_interface.get_compute_capability(device),
-        )
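The DeviceProperties named tuple removed above was added by #124592 so that the properties needed during autotuning could be captured once in the parent process and then read in compile workers without importing torch. A minimal sketch of how it was built and consumed before this revert (assumes a CUDA-enabled build; after the revert the import no longer exists):

    import torch
    from torch._inductor.runtime.hints import DeviceProperties  # pre-revert only

    # Parent process: importing torch here is fine, so capture everything once.
    props = DeviceProperties.create(torch.device("cuda", 0))

    # Worker side: only plain ints/strings are read, no torch import required.
    print(props.type, props.index, props.cc, props.multi_processor_count)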
@@ -16,12 +16,12 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple
 
 import torch
 
+from torch._dynamo.device_interface import DeviceGuard, get_interface_for_device
 from .coordinate_descent_tuner import CoordescTuner
 
 from .hints import (
     _NUM_THREADS_PER_WARP,
     AutotuneHint,
-    DeviceProperties,
     HeuristicType,
     ReductionHint,
     TileHint,
@@ -144,12 +144,7 @@ class CachingAutotuner(KernelInterface):
 
         assert len(configs) > 0, "Non-empty TritonConfig list required for compiling"
         self.fn = fn
-        self.device_props: DeviceProperties = triton_meta["device"]
-        self.triton_meta = {
-            **triton_meta,
-            "device": self.device_props.index,
-            "device_type": self.device_props.type,
-        }
+        self.triton_meta = triton_meta
         self.inductor_meta = {} if inductor_meta is None else inductor_meta
         self.save_cache_hook = save_cache_hook
         self.mutated_arg_names = mutated_arg_names
@@ -157,6 +152,13 @@ class CachingAutotuner(KernelInterface):
         self.heuristic_type = heuristic_type
         self.custom_kernel = custom_kernel
         self.cuda_kernel_saved = False
+
+        # Align the default design that default as cuda
+        self.device_type = (
+            triton_meta["device_type"] if "device_type" in triton_meta else "cuda"
+        )
+        self.device_interface = get_interface_for_device(self.device_type)
+
         if log.isEnabledFor(logging.DEBUG):
             log.debug(
                 "CachingAutotuner gets %d configs for %s",
@@ -184,7 +186,7 @@
         )
         self.filename = filename
 
-    def precompile(self, warm_cache_only=False):
+    def precompile(self, warm_cache_only_with_cc=None):
         with self.lock:
             if self.launchers:
                 return
@@ -196,7 +198,7 @@
             for c in self.configs:
                 try:
                     compiled_binary, launcher = self._precompile_config(
-                        c, warm_cache_only
+                        c, warm_cache_only_with_cc
                     )
                 except OutOfResources as e:
                     if len(self.configs) == 1:
@@ -217,19 +219,19 @@
 
         seen_configs = set(self.configs)
 
-        device_prop = self.device_props
+        device_prop = self.device_interface.Worker.get_device_properties(
+            self.triton_meta["device"]
+        )
         if (
             self.inductor_meta.get("dynamic_scale_rblock", True)
             and self.heuristic_type == HeuristicType.REDUCTION
             and self.size_hints is not None
-            # Disable for AMDGPU/Intel as Triton is not ready to return n_regs for a compiled_binary.
-            and device_prop.type == "cuda"
-            and device_prop.major
+            # Disable for AMDGPU as Triton is not ready to return n_regs for a compiled_binary.
+            and not self.inductor_meta.get("is_hip")
+            # Disable for Intel GPU as Triton is not ready to return n_regs for a compiled_binary.
+            and self.device_type != "xpu"
             and device_prop.major >= 8
         ):
-            assert device_prop.regs_per_multiprocessor
-            assert device_prop.max_threads_per_multi_processor
-            assert device_prop.multi_processor_count
             for triton_config, compiled_binary in zip(
                 self.configs, compiled_binaries
             ):
@@ -290,21 +292,15 @@
                     continue
                 seen_configs.add(new_config)
                 self.launchers.append(
-                    self._precompile_config(new_config, warm_cache_only)[1]
+                    self._precompile_config(new_config, warm_cache_only_with_cc)[1]
                 )
         self.configs = None
 
-    def get_device_interface(self):
-        # this code cannot run in compile workers, because it imports from torch
-        from torch._dynamo.device_interface import get_interface_for_device
-
-        return get_interface_for_device(self.device_props.type.replace("hip", "cuda"))
-
-    def _precompile_config(self, cfg: Config, warm_cache_only: bool):
+    def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]):
         """Ahead of time compile a given autotuner config."""
         compile_meta = copy.deepcopy(self.triton_meta)
         for k, v in cfg.kwargs.items():
-            if self.device_props.type != "hip":
+            if torch.version.hip is not None:
                 if k == "matrix_instr_nonkdim":
                     compile_meta["matrix_instr_nonkdim"] = v
                     continue
@@ -318,9 +314,22 @@
             "assert_indirect_indexing", True
         ) and not self.inductor_meta.get("is_hip", False)
 
-        # device type will be "hip" rather than "cuda" here
-        compile_meta["device_type"] = self.device_props.type
-        compile_meta["cc"] = self.device_props.cc
+        # Setting device_type="hip" required on ROCm to pass down to triton
+        compile_meta["device_type"] = (
+            self.device_type if torch.version.hip is None else "hip"
+        )
+
+        if warm_cache_only_with_cc:
+            cc = warm_cache_only_with_cc
+        else:
+            # Use device_type 'cuda' for both cuda and hip devices to retrieve
+            # the compute capability.
+            device_type = self.device_type if torch.version.hip is None else "cuda"
+            device_id = compile_meta["device"]
+            device = torch.device(device_type, device_id)
+            cc = self.device_interface.get_compute_capability(device)
+
+        compile_meta["cc"] = cc
 
         if ASTSource:
             compile_args = (
@@ -332,13 +341,13 @@
                 ),
             )
 
-            target = (compile_meta["device_type"], compile_meta["cc"])
+            target = (compile_meta["device_type"], cc)
             options = {
                 "num_warps": compile_meta["num_warps"],
                 "num_stages": compile_meta["num_stages"],
                 "debug": compile_meta["debug"],
             }
-            if self.device_props.type != "hip":
+            if torch.version.hip is not None:
                 if "waves_per_eu" in compile_meta:
                     options["waves_per_eu"] = compile_meta["waves_per_eu"]
                 if "matrix_instr_nonkdim" in compile_meta:
@@ -353,21 +362,16 @@
             compile_args = (self.fn,)
             compile_kwargs = compile_meta
 
-        if warm_cache_only:
+        if warm_cache_only_with_cc:
             return (
                 triton.compile(*compile_args, **compile_kwargs),
                 None,
             )
 
-        # importing from torch is safe now that precompile has returned
-        from torch._dynamo.device_interface import DeviceGuard
-
-        device_interface = self.get_device_interface()
-
         # load binary to the correct device
-        with DeviceGuard(device_interface, compile_meta["device"]):  # type: ignore[attr-defined]
+        with DeviceGuard(self.device_interface, compile_meta["device"]):  # type: ignore[attr-defined]
             # need to initialize context
-            device_interface.synchronize(device_interface.current_device())
+            self.device_interface.synchronize(self.device_interface.current_device())
 
             try:
                 binary = triton.compile(*compile_args, **compile_kwargs)
@@ -585,9 +589,8 @@
             )
             return float("inf")
 
-        device_interface = self.get_device_interface()
-        stream = device_interface.get_raw_stream(  # type: ignore[call-arg]
-            device_interface.current_device()
+        stream = self.device_interface.get_raw_stream(  # type: ignore[call-arg]
+            self.device_interface.current_device()
         )
 
         def kernel_call():
@@ -694,7 +697,7 @@
 
         from torch._inductor.codecache import CudaKernelParamCache
 
-        if self.device_props.type != "hip":
+        if torch.version.hip is None:
             CudaKernelParamCache.set(key, params, launcher.bin.asm["cubin"])
         else:
             # There is some divergence between CUDA and ROCm here.
@@ -732,7 +735,7 @@
 
         def benchmark_one_config(config):
             with self.lock:
-                _, launcher = self._precompile_config(config, False)
+                _, launcher = self._precompile_config(config, None)
                 config2launcher[config] = launcher
 
                 out = self.bench(launcher, *cloned_args, **kwargs)
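Read together, the CachingAutotuner and AsyncCompile hunks above restore the flow in which the parent process resolves the device interface and compute capability up front and ships them to the subprocess worker, which then calls precompile(warm_cache_only_with_cc=cc). A condensed, hedged sketch of that restored submission path (the helper name and wiring are illustrative; caching and error handling are omitted):

    import torch
    from torch._dynamo.device_interface import get_interface_for_device

    def submit_triton_precompile(process_pool, worker_fn, kernel, device_str="cuda"):
        # Parent: resolve interface, current device and compute capability up front.
        device_interface = get_interface_for_device(device_str)
        device = torch.device(device_str, device_interface.current_device())
        cc = device_interface.get_compute_capability(device)
        # worker_fn stands in for _worker_compile_triton from the hunk above: it
        # pins the worker to `device` and runs precompile(warm_cache_only_with_cc=cc).
        return process_pool.submit(
            worker_fn, kernel._reload_in_subproc, cc, device, device_interface
        )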
@@ -35,7 +35,6 @@ from .codegen.triton import (
 from .codegen.triton_utils import config_of, signature_to_meta
 from .exc import CUDACompileError
 from .ir import ChoiceCaller, PrimitiveInfoType
-from .runtime.hints import DeviceProperties
 from .runtime.runtime_utils import do_bench
 from .utils import get_dtype_size, Placeholder, sympy_dot, sympy_product, unique
 from .virtualized import V
@@ -148,7 +147,8 @@ class TritonTemplateKernel(TritonKernel):
         argdefs, _, signature = self.args.python_argdefs()
         triton_meta = {
             "signature": signature_to_meta(signature, size_dtype=self.index_dtype),
-            "device": DeviceProperties.create(self.output_node.get_device()),
+            "device": self.output_node.get_device().index,
+            "device_type": self.output_node.get_device().type,
             "constants": {},
         }
         triton_meta["configs"] = [config_of(signature)]