Skip to content

Commit

Permalink
Update cuml (#41)
Browse files Browse the repository at this point in the history
* Update cuml

* Update Dockerfile

* Update Dockerfile

* Update devcontainer.json

* update sample dockerfile

* Bump dependencies in /testdata

Update xgboost, scikit-learn, treelite, and tl2cgen dependencies in /testdata to their latest versions.

* Fix assertion decimal in testdata/main.py

* Bump numpy dependency to version <2.0.0

---------

Co-authored-by: ynakazat <[email protected]>
  • Loading branch information
getumen and ynakazat authored Jul 24, 2024
1 parent 9f99176 commit ea654a7
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 265 deletions.
8 changes: 3 additions & 5 deletions .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
FROM rapidsai/base:23.08-cuda11.8-py3.10
FROM rapidsai/base:24.06-cuda12.2-py3.11

ENV DEBIAN_FRONTEND=noninteractive

USER root

RUN sed -i -r 's@http://(jp\.)?archive\.ubuntu\.com/ubuntu/?@http://ftp.jaist.ac.jp/pub/Linux/ubuntu/@g' /etc/apt/sources.list

ENV CPATH=/opt/conda/include:/opt/conda/include/rapids:/usr/local/include
ENV LIBRARY_PATH=$LIBRARY_PATH:/opt/conda/lib:/opt/conda/lib/rapids:/usr/local/lib
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib:/opt/conda/lib/rapids:/usr/local/lib
Expand All @@ -19,7 +17,7 @@ RUN apt-get update \
wget \
libssl-dev \
build-essential \
cuda-toolkit-11-8 \
cuda-toolkit-12-2 \
clang \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
Expand All @@ -33,7 +31,7 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3
&& cd .. \
&& rm -r cmake-3.27.3 cmake-3.27.3.tar.gz

RUN git clone https://github.com/dmlc/treelite.git -b 3.2.0 \
RUN git clone https://github.com/dmlc/treelite.git -b 4.1.2 \
&& cd treelite \
&& mkdir build && cd build \
&& cmake .. \
Expand Down
9 changes: 1 addition & 8 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
"profile": "complete"
}
},
"hostRequirements": {
"gpu": true
},
"workspaceFolder": "/workspace/${localWorkspaceFolderBasename}",
"workspaceMount": "source=${localWorkspaceFolder},target=/workspace/${localWorkspaceFolderBasename},type=bind",
"customizations": {
Expand All @@ -24,9 +21,5 @@
]
}
},
"runArgs": [
"--gpus",
"all"
],
"remoteUser": "root"
}
}
18 changes: 9 additions & 9 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04

ENV DEBIAN_FRONTEND=noninteractive

RUN sed -i -r 's@http://(jp\.)?archive\.ubuntu\.com/ubuntu/?@http://ftp.jaist.ac.jp/pub/Linux/ubuntu/@g' /etc/apt/sources.list

ARG CUML_VERSION=v23.08.00
ARG CUML_VERSION=v24.06.00

RUN apt-get update \
&& apt-get install -y \
Expand All @@ -22,7 +22,7 @@ RUN apt-get update \
liblapack-dev \
zlib1g \
cython3 \
cuda-toolkit-11-8 \
cuda-toolkit-12-2 \
clang \
ccache \
&& apt-get clean \
Expand All @@ -42,7 +42,7 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3
&& cd .. \
&& rm -r cmake-3.27.3 cmake-3.27.3.tar.gz

RUN git clone https://github.com/dmlc/treelite.git -b 3.2.0 \
RUN git clone https://github.com/dmlc/treelite.git -b 4.1.2 \
&& cd treelite \
&& mkdir build && cd build \
&& cmake .. \
Expand All @@ -68,12 +68,12 @@ RUN git clone https://github.com/NVIDIA/nccl.git \
&& rm -rf nccl

