Skip to content

Commit

Permalink
Refactored the evaluation of NormContinuous elements
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Oct 13, 2024
1 parent 212333a commit 05c19fc
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.util.List;

import com.google.common.base.Function;
import org.dmg.pmml.LinearNorm;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.OutlierTreatmentMethod;
Expand Down Expand Up @@ -54,79 +55,55 @@ public Number normalize(NormContinuous normContinuous, Number value){
public <V extends Number> Value<V> normalize(NormContinuous normContinuous, Value<V> value){
List<LinearNorm> linearNorms = ensureLinearNorms(normContinuous);

LinearNorm start = linearNorms.get(0);
LinearNorm end = linearNorms.get(linearNorms.size() - 1);
LinearNorm start;
LinearNorm end;

Number startOrig = start.requireOrig();
Number endOrig = end.requireOrig();

if(value.compareTo(startOrig) < 0 || value.compareTo(endOrig) > 0){
int index = search(linearNorms, LinearNorm::requireOrig, value);
if(index < 0 || index == (linearNorms.size() - 1)){
OutlierTreatmentMethod outlierTreatmentMethod = normContinuous.getOutliers();

switch(outlierTreatmentMethod){
case AS_IS:
// "Extrapolate from the first interval"
if(value.compareTo(startOrig) < 0){
if(index < 0){
start = linearNorms.get(0);
end = linearNorms.get(1);

endOrig = end.requireOrig();
} else

// "Extrapolate from the last interval"
{
start = linearNorms.get(linearNorms.size() - 2);

startOrig = start.requireOrig();
end = linearNorms.get(linearNorms.size() - 1);
}
break;
case AS_MISSING_VALUES:
// "Map to a missing value"
return null;
case AS_EXTREME_VALUES:
// "Map to the value of the first interval"
if(value.compareTo(startOrig) < 0){
Number startNorm = start.requireNorm();
if(index < 0){
start = linearNorms.get(0);

return value.reset(startNorm);
return value.reset(start.requireNorm());
} else

// "Map to the value of the last interval"
{
Number endNorm = end.requireNorm();
end = linearNorms.get(linearNorms.size() - 1);

return value.reset(endNorm);
return value.reset(end.requireNorm());
}
default:
throw new UnsupportedAttributeException(normContinuous, outlierTreatmentMethod);
}
} else

{
for(int i = 1, max = (linearNorms.size() - 1); i < max; i++){
LinearNorm linearNorm = linearNorms.get(i);

Number orig = linearNorm.requireOrig();

if(value.compareTo(orig) >= 0){
start = linearNorm;

startOrig = orig;
} else

if(value.compareTo(orig) <= 0){
end = linearNorm;

endOrig = orig;

break;
}
}
start = linearNorms.get(index);
end = linearNorms.get(index + 1);
}

Number startNorm = start.requireNorm();
Number endNorm = end.requireNorm();

return value.normalize(startOrig, startNorm, endOrig, endNorm);
return value.normalize(start.requireOrig(), start.requireNorm(), end.requireOrig(), end.requireNorm());
}

static
Expand All @@ -142,36 +119,59 @@ public Number denormalize(NormContinuous normContinuous, Number value){
public <V extends Number> Value<V> denormalize(NormContinuous normContinuous, Value<V> value){
List<LinearNorm> linearNorms = ensureLinearNorms(normContinuous);

LinearNorm start = linearNorms.get(0);
LinearNorm end = linearNorms.get(linearNorms.size() - 1);
LinearNorm start;
LinearNorm end;

int index = search(linearNorms, LinearNorm::requireNorm, value);
if(index < 0 || index == (linearNorms.size() - 1)){
throw new NotImplementedException();
} else

{
start = linearNorms.get(index);
end = linearNorms.get(index + 1);
}

Number startNorm = start.requireNorm();
Number endNorm = end.requireNorm();
return value.denormalize(start.requireOrig(), start.requireNorm(), end.requireOrig(), end.requireNorm());
}

static
<V extends Number> int search(List<LinearNorm> linearNorms, Function<LinearNorm, Number> thresholdFunction, Value<V> value){

for(int i = 1, max = (linearNorms.size() - 1); i < max; i++){
for(int i = 0, max = linearNorms.size(); i < max; i++){
LinearNorm linearNorm = linearNorms.get(i);

Number norm = linearNorm.requireNorm();
Number threshold = thresholdFunction.apply(linearNorm);

if(value.compareTo(norm) >= 0){
start = linearNorm;
if(value.compareTo(threshold) >= 0){

startNorm = norm;
} else
if(i < (max - 1)){
LinearNorm nextLinearNorm = linearNorms.get(i + 1);

if(value.compareTo(norm) <= 0){
end = linearNorm;
Number nextThreshold = thresholdFunction.apply(nextLinearNorm);

endNorm = norm;
// Assume a closed-closed range, rather than a closed-open range.
// If the value matches some threshold value exactly,
// then it does not matter which bin (ie. this or the next) is used for interpolation.
if(value.compareTo(nextThreshold) <= 0){
return i;
}

continue;
} else

break;
// The last element
{
return i;
}
} else

{
return -1;
}
}

Number startOrig = start.requireOrig();
Number endOrig = end.requireOrig();

return value.denormalize(startOrig, startNorm, endOrig, endNorm);
throw new IllegalArgumentException();
}

static
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,61 +18,149 @@
*/
package org.jpmml.evaluator;

import java.util.ArrayList;
import java.util.List;

import org.dmg.pmml.LinearNorm;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.OutlierTreatmentMethod;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;

public class NormalizationUtilTest implements Deltas {

@Test
public void normalize(){
NormContinuous normContinuous = createNormContinuous();

assertEquals(BEGIN[1], (double)NormalizationUtil.normalize(normContinuous, BEGIN[0]), DOUBLE_EXACT);
assertEquals(interpolate(1.212d, BEGIN, MIDPOINT), (double)NormalizationUtil.normalize(normContinuous, 1.212d), DOUBLE_EXACT);
assertEquals(MIDPOINT[1], (double)NormalizationUtil.normalize(normContinuous, MIDPOINT[0]), DOUBLE_EXACT);
assertEquals(interpolate(6.5d, MIDPOINT, END), (double)NormalizationUtil.normalize(normContinuous, 6.5d), DOUBLE_EXACT);
assertEquals(END[1], (double)NormalizationUtil.normalize(normContinuous, END[0]), DOUBLE_EXACT);
assertEquals(BEGIN[1], normalize(normContinuous, BEGIN[0]), DOUBLE_EXACT);
assertEquals(interpolate(1.212d, BEGIN, MIDPOINT), normalize(normContinuous, 1.212d), DOUBLE_EXACT);
assertEquals(MIDPOINT[1], normalize(normContinuous, MIDPOINT[0]), DOUBLE_EXACT);
assertEquals(interpolate(6.5d, MIDPOINT, END), normalize(normContinuous, 6.5d), DOUBLE_EXACT);
assertEquals(END[1], normalize(normContinuous, END[0]), DOUBLE_EXACT);
}

@Test
public void normalizeOutliers(){
NormContinuous normContinuous = createNormContinuous();

assertEquals(interpolate(-1d, BEGIN, MIDPOINT), (double)NormalizationUtil.normalize(normContinuous, -1d), DOUBLE_EXACT);
assertEquals(interpolate(12.2d, MIDPOINT, END), (double)NormalizationUtil.normalize(normContinuous, 12.2d), DOUBLE_EXACT);
assertEquals(interpolate(-1d, BEGIN, MIDPOINT), normalize(normContinuous, -1d), DOUBLE_EXACT);
assertEquals(interpolate(12.2d, MIDPOINT, END), normalize(normContinuous, 12.2d), DOUBLE_EXACT);

normContinuous.setOutliers(OutlierTreatmentMethod.AS_MISSING_VALUES);

assertNull(NormalizationUtil.normalize(normContinuous, -1d));
assertNull(NormalizationUtil.normalize(normContinuous, 12.2d));
assertNull(normalize(normContinuous, -1d));
assertNull(normalize(normContinuous, 12.2d));

normContinuous.setOutliers(OutlierTreatmentMethod.AS_EXTREME_VALUES);

assertEquals(BEGIN[1], (double)NormalizationUtil.normalize(normContinuous, -1d), DOUBLE_EXACT);
assertEquals(END[1], (double)NormalizationUtil.normalize(normContinuous, 12.2d), DOUBLE_EXACT);
assertEquals(BEGIN[1], normalize(normContinuous, -1d), DOUBLE_EXACT);
assertEquals(END[1], normalize(normContinuous, 12.2d), DOUBLE_EXACT);
}

@Test
public void denormalize(){
NormContinuous normContinuous = createNormContinuous();

assertEquals(BEGIN[0], (double)NormalizationUtil.denormalize(normContinuous, BEGIN[1]), DOUBLE_EXACT);
assertEquals(0.3d, (double)NormalizationUtil.denormalize(normContinuous, interpolate(0.3d, BEGIN, MIDPOINT)), DOUBLE_EXACT);
assertEquals(MIDPOINT[0], (double)NormalizationUtil.denormalize(normContinuous, MIDPOINT[1]), DOUBLE_EXACT);
assertEquals(7.123d, (double)NormalizationUtil.denormalize(normContinuous, interpolate(7.123d, MIDPOINT, END)), DOUBLE_EXACT);
assertEquals(END[0], (double)NormalizationUtil.denormalize(normContinuous, END[1]), DOUBLE_EXACT);
try {
denormalize(normContinuous, -0.5d);

fail();
} catch(NotImplementedException nie){
// Ignored
}

assertEquals(BEGIN[0], denormalize(normContinuous, BEGIN[1]), DOUBLE_EXACT);
assertEquals(0.3d, denormalize(normContinuous, interpolate(0.3d, BEGIN, MIDPOINT)), DOUBLE_EXACT);
assertEquals(MIDPOINT[0], denormalize(normContinuous, MIDPOINT[1]), DOUBLE_EXACT);
assertEquals(7.123d, denormalize(normContinuous, interpolate(7.123d, MIDPOINT, END)), DOUBLE_EXACT);
assertEquals(END[0], denormalize(normContinuous, END[1]), DOUBLE_EXACT);

try {
denormalize(normContinuous, 1.5d);

fail();
} catch(NotImplementedException nie){
// Ignored
}
}

@Test
public void standardize(){
double mu = 1.5;
double stdev = Math.sqrt(2d);

NormContinuous normContinuous = new NormContinuous("x", null)
.setOutliers(OutlierTreatmentMethod.AS_IS)
.addLinearNorms(
new LinearNorm(0d, -(mu / stdev)),
new LinearNorm(mu, 0d)
);

assertEquals(zScore(-2d, mu, stdev), normalize(normContinuous, -2d), DOUBLE_EXACT);
assertEquals(zScore(-1d, mu, stdev), normalize(normContinuous, -1d), DOUBLE_EXACT);
assertEquals(zScore(0d, mu, stdev), normalize(normContinuous, 0d), DOUBLE_EXACT);
assertEquals(zScore(1d, mu, stdev), normalize(normContinuous, 1d), DOUBLE_EXACT);
assertEquals(zScore(2d, mu, stdev), normalize(normContinuous, 2d), DOUBLE_EXACT);

assertEquals(1d, denormalize(normContinuous, zScore(1d, mu, stdev)), DOUBLE_EXACT);
}

@Test
public void search(){
List<LinearNorm> linearNorms = new ArrayList<>();

linearNorms.add(new LinearNorm(0d, null));
linearNorms.add(new LinearNorm(1d, null));

assertEquals(-1, search(linearNorms, -1d));
assertEquals(0, search(linearNorms, 0d));
assertEquals(0, search(linearNorms, 1d));
assertEquals(1, search(linearNorms, 2d));

linearNorms.add(new LinearNorm(2d, null));

assertEquals(-1, search(linearNorms, -1d));
assertEquals(0, search(linearNorms, 1d));
assertEquals(1, search(linearNorms, 2d));
assertEquals(2, search(linearNorms, 3d));

linearNorms.add(new LinearNorm(3d, null));

assertEquals(-1, search(linearNorms,-1d));
assertEquals(1, search(linearNorms, 2d));
assertEquals(2, search(linearNorms, 3d));
assertEquals(3, search(linearNorms, 4d));
}

static
private Double normalize(NormContinuous normContinuous, double value){
return (Double)NormalizationUtil.normalize(normContinuous, value);
}

static
private Double denormalize(NormContinuous normContinuous, double value){
return (Double)NormalizationUtil.denormalize(normContinuous, value);
}

static
private int search(List<LinearNorm> linearNorms, double value){
return NormalizationUtil.search(linearNorms, LinearNorm::requireOrig, new DoubleValue(value));
}

static
private double interpolate(double x, double[] begin, double[] end){
return begin[1] + (x - begin[0]) / (end[0] - begin[0]) * (end[1] - begin[1]);
}

static
private double zScore(double x, double mu, double stdev){
return (x - mu) / stdev;
}

static
private NormContinuous createNormContinuous(){
NormContinuous result = new NormContinuous("x", null)
Expand Down

0 comments on commit 05c19fc

Please sign in to comment.