use std::fmt::Debug; use async_tungstenite::tungstenite::{protocol::CloseFrame, Error as WsError}; use rsa::errors::Error as RsaError; use serde_json::error::Error as JsonError; use sqlx_core::Error as SqlxError; use thiserror::Error as ThisError; /// Enum representing protocol implementation errors. #[derive(Debug, ThisError)] pub enum ExaProtocolError { #[error("expected {0} parameter sets; found a mismatch of length {1}")] ParameterLengthMismatch(usize, usize), #[error("invalid response from database, expecting {0}")] UnexpectedResponse(&'static str), #[error("transaction already open")] TransactionAlreadyOpen, #[error("no response data received")] MissingResponseData, #[error("no message received")] MissingMessage, #[error("type mismatch: expected SQL type `{0}` but was provided `{1}`")] DatatypeMismatch(String, String), #[error("server closed connection due to: {0}")] WebsocketClosed(String), #[error("feature 'compression' must be enabled to use compression")] CompressionDisabled, } impl<'a> From<Option<CloseFrame<'a>>> for ExaProtocolError { fn from(value: Option<CloseFrame<'a>>) -> Self { let msg = value.map_or("unknown reason".to_owned(), |c| c.to_string()); Self::WebsocketClosed(msg) } } impl From<ExaProtocolError> for SqlxError { fn from(value: ExaProtocolError) -> Self { Self::Protocol(value.to_string()) } } /// Helper trait used for converting errors from various underlying libraries to `SQLx`. pub(crate) trait ExaResultExt<T> { fn to_sqlx_err(self) -> Result<T, SqlxError>; } impl<T> ExaResultExt<T> for Result<T, WsError> { fn to_sqlx_err(self) -> Result<T, SqlxError> { let e = match self { Ok(v) => return Ok(v), Err(e) => e, }; let e = match e { WsError::ConnectionClosed => SqlxError::Protocol(WsError::ConnectionClosed.to_string()), WsError::AlreadyClosed => SqlxError::Protocol(WsError::AlreadyClosed.to_string()), WsError::Io(e) => SqlxError::Io(e), WsError::Tls(e) => SqlxError::Tls(e.into()), WsError::Capacity(e) => SqlxError::Protocol(e.to_string()), WsError::Protocol(e) => SqlxError::Protocol(e.to_string()), WsError::WriteBufferFull(e) => SqlxError::Protocol(e.to_string()), WsError::Utf8 => SqlxError::Protocol(WsError::Utf8.to_string()), WsError::Url(e) => SqlxError::Configuration(e.into()), WsError::Http(r) => SqlxError::Protocol(format!("HTTP error: {}", r.status())), WsError::HttpFormat(e) => SqlxError::Protocol(e.to_string()), WsError::AttackAttempt => SqlxError::Tls(WsError::AttackAttempt.into()), }; Err(e) } } impl<T> ExaResultExt<T> for Result<T, RsaError> { fn to_sqlx_err(self) -> Result<T, SqlxError> { self.map_err(|e| SqlxError::Protocol(e.to_string())) } } impl<T> ExaResultExt<T> for Result<T, JsonError> { fn to_sqlx_err(self) -> Result<T, SqlxError> { self.map_err(|e| SqlxError::Protocol(e.to_string())) } }