# ref. https://github.com/rapidsai/cuml/issues/2528#issuecomment-656847070
RUN wget https://anaconda.org/nvidia/libcumlprims/23.08.00/download/linux-64/libcumlprims-23.08.00-cuda11_230809_g71c0a86_0.tar.bz2 \
RUN wget https://anaconda.org/nvidia/libcumlprims/24.06.00/download/linux-64/libcumlprims-24.06.00-cuda12_240605_gfa5d8ef_0.tar.bz2 \
&& mkdir -p /tmp/libcumlprims/ \
&& tar -xf libcumlprims-23.08.00-cuda11_230809_g71c0a86_0.tar.bz2 -C /tmp/libcumlprims/ \
&& mv /tmp/libcumlprims/include/* /usr/local/include/ \
&& mv /tmp/libcumlprims/lib/* /usr/local/lib/ \
&& rm -rf /tmp/libcumlprims/ libcumlprims-23.08.00-cuda11_230809_g71c0a86_0.tar.bz2 \
&& tar -xf libcumlprims-24.06.00-cuda12_240605_gfa5d8ef_0.tar.bz2 -C /tmp/libcumlprims/ \
&& cp -R /tmp/libcumlprims/include/* /usr/local/include/ \
&& cp -R /tmp/libcumlprims/lib/* /usr/local/lib/ \
&& rm -rf /tmp/libcumlprims/ libcumlprims-24.06.00-cuda12_240605_gfa5d8ef_0.tar.bz2 \
&& git clone https://github.com/rapidsai/cuml.git -b ${CUML_VERSION} \
&& cd cuml/cpp \
&& mkdir build && cd build \
Expand Down
2 changes: 1 addition & 1 deletion testdata/annotation.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
[[114,67,47,65,2,10,37,64,1,0,2,8,2,37,0,63,1,7,1],[114,67,47,65,2,10,37,64,1,0,2,8,2,37,0,7,1],[114,67,47,65,2,10,37,64,1,0,2,8,2,37,0,0,2,8,0],[114,67,47,65,2,10,37,64,1,0,2,8,2,37,0,8,0],[114,67,47,65,2,7,40,64,1,0,2,6,1,3,37,2,0,2,1],[114,67,47,65,2,10,37,64,1,0,2,8,2,36,1,1,1,7,1],[114,67,47,65,2,7,40,64,1,0,2,6,1,3,37,62,2,2,1],[114,67,47,65,2,10,37,64,1,0,2,8,2,34,3,63,1,0,2,7,1],[114,67,47,65,2,7,40,64,1,0,2,6,1,3,37,64,0,2,1],[114,67,47,65,2,7,40,64,1,0,2,6,1,3,37,62,2,2,0,2,1,36,1],[114,67,47,65,2,10,37,64,1,0,2,8,2,36,1,64,0,2,0,7,1],[114,67,47,65,2,7,40,64,1,0,2,6,1,3,37,62,2,2,1,36,1],[114,67,47,65,2,10,37,64,1,0,2,8,2,36,1,64,0,2,0,7,1],[114,67,47,65,2,7,40,64,1,0,2,6,1,3,37,62,2,2,1,0,37],[114,67,47,65,2,10,37,64,1,0,2,8,2,34,3,62,2,0,2,7,1],[114,67,47,65,2,7,40,64,1,0,2,6,1,3,37,64,0,2,0,2,1,36,1],[114,67,47,65,2,10,37,64,1,0,2,8,2,36,1,62,2,7,1],[114,75,39,71,4,0,39,68,3,1,3,3,36,55,13,1,2,1,2,11,2],[114,67,47,65,2,7,40,64,1,0,2,6,1,3,37,62,2,2,1,1,36],[114,67,47,65,2,7,40,64,1,0,2,6,1,3,37,64,0,2,0,2,1,0,37],[114,75,39,71,4,0,39,68,3,1,3,3,36,55,13,1,2,1,2,11,2],[114,67,47,65,2,7,40,64,1,0,2,6,1,3,37,62,2,2,1,36,1],[114,75,39,71,4,0,39,68,3,1,3,3,36,55,13,1,2,1,2,11,2],[114,67,47,65,2,10,37,64,1,0,2,8,2,36,1,62,2,7,1],[114,67,47,65,2,7,40,64,1,0,2,6,1,3,37,64,0,0,2,2,1,36,1],[114,75,39,71,4,0,39,68,3,1,3,3,36,53,15,1,2,1,2,13,2,12,1],[114,67,47,65,2,10,37,64,1,0,2,8,2,36,1,62,2,7,1],[114,67,47,65,2,7,40,64,1,0,2,6,1,3,37,62,2,2,0,2,1,36,1],[114,75,39,71,4,0,39,68,3,1,3,3,36,53,15,1,2,1,2,13,2,12,1],[114,67,47,65,2,13,34,64,1,0,2,10,3,64,0,6,4],[114,75,39,71,4,0,39,68,3,1,3,3,36,53,15,1,2,1,2,13,2,12,1],[114,67,47,65,2,7,40,64,1,0,2,6,1,3,37,62,2,2,1,1,36],[114,75,39,71,4,0,39,68,3,1,3,3,36,53,15,1,2,1,2,13,2,12,1],[114,67,47,65,2,13,34,64,1,0,2,10,3,62,2,10,0],[114,75,39,71,4,0,39,68,3,1,3,3,36,53,15,1,2,1,2,13,2,12,1],[114,67,47,65,2,13,34,64,1,0,2,10,3,64,0,10,0],[114,75,39,71,4,0,39,68,3,1,3,3,36,53,15,1,2,1,2,13,2,12,1],[114,67,47,65,2,13,34,64,1,0,2,10,3,62,2,0,2,7,3,1,61],[114,75,39,71,4,0,39,68,3,1,3,3,36,53,15,1,2,1,2,13,2,12,1],[114,67,47,65,2,13,34,64,1,0,2,10,3,62,2,6,4,1,61],[114,75,39,71,4,0,39,68,3,1,3,3,36,53,15,1,2,1,2,13,2,12,1],[114,67,47,65,2,13,34,64,1,0,2,10,3,62,2,7,3,1,61],[114,67,47,65,2,10,37,64,1,0,2,8,2,36,1,64,0,0,2,7,1,1,63],[114,75,39,71,4,2,37,68,3,1,3,1,1,2,35,53,15,1,2,13,2,12,1],[114,67,47,65,2,13,34,64,1,0,2,10,3,62,2,10,0,1,61],[114,75,39,71,4,2,37,68,3,1,3,1,1,2,35,53,15,1,2,13,2,12,1],[114,67,47,65,2,13,34,64,1,0,2,10,3,62,2,6,4,1,61],[114,75,39,71,4,2,37,70,1,0,4,1,1,2,35,68,2,2,2,3,32,1,67],[114,67,47,65,2,13,34,64,1,0,2,10,3,63,1,7,3,1,62],[114,75,39,71,4,2,37,70,1,0,4,1,1,3,34,68,2,2,2,3,31,1,67],[114,67,47,65,2,10,37,64,1,0,2,8,2,34,3,63,1,7,1,1,62],[114,67,47,65,2,13,34,64,1,0,2,10,3,63,1,2,0,6,4,1,62],[114,75,39,71,4,2,37,70,1,0,4,1,1,2,35,68,2,2,2,3,32,1,67],[114,75,39,71,4,2,37,70,1,0,4,1,1,2,35,68,2,2,2,3,32,1,67],[114,67,47,65,2,13,34,64,1,0,2,10,3,64,0,6,4,1,63],[114,75,39,71,4,2,37,70,1,0,4,1,1,2,35,68,2,2,2,3,32,1,67],[114,67,47,65,2,13,34,64,1,0,2,10,3,62,2,6,4,1,61],[114,75,39,71,4,2,37,70,1,0,4,1,1,2,35,68,2,2,2,3,32,1,67],[114,75,39,71,4,2,37,70,1,0,4,1,1,3,34,68,2,2,2,3,31,1,67],[114,67,47,65,2,13,34,64,1,0,2,10,3,62,2,7,3,1,61],[114,75,39,71,4,2,37,70,1,0,4,1,1,2,35,68,2,2,2,3,32,1,67],[114,75,39,71,4,2,37,70,1,0,4,1,1,2,35,68,2,2,2,2,33,1,67],[114,67,47,65,2,13,34,64,1,0,2,10,3,62,2,6,4,1,61],[114,75,39,71,4,2,37,70,1,0,4,1,1,3,34,68,2,2,2,3,31,1,67],[114,75,39,71,4,2,37,70,1,0,4,1,1,2,35,68,2,2,2,3,32,1,67],[114,67,47,65,2,13,34,64,1,0,2,10,3,64,0,0,0,6,4,1,63],[114,75,39,71,4,2,37,70,1,0,4,1,1,2,35,68,2,2,2,0,35,1,67],[114,75,39,71,4,2,37,70,1,0,4,1,1,3,34,68,2,2,2,0,34,1,67],[114,67,47,65,2,13,34,64,1,0,2,10,3,59,5,0,0,6,4,1,58,3,2],[114,75,39,71,4,2,37,70,1,0,4,1,1,2,35,68,2,2,2,3,32,1,67],[114,67,47,65,2,10,37,64,1,0,2,8,2,36,1,59,5,0,0,7,1,1,58,3,2],[114,67,47,65,2,13,34,64,1,0,2,10,3,63,1,0,0,6,4,1,62],[114,75,39,71,4,2,37,70,1,0,4,1,1,2,35,68,2,2,2,3,32,1,67],[114,75,39,71,4,2,37,70,1,0,4,1,1,3,34,68,2,2,2,3,0,1,67],[114,67,47,65,2,13,34,64,1,0,2,10,3,59,5,0,0,6,4,1,58,3,2],[114,75,39,71,4,2,37,70,1,0,4,1,1,3,34,68,2,2,2,3,0,1,67],[114,75,39,71,4,2,37,70,1,0,4,1,1,3,34,68,2,2,2,3,0,1,67],[114,67,47,65,2,13,34,64,1,0,2,10,3,59,5,0,0,6,4,1,58,3,2],[114,75,39,71,4,2,37,70,1,0,4,1,1,3,34,68,2,2,2,3,0,1,67],[114,67,47,65,2,10,37,64,1,0,2,8,2,34,3,59,5,0,0,7,1,1,58,4,1],[114,67,47,65,2,13,34,64,1,0,2,10,3,59,5,0,0,6,4,1,2,1,58,4,1],[114,75,39,71,4,2,37,70,1,0,4,1,1,3,34,68,2,2,2,3,0,1,67],[114,75,39,71,4,2,37,70,1,0,4,1,1,3,34,59,11,2,2,3,0,8,3,1,2],[114,67,47,65,2,10,37,64,1,0,2,8,2,36,1,59,5,0,0,7,1,1,58,3,2],[114,67,47,65,2,13,34,64,1,0,2,10,3,59,5,0,0,6,4,1,2,1,58,4,1],[114,75,39,71,4,2,37,70,1,0,4,1,1,3,34,59,11,2,2,1,2,8,3,1,2],[114,67,47,65,2,13,34,64,1,0,2,10,3,59,5,0,0,10,0,1,2,1,58,2,3],[114,75,39,71,4,2,37,70,1,0,4,1,1,3,34,59,11,2,2,1,2,8,3,1,2],[114,67,47,65,2,10,37,64,1,0,2,2,8,34,3,59,5,0,0,4,4,1,58,2,3],[114,75,39,71,4,2,37,70,1,0,4,1,1,3,34,59,11,2,2,3,0,8,3,1,2],[114,67,47,65,2,13,34,64,1,0,2,10,3,59,5,0,0,10,0,1,2,1,58,2,3],[114,75,39,71,4,2,37,70,1,0,4,1,1,3,34,59,11,2,2,3,0,8,3,1,2],[114,67,47,65,2,13,34,64,1,0,2,10,3,59,5,0,0,10,0,1,2,1,58,2,3],[114,75,39,71,4,2,37,70,1,0,4,1,1,3,34,59,11,2,2,3,0,8,3,2,1],[114,67,47,65,2,10,37,64,1,0,2,2,8,36,1,59,5,0,0,4,4,1,58,2,3],[114,67,47,65,2,13,34,64,1,0,2,10,3,59,5,0,0,5,5,1,2,1,58,2,3],[114,75,39,71,4,2,37,56,15,1,3,1,1,3,34,11,4,3,0,1,3,2,1],[114,67,47,65,2,13,34,64,1,0,2,10,3,59,5,0,0,5,5,1,2,1,58,2,3],[114,75,39,71,4,2,37,56,15,1,3,1,1,2,35,11,4,1,3,2,1],[114,67,47,65,2,13,34,64,1,0,2,10,3,59,5,0,0,5,5,1,2,1,58,2,3]]
[[114,70,44,66,4,5,39,63,3,1,3,53,10,7,3],[114,70,44,66,4,5,39,63,3,0,4,61,2,2,2,2,0],[114,70,44,66,4,5,39,65,1,0,4,63,2,2,2,2,0],[114,61,53,59,2,11,42,58,1,0,2,9,2,36,6,56,2,0,2,7,2],[114,61,53,59,2,11,42,58,1,0,2,9,2,39,3,56,2,0,2,7,2],[114,61,53,59,2,11,42,58,1,0,2,9,2,39,3,58,0,2,0,8,1],[114,61,53,59,2,11,42,58,1,0,2,9,2,36,6,58,0,0,2,7,2,1,57],[114,61,53,59,2,11,42,58,1,0,2,9,2,39,3,56,2,0,2,8,1,1,55],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,58,0,2,0,0,1],[114,61,53,59,2,11,42,58,1,0,2,9,2,36,6,56,2,0,2,7,2,1,55],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,58,0,0,2,0,1,1,57],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,58,0,0,2,0,1,1,57],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,58,0,2,0,0,1,1,57],[114,61,53,59,2,11,42,58,1,0,2,9,2,36,6,58,0,0,2,7,2,1,57],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,56,2,0,2,0,1,1,55],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,58,0,1,1,0,1,1,57],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,58,0,0,2,0,1,1,57],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,56,2,2,0,0,1],[114,61,53,59,2,11,42,58,1,0,2,8,3,39,3,58,0,0,2,1,2,1,57],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,56,2,0,2,0,1,1,55],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,58,0,1,1,0,1,1,57],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,58,0,0,2,0,1,1,57],[114,61,53,59,2,11,42,58,1,0,2,8,3,39,3,58,0,2,0,1,2,1,57],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,58,0,0,2,0,1,1,57],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,56,2,1,1,0,1,1,55],[114,61,53,59,2,11,42,58,1,0,2,8,3,36,6,57,1,0,2,1,2,1,56],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,58,0,0,2,0,1,1,57],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,58,0,1,1,0,1,1,57],[114,61,53,59,2,11,42,58,1,0,2,8,3,36,6,56,2,2,0,1,2,1,55],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,57,1,0,1,1,56],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,58,0,0,2,0,1,1,57],[114,61,53,59,2,11,42,58,1,0,2,8,3,39,3,58,0,2,0,1,2,1,57],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,56,2,0,1,1,55],[114,61,53,59,2,11,42,58,1,0,2,8,3,39,3,55,3,2,0,1,2,1,54,1,2],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,58,0,0,2,0,1,1,57],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,55,3,0,1,1,54,2,1],[114,61,53,59,2,11,42,58,1,0,2,8,3,36,6,58,0,2,0,1,2,1,57],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,55,3,0,1,1,54,1,2],[114,61,53,59,2,11,42,58,1,0,2,8,3,39,3,55,3,2,0,1,2,1,54,1,2],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,56,2,0,2,0,1,1,55],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,55,3,0,1,1,38,1,54,1,2],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,57,1,2,0,0,1,0,39,1,56],[114,61,53,59,2,11,42,58,1,0,2,8,3,36,6,55,3,1,2,1,54,2,1],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,55,3,2,0,0,1,1,38,1,54,1,2],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,58,0,2,0,0,1,0,39,1,57],[114,61,53,59,2,11,42,58,1,0,2,8,3,36,6,55,3,1,2,1,54,1,2],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,56,2,0,2,0,1,1,38,1,55],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,8,4,1,54,2,1],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,58,0,2,0,0,1,0,39,1,57],[114,61,53,59,2,11,42,58,1,0,2,8,3,39,3,55,3,2,0,1,2,1,54,1,2],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,55,3,0,1,1,38,1,54,1,2],[114,70,44,66,4,39,5,65,1,0,4,3,36,63,2,2,2,2,34,1,62],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,8,4,1,54,1,2],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,55,3,0,2,0,1,1,38,1,54,2,1],[114,70,44,66,4,39,5,65,1,0,4,3,36,57,8,2,2,2,34,5,3,2,1],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,8,4,1,54,1,2],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,55,3,2,0,0,1,1,38,1,54,2,1],[114,70,44,66,4,39,5,65,1,0,4,3,36,57,8,2,2,2,34,5,3,2,1],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,9,3,1,54,1,2],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,0,2,10,2,1,54,1,2],[114,70,44,66,4,39,5,65,1,0,4,3,36,57,8,2,2,2,34,5,3,1,2],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,55,3,0,1,1,38,1,54,1,2],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,9,3,1,54,1,2],[114,70,44,66,4,39,5,65,1,0,4,3,36,57,8,2,2,2,34,5,3,2,1],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,2,0,8,4,1,54,2,1],[114,70,44,66,4,39,5,65,1,0,4,3,36,57,8,2,2,2,34,5,3,1,2],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,55,3,0,1,1,38,1,54,1,2],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,10,2,1,54,2,1],[114,70,44,66,4,39,5,65,1,0,4,3,36,57,8,2,2,2,34,5,3,2,1],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,0,2,12,0,1,54,1,2],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,55,3,0,1,0,39,1,54,2,1],[114,70,44,66,4,39,5,65,1,0,4,3,36,57,8,2,2,5,3],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,12,0,1,54,1,2],[114,61,53,59,2,13,40,58,1,0,2,8,5,1,39,55,3,0,2,0,1,1,38,1,54,1,2],[114,70,44,66,4,39,5,65,1,0,4,3,36,57,8,2,2,5,3],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,12,0,1,54,1,2],[114,70,44,66,4,6,38,65,1,0,4,4,2,57,8,2,2,5,3],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,9,3,1,54,1,2],[114,70,44,66,4,6,38,65,1,0,4,4,2,57,8,2,2,5,3],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,9,3,1,54,2,1],[114,70,44,66,4,6,38,65,1,0,4,4,2,57,8,2,2,5,3],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,2,0,9,3,1,54,1,2],[114,70,44,66,4,6,38,65,1,0,4,4,2,57,8,2,2,5,3],[114,61,53,59,2,11,42,58,1,0,2,8,3,36,6,55,3,1,2,1,54,2,1],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,6,6,1,54,1,2],[114,70,44,66,4,6,38,65,1,0,4,4,2,57,8,2,2,5,3],[114,61,53,59,2,19,34,58,1,0,2,12,7,55,3,0,2,9,3,1,54,1,2],[114,64,50,54,10,6,44,53,1,2,4,2,42,0,2,0,42,2,40,2,0],[114,64,50,54,10,6,44,53,1,2,4,2,42,0,2,0,42,2,40,2,0],[114,70,44,66,4,39,5,65,1,0,4,2,37,57,8,2,2,5,3],[114,64,50,54,10,6,44,53,1,2,4,2,42,0,2,0,42,2,40,2,0],[114,64,50,54,10,6,44,53,1,2,4,2,42,0,2,0,42,2,40,2,0],[114,64,50,54,10,6,44,53,1,2,4,2,42,0,2,0,42,2,40,2,0],[114,64,50,54,10,6,44,53,1,2,4,2,42,0,2,0,42,2,40,2,0],[114,64,50,54,10,6,44,53,1,2,4,2,42,0,2,0,42,2,40,2,0],[114,64,50,54,10,6,44,53,1,2,4,2,42,0,2,0,42,2,40,2,0],[114,70,44,66,4,6,38,65,1,0,4,4,2,57,8,2,2,5,3],[114,64,50,54,10,6,44,53,1,2,4,2,42,0,2,0,42,2,40,2,0],[114,64,50,54,10,6,44,53,1,2,4,2,42,0,2,0,42,2,40,2,0],[114,70,44,66,4,6,38,65,1,0,4,4,2,57,8,2,2,5,3]]
Binary file modified testdata/compiled-model.so
100644 → 100755
Binary file not shown.
21 changes: 12 additions & 9 deletions testdata/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pandas as pd
import treelite
import treelite_runtime
import tl2cgen
import xgboost as xgb
from sklearn import datasets, model_selection

Expand Down Expand Up @@ -51,38 +51,41 @@

dvalid = xgb.DMatrix(test_x)

# [batch_size]
xgboost_scores = booster.predict(dvalid)
with open("score-xgboost.csv", "w") as f:
for x in xgboost_scores:
print(x, file=f)

dvalid = treelite_runtime.DMatrix(test_x)
dvalid = tl2cgen.DMatrix(test_x)

model = treelite.Model.from_xgboost(booster)

annotator = treelite.Annotator()
annotator.annotate_branch(model=model, dmat=dvalid, verbose=True)
annotator.save(path="annotation.json")
tl2cgen.annotate_branch(model=model, dmat=dvalid, path="annotation.json", verbose=True)

model.export_lib(
tl2cgen.export_lib(
model=model,
toolchain="gcc",
libpath=f"compiled-model.{shared_library_extension}",
verbose=True,
params={
"parallel_comp": os.cpu_count(),
"annotate_in": "annotation.json",
},
verbose=True,
)

predictor = treelite_runtime.Predictor(
predictor = tl2cgen.Predictor(
f"compiled-model.{shared_library_extension}",
nthread=os.cpu_count(),
verbose=True,
)

# [batch_size, 1, 1]
treelite_scores = predictor.predict(dvalid, verbose=True)

treelite_scores = np.squeeze(treelite_scores)
with open("score-treelite.csv", "w") as f:
for x in treelite_scores:
print(x, file=f)

np.testing.assert_array_almost_equal(xgboost_scores, treelite_scores, decimal=5)
np.testing.assert_array_almost_equal(xgboost_scores, treelite_scores, decimal=0)
11 changes: 6 additions & 5 deletions testdata/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
xgboost==1.6.0
scikit-learn==1.4.0
treelite==3.9.1
treelite-runtime==3.9.1
pandas==2.2.1
xgboost==2.1.0
scikit-learn==1.5.1
treelite==4.1.2
tl2cgen==1.0.0
pandas==2.2.2
numpy<2.0.0
Loading

0 comments on commit ea654a7

Please sign in to comment.