// Declaration: encodes every inner vector of `vec` into one packed plaintext
// and returns the plaintexts in the same order as the input.
22 std::vector<lbcrypto::Plaintext> encodeVector(
const std::vector<std::vector<std::int64_t>> &vec);
// Declaration: encrypts every plaintext in `encoded_vec` and returns one
// ciphertext per input plaintext.
23 std::vector<lbcrypto::Ciphertext<lbcrypto::DCRTPoly>> encryptVector(
const std::vector<lbcrypto::Plaintext> &encoded_vec);
// Declaration: adds every plaintext in `A` to every ciphertext in `B`.
// The result holds A.size() * B.size() ciphertexts.
24 std::vector<lbcrypto::Ciphertext<lbcrypto::DCRTPoly>> eltwiseadd(
const std::vector<lbcrypto::Plaintext> &A,
25 const std::vector<lbcrypto::Ciphertext<lbcrypto::DCRTPoly>> &B);
// Declaration: decrypts every ciphertext in `encrypted_result` and returns
// one plaintext per input ciphertext.
26 std::vector<lbcrypto::Plaintext> decryptResult(
const std::vector<lbcrypto::Ciphertext<lbcrypto::DCRTPoly>> &encrypted_result);
// Declaration: unpacks every plaintext in `encoded_result` into a vector of
// 64-bit integers, one output vector per input plaintext.
27 std::vector<std::vector<int64_t>> decodeResult(
const std::vector<lbcrypto::Plaintext> &encoded_result);
// Bundles a PALISADE BFV crypto context with its key pair and slot count.
// NOTE(review): this listing omits brace-only lines, access specifiers and a
// few statements of the original file, so some bodies below are incomplete as
// shown — confirm against the full source before changing any code.
30 class PalisadeBFVContext
// Creates the BFVrns crypto context for `poly_modulus_degree`, enables the
// ENCRYPTION and SHE features on it, and generates the key pair.
33 PalisadeBFVContext(
int poly_modulus_degree)
// Fixed scheme parameters used by this example.
35 std::size_t plaintext_modulus = 65537;
36 lbcrypto::SecurityLevel sec_level = lbcrypto::HEStd_128_classic;
38 std::size_t num_coeff_moduli = 2;
39 std::size_t max_depth = 2;
40 std::size_t coeff_moduli_bits = 40;
// NOTE(review): `sigma` is passed below but its declaration is not visible in
// this listing (original line 37 is missing) — confirm its value in the full file.
// NOTE(review): the meaning of the literal 0 arguments depends on the
// positional parameter order of genCryptoContextBFVrns — verify against the
// PALISADE API.
42 lbcrypto::CryptoContext<lbcrypto::DCRTPoly> crypto_context =
43 lbcrypto::CryptoContextFactory<lbcrypto::DCRTPoly>::genCryptoContextBFVrns(
44 plaintext_modulus, sec_level, sigma, 0, num_coeff_moduli,
45 0, OPTIMIZED, max_depth, 0, coeff_moduli_bits, poly_modulus_degree);
46 crypto_context->Enable(ENCRYPTION);
47 crypto_context->Enable(SHE);
// Generate the key pair, copy it into a heap-allocated LPKeyPair, reset the
// local copy to a default-constructed pair, then hand the heap copy to the
// m_keys unique_ptr, which owns it from here on.
49 lbcrypto::LPKeyPair<lbcrypto::DCRTPoly> local_key = crypto_context->KeyGen();
50 lbcrypto::LPKeyPair<lbcrypto::DCRTPoly> *p_key =
new lbcrypto::LPKeyPair<lbcrypto::DCRTPoly>(local_key.publicKey, local_key.secretKey);
51 local_key = lbcrypto::LPKeyPair<lbcrypto::DCRTPoly>();
52 m_keys = std::unique_ptr<lbcrypto::LPKeyPair<lbcrypto::DCRTPoly>>(p_key);
// Keep a copy of the crypto context handle behind a shared_ptr.
54 m_p_palisade_context = std::make_shared<lbcrypto::CryptoContext<lbcrypto::DCRTPoly>>(crypto_context);
// The slot count is set equal to the polynomial modulus degree.
55 m_slot_count = poly_modulus_degree;
// Returns the public key of the stored key pair.
58 auto publicKey()
const {
return m_keys->publicKey; }
// Returns the slot count recorded by the constructor.
59 std::size_t getSlotCount()
const {
return m_slot_count; }
// Returns a reference to the stored PALISADE crypto context.
60 lbcrypto::CryptoContext<lbcrypto::DCRTPoly> &context() {
return *m_p_palisade_context; }
// Decrypts `cipher` with the stored secret key, writing the result to `plain`.
61 void decrypt(
const lbcrypto::Ciphertext<lbcrypto::DCRTPoly> &cipher, lbcrypto::Plaintext &plain)
63 context()->Decrypt(m_keys->secretKey, cipher, &plain);
// Overload that returns the decrypted plaintext by value.
// NOTE(review): only the declaration of `retval` is visible in this listing;
// the rest of the body is not shown — presumably it forwards to the overload
// above and returns `retval`; confirm against the full file.
66 lbcrypto::Plaintext
decrypt(
const lbcrypto::Ciphertext<lbcrypto::DCRTPoly> &cipher)
68 lbcrypto::Plaintext retval;
// Crypto context handle created in the constructor.
74 std::shared_ptr<lbcrypto::CryptoContext<lbcrypto::DCRTPoly>> m_p_palisade_context;
// Public/secret key pair generated in the constructor.
75 std::unique_ptr<lbcrypto::LPKeyPair<lbcrypto::DCRTPoly>> m_keys;
// Slot count; initialized from the polynomial modulus degree.
76 std::size_t m_slot_count;
// Length that decoded result vectors are resized to.
79 std::size_t m_vector_size;
// Shared PALISADE BFV context used by all workload operations.
80 std::shared_ptr<PalisadeBFVContext> m_p_context;
// Returns a reference to the workload's PalisadeBFVContext.
82 PalisadeBFVContext &context() {
return *m_p_context; }
// Fragment of a body that initializes m_vector_size and m_p_context.
// NOTE(review): the enclosing signature and the condition guarding this throw
// are not visible in this listing — presumably this is the Workload
// constructor rejecting an invalid `vector_size`; confirm against the full file.
88 throw std::invalid_argument(
"vector_size");
89 m_vector_size = vector_size;
// Create the BFV context; 8192 is passed as the constructor's
// poly_modulus_degree argument.
90 m_p_context = std::make_shared<PalisadeBFVContext>(8192);
// Encodes each inner vector of `vec` into a packed plaintext, storing the
// plaintext for vec[i] at retval[i].
// NOTE(review): the closing lines of this body (presumably `return retval;`)
// are not visible in this listing.
93 std::vector<lbcrypto::Plaintext> Workload::encodeVector(
const std::vector<std::vector<std::int64_t>> &vec)
95 std::vector<lbcrypto::Plaintext> retval(vec.size());
97 for (std::size_t i = 0; i < vec.size(); ++i)
// Each input vector must fit within the context's slot count. This is an
// assert, so the check is compiled out when NDEBUG is defined.
99 assert(vec[i].size() <= context().getSlotCount());
100 retval[i] = context().context()->MakePackedPlaintext(vec[i]);
// Encrypts each plaintext in `encoded_vec` with the context's public key,
// storing the ciphertext for encoded_vec[i] at retval[i].
// NOTE(review): the closing lines of this body (presumably `return retval;`)
// are not visible in this listing.
105 std::vector<lbcrypto::Ciphertext<lbcrypto::DCRTPoly>> Workload::encryptVector(
const std::vector<lbcrypto::Plaintext> &encoded_vec)
107 std::vector<lbcrypto::Ciphertext<lbcrypto::DCRTPoly>> retval(encoded_vec.size());
108 for (std::size_t i = 0; i < encoded_vec.size(); i++)
109 retval[i] = context().context()->Encrypt(context().publicKey(), encoded_vec[i]);
// Adds every plaintext in `A` to every ciphertext in `B` using EvalAdd.
// The result has A.size() * B.size() entries; the sum for the pair
// (A[A_i], B[B_i]) is stored at index A_i * B.size() + B_i.
// NOTE(review): original lines 117-122 and the closing lines of this body
// (presumably `return retval;`) are not visible in this listing.
113 std::vector<lbcrypto::Ciphertext<lbcrypto::DCRTPoly>> Workload::eltwiseadd(
const std::vector<lbcrypto::Plaintext> &A,
114 const std::vector<lbcrypto::Ciphertext<lbcrypto::DCRTPoly>> &B)
116 std::vector<lbcrypto::Ciphertext<lbcrypto::DCRTPoly>> retval(A.size() * B.size());
123 for (std::size_t A_i = 0; A_i < A.size(); ++A_i)
124 for (std::size_t B_i = 0; B_i < B.size(); ++B_i)
// Row-major position of the (A_i, B_i) pair in the flat result vector.
126 lbcrypto::Ciphertext<lbcrypto::DCRTPoly> &retval_item = retval[A_i * B.size() + B_i];
// The ciphertext operand is passed first, the plaintext operand second.
127 retval_item = context().context()->EvalAdd(B[B_i], A[A_i]);
// Decrypts each ciphertext in `encrypted_result` through the workload's
// context, storing the plaintext for encrypted_result[i] at retval[i].
// NOTE(review): the closing lines of this body (presumably `return retval;`)
// are not visible in this listing.
133 std::vector<lbcrypto::Plaintext> Workload::decryptResult(
const std::vector<lbcrypto::Ciphertext<lbcrypto::DCRTPoly>> &encrypted_result)
135 std::vector<lbcrypto::Plaintext> retval(encrypted_result.size());
136 for (std::size_t i = 0; i < encrypted_result.size(); i++)
137 retval[i] = context().decrypt(encrypted_result[i]);
// Unpacks each plaintext in `encoded_result` into a vector of 64-bit integers,
// storing the values for encoded_result[i] at retval[i].
// NOTE(review): the closing lines of this body (presumably `return retval;`)
// are not visible in this listing.
141 std::vector<std::vector<int64_t>> Workload::decodeResult(
const std::vector<lbcrypto::Plaintext> &encoded_result)
143 std::vector<std::vector<int64_t>> retval(encoded_result.size());
144 for (std::size_t i = 0; i < encoded_result.size(); ++i)
146 retval[i] = encoded_result[i]->GetPackedValue();
// Resize the unpacked values to the workload's configured vector length.
147 retval[i].resize(m_vector_size);
// Driver body: builds random inputs, runs the encode/encrypt/add/decrypt/decode
// pipeline, and compares the result against a cleartext computation.
// NOTE(review): the enclosing function signature and several statements are
// not visible in this listing — confirm against the full file.
// Problem size: DimensionA vectors in A and DimensionB vectors in B, each with
// VectorSize elements.
157 static const std::size_t VectorSize = 400;
158 static const std::size_t DimensionA = 2;
159 static const std::size_t DimensionB = 5;
162 std::vector<std::vector<std::int64_t>> A(DimensionA, std::vector<std::int64_t>(VectorSize));
163 std::vector<std::vector<std::int64_t>> B(DimensionB, std::vector<std::int64_t>(VectorSize));
// Fill A and B with random integers drawn uniformly from [-100, 100].
166 std::random_device rd;
167 std::mt19937 gen(rd());
168 std::uniform_int_distribution<> distrib(-100, 100);
170 for (std::size_t i = 0; i < A.size(); ++i)
171 for (std::size_t j = 0; j < A[i].size(); ++j)
172 A[i][j] = distrib(gen);
173 for (std::size_t i = 0; i < B.size(); ++i)
174 for (std::size_t j = 0; j < B[i].size(); ++j)
175 B[i][j] = distrib(gen);
181 std::shared_ptr<Workload> p_workload =
182 std::make_shared<Workload>(VectorSize);
// Pipeline: A is only encoded (stays in plaintext), B is encoded and then
// encrypted; the sums are decrypted and decoded back to integer vectors.
// NOTE(review): `workload` is used below but its declaration is not visible in
// this listing — presumably a reference to *p_workload; confirm.
186 std::vector<std::vector<std::int64_t>> result =
187 workload.decodeResult(
188 workload.decryptResult(
189 workload.eltwiseadd(workload.encodeVector(A),
190 workload.encryptVector(workload.encodeVector(B)))));
// Compute the expected sums in the clear, iterating the (A_i, B_i) pairs in
// the same nested order used by eltwiseadd.
198 std::vector<std::vector<std::int64_t>> exp_out(DimensionA * DimensionB, std::vector<std::int64_t>(VectorSize));
199 assert(exp_out.size() == result.size());
// NOTE(review): no statement advancing `result_i` is visible in this listing —
// presumably it is incremented once per (A_i, B_i) pair; confirm.
202 std::size_t result_i = 0;
203 for (std::size_t A_i = 0; A_i < A.size(); ++A_i)
204 for (std::size_t B_i = 0; B_i < B.size(); ++B_i)
206 for (std::size_t i = 0; i < VectorSize; ++i)
207 exp_out[result_i][i] = A[A_i][i] + B[B_i][i];
// Report whether the homomorphic result matches the expected output.
// NOTE(review): the `else` between the two prints is not visible in this listing.
211 if (result == exp_out)
212 std::cout <<
"OK" << std::endl;
214 std::cout <<
"Fail" << std::endl;
ErrorCode decrypt(Handle h_benchmark, Handle h_ciphertext, Handle *h_plaintext)
Decrypts a ciphertext into the corresponding plaintext.
Workload
Defines all possible workloads.
int main(int argc, char **argv)