use std::fmt::Debug; use std::sync::Arc; use async_trait::async_trait; use futures::{stream, Sink, SinkExt}; use tokio::net::TcpListener; use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::copy::NoopCopyHandler; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type}; use pgwire::error::ErrorInfo; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::messages::response::NoticeResponse; use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; use pgwire::tokio::process_socket; pub struct DummyProcessor; #[async_trait] impl NoopStartupHandler for DummyProcessor { async fn post_startup( &self, client: &mut C, _message: PgWireFrontendMessage, ) -> PgWireResult<()> where C: ClientInfo + Sink + Unpin + Send, C::Error: Debug, PgWireError: From<>::Error>, { println!("Connected: {}", client.socket_addr()); client .send(PgWireBackendMessage::NoticeResponse(NoticeResponse::from( ErrorInfo::new( "NOTICE".to_owned(), "01000".to_owned(), "Supported queries in this example:\n- BEGIN;\n- ROLLBACK;\n- COMMIT;\n- SELECT 1;" .to_string(), ), ))) .await?; Ok(()) } } #[async_trait] impl SimpleQueryHandler for DummyProcessor { async fn do_query<'a, C>( &self, _client: &mut C, query: &'a str, ) -> PgWireResult>> where C: ClientInfo + Sink + Unpin + Send + Sync, C::Error: Debug, PgWireError: From<>::Error>, { let resp = match query { "BEGIN;" => Response::TransactionStart(Tag::new("BEGIN")), "ROLLBACK;" => Response::TransactionEnd(Tag::new("ROLLBACK")), "COMMIT;" => Response::TransactionEnd(Tag::new("COMMIT")), "SELECT 1;" => { let f1 = FieldInfo::new("SELECT 1".into(), None, None, Type::INT4, FieldFormat::Text); let schema = Arc::new(vec![f1]); let schema_ref = schema.clone(); let row = { let mut encoder = DataRowEncoder::new(schema_ref.clone()); encoder.encode_field(&Some(1))?; encoder.finish() }; let data_row_stream = stream::iter(vec![row]); Response::Query(QueryResponse::new(schema, data_row_stream)) } _ => Response::Error(Box::new(ErrorInfo::new( "FATAL".to_string(), "38003".to_string(), "Unsupported statement.".to_string(), ))), }; Ok(vec![resp]) } } struct DummyProcessorFactory { handler: Arc, } impl PgWireServerHandlers for DummyProcessorFactory { type StartupHandler = DummyProcessor; type SimpleQueryHandler = DummyProcessor; type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; type CopyHandler = NoopCopyHandler; type ErrorHandler = NoopErrorHandler; fn simple_query_handler(&self) -> Arc { self.handler.clone() } fn extended_query_handler(&self) -> Arc { Arc::new(PlaceholderExtendedQueryHandler) } fn startup_handler(&self) -> Arc { self.handler.clone() } fn copy_handler(&self) -> Arc { Arc::new(NoopCopyHandler) } fn error_handler(&self) -> Arc { Arc::new(NoopErrorHandler) } } #[tokio::main] pub async fn main() { let factory = Arc::new(DummyProcessorFactory { handler: Arc::new(DummyProcessor), }); let server_addr = "127.0.0.1:5432"; let listener = TcpListener::bind(server_addr).await.unwrap(); println!("Listening to {}", server_addr); loop { let incoming_socket = listener.accept().await.unwrap(); let factory_ref = factory.clone(); tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await }); } }