/** * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include #include #include #include #include #include using namespace faiss; namespace { typedef Index::idx_t idx_t; // dimension of the vectors to index int d = 32; // size of the database we plan to index size_t nb = 1000; // nb of queries size_t nq = 200; std::mt19937 rng; std::vector make_data(size_t n) { std::vector database(n * d); std::uniform_real_distribution<> distrib; for (size_t i = 0; i < n * d; i++) { database[i] = distrib(rng); } return database; } std::unique_ptr make_index( const char* index_type, MetricType metric, const std::vector& x) { std::unique_ptr index(index_factory(d, index_type, metric)); index->train(nb, x.data()); index->add(nb, x.data()); return index; } std::vector search_index(Index* index, const float* xq) { int k = 10; std::vector I(k * nq); std::vector D(k * nq); index->search(nq, xq, k, D.data(), I.data()); return I; } std::vector search_index_with_params( Index* index, const float* xq, IVFSearchParameters* params) { int k = 10; std::vector I(k * nq); std::vector D(k * nq); ivflib::search_with_parameters( index, nq, xq, k, D.data(), I.data(), params); return I; } /************************************************************* * Test functions for a given index type *************************************************************/ int test_params_override(const char* index_key, MetricType metric) { std::vector xb = make_data(nb); // database vectors auto index = make_index(index_key, metric, xb); // index->train(nb, xb.data()); // index->add(nb, xb.data()); std::vector xq = make_data(nq); ParameterSpace ps; ps.set_index_parameter(index.get(), "nprobe", 2); auto res2ref = search_index(index.get(), xq.data()); ps.set_index_parameter(index.get(), "nprobe", 9); auto res9ref = search_index(index.get(), xq.data()); ps.set_index_parameter(index.get(), "nprobe", 1); IVFSearchParameters params; params.max_codes = 0; params.nprobe = 2; auto res2new = search_index_with_params(index.get(), xq.data(), ¶ms); params.nprobe = 9; auto res9new = search_index_with_params(index.get(), xq.data(), ¶ms); if (res2ref != res2new) return 2; if (res9ref != res9new) return 9; return 0; } } // namespace /************************************************************* * Test entry points *************************************************************/ TEST(TPO, IVFFlat) { int err1 = test_params_override("IVF32,Flat", METRIC_L2); EXPECT_EQ(err1, 0); int err2 = test_params_override("IVF32,Flat", METRIC_INNER_PRODUCT); EXPECT_EQ(err2, 0); } TEST(TPO, IVFPQ) { int err1 = test_params_override("IVF32,PQ8np", METRIC_L2); EXPECT_EQ(err1, 0); int err2 = test_params_override("IVF32,PQ8np", METRIC_INNER_PRODUCT); EXPECT_EQ(err2, 0); } TEST(TPO, IVFSQ) { int err1 = test_params_override("IVF32,SQ8", METRIC_L2); EXPECT_EQ(err1, 0); int err2 = test_params_override("IVF32,SQ8", METRIC_INNER_PRODUCT); EXPECT_EQ(err2, 0); } TEST(TPO, IVFFlatPP) { int err1 = test_params_override("PCA16,IVF32,SQ8", METRIC_L2); EXPECT_EQ(err1, 0); int err2 = test_params_override("PCA16,IVF32,SQ8", METRIC_INNER_PRODUCT); EXPECT_EQ(err2, 0); } /************************************************************* * Same for binary indexes *************************************************************/ std::vector make_data_binary(size_t n) { std::vector database(n * d / 8); std::uniform_int_distribution<> distrib; for (size_t i = 0; i < n * d / 8; i++) { database[i] = distrib(rng); } return database; } std::unique_ptr make_index( const char* index_type, const std::vector& x) { auto index = std::unique_ptr( dynamic_cast(index_binary_factory(d, index_type))); index->train(nb, x.data()); index->add(nb, x.data()); return index; } std::vector search_index(IndexBinaryIVF* index, const uint8_t* xq) { int k = 10; std::vector I(k * nq); std::vector D(k * nq); index->search(nq, xq, k, D.data(), I.data()); return I; } std::vector search_index_with_params( IndexBinaryIVF* index, const uint8_t* xq, IVFSearchParameters* params) { int k = 10; std::vector I(k * nq); std::vector D(k * nq); std::vector Iq(params->nprobe * nq); std::vector Dq(params->nprobe * nq); index->quantizer->search(nq, xq, params->nprobe, Dq.data(), Iq.data()); index->search_preassigned( nq, xq, k, Iq.data(), Dq.data(), D.data(), I.data(), false, params); return I; } int test_params_override_binary(const char* index_key) { std::vector xb = make_data_binary(nb); // database vectors auto index = make_index(index_key, xb); index->train(nb, xb.data()); index->add(nb, xb.data()); std::vector xq = make_data_binary(nq); index->nprobe = 2; auto res2ref = search_index(index.get(), xq.data()); index->nprobe = 9; auto res9ref = search_index(index.get(), xq.data()); index->nprobe = 1; IVFSearchParameters params; params.max_codes = 0; params.nprobe = 2; auto res2new = search_index_with_params(index.get(), xq.data(), ¶ms); params.nprobe = 9; auto res9new = search_index_with_params(index.get(), xq.data(), ¶ms); if (res2ref != res2new) return 2; if (res9ref != res9new) return 9; return 0; } TEST(TPOB, IVF) { int err1 = test_params_override_binary("BIVF32"); EXPECT_EQ(err1, 0); }