use std::sync::Arc;

use async_trait::async_trait;
use datafusion::{
    arrow::datatypes::SchemaRef,
    catalog::SchemaProvider,
    error::{DataFusionError, Result},
    execution::context::{SessionContext, SessionState},
    physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream},
    sql::unparser::dialect::{DefaultDialect, Dialect},
};
use futures::TryStreamExt;

use datafusion_federation::sql::SQLExecutor;

/// Replaces the default schema of the session's default catalog with `schema`.
pub fn overwrite_default_schema(
    state: &SessionState,
    schema: Arc<dyn SchemaProvider>,
) -> Result<()> {
    let options = &state.config().options().catalog;
    // The default catalog is registered by the default session config, so the
    // lookup is expected to succeed here.
    let catalog = state
        .catalog_list()
        .catalog(options.default_catalog.as_str())
        .unwrap();
    catalog.register_schema(options.default_schema.as_str(), schema)?;
    Ok(())
}

pub struct MockSqliteExecutor {
    session: Arc<SessionContext>,
}

impl MockSqliteExecutor {
    pub fn new(session: Arc<SessionContext>) -> Self {
        Self { session }
    }
}

#[async_trait]
impl SQLExecutor for MockSqliteExecutor {
    fn name(&self) -> &str {
        "mock_sqlite_executor"
    }

    fn compute_context(&self) -> Option<String> {
        // Don't return None here: it would cause incorrect federation with other
        // providers of the same name whose compute_context is also None.
        // Instead return a unique string that will never match another
        // provider's context.
        Some("sqlite_exec".to_string())
    }

    fn execute(&self, sql: &str, schema: SchemaRef) -> Result<SendableRecordBatchStream> {
        // Execute the SQL using the remote DataFusion session context, wrapping
        // the future in a stream so the result can be returned synchronously.
        let future_stream = _execute(self.session.clone(), sql.to_string());
        let stream = futures::stream::once(future_stream).try_flatten();
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            schema.clone(),
            stream,
        )))
    }

    async fn table_names(&self) -> Result<Vec<String>> {
        Err(DataFusionError::NotImplemented(
            "table inference not implemented".to_string(),
        ))
    }

    async fn get_table_schema(&self, table_name: &str) -> Result<SchemaRef> {
        // Infer the schema by fetching a single row from the remote table.
        let sql = format!("select * from {table_name} limit 1");
        let df = self.session.sql(&sql).await?;
        let schema = df.schema().as_arrow().clone();
        Ok(Arc::new(schema))
    }

    fn dialect(&self) -> Arc<dyn Dialect> {
        Arc::new(DefaultDialect {})
    }
}

#[allow(dead_code)]
pub struct MockPostgresExecutor {
    session: Arc<SessionContext>,
}

#[allow(dead_code)]
impl MockPostgresExecutor {
    pub fn new(session: Arc<SessionContext>) -> Self {
        Self { session }
    }
}

#[async_trait]
impl SQLExecutor for MockPostgresExecutor {
    fn name(&self) -> &str {
        "mock_postgres_executor"
    }

    fn compute_context(&self) -> Option<String> {
        // Don't return None here: it would cause incorrect federation with other
        // providers of the same name whose compute_context is also None.
        // Instead return a unique string that will never match another
        // provider's context.
        Some("postgres_exec".to_string())
    }

    fn execute(&self, sql: &str, schema: SchemaRef) -> Result<SendableRecordBatchStream> {
        // Execute the SQL using the remote DataFusion session context, wrapping
        // the future in a stream so the result can be returned synchronously.
        let future_stream = _execute(self.session.clone(), sql.to_string());
        let stream = futures::stream::once(future_stream).try_flatten();
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            schema.clone(),
            stream,
        )))
    }

    async fn table_names(&self) -> Result<Vec<String>> {
        Err(DataFusionError::NotImplemented(
            "table inference not implemented".to_string(),
        ))
    }

    async fn get_table_schema(&self, table_name: &str) -> Result<SchemaRef> {
        // Infer the schema by fetching a single row from the remote table.
        let sql = format!("select * from {table_name} limit 1");
        let df = self.session.sql(&sql).await?;
        let schema = df.schema().as_arrow().clone();
        Ok(Arc::new(schema))
    }

    fn dialect(&self) -> Arc<dyn Dialect> {
        Arc::new(DefaultDialect {})
    }
}

async fn _execute(ctx: Arc<SessionContext>, sql: String) -> Result<SendableRecordBatchStream> {
    ctx.sql(&sql).await?.execute_stream().await
}
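
// --- Usage sketch (illustrative, not part of the original file) ---
// A minimal sketch of wiring one of the mock executors into a federated
// session, assuming the `SQLFederationProvider` and `SQLSchemaProvider`
// types from datafusion_federation's `sql` module. The function name
// `example_federated_setup` and the table name "test" are hypothetical;
// `remote_ctx` stands in for a SessionContext with the remote tables
// already registered.
#[allow(dead_code)]
async fn example_federated_setup(remote_ctx: Arc<SessionContext>) -> Result<()> {
    use datafusion_federation::sql::{SQLFederationProvider, SQLSchemaProvider};

    // Wrap the executor in a federation provider so federated subplans can be
    // unparsed back to SQL and pushed down to the "remote" session.
    let executor = Arc::new(MockSqliteExecutor::new(remote_ctx));
    let provider = Arc::new(SQLFederationProvider::new(executor));

    // `table_names` above returns NotImplemented, so list the remote tables
    // explicitly rather than letting the schema provider discover them.
    let schema_provider = Arc::new(
        SQLSchemaProvider::new_with_tables(provider, vec!["test".to_string()]).await?,
    );

    // Expose the federated tables through the local context's default schema.
    // A real setup would also register datafusion-federation's analyzer rule
    // on the session state so that plans actually get federated.
    let local_ctx = SessionContext::new();
    overwrite_default_schema(&local_ctx.state(), schema_provider)?;

    Ok(())
}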