forked from uav4geo/OpenPointClass
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgbm.hpp
60 lines (47 loc) · 1.36 KB
/
gbm.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#ifndef GBM_H
#define GBM_H
#include <LightGBM/config.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/prediction_early_stop.h>
#include "vendor/json/json.hpp"
#include "features.hpp"
#include "labels.hpp"
#include "constants.hpp"
#include "point_io.hpp"
// Short alias for nlohmann's JSON type.
// NOTE(review): this alias is declared at global scope inside a header, so it
// is visible in every translation unit that includes gbm.hpp. It is left as-is
// because other files may rely on the unqualified name `json`; consider moving
// it into a namespace once callers are audited.
using json = nlohmann::json;
namespace gbm {

/// Alias for LightGBM's boosting model interface, used throughout this module.
using Boosting = LightGBM::Boosting;

/// @brief Train a LightGBM boosting model from the given input files.
///
/// Parameter meanings below follow their names; confirm against the
/// implementation before relying on them.
///
/// @param filenames        Input files used as training data.
/// @param startResolution  Passed by pointer rather than by value.
///                         NOTE(review): presumably an in/out parameter that
///                         train() may read and/or update — confirm.
/// @param numScales        Number of scales (per name).
/// @param numTrees         Number of trees (per name).
/// @param treeDepth        Tree depth (per name).
/// @param radius           Radius value forwarded to training (per name).
/// @param maxSamples       Sample cap (per name).
/// @param classes          Class codes involved in training (per name).
/// @return Pointer to the trained model.
///         NOTE(review): returned as a raw pointer; ownership presumably
///         transfers to the caller — confirm before deleting or storing it.
Boosting *train(const std::vector<std::string> &filenames,
                double *startResolution,
                int numScales,
                int numTrees,
                int treeDepth,
                double radius,
                int maxSamples,
                const std::vector<int> &classes);

/// @brief Parameters associated with a trained booster, as returned by
///        extractBoosterParams().
///
/// Members have default initializers so a default-constructed instance never
/// holds indeterminate values. Since C++14 a struct with default member
/// initializers is still an aggregate, so `BoosterParams{r, rad, n}` continues
/// to work.
struct BoosterParams {
    double resolution = 0.0;
    double radius = 0.0;
    int numScales = 0;
};

/// @brief Load a booster model from @p modelFilename.
/// @return Pointer to the loaded model.
///         NOTE(review): raw pointer; ownership presumably transfers to the
///         caller — confirm.
Boosting *loadBooster(const std::string &modelFilename);

/// @brief Save @p booster to @p modelFilename.
/// @param booster        Model to save (non-owning pointer; not released here
///                       as far as the signature shows — confirm).
/// @param modelFilename  Destination path.
void saveBooster(Boosting *booster, const std::string &modelFilename);

/// @brief Extract the resolution / radius / number of scales associated with
///        @p booster.
/// NOTE(review): presumably read from parameters stored in the model — confirm
/// against the implementation.
BoosterParams extractBoosterParams(Boosting *booster);

/// @brief Classify the points in @p pointSet using @p booster.
///
/// @param pointSet          Taken by non-const reference, so it may be
///                          modified (presumably classification results are
///                          written back — confirm).
/// @param booster           Model used for prediction.
/// @param features          Features used as model input.
/// @param labels            Label definitions.
/// @param regularization    Defaults to Regularization::None.
/// @param regRadius         Defaults to 2.5; presumably only relevant when
///                          regularization is enabled — confirm.
/// @param useColors         Defaults to false.
/// @param unclassifiedOnly  Defaults to false.
/// @param evaluate          Defaults to false.
/// @param skip              Defaults to empty.
/// @param statsFile         Defaults to an empty path.
void classify(PointSet &pointSet,
              Boosting *booster,
              const std::vector<Feature *> &features,
              const std::vector<Label> &labels,
              Regularization regularization = Regularization::None,
              double regRadius = 2.5,
              bool useColors = false,
              bool unclassifiedOnly = false,
              bool evaluate = false,
              const std::vector<int> &skip = {},
              const std::string &statsFile = "");

} // namespace gbm
#endif