<#@ template cleanws="true" #> use std::borrow::Cow; use std::marker::PhantomData; use std::mem; use num_traits::{FromPrimitive, ToPrimitive}; use tsproto::packets::{Direction, Flags, InCommand, OutCommand, OutPacket, PacketType}; /* Resulting code: Static arguments which depend on 'cmd go into the rental struct, others into the normal struct. Example for: sendmessage targetmode=1 target=2 static_arg_s=a\|b msg=bc|msg=a\|c */ pub trait InMessageTrait<'a> { fn new(cmd: &'a InCommand) -> Result where Self: Sized; } #[derive(Debug)] pub struct InMessage { cmd: InCommand, msg: InMessages<'static>, } #[derive(Debug)] pub enum InMessages<'a> { <# for msg_group in &self.0.msg_group { for msg in &msg_group.msg { #> <#= msg.name #>(In<#= msg.name #><'a>), <# } } #> Other, } impl InMessage { pub fn new(cmd: InCommand) -> Result { let mut res = Self { cmd, msg: InMessages::Other }; { // Parse message let msg: InMessages = loop { match res.cmd.data().name { <# let mut msgs: Vec<_> = self.0.msg_group.iter() .flat_map(|g| g.msg.iter()) .filter(|m| m.notify.is_some()) .collect(); msgs.sort_unstable_by_key(|m| m.notify.as_ref().map(|s| s.as_str()).unwrap()); for (notify, group) in &msgs.iter().group_by(|m| m.notify.as_ref().map(|s| s.as_str()).unwrap()) { #> "<#= notify #>" => {<# let group: Vec<_> = group.collect(); let (msg, group) = group.split_last().unwrap(); for msg in group { #> if let Ok(r) = In<#= msg.name #>::new(&res.cmd) { break InMessages::<#= msg.name #>(r); }<# } #> match In<#= msg.name #>::new(&res.cmd) { Ok(msg) => break InMessages::<#= msg.name #>(msg), Err(e) => return Err((res.cmd, e)), }} <# } #> s => { let s = s.to_string(); return Err((res.cmd, ParseError::UnknownCommand(s))); } }}; res.msg = unsafe { mem::transmute(msg) }; } Ok(res) } #[inline] pub fn command(&self) -> &InCommand { &self.cmd } #[inline] pub fn msg(&self) -> &InMessages { &self.msg } #[inline] pub fn into_command(self) -> InCommand { self.cmd } } // Statics would look like this inside a 
struct: // pub targetmode: TextMessageTargetMode, // pub static_arg_s: &'a str, <# for msg_group in &self.0.msg_group { for msg in &msg_group.msg { #> #[derive(Debug)] pub struct In<#= msg.name #><'a> { <# if msg_group.default.response { #> pub return_code: Option<&'a str>, <# } #> list: Vec<<#= msg.name #>Part<'a>>, } #[derive(Debug)] pub struct <#= msg.name #>Part<'a> { <# for a in &msg.attributes { let field = self.0.get_field(a); if field.get_rust_name() != field.ts { #> /// `<#= field.ts #>` in TeamSpeak. <# } #> pub <#= field.get_rust_name() #>: <#= field.get_rust_type(a, true).replace("&", "&'a ").replace("UidRef", "UidRef<'a>") #>, <# } /// Use the lifetime and make this struct non-exhaustive. // TODO But how do we create the parts then? #> pub phantom: PhantomData<&'a ()>, } impl<'a> InMessageTrait<'a> for In<#= msg.name #><'a> { fn new(cmd: &'a InCommand) -> Result { let data = cmd.data(); if data.name != "<#= msg.notify.as_ref().map(|s| s.as_str()).unwrap_or("") #>" { return Err(ParseError::WrongCommand(data.name.to_string())); } if <#= if msg_group.default.np { "!" } else { "" } #>cmd.newprotocol() { return Err(ParseError::WrongNewprotocol(cmd.newprotocol())); } if cmd.packet_type() != PacketType::Command<#= if msg_group.default.low { "Low" } else { "" } #> { return Err(ParseError::WrongPacketType(cmd.packet_type())); } <# if !msg_group.default.s2c { #> if cmd.direction() == Direction::S2C { return Err(ParseError::WrongDirection(cmd.direction())); } <# } #> <# if !msg_group.default.c2s { #> if cmd.direction() == Direction::C2S { return Err(ParseError::WrongDirection(cmd.direction())); } <# } #> <# /* // Statics let it = cmd.iter(); let arg = it.statics.get("targetmode") .ok_or_else(|| format_err!("Static argument targetmode not found in < #= msg.name # >"))?; let targetmode = TextMessageTargetMode::from_u32(arg.parse()?) 
.ok_or_else(|| format_err!("Invalid targetmode {} found in < #= msg.name # >", arg))?; let arg = it.statics.get("target"); let target; if let Some(arg) = arg { target = Some(arg.parse()?); } else { target = None; } let static_arg_s = *it.statics.get("static_arg_s") .ok_or_else(|| format_err!("Static argument static_arg_s not found in < #= msg.name # >"))?; */ #> <# if msg_group.default.response { #> // Get return code let return_code = cmd.data().static_args.iter() .find(|(k, _)| *k == "return_code") .map(|(_, v)| v.as_ref()); <# } #> // List arguments let mut list = Vec::new(); for <#= if msg.attributes.is_empty() { "_" } else { "ccmd" } #> in cmd.iter() { list.push(<#= msg.name #>Part { <# for a in &msg.attributes { let field = self.0.get_field(a); #> <#= field.get_rust_name() #>: { <# if !a.ends_with('?') { /* is not optional */ #> let val = ccmd.0.get("<#= field.ts #>") .ok_or(ParseError::ParameterNotFound { arg: "<#= field.pretty #>", name: "<#= msg.name #>", })?; <#= generate_deserializer(field) #> }, <# } else { #> if let Some(val) = ccmd.0.get("<#= field.ts #>") { Some({ <#= generate_deserializer(field) #> }) } else { None } }, <# } #> <# } #> phantom: PhantomData, }); } // TODO Still missing: Warn if there are more arguments than we parsed Ok(In<#= msg.name #> {<# if msg_group.default.response { #> return_code,<# } #> list }) } } impl<'a> In<#= msg.name #><'a> { #[inline] pub fn iter(&self) -> InMessageIterator<<#= msg.name #>Part> { self.into_iter() } } impl<'a> IntoIterator for &'a In<#= msg.name #><'a> { type Item = &'a <#= msg.name #>Part<'a>; type IntoIter = InMessageIterator<'a, <#= msg.name #>Part<'a>>; #[inline] fn into_iter(self) -> Self::IntoIter { InMessageIterator(&self.list, 0) } } <# } } #> /// The iterator is guaranteed to output at least one part. 
pub struct InMessageIterator<'a, T>(&'a [T], usize);

