use axum::{ extract::{ ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}, Json, Query, State, }, response::{ sse::{Event as SseEvent, Sse}, Html, IntoResponse, }, routing::{get, post}, Router, }; use futures::stream::Stream; use serde::{Deserialize, Serialize}; use std::convert::Infallible; use std::net::SocketAddr; use tagged_channels::TaggedChannels; #[derive(Clone, Eq, Hash, PartialEq)] enum ChannelTag { UserId(i32), IsAdmin, } #[derive(Deserialize, Serialize)] #[serde(tag = "_type")] enum EventMessage { User(UserMessage), Admin(SimpleMessage), Broadcast(SimpleMessage), } #[derive(Deserialize, Serialize)] struct UserMessage { user_id: i32, message: String, } #[derive(Deserialize, Serialize)] struct SimpleMessage { message: String, } #[derive(Deserialize)] struct ConnectionParams { user_id: Option, is_admin: bool, } #[tokio::main] async fn main() { let channels = TaggedChannels::new(); let app = Router::new() .route("/", get(index)) .route("/send", post(send)) .route("/sse", get(sse_ui)) .route("/ws", get(ws_ui)) .route("/sse-events", get(events)) .route("/ws-events", get(ws_events)) .with_state(channels); let addr = SocketAddr::from(([0, 0, 0, 0], 3000)); axum::Server::bind(&addr) .serve(app.into_make_service_with_connect_info::()) .await .unwrap(); } async fn index() -> Html { let page = [("WebSocket", "/ws"), ("SSE", "/sse")] .iter() .map(|(name, url)| format!(r#"
  • {name} example
  • "#)) .collect(); Html(page) } async fn sse_ui() -> Html { Html(include_str!("ui.html").replace("{{example}}", "sse")) } async fn ws_ui() -> Html { Html(include_str!("ui.html").replace("{{example}}", "ws")) } async fn send( State(channels): State>, Json(message): Json, ) { use EventMessage::*; match message { User(data) => { let tag = ChannelTag::UserId(data.user_id); channels.send_by_tag(&tag, User(data)).await } Admin(data) => { let tag = ChannelTag::IsAdmin; channels.send_by_tag(&tag, Admin(data)).await } Broadcast(data) => channels.broadcast(Broadcast(data)).await, } } /// Handler for browser to receive SSE events async fn events( Query(params): Query, State(mut channels): State>, ) -> Sse>> { let stream = async_stream::stream! { let mut rx = channels.create_channel(params.as_tags()); while let Some(msg) = rx.recv().await { let Ok(json) = serde_json::to_string(&msg) else { continue }; yield Ok(SseEvent::default().data(json)); } }; Sse::new(stream) } async fn ws_events( ws: WebSocketUpgrade, Query(params): Query, State(channels): State>, ) -> impl IntoResponse { ws.on_upgrade(move |socket| handle_socket(socket, channels, params.as_tags())) } async fn handle_socket( mut socket: WebSocket, mut channels: TaggedChannels, tags: Vec, ) { let mut rx = channels.create_channel(tags); while let Some(msg) = rx.recv().await { let Ok(json) = serde_json::to_string(&msg) else { continue }; if socket.send(WsMessage::Text(json)).await.is_err() { break; } } } impl ConnectionParams { fn as_tags(&self) -> Vec { let mut tags = Vec::new(); if let Some(id) = self.user_id { tags.push(ChannelTag::UserId(id)); } if self.is_admin { tags.push(ChannelTag::IsAdmin); } tags } }