diff --git a/heu/library/algorithms/dj/public_key.h b/heu/library/algorithms/dj/public_key.h index 95cfba34..b9872b41 100644 --- a/heu/library/algorithms/dj/public_key.h +++ b/heu/library/algorithms/dj/public_key.h @@ -54,7 +54,7 @@ class PublicKey : public HeObject { private: MPInt n_, hs_, pmod_, cmod_, bound_; - uint32_t s_; + uint32_t s_ = 0; // Updated by Ant Group struct LUT { std::unique_ptr m_space; // m-space for mod n^(s+1) diff --git a/heu/library/algorithms/dj/secret_key.h b/heu/library/algorithms/dj/secret_key.h index 178fe0ac..e02b419e 100644 --- a/heu/library/algorithms/dj/secret_key.h +++ b/heu/library/algorithms/dj/secret_key.h @@ -41,7 +41,7 @@ class SecretKey : public HeObject { MPInt2 n_; // (p, q) MPInt lambda_, mu_; // λ, μ MPInt pmod_; // n^s - uint32_t s_; + uint32_t s_ = 0; // Updated by Ant Group MPInt pp_; // p^s * (p^(-s) mod q^s), used for CRT MPInt2 inv_pq_; // ( q^(-1) mod p^s, p^(-1) mod q^s ) diff --git a/heu/library/algorithms/paillier_clustar_fpga/fpga_engine/driver/driver.cc b/heu/library/algorithms/paillier_clustar_fpga/fpga_engine/driver/driver.cc index 40765228..2527c80b 100644 --- a/heu/library/algorithms/paillier_clustar_fpga/fpga_engine/driver/driver.cc +++ b/heu/library/algorithms/paillier_clustar_fpga/fpga_engine/driver/driver.cc @@ -173,7 +173,7 @@ uint32_t get_error_num(uint32_t error_id) { */ int open_dev(int *userfd, int *h2cfd, int *c2hfd, uint8_t dev_num) { char *xdma_num = NULL; - int num = 255; + uint8_t num = 255; // Updated by Ant Group const char *xdma_dev = "XDMA_DEV"; xdma_num = getenv(xdma_dev); @@ -183,7 +183,7 @@ int open_dev(int *userfd, int *h2cfd, int *c2hfd, uint8_t dev_num) { int h2c_fd; int c2h_fd; - int fpga_id = (num <= dev_num ? num : dev_num); + uint8_t fpga_id = (num <= dev_num ? num : dev_num); // Updated by Ant Group fpga_id = fpga_id < MAX_NUM_OF_DEV ? fpga_id : 0; user_fd = open(DEV_USER[fpga_id], O_RDWR | O_SYNC); if (user_fd == -1) { diff --git a/heu/library/algorithms/paillier_clustar_fpga/fpga_engine/driver/multi_driver.cc b/heu/library/algorithms/paillier_clustar_fpga/fpga_engine/driver/multi_driver.cc index c54ebe9b..3d8be5e4 100644 --- a/heu/library/algorithms/paillier_clustar_fpga/fpga_engine/driver/multi_driver.cc +++ b/heu/library/algorithms/paillier_clustar_fpga/fpga_engine/driver/multi_driver.cc @@ -75,6 +75,9 @@ int fpga_fedai_operator_accl_split_task(fpga_config *cfg, char *para, sizeof(fpga_config)); // copy original cfg into cfg_last card_num = fpga_dev_number_get(); // number of available cards + if (card_num == 0) { + return -1; // Updated by Ant Group + } data1_length = getdatalength(cfg->data1_bitlen); // get the real bitlength of data1 @@ -136,8 +139,9 @@ int fpga_fedai_operator_accl_split_task(fpga_config *cfg, char *para, for (i = 0; i < card_num; i++) { fut_vec[i].get(); if (th_para[i].error != 0) { + int error = th_para[i].error; free(th_para); - return th_para[i].error; + return error; } } @@ -228,10 +232,11 @@ int fpga_fedai_operator_accl_split_task(fpga_config *cfg, char *para, for (i = 0; i < card_num; i++) { fut_vec[i].get(); if (th_para[i].error != 0) { + int error = th_para[i].error; free(cfg_last); free(cfg_first); free(th_para); - return th_para[i].error; + return error; } } @@ -275,6 +280,9 @@ int fpga_fedai_operator_accl_split_task(fpga_config *cfg, char *para, memcpy(&row_num, para, 32 / 8); total_row = cfg->batch_size / row_num; // Calculate total number of rows in this matmul. + if (total_row == 0) { + return -1; // Updated by Ant Group + } // card_num = 8; //debugcxd if (total_row >= card_num) { row_per_card = @@ -336,8 +344,9 @@ int fpga_fedai_operator_accl_split_task(fpga_config *cfg, char *para, for (i = 0; i < card_num; i++) { fut_vec[i].get(); if (th_para[i].error != 0) { + int error = th_para[i].error; free(th_para); - return th_para[i].error; + return error; } } @@ -422,11 +431,12 @@ int fpga_fedai_operator_accl_split_task(fpga_config *cfg, char *para, for (i = 0; i < used_card_num; i++) { fut_vec[i].get(); if (th_para[i].error != 0) { + int error = th_para[i].error; free(para_last); free(para_first); free(result_tmp); free(th_para); - return th_para[i].error; + return error; } } @@ -539,10 +549,11 @@ int fpga_fedai_operator_accl_split_task(fpga_config *cfg, char *para, for (i = 0; i < card_num; i++) { fut_vec[i].get(); if (th_para[i].error != 0) { + int error = th_para[i].error; free(cfg_last); free(cfg_first); free(th_para); - return th_para[i].error; + return error; } } diff --git a/heu/library/phe/encoding/batch_encoder.h b/heu/library/phe/encoding/batch_encoder.h index 53279bd1..226ce728 100644 --- a/heu/library/phe/encoding/batch_encoder.h +++ b/heu/library/phe/encoding/batch_encoder.h @@ -134,9 +134,9 @@ class BatchEncoder : public algorithms::HeObject { return pt; } - SchemaType schema_; - int64_t scale_; - size_t padding_bits_; + SchemaType schema_{}; + int64_t scale_ = 0; + size_t padding_bits_ = 0; }; } // namespace heu::lib::phe diff --git a/heu/library/phe/encoding/plain_encoder.h b/heu/library/phe/encoding/plain_encoder.h index 67938935..e782439a 100644 --- a/heu/library/phe/encoding/plain_encoder.h +++ b/heu/library/phe/encoding/plain_encoder.h @@ -81,8 +81,8 @@ class PlainEncoder : public algorithms::HeObject { private: explicit PlainEncoder(yacl::ByteContainerView in) { Deserialize(in); } - SchemaType schema_; - int64_t scale_; + SchemaType schema_{}; + int64_t scale_ = 0; }; } // namespace heu::lib::phe diff --git a/heu/library/spi/README.md b/heu/library/spi/README.md new file mode 100644 index 00000000..f434de53 --- /dev/null +++ b/heu/library/spi/README.md @@ -0,0 +1,115 @@ +# SPI 简介 + +## 什么是 SPI + +SPI(Service Provider Interface) 是 HEU 面向下层软硬件设置的一层功能扩展接口。通过 SPI,上层应用可以与下层的具体实现解耦。 + +下图展示了 SPI 在局部架构中的位置 +``` +┌─────────────────────────────────┐ +│ HE Based APPs │ +├─────────────────────────────────┤ +│ SPI │ +└─────────────────────────────────┘ +┌───────────┐ ┌───────────┐ +│ │ │ │ +│ LIB 1 │ │ LIB 2 │ ...... +│ │ │ │ +└───────────┘ └───────────┘ +``` + +### 名词解释 + +* SPI:Service Provider Interface,服务提供接口 +* LIB:功能包。SPI 中每一种具体实现称为一个 Lib,同一个 SPI 下挂的 Lib 功能一般类似,当上层 App 发起调用时,SPI 可自动选择最合适的 Lib + +## SPI 的使用场景 + +SPI 是围绕 HEU 高性能、可扩展的目标打造的一套技术框架,初期只在 HEU 中有应用,目前随着第三代 SPI 架构的推出,SPI 设计基本基本成熟,并在隐语其他模块中也有应用。未来,构建在 SPI 之上的隐语将会更加开放、灵活。 + +概率来说,SPI 适用以下这些场景: + +1. 硬件加速场景。硬件加速器可以封装成一个 Lib 接入 SPI,一经接入立即对上层所有算法适配 +2. 集成第三方库的场景。出于某些原因隐语需要依赖第三方库,但接入的同时又需要保持较低的耦合性,可以使用 SPI 方案 +3. 开源协议污染场景。对于一些带有 GPL 等传染性协议的三方库,使用 SPI 运行时加载库文件的能力可以有效屏蔽协议不兼容问题(规划中的能力) +4. 存在多个同类型库的场景。例如 HEU 实现了多种 PHE 算法,SPI 可以更好的组织、管理这些算法 + + +使用 SPI 的优势有: + +1. 解耦功能实现者和调用者,SPI 为底层 Lib 抽象出一层统一的接口,上层 APP 可以无缝在 Lib 之间切换 +2. 降低 Lib 接入难度。对于传统架构,硬件加速器开发者或者 Lib 开发者想要把产品接入隐语非常困难,需要有从上到下的理解,使用 SPI 模式后底层开发者只需要对接 SPI,无需关注上层应用 +3. 更好性能。SPI 支持变量堆上、栈上传递、自动处理标量/向量化调用,尽可能避免接口层的性能损耗,最大化发挥 Lib 性能 +4. 更快编译。Lib 代码修改后,上层依赖代码无需重编译,仅需重新 Link,在一定程度上提升研发效率 + + + +## HE SPI + +HE SPI 是基于第三代 SPI 技术专门为 HE 方向设计的一套接口,其特点是同时支持 PHE、Wordwise-FHE/LHE,Bitwise-FHE/LHE,支持 Multi-level 接入。 + +所谓 Multi-level SPI,是指 HEU 提供了两个接入层: + +* HE SPI,HE 算法整体接入层,如果第三方软/硬件实现了完整的 HE 算法,适合在此层接入 +* Polynomial+NTT SPI,一个较为底层的接入口,开放了多项式环计算和 NTT 转换的加速器接入口 + +``` + ┌──────────────────────────────────────────────────────────────┐ + │ HE Based APPs │ + ├──────────────────────────────────────────────────────────────┤ + │ HE SPI (PHE + LHE + FHE) │ + ├──────────────────────────────┬─┬──────────────┬──────────────┤ + │ HEU │ │ Third-party │ Third-party │ + │ Built-in │ │Software based│Hardware based│ + │ RNS Based HE Algorithms │ │ Libs │ Libs │ + ├──────────────────────────────┤ └──────────────┴──────────────┘ + │ Polynomial SPI (uint64) │ + └──────────────────────────────┘ + ┌─────────────┐ + │ NTT SPI │ + └─────────────┘ +``` + +FHE 算法非常复杂,如果直接接入 HE SPI 有难度,则可以考虑 Polynomial+NTT SPI,后者只要求实现基本的多项式环运算和 NTT 运算即可接入 HEU,难度降低很多,并且亦可对整体起到不错的加速效果。 + + +## SPI 的工作方式 + +SPI 并不只是一层接口,上文提到的每一种 SPI 其实都是一个“模块”,每一个 SPI 模块主要由以下几部分组成: + +1. SPI interface for user:SPI 对用户侧的接口,也就是用户看到的接口 +2. Multi-level sketches:对用户侧接口的多级预实现,对于一些较为简单,功能固定的接口 Sketch 可提供一个默认实现,这样每个 Lib 就不需要重复实现,简化 Lib 接入负担 +3. SPI interface for lib:Lib 侧的接口,这一层接口不固定,不同 Sketch 对用户侧接口的实现方式不一样,因此 Lib 侧看到的接口也不一样,取决于 Lib 从哪一个 Sketch 继承。Lib 接入时只需选择一个最合适的 Sketch,实现该 Sketch 要求的接口即可。 +4. SPI Factory:SPI 工厂用于创建 Lib 实例。Lib 运行需要的初始化参数、配置均由 SPI Factory注入 + +``` + S P I M o d u l e + ┌────────────────────────────────────────┐ + │ SPI interface for user │ + ├────────────────────────────────────────┤ +┌────────────┐ │ │ +│ │ │ Composed of multi-level sketches │ +│ S P I │ │ │ +│ Factory │ ├────────────────────────────────────────┤ +│ │ │ SPI interface for lib │ +└─────┬──────┘ └────────────────────────────────────────┘ + │Create ┌───────────┐ ┌───────────┐ + │Instance │ │ │ │ + └──────────►│ LIB 1 │ │ LIB 2 │ ...... + │ │ │ │ + └───────────┘ └───────────┘ +``` + +SPI 的代码组织和使用示意: + +```c++ +class HeSpi; // SPI 用户侧接口 + +class HeSpiVectorSketch : public HeSpi; // Sketch: 对用户接口的预实现 + +class HeGpuLib : public HeSpiVectorSketch; // 第三方 Lib,从 Sketch 继承 + +// 用户创建 Lib 实例 +std::unique_ptr hekit = HeSpiFactory::Instance()->Create(/* FHE Args...*/); +heKit->XXXX(); // 使用 +``` diff --git a/heu/library/spi/he/base.h b/heu/library/spi/he/base.h new file mode 100644 index 00000000..0b5fe15d --- /dev/null +++ b/heu/library/spi/he/base.h @@ -0,0 +1,36 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yacl/utils/spi/argument/arg_set.h" +#include "yacl/utils/spi/item.h" + +namespace heu::lib::spi { + +/* + * Item 本质上在 C++ 之上构建了一套无类型系统,类似于 + * Python,任何类型都可以转换成 Item, 反之 Item 也可以变成任何实际类型。 + * + * 一些缩写约定: + * PT => Plaintext + * CT => Ciphertext + * PTs => Plaintext Array + * CTs => Ciphertext Array + */ +using Item = yacl::Item; + +using SpiArgs = yacl::SpiArgs; + +} // namespace heu::lib::spi diff --git a/heu/library/spi/he/binary_evaluator.h b/heu/library/spi/he/binary_evaluator.h new file mode 100644 index 00000000..8fceba65 --- /dev/null +++ b/heu/library/spi/he/binary_evaluator.h @@ -0,0 +1,56 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "heu/library/spi/he/base.h" + +namespace heu::lib::spi { + +// Short int operations based on TFHE +// 基于 Bitwise FHE 封装得到的高级 integer 操作 +class BinaryEvaluator { + public: + ~BinaryEvaluator() = default; + + // Bitwise operations. + virtual Item ShiftL(const Item &x, uint32_t bits) const = 0; + virtual void ShiftLInplace(Item *x, uint32_t bits) const = 0; + + virtual Item ShiftR(const Item &x, uint32_t bits) const = 0; + virtual void ShiftRInplace(Item *x, uint32_t bits) const = 0; + + virtual Item RotateL(const Item &x, uint32_t bits) const = 0; + virtual void RotateLInplace(Item *x, uint32_t bits) const = 0; + + virtual Item RotateR(const Item &x, uint32_t bits) const = 0; + virtual void RotateRInplace(Item *x, uint32_t bits) const = 0; + + // Comparisons. + virtual Item IsEqual(const Item &x, const Item &y) const = 0; + virtual Item IsNotEqual(const Item &x, const Item &y) const = 0; + virtual Item IsGreaterThan(const Item &x, const Item &y) const = 0; + virtual Item IsGreaterEqual(const Item &x, const Item &y) const = 0; + virtual Item IsLower(const Item &x, const Item &y) const = 0; + virtual Item IsLowerEqual(const Item &x, const Item &y) const = 0; + + virtual Item Min(const Item &x) const = 0; + virtual Item Min(const Item &x, const Item &y) const = 0; + virtual Item Max(const Item &x) const = 0; + virtual Item Max(const Item &x, const Item &y) const = 0; + + // other api: type cast support, programmable bootstrapping support +}; + +} // namespace heu::lib::spi diff --git a/heu/library/spi/he/decryptor.h b/heu/library/spi/he/decryptor.h new file mode 100644 index 00000000..987fe28e --- /dev/null +++ b/heu/library/spi/he/decryptor.h @@ -0,0 +1,29 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "heu/library/spi/he/base.h" + +namespace heu::lib::spi { + +class Decryptor { + public: + // CT -> PT + // CTs -> PTs + virtual void Decrypt(const Item& ct, Item* out) const = 0; + virtual Item Decrypt(const Item& ct) const = 0; +}; + +} // namespace heu::lib::spi diff --git a/heu/library/spi/he/encryptor.h b/heu/library/spi/he/encryptor.h new file mode 100644 index 00000000..02f29c6a --- /dev/null +++ b/heu/library/spi/he/encryptor.h @@ -0,0 +1,29 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "heu/library/spi/he/base.h" + +namespace heu::lib::spi { + +class Encryptor { + public: + // message is encoded plaintext or plaintext array + // For all HE schema, plaintext is a custom type defined by underlying lib + // For 1bit-boolean-FHE, plaintext can be bool or custom type + virtual Item Encrypt(const Item &message) const = 0; +}; + +} // namespace heu::lib::spi diff --git a/heu/library/spi/he/gate_evaluator.h b/heu/library/spi/he/gate_evaluator.h new file mode 100644 index 00000000..2661dc26 --- /dev/null +++ b/heu/library/spi/he/gate_evaluator.h @@ -0,0 +1,86 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "heu/library/spi/he/base.h" + +namespace heu::lib::spi { + +// For single bit operations +class GateEvaluator { + public: + virtual ~GateEvaluator() = default; + + // CT = -CT + // CTs = -CTS + virtual Item Not(const Item &x) const = 0; + virtual void NotInplace(Item *x) const = 0; + + // PT = PT & PT + // CT = PT & CT + // CT = CT & PT + // CT = CT & CT + // PTs = PTs & PT [Broadcast] + // CTs = PTs & CT [Broadcast] + // CTs = CTs & PT [Broadcast] + // CTs = CTs & CT [Broadcast] + // PTs = PT & PTs [Broadcast] + // CTs = PT & CTs [Broadcast] + // CTs = CT & PTs [Broadcast] + // CTs = CT & CTs [Broadcast] + // PTs = PTs & PTs + // CTs = PTs & CTs + // CTs = CTs & PTs + // CTs = CTs & CTs + // If Item is plaintext, then the real type must be bool or vector/span + virtual Item And(const Item &x, const Item &y) const = 0; + // CT &= PT + // CT &= CT + // CTs &= PT [Broadcast] + // CTs &= CT [Broadcast] + // CTs &= PTs + // CTs &= CTs + virtual void AndInplace(Item *x, const Item &y) const = 0; + // AND gate with bootstrapping + virtual Item BootAnd(const Item &x, const Item &y) const = 0; + virtual void BootAndInplace(Item *x, const Item &y) const = 0; + + virtual Item Or(const Item &x, const Item &y) const = 0; + virtual void OrInplace(Item *x, const Item &y) const = 0; + virtual Item BootOr(const Item &x, const Item &y) const = 0; + virtual void BootOrInplace(Item *x, const Item &y) const = 0; + + virtual Item Xor(const Item &x, const Item &y) const = 0; + virtual void XorInplace(Item *x, const Item &y) const = 0; + virtual Item BootXor(const Item &x, const Item &y) const = 0; + virtual void BootXorInplace(Item *x, const Item &y) const = 0; + + virtual Item Nand(const Item &x, const Item &y) const = 0; + virtual void NandInplace(Item *x, const Item &y) const = 0; + virtual Item BootNand(const Item &x, const Item &y) const = 0; + virtual void BootNandInplace(Item *x, const Item &y) const = 0; + + virtual Item Nor(const Item &x, const Item &y) const = 0; + virtual void NorInplace(Item *x, const Item &y) const = 0; + virtual Item BootNor(const Item &x, const Item &y) const = 0; + virtual void BootNorInplace(Item *x, const Item &y) const = 0; + + virtual Item Xnor(const Item &x, const Item &y) const = 0; + virtual void XnorInplace(Item *x, const Item &y) const = 0; + virtual Item BootXnor(const Item &x, const Item &y) const = 0; + virtual void BootXnorInplace(Item *x, const Item &y) const = 0; +}; + +} // namespace heu::lib::spi diff --git a/heu/library/spi/he/kit.cc b/heu/library/spi/he/kit.cc new file mode 100644 index 00000000..a6b11c02 --- /dev/null +++ b/heu/library/spi/he/kit.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "heu/library/spi/he/kit.h" + +namespace heu::lib::spi { + +std::shared_ptr HeKit::GetEncryptor() const { + YACL_ENFORCE( + encryptor_, + "Encryptor is not enabled according to your initialization params"); + return encryptor_; +} + +std::shared_ptr HeKit::GetDecryptor() const { + YACL_ENFORCE( + decryptor_, + "Decryptor is not enabled according to your initialization params"); + return decryptor_; +} + +std::shared_ptr HeKit::GetWordEvaluator() const { + YACL_ENFORCE( + word_evaluator_, + "Word evaluator is not enabled according to your initialization params"); + return word_evaluator_; +} + +std::shared_ptr HeKit::GetGateEvaluator() const { + YACL_ENFORCE( + gate_evaluator_, + "Gate evaluator is not enabled according to your initialization params"); + return gate_evaluator_; +} + +std::shared_ptr HeKit::GetBinaryEvaluator() const { + YACL_ENFORCE(binary_evaluator_, + "Binary evaluator is not enabled according to your " + "initialization params"); + return binary_evaluator_; +} + +HeFactory& HeFactory::Instance() { + static HeFactory factory; + return factory; +} + +} // namespace heu::lib::spi diff --git a/heu/library/spi/he/kit.h b/heu/library/spi/he/kit.h new file mode 100644 index 00000000..1ead9639 --- /dev/null +++ b/heu/library/spi/he/kit.h @@ -0,0 +1,117 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "yacl/base/buffer.h" +#include "yacl/base/byte_container_view.h" +#include "yacl/utils/spi/spi_factory.h" + +#include "heu/library/spi/he/binary_evaluator.h" +#include "heu/library/spi/he/decryptor.h" +#include "heu/library/spi/he/encryptor.h" +#include "heu/library/spi/he/gate_evaluator.h" +#include "heu/library/spi/he/word_evaluator.h" + +namespace heu::lib::spi { + +enum class HeKeyType { + PublicKey, + SecretKey, + RelinKeys, + GaloisKeys, + BootstrappingKey, +}; + +class HeKit { + public: + virtual ~HeKit() = default; + + //=== Meta query ===// + + virtual std::string GetLibraryName() const = 0; + virtual std::string GetSchemaName() const = 0; + + virtual Item GetPublicKey() = 0; // equal to GetKey(HeKeyType::PublicKey); + virtual Item GetSecretKey() = 0; + virtual Item GetKey(HeKeyType key_type) = 0; + + //=== Get Operators ===// + + virtual std::shared_ptr GetEncryptor() const; + virtual std::shared_ptr GetDecryptor() const; + virtual std::shared_ptr GetWordEvaluator() const; + virtual std::shared_ptr GetGateEvaluator() const; + virtual std::shared_ptr GetBinaryEvaluator() const; + + //=== Get Encoders ===// + + // TODO + + /*====================================// + * I/O for HE Objects + * + * 以下所有函数的入参出参均支持如下形式: + * 1. Plaintext + * 2. Plaintext array + * 3. Ciphertext + * 4. Ciphertext array + * 5. All kinds of keys + *====================================*/ + + // Make a deep copy of obj. + virtual Item Clone(const Item& obj) const = 0; + + // To human-readable string + virtual std::string ToString(const Item& x) const = 0; + + // The format of object(s) is based on initial params + virtual yacl::Buffer Serialize(const Item& x) const = 0; + // serialize field element(s) to already allocated buffer. + // if buf is nullptr, then calc serialize size only + // @return: the actual size of serialized buffer + virtual size_t Serialize(const Item& x, uint8_t* buf, + size_t buf_len) const = 0; + + virtual Item Deserialize(yacl::ByteContainerView buffer) const = 0; + + protected: + // Generate all needed keys according to args. + // Or recover keys from previous serialized buffer + virtual void SetupContext(const SpiArgs& args) = 0; + + std::shared_ptr encryptor_; + std::shared_ptr decryptor_; + std::shared_ptr word_evaluator_; + std::shared_ptr gate_evaluator_; + std::shared_ptr binary_evaluator_; +}; + +class HeFactory final : public yacl::SpiFactoryBase { + public: + static HeFactory& Instance(); +}; + +/* + * The sign of creator/checker: + * > std::unique_ptr Create(const std::string &schema, const SpiArgs &); + * > bool Check(const std::string &schema, const SpiArgs &args); + */ +#define REGISTER_HE_LIBRARY(lib_name, performance, checker, creator) \ + REGISTER_SPI_LIBRARY_HELPER(HeFactory, lib_name, performance, checker, \ + creator) + +} // namespace heu::lib::spi diff --git a/heu/library/spi/he/sketches/README.md b/heu/library/spi/he/sketches/README.md new file mode 100644 index 00000000..9086bc41 --- /dev/null +++ b/heu/library/spi/he/sketches/README.md @@ -0,0 +1,21 @@ +# Sketch 说明 + +`sketches/` 目录存放 HE SPI 接口的预实现,根据预实现的方式不同 Sketch 分为两大类: + +- Scalar Sketch:为支持标量调用的 Lib 实现一些通用的功能 + - 特点是:一次调用只传入一个参数,但是调用频率非常高,一般为多线程并发调用 + - 适合基于 CPU 实现的 Lib + +- Vector Sketch:为支持向量化调用的 Lib 实现一些通用的功能 + - 特点:批量式调用,一次函数调用即可处理一批数据 + - 适合基于加速硬件实现的 Lib,例如 GPU 等,可以最大化发挥硬件的并发性能 + +对于 Lib 开发者来说,您可以选择从 Sketch 继承子类来实现您的功能,比直接从 SPI 接口继承开发会简单很多 + +Sketch 实现以下功能: + +- 将函数参数 Item 转换成 Lib 自定义的类型 +- 将上层的调用模式转换成 Lib 支持的调用模式 + - Scalar Sketch:无论上层为标量/向量调用,一律转换成标量调用模式 + - Vector Sketch:无论上层为标量/向量调用,一律转换成向量调用模式 +- 对于一些简单的功能,Sketch 提供默认实现 diff --git a/heu/library/spi/he/sketches/scalar/word_evaluator.h b/heu/library/spi/he/sketches/scalar/word_evaluator.h new file mode 100644 index 00000000..06c0b685 --- /dev/null +++ b/heu/library/spi/he/sketches/scalar/word_evaluator.h @@ -0,0 +1,126 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +// ================================================================ // +// <<< Sketch 接口与 SPI 接口基本类似 >>> // +// <<< 此处仅以 WordEvaluator 为例展示接口 >>> // +// <<< 其它 Encryptor/Evaluator/Decryptor 接口变化同理,此处不再展开 >>> // +// ================================================================ // + +namespace heu::lib::spi { + +template +class WordEvaluatorScalarSketch { + public: + virtual ~WordEvaluatorScalarSketch() = default; + + //=== Arithmetic Operations ===// + + // PT = -PT + // CT = -CT + virtual PlaintextT Negate(const PlaintextT& a) const = 0; + virtual void NegateInplace(PlaintextT* a) const = 0; + virtual CiphertextT Negate(const CiphertextT& a) const = 0; + virtual void NegateInplace(CiphertextT* a) const = 0; + + // PT = PT + PT + // CT = PT + CT + // CT = CT + PT + // CT = CT + CT + virtual PlaintextT Add(const PlaintextT& a, const PlaintextT& b) const = 0; + virtual CiphertextT Add(const PlaintextT& a, const CiphertextT& b) const = 0; + virtual CiphertextT Add(const CiphertextT& a, const PlaintextT& b) const = 0; + virtual CiphertextT Add(const CiphertextT& a, const CiphertextT& b) const = 0; + // CT += PT + // CT += CT + virtual void AddInplace(CiphertextT* a, const PlaintextT& b) const = 0; + virtual void AddInplace(CiphertextT* a, const CiphertextT& b) const = 0; + + // PT = PT + PT + // CT = PT + CT + // CT = CT + PT + // CT = CT + CT + // CT += PT + // CT += CT + virtual PlaintextT Sub(const PlaintextT& a, const PlaintextT& b) const = 0; + virtual CiphertextT Sub(const PlaintextT& a, const CiphertextT& b) const = 0; + virtual CiphertextT Sub(const CiphertextT& a, const PlaintextT& b) const = 0; + virtual CiphertextT Sub(const CiphertextT& a, const CiphertextT& b) const = 0; + virtual void SubInplace(CiphertextT* a, const PlaintextT& b) const = 0; + virtual void SubInplace(CiphertextT* a, const CiphertextT& b) const = 0; + + // PT = PT * PT [AHE/FHE] + // CT = PT * CT [AHE/FHE] + // CT = CT * PT [AHE/FHE] + // CT = CT * CT [FHE] + virtual PlaintextT Mul(const PlaintextT& a, const PlaintextT& b) const = 0; + virtual CiphertextT Mul(const PlaintextT& a, const CiphertextT& b) const = 0; + virtual CiphertextT Mul(const CiphertextT& a, const PlaintextT& b) const = 0; + virtual CiphertextT Mul(const CiphertextT& a, const CiphertextT& b) const = 0; + virtual void MulInplace(CiphertextT* a, const PlaintextT& b) const = 0; + virtual void MulInplace(CiphertextT* a, const CiphertextT& b) const = 0; + + virtual PlaintextT Square(const PlaintextT& a) const = 0; + virtual CiphertextT Square(const CiphertextT& a) const = 0; + virtual void SquareInplace(PlaintextT* a) const = 0; + virtual void SquareInplace(CiphertextT* a) const = 0; + + virtual PlaintextT Pow(const PlaintextT& a, uint64_t exponent) const = 0; + virtual CiphertextT Pow(const CiphertextT& a, uint64_t exponent) const = 0; + virtual void PowInplace(PlaintextT* a, uint64_t exponent) const = 0; + virtual void PowInplace(CiphertextT* a, uint64_t exponent) const = 0; + + //=== Ciphertext maintains ===// + + // CT -> CT + // The result is same with ct += Enc(0) + virtual void Randomize(CiphertextT* ct) const = 0; + + virtual CiphertextT Relinearize(const CiphertextT& a) const = 0; + virtual void RelinearizeInplace(CiphertextT* a) const = 0; + + // Given a ciphertext with modulo q_1...q_k, this function switches the + // modulus down to q_1...q_{k-1} + virtual CiphertextT ModSwitch(const CiphertextT& a) const = 0; + virtual void ModSwitchInplace(CiphertextT* a) const = 0; + + // Given a ciphertext with modulo q_1...q_k, this function switches the + // modulus down to q_1...q_{k-1}, and scales the message down accordingly + virtual CiphertextT Rescale(const CiphertextT& a) const = 0; + virtual void RescaleInplace(CiphertextT* a) const = 0; + + //=== Galois automorphism ===// + + // BFV/BGV only + virtual CiphertextT SwapRows(const CiphertextT& a) const = 0; + virtual void SwapRowsInplace(CiphertextT* a) const = 0; + + // CKKS only, for complex number + virtual CiphertextT Conjugate(const CiphertextT& a) const = 0; + virtual void ConjugateInplace(CiphertextT* a) const = 0; + + // BFV/BGV batching mode: + // The size of matrix is 2-by-(N/2), so move each row cyclically to the left + // (steps > 0) or to the right (steps < 0) + // CKKS batching mode: + // rotates the encrypted plaintext vector cyclically to the left (steps > 0) + // or to the right (steps < 0). + // All schemas: require abs(steps) < N/2 + virtual CiphertextT Rotate(const CiphertextT& a, int steps) const = 0; + virtual void RotateInplace(CiphertextT* a, int steps) const = 0; +}; + +} // namespace heu::lib::spi diff --git a/heu/library/spi/he/sketches/vector/word_evaluator.h b/heu/library/spi/he/sketches/vector/word_evaluator.h new file mode 100644 index 00000000..94f7ca9d --- /dev/null +++ b/heu/library/spi/he/sketches/vector/word_evaluator.h @@ -0,0 +1,175 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "absl/types/span.h" + +// ================================================================ // +// <<< Sketch 接口与 SPI 接口基本类似 >>> // +// <<< 此处仅以 WordEvaluator 为例展示接口 >>> // +// <<< 其它 Encryptor/Evaluator/Decryptor 接口变化同理,此处不再展开 >>> // +// ================================================================ // + +namespace heu::lib::spi { + +template +class WordEvaluatorVectorSketch { + public: + virtual ~WordEvaluatorVectorSketch() = default; + + //=== Arithmetic Operations ===// + + // PT = -PT + // CT = -CT + virtual std::vector Negate( + const absl::Span& a) const = 0; + virtual void NegateInplace(absl::Span a) const = 0; + virtual std::vector Negate( + const absl::Span& a) const = 0; + virtual void NegateInplace(absl::Span a) const = 0; + + // PT = PT + PT + // CT = PT + CT + // CT = CT + PT + // CT = CT + CT + virtual std::vector Add( + const absl::Span& a, + const absl::Span& b) const = 0; + virtual std::vector Add( + const absl::Span& a, + const absl::Span& b) const = 0; + virtual std::vector Add( + const absl::Span& a, + const absl::Span& b) const = 0; + virtual std::vector Add( + const absl::Span& a, + const absl::Span& b) const = 0; + // CT += PT + // CT += CT + virtual void AddInplace(absl::Span a, + const absl::Span& b) const = 0; + virtual void AddInplace(absl::Span a, + const absl::Span& b) const = 0; + + // PT = PT + PT + // CT = PT + CT + // CT = CT + PT + // CT = CT + CT + // CT += PT + // CT += CT + virtual std::vector Sub( + const absl::Span& a, + const absl::Span& b) const = 0; + virtual std::vector Sub( + const absl::Span& a, + const absl::Span& b) const = 0; + virtual std::vector Sub( + const absl::Span& a, + const absl::Span& b) const = 0; + virtual std::vector Sub( + const absl::Span& a, + const absl::Span& b) const = 0; + virtual void SubInplace(absl::Span a, + const absl::Span& b) const = 0; + virtual void SubInplace(absl::Span a, + const absl::Span& b) const = 0; + + // PT = PT * PT [AHE/FHE] + // CT = PT * CT [AHE/FHE] + // CT = CT * PT [AHE/FHE] + // CT = CT * CT [FHE] + virtual std::vector Mul( + const absl::Span& a, + const absl::Span& b) const = 0; + virtual std::vector Mul( + const absl::Span& a, + const absl::Span& b) const = 0; + virtual std::vector Mul( + const absl::Span& a, + const absl::Span& b) const = 0; + virtual std::vector Mul( + const absl::Span& a, + const absl::Span& b) const = 0; + virtual void MulInplace(absl::Span a, + const absl::Span& b) const = 0; + virtual void MulInplace(absl::Span a, + const absl::Span& b) const = 0; + + virtual std::vector Square( + const absl::Span& a) const = 0; + virtual std::vector Square( + const absl::Span& a) const = 0; + virtual void SquareInplace(absl::Span a) const = 0; + virtual void SquareInplace(absl::Span a) const = 0; + + virtual std::vector Pow(const absl::Span& a, + uint64_t exponent) const = 0; + virtual std::vector Pow(const absl::Span& a, + uint64_t exponent) const = 0; + virtual void PowInplace(absl::Span a, + uint64_t exponent) const = 0; + virtual void PowInplace(absl::Span a, + uint64_t exponent) const = 0; + + //=== Ciphertext maintains ===// + + // CT -> CT + // The result is same with ct += Enc(0) + virtual void Randomize(absl::Span ct) const = 0; + + virtual std::vector Relinearize( + const absl::Span& a) const = 0; + virtual void RelinearizeInplace(absl::Span a) const = 0; + + // Given a ciphertext with modulo q_1...q_k, this function switches the + // modulus down to q_1...q_{k-1} + virtual std::vector ModSwitch( + const absl::Span& a) const = 0; + virtual void ModSwitchInplace(absl::Span a) const = 0; + + // Given a ciphertext with modulo q_1...q_k, this function switches the + // modulus down to q_1...q_{k-1}, and scales the message down accordingly + virtual std::vector Rescale( + const absl::Span& a) const = 0; + virtual void RescaleInplace(absl::Span a) const = 0; + + //=== Galois automorphism ===// + + // BFV/BGV only + virtual std::vector SwapRows( + const absl::Span& a) const = 0; + virtual void SwapRowsInplace(absl::Span a) const = 0; + + // CKKS only, for complex number + virtual std::vector Conjugate( + const absl::Span& a) const = 0; + virtual void ConjugateInplace(absl::Span a) const = 0; + + // BFV/BGV batching mode: + // The size of matrix is 2-by-(N/2), so move each row cyclically to the left + // (steps > 0) or to the right (steps < 0) + // CKKS batching mode: + // rotates the encrypted plaintext vector cyclically to the left (steps > 0) + // or to the right (steps < 0). + // All schemas: require abs(steps) < N/2 + virtual std::vector Rotate( + const absl::Span& a, int steps) const = 0; + virtual void RotateInplace(absl::Span a, int steps) const = 0; +}; + +} // namespace heu::lib::spi diff --git a/heu/library/spi/he/word_evaluator.h b/heu/library/spi/he/word_evaluator.h new file mode 100644 index 00000000..f5e93f53 --- /dev/null +++ b/heu/library/spi/he/word_evaluator.h @@ -0,0 +1,129 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "heu/library/spi/he/base.h" + +namespace heu::lib::spi { + +class WordEvaluator { + public: + virtual ~WordEvaluator() = default; + + //=== Arithmetic Operations ===// + + // PT = -PT + // CT = -CT + // PTs = -PTs + // CTs = -CTs + virtual Item Negate(const Item& a) const = 0; + virtual void NegateInplace(Item* a) const = 0; + + // PT = PT + PT + // CT = PT + CT + // CT = CT + PT + // CT = CT + CT + // PTs = PTs + PT [Broadcast] + // CTs = PTs + CT [Broadcast] + // CTs = CTs + PT [Broadcast] + // CTs = CTs + CT [Broadcast] + // PTs = PT + PTs [Broadcast] + // CTs = PT + CTs [Broadcast] + // CTs = CT + PTs [Broadcast] + // CTs = CT + CTs [Broadcast] + // PTs = PTs + PTs + // CTs = PTs + CTs + // CTs = CTs + PTs + // CTs = CTs + CTs + virtual Item Add(const Item& a, const Item& b) const = 0; + // CT += PT + // CT += CT + // CTs += PT [Broadcast] + // CTs += CT [Broadcast] + // CTs += PTs + // CTs += CTs + virtual void AddInplace(Item* a, const Item& b) const = 0; + + // 参数可能的组合类型与 Add 相同 + virtual Item Sub(const Item& a, const Item& b) const = 0; + virtual void SubInplace(Item* a, const Item& b) const = 0; + + // PT = PT * PT [AHE/FHE] + // CT = PT * CT [AHE/FHE] + // CT = CT * PT [AHE/FHE] + // CT = CT * CT [FHE] + // PTs = PTs * PT [Broadcast] [AHE/FHE] + // CTs = PTs * CT [Broadcast] [AHE/FHE] + // CTs = CTs * PT [Broadcast] [AHE/FHE] + // CTs = CTs * CT [Broadcast] [FHE] + // PTs = PT * PTs [Broadcast] [AHE/FHE] + // CTs = PT * CTs [Broadcast] [AHE/FHE] + // CTs = CT * PTs [Broadcast] [AHE/FHE] + // CTs = CT * CTs [Broadcast] [FHE] + // PTs = PTs * PTs [AHE/FHE] + // CTs = PTs * CTs [AHE/FHE] + // CTs = CTs * PTs [AHE/FHE] + // CTs = CTs * CTs [FHE] + virtual Item Mul(const Item& a, const Item& b) const = 0; + virtual void MulInplace(Item* a, const Item& b) const = 0; + + virtual Item Square(const Item& a) const = 0; + virtual void SquareInplace(Item* a) const = 0; + + virtual Item Pow(const Item& a, uint64_t exponent) const = 0; + virtual void PowInplace(Item* a, uint64_t exponent) const = 0; + + //=== Ciphertext maintains ===// + + // CT -> CT + // CTs -> CTs + // The result is same with ct += Enc(0) + virtual void Randomize(Item* ct) const = 0; + + virtual Item Relinearize(const Item& a) const = 0; + virtual void RelinearizeInplace(Item* a) const = 0; + + // Given a ciphertext with modulo q_1...q_k, this function switches the + // modulus down to q_1...q_{k-1} + virtual Item ModSwitch(const Item& a) const = 0; + virtual void ModSwitchInplace(Item* a) const = 0; + + // Given a ciphertext with modulo q_1...q_k, this function switches the + // modulus down to q_1...q_{k-1}, and scales the message down accordingly + virtual Item Rescale(const Item& a) const = 0; + virtual void RescaleInplace(Item* a) const = 0; + + //=== Galois automorphism ===// + + // BFV/BGV only + virtual Item SwapRows(const Item& a) const = 0; + virtual void SwapRowsInplace(Item* a) const = 0; + + // CKKS only, for complex number + virtual Item Conjugate(const Item& a) const = 0; + virtual void ConjugateInplace(Item* a) const = 0; + + // BFV/BGV batching mode: + // The size of matrix is 2-by-(N/2), so move each row cyclically to the left + // (steps > 0) or to the right (steps < 0) + // CKKS batching mode: + // rotates the encrypted plaintext vector cyclically to the left (steps > 0) + // or to the right (steps < 0). + // All schemas: require abs(steps) < N/2 + virtual Item Rotate(const Item& a, int steps) const = 0; + virtual void RotateInplace(Item* a, int steps) const = 0; +}; + +} // namespace heu::lib::spi diff --git a/heu/library/spi/poly/ntt_op.h b/heu/library/spi/poly/ntt_op.h new file mode 100644 index 00000000..87428c77 --- /dev/null +++ b/heu/library/spi/poly/ntt_op.h @@ -0,0 +1,35 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "heu/library/spi/poly/poly_def.h" + +namespace heu::lib::spi { + +// Performs nega-cyclic forward and inverse number-theoretic transform (NTT) +// nega-cyclic means polynomial is mod by (X^N + 1) +class NttOperator { + public: + virtual ~NttOperator() = default; + + //=== (Batched) NTT Operations ===// + virtual Polys Forward(const Polys &Polys_in) const = 0; + virtual void ForwardInplace(Polys *Polys) const = 0; + + virtual Polys Inverse(const Polys &Polys_in) const = 0; + virtual void InverseInplace(Polys *Polys) const = 0; +}; + +} // namespace heu::lib::spi diff --git a/heu/library/spi/poly/poly_def.h b/heu/library/spi/poly/poly_def.h new file mode 100644 index 00000000..6be6b53f --- /dev/null +++ b/heu/library/spi/poly/poly_def.h @@ -0,0 +1,40 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace heu::lib::spi { + +class Poly : public std::vector { + public: + using std::vector::vector; +}; + +// Polys can store a batch of independent Polys, or sub-polynomials decomposed +// by RNS. +class Polys { + public: + // add public functions here + private: + std::vector polys_; +}; + +using Moduli = std::vector; + +using RnsPoly = Polys; + +} // namespace heu::lib::spi diff --git a/heu/library/spi/poly/poly_op.h b/heu/library/spi/poly/poly_op.h new file mode 100644 index 00000000..bfb0e98a --- /dev/null +++ b/heu/library/spi/poly/poly_op.h @@ -0,0 +1,88 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "heu/library/spi/poly/poly_def.h" + +namespace heu::lib::spi { + +class ElementWisePolyOperator { + public: + virtual ~ElementWisePolyOperator() = default; + + //=== Element-wise polynomial operations ===// + + // just mod the coefficients + /// \param coeff_modulus coefficient modulus of each polynomial + virtual Polys Mod(const Polys &in, const Moduli &coeff_modulus) const = 0; + virtual void ModInplace(Polys *polys, const Moduli &coeff_modulus) const = 0; + + // just mod the coefficients + virtual Polys NegateMod(const Polys &in, + const Moduli &coeff_modulus) const = 0; + virtual void NegateModInplace(Polys *polys, + const Moduli &coeff_modulus) const = 0; + + /// Add two batches of polynomials + /// \param in1 first batch of polynomials + /// \param in2 second batch of polynomials + /// \param coeff_modulus coefficient modulus of each polynomial + virtual Polys AddMod(const Polys &in1, const Polys &in2, + const Moduli &coeff_modulus) const = 0; + virtual void AddModInplace(Polys *polys_1, const Polys &polys_2, + const Moduli &coeff_modulus) const = 0; + + // add scalar to coefficients with broadcast + virtual Polys AddMod(const Polys &polys_in, + const std::vector &scalar_in, + const Moduli &coeff_modulus) const = 0; + virtual void AddModInplace(Polys *polys, + const std::vector &scalar_in, + const Moduli &coeff_modulus) const = 0; + + virtual Polys SubMod(const Polys &in1, const Polys &in2, + const Moduli &coeff_modulus) const = 0; + virtual void SubModInplace(Polys *polys_1, const Polys &polys_2, + const Moduli &coeff_modulus) const = 0; + + virtual Polys SubMod(const Polys &polys_in, + const std::vector &scalar_in, + const Moduli &coeff_modulus) const = 0; + virtual void SubModInplace(Polys *polys, + const std::vector &scalar_in, + const Moduli &coeff_modulus) const = 0; + + virtual Polys MulMod(const Polys &in1, const Polys &in2, + const Moduli &coeff_modulus) const = 0; + virtual void MulModInplace(Polys *polys_1, const Polys &polys_2, + const Moduli &coeff_modulus) const = 0; + + virtual Polys MulMod(const Polys &polys_in, + const std::vector &scalar_in, + const Moduli &coeff_modulus) const = 0; + virtual void MulModInplace(Polys *polys, + const std::vector &scalar_in, + const Moduli &coeff_modulus) const = 0; + + // move coeff[0] to coeff[offset], + // if offset > 0 do negacyclic shift, otherwise do cyclic shift + virtual Polys Shift(const Polys &in, int64_t offset) const = 0; + virtual void ShiftInplace(Polys *polys, int64_t offset) const = 0; +}; + +} // namespace heu::lib::spi