extern crate chrono; extern crate msql_srv; extern crate mysql; extern crate mysql_common as myc; extern crate nom; use msql_srv::AuthenticationContext; use msql_srv::{ Column, ErrorKind, InitWriter, MysqlIntermediary, MysqlShim, ParamParser, QueryResultWriter, StatementMetaWriter, }; use mysql::prelude::*; use mysql::MySqlError; use mysql::OptsBuilder; use mysql::SslOpts; #[cfg(all(feature = "tls", unix))] use openssl::{ asn1::Asn1Time, bn::{BigNum, MsbOption}, error::ErrorStack, hash::MessageDigest, pkcs12::Pkcs12, pkey::{PKey, Private}, rsa::Rsa, x509::{extension::SubjectKeyIdentifier, X509}, }; #[cfg(all(feature = "tls", unix))] use rcgen::generate_simple_self_signed; #[cfg(all(feature = "tls", unix))] use rustls::{ pki_types::{CertificateDer, PrivateKeyDer}, RootCertStore, ServerConfig, }; use std::error::Error; use std::io; use std::io::Read; use std::io::Write; use std::net; use std::sync::Arc; use std::sync::Mutex; use std::thread; use std::time::Duration; struct TestingShim { columns: Vec, params: Vec, on_q: Q, on_p: P, on_e: E, on_i: I, after_auth: A, #[cfg(feature = "tls")] server_tls: Option>, client_tls: Option, #[cfg(all(feature = "tls", unix))] client_cert_pkcs12_file: Option>, } impl MysqlShim for TestingShim where Q: FnMut(&str, QueryResultWriter) -> io::Result<()>, P: FnMut(&str) -> u32, E: FnMut(u32, Vec, QueryResultWriter) -> io::Result<()>, I: FnMut(&str, InitWriter) -> io::Result<()>, A: FnMut(&AuthenticationContext) -> io::Result<()>, { type Error = io::Error; fn on_prepare( &mut self, query: &str, info: StatementMetaWriter, ) -> io::Result<()> { let id = (self.on_p)(query); info.reply(id, &self.params, &self.columns) } fn on_execute( &mut self, id: u32, params: ParamParser, results: QueryResultWriter, ) -> io::Result<()> { (self.on_e)(id, params.into_iter().collect(), results) } fn on_close(&mut self, _: u32) {} fn on_init(&mut self, schema: &str, writer: InitWriter) -> io::Result<()> { (self.on_i)(schema, writer) } fn on_query( &mut self, 
query: &str, results: QueryResultWriter, ) -> io::Result<()> { (self.on_q)(query, results) } #[cfg(feature = "tls")] fn tls_config(&self) -> Option> { self.server_tls.as_ref().map(std::sync::Arc::clone) } fn after_authentication(&mut self, auth_context: &AuthenticationContext) -> io::Result<()> { (self.after_auth)(auth_context) } } impl TestingShim where Q: 'static + Send + FnMut(&str, QueryResultWriter) -> io::Result<()>, P: 'static + Send + FnMut(&str) -> u32, E: 'static + Send + FnMut(u32, Vec, QueryResultWriter) -> io::Result<()>, I: 'static + Send + FnMut(&str, InitWriter) -> io::Result<()>, T: 'static + Send + FnMut(&AuthenticationContext) -> io::Result<()>, { fn new(on_q: Q, on_p: P, on_e: E, on_i: I, after_auth: T) -> Self { TestingShim { columns: Vec::new(), params: Vec::new(), on_q, on_p, on_e, on_i, after_auth, #[cfg(feature = "tls")] server_tls: None, client_tls: None, #[cfg(all(feature = "tls", unix))] client_cert_pkcs12_file: None, } } fn with_params(mut self, p: Vec) -> Self { self.params = p; self } fn with_columns(mut self, c: Vec) -> Self { self.columns = c; self } #[cfg(all(feature = "tls", unix))] fn with_tls(mut self, client: bool, server: bool, use_client_certs: bool) -> Self { use std::fs::File; use mysql::ClientIdentity; use rustls::server::WebPkiClientVerifier; let mut client_cert_der = None; if use_client_certs { let (client_cert, client_pkey) = mk_client_cert().unwrap(); client_cert_der = Some(CertificateDer::from(client_cert.to_der().unwrap())); // Set up client cert der12 file. 
let client_cert_pkcs12_file = Arc::new(tempfile::NamedTempFile::new().unwrap()); self.client_cert_pkcs12_file = Some(Arc::clone(&client_cert_pkcs12_file)); let pkcs12 = Pkcs12::builder() .name("friendly_name") .cert(&client_cert) .pkey(&client_pkey) .build2("password") .unwrap(); let der = pkcs12.to_der().unwrap(); let mut f = File::create(&*client_cert_pkcs12_file).unwrap(); f.write(&der).unwrap(); f.flush().unwrap(); } if server { let cert = generate_simple_self_signed(vec!["localhost".to_string()]).unwrap(); let builder = ServerConfig::builder(); let builder = if let Some(client_cert_der) = client_cert_der { let mut client_auth_roots = RootCertStore::empty(); client_auth_roots.add(client_cert_der).unwrap(); let client_auth = WebPkiClientVerifier::builder(client_auth_roots.into()) .build() .unwrap(); builder.with_client_cert_verifier(client_auth) } else { builder.with_no_client_auth() } .with_single_cert( vec![CertificateDer::from(cert.serialize_der().unwrap())], PrivateKeyDer::Pkcs8(cert.get_key_pair().serialize_der().into()), ) .unwrap(); self.server_tls = Some(std::sync::Arc::new(builder)); } if client { self.client_tls = Some( SslOpts::default() .with_danger_accept_invalid_certs(true) .with_client_identity(self.client_cert_pkcs12_file.as_ref().map(|x| { ClientIdentity::new(x.path().to_owned()).with_password("password") })), ); } self } fn test(self, c: C) where C: FnOnce(&mut mysql::Conn), { self.test_with_result(c).unwrap() } fn test_with_result(self, c: C) -> Result<(), Box> where C: FnOnce(&mut mysql::Conn) -> (), { let client_tls = self.client_tls.clone(); let listener = net::TcpListener::bind("127.0.0.1:0").unwrap(); let port = listener.local_addr().unwrap().port(); let jh = thread::spawn(move || { let (s, _) = listener.accept().unwrap(); MysqlIntermediary::run_on_tcp(self, s) }); let opts = OptsBuilder::default() .ip_or_hostname(Some("localhost")) .user(Some("username")) .tcp_port(port) .ssl_opts(client_tls); let mut db = mysql::Conn::new(opts)?; c(&mut 
db); drop(db); jh.join().unwrap().unwrap(); Ok(()) } } #[cfg(all(feature = "tls", unix))] fn mk_client_cert() -> Result<(X509, PKey), ErrorStack> { let key_pair = PKey::from_rsa(Rsa::generate(2048)?)?; let mut cert_builder = X509::builder()?; cert_builder.set_version(2)?; let serial_number = { let mut serial = BigNum::new()?; serial.rand(159, MsbOption::MAYBE_ZERO, false)?; serial.to_asn1_integer()? }; cert_builder.set_serial_number(&serial_number)?; cert_builder.set_pubkey(&key_pair)?; let not_before = Asn1Time::days_from_now(0)?; cert_builder.set_not_before(¬_before)?; let not_after = Asn1Time::days_from_now(365)?; cert_builder.set_not_after(¬_after)?; let subject_key_identifier = SubjectKeyIdentifier::new().build(&cert_builder.x509v3_context(None, None))?; cert_builder.append_extension(subject_key_identifier)?; cert_builder.sign(&key_pair, MessageDigest::sha256())?; let cert = cert_builder.build(); Ok((cert, key_pair)) } #[test] fn it_connects() { let username = Arc::new(Mutex::new(None)); let username1 = Arc::clone(&username); TestingShim::new( |_, _| unreachable!(), |_| unreachable!(), |_, _, _| unreachable!(), |_, _| unreachable!(), move |a| { let mut ac = username1.lock().unwrap(); assert_eq!(*ac, None); *ac = a.username.clone(); Ok(()) }, ) .test(|_| {}); let username = username.lock().unwrap(); assert_eq!(*username, Some(b"username".to_vec())); } #[cfg(all(feature = "tls", unix))] fn tls_test_common( enable_client_tls: bool, enable_server_tls: bool, use_client_certs: bool, ) -> Result<(Option>, Option>>), Box> { let auth_context = Arc::new(Mutex::new((None, None))); let auth_context1 = Arc::clone(&auth_context); TestingShim::new( |_, _| unreachable!(), |_| unreachable!(), |_, _, _| unreachable!(), |_, _| unreachable!(), move |a| { let mut ac = auth_context1.lock().unwrap(); assert_eq!(*ac, (None, None)); *ac = ( a.username.clone(), a.tls_client_certs .map(|x| x.iter().map(|c| c.clone().into_owned()).collect()), ); Ok(()) }, ) .with_tls(enable_client_tls, 
enable_server_tls, use_client_certs) .test_with_result(|_| {})?; Ok(Arc::try_unwrap(auth_context).unwrap().into_inner().unwrap()) } #[test] #[cfg(all(feature = "tls", unix))] fn it_connects_tls_server_only() { // Client can connect ok without SSL when SSL is enabled on the server. let (username, certs) = tls_test_common(false, true, false).unwrap(); assert_eq!(username, Some(b"username".to_vec())); assert_eq!(certs, None); } #[test] #[cfg(all(feature = "tls", unix))] fn it_connects_tls_both_no_client_certs() { // SSL connection when ssl enabled on server and used by client, client not passing certs to the server. let (username, certs) = tls_test_common(true, true, false).unwrap(); assert_eq!(username, Some(b"username".to_vec())); assert_eq!(certs, None); } #[test] #[cfg(all(feature = "tls", unix))] fn it_connects_tls_both_with_client_certs() { // SSL connection when ssl enabled on server and used by client, with the client passing certs to the server. let (username, certs) = tls_test_common(true, true, true).unwrap(); assert_eq!(username, Some(b"username".to_vec())); assert!(!certs.expect("expected client certs").is_empty()); } #[test] #[cfg(all(feature = "tls", unix))] fn it_connects_tls_both_with_delayed_server_read() { // This test is to ensure correctly handle the case when we read both the pre-TLS data as well // as (at least part of) the TLS handshake into our the buffer. When that happens, we need to // ensure we correctly pass that TLS part of the data to rustls so that is can handle the TLS // handshake properly. 
use std::{marker::PhantomData, sync::Arc}; struct MyShim { ph: PhantomData, } impl MysqlShim for MyShim { type Error = io::Error; fn on_prepare( &mut self, _: &str, _: StatementMetaWriter<'_, RW>, ) -> Result<(), Self::Error> { unreachable!() } fn on_execute( &mut self, _: u32, _: ParamParser<'_>, _: QueryResultWriter<'_, RW>, ) -> Result<(), Self::Error> { unreachable!() } fn on_close(&mut self, _: u32) { unreachable!() } fn on_query(&mut self, _: &str, _: QueryResultWriter<'_, RW>) -> Result<(), Self::Error> { unreachable!() } fn tls_config(&self) -> Option> { let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap(); Some(std::sync::Arc::new( ServerConfig::builder() .with_no_client_auth() .with_single_cert( vec![CertificateDer::from(cert.serialize_der().unwrap())], PrivateKeyDer::Pkcs8(cert.get_key_pair().serialize_der().into()), ) .unwrap(), )) } } let shim = MyShim { ph: PhantomData::default(), }; let listener = net::TcpListener::bind("127.0.0.1:0").unwrap(); let port = listener.local_addr().unwrap().port(); let jh = thread::spawn(move || { let (s, _) = listener.accept().unwrap(); let s = DelayedReadRW { s, read_delay: Duration::from_millis(200), }; MysqlIntermediary::run_on(shim, s) }); let db = mysql::Conn::new( OptsBuilder::default() .ip_or_hostname(Some("localhost")) .tcp_port(port) .ssl_opts(Some( SslOpts::default().with_danger_accept_invalid_certs(true), )), ) .unwrap(); drop(db); jh.join().unwrap().unwrap(); } struct DelayedReadRW { s: RW, read_delay: Duration, } impl Read for DelayedReadRW { fn read(&mut self, buf: &mut [u8]) -> io::Result { thread::sleep(self.read_delay); self.s.read(buf) } } impl Write for DelayedReadRW { fn write(&mut self, buf: &[u8]) -> io::Result { self.s.write(buf) } fn flush(&mut self) -> io::Result<()> { self.s.flush() } } #[test] #[cfg(all(feature = "tls", unix))] fn it_does_not_connect_tls_client_only() { // Client requesting tls fails as expected when server does not support it. 
let e = tls_test_common(true, false, false).expect_err("client should not have connected"); assert!( matches!( e.downcast_ref::(), Some(mysql::Error::DriverError( mysql::DriverError::TlsNotSupported )) ), "unexpected error {:?}", e ); } #[test] fn it_fails_correctly_on_after_auth_error() { let e = TestingShim::new( |_, _| unreachable!(), |_| unreachable!(), |_, _, _| unreachable!(), |_, _| unreachable!(), move |_| Err(io::Error::new(io::ErrorKind::Other, "")), ) .test_with_result(|_| {}) .expect_err("client should not have connected"); let expected = MySqlError { state: "28000".to_owned(), message: "client authentication failed".to_owned(), code: 1045, }; assert!( matches!( e.downcast_ref::(), Some(mysql::Error::MySqlError(m)) if m == &expected, ), "unexpected error {:?}", e ); } #[test] fn it_inits_ok() { TestingShim::new( |_, _| unreachable!(), |_| unreachable!(), |_, _, _| unreachable!(), |schema, writer| { assert_eq!(schema, "test"); writer.ok() }, |_| Ok(()), ) .test(|db| assert_eq!(true, db.select_db("test"))); } #[test] fn it_inits_error() { TestingShim::new( |_, _| unreachable!(), |_| unreachable!(), |_, _, _| unreachable!(), |schema, writer| { assert_eq!(schema, "test"); writer.error( ErrorKind::ER_BAD_DB_ERROR, format!("Database {} not found", schema).as_bytes(), ) }, |_| Ok(()), ) .test(|db| assert_eq!(false, db.select_db("test"))); } #[test] fn it_inits_on_use_query_ok() { TestingShim::new( |_, _| unreachable!(), |_| unreachable!(), |_, _, _| unreachable!(), |schema, writer| { assert_eq!(schema, "test"); writer.ok() }, |_| Ok(()), ) .test(|db| match db.query_drop("USE `test`;") { Ok(_) => assert!(true), Err(_) => assert!(false), }); } #[test] fn it_pings() { TestingShim::new( |_, _| unreachable!(), |_| unreachable!(), |_, _, _| unreachable!(), |_, _| unreachable!(), |_| Ok(()), ) .test(|db| assert_eq!(db.ping(), true)) } #[test] fn empty_response() { TestingShim::new( |_, w| w.completed(0, 0), |_| unreachable!(), |_, _, _| unreachable!(), |_, _| 
unreachable!(), |_| Ok(()), ) .test(|db| { assert_eq!(db.query_iter("SELECT a, b FROM foo").unwrap().count(), 0); }) } #[test] fn no_rows() { let cols = [Column { table: String::new(), column: "a".to_owned(), coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT, colflags: myc::constants::ColumnFlags::empty(), }]; TestingShim::new( move |_, w| w.start(&cols[..])?.finish(), |_| unreachable!(), |_, _, _| unreachable!(), |_, _| unreachable!(), |_| Ok(()), ) .test(|db| { assert_eq!(db.query_iter("SELECT a, b FROM foo").unwrap().count(), 0); }) } #[test] fn no_columns() { TestingShim::new( move |_, w| w.start(&[])?.finish(), |_| unreachable!(), |_, _, _| unreachable!(), |_, _| unreachable!(), |_| Ok(()), ) .test(|db| { assert_eq!(db.query_iter("SELECT a, b FROM foo").unwrap().count(), 0); }) } #[test] fn no_columns_but_rows() { TestingShim::new( move |_, w| w.start(&[])?.write_col(42).map(|_| ()), |_| unreachable!(), |_, _, _| unreachable!(), |_, _| unreachable!(), |_| Ok(()), ) .test(|db| { assert_eq!(db.query_iter("SELECT a, b FROM foo").unwrap().count(), 0); }) } #[test] fn error_response() { let err = (ErrorKind::ER_NO, "clearly not"); TestingShim::new( move |_, w| w.error(err.0, err.1.as_bytes()), |_| unreachable!(), |_, _, _| unreachable!(), |_, _| unreachable!(), |_| Ok(()), ) .test(|db| { if let mysql::Error::MySqlError(e) = db.query_iter("SELECT a, b FROM foo").unwrap_err() { assert_eq!( e, mysql::error::MySqlError { state: String::from_utf8(err.0.sqlstate().to_vec()).unwrap(), message: err.1.to_owned(), code: err.0 as u16, } ); } else { unreachable!(); } }) } #[test] fn error_in_result_set_response() { let err = (ErrorKind::ER_NO, "clearly not"); TestingShim::new( move |_, w| { let cols = &[Column { table: String::new(), column: "a".to_owned(), coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT, colflags: myc::constants::ColumnFlags::empty(), }]; let mut w = w.start(cols)?; w.write_col(1024)?; w.finish_error(err.0, &err.1.as_bytes()) }, |_| unreachable!(), 
|_, _, _| unreachable!(), |_, _| unreachable!(), |_| Ok(()), ) .test(|db| { let mut result = db.query_iter("SELECT a FROM foo").unwrap(); let row1 = result.next().unwrap().unwrap().get::(0).unwrap(); assert_eq!(row1, 1024); if let mysql::Error::MySqlError(e) = result.by_ref().next().unwrap().unwrap_err() { assert_eq!( e, mysql::error::MySqlError { state: String::from_utf8(err.0.sqlstate().to_vec()).unwrap(), message: err.1.to_owned(), code: err.0 as u16, } ); } else { unreachable!() } }) } #[test] fn empty_on_drop() { let cols = [Column { table: String::new(), column: "a".to_owned(), coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT, colflags: myc::constants::ColumnFlags::empty(), }]; TestingShim::new( move |_, w| w.start(&cols[..]).map(|_| ()), |_| unreachable!(), |_, _, _| unreachable!(), |_, _| unreachable!(), |_| Ok(()), ) .test(|db| { assert_eq!(db.query_iter("SELECT a, b FROM foo").unwrap().count(), 0); }) } #[test] fn it_queries_nulls() { TestingShim::new( |_, w| { let cols = &[Column { table: String::new(), column: "a".to_owned(), coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT, colflags: myc::constants::ColumnFlags::empty(), }]; let mut w = w.start(cols)?; w.write_col(None::)?; w.finish() }, |_| unreachable!(), |_, _, _| unreachable!(), |_, _| unreachable!(), |_| Ok(()), ) .test(|db| { let row = db .query_iter("SELECT a, b FROM foo") .unwrap() .next() .unwrap() .unwrap(); assert_eq!(row.as_ref(0), Some(&mysql::Value::NULL)); }) } #[test] fn it_queries() { TestingShim::new( |_, w| { let cols = &[Column { table: String::new(), column: "a".to_owned(), coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT, colflags: myc::constants::ColumnFlags::empty(), }]; let mut w = w.start(cols)?; w.write_col(1024i16)?; w.finish() }, |_| unreachable!(), |_, _, _| unreachable!(), |_, _| unreachable!(), |_| Ok(()), ) .test(|db| { let row = db .query_iter("SELECT a, b FROM foo") .unwrap() .next() .unwrap() .unwrap(); assert_eq!(row.get::(0), Some(1024)); }) } 
/// Two result sets in one response. `finish_one` closes the first set and returns a writer
/// for the next; the client reads them via successive `iter()` calls.
#[test]
fn multi_result() {
    TestingShim::new(
        |_, w| {
            let cols = &[Column {
                table: String::new(),
                column: "a".to_owned(),
                coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
                colflags: myc::constants::ColumnFlags::empty(),
            }];
            let mut row = w.start(cols)?;
            row.write_col(1024i16)?;
            let w = row.finish_one()?;
            let mut row = w.start(cols)?;
            row.write_col(1025i16)?;
            row.finish()
        },
        |_| unreachable!(),
        |_, _, _| unreachable!(),
        |_, _| unreachable!(),
        |_| Ok(()),
    )
    .test(|db| {
        let mut result = db
            .query_iter("SELECT a FROM foo; SELECT a FROM foo")
            .unwrap();

        let mut set = result.iter().unwrap();
        let row1: Vec<_> = set
            .by_ref()
            .filter_map(|row| row.unwrap().get::<i16, _>(0))
            .collect();
        assert_eq!(row1, vec![1024]);
        drop(set);

        let mut set = result.iter().unwrap();
        let row2: Vec<_> = set
            .by_ref()
            .filter_map(|row| row.unwrap().get::<i16, _>(0))
            .collect();
        assert_eq!(row2, vec![1025]);
        drop(set);

        // No third result set.
        assert!(result.iter().is_none());
    })
}

