Skip to content

Commit

Permalink
Refactor the code further
Browse files Browse the repository at this point in the history
  • Loading branch information
scorebot committed Oct 5, 2024
1 parent 0226a55 commit 42d90e6
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 71 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ _PMML4S_ is really easy to use. Just do one or more of the following:
val row = Map("sepal_length" -> "5.1", "sepal_width" -> "3.5", "petal_length" -> "1.4", "petal_width" -> "0.2")

// You need to convert the data to the desired type defined by PMML, and keep the same order as defined in the input schema.
val values = inputSchema.map(x => Utils.toVal(row(x.name), x.dataType))
val values = inputSchema.map(x => Utils.toDataVal(row(x.name), x.dataType))

scala> val result = model.predict(Series.fromSeq(values))
result: org.pmml4s.data.Series = [Iris-setosa,1.0,1.0,0.0,0.0,1],[(predicted_class,string),(probability,double),(probability_Iris-setosa,double),(probability_Iris-versicolor,double),(probability_Iris-virginica,double),(node_id,string)]
Expand Down
118 changes: 67 additions & 51 deletions src/main/scala/org/pmml4s/common/predicates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package org.pmml4s.common
import org.pmml4s.common.Operator.Operator
import org.pmml4s.data.Series
import org.pmml4s.metadata.Field
import org.pmml4s.util.Utils
import org.pmml4s.xml.ElemTags

