[Cutlass] Import cutlass python API for EVT (#150344)

This imports the parts of the CUTLASS Python API that are needed for Python EVT (epilogue visitor tree) tracing. It builds on the existing import mechanism for cutlass_library. Once EVT tracing has been added to cutlass_library (expected later this year), this can be removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150344
Approved by: https://github.com/henrylhtsang, https://github.com/eellison
This commit is contained in:
Michael Lazos 2025-04-14 15:01:51 -07:00 committed by PyTorch MergeBot
parent 91923f0ee1
commit d77e0cddfe
8 changed files with 102 additions and 20 deletions

View File

@ -135,6 +135,15 @@ class TestCutlassBackend(TestCase):
), "Cutlass Kernels should have been filtered, GEMM size is too small"
torch.testing.assert_close(Y_compiled, Y)
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_import_cutlass(self):
    """Check that ``try_import_cutlass`` reports success and both packages import."""
    from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass

    import_succeeded = try_import_cutlass()
    self.assertTrue(import_succeeded)

    # These imports raise if either package is not importable after the call above.
    import cutlass  # noqa: F401
    import cutlass_library  # noqa: F401
@unittest.skipIf(not SM90OrLater, "need sm_90")
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_cutlass_backend_subproc_mm(self):

View File

@ -0,0 +1,6 @@
import torch
# Mirror `torch.version.cuda` as this package's `__version__`.
# NOTE(review): presumably read by cutlass when it imports `cuda` — confirm.
__version__ = torch.version.cuda
# Re-export the names provided by the sibling `cuda` submodule.
from .cuda import * # noqa: F403

View File

@ -0,0 +1,17 @@
# mypy: disable-error-code="no-untyped-def"
# flake8: noqa
class CUdeviceptr:
    """Empty placeholder type; defines no attributes or methods of its own."""
class CUstream:
    """Placeholder type whose constructor accepts one argument and ignores it."""

    def __init__(self, v):
        # The argument is discarded; no state is stored on the instance.
        return None
class CUresult:
    """Empty placeholder type; defines no attributes or methods of its own."""
class nvrtc:
    """Empty placeholder type; defines no attributes or methods of its own."""

View File

@ -0,0 +1,2 @@
# mypy: disable-error-code="var-annotated"
# Placeholder: binds the name `Dot` to None.
# NOTE(review): presumably part of the `pydot` mock package — confirm.
Dot = None

View File

@ -0,0 +1,3 @@
# typing: ignore
# flake8: noqa
from .special import *

View File

@ -0,0 +1,2 @@
# mypy: disable-error-code="var-annotated"
# Placeholder: binds the name `erf` to None.
# NOTE(review): presumably the `special` submodule of the `scipy` mock — confirm.
erf = None

View File

@ -78,31 +78,74 @@ def try_import_cutlass() -> bool:
# This is a temporary hack to avoid CUTLASS module naming conflicts.
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
cutlass_py_full_path = os.path.abspath(
os.path.join(config.cuda.cutlass_dir, "python/cutlass_library")
)
tmp_cutlass_py_full_path = os.path.abspath(
os.path.join(cache_dir(), "torch_cutlass_library")
)
dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")
# TODO(mlazos): epilogue visitor tree currently lives in python/cutlass,
# but will be moved to python/cutlass_library in the future
def path_join(path0, path1):
    """Join ``path1`` onto ``path0`` and return the result as an absolute path."""
    combined = os.path.join(path0, path1)
    return os.path.abspath(combined)
if os.path.isdir(cutlass_py_full_path):
if tmp_cutlass_py_full_path not in sys.path:
if os.path.exists(dst_link):
assert os.path.islink(dst_link), (
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
# This directory contains both cutlass and cutlass_library;
# we need cutlass for EVT
cutlass_python_path = path_join(config.cuda.cutlass_dir, "python")
torch_root = os.path.abspath(os.path.dirname(torch.__file__))
mock_src_path = os.path.join(
torch_root,
"_inductor",
"codegen",
"cuda",
"cutlass_lib_extensions",
"cutlass_mock_imports",
)
cutlass_library_src_path = path_join(cutlass_python_path, "cutlass_library")
cutlass_src_path = path_join(cutlass_python_path, "cutlass")
pycute_src_path = path_join(cutlass_python_path, "pycute")
tmp_cutlass_full_path = os.path.abspath(os.path.join(cache_dir(), "torch_cutlass"))
dst_link_library = path_join(tmp_cutlass_full_path, "cutlass_library")
dst_link_cutlass = path_join(tmp_cutlass_full_path, "cutlass")
dst_link_pycute = path_join(tmp_cutlass_full_path, "pycute")
# mock modules to import cutlass
mock_modules = ["cuda", "scipy", "pydot"]
if os.path.isdir(cutlass_python_path):
if tmp_cutlass_full_path not in sys.path:
def link_and_append(dst_link, src_path, parent_dir):
    """Ensure ``dst_link`` is a symlink to ``src_path`` and ``parent_dir`` is on ``sys.path``.

    Args:
        dst_link: Path at which the symlink should live.
        src_path: Path the symlink must point to.
        parent_dir: Directory that is created if missing and appended to
            ``sys.path`` if not already present.

    Raises:
        AssertionError: If ``dst_link`` already exists but is not a symlink,
            or is a symlink that resolves somewhere other than ``src_path``.
    """
    # lexists, unlike exists, is also true for a dangling symlink. Using
    # exists here would send a stale link down the creation branch, where
    # os.symlink fails with FileExistsError.
    must_validate = os.path.lexists(dst_link)
    if not must_validate:
        os.makedirs(parent_dir, exist_ok=True)
        try:
            os.symlink(src_path, dst_link)
        except FileExistsError:
            # Another process created the entry between the check and the
            # call; fall through and validate what is there now.
            must_validate = True

    if must_validate:
        assert os.path.islink(dst_link), (
            f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
        )
        assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
            src_path,
        ), f"Symlink at {dst_link} does not point to {src_path}"

    if parent_dir not in sys.path:
        sys.path.append(parent_dir)
link_and_append(
dst_link_library, cutlass_library_src_path, tmp_cutlass_full_path
)
link_and_append(dst_link_cutlass, cutlass_src_path, tmp_cutlass_full_path)
link_and_append(dst_link_pycute, pycute_src_path, tmp_cutlass_full_path)
for module in mock_modules:
link_and_append(
path_join(tmp_cutlass_full_path, module), # dst_link
path_join(mock_src_path, module), # src_path
tmp_cutlass_full_path, # parent
)
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
cutlass_py_full_path
), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
else:
os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
os.symlink(cutlass_py_full_path, dst_link)
sys.path.append(tmp_cutlass_py_full_path)
try:
import cutlass # noqa: F401
import cutlass_library.generator # noqa: F401
import cutlass_library.library # noqa: F401
import cutlass_library.manifest # noqa: F401
import pycute # type: ignore[import-not-found] # noqa: F401
return True
except ImportError as e:
@ -113,7 +156,7 @@ def try_import_cutlass() -> bool:
else:
log.debug(
"Failed to import CUTLASS packages: CUTLASS repo does not exist: %s",
cutlass_py_full_path,
cutlass_python_path,
)
return False