/// Rows can be written column-by-column (`write_col` + `end_row`) or whole (`write_row`).
#[test]
fn it_queries_many_rows() {
    TestingShim::new(
        |_, w| {
            let cols = &[
                Column {
                    table: String::new(),
                    column: "a".to_owned(),
                    coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
                    colflags: myc::constants::ColumnFlags::empty(),
                },
                Column {
                    table: String::new(),
                    column: "b".to_owned(),
                    coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
                    colflags: myc::constants::ColumnFlags::empty(),
                },
            ];
            let mut w = w.start(cols)?;
            w.write_col(1024i16)?;
            w.write_col(1025i16)?;
            w.end_row()?;
            w.write_row(&[1024i16, 1025i16])?;
            w.finish()
        },
        |_| unreachable!(),
        |_, _, _| unreachable!(),
        |_, _| unreachable!(),
        |_| Ok(()),
    )
    .test(|db| {
        let mut rows = 0;
        for row in db.query_iter("SELECT a, b FROM foo").unwrap() {
            let row = row.unwrap();
            assert_eq!(row.get::<i16, _>(0), Some(1024));
            assert_eq!(row.get::<i16, _>(1), Some(1025));
            rows += 1;
        }
        assert_eq!(rows, 2);
    })
}

/// Prepare followed by execute: the statement id from `on_prepare` and the bound parameter
/// both reach `on_execute`.
#[test]
fn it_prepares() {
    let cols = vec![Column {
        table: String::new(),
        column: "a".to_owned(),
        coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
        colflags: myc::constants::ColumnFlags::empty(),
    }];
    let cols2 = cols.clone();
    let params = vec![Column {
        table: String::new(),
        column: "c".to_owned(),
        coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
        colflags: myc::constants::ColumnFlags::empty(),
    }];

    TestingShim::new(
        |_, _| unreachable!(),
        |q| {
            assert_eq!(q, "SELECT a FROM b WHERE c = ?");
            41
        },
        move |stmt, params, w| {
            assert_eq!(stmt, 41);
            assert_eq!(params.len(), 1);
            // rust-mysql sends all numbers as LONGLONG
            assert_eq!(
                params[0].coltype,
                myc::constants::ColumnType::MYSQL_TYPE_LONGLONG
            );
            assert_eq!(Into::<i8>::into(params[0].value), 42i8);

            let mut w = w.start(&cols)?;
            w.write_col(1024i16)?;
            w.finish()
        },
        |_, _| unreachable!(),
        |_| Ok(()),
    )
    .with_params(params)
    .with_columns(cols2)
    .test(|db| {
        let row = db
            .exec_iter("SELECT a FROM b WHERE c = ?", (42i16,))
            .unwrap()
            .next()
            .unwrap()
            .unwrap();
        assert_eq!(row.get::<i16, _>(0), Some(1024i16));
    })
}

