use chrono::Utc; use dragon_db::{ db_postgres::{Chunk, Document, VectorDB, VectorDBConfig}, embedding::{embed, embed_query, EmbeddingModelType}, server::models::{SearchMode, SearchQuery, SearchType, SimilarityMethod}, DB_PATH, DIMENSIONS, }; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::OnceCell; static TEST_DB: OnceCell> = OnceCell::const_new(); async fn get_test_db() -> Arc { TEST_DB .get_or_init(|| async { println!("Creating new test db"); let config = VectorDBConfig { path: format!("{}/testing", DB_PATH.to_string()), }; let vector_db = VectorDB::with_config(config).await.unwrap(); vector_db.reset().await.unwrap(); Arc::new(vector_db) }) .await .clone() } #[tokio::test] async fn postgres_test_insert_and_get() { let db = get_test_db().await; let collection_name = "test_collection"; db.create_collection(collection_name).await.unwrap(); let vector = vec![1.0; DIMENSIONS]; let document = Document { id: "test".to_string(), filename: "test.txt".to_string(), filetype: "text/plain".to_string(), filesize: 9, // "test text".len() date_created: Utc::now(), date_modified: None, description: Some("A test document".to_string()), description_embedding: None, metadata: None, text: Some("test text".to_string()), }; let chunk = Chunk { chunk_id: 0, doc_id: document.id.clone(), chunk_content: "test text".to_string(), chunk_embedding: vector.clone(), chunk_index: 0, }; let id = db .add(collection_name, vec![chunk], Some(document.clone())) .await .unwrap(); let retrieved_chunks = db.get_chunks(collection_name, &id).await.unwrap(); assert_eq!(retrieved_chunks[0].chunk_embedding, vector); let retrieved_document = db .get_document(collection_name, &id) .await .unwrap() .unwrap(); assert_eq!(retrieved_document.id, document.id); assert_eq!(retrieved_document.text, document.text); // Shutdown the db db.shutdown().await.unwrap(); } #[tokio::test] async fn postgres_test_delete() { let db = get_test_db().await; let collection_name = "test_collection"; db.create_collection(collection_name).await.unwrap(); let vector = vec![1.0; DIMENSIONS]; let document = Document { id: "test".to_string(), filename: "test.txt".to_string(), filetype: "text/plain".to_string(), filesize: 9, // "test text".len() date_created: Utc::now(), date_modified: None, description: Some("A test document".to_string()), description_embedding: None, metadata: None, text: Some("test text".to_string()), }; let chunk = Chunk { chunk_id: 0, doc_id: document.id.clone(), chunk_content: "test text".to_string(), chunk_embedding: vector, chunk_index: 0, }; let id = db .add(collection_name, vec![chunk], Some(document)) .await .unwrap(); db.delete(collection_name, vec![id.clone()]).await.unwrap(); assert!(db .get_chunks(collection_name, &id) .await .unwrap() .is_empty()); assert!(db .get_document(collection_name, &id) .await .unwrap() .is_none()); // Shutdown the db db.shutdown().await.unwrap(); } #[tokio::test] async fn postgres_test_semantic_search_chunks() { let db = get_test_db().await; let collection_name = "test_collection"; db.create_collection(collection_name).await.unwrap(); let vector1 = vec![1.0; DIMENSIONS]; let vector2 = vec![0.0; DIMENSIONS]; let vector3 = vec![0.5; DIMENSIONS]; let id1 = db .add( collection_name, vec![Chunk { chunk_id: 0, doc_id: "vector1".to_string(), chunk_content: "The quick brown fox jumps over the lazy dog".to_string(), chunk_embedding: vector1.clone(), chunk_index: 0, }], Some(Document { id: "vector1".to_string(), filename: "vector1.txt".to_string(), filetype: "text/plain".to_string(), filesize: 43, date_created: Utc::now(), date_modified: None, description: Some("Vector 1".to_string()), description_embedding: None, metadata: None, text: Some("The quick brown fox jumps over the lazy dog".to_string()), }), ) .await .unwrap(); let id2 = db .add( collection_name, vec![Chunk { chunk_id: 0, doc_id: "vector2".to_string(), chunk_content: "A fast auburn fox leaps above the sleepy canine".to_string(), chunk_embedding: vector2.clone(), chunk_index: 0, }], Some(Document { id: "vector2".to_string(), filename: "vector2.txt".to_string(), filetype: "text/plain".to_string(), filesize: 43, date_created: Utc::now(), date_modified: None, description: Some("Vector 2".to_string()), description_embedding: None, metadata: None, text: Some("A fast auburn fox leaps above the sleepy canine".to_string()), }), ) .await .unwrap(); let id3 = db .add( collection_name, vec![Chunk { chunk_id: 0, doc_id: "vector3".to_string(), chunk_content: "Python is a popular programming language".to_string(), chunk_embedding: vector3.clone(), chunk_index: 0, }], Some(Document { id: "vector3".to_string(), filename: "vector3.txt".to_string(), filetype: "text/plain".to_string(), filesize: 35, date_created: Utc::now(), date_modified: None, description: Some("Vector 3".to_string()), description_embedding: None, metadata: None, text: Some("Python is a popular programming language".to_string()), }), ) .await .unwrap(); // Assert that the vectors are added to the database let docs = db.list_documents(collection_name).await.unwrap(); assert_eq!(docs.len(), 3); assert_eq!(docs[0].id, id1); assert_eq!(docs[1].id, id2); assert_eq!(docs[2].id, id3); // Assert chunk embeddings are added to the database let chunks = db.list_chunks(collection_name).await.unwrap(); assert_eq!(chunks.len(), 3); assert_eq!(chunks[0].chunk_embedding, vector1); assert_eq!(chunks[1].chunk_embedding, vector2); assert_eq!(chunks[2].chunk_embedding, vector3); let query = vec![0.9; DIMENSIONS]; let results = db .semantic_search( collection_name, SearchQuery::VSS(query.clone()), 2, SearchType::Chunks, SearchMode::Vector, SimilarityMethod::Euclidean, Some(vec![id1.clone(), id3.clone()]), ) .await .unwrap(); assert_eq!(results.len(), 2); assert_eq!(results[0].2.id, id1); assert_eq!(results[1].2.id, id3); // Shutdown the db db.shutdown().await.unwrap(); } // // NOTE: WILL FAIL UNTIL DESCRIPTION EMBEDDING IS IMPLEMENTED // // #[tokio::test] // async fn postgres_test_semantic_search_docs() { // let db = get_test_db().await; // let collection_name = "test_collection"; // db.create_collection(collection_name).await.unwrap(); // let vector1 = vec![1.0; DIMENSIONS]; // let vector2 = vec![0.0; DIMENSIONS]; // let vector3 = vec![0.5; DIMENSIONS]; // let id1 = db // .add( // collection_name, // vec![Chunk { // chunk_id: 0, // doc_id: "vector1".to_string(), // chunk_content: "The quick brown fox jumps over the lazy dog".to_string(), // chunk_embedding: vector1.clone(), // chunk_index: 0, // }], // Some(Document { // id: "vector1".to_string(), // filename: "vector1.txt".to_string(), // filetype: "text/plain".to_string(), // filesize: 43, // date_created: Utc::now(), // date_modified: None, // description: Some("Vector 1".to_string()), // description_embedding: Some(vector1.clone()), // metadata: None, // text: Some("The quick brown fox jumps over the lazy dog".to_string()), // }), // ) // .await // .unwrap(); // let id2 = db // .add( // collection_name, // vec![Chunk { // chunk_id: 0, // doc_id: "vector2".to_string(), // chunk_content: "A fast auburn fox leaps above the sleepy canine".to_string(), // chunk_embedding: vector2.clone(), // chunk_index: 0, // }], // Some(Document { // id: "vector2".to_string(), // filename: "vector2.txt".to_string(), // filetype: "text/plain".to_string(), // filesize: 43, // date_created: Utc::now(), // date_modified: None, // description: Some("Vector 2".to_string()), // description_embedding: Some(vector2.clone()), // metadata: None, // text: Some("A fast auburn fox leaps above the sleepy canine".to_string()), // }), // ) // .await // .unwrap(); // let id3 = db // .add( // collection_name, // vec![Chunk { // chunk_id: 0, // doc_id: "vector3".to_string(), // chunk_content: "Python is a popular programming language".to_string(), // chunk_embedding: vector3.clone(), // chunk_index: 0, // }], // Some(Document { // id: "vector3".to_string(), // filename: "vector3.txt".to_string(), // filetype: "text/plain".to_string(), // filesize: 35, // date_created: Utc::now(), // date_modified: None, // description: Some("Vector 3".to_string()), // description_embedding: Some(vector3.clone()), // metadata: None, // text: Some("Python is a popular programming language".to_string()), // }), // ) // .await // .unwrap(); // // Assert that the vectors are added to the database // let docs = db.list_documents(collection_name).await.unwrap(); // assert_eq!(docs.len(), 3); // assert_eq!(docs[0].id, id1); // assert_eq!(docs[1].id, id2); // assert_eq!(docs[2].id, id3); // // Assert chunk embeddings are added to the database // let chunks = db.list_chunks(collection_name).await.unwrap(); // assert_eq!(chunks.len(), 3); // assert_eq!(chunks[0].chunk_embedding, vector1); // assert_eq!(chunks[1].chunk_embedding, vector2); // assert_eq!(chunks[2].chunk_embedding, vector3); // let query = vec![0.9; DIMENSIONS]; // let results = db.semantic_search_docs(collection_name, query, 2).await.unwrap(); // assert_eq!(results.len(), 2); // assert_eq!(results[0].1.id, id1); // assert_eq!(results[1].1.id, id3); // // Shutdown the db // db.shutdown().await.unwrap(); // } #[tokio::test] async fn postgres_test_list() { let db = get_test_db().await; let collection_name = "test_collection"; db.create_collection(collection_name).await.unwrap(); let vector1 = vec![1.0; DIMENSIONS]; let vector2 = vec![0.0; DIMENSIONS]; db.add( collection_name, vec![Chunk { chunk_id: 0, doc_id: "vector1".to_string(), chunk_content: "vector1".to_string(), chunk_embedding: vector1, chunk_index: 0, }], Some(Document { id: "vector1".to_string(), filename: "vector1.txt".to_string(), filetype: "text/plain".to_string(), filesize: 7, // "vector1".len() date_created: Utc::now(), date_modified: None, description: Some("Vector 1".to_string()), description_embedding: None, metadata: None, text: Some("vector1".to_string()), }), ) .await .unwrap(); db.add( collection_name, vec![Chunk { chunk_id: 0, doc_id: "vector2".to_string(), chunk_content: "vector2".to_string(), chunk_embedding: vector2, chunk_index: 0, }], Some(Document { id: "vector2".to_string(), filename: "vector2.txt".to_string(), filetype: "text/plain".to_string(), filesize: 7, // "vector2".len() date_created: Utc::now(), date_modified: None, description: Some("Vector 2".to_string()), description_embedding: None, metadata: None, text: Some("vector2".to_string()), }), ) .await .unwrap(); let list = db.list_documents(collection_name).await.unwrap(); assert_eq!(list.len(), 2); // Shutdown the db db.shutdown().await.unwrap(); } #[tokio::test] async fn postgres_test_update() { let db = get_test_db().await; let collection_name = "test_collection"; db.create_collection(collection_name).await.unwrap(); let vector = vec![1.0; DIMENSIONS]; let id = db .add( collection_name, vec![Chunk { chunk_id: 0, doc_id: "original".to_string(), chunk_content: "original".to_string(), chunk_embedding: vector, chunk_index: 0, }], Some(Document { id: "original".to_string(), filename: "original.txt".to_string(), filetype: "text/plain".to_string(), filesize: 8, // "original".len() date_created: Utc::now(), date_modified: None, description: Some("Original document".to_string()), description_embedding: None, metadata: None, text: Some("original".to_string()), }), ) .await .unwrap(); let new_vector = vec![2.0; DIMENSIONS]; let new_document = Document { id: id.clone(), filename: "updated.txt".to_string(), filetype: "text/plain".to_string(), filesize: 7, // "updated".len() date_created: Utc::now(), date_modified: Some(Utc::now()), description: Some("Updated document".to_string()), description_embedding: None, metadata: None, text: Some("updated".to_string()), }; db.update( collection_name, &id, Some(vec![Chunk { chunk_id: 0, doc_id: id.clone(), chunk_content: "updated".to_string(), chunk_embedding: new_vector.clone(), chunk_index: 0, }]), Some(new_document.clone()), ) .await .unwrap(); let updated_chunks = db.get_chunks(collection_name, &id).await.unwrap(); let updated_document = db.get_document(collection_name, &id).await.unwrap(); assert_eq!(updated_chunks[0].chunk_embedding, new_vector); assert_eq!(updated_document.unwrap().filename, new_document.filename); // Shutdown the db db.shutdown().await.unwrap(); } #[tokio::test] async fn postgres_test_batch_add() { let db = get_test_db().await; let collection_name = "test_collection"; db.create_collection(collection_name).await.unwrap(); let data = vec![ ( vec![Chunk { chunk_id: 0, doc_id: "vector1".to_string(), chunk_content: "vector1".to_string(), chunk_embedding: vec![1.0; DIMENSIONS], chunk_index: 0, }], Some(Document { id: "vector1".to_string(), filename: "vector1.txt".to_string(), filetype: "text/plain".to_string(), filesize: 7, // "vector1".len() date_created: Utc::now(), date_modified: None, description: Some("Vector 1".to_string()), description_embedding: None, metadata: None, text: Some("vector1".to_string()), }), ), ( vec![Chunk { chunk_id: 0, doc_id: "vector2".to_string(), chunk_content: "vector2".to_string(), chunk_embedding: vec![2.0; DIMENSIONS], chunk_index: 0, }], Some(Document { id: "vector2".to_string(), filename: "vector2.txt".to_string(), filetype: "text/plain".to_string(), filesize: 7, // "vector2".len() date_created: Utc::now(), date_modified: None, description: Some("Vector 2".to_string()), description_embedding: None, metadata: None, text: Some("vector2".to_string()), }), ), ]; let ids = db.batch_add(collection_name, data).await.unwrap(); assert_eq!(ids.len(), 2); let list = db.list_documents(collection_name).await.unwrap(); assert_eq!(list.len(), 2); // Verify chunk count is correct let chunks = db.get_chunks(collection_name, &ids[0]).await.unwrap(); assert_eq!(chunks.len(), 1); let chunks = db.get_chunks(collection_name, &ids[1]).await.unwrap(); assert_eq!(chunks.len(), 1); // Verify total chunk count let all_chunks = db.list_chunks(collection_name).await.unwrap(); assert_eq!(all_chunks.len(), 2); // Shutdown the db db.shutdown().await.unwrap(); } #[tokio::test] async fn postgres_test_batch_update() { let db = get_test_db().await; let collection_name = "test_collection"; db.create_collection(collection_name).await.unwrap(); let data = vec![ ( vec![Chunk { chunk_id: 0, doc_id: "vector1".to_string(), chunk_content: "vector1".to_string(), chunk_embedding: vec![1.0; DIMENSIONS], chunk_index: 0, }], Some(Document { id: "vector1".to_string(), filename: "vector1.txt".to_string(), filetype: "text/plain".to_string(), filesize: 7, // "vector1".len() date_created: Utc::now(), date_modified: None, description: Some("Vector 1".to_string()), description_embedding: None, metadata: None, text: Some("vector1".to_string()), }), ), ( vec![Chunk { chunk_id: 0, doc_id: "vector2".to_string(), chunk_content: "vector2".to_string(), chunk_embedding: vec![2.0; DIMENSIONS], chunk_index: 0, }], Some(Document { id: "vector2".to_string(), filename: "vector2.txt".to_string(), filetype: "text/plain".to_string(), filesize: 7, // "vector2".len() date_created: Utc::now(), date_modified: None, description: Some("Vector 2".to_string()), description_embedding: None, metadata: None, text: Some("vector2".to_string()), }), ), ]; let ids = db.batch_add(collection_name, data).await.unwrap(); let update_data = vec![ ( ids[0].clone(), vec![Chunk { chunk_id: 0, doc_id: ids[0].clone(), chunk_content: "updated vector1".to_string(), chunk_embedding: vec![3.0; DIMENSIONS], chunk_index: 0, }], Some(Document { id: ids[0].clone(), filename: "updated_vector1.txt".to_string(), filetype: "text/plain".to_string(), filesize: 15, // "updated vector1".len() date_created: Utc::now(), date_modified: Some(Utc::now()), description: Some("Updated Vector 1".to_string()), description_embedding: None, metadata: None, text: Some("updated vector1".to_string()), }), ), ( ids[1].clone(), vec![Chunk { chunk_id: 0, doc_id: ids[1].clone(), chunk_content: "updated vector2".to_string(), chunk_embedding: vec![4.0; DIMENSIONS], chunk_index: 0, }], Some(Document { id: ids[1].clone(), filename: "updated_vector2.txt".to_string(), filetype: "text/plain".to_string(), filesize: 15, // "updated vector2".len() date_created: Utc::now(), date_modified: Some(Utc::now()), description: Some("Updated Vector 2".to_string()), description_embedding: None, metadata: None, text: Some("updated vector2".to_string()), }), ), ]; db.batch_update(collection_name, update_data).await.unwrap(); let list = db.list_documents(collection_name).await.unwrap(); assert_eq!(list.len(), 2); assert_eq!(list[0].filename, "updated_vector1.txt"); assert_eq!(list[1].filename, "updated_vector2.txt"); let chunks = db.get_chunks(collection_name, &ids[0]).await.unwrap(); assert_eq!(chunks[0].chunk_embedding, vec![3.0; DIMENSIONS]); // Shutdown the db db.shutdown().await.unwrap(); } #[tokio::test] async fn postgres_test_execute_sql() { let db = get_test_db().await; let vector1 = vec![1.0; DIMENSIONS]; let vector2 = vec![0.0; DIMENSIONS]; let collection_name = "test_collection"; db.create_collection(collection_name).await.unwrap(); db.add( collection_name, vec![Chunk { chunk_id: 0, doc_id: "vector1".to_string(), chunk_content: "vector1".to_string(), chunk_embedding: vector1.clone(), chunk_index: 0, }], Some(Document { id: "vector1".to_string(), filename: "vector1.txt".to_string(), filetype: "text/plain".to_string(), filesize: 7, // "vector1".len() date_created: Utc::now(), date_modified: None, description: Some("Vector 1".to_string()), description_embedding: Some(vector1), metadata: Some(HashMap::from([( "user_id".to_string(), "test_user1".to_string(), )])), text: Some("vector1".to_string()), }), ) .await .unwrap(); db.add( collection_name, vec![Chunk { chunk_id: 0, doc_id: "vector2".to_string(), chunk_content: "vector2".to_string(), chunk_embedding: vector2.clone(), chunk_index: 0, }], Some(Document { id: "vector2".to_string(), filename: "vector2.txt".to_string(), filetype: "text/plain".to_string(), filesize: 7, // "vector2".len() date_created: Utc::now(), date_modified: None, description: Some("Vector 2".to_string()), description_embedding: Some(vector2), metadata: Some(HashMap::from([( "user_id".to_string(), "test_user2".to_string(), )])), text: Some("vector2".to_string()), }), ) .await .unwrap(); let sql_query = "SELECT documents.doc_id, chunks.chunk_content, documents.metadata->>'user_id' as user_id FROM test_collection_chunks chunks JOIN test_collection_docs documents ON chunks.doc_id = documents.doc_id ORDER BY documents.doc_id"; let results = db.execute_sql(sql_query).await.unwrap(); println!("Results: {:?}", results); assert_eq!(results.len(), 2); assert_eq!(results[0]["doc_id"], "vector1"); assert_eq!(results[0]["chunk_content"], "vector1"); assert_eq!(results[0]["user_id"], "test_user1"); assert_eq!(results[1]["doc_id"], "vector2"); assert_eq!(results[1]["chunk_content"], "vector2"); assert_eq!(results[1]["user_id"], "test_user2"); // Test a more complex query let complex_query = "SELECT metadata->>'user_id' as user_id, COUNT(*) as count FROM test_collection_docs GROUP BY metadata->>'user_id' ORDER BY metadata->>'user_id'"; let complex_results = db.execute_sql(complex_query).await.unwrap(); assert_eq!(complex_results.len(), 2); assert_eq!(complex_results[0]["user_id"], "test_user1"); assert_eq!(complex_results[0]["count"].as_i64().unwrap(), 1); assert_eq!(complex_results[1]["user_id"], "test_user2"); assert_eq!(complex_results[1]["count"].as_i64().unwrap(), 1); // Shutdown the db db.shutdown().await.unwrap(); } #[tokio::test] async fn test_hybrid_search() { let db = get_test_db().await; let collection_name = "test_collection"; db.create_collection(collection_name).await.unwrap(); // Add some test documents let vector1 = vec![0.1; DIMENSIONS]; let vector2 = vec![0.2; DIMENSIONS]; db.add( collection_name, vec![Chunk { chunk_id: 0, doc_id: "doc1".to_string(), chunk_content: "The quick brown fox jumps over the lazy dog".to_string(), chunk_embedding: vector1.clone(), chunk_index: 0, }], Some(Document { id: "doc1".to_string(), filename: "doc1.txt".to_string(), filetype: "text/plain".to_string(), filesize: 43, date_created: Utc::now(), date_modified: None, description: Some("Document 1".to_string()), description_embedding: Some(vector1), metadata: Some(HashMap::from([( "category".to_string(), "animals".to_string(), )])), text: Some("The quick brown fox jumps over the lazy dog".to_string()), }), ) .await .unwrap(); db.add( collection_name, vec![Chunk { chunk_id: 0, doc_id: "doc2".to_string(), chunk_content: "A quick brown fox jumps over the lazy cat".to_string(), chunk_embedding: vector2.clone(), chunk_index: 0, }], Some(Document { id: "doc2".to_string(), filename: "doc2.txt".to_string(), filetype: "text/plain".to_string(), filesize: 43, date_created: Utc::now(), date_modified: None, description: Some("Document 2".to_string()), description_embedding: Some(vector2), metadata: Some(HashMap::from([( "category".to_string(), "animals".to_string(), )])), text: Some("A quick brown fox jumps over the lazy cat".to_string()), }), ) .await .unwrap(); // Perform hybrid search let query = "quick fox lazy"; let query_vector = embed_query(query, &EmbeddingModelType::default()) .await .unwrap(); let results = db .hybrid_search( collection_name, query.to_string(), query_vector, 2, None, None, None, None, ) .await .unwrap(); // Check results assert_eq!(results.len(), 2); assert!( results[0].1.contains("quick") && results[0].1.contains("fox") && results[0].1.contains("lazy") ); // Shutdown the db db.shutdown().await.unwrap(); } #[tokio::test] async fn test_hybrid_search_with_document_ids() { let db = get_test_db().await; // Create a collection db.create_collection("test_collection").await.unwrap(); // Add multiple documents let docs = vec![ ("doc1", "The quick brown fox jumps over the lazy dog"), ("doc2", "A lazy cat sleeps all day long"), ("doc3", "Birds fly high in the blue sky"), ]; for (id, text) in docs { let embedding = embed(vec![text.to_string()], &EmbeddingModelType::default()) .await .unwrap(); db.add( "test_collection", vec![Chunk { chunk_id: 0, doc_id: id.to_string(), chunk_content: text.to_string(), chunk_embedding: embedding[0].clone(), chunk_index: 0, }], Some(Document { id: id.to_string(), filename: format!("{}.txt", id), filetype: "text/plain".to_string(), filesize: text.len(), date_created: Utc::now(), date_modified: None, description: Some(format!("Document {}", id)), description_embedding: Some(embedding[0].clone()), metadata: None, text: Some(text.to_string()), }), ) .await .unwrap(); } // Perform hybrid search with document_ids filter let query = "lazy animal"; let query_vector = embed_query(query, &EmbeddingModelType::default()) .await .unwrap(); let results = db .hybrid_search( "test_collection", query.to_string(), query_vector, 2, None, None, None, Some(vec!["doc2".to_string()]), ) .await .unwrap(); // Check results assert_eq!(results.len(), 1); assert_eq!(results[0].2.id, "doc2"); assert!(results[0].1.contains("lazy") && results[0].1.contains("cat")); // Delete the collection db.delete_collection("test_collection").await.unwrap(); }