// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include #include #include #include #include #include #include #include "filter_utils.h" #include #ifndef _WINDOWS #include #endif #include "index.h" #include "memory_mapper.h" #include "parameters.h" #include "utils.h" #include "program_options_utils.hpp" namespace po = boost::program_options; typedef std::tuple>, uint64_t> stitch_indices_return_values; /* * Inline function to display progress bar. */ inline void print_progress(double percentage) { int val = (int)(percentage * 100); int lpad = (int)(percentage * PBWIDTH); int rpad = PBWIDTH - lpad; printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, ""); fflush(stdout); } /* * Inline function to generate a random integer in a range. */ inline size_t random(size_t range_from, size_t range_to) { std::random_device rand_dev; std::mt19937 generator(rand_dev()); std::uniform_int_distribution distr(range_from, range_to); return distr(generator); } /* * function to handle command line parsing. * * Arguments are merely the inputs from the command line. */ void handle_args(int argc, char **argv, std::string &data_type, path &input_data_path, path &final_index_path_prefix, path &label_data_path, std::string &universal_label, uint32_t &num_threads, uint32_t &R, uint32_t &L, uint32_t &stitched_R, float &alpha) { po::options_description desc{ program_options_utils::make_program_description("build_stitched_index", "Build a stitched DiskANN index.")}; try { desc.add_options()("help,h", "Print information on arguments"); // Required parameters po::options_description required_configs("Required"); required_configs.add_options()("data_type", po::value(&data_type)->required(), program_options_utils::DATA_TYPE_DESCRIPTION); required_configs.add_options()("index_path_prefix", po::value(&final_index_path_prefix)->required(), program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); required_configs.add_options()("data_path", po::value(&input_data_path)->required(), program_options_utils::INPUT_DATA_PATH); // Optional parameters po::options_description optional_configs("Optional"); optional_configs.add_options()("num_threads,T", po::value(&num_threads)->default_value(omp_get_num_procs()), program_options_utils::NUMBER_THREADS_DESCRIPTION); optional_configs.add_options()("max_degree,R", po::value(&R)->default_value(64), program_options_utils::MAX_BUILD_DEGREE); optional_configs.add_options()("Lbuild,L", po::value(&L)->default_value(100), program_options_utils::GRAPH_BUILD_COMPLEXITY); optional_configs.add_options()("alpha", po::value(&alpha)->default_value(1.2f), program_options_utils::GRAPH_BUILD_ALPHA); optional_configs.add_options()("label_file", po::value(&label_data_path)->default_value(""), program_options_utils::LABEL_FILE); optional_configs.add_options()("universal_label", po::value(&universal_label)->default_value(""), program_options_utils::UNIVERSAL_LABEL); optional_configs.add_options()("stitched_R", po::value(&stitched_R)->default_value(100), "Degree to prune final graph down to"); // Merge required and optional parameters desc.add(required_configs).add(optional_configs); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); if (vm.count("help")) { std::cout << desc; exit(0); } po::notify(vm); } catch (const std::exception &ex) { std::cerr << ex.what() << '\n'; throw; } } /* * Custom index save to write the in-memory index to disk. * Also writes required files for diskANN API - * 1. labels_to_medoids * 2. universal_label * 3. data (redundant for static indices) * 4. labels (redundant for static indices) */ void save_full_index(path final_index_path_prefix, path input_data_path, uint64_t final_index_size, std::vector> stitched_graph, tsl::robin_map entry_points, std::string universal_label, path label_data_path) { // aux. file 1 auto saving_index_timer = std::chrono::high_resolution_clock::now(); std::ifstream original_label_data_stream; original_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit); original_label_data_stream.open(label_data_path, std::ios::binary); std::ofstream new_label_data_stream; new_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit); new_label_data_stream.open(final_index_path_prefix + "_labels.txt", std::ios::binary); new_label_data_stream << original_label_data_stream.rdbuf(); original_label_data_stream.close(); new_label_data_stream.close(); // aux. file 2 std::ifstream original_input_data_stream; original_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit); original_input_data_stream.open(input_data_path, std::ios::binary); std::ofstream new_input_data_stream; new_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit); new_input_data_stream.open(final_index_path_prefix + ".data", std::ios::binary); new_input_data_stream << original_input_data_stream.rdbuf(); original_input_data_stream.close(); new_input_data_stream.close(); // aux. file 3 std::ofstream labels_to_medoids_writer; labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit); labels_to_medoids_writer.open(final_index_path_prefix + "_labels_to_medoids.txt"); for (auto iter : entry_points) labels_to_medoids_writer << iter.first << ", " << iter.second << std::endl; labels_to_medoids_writer.close(); // aux. file 4 (only if we're using a universal label) if (universal_label != "") { std::ofstream universal_label_writer; universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit); universal_label_writer.open(final_index_path_prefix + "_universal_label.txt"); universal_label_writer << universal_label << std::endl; universal_label_writer.close(); } // main index uint64_t index_num_frozen_points = 0, index_num_edges = 0; uint32_t index_max_observed_degree = 0, index_entry_point = 0; const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t); for (auto &point_neighbors : stitched_graph) { index_max_observed_degree = std::max(index_max_observed_degree, (uint32_t)point_neighbors.size()); } std::ofstream stitched_graph_writer; stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit); stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary); stitched_graph_writer.write((char *)&final_index_size, sizeof(uint64_t)); stitched_graph_writer.write((char *)&index_max_observed_degree, sizeof(uint32_t)); stitched_graph_writer.write((char *)&index_entry_point, sizeof(uint32_t)); stitched_graph_writer.write((char *)&index_num_frozen_points, sizeof(uint64_t)); size_t bytes_written = METADATA; for (uint32_t node_point = 0; node_point < stitched_graph.size(); node_point++) { uint32_t current_node_num_neighbors = (uint32_t)stitched_graph[node_point].size(); std::vector current_node_neighbors = stitched_graph[node_point]; stitched_graph_writer.write((char *)¤t_node_num_neighbors, sizeof(uint32_t)); bytes_written += sizeof(uint32_t); for (const auto ¤t_node_neighbor : current_node_neighbors) { stitched_graph_writer.write((char *)¤t_node_neighbor, sizeof(uint32_t)); bytes_written += sizeof(uint32_t); } index_num_edges += current_node_num_neighbors; } if (bytes_written != final_index_size) { std::cerr << "Error: written bytes does not match allocated space" << std::endl; throw; } stitched_graph_writer.close(); std::chrono::duration saving_index_time = std::chrono::high_resolution_clock::now() - saving_index_timer; std::cout << "Stitched graph written in " << saving_index_time.count() << " seconds" << std::endl; std::cout << "Stitched graph average degree: " << ((float)index_num_edges) / ((float)(stitched_graph.size())) << std::endl; std::cout << "Stitched graph max degree: " << index_max_observed_degree << std::endl << std::endl; } /* * Unions the per-label graph indices together via the following policy: * - any two nodes can only have at most one edge between them - * * Returns the "stitched" graph and its expected file size. */ template stitch_indices_return_values stitch_label_indices( path final_index_path_prefix, uint32_t total_number_of_points, label_set all_labels, tsl::robin_map labels_to_number_of_points, tsl::robin_map &label_entry_points, tsl::robin_map> label_id_to_orig_id_map) { size_t final_index_size = 0; std::vector> stitched_graph(total_number_of_points); auto stitching_index_timer = std::chrono::high_resolution_clock::now(); for (const auto &lbl : all_labels) { path curr_label_index_path(final_index_path_prefix + "_" + lbl); std::vector> curr_label_index; uint64_t curr_label_index_size; uint32_t curr_label_entry_point; std::tie(curr_label_index, curr_label_index_size) = diskann::load_label_index(curr_label_index_path, labels_to_number_of_points[lbl]); curr_label_entry_point = (uint32_t)random(0, curr_label_index.size()); label_entry_points[lbl] = label_id_to_orig_id_map[lbl][curr_label_entry_point]; for (uint32_t node_point = 0; node_point < curr_label_index.size(); node_point++) { uint32_t original_point_id = label_id_to_orig_id_map[lbl][node_point]; for (auto &node_neighbor : curr_label_index[node_point]) { uint32_t original_neighbor_id = label_id_to_orig_id_map[lbl][node_neighbor]; std::vector curr_point_neighbors = stitched_graph[original_point_id]; if (std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(), original_neighbor_id) == curr_point_neighbors.end()) { stitched_graph[original_point_id].push_back(original_neighbor_id); final_index_size += sizeof(uint32_t); } } } } const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t); final_index_size += (total_number_of_points * sizeof(uint32_t) + METADATA); std::chrono::duration stitching_index_time = std::chrono::high_resolution_clock::now() - stitching_index_timer; std::cout << "stitched graph generated in memory in " << stitching_index_time.count() << " seconds" << std::endl; return std::make_tuple(stitched_graph, final_index_size); } /* * Applies the prune_neighbors function from src/index.cpp to * every node in the stitched graph. * * This is an optional step, hence the saving of both the full * and pruned graph. */ template void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, path input_data_path, std::vector> stitched_graph, uint32_t stitched_R, tsl::robin_map label_entry_points, std::string universal_label, path label_data_path, uint32_t num_threads) { size_t dimension, number_of_label_points; auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr); auto std_cout_buffer = std::cout.rdbuf(nullptr); auto pruning_index_timer = std::chrono::high_resolution_clock::now(); diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension); diskann::Index index(diskann::Metric::L2, dimension, number_of_label_points, false, false); // not searching this index, set search_l to 0 index.load(full_index_path_prefix.c_str(), num_threads, 1); std::cout << "parsing labels" << std::endl; index.prune_all_neighbors(stitched_R, 750, 1.2); index.save((final_index_path_prefix).c_str()); diskann::cout.rdbuf(diskann_cout_buffer); std::cout.rdbuf(std_cout_buffer); std::chrono::duration pruning_index_time = std::chrono::high_resolution_clock::now() - pruning_index_timer; std::cout << "pruning performed in " << pruning_index_time.count() << " seconds\n" << std::endl; } /* * Delete all temporary artifacts. * In the process of creating the stitched index, some temporary artifacts are * created: * 1. the separate bin files for each labels' points * 2. the separate diskANN indices built for each label * 3. the '.data' file created while generating the indices */ void clean_up_artifacts(path input_data_path, path final_index_path_prefix, label_set all_labels) { for (const auto &lbl : all_labels) { path curr_label_input_data_path(input_data_path + "_" + lbl); path curr_label_index_path(final_index_path_prefix + "_" + lbl); path curr_label_index_path_data(curr_label_index_path + ".data"); if (std::remove(curr_label_index_path.c_str()) != 0) throw; if (std::remove(curr_label_input_data_path.c_str()) != 0) throw; if (std::remove(curr_label_index_path_data.c_str()) != 0) throw; } } int main(int argc, char **argv) { // 1. handle cmdline inputs std::string data_type; path input_data_path, final_index_path_prefix, label_data_path; std::string universal_label; uint32_t num_threads, R, L, stitched_R; float alpha; auto index_timer = std::chrono::high_resolution_clock::now(); handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix, label_data_path, universal_label, num_threads, R, L, stitched_R, alpha); path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt"; path labels_map_file = final_index_path_prefix + "_labels_map.txt"; convert_labels_string_to_int(label_data_path, labels_file_to_use, labels_map_file, universal_label); // 2. parse label file and create necessary data structures std::vector point_ids_to_labels; tsl::robin_map labels_to_number_of_points; label_set all_labels; std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) = diskann::parse_label_file(labels_file_to_use, universal_label); // 3. for each label, make a separate data file tsl::robin_map> label_id_to_orig_id_map; uint32_t total_number_of_points = (uint32_t)point_ids_to_labels.size(); #ifndef _WINDOWS if (data_type == "uint8") label_id_to_orig_id_map = diskann::generate_label_specific_vector_files( input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); else if (data_type == "int8") label_id_to_orig_id_map = diskann::generate_label_specific_vector_files( input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); else if (data_type == "float") label_id_to_orig_id_map = diskann::generate_label_specific_vector_files( input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); else throw; #else if (data_type == "uint8") label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat( input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); else if (data_type == "int8") label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat( input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); else if (data_type == "float") label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat( input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); else throw; #endif // 4. for each created data file, create a vanilla diskANN index if (data_type == "uint8") diskann::generate_label_indices(input_data_path, final_index_path_prefix, all_labels, R, L, alpha, num_threads); else if (data_type == "int8") diskann::generate_label_indices(input_data_path, final_index_path_prefix, all_labels, R, L, alpha, num_threads); else if (data_type == "float") diskann::generate_label_indices(input_data_path, final_index_path_prefix, all_labels, R, L, alpha, num_threads); else throw; // 5. "stitch" the indices together std::vector> stitched_graph; tsl::robin_map label_entry_points; uint64_t stitched_graph_size; if (data_type == "uint8") std::tie(stitched_graph, stitched_graph_size) = stitch_label_indices(final_index_path_prefix, total_number_of_points, all_labels, labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map); else if (data_type == "int8") std::tie(stitched_graph, stitched_graph_size) = stitch_label_indices(final_index_path_prefix, total_number_of_points, all_labels, labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map); else if (data_type == "float") std::tie(stitched_graph, stitched_graph_size) = stitch_label_indices(final_index_path_prefix, total_number_of_points, all_labels, labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map); else throw; path full_index_path_prefix = final_index_path_prefix + "_full"; // 5a. save the stitched graph to disk save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size, stitched_graph, label_entry_points, universal_label, labels_file_to_use); // 6. run a prune on the stitched index, and save to disk if (data_type == "uint8") prune_and_save(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph, stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads); else if (data_type == "int8") prune_and_save(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph, stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads); else if (data_type == "float") prune_and_save(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph, stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads); else throw; std::chrono::duration index_time = std::chrono::high_resolution_clock::now() - index_timer; std::cout << "pruned/stitched graph generated in " << index_time.count() << " seconds" << std::endl; clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels); }