/// String and DATETIME parameters of an INSERT are decoded with their wire types and values,
/// and the affected-rows / last-insert-id counters reach the client.
#[test]
fn insert_exec() {
    let params = vec![
        Column {
            table: String::new(),
            column: "username".to_owned(),
            coltype: myc::constants::ColumnType::MYSQL_TYPE_VARCHAR,
            colflags: myc::constants::ColumnFlags::empty(),
        },
        Column {
            table: String::new(),
            column: "email".to_owned(),
            coltype: myc::constants::ColumnType::MYSQL_TYPE_VARCHAR,
            colflags: myc::constants::ColumnFlags::empty(),
        },
        Column {
            table: String::new(),
            column: "pw".to_owned(),
            coltype: myc::constants::ColumnType::MYSQL_TYPE_VARCHAR,
            colflags: myc::constants::ColumnFlags::empty(),
        },
        Column {
            table: String::new(),
            column: "created".to_owned(),
            coltype: myc::constants::ColumnType::MYSQL_TYPE_DATETIME,
            colflags: myc::constants::ColumnFlags::empty(),
        },
        Column {
            table: String::new(),
            column: "session".to_owned(),
            coltype: myc::constants::ColumnType::MYSQL_TYPE_VARCHAR,
            colflags: myc::constants::ColumnFlags::empty(),
        },
        Column {
            table: String::new(),
            column: "rss".to_owned(),
            coltype: myc::constants::ColumnType::MYSQL_TYPE_VARCHAR,
            colflags: myc::constants::ColumnFlags::empty(),
        },
        Column {
            table: String::new(),
            column: "mail".to_owned(),
            coltype: myc::constants::ColumnType::MYSQL_TYPE_VARCHAR,
            colflags: myc::constants::ColumnFlags::empty(),
        },
    ];

    TestingShim::new(
        |_, _| unreachable!(),
        |_| 1,
        move |_, params, w| {
            assert_eq!(params.len(), 7);
            assert_eq!(
                params[0].coltype,
                myc::constants::ColumnType::MYSQL_TYPE_VAR_STRING
            );
            assert_eq!(
                params[1].coltype,
                myc::constants::ColumnType::MYSQL_TYPE_VAR_STRING
            );
            assert_eq!(
                params[2].coltype,
                myc::constants::ColumnType::MYSQL_TYPE_VAR_STRING
            );
            assert_eq!(
                params[3].coltype,
                myc::constants::ColumnType::MYSQL_TYPE_DATETIME
            );
            assert_eq!(
                params[4].coltype,
                myc::constants::ColumnType::MYSQL_TYPE_VAR_STRING
            );
            assert_eq!(
                params[5].coltype,
                myc::constants::ColumnType::MYSQL_TYPE_VAR_STRING
            );
            assert_eq!(
                params[6].coltype,
                myc::constants::ColumnType::MYSQL_TYPE_VAR_STRING
            );
            assert_eq!(Into::<&str>::into(params[0].value), "user199");
            assert_eq!(Into::<&str>::into(params[1].value), "user199@example.com");
            assert_eq!(
                Into::<&str>::into(params[2].value),
                "$2a$10$Tq3wrGeC0xtgzuxqOlc3v.07VTUvxvwI70kuoVihoO2cE5qj7ooka"
            );
            assert_eq!(
                Into::<chrono::NaiveDateTime>::into(params[3].value),
                chrono::NaiveDate::from_ymd_opt(2018, 4, 6)
                    .unwrap()
                    .and_hms_opt(13, 0, 56)
                    .unwrap()
            );
            assert_eq!(Into::<&str>::into(params[4].value), "token199");
            assert_eq!(Into::<&str>::into(params[5].value), "rsstoken199");
            assert_eq!(Into::<&str>::into(params[6].value), "mtok199");
            w.completed(42, 1)
        },
        |_, _| unreachable!(),
        |_| Ok(()),
    )
    .with_params(params)
    .test(|db| {
        let res = db
            .exec_iter(
                "INSERT INTO `users` \
                 (`username`, `email`, `password_digest`, `created_at`, \
                 `session_token`, `rss_token`, `mailing_list_token`) \
                 VALUES (?, ?, ?, ?, ?, ?, ?)",
                (
                    "user199",
                    "user199@example.com",
                    "$2a$10$Tq3wrGeC0xtgzuxqOlc3v.07VTUvxvwI70kuoVihoO2cE5qj7ooka",
                    mysql::Value::Date(2018, 4, 6, 13, 0, 56, 0),
                    "token199",
                    "rsstoken199",
                    "mtok199",
                ),
            )
            .unwrap();
        assert_eq!(res.affected_rows(), 42);
        assert_eq!(res.last_insert_id(), Some(1));
    })
}

