use async_icmp::message::decode::DecodedIcmpMsg; use async_icmp::message::{IcmpV4MsgType, IcmpV6MsgType}; use async_icmp::{message::IcmpEchoRequest, socket::IcmpSocket, IpVersion}; use clap::Parser as _; use itertools::Itertools; use log::{debug, error, info, warn}; use rand::Rng; use std::sync::Arc; use std::{collections, net, time}; use tokio::select; use tokio::sync::mpsc; use tokio::task; use winnow::{binary, combinator, Parser}; #[tokio::main] async fn main() -> anyhow::Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("INFO")) .format_timestamp_millis() .init(); let cli = Cli::parse(); let ip_version = (&cli.dest).into(); let socket = Arc::new(IcmpSocket::new(ip_version)?); let mut rng = rand::thread_rng(); let id = rng.gen(); let mut data = vec![0; 32]; rng.fill(&mut data[..]); let (recv_tx, mut recv_rx) = mpsc::channel(cli.count.into()); let recv_task = spawn_recv_task(id, cli.count, data.clone(), socket.clone(), recv_tx); let (send_tx, mut send_rx) = mpsc::channel(cli.count.into()); let send_task = spawn_send_task( id, cli.count, cli.between_pings.into(), &data, socket, cli.dest, send_tx, ); // get sent tstamp but no reply tstamp yet let mut send_unpaired_timestamps = collections::HashMap::::new(); // got reply tstamp but no sent tstamp yet let mut recv_unpaired_timestamps = collections::HashMap::::new(); // seqs for which we haven't yet seen a sent ts let mut pending_send_seqs = (0_u16..cli.count).collect::>(); // send tstamps with not yet timed out pending replies. // Like `send_unpaired_timestamps` but filtered out entries that are too old. let mut waitable_seqs = collections::HashMap::::new(); // durations for matching send/recv pairs let mut durations = collections::HashMap::new(); loop { // keep the set from growing without bound during packet loss: // retain only send tstamps that haven't expired let oldest_non_expired_timestamp = time::Instant::now() .checked_sub(cli.timeout.into()) .expect("Any reasonable timeout will still fit in Instant"); waitable_seqs.retain(|_seq, tstamp| *tstamp >= oldest_non_expired_timestamp); // Wait up to `timeout` past the latest send tstamp let recv_timeout_at = waitable_seqs .iter() .max_by_key(|(_seq, ts)| **ts) .and_then(|(_seq, ts)| ts.checked_add(cli.timeout.into())) .unwrap_or( // If there's no response to wait for, use effectively infinite timeout time::Instant::now() .checked_add(time::Duration::from_secs(3600 * 24 * 365)) .expect("1yr in the future should fit in Instant"), ); let should_wait_for_send = !pending_send_seqs.is_empty() && !send_rx.is_closed(); let should_wait_for_recv = !waitable_seqs.is_empty() && !recv_rx.is_closed(); let opt_dur = select! { send = send_rx.recv(), if should_wait_for_send => { match send { None => { // channel just closed error!("Send task died early"); continue; } Some((seq, tstamp)) => { pending_send_seqs.remove(&seq); // task scheduling might lead to finding the send after the recv if let Some(recv_ts) = recv_unpaired_timestamps.remove(&seq) { Some((seq, recv_ts.saturating_duration_since(tstamp))) } else { send_unpaired_timestamps.insert(seq, tstamp); waitable_seqs.insert(seq, tstamp); None } } } } recv = tokio::time::timeout_at(recv_timeout_at.into(), recv_rx.recv()), if should_wait_for_recv => { match recv { Ok(opt) => match opt { None => { // channel just closed error!("Recv task died early"); continue; } Some((seq, tstamp)) => { if let Some(send_ts) = send_unpaired_timestamps.remove(&seq) { waitable_seqs.remove(&seq); Some((seq, tstamp.duration_since(send_ts))) } else { recv_unpaired_timestamps.insert(seq, tstamp); None } } } Err(_e) => { // we timed out while waiting for recv. // let waitable filtering apply continue; } } } else => { // no send seq #'s left, and nothing to wait for: we're done break; } }; if let Some((seq, dur)) = opt_dur { info!("Ping seq={seq} response received in {dur:?}"); durations.insert(seq, dur); } } drop(send_rx); drop(recv_rx); let _ = send_task.await?; let _ = recv_task.await?; info!( "{} pings total, {} responses received ({}%{}), min {:?}, mean {:?}, max {:?}", cli.count, durations.len(), durations.len() as f64 / cli.count as f64 * 100.0_f64, if send_unpaired_timestamps.is_empty() { "".to_string() } else { format!( ", missing seqs {:?}", send_unpaired_timestamps .into_iter() .map(|(seq, _ts)| seq) .sorted() .collect_vec() ) }, durations .iter() .map(|(_seq, dur)| dur) .min() .unwrap_or(&time::Duration::ZERO), durations .iter() .map(|(_seq, dur)| dur) .sum::() / durations.len().try_into().expect("# pings will fit in u32"), durations .iter() .map(|(_seq, dur)| dur) .max() .unwrap_or(&time::Duration::ZERO), ); Ok(()) } fn spawn_send_task( id: u16, count: u16, between_pings: time::Duration, data: &[u8], socket: Arc, dest: net::IpAddr, tx: mpsc::Sender<(u16, time::Instant)>, ) -> task::JoinHandle<()> { let mut echo = IcmpEchoRequest::new(id, 0, &data); tokio::spawn(async move { let start = time::Instant::now(); for seq in 0..count { tokio::time::sleep_until((start + between_pings * u32::from(seq)).into()).await; echo.set_seq(seq); if let Err(e) = socket.send_to(dest, &mut echo).await { warn!("Could not send ping seq={seq}: {}", e); } info!("Ping seq={} sent", seq); let tstamp = time::Instant::now(); if let Err(e) = tx.try_send((seq, tstamp)) { match e { mpsc::error::TrySendError::Full(_) => { error!( "Dropped send timestamp -- consumer dead? seq={} time={:?}", seq, tstamp ); } mpsc::error::TrySendError::Closed(_) => { // upstream must have shut down return; } } } } }) } fn spawn_recv_task( id: u16, count: u16, data: Vec, socket: Arc, tx: mpsc::Sender<(u16, time::Instant)>, ) -> task::JoinHandle<()> { tokio::spawn(async move { let mut seqs_to_recv = (0..count).collect::>(); // we're only sending 32 byte data so this allows some room for echo replies // with extra data appended, as seems to happen from some hosts let mut buf = vec![0_u8; 128]; let echo_reply_msg_type = match socket.ip_version() { IpVersion::V4 => IcmpV4MsgType::EchoReply as u8, IpVersion::V6 => IcmpV6MsgType::EchoReply as u8, }; loop { let res = select! { _ = tx.closed() => { debug!("Recv task cancelled"); return; } res = socket.recv(&mut buf[..]) => { res } }; match res { Ok((bytes, _)) => { let tstamp = time::Instant::now(); let res = DecodedIcmpMsg::decode(bytes); match res { Ok(message) => { if message.msg_type() != echo_reply_msg_type || message.msg_code() != 0 { debug!( "Skipping message with type={}, code={}", message.msg_type(), message.msg_code() ); continue; } debug!( "Got ICMP echo reply message with type={}, code={}", message.msg_type(), message.msg_code() ); match parse_echo_reply(message.body()) { Ok((incoming_id, incoming_seq, incoming_data)) => { debug!( "Parsed echo reply: id={}, seq={}, data=0x{}", incoming_id, incoming_seq, hex::encode_upper(incoming_data) ); if incoming_id != id { debug!("Unexpected id: {incoming_id}, expected {id}"); continue; } if !seqs_to_recv.remove(&incoming_seq) { debug!("Unexpected (duplicate?) seq: {incoming_seq}"); continue; } if data != incoming_data { debug!( "Unexpected data: {}", hex::encode_upper(incoming_data) ); continue; } debug!("Received seq={incoming_seq}"); if let Err(e) = tx.try_send((incoming_seq, tstamp)) { match e { mpsc::error::TrySendError::Full(_) => { error!("Dropped recv timestamp -- consumer dead? seq={} time={:?}", incoming_seq, tstamp); } mpsc::error::TrySendError::Closed(_) => { return; } } } if seqs_to_recv.is_empty() { debug!("All seq numbers received"); return; } } Err(e) => { warn!("Echo reply could not be parsed: {}", e); } } } Err(_) => { warn!("Packet could not be decoded: {}", hex::encode_upper(bytes)) } } } Err(e) => warn!("Recv error: {}", e), } } }) } /// Parse into (`ident`, `seq`, `data`) fn parse_echo_reply( icmp_reply_body: &[u8], ) -> Result< (u16, u16, &[u8]), winnow::error::ParseError, winnow::error::ContextError>, > { (binary::be_u16, binary::be_u16, combinator::rest).parse(winnow::Located::new(icmp_reply_body)) } #[derive(clap::Parser)] struct Cli { /// The number of times to ping #[arg(short, long, default_value_t = 3)] count: u16, /// How long to wait between sending pings #[arg(long, default_value_t = time::Duration::from_millis(250).into())] between_pings: humantime::Duration, /// Wait at least this long for a response #[arg(long, default_value_t = time::Duration::from_secs(2).into())] timeout: humantime::Duration, /// Ip address to ping dest: net::IpAddr, }