Skip to content

Commit

Permalink
Support input image from path or data buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
sh1r0 committed Aug 16, 2016
1 parent 06db3bf commit 080c238
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 23 deletions.
57 changes: 48 additions & 9 deletions android/caffe_jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

#include <cblas.h>

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>

#include "caffe/caffe.hpp"
#include "caffe_mobile.hpp"

Expand All @@ -30,6 +34,29 @@ string jstring2string(JNIEnv *env, jstring jstr) {
return str;
}

/**
 * Copies the bytes of a Java byte[] into a string.
 *
 * NOTE: the Java side is expected to build the array with
 *       byte[] buf = str.getBytes("US-ASCII")
 *
 * @param env JNI environment of the calling thread.
 * @param buf Java byte array whose bytes are copied verbatim.
 * @return the array contents as a string; an empty string when buf is null
 *         or its elements cannot be obtained from the VM.
 */
string bytes2string(JNIEnv *env, jbyteArray buf) {
  if (buf == nullptr) {
    return string();
  }

  jbyte *ptr = env->GetByteArrayElements(buf, nullptr);
  if (ptr == nullptr) {
    // The VM could not hand out the elements; avoid dereferencing null.
    return string();
  }

  string s(reinterpret_cast<const char *>(ptr),
           static_cast<size_t>(env->GetArrayLength(buf)));
  // The elements were only read, so JNI_ABORT releases them without the
  // copy-back that mode 0 would perform.
  env->ReleaseByteArrayElements(buf, ptr, JNI_ABORT);
  return s;
}

/**
 * Converts an NV21 (YUV 4:2:0) frame stored in a Java byte[] into an RGBA
 * cv::Mat.
 *
 * The frame is treated as a single-channel image of (height + height / 2)
 * rows by width columns: the luma plane followed by the interleaved chroma
 * plane.
 *
 * @param env    JNI environment of the calling thread.
 * @param buf    Java byte array holding the frame.
 * @param width  frame width in pixels; must be positive and even.
 * @param height frame height in pixels; must be positive and even.
 * @return the converted image, which owns its own pixel data; an empty
 *         cv::Mat when the arguments are invalid, the array is too short
 *         for the stated dimensions, or the elements cannot be obtained.
 */
cv::Mat imgbuf2mat(JNIEnv *env, jbyteArray buf, int width, int height) {
  // NV21 subsamples chroma 2x2, so both dimensions must be positive and even.
  if (buf == nullptr || width <= 0 || height <= 0 || width % 2 != 0 ||
      height % 2 != 0) {
    return cv::Mat();
  }

  // Reject buffers that are too short, otherwise the conversion below would
  // read past the end of the Java array. 64-bit math avoids int overflow.
  const long long required = static_cast<long long>(width) * height * 3 / 2;
  if (static_cast<long long>(env->GetArrayLength(buf)) < required) {
    return cv::Mat();
  }

  jbyte *ptr = env->GetByteArrayElements(buf, nullptr);
  if (ptr == nullptr) {
    return cv::Mat();
  }

  // Wrap the Java-owned bytes without copying; this header must not outlive
  // the release call below.
  const cv::Mat yuv(height + height / 2, width, CV_8UC1,
                    reinterpret_cast<unsigned char *>(ptr));
  // Convert into a separate Mat so the result owns freshly allocated memory.
  cv::Mat rgba;
  cv::cvtColor(yuv, rgba, CV_YUV2RGBA_NV21);

  // The source bytes were only read, so skip the copy-back.
  env->ReleaseByteArrayElements(buf, ptr, JNI_ABORT);
  return rgba;
}

/**
 * Builds the input image from either a file path or a raw frame.
 *
 * When both width and height are 0, buf is interpreted as the bytes of an
 * image path and the file is loaded with cv::imread (flag -1). Otherwise
 * buf is handed to imgbuf2mat together with the given dimensions.
 */
cv::Mat getImage(JNIEnv *env, jbyteArray buf, int width, int height) {
  const bool buf_is_path = (width == 0) && (height == 0);
  if (!buf_is_path) {
    return imgbuf2mat(env, buf, width, height);
  }
  return cv::imread(bytes2string(env, buf), -1);
}