/// A byte-string argument for a BLOB parameter arrives as VAR_STRING with the same bytes.
#[test]
fn send_long() {
    let cols = vec![Column {
        table: String::new(),
        column: "a".to_owned(),
        coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
        colflags: myc::constants::ColumnFlags::empty(),
    }];
    let cols2 = cols.clone();
    let params = vec![Column {
        table: String::new(),
        column: "c".to_owned(),
        coltype: myc::constants::ColumnType::MYSQL_TYPE_BLOB,
        colflags: myc::constants::ColumnFlags::empty(),
    }];

    TestingShim::new(
        |_, _| unreachable!(),
        |q| {
            assert_eq!(q, "SELECT a FROM b WHERE c = ?");
            41
        },
        move |stmt, params, w| {
            assert_eq!(stmt, 41);
            assert_eq!(params.len(), 1);
            // rust-mysql sends all strings as VAR_STRING
            assert_eq!(
                params[0].coltype,
                myc::constants::ColumnType::MYSQL_TYPE_VAR_STRING
            );
            assert_eq!(Into::<&[u8]>::into(params[0].value), b"Hello world");

            let mut w = w.start(&cols)?;
            w.write_col(1024i16)?;
            w.finish()
        },
        |_, _| unreachable!(),
        |_| Ok(()),
    )
    .with_params(params)
    .with_columns(cols2)
    .test(|db| {
        let row = db
            .exec_iter("SELECT a FROM b WHERE c = ?", (b"Hello world",))
            .unwrap()
            .next()
            .unwrap()
            .unwrap();
        assert_eq!(row.get::<i16, _>(0), Some(1024i16));
    })
}

