Skip to content

Commit

Permalink
added kernel to calculate block RMSE 2D
Browse files Browse the repository at this point in the history
  • Loading branch information
ammendes committed Sep 12, 2024
1 parent 4249d83 commit 640da93
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 41 deletions.
93 changes: 52 additions & 41 deletions src/main/java/BlockRedundancy2D_.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,17 @@ public class BlockRedundancy2D_ implements PlugIn {
static private CLContext context;

static private CLProgram programGetPatchMeans, programGetPatchCosineSim, programGetPatchDiffStd, programGetPatchPearson,
programGetPatchHu, programGetPatchSsim, programGetRelevanceMap;
programGetPatchHu, programGetPatchSsim, programGetPatchRmse, programGetRelevanceMap;

static private CLKernel kernelGetPatchMeans, kernelGetPatchDiffStd, kernelGetPatchPearson,
kernelGetPatchHu, kernelGetPatchSsim, kernelGetRelevanceMap, kernelGetPatchCosineSim;
kernelGetPatchHu, kernelGetPatchSsim, kernelGetRelevanceMap, kernelGetPatchCosineSim, kernelGetPatchRmse;

static private CLPlatform clPlatformMaxFlop;

static private CLCommandQueue queue;

private CLBuffer<FloatBuffer> clRefPixels, clLocalMeans, clLocalStds, clPatchPixels, clCosineSimMap, clDiffStdMap, clPearsonMap,
clHuMap, clSsimMap, clRelevanceMap, clGaussianWindow;
clHuMap, clSsimMap, clRelevanceMap, clRmseMap;

@Override
public void run(String s) {
Expand Down Expand Up @@ -610,62 +610,73 @@ public void run(String s) {
programGetPatchSsim.release();
}

if (metric == metrics[3]) { // SSIM
showStatus("Calculating SSIM...");
if (metric == metrics[3]) { // NRMSE (inverted)
showStatus("Calculating NRMSE...");

// Build OpenCL program
String programStringGetPatchSsim = getResourceAsString(BlockRedundancy2D_.class, "kernelGetPatchSsim2D.cl");
programStringGetPatchSsim = replaceFirst(programStringGetPatchSsim, "$WIDTH$", "" + w);
programStringGetPatchSsim = replaceFirst(programStringGetPatchSsim, "$HEIGHT$", "" + h);
programStringGetPatchSsim = replaceFirst(programStringGetPatchSsim, "$PATCH_SIZE$", "" + patchSize);
programStringGetPatchSsim = replaceFirst(programStringGetPatchSsim, "$BW$", "" + bW);
programStringGetPatchSsim = replaceFirst(programStringGetPatchSsim, "$BH$", "" + bH);
programStringGetPatchSsim = replaceFirst(programStringGetPatchSsim, "$BRW$", "" + bRW);
programStringGetPatchSsim = replaceFirst(programStringGetPatchSsim, "$BRH$", "" + bRH);
programStringGetPatchSsim = replaceFirst(programStringGetPatchSsim, "$PATCH_MEAN$", "" + patchMeanFloat);
programStringGetPatchSsim = replaceFirst(programStringGetPatchSsim, "$PATCH_STD$", "" + patchStdDev);
programStringGetPatchSsim = replaceFirst(programStringGetPatchSsim, "$EPSILON$", "" + EPSILON);
programGetPatchSsim = context.createProgram(programStringGetPatchSsim).build();
String programStringGetPatchRmse = getResourceAsString(BlockRedundancy2D_.class, "kernelGetPatchRmse2D.cl");
programStringGetPatchRmse = replaceFirst(programStringGetPatchRmse, "$WIDTH$", "" + w);
programStringGetPatchRmse = replaceFirst(programStringGetPatchRmse, "$HEIGHT$", "" + h);
programStringGetPatchRmse = replaceFirst(programStringGetPatchRmse, "$PATCH_SIZE$", "" + patchSize);
programStringGetPatchRmse = replaceFirst(programStringGetPatchRmse, "$BW$", "" + bW);
programStringGetPatchRmse = replaceFirst(programStringGetPatchRmse, "$BH$", "" + bH);
programStringGetPatchRmse = replaceFirst(programStringGetPatchRmse, "$BRW$", "" + bRW);
programStringGetPatchRmse = replaceFirst(programStringGetPatchRmse, "$BRH$", "" + bRH);
programStringGetPatchRmse = replaceFirst(programStringGetPatchRmse, "$PATCH_MEAN$", "" + patchMeanFloat);
programStringGetPatchRmse = replaceFirst(programStringGetPatchRmse, "$EPSILON$", "" + EPSILON);
programGetPatchRmse = context.createProgram(programStringGetPatchRmse).build();
//System.out.println(programGetPatchSsim.getBuildLog()); // Print program build log to check for errors

// Fill OpenCL buffers
clPatchPixels = context.createFloatBuffer(patchSize, READ_ONLY);
fillBufferWithFloatArray(clPatchPixels, patchPixelsFloat);

clSsimMap = context.createFloatBuffer(wh, READ_WRITE);
fillBufferWithFloatArray(clSsimMap, repetitionMap);
clRmseMap = context.createFloatBuffer(wh, READ_WRITE);
fillBufferWithFloatArray(clRmseMap, repetitionMap);

// Create kernel and set args
kernelGetPatchSsim = programGetPatchSsim.createCLKernel("kernelGetPatchSsim2D");
kernelGetPatchRmse = programGetPatchRmse.createCLKernel("kernelGetPatchRmse2D");

argn = 0;
kernelGetPatchSsim.setArg(argn++, clPatchPixels);
kernelGetPatchSsim.setArg(argn++, clRefPixels);
kernelGetPatchSsim.setArg(argn++, clLocalMeans);
kernelGetPatchSsim.setArg(argn++, clLocalStds);
kernelGetPatchSsim.setArg(argn++, clSsimMap);
kernelGetPatchRmse.setArg(argn++, clPatchPixels);
kernelGetPatchRmse.setArg(argn++, clRefPixels);
kernelGetPatchRmse.setArg(argn++, clLocalMeans);
kernelGetPatchRmse.setArg(argn++, clRmseMap);

// Calculate SSIM
// Calculate RMSE
queue.putWriteBuffer(clPatchPixels, true);
queue.putWriteBuffer(clSsimMap, true);
queue.put2DRangeKernel(kernelGetPatchSsim, 0, 0, w, h, 0, 0);
queue.putWriteBuffer(clRmseMap, true);
queue.put2DRangeKernel(kernelGetPatchRmse, 0, 0, w, h, 0, 0);
queue.finish();

// Read SSIM back from the GPU
queue.putReadBuffer(clSsimMap, true);
// Read RMSE back from the GPU
queue.putReadBuffer(clRmseMap, true);
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
repetitionMap[y * w + x] = clSsimMap.getBuffer().get(y * w + x);
repetitionMap[y * w + x] = clRmseMap.getBuffer().get(y * w + x);
queue.finish();
}
}
queue.finish();

// Release GPU resources
kernelGetPatchSsim.release();
kernelGetPatchRmse.release();
clPatchPixels.release();
clSsimMap.release();
programGetPatchSsim.release();
clRmseMap.release();
programGetPatchRmse.release();

// Invert RMSE
for (int y = bRH; y < h-bRH; y++) {
for (int x = bRW; x < w-bRW; x++) {
float rmse = repetitionMap[y*w+x];

if(rmse == 0.0f){ // Special case where RMSE is 0, 1/rmse would be undefined but we want perfect similarity
repetitionMap[y*w+x] = 1.0f;
}else{
repetitionMap[y*w+x] = 1.0f / rmse;
}
}
}
}


Expand Down Expand Up @@ -751,18 +762,18 @@ public void run(String s) {
// ----------------------------------------------------------------------- //

// Find min and max within the relevance mask
float pearsonMin = Float.MAX_VALUE;
float pearsonMax = -Float.MAX_VALUE;
float repetitionMin = Float.MAX_VALUE;
float repetitionMax = -Float.MAX_VALUE;

for (int j = bRH; j < h - bRH; j++) {
for (int i = bRW; i < w - bRW; i++) {
if (relevanceMap[j * w + i] > noiseMeanVar * filterConstant) {
float pixelValue = repetitionMap[j * w + i];
if (pixelValue > pearsonMax) {
pearsonMax = pixelValue;
if (pixelValue > repetitionMax) {
repetitionMax = pixelValue;
}
if (pixelValue < pearsonMin) {
pearsonMin = pixelValue;
if (pixelValue < repetitionMin) {
repetitionMin = pixelValue;
}
}
}
Expand All @@ -772,7 +783,7 @@ public void run(String s) {
for (int j = bRH; j < h - bRH; j++) {
for (int i = bRW; i < w - bRW; i++) {
if (relevanceMap[j * w + i] > noiseMeanVar * filterConstant) {
repetitionMap[j * w + i] = (repetitionMap[j * w + i] - pearsonMin) / (pearsonMax - pearsonMin + EPSILON);
repetitionMap[j * w + i] = (repetitionMap[j * w + i] - repetitionMin) / (repetitionMax - repetitionMin + EPSILON);
}
}
}
Expand Down
74 changes: 74 additions & 0 deletions src/main/resources/kernelGetPatchRmse2D.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Compile-time constants for kernelGetPatchRmse2D.
// The right-hand side of every #define below is a placeholder token that the host
// program substitutes with a literal value before this source is built.
// NOTE(review): the host appears to use a replace-first helper, so the placeholder
// tokens must not be repeated anywhere above their #define (including in comments).
//#pragma OPENCL EXTENSION cl_khr_fp64 : enable
// Image width; also the row stride used for flat indexing (row*w + column).
#define w $WIDTH$
// Image height; used by the kernel's border check.
#define h $HEIGHT$
// Number of elements in the reference patch buffer and divisor of the mean squared error.
#define patch_size $PATCH_SIZE$
// Not referenced by the kernel in this file (presumably full block width/height — TODO confirm).
#define bW $BW$
#define bH $BH$
// Horizontal / vertical half-extent of the block around each pixel; also the border margin.
#define bRW $BRW$
#define bRH $BRH$
// Not referenced by the kernel in this file.
#define ref_mean $PATCH_MEAN$
// Not referenced by the kernel in this file.
#define EPSILON $EPSILON$

// Computes, for every non-border pixel, the root-mean-square error (RMSE) between the
// reference patch and the mean-subtracted elliptical block centred on that pixel.
//
// Arguments:
//   patch_pixels - reference patch, patch_size floats (the original code notes that the
//                  host already subtracted its mean — TODO confirm against the host class)
//   ref_pixels   - input image, indexed as row*w + column
//   local_means  - per-pixel mean of the block centred on each pixel, indexed like ref_pixels
//   rmse_map     - output; the RMSE is written at row*w + column. Border pixels are left
//                  untouched, so they keep whatever the host initialised the buffer with.
//
// The value written is the plain (non-normalised) RMSE.
//
// Implementation notes:
//   * The previous version copied the reference patch into a __local array from every
//     work-item without a barrier. That is an unsynchronised concurrent write to shared
//     local memory and saved no global reads (each element is read exactly once), so the
//     reference patch is now read directly from global memory.
//   * The comparison block is no longer staged in a patch_size-sized private array; the
//     squared differences are accumulated in a single pass, in the same element order and
//     with the same arithmetic as before.
kernel void kernelGetPatchRmse2D(
    global float* patch_pixels,
    global float* ref_pixels,
    global float* local_means,
    global float* rmse_map
){

    int gx = get_global_id(0);
    int gy = get_global_id(1);

    // Bound check (avoids borders dynamically based on patch dimensions)
    if(gx<bRW || gx>=w-bRW || gy<bRH || gy>=h-bRH){
        return;
    }

    // Mean of the block centred on this pixel, used to mean-subtract the comparison block
    float comp_mean = local_means[gy*w+gx];

    // ---------------------------------------------------------------- //
    // ---- Accumulate squared differences over the elliptical block ---- //
    // ---------------------------------------------------------------- //

    float sum_sq = 0.0f;
    int index = 0; // Position of the current in-ellipse pixel within the reference patch
    for(int j=gy-bRH; j<=gy+bRH; j++){
        for(int i=gx-bRW; i<=gx+bRW; i++){
            float dx = (float)(i-gx);
            float dy = (float)(j-gy);
            if(((dx*dx)/(float)(bRW*bRW))+((dy*dy)/(float)(bRH*bRH)) <= 1.0f){
                // Guard against an ellipse pixel count larger than patch_size, which
                // previously wrote past the end of the private staging array.
                if(index < patch_size){
                    float diff = patch_pixels[index] - (ref_pixels[j*w+i] - comp_mean);
                    sum_sq += diff*diff;
                }
                index++;
            }
        }
    }

    // If the ellipse holds fewer than patch_size pixels, the remaining comparison values
    // are treated as zero before mean subtraction (matches the former zero-initialised array).
    for(int i=index; i<patch_size; i++){
        float diff = patch_pixels[i] - (0.0f - comp_mean);
        sum_sq += diff*diff;
    }

    // ------------------------ //
    // ---- Calculate RMSE ---- //
    // ------------------------ //

    float mse = sum_sq / (float) patch_size;
    rmse_map[gy*w+gx] = sqrt(mse);
}

0 comments on commit 640da93

Please sign in to comment.