JNIEXPORT void JNICALL
Java_com_sh1r0_caffe_1android_1lib_CaffeMobile_setNumThreads(JNIEnv *env,
jobject thiz,
Expand Down Expand Up @@ -71,12 +98,16 @@ JNIEXPORT void JNICALL Java_com_sh1r0_caffe_1android_1lib_CaffeMobile_setScale(
caffe_mobile->SetScale(scale);
}

/**
* NOTE: when width == 0 && height == 0, buf is a byte array
* (str.getBytes("US-ASCII")) that contains the image path; otherwise buf
* holds the raw image data (an NV21 frame of width x height pixels).
*/
JNIEXPORT jfloatArray JNICALL
Java_com_sh1r0_caffe_1android_1lib_CaffeMobile_getConfidenceScore(
JNIEnv *env, jobject thiz, jstring imgPath) {
JNIEnv *env, jobject thiz, jbyteArray buf, jint width, jint height) {
CaffeMobile *caffe_mobile = CaffeMobile::Get();
vector<float> conf_score =
caffe_mobile->GetConfidenceScore(jstring2string(env, imgPath));
caffe_mobile->GetConfidenceScore(getImage(env, buf, width, height));

jfloatArray result;
result = env->NewFloatArray(conf_score.size());
Expand All @@ -88,14 +119,17 @@ Java_com_sh1r0_caffe_1android_1lib_CaffeMobile_getConfidenceScore(
return result;
}

/**
* NOTE: when width == 0 && height == 0, buf is a byte array
* (str.getBytes("US-ASCII")) that contains the image path; otherwise buf
* holds the raw image data (an NV21 frame of width x height pixels).
*/
JNIEXPORT jintArray JNICALL
Java_com_sh1r0_caffe_1android_1lib_CaffeMobile_predictImage(JNIEnv *env,
jobject thiz,
jstring imgPath,
jint k) {
Java_com_sh1r0_caffe_1android_1lib_CaffeMobile_predictImage(
JNIEnv *env, jobject thiz, jbyteArray buf, jint width, jint height,
jint k) {
CaffeMobile *caffe_mobile = CaffeMobile::Get();
vector<int> top_k =
caffe_mobile->PredictTopK(jstring2string(env, imgPath), k);
caffe_mobile->PredictTopK(getImage(env, buf, width, height), k);

jintArray result;
result = env->NewIntArray(k);
Expand All @@ -107,12 +141,17 @@ Java_com_sh1r0_caffe_1android_1lib_CaffeMobile_predictImage(JNIEnv *env,
return result;
}

/**
* NOTE: when width == 0 && height == 0, buf is a byte array
* (str.getBytes("US-ASCII")) that contains the image path; otherwise buf
* holds the raw image data (an NV21 frame of width x height pixels).
*/
JNIEXPORT jobjectArray JNICALL
Java_com_sh1r0_caffe_1android_1lib_CaffeMobile_extractFeatures(
JNIEnv *env, jobject thiz, jstring imgPath, jstring blobNames) {
JNIEnv *env, jobject thiz, jbyteArray buf, jint width, jint height,
jstring blobNames) {
CaffeMobile *caffe_mobile = CaffeMobile::Get();
vector<vector<float>> features = caffe_mobile->ExtractFeatures(
jstring2string(env, imgPath), jstring2string(env, blobNames));
getImage(env, buf, width, height), jstring2string(env, blobNames));

jobjectArray array2D =
env->NewObjectArray(features.size(), env->FindClass("[F"), NULL);
Expand Down
19 changes: 9 additions & 10 deletions android/caffe_mobile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,8 @@ void CaffeMobile::WrapInputLayer(std::vector<cv::Mat> *input_channels) {
}
}

vector<float> CaffeMobile::Forward(const string &filename) {
cv::Mat img = cv::imread(filename, -1);
CHECK(!img.empty()) << "Unable to decode image " << filename;
vector<float> CaffeMobile::Forward(const cv::Mat &img) {
CHECK(!img.empty()) << "img should not be empty";

Blob<float> *input_layer = net_->input_blobs()[0];
input_layer->Reshape(1, num_channels_, input_geometry_.height,
Expand All @@ -218,20 +217,20 @@ vector<float> CaffeMobile::Forward(const string &filename) {
return vector<float>(begin, end);
}

vector<float> CaffeMobile::GetConfidenceScore(const string &img_path) {
return Forward(img_path);
vector<float> CaffeMobile::GetConfidenceScore(const cv::Mat &img) {
return Forward(img);
}

vector<int> CaffeMobile::PredictTopK(const string &img_path, int k) {
const vector<float> probs = Forward(img_path);
vector<int> CaffeMobile::PredictTopK(const cv::Mat &img, int k) {
const vector<float> probs = Forward(img);
k = std::min<int>(std::max(k, 1), probs.size());
return argmax(probs, k);
}

vector<vector<float>>
CaffeMobile::ExtractFeatures(const string &img_path,
CaffeMobile::ExtractFeatures(const cv::Mat &img,
const string &str_blob_names) {
Forward(img_path);
Forward(img);

vector<std::string> blob_names;
boost::split(blob_names, str_blob_names, boost::is_any_of(","));
Expand Down Expand Up @@ -266,7 +265,7 @@ int main(int argc, char const *argv[]) {
CaffeMobile *caffe_mobile =
CaffeMobile::Get(string(argv[1]), string(argv[2]));
caffe_mobile->SetMean(string(argv[3]));
vector<int> top_3 = caffe_mobile->PredictTopK(string(argv[4]), 3);
vector<int> top_3 = caffe_mobile->PredictTopK(cv::imread(string(argv[4]), -1), 3);
for (auto i : top_3) {
std::cout << i << std::endl;
}
Expand Down
8 changes: 4 additions & 4 deletions android/caffe_mobile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ class CaffeMobile {

void SetScale(const float scale);

vector<float> GetConfidenceScore(const string &img_path);
vector<float> GetConfidenceScore(const cv::Mat &img);

vector<int> PredictTopK(const string &img_path, int k);
vector<int> PredictTopK(const cv::Mat &img, int k);

vector<vector<float>> ExtractFeatures(const string &img_path,
vector<vector<float>> ExtractFeatures(const cv::Mat &img,
const string &str_blob_names);

private:
Expand All @@ -42,7 +42,7 @@ class CaffeMobile {

void WrapInputLayer(std::vector<cv::Mat> *input_channels);

vector<float> Forward(const string &filename);
vector<float> Forward(const cv::Mat &img);

shared_ptr<Net<float>> net_;
cv::Size input_geometry_;
Expand Down

0 comments on commit 080c238

Please sign in to comment.