mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[AOTI][refactor] Remove model_container_runner_cuda.cpp (#116113)
Differential Revision: [D52301272](https://our.internmc.facebook.com/intern/diff/D52301272) Pull Request resolved: https://github.com/pytorch/pytorch/pull/116113 Approved by: https://github.com/khabinov ghstack dependencies: #116047
This commit is contained in:
parent
f71d302c63
commit
2dce364634
|
|
@ -652,7 +652,6 @@ libtorch_cuda_core_sources = [
|
|||
"torch/csrc/CudaIPCTypes.cpp",
|
||||
"torch/csrc/cuda/comm.cpp",
|
||||
"torch/csrc/cuda/memory_snapshot.cpp",
|
||||
"torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp",
|
||||
"torch/csrc/inductor/aoti_torch/shim_cuda.cpp",
|
||||
"torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp",
|
||||
"torch/csrc/profiler/stubs/cuda.cpp",
|
||||
|
|
|
|||
|
|
@ -9,6 +9,10 @@ class TORCH_API AOTIModelContainerRunnerCpu : public AOTIModelContainerRunner {
|
|||
const std::string& model_so_path,
|
||||
size_t num_models = 1)
|
||||
: AOTIModelContainerRunner(model_so_path, num_models, true, "") {}
|
||||
|
||||
std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs) {
|
||||
return AOTIModelContainerRunner::run(inputs);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace torch::inductor
|
||||
|
|
|
|||
|
|
@ -1,16 +0,0 @@
|
|||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
|
||||
|
||||
namespace torch::inductor {
|
||||
|
||||
std::vector<at::Tensor> AOTIModelContainerRunnerCuda::run(
|
||||
std::vector<at::Tensor>& inputs,
|
||||
cudaStream_t cuda_stream_handle) {
|
||||
if (cuda_stream_handle == nullptr) {
|
||||
cuda_stream_handle = c10::cuda::getCurrentCUDAStream().stream();
|
||||
}
|
||||
return AOTIModelContainerRunner::run(
|
||||
inputs, reinterpret_cast<AOTInductorStreamHandle>(cuda_stream_handle));
|
||||
}
|
||||
|
||||
} // namespace torch::inductor
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
|
||||
|
||||
namespace torch::inductor {
|
||||
|
|
@ -15,7 +15,13 @@ class TORCH_API AOTIModelContainerRunnerCuda : public AOTIModelContainerRunner {
|
|||
|
||||
std::vector<at::Tensor> run(
|
||||
std::vector<at::Tensor>& inputs,
|
||||
cudaStream_t cuda_stream_handle = nullptr);
|
||||
cudaStream_t cuda_stream_handle = nullptr) {
|
||||
if (cuda_stream_handle == nullptr) {
|
||||
cuda_stream_handle = c10::cuda::getCurrentCUDAStream().stream();
|
||||
}
|
||||
return AOTIModelContainerRunner::run(
|
||||
inputs, reinterpret_cast<AOTInductorStreamHandle>(cuda_stream_handle));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace torch::inductor
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user