// Integration tests for `odbc_futures` that run against real ODBC data sources.
//
// Each test iterates over `RDBMS` variants and only runs for backends whose
// `ODBC_CONNECTION_STRING_<NAME>` environment variable is set (see
// `RDBMS::connection_string`); backends without a connection string are skipped.
//
// NOTE(review): generic type arguments appear to have been stripped from this copy of
// the file (e.g. `Arc>`, `binary: Vec,`, `exec_direct_collect_async::(`,
// `impl Iterator`, `-> Option {`, `Option>>`). The code does not compile as shown;
// restore the type arguments from version control rather than guessing them.
use futures::*;
use lazy_static::lazy_static;
use odbc_futures::*;
use odbc_futures_derive::Odbc;
use odbc_sys::*;
use std::sync::{Arc, RwLock};

lazy_static! {
    /// Process-wide ODBC environment that every connection opened by these tests is
    /// created in (see `RDBMS::connection`).
    ///
    /// Requests ODBC 3.80 behaviour and enables connection pooling: `DriverAware` on
    /// Windows, `OnePerDriver` elsewhere. Any setup failure panics.
    ///
    /// NOTE(review): the type was presumably `Arc<RwLock<SqlEnvironment>>`, since the
    /// initializer returns `Arc::new(RwLock::new(env))` — confirm when restoring generics.
    pub static ref ODBC_ENVIRONMENT: Arc> = {
        let mut env = SqlEnvironment::new().expect("failed to create ODBC environment");
        env.set_version(OdbcVersion::SQL_OV_ODBC3_80)
            .expect("failed to set ODBC environment version");
        // Driver-aware pooling is only selected on Windows; presumably the other driver
        // managers in use do not support it — confirm.
        let pooling = if cfg!(windows) {
            env.set_connection_pooling(SqlConnectionPooling::DriverAware)
        } else {
            env.set_connection_pooling(SqlConnectionPooling::OnePerDriver)
        };
        pooling.expect("failed to setup connection pooling");
        // Optional ODBC trace logging to `common.log` in the working directory; disabled.
        //let path = std::env::current_dir()
        //    .expect("failed to get current working directory")
        //    .join("common.log");
        //let file = path.to_str().expect("failed to get source file path");
        //SqlConnection::set_trace_file(file).expect("failed to setup ODBC trace logging file");
        //SqlConnection::set_trace(true).expect("failed to setup ODBC trace logging");
        Arc::new(RwLock::new(env))
    };
}

/// Lists catalog names from `information_schema.schemata` on each configured backend
/// (SQLite3 is not in the list) and maps every row into `Database`.
#[test]
fn databases() {
    /// One row of the catalog query; `name` matches the `"name"` column alias.
    #[derive(Debug, Clone, Default, Odbc)]
    struct Database {
        name: String,
    }

    for (_, stmt) in [RDBMS::MSSQL, RDBMS::Postgres, RDBMS::MariaDB, RDBMS::MySQL]
        .iter()
        .filter_map(RDBMS::statement)
    {
        // Blocks on the future with `.wait()`; the second tuple element (presumably the
        // statement handed back) is unused.
        let (results, _stmt) = stmt
            .exec_direct_collect_async::(
                "select distinct catalog_name as \"name\" from information_schema.schemata",
            )
            .wait()
            .unwrap();
        // Expects exactly one distinct catalog name; this depends on how the test
        // database is provisioned.
        assert_eq!(1, results.len());
        println!("{:?}", results);
    }
}

/// Exercises `#[odbc_nested]`: the query returns flat columns `a` and `b`, where `b` is
/// presumably mapped into `Outer::inner`. Only the row count is asserted; the values
/// are printed, not checked.
#[test]
fn nested() {
    /// Outer row type; `inner` is populated through the `#[odbc_nested]` attribute.
    #[derive(Debug, Clone, Default, Odbc)]
    struct Outer {
        a: f32,
        #[odbc_nested]
        inner: Inner,
    }

    /// Nested part of the row holding column `b`.
    #[derive(Debug, Clone, Default, Odbc)]
    struct Inner {
        b: u32
    }

    for (_, stmt) in [RDBMS::MSSQL, RDBMS::Postgres, RDBMS::MariaDB, RDBMS::MySQL]
        .iter()
        .filter_map(RDBMS::statement)
    {
        let (results, _stmt) = stmt
            .exec_direct_collect_async::(
                "select 2.0 as a, 42 as b;",
            )
            .wait()
            .unwrap();
        assert_eq!(1, results.len());
        println!("{:?}", results);
    }
}

/// Fetches the driver's type information on every configured backend (including SQLite3,
/// via `RDBMS::statements`) and asserts the result is non-empty.
#[test]
fn get_type_info() {
    for (_, stmt) in RDBMS::statements() {
        // `None` presumably requests all data types — confirm against the crate's API.
        let (type_info, _stmt) = stmt.get_type_info_collect_async(None).wait().unwrap();
        assert!(!type_info.is_empty());
        println!("##########################################");
        println!("{:#?}", type_info);
        println!("##########################################");
    }
}

