mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[AOTI][Tooling][7/n] Add debug printing support for JIT inductor codegen path as well (#135285)
Summary: 1. Add the debug printer call one level lower, in the Triton kernel Python wrapper codegen path. 2. Add `torch.save()` support for JIT Inductor as well. 3. This also fixes the issue introduced in D61949020 (intermediate values were not being printed at the Python wrapper code level for Triton kernels). Test Plan: ``` AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+graph, inductor, +schedule, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_addmm_abi_compatible_cuda ``` Differential Revision: D62272588 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135285 Approved by: https://github.com/chenyang78
This commit is contained in:
parent
fc88ba260f
commit
1f15973657
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
@ -136,8 +137,23 @@ class DebugPrinterManager:
|
||||||
# TODO: add non-abi compatible mode debug printing info
|
# TODO: add non-abi compatible mode debug printing info
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# currently, not cpp wrapper codegen mode not supported.
|
cwd = os.getcwd()
|
||||||
pass
|
saved_dir = cwd + "/tmp/jit_inductor/"
|
||||||
|
if not os.path.exists(saved_dir):
|
||||||
|
log.info(
|
||||||
|
"Creating directory to save inductor intermediate tensor values."
|
||||||
|
)
|
||||||
|
os.makedirs(saved_dir)
|
||||||
|
# Save the model to the directory
|
||||||
|
saved_path = saved_dir + f"{launch_prefix}_{kernel_name}_{arg}.pt"
|
||||||
|
log.info(
|
||||||
|
"Saved intermediate tensor %s for %s to %s",
|
||||||
|
arg,
|
||||||
|
kernel_name,
|
||||||
|
saved_path,
|
||||||
|
)
|
||||||
|
line = f"torch.save({arg}, '{saved_path}')"
|
||||||
|
V.graph.wrapper_code.writeline(line)
|
||||||
|
|
||||||
def codegen_intermediate_tensor_value_print(
|
def codegen_intermediate_tensor_value_print(
|
||||||
self,
|
self,
|
||||||
|
|
@ -171,5 +187,7 @@ class DebugPrinterManager:
|
||||||
# TODO: add non-abi compatible mode debug printing info
|
# TODO: add non-abi compatible mode debug printing info
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
line = f"print('{launch_prefix} - {kernel_name} - {arg}', {arg})"
|
line = (
|
||||||
|
f"print('inductor: {launch_prefix} - {kernel_name} - {arg}', {arg})"
|
||||||
|
)
|
||||||
V.graph.wrapper_code.writeline(line)
|
V.graph.wrapper_code.writeline(line)
|
||||||
|
|
|
||||||
|
|
@ -1679,9 +1679,15 @@ class WrapperCodeGen(CodeGen):
|
||||||
if grid_extra_kwargs:
|
if grid_extra_kwargs:
|
||||||
grid_str = f"{grid_str}, {grid_extra_kwargs}"
|
grid_str = f"{grid_str}, {grid_extra_kwargs}"
|
||||||
grid_str = f"{grid_fn}({grid_str})"
|
grid_str = f"{grid_fn}({grid_str})"
|
||||||
self.writeline(
|
# add debug printer code for triton kernel calls at (jit) inductor level
|
||||||
f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})"
|
debug_printer_manager = V.graph.wrapper_code.debug_printer
|
||||||
|
debug_printer_manager.set_printer_args(
|
||||||
|
call_args, kernel_name, arg_types, None
|
||||||
)
|
)
|
||||||
|
with debug_printer_manager:
|
||||||
|
self.writeline(
|
||||||
|
f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})"
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
config.triton.autotune_at_compile_time
|
config.triton.autotune_at_compile_time
|
||||||
and kernel_name not in self.kernel_autotune_names
|
and kernel_name not in self.kernel_autotune_names
|
||||||
|
|
|
||||||
|
|
@ -1025,7 +1025,7 @@ AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle(
|
||||||
fout.write(bytes.data(), bytes.size());
|
fout.write(bytes.data(), bytes.size());
|
||||||
fout.close();
|
fout.close();
|
||||||
|
|
||||||
std::cout << "aoti_torch_save_tensor_handle: Saved tensor to: "
|
std::cout << "aoti_torch_save_tensor_handle: Saved tensor to "
|
||||||
<< tensor_filepath_to_save << std::endl;
|
<< tensor_filepath_to_save << std::endl;
|
||||||
#endif // !defined(C10_MOBILE)
|
#endif // !defined(C10_MOBILE)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user