// RPC Client compatible with https://github.com/janestreet/async_rpc_kernel // This can be used with the example server (in OCaml) available here: // https://github.com/janestreet/async/blob/v0.14/async_rpc/example/rpc_server.ml // // RPC magic number 4_411_474 use anyhow::Result; use binprot::macros::{BinProtRead, BinProtWrite}; use binprot::{BinProtRead, BinProtSize, BinProtWrite}; use std::collections::BTreeMap; use std::io::{Read, Write}; use std::net::{TcpListener, TcpStream}; #[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)] struct Handshake(Vec); #[derive(BinProtRead, BinProtWrite, Clone, PartialEq)] enum Sexp { Atom(String), List(Vec), } // Dummy formatter, escaping is not handled properly. impl std::fmt::Debug for Sexp { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { match self { Sexp::Atom(atom) => { if atom.contains(|c: char| !c.is_alphanumeric()) { fmt.write_str("\"")?; for c in atom.escape_default() { std::fmt::Write::write_char(fmt, c)?; } fmt.write_str("\"")?; } else { fmt.write_str(&atom)?; } Ok(()) } Sexp::List(list) => { fmt.write_str("(")?; for (index, sexp) in list.iter().enumerate() { if index > 0 { fmt.write_str(" ")?; } sexp.fmt(fmt)?; } fmt.write_str(")")?; Ok(()) } } } } #[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)] struct Query { rpc_tag: String, version: i64, id: i64, data: binprot::WithLen, } #[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)] #[polymorphic_variant] enum Version { Version(i64), } #[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)] enum RpcError { BinIoExn(Sexp), ConnectionClosed, WriteError(Sexp), UncaughtExn(Sexp), UnimplementedRpc((String, Version)), UnknownQueryId(String), } #[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)] enum RpcResult { Ok(binprot::WithLen), Error(RpcError), } #[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)] struct Response { id: i64, data: RpcResult, } #[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)] enum Message { Heartbeat, Query(Query), Response(Response), } fn read_bin_prot(stream: &mut TcpStream, buffer: &mut Vec) -> Result { let mut recv_bytes = [0u8; 8]; stream.read_exact(&mut recv_bytes)?; let recv_len = i64::from_le_bytes(recv_bytes); buffer.resize(recv_len as usize, 0u8); stream.read_exact(buffer)?; let mut slice = buffer.as_slice(); let data = T::binprot_read(&mut slice)?; Ok(data) } fn write_bin_prot(stream: &mut TcpStream, v: &T) -> Result<()> { let len = v.binprot_size() as i64; stream.write_all(&len.to_le_bytes())?; v.binprot_write(stream)?; Ok(()) } trait JRpc { const NAME: &'static str; const VERSION: i64; type Q; type R; } struct RpcGetUniqueId; impl JRpc for RpcGetUniqueId { const NAME: &'static str = "get-unique-id"; const VERSION: i64 = 0; type Q = (); type R = i64; } struct RpcGetUniqueIdTypo; impl JRpc for RpcGetUniqueIdTypo { const NAME: &'static str = "get-unique-id2"; const VERSION: i64 = 0; type Q = (); type R = i64; } struct RpcSetIdCounter; impl JRpc for RpcSetIdCounter { const NAME: &'static str = "set-id-counter"; const VERSION: i64 = 1; type Q = i64; type R = (); } struct RpcClient { stream: std::net::TcpStream, buffer: Vec, id: i64, } impl RpcClient { fn connect(address: &str) -> Result { let mut stream = TcpStream::connect(address)?; let mut buffer = vec![0u8; 256]; println!("Successfully connected to {}", address); let handshake: Handshake = read_bin_prot(&mut stream, &mut buffer)?; println!("Received {:?}", handshake); write_bin_prot(&mut stream, &handshake)?; Ok(RpcClient { stream, buffer, id: 0 }) } fn dispatch(&mut self, query: T::Q) -> Result> where T::Q: BinProtWrite, T::R: BinProtRead, { self.id = self.id + 1; let query = Query { rpc_tag: T::NAME.to_owned(), version: T::VERSION, id: self.id, data: binprot::WithLen(query), }; write_bin_prot(&mut self.stream, &Message::Query::(query))?; loop { let received: Message<(), T::R> = read_bin_prot(&mut self.stream, &mut self.buffer).unwrap(); match received { Message::Heartbeat => (), Message::Response(r) => return Ok(r), Message::Query(_) => (), } } } } trait JRpcImpl { type Q; // Query type R; // Response type E; // Error fn rpc_impl(&mut self, q: Self::Q) -> std::result::Result; } trait ErasedJRpcImpl { fn erased_rpc_impl(&mut self, stream: &mut TcpStream, id: i64) -> Result<()>; } //impl ErasedJRpcImpl for dyn JRpcImpl //where // Q: BinProtRead, // R: BinProtWrite, // E: std::error::Error, impl ErasedJRpcImpl for T where T: JRpcImpl, T::Q: BinProtRead, T::R: BinProtWrite, T::E: std::error::Error, { fn erased_rpc_impl(&mut self, stream: &mut TcpStream, id: i64) -> Result<()> { let query = T::Q::binprot_read(stream)?; let rpc_result = match self.rpc_impl(query) { Ok(response) => RpcResult::Ok(binprot::WithLen(response)), Err(error) => { let sexp = Sexp::Atom(error.to_string()); RpcResult::Error(RpcError::UncaughtExn(sexp)) } }; let response = Response { id, data: rpc_result }; write_bin_prot(stream, &Message::Response::<(), T::R>(response))?; Ok(()) } } #[allow(dead_code)] struct RpcServer { listener: TcpListener, buffer: Vec, id: i64, rpc_impls: BTreeMap>, } struct GetUniqueIdImpl(i64); impl JRpcImpl for GetUniqueIdImpl { type Q = (); type R = i64; type E = std::convert::Infallible; fn rpc_impl(&mut self, _q: Self::Q) -> std::result::Result { let result = self.0; self.0 += 1; Ok(result) } } // It is not easy to use [Query] on the server side as we do not // know which rpcs will be triggered. So instead we use this type // that only parses up to the length of the payload. // This only works because the payload appears last in the // serialized representation #[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)] struct ServerQuery { rpc_tag: String, version: i64, id: i64, data: binprot::Nat0, } #[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)] enum ServerMessage { Heartbeat, Query(ServerQuery), Response(R), } impl RpcServer { fn bind(address: &str) -> Result { let listener = TcpListener::bind(address)?; let buffer = vec![0u8; 256]; println!("Successfully bound to {}", address); let mut rpc_impls: BTreeMap> = BTreeMap::new(); let get_unique_id_impl: Box = Box::new(GetUniqueIdImpl(0)); rpc_impls.insert("get-unique-id".to_string(), get_unique_id_impl); Ok(RpcServer { listener, buffer, id: 0, rpc_impls }) } fn run(&mut self) -> Result<()> { for stream in self.listener.incoming() { let mut stream = stream?; println!("Got connection {:?}.", stream); write_bin_prot(&mut stream, &Handshake(vec![4411474, 1]))?; let handshake: Handshake = read_bin_prot(&mut stream, &mut self.buffer)?; println!("Received handshake {:?}", handshake); let mut recv_bytes = [0u8; 8]; loop { // We don't know the type of rpcs that will be received so the // following parses the incoming messages in a "manual" way. stream.read_exact(&mut recv_bytes)?; let _recv_len = i64::from_le_bytes(recv_bytes); let query = ServerMessage::<()>::binprot_read(&mut stream)?; println!("Received rpc query {:?}", query); match query { ServerMessage::Heartbeat => {} ServerMessage::Query(query) => match self.rpc_impls.get_mut(&query.rpc_tag) { None => { let err = RpcError::UnimplementedRpc(( query.rpc_tag, Version::Version(query.version), )); let message = ServerMessage::Response(Response::<()> { id: query.id, data: RpcResult::Error(err), }); self.buffer.resize(query.data.0 as usize, 0u8); stream.read_exact(&mut self.buffer)?; write_bin_prot(&mut stream, &message)? } Some(r) => r.erased_rpc_impl(&mut stream, query.id)?, }, ServerMessage::Response(()) => unimplemented!(), }; } } Ok(()) } } fn main() -> Result<()> { let arg = std::env::args().skip(1).next(); match arg.as_deref() { Some("client") => { let mut client = RpcClient::connect("localhost:8080")?; let response = client.dispatch::(())?; println!(">> {:?}", response); let response = client.dispatch::(())?; println!(">> {:?}", response); let response = client.dispatch::(42)?; println!(">> {:?}", response); let response = client.dispatch::(())?; println!(">> {:?}", response); let response = client.dispatch::(0)?; println!(">> {:?}", response); let response = client.dispatch::(())?; println!(">> {:?}", response); } Some("server") => { let mut server = RpcServer::bind("localhost:8080")?; server.run()? } Some(_) => { panic!("unexpected argument, try client or server") } None => { panic!("missing argument") } } Ok(()) }