// Template for the generated `Model` enum and its conversions.
// The `{% ... %}` / `{{ ... }}` tags are expanded once per entry in `models`.
use std::str::FromStr;

use serde::{Deserialize, Serialize};

/// One variant per entry of the template's `models` list.
///
/// The variant whose codename equals `default_model_codename` is marked
/// `#[default]` and is therefore the value returned by `Model::default()`.
#[derive(Debug, Copy, Clone, Default, PartialEq)]
pub enum Model {
    {% for model in models %}
    {% if model.codename == default_model_codename %}
    #[default]
    {% endif %}
    {{ model.enumname }},
    {% endfor %}
}

impl FromStr for Model {
    type Err = String;

    /// Parses a model codename into its variant.
    ///
    /// # Errors
    ///
    /// Returns a message naming the offending input when `s` matches none of
    /// the known codenames. Matching is exact (no trimming or case folding).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            {% for model in models %}
            "{{ model.codename }}" => Ok(Self::{{ model.enumname }}),
            {% endfor %}
            _ => Err(format!("{} is not a valid model", s)),
        }
    }
}

/// Writes the model's codename, the same string `FromStr` accepts.
///
/// Implemented as `Display` rather than `ToString` so that `to_string()` is
/// still available through the standard blanket impl while the type can also
/// be used directly in `format!`-style macros.
impl std::fmt::Display for Model {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let codename = match self {
            {% for model in models %}
            Self::{{ model.enumname }} => "{{ model.codename }}",
            {% endfor %}
        };
        f.write_str(codename)
    }
}

impl Serialize for Model {
    /// Serializes the model as its codename string.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_str(&self.to_string())
    }
}

impl<'de> Deserialize<'de> for Model {
    /// Deserializes a string and parses it with `FromStr`.
    ///
    /// An unknown codename is reported through the deserializer's custom
    /// error, carrying the `FromStr` error message.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        Self::from_str(&s).map_err(serde::de::Error::custom)
    }
}

impl Model {
    /// Returns every variant, in the order the template's `models` list
    /// declares them.
    pub fn all() -> Vec<Self> {
        vec![
            {% for model in models %}
            Self::{{ model.enumname }},
            {% endfor %}
        ]
    }

    /// Computes the cost of a request from its token counts.
    ///
    /// Each model's `prompt_cost` and `completion_cost` are divided by
    /// 1,000,000 before being multiplied by the corresponding token count,
    /// i.e. the template values are treated as rates per million tokens.
    ///
    /// NOTE(review): the rates must render as floating-point literals for the
    /// division below to type-check; confirm against the template data.
    pub fn cost(&self, prompt_tokens: usize, completion_tokens: usize) -> f64 {
        let (prompt_cost, completion_cost) = match self {
            {% for model in models %}
            Self::{{ model.enumname }} => ({{ model.prompt_cost }}, {{ model.completion_cost }}),
            {% endfor %}
        };

        // prompt_tokens * prompt_rate + completion_tokens * completion_rate,
        // with the final multiply-add fused into a single rounding step.
        (prompt_tokens as f64).mul_add(
            prompt_cost / 1000000.0,
            (completion_tokens as f64) * (completion_cost / 1000000.0),
        )
    }

    /// Returns the `context_size` value configured for this model in the
    /// template data.
    pub const fn context_size(&self) -> usize {
        match self {
            {% for model in models %}
            Self::{{ model.enumname }} => {{ model.context_size }},
            {% endfor %}
        }
    }
}