//! A chat server that broadcasts a message to all connections. //! //! This is a simple line-based server which accepts WebSocket connections, //! reads lines from those connections, and broadcasts the lines to all other //! connected clients. //! //! You can test this out by running: //! //! cargo run --example server 127.0.0.1:12345 //! //! And then in another window run: //! //! cargo run --example client ws://127.0.0.1:12345/socket //! //! You can run the second command in multiple windows and then chat between the //! two, seeing the messages from the other client as they're received. For all //! connected clients they'll all join the same room and see everyone else's //! messages. use std::{ collections::HashMap, convert::Infallible, env, net::SocketAddr, sync::{Arc, Mutex}, }; use hyper::{ header::{ HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE, }, server::conn::AddrStream, service::{make_service_fn, service_fn}, upgrade::Upgraded, Body, Method, Request, Response, Server, StatusCode, Version, }; use futures_channel::mpsc::{unbounded, UnboundedSender}; use futures_util::{future, pin_mut, stream::TryStreamExt, StreamExt}; use async_tungstenite::{tokio::TokioAdapter, WebSocketStream}; use tungstenite::{ handshake::derive_accept_key, protocol::{Message, Role}, }; type Tx = UnboundedSender; type PeerMap = Arc>>; async fn handle_connection( peer_map: PeerMap, ws_stream: WebSocketStream>, addr: SocketAddr, ) { println!("WebSocket connection established: {}", addr); // Insert the write part of this peer to the peer map. let (tx, rx) = unbounded(); peer_map.lock().unwrap().insert(addr, tx); let (outgoing, incoming) = ws_stream.split(); let broadcast_incoming = incoming.try_for_each(|msg| { println!( "Received a message from {}: {}", addr, msg.to_text().unwrap() ); let peers = peer_map.lock().unwrap(); // We want to broadcast the message to everyone except ourselves. let broadcast_recipients = peers .iter() .filter(|(peer_addr, _)| peer_addr != &&addr) .map(|(_, ws_sink)| ws_sink); for recp in broadcast_recipients { recp.unbounded_send(msg.clone()).unwrap(); } future::ok(()) }); let receive_from_others = rx.map(Ok).forward(outgoing); pin_mut!(broadcast_incoming, receive_from_others); future::select(broadcast_incoming, receive_from_others).await; println!("{} disconnected", &addr); peer_map.lock().unwrap().remove(&addr); } async fn handle_request( peer_map: PeerMap, mut req: Request, addr: SocketAddr, ) -> Result, Infallible> { println!("Received a new, potentially ws handshake"); println!("The request's path is: {}", req.uri().path()); println!("The request's headers are:"); for (ref header, _value) in req.headers() { println!("* {}", header); } let upgrade = HeaderValue::from_static("Upgrade"); let websocket = HeaderValue::from_static("websocket"); let headers = req.headers(); let key = headers.get(SEC_WEBSOCKET_KEY); let derived = key.map(|k| derive_accept_key(k.as_bytes())); if req.method() != Method::GET || req.version() < Version::HTTP_11 || !headers .get(CONNECTION) .and_then(|h| h.to_str().ok()) .map(|h| { h.split(|c| c == ' ' || c == ',') .any(|p| p.eq_ignore_ascii_case(upgrade.to_str().unwrap())) }) .unwrap_or(false) || !headers .get(UPGRADE) .and_then(|h| h.to_str().ok()) .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) || !headers .get(SEC_WEBSOCKET_VERSION) .map(|h| h == "13") .unwrap_or(false) || key.is_none() || req.uri() != "/socket" { return Ok(Response::new(Body::from("Hello World!"))); } let ver = req.version(); tokio::task::spawn(async move { match hyper::upgrade::on(&mut req).await { Ok(upgraded) => { handle_connection( peer_map, WebSocketStream::from_raw_socket( TokioAdapter::new(upgraded), Role::Server, None, ) .await, addr, ) .await; } Err(e) => println!("upgrade error: {}", e), } }); let mut res = Response::new(Body::empty()); *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; *res.version_mut() = ver; res.headers_mut().append(CONNECTION, upgrade); res.headers_mut().append(UPGRADE, websocket); res.headers_mut() .append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap()); // Let's add an additional header to our response to the client. res.headers_mut() .append("MyCustomHeader", ":)".parse().unwrap()); res.headers_mut() .append("SOME_TUNGSTENITE_HEADER", "header_value".parse().unwrap()); Ok(res) } #[tokio::main] async fn main() -> Result<(), hyper::Error> { let state = PeerMap::new(Mutex::new(HashMap::new())); let addr = env::args() .nth(1) .unwrap_or_else(|| "127.0.0.1:8080".to_string()) .parse() .unwrap(); let make_svc = make_service_fn(move |conn: &AddrStream| { let remote_addr = conn.remote_addr(); let state = state.clone(); let service = service_fn(move |req| handle_request(state.clone(), req, remote_addr)); async { Ok::<_, Infallible>(service) } }); let server = Server::bind(&addr).serve(make_svc); server.await?; Ok::<_, hyper::Error>(()) }