impl<'a, T> Iterator for InMessageIterator<'a, T> {
    type Item = &'a T;

    /// Returns the element at the current position and advances by one.
    fn next(&mut self) -> Option<Self::Item> {
        let i = self.1;
        self.1 += 1;
        self.0.get(i)
    }
}

<# for msg_group in &self.0.msg_group {
    for msg in &msg_group.msg { #>
pub struct Out<#= msg.name #>Message;
impl Out<#= msg.name #>Message {
    /// Builds the packet for this message, one list entry per part in `list`.
    pub fn new<'a, I: Iterator<Item = <#= msg.name #>Part<'a>>>(list: I<#= if msg_group.default.response { ", return_code: Option<&str>" } else { "" } #>) -> OutPacket {
        let mut packet = OutPacket::new_with_dir(Direction::<#= if msg_group.default.s2c { "S2C" } else { "C2S" } #>,
            Flags::<#= if msg_group.default.np { "NEWPROTOCOL" } else { "empty()" } #>,
            PacketType::Command<#= if msg_group.default.low { "Low" } else { "" } #>);
        <# if msg_group.default.response { #>
        let static_args = return_code.iter().map(|s| ("return_code", Cow::Borrowed(*s)));
        <# } else { #>
        let static_args = std::iter::empty();
        <# }
        if msg.attributes.is_empty() { #>
        let list_args = list.map(|_| { std::iter::empty() });
        <# } else { #>
        let list_args = list.map(|p| {
            let mut res = Vec::new();
<# for a in &msg.attributes {
    let field = self.0.get_field(a);
    let val = format!("p.{}", field.get_rust_name());
    if a.ends_with('?') { #>
            if let Some(val) = <#= val #> { res.push(("<#= field.ts #>", <#= generate_serializer(field, "val") #>)); }
    <# } else { #>
            res.push(("<#= field.ts #>", <#= generate_serializer(field, &val) #>));
    <# } #>
<# } #>
            res.into_iter()
        });
        <# } #>
        OutCommand::new_into::<&'static str, Cow<str>, &'static str, Cow<str>, _, _, _>(
            "<#= msg.notify.as_ref().map(|s| s.as_str()).unwrap_or("") #>", static_args, list_args, packet.data_mut());
        packet
    }
}

<# }
} #>

<#
/// Returns Rust source that converts the string `val` into the type of
/// `field`.
fn generate_deserializer(field: &Field) -> String {
    let rust_type = field.get_rust_type("", true);
    if rust_type.starts_with("Vec<") {
        vector_value_deserializer(field)
    } else {
        single_value_deserializer(field, &rust_type)
    }
}

/// Returns Rust source that parses the string `val` into a single value of
/// `rust_type`.
///
/// The emitted code uses `?`, so it has to be placed inside a function that
/// returns a `Result` with a `ParseError`.
///
/// # Panics
///
/// Panics if `rust_type` (or the original duration type of `field`) is not
/// known to the generator.
fn single_value_deserializer(field: &Field, rust_type: &str)
    -> String {
    let res = match rust_type {
        "i8" | "u8" | "i16" | "u16" | "i32" | "u32" | "i64" | "u64" =>
            format!("val.parse().map_err(|e| ParseError::ParseInt {{ \
                arg: \"{}\", value: val.to_string(), error: e, }})?", field.pretty),
        "f32" | "f64" =>
            format!("val.parse().map_err(|e| ParseError::ParseFloat {{ \
                arg: \"{}\", value: val.to_string(), error: e, }})?", field.pretty),
        "bool" => format!("match *val {{ \"0\" => false, \"1\" => true, \
            _ => Err(ParseError::ParseBool {{ \
            arg: \"{}\", value: val.to_string(), }})? }}", field.pretty),
        "UidRef" => "UidRef(val)".into(),
        "&str" => "val".into(),
        // NOTE(review): negative values are read as a signed integer and
        // reinterpreted as `u32`; the concrete widths `i32`/`u64` are a
        // reconstruction - confirm against the upstream generator.
        "IconHash" => format!("IconHash(if val.starts_with('-') {{ \
            val.parse::<i32>().map(|i| i as u32) \
            }} else {{ \
            val.parse::<u64>().map(|i| i as u32) \
            }}.map_err(|e| ParseError::ParseInt {{ \
            arg: \"{}\", value: val.to_string(), error: e, }})?)", field.pretty),
        "ClientId" |
        "ClientDbId" |
        "ChannelId" |
        "ServerGroupId" |
        "ChannelGroupId" => format!("{}(val.parse().map_err(|e| ParseError::ParseInt {{ \
            arg: \"{}\", value: val.to_string(), error: e, }})?)", rust_type, field.pretty),
        "TextMessageTargetMode" |
        "HostMessageMode" |
        "HostBannerMode" |
        "LicenseType" |
        "LogLevel" |
        "Codec" |
        "CodecEncryptionMode" |
        "Reason" |
        "ClientType" |
        "GroupNamingMode" |
        "GroupType" |
        "Permission" |
        "PermissionType" |
        "TokenType" |
        "PluginTargetMode" |
        "Error" => format!("{}::from_u32(val.parse().map_err(|e| ParseError::ParseInt {{ \
            arg: \"{}\", value: val.to_string(), error: e, }})?).ok_or(ParseError::InvalidValue {{ \
            arg: \"{1}\", value: val.to_string(), }})?", rust_type, field.pretty),
        "Duration" =>
            if field.type_s == "DurationSeconds" {
                // The emitted code rejects values for which `val * 1000`
                // overflows before calling `Duration::seconds`.
                format!("let val = val.parse::<i64>().map_err(|e| ParseError::ParseInt {{ \
                    arg: \"{}\", value: val.to_string(), error: e, }})?; \
                    if let Some(_) = val.checked_mul(1000) {{ Duration::seconds(val) }} \
                    else {{ Err(ParseError::InvalidValue {{ \
                    arg: \"{0}\", value: val.to_string(), }})? }}", field.pretty)
            } else if field.type_s == "DurationMilliseconds" {
                format!("Duration::milliseconds(val.parse::<i64>().map_err(|e| ParseError::ParseInt {{ \
                    arg: \"{}\", value: val.to_string(), error: e, }})?)", field.pretty)
            } else {
                panic!("Unknown original time type {} found.", field.type_s);
            },
        "DateTime" => format!("DateTime::<Utc>::from_utc(NaiveDateTime::from_timestamp_opt(\
            val.parse().map_err(|e| ParseError::ParseInt {{ \
            arg: \"{}\", value: val.to_string(), error: e, }})?, \
            0).ok_or(ParseError::InvalidValue {{ \
            arg: \"{0}\", value: val.to_string(), }})?, Utc)", field.pretty),
        _ => panic!("Unknown type '{}'", rust_type),
    };
    // Multi-line snippets get indented to fit into the generated struct
    // literal.
    if res.contains('\n') {
        indent(&res, 2)
    } else {
        res
    }
}

/// Returns Rust source that splits `val` at `,` and parses every piece into
/// the element type of the `Vec` field, stopping at the first error.
fn vector_value_deserializer(field: &Field) -> String {
    let rust_type = field.get_rust_type("", true);
    // Strip the surrounding `Vec<` and `>`.
    let inner_type = &rust_type[4..rust_type.len()-1];
    format!("val.split(',').map(|val| Ok({})).collect::<Result<Vec<{}>, ParseError>>()?",
        single_value_deserializer(field, inner_type), inner_type)
}

/// Returns Rust source that converts the expression `name`, which has the
/// type of `field`, into a `Cow<str>`.
fn generate_serializer(field: &Field, name: &str) -> String {
    let rust_type = field.get_rust_type("", true);
    if rust_type.starts_with("Vec<") {
        // Strip the surrounding `Vec<` and `>`.
        let inner_type = &rust_type[4..rust_type.len()-1];
        vector_value_serializer(field, inner_type, name)
    } else {
        single_value_serializer(field, &rust_type, name)
    }
}

/// Returns Rust source that converts the expression `name` of type
/// `rust_type` into a `Cow<str>`.
///
/// # Panics
///
/// Panics if `rust_type` (or the original duration type of `field`) is not
/// known to the generator.
fn single_value_serializer(field: &Field, rust_type: &str, name: &str) -> String {
    match rust_type {
        "i8" | "u8" | "i16" | "u16" | "i32" | "u32" | "i64" | "u64" |
        "f32" | "f64" => format!("Cow::Owned({}.to_string())", name),
        "bool" => format!("Cow::Borrowed(if {} {{ \"1\" }} else {{ \"0\" }})", name),
        "&str" => format!("Cow::Borrowed({})", name),
        "UidRef" => format!("Cow::Borrowed({}.0)", name),
        "ClientId" |
        "ClientDbId" |
        "ChannelId" |
        "ServerGroupId" |
        "ChannelGroupId" |
        "IconHash" => format!("Cow::Owned({}.0.to_string())", name),
        "TextMessageTargetMode" |
        "HostMessageMode" |
        "HostBannerMode" |
        "LicenseType" |
        "LogLevel" |
        "Codec" |
        "CodecEncryptionMode" |
        "Reason" |
        "ClientType" |
        "GroupNamingMode" |
        "GroupType" |
        "Permission" |
        "PermissionType" |
        "TokenType" |
        "PluginTargetMode" |
        "Error" => format!("Cow::Owned({}.to_u32().unwrap().to_string())", name),
        "Duration" =>
            if field.type_s == "DurationSeconds" {
                format!("Cow::Owned({}.num_seconds().to_string())", name)
            } else if field.type_s == "DurationMilliseconds" {
                format!("Cow::Owned({}.num_milliseconds().to_string())", name)
            } else {
                panic!("Unknown original time type {} found.", field.type_s);
            },
        "DateTime" => format!("Cow::Owned({}.timestamp().to_string())", name),
        _ => panic!("Unknown type '{}'", rust_type),
    }
}

/// Returns Rust source that serializes every element of the collection
/// `name` and joins the results with `,`.
fn vector_value_serializer(field: &Field, inner_type: &str, name: &str) -> String {
    format!("{{ let mut s = String::new(); \
        for val in {} {{ \
        if !s.is_empty() {{ s += \",\" }} \
        let t: Cow<str> = {}; s += t.as_ref(); \
        }} \
        Cow::Owned(s) }}", name, single_value_serializer(field, inner_type, "val"))
}
#>