// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include #include #include // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include #include #include #include #include #include using namespace web; using namespace web::http; using namespace web::http::client; using namespace diskann; namespace po = boost::program_options; template void query_loop(const std::string &ip_addr_port, const std::string &query_file, const unsigned nq, const unsigned Ls, const unsigned k_value) { web::http::client::http_client client(U(ip_addr_port)); T *data; size_t npts = 1, ndims = 128, rounded_dim = 128; diskann::load_aligned_bin(query_file, data, npts, ndims, rounded_dim); for (unsigned i = 0; i < nq; ++i) { T *vec = data + i * rounded_dim; web::http::http_request http_query(methods::POST); web::json::value queryJson = web::json::value::object(); queryJson[QUERY_ID_KEY] = i; queryJson[K_KEY] = k_value; queryJson[L_KEY] = Ls; for (size_t i = 0; i < ndims; ++i) { queryJson[VECTOR_KEY][i] = web::json::value::number(vec[i]); } http_query.set_body(queryJson); client.request(http_query) .then([](web::http::http_response response) -> pplx::task { if (response.status_code() == status_codes::OK) { return response.extract_string(); } std::cerr << "Query failed" << std::endl; return pplx::task_from_result(utility::string_t()); }) .then([](pplx::task previousTask) { try { std::cout << previousTask.get() << std::endl; } catch (http_exception const &e) { std::wcout << e.what() << std::endl; } }) .wait(); } } int main(int argc, char *argv[]) { std::string data_type, query_file, address; uint32_t num_queries; uint32_t l_search, k_value; po::options_description desc{"Arguments"}; try { desc.add_options()("help,h", "Print information on arguments"); desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); desc.add_options()("address", po::value(&address)->required(), "Web server address"); desc.add_options()("query_file", po::value(&query_file)->required(), "File containing the queries to search"); desc.add_options()("num_queries,Q", po::value(&num_queries)->required(), "Number of queries to search"); desc.add_options()("l_search", po::value(&l_search)->required(), "Value of L"); desc.add_options()("k_value,K", po::value(&k_value)->default_value(10), "Value of K (default 10)"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); if (vm.count("help")) { std::cout << desc; return 0; } po::notify(vm); } catch (const std::exception &ex) { std::cerr << ex.what() << std::endl; return -1; } if (data_type == std::string("float")) { query_loop(address, query_file, num_queries, l_search, k_value); } else if (data_type == std::string("int8")) { query_loop(address, query_file, num_queries, l_search, k_value); } else if (data_type == std::string("uint8")) { query_loop(address, query_file, num_queries, l_search, k_value); } else { std::cerr << "Unsupported type " << argv[2] << std::endl; return -1; } return 0; }