diff --git a/src/dmatrix.jl b/src/dmatrix.jl index bd5ec79..7e1fda2 100644 --- a/src/dmatrix.jl +++ b/src/dmatrix.jl @@ -581,13 +581,13 @@ function _unsafe_dataiter_reset(ptr::Ptr) end function _dmatrix_caching_config_json(;cache_prefix::AbstractString, - nthreads::Integer, + nthreads::Union{Integer, Nothing}, missing_value::Float32=NaN32, ) d = Dict("missing"=>"__NAN_STR__", "cache_prefix"=>cache_prefix, ) - d["nthreads"] = string(nthreads) + isnothing(nthreads) || (d["nthreads"] = string(nthreads)) # this is to strip out the special Float32 values to representations it'll accept nanstr = if isnan(missing_value) "NaN" @@ -603,7 +603,7 @@ end function DMatrix(itr::DataIterator; missing_value::Float32=NaN32, cache_prefix::AbstractString=joinpath(tempdir(),"xgb-cache"), - nthreads::Integer=Threads.nthreads(), + nthreads::Union{Integer, Nothing}=Threads.nthreads(), kw... ) o = Ref{DMatrixHandle}()