use crate::common::{create_model_info, setup_mock_provider_with_models}; use kernelx_core::{prelude::*, Error}; mod common; #[test] fn test_capability_checks() { let provider = setup_mock_provider_with_models(vec![create_model_info( "test-model", vec![Capability::Chat, Capability::Complete], )]) .unwrap(); // Check capabilities through the model's info from provider let model_info = provider .models() .iter() .find(|m| m.id == "test-model") .unwrap(); assert!(model_info.capabilities.contains(&Capability::Chat)); assert!(model_info.capabilities.contains(&Capability::Complete)); assert!(!model_info.capabilities.contains(&Capability::Vision)); } #[tokio::test] async fn test_capability_requirements() { let provider = OpenAI::builder() .api_key("test-key") .models(vec![ ModelInfo { id: "gpt-4".to_string(), capabilities: vec![Capability::Chat, Capability::Complete], }, ModelInfo { id: "gpt-4-vision".to_string(), capabilities: vec![Capability::Chat, Capability::Complete, Capability::Vision], }, ]) .build() .unwrap(); // Test that getting a Vision model for gpt-4 should return an error assert!(matches!( provider.get_model::("gpt-4"), Err(Error::UnsupportedCapability(_)) )); // These should work assert!(provider.get_model::("gpt-4").is_ok()); assert!(provider.get_model::("gpt-4-vision").is_ok()); } #[test] fn test_unsupported_capability() { let provider = setup_mock_provider_with_models(vec![create_model_info( "test-model", vec![Capability::Chat, Capability::Complete], )]) .unwrap(); assert!(matches!( provider.get_model::("test-model"), Err(Error::UnsupportedCapability(_)) )); } #[test] fn test_capability_combinations() { let provider = setup_mock_provider_with_models(vec![ create_model_info("basic-model", vec![Capability::Complete]), create_model_info("chat-model", vec![Capability::Chat, Capability::Complete]), create_model_info( "full-model", vec![ Capability::Chat, Capability::Complete, Capability::Structured, ], ), ]) .unwrap(); // Test different capability combinations assert!(provider.get_model::("basic-model").is_ok()); assert!(matches!( provider.get_model::("basic-model"), Err(Error::UnsupportedCapability(_)) )); // Basic LLM capabilities (Chat + Complete) assert!(provider.get_model::("chat-model").is_ok()); // Extended LLM capabilities (Chat + Complete + Structured) assert!(provider.get_model::("full-model").is_ok()); }