From 678afbe714456bf6aa65c5fd6ac10a327f54b5bb Mon Sep 17 00:00:00 2001 From: Moelf Date: Wed, 25 Oct 2023 17:57:49 +0200 Subject: [PATCH 1/4] Set more sensible default threads count --- src/booster.jl | 3 ++- src/dmatrix.jl | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/booster.jl b/src/booster.jl index bebb054..9d0d234 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -103,6 +103,7 @@ function Booster(cache::AbstractVector{<:DMatrix}; model_file::AbstractString="", tree_method::Union{Nothing,AbstractString}=nothing, validate_parameters::Bool=true, + thread=Threads.nthreads(), kw... ) o = Ref{BoosterHandle}() @@ -119,7 +120,7 @@ function Booster(cache::AbstractVector{<:DMatrix}; else (tree_method=tree_method,) end - setparams!(b; validate_parameters, tm..., kw...) + setparams!(b; validate_parameters, thread=thread, tm..., kw...) b end Booster(dm::DMatrix; kw...) = Booster([dm]; kw...) diff --git a/src/dmatrix.jl b/src/dmatrix.jl index f3962e8..bd5ec79 100644 --- a/src/dmatrix.jl +++ b/src/dmatrix.jl @@ -581,13 +581,13 @@ function _unsafe_dataiter_reset(ptr::Ptr) end function _dmatrix_caching_config_json(;cache_prefix::AbstractString, - nthreads::Union{Integer,Nothing}, + nthreads::Integer, missing_value::Float32=NaN32, ) d = Dict("missing"=>"__NAN_STR__", "cache_prefix"=>cache_prefix, ) - isnothing(nthreads) || (d["nthreads"] = nthreads) + d["nthreads"] = string(nthreads) # this is to strip out the special Float32 values to representations it'll accept nanstr = if isnan(missing_value) "NaN" @@ -603,7 +603,7 @@ end function DMatrix(itr::DataIterator; missing_value::Float32=NaN32, cache_prefix::AbstractString=joinpath(tempdir(),"xgb-cache"), - nthreads::Union{Integer,Nothing}=nothing, + nthreads::Integer=Threads.nthreads(), kw... ) o = Ref{DMatrixHandle}() From 5d4c1fc26926c5a2c7f9419e9ad1a6162d265a6b Mon Sep 17 00:00:00 2001 From: Moelf Date: Wed, 25 Oct 2023 17:59:10 +0200 Subject: [PATCH 2/4] Don't care about pre 1.6 --- src/booster.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/booster.jl b/src/booster.jl index 9d0d234..0f163d0 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -120,7 +120,7 @@ function Booster(cache::AbstractVector{<:DMatrix}; else (tree_method=tree_method,) end - setparams!(b; validate_parameters, thread=thread, tm..., kw...) + setparams!(b; validate_parameters, thread, tm..., kw...) b end Booster(dm::DMatrix; kw...) = Booster([dm]; kw...) From 8c0e36998e2e512c08e0413ba11734b3ff33cac3 Mon Sep 17 00:00:00 2001 From: Moelf Date: Wed, 25 Oct 2023 18:00:18 +0200 Subject: [PATCH 3/4] Typo --- src/booster.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/booster.jl b/src/booster.jl index 0f163d0..4f7ff0c 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -103,7 +103,7 @@ function Booster(cache::AbstractVector{<:DMatrix}; model_file::AbstractString="", tree_method::Union{Nothing,AbstractString}=nothing, validate_parameters::Bool=true, - thread=Threads.nthreads(), + nthread=Threads.nthreads(), kw... ) o = Ref{BoosterHandle}() @@ -120,7 +120,7 @@ function Booster(cache::AbstractVector{<:DMatrix}; else (tree_method=tree_method,) end - setparams!(b; validate_parameters, thread, tm..., kw...) + setparams!(b; validate_parameters, nthread, tm..., kw...) b end Booster(dm::DMatrix; kw...) = Booster([dm]; kw...) From c54cfcc48e0e556511d3243f219c28004c224fc1 Mon Sep 17 00:00:00 2001 From: Moelf Date: Sun, 5 Nov 2023 23:35:42 +0100 Subject: [PATCH 4/4] add Nothing back into the union --- src/dmatrix.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dmatrix.jl b/src/dmatrix.jl index bd5ec79..7e1fda2 100644 --- a/src/dmatrix.jl +++ b/src/dmatrix.jl @@ -581,13 +581,13 @@ function _unsafe_dataiter_reset(ptr::Ptr) end function _dmatrix_caching_config_json(;cache_prefix::AbstractString, - nthreads::Integer, + nthreads::Union{Integer, Nothing}, missing_value::Float32=NaN32, ) d = Dict("missing"=>"__NAN_STR__", "cache_prefix"=>cache_prefix, ) - d["nthreads"] = string(nthreads) + isnothing(nthreads) || (d["nthreads"] = string(nthreads)) # this is to strip out the special Float32 values to representations it'll accept nanstr = if isnan(missing_value) "NaN" @@ -603,7 +603,7 @@ end function DMatrix(itr::DataIterator; missing_value::Float32=NaN32, cache_prefix::AbstractString=joinpath(tempdir(),"xgb-cache"), - nthreads::Integer=Threads.nthreads(), + nthreads::Union{Integer, Nothing}=Threads.nthreads(), kw... ) o = Ref{DMatrixHandle}()