use std::fmt::Debug; use std::sync::Arc; use async_trait::async_trait; use futures::{Sink, SinkExt}; use tokio::net::TcpListener; use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::copy::CopyHandler; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{CopyResponse, Response}; use pgwire::api::{ClientInfo, PgWireConnectionState, PgWireHandlerFactory}; use pgwire::error::ErrorInfo; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::messages::copy::{CopyData, CopyDone, CopyFail}; use pgwire::messages::response::NoticeResponse; use pgwire::messages::PgWireBackendMessage; use pgwire::tokio::process_socket; pub struct DummyProcessor; impl NoopStartupHandler for DummyProcessor {} #[async_trait] impl SimpleQueryHandler for DummyProcessor { async fn do_query<'a, C>( &self, client: &mut C, query: &'a str, ) -> PgWireResult>> where C: ClientInfo + Sink + Unpin + Send + Sync, C::Error: Debug, PgWireError: From<>::Error>, { client .send(PgWireBackendMessage::NoticeResponse(NoticeResponse::from( ErrorInfo::new( "NOTICE".to_owned(), "01000".to_owned(), format!("Query received {}", query), ), ))) .await?; Ok(vec![Response::CopyIn(CopyResponse::new(0, 1, vec![0]))]) } } #[async_trait] impl CopyHandler for DummyProcessor { async fn on_copy_data(&self, client: &mut C, copy_data: CopyData) -> PgWireResult<()> where C: ClientInfo + Sink + Unpin + Send + Sync, C::Error: Debug, PgWireError: From<>::Error>, { use PgWireConnectionState::*; // This is set by the `on_query` implementations while handling a // `CopyIn`/`CopyOut`/`CopyBoth` response. assert!(matches!(client.state(), CopyInProgress(_))); println!("receiving data: {:?}", copy_data); Ok(()) } async fn on_copy_done(&self, client: &mut C, _done: CopyDone) -> PgWireResult<()> where C: ClientInfo + Sink + Unpin + Send + Sync, C::Error: Debug, PgWireError: From<>::Error>, { use PgWireConnectionState::*; // This is set by the `on_query` implementations while handling a // `CopyIn`/`CopyOut`/`CopyBoth` response. assert!(matches!(client.state(), CopyInProgress(_))); println!("copy done"); Ok(()) } async fn on_copy_fail(&self, client: &mut C, fail: CopyFail) -> PgWireError where C: ClientInfo + Sink + Unpin + Send + Sync, C::Error: Debug, PgWireError: From<>::Error>, { use PgWireConnectionState::*; // This is set by the `on_query` implementations while handling a // `CopyIn`/`CopyOut`/`CopyBoth` response. assert!(matches!(client.state(), CopyInProgress(_))); println!("copy failed: {:?}", fail); PgWireError::UserError(Box::new(ErrorInfo::new( "ERROR".to_owned(), "XX000".to_owned(), format!("COPY IN mode terminated by the user: {}", fail.message), ))) } } struct DummyProcessorFactory { handler: Arc, } impl PgWireHandlerFactory for DummyProcessorFactory { type StartupHandler = DummyProcessor; type SimpleQueryHandler = DummyProcessor; type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; type CopyHandler = DummyProcessor; fn simple_query_handler(&self) -> Arc { self.handler.clone() } fn extended_query_handler(&self) -> Arc { Arc::new(PlaceholderExtendedQueryHandler) } fn startup_handler(&self) -> Arc { self.handler.clone() } fn copy_handler(&self) -> Arc { self.handler.clone() } } #[tokio::main] pub async fn main() { let factory = Arc::new(DummyProcessorFactory { handler: Arc::new(DummyProcessor), }); let server_addr = "127.0.0.1:5432"; let listener = TcpListener::bind(server_addr).await.unwrap(); println!("Listening to {}", server_addr); loop { let incoming_socket = listener.accept().await.unwrap(); let factory_ref = factory.clone(); tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await }); } }