[AOTI] Add device guard when launching autotune kernels (#158034)

Summary: Fix https://github.com/pytorch/pytorch/issues/157737. When launching Triton kernels in the autotune block, we need to account for the fact that the model may not always be on device 0. This was not caught in CI because test_on_gpu_device1 requires multiple GPUs and was not being run on a multi-GPU instance. Added test_on_gpu_device1 and other similar multi-GPU tests back to the CI script.
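To illustrate the idea, here is a minimal runnable sketch (not the generated code itself): ToyKernel is a hypothetical stand-in for a Triton autotuner handle, and the public torch.cuda.device context manager stands in for the torch.cuda._DeviceGuard that the wrapper actually emits.

```python
# Sketch only: guard the launch on the model's actual device instead of
# assuming device 0. Requires at least two GPUs to demonstrate.
import torch

class ToyKernel:  # hypothetical stand-in for a Triton autotuner handle
    def run(self, *args, stream=None):
        print(f"launched on cuda:{torch.cuda.current_device()}, stream={stream}")

def guarded_launch(kernel, device_index, *args):
    # Enter a device guard for the model's device, then pick up that
    # device's stream, mirroring what the autotune block now does.
    with torch.cuda.device(device_index):
        stream = torch.cuda.current_stream(device_index).cuda_stream
        kernel.run(*args, stream=stream)

if torch.cuda.device_count() > 1:
    guarded_launch(ToyKernel(), 1, "buf0", "arg0_1")  # launches on cuda:1
```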

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158034
Approved by: https://github.com/eqy, https://github.com/yushangdi
Author: Bin Bao, 2025-07-10 08:37:46 -07:00; committed by PyTorch MergeBot
parent 7d4228dbfd
commit 326e751d07
2 changed files with 10 additions and 5 deletions

File: .ci/pytorch/test.sh

@@ -382,9 +382,10 @@ test_einops() {
 test_inductor_distributed() {
   # Smuggle a few multi-gpu tests here so that we don't have to request another large node
   echo "Testing multi_gpu tests in test_torchinductor"
   python test/run_test.py -i inductor/test_torchinductor.py -k test_multi_gpu --verbose
-  python test/run_test.py -i inductor/test_aot_inductor.py -k test_non_default_cuda_device --verbose
-  python test/run_test.py -i inductor/test_aot_inductor.py -k test_replicate_on_devices --verbose
+  python test/run_test.py -i inductor/test_aot_inductor.py -k test_on_gpu_device1 --verbose
+  python test/run_test.py -i inductor/test_aot_inductor.py -k test_non_default_gpu_device --verbose
+  python test/run_test.py -i inductor/test_aot_inductor.py -k test_load_package_multiple_gpus --verbose
   python test/run_test.py -i distributed/test_c10d_functional_native.py --verbose
   python test/run_test.py -i distributed/tensor/test_dtensor_compile.py --verbose
   python test/run_test.py -i distributed/tensor/parallel/test_micro_pipeline_tp.py --verbose

File: torch/_inductor/codegen/wrapper.py

@@ -1247,9 +1247,6 @@ class PythonWrapperCodegen(CodeGen):
                 f"with {V.graph.device_ops.device_guard(device_idx)}:"
             )
             self.kernel_autotune_calls.do_indent()
-            self.kernel_autotune_calls.writeline(
-                V.graph.device_ops.set_device(device_idx)
-            )
             if is_codegen_graph_partition_subgraph(self):
                 # Need get_raw_stream for subgraph
                 self.write_get_raw_stream_header()
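The set_device line removed above appears redundant once the surrounding guard exists: entering the guard switches the current device and exiting restores it. A small runnable illustration of that semantics, using the public torch.cuda.device context manager (the emitted torch.cuda._DeviceGuard behaves the same way):

```python
# Sketch: a device guard both switches and restores the current device,
# so an explicit set_device inside the guarded block adds nothing.
import torch

if torch.cuda.device_count() > 1:
    torch.cuda.set_device(0)
    with torch.cuda.device(1):
        assert torch.cuda.current_device() == 1  # no set_device(1) needed
    assert torch.cuda.current_device() == 0      # restored on exit
```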
@@ -2679,9 +2676,16 @@ class PythonWrapperCodegen(CodeGen):
             arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg)
             all_args.append(arg_str if key is None else f"{key}={arg_str}")
+
+        # Make sure kernel launch under a device guard because models don't always run on device 0
+        self.kernel_autotune_calls.writeline(
+            f"with {V.graph.device_ops.device_guard(device.index)}:"
+        )
+        self.kernel_autotune_calls.do_indent()
         self.kernel_autotune_calls.writeline(
             f"{kernel_name}.run({', '.join(all_args)}, stream={stream_name})"
         )
+        self.kernel_autotune_calls.do_unindent()
         self.kernel_autotune_calls.writeline(
             DelayReplaceLine("<del_call>", get_autotune_deletion_call, "<del_call>")
         )
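Taken together, these writelines emit an autotune call of roughly the shape printed by the toy below. The kernel name, arguments, and stream name are hypothetical; the real text comes from V.graph.device_ops.device_guard (which produces torch.cuda._DeviceGuard for CUDA) and the kernel's run call:

```python
# Runnable toy that prints the rough shape of the emitted autotune block.
kernel_name = "triton_poi_fused_add_0"  # hypothetical kernel name
all_args = ["buf0", "arg0_1", "256"]    # hypothetical args
device_index, stream_name = 1, "stream1"

emitted = [
    f"with torch.cuda._DeviceGuard({device_index}):",
    f"    {kernel_name}.run({', '.join(all_args)}, stream={stream_name})",
]
print("\n".join(emitted))
```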