typing tvm.py (#160369)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160369
Approved by: https://github.com/Skylion007
ghstack dependencies: #160362, #160363, #160364, #160365, #160366, #160367, #160368
This commit is contained in:
Lucas Kabela 2025-08-14 16:03:35 -07:00 committed by PyTorch MergeBot
parent 39ca0ce0c8
commit 4d5f92aa39

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
This module provides TVM backend integration for TorchDynamo.
@ -29,9 +27,10 @@ import os
import sys
import tempfile
from types import MappingProxyType
from typing import Optional
from typing import Any, Callable, Optional
import torch
from torch import fx
from .common import device_from_inputs, fake_tensor_unsupported
from .registry import register_backend
@ -41,15 +40,16 @@ log = logging.getLogger(__name__)
@register_backend
@fake_tensor_unsupported
@fake_tensor_unsupported # type: ignore[arg-type]
def tvm(
gm,
example_inputs,
gm: fx.GraphModule,
example_inputs: list[torch.Tensor],
*,
options: Optional[MappingProxyType] = MappingProxyType(
{"scheduler": None, "trials": 20000, "opt_level": 3}
),
):
options: Optional[MappingProxyType[str, Any]] = None,
) -> Callable[..., Any]:
if options is None:
options = MappingProxyType({"scheduler": None, "trials": 20000, "opt_level": 3})
assert options is not None
import tvm # type: ignore[import]
from tvm import relay # type: ignore[import]
from tvm.contrib import graph_executor # type: ignore[import]
@ -147,7 +147,7 @@ def tvm(
)
m = graph_executor.GraphModule(lib["default"](dev))
def to_torch_tensor(nd_tensor):
def to_torch_tensor(nd_tensor: tvm.nd.array) -> torch.Tensor:
"""A helper function to transfer a NDArray to torch.tensor."""
if nd_tensor.dtype == "bool":
# DLPack does not support boolean so it can't be handled by
@ -156,7 +156,7 @@ def tvm(
return torch.from_numpy(nd_tensor.numpy())
return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())
def to_tvm_tensor(torch_tensor):
def to_tvm_tensor(torch_tensor: torch.Tensor) -> tvm.nd.array:
"""A helper function to transfer a torch.tensor to NDArray."""
if torch_tensor.dtype == torch.bool:
# same reason as above, fallback to numpy conversion which
@ -164,7 +164,7 @@ def tvm(
return tvm.nd.array(torch_tensor.cpu().numpy())
return tvm.nd.from_dlpack(torch_tensor)
def exec_tvm(*i_args):
def exec_tvm(*i_args: torch.Tensor) -> list[torch.Tensor]:
args = [a.contiguous() for a in i_args]
shape_info, _ = m.get_input_info()
active_inputs = {name for name, _ in shape_info.items()}
@ -193,7 +193,7 @@ tvm_meta_schedule = functools.partial(tvm, scheduler="meta_schedule")
tvm_auto_scheduler = functools.partial(tvm, scheduler="auto_scheduler")
def has_tvm():
def has_tvm() -> bool:
try:
importlib.import_module("tvm")
return True
@ -202,7 +202,7 @@ def has_tvm():
@functools.cache
def llvm_target():
def llvm_target() -> str:
if sys.platform == "linux":
cpuinfo = open("/proc/cpuinfo").read()
if "avx512" in cpuinfo: