mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[AOTI][refactor] Remove model_container_runner_cuda.cpp (#116113)
Differential Revision: [D52301272](https://our.internmc.facebook.com/intern/diff/D52301272) Pull Request resolved: https://github.com/pytorch/pytorch/pull/116113 Approved by: https://github.com/khabinov ghstack dependencies: #116047
This commit is contained in:
parent
f71d302c63
commit
2dce364634
|
|
@ -652,7 +652,6 @@ libtorch_cuda_core_sources = [
|
||||||
"torch/csrc/CudaIPCTypes.cpp",
|
"torch/csrc/CudaIPCTypes.cpp",
|
||||||
"torch/csrc/cuda/comm.cpp",
|
"torch/csrc/cuda/comm.cpp",
|
||||||
"torch/csrc/cuda/memory_snapshot.cpp",
|
"torch/csrc/cuda/memory_snapshot.cpp",
|
||||||
"torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp",
|
|
||||||
"torch/csrc/inductor/aoti_torch/shim_cuda.cpp",
|
"torch/csrc/inductor/aoti_torch/shim_cuda.cpp",
|
||||||
"torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp",
|
"torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp",
|
||||||
"torch/csrc/profiler/stubs/cuda.cpp",
|
"torch/csrc/profiler/stubs/cuda.cpp",
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,10 @@ class TORCH_API AOTIModelContainerRunnerCpu : public AOTIModelContainerRunner {
|
||||||
const std::string& model_so_path,
|
const std::string& model_so_path,
|
||||||
size_t num_models = 1)
|
size_t num_models = 1)
|
||||||
: AOTIModelContainerRunner(model_so_path, num_models, true, "") {}
|
: AOTIModelContainerRunner(model_so_path, num_models, true, "") {}
|
||||||
|
|
||||||
|
std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs) {
|
||||||
|
return AOTIModelContainerRunner::run(inputs);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace torch::inductor
|
} // namespace torch::inductor
|
||||||
|
|
|
||||||
|
|
@ -1,16 +0,0 @@
|
||||||
#include <c10/cuda/CUDAStream.h>
|
|
||||||
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
|
|
||||||
|
|
||||||
namespace torch::inductor {
|
|
||||||
|
|
||||||
std::vector<at::Tensor> AOTIModelContainerRunnerCuda::run(
|
|
||||||
std::vector<at::Tensor>& inputs,
|
|
||||||
cudaStream_t cuda_stream_handle) {
|
|
||||||
if (cuda_stream_handle == nullptr) {
|
|
||||||
cuda_stream_handle = c10::cuda::getCurrentCUDAStream().stream();
|
|
||||||
}
|
|
||||||
return AOTIModelContainerRunner::run(
|
|
||||||
inputs, reinterpret_cast<AOTInductorStreamHandle>(cuda_stream_handle));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace torch::inductor
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cuda_runtime_api.h>
|
#include <c10/cuda/CUDAStream.h>
|
||||||
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
|
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
|
||||||
|
|
||||||
namespace torch::inductor {
|
namespace torch::inductor {
|
||||||
|
|
@ -15,7 +15,13 @@ class TORCH_API AOTIModelContainerRunnerCuda : public AOTIModelContainerRunner {
|
||||||
|
|
||||||
std::vector<at::Tensor> run(
|
std::vector<at::Tensor> run(
|
||||||
std::vector<at::Tensor>& inputs,
|
std::vector<at::Tensor>& inputs,
|
||||||
cudaStream_t cuda_stream_handle = nullptr);
|
cudaStream_t cuda_stream_handle = nullptr) {
|
||||||
|
if (cuda_stream_handle == nullptr) {
|
||||||
|
cuda_stream_handle = c10::cuda::getCurrentCUDAStream().stream();
|
||||||
|
}
|
||||||
|
return AOTIModelContainerRunner::run(
|
||||||
|
inputs, reinterpret_cast<AOTInductorStreamHandle>(cuda_stream_handle));
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace torch::inductor
|
} // namespace torch::inductor
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user