/// A prepared statement without parameters can return several two-column rows.
#[test]
fn it_prepares_many() {
    let cols = vec![
        Column {
            table: String::new(),
            column: "a".to_owned(),
            coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
            colflags: myc::constants::ColumnFlags::empty(),
        },
        Column {
            table: String::new(),
            column: "b".to_owned(),
            coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
            colflags: myc::constants::ColumnFlags::empty(),
        },
    ];
    let cols2 = cols.clone();

    TestingShim::new(
        |_, _| unreachable!(),
        |q| {
            assert_eq!(q, "SELECT a, b FROM x");
            41
        },
        move |stmt, params, w| {
            assert_eq!(stmt, 41);
            assert_eq!(params.len(), 0);

            let mut w = w.start(&cols)?;
            w.write_col(1024i16)?;
            w.write_col(1025i16)?;
            w.end_row()?;
            w.write_row(&[1024i16, 1025i16])?;
            w.finish()
        },
        |_, _| unreachable!(),
        |_| Ok(()),
    )
    .with_params(Vec::new())
    .with_columns(cols2)
    .test(|db| {
        let mut rows = 0;
        for row in db.exec_iter("SELECT a, b FROM x", ()).unwrap() {
            let row = row.unwrap();
            assert_eq!(row.get::<i16, _>(0), Some(1024));
            assert_eq!(row.get::<i16, _>(1), Some(1025));
            rows += 1;
        }
        assert_eq!(rows, 2);
    })
}

