use std::sync::{Arc, Mutex}; use async_trait::async_trait; use futures::stream; use futures::Stream; use pgwire::api::auth::md5pass::{hash_md5_password, Md5PasswordAuthStartupHandler}; use pgwire::api::auth::{AuthSource, DefaultServerParameterProvider, LoginInfo, Password}; use pgwire::api::copy::NoopCopyHandler; use pgwire::api::portal::{Format, Portal}; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{ DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag, }; use pgwire::api::stmt::{NoopQueryParser, StoredStatement}; use pgwire::api::PgWireHandlerFactory; use pgwire::api::{ClientInfo, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::messages::data::DataRow; use pgwire::tokio::process_socket; use rusqlite::Rows; use rusqlite::{types::ValueRef, Connection, Statement, ToSql}; use tokio::net::TcpListener; pub struct SqliteBackend { conn: Arc>, query_parser: Arc, } struct DummyAuthSource; #[async_trait] impl AuthSource for DummyAuthSource { async fn get_password(&self, login_info: &LoginInfo) -> PgWireResult { println!("login info: {:?}", login_info); let salt = vec![0, 0, 0, 0]; let password = "pencil"; let hash_password = hash_md5_password(login_info.user().as_ref().unwrap(), password, salt.as_ref()); Ok(Password::new(Some(salt), hash_password.as_bytes().to_vec())) } } #[async_trait] impl SimpleQueryHandler for SqliteBackend { async fn do_query<'a, C>( &self, _client: &mut C, query: &'a str, ) -> PgWireResult>> where C: ClientInfo + Unpin + Send + Sync, { let conn = self.conn.lock().unwrap(); if query.to_uppercase().starts_with("SELECT") { let mut stmt = conn .prepare(query) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let header = Arc::new(row_desc_from_stmt(&stmt, &Format::UnifiedText)?); stmt.query(()) .map(|rows| { let s = encode_row_data(rows, header.clone()); vec![Response::Query(QueryResponse::new(header, s))] }) .map_err(|e| PgWireError::ApiError(Box::new(e))) } else { conn.execute(query, ()) .map(|affected_rows| { vec![Response::Execution( Tag::new("OK").with_rows(affected_rows).into(), )] }) .map_err(|e| PgWireError::ApiError(Box::new(e))) } } } fn name_to_type(name: &str) -> PgWireResult { dbg!(name); match name.to_uppercase().as_ref() { "INT" => Ok(Type::INT8), "VARCHAR" => Ok(Type::VARCHAR), "TEXT" => Ok(Type::TEXT), "BINARY" => Ok(Type::BYTEA), "FLOAT" => Ok(Type::FLOAT8), _ => Err(PgWireError::UserError(Box::new(ErrorInfo::new( "ERROR".to_owned(), "42846".to_owned(), format!("Unsupported data type: {name}"), )))), } } fn row_desc_from_stmt(stmt: &Statement, format: &Format) -> PgWireResult> { stmt.columns() .iter() .enumerate() .map(|(idx, col)| { let field_type = col .decl_type() .map(name_to_type) .unwrap_or(Ok(Type::UNKNOWN))?; Ok(FieldInfo::new( col.name().to_owned(), None, None, field_type, format.format_for(idx), )) }) .collect() } fn encode_row_data( mut rows: Rows, schema: Arc>, ) -> impl Stream> { let mut results = Vec::new(); let ncols = schema.len(); while let Ok(Some(row)) = rows.next() { let mut encoder = DataRowEncoder::new(schema.clone()); for idx in 0..ncols { let data = row.get_ref_unwrap::(idx); match data { ValueRef::Null => encoder.encode_field(&None::).unwrap(), ValueRef::Integer(i) => { encoder.encode_field(&i).unwrap(); } ValueRef::Real(f) => { encoder.encode_field(&f).unwrap(); } ValueRef::Text(t) => { encoder .encode_field(&String::from_utf8_lossy(t).as_ref()) .unwrap(); } ValueRef::Blob(b) => { encoder.encode_field(&b).unwrap(); } } } results.push(encoder.finish()); } stream::iter(results.into_iter()) } fn get_params(portal: &Portal) -> Vec> { let mut results = Vec::with_capacity(portal.parameter_len()); for i in 0..portal.parameter_len() { let param_type = portal.statement.parameter_types.get(i).unwrap(); // we only support a small amount of types for demo match param_type { &Type::BOOL => { let param = portal.parameter::(i, param_type).unwrap(); results.push(Box::new(param) as Box); } &Type::INT2 => { let param = portal.parameter::(i, param_type).unwrap(); results.push(Box::new(param) as Box); } &Type::INT4 => { let param = portal.parameter::(i, param_type).unwrap(); results.push(Box::new(param) as Box); } &Type::INT8 => { let param = portal.parameter::(i, param_type).unwrap(); results.push(Box::new(param) as Box); } &Type::TEXT | &Type::VARCHAR => { let param = portal.parameter::(i, param_type).unwrap(); results.push(Box::new(param) as Box); } &Type::FLOAT4 => { let param = portal.parameter::(i, param_type).unwrap(); results.push(Box::new(param) as Box); } &Type::FLOAT8 => { let param = portal.parameter::(i, param_type).unwrap(); results.push(Box::new(param) as Box); } _ => { unimplemented!("parameter type not supported") } } } results } #[async_trait] impl ExtendedQueryHandler for SqliteBackend { type Statement = String; type QueryParser = NoopQueryParser; fn query_parser(&self) -> Arc { self.query_parser.clone() } async fn do_query<'a, C>( &self, _client: &mut C, portal: &'a Portal, _max_rows: usize, ) -> PgWireResult> where C: ClientInfo + Unpin + Send + Sync, { let conn = self.conn.lock().unwrap(); let query = &portal.statement.statement; let mut stmt = conn .prepare_cached(query) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let params = get_params(portal); let params_ref = params .iter() .map(|f| f.as_ref()) .collect::>(); if query.to_uppercase().starts_with("SELECT") { let header = Arc::new(row_desc_from_stmt(&stmt, &portal.result_column_format)?); stmt.query::<&[&dyn rusqlite::ToSql]>(params_ref.as_ref()) .map(|rows| { let s = encode_row_data(rows, header.clone()); Response::Query(QueryResponse::new(header, s)) }) .map_err(|e| PgWireError::ApiError(Box::new(e))) } else { stmt.execute::<&[&dyn rusqlite::ToSql]>(params_ref.as_ref()) .map(|affected_rows| { Response::Execution(Tag::new("OK").with_rows(affected_rows).into()) }) .map_err(|e| PgWireError::ApiError(Box::new(e))) } } async fn do_describe_statement( &self, _client: &mut C, stmt: &StoredStatement, ) -> PgWireResult where C: ClientInfo + Unpin + Send + Sync, { let conn = self.conn.lock().unwrap(); let param_types = stmt.parameter_types.clone(); let stmt = conn .prepare_cached(&stmt.statement) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; row_desc_from_stmt(&stmt, &Format::UnifiedBinary) .map(|fields| DescribeStatementResponse::new(param_types, fields)) } async fn do_describe_portal( &self, _client: &mut C, portal: &Portal, ) -> PgWireResult where C: ClientInfo + Unpin + Send + Sync, { let conn = self.conn.lock().unwrap(); let stmt = conn .prepare_cached(&portal.statement.statement) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; row_desc_from_stmt(&stmt, &portal.result_column_format) .map(|fields| DescribePortalResponse::new(fields)) } } impl SqliteBackend { fn new() -> SqliteBackend { SqliteBackend { conn: Arc::new(Mutex::new(Connection::open_in_memory().unwrap())), query_parser: Arc::new(NoopQueryParser::new()), } } } struct SqliteBackendFactory { handler: Arc, } impl PgWireHandlerFactory for SqliteBackendFactory { type StartupHandler = Md5PasswordAuthStartupHandler; type SimpleQueryHandler = SqliteBackend; type ExtendedQueryHandler = SqliteBackend; type CopyHandler = NoopCopyHandler; fn simple_query_handler(&self) -> Arc { self.handler.clone() } fn extended_query_handler(&self) -> Arc { self.handler.clone() } fn startup_handler(&self) -> Arc { let mut parameters = DefaultServerParameterProvider::default(); parameters.server_version = rusqlite::version().to_owned(); Arc::new(Md5PasswordAuthStartupHandler::new( Arc::new(DummyAuthSource), Arc::new(parameters), )) } fn copy_handler(&self) -> Arc { Arc::new(NoopCopyHandler) } } #[tokio::main] pub async fn main() { let factory = Arc::new(SqliteBackendFactory { handler: Arc::new(SqliteBackend::new()), }); let server_addr = "127.0.0.1:5432"; let listener = TcpListener::bind(server_addr).await.unwrap(); println!("Listening to {}", server_addr); loop { let incoming_socket = listener.accept().await.unwrap(); let factory_ref = factory.clone(); tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await }); } }