Skip to content

Commit

Permalink
Merge pull request #53 from Intel-HLS/gp_fpga_support
Browse files Browse the repository at this point in the history
Add FPGA accelerated PairHMM
  • Loading branch information
George Powley authored Apr 23, 2017
2 parents dd638ad + 9d23edc commit e24dcb5
Show file tree
Hide file tree
Showing 12 changed files with 225 additions and 52 deletions.
2 changes: 1 addition & 1 deletion .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 12 additions & 4 deletions src/main/java/com/intel/gkl/pairhmm/IntelPairHmm.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package com.intel.gkl.pairhmm;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import com.intel.gkl.IntelGKLUtils;
import com.intel.gkl.NativeLibraryLoader;
import org.broadinstitute.gatk.nativebindings.pairhmm.HaplotypeDataHolder;
Expand All @@ -13,8 +16,10 @@
* Provides a native PairHMM implementation accelerated for the Intel Architecture.
*/
public class IntelPairHmm implements PairHMMNativeBinding {
private final static Logger logger = LogManager.getLogger(IntelPairHmm.class);
private static final String NATIVE_LIBRARY_NAME = "gkl_pairhmm";
private String nativeLibraryName = "gkl_pairhmm";
boolean useFpga = false;

void setNativeLibraryName(String nativeLibraryName) {
this.nativeLibraryName = nativeLibraryName;
Expand Down Expand Up @@ -46,13 +51,15 @@ public synchronized boolean load(File tempDir) {
* @param args the args used to configure native PairHMM
*/
public void initialize(PairHMMNativeArguments args) {
if(args == null)
{
if (args == null) {
args = new PairHMMNativeArguments();
args.useDoublePrecision = false;
args.maxNumberOfThreads = 1;
}
initNative(ReadDataHolder.class, HaplotypeDataHolder.class, args.useDoublePrecision, args.maxNumberOfThreads);
if (args.useDoublePrecision && useFpga) {
logger.warn("FPGA PairHMM does not support double precision floating-point. Using AVX PairHMM");
}
initNative(ReadDataHolder.class, HaplotypeDataHolder.class, args.useDoublePrecision, args.maxNumberOfThreads, useFpga);
}

/**
Expand Down Expand Up @@ -82,7 +89,8 @@ public void done() {
private native static void initNative(Class<?> readDataHolderClass,
Class<?> haplotypeDataHolderClass,
boolean doublePrecision,
int maxThreads);
int maxThreads,
boolean useFpga);

private native void computeLikelihoodsNative(Object[] readDataArray,
Object[] haplotypeDataArray,
Expand Down
13 changes: 13 additions & 0 deletions src/main/java/com/intel/gkl/pairhmm/IntelPairHmmFpga.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.intel.gkl.pairhmm;

/**
 * PairHMM native binding that targets Intel FPGAs.
 *
 * Behaves like {@link IntelPairHmm}; the only additions are made in the
 * constructor, which selects the "gkl_pairhmm_fpga" native library name and
 * switches on the inherited {@code useFpga} flag.
 */
public final class IntelPairHmmFpga extends IntelPairHmm {
    /** Name handed to {@code setNativeLibraryName} for the FPGA build. */
    private static final String FPGA_LIBRARY_NAME = "gkl_pairhmm_fpga";

    /** Configures the inherited binding for the FPGA native library. */
    public IntelPairHmmFpga() {
        super();
        setNativeLibraryName(FPGA_LIBRARY_NAME);
        this.useFpga = true;
    }
}
24 changes: 20 additions & 4 deletions src/main/native/pairhmm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,21 +1,37 @@
#---------------------------------------------------------------------
# pairhmm_shacc (stub version)
#---------------------------------------------------------------------
# Shared library built solely from shacc_pairhmm_stub.cc. It is linked into
# gkl_pairhmm_fpga below, and into gkl_pairhmm on Apple platforms.
set(TARGET gkl_pairhmm_shacc)
add_library(${TARGET} SHARED shacc_pairhmm_stub.cc)
install(TARGETS ${TARGET} DESTINATION ${CMAKE_BINARY_DIR})

#---------------------------------------------------------------------
# pairhmm
#---------------------------------------------------------------------
# Shared library built from IntelPairHmm.cc.
set(TARGET gkl_pairhmm)

add_library(${TARGET} SHARED IntelPairHmm.cc)
# On Apple platforms only, link against the gkl_pairhmm_shacc library above.
# NOTE(review): presumably needed so shacc_pairhmm symbols referenced by
# IntelPairHmm.cc resolve at link time on macOS - confirm.
if (APPLE)
target_link_libraries(${TARGET} gkl_pairhmm_shacc)
endif()
install(TARGETS ${TARGET} DESTINATION ${CMAKE_BINARY_DIR})

#---------------------------------------------------------------------
# pairhmm_omp
#---------------------------------------------------------------------
# Same IntelPairHmm.cc source, compiled and linked with ${OpenMP_CXX_FLAGS}.
# The library is only added when find_package(OpenMP) sets OPENMP_FOUND.
set(TARGET gkl_pairhmm_omp)

find_package(OpenMP)

if(OPENMP_FOUND)
set(TARGET gkl_pairhmm_omp)
add_library(${TARGET} SHARED IntelPairHmm.cc)
set_target_properties(${TARGET} PROPERTIES COMPILE_OPTIONS ${OpenMP_CXX_FLAGS})
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS ${OpenMP_CXX_FLAGS})
target_link_libraries(${TARGET} ${OpenMP_CXX_FLAGS})
install(TARGETS ${TARGET} DESTINATION ${CMAKE_BINARY_DIR})
endif()

#---------------------------------------------------------------------
# pairhmm_fpga
#---------------------------------------------------------------------
# Same IntelPairHmm.cc source, linked against gkl_pairhmm_shacc on every
# platform (no APPLE guard here, unlike gkl_pairhmm).
set(TARGET gkl_pairhmm_fpga)
add_library(${TARGET} SHARED IntelPairHmm.cc)
target_link_libraries(${TARGET} gkl_pairhmm_shacc)
install(TARGETS ${TARGET} DESTINATION ${CMAKE_BINARY_DIR})
28 changes: 21 additions & 7 deletions src/main/native/pairhmm/IntelPairHmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
#include "IntelPairHmm.h"
#include "pairhmm_common.h"
#include "avx-pairhmm.h"
#include "shacc_pairhmm.h"
#include "JavaData.h"
#include "debug.h"

bool g_use_double;
int g_max_threads;
bool g_use_fpga;

Context<float> ctxf;
Context<double> ctxd;
Context<float> g_ctxf;
Context<double> g_ctxd;

float (*g_compute_full_prob_float)(testcase *tc);
double (*g_compute_full_prob_double)(testcase *tc);
Expand All @@ -24,7 +27,7 @@ double (*g_compute_full_prob_double)(testcase *tc);
*/
JNIEXPORT void JNICALL Java_com_intel_gkl_pairhmm_IntelPairHmm_initNative
(JNIEnv* env, jclass cls, jclass readDataHolder, jclass haplotypeDataHolder,
jboolean use_double, jint max_threads)
jboolean use_double, jint max_threads, jboolean use_fpga)
{
DBG("Enter");

Expand All @@ -37,6 +40,8 @@ JNIEXPORT void JNICALL Java_com_intel_gkl_pairhmm_IntelPairHmm_initNative
g_max_threads = std::min((int)max_threads, omp_get_max_threads());
#endif

g_use_fpga = use_fpga;

// enable FTZ
_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);

Expand All @@ -49,7 +54,6 @@ JNIEXPORT void JNICALL Java_com_intel_gkl_pairhmm_IntelPairHmm_initNative
DBG("Exit");
}


/*
* Class: com_intel_gkl_pairhmm_IntelPairHmm
* Method: computeLikelihoodsNative
Expand All @@ -69,18 +73,28 @@ JNIEXPORT void JNICALL Java_com_intel_gkl_pairhmm_IntelPairHmm_computeLikelihood

//==================================================================
// calculate pairHMM
shacc_pairhmm::Batch batch;
bool batch_valid = false;
if (g_use_fpga && !g_use_double) {
batch = javaData.getBatch();
batch_valid = shacc_pairhmm::calculate(batch);
}

#ifdef _OPENMP
#pragma omp parallel for schedule(dynamic, 1) num_threads(g_max_threads)
#endif
for (int i = 0; i < testcases.size(); i++) {
double result_final = 0;

float result_float = g_use_double ? 0.0f : g_compute_full_prob_float(&testcases[i]);
float result_float = g_use_double ? 0.0f :
batch_valid ? batch.results[i] : g_compute_full_prob_float(&testcases[i]);

if (result_float < MIN_ACCEPTED) {
double result_double = g_compute_full_prob_double(&testcases[i]);
result_final = log10(result_double) - ctxd.LOG10_INITIAL_CONSTANT;
result_final = log10(result_double) - g_ctxd.LOG10_INITIAL_CONSTANT;
}
else {
result_final = (double)(log10f(result_float) - ctxf.LOG10_INITIAL_CONSTANT);
result_final = (double)(log10f(result_float) - g_ctxf.LOG10_INITIAL_CONSTANT);
}

javaResults[i] = result_final;
Expand Down
4 changes: 2 additions & 2 deletions src/main/native/pairhmm/IntelPairHmm.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

101 changes: 75 additions & 26 deletions src/main/native/pairhmm/JavaData.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@

#include <vector>
#include "pairhmm_common.h"
#include "shacc_pairhmm.h"
#include "debug.h"

class JavaData {
public:

// cache field ids
static void init(JNIEnv *env, jclass readDataHolder, jclass haplotypeDataHolder) {
m_readBasesFid = getFieldId(env, readDataHolder, "readBases", "[B");
m_readQualsFid = getFieldId(env, readDataHolder, "readQuals", "[B");
Expand All @@ -15,24 +19,29 @@ class JavaData {
m_haplotypeBasesFid = getFieldId(env, haplotypeDataHolder, "haplotypeBases", "[B");
}

// create array of testcases
std::vector<testcase> getData(JNIEnv *env, jobjectArray& readDataArray, jobjectArray& haplotypeDataArray) {
int numReads = env->GetArrayLength(readDataArray);
int numHaplotypes = env->GetArrayLength(haplotypeDataArray);

testcases.resize(numReads * numHaplotypes);
m_batch.num_reads = numReads;
m_batch.num_haps = numHaplotypes;

std::vector<char*> haplotypes;
std::vector<int> haplotypeLengths;

std::vector<char*> haplotypes(numHaplotypes);
std::vector<int> haplotypeLengths(numHaplotypes);
long total_hap_length = 0;
long total_read_length = 0;

// get haplotypes
for (int i = 0; i < numHaplotypes; i++) {
int length = 0;
haplotypes[i] = getCharArray(env, haplotypeDataArray, i, m_haplotypeBasesFid, length);
haplotypeLengths[i] = length;
haplotypes.push_back(getCharArray(env, haplotypeDataArray, i, m_haplotypeBasesFid, length));
haplotypeLengths.push_back(length);
total_hap_length += length;
}

// get reads and create testcases
int i = 0;
for (int r = 0; r < numReads; r++) {
int length = 0;
char* reads = getCharArray(env, readDataArray, r, m_readBasesFid, length);
Expand All @@ -41,33 +50,69 @@ class JavaData {
char* delGops = getCharArray(env, readDataArray, r, m_deletionGopFid, length);
char* gapConts = getCharArray(env, readDataArray, r, m_overallGcpFid, length);
char* readQuals = getCharArray(env, readDataArray, r, m_readQualsFid, length);
total_read_length += length;

for (int h = 0; h < numHaplotypes; h++) {
testcases[i].hap = haplotypes[h];
testcases[i].haplen = haplotypeLengths[h];
testcases[i].rs = reads;
testcases[i].rslen = readLength;
testcases[i].i = insGops;
testcases[i].d = delGops;
testcases[i].c = gapConts;
testcases[i].q = readQuals;
i++;
testcase tc;
tc.hap = haplotypes[h];
tc.haplen = haplotypeLengths[h];
tc.rs = reads;
tc.rslen = readLength;
tc.i = insGops;
tc.d = delGops;
tc.c = gapConts;
tc.q = readQuals;
m_testcases.push_back(tc);
}
}

return testcases;
m_total_cells = 3 * total_read_length * total_hap_length;

return m_testcases;
}

// Thin wrapper: forwards to getDoubleArray(env, array) and returns its result.
// NOTE(review): getDoubleArray is only partly visible in this view; it appears
// to record the array in m_doubleArrays so releaseData() can release it later -
// confirm against the full file.
double* getOutputArray(JNIEnv *env, jdoubleArray array) {
return getDoubleArray(env, array);
}

// create shacc_pairhmm::batch from array of testcases
// Create a shacc_pairhmm::Batch from the testcases collected by getData().
//
// m_testcases is filled read-major by getData() (for each read, one entry per
// haplotype), so:
//   - entries 0, num_haps, 2*num_haps, ... each carry the data of one read;
//   - entries 0 .. num_haps-1 carry each haplotype exactly once.
//
// The returned Batch is a copy of m_batch, but its `reads` and `haps` pointers
// refer to storage owned by m_reads / m_haps. They stay valid only while this
// JavaData object is alive and getBatch() is not called again.
shacc_pairhmm::Batch getBatch() {
    const int num_testcases = m_batch.num_reads * m_batch.num_haps;

    // Start from empty staging vectors so a repeated call rebuilds the batch
    // instead of appending duplicate reads/haplotypes.
    m_reads.clear();
    m_haps.clear();
    m_reads.reserve(m_batch.num_reads);
    m_haps.reserve(m_batch.num_haps);

    // get reads: first testcase of every read's group of num_haps entries
    for (int i = 0; i < num_testcases; i += m_batch.num_haps) {
        const testcase& tc = m_testcases[i];
        shacc_pairhmm::Read read;
        read.bases = tc.rs;
        read.length = tc.rslen;
        read.i = tc.i;
        read.d = tc.d;
        read.c = tc.c;
        read.q = tc.q;
        m_reads.push_back(read);
    }
    m_batch.reads = m_reads.data();

    // get haplotypes: taken from the first read's testcases. With zero reads
    // m_testcases is empty, so there is nothing to index; skip the loop rather
    // than read past the end of the vector.
    // NOTE(review): in that case m_batch.num_haps is left unchanged while haps
    // holds no entries - confirm the consumer tolerates a batch with no reads.
    const int num_haps_available = m_testcases.empty() ? 0 : m_batch.num_haps;
    for (int i = 0; i < num_haps_available; i++) {
        const testcase& tc = m_testcases[i];
        shacc_pairhmm::Haplotype hap;
        DBG("hap #%d len = %d", i, tc.haplen);
        hap.bases = tc.hap;
        hap.length = tc.haplen;
        m_haps.push_back(hap);
    }
    m_batch.haps = m_haps.data();

    // total cell count computed by getData()
    m_batch.num_cells = m_total_cells;

    return m_batch;
}

void releaseData(JNIEnv *env) {
for (int i = 0; i < byteArrays.size(); i++) {
env->ReleaseByteArrayElements(byteArrays[i].first, byteArrays[i].second, 0);
for (int i = 0; i < m_byteArrays.size(); i++) {
env->ReleaseByteArrayElements(m_byteArrays[i].first, m_byteArrays[i].second, 0);
}
for (int i = 0; i < doubleArrays.size(); i++) {
env->ReleaseDoubleArrayElements(doubleArrays[i].first, doubleArrays[i].second, 0);
for (int i = 0; i < m_doubleArrays.size(); i++) {
env->ReleaseDoubleArrayElements(m_doubleArrays[i].first, m_doubleArrays[i].second, 0);
}
}

Expand All @@ -88,7 +133,7 @@ class JavaData {
env->ThrowNew(env->FindClass("java/lang/OutOfMemoryError"), "Unable to access jbyteArray");
}
length = env->GetArrayLength(byteArray);
byteArrays.push_back(std::make_pair(byteArray, primArray));
m_byteArrays.push_back(std::make_pair(byteArray, primArray));
return (char*)primArray;
}

Expand All @@ -97,13 +142,17 @@ class JavaData {
if (primArray == NULL) {
env->ThrowNew(env->FindClass("java/lang/OutOfMemoryError"), "Unable to access jdoubleArray");
}
doubleArrays.push_back(std::make_pair(array, primArray));
m_doubleArrays.push_back(std::make_pair(array, primArray));
return (double*)primArray;
}

std::vector<testcase> testcases;
std::vector<std::pair<jbyteArray, jbyte*> > byteArrays;
std::vector<std::pair<jdoubleArray, jdouble*> > doubleArrays;
shacc_pairhmm::Batch m_batch;
std::vector<shacc_pairhmm::Read> m_reads;
std::vector<shacc_pairhmm::Haplotype> m_haps;
std::vector<testcase> m_testcases;
std::vector<std::pair<jbyteArray, jbyte*> > m_byteArrays;
std::vector<std::pair<jdoubleArray, jdouble*> > m_doubleArrays;
long m_total_cells;

static jfieldID m_readBasesFid;
static jfieldID m_readQualsFid;
Expand Down
18 changes: 18 additions & 0 deletions src/main/native/pairhmm/debug.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef DEBUG_H
#define DEBUG_H

#include <stdio.h>

/*
 * printf-style logging macros that write one line to stderr.
 *
 *   DBG(fmt, ...)    debug trace; expands to nothing unless DEBUG is defined
 *   INFO(fmt, ...)   informational message
 *   WARN(fmt, ...)   warning message
 *   ERROR(fmt, ...)  error message
 *
 * `fmt` must be a string literal: it is concatenated with the level prefix
 * and a trailing newline at compile time. When DEBUG is defined every message
 * is additionally tagged with the calling function name and line number.
 *
 * `##__VA_ARGS__` is the GNU extension that drops the preceding comma when no
 * variadic arguments are supplied, so `INFO("text")` is valid.
 */
#ifdef DEBUG
#  define DBG(M, ...)   fprintf(stderr, "[DEBUG] (%s:%d) : " M "\n", __FUNCTION__, __LINE__, ##__VA_ARGS__)
#  define INFO(M, ...)  fprintf(stderr, "[INFO] (%s:%d) : " M "\n", __FUNCTION__, __LINE__, ##__VA_ARGS__)
#  define WARN(M, ...)  fprintf(stderr, "[WARNING] (%s:%d) : " M "\n", __FUNCTION__, __LINE__, ##__VA_ARGS__)
/* Separator is " : " like the other levels (it previously lacked the space after the colon). */
#  define ERROR(M, ...) fprintf(stderr, "[ERROR] (%s:%d) : " M "\n", __FUNCTION__, __LINE__, ##__VA_ARGS__)
#else
#  define DBG(M, ...)
#  define INFO(M, ...)  fprintf(stderr, "[INFO] " M "\n", ##__VA_ARGS__)
#  define WARN(M, ...)  fprintf(stderr, "[WARNING] " M "\n", ##__VA_ARGS__)
#  define ERROR(M, ...) fprintf(stderr, "[ERROR] " M "\n", ##__VA_ARGS__)
#endif

#endif
Loading

0 comments on commit e24dcb5

Please sign in to comment.