/// Executing a prepared statement that only replies `completed` yields zero rows.
#[test]
fn prepared_empty() {
    let cols = vec![Column {
        table: String::new(),
        column: "a".to_owned(),
        coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
        colflags: myc::constants::ColumnFlags::empty(),
    }];
    let params = vec![Column {
        table: String::new(),
        column: "c".to_owned(),
        coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
        colflags: myc::constants::ColumnFlags::empty(),
    }];

    TestingShim::new(
        |_, _| unreachable!(),
        |_| 0,
        move |_, params, w| {
            assert!(!params.is_empty());
            w.completed(0, 0)
        },
        |_, _| unreachable!(),
        |_| Ok(()),
    )
    .with_params(params)
    .with_columns(cols)
    .test(|db| {
        assert_eq!(
            db.exec_iter("SELECT a FROM b WHERE c = ?", (42i16,))
                .unwrap()
                .count(),
            0
        );
    })
}

/// A prepared statement declared with no parameters receives an empty parameter list.
#[test]
fn prepared_no_params() {
    let cols = vec![Column {
        table: String::new(),
        column: "a".to_owned(),
        coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
        colflags: myc::constants::ColumnFlags::empty(),
    }];
    let cols2 = cols.clone();
    let params = vec![];

    TestingShim::new(
        |_, _| unreachable!(),
        |_| 0,
        move |_, params, w| {
            assert!(params.is_empty());
            let mut w = w.start(&cols)?;
            w.write_col(1024i16)?;
            w.finish()
        },
        |_, _| unreachable!(),
        |_| Ok(()),
    )
    .with_params(params)
    .with_columns(cols2)
    .test(|db| {
        let row = db.exec_iter("foo", ()).unwrap().next().unwrap().unwrap();
        assert_eq!(row.get::<i16, _>(0), Some(1024i16));
    })
}

