use kernelx_core::{models, prelude::*}; use kernelx_macros::{lmp, lmp_schema, provider}; use schemars::{ schema::{RootSchema, SchemaObject}, JsonSchema, }; use serde::{Deserialize, Serialize}; mod mock { use super::*; use async_trait::async_trait; use serde_json::Value; #[derive(Clone, Debug, Default)] pub struct Provider { models: Vec, config: ModelConfig, } impl Provider { pub fn builder() -> ProviderBuilder { ProviderBuilder::new() } } impl kernelx_core::providers::Provider for Provider { fn models(&self) -> &[ModelInfo] { &self.models } fn from_builder(mut builder: ProviderBuilder) -> Result { Ok(Self { models: builder.take_models().unwrap_or_else(|| { vec![ModelInfo { id: "test-model".to_string(), capabilities: vec![ Capability::Chat, Capability::Complete, Capability::Structured, ], }] }), config: ModelConfig::default(), }) } } impl HasCapability for Provider { fn model_id(&self) -> &str { "mock" } fn config(&self) -> &ModelConfig { &self.config } } #[async_trait] impl Chat for Provider { async fn chat_impl( &self, _model: &str, _messages: Vec, _config: &ModelConfig, ) -> Result { Ok("mock response".to_string()) } } #[async_trait] impl Complete for Provider { async fn complete_impl( &self, _model: &str, _prompt: &str, _config: &ModelConfig, ) -> Result { Ok("mock response".to_string()) } } #[async_trait] impl Structured for Provider { async fn structured_impl( &self, _model: &str, _prompt: &str, _schema: &Value, _config: &ModelConfig, ) -> Result { Ok(serde_json::json!({ "message": "mock response" })) } } } // Re-export Provider for use in tests #[test] fn test_lmp_schema_derive() { #[lmp_schema] struct TestStruct { field1: String, field2: i32, nested: Option>, } let schema: RootSchema = schemars::schema_for!(TestStruct); let SchemaObject { object, .. } = schema.schema; if let Some(obj_validation) = object { assert!(obj_validation.properties.contains_key("field1")); assert!(obj_validation.properties.contains_key("field2")); assert!(obj_validation.properties.contains_key("nested")); } else { panic!("Schema should have object validation"); } let test = TestStruct { field1: "test".to_string(), field2: 42, nested: Some(vec!["nested".to_string()]), }; let json = serde_json::to_string(&test).unwrap(); let deserialized: TestStruct = serde_json::from_str(&json).unwrap(); assert_eq!(deserialized.field1, "test"); assert_eq!(deserialized.field2, 42); assert_eq!(deserialized.nested, Some(vec!["nested".to_string()])); } #[test] fn test_provider_configuration() { provider! { mock::Provider, api_base: "http://localhost:8080", api_key: "test-key", models: models![ "test-model" => [Capability::Complete, Capability::Chat, Capability::Structured], "vision-model" => [Capability::Vision], ] } let provider = get_provider(); let models = provider.models(); assert_eq!(models.len(), 2); // Test model capabilities let test_model = models.iter().find(|m| m.id == "test-model").unwrap(); assert!(test_model.capabilities.contains(&Capability::Complete)); assert!(test_model.capabilities.contains(&Capability::Chat)); assert!(test_model.capabilities.contains(&Capability::Structured)); let vision_model = models.iter().find(|m| m.id == "vision-model").unwrap(); assert!(vision_model.capabilities.contains(&Capability::Vision)); } #[tokio::test] async fn test_lmp_variants() { provider! { mock::Provider, api_base: "http://localhost:8080", api_key: "test-key", models: models!["test-model" => [Capability::Complete, Capability::Chat, Capability::Structured]] } #[lmp(model = "test-model")] async fn basic(input: &str) -> Result { format!("Echo: {}", input) } #[derive(Debug, Serialize, Deserialize, JsonSchema)] struct TestResponse { message: String, } #[lmp(model = "test-model", response_format = TestResponse)] async fn structured(input: &str) -> Result { format!("Response: {}", input) } let result = basic("test").await; assert!(result.is_ok()); assert_eq!(result.unwrap(), "mock response"); let result = structured("test").await; assert!(result.is_ok()); let response = result.unwrap(); assert_eq!(response.message, "mock response"); } #[tokio::test] async fn test_thread_safety() { provider! { mock::Provider, api_base: "http://localhost:8080", api_key: "test-key", models: models!["test-model" => [Capability::Complete]] } let handles: Vec<_> = (0..10) .map(|_| { tokio::spawn(async { let provider = get_provider(); assert_eq!(provider.models().len(), 1); }) }) .collect(); for handle in handles { handle.await.unwrap(); } }