#![allow(unused)] use std::collections::HashMap; use std::fmt::{Display, Formatter}; use std::io; use std::net::{Ipv4Addr, Ipv6Addr}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Mutex as StdMutex}; use async_trait::async_trait; use rand_core::OsRng; use tokio::sync::{mpsc, Mutex}; use wiretun::noise::crypto::LocalStaticSecret; use wiretun::*; pub struct TestKit {} impl TestKit { #[inline(always)] pub fn gen_local_secret() -> LocalStaticSecret { let pri = x25519_dalek::StaticSecret::random_from_rng(OsRng).to_bytes(); LocalStaticSecret::new(pri) } } #[derive(Clone)] pub struct StubTun { inbound_sent: Arc, inbound_recording: Arc>>>, outbound_sent: Arc, outbound_recording: Arc>>>, outbound_tx: mpsc::Sender>, outbound_rx: Arc>>>, inbound_tx: mpsc::Sender>, inbound_rx: Arc>>>, } impl StubTun { pub fn new() -> Self { let (inbound_tx, inbound_rx) = mpsc::channel(256); let (outbound_tx, outbound_rx) = mpsc::channel(256); Self { inbound_sent: Arc::new(AtomicU64::new(0)), inbound_recording: Arc::new(StdMutex::new(vec![])), outbound_sent: Arc::new(AtomicU64::new(0)), outbound_recording: Arc::new(StdMutex::new(vec![])), outbound_tx, outbound_rx: Arc::new(Mutex::new(outbound_rx)), inbound_tx, inbound_rx: Arc::new(Mutex::new(inbound_rx)), } } #[inline(always)] pub fn inbound_sent(&self) -> u64 { self.outbound_sent.load(Ordering::Relaxed) } #[inline(always)] pub fn outbound_sent(&self) -> u64 { self.outbound_sent.load(Ordering::Relaxed) } #[inline(always)] pub fn inbound_recording(&self) -> Vec> { self.inbound_recording.lock().unwrap().clone() } #[inline(always)] pub fn outbound_recording(&self) -> Vec> { self.outbound_recording.lock().unwrap().clone() } pub async fn send_outbound(&self, data: &[u8]) { self.outbound_sent.fetch_add(1, Ordering::Relaxed); self.outbound_recording.lock().unwrap().push(data.to_vec()); self.outbound_tx.send(data.to_vec()).await.unwrap(); } pub async fn fetch_inbound(&self) -> Vec { let mut rx = self.inbound_rx.lock().await; rx.recv().await.unwrap() } } impl Default for StubTun { fn default() -> Self { Self::new() } } #[async_trait] impl Tun for StubTun { fn name(&self) -> &str { "stub" } fn mtu(&self) -> Result { Ok(1500) } fn set_mtu(&self, _mtu: u16) -> Result<(), TunError> { Ok(()) } async fn recv(&self) -> Result, TunError> { let mut rx = self.outbound_rx.lock().await; let data = rx.recv().await.unwrap(); Ok(data) } async fn send(&self, buf: &[u8]) -> Result<(), TunError> { self.inbound_sent.fetch_add(1, Ordering::Relaxed); self.inbound_recording.lock().unwrap().push(buf.to_vec()); self.inbound_tx.send(buf.to_vec()).await.unwrap(); Ok(()) } } type TransportPacket = (Endpoint, Vec); #[derive(Clone)] pub struct StubTransport { ipv4: Ipv4Addr, ipv6: Ipv6Addr, port: u16, inbound_sent: Arc, inbound_recording: Arc>>, outbound_sent: Arc, outbound_recording: Arc>>, outbound_tx: mpsc::Sender, outbound_rx: Arc>>, inbound_tx: mpsc::Sender, inbound_rx: Arc>>, } impl StubTransport { #[inline(always)] pub fn inbound_sent(&self) -> u64 { self.inbound_sent.load(Ordering::Relaxed) } #[inline(always)] pub fn outbound_sent(&self) -> u64 { self.outbound_sent.load(Ordering::Relaxed) } #[inline(always)] pub fn inbound_recording(&self) -> Vec<(Endpoint, Vec)> { self.inbound_recording.lock().unwrap().clone() } #[inline(always)] pub fn outbound_recording(&self) -> Vec<(Endpoint, Vec)> { self.outbound_recording.lock().unwrap().clone() } pub async fn send_inbound(&self, data: &[u8], endpoint: &Endpoint) { self.inbound_sent.fetch_add(1, Ordering::Relaxed); self.inbound_recording .lock() .unwrap() .push((endpoint.clone(), data.to_vec())); self.inbound_tx .send((endpoint.clone(), data.to_vec())) .await .unwrap(); } pub async fn fetch_outbound(&self) -> (Endpoint, Vec) { let mut rx = self.outbound_rx.lock().await; rx.recv().await.unwrap() } } impl Display for StubTransport { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "StubTransport") } } #[async_trait] impl Transport for StubTransport { async fn bind(ipv4: Ipv4Addr, ipv6: Ipv6Addr, port: u16) -> Result { let (inbound_tx, inbound_rx) = mpsc::channel(256); let (outbound_tx, outbound_rx) = mpsc::channel(256); Ok(Self { ipv4, ipv6, port, inbound_sent: Arc::new(AtomicU64::new(0)), inbound_recording: Arc::new(StdMutex::new(vec![])), outbound_sent: Arc::new(AtomicU64::new(0)), outbound_recording: Arc::new(StdMutex::new(vec![])), outbound_tx, outbound_rx: Arc::new(Mutex::new(outbound_rx)), inbound_tx, inbound_rx: Arc::new(Mutex::new(inbound_rx)), }) } fn ipv4(&self) -> Ipv4Addr { self.ipv4 } fn ipv6(&self) -> Ipv6Addr { self.ipv6 } fn port(&self) -> u16 { self.port } async fn send_to(&self, data: &[u8], endpoint: &Endpoint) -> Result<(), io::Error> { self.outbound_sent.fetch_add(1, Ordering::Relaxed); self.outbound_recording .lock() .unwrap() .push((endpoint.clone(), data.to_vec())); self.outbound_tx .send((endpoint.clone(), data.to_vec())) .await .unwrap(); Ok(()) } async fn recv_from(&mut self) -> Result<(Endpoint, Vec), io::Error> { let rv = self.inbound_rx.lock().await.recv().await.unwrap(); Ok(rv) } }