-
Notifications
You must be signed in to change notification settings - Fork 4
/
p3a_search.hpp
321 lines (305 loc) · 11.8 KB
/
p3a_search.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
#pragma once
#include <iterator>

#include "p3a_macros.hpp"
#include "p3a_functions.hpp"
namespace p3a {
/* Function object that orders two values with operator<.
 *
 * less<T> compares two values of type T.
 *
 * less<void> (which is what the default argument less<> names) is a
 * transparent comparator: its call operator accepts any two types for
 * which lhs < rhs is valid. Without this specialization the default
 * argument would name a class that cannot be instantiated, because
 * "void const&" is not a valid parameter type.
 */
template <class T = void>
class less {
 public:
  /* Returns the result of lhs < rhs. */
  [[nodiscard]] P3A_HOST_DEVICE P3A_ALWAYS_INLINE inline constexpr
  auto operator()(T const& lhs, T const& rhs) const
  {
    return lhs < rhs;
  }
};

template <>
class less<void> {
 public:
  /* Returns the result of lhs < rhs for operands of possibly different types. */
  template <class Lhs, class Rhs>
  [[nodiscard]] P3A_HOST_DEVICE P3A_ALWAYS_INLINE inline constexpr
  auto operator()(Lhs const& lhs, Rhs const& rhs) const
  {
    return lhs < rhs;
  }
};
/* Binary search for the first position in [first, last) whose element
 * compares greater than value, i.e. the first element for which
 * comp(value, element) is true. Returns last if there is no such element.
 *
 * Despite the template parameter name, the iterator must support
 * "last - first" and "it += n".
 * The range is presumed to be ordered consistently with comp.
 */
template <class ForwardIt, class T, class Compare>
[[nodiscard]] P3A_HOST_DEVICE P3A_ALWAYS_INLINE inline
ForwardIt upper_bound(
    ForwardIt first,
    ForwardIt last,
    T const& value,
    Compare comp)
{
  // number of candidate positions that are still undecided
  auto remaining = last - first;
  while (remaining > 0) {
    auto const half = remaining / 2;
    auto probe = first;
    probe += half;
    if (comp(value, *probe)) {
      // the probed element is already greater: the answer is at or before it
      remaining = half;
    } else {
      // the probed element is not greater: discard it and everything before it
      ++probe;
      first = probe;
      remaining -= half + 1;
    }
  }
  return first;
}
/* Binary search for the first position in [first, last) whose element
 * does not compare less than value, i.e. the first element for which
 * comp(element, value) is false. Returns last if there is no such element.
 *
 * Despite the template parameter name, the iterator must support
 * "last - first" and "it += n".
 * The range is presumed to be ordered consistently with comp.
 */
template <class ForwardIt, class T, class Compare>
[[nodiscard]] P3A_HOST_DEVICE P3A_ALWAYS_INLINE inline
ForwardIt lower_bound(ForwardIt first, ForwardIt last, T const& value, Compare comp)
{
  // number of candidate positions that are still undecided
  auto span = last - first;
  while (span > 0) {
    auto const half = span / 2;
    auto middle = first;
    middle += half;
    if (!comp(*middle, value)) {
      // the middle element is not less than value: keep only the lower half
      span = half;
      continue;
    }
    // the middle element is less than value: resume just after it
    ++middle;
    first = middle;
    span -= half + 1;
  }
  return first;
}
template <class ForwardIt, class T>
[[nodiscard]] P3A_HOST_DEVICE P3A_ALWAYS_INLINE inline
ForwardIt upper_bound(
ForwardIt first,
ForwardIt last,
T const& value)
{
return p3a::upper_bound(first, last, value, p3a::less<T>());
}
template <class ForwardIt, class T>
[[nodiscard]] P3A_HOST_DEVICE P3A_ALWAYS_INLINE inline
ForwardIt lower_bound(
ForwardIt first,
ForwardIt last,
T const& value)
{
return p3a::lower_bound(first, last, value, p3a::less<T>());
}
/* Status codes returned by the search and inversion routines in this file. */
enum class search_errc : int {
  // the search finished and produced a result
  success = 0,
  // the desired range value is smaller than the smallest range value
  // found at the ends of the searched region
  desired_value_below_minimum = 1,
  // the desired range value is larger than the largest range value
  // found at the ends of the searched region
  desired_value_above_maximum = 2,
  // the iteration limit was reached before the search finished
  exceeded_maximum_iterations = 3
};
/* This code acts as the inverse of a real differentiable function
 * in a specified subset of the domain.
 * It is given the ability to compute the function and its first derivative.
 * It then uses a combination of methods to search the given space until it finds
 * a domain value for which the range value is close enough to the desired range value.
 * The primary method is Newton's method.
 * Whenever the Newton step cannot be trusted (zero derivative, or a step that
 * leaves the current search bracket), it falls back to a bisection step.
 *
 * The function only needs to be continuous and differentiable, it does not need
 * to be monotonic in the given subset of the domain.
 *
 * Execution speed via minimizing actual function evaluations is a
 * primary design goal of this code.
 * This is the reason for using Newton's method in the common case:
 * it needs fewer evaluations than something like bisection would.
 * This is also the reason for having a separate "function state"
 * that the value and derivative are computed from.
 * The derivative calculation can often reuse information
 * from the value calculation, so "function state" is a mechanism
 * for users to implement that optimization.
 *
 * Parameters:
 *   state_from_domain_value      callable: domain value -> function state
 *   range_value_from_state       callable: function state -> function value
 *   derivative_value_from_state  callable: function state -> first derivative
 *   desired_range_value          the function value to search for
 *   tolerance                    forwarded to are_close(), which decides convergence
 *   minimum_domain_value,
 *   maximum_domain_value         ends of the search bracket; the local copies
 *                                are narrowed after every iteration
 *   domain_value                 output: starts at minimum_domain_value and holds
 *                                the most recent iterate when the function returns
 *
 * Returns search_errc::success once are_close(range value, desired_range_value, tolerance)
 * holds, or search_errc::exceeded_maximum_iterations if that does not happen
 * within 100 iterations.
 */
template <
  class DomainValue,
  class RangeValue,
  class Tolerance,
  class StateFromDomainValue,
  class RangeValueFromState,
  class DerivativeValueFromState>
[[nodiscard]] P3A_HOST_DEVICE inline
search_errc invert_differentiable_function(
    StateFromDomainValue const& state_from_domain_value,
    RangeValueFromState const& range_value_from_state,
    DerivativeValueFromState const& derivative_value_from_state,
    RangeValue const& desired_range_value,
    Tolerance const& tolerance,
    DomainValue minimum_domain_value,
    DomainValue maximum_domain_value,
    DomainValue& domain_value)
{
  int constexpr maximum_iterations = 100;
  auto const state_at_maximum_domain_value = state_from_domain_value(maximum_domain_value);
  auto range_value_at_maximum_domain_value = range_value_from_state(state_at_maximum_domain_value);
  // the search starts at the lower end of the bracket
  domain_value = minimum_domain_value;
  auto state_at_domain_value = state_from_domain_value(domain_value);
  auto range_value_at_minimum_domain_value = range_value_from_state(state_at_domain_value);
  auto range_value = range_value_at_minimum_domain_value;
  auto derivative_value = derivative_value_from_state(state_at_domain_value);
  for (int iteration = 0; iteration < maximum_iterations; ++iteration) {
    if (are_close(range_value, desired_range_value, tolerance)) return search_errc::success;
    auto const next_domain_value_newton =
      domain_value - (range_value - desired_range_value) / derivative_value;
    auto const next_domain_value_bisection =
      minimum_domain_value + (maximum_domain_value - minimum_domain_value) / 2;
    // a Newton step is rejected if the derivative vanishes or
    // if the step would leave the current bracket
    auto const newton_will_not_converge =
      (derivative_value == decltype(derivative_value)(0)) ||
      (next_domain_value_newton > maximum_domain_value) ||
      (next_domain_value_newton < minimum_domain_value);
    domain_value =
      condition(
          newton_will_not_converge,
          next_domain_value_bisection,
          next_domain_value_newton);
    state_at_domain_value = state_from_domain_value(domain_value);
    range_value = range_value_from_state(state_at_domain_value);
    derivative_value = derivative_value_from_state(state_at_domain_value);
    auto const is_new_minimum =
      // this is a logical XOR operation, designed to flip the logic if the function
      // is decreasing rather than increasing
      (!(range_value < desired_range_value)) !=
      (!(range_value_at_maximum_domain_value < range_value_at_minimum_domain_value));
    // the new iterate replaces exactly one end of the bracket
    minimum_domain_value = condition(is_new_minimum, domain_value, minimum_domain_value);
    maximum_domain_value = condition(is_new_minimum, maximum_domain_value, domain_value);
    range_value_at_minimum_domain_value = condition(is_new_minimum,
        range_value, range_value_at_minimum_domain_value);
    range_value_at_maximum_domain_value = condition(is_new_minimum,
        range_value_at_maximum_domain_value, range_value);
  }
  // The loop tests for convergence only at the top of each pass, so the iterate
  // produced by the final pass has been evaluated but not yet tested.
  // Test it here so that a converged result is not reported as a failure.
  if (are_close(range_value, desired_range_value, tolerance)) return search_errc::success;
  return search_errc::exceeded_maximum_iterations;
}
/* Given a set of tabulated values of a continuous real function,
 * this code finds an interval such that the tabulated range values
 * at the two ends of that interval bound the desired range value.
 *
 * It uses binary search (bisection in index space) to do this.
 *
 * Parameters:
 *   number_of_points        number of tabulated points; points 0 and
 *                           number_of_points - 1 are both evaluated
 *   range_value_from_point  callable: point index -> tabulated range value
 *   desired_range_value     the range value to bracket
 *   interval                output: index of the lower point of the interval found;
 *                           only written on success
 *
 * Returns desired_value_below_minimum / desired_value_above_maximum if the desired
 * value lies outside the values tabulated at the first and last points,
 * success once the bracket has shrunk to at most one interval, and
 * exceeded_maximum_iterations if that takes more than 100 halvings.
 */
template <
  class Index,
  class RangeValueFromPoint,
  class RangeValue>
[[nodiscard]] P3A_HOST_DEVICE inline
search_errc find_tabulated_interval(
    Index const& number_of_points,
    RangeValueFromPoint const& range_value_from_point,
    RangeValue const& desired_range_value,
    Index& interval)
{
  auto low_point = Index(0);
  auto high_point = number_of_points - 1;
  auto value_at_low = range_value_from_point(low_point);
  auto value_at_high = range_value_from_point(high_point);
  // the table may run in either direction, so order the two end values first
  auto const smallest_end_value = min(value_at_low, value_at_high);
  auto const largest_end_value = max(value_at_low, value_at_high);
  if (desired_range_value < smallest_end_value) {
    return search_errc::desired_value_below_minimum;
  }
  if (desired_range_value > largest_end_value) {
    return search_errc::desired_value_above_maximum;
  }
  int constexpr iteration_limit = 100;
  for (int pass = 0; pass != iteration_limit; ++pass) {
    auto const width = high_point - low_point;
    if (width <= Index(1)) {
      // the bracket is down to a single interval (or a single point)
      interval = low_point;
      return search_errc::success;
    }
    auto const middle_point = low_point + width / 2;
    auto const middle_value = range_value_from_point(middle_point);
    // XOR of these two flags: the middle point becomes the new lower end when
    // its value is below the target on a non-descending bracket, or
    // not below the target on a descending one
    auto const value_not_below_target = !(middle_value < desired_range_value);
    auto const bracket_not_descending = !(value_at_high < value_at_low);
    auto const raise_lower_end = value_not_below_target != bracket_not_descending;
    low_point = condition(raise_lower_end, middle_point, low_point);
    high_point = condition(raise_lower_end, high_point, middle_point);
    value_at_low = condition(raise_lower_end, middle_value, value_at_low);
    value_at_high = condition(raise_lower_end, value_at_high, middle_value);
  }
  return search_errc::exceeded_maximum_iterations;
}
/* Inverts a function that is described piecewise over a table of points.
 *
 * Step 1: find_tabulated_interval() searches the tabulated range values for the
 *         interval that brackets desired_range_value and stores it in `interval`.
 * Step 2: invert_differentiable_function() searches inside that interval, between
 *         domain_value_from_point(interval) and domain_value_from_point(interval + 1),
 *         using the state functor returned by state_functor_from_interval(interval).
 *
 * If step 1 fails its status is returned unchanged and step 2 is not run;
 * otherwise the status of step 2 is returned.
 * `domain_value` is written only by step 2.
 */
template <
  class Index,
  class RangeValueFromPoint,
  class DomainValueFromPoint,
  class RangeValue,
  class DomainValue,
  class Tolerance,
  class StateFunctorFromInterval,
  class RangeValueFromState,
  class DerivativeValueFromState>
[[nodiscard]] P3A_HOST_DEVICE inline
search_errc invert_piecewise_differentiable_function(
    Index const& number_of_points,
    RangeValueFromPoint const& range_value_from_point,
    DomainValueFromPoint const& domain_value_from_point,
    StateFunctorFromInterval const& state_functor_from_interval,
    RangeValueFromState const& range_value_from_state,
    DerivativeValueFromState const& derivative_value_from_state,
    RangeValue const& desired_range_value,
    Tolerance const& tolerance,
    Index& interval,
    DomainValue& domain_value)
{
  auto const interval_status = find_tabulated_interval(
      number_of_points, range_value_from_point, desired_range_value, interval);
  if (interval_status != search_errc::success) {
    return interval_status;
  }
  // the bracketing interval is known: refine the answer inside it
  auto const interval_state_functor = state_functor_from_interval(interval);
  return invert_differentiable_function(
      interval_state_functor,
      range_value_from_state,
      derivative_value_from_state,
      desired_range_value,
      tolerance,
      domain_value_from_point(interval),
      domain_value_from_point(interval + Index(1)),
      domain_value);
}
/* Adapts an iterator so that it can be called like a function:
 * calling the adapter with an offset i yields iterator[i].
 */
template <class Iterator>
class iterator_as_functor {
 private:
  using traits_type = std::iterator_traits<Iterator>;

