mod support;

use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::time::Duration;

use tokio::time;

use support::*;
use wiretun::noise::protocol::{HandshakeInitiation, TransportData};
use wiretun::*;

/// A device whose only peer has a public key but no endpoint configured should stay silent:
/// after 30 seconds neither the stub transport nor the stub TUN device has carried any traffic.
#[tokio::test]
async fn test_noop_when_no_endpoint() {
    let local_secret = TestKit::gen_local_secret();
    let stub_tun = StubTun::new();
    let stub_transport = StubTransport::bind(Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED, 0)
        .await
        .unwrap();

    // The peer gets only a public key; no `.endpoint(..)` is set.
    let peer_public = TestKit::gen_local_secret().public_key().to_bytes();
    let config = DeviceConfig::default()
        .private_key(local_secret.private_key().to_bytes())
        .peer(PeerConfig::default().public_key(peer_public));

    let device = Device::with_transport(stub_tun.clone(), stub_transport.clone(), config)
        .await
        .unwrap();
    // Bound to a named variable so the control handle lives until the end of the test.
    let _control = device.control();

    time::sleep(Duration::from_secs(30)).await;

    assert_eq!(stub_transport.inbound_sent(), 0);
    assert_eq!(stub_transport.outbound_sent(), 0);
    assert_eq!(stub_tun.inbound_sent(), 0);
    assert_eq!(stub_tun.outbound_sent(), 0);
}

/// A device whose peer has an endpoint but never receives anything back (no inbound traffic is
/// injected) should keep sending to that endpoint. Every outbound datagram it sent must be
/// addressed to the peer endpoint and parse as a `HandshakeInitiation`. The TUN device must see
/// no traffic.
#[tokio::test]
async fn test_keep_initiation_when_no_response() {
    let local_secret = TestKit::gen_local_secret();
    let stub_tun = StubTun::new();
    let stub_transport = StubTransport::bind(Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED, 0)
        .await
        .unwrap();

    let peer_public = TestKit::gen_local_secret().public_key().to_bytes();
    let peer_endpoint = "10.0.0.1:80".parse().unwrap();
    let peer = PeerConfig::default()
        .public_key(peer_public)
        .endpoint(peer_endpoint);
    let config = DeviceConfig::default()
        .private_key(local_secret.private_key().to_bytes())
        .peer(peer);

    let device = Device::with_transport(stub_tun.clone(), stub_transport.clone(), config)
        .await
        .unwrap();
    // Bound to a named variable so the control handle lives until the end of the test.
    let _control = device.control();

    time::sleep(Duration::from_secs(30)).await;

    assert_eq!(stub_transport.inbound_sent(), 0);
    assert!(stub_transport.outbound_sent() > 0);
    assert_eq!(stub_tun.inbound_sent(), 0);
    assert_eq!(stub_tun.outbound_sent(), 0);

    // Fetch as many outbound datagrams as were counted and check each one.
    let sent = stub_transport.outbound_sent();
    for _ in 0..sent {
        let (to, payload) = stub_transport.fetch_outbound().await;
        assert_eq!(to.dst(), peer_endpoint);
        assert!(HandshakeInitiation::try_from(payload.as_slice()).is_ok());
    }
}

#[tokio::test]
async fn test_complete_handshake() { use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; tracing_subscriber::registry() .with(tracing_subscriber::EnvFilter::new( std::env::var("RUST_LOG").unwrap_or_else(|_| "debug".into()), )) .with(tracing_subscriber::fmt::layer()) .init(); let secret1 = TestKit::gen_local_secret(); let endpoint1 = "10.10.0.1:6789".parse::().unwrap(); let endpoint2 = "10.10.0.2:1245".parse::().unwrap(); let secret2 = TestKit::gen_local_secret(); let (_device1, tun1, transport1) = { let tun = StubTun::new(); let transport = StubTransport::bind(Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED, 0) .await .unwrap(); let cfg = DeviceConfig::default() .private_key(secret1.private_key().to_bytes()) .peer( PeerConfig::default() .public_key(secret2.public_key().to_bytes()) .allowed_ip(endpoint2.ip()) .endpoint(endpoint2), ); let device = Device::with_transport(tun.clone(), transport.clone(), cfg) .await .unwrap(); (device, tun, transport) }; let (_device2, tun2, transport2) = { let tun = StubTun::new(); let transport = StubTransport::bind(Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED, 0) .await .unwrap(); let cfg = DeviceConfig::default() .private_key(secret2.private_key().to_bytes()) .peer( PeerConfig::default() .public_key(secret1.public_key().to_bytes()) .allowed_ip(endpoint1.ip()), ); let device = Device::with_transport(tun.clone(), transport.clone(), cfg) .await .unwrap(); (device, tun, transport) }; { let (t1, t2) = (transport1.clone(), transport2.clone()); tokio::spawn(async move { loop { let (endpoint, data) = t1.fetch_outbound().await; assert_eq!(endpoint.dst(), endpoint2); let endpoint = Endpoint::new(t2.clone(), endpoint1); t2.send_inbound(&data, &endpoint).await; } }); let (t1, t2) = (transport1.clone(), transport2.clone()); tokio::spawn(async move { loop { let (endpoint, data) = t2.fetch_outbound().await; assert_eq!(endpoint.dst(), endpoint1); let endpoint = Endpoint::new(t1.clone(), endpoint2); t1.send_inbound(&data, 
&endpoint).await; } }); } time::sleep(Duration::from_secs(30)).await; assert_eq!(tun1.inbound_sent(), 0); assert_eq!(tun1.outbound_sent(), 0); assert_eq!(tun2.inbound_sent(), 0); assert_eq!(tun2.outbound_sent(), 0); assert!(transport1.inbound_sent() > 0); assert!(transport1.outbound_sent() > 0); assert!(transport2.inbound_sent() > 0); assert!(transport2.outbound_sent() > 0); let (mut d1_completed, mut d2_completed) = (false, false); for (_, data) in transport1.outbound_recording() { if TransportData::try_from(data.as_slice()).is_ok() { d1_completed = true; break; } } for (_, data) in transport2.outbound_recording() { if TransportData::try_from(data.as_slice()).is_ok() { d2_completed = true; break; } } assert!(d1_completed); assert!(!d2_completed); }