/** * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include #include #include #include #include #include using namespace faiss; namespace { typedef Index::idx_t idx_t; // dimension of the vectors to index int d = 32; // nb of training vectors size_t nt = 5000; // size of the database points per window step size_t nb = 1000; // nb of queries size_t nq = 200; std::mt19937 rng; std::vector make_data(size_t n) { std::vector database(n * d); std::uniform_real_distribution<> distrib; for (size_t i = 0; i < n * d; i++) { database[i] = distrib(rng); } return database; } std::unique_ptr make_trained_index(const char* index_type) { auto index = std::unique_ptr(index_factory(d, index_type)); auto xt = make_data(nt * d); index->train(nt, xt.data()); ParameterSpace().set_index_parameter(index.get(), "nprobe", 4); return index; } std::vector search_index(Index* index, const float* xq) { int k = 10; std::vector I(k * nq); std::vector D(k * nq); index->search(nq, xq, k, D.data(), I.data()); return I; } /************************************************************* * Test functions for a given index type *************************************************************/ struct EncapsulateInvertedLists : InvertedLists { const InvertedLists* il; EncapsulateInvertedLists(const InvertedLists* il) : InvertedLists(il->nlist, il->code_size), il(il) {} static void* memdup(const void* m, size_t size) { if (size == 0) return nullptr; return memcpy(malloc(size), m, size); } size_t list_size(size_t list_no) const override { return il->list_size(list_no); } const uint8_t* get_codes(size_t list_no) const override { return (uint8_t*)memdup( il->get_codes(list_no), list_size(list_no) * code_size); } const idx_t* get_ids(size_t list_no) const override { return (idx_t*)memdup( il->get_ids(list_no), list_size(list_no) * sizeof(idx_t)); } void release_codes(size_t, const uint8_t* codes) const override { free((void*)codes); } void release_ids(size_t, const idx_t* ids) const override { free((void*)ids); } const uint8_t* get_single_code(size_t list_no, size_t offset) const override { return (uint8_t*)memdup( il->get_single_code(list_no, offset), code_size); } size_t add_entries(size_t, size_t, const idx_t*, const uint8_t*) override { assert(!"not implemented"); return 0; } void update_entries(size_t, size_t, size_t, const idx_t*, const uint8_t*) override { assert(!"not implemented"); } void resize(size_t, size_t) override { assert(!"not implemented"); } ~EncapsulateInvertedLists() override {} }; int test_dealloc_invlists(const char* index_key) { std::unique_ptr index = make_trained_index(index_key); IndexIVF* index_ivf = ivflib::extract_index_ivf(index.get()); auto xb = make_data(nb * d); index->add(nb, xb.data()); auto xq = make_data(nq * d); auto ref_res = search_index(index.get(), xq.data()); EncapsulateInvertedLists eil(index_ivf->invlists); index_ivf->own_invlists = false; index_ivf->replace_invlists(&eil, false); // TEST: this could crash or leak mem auto new_res = search_index(index.get(), xq.data()); // delete explicitly delete eil.il; // just to make sure EXPECT_EQ(ref_res, new_res); return 0; } } // anonymous namespace /************************************************************* * Test entry points *************************************************************/ TEST(TestIvlistDealloc, IVFFlat) { test_dealloc_invlists("IVF32,Flat"); } TEST(TestIvlistDealloc, IVFSQ) { test_dealloc_invlists("IVF32,SQ8"); } TEST(TestIvlistDealloc, IVFPQ) { test_dealloc_invlists("IVF32,PQ4np"); }