forked from secretflow/spu
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fxp_approx.cc
738 lines (600 loc) · 24.6 KB
/
fxp_approx.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
// Copyright 2021 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "libspu/kernel/hal/fxp_approx.h"
#include <algorithm>
#include <array>
#include <cmath>
#include <future>
#include "libspu/core/trace.h"
#include "libspu/kernel/hal/constants.h"
#include "libspu/kernel/hal/fxp_base.h"
#include "libspu/kernel/hal/fxp_cleartext.h"
#include "libspu/kernel/hal/ring.h"
#include "libspu/kernel/hal/shape_ops.h"
namespace spu::kernel::hal {
namespace detail {
// Pade approximation of log2(x) for x in [0.5, 1]:
//
// p2524(x) = -0.205466671951 * 10
//          + x   * -0.88626599391 * 10
//          + x^2 * 0.610585199015 * 10
//          + x^3 * 0.481147460989 * 10
// q2524(x) = 0.353553425277
//          + x   * 0.454517087629 * 10
//          + x^2 * 0.642784209029 * 10
//          + x^3 * 0.1 * 10
// log2(x) = p2524(x) / q2524(x)
//
// For each cubic, the three non-constant products are accumulated with the
// raw `_mul`/`_add` and `_trunc` is applied once to their sum; the constant
// term is added afterwards. The quotient comes from detail::div_goldschmidt.
Value log2_pade_normalized(SPUContext* ctx, const Value& x) {
  const auto x_sq = f_square(ctx, x);
  const auto x_cu = f_mul(ctx, x_sq, x);

  // Evaluates c0 + c1 * x + c2 * x^2 + c3 * x^3 as described above.
  const auto cubic = [&](float c0, float c1, float c2, float c3) {
    const auto k0 = constant(ctx, c0, x.dtype(), x.shape());
    const auto k1 = constant(ctx, c1, x.dtype(), x.shape());
    const auto k2 = constant(ctx, c2, x.dtype(), x.shape());
    const auto k3 = constant(ctx, c3, x.dtype(), x.shape());

    auto acc = _mul(ctx, x, k1);
    acc = _add(ctx, acc, _mul(ctx, x_sq, k2));
    acc = _add(ctx, acc, _mul(ctx, x_cu, k3));
    Value sum = _add(ctx, _trunc(ctx, acc), k0).setDtype(x.dtype());
    return sum;
  };

  const Value numerator =
      cubic(-0.205466671951F * 10, -0.88626599391F * 10, 0.610585199015F * 10,
            0.481147460989F * 10);
  const Value denominator = cubic(0.353553425277F, 0.454517087629F * 10,
                                  0.642784209029F * 10, 0.1F * 10);

  return detail::div_goldschmidt(ctx, numerator, denominator);
}
// log2(x) through normalization plus the Pade approximation above.
//
// Refer to
// Chapter 5 Exponentiation and Logarithms
// Benchmarking Privacy Preserving Scientific Operations
// https://www.esat.kuleuven.be/cosic/publications/article-3013.pdf
//
// The input is multiplied by a power-of-two `scale` so that
// x_norm = x * scale lies in [0.5, 1.0) (as stated by the reference comments;
// this relies on the semantics of highestOneBit/_bitrev, which live in other
// files). Then
//   log2(x) = log2(x_norm) - log2(scale)
//           = log2(x_norm) + (k - fxp_bits)
// where k is the popcount of the prefix-or of x.
Value log2_pade(SPUContext* ctx, const Value& x) {
  SPU_TRACE_HAL_DISP(ctx, x);

  const size_t field_bits = SizeOf(ctx->config().field()) * 8;
  const size_t fxp_bits = ctx->getFxpBits();

  auto k = _popcount(ctx, _prefix_or(ctx, x), field_bits);

  // Mirror the highest set bit inside the low 2 * fxp_bits bits to obtain the
  // normalization factor.
  auto top_bit = detail::highestOneBit(ctx, x);
  auto scale = _bitrev(ctx, top_bit, 0, 2 * fxp_bits).setDtype(x.dtype());
  detail::hintNumberOfBits(scale, 2 * fxp_bits);
  auto x_norm = f_mul(ctx, x, scale);

  auto log2_norm = log2_pade_normalized(ctx, x_norm);

  // (k - fxp_bits) is an integer; shift it left to the fixed-point scale.
  auto exponent = _lshift(
      ctx, _sub(ctx, k, _constant(ctx, fxp_bits, x.shape())), fxp_bits);

  return _add(ctx, log2_norm, exponent).setDtype(x.dtype());
}
// Natural logarithm via modified Householder iterations.
//
// See P11, A.2.4 Logarithm and Exponent,
// https://lvdmaaten.github.io/publications/papers/crypten.pdf
// https://github.com/facebookresearch/CrypTen/blob/master/crypten/common/functions/approximations.py#L55-L104
// According to that reference the approximation is accurate within 2%
// relative error on [0.0001, 250].
//
// The iteration count and the order of the correction polynomial come from
// the runtime config (fxp_log_iters / fxp_log_orders); both must be non-zero.
Value log_householder(SPUContext* ctx, const Value& x) {
  // Initial guess: y = x / 120 - 20 * exp(-(2 * x + 1)) + 3
  const Value x_over_120 =
      f_div(ctx, x, constant(ctx, 120.0, x.dtype(), x.shape()));
  const Value two_x = f_mul(ctx, x, constant(ctx, 2.0, x.dtype(), x.shape()));
  const Value two_x_plus_1 =
      f_add(ctx, two_x, constant(ctx, 1.0, x.dtype(), x.shape()));
  const Value decay = f_exp(ctx, f_negate(ctx, two_x_plus_1));
  const Value damping =
      f_mul(ctx, decay, constant(ctx, 20.0, x.dtype(), x.shape()));
  Value y = f_add(ctx, f_sub(ctx, x_over_120, damping),
                  constant(ctx, 3.0, x.dtype(), x.shape()));

  // Correction polynomial coefficients: 1, 1/2, 1/3, ...
  const size_t fxp_log_orders = ctx->config().fxp_log_orders();
  SPU_ENFORCE(fxp_log_orders != 0, "fxp_log_orders should not be {}",
              fxp_log_orders);
  std::vector<float> coeffs(fxp_log_orders);
  for (size_t k = 0; k < coeffs.size(); ++k) {
    coeffs[k] = static_cast<float>(1.0 / (1.0 + k));
  }

  const size_t num_iters = ctx->config().fxp_log_iters();
  SPU_ENFORCE(num_iters != 0, "fxp_log_iters should not be {}", num_iters);
  for (size_t remaining = num_iters; remaining > 0; --remaining) {
    // h = 1 - x * exp(-y)
    const Value x_exp_neg_y = f_mul(ctx, x, f_exp(ctx, f_negate(ctx, y)));
    const Value h =
        f_sub(ctx, constant(ctx, 1.0, x.dtype(), x.shape()), x_exp_neg_y);
    // y <- y - polynomial(h)
    y = f_sub(ctx, y, detail::polynomial(ctx, h, coeffs));
  }

  return y;
}
// exp(x) via the limit exp(x) = (1 + x / n) ^ n for large n.
// see https://lvdmaaten.github.io/publications/papers/crypten.pdf
//
// Here n = 2^fxp_exp_iters: x is truncated by fxp_exp_iters bits, 1 is added,
// and the result is squared fxp_exp_iters times. fxp_exp_iters must be
// non-zero.
Value exp_taylor(SPUContext* ctx, const Value& x) {
  const size_t fxp_exp_iters = ctx->config().fxp_exp_iters();
  SPU_ENFORCE(fxp_exp_iters != 0, "fxp_exp_iters should not be {}",
              fxp_exp_iters);

  // base = 1 + x / 2^fxp_exp_iters
  const auto x_over_n = _trunc(ctx, x, fxp_exp_iters).setDtype(x.dtype());
  const auto one = constant(ctx, 1.0F, x.dtype(), x.shape());
  Value acc = f_add(ctx, x_over_n, one);

  // acc <- acc^2, repeated fxp_exp_iters times.
  size_t squarings_left = fxp_exp_iters;
  while (squarings_left-- > 0) {
    acc = f_square(ctx, acc);
  }

  return acc;
}
namespace {
// Polynomial approximation of exp2(x) for x in [0, 1]:
//
// p1015(x) = 0.100000007744302 * 10
//          + x   * 0.693147180426163
//          + x^2 * 0.240226510710170
//          + x^3 * 0.555040686204663 / 10
//          + x^4 * 0.961834122588046 / 100
//          + x^5 * 0.133273035928143 / 100
//
// The powers x^2..x^5 are built with f_mul first; the five weighted terms are
// then summed with raw `_mul`/`_add`, `_trunc` is applied once to the sum and
// the constant term is added last.
Value exp2_pade_normalized(SPUContext* ctx, const Value& x) {
  constexpr size_t kDegree = 5;
  // Coefficients of x^1 .. x^5.
  const std::array<float, kDegree> weights = {
      0.693147180426163F, 0.240226510710170F, 0.555040686204663F / 10,
      0.961834122588046F / 100, 0.133273035928143F / 100};

  // powers[i] holds x^(i + 1).
  std::vector<Value> powers;
  powers.reserve(kDegree);
  powers.push_back(x);
  for (size_t i = 1; i < kDegree; ++i) {
    powers.push_back(f_mul(ctx, x, powers.back()));
  }

  const auto c0 = constant(ctx, 0.100000007744302F * 10, x.dtype(), x.shape());

  auto acc =
      _mul(ctx, powers[0], constant(ctx, weights[0], x.dtype(), x.shape()));
  for (size_t i = 1; i < kDegree; ++i) {
    const auto w = constant(ctx, weights[i], x.dtype(), x.shape());
    acc = _add(ctx, acc, _mul(ctx, powers[i], w));
  }

  return _add(ctx, _trunc(ctx, acc), c0).setDtype(x.dtype());
}
} // namespace
// exp2(x) for a fixed-point x, built from exp2 of the fractional part and a
// product of per-bit powers of two for the integer part.
//
// Refer to
// Chapter 5 Exponentiation and Logarithms
// Benchmarking Privacy Preserving Scientific Operations
// https://www.esat.kuleuven.be/cosic/publications/article-3013.pdf
//
// NOTE(junfeng): The valid integer bits of x is 5. Otherwise, the output is
// incorrect.
//
// Steps:
//   1. split x into an integer part (x >> fbits) and a fractional part;
//   2. approximate exp2 of the fractional part with exp2_pade_normalized;
//   3. for each of the low `int_bits` bits a_i of the integer part, multiply
//      by (a_i ? 2^(2^i) : 1), written as a_i * K + (1 - a_i);
//   4. for negative x (msb set) use the result divided by 2^(2^int_bits).
Value exp2_pade(SPUContext* ctx, const Value& x) {
  const size_t fbits = ctx->getFxpBits();
  const auto k1 = _constant(ctx, 1U, x.shape());
  // TODO(junfeng): Make int_bits configurable.
  constexpr size_t int_bits = 5;
  // The largest per-bit shift used below is 2^(int_bits - 1); it has to be a
  // valid shift amount for size_t.
  static_assert((size_t{1} << (int_bits - 1)) < sizeof(size_t) * 8,
                "int_bits is too large for the integer type used for K");
  const size_t bit_width = SizeOf(ctx->getField()) * 8;

  const auto x_bshare = _prefer_b(ctx, x);
  const auto x_msb = _rshift(ctx, x_bshare, bit_width - 1);
  auto x_integer = _rshift(ctx, x_bshare, fbits);
  auto x_fraction =
      _sub(ctx, x, _lshift(ctx, x_integer, fbits)).setDtype(x.dtype());
  auto ret = exp2_pade_normalized(ctx, x_fraction);

  for (size_t idx = 0; idx < int_bits; idx++) {
    auto a = _and(ctx, _rshift(ctx, x_integer, idx), k1);
    detail::hintNumberOfBits(a, 1);
    a = _prefer_a(ctx, a);
    // K = 2^(2^idx), with the exponent capped at bit_width - 2. All operands
    // are size_t so std::min deduces a single type on every platform and the
    // shift is not limited to 32 bits.
    const size_t shift = std::min<size_t>(size_t{1} << idx, bit_width - 2);
    const size_t K = size_t{1} << shift;
    ret = _mul(ctx, ret,
               _add(ctx, _mul(ctx, a, _constant(ctx, K, x.shape())),
                    _sub(ctx, k1, a)))
              .setDtype(ret.dtype());
  }

  // If we could ensure the integer bits of x is 5.
  // we have x, -x, -x_hat. x_hat is 2's complement of -x.
  // Then,
  //             x + (x_hat) = 32
  //            (x_hat) - 32 = -x
  //  exp2(x_hat) / exp2(32) = exp(-x)
  //  so exp(-x) = exp2(x_hat) / exp2(32)
  //
  // Truncating by 2^int_bits (= 32) bits is that division by exp2(32); the
  // bit count is computed with an integer shift instead of std::pow.
  auto ret_reciprocal =
      _trunc(ctx, ret, size_t{1} << int_bits).setDtype(ret.dtype());

  // ret + msb * (reciprocal - ret)
  return f_add(
      ctx, ret,
      _mul(ctx, x_msb, f_sub(ctx, ret_reciprocal, ret)).setDtype(ret.dtype()));
}
// exp(x) expressed through exp2: exp(x) = exp2(x * log2(e)).
Value exp_pade(SPUContext* ctx, const Value& x) {
  const auto log2_e =
      constant(ctx, std::log2(std::exp(1.0F)), x.dtype(), x.shape());
  const auto scaled = f_mul(ctx, x, log2_e);
  return f_exp2(ctx, scaled);
}
// Pade approximation of tanh(x), order (5, 5). Refer to
// https://www.wolframalpha.com/input?i=Pade+approximation+tanh%28x%29+order+5%2C5.
//
//   tanh(x) = (x + x^3 / 9.0 + x^5 / 945.0) /
//             (1 + 4 * x^2 / 9.0 + x^4 / 63.0)
//
// Both sides of the fraction are multiplied by 945 so that the remaining
// coefficients are integers:
//
//   tanh(x) = x * (945 + 105 * x^2 + x^4) / (945 + 420 * x^2 + 15 * x^4)
//
// Multiplying by an integer constant uses the raw `_mul`, which saves some
// truncations.
Value tanh_pade(SPUContext* ctx, const Value& x) {
  const auto x_sq = f_square(ctx, x);
  const auto x_qd = f_square(ctx, x_sq);

  // v * factor for an integer factor; the result keeps v's dtype.
  const auto times_int = [&](const Value& v, int factor) {
    Value scaled =
        _mul(ctx, v, constant(ctx, factor, DT_I32, x.shape())).setDtype(v.dtype());
    return scaled;
  };

  const auto base = constant(ctx, 945.0F, x.dtype(), x.shape());
  const Value x_sq_105 = times_int(x_sq, 105);
  const Value x_sq_420 = times_int(x_sq, 420);
  const Value x_qd_15 = times_int(x_qd, 15);

  // x * (945 + 105 * x^2 + x^4)
  const auto upper_poly = f_add(ctx, base, f_add(ctx, x_sq_105, x_qd));
  const auto numerator = f_mul(ctx, x, upper_poly);

  // 945 + 420 * x^2 + 15 * x^4
  const auto denominator = f_add(ctx, base, f_add(ctx, x_sq_420, x_qd_15));

  return f_div(ctx, numerator, denominator);
}
// Builds the first `terms` odd-order Chebyshev polynomials of the first kind,
// T_1(x), T_3(x), ..., T_{2 * terms - 1}(x), and concatenates them along
// axis 0.
//
// Reference:
// https://github.com/facebookresearch/CrypTen/blob/6ef151101668591bcfb2bbf7e7ebd39ab6db0413/crypten/common/functions/approximations.py#L365
// https://en.wikipedia.org/wiki/Chebyshev_polynomials#Recurrence_definition
//
// Chebyshev polynomials of the first kind satisfy
//   T_0(x) = 1
//   T_1(x) = x
//   T_{n+1}(x) = 2x * T_n(x) - T_{n-1}(x)
// Applying the recurrence twice gives a two-step form that skips the
// even-order polynomials:
//   T_{n+2}(x) = (4x^2 - 2) * T_n(x) - T_{n-2}(x)
// with T_3(x) = x * (4x^2 - 3) as the second seed.
//
// ctx:   evaluation context.
// x:     input value; the callers in this file pass a {1, numel} row.
// terms: number of polynomials to produce. Values below 1 are treated as 1,
//        i.e. only T_1(x) = x is returned.
Value compute_chebyshev_polynomials(SPUContext* ctx, const Value& x,
                                    int64_t terms) {
  // With a single term there is nothing to concatenate. Previously T_3 was
  // appended unconditionally, which produced terms + 1 rows for terms == 1.
  if (terms <= 1) {
    return x;
  }

  std::vector<Value> poly;
  poly.reserve(static_cast<size_t>(terms));
  poly.emplace_back(x);

  // y = 4*x^2 - 2
  auto four = constant(ctx, 4, DT_I32, x.shape());
  auto two = constant(ctx, 2.0F, x.dtype(), x.shape());
  auto y =
      f_sub(ctx, _mul(ctx, four, f_square(ctx, x)).setDtype(x.dtype()), two);

  // z = y - 1 = 4*x^2 - 3, so x * z = T_3(x)
  auto one = constant(ctx, 1.0F, x.dtype(), x.shape());
  auto z = f_sub(ctx, y, one);
  poly.emplace_back(f_mul(ctx, x, z));

  for (int64_t idx = 2; idx < terms; ++idx) {
    // next_polynomial = y * polynomials[k - 1] - polynomials[k - 2]
    auto next = f_sub(ctx, f_mul(ctx, y, poly[idx - 1]), poly[idx - 2]);
    poly.emplace_back(std::move(next));
  }

  return concatenate(ctx, poly, 0);
}
// tanh(x) from a Chebyshev series.
//
// The input is flattened to one row, clamped to [-5, 5] and scaled by 0.2
// before the polynomial basis is evaluated; the basis is then combined with
// the coefficient row by a matrix multiplication and reshaped back to the
// shape of x.
Value tanh_chebyshev(SPUContext* ctx, const Value& x) {
  // Cheb coeff, deg = 17, domain = [-5,5]
  static const std::array<float, 9> kCoeffs = {
      1.2514045938932097,   -0.3655987797163166,  0.17253141478140663,
      -0.08943445792774211, 0.047703017901250824, -0.025830290571688078,
      0.014338801903468182, -0.008541730970059077, 0.0061230685785789475};
  const auto num_terms = static_cast<int64_t>(kCoeffs.size());

  // Flatten, then clamp into the fitted domain.
  auto row = reshape(ctx, x, {1, x.numel()});
  const auto lower = constant(ctx, -5.0F, row.dtype(), row.shape());
  const auto upper = constant(ctx, 5.0F, row.dtype(), row.shape());
  row = _clamp(ctx, row, lower, upper).setDtype(x.dtype());

  // Scale by 0.2 = 1 / 5.
  const auto inv_half_width = constant(ctx, 0.2F, x.dtype(), row.shape());
  row = f_mul(ctx, inv_half_width, row);

  const auto basis = compute_chebyshev_polynomials(ctx, row, num_terms);
  const auto weights = constant(ctx, kCoeffs, x.dtype(), {1, num_terms});

  return reshape(ctx, f_mmul(ctx, weights, basis), x.shape());
}
// sin(x) from a Chebyshev series.
//
// The angle is first wrapped with
//   theta - TWO_PI * floor((theta + PI) / TWO_PI)
// which is intended to land in [-pi, pi], then flattened to one row and
// rescaled by 0.25464790894703254 (about 1 / (1.25 * pi)) before the
// polynomial basis is evaluated.
Value sin_chebyshev(SPUContext* ctx, const Value& x) {
  // Cheb coeff, deg = 9, domain = [-1.25*pi, 1.25*pi]
  // use larger domain for accurate output on boundary
  static const std::array<float, 5> kCoeffs = {
      -0.07570787578233389, -0.8532364056408055, 0.2474789050491474,
      -0.02719844932262742, 0.0016750058127101841};
  const auto num_terms = static_cast<int64_t>(kCoeffs.size());
  const auto weights = constant(ctx, kCoeffs, x.dtype(), {1, num_terms});

  const auto pi = constant(ctx, M_PI, x.dtype(), x.shape());
  const auto two_pi = constant(ctx, 2 * M_PI, x.dtype(), x.shape());
  const auto inv_two_pi = constant(ctx, 1 / (2 * M_PI), x.dtype(), x.shape());

  // periods = floor((x + pi) / (2 * pi))
  const auto shifted = f_mul(ctx, f_add(ctx, x, pi), inv_two_pi);
  const auto whole_turns = f_mul(ctx, f_floor(ctx, shifted), two_pi);
  const auto wrapped = f_sub(ctx, x, whole_turns);

  // Flatten to a single row and rescale.
  const auto row = reshape(ctx, wrapped, {1, wrapped.numel()});
  const auto rescale =
      constant(ctx, 0.25464790894703254F, row.dtype(), row.shape());
  const auto scaled_row = f_mul(ctx, rescale, row);

  const auto basis = compute_chebyshev_polynomials(ctx, scaled_row, num_terms);
  const auto series = f_mmul(ctx, weights, basis);

  return reshape(ctx, series, x.shape());
}
// cos(x) through the identity cos(x) = sin(pi/2 - x).
Value cos_chebyshev(SPUContext* ctx, const Value& x) {
  const auto quarter_turn = constant(ctx, M_PI / 2, x.dtype(), x.shape());
  const auto complement = f_sub(ctx, quarter_turn, x);
  return sin_chebyshev(ctx, complement);
}
} // namespace detail
// Fixed-point exp(x).
//
// Public inputs go to f_exp_p. Otherwise the approximation is chosen by the
// runtime config `fxp_exp_mode`: Taylor (also the default) or Pade. An
// unknown mode raises through SPU_THROW.
Value f_exp(SPUContext* ctx, const Value& x) {
  SPU_TRACE_HAL_LEAF(ctx, x);

  SPU_ENFORCE(x.isFxp());

  if (x.isPublic()) {
    return f_exp_p(ctx, x);
  }

  const auto mode = ctx->config().fxp_exp_mode();

  if (mode == RuntimeConfig::EXP_DEFAULT ||
      mode == RuntimeConfig::EXP_TAYLOR) {
    return detail::exp_taylor(ctx, x);
  }

  if (mode == RuntimeConfig::EXP_PADE) {
    // The valid input for exp_pade is [-kInputLimit, kInputLimit].
    // TODO(junfeng): should merge clamp into exp_pade to save msb ops.
    const float kInputLimit = 32 / std::log2(std::exp(1));
    const auto lower = constant(ctx, -kInputLimit, x.dtype(), x.shape());
    const auto upper = constant(ctx, kInputLimit, x.dtype(), x.shape());
    const auto bounded = _clamp(ctx, x, lower, upper).setDtype(x.dtype());
    return detail::exp_pade(ctx, bounded);
  }

  SPU_THROW("unexpected exp approximation method {}", mode);
}
// Fixed-point natural logarithm.
//
// Public inputs go to f_log_p. Otherwise the approximation is chosen by the
// runtime config `fxp_log_mode`: Pade (also the default), computed as
// ln(2) * log2(x), or the Householder iteration for LOG_NEWTON. An unknown
// mode raises through SPU_THROW.
Value f_log(SPUContext* ctx, const Value& x) {
  SPU_TRACE_HAL_LEAF(ctx, x);

  SPU_ENFORCE(x.isFxp());

  if (x.isPublic()) {
    return f_log_p(ctx, x);
  }

  const auto mode = ctx->config().fxp_log_mode();

  if (mode == RuntimeConfig::LOG_DEFAULT || mode == RuntimeConfig::LOG_PADE) {
    // ln(x) = ln(2) * log2(x)
    return f_mul(ctx, constant(ctx, std::log(2.0F), x.dtype(), x.shape()),
                 f_log2(ctx, x));
  }

  if (mode == RuntimeConfig::LOG_NEWTON) {
    return detail::log_householder(ctx, x);
  }

  SPU_THROW("unexpected log approximation method {}", mode);
}
// Fixed-point log(1 + x), computed as f_log(1 + x).
Value f_log1p(SPUContext* ctx, const Value& x) {
  SPU_TRACE_HAL_LEAF(ctx, x);

  SPU_ENFORCE(x.isFxp());

  const auto one = constant(ctx, 1.0F, x.dtype(), x.shape());
  const auto one_plus_x = f_add(ctx, one, x);
  return f_log(ctx, one_plus_x);
}
// Fixed-point log2(x), delegated to detail::log2_pade; the result carries the
// dtype of x.
Value f_log2(SPUContext* ctx, const Value& x) {
  SPU_TRACE_HAL_LEAF(ctx, x);

  SPU_ENFORCE(x.isFxp());

  Value result = detail::log2_pade(ctx, x).setDtype(x.dtype());
  return result;
}
// Fixed-point exp2(x), delegated to detail::exp2_pade.
Value f_exp2(SPUContext* ctx, const Value& x) {
  SPU_TRACE_HAL_LEAF(ctx, x);

  Value result = detail::exp2_pade(ctx, x);
  return result;
}
// Fixed-point tanh(x).
//
// By default the Chebyshev approximation is used. Building with
// TANH_USE_PADE defined selects the Pade approximation on the input clamped
// to [-3, 3] instead.
Value f_tanh(SPUContext* ctx, const Value& x) {
  SPU_TRACE_HAL_LEAF(ctx, x);

#ifndef TANH_USE_PADE
  return detail::tanh_chebyshev(ctx, x);
#else
  // The alternative branch must be `#else`: a bare `#elif` has no condition
  // and does not compile once TANH_USE_PADE is defined.
  //
  // For tanh inputs beyond [-3, 3], result is infinitely close to -1, 1
  // pade approximation has a relative ok result between [-3, 3], so clamp
  // inputs to this range.
  auto normalized_x = _clamp(ctx, x, constant(ctx, -3.F, x.dtype(), x.shape()),
                             constant(ctx, 3.F, x.dtype(), x.shape()))
                          .setDtype(x.dtype());

  return detail::tanh_pade(ctx, normalized_x);
#endif
}
// Initial guess for rsqrt on the normalized mantissa.
//
// x is scaled by the bit-reversed z to obtain u (the reference comments state
// u lies in [0.25, 0.5)), then rsqrt(u) is approximated by a polynomial:
//   default:        26.02942339 * u^4 - 49.86605845 * u^3 + 38.4714796 * u^2
//                   - 15.47994394 * u + 4.14285016
//   lower accuracy: 4.7979 * u^2 - 5.9417 * u + 3.1855
// The lower-accuracy variant is selected by the runtime config
// `enable_lower_accuracy_rsqrt`.
static Value rsqrt_init_guess(SPUContext* ctx, const Value& x, const Value& z) {
  SPU_TRACE_HAL_LEAF(ctx, x, z);

  const size_t fxp_bits = ctx->getFxpBits();

  auto scale = _bitrev(ctx, z, 0, 2 * fxp_bits);
  detail::hintNumberOfBits(scale, 2 * fxp_bits);

  auto u = _trunc(ctx, _mul(ctx, x, scale)).setDtype(x.dtype());

  if (ctx->config().enable_lower_accuracy_rsqrt()) {
    auto coeffs = {-5.9417F, 4.7979F};
    return f_add(ctx, detail::polynomial(ctx, u, coeffs),
                 constant(ctx, 3.1855F, x.dtype(), x.shape()));
  }

  auto coeffs = {-15.47994394F, 38.4714796F, -49.86605845F, 26.02942339F};
  return f_add(ctx, detail::polynomial(ctx, u, coeffs),
               constant(ctx, 4.14285016F, x.dtype(), x.shape()));
}
// Computes the power-of-two compensation factor of the reciprocal square
// root (see https://dl.acm.org/doi/10.1145/3411501.3419427).
//
// ctx: evaluation context; supplies the ring width and the number of
//      fixed-point fraction bits.
// x:   the rsqrt input; inside this function it is only traced and used for
//      its shape when building constants.
// z:   the value produced by rsqrt_np2 (per the comments in f_rsqrt,
//      z = 2^(e+f) with e = NP2(x)).
//
// Returns _mux(b, c0, c1) * a_rev using the raw `_mul`; no truncation is
// applied here. f_rsqrt multiplies this with the initial guess and truncates.
static Value rsqrt_comp(SPUContext* ctx, const Value& x, const Value& z) {
SPU_TRACE_HAL_LEAF(ctx, x, z);
// k: ring width in bits, f: number of fixed-point fraction bits.
const size_t k = SizeOf(ctx->getField()) * 8;
const size_t f = ctx->getFxpBits();
// let a = 2^((e+f)/2), that is a[i] = 1 for i = (e+f)/2 else 0
// let b = lsb(e+f)
// NOTE(review): b is computed below as the parity of the even-position bits
// of z. Since z has a single set bit, that parity is 1 exactly when the bit
// sits at an even index, i.e. when e+f is even -- which looks like the
// complement of lsb(e+f). Confirm against the paper.
Value a;
Value b;
{
// presumably _bitdeintl moves the even-indexed bits of z into the low half
// and the odd-indexed bits into the high half; the masks below rely on that
// layout -- TODO confirm against its definition.
auto z_sep = _bitdeintl(ctx, z);
// Mask with the low k/2 bits set.
auto lo_mask =
_constant(ctx, (static_cast<uint128_t>(1) << (k / 2)) - 1, x.shape());
auto z_even = _and(ctx, z_sep, lo_mask);
auto z_odd = _and(ctx, _rshift(ctx, z_sep, k / 2), lo_mask);
// a[i] = z[2*i] ^ z[2*i+1]
a = _xor(ctx, z_odd, z_even);
// b ^= z[2*i]
b = _bit_parity(ctx, z_even, k / 2);
detail::hintNumberOfBits(b, 1);
}
// Reverse a within its low (f / 2) * 2 bits (f rounded down to an even
// number).
auto a_rev = _bitrev(ctx, a, 0, (f / 2) * 2);
detail::hintNumberOfBits(a_rev, (f / 2) * 2);
// do compensation
// Note:
// https://arxiv.org/pdf/2107.00501.pdf
// - the magic number c0 & c1 seems to be wrong.
// - the LSB algorithm is correct and used in this implementation.
//
// The following constant is deduced exactly from:
// https://dl.acm.org/doi/10.1145/3411501.3419427
//
// In each branch one constant is a plain power of two and the other is a
// power of two times sqrt(2); which one carries the sqrt(2) depends on the
// parity of f. The sqrt(2) products are doubles that are converted to the
// integer type taken by _constant.
Value c0;
Value c1;
if (f % 2 == 1) {
c0 = _constant(ctx, 1 << ((f + 3) / 2), x.shape()); // f+e even
c1 = _constant(ctx, (1 << (f / 2 + 1)) * std::sqrt(2),
x.shape()); // f+e odd
} else {
c0 = _constant(ctx, (1 << (f / 2)) * std::sqrt(2),
x.shape()); // f+e even
c1 = _constant(ctx, 1 << (f / 2), x.shape()); // f+e odd
}
// let comp = 2^(-(e-1)/2) = mux(b, c0, c1) * a_rev
// NOTE(review): an earlier version of this comment wrote mux(b, c1, c0); the
// code passes (b, c0, c1). Assuming _mux(pred, a, b) yields a where pred is
// set, c0 ("f+e even") is chosen when b == 1 -- confirm.
return _mul(ctx, _mux(ctx, b, c0, c1), a_rev);
}
// Returns the highest set bit of x shifted left by one position.
// Per the reference comments: let e = NP2(x), z = 2^(e+f).
static Value rsqrt_np2(SPUContext* ctx, const Value& x) {
  SPU_TRACE_HAL_LEAF(ctx, x);

  const auto top_bit = detail::highestOneBit(ctx, x);
  return _lshift(ctx, top_bit, 1);
}
// Fixed-point reciprocal square root, 1 / sqrt(x).
//
// Reference:
//  1. https://dl.acm.org/doi/10.1145/3411501.3419427
// Main idea:
//  1. convert x to u * 2^(e + 1) while u belongs to [0.25, 0.5).
//  2. get a nice approximation for u part.
//  3. get the compensation for 2^(e + 1) part.
//  4. multiply the two parts and get the result.
//
// When the runtime config `experimental_enable_intra_op_par` is set, the
// initial guess runs on a forked context in a separate thread while the
// compensation is computed on the calling thread.
Value f_rsqrt(SPUContext* ctx, const Value& x) {
  SPU_TRACE_HAL_LEAF(ctx, x);

  // let e = NP2(x) , z = 2^(e+f)
  auto z = rsqrt_np2(ctx, x);

  // TODO: we should avoid fork context in hal layer, it will make global
  // scheduling harder and also make profiling harder.
  if (ctx->config().experimental_enable_intra_op_par()) {
    auto sub_ctx = ctx->fork();
    // Request std::launch::async explicitly: with the default policy the
    // implementation is allowed to defer the task until get(), which would
    // serialize the two halves and defeat the purpose of this branch.
    auto r = std::async(std::launch::async, rsqrt_init_guess,
                        dynamic_cast<SPUContext*>(sub_ctx.get()), x, z);
    auto comp = rsqrt_comp(ctx, x, z);
    return _trunc(ctx, _mul(ctx, r.get(), comp)).setDtype(x.dtype());
  } else {
    auto r = rsqrt_init_guess(ctx, x, z);
    auto comp = rsqrt_comp(ctx, x, z);
    return _trunc(ctx, _mul(ctx, r, comp)).setDtype(x.dtype());
  }
}
// Fixed-point sqrt(x) by Goldschmidt iteration.
//
// Reference:
//  1. https://eprint.iacr.org/2012/405.pdf, section 6.1
//  2.
//  https://github.com/tf-encrypted/tf-encrypted/blob/3b0f14d26e900caf12a92a9ea2284ccd4d58e683/tf_encrypted/protocol/aby3/fp.py#L35-L52
//
// The iteration needs an initial approximation of sqrt_inv(x), taken from
// f_rsqrt. g approximates sqrt(x) while h approximates 1 / (2 * sqrt(x));
// only g is returned.
Value f_sqrt(SPUContext* ctx, const Value& x) {
  SPU_TRACE_HAL_LEAF(ctx, x);

  const auto c0 = constant(ctx, 0.5F, x.dtype(), x.shape());
  const auto c1 = constant(ctx, 1.5F, x.dtype(), x.shape());

  Value y0 = f_rsqrt(ctx, x);
  Value g = f_mul(ctx, x, y0);
  Value h = f_mul(ctx, y0, c0);

  // iterations of 1 is enough.
  constexpr int iterations = 1;

  for (int i = 0; i < iterations; i++) {
    const auto r = f_sub(ctx, c1, f_mul(ctx, g, h));
    g = f_mul(ctx, g, r);
    // h only feeds the next round, so skip its update on the last one; this
    // saves one multiplication whose result would never be read.
    if (i + 1 < iterations) {
      h = f_mul(ctx, h, r);
    }
  }

  return g;
}
namespace {
// Exact-form sigmoid: f(x) = 1 / (1 + exp(-x)).
Value sigmoid_real(SPUContext* ctx, const Value& x) {
  const auto one = constant(ctx, 1.0F, x.dtype(), x.shape());
  const auto exp_neg_x = f_exp(ctx, f_negate(ctx, x));
  const auto denominator = f_add(ctx, one, exp_neg_x);
  return f_reciprocal(ctx, denominator);
}
// SigmoidMM1, a linear approximation of sigmoid: f(x) = 0.5 + 0.125 * x.
Value sigmoid_mm1(SPUContext* ctx, const Value& x) {
  const auto intercept = constant(ctx, 0.5F, x.dtype(), x.shape());
  const auto slope = constant(ctx, 0.125F, x.dtype(), x.shape());
  const auto linear_term = f_mul(ctx, slope, x);
  return f_add(ctx, intercept, linear_term);
}
// Three-segment piecewise-linear sigmoid:
//   f(x) = 0.5 + 0.125x  if -4 <= x <= 4
//          1             if       x > 4
//          0             if  -4 > x
// Rounds = Gt + Mux*2 = 4 + Log(K)
Value sigmoid_seg3(SPUContext* ctx, const Value& x) {
  const auto one = constant(ctx, 1.0F, x.dtype(), x.shape());
  const auto zero = constant(ctx, 0.0F, x.dtype(), x.shape());
  const auto linear = sigmoid_mm1(ctx, x);

  const auto hi_bound = constant(ctx, 4.0F, x.dtype(), x.shape());
  const auto lo_bound = constant(ctx, -4.0F, x.dtype(), x.shape());

  // Saturate to 1 above the upper bound.
  const auto above = f_less(ctx, hi_bound, x);
  const auto capped = _mux(ctx, above, one, linear);

  // Saturate to 0 below the lower bound.
  const auto below = f_less(ctx, x, lo_bound);
  return _mux(ctx, below, zero, capped).setDtype(x.dtype());
}
} // namespace
// Fixed-point sigmoid; the variant is chosen by the runtime config
// `sigmoid_mode` (MM1 is also the default). An unknown mode raises through
// SPU_THROW.
Value f_sigmoid(SPUContext* ctx, const Value& x) {
  SPU_TRACE_HAL_DISP(ctx, x);

  SPU_ENFORCE(x.isFxp());

  const auto mode = ctx->config().sigmoid_mode();

  if (mode == RuntimeConfig::SIGMOID_DEFAULT ||
      mode == RuntimeConfig::SIGMOID_MM1) {
    return sigmoid_mm1(ctx, x);
  }
  if (mode == RuntimeConfig::SIGMOID_SEG3) {
    return sigmoid_seg3(ctx, x);
  }
  if (mode == RuntimeConfig::SIGMOID_REAL) {
    return sigmoid_real(ctx, x);
  }

  SPU_THROW("Should not hit");
}
// Fixed-point sin(x): f_sine_p for public inputs, the Chebyshev
// approximation otherwise.
Value f_sine(SPUContext* ctx, const Value& x) {
  SPU_TRACE_HAL_DISP(ctx, x);

  SPU_ENFORCE(x.isFxp());

  return x.isPublic() ? f_sine_p(ctx, x) : detail::sin_chebyshev(ctx, x);
}
// Fixed-point cos(x): f_cosine_p for public inputs, the Chebyshev
// approximation otherwise.
Value f_cosine(SPUContext* ctx, const Value& x) {
  SPU_TRACE_HAL_DISP(ctx, x);

  SPU_ENFORCE(x.isFxp());

  return x.isPublic() ? f_cosine_p(ctx, x) : detail::cos_chebyshev(ctx, x);
}
namespace {
// Evaluates a polynomial in x with Horner's scheme.
//
// ctx:          evaluation context.
// x:            the point of evaluation.
// coefficients: polynomial coefficients ordered from the highest degree down
//               to the constant term.
//
// Returns coefficients[0] * x^(n-1) + ... + coefficients[n-1] with
// n = coefficients.size(). An empty coefficient list denotes the zero
// polynomial and yields a zero constant shaped like x (previously this read
// coefficients[0] out of bounds).
Value EvaluatePolynomial(SPUContext* ctx, const Value& x,
                         absl::Span<const float> coefficients) {
  if (coefficients.empty()) {
    return constant(ctx, 0.0F, x.dtype(), x.shape());
  }

  auto poly = constant(ctx, coefficients[0], x.dtype(), x.shape());

  for (size_t i = 1; i < coefficients.size(); ++i) {
    auto c = constant(ctx, coefficients[i], x.dtype(), x.shape());
    // poly <- poly * x + c
    poly = f_mul(ctx, poly, x);
    poly = f_add(ctx, poly, c);
  }

  return poly;
}
// Approximates erf(x) as
//   1 - 1 / (1 + a1*x + a2*x^2 + a3*x^3 + a4*x^4)^4
// with a1 = 0.278393, a2 = 0.230389, a3 = 0.000972, a4 = 0.078108.
//
// The polynomial is evaluated by EvaluatePolynomial, raised to the fourth
// power by squaring twice, and inverted with
// detail::reciprocal_goldschmidt_positive. f_erf calls this with |x|.
Value ErfImpl(SPUContext* ctx, const Value& x) {
  // Coefficients ordered from x^4 down to the constant term. The table is
  // never modified, so it is const like the coefficient tables of the
  // Chebyshev approximations in this file.
  static const std::array<float, 5> kErfCoefficient{0.078108, 0.000972,
                                                    0.230389, 0.278393, 1.0};

  auto one = constant(ctx, 1.0, x.dtype(), x.shape());

  auto z = EvaluatePolynomial(ctx, x, kErfCoefficient);
  // z <- z^4
  z = f_square(ctx, z);
  z = f_square(ctx, z);
  z = detail::reciprocal_goldschmidt_positive(ctx, z);

  return f_sub(ctx, one, z);
}
} // namespace
// Fixed-point erf(x).
//
// Ref:
// Handbook of Mathematical Functions: with Formulas, Graphs, and Mathematical
// Tables, equation 7.1.27, maximum absolute error <= 5e-4
//
// Public inputs go to f_erf_p. Otherwise erf(|x|) is approximated by ErfImpl
// for |x| < 3 and replaced by 1 elsewhere; the sign of x is restored at the
// end.
Value f_erf(SPUContext* ctx, const Value& x) {
  if (x.isPublic()) {
    return f_erf_p(ctx, x);
  }

  const auto is_negative =
      f_less(ctx, x, constant(ctx, 0.0, x.dtype(), x.shape()));
  const auto magnitude = f_abs(ctx, x);

  const auto in_range =
      f_less(ctx, magnitude, constant(ctx, 3.0, x.dtype(), x.shape()));
  const auto approx = ErfImpl(ctx, magnitude);

  // The result is replaced by 1 for |x| >= 3 because:
  // 1. for large abs_x, reciprocal may overflow
  // 2. error is sufficiently small (< 2.2e-5)
  const auto saturated =
      _mux(ctx, in_range, approx, constant(ctx, 1.0F, x.dtype(), x.shape()))
          .setDtype(x.dtype());

  // erf is odd: negate the result for negative inputs.
  return _mux(ctx, is_negative, f_negate(ctx, saturated), saturated)
      .setDtype(x.dtype());
}
} // namespace spu::kernel::hal