use std::fs::File; use std::io::Write; use std::sync::Arc; use std::{fs::create_dir_all, path::PathBuf}; use clap::Parser; use tokio::io::BufStream; use tokio_rustls::rustls::{self, Certificate, PrivateKey}; use tokio_rustls::TlsAcceptor; use tracing_subscriber::util::SubscriberInitExt; use ws_tool::codec::AsyncStringCodec; use ws_tool::{codec::default_handshake_handler, ServerBuilder}; /// websocket client connect to binance futures websocket #[derive(Parser)] struct Args { /// server host #[arg(long, default_value = "127.0.0.1")] host: String, /// server port #[arg(short, long, default_value = "9000")] port: u16, /// relative path from workspace dir for certs #[arg(short, long, default_value = "certs")] cert: PathBuf, /// enable ssl #[arg(short, long)] ssl: bool, /// level #[arg(short, long, default_value = "info")] level: tracing::Level, } #[tokio::main] async fn main() -> Result<(), ()> { let args = Args::parse(); tracing_subscriber::fmt::fmt() .with_max_level(args.level) .with_file(true) .with_line_number(true) .finish() .try_init() .expect("failed to init log"); if args.ssl { let cert = rcgen::generate_simple_self_signed(vec![args.host.clone()]) .expect("unable to generate certs"); let mut cert_dir = std::env::current_dir().expect("failed to get current work dir"); cert_dir.push(args.cert); if !cert_dir.exists() { create_dir_all(&cert_dir).expect("failed to create cert dir"); } let mut cert_file_path = cert_dir.clone(); cert_file_path.push("certs.pem"); let mut cert_file = File::create(&cert_file_path).expect("failed to create cert file"); let cert_content = cert.serialize_pem().unwrap(); cert_file .write_all(cert_content.as_bytes()) .expect("fail to write cert file"); cert_file.sync_all().unwrap(); tracing::info!("cert file saved at {}", cert_file_path.display()); let mut key_file_path = cert_dir.clone(); key_file_path.push("key.pem"); let mut key_file = File::create(&key_file_path).expect("failed to create key file"); let key_content = cert.serialize_private_key_pem(); key_file .write_all(key_content.as_bytes()) .expect("fail to write key file"); key_file.sync_all().unwrap(); tracing::info!("key file saved at {}", key_file_path.display()); let cert_file = File::open(cert_file_path).unwrap(); let mut reader = std::io::BufReader::new(cert_file); let certs = rustls_pemfile::certs(&mut reader) .unwrap() .into_iter() .map(Certificate) .collect(); let key_file = File::open(key_file_path).unwrap(); let mut reader = std::io::BufReader::new(key_file); let mut keys = rustls_pemfile::pkcs8_private_keys(&mut reader).unwrap(); let config = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert( // vec![Certificate(cert_content.into())], // PrivateKey(key_content.into()), certs, PrivateKey(keys.remove(0)), ) .expect("failed to init ssl context"); let accepter = TlsAcceptor::from(Arc::new(config)); tracing::info!("binding on {}:{}", args.host, args.port); let listener = tokio::net::TcpListener::bind(format!("{}:{}", args.host, args.port)) .await .unwrap(); loop { let (stream, addr) = listener.accept().await.unwrap(); let stream = match accepter.accept(stream).await { Ok(stream) => stream, Err(e) => { tracing::error!("{e:?}"); continue; } }; let stream = BufStream::with_capacity(0, 0, stream); tracing::info!("got connect from {:?}", addr); let (mut read, mut write) = ServerBuilder::async_accept( stream, default_handshake_handler, // AsyncWsStringCodec::factory, AsyncStringCodec::factory, ) .await .unwrap() .split(); while let Ok(msg) = read.receive().await { write.send((msg.code, msg.data)).await.unwrap(); } } } else { tracing::info!("binding on {}:{}", args.host, args.port); let listener = tokio::net::TcpListener::bind(format!("{}:{}", args.host, args.port)) .await .unwrap(); loop { let (stream, addr) = listener.accept().await.unwrap(); tracing::info!("got connect from {:?}", addr); let (mut read, mut write) = ServerBuilder::async_accept( stream, default_handshake_handler, // AsyncWsStringCodec::factory, AsyncStringCodec::factory, ) .await .unwrap() .split(); loop { match read.receive().await { Ok(msg) => write.send(msg).await.unwrap(), Err(e) => { dbg!(e); break; } } } } } }