/// NULL parameters decode as MYSQL_TYPE_NULL, and `None` values in `write_row` reach the
/// client as NULL.
#[test]
fn prepared_nulls() {
    let cols = vec![
        Column {
            table: String::new(),
            column: "a".to_owned(),
            coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
            colflags: myc::constants::ColumnFlags::empty(),
        },
        Column {
            table: String::new(),
            column: "b".to_owned(),
            coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
            colflags: myc::constants::ColumnFlags::empty(),
        },
    ];
    let cols2 = cols.clone();
    let params = vec![
        Column {
            table: String::new(),
            column: "c".to_owned(),
            coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
            colflags: myc::constants::ColumnFlags::empty(),
        },
        Column {
            table: String::new(),
            column: "d".to_owned(),
            coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
            colflags: myc::constants::ColumnFlags::empty(),
        },
    ];

    TestingShim::new(
        |_, _| unreachable!(),
        |_| 0,
        move |_, params, w| {
            assert_eq!(params.len(), 2);
            assert!(params[0].value.is_null());
            assert!(!params[1].value.is_null());
            assert_eq!(
                params[0].coltype,
                myc::constants::ColumnType::MYSQL_TYPE_NULL
            );
            // rust-mysql sends all numbers as LONGLONG :'(
            assert_eq!(
                params[1].coltype,
                myc::constants::ColumnType::MYSQL_TYPE_LONGLONG
            );
            assert_eq!(Into::<i8>::into(params[1].value), 42i8);

            let mut w = w.start(&cols)?;
            w.write_row(vec![None::<i16>, Some(42)])?;
            w.finish()
        },
        |_, _| unreachable!(),
        |_| Ok(()),
    )
    .with_params(params)
    .with_columns(cols2)
    .test(|db| {
        let row = db
            .exec_iter(
                "SELECT a, b FROM x WHERE c = ? AND d = ?",
                (mysql::Value::NULL, 42),
            )
            .unwrap()
            .next()
            .unwrap()
            .unwrap();
        assert_eq!(row.as_ref(0), Some(&mysql::Value::NULL));
        assert_eq!(row.get::<i16, _>(1), Some(42));
    })
}

