/** * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include #include #include #include #include #include using namespace faiss; typedef Index::idx_t idx_t; // dimension of the vectors to index int d = 32; // nb of training vectors size_t nt = 5000; // size of the database points per window step size_t nb = 1000; // nb of queries size_t nq = 200; int total_size = 40; int window_size = 10; std::vector make_data(size_t n) { std::vector database(n * d); std::mt19937 rng; std::uniform_real_distribution<> distrib; for (size_t i = 0; i < n * d; i++) { database[i] = distrib(rng); } return database; } std::unique_ptr make_trained_index(const char* index_type) { auto index = std::unique_ptr(index_factory(d, index_type)); auto xt = make_data(nt * d); index->train(nt, xt.data()); ParameterSpace().set_index_parameter(index.get(), "nprobe", 4); return index; } std::vector search_index(Index* index, const float* xq) { int k = 10; std::vector I(k * nq); std::vector D(k * nq); index->search(nq, xq, k, D.data(), I.data()); return I; } /************************************************************* * Test functions for a given index type *************************************************************/ // make a few slices of indexes that can be merged void make_index_slices( const Index* trained_index, std::vector>& sub_indexes) { for (int i = 0; i < total_size; i++) { sub_indexes.emplace_back(clone_index(trained_index)); printf("preparing sub-index # %d\n", i); Index* index = sub_indexes.back().get(); auto xb = make_data(nb * d); std::vector ids(nb); std::mt19937 rng; std::uniform_int_distribution<> distrib; for (int j = 0; j < nb; j++) { ids[j] = distrib(rng); } index->add_with_ids(nb, xb.data(), ids.data()); } } // build merged index explicitly at sliding window position i Index* make_merged_index( const Index* trained_index, const std::vector>& sub_indexes, int i) { Index* merged_index = clone_index(trained_index); for (int j = i - window_size + 1; j <= i; j++) { if (j < 0 || j >= total_size) continue; std::unique_ptr sub_index(clone_index(sub_indexes[j].get())); IndexIVF* ivf0 = ivflib::extract_index_ivf(merged_index); IndexIVF* ivf1 = ivflib::extract_index_ivf(sub_index.get()); ivf0->merge_from(*ivf1, 0); merged_index->ntotal = ivf0->ntotal; } return merged_index; } int test_sliding_window(const char* index_key) { std::unique_ptr trained_index = make_trained_index(index_key); // make the index slices std::vector> sub_indexes; make_index_slices(trained_index.get(), sub_indexes); // now slide over the windows std::unique_ptr index(clone_index(trained_index.get())); ivflib::SlidingIndexWindow window(index.get()); auto xq = make_data(nq * d); for (int i = 0; i < total_size + window_size; i++) { printf("doing step %d / %d\n", i, total_size + window_size); // update the index window.step( i < total_size ? sub_indexes[i].get() : nullptr, i >= window_size); printf(" current n_slice = %d\n", window.n_slice); auto new_res = search_index(index.get(), xq.data()); std::unique_ptr merged_index( make_merged_index(trained_index.get(), sub_indexes, i)); auto ref_res = search_index(merged_index.get(), xq.data()); EXPECT_EQ(ref_res.size(), new_res.size()); EXPECT_EQ(ref_res, new_res); } return 0; } int test_sliding_invlists(const char* index_key) { std::unique_ptr trained_index = make_trained_index(index_key); // make the index slices std::vector> sub_indexes; make_index_slices(trained_index.get(), sub_indexes); // now slide over the windows std::unique_ptr index(clone_index(trained_index.get())); IndexIVF* index_ivf = ivflib::extract_index_ivf(index.get()); auto xq = make_data(nq * d); for (int i = 0; i < total_size + window_size; i++) { printf("doing step %d / %d\n", i, total_size + window_size); // update the index std::vector ils; for (int j = i - window_size + 1; j <= i; j++) { if (j < 0 || j >= total_size) continue; ils.push_back( ivflib::extract_index_ivf(sub_indexes[j].get())->invlists); } if (ils.size() == 0) continue; ConcatenatedInvertedLists* ci = new ConcatenatedInvertedLists(ils.size(), ils.data()); // will be deleted by the index index_ivf->replace_invlists(ci, true); printf(" nb invlists = %zd\n", ils.size()); auto new_res = search_index(index.get(), xq.data()); std::unique_ptr merged_index( make_merged_index(trained_index.get(), sub_indexes, i)); auto ref_res = search_index(merged_index.get(), xq.data()); EXPECT_EQ(ref_res.size(), new_res.size()); size_t ndiff = 0; for (size_t j = 0; j < ref_res.size(); j++) { if (ref_res[j] != new_res[j]) ndiff++; } printf(" nb differences: %zd / %zd\n", ndiff, ref_res.size()); EXPECT_EQ(ref_res, new_res); } return 0; } /************************************************************* * Test entry points *************************************************************/ TEST(SlidingWindow, IVFFlat) { test_sliding_window("IVF32,Flat"); } TEST(SlidingWindow, PCAIVFFlat) { test_sliding_window("PCA24,IVF32,Flat"); } TEST(SlidingInvlists, IVFFlat) { test_sliding_invlists("IVF32,Flat"); } TEST(SlidingInvlists, PCAIVFFlat) { test_sliding_invlists("PCA24,IVF32,Flat"); }