use std::{num::NonZeroU32, time::Duration}; use wgpu_sort::{utils::{download_buffer, guess_workgroup_size}, GPUSorter, SortBuffers}; struct SortStuff{ device:wgpu::Device, queue:wgpu::Queue, query_set:wgpu::QuerySet, query_buffer:wgpu::Buffer, } async fn setup()-> SortStuff{ let instance = wgpu::Instance::new(wgpu::InstanceDescriptor::default()); let adapter = wgpu::util::initialize_adapter_from_env_or_default(&instance, None) .await .unwrap(); let (device, queue) = adapter .request_device( &wgpu::DeviceDescriptor { required_features: wgpu::Features::TIMESTAMP_QUERY, required_limits: wgpu::Limits{ max_buffer_size:1<<30, max_storage_buffer_binding_size:1<<30, ..Default::default() }, label: None, }, None, ) .await .unwrap(); let capacity = 2; let query_set = device.create_query_set(&wgpu::QuerySetDescriptor { label: Some("time stamp query set"), ty: wgpu::QueryType::Timestamp, count: capacity, }); let query_buffer = device.create_buffer(&wgpu::BufferDescriptor { label: Some("query set buffer"), size: capacity as u64 * std::mem::size_of::() as u64, usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC, mapped_at_creation: false, }); return SortStuff{device,queue,query_set,query_buffer} } async fn sort(context:&SortStuff,sorter:&GPUSorter,buffers:&SortBuffers,n:u32,iters:u32) -> Duration { let mut encoder = context.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None, }); encoder.write_timestamp(&context.query_set, 0); for _ in 0..iters{ sorter.sort(&mut encoder,&context.queue,buffers,Some(n)); } encoder.write_timestamp(&context.query_set, 1); encoder.resolve_query_set( &context.query_set, 0..2, &context.query_buffer, 0, ); let idx = context.queue.submit([encoder.finish()]); context.device.poll(wgpu::Maintain::WaitForSubmissionIndex(idx)); let timestamps : Vec = pollster::block_on(download_buffer(&context.query_buffer, &context.device, &context.queue, ..)); let diff_ticks = timestamps[1] - timestamps[0]; let period = context.queue.get_timestamp_period(); let diff_time = Duration::from_nanos((diff_ticks as f32 * period / iters as f32) as u64); return diff_time; } #[pollster::main] async fn main() { let context = setup().await; let subgroup_size = guess_workgroup_size(&context.device, &context.queue).await.expect("could not find a valid subgroup size"); let sorter = GPUSorter::new(&context.device, subgroup_size); for n in [10_000,100_000,1_000_000,8_000_000,20_000_000]{ let buffers = sorter.create_sort_buffers(&context.device, NonZeroU32::new(n).unwrap()); let d = sort(&context,&sorter, &buffers,n,10000).await; println!("{n}: {d:?}"); } }