use std::collections::HashSet; use std::error::Error; use std::io::Cursor; use std::net::Ipv4Addr; use std::thread; use std::time::{Duration, Instant}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use clap::Clap; use etherparse::{ InternetSlice, PacketBuilder, PacketBuilderStep, SlicedPacket, TransportSlice, UdpHeader, }; use xdpsock::{ socket::{BindFlags, SocketConfig, SocketConfigBuilder, XdpFlags}, umem::{UmemConfig, UmemConfigBuilder}, xsk::{Xsk2, MAX_PACKET_SIZE}, }; #[derive(Clap, Debug, Clone)] enum Mode { Tx, Rx, } /// Send or Receive UDP packets on the specified interface #[derive(Debug, Clone, Clap)] #[clap(version = "1.0", author = "Collins Huff")] struct Opts { /// interface name #[clap(short, long)] dev: String, /// source MAC address #[clap(long)] src_mac: String, /// destination MAC address #[clap(long)] dest_mac: String, /// source IP address #[clap(long)] src_ip: String, /// source port #[clap(long)] src_port: u16, /// destination IP address #[clap(long)] dest_ip: String, /// destination port #[clap(long)] dest_port: u16, /// A level of verbosity, and can be used multiple times #[clap(short, long, parse(from_occurrences))] verbose: i32, /// Transmit or Receive mode #[clap(subcommand)] mode: Mode, /// Number of packets to transmit or wait to receive #[clap(short, long)] n_pkts: u64, #[clap(long)] n_threads: Option, } fn main() { env_logger::init(); let opts: Opts = Opts::parse(); let umem_config = UmemConfigBuilder::new() .frame_count(8192) .comp_queue_size(4096) .fill_queue_size(4096) .build() .unwrap(); let socket_config = SocketConfigBuilder::new() .tx_queue_size(4096) .rx_queue_size(4096) .bind_flags(BindFlags::XDP_COPY) .xdp_flags(XdpFlags::XDP_FLAGS_SKB_MODE) .build() .unwrap(); let n_tx_frames = umem_config.frame_count() / 2; let dev_ifname = opts.dev.clone(); let mut xsk = Xsk2::new( &dev_ifname, 0, umem_config, socket_config, n_tx_frames as usize, ); match opts.mode { Mode::Tx => spawn_tx(xsk, opts), Mode::Rx => spawn_rx(xsk, opts), } } fn spawn_tx(mut xsk: Xsk2, opts: Opts) { let n_send_threads = match opts.n_threads { Some(n) => n, None => 1, }; eprintln!("sending {} pkts", opts.n_pkts); let src_mac = parse_mac(&opts.src_mac).expect("failed to parse src mac addr"); let dest_mac = parse_mac(&opts.dest_mac).expect("failed to parse dest mac addr"); let filter = Filter::new(&opts.src_ip, opts.src_port, &opts.dest_ip, opts.dest_port).unwrap(); let tx_send = xsk.tx_sender().unwrap(); let mut send_handles = vec![]; let pkts_per_thread = opts.n_pkts / n_send_threads; for i in 0..n_send_threads { let n_start = i * pkts_per_thread; let n_end = n_start + pkts_per_thread; eprintln!("thread {} sending nums {} to {}", i, n_start, n_end); let filter = filter.clone(); let tx_send = tx_send.clone(); let send_handle = thread::spawn(move || { for n in n_start..n_end { let pkt_builder = PacketBuilder::ethernet2(src_mac, dest_mac) .ipv4(filter.src_ip, filter.dest_ip, 20) .udp(filter.src_port, filter.dest_port); let pkt_with_payload = generate_pkt(pkt_builder, n); let mut packet: [u8; MAX_PACKET_SIZE] = [0; MAX_PACKET_SIZE]; let l = std::cmp::min(MAX_PACKET_SIZE, pkt_with_payload.len()); let packet_slice = &mut packet[..l]; packet_slice.copy_from_slice(&pkt_with_payload[..l]); tx_send .send((packet, pkt_with_payload.len())) .expect("failed to put packet on tx queue"); } }); send_handles.push(send_handle); } drop(tx_send); for handle in send_handles.into_iter() { handle.join().expect("failed to join tx handle"); } let tx_stats = xsk.shutdown_tx().expect("failed to shutdown tx"); let rx_stats = xsk.shutdown_rx().expect("failed to shut down rx"); eprintln!("tx_stats = {:?}", tx_stats); eprintln!("tx duration = {:?}", tx_stats.duration()); eprintln!("tx pps = {:?}", tx_stats.pps()); eprintln!("rx_stats = {:?}", rx_stats); } fn generate_pkt(pkt_builder: PacketBuilderStep, n: u64) -> Vec { let mut payload = vec![]; payload.write_u64::(n).unwrap(); //get some memory to store the result let mut result = Vec::::with_capacity(pkt_builder.size(payload.len())); //serialize pkt_builder .write(&mut result, &payload) .expect("failed to build packet"); result } #[derive(Debug, Clone)] struct Filter { src_ip: [u8; 4], src_port: u16, dest_ip: [u8; 4], dest_port: u16, } impl Filter { fn new( src_ip: &str, src_port: u16, dest_ip: &str, dest_port: u16, ) -> Result> { let src_ipv4: Ipv4Addr = src_ip.parse()?; let dest_ipv4: Ipv4Addr = dest_ip.parse()?; Ok(Self { src_ip: src_ipv4.octets(), src_port, dest_ip: dest_ipv4.octets(), dest_port, }) } } fn spawn_rx(mut xsk: Xsk2, opts: Opts) { let rx_recv = xsk.rx_receiver().unwrap(); let filter = Filter::new(&opts.src_ip, opts.src_port, &opts.dest_ip, opts.dest_port).unwrap(); let recv_handle = thread::spawn(move || { let mut recvd_nums: HashSet = HashSet::new(); for (pkt, len) in rx_recv.iter() { match SlicedPacket::from_ethernet(&pkt[..len]) { Ok(pkt) => { if filter_pkt(&pkt, &filter) { let mut rdr = Cursor::new(&pkt.payload[0..8]); let n = rdr.read_u64::().unwrap(); recvd_nums.insert(n); } } Err(e) => log::warn!("failed to parse packet {:?}", e), } } recvd_nums }); thread::sleep(Duration::from_secs(30)); let rx_stats = xsk.shutdown_rx().expect("failed to shut down rx"); eprintln!("rx_stats = {:?}", rx_stats); eprintln!("rx duration = {:?}", rx_stats.duration()); eprintln!("rx pps = {:?}", rx_stats.pps()); let tx_stats = xsk.shutdown_tx().expect("failed to shut down tx"); eprintln!("tx_stats = {:?}", tx_stats); let recvd_nums = recv_handle.join().expect("failed to join recv handle"); let expected_recvd_nums: Vec = (0..opts.n_pkts).into_iter().collect(); let mut n_missing = 0; for n in expected_recvd_nums.iter() { if !recvd_nums.contains(n) { //eprintln!("missing {}", n); n_missing += 1; } } eprintln!("missing {} packets", n_missing); } fn filter_pkt(parsed_pkt: &SlicedPacket, filter: &Filter) -> bool { let mut ip_match = false; let mut transport_match = false; if let Some(ref ip) = parsed_pkt.ip { if let InternetSlice::Ipv4(ipv4) = ip { ip_match = (ipv4.source() == filter.src_ip) && (ipv4.destination() == filter.dest_ip); } } if let Some(ref transport) = parsed_pkt.transport { if let TransportSlice::Udp(udp) = transport { transport_match = (udp.source_port() == filter.src_port) && (udp.destination_port() == filter.dest_port); } } ip_match && transport_match } fn parse_mac(mac: &str) -> Result<[u8; 6], Box> { let mut mac_bytes: [u8; 6] = [0; 6]; let parts: Vec<&str> = mac.split(':').into_iter().collect(); if parts.len() != 6 { Err("wrong len".into()) } else { for (i, part) in parts.iter().enumerate() { let mac_byte = u8::from_str_radix(part, 16)?; mac_bytes[i] = mac_byte; } Ok(mac_bytes) } }