[AOTI][Tooling][7/n] Add debug printing support for JIT inductor codegen path as well (#135285)

Summary:
1. Add the debug printer call at a lower level for the triton kernel python wrapper codegen path
2. Add `torch.save()` for jit inductor as well
3. This also fixes the issue introduced in D61949020 (the python wrapper code for triton kernels was not printing)

Test Plan:
```
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1  TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+graph, inductor, +schedule, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_addmm_abi_compatible_cuda
```

Differential Revision: D62272588

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135285
Approved by: https://github.com/chenyang78
This commit is contained in:
Rachel Guo 2024-09-10 19:24:58 +00:00 committed by PyTorch MergeBot
parent fc88ba260f
commit 1f15973657
3 changed files with 30 additions and 6 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import functools
import logging
import os
from enum import Enum
from typing import List, Optional
@ -136,8 +137,23 @@ class DebugPrinterManager:
# TODO: add non-abi compatible mode debug printing info
pass
else:
# currently, not cpp wrapper codegen mode not supported.
pass
cwd = os.getcwd()
saved_dir = cwd + "/tmp/jit_inductor/"
if not os.path.exists(saved_dir):
log.info(
"Creating directory to save inductor intermediate tensor values."
)
os.makedirs(saved_dir)
# Save the model to the directory
saved_path = saved_dir + f"{launch_prefix}_{kernel_name}_{arg}.pt"
log.info(
"Saved intermediate tensor %s for %s to %s",
arg,
kernel_name,
saved_path,
)
line = f"torch.save({arg}, '{saved_path}')"
V.graph.wrapper_code.writeline(line)
def codegen_intermediate_tensor_value_print(
self,
@ -171,5 +187,7 @@ class DebugPrinterManager:
# TODO: add non-abi compatible mode debug printing info
pass
else:
line = f"print('{launch_prefix} - {kernel_name} - {arg}', {arg})"
line = (
f"print('inductor: {launch_prefix} - {kernel_name} - {arg}', {arg})"
)
V.graph.wrapper_code.writeline(line)

View File

@ -1679,9 +1679,15 @@ class WrapperCodeGen(CodeGen):
if grid_extra_kwargs:
grid_str = f"{grid_str}, {grid_extra_kwargs}"
grid_str = f"{grid_fn}({grid_str})"
self.writeline(
f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})"
# add debug printer code for triton kernel calls at (jit) inductor level
debug_printer_manager = V.graph.wrapper_code.debug_printer
debug_printer_manager.set_printer_args(
call_args, kernel_name, arg_types, None
)
with debug_printer_manager:
self.writeline(
f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})"
)
if (
config.triton.autotune_at_compile_time
and kernel_name not in self.kernel_autotune_names

View File

@ -1025,7 +1025,7 @@ AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle(
fout.write(bytes.data(), bytes.size());
fout.close();
std::cout << "aoti_torch_save_tensor_handle: Saved tensor to: "
std::cout << "aoti_torch_save_tensor_handle: Saved tensor to "
<< tensor_filepath_to_save << std::endl;
#endif // !defined(C10_MOBILE)
}