Skip to content

Commit

Permalink
optionally build bellman-cuda from source if no build is found
Browse files Browse the repository at this point in the history
  • Loading branch information
robik75 committed Aug 26, 2024
1 parent 9091a68 commit 4e63115
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
2 changes: 2 additions & 0 deletions crates/gpu-ffi/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ num_cpus = "1"
crossbeam = "0.8"

[build-dependencies]
era_cudart_sys.workspace = true
bindgen = "0.59.1"
cmake = "0.1"

[dev-dependencies]
rand = "0.4"
Expand Down
39 changes: 26 additions & 13 deletions crates/gpu-ffi/build.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
extern crate bindgen;
use era_cudart_sys::get_cuda_lib_path;
use std::env::var;
use std::path::Path;
use std::{env, path::PathBuf};

// build.rs

fn main() {
Expand Down Expand Up @@ -53,16 +55,27 @@ fn generate_bindings(bellman_cuda_path: &str) {
}

fn link_multiexp_library(bellman_cuda_path: &str) {
let kind = "static";
let name = "bellman-cuda";

println!("cargo:rustc-link-lib=dylib=stdc++");
println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=static=cudadevrt");
println!(
"cargo:rustc-link-search=native={}/build/src",
bellman_cuda_path
);
println!("cargo:rustc-link-lib={}={}", kind, name);
let bellman_cuda_lib_path = if Path::new(bellman_cuda_path).join("build").exists() {
Path::new(bellman_cuda_path)
.join("build")
.join("src")
.to_str()
.unwrap()
.to_string()
} else {
let cudaarchs = var("CUDAARCHS").unwrap_or("native".to_string());
let dst = cmake::Config::new(bellman_cuda_path)
.profile("Release")
.define("CMAKE_CUDA_ARCHITECTURES", cudaarchs)
.build();
dst.to_str().unwrap().to_string()
};
println!("cargo:rustc-link-search=native={bellman_cuda_lib_path}");
println!("cargo:rustc-link-lib=static=bellman-cuda");
let cuda_lib_path = get_cuda_lib_path().unwrap();
let cuda_lib_path_str = cuda_lib_path.to_str().unwrap();
println!("cargo:rustc-link-search=native={cuda_lib_path_str}");
println!("cargo:rustc-link-lib=cudart");
#[cfg(target_os = "linux")]
println!("cargo:rustc-link-lib=stdc++");
}

0 comments on commit 4e63115

Please sign in to comment.