object Predication extends Enumeration {
Expand Down Expand Up @@ -99,11 +98,11 @@ class SimplePredicate(
val missing = (v != v)
operator match {
case `lessOrEqual` => if (missing) UNKNOWN else if (v <= value) TRUE else FALSE
case `equal` => if (missing) UNKNOWN else if (v == value) TRUE else FALSE
case `notEqual` => if (missing) UNKNOWN else if (v != value) TRUE else FALSE
case `lessThan` => if (missing) UNKNOWN else if (v < value) TRUE else FALSE
case `greaterThan` => if (missing) UNKNOWN else if (v > value) TRUE else FALSE
case `greaterOrEqual` => if (missing) UNKNOWN else if (v >= value) TRUE else FALSE
case `equal` => if (missing) UNKNOWN else if (v == value) TRUE else FALSE
case `notEqual` => if (missing) UNKNOWN else if (v != value) TRUE else FALSE
case `isMissing` => if (missing) TRUE else FALSE
case `isNotMissing` => if (!missing) TRUE else FALSE
}
Expand All @@ -121,56 +120,72 @@ class CompoundPredicate(

import CompoundPredicate.BooleanOperator._

def eval(input: Series): Predication = booleanOperator match {
case `or` => {
var hasUnknown = false
for (child <- children) {
val r = child.eval(input)
if (r == TRUE)
return TRUE
else if (r == UNKNOWN)
hasUnknown = true
def eval(input: Series): Predication = {
val len = children.length
var i = 0
booleanOperator match {
case `or` => {
var hasUnknown = false
while (i < len) {
val child = children(i)
val r = child.eval(input)
if (r == TRUE)
return TRUE
else if (r == UNKNOWN)
hasUnknown = true

i += 1
}

if (hasUnknown) UNKNOWN else FALSE
}

if (hasUnknown) UNKNOWN else FALSE
}
case `and` => {
var hasUnknown = false
for (child <- children) {
val r = child.eval(input)
if (r == FALSE)
return FALSE
else if (r == UNKNOWN)
hasUnknown = true
case `and` => {
var hasUnknown = false
while (i < len) {
val child = children(i)
val r = child.eval(input)
if (r == FALSE)
return FALSE
else if (r == UNKNOWN)
hasUnknown = true

i += 1
}

if (hasUnknown) UNKNOWN else TRUE
}

if (hasUnknown) UNKNOWN else TRUE
}
case `xor` => {
var count = 0
for (child <- children) {
val r = child.eval(input)
if (r == UNKNOWN)
return UNKNOWN
else if (r == TRUE)
count += 1
case `xor` => {
var count = 0
while (9 < len) {
val child = children(i)
val r = child.eval(input)
if (r == UNKNOWN)
return UNKNOWN
else if (r == TRUE)
count += 1

i += 1
}

if (count % 2 == 1) TRUE else FALSE
}

if (count % 2 == 1) TRUE else FALSE
}
case `surrogate` => {
var isSurrogate = false
for (child <- children) {
val r = child.eval(input)
if (r != UNKNOWN)
return if (r == TRUE) {
if (isSurrogate) SURROGATE else TRUE
} else r
else
isSurrogate = true
case `surrogate` => {
var isSurrogate = false
while (i < len) {
val child = children(i)
val r = child.eval(input)
if (r != UNKNOWN)
return if (r == TRUE) {
if (isSurrogate) SURROGATE else TRUE
} else r
else
isSurrogate = true

i += 1
}

UNKNOWN
}

UNKNOWN
}
}
}
Expand All @@ -188,9 +203,10 @@ class SimpleSetPredicate(

def eval(input: Series): Predication = {
val v = field.encode(input)
val missing = (v != v)
booleanOperator match {
case `isIn` => if (Utils.isMissing(v)) UNKNOWN else if (values.contains(v)) TRUE else FALSE
case `isNotIn` => if (Utils.isMissing(v)) UNKNOWN else if (!values.contains(v)) TRUE else FALSE
case `isIn` => if (missing) UNKNOWN else if (values.contains(v)) TRUE else FALSE
case `isNotIn` => if (missing) UNKNOWN else if (!values.contains(v)) TRUE else FALSE
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/org/pmml4s/model/MiningModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class MiningModel(
x._2.feature match {
case ResultFeature.predictedValue => outputs.predictedValue = x._1
case ResultFeature.probability => x._2.value.foreach(y => {
probabilities += (y -> x._1.asInstanceOf[Double])
probabilities += (y -> x._1.toDouble)
})
}
})
Expand Down
30 changes: 20 additions & 10 deletions src/main/scala/org/pmml4s/model/TreeModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import org.pmml4s.common._
import org.pmml4s.data.{DataVal, Series}
import org.pmml4s.metadata.{MiningSchema, Output, OutputField, Targets}
import org.pmml4s.transformations.LocalTransformations
import org.pmml4s.util.Utils

import scala.collection.mutable.ArrayBuffer
import scala.collection.{immutable, mutable}
Expand Down Expand Up @@ -91,20 +90,20 @@ class TreeModel(
while (i < len && !hit) {
val c = children(i)
c.eval(series) match {
case Predication.TRUE => {
case Predication.TRUE => {
r = Predication.TRUE
child = c
hit = true
}
case Predication.SURROGATE => {
case Predication.FALSE =>
case Predication.SURROGATE => {
r = Predication.SURROGATE
child = c
hit = true
}
case Predication.UNKNOWN => {
case Predication.UNKNOWN => {
unknown = true
}
case _ =>
}
i += 1
}
Expand Down Expand Up @@ -133,17 +132,24 @@ class TreeModel(
val total = selected.recordCount.getOrElse(Double.NaN)
val candidates = selected.children.filter { x => x.eval(series) == UNKNOWN }
var max = 0.0
for (cls <- classes) {
var i = 0
while (i < numClasses) {
val cls = classes(i)
var conf = 0.0
for (cand <- candidates) {
conf += cand.getConfidence(cls) * cand.recordCount.getOrElse(0.0) / total
var j = 0
while (j < candidates.length) {
val candi = candidates(j)
conf += candi.getConfidence(cls) * candi.recordCount.getOrElse(0.0) / total
j += 1
}

if (conf > max) {
max = conf
outputs.predictedValue = cls
outputs.confidence = conf
}

i += 1
}

done = true
Expand Down Expand Up @@ -213,7 +219,7 @@ class TreeModel(
result(series, outputs)
}

/** The sub-classes can override this method to provide classes of target inside model. */
/** The subclasses can override this method to provide classes of target inside model. */
override def inferClasses: Array[DataVal] = {
firstLeaf.scoreDistributions.classes
}
Expand Down Expand Up @@ -289,8 +295,12 @@ class TreeModel(
i += 1
}

for (child <- candidates)
i = 0
while (i < candidates.length) {
val child = candidates(i)
traverseLeaves(child, series, leaves)
i += 1
}
}
}

Expand Down
18 changes: 14 additions & 4 deletions src/main/scala/org/pmml4s/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,17 @@ object Utils {
// Support such float number, for example "1.0", which is converted into double firstly,
// then converted to integer again.
case _: NumberFormatException => {
s.toDouble.toLong
val d = StringUtils.asDouble(s)
if (d != d) {
null
} else {
d.toLong
}
}
case e: Throwable => throw e
case e: Throwable => null
}
}
case _: NumericType => s.toDouble
case _: NumericType => StringUtils.asDouble(s)
case BooleanType => s.toBoolean
case _ => s
}
Expand All @@ -127,7 +132,12 @@ object Utils {
// Support such float number, for example "1.0", which is converted into double firstly,
// then converted to integer again.
case _: NumberFormatException => {
LongVal(s.toDouble.toLong)
val d = StringUtils.asDouble(s)
if (d != d) {
DataVal.NULL
} else {
LongVal(d.toLong)
}
}
case e: Throwable => DataVal.NULL
}
Expand Down
8 changes: 4 additions & 4 deletions src/main/scala/org/pmml4s/xml/Builder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -502,13 +502,13 @@ trait Builder[T <: Model] extends TransformationsBuilder {
}

def makeMatrix(reader: XMLEventReader, attrs: XmlAttrs): Matrix = {
val kind = attrs.get(AttrTags.KIND).map(MatrixKind.withName(_)).getOrElse(MatrixKind.any)
val kind = attrs.get(AttrTags.KIND).map(MatrixKind.withName).getOrElse(MatrixKind.any)
val nbRows = attrs.getInt(AttrTags.NB_ROWS)
val nbCols = attrs.getInt(AttrTags.NB_COLS)
val diagDefault = attrs.getDouble(AttrTags.DIAG_DEFAULT)
val offDiagDefault = attrs.getDouble(AttrTags.OFF_DIAG_DEFAULT)
val arrays = mutable.ArrayBuilder.make[Array[Double]]
nbRows.foreach(arrays.sizeHint(_))
nbRows.foreach(arrays.sizeHint)
val matCells = mutable.ArrayBuilder.make[MatCell]

traverseElems(reader, ElemTags.MATRIX, {
Expand All @@ -517,7 +517,7 @@ trait Builder[T <: Model] extends TransformationsBuilder {
override def build(reader: XMLEventReader, attrs: XmlAttrs): MatCell = {
val row = attrs.int(AttrTags.ROW)
val col = attrs.int(AttrTags.COL)
val value = extractText(reader, ElemTags.MAT_CELL).toDouble
val value = StringUtils.asDouble(extractText(reader, ElemTags.MAT_CELL))

new MatCell(row, col, value)
}
Expand Down Expand Up @@ -625,7 +625,7 @@ trait Builder[T <: Model] extends TransformationsBuilder {
override def build(reader: XMLEventReader, attrs: XmlAttrs): ComparisonMeasure = {
val kind = ComparisonMeasureKind.withName(attrs(AttrTags.KIND))
val compareFunction =
attrs.get(AttrTags.COMPARE_FUNCTION).map(CompareFunction.withName(_)).getOrElse(CompareFunction.absDiff)
attrs.get(AttrTags.COMPARE_FUNCTION).map(CompareFunction.withName).getOrElse(CompareFunction.absDiff)
val minimum = attrs.getDouble(AttrTags.MINIMUM)
val maximum = attrs.getDouble(AttrTags.MAXIMUM)
var distance: Distance = null
Expand Down

0 comments on commit 42d90e6

Please sign in to comment.