//! A PostgreSQL wire-protocol server backed by an in-memory DuckDB database,
//! built on `pgwire`. It supports MD5 password authentication (password
//! `pencil`), simple queries, and extended (prepared, parameterized) queries.

use std::sync::{Arc, Mutex};

use async_trait::async_trait;
use duckdb::arrow::datatypes::DataType;
use duckdb::Rows;
use duckdb::{params, types::ValueRef, Connection, Statement, ToSql};
use futures::stream;
use futures::Stream;
use pgwire::api::auth::md5pass::{hash_md5_password, Md5PasswordAuthStartupHandler};
use pgwire::api::auth::{AuthSource, DefaultServerParameterProvider, LoginInfo, Password};
use pgwire::api::copy::NoopCopyHandler;
use pgwire::api::portal::{Format, Portal};
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{
    DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse,
    Response, Tag,
};
use pgwire::api::stmt::{NoopQueryParser, StoredStatement};
use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::data::DataRow;
use pgwire::tokio::process_socket;
use tokio::net::TcpListener;

pub struct DuckDBBackend {
    conn: Arc<Mutex<Connection>>,
    query_parser: Arc<NoopQueryParser>,
}

/// Serves the same hard-coded password (`pencil`) for every user.
struct DummyAuthSource;

#[async_trait]
impl AuthSource for DummyAuthSource {
    async fn get_password(&self, login_info: &LoginInfo) -> PgWireResult<Password> {
        println!("login info: {:?}", login_info);

        let salt = vec![0, 0, 0, 0];
        let password = "pencil";

        let hash_password =
            hash_md5_password(login_info.user().as_ref().unwrap(), password, salt.as_ref());
        Ok(Password::new(Some(salt), hash_password.as_bytes().to_vec()))
    }
}

#[async_trait]
impl SimpleQueryHandler for DuckDBBackend {
    async fn do_query<'a, C>(
        &self,
        _client: &mut C,
        query: &'a str,
    ) -> PgWireResult<Vec<Response<'a>>>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let conn = self.conn.lock().unwrap();
        if query.to_uppercase().starts_with("SELECT") {
            // SELECTs return a row description followed by encoded row data.
            let mut stmt = conn
                .prepare(query)
                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
            let header = Arc::new(row_desc_from_stmt(&stmt, &Format::UnifiedText)?);
            stmt.query(params![])
                .map(|rows| {
                    let s = encode_row_data(rows, header.clone());
                    vec![Response::Query(QueryResponse::new(header, s))]
                })
                .map_err(|e| PgWireError::ApiError(Box::new(e)))
        } else {
            // Everything else is executed and reported as an affected-row count.
            conn.execute(query, params![])
                .map(|affected_rows| {
                    vec![Response::Execution(
                        Tag::new("OK").with_rows(affected_rows).into(),
                    )]
                })
                .map_err(|e| PgWireError::ApiError(Box::new(e)))
        }
    }
}

/// Map DuckDB's Arrow column types to PostgreSQL types.
fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
    Ok(match df_type {
        DataType::Null => Type::UNKNOWN,
        DataType::Boolean => Type::BOOL,
        DataType::Int8 | DataType::UInt8 => Type::CHAR,
        DataType::Int16 | DataType::UInt16 => Type::INT2,
        DataType::Int32 | DataType::UInt32 => Type::INT4,
        DataType::Int64 | DataType::UInt64 => Type::INT8,
        DataType::Timestamp(_, _) => Type::TIMESTAMP,
        DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
        DataType::Date32 | DataType::Date64 => Type::DATE,
        DataType::Binary => Type::BYTEA,
        DataType::Float32 => Type::FLOAT4,
        DataType::Float64 => Type::FLOAT8,
        DataType::Utf8 => Type::VARCHAR,
        DataType::List(field) => match field.data_type() {
            DataType::Boolean => Type::BOOL_ARRAY,
            DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY,
            DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY,
            DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY,
            DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY,
            DataType::Timestamp(_, _) => Type::TIMESTAMP_ARRAY,
            DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY,
            DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY,
            DataType::Binary => Type::BYTEA_ARRAY,
            DataType::Float32 => Type::FLOAT4_ARRAY,
            DataType::Float64 => Type::FLOAT8_ARRAY,
            DataType::Utf8 => Type::VARCHAR_ARRAY,
            list_type => {
                return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
                    "ERROR".to_owned(),
                    "XX000".to_owned(),
                    format!("Unsupported List Datatype {list_type}"),
                ))));
            }
        },
        _ => {
            return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
                "ERROR".to_owned(),
                "XX000".to_owned(),
                format!("Unsupported Datatype {df_type}"),
            ))));
        }
    })
}

/// Build the PostgreSQL row description from a prepared DuckDB statement.
fn row_desc_from_stmt(stmt: &Statement, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
    let columns = stmt.column_count();

    (0..columns)
        .map(|idx| {
            let datatype = stmt.column_type(idx);
            let name = stmt.column_name(idx).unwrap();

            Ok(FieldInfo::new(
                name.clone(),
                None,
                None,
                into_pg_type(&datatype)?,
                format.format_for(idx),
            ))
        })
        .collect()
}

/// Encode DuckDB rows into PostgreSQL `DataRow` messages.
fn encode_row_data(
    mut rows: Rows<'_>,
    schema: Arc<Vec<FieldInfo>>,
) -> impl Stream<Item = PgWireResult<DataRow>> {
    let mut results = Vec::new();
    let ncols = schema.len();
    while let Ok(Some(row)) = rows.next() {
        let mut encoder = DataRowEncoder::new(schema.clone());
        for idx in 0..ncols {
            let data = row.get_ref_unwrap::<usize>(idx);
            match data {
                ValueRef::Null => encoder.encode_field(&None::<i8>).unwrap(),
                ValueRef::TinyInt(i) => {
                    encoder.encode_field(&i).unwrap();
                }
                ValueRef::SmallInt(i) => {
                    encoder.encode_field(&i).unwrap();
                }
                ValueRef::Int(i) => {
                    encoder.encode_field(&i).unwrap();
                }
                ValueRef::BigInt(i) => {
                    encoder.encode_field(&i).unwrap();
                }
                ValueRef::Float(f) => {
                    encoder.encode_field(&f).unwrap();
                }
                ValueRef::Double(f) => {
                    encoder.encode_field(&f).unwrap();
                }
                ValueRef::Text(t) => {
                    encoder
                        .encode_field(&String::from_utf8_lossy(t).as_ref())
                        .unwrap();
                }
                ValueRef::Blob(b) => {
                    encoder.encode_field(&b).unwrap();
                }
                _ => {
                    unimplemented!("More types to be supported.")
                }
            }
        }

        results.push(encoder.finish());
    }

    stream::iter(results.into_iter())
}

/// Convert the bound portal parameters into DuckDB `ToSql` values.
fn get_params(portal: &Portal<String>) -> Vec<Box<dyn ToSql>> {
    let mut results = Vec::with_capacity(portal.parameter_len());
    for i in 0..portal.parameter_len() {
        let param_type = portal.statement.parameter_types.get(i).unwrap();
        // we only support a small set of types for this demo
        match param_type {
            &Type::BOOL => {
                let param = portal.parameter::<bool>(i, param_type).unwrap();
                results.push(Box::new(param) as Box<dyn ToSql>);
            }
            &Type::INT2 => {
                let param = portal.parameter::<i16>(i, param_type).unwrap();
                results.push(Box::new(param) as Box<dyn ToSql>);
            }
            &Type::INT4 => {
                let param = portal.parameter::<i32>(i, param_type).unwrap();
                results.push(Box::new(param) as Box<dyn ToSql>);
            }
            &Type::INT8 => {
                let param = portal.parameter::<i64>(i, param_type).unwrap();
                results.push(Box::new(param) as Box<dyn ToSql>);
            }
            &Type::TEXT | &Type::VARCHAR => {
                let param = portal.parameter::<String>(i, param_type).unwrap();
                results.push(Box::new(param) as Box<dyn ToSql>);
            }
            &Type::FLOAT4 => {
                let param = portal.parameter::<f32>(i, param_type).unwrap();
                results.push(Box::new(param) as Box<dyn ToSql>);
            }
            &Type::FLOAT8 => {
                let param = portal.parameter::<f64>(i, param_type).unwrap();
                results.push(Box::new(param) as Box<dyn ToSql>);
            }
            _ => {
                unimplemented!("parameter type not supported")
            }
        }
    }

    results
}

#[async_trait]
impl ExtendedQueryHandler for DuckDBBackend {
    type Statement = String;
    type QueryParser = NoopQueryParser;

    fn query_parser(&self) -> Arc<Self::QueryParser> {
        self.query_parser.clone()
    }

    async fn do_query<'a, C>(
        &self,
        _client: &mut C,
        portal: &'a Portal<Self::Statement>,
        _max_rows: usize,
    ) -> PgWireResult<Response<'a>>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let conn = self.conn.lock().unwrap();
        let query = &portal.statement.statement;
        let mut stmt = conn
            .prepare_cached(query)
            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
        let params = get_params(portal);
        let params_ref = params
            .iter()
            .map(|f| f.as_ref())
            .collect::<Vec<&dyn duckdb::ToSql>>();

        if query.to_uppercase().starts_with("SELECT") {
            let header = Arc::new(row_desc_from_stmt(&stmt, &portal.result_column_format)?);
            stmt.query::<&[&dyn duckdb::ToSql]>(params_ref.as_ref())
                .map(|rows| {
                    let s = encode_row_data(rows, header.clone());
                    Response::Query(QueryResponse::new(header, s))
                })
                .map_err(|e| PgWireError::ApiError(Box::new(e)))
        } else {
            stmt.execute::<&[&dyn duckdb::ToSql]>(params_ref.as_ref())
                .map(|affected_rows| {
                    Response::Execution(Tag::new("OK").with_rows(affected_rows).into())
                })
                .map_err(|e| PgWireError::ApiError(Box::new(e)))
        }
    }

    async fn do_describe_statement<C>(
        &self,
        _client: &mut C,
        stmt: &StoredStatement<Self::Statement>,
    ) -> PgWireResult<DescribeStatementResponse>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let conn = self.conn.lock().unwrap();
        let param_types = stmt.parameter_types.clone();
        let stmt = conn
            .prepare_cached(&stmt.statement)
            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
        row_desc_from_stmt(&stmt, &Format::UnifiedBinary)
            .map(|fields| DescribeStatementResponse::new(param_types, fields))
    }

    async fn do_describe_portal<C>(
        &self,
        _client: &mut C,
        portal: &Portal<Self::Statement>,
    ) -> PgWireResult<DescribePortalResponse>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let conn = self.conn.lock().unwrap();
        let stmt = conn
            .prepare_cached(&portal.statement.statement)
            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
        row_desc_from_stmt(&stmt, &portal.result_column_format)
            .map(|fields| DescribePortalResponse::new(fields))
    }
}

impl DuckDBBackend {
    fn new() -> DuckDBBackend {
        DuckDBBackend {
            conn: Arc::new(Mutex::new(Connection::open_in_memory().unwrap())),
            query_parser: Arc::new(NoopQueryParser::new()),
        }
    }
}

struct DuckDBBackendFactory {
    handler: Arc<DuckDBBackend>,
}

impl PgWireServerHandlers for DuckDBBackendFactory {
    type StartupHandler =
        Md5PasswordAuthStartupHandler<DummyAuthSource, DefaultServerParameterProvider>;
    type SimpleQueryHandler = DuckDBBackend;
    type ExtendedQueryHandler = DuckDBBackend;
    type CopyHandler = NoopCopyHandler;
    type ErrorHandler = NoopErrorHandler;

    fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
        self.handler.clone()
    }

    fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
        self.handler.clone()
    }

    fn startup_handler(&self) -> Arc<Self::StartupHandler> {
        Arc::new(Md5PasswordAuthStartupHandler::new(
            Arc::new(DummyAuthSource),
            Arc::new(DefaultServerParameterProvider::default()),
        ))
    }

    fn copy_handler(&self) -> Arc<Self::CopyHandler> {
        Arc::new(NoopCopyHandler)
    }

    fn error_handler(&self) -> Arc<Self::ErrorHandler> {
        Arc::new(NoopErrorHandler)
    }
}

#[tokio::main]
pub async fn main() {
    let factory = Arc::new(DuckDBBackendFactory {
        handler: Arc::new(DuckDBBackend::new()),
    });
    let server_addr = "127.0.0.1:5432";
    let listener = TcpListener::bind(server_addr).await.unwrap();
    println!(
        "Listening to {}, use password `pencil` to connect",
        server_addr
    );
    loop {
        let incoming_socket = listener.accept().await.unwrap();
        let factory_ref = factory.clone();

        tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await });
    }
}
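
// Connecting: a sketch of client usage, assuming a stock PostgreSQL client
// such as `psql` (not part of this example). `DummyAuthSource` serves the same
// password for every user name, so any user works:
//
//     psql -h 127.0.0.1 -p 5432 -U postgres
//     Password: pencil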