// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #ifdef _WINDOWS #include #else #include #endif #include "filter_utils.h" #include "utils.h" // WORKS FOR UPTO 2 BILLION POINTS (as we use INT INSTEAD OF UNSIGNED) #define PARTSIZE 10000000 #define ALIGNMENT 512 // custom types (for readability) typedef tsl::robin_set label_set; typedef std::string path; namespace po = boost::program_options; template T div_round_up(const T numerator, const T denominator) { return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator); } using pairIF = std::pair; struct cmpmaxstruct { bool operator()(const pairIF &l, const pairIF &r) { return l.second < r.second; }; }; using maxPQIFCS = std::priority_queue, cmpmaxstruct>; template T *aligned_malloc(const size_t n, const size_t alignment) { #ifdef _WINDOWS return (T *)_aligned_malloc(sizeof(T) * n, alignment); #else return static_cast(aligned_alloc(alignment, sizeof(T) * n)); #endif } inline bool custom_dist(const std::pair &a, const std::pair &b) { return a.second < b.second; } void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim) { assert(points_l2sq != NULL); #pragma omp parallel for schedule(static, 65536) for (int64_t d = 0; d < num_points; ++d) points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1); } void distsq_to_points(const size_t dim, float *dist_matrix, // Col Major, cols are queries, rows are points size_t npoints, const float *const points, const float *const points_l2sq, // points in Col major size_t nqueries, const float *const queries, const float *const queries_l2sq, // queries in Col major float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 { bool ones_vec_alloc = false; if (ones_vec == NULL) { ones_vec = new float[nqueries > npoints ? nqueries : npoints]; std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); ones_vec_alloc = true; } cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim, (float)0.0, dist_matrix, npoints); cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints, ones_vec, nqueries, (float)1.0, dist_matrix, npoints); cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints, queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints); if (ones_vec_alloc) delete[] ones_vec; } void inner_prod_to_points(const size_t dim, float *dist_matrix, // Col Major, cols are queries, rows are points size_t npoints, const float *const points, size_t nqueries, const float *const queries, float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 { bool ones_vec_alloc = false; if (ones_vec == NULL) { ones_vec = new float[nqueries > npoints ? nqueries : npoints]; std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); ones_vec_alloc = true; } cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim, (float)0.0, dist_matrix, npoints); if (ones_vec_alloc) delete[] ones_vec; } void exact_knn(const size_t dim, const size_t k, size_t *const closest_points, // k * num_queries preallocated, col // major, queries columns float *const dist_closest_points, // k * num_queries // preallocated, Dist to // corresponding closes_points size_t npoints, float *points_in, // points in Col major size_t nqueries, float *queries_in, diskann::Metric metric = diskann::Metric::L2) // queries in Col major { float *points_l2sq = new float[npoints]; float *queries_l2sq = new float[nqueries]; compute_l2sq(points_l2sq, points_in, npoints, dim); compute_l2sq(queries_l2sq, queries_in, nqueries, dim); float *points = points_in; float *queries = queries_in; if (metric == diskann::Metric::COSINE) { // we convert cosine distance as // normalized L2 distnace points = new float[npoints * dim]; queries = new float[nqueries * dim]; #pragma omp parallel for schedule(static, 4096) for (int64_t i = 0; i < (int64_t)npoints; i++) { float norm = std::sqrt(points_l2sq[i]); if (norm == 0) { norm = std::numeric_limits::epsilon(); } for (uint32_t j = 0; j < dim; j++) { points[i * dim + j] = points_in[i * dim + j] / norm; } } #pragma omp parallel for schedule(static, 4096) for (int64_t i = 0; i < (int64_t)nqueries; i++) { float norm = std::sqrt(queries_l2sq[i]); if (norm == 0) { norm = std::numeric_limits::epsilon(); } for (uint32_t j = 0; j < dim; j++) { queries[i * dim + j] = queries_in[i * dim + j] / norm; } } // recalculate norms after normalizing, they should all be one. compute_l2sq(points_l2sq, points, npoints, dim); compute_l2sq(queries_l2sq, queries, nqueries, dim); } std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in " << dim << " dimensions using"; if (metric == diskann::Metric::INNER_PRODUCT) std::cout << " MIPS "; else if (metric == diskann::Metric::COSINE) std::cout << " Cosine "; else std::cout << " L2 "; std::cout << "distance fn. " << std::endl; size_t q_batch_size = (1 << 9); float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints]; for (size_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b) { int64_t q_b = b * q_batch_size; int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size; if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE) { distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b, queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b); } else { inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b, queries + (ptrdiff_t)q_b * (ptrdiff_t)dim); } std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl; #pragma omp parallel for schedule(dynamic, 16) for (long long q = q_b; q < q_e; q++) { maxPQIFCS point_dist; for (size_t p = 0; p < k; p++) point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); for (size_t p = k; p < npoints; p++) { if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]) point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); if (point_dist.size() > k) point_dist.pop(); } for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l) { closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first; dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second; point_dist.pop(); } assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k, dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k)); } std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl; } delete[] dist_matrix; delete[] points_l2sq; delete[] queries_l2sq; if (metric == diskann::Metric::COSINE) { delete[] points; delete[] queries; } } template inline int get_num_parts(const char *filename) { std::ifstream reader; reader.exceptions(std::ios::failbit | std::ios::badbit); reader.open(filename, std::ios::binary); std::cout << "Reading bin file " << filename << " ...\n"; int npts_i32, ndims_i32; reader.read((char *)&npts_i32, sizeof(int)); reader.read((char *)&ndims_i32, sizeof(int)); std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl; reader.close(); uint32_t num_parts = (npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1; std::cout << "Number of parts: " << num_parts << std::endl; return num_parts; } template inline void load_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims, int part_num) { std::ifstream reader; reader.exceptions(std::ios::failbit | std::ios::badbit); reader.open(filename, std::ios::binary); std::cout << "Reading bin file " << filename << " ...\n"; int npts_i32, ndims_i32; reader.read((char *)&npts_i32, sizeof(int)); reader.read((char *)&ndims_i32, sizeof(int)); uint64_t start_id = part_num * PARTSIZE; uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); npts = end_id - start_id; ndims = (uint64_t)ndims_i32; std::cout << "#pts in part = " << npts << ", #dims = " << ndims << ", size = " << npts * ndims * sizeof(T) << "B" << std::endl; reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg); T *data_T = new T[npts * ndims]; reader.read((char *)data_T, sizeof(T) * npts * ndims); std::cout << "Finished reading part of the bin file." << std::endl; reader.close(); data = aligned_malloc(npts * ndims, ALIGNMENT); #pragma omp parallel for schedule(dynamic, 32768) for (int64_t i = 0; i < (int64_t)npts; i++) { for (int64_t j = 0; j < (int64_t)ndims; j++) { float cur_val_float = (float)data_T[i * ndims + j]; std::memcpy((char *)(data + i * ndims + j), (char *)&cur_val_float, sizeof(float)); } } delete[] data_T; std::cout << "Finished converting part data to float." << std::endl; } template inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims) { std::ofstream writer; writer.exceptions(std::ios::failbit | std::ios::badbit); writer.open(filename, std::ios::binary | std::ios::out); std::cout << "Writing bin: " << filename << "\n"; int npts_i32 = (int)npts, ndims_i32 = (int)ndims; writer.write((char *)&npts_i32, sizeof(int)); writer.write((char *)&ndims_i32, sizeof(int)); std::cout << "bin: #pts = " << npts << ", #dims = " << ndims << ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl; writer.write((char *)data, npts * ndims * sizeof(T)); writer.close(); std::cout << "Finished writing bin" << std::endl; } inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts, size_t ndims) { std::ofstream writer(filename, std::ios::binary | std::ios::out); int npts_i32 = (int)npts, ndims_i32 = (int)ndims; writer.write((char *)&npts_i32, sizeof(int)); writer.write((char *)&ndims_i32, sizeof(int)); std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, " "npts*dim dist-matrix) with npts = " << npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int) << "B" << std::endl; writer.write((char *)data, npts * ndims * sizeof(uint32_t)); writer.write((char *)distances, npts * ndims * sizeof(float)); writer.close(); std::cout << "Finished writing truthset" << std::endl; } template std::vector>> processUnfilteredParts(const std::string &base_file, size_t &nqueries, size_t &npoints, size_t &dim, size_t &k, float *query_data, const diskann::Metric &metric, std::vector &location_to_tag) { float *base_data = nullptr; int num_parts = get_num_parts(base_file.c_str()); std::vector>> res(nqueries); for (int p = 0; p < num_parts; p++) { size_t start_id = p * PARTSIZE; load_bin_as_float(base_file.c_str(), base_data, npoints, dim, p); size_t *closest_points_part = new size_t[nqueries * k]; float *dist_closest_points_part = new float[nqueries * k]; auto part_k = k < npoints ? k : npoints; exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data, metric); for (size_t i = 0; i < nqueries; i++) { for (size_t j = 0; j < part_k; j++) { if (!location_to_tag.empty()) if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) continue; res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id), dist_closest_points_part[i * part_k + j])); } } delete[] closest_points_part; delete[] dist_closest_points_part; diskann::aligned_free(base_data); } return res; }; template int aux_main(const std::string &base_file, const std::string &query_file, const std::string >_file, size_t k, const diskann::Metric &metric, const std::string &tags_file = std::string("")) { size_t npoints, nqueries, dim; float *query_data; load_bin_as_float(query_file.c_str(), query_data, nqueries, dim, 0); if (nqueries > PARTSIZE) std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE << ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl; // load tags const bool tags_enabled = tags_file.empty() ? false : true; std::vector location_to_tag = diskann::loadTags(tags_file, base_file); int *closest_points = new int[nqueries * k]; float *dist_closest_points = new float[nqueries * k]; std::vector>> results = processUnfilteredParts(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag); for (size_t i = 0; i < nqueries; i++) { std::vector> &cur_res = results[i]; std::sort(cur_res.begin(), cur_res.end(), custom_dist); size_t j = 0; for (auto iter : cur_res) { if (j == k) break; if (tags_enabled) { std::uint32_t index_with_tag = location_to_tag[iter.first]; closest_points[i * k + j] = (int32_t)index_with_tag; } else { closest_points[i * k + j] = (int32_t)iter.first; } if (metric == diskann::Metric::INNER_PRODUCT) dist_closest_points[i * k + j] = -iter.second; else dist_closest_points[i * k + j] = iter.second; ++j; } if (j < k) std::cout << "WARNING: found less than k GT entries for query " << i << std::endl; } save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k); delete[] closest_points; delete[] dist_closest_points; diskann::aligned_free(query_data); return 0; } void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim) { size_t read_blk_size = 64 * 1024 * 1024; cached_ifstream reader(bin_file, read_blk_size); diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl; size_t actual_file_size = reader.get_file_size(); int npts_i32, dim_i32; reader.read((char *)&npts_i32, sizeof(int)); reader.read((char *)&dim_i32, sizeof(int)); npts = (uint32_t)npts_i32; dim = (uint32_t)dim_i32; diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl; int truthset_type = -1; // 1 means truthset has ids and distances, 2 means // only ids, -1 is error size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); if (actual_file_size == expected_file_size_with_dists) truthset_type = 1; size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); if (actual_file_size == expected_file_size_just_ids) truthset_type = 2; if (truthset_type == -1) { std::stringstream stream; stream << "Error. File size mismatch. File should have bin format, with " "npts followed by ngt followed by npts*ngt ids and optionally " "followed by npts*ngt distance values; actual size: " << actual_file_size << ", expected: " << expected_file_size_with_dists << " or " << expected_file_size_just_ids; diskann::cout << stream.str(); throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } ids = new uint32_t[npts * dim]; reader.read((char *)ids, npts * dim * sizeof(uint32_t)); if (truthset_type == 1) { dists = new float[npts * dim]; reader.read((char *)dists, npts * dim * sizeof(float)); } } int main(int argc, char **argv) { std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file; uint64_t K; try { po::options_description desc{"Arguments"}; desc.add_options()("help,h", "Print information on arguments"); desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); desc.add_options()("dist_fn", po::value(&dist_fn)->required(), "distance function "); desc.add_options()("base_file", po::value(&base_file)->required(), "File containing the base vectors in binary format"); desc.add_options()("query_file", po::value(&query_file)->required(), "File containing the query vectors in binary format"); desc.add_options()("gt_file", po::value(>_file)->required(), "File name for the writing ground truth in binary " "format, please don' append .bin at end if " "no filter_label or filter_label_file is provided it " "will save the file with '.bin' at end." "else it will save the file as filename_label.bin"); desc.add_options()("K", po::value(&K)->required(), "Number of ground truth nearest neighbors to compute"); desc.add_options()("tags_file", po::value(&tags_file)->default_value(std::string()), "File containing the tags in binary format"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); if (vm.count("help")) { std::cout << desc; return 0; } po::notify(vm); } catch (const std::exception &ex) { std::cerr << ex.what() << '\n'; return -1; } if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8")) { std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl; return -1; } diskann::Metric metric; if (dist_fn == std::string("l2")) { metric = diskann::Metric::L2; } else if (dist_fn == std::string("mips")) { metric = diskann::Metric::INNER_PRODUCT; } else if (dist_fn == std::string("cosine")) { metric = diskann::Metric::COSINE; } else { std::cerr << "Unsupported distance function. Use l2/mips/cosine." << std::endl; return -1; } try { if (data_type == std::string("float")) aux_main(base_file, query_file, gt_file, K, metric, tags_file); if (data_type == std::string("int8")) aux_main(base_file, query_file, gt_file, K, metric, tags_file); if (data_type == std::string("uint8")) aux_main(base_file, query_file, gt_file, K, metric, tags_file); } catch (const std::exception &e) { std::cout << std::string(e.what()) << std::endl; diskann::cerr << "Compute GT failed." << std::endl; return -1; } }