forked from secretflow/spu
-
Notifications
You must be signed in to change notification settings - Fork 0
/
permute.cc
99 lines (85 loc) · 2.95 KB
/
permute.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
// Copyright 2023 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "libspu/mpc/utils/permute.h"

#include <algorithm>
#include <numeric>
#include <random>

#include "libspu/core/ndarray_ref.h"
#include "libspu/core/type_util.h"
namespace spu::mpc {
// Converts a ring-typed array into a permutation vector.
//
// The element at flat index i of `x` is cast to int64_t and stored at index i
// of the result, so the result has x.numel() entries. Fails via SPU_ENFORCE
// when the element type of `x` is not Ring2k.
PermVector ring2pv(const NdArrayRef& x) {
  SPU_ENFORCE(x.eltype().isa<Ring2k>(), "must be ring2k_type, got={}",
              x.eltype());

  const auto count = x.numel();
  PermVector result(count);

  const auto ring_field = x.eltype().as<Ring2k>()->field();
  DISPATCH_ALL_FIELDS(ring_field, "_", [&]() {
    NdArrayView<ring2k_t> src(x);
    // Values wider than 64 bits are narrowed by the cast.
    pforeach(0, count, [&](int64_t pos) {
      result[pos] = static_cast<int64_t>(src[pos]);
    });
  });

  return result;
}
// Scatters `x` according to `pv`: the result y satisfies y[pv[i]] = x[i].
//
// Requirements (enforced):
//   - `x` is a 1-d tensor;
//   - `pv` has exactly one entry per element of `x`.
// The entries of `pv` are used directly as destination indices and are NOT
// range-checked here, so the caller must pass indices inside [0, x.numel()).
// Positions of the result that no entry of `pv` points at are never written.
//
// Returns a new array with the same element type and shape as `x`.
NdArrayRef applyInvPerm(const NdArrayRef& x, absl::Span<const int64_t> pv) {
  SPU_ENFORCE_EQ(x.shape().ndim(), 1U, "x should be 1-d tensor");
  // The loop below reads pv[i] for every i < numel; a shorter span would be
  // read out of bounds, so reject a length mismatch up front.
  SPU_ENFORCE_EQ(static_cast<int64_t>(pv.size()), x.numel(),
                 "pv should have one entry per element of x");

  NdArrayRef y(x.eltype(), x.shape());
  const auto field = x.eltype().as<Ring2k>()->field();
  DISPATCH_ALL_FIELDS(field, kPermModule, [&]() {
    NdArrayView<ring2k_t> _x(x);
    NdArrayView<ring2k_t> _y(y);
    for (int64_t i = 0; i < y.numel(); i++) {
      _y[pv[i]] = _x[i];
    }
  });
  return y;
}
// Gathers `x` according to `pv`: the result y satisfies y[i] = x[pv[i]].
//
// Requirements (enforced):
//   - `x` is a 1-d tensor;
//   - `pv` has exactly one entry per element of `x`.
// The entries of `pv` are used directly as source indices and are NOT
// range-checked here, so the caller must pass indices inside [0, x.numel()).
//
// Returns a new array with the same element type and shape as `x`.
NdArrayRef applyPerm(const NdArrayRef& x, absl::Span<const int64_t> pv) {
  SPU_ENFORCE_EQ(x.shape().ndim(), 1U, "x should be 1-d tensor");
  // The loop below reads pv[i] for every i < numel; a shorter span would be
  // read out of bounds, so reject a length mismatch up front.
  SPU_ENFORCE_EQ(static_cast<int64_t>(pv.size()), x.numel(),
                 "pv should have one entry per element of x");

  NdArrayRef y(x.eltype(), x.shape());
  const auto field = x.eltype().as<Ring2k>()->field();
  DISPATCH_ALL_FIELDS(field, kPermModule, [&]() {
    NdArrayView<ring2k_t> _x(x);
    NdArrayView<ring2k_t> _y(y);
    for (int64_t i = 0; i < y.numel(); i++) {
      _y[i] = _x[pv[i]];
    }
  });
  return y;
}
// Generates a pseudo-random permutation of the indices 0 .. size-1.
//
// The result starts as the identity sequence and is shuffled with a
// std::mt19937 engine seeded from `seed`, so one seed always yields the same
// permutation within a given build.
//
// NOTE: std::mt19937 is a 32-bit engine, so only the low 32 bits of `seed`
// take part in seeding. The exact permutation produced by std::shuffle is
// also not specified to match across standard-library implementations.
PermVector genRandomPerm(size_t size, uint64_t seed) {
  PermVector perm(size);
  // Start the sequence from a 64-bit zero: std::iota increments a counter of
  // the same type as its initial value, and a plain `0` (int) would overflow
  // once `size` exceeds INT_MAX.
  std::iota(perm.begin(), perm.end(), int64_t{0});
  // TODO: change PRNG to CSPRNG
  std::mt19937 rng(seed);
  std::shuffle(perm.begin(), perm.end(), rng);
  return perm;
}
// Builds the inverse of the permutation `pv`.
//
// For every position i of `pv`, the result holds i at index pv[i]; the result
// has pv.size() entries. The entries of `pv` are used directly as indices and
// are not range-checked, so they must lie inside [0, pv.size()).
PermVector genInversePerm(absl::Span<const int64_t> pv) {
  PermVector inverse(pv.size());

  int64_t position = 0;
  for (const int64_t target : pv) {
    inverse[target] = position;
    ++position;
  }

  return inverse;
}
// Computes the permutation that stably sorts `x` in ascending order.
//
// `x` must be a 1-d tensor (enforced). Its elements are viewed through the
// signed counterpart of the ring type, so comparison is signed. The result
// holds one index per element of `x`: entry k is the index in `x` of the k-th
// smallest element, with equal elements kept in their original order.
PermVector genPermBySort(const NdArrayRef& x) {
  SPU_ENFORCE_EQ(x.shape().ndim(), 1U, "x should be 1-d tensor");
  PermVector perm(x.shape()[0]);
  // Start the sequence from a 64-bit zero: std::iota increments a counter of
  // the same type as its initial value, and a plain `0` (int) would overflow
  // once the tensor has more than INT_MAX elements.
  std::iota(perm.begin(), perm.end(), int64_t{0});
  const auto field = x.eltype().as<Ring2k>()->field();
  DISPATCH_ALL_FIELDS(field, kPermModule, [&]() {
    // Reinterpret the ring elements as signed so that ordering is signed.
    using T = std::make_signed_t<ring2k_t>;
    NdArrayView<T> _x(x);
    auto cmp = [&_x](int64_t a, int64_t b) { return _x[a] < _x[b]; };
    // stable_sort keeps the original relative order of equal elements.
    std::stable_sort(perm.begin(), perm.end(), cmp);
  });
  return perm;
}
} // namespace spu::mpc