Skip to content

Commit

Permalink
Fixup wrt. cuda+none
Browse files Browse the repository at this point in the history
  • Loading branch information
stemann committed Dec 1, 2024
1 parent fa02b9e commit 7bdd1cf
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions T/Torch/build_tarballs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,8 @@ if [[ $target == x86_64-apple-darwin* # Fails to compile: /workspace/srcdir/pyto
cmake_extra_args+="-DUSE_MKLDNN=OFF "
fi
if [[ $bb_full_target == *cuda* ]]; then
cuda_version=`echo $bb_full_target | sed -E -e 's/.*cuda\+([0-9]+\.[0-9]+).*/\1/'`
cuda_version_major=`echo $cuda_version | cut -d . -f 1`
cuda_version_minor=`echo $cuda_version | cut -d . -f 2`
cuda_version=${bb_full_target##*-cuda+}
if [[ $bb_full_target == *cuda* ]] && [[ $cuda_version != none ]]; then
export CUDA_PATH="$prefix/cuda"
mkdir $WORKSPACE/tmpdir
export TMPDIR=$WORKSPACE/tmpdir
Expand All @@ -122,6 +120,8 @@ if [[ $bb_full_target == *cuda* ]]; then
-DUSE_MAGMA=ON \
-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_PATH \
-DCUB_INCLUDE_DIR=$WORKSPACE/srcdir/pytorch/third_party/cub "
cuda_version_major=`echo $cuda_version | cut -d . -f 1`
cuda_version_minor=`echo $cuda_version | cut -d . -f 2`
micromamba install -y magma-cuda${cuda_version_major}${cuda_version_minor} -c pytorch
git submodule update --init \
third_party/cub \
Expand Down Expand Up @@ -177,10 +177,10 @@ configure() {
$cmake_extra_args \
..
}
if [[ $bb_full_target != *cuda* ]]; then
configure
else
if [[ $bb_full_target == *cuda* ]] && [[ $cuda_version != none ]]; then
configure || configure
else
configure
fi
cmake --build . -- -j $nproc
make install
Expand Down Expand Up @@ -263,7 +263,7 @@ builds = []
for platform in platforms
should_build_platform(platform) || continue
additional_deps = BinaryBuilder.AbstractDependency[]
if haskey(platform, "cuda")
if haskey(platform, "cuda") && platform["cuda"] != "none"
if platform["cuda"] == "11.3"
additional_deps = BinaryBuilder.AbstractDependency[
BuildDependency(PackageSpec("CUDA_full_jll", v"11.3.1")),
Expand Down

0 comments on commit 7bdd1cf

Please sign in to comment.