use callm::templates::{MessageRole, TemplateImpl, TemplateJinja as Template};
// Mistral-7B-Instruct-v0.3

/// Chat template of Mistral-7B-Instruct-v0.3.
///
/// User turns render as `[INST] <content> [/INST]`, and assistant turns render as
/// `<content>` followed by `eos_token`. Roles must strictly alternate starting with
/// `user`; any other role, or any other ordering, makes the template call
/// `raise_exception`.
const JINJA_TEMPLATE: &str = r#"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"#;
/// BOS token handed to the template; empty, so rendered prompts carry no BOS marker.
const BOS_TOKEN: &str = "";
/// EOS token handed to the template; empty, so assistant turns get no suffix.
const EOS_TOKEN: &str = "";
/// With empty BOS/EOS tokens, a lone user turn renders as just `[INST] … [/INST]`.
#[test]
fn single_user_message() {
    let mut template = Template::new(JINJA_TEMPLATE);
    template.set_bos_token(Some(String::from(BOS_TOKEN)));
    template.set_eos_token(Some(String::from(EOS_TOKEN)));

    let conversation = [(MessageRole::User, String::from("User message 1"))];
    let rendered = template.apply(conversation.as_slice()).unwrap();

    assert_eq!(rendered, "[INST] User message 1 [/INST]");
}
/// An assistant reply is emitted right after the user's `[/INST]`, followed by
/// the (here empty) EOS token.
#[test]
fn two_messages() {
    let mut template = Template::new(JINJA_TEMPLATE);
    template.set_bos_token(Some(String::from(BOS_TOKEN)));
    template.set_eos_token(Some(String::from(EOS_TOKEN)));

    let conversation = [
        (MessageRole::User, String::from("User message 1")),
        (MessageRole::Assistant, String::from("Assistant message 1")),
    ];
    let rendered = template.apply(conversation.as_slice()).unwrap();

    assert_eq!(rendered, "[INST] User message 1 [/INST]Assistant message 1");
}
/// Alternating user/assistant/user turns are concatenated with no separator
/// between them.
#[test]
fn three_messages() {
    let mut template = Template::new(JINJA_TEMPLATE);
    template.set_bos_token(Some(String::from(BOS_TOKEN)));
    template.set_eos_token(Some(String::from(EOS_TOKEN)));

    let conversation = [
        (MessageRole::User, String::from("User message 1")),
        (MessageRole::Assistant, String::from("Assistant message 1")),
        (MessageRole::User, String::from("User message 2")),
    ];
    let rendered = template.apply(conversation.as_slice()).unwrap();

    let expected =
        "[INST] User message 1 [/INST]Assistant message 1[INST] User message 2 [/INST]";
    assert_eq!(rendered, expected);
}
/// Mistral-7B-Instruct-v0.3's template accepts only alternating user/assistant
/// turns. A leading system message trips its `raise_exception` guard, so rendering
/// must fail.
///
/// The error is asserted explicitly instead of relying on `#[should_panic]`. That
/// way an unrelated panic (or a comparison against the wrong expected prompt)
/// cannot make this test pass.
#[test]
fn with_system_message() {
    let msgs = vec![
        (MessageRole::System, "System message".to_string()),
        (MessageRole::User, "User message 1".to_string()),
    ];
    let mut template = Template::new(JINJA_TEMPLATE);
    template.set_bos_token(Some(BOS_TOKEN.to_string()));
    template.set_eos_token(Some(EOS_TOKEN.to_string()));
    // NOTE(review): assumes callm surfaces the template's `raise_exception` as an
    // `Err` from `apply` (the original test's `unwrap()` relied on the same thing).
    assert!(
        template.apply(msgs.as_slice()).is_err(),
        "system messages are not supported by the Mistral v0.3 template and must be rejected"
    );
}