use std::net::SocketAddr; use http_body_util::Empty; use hyper::body::Bytes; use rand::Rng; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use websocket_sans_io::{ FrameInfo, Opcode, WebsocketFrameDecoder, WebsocketFrameEncoder, WebsocketFrameEvent, }; #[path="../src/tokiort.rs"] mod tokiort; #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { let addr: SocketAddr = "127.0.0.1:1234".parse()?; let s = tokio::net::TcpStream::connect(addr).await?; let s = tokiort::TokioIo::new(s); let b = hyper::client::conn::http1::Builder::new(); let (mut sr, conn) = b.handshake::<_, Empty>(s).await?; tokio::spawn(async move { match conn.with_upgrades().await { Ok(()) => (), Err(e) => { eprintln!("Error from connection: {e}"); } } }); let rq = hyper::Request::builder() .uri("/") .header("Connection", "Upgrade") .header("Upgrade", "websocket") .header("Sec-WebSocket-Version", "13") .header("Sec-WebSocket-Key", "wmnj1sVQ7pHv1bVR/wraDw==") .body(Empty::new())?; let resp = sr.send_request(rq).await?; let upg = hyper::upgrade::on(resp).await?; let Ok(parts) = upg.downcast::>() else { return Err("Failed to downcast".into()); }; let mut s = parts.io.inner(); let debt = parts.read_buf; let mut buf = Vec::::with_capacity(debt.len().max(4096)); buf.extend_from_slice(&debt[..]); let mut frame_decoder = WebsocketFrameDecoder::new(); let mut frame_encoder = WebsocketFrameEncoder::new(); let mut bufptr = 0; let mut error = false; println!("Connected to a WebSocket"); loop { let bufslice = &mut buf[bufptr..]; let ret = frame_decoder.add_data(bufslice).unwrap(); bufptr += ret.consumed_bytes; if let Some(ref ev) = ret.event { match ev { WebsocketFrameEvent::Start{frame_info: mut fi, ..} => { if !fi.is_reasonable() { println!("Unreasonable frame"); error = true; break; } if fi.mask.is_some() { println!("Masked frame while expected unmasked one"); error = true; break; } println!( "Frame {:?} with payload length {}{}", fi.opcode, fi.payload_length, if fi.fin { "" } else { " (non-final)" } ); if fi.opcode == Opcode::Ping { fi.opcode = Opcode::Pong; } if fi.opcode == Opcode::ConnectionClose { break; } fi.mask = Some(rand::thread_rng().gen()); let header = frame_encoder.start_frame(&fi); s.write_all(&header[..]).await?; } WebsocketFrameEvent::PayloadChunk{original_opcode: _} => { let payload_slice = &mut bufslice[0..ret.consumed_bytes]; frame_encoder.transform_frame_payload(payload_slice); s.write_all(payload_slice).await?; } WebsocketFrameEvent::End{..} => (), } } if ret.consumed_bytes == 0 && ret.event.is_none() { bufptr = 0; buf.resize(buf.capacity(), 0); let n = s.read(&mut buf[..]).await?; if n == 0 { println!("EOF"); error = true; break; } buf.resize(n, 0); } } let header = frame_encoder.start_frame(&FrameInfo { opcode: Opcode::ConnectionClose, payload_length: 2, mask: Some(rand::thread_rng().gen()), fin: true, reserved: 0, }); s.write_all(&header[..]).await?; let mut last_buf : [u8; 2] = if error { 1002u16 } else { 1000u16 }.to_be_bytes(); frame_encoder.transform_frame_payload(&mut last_buf[..]); s.write_all(&last_buf[..]).await?; println!("Finished"); Ok(()) }