// Copyright 2024-present the vsag project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "vsag/vsag.h" TEST_CASE("index params", "[ft][factory]") { int dim = 16; int max_elements = 1000; int max_degree = 16; int ef_construction = 100; int ef_search = 100; nlohmann::json hnsw_parameters{ {"max_degree", max_degree}, {"ef_construction", ef_construction}, {"ef_search", ef_search}, }; nlohmann::json index_parameters{ {"dtype", "float32"}, {"metric_type", "l2"}, {"dim", dim}, {"hnsw", hnsw_parameters}}; std::shared_ptr hnsw; auto index = vsag::Factory::CreateIndex("hnsw", index_parameters.dump()); REQUIRE(index.has_value()); hnsw = index.value(); // Generate random data std::mt19937 rng; rng.seed(47); std::uniform_real_distribution<> distrib_real; int64_t* ids = new int64_t[max_elements]; float* data = new float[dim * max_elements]; for (int i = 0; i < max_elements; i++) { ids[i] = i; } for (int i = 0; i < dim * max_elements; i++) { data[i] = distrib_real(rng); } auto dataset = vsag::Dataset::Make(); dataset->Dim(dim)->NumElements(max_elements)->Ids(ids)->Float32Vectors(data); hnsw->Build(dataset); // Query the elements for themselves and measure recall 1@1 float correct = 0; for (int i = 0; i < max_elements; i++) { auto query = vsag::Dataset::Make(); query->NumElements(1)->Dim(dim)->Float32Vectors(data + i * dim)->Owner(false); nlohmann::json parameters{ {"hnsw", {{"ef_search", ef_search}}}, }; int64_t k = 10; if (auto result = hnsw->KnnSearch(query, k, parameters.dump()); result.has_value()) { if (result.value()->GetIds()[0] == i) { correct++; } } else if (result.error().type == vsag::ErrorType::INTERNAL_ERROR) { std::cerr << "failed to search on index: internalError" << std::endl; } } float recall = correct / max_elements; REQUIRE(recall == 1); }