use dragon_db::embedding::EmbeddingModelType; use dragon_db::{db_postgres::VectorDB, db_postgres::VectorDBConfig}; use dragon_db::{server::models::*, server::routes::*, DB_PATH, DIMENSIONS}; use rocket::async_test; use rocket::fairing::AdHoc; use rocket::http::ContentType; use rocket::http::Status; use rocket::local::asynchronous::Client; use rocket::routes; use std::collections::HashMap; use std::fs; use tempfile::TempDir; use tokio::sync::Mutex; const USE_TEMP_DIR: bool = false; async fn create_test_rocket() -> Client { let temp_dir = TempDir::new().unwrap(); let tmp_path = temp_dir.path().to_owned(); let path = if USE_TEMP_DIR { tmp_path.to_str().unwrap().to_string() } else { format!("{}/testing", DB_PATH.to_string()) }; let config = VectorDBConfig { path }; let database = VectorDB::with_config(config).await.unwrap(); database.reset().await.unwrap(); let rock = rocket::build() .manage(Mutex::new(database)) .manage(temp_dir) .mount( "/", routes![ add, add_file, search, get_texts, get_docs, delete_texts, count_items, clear_collection, get_info, batch_add_file, execute_sql, create_collection, delete_collection, update_text ], ) .attach(AdHoc::on_shutdown("Cleanup TempDir", |_| { Box::pin(async move { let _ = fs::remove_dir_all(tmp_path); }) })); Client::tracked(rock) .await .expect("Invalid rocket instance") } async fn create_test_collection(client: &Client, collection_name: &str) { let create_request = CreateCollectionRequest { collection_name: collection_name.to_string(), dimensions: DIMENSIONS, }; let response = client .post("/create_collection") .json(&create_request) .dispatch() .await; assert_eq!(response.status(), Status::Ok); } #[async_test] async fn test_add_text() { let client = create_test_rocket().await; create_test_collection(&client, "test_collection").await; let add_request = AddRequest { collection_name: "test_collection".to_string(), texts: vec!["Test text".to_string()], vectors: None, metadata: Some(vec![HashMap::from([( "key".to_string(), "value".to_string(), )])]), model: Some(EmbeddingModelType::default()), document_id: None, create_new_doc: Some(true), }; let response = client.post("/add").json(&add_request).dispatch().await; let status = response.status(); println!("Status: {:?}", status); let body = response .into_string() .await .expect("Failed to get response body"); println!("Response body: {}", body); assert_eq!(status, Status::Ok); let body: AddResponse = serde_json::from_str(&body).expect("Failed to parse JSON"); assert!(!body.id.is_empty()); } #[async_test] async fn test_file_txt() { let client = create_test_rocket().await; create_test_collection(&client, "test_collection").await; // Create a multipart form with the example file text let content = "This is a test file content"; let ct = "multipart/form-data; boundary=X-BOUNDARY" .parse::() .unwrap(); let body = &[ "--X-BOUNDARY", r#"Content-Disposition: form-data; name="file"; filename="test.txt""#, "Content-Type: text/plain", "", content, "--X-BOUNDARY", r#"Content-Disposition: form-data; name="model""#, "", "OpenAI", "--X-BOUNDARY", r#"Content-Disposition: form-data; name="collection_name""#, "", "test_collection", "--X-BOUNDARY--", "", ] .join("\r\n"); // Send the request let response = client .post("/add_file") .header(ct) .body(body) .dispatch() .await; let status = response.status(); println!("Status: {:?}", status); assert_eq!(status, Status::Ok); let body = response .into_string() .await .expect("Failed to get response body"); println!("Response body: {}", body); let body: AddResponse = serde_json::from_str(&body).expect("Failed to parse JSON"); assert!(!body.id.is_empty()); // Get the text using the returned ID let get_request = GetTextsRequest { collection_name: "test_collection".to_string(), ids: Some(vec![body.id]), offset: None, limit: None, }; let get_response = client.get("/get_texts").json(&get_request).dispatch().await; assert_eq!(get_response.status(), Status::Ok); let retrieved_content = get_response.into_string().await.unwrap(); let parsed: GetTextsResponse = serde_json::from_str(&retrieved_content).unwrap(); assert_eq!(parsed.results.len(), 1); } #[async_test] async fn test_delete_texts() { let client = create_test_rocket().await; create_test_collection(&client, "test_collection").await; // Add a text to delete let text_data = AddRequest { collection_name: "test_collection".to_string(), texts: vec!["Text to delete".to_string()], vectors: None, metadata: None, model: Some(EmbeddingModelType::default()), document_id: None, create_new_doc: Some(true), }; let add_response = client .post("/add") .header(ContentType::JSON) .json(&text_data) .dispatch() .await; assert_eq!(add_response.status(), Status::Ok); let add_result: AddResponse = add_response.into_json().await.unwrap(); // Delete the text let delete_request = DeleteRequest { collection_name: "test_collection".to_string(), ids: vec![add_result.id.clone()], }; let delete_response = client .delete("/delete") .json(&delete_request) .dispatch() .await; assert_eq!(delete_response.status(), Status::Ok); // Verify the text is deleted let get_request = GetTextsRequest { collection_name: "test_collection".to_string(), ids: Some(vec![add_result.id]), offset: None, limit: None, }; let get_response = client.get("/get_texts").json(&get_request).dispatch().await; assert_eq!(get_response.status(), Status::Ok); let texts: GetTextsResponse = get_response.into_json().await.unwrap(); assert_eq!(texts.results.len(), 0); } #[async_test] async fn test_update_text() { let client = create_test_rocket().await; create_test_collection(&client, "test_collection").await; // Add a text to update let original_text_data = AddRequest { collection_name: "test_collection".to_string(), texts: vec!["Original text".to_string()], vectors: None, metadata: None, model: Some(EmbeddingModelType::default()), document_id: None, create_new_doc: Some(true), }; let add_response = client .post("/add") .header(ContentType::JSON) .json(&original_text_data) .dispatch() .await; assert_eq!(add_response.status(), Status::Ok); let add_result: AddResponse = add_response.into_json().await.unwrap(); // Update the text let updated_text_data = UpdateRequest { collection_name: "test_collection".to_string(), id: add_result.id.clone(), text: Some("Updated text".to_string()), metadata: None, model: Some(EmbeddingModelType::default()), }; let update_response = client .put("/update") .header(ContentType::JSON) .json(&updated_text_data) .dispatch() .await; assert_eq!(update_response.status(), Status::Ok); assert_eq!( update_response .into_json::() .await .unwrap() .success, true ); // Verify the text is updated let get_request = GetTextsRequest { collection_name: "test_collection".to_string(), ids: Some(vec![add_result.id]), offset: None, limit: None, }; let get_response = client.get("/get_texts").json(&get_request).dispatch().await; assert_eq!(get_response.status(), Status::Ok); let texts: GetTextsResponse = get_response.into_json().await.unwrap(); assert_eq!(texts.results.len(), 1); assert_eq!(texts.results[0].chunk_content, "Updated text".to_string()); } #[async_test] async fn test_clear_collection() { let client = create_test_rocket().await; create_test_collection(&client, "test_collection").await; // Add some texts for i in 0..3 { let text_data = AddRequest { collection_name: "test_collection".to_string(), texts: vec![format!("Text {}", i)], vectors: None, metadata: None, model: Some(EmbeddingModelType::default()), document_id: None, create_new_doc: Some(true), }; let add_response = client .post("/add") .header(ContentType::JSON) .json(&text_data) .dispatch() .await; assert_eq!(add_response.status(), Status::Ok); } // Count the 3 items let count_request = CountItemsRequest { collection_name: "test_collection".to_string(), }; let count_response = client.get("/count").json(&count_request).dispatch().await; assert_eq!(count_response.status(), Status::Ok); let count: CountResponse = count_response.into_json().await.unwrap(); assert_eq!(count.count, 3); // Clear collection let clear_request = ClearCollectionRequest { collection_name: "test_collection".to_string(), }; let clear_response = client .post("/clear_collection") .json(&clear_request) .dispatch() .await; assert_eq!(clear_response.status(), Status::Ok); // Verify collection is empty let count_response = client.get("/count").json(&count_request).dispatch().await; assert_eq!(count_response.status(), Status::Ok); let count: CountResponse = count_response.into_json().await.unwrap(); assert_eq!(count.count, 0); } #[async_test] async fn test_get_info() { let client = create_test_rocket().await; create_test_collection(&client, "test_collection").await; let info_request = GetInfoRequest { collection_name: "test_collection".to_string(), }; let response = client.post("/info").json(&info_request).dispatch().await; assert_eq!(response.status(), Status::Ok); let info: GetInfoResponse = response.into_json().await.unwrap(); // Assert on specific fields in the info response assert_eq!(info.table_name, "test_collection"); assert_eq!(info.dimension, DIMENSIONS); } #[async_test] async fn test_batch_add() { let client = create_test_rocket().await; create_test_collection(&client, "test_collection").await; let batch_request = AddRequest { collection_name: "test_collection".to_string(), texts: vec!["Test text 1".to_string(), "Test text 2".to_string()], vectors: None, metadata: Some(vec![ HashMap::from([("key1".to_string(), "value1".to_string())]), HashMap::from([("key2".to_string(), "value2".to_string())]), ]), model: Some(EmbeddingModelType::default()), document_id: None, create_new_doc: Some(true), }; let response = client.post("/add").json(&batch_request).dispatch().await; assert_eq!(response.status(), Status::Ok); let body: AddResponse = response.into_json().await.expect("Failed to parse JSON"); assert!(!body.id.is_empty()); // Verify that two documents were added let count_request = CountItemsRequest { collection_name: "test_collection".to_string(), }; let count_response = client.get("/count").json(&count_request).dispatch().await; assert_eq!(count_response.status(), Status::Ok); let count: CountResponse = count_response.into_json().await.unwrap(); assert_eq!(count.count, 2); } #[async_test] async fn test_batch_add_file() { let client = create_test_rocket().await; create_test_collection(&client, "test_collection").await; let boundary = "X-BOUNDARY"; let content_type = format!("multipart/form-data; boundary={}", boundary); let body = vec![ format!("--{}", boundary), r#"Content-Disposition: form-data; name="files"; filename="test1.txt""#.to_string(), "Content-Type: text/plain".to_string(), "".to_string(), "This is test file 1".to_string(), format!("--{}", boundary), r#"Content-Disposition: form-data; name="files"; filename="test2.txt""#.to_string(), "Content-Type: text/plain".to_string(), "".to_string(), "This is test file 2".to_string(), format!("--{}", boundary), r#"Content-Disposition: form-data; name="metadata""#.to_string(), "".to_string(), r#"{"key": "value"}"#.to_string(), format!("--{}", boundary), r#"Content-Disposition: form-data; name="model""#.to_string(), "".to_string(), "OpenAI".to_string(), format!("--{}", boundary), r#"Content-Disposition: form-data; name="collection_name""#.to_string(), "".to_string(), "test_collection".to_string(), format!("--{}--", boundary), "".to_string(), ] .join("\r\n"); let response = client .post("/batch_add_file") .header(ContentType::parse_flexible(&content_type).unwrap()) .body(body) .dispatch() .await; assert_eq!(response.status(), Status::Ok); let body: BatchAddFileResponse = response.into_json().await.expect("Failed to parse JSON"); assert_eq!(body.ids.len(), 2); } #[async_test] async fn test_execute_sql() { let client = create_test_rocket().await; create_test_collection(&client, "test_collection").await; // Add some test data let add_request = AddRequest { collection_name: "test_collection".to_string(), texts: vec!["Test text".to_string()], vectors: None, metadata: Some(vec![HashMap::from([( "key".to_string(), "value".to_string(), )])]), model: Some(EmbeddingModelType::default()), document_id: None, create_new_doc: Some(true), }; client.post("/add").json(&add_request).dispatch().await; // Execute a custom SQL query let sql_query = "SELECT doc_id, text FROM test_collection_docs"; let execute_sql_request = ExecuteSqlRequest { sql_query: sql_query.to_string(), }; let response = client .post("/execute_sql") .json(&execute_sql_request) .dispatch() .await; assert_eq!(response.status(), Status::Ok); let body = response.into_json::().await.unwrap(); println!("Body: {:?}", body); assert_eq!(body.results.len(), 1); assert!(body.results[0].get("doc_id").is_some()); assert_eq!( body.results[0].get("text").unwrap().as_str().unwrap(), "Test text" ); } #[async_test] async fn test_vector_search() { let client = create_test_rocket().await; create_test_collection(&client, "test_collection").await; let model = EmbeddingModelType::default(); // Add two documents let add_request1 = AddRequest { collection_name: "test_collection".to_string(), texts: vec!["The quick brown fox jumps over the lazy dog".to_string()], vectors: None, metadata: Some(vec![HashMap::from([( "category".to_string(), "animals".to_string(), )])]), model: Some(model.clone()), document_id: None, create_new_doc: Some(true), }; let add_request2 = AddRequest { collection_name: "test_collection".to_string(), texts: vec!["A quick brown fox jumps over the lazy cat".to_string()], vectors: None, metadata: Some(vec![HashMap::from([( "category".to_string(), "animals".to_string(), )])]), model: Some(model.clone()), document_id: None, create_new_doc: Some(true), }; let response1 = client.post("/add").json(&add_request1).dispatch().await; let response2 = client.post("/add").json(&add_request2).dispatch().await; assert_eq!(response1.status(), Status::Ok); assert_eq!(response2.status(), Status::Ok); // Verify count of chunks let count_request = CountItemsRequest { collection_name: "test_collection".to_string(), }; let count_response = client.get("/count").json(&count_request).dispatch().await; assert_eq!(count_response.status(), Status::Ok); let count: CountResponse = count_response.into_json().await.unwrap(); assert_eq!(count.count, 2); // Perform chunk search let chunk_search_request = SearchRequest { collection_name: "test_collection".to_string(), query: "quick fox lazy".to_string(), top_k: Some(2), model: Some(model.clone()), model_url: None, search_type: Some(SearchType::Chunks), similarity_method: Some(SimilarityMethod::Cosine), search_mode: Some(SearchMode::Vector), document_ids: None, }; let chunk_search_response = client .post("/search") .json(&chunk_search_request) .dispatch() .await; assert_eq!(chunk_search_response.status(), Status::Ok); let chunk_search_body = chunk_search_response .into_json::() .await .unwrap(); assert_eq!(chunk_search_body.results.len(), 2); // Check that the chunk results contain the expected texts let chunk_result_texts: Vec<&str> = chunk_search_body .results .iter() .map(|r| r.document.text.as_ref().unwrap().as_str()) .collect(); assert!(chunk_result_texts.contains(&"The quick brown fox jumps over the lazy dog")); assert!(chunk_result_texts.contains(&"A quick brown fox jumps over the lazy cat")); // Check that the chunk distances are reasonable (between 0 and 1 for cosine similarity) for result in &chunk_search_body.results { assert!(result.distance >= 0.0 && result.distance <= 1.0); } // Check that the chunk results are sorted by distance (ascending order for cosine similarity) assert!(chunk_search_body.results[0].distance <= chunk_search_body.results[1].distance); // Verify count of documents let count_request = CountItemsRequest { collection_name: "test_collection".to_string(), }; let count_response = client.get("/count").json(&count_request).dispatch().await; assert_eq!(count_response.status(), Status::Ok); let count: CountResponse = count_response.into_json().await.unwrap(); assert_eq!(count.count, 2); // // NOTE: Won't be able to perform document search until we have a description embedding model // // // Perform document search // let doc_search_request = SearchRequest { // text: "quick fox lazy".to_string(), // top_k: Some(2), // model: Some(model.clone()), // search_type: SearchType::Documents, // similarity_method: Some(SimilarityMethod::Cosine), // }; // let doc_search_response = client // .post("/search") // .json(&doc_search_request) // .dispatch() // .await; // assert_eq!(doc_search_response.status(), Status::Ok); // let doc_search_body = doc_search_response // .into_json::() // .await // .unwrap(); // assert_eq!(doc_search_body.results.len(), 2); // // Check that the document results contain the expected texts // let doc_result_texts: Vec<&str> = doc_search_body // .results // .iter() // .map(|r| r.document.text.as_ref().unwrap().as_str()) // .collect(); // assert!(doc_result_texts.contains(&"The quick brown fox jumps over the lazy dog")); // assert!(doc_result_texts.contains(&"A quick brown fox jumps over the lazy cat")); // // Check that the document distances are reasonable (between 0 and 1 for cosine similarity) // for result in &doc_search_body.results { // assert!(result.distance >= 0.0 && result.distance <= 1.0); // } // // Check that the document results are sorted by distance (ascending order for cosine similarity) // assert!(doc_search_body.results[0].distance <= doc_search_body.results[1].distance); } #[async_test] async fn test_fulltext_search() { let client = create_test_rocket().await; create_test_collection(&client, "test_collection").await; let model = EmbeddingModelType::default(); // Add two documents let add_request1 = AddRequest { collection_name: "test_collection".to_string(), texts: vec!["The quick brown fox jumps over the lazy dog".to_string()], vectors: None, metadata: Some(vec![HashMap::from([( "category".to_string(), "animals".to_string(), )])]), model: Some(model.clone()), document_id: None, create_new_doc: Some(true), }; let add_request2 = AddRequest { collection_name: "test_collection".to_string(), texts: vec!["A quick brown fox jumps over the lazy cat".to_string()], vectors: None, metadata: Some(vec![HashMap::from([( "category".to_string(), "animals".to_string(), )])]), model: Some(model.clone()), document_id: None, create_new_doc: Some(true), }; let response1 = client.post("/add").json(&add_request1).dispatch().await; let response2 = client.post("/add").json(&add_request2).dispatch().await; assert_eq!(response1.status(), Status::Ok); assert_eq!(response2.status(), Status::Ok); // Count the added documents and chunks let count_request = CountItemsRequest { collection_name: "test_collection".to_string(), }; let count_response = client.get("/count").json(&count_request).dispatch().await; assert_eq!(count_response.status(), Status::Ok); let count: CountResponse = count_response.into_json().await.unwrap(); assert_eq!(count.count, 2); // Print and list documents let get_docs_request = GetDocsRequest { collection_name: "test_collection".to_string(), ids: None, limit: None, offset: None, }; let list_docs_response = client .get("/get_docs") .json(&get_docs_request) .dispatch() .await; assert_eq!(list_docs_response.status(), Status::Ok); let documents: GetDocsResponse = list_docs_response.into_json().await.unwrap(); println!("Documents:"); for doc in &documents.results { println!(" ID: {}, Filename: {}", doc.id, doc.filename); } // Print and list chunks let get_chunks_request = GetTextsRequest { collection_name: "test_collection".to_string(), ids: None, limit: None, offset: None, }; let list_chunks_response = client .get("/get_texts") .json(&get_chunks_request) .dispatch() .await; assert_eq!(list_chunks_response.status(), Status::Ok); let chunks: GetTextsResponse = list_chunks_response.into_json().await.unwrap(); println!("Chunks:"); for chunk in &chunks.results { println!( " Doc ID: {}, Content: {}", chunk.doc_id, chunk.chunk_content ); } // Perform fulltext search let fulltext_search_request = SearchRequest { collection_name: "test_collection".to_string(), query: "quick fox lazy".to_string(), model: Some(model.clone()), top_k: Some(2), model_url: None, search_type: Some(SearchType::Documents), similarity_method: None, search_mode: Some(SearchMode::FullText), document_ids: None, }; let search_response = client .post("/search") .json(&fulltext_search_request) .dispatch() .await; assert_eq!(search_response.status(), Status::Ok); let search_body = search_response.into_json::().await.unwrap(); assert_eq!(search_body.results.len(), 2); // Check that the results contain the expected texts let result_texts: Vec<&str> = search_body .results .iter() .map(|r| r.document.text.as_ref().unwrap().as_str()) .collect(); assert!(result_texts.contains(&"The quick brown fox jumps over the lazy dog")); assert!(result_texts.contains(&"A quick brown fox jumps over the lazy cat")); // Check that the results have scores (for fulltext search) for result in &search_body.results { assert!(result.distance >= 0.0 && result.distance <= 1.0); } // Check that the results are sorted by score (descending order for fulltext search) assert!(search_body.results[0].distance <= search_body.results[1].distance); } #[async_test] async fn test_hybrid_search() { let client = create_test_rocket().await; create_test_collection(&client, "test_collection").await; let model = EmbeddingModelType::default(); // Add two documents let add_request1 = AddRequest { collection_name: "test_collection".to_string(), texts: vec!["The quick brown fox jumps over the lazy dog".to_string()], vectors: None, metadata: Some(vec![HashMap::from([( "category".to_string(), "animals".to_string(), )])]), model: Some(model.clone()), document_id: None, create_new_doc: Some(true), }; let add_request2 = AddRequest { collection_name: "test_collection".to_string(), texts: vec!["A quick brown fox jumps over the lazy cat".to_string()], vectors: None, metadata: Some(vec![HashMap::from([( "category".to_string(), "animals".to_string(), )])]), model: Some(model.clone()), document_id: None, create_new_doc: Some(true), }; let response1 = client.post("/add").json(&add_request1).dispatch().await; let response2 = client.post("/add").json(&add_request2).dispatch().await; assert_eq!(response1.status(), Status::Ok); assert_eq!(response2.status(), Status::Ok); // Perform hybrid search let hybrid_search_request = SearchRequest { collection_name: "test_collection".to_string(), query: "quick fox lazy".to_string(), model: Some(model.clone()), top_k: Some(2), model_url: None, search_type: Some(SearchType::Chunks), similarity_method: Some(SimilarityMethod::Cosine), search_mode: Some(SearchMode::Hybrid), document_ids: None, }; let search_response = client .post("/search") .json(&hybrid_search_request) .dispatch() .await; assert_eq!(search_response.status(), Status::Ok); let search_body = search_response.into_json::().await.unwrap(); // Check results assert_eq!(search_body.results.len(), 2); // Check if both results contain the search terms for result in &search_body.results { assert!( result.content.contains("quick") && result.content.contains("fox") && result.content.contains("lazy") ); } // Ensure results are in the correct order (most relevant first) assert!(search_body.results[0].distance >= search_body.results[1].distance); } #[async_test] async fn test_create_and_delete_collection() { let client = create_test_rocket().await; // Test create collection let create_request = CreateCollectionRequest { collection_name: "test_collection".to_string(), dimensions: DIMENSIONS, }; let create_response = client .post("/create_collection") .json(&create_request) .dispatch() .await; assert_eq!(create_response.status(), Status::Ok); let create_body = create_response .into_json::() .await .unwrap(); assert!(create_body.success); // Test delete collection let delete_request = DeleteCollectionRequest { collection_name: "test_collection".to_string(), }; let delete_response = client .delete("/delete_collection") .json(&delete_request) .dispatch() .await; assert_eq!(delete_response.status(), Status::Ok); let delete_body = delete_response .into_json::() .await .unwrap(); assert!(delete_body.success); }