From 0b4977a1e1280199fe7c171eed46714520d3f607 Mon Sep 17 00:00:00 2001 From: Jerry Ling Date: Fri, 17 Nov 2023 20:22:47 +0100 Subject: [PATCH] Set more sensible default threads count (#196) * Set more sensible default threads count * Don't care about pre 1.6 * Typo * add Nothing back into the union --- src/booster.jl | 3 ++- src/dmatrix.jl | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/booster.jl b/src/booster.jl index 32f94bb..e4df32c 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -108,6 +108,7 @@ function Booster(cache::AbstractVector{<:DMatrix}; model_file::AbstractString="", tree_method::Union{Nothing,AbstractString}=nothing, validate_parameters::Bool=true, + nthread=Threads.nthreads(), kw... ) o = Ref{BoosterHandle}() @@ -124,7 +125,7 @@ function Booster(cache::AbstractVector{<:DMatrix}; else (tree_method=tree_method,) end - setparams!(b; validate_parameters, tm..., kw...) + setparams!(b; validate_parameters, nthread, tm..., kw...) b end Booster(dm::DMatrix; kw...) = Booster([dm]; kw...) diff --git a/src/dmatrix.jl b/src/dmatrix.jl index f3962e8..7e1fda2 100644 --- a/src/dmatrix.jl +++ b/src/dmatrix.jl @@ -581,13 +581,13 @@ function _unsafe_dataiter_reset(ptr::Ptr) end function _dmatrix_caching_config_json(;cache_prefix::AbstractString, - nthreads::Union{Integer,Nothing}, + nthreads::Union{Integer, Nothing}, missing_value::Float32=NaN32, ) d = Dict("missing"=>"__NAN_STR__", "cache_prefix"=>cache_prefix, ) - isnothing(nthreads) || (d["nthreads"] = nthreads) + isnothing(nthreads) || (d["nthreads"] = string(nthreads)) # this is to strip out the special Float32 values to representations it'll accept nanstr = if isnan(missing_value) "NaN" @@ -603,7 +603,7 @@ end function DMatrix(itr::DataIterator; missing_value::Float32=NaN32, cache_prefix::AbstractString=joinpath(tempdir(),"xgb-cache"), - nthreads::Union{Integer,Nothing}=nothing, + nthreads::Union{Integer, Nothing}=Threads.nthreads(), kw... ) o = Ref{DMatrixHandle}()