From 2dce364634220d2d0db97653ab3f3d024df2df03 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 20 Dec 2023 08:36:12 -0800 Subject: [PATCH] [AOTI][refactor] Remove model_container_runner_cuda.cpp (#116113) Differential Revision: [D52301272](https://our.internmc.facebook.com/intern/diff/D52301272) Pull Request resolved: https://github.com/pytorch/pytorch/pull/116113 Approved by: https://github.com/khabinov ghstack dependencies: #116047 --- build_variables.bzl | 1 - .../aoti_runner/model_container_runner_cpu.h | 4 ++++ .../aoti_runner/model_container_runner_cuda.cpp | 16 ---------------- .../aoti_runner/model_container_runner_cuda.h | 10 ++++++++-- 4 files changed, 12 insertions(+), 19 deletions(-) delete mode 100644 torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp diff --git a/build_variables.bzl b/build_variables.bzl index def6eac76e54f..711ca08253339 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -652,7 +652,6 @@ libtorch_cuda_core_sources = [ "torch/csrc/CudaIPCTypes.cpp", "torch/csrc/cuda/comm.cpp", "torch/csrc/cuda/memory_snapshot.cpp", - "torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp", "torch/csrc/inductor/aoti_torch/shim_cuda.cpp", "torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp", "torch/csrc/profiler/stubs/cuda.cpp", diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h b/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h index a8de82a954093..360e9f1404fcf 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h +++ b/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h @@ -9,6 +9,10 @@ class TORCH_API AOTIModelContainerRunnerCpu : public AOTIModelContainerRunner { const std::string& model_so_path, size_t num_models = 1) : AOTIModelContainerRunner(model_so_path, num_models, true, "") {} + + std::vector run(std::vector& inputs) { + return AOTIModelContainerRunner::run(inputs); + } }; } // namespace torch::inductor diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp deleted file mode 100644 index 18eb0555028fb..0000000000000 --- a/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include -#include - -namespace torch::inductor { - -std::vector AOTIModelContainerRunnerCuda::run( - std::vector& inputs, - cudaStream_t cuda_stream_handle) { - if (cuda_stream_handle == nullptr) { - cuda_stream_handle = c10::cuda::getCurrentCUDAStream().stream(); - } - return AOTIModelContainerRunner::run( - inputs, reinterpret_cast(cuda_stream_handle)); -} - -} // namespace torch::inductor diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h b/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h index 5832320c3c923..2e73cb8784747 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h +++ b/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include namespace torch::inductor { @@ -15,7 +15,13 @@ class TORCH_API AOTIModelContainerRunnerCuda : public AOTIModelContainerRunner { std::vector run( std::vector& inputs, - cudaStream_t cuda_stream_handle = nullptr); + cudaStream_t cuda_stream_handle = nullptr) { + if (cuda_stream_handle == nullptr) { + cuda_stream_handle = c10::cuda::getCurrentCUDAStream().stream(); + } + return AOTIModelContainerRunner::run( + inputs, reinterpret_cast(cuda_stream_handle)); + } }; } // namespace torch::inductor