#[allow(unused_imports)]
pub(crate) use anyhow::{anyhow, bail, Error, Result};
use llm_models::local_model::{gguf::preset::LlmPreset, LocalLlmModel};
use llm_prompt::{apply_chat_template, LlmPrompt};
use std::collections::HashMap;
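/// Builds a prompt from the default local model's chat template, then checks
/// the rendered prompt string and token counts before and after setting a
/// generation prefix.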
#[test]
fn test_chat() -> crate::Result<()> {
    let model = LocalLlmModel::default();
    let prompt = LlmPrompt::new_chat_template_prompt(
        &model.chat_template.chat_template,
        &model.chat_template.bos_token,
        &model.chat_template.eos_token,
        model.chat_template.unk_token.as_deref(),
        model.chat_template.base_generation_prefix.as_deref(),
        model.model_base.tokenizer.clone(),
    );
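    // Build a short three-turn conversation: user, assistant, user.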
    prompt
        .add_user_message()?
        .set_content("test user content 1");
    prompt
        .add_assistant_message()?
        .set_content("test assistant content");
    prompt
        .add_user_message()?
        .set_content("test user content 2");
    let test_chat = prompt.get_built_prompt_string()?;
    println!("{prompt}");
    assert_eq!(
        test_chat,
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\ntest user content 1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\ntest assistant content<|eot_id|><|start_header_id|>user<|end_header_id|>\n\ntest user content 2<|eot_id|>"
    );
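    // The reported token count should match the length of the tokenized prompt.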
    let token_count = prompt.get_total_prompt_tokens()?;
    let prompt_as_tokens = prompt.get_built_prompt_as_tokens()?;
    assert_eq!(54, token_count);
    assert_eq!(token_count, prompt_as_tokens.len() as u64);
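    // Setting a generation prefix appends the assistant header and the prefix
    // text to the rendered prompt, so both the string and token count grow.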
    prompt.set_generation_prefix("Generating 12345:");
    let test_chat = prompt.get_built_prompt_string()?;
    println!("{prompt}");
    assert_eq!(
        test_chat,
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\ntest user content 1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\ntest assistant content<|eot_id|><|start_header_id|>user<|end_header_id|>\n\ntest user content 2<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nGenerating 12345:"
    );
    let token_count = prompt.get_total_prompt_tokens()?;
    let prompt_as_tokens = prompt.get_built_prompt_as_tokens()?;
    assert_eq!(63, token_count);
    assert_eq!(token_count, prompt_as_tokens.len() as u64);
    Ok(())
}
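// Shared conversation fixture for the template comparison test below.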
const USER_PROMPT_1: &str = "tell me a joke";
const ASSISTANT_PROMPT_1: &str = "the clouds";
const USER_PROMPT_2: &str = "funny";
const ASSISTANT_PROMPT_2: &str = "beepboop";
const USER_PROMPT_3: &str = "robot?";
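/// Renders the same five-message conversation through two preset chat
/// templates and compares each result against a known-good output string.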
#[test]
fn test_chat_templates() -> crate::Result<()> {
    let expected_outputs = [
        // mistralai/Mistral-7B-Instruct-v0.3
        "[INST] tell me a joke [/INST]the clouds[INST] funny [/INST]beepboop[INST] robot? [/INST]",
        // phi/Phi-3-mini-4k-instruct
        "<|user|>\ntell me a joke<|end|>\n<|assistant|>\nthe clouds<|end|>\n<|user|>\nfunny<|end|>\n<|assistant|>\nbeepboop<|end|>\n<|user|>\nrobot?<|end|>\n<|assistant|>\n",
    ];
    let messages: Vec<HashMap<String, String>> = vec![
        HashMap::from([
            ("role".to_string(), "user".to_string()),
            ("content".to_string(), USER_PROMPT_1.to_string()),
        ]),
        HashMap::from([
            ("role".to_string(), "assistant".to_string()),
            ("content".to_string(), ASSISTANT_PROMPT_1.to_string()),
        ]),
        HashMap::from([
            ("role".to_string(), "user".to_string()),
            ("content".to_string(), USER_PROMPT_2.to_string()),
        ]),
        HashMap::from([
            ("role".to_string(), "assistant".to_string()),
            ("content".to_string(), ASSISTANT_PROMPT_2.to_string()),
        ]),
        HashMap::from([
            ("role".to_string(), "user".to_string()),
            ("content".to_string(), USER_PROMPT_3.to_string()),
        ]),
    ];
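    // Load the chat template from each preset, in the same order as
    // expected_outputs.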
    let templates = vec![
        LlmPreset::Mistral7bInstructV0_3.load()?.chat_template,
        LlmPreset::Phi3Mini4kInstruct.load()?.chat_template,
    ];
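    // Each template should render the shared messages to its expected output.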
    for (i, chat_template) in templates.iter().enumerate() {
        let res = apply_chat_template(
            &messages,
            &chat_template.chat_template,
            &chat_template.bos_token,
            &chat_template.eos_token,
            chat_template.unk_token.as_deref(),
        );
        assert_eq!(res, expected_outputs[i]);
    }
    Ok(())
}