/// A prepared statement whose result set has a column but no rows yields zero rows.
#[test]
fn prepared_no_rows() {
    let cols = vec![Column {
        table: String::new(),
        column: "a".to_owned(),
        coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
        colflags: myc::constants::ColumnFlags::empty(),
    }];
    let cols2 = cols.clone();
    TestingShim::new(
        |_, _| unreachable!(),
        |_| 0,
        move |_, _, w| w.start(&cols[..])?.finish(),
        |_, _| unreachable!(),
        |_| Ok(()),
    )
    .with_columns(cols2)
    .test(|db| {
        assert_eq!(db.exec_iter("SELECT a, b FROM foo", ()).unwrap().count(), 0);
    })
}

/// Writing a value into a zero-column prepared result still yields zero rows for the client.
#[test]
fn prepared_no_cols_but_rows() {
    TestingShim::new(
        |_, _| unreachable!(),
        |_| 0,
        move |_, _, w| w.start(&[])?.write_col(42).map(|_| ()),
        |_, _| unreachable!(),
        |_| Ok(()),
    )
    .test(|db| {
        assert_eq!(db.exec_iter("SELECT a, b FROM foo", ()).unwrap().count(), 0);
    })
}

/// A zero-column prepared result yields zero rows.
#[test]
fn prepared_no_cols() {
    TestingShim::new(
        |_, _| unreachable!(),
        |_| 0,
        move |_, _, w| w.start(&[])?.finish(),
        |_, _| unreachable!(),
        |_| Ok(()),
    )
    .test(|db| {
        assert_eq!(db.exec_iter("SELECT a, b FROM foo", ()).unwrap().count(), 0);
    })
}

/// A long text query reaches `on_query` byte-for-byte intact.
#[test]
fn really_long_query() {
    let long = "CREATE TABLE `stories` (`id` int unsigned NOT NULL AUTO_INCREMENT PRIMARY KEY, `always_null` int, `created_at` datetime, `user_id` int unsigned, `url` varchar(250) DEFAULT '', `title` varchar(150) DEFAULT '' NOT NULL, `description` mediumtext, `short_id` varchar(6) DEFAULT '' NOT NULL, `is_expired` tinyint(1) DEFAULT 0 NOT NULL, `is_moderated` tinyint(1) DEFAULT 0 NOT NULL, `markeddown_description` mediumtext, `story_cache` mediumtext, `merged_story_id` int, `unavailable_at` datetime, `twitter_id` varchar(20), `user_is_author` tinyint(1) DEFAULT 0, INDEX `index_stories_on_created_at` (`created_at`), fulltext INDEX `index_stories_on_description` (`description`), INDEX `is_idxes` (`is_expired`, `is_moderated`), INDEX `index_stories_on_is_expired` (`is_expired`), INDEX `index_stories_on_is_moderated` (`is_moderated`), INDEX `index_stories_on_merged_story_id` (`merged_story_id`), UNIQUE INDEX `unique_short_id` (`short_id`), fulltext INDEX `index_stories_on_story_cache` (`story_cache`), fulltext INDEX `index_stories_on_title` (`title`), INDEX `index_stories_on_twitter_id` (`twitter_id`), INDEX `url` (`url`(191)), INDEX `index_stories_on_user_id` (`user_id`)) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;";
    TestingShim::new(
        move |q, w| {
            assert_eq!(q, long);
            w.start(&[])?.finish()
        },
        |_| 0,
        |_, _, _| unreachable!(),
        |_, _| unreachable!(),
        |_| Ok(()),
    )
    .test(move |db| {
        db.query_iter(long).unwrap();
    })
}