Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00
typing tvm.py (#160369)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160369
Approved by: https://github.com/Skylion007
ghstack dependencies: #160362, #160363, #160364, #160365, #160366, #160367, #160368
This commit is contained in:
parent 39ca0ce0c8
commit 4d5f92aa39
torch/_dynamo/backends/tvm.py

@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module provides TVM backend integration for TorchDynamo.
 
@@ -29,9 +27,10 @@ import os
 import sys
 import tempfile
 from types import MappingProxyType
-from typing import Optional
+from typing import Any, Callable, Optional
 
 import torch
+from torch import fx
 
 from .common import device_from_inputs, fake_tensor_unsupported
 from .registry import register_backend
@@ -41,15 +40,16 @@ log = logging.getLogger(__name__)
 
 
 @register_backend
-@fake_tensor_unsupported
+@fake_tensor_unsupported  # type: ignore[arg-type]
 def tvm(
-    gm,
-    example_inputs,
+    gm: fx.GraphModule,
+    example_inputs: list[torch.Tensor],
     *,
-    options: Optional[MappingProxyType] = MappingProxyType(
-        {"scheduler": None, "trials": 20000, "opt_level": 3}
-    ),
-):
+    options: Optional[MappingProxyType[str, Any]] = None,
+) -> Callable[..., Any]:
+    if options is None:
+        options = MappingProxyType({"scheduler": None, "trials": 20000, "opt_level": 3})
+    assert options is not None
     import tvm  # type: ignore[import]
     from tvm import relay  # type: ignore[import]
     from tvm.contrib import graph_executor  # type: ignore[import]
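The signature change above swaps a MappingProxyType default in the parameter list for options=None plus an in-body default. A minimal sketch of that pattern, assuming nothing beyond the standard library (the run() function and _DEFAULTS name are illustrative, not part of the file):

from types import MappingProxyType
from typing import Any, Optional

# Illustrative only: the same Optional-plus-in-body-default pattern as the
# tvm() signature above. MappingProxyType is read-only, so callers cannot
# mutate the shared default mapping.
_DEFAULTS: MappingProxyType[str, Any] = MappingProxyType(
    {"scheduler": None, "trials": 20000, "opt_level": 3}
)


def run(*, options: Optional[MappingProxyType[str, Any]] = None) -> dict[str, Any]:
    if options is None:
        options = _DEFAULTS
    return dict(options)


print(run())  # {'scheduler': None, 'trials': 20000, 'opt_level': 3}
print(run(options=MappingProxyType({"scheduler": None, "trials": 1, "opt_level": 0})))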
@@ -147,7 +147,7 @@ def tvm(
     )
     m = graph_executor.GraphModule(lib["default"](dev))
 
-    def to_torch_tensor(nd_tensor):
+    def to_torch_tensor(nd_tensor: tvm.nd.array) -> torch.Tensor:
         """A helper function to transfer a NDArray to torch.tensor."""
         if nd_tensor.dtype == "bool":
             # DLPack does not support boolean so it can't be handled by
@@ -156,7 +156,7 @@ def tvm(
             return torch.from_numpy(nd_tensor.numpy())
         return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())
 
-    def to_tvm_tensor(torch_tensor):
+    def to_tvm_tensor(torch_tensor: torch.Tensor) -> tvm.nd.array:
         """A helper function to transfer a torch.tensor to NDArray."""
         if torch_tensor.dtype == torch.bool:
             # same reason as above, fallback to numpy conversion which
@@ -164,7 +164,7 @@ def tvm(
             return tvm.nd.array(torch_tensor.cpu().numpy())
         return tvm.nd.from_dlpack(torch_tensor)
 
-    def exec_tvm(*i_args):
+    def exec_tvm(*i_args: torch.Tensor) -> list[torch.Tensor]:
        args = [a.contiguous() for a in i_args]
        shape_info, _ = m.get_input_info()
        active_inputs = {name for name, _ in shape_info.items()}
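The helpers annotated in the three hunks above move data between TVM NDArrays and torch tensors over DLPack, with a numpy fallback because DLPack has no boolean dtype. A rough standalone sketch of the torch-bound direction, assuming TVM is installed (the to_torch name is illustrative, not part of the file):

import numpy as np
import torch
import torch.utils.dlpack
import tvm  # type: ignore[import]  # assumes TVM is available, as the backend requires


def to_torch(nd_tensor: tvm.nd.array) -> torch.Tensor:
    # Mirrors the fallback in to_torch_tensor(): bool NDArrays cannot cross
    # DLPack, so they round-trip through numpy instead of the zero-copy path.
    if nd_tensor.dtype == "bool":
        return torch.from_numpy(nd_tensor.numpy())
    return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())


floats = tvm.nd.array(np.ones((2, 3), dtype="float32"))
flags = tvm.nd.array(np.array([True, False]))
print(to_torch(floats).dtype, to_torch(flags).dtype)  # torch.float32 torch.bool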
@@ -193,7 +193,7 @@ tvm_meta_schedule = functools.partial(tvm, scheduler="meta_schedule")
 tvm_auto_scheduler = functools.partial(tvm, scheduler="auto_scheduler")
 
 
-def has_tvm():
+def has_tvm() -> bool:
     try:
         importlib.import_module("tvm")
         return True
@@ -202,7 +202,7 @@ def has_tvm():
 
 
 @functools.cache
-def llvm_target():
+def llvm_target() -> str:
     if sys.platform == "linux":
         cpuinfo = open("/proc/cpuinfo").read()
         if "avx512" in cpuinfo:
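For context, the now-typed tvm() entry point is the callable TorchDynamo invokes when this backend is selected, and its Callable[..., Any] return is the compiled replacement for the FX graph. A hedged usage sketch (the toy model and shapes are made up; "tvm" is the backend name registered via @register_backend above):

import torch
from torch._dynamo.backends.tvm import has_tvm

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

if has_tvm():
    # Compile through the TVM backend typed in this commit; the default
    # options (scheduler=None, trials=20000, opt_level=3) apply.
    compiled = torch.compile(model, backend="tvm")
    print(compiled(torch.randn(2, 8)).shape)
else:
    print("TVM is not installed; skipping the example")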