#![allow(deprecated, unused_must_use)] extern crate futures; extern crate tokio_codec; extern crate tokio_udp; #[macro_use] extern crate tokio_io; extern crate bytes; extern crate env_logger; use std::io; use std::net::SocketAddr; use futures::{Future, Poll, Sink, Stream}; use bytes::{BufMut, BytesMut}; use tokio_codec::{Decoder, Encoder, LinesCodec}; use tokio_udp::{UdpFramed, UdpSocket}; macro_rules! t { ($e:expr) => { match $e { Ok(e) => e, Err(e) => panic!("{} failed with {:?}", stringify!($e), e), } }; } fn send_messages(send: S, recv: R) { let mut a = t!(UdpSocket::bind(&([127, 0, 0, 1], 0).into())); let mut b = t!(UdpSocket::bind(&([127, 0, 0, 1], 0).into())); let a_addr = t!(a.local_addr()); let b_addr = t!(b.local_addr()); { let send = SendMessage::new(a, send.clone(), b_addr, b"1234"); let recv = RecvMessage::new(b, recv.clone(), a_addr, b"1234"); let (sendt, received) = t!(send.join(recv).wait()); a = sendt; b = received; } { let send = SendMessage::new(a, send, b_addr, b""); let recv = RecvMessage::new(b, recv, a_addr, b""); t!(send.join(recv).wait()); } } #[test] fn send_to_and_recv_from() { send_messages(SendTo {}, RecvFrom {}); } #[test] fn send_and_recv() { send_messages(Send {}, Recv {}); } trait SendFn { fn send(&self, &mut UdpSocket, &[u8], &SocketAddr) -> Result; } #[derive(Debug, Clone)] struct SendTo {} impl SendFn for SendTo { fn send( &self, socket: &mut UdpSocket, buf: &[u8], addr: &SocketAddr, ) -> Result { socket.send_to(buf, addr) } } #[derive(Debug, Clone)] struct Send {} impl SendFn for Send { fn send( &self, socket: &mut UdpSocket, buf: &[u8], addr: &SocketAddr, ) -> Result { socket.connect(addr).expect("could not connect"); socket.send(buf) } } struct SendMessage { socket: Option, send: S, addr: SocketAddr, data: &'static [u8], } impl SendMessage { fn new(socket: UdpSocket, send: S, addr: SocketAddr, data: &'static [u8]) -> SendMessage { SendMessage { socket: Some(socket), send: send, addr: addr, data: data, } } } impl Future for SendMessage { type Item = UdpSocket; type Error = io::Error; fn poll(&mut self) -> Poll { let n = try_nb!(self .send .send(self.socket.as_mut().unwrap(), &self.data[..], &self.addr)); assert_eq!(n, self.data.len()); Ok(self.socket.take().unwrap().into()) } } trait RecvFn { fn recv(&self, &mut UdpSocket, &mut [u8], &SocketAddr) -> Result; } #[derive(Debug, Clone)] struct RecvFrom {} impl RecvFn for RecvFrom { fn recv( &self, socket: &mut UdpSocket, buf: &mut [u8], expected_addr: &SocketAddr, ) -> Result { socket.recv_from(buf).map(|(s, addr)| { assert_eq!(addr, *expected_addr); s }) } } #[derive(Debug, Clone)] struct Recv {} impl RecvFn for Recv { fn recv( &self, socket: &mut UdpSocket, buf: &mut [u8], _: &SocketAddr, ) -> Result { socket.recv(buf) } } struct RecvMessage { socket: Option, recv: R, expected_addr: SocketAddr, expected_data: &'static [u8], } impl RecvMessage { fn new( socket: UdpSocket, recv: R, expected_addr: SocketAddr, expected_data: &'static [u8], ) -> RecvMessage { RecvMessage { socket: Some(socket), recv: recv, expected_addr: expected_addr, expected_data: expected_data, } } } impl Future for RecvMessage { type Item = UdpSocket; type Error = io::Error; fn poll(&mut self) -> Poll { let mut buf = vec![0u8; 10 + self.expected_data.len() * 10]; let n = try_nb!(self.recv.recv( &mut self.socket.as_mut().unwrap(), &mut buf[..], &self.expected_addr )); assert_eq!(n, self.expected_data.len()); assert_eq!(&buf[..self.expected_data.len()], &self.expected_data[..]); Ok(self.socket.take().unwrap().into()) } } #[test] fn send_dgrams() { let mut a = t!(UdpSocket::bind(&t!("127.0.0.1:0".parse()))); let mut b = t!(UdpSocket::bind(&t!("127.0.0.1:0".parse()))); let mut buf = [0u8; 50]; let b_addr = t!(b.local_addr()); { let send = a.send_dgram(&b"4321"[..], &b_addr); let recv = b.recv_dgram(&mut buf[..]); let (sendt, received) = t!(send.join(recv).wait()); assert_eq!(received.2, 4); assert_eq!(&received.1[..4], b"4321"); a = sendt.0; b = received.0; } { let send = a.send_dgram(&b""[..], &b_addr); let recv = b.recv_dgram(&mut buf[..]); let received = t!(send.join(recv).wait()).1; assert_eq!(received.2, 0); } } pub struct ByteCodec; impl Decoder for ByteCodec { type Item = Vec; type Error = io::Error; fn decode(&mut self, buf: &mut BytesMut) -> Result>, io::Error> { let len = buf.len(); Ok(Some(buf.split_to(len).to_vec())) } } impl Encoder for ByteCodec { type Item = Vec; type Error = io::Error; fn encode(&mut self, data: Vec, buf: &mut BytesMut) -> Result<(), io::Error> { buf.reserve(data.len()); buf.put(data); Ok(()) } } #[test] fn send_framed_byte_codec() { drop(env_logger::try_init()); let mut a_soc = t!(UdpSocket::bind(&t!("127.0.0.1:0".parse()))); let mut b_soc = t!(UdpSocket::bind(&t!("127.0.0.1:0".parse()))); let a_addr = t!(a_soc.local_addr()); let b_addr = t!(b_soc.local_addr()); { let a = UdpFramed::new(a_soc, ByteCodec); let b = UdpFramed::new(b_soc, ByteCodec); let msg = b"4567".to_vec(); let send = a.send((msg.clone(), b_addr)); let recv = b.into_future().map_err(|e| e.0); let (sendt, received) = t!(send.join(recv).wait()); let (data, addr) = received.0.unwrap(); assert_eq!(msg, data); assert_eq!(a_addr, addr); a_soc = sendt.into_inner(); b_soc = received.1.into_inner(); } { let a = UdpFramed::new(a_soc, ByteCodec); let b = UdpFramed::new(b_soc, ByteCodec); let msg = b"".to_vec(); let send = a.send((msg.clone(), b_addr)); let recv = b.into_future().map_err(|e| e.0); let received = t!(send.join(recv).wait()).1; let (data, addr) = received.0.unwrap(); assert_eq!(msg, data); assert_eq!(a_addr, addr); } } #[test] fn send_framed_lines_codec() { drop(env_logger::try_init()); let a_soc = t!(UdpSocket::bind(&t!("127.0.0.1:0".parse()))); let b_soc = t!(UdpSocket::bind(&t!("127.0.0.1:0".parse()))); let a_addr = t!(a_soc.local_addr()); let b_addr = t!(b_soc.local_addr()); let a = UdpFramed::new(a_soc, ByteCodec); let b = UdpFramed::with_decode(b_soc, LinesCodec::new(), true); let msg = b"1\r\n2\r\n3\r\n".to_vec(); let send = a.send((msg.clone(), b_addr)); t!(send.wait()); let mut recv = Stream::wait(b).map(|e| e.unwrap()); assert_eq!(recv.next(), Some(("1".to_string(), a_addr))); assert_eq!(recv.next(), Some(("2".to_string(), a_addr))); assert_eq!(recv.next(), Some(("3".to_string(), a_addr))); } #[test] fn recv_framed_codec_errs() { drop(env_logger::try_init()); #[derive(Debug)] struct LinesCodecMaxLen { max_len: usize, codec: LinesCodec, } impl LinesCodecMaxLen { fn new(max_len: usize) -> Self { Self { max_len, codec: LinesCodec::new(), } } } impl Decoder for LinesCodecMaxLen { type Item = String; type Error = io::Error; fn decode(&mut self, buf: &mut BytesMut) -> Result, io::Error> { let opt_string = self.codec.decode_eof(buf)?; match opt_string { None => Ok(None), Some(string) => { if string.len() > self.max_len { Err(io::Error::new(io::ErrorKind::InvalidData, "Too big")) } else { Ok(Some(string)) } } } } } let a_soc = t!(UdpSocket::bind(&t!("127.0.0.1:0".parse()))); let b_soc = t!(UdpSocket::bind(&t!("127.0.0.1:0".parse()))); let a_addr = t!(a_soc.local_addr()); let b_addr = t!(b_soc.local_addr()); { let a = UdpFramed::new(a_soc, ByteCodec); let b = UdpFramed::new(b_soc, LinesCodecMaxLen::new(/*max_len*/ 1)); let msg = b"hello world".to_vec(); // hello world is too big let send = a.send((msg.clone(), b_addr)); let a = t!(send.wait()); let msg = b"1\r\n".to_vec(); // fits ok let send = a.send((msg.clone(), b_addr)); t!(send.wait()); let mut b = Stream::wait(b); let hello_world = b.next().unwrap(); assert!(hello_world.is_err()); // first one is too big let mut recv = b.map(|e| e.unwrap()); // and then we restore the state and continue receiving assert_eq!(recv.next(), Some(("1".to_string(), a_addr))); } } #[test] fn send_framed_lines_codec_with_non_terminating_frame() { drop(env_logger::try_init()); let a_soc = t!(UdpSocket::bind(&t!("127.0.0.1:0".parse()))); let b_soc = t!(UdpSocket::bind(&t!("127.0.0.1:0".parse()))); let a_addr = t!(a_soc.local_addr()); let b_addr = t!(b_soc.local_addr()); let a = UdpFramed::new(a_soc, ByteCodec); let b = UdpFramed::with_decode(b_soc, LinesCodec::new(), true); // This has no terminating delimiter thus we want to return the rest of the // frame and this tests that if decode fails, we try to decode_eof. let msg = b"1\r\n2\r\n3".to_vec(); let send = a.send((msg.clone(), b_addr)); t!(send.wait()); let mut recv = Stream::wait(b).map(|e| e.unwrap()); assert_eq!(recv.next(), Some(("1".to_string(), a_addr))); assert_eq!(recv.next(), Some(("2".to_string(), a_addr))); assert_eq!(recv.next(), Some(("3".to_string(), a_addr))); } #[test] fn recv_multi_framed_lines_codec_errs() { drop(env_logger::try_init()); #[derive(Debug)] struct LinesCodecMaxLen { max_len: usize, codec: LinesCodec, } impl LinesCodecMaxLen { fn new(max_len: usize) -> Self { Self { max_len, codec: LinesCodec::new(), } } } impl Decoder for LinesCodecMaxLen { type Item = String; type Error = io::Error; fn decode(&mut self, buf: &mut BytesMut) -> Result, io::Error> { return self.codec.decode(buf); } fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, io::Error> { let opt_string = self.codec.decode_eof(buf)?; match opt_string { None => Ok(None), Some(string) => { if string.len() > self.max_len { Err(io::Error::new(io::ErrorKind::InvalidData, "Too big")) } else { Ok(Some(string)) } } } } } let a_soc = t!(UdpSocket::bind(&t!("127.0.0.1:0".parse()))); let b_soc = t!(UdpSocket::bind(&t!("127.0.0.1:0".parse()))); let a_addr = t!(a_soc.local_addr()); let b_addr = t!(b_soc.local_addr()); let a = UdpFramed::new(a_soc, ByteCodec); let b = UdpFramed::with_decode(b_soc, LinesCodecMaxLen::new(/*max_len*/ 1), true); let msg = b"hello world".to_vec(); // hello world is too big let send = a.send((msg.clone(), b_addr)); let a = t!(send.wait()); let msg = b"1\r\n2\r\n3\r\n".to_vec(); let send = a.send((msg.clone(), b_addr)); t!(send.wait()); let mut b = Stream::wait(b); let hello_world = b.next().unwrap(); assert!(hello_world.is_err()); // first one is too big let mut recv = b.map(|e| e.unwrap()); // and then we restore the state and continue receiving assert_eq!(recv.next(), Some(("1".to_string(), a_addr))); assert_eq!(recv.next(), Some(("2".to_string(), a_addr))); assert_eq!(recv.next(), Some(("3".to_string(), a_addr))); }