[AOTI][refactor] Remove model_container_runner_cuda.cpp (#116113)

Differential Revision: [D52301272](https://our.internmc.facebook.com/intern/diff/D52301272)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116113
Approved by: https://github.com/khabinov
ghstack dependencies: #116047
This commit is contained in:
Bin Bao 2023-12-20 08:36:12 -08:00 committed by PyTorch MergeBot
parent f71d302c63
commit 2dce364634
4 changed files with 12 additions and 19 deletions

View File

@ -652,7 +652,6 @@ libtorch_cuda_core_sources = [
"torch/csrc/CudaIPCTypes.cpp",
"torch/csrc/cuda/comm.cpp",
"torch/csrc/cuda/memory_snapshot.cpp",
"torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp",
"torch/csrc/inductor/aoti_torch/shim_cuda.cpp",
"torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp",
"torch/csrc/profiler/stubs/cuda.cpp",

View File

@ -9,6 +9,10 @@ class TORCH_API AOTIModelContainerRunnerCpu : public AOTIModelContainerRunner {
const std::string& model_so_path,
size_t num_models = 1)
: AOTIModelContainerRunner(model_so_path, num_models, true, "") {}
// Runs the loaded AOTInductor model on the given inputs.
// CPU variant: no stream handle is needed, so this simply delegates
// to the base-class runner.
std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs) {
return AOTIModelContainerRunner::run(inputs);
}
};
} // namespace torch::inductor

View File

@ -1,16 +0,0 @@
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
namespace torch::inductor {
// Runs the loaded AOTInductor model on the given inputs using CUDA.
// If the caller passes no stream (nullptr, the declared default), the
// current CUDA stream of the active device is used instead.
std::vector<at::Tensor> AOTIModelContainerRunnerCuda::run(
std::vector<at::Tensor>& inputs,
cudaStream_t cuda_stream_handle) {
if (cuda_stream_handle == nullptr) {
cuda_stream_handle = c10::cuda::getCurrentCUDAStream().stream();
}
// The base runner speaks the AOTInductor C ABI, which represents the
// stream as an opaque AOTInductorStreamHandle; reinterpret_cast bridges
// the raw cudaStream_t to that handle type.
return AOTIModelContainerRunner::run(
inputs, reinterpret_cast<AOTInductorStreamHandle>(cuda_stream_handle));
}
} // namespace torch::inductor

View File

@ -1,6 +1,6 @@
#pragma once
#include <cuda_runtime_api.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
namespace torch::inductor {
@ -15,7 +15,13 @@ class TORCH_API AOTIModelContainerRunnerCuda : public AOTIModelContainerRunner {
std::vector<at::Tensor> run(
std::vector<at::Tensor>& inputs,
cudaStream_t cuda_stream_handle = nullptr);
cudaStream_t cuda_stream_handle = nullptr) {
if (cuda_stream_handle == nullptr) {
cuda_stream_handle = c10::cuda::getCurrentCUDAStream().stream();
}
return AOTIModelContainerRunner::run(
inputs, reinterpret_cast<AOTInductorStreamHandle>(cuda_stream_handle));
}
};
} // namespace torch::inductor