// NOTE(review): Round-trip tests for `npyz`. Each test builds rows, writes them
// through `npyz::WriteOptions` into an in-memory `io::Cursor`, then reads them back
// with `npyz::NpyFile::new` and compares the decoded rows. Some tests also check the
// dtype, the raw trailing data bytes (`buffer.ends_with(...)`), or the header
// version bytes (`assert_version`).
//
// NOTE(review): this copy of the file looks damaged and cannot compile as shown.
// TODO: restore it from version control before making code changes.
//   * Line breaks have been collapsed. The leading `//` header comment therefore
//     runs to the end of the first physical line and swallows the `#![allow]`, the
//     `use` items and the struct/impl definitions that follow it on that line.
//   * Text between `<` and `>` appears to have been stripped. Examples:
//     - `struct Vector5(Vec);` has lost its element type.
//     - `write_i32::(...)` / `read_i32::()` have presumably lost their byte-order
//       parameter (`LittleEndian` is imported but never visibly used).
//     - `into_vec::()` has lost its target type.
//     - `DType::Plain(" Result { ...` has lost the dtype string and, presumably,
//       the header of the impl that returns `Vector5Writer`.
//   * Larger spans are missing entirely. Examples:
//     - The rows/asserts of `roundtrip_byteorder`.
//     - Two of the three dtype fields in `roundtrip_datetime`.
//     - The `be` fields in `roundtrip_bytes_byteorder`'s dtype.
//     - Most of `nested_array_of_struct`, `roundtrip_zero_length_array_member`,
//       `roundtrip_scalar` and `roundtrip_version3`.
// * The definition of a "mixed script confusable" can change over time, so when testing unicode // identifiers there is no safe identifier we can use that is guaranteed never to generate the warning. // * It has to be allowed at global level for the crate, because that's the level at // which the warning is generated. #![allow(mixed_script_confusables)] use std::io::{self, Read, Write, Cursor}; use byteorder::{WriteBytesExt, ReadBytesExt, LittleEndian}; use npyz::{DType, Field, Serialize, Deserialize, AutoSerialize, WriterBuilder}; use half::f16; // Allows to use the `#[test]` on WASM. #[cfg(target_arch="wasm32")] use wasm_bindgen_test::wasm_bindgen_test as test; #[derive(Serialize, Deserialize, AutoSerialize)] #[derive(Debug, PartialEq, Clone)] struct Nested { v1: f32, v2: f32, } #[derive(Serialize, Deserialize, AutoSerialize)] #[derive(Debug, PartialEq, Clone)] struct Array { v_i8: i8, v_i16: i16, v_i32: i32, v_i64: i64, v_u8: u8, v_u16: u16, v_u32: u32, v_u64: u64, v_f16: f16, v_f32: f32, v_f64: f64, v_arr_u32: [u32;7], v_mat_u64: [[u64; 3]; 5], vec: Vector5, nested: Nested, } #[derive(Serialize, Deserialize, AutoSerialize)] #[derive(Debug, PartialEq, Clone)] struct Version3 { v1: f32, v2: f32, } #[derive(Debug, PartialEq, Clone)] struct Vector5(Vec); impl AutoSerialize for Vector5 { #[inline] fn default_dtype() -> DType { DType::Array(5, Box::new(DType::Plain(" Result { if dtype == &Self::default_dtype() { Ok(Vector5Writer) } else { Err(npyz::DTypeError::custom("Vector5 only supports ' Result { if dtype == &Self::default_dtype() { Ok(Vector5Reader) } else { Err(npyz::DTypeError::custom("Vector5 only supports '(&self, mut writer: W, value: &Self::Value) -> std::io::Result<()> { for i in 0..5 { writer.write_i32::(value.0[i])?
} Ok(()) } } impl npyz::TypeRead for Vector5Reader { type Value = Vector5; #[inline] fn read_one(&self, mut reader: R) -> std::io::Result { let mut ret = Vector5(vec![]); for _ in 0..5 { ret.0.push(reader.read_i32::()?); } Ok(ret) } } #[test] fn roundtrip() { let n = 100i64; let mut arrays = vec![]; for i in 0..n { let j = i as u32 * 5 + 2; let k = i as u64 * 2 + 5; let a = Array { v_i8: i as i8, v_i16: i as i16, v_i32: i as i32, v_i64: i as i64, v_u8: i as u8, v_u16: i as u16, v_u32: i as u32, v_u64: i as u64, v_f16: f16::from_f32(i as f32), v_f32: i as f32, v_f64: i as f64, v_arr_u32: [j,1+j,2+j,3+j,4+j,5+j,6+j], v_mat_u64: [[k,1+k,2+k],[3+k,4+k,5+k],[6+k,7+k,8+k],[9+k,10+k,11+k],[12+k,13+k,14+k]], vec: Vector5(vec![1,2,3,4,5]), nested: Nested { v1: 10.0 * i as f32, v2: i as f32 }, }; arrays.push(a); } let mut writer = io::Cursor::new(vec![]); let mut out_file = npyz::WriteOptions::new().default_dtype().writer(&mut writer).begin_1d().unwrap(); out_file.extend(arrays.iter()).unwrap(); out_file.finish().unwrap(); let buf = writer.into_inner(); assert_version(&buf, (1, 0)); let arrays2 = npyz::NpyFile::new(&buf[..]).unwrap().into_vec().unwrap(); assert_eq!(arrays, arrays2); } fn plain_field(name: &str, dtype: &str) -> Field { Field { name: name.to_string(), dtype: DType::new_scalar(dtype.parse().unwrap()), } } #[test] fn roundtrip_with_plain_dtype() { let array_written = vec![2., 3., 4., 5.]; let mut writer = io::Cursor::new(vec![]); let mut out_file = npyz::WriteOptions::new().default_dtype().writer(&mut writer).begin_1d().unwrap(); out_file.extend(array_written.iter()).unwrap(); out_file.finish().unwrap(); let buffer = writer.into_inner(); let array_read = npyz::NpyFile::new(&buffer[..]).unwrap().into_vec().unwrap(); assert_eq!(array_written, array_read); } #[test] fn roundtrip_byteorder() { #[derive(npyz::Serialize, npyz::Deserialize)] #[derive(Debug, PartialEq, Clone)] struct Row { be_u32: u32, le_u32: u32, be_f16: f16, le_f16: f16, be_f32: f32, le_f32: f32,
be_i8: i8, le_i8: i8, na_i8: i8, } let dtype = DType::Record(vec![ plain_field("be_u32", ">u4"), plain_field("le_u32", "f2"), plain_field("le_f16", "f4"), plain_field("le_f32", "i1"), plain_field("le_i8", "().unwrap(), vec![row]); } #[test] fn roundtrip_datetime() { // Similar to: // // ``` // import numpy.datetime64 as dt // import numpy as np // // arr = np.array([( // dt('2011-01-01', 'ns'), // dt('2011-01-02') - dt('2011-01-01'), // dt('2011-01-02') - dt('2011-01-01'), // )], dtype=[ // ('datetime', 'm8[D]'), // ]) // ``` #[derive(npyz::Serialize, npyz::Deserialize)] #[derive(Debug, PartialEq, Clone)] struct Row { datetime: i64, timedelta_le: i64, timedelta_be: i64, } let dtype = DType::Record(vec![ plain_field("datetime", "m8[D]"), ]); let row = Row { datetime: 1_293_840_000_000_000_000, timedelta_le: 1, timedelta_be: 1, }; let expected_data_bytes = { let mut buf = vec![]; buf.extend_from_slice(&i64::to_le_bytes(1_293_840_000_000_000_000)); buf.extend_from_slice(&i64::to_le_bytes(1)); buf.extend_from_slice(&i64::to_be_bytes(1)); buf }; let mut writer = io::Cursor::new(vec![]); let mut out_file = npyz::WriteOptions::new().dtype(dtype.clone()).writer(&mut writer).begin_1d().unwrap(); out_file.push(&row).unwrap(); out_file.finish().unwrap(); let buffer = writer.into_inner(); assert!(buffer.ends_with(&expected_data_bytes)); let data = npyz::NpyFile::new(&buffer[..]).unwrap(); assert_eq!(data.dtype(), dtype); assert_eq!(data.into_vec::().unwrap(), vec![row]); } #[test] fn roundtrip_bytes() { // Similar to: // // ``` // import numpy as np // // arr = np.array([( // b"\x00such\x00wow", // b"\x00such\x00wow\x00\x00\x00", // )], dtype=[ // ('bytestr', '|S12'), // ('raw', '|V12'), // ]) // ``` #[derive(npyz::Serialize, npyz::Deserialize)] #[derive(Debug, PartialEq, Clone)] struct Row { bytestr: Vec, raw: Vec, } let dtype = DType::Record(vec![ plain_field("bytestr", "|S12"), plain_field("raw", "|V12"), ]); let row = Row { // checks that: // * bytestr can be shorter than
the len // * bytestr can contain non-trailing NULs bytestr: b"\x00lol\x00lol".to_vec(), // * raw can contain trailing NULs raw: b"\x00lol\x00lol\x00\x00\x00\x00".to_vec(), }; let expected_data_bytes = { let mut buf = vec![]; // check that bytestr is nul-padded buf.extend_from_slice(b"\x00lol\x00lol\x00\x00\x00\x00"); buf.extend_from_slice(b"\x00lol\x00lol\x00\x00\x00\x00"); buf }; let mut writer = io::Cursor::new(vec![]); let mut out_file = npyz::WriteOptions::new().dtype(dtype.clone()).writer(&mut writer).begin_1d().unwrap(); out_file.push(&row).unwrap(); out_file.finish().unwrap(); let buffer = writer.into_inner(); assert!(buffer.ends_with(&expected_data_bytes)); let data = npyz::NpyFile::new(&buffer[..]).unwrap(); assert_eq!(data.dtype(), dtype); assert_eq!(data.into_vec::().unwrap(), vec![row]); } // check that all byte orders are identical for bytestrings // (i.e. don't accidentally reverse the bytestrings) #[test] fn roundtrip_bytes_byteorder() { #[derive(npyz::Serialize, npyz::Deserialize)] #[derive(Debug, PartialEq, Clone)] struct Row { s_le: Vec, s_be: Vec, s_na: Vec, v_le: Vec, v_be: Vec, v_na: Vec, } let dtype = DType::Record(vec![ plain_field("s_le", "S4"), plain_field("s_na", "|S4"), plain_field("v_le", "V4"), plain_field("v_na", "|V4"), ]); let row = Row { s_le: b"abcd".to_vec(), s_be: b"abcd".to_vec(), s_na: b"abcd".to_vec(), v_le: b"abcd".to_vec(), v_be: b"abcd".to_vec(), v_na: b"abcd".to_vec(), }; let expected_data_bytes = { let mut buf = vec![]; for _ in 0..6 { buf.extend_from_slice(b"abcd"); } buf }; let mut writer = io::Cursor::new(vec![]); let mut out_file = npyz::WriteOptions::new().dtype(dtype.clone()).writer(&mut writer).begin_1d().unwrap(); out_file.push(&row).unwrap(); out_file.finish().unwrap(); let buffer = writer.into_inner(); assert!(buffer.ends_with(&expected_data_bytes)); let data = npyz::NpyFile::new(&buffer[..]).unwrap(); assert_eq!(data.dtype(), dtype); assert_eq!(data.into_vec::().unwrap(), vec![row]); } #[test] fn
nested_array_of_struct() { #[derive(npyz::Deserialize, npyz::Serialize, npyz::AutoSerialize)] #[derive(Debug, PartialEq, Clone, Copy, Default)] struct Outer { foo: [Inner; 3], } #[derive(npyz::Deserialize, npyz::Serialize, npyz::AutoSerialize)] #[derive(Debug, PartialEq, Clone, Copy, Default)] struct Inner { bar: f64, } let dtype = DType::Record(vec![ Field { name: "foo".into(), dtype: DType::Array(3, Box::new(DType::Record(vec![ plain_field("bar", "().unwrap(), vec![row]); } #[test] fn roundtrip_zero_length_array_member() { // Similar to: // // ``` // import numpy as np // // arr = np.array([ // (3, np.zeros((3, 0, 7))), // (4, np.zeros((3, 0, 7))), // ], dtype=[ // ('a', '().unwrap(), vec![row_0, row_1]); } // Try ndim == 0 #[test] fn roundtrip_scalar() { // This is format.npy in a bsr formatted matrix. type Row = i32; let row: Row = 1; let dtype = DType::new_scalar("().unwrap(), vec![row]); } // try a unicode field name, which forces version 3 #[test] fn roundtrip_version3() { #[derive(npyz::Serialize, npyz::Deserialize, npyz::AutoSerialize)] #[derive(Debug, PartialEq, Clone)] struct Row { num: i32, αβ: i32, } let dtype = DType::Record(vec![ plain_field("num", "().unwrap(), vec![row]); } #[track_caller] fn assert_version(npy_bytes: &[u8], expected: (u8, u8)) { assert_eq!(&npy_bytes[6..8], &[expected.0, expected.1]); }