forked from aminnj/NanoTools
-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
25 changed files
with
378 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#ifndef MLWRAPPER_H | ||
#define MLWRAPPER_H | ||
|
||
#include <unordered_map> | ||
#include <vector> | ||
|
||
class MLWrapper{ | ||
public: | ||
|
||
MLWrapper(){}; | ||
virtual ~MLWrapper(){}; | ||
|
||
virtual bool build(std::string fname, std::vector<std::string> const& varnames, float missing_entry_val) = 0; | ||
|
||
}; | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#include <iostream> | ||
#include <limits> | ||
#include <ostream> | ||
//#include <xgboost/c_api.h> | ||
//#include "/cvmfs/cms.cern.ch/slc7_amd64_gcc900/external/py3-xgboost/0.90-ghbfee2/lib/python3.8/site-packages/xgboost/include/xgboost/c_api.h" | ||
#include "XGBoostInterface.hpp" | ||
|
||
XGBoostInterface::XGBoostInterface() : MLWrapper(), booster(nullptr), defval(0) {} | ||
|
||
XGBoostInterface::~XGBoostInterface() | ||
{ | ||
SAFE_XGBOOST(XGBoosterFree(*booster)); | ||
delete booster; | ||
} | ||
|
||
bool XGBoostInterface::build(std::string fname, std::vector<std::string> const &varnames, float missing_entry_val) | ||
{ | ||
|
||
if (booster) | ||
{ | ||
std::cerr << "XGBoostInterface::build: The booster is already built." << endl; | ||
return false; | ||
} | ||
if (fname == "") | ||
{ | ||
std::cerr << "XGBoostInterface::build: The file name is an empty string. This function should be called to load models from a file." << endl; | ||
assert(0); | ||
} | ||
|
||
defval = missing_entry_val; | ||
variable_names = varnames; | ||
|
||
booster = new BoosterHandle; | ||
SAFE_XGBOOST(XGBoosterCreate(nullptr, 0, booster)); | ||
|
||
// std::cout << "XGBoostInterface::build: A new xgboost is created. Loading the model in " << fname << "..." << endl; | ||
|
||
SAFE_XGBOOST(XGBoosterLoadModel(*booster, fname.data())); | ||
|
||
return true; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#ifndef XGBOOSTINTERFACE_H | ||
#define XGBOOSTINTERFACE_H | ||
|
||
#include <xgboost/c_api.h> | ||
//#include "${XGBOOST_PATH}/include/xgboost/c_api.h" | ||
//#include "/cvmfs/cms.cern.ch/slc7_amd64_gcc900/external/py3-xgboost/0.90-ghbfee2/lib/python3.8/site-packages/xgboost/include/xgboost/c_api.h" | ||
#include "MLWrapper.h" | ||
|
||
class XGBoostInterface : public MLWrapper | ||
{ | ||
protected: | ||
BoosterHandle *booster; | ||
float defval; | ||
std::vector<std::string> variable_names; | ||
|
||
public: | ||
XGBoostInterface(); | ||
virtual ~XGBoostInterface(); | ||
|
||
bool build(std::string fname, std::vector<std::string> const &varnames, float missing_entry_val); | ||
|
||
std::vector<std::string> const &getVariableNames() const { return variable_names; } | ||
|
||
BoosterHandle *const &getBooster() const { return booster; } | ||
|
||
template <typename T> bool eval(std::unordered_map<std::string, float> const &vars, std::vector<T> &res); | ||
template <typename T> bool eval(std::unordered_map<std::string, float> const &vars, T &res); | ||
}; | ||
|
||
#endif | ||
|
Oops, something went wrong.