Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/distributed computing #29

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions examples/distributed/distributed_gblup.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Authors
# Alexander Freudenberg, [email protected]

# Copyright (C) 2023 Alexander Freudenberg

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


using Base;
@timev "Loading libraries" begin
using Random;
using LinearAlgebra;
using CSV;
using ClusterManagers, Distributed;
using DelimitedFiles;
using DataFrames;
using BenchmarkTools;
using Statistics;
using Libdl;
end

# =====================
# Global definitions
# =====================

ROOT_DIR = string(@__DIR__) * "/../.."

MODULE_PATH = ROOT_DIR * "/src/bindings/Julia/miraculix.jl"
LIBRARY_PATH = ROOT_DIR * "/src/miraculix/miraculix.so"
DATA_DIR = ROOT_DIR * "/data"
LOG_DIR = DATA_DIR * "/logs"

# Control miraculix verbosity
ENV["PRINT_LEVEL"] = "1";

# Get thread number
@assert haskey(ENV, "OMP_NUM_THREADS") "OMP_NUM_THREADS not set"
OMP_NUM_THREADS = ENV["OMP_NUM_THREADS"];
BLAS.set_num_threads(parse(Int,OMP_NUM_THREADS))
println("OMP threads set to $OMP_NUM_THREADS")
include(MODULE_PATH)

# =====================
# Auxiliary functions
# =====================

function multiply_Zt(obj_ref::Vector{Ref{Ptr{Cvoid}}}, B::Matrix{Float64}, snps::Int, indiv::Int, snps_per_tile::Int, n_col::Int)
if snps % snps_per_tile != 0
error("SNPs must be a multiple of snps_per_tile")
end
indices = range(1, snps, step = snps_per_tile)
ZtB = zeros((snps, n_col))
for (i,j) in enumerate(indices)
ZtB[j : (j + snps_per_tile - 1),:] += miraculix.dgemm_compressed.dgemm_compressed_main(true, obj_ref[i], B, snps_per_tile, indiv)
end
return ZtB
end

function multiply_Z(obj_ref::Vector{Ref{Ptr{Cvoid}}}, B::Matrix{Float64}, snps::Int, indiv::Int, indiv_per_tile::Int, n_col::Int)
if indiv % indiv_per_tile != 0
error("Indiv must be a multiple of indiv_per_tile")
end
indices = range(1, indiv, step = indiv_per_tile)
ZB = zeros((indiv, n_col))
for (i, j) in enumerate(indices)
ZB[j : (j + indiv_per_tile - 1),:] += miraculix.dgemm_compressed.dgemm_compressed_main(false, obj_ref[i], B, snps, indiv_per_tile)
end
return ZB
end


# =====================
# Main
# =====================

miraculix.set_library_path(LIBRARY_PATH)
miraculix.load_shared_library()
miraculix.dgemm_compressed.set_options(use_gpu=true, verbose=1)

init_sym = dlsym(miraculix.LIBRARY_HANDLE[], :plink2compressed)
dgemm_compressed_sym = dlsym(miraculix.LIBRARY_HANDLE[], :dgemm_compressed)
free_sym = dlsym(miraculix.LIBRARY_HANDLE[], :free_compressed)


# Set hyper parameters
max_iter = 1_000 # Maximum number of iterations
print_iter = 1e2 # Frequency of convergence information
conv_crit = 1e-2 # Maximum norm of residual
n_devices = 1 # Number of devices
n_col = 5 # Number of columns of RHS

# We assume that genotype data has been generated by the R package MoBPS
data_file = DATA_DIR * "/xsmall.bed"

# Read-in data from PLINK binary format
@info "Reading in data from $data_file and transpose it"
@timev "Preprocessing" begin
# Read PLINK data and calculate allele frequencies
wtime = @elapsed plink, n_snps, n_indiv = miraculix.read_plink.read_bed(data_file, coding_twobit = false, calc_freq = false, check_for_missings = false)
freq = miraculix.read_plink.read_freq(DATA_DIR * "/xsmall.freq")

@debug "Time for reading: $wtime s."

if (length(ARGS) > 0) && (ARGS[1] == "test")
n_snps = 100
plink = plink[:,1:n_snps]
freq = freq[1:n_snps]
end

# Transpose matrix
wtime = @elapsed plink_transposed = miraculix.compressed_operations.transpose_genotype_matrix(plink, n_snps, n_indiv)
@debug "Time for transposing: $wtime s."

GC.gc()
end


obj_ref = Vector{Ref{Ptr{Cvoid}}}()
obj_ref_trans = Vector{Ref{Ptr{Cvoid}}}()
n_indiv_per_tile = Int(n_indiv / 2)
n_snps_per_tile = Int(n_snps / 2)

