// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "common_includes.h"

#include <boost/program_options.hpp>

#include "index.h"
#include "disk_utils.h"
#include "math_utils.h"
#include "memory_mapper.h"
#include "partition.h"
#include "pq_flash_index.h"
#include "timer.h"
#include "percentile_stats.h"
#include "program_options_utils.hpp"

#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include "linux_aligned_file_reader.h"
#else
#ifdef USE_BING_INFRA
#include "bing_aligned_file_reader.h"
#else
#include "windows_aligned_file_reader.h"
#endif
#endif

#define WARMUP false

namespace po = boost::program_options;

void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
{
    diskann::cout << std::setw(20) << category << ": " << std::flush;
    for (uint32_t s = 0; s < percentiles.size(); s++)
    {
        diskann::cout << std::setw(8) << percentiles[s] << "%";
    }
    diskann::cout << std::endl;
    diskann::cout << std::setw(22) << " " << std::flush;
    for (uint32_t s = 0; s < percentiles.size(); s++)
    {
        diskann::cout << std::setw(9) << results[s];
    }
    diskann::cout << std::endl;
}

template <typename T, typename LabelT = uint32_t>
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix,
                      const std::string &result_output_prefix, const std::string &query_file, std::string &gt_file,
                      const uint32_t num_threads, const uint32_t recall_at, const uint32_t beamwidth,
                      const uint32_t num_nodes_to_cache, const uint32_t search_io_limit,
                      const std::vector<uint32_t> &Lvec, const float fail_if_recall_below,
                      const std::vector<std::string> &query_filters, const bool use_reorder_data = false)
{
    diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
    if (beamwidth <= 0)
        diskann::cout << "beamwidth to be optimized for each L value" << std::flush;
    else
        diskann::cout << " beamwidth: " << beamwidth << std::flush;
    if (search_io_limit == std::numeric_limits<uint32_t>::max())
        diskann::cout << "." << std::endl;
    else
        diskann::cout << ", io_limit: " << search_io_limit << "." << std::endl;

    std::string warmup_query_file = index_path_prefix + "_sample_data.bin";

    // load query bin
    T *query = nullptr;
    uint32_t *gt_ids = nullptr;
    float *gt_dists = nullptr;
    size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim;
    diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);

    bool filtered_search = false;
    if (!query_filters.empty())
    {
        filtered_search = true;
        if (query_filters.size() != 1 && query_filters.size() != query_num)
        {
            std::cout << "Error. Mismatch in number of queries and size of query filters file" << std::endl;
            return -1; // TODO: return -1 or use some other error handling?
        }
    }
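    // Ground truth is optional: recall is computed only when --gt_file names an existing truthset.
    // A mismatch between the number of queries and ground-truth entries is reported but not treated as fatal.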
    bool calc_recall_flag = false;
    if (gt_file != std::string("null") && gt_file != std::string("NULL") && file_exists(gt_file))
    {
        diskann::load_truthset(gt_file, gt_ids, gt_dists, gt_num, gt_dim);
        if (gt_num != query_num)
        {
            diskann::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
        }
        calc_recall_flag = true;
    }

    std::shared_ptr<AlignedFileReader> reader = nullptr;
#ifdef _WINDOWS
#ifndef USE_BING_INFRA
    reader.reset(new WindowsAlignedFileReader());
#else
    reader.reset(new diskann::BingAlignedFileReader());
#endif
#else
    reader.reset(new LinuxAlignedFileReader());
#endif

    std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
        new diskann::PQFlashIndex<T, LabelT>(reader, metric));

    int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());

    if (res != 0)
    {
        return res;
    }

    // cache bfs levels
    std::vector<uint32_t> node_list;
    diskann::cout << "Caching " << num_nodes_to_cache << " BFS nodes around medoid(s)" << std::endl;
    //_pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list);
    if (num_nodes_to_cache > 0)
        _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6, num_nodes_to_cache,
                                                              num_threads, node_list);
    _pFlashIndex->load_cache_list(node_list);
    node_list.clear();
    node_list.shrink_to_fit();

    omp_set_num_threads(num_threads);

    uint64_t warmup_L = 20;
    uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0;
    T *warmup = nullptr;

    if (WARMUP)
    {
        if (file_exists(warmup_query_file))
        {
            diskann::load_aligned_bin<T>(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim);
        }
        else
        {
            warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads);
            warmup_dim = query_dim;
            warmup_aligned_dim = query_aligned_dim;
            diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T));
            std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
            std::random_device rd;
            std::mt19937 gen(rd());
            std::uniform_int_distribution<> dis(-128, 127);
            for (uint32_t i = 0; i < warmup_num; i++)
            {
                for (uint32_t d = 0; d < warmup_dim; d++)
                {
                    warmup[i * warmup_aligned_dim + d] = (T)dis(gen);
                }
            }
        }
        diskann::cout << "Warming up index... " << std::flush;
        std::vector<uint64_t> warmup_result_ids_64(warmup_num, 0);
        std::vector<float> warmup_result_dists(warmup_num, 0);

#pragma omp parallel for schedule(dynamic, 1)
        for (int64_t i = 0; i < (int64_t)warmup_num; i++)
        {
            _pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L,
                                             warmup_result_ids_64.data() + (i * 1),
                                             warmup_result_dists.data() + (i * 1), 4);
        }
        diskann::cout << "..done" << std::endl;
    }

    diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
    diskann::cout.precision(2);

    std::string recall_string = "Recall@" + std::to_string(recall_at);
    diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16)
                  << "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16)
                  << "CPU (s)";
    if (calc_recall_flag)
    {
        diskann::cout << std::setw(16) << recall_string << std::endl;
    }
    else
        diskann::cout << std::endl;
    diskann::cout << "==============================================================="
                     "======================================================="
                  << std::endl;

    std::vector<std::vector<uint32_t>> query_result_ids(Lvec.size());
    std::vector<std::vector<float>> query_result_dists(Lvec.size());

    uint32_t optimized_beamwidth = 2;

    double best_recall = 0.0;
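    // For each candidate list size L: optionally tune the beamwidth, run every query in parallel
    // via OpenMP, convert the 64-bit result ids to 32 bits, and report QPS, mean and 99.9th
    // percentile latency, mean IOs, mean CPU time, and (when ground truth is available) recall@K.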
    for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
    {
        uint32_t L = Lvec[test_id];

        if (L < recall_at)
        {
            diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
            continue;
        }

        if (beamwidth <= 0)
        {
            diskann::cout << "Tuning beamwidth.." << std::endl;
            optimized_beamwidth =
                optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth);
        }
        else
            optimized_beamwidth = beamwidth;

        query_result_ids[test_id].resize(recall_at * query_num);
        query_result_dists[test_id].resize(recall_at * query_num);

        auto stats = new diskann::QueryStats[query_num];

        std::vector<uint64_t> query_result_ids_64(recall_at * query_num);
        auto s = std::chrono::high_resolution_clock::now();
#pragma omp parallel for schedule(dynamic, 1)
        for (int64_t i = 0; i < (int64_t)query_num; i++)
        {
            if (!filtered_search)
            {
                _pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L,
                                                 query_result_ids_64.data() + (i * recall_at),
                                                 query_result_dists[test_id].data() + (i * recall_at),
                                                 optimized_beamwidth, use_reorder_data, stats + i);
            }
            else
            {
                LabelT label_for_search;
                if (query_filters.size() == 1)
                { // one label for all queries
                    label_for_search = _pFlashIndex->get_converted_label(query_filters[0]);
                }
                else
                { // one label for each query
                    label_for_search = _pFlashIndex->get_converted_label(query_filters[i]);
                }
                _pFlashIndex->cached_beam_search(
                    query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at),
                    query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search,
                    use_reorder_data, stats + i);
            }
        }
        auto e = std::chrono::high_resolution_clock::now();
        std::chrono::duration<double> diff = e - s;
        double qps = (1.0 * query_num) / (1.0 * diff.count());

        diskann::convert_types<uint64_t, uint32_t>(query_result_ids_64.data(), query_result_ids[test_id].data(),
                                                   query_num, recall_at);

        auto mean_latency = diskann::get_mean_stats<float>(
            stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; });

        auto latency_999 = diskann::get_percentile_stats<float>(
            stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; });

        auto mean_ios = diskann::get_mean_stats<uint32_t>(
            stats, query_num, [](const diskann::QueryStats &stats) { return stats.n_ios; });

        auto mean_cpuus = diskann::get_mean_stats<float>(
            stats, query_num, [](const diskann::QueryStats &stats) { return stats.cpu_us; });

        double recall = 0;
        if (calc_recall_flag)
        {
            recall = diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
                                               query_result_ids[test_id].data(), recall_at, recall_at);
            best_recall = std::max(recall, best_recall);
        }

        diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
                      << std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
                      << std::setw(16) << mean_cpuus;
        if (calc_recall_flag)
        {
            diskann::cout << std::setw(16) << recall << std::endl;
        }
        else
            diskann::cout << std::endl;
        delete[] stats;
    }

    diskann::cout << "Done searching. Now saving results " << std::endl;
    uint64_t test_id = 0;
    for (auto L : Lvec)
    {
        if (L < recall_at)
            continue;

        std::string cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_idx_uint32.bin";
        diskann::save_bin<uint32_t>(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at);

        cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_dists_float.bin";
        diskann::save_bin<float>(cur_result_path, query_result_dists[test_id++].data(), query_num, recall_at);
    }

    diskann::aligned_free(query);
    if (warmup != nullptr)
        diskann::aligned_free(warmup);
    return best_recall >= fail_if_recall_below ? 0 : -1;
}
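// Illustrative invocation of the resulting binary (all paths and parameter values below are
// hypothetical; the flags are the ones registered in main(), with -K/-L/-W/-T as the short forms
// of recall_at/search_list/beamwidth/num_threads):
//
//   ./search_disk_index --data_type float --dist_fn l2 \
//       --index_path_prefix my_index/disk_index --query_file queries.fbin --gt_file gt.bin \
//       --result_path results/res -K 10 -L 10 20 40 80 -W 4 -T 16 --num_nodes_to_cache 100000
//
// For each tested L, result ids and distances are written to <result_path>_<L>_idx_uint32.bin and
// <result_path>_<L>_dists_float.bin, and the process exits non-zero if the best observed recall
// falls below --fail_if_recall_below.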
int main(int argc, char **argv)
{
    std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file, filter_label,
        label_type, query_filters_file;
    uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit;
    std::vector<uint32_t> Lvec;
    bool use_reorder_data = false;
    float fail_if_recall_below = 0.0f;

    po::options_description desc{
        program_options_utils::make_program_description("search_disk_index", "Searches on-disk DiskANN indexes")};
    try
    {
        desc.add_options()("help,h", "Print information on arguments");

        // Required parameters
        po::options_description required_configs("Required");
        required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
                                       program_options_utils::DATA_TYPE_DESCRIPTION);
        required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
                                       program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
        required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
                                       program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
        required_configs.add_options()("result_path", po::value<std::string>(&result_path_prefix)->required(),
                                       program_options_utils::RESULT_PATH_DESCRIPTION);
        required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
                                       program_options_utils::QUERY_FILE_DESCRIPTION);
        required_configs.add_options()("recall_at,K", po::value<uint32_t>(&K)->required(),
                                       program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION);
        required_configs.add_options()("search_list,L",
                                       po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
                                       program_options_utils::SEARCH_LIST_DESCRIPTION);

        // Optional parameters
        po::options_description optional_configs("Optional");
        optional_configs.add_options()("gt_file",
                                       po::value<std::string>(&gt_file)->default_value(std::string("null")),
                                       program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
        optional_configs.add_options()("beamwidth,W", po::value<uint32_t>(&W)->default_value(2),
                                       program_options_utils::BEAMWIDTH);
        optional_configs.add_options()("num_nodes_to_cache",
                                       po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
                                       program_options_utils::NUMBER_OF_NODES_TO_CACHE);
        optional_configs.add_options()(
            "search_io_limit",
            po::value<uint32_t>(&search_io_limit)->default_value(std::numeric_limits<uint32_t>::max()),
            "Max #IOs for search. Default value: uint32::max()");
        optional_configs.add_options()("num_threads,T",
                                       po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
                                       program_options_utils::NUMBER_THREADS_DESCRIPTION);
        optional_configs.add_options()("use_reorder_data", po::bool_switch()->default_value(false),
                                       "Include full precision data in the index. Use only in "
                                       "conjunction with compressed data on SSD. Default value: false");
Default value: false"); optional_configs.add_options()("filter_label", po::value(&filter_label)->default_value(std::string("")), program_options_utils::FILTER_LABEL_DESCRIPTION); optional_configs.add_options()("query_filters_file", po::value(&query_filters_file)->default_value(std::string("")), program_options_utils::FILTERS_FILE_DESCRIPTION); optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), program_options_utils::LABEL_TYPE_DESCRIPTION); optional_configs.add_options()("fail_if_recall_below", po::value(&fail_if_recall_below)->default_value(0.0f), program_options_utils::FAIL_IF_RECALL_BELOW); // Merge required and optional parameters desc.add(required_configs).add(optional_configs); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); if (vm.count("help")) { std::cout << desc; return 0; } po::notify(vm); if (vm["use_reorder_data"].as()) use_reorder_data = true; } catch (const std::exception &ex) { std::cerr << ex.what() << '\n'; return -1; } diskann::Metric metric; if (dist_fn == std::string("mips")) { metric = diskann::Metric::INNER_PRODUCT; } else if (dist_fn == std::string("l2")) { metric = diskann::Metric::L2; } else if (dist_fn == std::string("cosine")) { metric = diskann::Metric::COSINE; } else { std::cout << "Unsupported distance function. Currently only L2/ Inner " "Product/Cosine are supported." << std::endl; return -1; } if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT)) { std::cout << "Currently support only floating point data for Inner Product." << std::endl; return -1; } if (use_reorder_data && data_type != std::string("float")) { std::cout << "Error: Reorder data for reordering currently only " "supported for float data type." << std::endl; return -1; } if (filter_label != "" && query_filters_file != "") { std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl; return -1; } std::vector query_filters; if (filter_label != "") { query_filters.push_back(filter_label); } else if (query_filters_file != "") { query_filters = read_file_to_vector_of_strings(query_filters_file); } try { if (!query_filters.empty() && label_type == "ushort") { if (data_type == std::string("float")) return search_disk_index( metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); else if (data_type == std::string("int8")) return search_disk_index( metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); else if (data_type == std::string("uint8")) return search_disk_index( metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); else { std::cerr << "Unsupported data type. 
        else
        {
            if (data_type == std::string("float"))
                return search_disk_index<float>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
                                                num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
                                                fail_if_recall_below, query_filters, use_reorder_data);
            else if (data_type == std::string("int8"))
                return search_disk_index<int8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
                                                 num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
                                                 fail_if_recall_below, query_filters, use_reorder_data);
            else if (data_type == std::string("uint8"))
                return search_disk_index<uint8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
                                                  num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
                                                  fail_if_recall_below, query_filters, use_reorder_data);
            else
            {
                std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
                return -1;
            }
        }
    }
    catch (const std::exception &e)
    {
        std::cout << std::string(e.what()) << std::endl;
        diskann::cerr << "Index search failed." << std::endl;
        return -1;
    }
}