/// Binds three `LONG_DATA_LENGTH`-sized parameters (a string of `'0'`, a byte vector of
/// `1u8`, and a string of `'1'`), selects them back, and checks that the first two
/// round-trip intact. The third (`c_string`) is bound and selected, but its struct field
/// and assertions are commented out.
#[test]
fn bind_long() {
    const LONG_DATA_LENGTH: usize = 12000;
    use odbc_futures_derive::Odbc;

    /// Row returned by the `select ? as ...` query below.
    #[derive(Debug, Clone, Default, Odbc)]
    struct LongData {
        string: String,
        binary: Vec,
        //c_string: String,
    }

    // Built once and cloned per backend, because `bind_parameter` takes the value by `Some(..)`.
    let string = "0".repeat(LONG_DATA_LENGTH);
    for (_rdbms, mut stmt) in [RDBMS::MSSQL, RDBMS::Postgres, RDBMS::MariaDB, RDBMS::MySQL]
        .iter()
        .filter_map(RDBMS::statement)
    {
        let binary = vec![1u8; LONG_DATA_LENGTH];
        // NOTE(review): the stray double semicolon below is harmless but should be removed.
        let c_string = "1".repeat(LONG_DATA_LENGTH);;
        // NOTE(review): the second type argument appears to choose how a `String` is
        // bound (parameter 1 vs. parameter 3, which uses `()`); its original value was
        // lost in this copy — confirm before restoring.
        stmt.bind_parameter::<_, Vec>(1, Some(string.clone()), None)
            .unwrap();
        stmt.bind_parameter(2, Some(binary), None).unwrap();
        stmt.bind_parameter::<_, ()>(3, Some(c_string.clone()), None)
            .unwrap();
        let (mut results, _stmt): (Vec, _) = stmt
            .exec_direct_collect_async::(
                r#" select ? as "string" , ? as "binary" , ? as "c_string" "#,
            )
            .wait()
            .unwrap();
        // Takes the last row; panics if the query returned no rows.
        let result = results.pop().unwrap();
        assert_eq!(LONG_DATA_LENGTH, result.string.len());
        assert!(result.string.chars().all(|c| c == '0'));
        assert_eq!(LONG_DATA_LENGTH, result.binary.len());
        assert!(result.binary.iter().all(|c| *c == 1u8));
        //assert_eq!(LONG_DATA_LENGTH, result.c_string.len());
        //assert!(result.c_string.chars().all(|c| c == '1'));
    }
}

/// Generates long values server-side and checks they are fetched intact.
///
/// Each backend's query yields three rows with `index` 1..=3. For row `k`, the string
/// and binary columns hold `k` repeated `k * LONG_DATA_LENGTH` times. All three `?`
/// placeholders are bound to `LONG_DATA_LENGTH`.
#[test]
fn long_data() {
    //use std::ffi::CString;
    use odbc_futures_derive::Odbc;

    /// One generated row. The queries also return a `c_string` column that has no
    /// field here (its field is commented out).
    #[derive(Debug, Clone, Default, Odbc)]
    struct LongData {
        index: u32,
        string: String,
        binary: Vec,
        //c_string: CString,
    }

    const LONG_DATA_LENGTH: u32 = 12000;
    for (rdbms, mut stmt) in [RDBMS::MSSQL, RDBMS::Postgres, RDBMS::MariaDB, RDBMS::MySQL]
        .iter()
        .filter_map(RDBMS::statement)
    {
        // NOTE(review): none of the queries has an ORDER BY, yet the assertions below
        // expect rows in `index` order 1, 2, 3 — consider adding `order by`.
        let stmt_text = match rdbms {
            // The line break inside this literal is part of the original SQL text.
            RDBMS::MSSQL => r#" select x.val as "index" , cast(replicate(cast(cast(x.val as tinyint) as varbinary(max)), x.val * ?) 
as varbinary(max)) as "binary" , cast(replicate(cast(x.val as nvarchar(max)), x.val * ?) as nvarchar(max)) as "string" , cast(replicate(cast(x.val as varchar(max)), x.val * ?) as varchar(max)) as "c_string" from ( values (1), (2), (3) ) as x(val); "#,
            RDBMS::Postgres => r#" select repeat(x.val::text, x.val * ?) as "string" , repeat(x.val::text, x.val * ?) as "c_string" , repeat(chr(x.val), x.val * ?)::bytea as "binary" , x.val as "index" from generate_series(1, 3) as x(val); "#,
            RDBMS::MariaDB | RDBMS::MySQL => {
                // Switch the session to ANSI mode — presumably for ANSI_QUOTES, so the
                // double-quoted names in the query are parsed as identifiers.
                stmt.exec_direct("set session sql_mode = 'ANSI';").unwrap();
                r#" with x as ( select 1 as "val" union all select 2 as "val" union all select 3 as "val" ) select repeat(cast(x.val as nchar), x.val * ?) as "string" , repeat(cast(x.val as char), x.val * ?) as "c_string" , cast(repeat(char(x.val), x.val * ?) as binary) as "binary" , x.val as "index" from x; "#
            },
            // SQLite3 is not in the list iterated above, so this arm is not reached in
            // practice; it only keeps the match exhaustive.
            _ => continue
        };
        // Every query has three `?` placeholders, all bound to the same repeat factor.
        for n in 1..=3 {
            stmt.bind_parameter(n, Some(LONG_DATA_LENGTH), None)
                .unwrap();
        }
        let (results, _stmt): (Vec, _) = stmt
            .exec_direct_collect_async::(&stmt_text)
            .wait()
            .unwrap();
        assert_eq!(3, results.len());
        for (row_number, row) in results.iter().enumerate() {
            assert_eq!(row_number + 1, row.index as usize);
            let expected_len = LONG_DATA_LENGTH as usize * (row_number + 1);
            assert_eq!(expected_len, row.binary.len());
            // `len()` counts bytes; the generated digits are ASCII, so this equals the
            // character count.
            assert_eq!(expected_len, row.string.len());
            //assert_eq!(expected_len, row.c_string.as_bytes().len());
            // Every generated byte should equal the row's value.
            let binary = row.index as u8;
            assert!(row.binary.iter().all(|x| *x == binary));
            //let digit = std::char::from_digit(row.index, 10).unwrap();
            //let c_string = row.c_string.as_str();
            //for s in [row.string.as_str(), c_string].into_iter() {
            //    assert!(s.chars().all(|c| c == digit));
            //}
        }
    }
}