 public:
  using difference_type = typename traits_type::difference_type;
  using reference = typename traits_type::reference;

  /* Stores a copy of the iterator that later calls will index into. */
  P3A_HOST_DEVICE P3A_ALWAYS_INLINE inline constexpr
  iterator_as_functor(Iterator iterator_arg)
    : m_base(iterator_arg)
  {
  }

  /* Returns the element at offset i from the stored iterator. */
  P3A_HOST_DEVICE P3A_ALWAYS_INLINE inline constexpr
  reference operator()(difference_type i) const
  {
    return m_base[i];
  }

 private:
  Iterator m_base;
};
/* Inverts a continuous function by bisection, without using derivatives.
 *
 * Searches between minimum_domain_value and maximum_domain_value for a domain
 * value whose function value satisfies
 * are_close(function value, desired_range_value, tolerance).
 * The function may be increasing or decreasing across the bracket; bisection
 * presumes that it crosses the desired value somewhere inside it.
 *
 * Parameters:
 *   function              callable: domain value -> range value
 *   desired_range_value   the function value to search for
 *   domain_value          output: starts at minimum_domain_value and holds the
 *                         most recent midpoint when the function returns
 *   minimum_domain_value,
 *   maximum_domain_value  ends of the search bracket; the local copies are
 *                         narrowed after every iteration
 *   tolerance             forwarded to are_close(), which decides convergence
 *   maximum_iterations    maximum number of bisection steps
 *
 * Returns desired_value_above_maximum / desired_value_below_minimum if the
 * desired value lies outside the function values at the two ends of the bracket,
 * success on convergence, and exceeded_maximum_iterations otherwise.
 */
template <
  class DomainValue,
  class RangeValue,
  class Tolerance,
  class Function>
[[nodiscard]] P3A_HOST_DEVICE inline
search_errc invert_function(
    Function const& function,
    RangeValue const& desired_range_value,
    DomainValue& domain_value,
    DomainValue minimum_domain_value,
    DomainValue maximum_domain_value,
    Tolerance const& tolerance,
    int maximum_iterations)
{
  auto range_value_at_maximum_domain_value = function(maximum_domain_value);
  auto range_value_at_minimum_domain_value = function(minimum_domain_value);
  domain_value = minimum_domain_value;
  // The function may be decreasing, in which case the value at the maximum
  // domain value is the smaller one. Order the two end values before checking
  // the desired value against them, so that decreasing functions are not
  // rejected (the bisection loop below already supports them).
  auto const function_is_decreasing =
    range_value_at_maximum_domain_value < range_value_at_minimum_domain_value;
  auto const lowest_range_value = condition(function_is_decreasing,
      range_value_at_maximum_domain_value, range_value_at_minimum_domain_value);
  auto const highest_range_value = condition(function_is_decreasing,
      range_value_at_minimum_domain_value, range_value_at_maximum_domain_value);
  if (desired_range_value > highest_range_value) {
    return search_errc::desired_value_above_maximum;
  }
  if (desired_range_value < lowest_range_value) {
    return search_errc::desired_value_below_minimum;
  }
  auto range_value = range_value_at_minimum_domain_value;
  for (int iteration = 0; iteration < maximum_iterations; ++iteration) {
    if (are_close(range_value, desired_range_value, tolerance)) return search_errc::success;
    auto const next_domain_value_bisection =
      minimum_domain_value + (maximum_domain_value - minimum_domain_value) / 2;
    domain_value = next_domain_value_bisection;
    range_value = function(domain_value);
    auto const is_new_minimum =
      // this is a logical XOR operation, designed to flip the logic if the function
      // is decreasing rather than increasing
      (!(range_value < desired_range_value)) !=
      (!(range_value_at_maximum_domain_value < range_value_at_minimum_domain_value));
    // the midpoint replaces exactly one end of the bracket
    minimum_domain_value = condition(is_new_minimum, domain_value, minimum_domain_value);
    maximum_domain_value = condition(is_new_minimum, maximum_domain_value, domain_value);
    range_value_at_minimum_domain_value = condition(is_new_minimum,
        range_value, range_value_at_minimum_domain_value);
    range_value_at_maximum_domain_value = condition(is_new_minimum,
        range_value_at_maximum_domain_value, range_value);
  }
  // The loop tests for convergence only at the top of each pass, so the midpoint
  // produced by the final pass has been evaluated but not yet tested.
  // Test it here so that a converged result is not reported as a failure.
  if (are_close(range_value, desired_range_value, tolerance)) return search_errc::success;
  return search_errc::exceeded_maximum_iterations;
}
}