//! Test runners that execute compiled PRQL against the supported databases.

use anyhow::Result;
use connector_arrow::arrow::record_batch::RecordBatch;
use glob::glob;
use prqlc::sql::Dialect;

use super::protocol::DbProtocol;

/// Placeholder in PRQL sources that is replaced with each runner's data file root.
static DATA_FILE_ROOT_KEYWORD: &str = "data_file_root";

/// A database backend that the integration tests can run PRQL queries against.
pub(crate) trait DbTestRunner: Send {
    fn dialect(&self) -> Dialect;
    fn data_file_root(&self) -> &str;
    fn import_csv(&mut self, csv_path: &str, table_name: &str);
    fn modify_ddl(&self, sql: String) -> String;

    /// Compile the PRQL query for this runner's dialect and execute it.
    fn query(&mut self, prql: &str) -> Result<Vec<RecordBatch>> {
        let prql = prql.replace(DATA_FILE_ROOT_KEYWORD, self.data_file_root());
        let options =
            prqlc::Options::default().with_target(prqlc::Target::Sql(Some(self.dialect())));
        let sql = prqlc::compile(&prql, &options)?;
        self.protocol().query(&sql)
    }

    fn protocol(&mut self) -> &mut dyn DbProtocol;

    /// Create the chinook tables, then load each CSV file into its table.
    fn setup(&mut self) {
        let schema = include_str!("../data/chinook/schema.sql");
        let statements: Vec<String> = schema
            .split(';')
            .map(str::trim)
            .filter(|s| !s.is_empty())
            .map(|s| self.modify_ddl(s.to_string()))
            .collect();
        for statement in statements {
            self.protocol().execute(&statement).unwrap();
        }

        for file_path in glob("tests/integration/data/chinook/*.csv").unwrap() {
            let file_path = file_path.unwrap();
            let stem = file_path.file_stem().unwrap().to_str().unwrap();
            let path = format!("{}/{}.csv", self.data_file_root(), stem);
            self.import_csv(&path, stem);
        }
    }
}

pub(crate) struct DuckDbTestRunner {
    protocol: connector_arrow::duckdb::DuckDBConnection,
    data_file_root: String,
}

impl DuckDbTestRunner {
    pub(crate) fn new(data_file_root: String) -> Self {
        let conn = ::duckdb::Connection::open_in_memory().unwrap();
        let conn_ar = connector_arrow::duckdb::DuckDBConnection::new(conn);
        Self {
            protocol: conn_ar,
            data_file_root,
        }
    }
}

impl DbTestRunner for DuckDbTestRunner {
    fn dialect(&self) -> Dialect {
        Dialect::DuckDb
    }
    fn protocol(&mut self) -> &mut dyn DbProtocol {
        &mut self.protocol
    }
    fn data_file_root(&self) -> &str {
        &self.data_file_root
    }
    fn import_csv(&mut self, csv_path: &str, table_name: &str) {
        self.protocol
            .execute(&format!(
                "COPY {table_name} FROM '{csv_path}' (AUTO_DETECT TRUE);"
            ))
            .unwrap();
    }
    fn modify_ddl(&self, sql: String) -> String {
        sql.replace("FLOAT", "DOUBLE")
    }
}

pub(crate) struct SQLiteTestRunner {
    protocol: connector_arrow::sqlite::SQLiteConnection,
    data_file_root: String,
}

impl SQLiteTestRunner {
    pub(crate) fn new(data_file_root: String) -> Self {
        let conn = rusqlite::Connection::open_in_memory().unwrap();
        let conn_ar = connector_arrow::sqlite::SQLiteConnection::new(conn);
        Self {
            protocol: conn_ar,
            data_file_root,
        }
    }
}

impl DbTestRunner for SQLiteTestRunner {
    fn dialect(&self) -> Dialect {
        Dialect::SQLite
    }
    fn protocol(&mut self) -> &mut dyn DbProtocol {
        &mut self.protocol
    }
    fn data_file_root(&self) -> &str {
        &self.data_file_root
    }
    fn import_csv(&mut self, csv_path: &str, table_name: &str) {
        // SQLite has no CSV COPY, so insert row by row, quoting non-empty values
        // and mapping empty fields to NULL.
        let mut reader = csv::ReaderBuilder::new()
            .has_headers(true)
            .from_path(csv_path)
            .unwrap();
        let headers = reader
            .headers()
            .unwrap()
            .iter()
            .map(|s| s.to_string())
            .collect::<Vec<_>>();
        for result in reader.records() {
            let r = result.unwrap();
            let q = format!(
                "INSERT INTO {table_name} ({}) VALUES ({})",
                headers.join(","),
                r.iter()
                    .map(|s| if s.is_empty() {
                        "null".to_string()
                    } else {
                        format!("\"{}\"", s.replace('"', "\"\""))
                    })
                    .collect::<Vec<_>>()
                    .join(",")
            );
            self.protocol.execute(&q).unwrap();
        }
    }
    fn modify_ddl(&self, sql: String) -> String {
        sql.replace("TIMESTAMP", "TEXT") // timestamps in chinook are stored as ISO8601
    }
}
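// A minimal usage sketch, not part of the harness itself: it assumes the chinook
// CSVs live under `tests/integration/data/chinook` (where `setup` globs for them)
// and shows how a runner compiles PRQL for its dialect and returns Arrow record
// batches. The function name and the query are illustrative only.
#[allow(dead_code)]
fn example_duckdb_roundtrip() -> Result<Vec<RecordBatch>> {
    let mut runner = DuckDbTestRunner::new("tests/integration/data/chinook".to_string());
    // Creates the chinook tables and imports every CSV via `import_csv`.
    runner.setup();
    // Any `data_file_root` placeholder in the PRQL source would be substituted
    // with the runner's root before compilation; this query does not need one.
    runner.query("from albums | take 5")
}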
#[cfg(feature = "test-dbs-external")]
pub(crate) use self::external::*;

#[cfg(feature = "test-dbs-external")]
pub(crate) mod external {
    use super::*;
    use std::fs;

    pub(crate) struct PostgresTestRunner {
        protocol: connector_arrow::postgres::PostgresConnection,
        data_file_root: String,
    }

    impl PostgresTestRunner {
        pub(crate) fn new(url: &str, data_file_root: String) -> Self {
            use connector_arrow::postgres::PostgresConnection;
            let client = ::postgres::Client::connect(url, ::postgres::NoTls).unwrap();
            Self {
                protocol: PostgresConnection::new(client),
                data_file_root,
            }
        }
    }

    impl DbTestRunner for PostgresTestRunner {
        fn dialect(&self) -> Dialect {
            Dialect::Postgres
        }
        fn protocol(&mut self) -> &mut dyn DbProtocol {
            &mut self.protocol
        }
        fn data_file_root(&self) -> &str {
            &self.data_file_root
        }
        fn import_csv(&mut self, csv_path: &str, table_name: &str) {
            self.protocol
                .execute(&format!(
                    "COPY {table_name} FROM '{csv_path}' DELIMITER ',' CSV HEADER;"
                ))
                .unwrap();
        }
        fn modify_ddl(&self, sql: String) -> String {
            sql.replace("FLOAT", "DOUBLE PRECISION")
        }
    }

    pub(crate) struct MySqlTestRunner {
        protocol: connector_arrow::mysql::MySQLConnection<::mysql::Conn>,
        data_file_root: String,
    }

    impl MySqlTestRunner {
        pub(crate) fn new(url: &str, data_file_root: String) -> Self {
            let conn = ::mysql::Conn::new(url)
                .unwrap_or_else(|e| panic!("Failed to connect to {}:\n{}", url, e));
            Self {
                protocol: connector_arrow::mysql::MySQLConnection::<::mysql::Conn>::new(conn),
                data_file_root,
            }
        }
    }

    impl DbTestRunner for MySqlTestRunner {
        fn dialect(&self) -> Dialect {
            Dialect::MySql
        }
        fn protocol(&mut self) -> &mut dyn DbProtocol {
            &mut self.protocol
        }
        fn data_file_root(&self) -> &str {
            &self.data_file_root
        }
        fn import_csv(&mut self, csv_path: &str, table_name: &str) {
            // MySQL-specific CSV import: rewrite empty fields in the local CSV as
            // `\N` NULL markers, then point LOAD DATA INFILE at the rewritten copy
            // under the server's data file root.
            let csv_path_binding = std::path::PathBuf::from(csv_path);
            let local_csv_path = format!(
                "tests/integration/data/chinook/{}",
                csv_path_binding.file_name().unwrap().to_str().unwrap()
            );
            let local_old_path = std::path::PathBuf::from(local_csv_path);
            let mut local_new_path = local_old_path.clone();
            local_new_path.pop();
            local_new_path.push(format!("{table_name}.my.csv").as_str());
            let file_content = fs::read_to_string(local_old_path)
                .unwrap()
                .replace(",,", ",\\N,")
                .replace(",\n", ",\\N\n");
            fs::write(&local_new_path, file_content).unwrap();
            self.protocol
                .execute(&format!(
                    "LOAD DATA INFILE '{}' INTO TABLE {table_name} FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' LINES TERMINATED BY '\n' IGNORE 1 ROWS;",
                    &csv_path_binding
                        .parent()
                        .unwrap()
                        .join(local_new_path.file_name().unwrap())
                        .to_str()
                        .unwrap()
                ))
                .unwrap();
            fs::remove_file(&local_new_path).unwrap();
        }
        fn modify_ddl(&self, sql: String) -> String {
            sql.replace("TIMESTAMP", "DATETIME")
        }
    }
    pub(crate) struct ClickHouseTestRunner {
        // ClickHouse is reached over its MySQL-compatible wire protocol.
        protocol: connector_arrow::mysql::MySQLConnection<::mysql::Conn>,
        data_file_root: String,
    }

    impl ClickHouseTestRunner {
        #[allow(dead_code)]
        pub(crate) fn new(url: &str, data_file_root: String) -> Self {
            Self {
                protocol: connector_arrow::mysql::MySQLConnection::<::mysql::Conn>::new(
                    ::mysql::Conn::new(url)
                        .unwrap_or_else(|e| panic!("Failed to connect to {}:\n{}", url, e)),
                ),
                data_file_root,
            }
        }
    }

    impl DbTestRunner for ClickHouseTestRunner {
        fn dialect(&self) -> Dialect {
            Dialect::ClickHouse
        }
        fn protocol(&mut self) -> &mut dyn DbProtocol {
            &mut self.protocol
        }
        fn data_file_root(&self) -> &str {
            &self.data_file_root
        }
        fn import_csv(&mut self, csv_path: &str, table_name: &str) {
            self.protocol
                .execute(&format!(
                    "INSERT INTO {table_name} SELECT * FROM file('{csv_path}')"
                ))
                .unwrap();
        }
        fn modify_ddl(&self, sql: String) -> String {
            use regex::Regex;

            // Append an engine clause at the final closing parenthesis of each
            // CREATE TABLE statement and map types to ClickHouse equivalents.
            let re = Regex::new(r"(?s)\)$").unwrap();
            re.replace(&sql, r") ENGINE = Memory")
                .replace("TIMESTAMP", "DATETIME64")
                .replace("FLOAT", "DOUBLE")
                .replace("VARCHAR(255)", "Nullable(String)")
                .to_string()
        }
    }

    pub(crate) struct MsSqlTestRunner {
        protocol: connector_arrow::tiberius::TiberiusConnection<
            tokio_util::compat::Compat<tokio::net::TcpStream>,
        >,
        data_file_root: String,
    }

    impl MsSqlTestRunner {
        pub(crate) fn new(data_file_root: String) -> Self {
            use std::sync::Arc;

            use tokio_util::compat::TokioAsyncWriteCompatExt;

            let mut config = tiberius::Config::new();
            config.host("localhost");
            config.port(1433);
            config.trust_cert();
            config.authentication(tiberius::AuthMethod::sql_server("sa", "Wordpass123##"));

            let rt = Arc::new(
                tokio::runtime::Builder::new_current_thread()
                    .enable_all()
                    .build()
                    .unwrap(),
            );
            let client = rt
                .block_on(async {
                    let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
                    tcp.set_nodelay(true).unwrap();
                    tiberius::Client::connect(config, tcp.compat_write()).await
                })
                .unwrap();

            Self {
                protocol: connector_arrow::tiberius::TiberiusConnection::new(rt, client),
                data_file_root,
            }
        }
    }

    impl DbTestRunner for MsSqlTestRunner {
        fn dialect(&self) -> Dialect {
            Dialect::MsSql
        }
        fn protocol(&mut self) -> &mut dyn DbProtocol {
            &mut self.protocol
        }
        fn data_file_root(&self) -> &str {
            &self.data_file_root
        }
        fn import_csv(&mut self, csv_path: &str, table_name: &str) {
            self.protocol
                .execute(&format!(
                    "BULK INSERT {table_name} FROM '{csv_path}' WITH (FIRSTROW = 2, FIELDTERMINATOR = ',', ROWTERMINATOR = '\n', TABLOCK, FORMAT = 'CSV', CODEPAGE = 'RAW');"
                ))
                .unwrap();
        }
        fn modify_ddl(&self, sql: String) -> String {
            sql.replace("TIMESTAMP", "DATETIME")
                .replace("FLOAT", "FLOAT(53)")
                .replace(" AS TEXT", " AS VARCHAR")
        }
    }

    pub(crate) struct GlareDbTestRunner {
        protocol: connector_arrow::postgres::PostgresConnection,
        data_file_root: String,
    }

    impl GlareDbTestRunner {
        pub(crate) fn new(url: &str, data_file_root: String) -> Self {
            use connector_arrow::postgres::PostgresConnection;
            let client = ::postgres::Client::connect(url, ::postgres::NoTls).unwrap();
            Self {
                protocol: PostgresConnection::new(client),
                data_file_root,
            }
        }
    }

    impl DbTestRunner for GlareDbTestRunner {
        fn dialect(&self) -> Dialect {
            Dialect::GlareDb
        }
        fn protocol(&mut self) -> &mut dyn DbProtocol {
            &mut self.protocol
        }
        fn data_file_root(&self) -> &str {
            &self.data_file_root
        }
        fn import_csv(&mut self, csv_path: &str, table_name: &str) {
            self.protocol
                .execute(&format!(
                    "INSERT INTO {table_name} SELECT * FROM '{csv_path}'"
                ))
                .unwrap();
        }
        fn modify_ddl(&self, sql: String) -> String {
            sql.replace("FLOAT", "DOUBLE PRECISION")
        }
    }
}
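// A sketch of constructing one of the feature-gated external runners; the URL
// and file root below are placeholders, not values used by the real test setup.
// Boxing works because `DbTestRunner` is object-safe, which lets a suite treat
// all backends uniformly.
#[cfg(feature = "test-dbs-external")]
#[allow(dead_code)]
fn example_external_runner() -> Box<dyn DbTestRunner> {
    Box::new(external::PostgresTestRunner::new(
        "postgresql://postgres:postgres@localhost:5432/postgres", // placeholder URL
        "/tmp/chinook".to_string(),                               // placeholder file root
    ))
}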