forked from Cell-veto/postlhc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmarsaglia.hpp
113 lines (98 loc) · 2.95 KB
/
marsaglia.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// (c) 2015-2016 Sebastian Kapfer <[email protected]>, FAU Erlangen
#ifndef MTW_DISCR_DISTR_HPP_INCLUDED
#define MTW_DISCR_DISTR_HPP_INCLUDED
#include "tools.hpp"
/// Samples one of a set of values with probability proportional to the
/// weight it was registered with, using a multi-level lookup table.
///
/// Usage: call add() for every (weight, value) pair, then finish() once to
/// build the tables, then random_sample() as often as needed.  clear()
/// returns the sampler to its empty state so it can be refilled.
///
/// The probabilities are discretized to multiples of 1/P_SCALE.  Level n of
/// the table holds entries that each stand for P_SCALE / 64^(n+1) of those
/// units, so a draw costs at most NUM_LEVELS comparisons and one lookup.
/// NOTE(review): the file name and the "Mtw" prefix suggest the
/// Marsaglia/Tsang/Wang table method -- confirm against the project docs.
template <typename VALUE>
struct MtwDiscreteSampler
{
private:
    typedef VALUE value_t;
    typedef unsigned long index_t;

    // weights[i] belongs to values[i]; zero weights are never stored.
    std::vector <double> weights;
    // Sum of all stored weights as of the last finish(); 0 while empty.
    double weights_total = 0.;
    std::vector <value_t> values;

    static const unsigned NUM_LEVELS = 5;
    // Resolution of the discretized probabilities (2^30 units in total).
    static const index_t P_SCALE = 1ull << 30;

    // table[n][k] is the index (into weights/values) owning the k-th entry
    // of level n.
    std::vector <index_t> table[NUM_LEVELS];
    // tbase[n] is the first scaled unit covered by level n;
    // tbase[NUM_LEVELS] is the total number of units allocated.
    index_t tbase[NUM_LEVELS+1] = {};

    /// Drop all lookup-table state so that no stale index can be returned.
    void reset_table ()
    {
        for (unsigned n = 0; n != NUM_LEVELS; ++n)
        {
            table[n].clear ();
            tbase[n] = 0;
        }
        tbase[NUM_LEVELS] = 0;
    }

    /// Map a scaled random number u to an index into weights/values.
    /// Aborts via ABORT if u is not covered by any level (which is always
    /// the case while the sampler is empty).
    index_t sample (index_t u) const
    {
        for (unsigned n = 0; n != NUM_LEVELS; ++n)
        {
            if (u < tbase[n+1])
            {
                // offset within this level, in scaled units
                u -= tbase[n];
                // divide by this level's entry weight, 64^(NUM_LEVELS-1-n)
                u >>= 6 * (NUM_LEVELS-1-n);
                return table[n].at (u);
            }
        }

        std::cerr << "Broken lookup table in MtwDiscreteSampler" << ABORT;
        return 0; // squelch warning
    }

    /// Rebuild the lookup tables from weights and weights_total.
    /// Expects weights to be non-empty and weights_total to be their sum.
    void init_table ()
    {
        // A zero, negative or NaN total would make the scaling below
        // meaningless and the conversion to index_t undefined.
        if (!(weights_total > 0.))
            std::cerr << "Initializing MtwDiscreteSampler: total weight is not positive" << ABORT;

        // convert to scaled representation
        std::vector <index_t> scaled_ws;
        scaled_ws.reserve (weights.size ());
        for (double w : weights)
            // + .5 would be fair rounding
            // + 1. to over-allocate a bit (compensates for rounding)
            scaled_ws.push_back (
                static_cast <index_t> (P_SCALE / weights_total * w + 1.));

        // set up lookup tables
        index_t entry_weight = P_SCALE;
        index_t total_alloc = 0;

        for (unsigned n = 0; n != NUM_LEVELS; ++n)
        {
            table[n].clear ();
            tbase[n] = total_alloc;
            entry_weight /= 64;

            // Hand out as many entries of this level as fit into what is
            // left of each weight; the remainder goes to finer levels.
            // The last level has entry_weight 1, so nothing is left over.
            for (index_t i = 0; i != weights.size (); ++i)
            {
                while (scaled_ws[i] >= entry_weight)
                {
                    scaled_ws[i] -= entry_weight;
                    table[n].push_back (i);
                    total_alloc += entry_weight;
                }
            }
        }

        tbase[NUM_LEVELS] = total_alloc;

        std::cerr << "mtw_init alloci " << total_alloc << ' ' << P_SCALE
            << "\nmtw_init rawrate " << weights_total << '\n';

        // Every number below P_SCALE must hit a table entry.
        if (total_alloc < P_SCALE)
            std::cerr << "Initializing MtwDiscreteSampler: rounding error" << ABORT;
    }

public:
    /// Remove all weights and values and invalidate the lookup tables.
    /// Afterwards total_weights() is 0 until finish() is called again.
    void clear ()
    {
        weights.clear ();
        values.clear ();
        weights_total = 0.;
        reset_table ();
    }

    /// Register value with the given weight.  Entries with a weight of
    /// exactly zero are dropped.  Takes effect at the next finish().
    /// NOTE(review): negative or non-finite weights are not rejected here;
    /// callers are assumed to pass non-negative finite weights.
    void add (double weight, const value_t &value)
    {
        if (weight == 0)
            return;
        weights.push_back (weight);
        values.push_back (value);
    }

    /// Sum up the registered weights and build the lookup tables.
    /// Must be called after the last add() and before random_sample().
    void finish ()
    {
        // normalize weights
        weights_total = 0.;
        for (double w : weights)
            weights_total += w;

        if (weights.size () == 0)
        {
            // nothing to sample from; make sure no stale table survives
            reset_table ();
            return;
        }

        init_table ();
    }

    /// Sum of all registered weights as computed by the last finish()
    /// (0 for an empty sampler).
    double total_weights () const
    {
        return weights_total;
    }

    /// Draw one entry: stores its value in *out and returns a reference to
    /// its weight.  The reference stays valid until the sampler is modified.
    /// random->uint (P_SCALE) presumably yields a uniform integer below
    /// P_SCALE -- see RandomContext.
    const double &random_sample (value_t *out, RandomContext *random) const
    {
        index_t i = sample (random->uint (P_SCALE));
        *out = values.at (i);
        return weights.at (i);
    }
};
#endif /* MTW_DISCR_DISTR_HPP_INCLUDED */