mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Cutlass] Import cutlass python API for EVT (#150344)
This imports the pieces of the cutlass python API that are needed for python EVT tracing. It builds on existing importing for cutlass_library. Once EVT tracing has been added to cutlass_library (should be later this year) this can be removed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150344 Approved by: https://github.com/henrylhtsang, https://github.com/eellison
This commit is contained in:
parent
91923f0ee1
commit
d77e0cddfe
|
|
@ -135,6 +135,15 @@ class TestCutlassBackend(TestCase):
|
|||
), "Cutlass Kernels should have been filtered, GEMM size is too small"
|
||||
torch.testing.assert_close(Y_compiled, Y)
|
||||
|
||||
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
|
||||
def test_import_cutlass(self):
|
||||
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
|
||||
|
||||
self.assertTrue(try_import_cutlass())
|
||||
|
||||
import cutlass # noqa: F401
|
||||
import cutlass_library # noqa: F401
|
||||
|
||||
@unittest.skipIf(not SM90OrLater, "need sm_90")
|
||||
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
|
||||
def test_cutlass_backend_subproc_mm(self):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
import torch
|
||||
|
||||
|
||||
__version__ = torch.version.cuda
|
||||
|
||||
from .cuda import * # noqa: F403
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
# mypy: disable-error-code="no-untyped-def"
|
||||
# flake8: noqa
|
||||
class CUdeviceptr:
|
||||
pass
|
||||
|
||||
|
||||
class CUstream:
|
||||
def __init__(self, v):
|
||||
pass
|
||||
|
||||
|
||||
class CUresult:
|
||||
pass
|
||||
|
||||
|
||||
class nvrtc:
|
||||
pass
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
# mypy: disable-error-code="var-annotated"
|
||||
Dot = None
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# typing: ignore
|
||||
# flake8: noqa
|
||||
from .special import *
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
# mypy: disable-error-code="var-annotated"
|
||||
erf = None
|
||||
|
|
@ -78,31 +78,74 @@ def try_import_cutlass() -> bool:
|
|||
# This is a temporary hack to avoid CUTLASS module naming conflicts.
|
||||
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
|
||||
|
||||
cutlass_py_full_path = os.path.abspath(
|
||||
os.path.join(config.cuda.cutlass_dir, "python/cutlass_library")
|
||||
)
|
||||
tmp_cutlass_py_full_path = os.path.abspath(
|
||||
os.path.join(cache_dir(), "torch_cutlass_library")
|
||||
)
|
||||
dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")
|
||||
# TODO(mlazos): epilogue visitor tree currently lives in python/cutlass,
|
||||
# but will be moved to python/cutlass_library in the future
|
||||
def path_join(path0, path1):
|
||||
return os.path.abspath(os.path.join(path0, path1))
|
||||
|
||||
if os.path.isdir(cutlass_py_full_path):
|
||||
if tmp_cutlass_py_full_path not in sys.path:
|
||||
if os.path.exists(dst_link):
|
||||
assert os.path.islink(dst_link), (
|
||||
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
|
||||
# contains both cutlass and cutlass_library
|
||||
# we need cutlass for eVT
|
||||
cutlass_python_path = path_join(config.cuda.cutlass_dir, "python")
|
||||
torch_root = os.path.abspath(os.path.dirname(torch.__file__))
|
||||
mock_src_path = os.path.join(
|
||||
torch_root,
|
||||
"_inductor",
|
||||
"codegen",
|
||||
"cuda",
|
||||
"cutlass_lib_extensions",
|
||||
"cutlass_mock_imports",
|
||||
)
|
||||
|
||||
cutlass_library_src_path = path_join(cutlass_python_path, "cutlass_library")
|
||||
cutlass_src_path = path_join(cutlass_python_path, "cutlass")
|
||||
pycute_src_path = path_join(cutlass_python_path, "pycute")
|
||||
|
||||
tmp_cutlass_full_path = os.path.abspath(os.path.join(cache_dir(), "torch_cutlass"))
|
||||
|
||||
dst_link_library = path_join(tmp_cutlass_full_path, "cutlass_library")
|
||||
dst_link_cutlass = path_join(tmp_cutlass_full_path, "cutlass")
|
||||
dst_link_pycute = path_join(tmp_cutlass_full_path, "pycute")
|
||||
|
||||
# mock modules to import cutlass
|
||||
mock_modules = ["cuda", "scipy", "pydot"]
|
||||
|
||||
if os.path.isdir(cutlass_python_path):
|
||||
if tmp_cutlass_full_path not in sys.path:
|
||||
|
||||
def link_and_append(dst_link, src_path, parent_dir):
|
||||
if os.path.exists(dst_link):
|
||||
assert os.path.islink(dst_link), (
|
||||
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
|
||||
)
|
||||
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
|
||||
src_path,
|
||||
), f"Symlink at {dst_link} does not point to {src_path}"
|
||||
else:
|
||||
os.makedirs(parent_dir, exist_ok=True)
|
||||
os.symlink(src_path, dst_link)
|
||||
|
||||
if parent_dir not in sys.path:
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
link_and_append(
|
||||
dst_link_library, cutlass_library_src_path, tmp_cutlass_full_path
|
||||
)
|
||||
link_and_append(dst_link_cutlass, cutlass_src_path, tmp_cutlass_full_path)
|
||||
link_and_append(dst_link_pycute, pycute_src_path, tmp_cutlass_full_path)
|
||||
|
||||
for module in mock_modules:
|
||||
link_and_append(
|
||||
path_join(tmp_cutlass_full_path, module), # dst_link
|
||||
path_join(mock_src_path, module), # src_path
|
||||
tmp_cutlass_full_path, # parent
|
||||
)
|
||||
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
|
||||
cutlass_py_full_path
|
||||
), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
|
||||
else:
|
||||
os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
|
||||
os.symlink(cutlass_py_full_path, dst_link)
|
||||
sys.path.append(tmp_cutlass_py_full_path)
|
||||
|
||||
try:
|
||||
import cutlass # noqa: F401
|
||||
import cutlass_library.generator # noqa: F401
|
||||
import cutlass_library.library # noqa: F401
|
||||
import cutlass_library.manifest # noqa: F401
|
||||
import pycute # type: ignore[import-not-found] # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError as e:
|
||||
|
|
@ -113,7 +156,7 @@ def try_import_cutlass() -> bool:
|
|||
else:
|
||||
log.debug(
|
||||
"Failed to import CUTLASS packages: CUTLASS repo does not exist: %s",
|
||||
cutlass_py_full_path,
|
||||
cutlass_python_path,
|
||||
)
|
||||
return False
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user