Skip to content

Commit

Permalink
Update cuml (#45)
Browse files Browse the repository at this point in the history
* chore: Update Dockerfile and add dependencies in /testdata

* chore: Update Dockerfile to remove unused dependencies

* update bindings
  • Loading branch information
getumen authored Jul 31, 2024
1 parent 1f111bb commit bb6f56c
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 32 deletions.
4 changes: 2 additions & 2 deletions go/rawcuml4go/device_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ func UseBinningMemoryResource(
}, nil
}

func UseArenaMemoryResource() (
func UseArenaMemoryResource(arena_size uint64) (
*MemoryResource,
error,
) {
var pointer C.DeviceMemoryResource
ret := C.UseArenaMemoryResource(&pointer)
ret := C.UseArenaMemoryResource(&pointer, arena_size)
if ret != 0 {
return nil, ErrGetDeviceMemoryResource
}
Expand Down
14 changes: 1 addition & 13 deletions go/rawcuml4go/fil.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ var (
type FILModel struct {
deviceResource *DeviceResource
pointer C.FILModelHandle
numClass int
}

// NewFILModel
Expand Down Expand Up @@ -65,16 +64,9 @@ func NewFILModel(
return nil, ErrFILModelLoad
}

var numClass uint64
ret = C.FILGetNumClasses(handle, (*C.ulong)(&numClass))
if ret != 0 {
return nil, ErrFILModelLoad
}

return &FILModel{
deviceResource: deviceResource,
pointer: handle,
numClass: int(numClass),
}, nil

}
Expand All @@ -90,7 +82,7 @@ func (m *FILModel) Predict(
if preds == nil {
var predsLen int
if outputClassProbability {
predsLen = numRow * m.numClass
predsLen = numRow * 2
} else {
predsLen = numRow
}
Expand Down Expand Up @@ -121,7 +113,3 @@ func (m *FILModel) Close() error {
}
return nil
}

func (m *FILModel) NumClass() int {
return m.numClass
}
6 changes: 4 additions & 2 deletions rust/src/fil.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ use crate::{
sys::{
bindings::FILModelHandle,
device_resource::DeviceResource,
fil::{fil_free_model, fil_get_num_class, fil_load_model, fil_predict},
fil::{fil_free_model, fil_load_model, fil_predict},
},
};

const FIL_SUPPORTED_CLASS_NUM: usize = 2;

pub enum ModelType {
// XGBoost xgboost model (binary model file)
XGBoost = 0,
Expand Down Expand Up @@ -92,7 +94,7 @@ impl Model {
output_class_probabilities: bool,
) -> Result<Vec<f32>, CumlError> {
let mut preds = if output_class_probabilities {
let num_class = fil_get_num_class(self.model)?;
let num_class = FIL_SUPPORTED_CLASS_NUM;
vec![0f32; num_row * num_class]
} else {
vec![0f32; num_row]
Expand Down
14 changes: 1 addition & 13 deletions rust/src/sys/fil.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use anyhow::{anyhow, Context};
use crate::errors::CumlError;

use super::{
bindings::{FILFreeModel, FILGetNumClasses, FILLoadModel, FILModelHandle, FILPredict},
bindings::{FILFreeModel, FILLoadModel, FILModelHandle, FILPredict},
device_resource::DeviceResource,
};

Expand Down Expand Up @@ -78,15 +78,3 @@ pub fn fil_predict(

Ok(())
}

pub fn fil_get_num_class(model: FILModelHandle) -> Result<usize, CumlError> {
let mut out = 0usize;

let result = unsafe { FILGetNumClasses(model, &mut out) };

if result != 0 {
Err(anyhow!("fail to get num class"))?
}

Ok(out)
}
4 changes: 2 additions & 2 deletions rust/src/sys/memory_resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ impl MemoryResource {
})
}

pub fn use_arena_memoryg_resource() -> Result<Self, CumlError> {
pub fn use_arena_memoryg_resource(arena_size: usize) -> Result<Self, CumlError> {
let mut resource: DeviceMemoryResource = null_mut();
let ret = unsafe { UseArenaMemoryResource(&mut resource) };
let ret = unsafe { UseArenaMemoryResource(&mut resource, arena_size) };
if ret != 0 {
Err(anyhow!("fail to use arena memory resource"))?
}
Expand Down

0 comments on commit bb6f56c

Please sign in to comment.