-
Notifications
You must be signed in to change notification settings - Fork 101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for multiple data types #130
Conversation
This reverts commit 79eb179.
Codecov Report
@@ Coverage Diff @@
## master #130 +/- ##
============================================
- Coverage 47.82% 41.47% -6.36%
============================================
Files 81 67 -14
Lines 5733 5155 -578
Branches 37 0 -37
============================================
- Hits 2742 2138 -604
- Misses 2964 3017 +53
+ Partials 27 0 -27
Continue to review full report at Codecov.
|
Note: code coverage is broken. |
How to use the Value value;
switch (value.Type()) {
case treelite::ADT::ValueImpl::ValueKind::kInt32:
int32_t int32val = treelite::ADT::get<treelite::ADT::Int32Value>(value);
// ...
break;
case treelite::ADT::ValueImpl::ValueKind::kFloat32:
float float32val = treelite::ADT::get<treelite::ADT::Float32Value>(value);
// ...
break;
case treelite::ADT::ValueImpl::ValueKind::kFloat64:
double float64val = treelite::ADT::get<treelite::ADT::Float64Value>(value);
// ...
break;
default:
LOG(FATAL) << "Unknown value variant";
} |
Need to also change generated lib to expose function interface with a different type. |
Now there are at least 4 different possible signatures for the prediction function: int predict(union Entry* data, int pred_margin);
double predict(union Entry* data, int pred_margin);
size_t predict(union Entry* data, int pred_margin, float* result);
size_t predict(union Entry* data, int pred_margin, double* result); This is quickly getting out of control. We should introduce type-erased interface to reduce complexity. |
Indeed. I strongly recommended an 8-bit datatype in here too. For the sake of reducing model size by as much as 75% in compiled C code, I always run my C source through a post processor to convert all floats to uint8_t. I was about to create a pull request that could intelligently determine the smallest data type that could successfully separate the branches of a tree, but this pull request has the potential. |
Hi, Any plans to release this version anytime soon? |
Addresses #95 and #111. Follow-up to #198, #199, #201 Trying again, since #130 failed. This time, I made the Model class to be polymorphic. This way, the amount of pointer indirection is minimized. Summary: Model is an opaque container that wraps the polymorphic handle ModelImpl<ThresholdType, LeafOutputType>. The handle in turn stores the list of trees Tree<ThresholdType, LeafOutputType>. To unbox the Model container and obtain ModelImpl<ThresholdType, LeafOutputType>, use Model::Dispatch(<lambda expression>). Also, upgrade to C++14 to access the generic lambda feature, which proved to be very useful in the dispatching logic for the polymorphic Model class. * Turn the Model and Tree classes into template classes * Revise the string templates so that correct data types are used in the generated C code * Rewrite the model builder class * Revise the zero-copy serializer * Create an abstract matrix class that supports multiple data types (float32, float64 for now). * Move the DMatrix class to the runtime. * Extend the DMatrix class so that it can hold float32 and float64. * Redesign the C runtime API using the DMatrix class. * Ensure accuracy of scikit-learn models. To achieve the best results, use float32 for the input matrix and float64 for the split thresholds and leaf outputs. * Revise the JVM runtime.
Hi there, may I know if this is merged? Just to say I'm still having the wrong predictions using the latest master #222. Thanks! |
@simon19891101 Take a look at #203 |
Addresses #95 and #111.
@trivialfis Your RTTI implementation in XGBoost JSON served as a great inspiration for this work. Thank you.
cc @canonizer @teju85