#![deny(warnings)] extern crate hyper; use std::net::{TcpStream, SocketAddr}; use std::io::{self, Read, Write}; use std::sync::mpsc; use std::time::Duration; use hyper::{Next, Encoder, Decoder}; use hyper::net::{HttpListener, HttpStream}; use hyper::server::{Server, Handler, Request, Response}; struct Serve { listening: Option, msg_rx: mpsc::Receiver, reply_tx: mpsc::Sender, } impl Serve { fn addrs(&self) -> &[SocketAddr] { self.listening.as_ref().unwrap().addrs() } fn addr(&self) -> &SocketAddr { let addrs = self.addrs(); assert!(addrs.len() == 1); &addrs[0] } /* fn head(&self) -> Request { unimplemented!() } */ fn body(&self) -> Vec { let mut buf = vec![]; while let Ok(Msg::Chunk(msg)) = self.msg_rx.try_recv() { buf.extend(&msg); } buf } fn reply(&self) -> ReplyBuilder { ReplyBuilder { tx: &self.reply_tx } } } struct ReplyBuilder<'a> { tx: &'a mpsc::Sender, } impl<'a> ReplyBuilder<'a> { fn status(self, status: hyper::StatusCode) -> Self { self.tx.send(Reply::Status(status)).unwrap(); self } fn header(self, header: H) -> Self { let mut headers = hyper::Headers::new(); headers.set(header); self.tx.send(Reply::Headers(headers)).unwrap(); self } fn body>(self, body: T) { self.tx.send(Reply::Body(body.as_ref().into())).unwrap(); } } impl Drop for Serve { fn drop(&mut self) { self.listening.take().unwrap().close(); } } struct TestHandler { tx: mpsc::Sender, reply: Vec, peeked: Option>, timeout: Option, } enum Reply { Status(hyper::StatusCode), Headers(hyper::Headers), Body(Vec), } enum Msg { //Head(Request), Chunk(Vec), } impl TestHandler { fn next(&self, next: Next) -> Next { if let Some(dur) = self.timeout { next.timeout(dur) } else { next } } } impl Handler for TestHandler { fn on_request(&mut self, _req: Request) -> Next { //self.tx.send(Msg::Head(req)).unwrap(); self.next(Next::read()) } fn on_request_readable(&mut self, decoder: &mut Decoder) -> Next { let mut vec = vec![0; 1024]; match decoder.read(&mut vec) { Ok(0) => { self.next(Next::write()) } Ok(n) => { vec.truncate(n); self.tx.send(Msg::Chunk(vec)).unwrap(); self.next(Next::read()) } Err(e) => match e.kind() { io::ErrorKind::WouldBlock => self.next(Next::read()), _ => panic!("test error: {}", e) } } } fn on_response(&mut self, res: &mut Response) -> Next { for reply in self.reply.drain(..) { match reply { Reply::Status(s) => { res.set_status(s); }, Reply::Headers(headers) => { use std::iter::Extend; res.headers_mut().extend(headers.iter()); }, Reply::Body(body) => { self.peeked = Some(body); }, } } if self.peeked.is_some() { self.next(Next::write()) } else { self.next(Next::end()) } } fn on_response_writable(&mut self, encoder: &mut Encoder) -> Next { match self.peeked { Some(ref body) => { encoder.write(body).unwrap(); self.next(Next::end()) }, None => self.next(Next::end()) } } } fn serve() -> Serve { serve_with_timeout(None) } fn serve_with_timeout(dur: Option) -> Serve { serve_n_with_timeout(1, dur) } fn serve_n(n: u32) -> Serve { serve_n_with_timeout(n, None) } fn serve_n_with_timeout(n: u32, dur: Option) -> Serve { use std::thread; let (msg_tx, msg_rx) = mpsc::channel(); let (reply_tx, reply_rx) = mpsc::channel(); let addr = "127.0.0.1:0".parse().unwrap(); let listeners = (0..n).map(|_| HttpListener::bind(&addr).unwrap()); let (listening, server) = Server::new(listeners) .handle(move |_| { let mut replies = Vec::new(); while let Ok(reply) = reply_rx.try_recv() { replies.push(reply); } TestHandler { tx: msg_tx.clone(), timeout: dur, reply: replies, peeked: None, } }).unwrap(); let thread_name = format!("test-server-{}: {:?}", listening, dur); thread::Builder::new().name(thread_name).spawn(move || { server.run(); }).unwrap(); Serve { listening: Some(listening), msg_rx: msg_rx, reply_tx: reply_tx, } } #[test] fn server_get_should_ignore_body() { let server = serve(); let mut req = TcpStream::connect(server.addr()).unwrap(); req.write_all(b"\ GET / HTTP/1.1\r\n\ Host: example.domain\r\n\ \r\n\ I shouldn't be read.\r\n\ ").unwrap(); req.read(&mut [0; 256]).unwrap(); assert_eq!(server.body(), b""); } #[test] fn server_get_with_body() { let server = serve(); let mut req = TcpStream::connect(server.addr()).unwrap(); req.write_all(b"\ GET / HTTP/1.1\r\n\ Host: example.domain\r\n\ Content-Length: 19\r\n\ \r\n\ I'm a good request.\r\n\ ").unwrap(); req.read(&mut [0; 256]).unwrap(); // note: doesnt include trailing \r\n, cause Content-Length wasn't 21 assert_eq!(server.body(), b"I'm a good request."); } #[test] fn server_get_fixed_response() { let foo_bar = b"foo bar baz"; let server = serve(); server.reply() .status(hyper::Ok) .header(hyper::header::ContentLength(foo_bar.len() as u64)) .body(foo_bar); let mut req = TcpStream::connect(server.addr()).unwrap(); req.write_all(b"\ GET / HTTP/1.1\r\n\ Host: example.domain\r\n\ Connection: close\r\n \r\n\ ").unwrap(); let mut body = String::new(); req.read_to_string(&mut body).unwrap(); let n = body.find("\r\n\r\n").unwrap() + 4; assert_eq!(&body[n..], "foo bar baz"); } #[test] fn server_get_chunked_response() { let foo_bar = b"foo bar baz"; let server = serve(); server.reply() .status(hyper::Ok) .header(hyper::header::TransferEncoding::chunked()) .body(foo_bar); let mut req = TcpStream::connect(server.addr()).unwrap(); req.write_all(b"\ GET / HTTP/1.1\r\n\ Host: example.domain\r\n\ Connection: close\r\n \r\n\ ").unwrap(); let mut body = String::new(); req.read_to_string(&mut body).unwrap(); let n = body.find("\r\n\r\n").unwrap() + 4; assert_eq!(&body[n..], "B\r\nfoo bar baz\r\n0\r\n\r\n"); } #[test] fn server_post_with_chunked_body() { let server = serve(); let mut req = TcpStream::connect(server.addr()).unwrap(); req.write_all(b"\ POST / HTTP/1.1\r\n\ Host: example.domain\r\n\ Transfer-Encoding: chunked\r\n\ \r\n\ 1\r\n\ q\r\n\ 2\r\n\ we\r\n\ 2\r\n\ rt\r\n\ 0\r\n\ \r\n ").unwrap(); req.read(&mut [0; 256]).unwrap(); assert_eq!(server.body(), b"qwert"); } /* #[test] fn server_empty_response() { let server = serve(); server.reply() .status(hyper::Ok); let mut req = TcpStream::connect(server.addr()).unwrap(); req.write_all(b"\ GET / HTTP/1.1\r\n\ Host: example.domain\r\n\ Connection: close\r\n \r\n\ ").unwrap(); let mut response = String::new(); req.read_to_string(&mut response).unwrap(); assert_eq!(response, "foo"); assert!(!response.contains("Transfer-Encoding: chunked\r\n")); let mut lines = response.lines(); assert_eq!(lines.next(), Some("HTTP/1.1 200 OK")); let mut lines = lines.skip_while(|line| !line.is_empty()); assert_eq!(lines.next(), Some("")); assert_eq!(lines.next(), None); } */ #[test] fn server_empty_response_chunked() { let server = serve(); server.reply() .status(hyper::Ok) .body(""); let mut req = TcpStream::connect(server.addr()).unwrap(); req.write_all(b"\ GET / HTTP/1.1\r\n\ Host: example.domain\r\n\ Connection: close\r\n \r\n\ ").unwrap(); let mut response = String::new(); req.read_to_string(&mut response).unwrap(); assert!(response.contains("Transfer-Encoding: chunked\r\n")); let mut lines = response.lines(); assert_eq!(lines.next(), Some("HTTP/1.1 200 OK")); let mut lines = lines.skip_while(|line| !line.is_empty()); assert_eq!(lines.next(), Some("")); // 0\r\n\r\n assert_eq!(lines.next(), Some("0")); assert_eq!(lines.next(), Some("")); assert_eq!(lines.next(), None); } #[test] fn server_empty_response_chunked_without_calling_write() { let server = serve(); server.reply() .status(hyper::Ok) .header(hyper::header::TransferEncoding::chunked()); let mut req = TcpStream::connect(server.addr()).unwrap(); req.write_all(b"\ GET / HTTP/1.1\r\n\ Host: example.domain\r\n\ Connection: close\r\n \r\n\ ").unwrap(); let mut response = String::new(); req.read_to_string(&mut response).unwrap(); assert!(response.contains("Transfer-Encoding: chunked\r\n")); let mut lines = response.lines(); assert_eq!(lines.next(), Some("HTTP/1.1 200 OK")); let mut lines = lines.skip_while(|line| !line.is_empty()); assert_eq!(lines.next(), Some("")); // 0\r\n\r\n assert_eq!(lines.next(), Some("0")); assert_eq!(lines.next(), Some("")); assert_eq!(lines.next(), None); } #[test] fn server_keep_alive() { extern crate env_logger; env_logger::init().unwrap(); let foo_bar = b"foo bar baz"; let server = serve(); server.reply() .status(hyper::Ok) .header(hyper::header::ContentLength(foo_bar.len() as u64)) .body(foo_bar); let mut req = TcpStream::connect(server.addr()).unwrap(); req.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); req.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); req.write_all(b"\ GET / HTTP/1.1\r\n\ Host: example.domain\r\n\ Connection: keep-alive\r\n\ \r\n\ ").expect("writing 1"); let mut buf = [0; 1024 * 8]; loop { let n = req.read(&mut buf[..]).expect("reading 1"); if n < buf.len() { if &buf[n - foo_bar.len()..n] == foo_bar { break; } else { println!("{:?}", ::std::str::from_utf8(&buf[..n])); } } } // try again! let quux = b"zar quux"; server.reply() .status(hyper::Ok) .header(hyper::header::ContentLength(quux.len() as u64)) .body(quux); req.write_all(b"\ GET /quux HTTP/1.1\r\n\ Host: example.domain\r\n\ Connection: close\r\n\ \r\n\ ").expect("writing 2"); let mut buf = [0; 1024 * 8]; loop { let n = req.read(&mut buf[..]).expect("reading 2"); assert!(n > 0, "n = {}", n); if n < buf.len() { if &buf[n - quux.len()..n] == quux { break; } } } } #[test] fn server_get_with_body_three_listeners() { let server = serve_n(3); let addrs = server.addrs(); assert_eq!(addrs.len(), 3); for (i, addr) in addrs.iter().enumerate() { let mut req = TcpStream::connect(addr).unwrap(); write!(req, "\ GET / HTTP/1.1\r\n\ Host: example.domain\r\n\ Content-Length: 17\r\n\ \r\n\ I'm sending to {}.\r\n\ ", i).unwrap(); req.read(&mut [0; 256]).unwrap(); // note: doesnt include trailing \r\n, cause Content-Length wasn't 19 let comparison = format!("I'm sending to {}.", i).into_bytes(); assert_eq!(server.body(), comparison); } }