// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. mod common; use std::{pin::Pin, sync::Arc}; use crate::common::fixture::TestFixture; use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, TimestampNanosecondArray}; use arrow_flight::{ decode::FlightRecordBatchStream, encode::FlightDataEncoderBuilder, flight_service_server::{FlightService, FlightServiceServer}, sql::{ server::{FlightSqlService, PeekableFlightDataStream}, ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, Any, CommandGetCatalogs, CommandGetDbSchemas, CommandGetTableTypes, CommandGetTables, CommandPreparedStatementQuery, CommandStatementQuery, DoPutPreparedStatementResult, ProstMessageExt, SqlInfo, }, utils::batches_to_flight_data, Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, IpcMessage, SchemaAsIpc, Ticket, }; use arrow_ipc::writer::IpcWriteOptions; use arrow_schema::{ArrowError, DataType, Field, Schema, TimeUnit}; use assert_cmd::Command; use bytes::Bytes; use futures::{Stream, TryStreamExt}; use prost::Message; use tonic::{Request, Response, Status, Streaming}; const QUERY: &str = "SELECT * FROM table;"; #[tokio::test] async fn test_simple() { let test_server = FlightSqlServiceImpl::default(); let fixture = TestFixture::new(test_server.service()).await; let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { Command::cargo_bin("flight_sql_client") .unwrap() .env_clear() .env("RUST_BACKTRACE", "1") .env("RUST_LOG", "warn") .arg("--host") .arg(addr.ip().to_string()) .arg("--port") .arg(addr.port().to_string()) .arg("statement-query") .arg(QUERY) .assert() .success() .get_output() .stdout .clone() }) .await .unwrap(); fixture.shutdown_and_wait().await; assert_eq!( std::str::from_utf8(&stdout).unwrap().trim(), "+--------------+-----------+---------------------------+-----------------------------+\ \n| field_string | field_int | field_timestamp_nano_notz | field_timestamp_nano_berlin |\ \n+--------------+-----------+---------------------------+-----------------------------+\ \n| Hello | 42 | | |\ \n| lovely | | 1970-01-01T00:00:00 | 1970-01-01T01:00:00+01:00 |\ \n| FlightSQL! | 1337 | 2024-10-30T11:36:57 | 2024-10-30T12:36:57+01:00 |\ \n+--------------+-----------+---------------------------+-----------------------------+", ); } #[tokio::test] async fn test_get_catalogs() { let test_server = FlightSqlServiceImpl::default(); let fixture = TestFixture::new(test_server.service()).await; let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { Command::cargo_bin("flight_sql_client") .unwrap() .env_clear() .env("RUST_BACKTRACE", "1") .env("RUST_LOG", "warn") .arg("--host") .arg(addr.ip().to_string()) .arg("--port") .arg(addr.port().to_string()) .arg("catalogs") .assert() .success() .get_output() .stdout .clone() }) .await .unwrap(); fixture.shutdown_and_wait().await; assert_eq!( std::str::from_utf8(&stdout).unwrap().trim(), "+--------------+\ \n| catalog_name |\ \n+--------------+\ \n| catalog_a |\ \n| catalog_b |\ \n+--------------+", ); } #[tokio::test] async fn test_get_db_schemas() { let test_server = FlightSqlServiceImpl::default(); let fixture = TestFixture::new(test_server.service()).await; let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { Command::cargo_bin("flight_sql_client") .unwrap() .env_clear() .env("RUST_BACKTRACE", "1") .env("RUST_LOG", "warn") .arg("--host") .arg(addr.ip().to_string()) .arg("--port") .arg(addr.port().to_string()) .arg("db-schemas") .arg("catalog_a") .assert() .success() .get_output() .stdout .clone() }) .await .unwrap(); fixture.shutdown_and_wait().await; assert_eq!( std::str::from_utf8(&stdout).unwrap().trim(), "+--------------+----------------+\ \n| catalog_name | db_schema_name |\ \n+--------------+----------------+\ \n| catalog_a | schema_1 |\ \n| catalog_a | schema_2 |\ \n+--------------+----------------+", ); } #[tokio::test] async fn test_get_tables() { let test_server = FlightSqlServiceImpl::default(); let fixture = TestFixture::new(test_server.service()).await; let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { Command::cargo_bin("flight_sql_client") .unwrap() .env_clear() .env("RUST_BACKTRACE", "1") .env("RUST_LOG", "warn") .arg("--host") .arg(addr.ip().to_string()) .arg("--port") .arg(addr.port().to_string()) .arg("tables") .arg("catalog_a") .assert() .success() .get_output() .stdout .clone() }) .await .unwrap(); fixture.shutdown_and_wait().await; assert_eq!( std::str::from_utf8(&stdout).unwrap().trim(), "+--------------+----------------+------------+------------+\ \n| catalog_name | db_schema_name | table_name | table_type |\ \n+--------------+----------------+------------+------------+\ \n| catalog_a | schema_1 | table_1 | TABLE |\ \n| catalog_a | schema_2 | table_2 | VIEW |\ \n+--------------+----------------+------------+------------+", ); } #[tokio::test] async fn test_get_tables_db_filter() { let test_server = FlightSqlServiceImpl::default(); let fixture = TestFixture::new(test_server.service()).await; let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { Command::cargo_bin("flight_sql_client") .unwrap() .env_clear() .env("RUST_BACKTRACE", "1") .env("RUST_LOG", "warn") .arg("--host") .arg(addr.ip().to_string()) .arg("--port") .arg(addr.port().to_string()) .arg("tables") .arg("catalog_a") .arg("--db-schema-filter") .arg("schema_2") .assert() .success() .get_output() .stdout .clone() }) .await .unwrap(); fixture.shutdown_and_wait().await; assert_eq!( std::str::from_utf8(&stdout).unwrap().trim(), "+--------------+----------------+------------+------------+\ \n| catalog_name | db_schema_name | table_name | table_type |\ \n+--------------+----------------+------------+------------+\ \n| catalog_a | schema_2 | table_2 | VIEW |\ \n+--------------+----------------+------------+------------+", ); } #[tokio::test] async fn test_get_tables_types() { let test_server = FlightSqlServiceImpl::default(); let fixture = TestFixture::new(test_server.service()).await; let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { Command::cargo_bin("flight_sql_client") .unwrap() .env_clear() .env("RUST_BACKTRACE", "1") .env("RUST_LOG", "warn") .arg("--host") .arg(addr.ip().to_string()) .arg("--port") .arg(addr.port().to_string()) .arg("table-types") .assert() .success() .get_output() .stdout .clone() }) .await .unwrap(); fixture.shutdown_and_wait().await; assert_eq!( std::str::from_utf8(&stdout).unwrap().trim(), "+--------------+\ \n| table_type |\ \n+--------------+\ \n| SYSTEM_TABLE |\ \n| TABLE |\ \n| VIEW |\ \n+--------------+", ); } const PREPARED_QUERY: &str = "SELECT * FROM table WHERE field = $1"; const PREPARED_STATEMENT_HANDLE: &str = "prepared_statement_handle"; const UPDATED_PREPARED_STATEMENT_HANDLE: &str = "updated_prepared_statement_handle"; async fn test_do_put_prepared_statement(test_server: FlightSqlServiceImpl) { let fixture = TestFixture::new(test_server.service()).await; let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { Command::cargo_bin("flight_sql_client") .unwrap() .env_clear() .env("RUST_BACKTRACE", "1") .env("RUST_LOG", "warn") .arg("--host") .arg(addr.ip().to_string()) .arg("--port") .arg(addr.port().to_string()) .arg("prepared-statement-query") .arg(PREPARED_QUERY) .args(["-p", "$1=string"]) .args(["-p", "$2=64"]) .assert() .success() .get_output() .stdout .clone() }) .await .unwrap(); fixture.shutdown_and_wait().await; assert_eq!( std::str::from_utf8(&stdout).unwrap().trim(), "+--------------+-----------+---------------------------+-----------------------------+\ \n| field_string | field_int | field_timestamp_nano_notz | field_timestamp_nano_berlin |\ \n+--------------+-----------+---------------------------+-----------------------------+\ \n| Hello | 42 | | |\ \n| lovely | | 1970-01-01T00:00:00 | 1970-01-01T01:00:00+01:00 |\ \n| FlightSQL! | 1337 | 2024-10-30T11:36:57 | 2024-10-30T12:36:57+01:00 |\ \n+--------------+-----------+---------------------------+-----------------------------+", ); } #[tokio::test] pub async fn test_do_put_prepared_statement_stateless() { test_do_put_prepared_statement(FlightSqlServiceImpl { stateless_prepared_statements: true, }) .await } #[tokio::test] pub async fn test_do_put_prepared_statement_stateful() { test_do_put_prepared_statement(FlightSqlServiceImpl { stateless_prepared_statements: false, }) .await } #[derive(Clone)] pub struct FlightSqlServiceImpl { /// Whether to emulate stateless (true) or stateful (false) behavior for /// prepared statements. stateful servers will not return an updated /// handle after executing `DoPut(CommandPreparedStatementQuery)` stateless_prepared_statements: bool, } impl Default for FlightSqlServiceImpl { fn default() -> Self { Self { stateless_prepared_statements: true, } } } impl FlightSqlServiceImpl { /// Return an [`FlightServiceServer`] that can be used with a /// [`Server`](tonic::transport::Server) pub fn service(&self) -> FlightServiceServer { // wrap up tonic goop FlightServiceServer::new(self.clone()) } fn schema() -> Schema { Schema::new(vec![ Field::new("field_string", DataType::Utf8, false), Field::new("field_int", DataType::Int64, true), Field::new( "field_timestamp_nano_notz", DataType::Timestamp(TimeUnit::Nanosecond, None), true, ), Field::new( "field_timestamp_nano_berlin", DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("Europe/Berlin"))), true, ), ]) } fn fake_result() -> Result { let schema = Self::schema(); let string_array = StringArray::from(vec!["Hello", "lovely", "FlightSQL!"]); let int_array = Int64Array::from(vec![Some(42), None, Some(1337)]); let timestamp_array = TimestampNanosecondArray::from(vec![None, Some(0), Some(1730288217000000000)]); let timestamp_ts_array = timestamp_array .clone() .with_timezone(Arc::from("Europe/Berlin")); let cols = vec![ Arc::new(string_array) as ArrayRef, Arc::new(int_array) as ArrayRef, Arc::new(timestamp_array) as ArrayRef, Arc::new(timestamp_ts_array) as ArrayRef, ]; RecordBatch::try_new(Arc::new(schema), cols) } fn create_fake_prepared_stmt() -> Result { let handle = PREPARED_STATEMENT_HANDLE.to_string(); let schema = Self::schema(); let parameter_schema = Schema::new(vec![ Field::new("$1", DataType::Utf8, false), Field::new("$2", DataType::Int64, true), ]); Ok(ActionCreatePreparedStatementResult { prepared_statement_handle: handle.into(), dataset_schema: serialize_schema(&schema)?, parameter_schema: serialize_schema(¶meter_schema)?, }) } fn fake_flight_info(&self) -> Result { let batch = Self::fake_result()?; Ok(FlightInfo::new() .try_with_schema(batch.schema_ref()) .expect("encoding schema") .with_endpoint( FlightEndpoint::new().with_ticket(Ticket::new( FetchResults { handle: String::from("part_1"), } .as_any() .encode_to_vec(), )), ) .with_endpoint( FlightEndpoint::new().with_ticket(Ticket::new( FetchResults { handle: String::from("part_2"), } .as_any() .encode_to_vec(), )), ) .with_total_records(batch.num_rows() as i64) .with_total_bytes(batch.get_array_memory_size() as i64) .with_ordered(false)) } } fn serialize_schema(schema: &Schema) -> Result { Ok(IpcMessage::try_from(SchemaAsIpc::new(schema, &IpcWriteOptions::default()))?.0) } #[tonic::async_trait] impl FlightSqlService for FlightSqlServiceImpl { type FlightService = FlightSqlServiceImpl; async fn do_handshake( &self, _request: Request>, ) -> Result< Response> + Send>>>, Status, > { Err(Status::unimplemented("do_handshake not implemented")) } async fn do_get_fallback( &self, _request: Request, message: Any, ) -> Result::DoGetStream>, Status> { let part = message.unpack::().unwrap().unwrap().handle; let batch = Self::fake_result().unwrap(); let batch = match part.as_str() { "part_1" => batch.slice(0, 2), "part_2" => batch.slice(2, 1), ticket => panic!("Invalid ticket: {ticket:?}"), }; let schema = batch.schema_ref(); let batches = vec![batch.clone()]; let flight_data = batches_to_flight_data(schema, batches) .unwrap() .into_iter() .map(Ok); let stream: Pin> + Send>> = Box::pin(futures::stream::iter(flight_data)); let resp = Response::new(stream); Ok(resp) } async fn get_flight_info_catalogs( &self, query: CommandGetCatalogs, request: Request, ) -> Result, Status> { let flight_descriptor = request.into_inner(); let ticket = Ticket { ticket: query.as_any().encode_to_vec().into(), }; let endpoint = FlightEndpoint::new().with_ticket(ticket); let flight_info = FlightInfo::new() .try_with_schema(&query.into_builder().schema()) .unwrap() .with_endpoint(endpoint) .with_descriptor(flight_descriptor); Ok(Response::new(flight_info)) } async fn get_flight_info_schemas( &self, query: CommandGetDbSchemas, request: Request, ) -> Result, Status> { let flight_descriptor = request.into_inner(); let ticket = Ticket { ticket: query.as_any().encode_to_vec().into(), }; let endpoint = FlightEndpoint::new().with_ticket(ticket); let flight_info = FlightInfo::new() .try_with_schema(&query.into_builder().schema()) .unwrap() .with_endpoint(endpoint) .with_descriptor(flight_descriptor); Ok(Response::new(flight_info)) } async fn get_flight_info_tables( &self, query: CommandGetTables, request: Request, ) -> Result, Status> { let flight_descriptor = request.into_inner(); let ticket = Ticket { ticket: query.as_any().encode_to_vec().into(), }; let endpoint = FlightEndpoint::new().with_ticket(ticket); let flight_info = FlightInfo::new() .try_with_schema(&query.into_builder().schema()) .unwrap() .with_endpoint(endpoint) .with_descriptor(flight_descriptor); Ok(Response::new(flight_info)) } async fn get_flight_info_table_types( &self, query: CommandGetTableTypes, request: Request, ) -> Result, Status> { let flight_descriptor = request.into_inner(); let ticket = Ticket { ticket: query.as_any().encode_to_vec().into(), }; let endpoint = FlightEndpoint::new().with_ticket(ticket); let flight_info = FlightInfo::new() .try_with_schema(&query.into_builder().schema()) .unwrap() .with_endpoint(endpoint) .with_descriptor(flight_descriptor); Ok(Response::new(flight_info)) } async fn get_flight_info_statement( &self, query: CommandStatementQuery, _request: Request, ) -> Result, Status> { assert_eq!(query.query, QUERY); let resp = Response::new(self.fake_flight_info().unwrap()); Ok(resp) } async fn get_flight_info_prepared_statement( &self, cmd: CommandPreparedStatementQuery, _request: Request, ) -> Result, Status> { if self.stateless_prepared_statements { assert_eq!( cmd.prepared_statement_handle, UPDATED_PREPARED_STATEMENT_HANDLE.as_bytes() ); } else { assert_eq!( cmd.prepared_statement_handle, PREPARED_STATEMENT_HANDLE.as_bytes() ); } let resp = Response::new(self.fake_flight_info().unwrap()); Ok(resp) } async fn do_get_catalogs( &self, query: CommandGetCatalogs, _request: Request, ) -> Result::DoGetStream>, Status> { let mut builder = query.into_builder(); for catalog_name in ["catalog_a", "catalog_b"] { builder.append(catalog_name); } let schema = builder.schema(); let batch = builder.build(); let stream = FlightDataEncoderBuilder::new() .with_schema(schema) .build(futures::stream::once(async { batch })) .map_err(Status::from); Ok(Response::new(Box::pin(stream))) } async fn do_get_schemas( &self, query: CommandGetDbSchemas, _request: Request, ) -> Result::DoGetStream>, Status> { let mut builder = query.into_builder(); for (catalog_name, schema_name) in [ ("catalog_a", "schema_1"), ("catalog_a", "schema_2"), ("catalog_b", "schema_3"), ] { builder.append(catalog_name, schema_name); } let schema = builder.schema(); let batch = builder.build(); let stream = FlightDataEncoderBuilder::new() .with_schema(schema) .build(futures::stream::once(async { batch })) .map_err(Status::from); Ok(Response::new(Box::pin(stream))) } async fn do_get_tables( &self, query: CommandGetTables, _request: Request, ) -> Result::DoGetStream>, Status> { let mut builder = query.into_builder(); for (catalog_name, schema_name, table_name, table_type, schema) in [ ( "catalog_a", "schema_1", "table_1", "TABLE", Arc::new(Schema::empty()), ), ( "catalog_a", "schema_2", "table_2", "VIEW", Arc::new(Schema::empty()), ), ( "catalog_b", "schema_3", "table_3", "TABLE", Arc::new(Schema::empty()), ), ] { builder .append(catalog_name, schema_name, table_name, table_type, &schema) .unwrap(); } let schema = builder.schema(); let batch = builder.build(); let stream = FlightDataEncoderBuilder::new() .with_schema(schema) .build(futures::stream::once(async { batch })) .map_err(Status::from); Ok(Response::new(Box::pin(stream))) } async fn do_get_table_types( &self, query: CommandGetTableTypes, _request: Request, ) -> Result::DoGetStream>, Status> { let mut builder = query.into_builder(); for table_type in ["TABLE", "VIEW", "SYSTEM_TABLE"] { builder.append(table_type); } let schema = builder.schema(); let batch = builder.build(); let stream = FlightDataEncoderBuilder::new() .with_schema(schema) .build(futures::stream::once(async { batch })) .map_err(Status::from); Ok(Response::new(Box::pin(stream))) } async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, request: Request, ) -> Result { // just make sure decoding the parameters works let parameters = FlightRecordBatchStream::new_from_flight_data( request.into_inner().map_err(|e| e.into()), ) .try_collect::>() .await?; for (left, right) in parameters[0].schema().flattened_fields().iter().zip(vec![ Field::new("$1", DataType::Utf8, false), Field::new("$2", DataType::Int64, true), ]) { if left.name() != right.name() || left.data_type() != right.data_type() { return Err(Status::invalid_argument(format!( "Parameters did not match parameter schema\ngot {}", parameters[0].schema(), ))); } } let handle = if self.stateless_prepared_statements { UPDATED_PREPARED_STATEMENT_HANDLE.to_string().into() } else { PREPARED_STATEMENT_HANDLE.to_string().into() }; let result = DoPutPreparedStatementResult { prepared_statement_handle: Some(handle), }; Ok(result) } async fn do_action_create_prepared_statement( &self, _query: ActionCreatePreparedStatementRequest, _request: Request, ) -> Result { Self::create_fake_prepared_stmt() .map_err(|e| Status::internal(format!("Unable to serialize schema: {e}"))) } async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} } #[derive(Clone, PartialEq, ::prost::Message)] pub struct FetchResults { #[prost(string, tag = "1")] pub handle: ::prost::alloc::string::String, } impl ProstMessageExt for FetchResults { fn type_url() -> &'static str { "type.googleapis.com/arrow.flight.protocol.sql.FetchResults" } fn as_any(&self) -> Any { Any { type_url: FetchResults::type_url().to_string(), value: ::prost::Message::encode_to_vec(self).into(), } } }