Skip to content

Commit

Permalink
add platt scale (#266)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiong-zhang authored Dec 1, 2023
1 parent d2708ec commit bde8676
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 0 deletions.
56 changes: 56 additions & 0 deletions pecos/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ def __init__(self, dirname, soname, forced_rebuild=False):
self.link_ann_hnsw_methods()
self.link_mmap_hashmap_methods()
self.link_mmap_valstore_methods()
self.link_calibrator_methods()

def link_xlinear_methods(self):
"""
Expand Down Expand Up @@ -1939,5 +1940,60 @@ def mmap_valstore_init(self, store_type):
raise NotImplementedError(f"store_type={store_type} is not implemented.")
return self.mmap_valstore_fn_dict[store_type]

def link_calibrator_methods(self):
"""
Specify C-lib's score calibration methods arguments and return types.
"""
corelib.fillprototype(
self.clib_float32.c_fit_platt_transform_f32,
None,
[c_uint64, POINTER(c_float), POINTER(c_float), POINTER(c_double)],
)
corelib.fillprototype(
self.clib_float32.c_fit_platt_transform_f64,
None,
[c_uint64, POINTER(c_double), POINTER(c_double), POINTER(c_double)],
)

def fit_platt_transform(self, logits, tgt_prob):
"""Python to C/C++ interface for platt transfrom fit.
Ref: https://www.csie.ntu.edu.tw/~cjlin/papers/plattprob.pdf
Args:
logits (ndarray): 1-d array of logit with length N.
tgt_prob (ndarray): 1-d array of target probability scores within [0, 1] with length N.
Returns:
A, B: coefficients for Platt's scale.
"""
assert isinstance(logits, np.ndarray)
assert isinstance(tgt_prob, np.ndarray)
assert len(logits) == len(tgt_prob)
assert logits.dtype == tgt_prob.dtype

if tgt_prob.min() < 0 or tgt_prob.max() > 1.0:
raise ValueError("Target probability out of bound!")

AB = np.array([0, 0], dtype=np.float64)

if tgt_prob.dtype == np.float32:
clib.clib_float32.c_fit_platt_transform_f32(
len(logits),
logits.ctypes.data_as(POINTER(c_float)),
tgt_prob.ctypes.data_as(POINTER(c_float)),
AB.ctypes.data_as(POINTER(c_double)),
)
elif tgt_prob.dtype == np.float64:
clib.clib_float32.c_fit_platt_transform_f64(
len(logits),
logits.ctypes.data_as(POINTER(c_double)),
tgt_prob.ctypes.data_as(POINTER(c_double)),
AB.ctypes.data_as(POINTER(c_double)),
)
else:
raise ValueError(f"Unsupported dtype: {tgt_prob.dtype}")

return AB[0], AB[1]


clib = corelib(os.path.join(os.path.dirname(os.path.abspath(pecos.__file__)), "core"), "libpecos")
14 changes: 14 additions & 0 deletions pecos/core/libpecos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,4 +651,18 @@ extern "C" {
static_cast<mmap_valstore_bytes *>(map_ptr)->batch_get(
n_sub_row, n_sub_col, sub_rows, sub_cols, trunc_val_len, ret, ret_lens, threads);
}

// ==== C Interface of Score Calibrator ====

#define C_FIT_PLATT_TRANSFORM(SUFFIX, VAL_TYPE) \
void c_fit_platt_transform ## SUFFIX( \
size_t num_samples, \
const VAL_TYPE* logits, \
const VAL_TYPE* tgt_probs, \
double* AB \
) { \
pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1]); \
}
C_FIT_PLATT_TRANSFORM(_f32, float32_t)
C_FIT_PLATT_TRANSFORM(_f64, float64_t)
}
112 changes: 112 additions & 0 deletions pecos/core/utils/newton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,5 +272,117 @@ namespace pecos {
return cg_iter;
};
};


// Platt scale with given target curve.
// Reference Implementation:
// https://github.com/cjlin1/libsvm/blob/master/svm.cpp

template <typename value_type>
static void fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) {
// hyper parameters
int max_iter = 100; // Maximal number of iterations
double min_step = 1e-10; // Minimal step taken in line search
double sigma = 1e-12; // For numerically strict PD of Hessian
double eps = 1e-6;

int iter;

// Initial Point and Initial Fun Value
A = 0.0; B = 1.0;
double fval = 0.0;

// check for out of bound in tgt_probs
for (size_t i = 0; i < num_samples; i++) {
if (tgt_probs[i] > 1.0 || tgt_probs[i] < 0) {
throw std::runtime_error("fit_platt_transform: target probability out of bound\n");
}
}


for (size_t i = 0; i < num_samples; i++) {
double fApB = logits[i] * A + B;
if (fApB >= 0) {
fval += tgt_probs[i] * fApB + log(1 + exp(-fApB));
} else {
fval += (tgt_probs[i] - 1) * fApB + log(1 + exp(fApB));
}
}
for (iter = 0; iter < max_iter; iter++) {
// Update Gradient and Hessian (use H' = H + sigma I)
double h11 = sigma;
double h22 = sigma; // numerically ensures strict PD
double h21 = 0.0;
double g1 = 0.0;
double g2 = 0.0;

for (size_t i = 0; i < num_samples; i++) {
double fApB = logits[i] * A + B;
double p = 0, q = 0;
if (fApB >= 0) {
p = exp(-fApB) / (1.0 + exp(-fApB));
q = 1.0 / (1.0 + exp(-fApB));
} else {
p = 1.0 / (1.0 + exp(fApB));
q = exp(fApB) / (1.0 + exp(fApB));
}
double d1 = tgt_probs[i] - p;
double d2 = p * q;

h11 += d2 * logits[i] * logits[i];
h22 += d2;
h21 += logits[i] * d2;
g1 += logits[i] * d1;
g2 += d1;
}

// Stopping Criteria
if (fabs(g1) < eps && fabs(g2) < eps)
break;

// Finding Newton direction: -inv(H') * g
double det = h11 * h22 - h21 * h21;
double dA = -(h22 * g1 - h21 * g2) / det;
double dB = -(-h21 * g1 + h11 * g2) / det;
double gd = g1 * dA + g2 * dB;

// Line Search
double stepsize = 1.0;

while (stepsize >= min_step) {
double newA = A + stepsize * dA;
double newB = B + stepsize * dB;

// New function value
double newf = 0.0;
for (size_t i = 0; i < num_samples; i++) {
double fApB = logits[i] * newA + newB;
if (fApB >= 0) {
newf += tgt_probs[i] * fApB + log(1 + exp(-fApB));
} else {
newf += (tgt_probs[i] - 1) * fApB + log(1 + exp(fApB));
}
}
// Check sufficient decrease
if (newf < fval + 0.0001 * stepsize * gd)
{
A = newA;
B = newB;
fval = newf;
break;
} else {
stepsize = stepsize / 2.0;
}
}

if (stepsize < min_step) {
throw std::runtime_error("fit_platt_transform: Line search fails\n");
}
}

if (iter >= max_iter) {
throw std::runtime_error("fit_platt_transform: Reaching maximal iterations\n");
}
}
} // namespace pecos
#endif
20 changes: 20 additions & 0 deletions test/pecos/core/test_clib.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,23 @@ def test_sparse_inner_products():
assert true_vals == approx(
pred_vals, abs=1e-9
), f"true_vals != pred_vals, where X/Y are drm/dcm"


def test_platt_scale():
import numpy as np
from pecos.core import clib

A = 0.25
B = 3.14

orig = np.arange(-15, 15, 1, dtype=np.float32)
tgt = np.array([1.0 / (1 + np.exp(A * t + B)) for t in orig], dtype=np.float32)
At, Bt = clib.fit_platt_transform(orig, tgt)
assert B == approx(Bt, abs=1e-6), f"Platt_scale B error: {B} != {Bt}"
assert A == approx(At, abs=1e-6), f"Platt_scale A error: {A} != {At}"

orig = np.arange(-15, 15, 1, dtype=np.float64)
tgt = np.array([1.0 / (1 + np.exp(A * t + B)) for t in orig], dtype=np.float64)
At, Bt = clib.fit_platt_transform(orig, tgt)
assert B == approx(Bt, abs=1e-6), f"Platt_scale B error: {B} != {Bt}"
assert A == approx(At, abs=1e-6), f"Platt_scale A error: {A} != {At}"

0 comments on commit bde8676

Please sign in to comment.