obj_reference = Ref{Ptr{Cvoid}}(C_NULL)
ccall(init_sym, Cvoid, (Ptr{UInt8}, Ptr{UInt8}, Cint, Cint, Ptr{Float64}, Cint, Ptr{Ptr{Cvoid}}), plink, plink_transposed, Int32(n_snps), Int32(n_indiv), freq, Int32(n_col), obj_reference)


for i in range(0, 1)
push!(obj_ref, Ref{Ptr{Cvoid}}(C_NULL))
push!(obj_ref_trans, Ref{Ptr{Cvoid}}(C_NULL))

ENV["CUDA_DEVICE"] = i % n_devices
ccall(init_sym, Cvoid, (Ptr{UInt8}, Ptr{UInt8}, Cint, Cint, Ptr{Float64}, Cint, Ptr{Ptr{Cvoid}}), pointer(plink, i * n_snps_per_tile * size(plink,1) + 1 ), C_NULL, Int32(n_snps_per_tile), Int32(n_indiv), pointer(freq,i * n_snps_per_tile + 1), Int32(n_col), obj_ref[i + 1])

ENV["CUDA_DEVICE"] = (i + 1) % n_devices
ccall(init_sym, Cvoid, (Ptr{UInt8}, Ptr{UInt8}, Cint, Cint, Ptr{Float64}, Cint, Ptr{Ptr{Cvoid}}), C_NULL, pointer(plink_transposed, (i * n_indiv_per_tile * size(plink_transposed, 1) + 1)), Int32(n_snps), Int32(n_indiv_per_tile), freq, Int32(n_col), obj_ref_trans[i + 1])
end

B = ones(Float64, n_indiv, n_col) # RHS of equation system

V = multiply_Zt(obj_ref, B, n_snps, n_indiv, n_snps_per_tile, n_col)
V_ref = miraculix.dgemm_compressed.dgemm_compressed_main(true, obj_reference, B, n_snps, n_indiv)
@assert isapprox(V, V_ref)

W = multiply_Z(obj_ref_trans, V, n_snps, n_indiv, n_indiv_per_tile, n_col)
W_ref = miraculix.dgemm_compressed.dgemm_compressed_main(false, obj_reference, V_ref, n_snps, n_indiv)

for ref in obj_ref
ccall(free_sym, Cvoid, (Ptr{Ptr{Cvoid}},), ref)
end
for ref in obj_ref_trans
ccall(free_sym, Cvoid, (Ptr{Ptr{Cvoid}},), ref)
end
ccall(free_sym, Cvoid, (Ptr{Ptr{Cvoid}},), obj_reference)
2 changes: 1 addition & 1 deletion examples/gblup/calculate_gblup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ include(MODULE_PATH)
# =====================
function multiply_ld(obj_ref::Ref{Ptr{Cvoid}}, snps::Int, indiv::Int, B::Matrix{Float64})
# Calculate SNPs times Matrix
Z_vec = miraculix.dgemm_compressed.dgemm_compressed_main(false, obj_ref, B, n_snps, n_indiv)
Z_vec = miraculix.dgemm_compressed.dgemm_compressed_main(false, obj_ref, B, snps, indiv)
# Calculate Individuals times Matrix
ZtZ_vec = miraculix.dgemm_compressed.dgemm_compressed_main(true, obj_ref, Z_vec, snps, indiv)
return ZtZ_vec
Expand Down
6 changes: 5 additions & 1 deletion examples/iterative_solver/grm_solve_cg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ function GRM_vec(obj_ref::Ref{Ptr{Cvoid}}, B::Matrix{Float64}, snps::Int, indiv:
return Gv
end

# =====================
# Main
# =====================

# Set library path and load library
miraculix.set_library_path(LIBRARY_PATH)
miraculix.load_shared_library()
Expand All @@ -94,7 +98,7 @@ genotype_data, calc_freq, n_snps, n_indiv = miraculix.read_plink.read_bed(DATA_F
genotype_data_transposed = miraculix.compressed_operations.transpose_genotype_matrix(genotype_data, n_snps, n_indiv)

# Initialize storage object
obj_ref = miraculix.dgemm_compressed.init_compressed(genotype_data,genotype_data_transposed, n_snps, n_indiv, freq, 1)
obj_ref = miraculix.dgemm_compressed.init_compressed(genotype_data, genotype_data_transposed, n_snps, n_indiv, freq, 1)

# Set hyper parameters
max_iter = 1_000 # Maximum number of iterations
Expand Down
Loading