#![feature(portable_simd)] #![feature(array_chunks)] // extern crate blas_src; use std::simd::num::SimdInt; use std::simd::{f32x16, f32x32, f32x8, i16x16, i16x8, i32x16, i32x4, i32x8}; use std::simd::{f32x4, num::SimdFloat, StdFloat}; use std::time::Instant; // use ndarray::linalg::Dot; // use ndarray::{ArcArray, CowArray}; // pub fn f32_dot_prod_ndarray(a: &[f32], b: &[f32]) -> f32 { // let ar = CowArray::from(a); // let br = CowArray::from(b); // ar.dot(&br) // } // pub fn i32_dot_prod_ndarray(a: &[i32], b: &[i32]) -> i32 { // let ar = CowArray::from(a); // let br = CowArray::from(b); // ar.dot(&br) // } pub fn f32_dot_prod_test() { let N = 2048; let K = 256128; let a: Vec = (0..N).map(|i| ((i as f32) / 100.0).sin()).collect(); let b: Vec = (0..N).map(|i| ((i as f32) / 100.0).cos()).collect(); let mut result = 0.0; let mut before = Instant::now(); for _ in 0..K { result = f32_dot_prod_simd_4(&a, &b); } println!("f32x4 time: {:.2?}", before.elapsed()); println!("result: {}", result); before = Instant::now(); for _ in 0..K { result = f32_dot_prod_simd_8(&a, &b); } println!("f32x8 time: {:.2?}", before.elapsed()); println!("result: {}", result); before = Instant::now(); for _ in 0..K { result = f32_dot_prod_simd_16(&a, &b); } println!("f32x16 time: {:.2?}", before.elapsed()); println!("result: {}", result); before = Instant::now(); for _ in 0..K { result = f32_dot_prod_simd_32(&a, &b); } println!("f32x32 time: {:.2?}", before.elapsed()); println!("result: {}", result); before = Instant::now(); for _ in 0..K { result = f32_dot_prod(&a, &b); } println!("f32 plain time: {:.2?}", before.elapsed()); println!("result: {}", result); // before = Instant::now(); // for _ in 0..K { // result = f32_dot_prod_ndarray(&a, &b); // } // println!("f32 ndarray time: {:.2?}", before.elapsed()); // println!("result: {}", result); } pub fn i16_dot_prod_test() { let N = 2048; let K = 256128; let a: Vec = (0..N).collect(); let b: Vec = (0..N).collect(); let mut result: i32 = 0; let mut before = Instant::now(); for _ in 0..K { result = i16_dot_prod_simd_8(&a, &b); } println!("i16x8 time: {:.2?}", before.elapsed()); println!("result: {}", result); before = Instant::now(); for _ in 0..K { result = i16_dot_prod_simd_16(&a, &b); } println!("i16x16 time: {:.2?}", before.elapsed()); println!("result: {}", result); before = Instant::now(); for _ in 0..K { result = i16_dot_prod(&a, &b); } println!("i16 plain time: {:.2?}", before.elapsed()); println!("result: {}", result); } pub fn i32_dot_prod_test() { let N = 2048; let K = 256128; let a: Vec = (0..N).collect(); let b: Vec = (0..N).collect(); let mut result = 0; let mut before = Instant::now(); for _ in 0..K { result = i32_dot_prod_simd_4(&a, &b); } println!("i32x4 time: {:.2?}", before.elapsed()); println!("result: {}", result); before = Instant::now(); for _ in 0..K { result = i32_dot_prod_simd_8(&a, &b); } println!("i32x8 time: {:.2?}", before.elapsed()); println!("result: {}", result); before = Instant::now(); for _ in 0..K { result = i32_dot_prod_simd_16(&a, &b); } println!("i32x16 time: {:.2?}", before.elapsed()); println!("result: {}", result); before = Instant::now(); for _ in 0..K { result = i32_dot_prod(&a, &b); } println!("i32 plain time: {:.2?}", before.elapsed()); println!("result: {}", result); // before = Instant::now(); // for _ in 0..K { // result = i32_dot_prod_ndarray(&a, &b); // } // println!("i32 ndarray time: {:.2?}", before.elapsed()); // println!("result: {}", result); } fn f32_dot_prod_simd_4(a: &[f32], b: &[f32]) -> f32 { a.array_chunks::<4>() .map(|&a| f32x4::from_array(a)) .zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b))) .fold(f32x4::splat(0.), |acc, (a, b)| a.mul_add(b, acc)) .reduce_sum() } fn i32_dot_prod_simd_4(a: &[i32], b: &[i32]) -> i32 { a.array_chunks::<4>() .map(|&a| i32x4::from_array(a)) .zip(b.array_chunks::<4>().map(|&b| i32x4::from_array(b))) .fold(i32x4::splat(0), |acc, (a, b)| a * b + acc) .reduce_sum() } fn f32_dot_prod_simd_8(a: &[f32], b: &[f32]) -> f32 { a.array_chunks::<8>() .map(|&a| f32x8::from_array(a)) .zip(b.array_chunks::<8>().map(|&b| f32x8::from_array(b))) .fold(f32x8::splat(0.), |acc, (a, b)| a.mul_add(b, acc)) .reduce_sum() } fn i32_dot_prod_simd_8(a: &[i32], b: &[i32]) -> i32 { a.array_chunks::<8>() .map(|&a| i32x8::from_array(a)) .zip(b.array_chunks::<8>().map(|&b| i32x8::from_array(b))) .fold(i32x8::splat(0), |acc, (a, b)| a * b + acc) .reduce_sum() } fn i16_dot_prod_simd_8(a: &[i16], b: &[i16]) -> i32 { a.array_chunks::<8>() .map(|&a| i16x8::from_array(a)) .zip(b.array_chunks::<8>().map(|&b| i16x8::from_array(b))) .fold(i16x8::splat(0), |acc, (a, b)| a * b + acc) .reduce_sum() as i32 } fn f32_dot_prod_simd_16(a: &[f32], b: &[f32]) -> f32 { a.array_chunks::<16>() .map(|&a| f32x16::from_array(a)) .zip(b.array_chunks::<16>().map(|&b| f32x16::from_array(b))) .fold(f32x16::splat(0.), |acc, (a, b)| a.mul_add(b, acc)) .reduce_sum() } fn i32_dot_prod_simd_16(a: &[i32], b: &[i32]) -> i32 { a.array_chunks::<16>() .map(|&a| i32x16::from_array(a)) .zip(b.array_chunks::<16>().map(|&b| i32x16::from_array(b))) .fold(i32x16::splat(0), |acc, (a, b)| a * b + acc) .reduce_sum() } fn i16_dot_prod_simd_16(a: &[i16], b: &[i16]) -> i32 { a.array_chunks::<16>() .map(|&a| i16x16::from_array(a)) .zip(b.array_chunks::<16>().map(|&b| i16x16::from_array(b))) .fold(i16x16::splat(0), |acc, (a, b)| a * b + acc) .reduce_sum() as i32 } fn f32_dot_prod_simd_32(a: &[f32], b: &[f32]) -> f32 { a.array_chunks::<32>() .map(|&a| f32x32::from_array(a)) .zip(b.array_chunks::<32>().map(|&b| f32x32::from_array(b))) .fold(f32x32::splat(0.), |acc, (a, b)| a.mul_add(b, acc)) .reduce_sum() } fn f32_dot_prod(a: &[f32], b: &[f32]) -> f32 { assert_eq!(a.len(), b.len()); let mut product = 0.0; for i in 0..a.len() { product += a[i] * b[i]; } product } fn i32_dot_prod(a: &[i32], b: &[i32]) -> i32 { assert_eq!(a.len(), b.len()); let mut product = 0; for i in 0..a.len() { product += a[i] * b[i]; } product } fn i16_dot_prod(a: &[i16], b: &[i16]) -> i32 { assert_eq!(a.len(), b.len()); let mut product = 0; for i in 0..a.len() { product += a[i] * b[i]; } product as i32 } fn main() { f32_dot_prod_test(); i32_dot_prod_test(); i16_dot_prod_test(); }