/// Database backends the tests can target. A backend is only used when its
/// `ODBC_CONNECTION_STRING_*` environment variable is set (see `connection_string`).
#[derive(Debug, Copy, Clone)]
pub enum RDBMS {
    MSSQL,
    Postgres,
    MariaDB,
    SQLite3,
    MySQL,
}

impl RDBMS {
    /// Opens a new connection for this backend and creates a statement on it.
    ///
    /// Returns `None` when no connection string is configured for the backend.
    ///
    /// # Panics
    /// If connecting (see `connection`) or creating the statement fails.
    pub fn statement(&self) -> Option<(RDBMS, SqlStatement)> {
        if let Some(conn) = self.connection() {
            // NOTE(review): this message literal contains a line break (likely a
            // formatting artifact). Also, `expect(&format!(..))` builds the message even
            // on success; `unwrap_or_else(|e| panic!(..))` would avoid that.
            let stmt = SqlStatement::new(&conn)
                .expect(&format!("failed to create statement 
for {:?}", self));
            Some((*self, stmt))
        } else {
            None
        }
    }

    /// Lazily yields a `(backend, statement)` pair for every configured backend,
    /// including SQLite3. Each item opens its own connection as the iterator advances.
    ///
    /// NOTE(review): the `Item` type of the returned iterator was stripped in this copy.
    pub fn statements() -> impl Iterator {
        use self::RDBMS::*;
        [MSSQL, Postgres, MariaDB, SQLite3, MySQL]
            .iter()
            .filter_map(Self::statement)
    }

    /// Reads this backend's connection string from its environment variable.
    ///
    /// Any `std::env::var` error (variable unset or not valid Unicode) yields `None`.
    fn connection_string(&self) -> Option {
        use self::RDBMS::*;
        let result = match self {
            MSSQL => std::env::var("ODBC_CONNECTION_STRING_MSSQL"),
            Postgres => std::env::var("ODBC_CONNECTION_STRING_POSTGRES"),
            MariaDB => std::env::var("ODBC_CONNECTION_STRING_MARIADB"),
            SQLite3 => std::env::var("ODBC_CONNECTION_STRING_SQLITE3"),
            MySQL => std::env::var("ODBC_CONNECTION_STRING_MYSQL"),
        };
        if let Ok(string) = result {
            Some(string)
        } else {
            None
        }
    }

    /// Creates a connection in `ODBC_ENVIRONMENT`, connects it with `driver_connect`,
    /// and wraps it in `Arc::new(RwLock::new(..))`.
    ///
    /// Returns `None` when no connection string is configured; panics if creating or
    /// connecting the connection fails.
    fn connection(&self) -> Option>> {
        let connection_string = self.connection_string()?;
        let mut connection = SqlConnection::new(&ODBC_ENVIRONMENT)
            .expect(&format!("failed to create connection for {:?}", self));
        connection
            .driver_connect(&connection_string)
            .expect(&format!("failed to connect to {:?}", self));
        Some(Arc::new(RwLock::new(connection)))
    }
}