// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "seal/context.h" #include "seal/galoiskeys.h" #include "seal/keygenerator.h" #include "seal/modulus.h" #include "seal/util/polyarithsmallmod.h" #include "seal/util/uintcore.h" #include #include "gtest/gtest.h" using namespace seal; using namespace seal::util; using namespace std; namespace sealtest { TEST(GaloisKeysTest, GaloisKeysSaveLoad) { auto galoiskey_save_load = [](scheme_type scheme) { stringstream stream; { EncryptionParameters parms(scheme); parms.set_poly_modulus_degree(64); parms.set_plain_modulus(65537); parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); GaloisKeys keys; GaloisKeys test_keys; keys.save(stream); test_keys.unsafe_load(context, stream); ASSERT_EQ(keys.data().size(), test_keys.data().size()); ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); ASSERT_EQ(0ULL, keys.data().size()); keygen.create_galois_keys(keys); keys.save(stream); test_keys.load(context, stream); ASSERT_EQ(keys.data().size(), test_keys.data().size()); ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); for (size_t j = 0; j < test_keys.data().size(); j++) { for (size_t i = 0; i < test_keys.data()[j].size(); i++) { ASSERT_EQ(keys.data()[j][i].data().size(), test_keys.data()[j][i].data().size()); ASSERT_EQ( keys.data()[j][i].data().dyn_array().size(), test_keys.data()[j][i].data().dyn_array().size()); ASSERT_TRUE(is_equal_uint( keys.data()[j][i].data().data(), test_keys.data()[j][i].data().data(), keys.data()[j][i].data().dyn_array().size())); } } ASSERT_EQ(64ULL, keys.data().size()); } { EncryptionParameters parms(scheme); parms.set_poly_modulus_degree(256); parms.set_plain_modulus(65537); parms.set_coeff_modulus(CoeffModulus::Create(256, { 60, 50 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); GaloisKeys keys; GaloisKeys test_keys; keys.save(stream); test_keys.unsafe_load(context, stream); ASSERT_EQ(keys.data().size(), test_keys.data().size()); ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); ASSERT_EQ(0ULL, keys.data().size()); keygen.create_galois_keys(keys); keys.save(stream); test_keys.load(context, stream); ASSERT_EQ(keys.data().size(), test_keys.data().size()); ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); for (size_t j = 0; j < test_keys.data().size(); j++) { for (size_t i = 0; i < test_keys.data()[j].size(); i++) { ASSERT_EQ(keys.data()[j][i].data().size(), test_keys.data()[j][i].data().size()); ASSERT_EQ( keys.data()[j][i].data().dyn_array().size(), test_keys.data()[j][i].data().dyn_array().size()); ASSERT_TRUE(is_equal_uint( keys.data()[j][i].data().data(), test_keys.data()[j][i].data().data(), keys.data()[j][i].data().dyn_array().size())); } } ASSERT_EQ(256ULL, keys.data().size()); } }; galoiskey_save_load(scheme_type::bfv); galoiskey_save_load(scheme_type::bgv); } TEST(GaloisKeysTest, GaloisKeysSeededSaveLoad) { auto galoiskey_seeded_save_load = [](scheme_type scheme) { // Returns true if a, b contains the same error. auto compare_kswitchkeys = [](const KSwitchKeys &a, const KSwitchKeys &b, const SecretKey &sk, const SEALContext &context) { auto compare_error = [](const Ciphertext &a_ct, const Ciphertext &b_ct, const SecretKey &sk1, const SEALContext &ctx) { auto get_error = [](const Ciphertext &encrypted, const SecretKey &sk2, const SEALContext &ctx2) { auto pool = MemoryManager::GetPool(); auto &ctx2_data = *ctx2.get_context_data(encrypted.parms_id()); auto &parms = ctx2_data.parms(); auto &coeff_modulus = parms.coeff_modulus(); size_t coeff_count = parms.poly_modulus_degree(); size_t coeff_modulus_size = coeff_modulus.size(); size_t rns_poly_uint64_count = util::mul_safe(coeff_count, coeff_modulus_size); DynArray error; error.resize(rns_poly_uint64_count); auto destination = error.begin(); auto copy_operand1(util::allocate_uint(coeff_count, pool)); for (size_t i = 0; i < coeff_modulus_size; i++) { // Initialize pointers for multiplication const uint64_t *encrypted_ptr = encrypted.data(1) + (i * coeff_count); const uint64_t *secret_key_ptr = sk2.data().data() + (i * coeff_count); uint64_t *destination_ptr = destination + (i * coeff_count); util::set_zero_uint(coeff_count, destination_ptr); util::set_uint(encrypted_ptr, coeff_count, copy_operand1.get()); // compute c_{j+1} * s^{j+1} util::dyadic_product_coeffmod( copy_operand1.get(), secret_key_ptr, coeff_count, coeff_modulus[i], copy_operand1.get()); // add c_{j+1} * s^{j+1} to destination util::add_poly_coeffmod( destination_ptr, copy_operand1.get(), coeff_count, coeff_modulus[i], destination_ptr); // add c_0 into destination util::add_poly_coeffmod( destination_ptr, encrypted.data() + (i * coeff_count), coeff_count, coeff_modulus[i], destination_ptr); } return error; }; auto error_a = get_error(a_ct, sk1, ctx); auto error_b = get_error(b_ct, sk1, ctx); ASSERT_EQ(error_a.size(), error_b.size()); ASSERT_TRUE(is_equal_uint(error_a.cbegin(), error_b.cbegin(), error_a.size())); }; ASSERT_EQ(a.size(), b.size()); auto iter_a = a.data().begin(); auto iter_b = b.data().begin(); for (; iter_a != a.data().end(); iter_a++, iter_b++) { ASSERT_EQ(iter_a->size(), iter_b->size()); auto pk_a = iter_a->begin(); auto pk_b = iter_b->begin(); for (; pk_a != iter_a->end(); pk_a++, pk_b++) { compare_error(pk_a->data(), pk_b->data(), sk, context); } } }; stringstream stream; { EncryptionParameters parms(scheme); parms.set_poly_modulus_degree(8); parms.set_plain_modulus(65537); parms.set_coeff_modulus(CoeffModulus::Create(8, { 60, 60 })); prng_seed_type seed; for (auto &i : seed) { i = random_uint64(); } auto rng = make_shared(Blake2xbPRNGFactory(seed)); parms.set_random_generator(rng); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); SecretKey secret_key = keygen.secret_key(); keygen.create_galois_keys().save(stream); GaloisKeys test_keys; test_keys.load(context, stream); GaloisKeys keys; keygen.create_galois_keys(keys); compare_kswitchkeys(keys, test_keys, secret_key, context); } { EncryptionParameters parms(scheme); parms.set_poly_modulus_degree(256); parms.set_plain_modulus(65537); parms.set_coeff_modulus(CoeffModulus::Create(256, { 60, 50 })); prng_seed_type seed; for (auto &i : seed) { i = random_uint64(); } auto rng = make_shared(Blake2xbPRNGFactory(seed)); parms.set_random_generator(rng); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); SecretKey secret_key = keygen.secret_key(); keygen.create_galois_keys().save(stream); GaloisKeys test_keys; test_keys.load(context, stream); GaloisKeys keys; keygen.create_galois_keys(keys); compare_kswitchkeys(keys, test_keys, secret_key, context); } }; galoiskey_seeded_save_load(scheme_type::bfv); galoiskey_seeded_save_load(scheme_type::bgv); } } // namespace sealtest