Skip to content

Commit

Permalink
[AOTI][refactor] Remove model_container_runner_cuda.cpp (pytorch#116113)
Browse files Browse the repository at this point in the history
Differential Revision: [D52301272](https://our.internmc.facebook.com/intern/diff/D52301272)
Pull Request resolved: pytorch#116113
Approved by: https://github.com/khabinov
ghstack dependencies: pytorch#116047
  • Loading branch information
desertfire authored and pytorchmergebot committed Dec 21, 2023
1 parent f71d302 commit 2dce364
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 19 deletions.
1 change: 0 additions & 1 deletion build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,6 @@ libtorch_cuda_core_sources = [
"torch/csrc/CudaIPCTypes.cpp",
"torch/csrc/cuda/comm.cpp",
"torch/csrc/cuda/memory_snapshot.cpp",
"torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp",
"torch/csrc/inductor/aoti_torch/shim_cuda.cpp",
"torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp",
"torch/csrc/profiler/stubs/cuda.cpp",
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ class TORCH_API AOTIModelContainerRunnerCpu : public AOTIModelContainerRunner {
const std::string& model_so_path,
size_t num_models = 1)
: AOTIModelContainerRunner(model_so_path, num_models, true, "") {}

std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs) {
return AOTIModelContainerRunner::run(inputs);
}
};

} // namespace torch::inductor
16 changes: 0 additions & 16 deletions torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp

This file was deleted.

10 changes: 8 additions & 2 deletions torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <cuda_runtime_api.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>

namespace torch::inductor {
Expand All @@ -15,7 +15,13 @@ class TORCH_API AOTIModelContainerRunnerCuda : public AOTIModelContainerRunner {

std::vector<at::Tensor> run(
std::vector<at::Tensor>& inputs,
cudaStream_t cuda_stream_handle = nullptr);
cudaStream_t cuda_stream_handle = nullptr) {
if (cuda_stream_handle == nullptr) {
cuda_stream_handle = c10::cuda::getCurrentCUDAStream().stream();
}
return AOTIModelContainerRunner::run(
inputs, reinterpret_cast<AOTInductorStreamHandle>(cuda_stream_handle));
}
};

} // namespace torch::inductor

0 comments on commit 2dce364

Please sign in to comment.