diff --git a/go/rawcuml4go/device_memory.go b/go/rawcuml4go/device_memory.go index a5ae13a..c0ee014 100644 --- a/go/rawcuml4go/device_memory.go +++ b/go/rawcuml4go/device_memory.go @@ -65,12 +65,12 @@ func UseBinningMemoryResource( }, nil } -func UseArenaMemoryResource() ( +func UseArenaMemoryResource(arena_size uint64) ( *MemoryResource, error, ) { var pointer C.DeviceMemoryResource - ret := C.UseArenaMemoryResource(&pointer) + ret := C.UseArenaMemoryResource(&pointer, arena_size) if ret != 0 { return nil, ErrGetDeviceMemoryResource } diff --git a/go/rawcuml4go/fil.go b/go/rawcuml4go/fil.go index a58a498..bee8487 100644 --- a/go/rawcuml4go/fil.go +++ b/go/rawcuml4go/fil.go @@ -19,7 +19,6 @@ var ( type FILModel struct { deviceResource *DeviceResource pointer C.FILModelHandle - numClass int } // NewFILModel @@ -65,16 +64,9 @@ func NewFILModel( return nil, ErrFILModelLoad } - var numClass uint64 - ret = C.FILGetNumClasses(handle, (*C.ulong)(&numClass)) - if ret != 0 { - return nil, ErrFILModelLoad - } - return &FILModel{ deviceResource: deviceResource, pointer: handle, - numClass: int(numClass), }, nil } @@ -90,7 +82,7 @@ func (m *FILModel) Predict( if preds == nil { var predsLen int if outputClassProbability { - predsLen = numRow * m.numClass + predsLen = numRow * 2 } else { predsLen = numRow } @@ -121,7 +113,3 @@ func (m *FILModel) Close() error { } return nil } - -func (m *FILModel) NumClass() int { - return m.numClass -} diff --git a/rust/src/fil.rs b/rust/src/fil.rs index 4f890bd..9bd3944 100644 --- a/rust/src/fil.rs +++ b/rust/src/fil.rs @@ -5,10 +5,12 @@ use crate::{ sys::{ bindings::FILModelHandle, device_resource::DeviceResource, - fil::{fil_free_model, fil_get_num_class, fil_load_model, fil_predict}, + fil::{fil_free_model, fil_load_model, fil_predict}, }, }; +const FIL_SUPPORTED_CLASS_NUM: usize = 2; + pub enum ModelType { // XGBoost xgboost model (binary model file) XGBoost = 0, @@ -92,7 +94,7 @@ impl Model { output_class_probabilities: bool, ) -> Result, CumlError> { let mut preds = if output_class_probabilities { - let num_class = fil_get_num_class(self.model)?; + let num_class = FIL_SUPPORTED_CLASS_NUM; vec![0f32; num_row * num_class] } else { vec![0f32; num_row] diff --git a/rust/src/sys/fil.rs b/rust/src/sys/fil.rs index ef81209..72d78dc 100644 --- a/rust/src/sys/fil.rs +++ b/rust/src/sys/fil.rs @@ -5,7 +5,7 @@ use anyhow::{anyhow, Context}; use crate::errors::CumlError; use super::{ - bindings::{FILFreeModel, FILGetNumClasses, FILLoadModel, FILModelHandle, FILPredict}, + bindings::{FILFreeModel, FILLoadModel, FILModelHandle, FILPredict}, device_resource::DeviceResource, }; @@ -78,15 +78,3 @@ pub fn fil_predict( Ok(()) } - -pub fn fil_get_num_class(model: FILModelHandle) -> Result { - let mut out = 0usize; - - let result = unsafe { FILGetNumClasses(model, &mut out) }; - - if result != 0 { - Err(anyhow!("fail to get num class"))? - } - - Ok(out) -} diff --git a/rust/src/sys/memory_resource.rs b/rust/src/sys/memory_resource.rs index cba885d..df88f2c 100644 --- a/rust/src/sys/memory_resource.rs +++ b/rust/src/sys/memory_resource.rs @@ -47,9 +47,9 @@ impl MemoryResource { }) } - pub fn use_arena_memoryg_resource() -> Result { + pub fn use_arena_memoryg_resource(arena_size: usize) -> Result { let mut resource: DeviceMemoryResource = null_mut(); - let ret = unsafe { UseArenaMemoryResource(&mut resource) }; + let ret = unsafe { UseArenaMemoryResource(&mut resource, arena_size) }; if ret != 0 { Err(anyhow!("fail to use arena memory resource"))? }