From 9d23edc0ed101d198f8f9ca9d292262a4be7e6e5 Mon Sep 17 00:00:00 2001 From: George Powley Date: Fri, 14 Apr 2017 14:05:16 -0400 Subject: [PATCH] add pairhmm fpga support --- .idea/misc.xml | 2 +- .../com/intel/gkl/pairhmm/IntelPairHmm.java | 16 ++- .../intel/gkl/pairhmm/IntelPairHmmFpga.java | 13 +++ src/main/native/pairhmm/CMakeLists.txt | 24 ++++- src/main/native/pairhmm/IntelPairHmm.cc | 28 +++-- src/main/native/pairhmm/IntelPairHmm.h | 4 +- src/main/native/pairhmm/JavaData.h | 101 +++++++++++++----- src/main/native/pairhmm/debug.h | 18 ++++ src/main/native/pairhmm/pairhmm_common.h | 9 +- src/main/native/pairhmm/shacc_pairhmm.h | 38 +++++++ src/main/native/pairhmm/shacc_pairhmm_stub.cc | 12 +++ .../intel/gkl/pairhmm/PairHmmUnitTest.java | 12 ++- 12 files changed, 225 insertions(+), 52 deletions(-) create mode 100644 src/main/java/com/intel/gkl/pairhmm/IntelPairHmmFpga.java create mode 100644 src/main/native/pairhmm/debug.h create mode 100644 src/main/native/pairhmm/shacc_pairhmm.h create mode 100644 src/main/native/pairhmm/shacc_pairhmm_stub.cc diff --git a/.idea/misc.xml b/.idea/misc.xml index 2139cc15..857dea7e 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -14,7 +14,7 @@ - + diff --git a/src/main/java/com/intel/gkl/pairhmm/IntelPairHmm.java b/src/main/java/com/intel/gkl/pairhmm/IntelPairHmm.java index 8cd3984c..34657929 100644 --- a/src/main/java/com/intel/gkl/pairhmm/IntelPairHmm.java +++ b/src/main/java/com/intel/gkl/pairhmm/IntelPairHmm.java @@ -1,5 +1,8 @@ package com.intel.gkl.pairhmm; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + import com.intel.gkl.IntelGKLUtils; import com.intel.gkl.NativeLibraryLoader; import org.broadinstitute.gatk.nativebindings.pairhmm.HaplotypeDataHolder; @@ -13,8 +16,10 @@ * Provides a native PairHMM implementation accelerated for the Intel Architecture. */ public class IntelPairHmm implements PairHMMNativeBinding { + private final static Logger logger = LogManager.getLogger(IntelPairHmm.class); private static final String NATIVE_LIBRARY_NAME = "gkl_pairhmm"; private String nativeLibraryName = "gkl_pairhmm"; + boolean useFpga = false; void setNativeLibraryName(String nativeLibraryName) { this.nativeLibraryName = nativeLibraryName; @@ -46,13 +51,15 @@ public synchronized boolean load(File tempDir) { * @param args the args used to configure native PairHMM */ public void initialize(PairHMMNativeArguments args) { - if(args == null) - { + if (args == null) { args = new PairHMMNativeArguments(); args.useDoublePrecision = false; args.maxNumberOfThreads = 1; } - initNative(ReadDataHolder.class, HaplotypeDataHolder.class, args.useDoublePrecision, args.maxNumberOfThreads); + if (args.useDoublePrecision && useFpga) { + logger.warn("FPGA PairHMM does not support double precision floating-point. Using AVX PairHMM"); + } + initNative(ReadDataHolder.class, HaplotypeDataHolder.class, args.useDoublePrecision, args.maxNumberOfThreads, useFpga); } /** @@ -82,7 +89,8 @@ public void done() { private native static void initNative(Class readDataHolderClass, Class haplotypeDataHolderClass, boolean doublePrecision, - int maxThreads); + int maxThreads, + boolean useFpga); private native void computeLikelihoodsNative(Object[] readDataArray, Object[] haplotypeDataArray, diff --git a/src/main/java/com/intel/gkl/pairhmm/IntelPairHmmFpga.java b/src/main/java/com/intel/gkl/pairhmm/IntelPairHmmFpga.java new file mode 100644 index 00000000..861fc94f --- /dev/null +++ b/src/main/java/com/intel/gkl/pairhmm/IntelPairHmmFpga.java @@ -0,0 +1,13 @@ +package com.intel.gkl.pairhmm; + +/** + * Provides a native PairHMM implementation accelerated using Intel FPGAs + */ +public final class IntelPairHmmFpga extends IntelPairHmm { + private static final String NATIVE_LIBRARY_NAME = "gkl_pairhmm_fpga"; + + public IntelPairHmmFpga() { + setNativeLibraryName(NATIVE_LIBRARY_NAME); + useFpga = true; + } +} diff --git a/src/main/native/pairhmm/CMakeLists.txt b/src/main/native/pairhmm/CMakeLists.txt index 6b12bbdf..0c611397 100644 --- a/src/main/native/pairhmm/CMakeLists.txt +++ b/src/main/native/pairhmm/CMakeLists.txt @@ -1,21 +1,37 @@ +#--------------------------------------------------------------------- +# pairhmm_shacc (stub version) +#--------------------------------------------------------------------- +set(TARGET gkl_pairhmm_shacc) +add_library(${TARGET} SHARED shacc_pairhmm_stub.cc) +install(TARGETS ${TARGET} DESTINATION ${CMAKE_BINARY_DIR}) + #--------------------------------------------------------------------- # pairhmm #--------------------------------------------------------------------- set(TARGET gkl_pairhmm) - add_library(${TARGET} SHARED IntelPairHmm.cc) +if (APPLE) + target_link_libraries(${TARGET} gkl_pairhmm_shacc) +endif() install(TARGETS ${TARGET} DESTINATION ${CMAKE_BINARY_DIR}) #--------------------------------------------------------------------- # pairhmm_omp #--------------------------------------------------------------------- -set(TARGET gkl_pairhmm_omp) - find_package(OpenMP) if(OPENMP_FOUND) + set(TARGET gkl_pairhmm_omp) add_library(${TARGET} SHARED IntelPairHmm.cc) set_target_properties(${TARGET} PROPERTIES COMPILE_OPTIONS ${OpenMP_CXX_FLAGS}) - set_target_properties(${TARGET} PROPERTIES LINK_FLAGS ${OpenMP_CXX_FLAGS}) + target_link_libraries(${TARGET} ${OpenMP_CXX_FLAGS}) install(TARGETS ${TARGET} DESTINATION ${CMAKE_BINARY_DIR}) endif() + +#--------------------------------------------------------------------- +# pairhmm_fpga +#--------------------------------------------------------------------- +set(TARGET gkl_pairhmm_fpga) +add_library(${TARGET} SHARED IntelPairHmm.cc) +target_link_libraries(${TARGET} gkl_pairhmm_shacc) +install(TARGETS ${TARGET} DESTINATION ${CMAKE_BINARY_DIR}) diff --git a/src/main/native/pairhmm/IntelPairHmm.cc b/src/main/native/pairhmm/IntelPairHmm.cc index d0835095..26a506ad 100644 --- a/src/main/native/pairhmm/IntelPairHmm.cc +++ b/src/main/native/pairhmm/IntelPairHmm.cc @@ -6,13 +6,16 @@ #include "IntelPairHmm.h" #include "pairhmm_common.h" #include "avx-pairhmm.h" +#include "shacc_pairhmm.h" #include "JavaData.h" +#include "debug.h" bool g_use_double; int g_max_threads; +bool g_use_fpga; -Context ctxf; -Context ctxd; +Context g_ctxf; +Context g_ctxd; float (*g_compute_full_prob_float)(testcase *tc); double (*g_compute_full_prob_double)(testcase *tc); @@ -24,7 +27,7 @@ double (*g_compute_full_prob_double)(testcase *tc); */ JNIEXPORT void JNICALL Java_com_intel_gkl_pairhmm_IntelPairHmm_initNative (JNIEnv* env, jclass cls, jclass readDataHolder, jclass haplotypeDataHolder, - jboolean use_double, jint max_threads) + jboolean use_double, jint max_threads, jboolean use_fpga) { DBG("Enter"); @@ -37,6 +40,8 @@ JNIEXPORT void JNICALL Java_com_intel_gkl_pairhmm_IntelPairHmm_initNative g_max_threads = std::min((int)max_threads, omp_get_max_threads()); #endif + g_use_fpga = use_fpga; + // enable FTZ _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); @@ -49,7 +54,6 @@ JNIEXPORT void JNICALL Java_com_intel_gkl_pairhmm_IntelPairHmm_initNative DBG("Exit"); } - /* * Class: com_intel_gkl_pairhmm_IntelPairHmm * Method: computeLikelihoodsNative @@ -69,18 +73,28 @@ JNIEXPORT void JNICALL Java_com_intel_gkl_pairhmm_IntelPairHmm_computeLikelihood //================================================================== // calcutate pairHMM + shacc_pairhmm::Batch batch; + bool batch_valid = false; + if (g_use_fpga && !g_use_double) { + batch = javaData.getBatch(); + batch_valid = shacc_pairhmm::calculate(batch); + } + +#ifdef _OPENMP #pragma omp parallel for schedule(dynamic, 1) num_threads(g_max_threads) +#endif for (int i = 0; i < testcases.size(); i++) { double result_final = 0; - float result_float = g_use_double ? 0.0f : g_compute_full_prob_float(&testcases[i]); + float result_float = g_use_double ? 0.0f : + batch_valid ? batch.results[i] : g_compute_full_prob_float(&testcases[i]); if (result_float < MIN_ACCEPTED) { double result_double = g_compute_full_prob_double(&testcases[i]); - result_final = log10(result_double) - ctxd.LOG10_INITIAL_CONSTANT; + result_final = log10(result_double) - g_ctxd.LOG10_INITIAL_CONSTANT; } else { - result_final = (double)(log10f(result_float) - ctxf.LOG10_INITIAL_CONSTANT); + result_final = (double)(log10f(result_float) - g_ctxf.LOG10_INITIAL_CONSTANT); } javaResults[i] = result_final; diff --git a/src/main/native/pairhmm/IntelPairHmm.h b/src/main/native/pairhmm/IntelPairHmm.h index c18b8150..ff802c1d 100644 --- a/src/main/native/pairhmm/IntelPairHmm.h +++ b/src/main/native/pairhmm/IntelPairHmm.h @@ -10,10 +10,10 @@ extern "C" { /* * Class: com_intel_gkl_pairhmm_IntelPairHmm * Method: initNative - * Signature: (Ljava/lang/Class;Ljava/lang/Class;ZI)V + * Signature: (Ljava/lang/Class;Ljava/lang/Class;ZIZ)V */ JNIEXPORT void JNICALL Java_com_intel_gkl_pairhmm_IntelPairHmm_initNative - (JNIEnv *, jclass, jclass, jclass, jboolean, jint); +(JNIEnv *, jclass, jclass, jclass, jboolean, jint, jboolean); /* * Class: com_intel_gkl_pairhmm_IntelPairHmm diff --git a/src/main/native/pairhmm/JavaData.h b/src/main/native/pairhmm/JavaData.h index 19fd9ba6..a5b71586 100644 --- a/src/main/native/pairhmm/JavaData.h +++ b/src/main/native/pairhmm/JavaData.h @@ -3,9 +3,13 @@ #include #include "pairhmm_common.h" +#include "shacc_pairhmm.h" +#include "debug.h" class JavaData { public: + + // cache field ids static void init(JNIEnv *env, jclass readDataHolder, jclass haplotypeDataHolder) { m_readBasesFid = getFieldId(env, readDataHolder, "readBases", "[B"); m_readQualsFid = getFieldId(env, readDataHolder, "readQuals", "[B"); @@ -15,24 +19,29 @@ class JavaData { m_haplotypeBasesFid = getFieldId(env, haplotypeDataHolder, "haplotypeBases", "[B"); } + // create array of testcases std::vector getData(JNIEnv *env, jobjectArray& readDataArray, jobjectArray& haplotypeDataArray) { int numReads = env->GetArrayLength(readDataArray); int numHaplotypes = env->GetArrayLength(haplotypeDataArray); - testcases.resize(numReads * numHaplotypes); + m_batch.num_reads = numReads; + m_batch.num_haps = numHaplotypes; + + std::vector haplotypes; + std::vector haplotypeLengths; - std::vector haplotypes(numHaplotypes); - std::vector haplotypeLengths(numHaplotypes); + long total_hap_length = 0; + long total_read_length = 0; // get haplotypes for (int i = 0; i < numHaplotypes; i++) { int length = 0; - haplotypes[i] = getCharArray(env, haplotypeDataArray, i, m_haplotypeBasesFid, length); - haplotypeLengths[i] = length; + haplotypes.push_back(getCharArray(env, haplotypeDataArray, i, m_haplotypeBasesFid, length)); + haplotypeLengths.push_back(length); + total_hap_length += length; } - + // get reads and create testcases - int i = 0; for (int r = 0; r < numReads; r++) { int length = 0; char* reads = getCharArray(env, readDataArray, r, m_readBasesFid, length); @@ -41,33 +50,69 @@ class JavaData { char* delGops = getCharArray(env, readDataArray, r, m_deletionGopFid, length); char* gapConts = getCharArray(env, readDataArray, r, m_overallGcpFid, length); char* readQuals = getCharArray(env, readDataArray, r, m_readQualsFid, length); + total_read_length += length; for (int h = 0; h < numHaplotypes; h++) { - testcases[i].hap = haplotypes[h]; - testcases[i].haplen = haplotypeLengths[h]; - testcases[i].rs = reads; - testcases[i].rslen = readLength; - testcases[i].i = insGops; - testcases[i].d = delGops; - testcases[i].c = gapConts; - testcases[i].q = readQuals; - i++; + testcase tc; + tc.hap = haplotypes[h]; + tc.haplen = haplotypeLengths[h]; + tc.rs = reads; + tc.rslen = readLength; + tc.i = insGops; + tc.d = delGops; + tc.c = gapConts; + tc.q = readQuals; + m_testcases.push_back(tc); } } - return testcases; + m_total_cells = 3 * total_read_length * total_hap_length; + + return m_testcases; } double* getOutputArray(JNIEnv *env, jdoubleArray array) { return getDoubleArray(env, array); } + // create shacc_pairhmm::batch from array of testcases + shacc_pairhmm::Batch getBatch() { + int num_testcases = m_batch.num_reads * m_batch.num_haps; + + // get reads + for (int i = 0; i < num_testcases; i += m_batch.num_haps) { + shacc_pairhmm::Read read; + read.bases = m_testcases[i].rs; + read.length = m_testcases[i].rslen; + read.i = m_testcases[i].i; + read.d = m_testcases[i].d; + read.c = m_testcases[i].c; + read.q = m_testcases[i].q; + m_reads.push_back(read); + } + m_batch.reads = m_reads.data(); + + // get haplotypes + for (int i = 0; i < m_batch.num_haps; i++) { + shacc_pairhmm::Haplotype hap; + DBG("hap #%d len = %d", i, m_testcases[i].haplen); + hap.bases = m_testcases[i].hap; + hap.length = m_testcases[i].haplen; + m_haps.push_back(hap); + } + m_batch.haps = m_haps.data(); + + m_batch.num_cells = m_total_cells; + + return m_batch; + } + void releaseData(JNIEnv *env) { - for (int i = 0; i < byteArrays.size(); i++) { - env->ReleaseByteArrayElements(byteArrays[i].first, byteArrays[i].second, 0); + for (int i = 0; i < m_byteArrays.size(); i++) { + env->ReleaseByteArrayElements(m_byteArrays[i].first, m_byteArrays[i].second, 0); } - for (int i = 0; i < doubleArrays.size(); i++) { - env->ReleaseDoubleArrayElements(doubleArrays[i].first, doubleArrays[i].second, 0); + for (int i = 0; i < m_doubleArrays.size(); i++) { + env->ReleaseDoubleArrayElements(m_doubleArrays[i].first, m_doubleArrays[i].second, 0); } } @@ -88,7 +133,7 @@ class JavaData { env->ThrowNew(env->FindClass("java/lang/OutOfMemoryError"), "Unable to access jbyteArray"); } length = env->GetArrayLength(byteArray); - byteArrays.push_back(std::make_pair(byteArray, primArray)); + m_byteArrays.push_back(std::make_pair(byteArray, primArray)); return (char*)primArray; } @@ -97,13 +142,17 @@ class JavaData { if (primArray == NULL) { env->ThrowNew(env->FindClass("java/lang/OutOfMemoryError"), "Unable to access jdoubleArray"); } - doubleArrays.push_back(std::make_pair(array, primArray)); + m_doubleArrays.push_back(std::make_pair(array, primArray)); return (double*)primArray; } - std::vector testcases; - std::vector > byteArrays; - std::vector > doubleArrays; + shacc_pairhmm::Batch m_batch; + std::vector m_reads; + std::vector m_haps; + std::vector m_testcases; + std::vector > m_byteArrays; + std::vector > m_doubleArrays; + long m_total_cells; static jfieldID m_readBasesFid; static jfieldID m_readQualsFid; diff --git a/src/main/native/pairhmm/debug.h b/src/main/native/pairhmm/debug.h new file mode 100644 index 00000000..59b5061a --- /dev/null +++ b/src/main/native/pairhmm/debug.h @@ -0,0 +1,18 @@ +#ifndef DEBUG_H +#define DEBUG_H + +#include + +#ifdef DEBUG +# define DBG(M, ...) fprintf(stderr, "[DEBUG] (%s:%d) : " M "\n", __FUNCTION__, __LINE__, ##__VA_ARGS__) +# define INFO(M, ...) fprintf(stderr, "[INFO] (%s:%d) : " M "\n", __FUNCTION__, __LINE__, ##__VA_ARGS__) +# define WARN(M, ...) fprintf(stderr, "[WARNING] (%s:%d) : " M "\n", __FUNCTION__, __LINE__, ##__VA_ARGS__) +# define ERROR(M, ...) fprintf(stderr, "[ERROR] (%s:%d) :" M "\n", __FUNCTION__, __LINE__, ##__VA_ARGS__) +#else +# define DBG(M, ...) +# define INFO(M, ...) fprintf(stderr, "[INFO] " M "\n", ##__VA_ARGS__) +# define WARN(M, ...) fprintf(stderr, "[WARNING] " M "\n", ##__VA_ARGS__) +# define ERROR(M, ...) fprintf(stderr, "[ERROR] " M "\n", ##__VA_ARGS__) +#endif + +#endif diff --git a/src/main/native/pairhmm/pairhmm_common.h b/src/main/native/pairhmm/pairhmm_common.h index aee663c5..6e28ce45 100644 --- a/src/main/native/pairhmm/pairhmm_common.h +++ b/src/main/native/pairhmm/pairhmm_common.h @@ -10,9 +10,6 @@ #include #include -#define DBG(M, ...) -//#define DBG(M, ...) fprintf(stdout, "[DEBUG] (%s %s:%d) " M "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) - #define CAT(X,Y) X##Y #define CONCAT(X,Y) CAT(X,Y) @@ -22,10 +19,8 @@ typedef struct { int rslen, haplen; - char *q, *i, *d, *c; - char *hap, *rs; - int *ihap; - int *irs; + const char *q, *i, *d, *c; + const char *hap, *rs; } testcase; class ConvertChar { diff --git a/src/main/native/pairhmm/shacc_pairhmm.h b/src/main/native/pairhmm/shacc_pairhmm.h new file mode 100644 index 00000000..2f6e0b01 --- /dev/null +++ b/src/main/native/pairhmm/shacc_pairhmm.h @@ -0,0 +1,38 @@ +#ifndef SHACC_PAIRHMM_H +#define SHACC_PAIRHMM_H + +#ifdef __APPLE__ +#define WEAK __attribute__((weak_import)) +#else +#define WEAK __attribute__((weak)) +#endif + +namespace shacc_pairhmm { + + struct Read { + int length; + const char* bases; + const char* q; + const char* i; + const char* d; + const char* c; + }; + + struct Haplotype { + int length; + const char* bases; + }; + + struct Batch { + int num_reads; + int num_haps; + long num_cells; + Read* reads; + Haplotype* haps; + float* results; + }; + + extern WEAK bool calculate(Batch& batch); +} + +#endif diff --git a/src/main/native/pairhmm/shacc_pairhmm_stub.cc b/src/main/native/pairhmm/shacc_pairhmm_stub.cc new file mode 100644 index 00000000..2581abed --- /dev/null +++ b/src/main/native/pairhmm/shacc_pairhmm_stub.cc @@ -0,0 +1,12 @@ +#include "shacc_pairhmm.h" +#include "debug.h" + +namespace shacc_pairhmm { + +bool calculate(Batch& batch) { + WARN("Using stub version of shacc::calculate()"); + // return false so batch will be computed on the CPU + return false; +} + +} diff --git a/src/test/java/com/intel/gkl/pairhmm/PairHmmUnitTest.java b/src/test/java/com/intel/gkl/pairhmm/PairHmmUnitTest.java index e2e67717..bf87b31c 100644 --- a/src/test/java/com/intel/gkl/pairhmm/PairHmmUnitTest.java +++ b/src/test/java/com/intel/gkl/pairhmm/PairHmmUnitTest.java @@ -1,6 +1,7 @@ package com.intel.gkl.pairhmm; import com.intel.gkl.IntelGKLUtils; +import com.intel.gkl.NativeLibraryLoader; import com.intel.gkl.pairhmm.IntelPairHmm; import org.broadinstitute.gatk.nativebindings.pairhmm.HaplotypeDataHolder; import org.broadinstitute.gatk.nativebindings.pairhmm.PairHMMNativeArguments; @@ -54,6 +55,15 @@ public void simpleTest() { Assert.assertEquals(likelihoodArray[0], expectedResult, 1e-5, "Likelihood not in expected range."); } + @Test(enabled = true) + public void fpgaTest() { + final boolean shaccIsLoaded = NativeLibraryLoader.load(null, "gkl_pairhmm_shacc"); + Assert.assertTrue(shaccIsLoaded); + + final boolean isloaded = new IntelPairHmmFpga().load(null); + Assert.assertTrue(isloaded); + } + @Test(enabled = true) public void omp_Test() { final boolean isSupported = new IntelPairHmmOMP().load(null); @@ -261,4 +271,4 @@ static byte[] normalize(byte[] scores, int min) { } return scores; } -} \ No newline at end of file +}