mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[submodule] CUTLASS upgrade to 4.2.0 and change cutlass to cutlass_cppgen (#163092)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163092 Approved by: https://github.com/drisspg, https://github.com/Skylion007
This commit is contained in:
parent
4b7aed89d8
commit
a81a2e54ed
|
|
@ -257,7 +257,7 @@ class TestCutlassBackend(TestCase):
|
||||||
if config.is_fbcode():
|
if config.is_fbcode():
|
||||||
import python_cutlass
|
import python_cutlass
|
||||||
else:
|
else:
|
||||||
import cutlass as python_cutlass # noqa: F401
|
import cutlass_cppgen as python_cutlass # noqa: F401
|
||||||
import cutlass_library # noqa: F401
|
import cutlass_library # noqa: F401
|
||||||
|
|
||||||
def test_cutlass_key(self):
|
def test_cutlass_key(self):
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ if try_import_cutlass():
|
||||||
if config.is_fbcode():
|
if config.is_fbcode():
|
||||||
import python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
import python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
||||||
else:
|
else:
|
||||||
import cutlass as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
import cutlass_cppgen as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
||||||
CutlassTensor = python_cutlass.backend.evt.ir.tensor.Tensor
|
CutlassTensor = python_cutlass.backend.evt.ir.tensor.Tensor
|
||||||
|
|
||||||
BIAS_CODE = """def example_epilogue(accum, C, aux, bias):
|
BIAS_CODE = """def example_epilogue(accum, C, aux, bias):
|
||||||
|
|
|
||||||
2
third_party/cutlass
vendored
2
third_party/cutlass
vendored
|
|
@ -1 +1 @@
|
||||||
Subproject commit e51efbfe18fe4f4cbb66ab814c55bf4aa0185491
|
Subproject commit 57e3cfb47a2d9e0d46eb6335c3dc411498efa198
|
||||||
|
|
@ -38,7 +38,7 @@ if try_import_cutlass():
|
||||||
if config.is_fbcode():
|
if config.is_fbcode():
|
||||||
import python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
import python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
||||||
else:
|
else:
|
||||||
import cutlass as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
import cutlass_cppgen as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
||||||
|
|
||||||
from torch._inductor.codegen.cuda import cuda_env
|
from torch._inductor.codegen.cuda import cuda_env
|
||||||
from torch._inductor.utils import IndentedBuffer
|
from torch._inductor.utils import IndentedBuffer
|
||||||
|
|
@ -174,7 +174,7 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
|
||||||
def is_nested_visitor_type(t: type) -> bool:
|
def is_nested_visitor_type(t: type) -> bool:
|
||||||
return ".".join([t.__module__, t.__qualname__]) in {
|
return ".".join([t.__module__, t.__qualname__]) in {
|
||||||
"python_cutlass.backend.c_types.visitor_factory.<locals>.VisitorType",
|
"python_cutlass.backend.c_types.visitor_factory.<locals>.VisitorType",
|
||||||
"cutlass.backend.c_types.visitor_factory.<locals>.VisitorType",
|
"cutlass_cppgen.backend.c_types.visitor_factory.<locals>.VisitorType",
|
||||||
}
|
}
|
||||||
|
|
||||||
buffer = IndentedBuffer()
|
buffer = IndentedBuffer()
|
||||||
|
|
@ -235,7 +235,7 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
|
||||||
# Once again, need to check for local class type for stride tuple
|
# Once again, need to check for local class type for stride tuple
|
||||||
if str(arg_ty) in {
|
if str(arg_ty) in {
|
||||||
"<class 'python_cutlass.backend.c_types.tuple_factory_.<locals>.TupleType'>",
|
"<class 'python_cutlass.backend.c_types.tuple_factory_.<locals>.TupleType'>",
|
||||||
"<class 'cutlass.backend.c_types.tuple_factory_.<locals>.TupleType'>",
|
"<class 'cutlass_cppgen.backend.c_types.tuple_factory_.<locals>.TupleType'>",
|
||||||
}:
|
}:
|
||||||
DEFAULT_STRIDE_LEN = 3
|
DEFAULT_STRIDE_LEN = 3
|
||||||
assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN
|
assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ def move_cutlass_compiled_cache() -> None:
|
||||||
if config.is_fbcode():
|
if config.is_fbcode():
|
||||||
import python_cutlass # type: ignore[import-not-found]
|
import python_cutlass # type: ignore[import-not-found]
|
||||||
else:
|
else:
|
||||||
import cutlass as python_cutlass # type: ignore[import-not-found] # noqa: F401
|
import cutlass_cppgen as python_cutlass # type: ignore[import-not-found] # noqa: F401
|
||||||
|
|
||||||
# Check if the CACHE_FILE attribute exists in python_cutlass and if the file exists
|
# Check if the CACHE_FILE attribute exists in python_cutlass and if the file exists
|
||||||
if not hasattr(python_cutlass, "CACHE_FILE") or not os.path.exists(
|
if not hasattr(python_cutlass, "CACHE_FILE") or not os.path.exists(
|
||||||
|
|
@ -118,7 +118,7 @@ def try_import_cutlass() -> bool:
|
||||||
tmp_cutlass_full_path = os.path.abspath(os.path.join(cache_dir(), "torch_cutlass"))
|
tmp_cutlass_full_path = os.path.abspath(os.path.join(cache_dir(), "torch_cutlass"))
|
||||||
|
|
||||||
dst_link_library = path_join(tmp_cutlass_full_path, "cutlass_library")
|
dst_link_library = path_join(tmp_cutlass_full_path, "cutlass_library")
|
||||||
dst_link_cutlass = path_join(tmp_cutlass_full_path, "cutlass")
|
dst_link_cutlass = path_join(tmp_cutlass_full_path, "cutlass_cppgen")
|
||||||
dst_link_pycute = path_join(tmp_cutlass_full_path, "pycute")
|
dst_link_pycute = path_join(tmp_cutlass_full_path, "pycute")
|
||||||
|
|
||||||
# mock modules to import cutlass
|
# mock modules to import cutlass
|
||||||
|
|
@ -156,7 +156,7 @@ def try_import_cutlass() -> bool:
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cutlass # noqa: F401, F811
|
import cutlass_cppgen # noqa: F401, F811
|
||||||
import cutlass_library.generator # noqa: F401
|
import cutlass_library.generator # noqa: F401
|
||||||
import cutlass_library.library # noqa: F401
|
import cutlass_library.library # noqa: F401
|
||||||
import cutlass_library.manifest # noqa: F401
|
import cutlass_library.manifest # noqa: F401
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user