use std::fs::File; use std::io::{BufReader, Error as IOError, ErrorKind}; use std::sync::Arc; use async_trait::async_trait; use futures::{stream, StreamExt}; use rustls_pemfile::{certs, pkcs8_private_keys}; use rustls_pki_types::{CertificateDer, PrivateKeyDer}; use tokio::net::TcpListener; use tokio_rustls::rustls::ServerConfig; use tokio_rustls::TlsAcceptor; use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::copy::NoopCopyHandler; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type}; use pgwire::error::PgWireResult; use pgwire::tokio::process_socket; pub struct DummyProcessor; impl NoopStartupHandler for DummyProcessor {} #[async_trait] impl SimpleQueryHandler for DummyProcessor { async fn do_query<'a, C>( &self, _client: &mut C, query: &'a str, ) -> PgWireResult>> where C: ClientInfo + Unpin + Send + Sync, { println!("{:?}", query); if query.starts_with("SELECT") { let f1 = FieldInfo::new("id".into(), None, None, Type::INT4, FieldFormat::Text); let f2 = FieldInfo::new("name".into(), None, None, Type::VARCHAR, FieldFormat::Text); let schema = Arc::new(vec![f1, f2]); let data = vec![ (Some(0), Some("Tom")), (Some(1), Some("Jerry")), (Some(2), None), ]; let schema_ref = schema.clone(); let data_row_stream = stream::iter(data.into_iter()).map(move |r| { let mut encoder = DataRowEncoder::new(schema_ref.clone()); encoder.encode_field(&r.0)?; encoder.encode_field(&r.1)?; encoder.finish() }); Ok(vec![Response::Query(QueryResponse::new( schema, data_row_stream, ))]) } else { Ok(vec![Response::Execution(Tag::new("OK").with_rows(1))]) } } } fn setup_tls() -> Result { let cert = certs(&mut BufReader::new(File::open("examples/ssl/server.crt")?)) .collect::, IOError>>()?; let key = pkcs8_private_keys(&mut BufReader::new(File::open("examples/ssl/server.key")?)) .map(|key| key.map(PrivateKeyDer::from)) .collect::, IOError>>()? .remove(0); let mut config = ServerConfig::builder() .with_no_client_auth() .with_single_cert(cert, key) .map_err(|err| IOError::new(ErrorKind::InvalidInput, err))?; config.alpn_protocols = vec![b"postgresql".to_vec()]; Ok(TlsAcceptor::from(Arc::new(config))) } struct DummyProcessorFactory { handler: Arc, } impl PgWireServerHandlers for DummyProcessorFactory { type StartupHandler = DummyProcessor; type SimpleQueryHandler = DummyProcessor; type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; type CopyHandler = NoopCopyHandler; type ErrorHandler = NoopErrorHandler; fn simple_query_handler(&self) -> Arc { self.handler.clone() } fn extended_query_handler(&self) -> Arc { Arc::new(PlaceholderExtendedQueryHandler) } fn startup_handler(&self) -> Arc { self.handler.clone() } fn copy_handler(&self) -> Arc { Arc::new(NoopCopyHandler) } fn error_handler(&self) -> Arc { Arc::new(NoopErrorHandler) } } #[tokio::main] pub async fn main() { let factory = Arc::new(DummyProcessorFactory { handler: Arc::new(DummyProcessor), }); let server_addr = "127.0.0.1:5433"; let tls_acceptor = Arc::new(setup_tls().unwrap()); let listener = TcpListener::bind(server_addr).await.unwrap(); println!("Listening to {}", server_addr); loop { let incoming_socket = listener.accept().await.unwrap(); let tls_acceptor_ref = tls_acceptor.clone(); let factory_ref = factory.clone(); tokio::spawn(async move { process_socket(incoming_socket.0, Some(tls_acceptor_ref), factory_ref).await }); } }