// Copyright (c) 2022 Espresso Systems (espressosys.com)
// This file is part of the tide-disco library.
// You should have received a copy of the MIT License
// along with the tide-disco library. If not, see <https://mit-license.org/>.
use crate::method::Method;
use serde::{Deserialize, Serialize};
use snafu::{OptionExt, Snafu};
use std::any::type_name;
use std::collections::HashMap;
use std::fmt::Display;
use strum_macros::EnumString;
use tagged_base64::TaggedBase64;
use tide::http::{self, content::Accept, mime::Mime, Headers};
use vbs::{version::StaticVersionType, BinarySerializer, Serializer};
/// Errors that can arise while interpreting an incoming request.
///
/// For example, [RequestParams::param] returns [RequestError::MissingParam] when a parameter is
/// absent, and [RequestParams::accept] returns [RequestError::Http] when the `Accept` header is
/// malformed.
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`. Binary encodings (e.g. via `vbs`,
/// imported in this file) typically encode the variant index, so reordering variants likely
/// changes the wire format. Confirm before reordering.
#[derive(Clone, Debug, Snafu, Deserialize, Serialize)]
pub enum RequestError {
/// A required parameter called `name` was not provided with the request.
#[snafu(display("missing required parameter: {}", name))]
MissingParam { name: String },
/// A parameter of type `actual` could not be converted to the `expected` type.
#[snafu(display(
"incorrect parameter type: {} cannot be converted to {}",
actual,
expected
))]
IncorrectParamType {
actual: RequestParamType,
expected: RequestParamType,
},
/// An integer `value` does not fit in the requested integer type, named by `expected`.
#[snafu(display("value {} is too large for type {}", value, expected))]
IntegerOverflow { value: u128, expected: String },
/// Deserializing JSON data failed.
#[snafu(display("Unable to deserialize from JSON"))]
Json,
/// Deserializing binary data failed.
#[snafu(display("Unable to deserialize from binary"))]
Binary,
/// Decoding a tagged base 64 value failed; `reason` carries the underlying message.
#[snafu(display("Unable to deserialise from tagged base 64: {}", reason))]
TaggedBase64 { reason: String },
/// The content type was missing or is not one that can be handled.
#[snafu(display("Content type not specified or type not supported"))]
UnsupportedContentType,
/// An HTTP-level failure, such as a malformed `Accept` header; `reason` is the error message.
#[snafu(display("HTTP protocol error: {}", reason))]
Http { reason: String },
/// A parameter of kind `param_type` could not be parsed; `reason` explains why.
#[snafu(display("error parsing {} parameter: {}", param_type, reason))]
InvalidParam { param_type: String, reason: String },
/// A TaggedBase64 value carried tag `actual` where tag `expected` was required.
#[snafu(display("unexpected tag in TaggedBase64: {} (expected {})", actual, expected))]
TagMismatch { actual: String, expected: String },
}
/// Parameters passed to a route handler.
///
/// These parameters describe the incoming request and the current server state.
#[derive(Clone, Debug)]
pub struct RequestParams {
req: http::Request,
post_data: Vec,
params: HashMap,
}
impl RequestParams {
pub(crate) async fn new(
mut req: tide::Request,
formal_params: &[RequestParam],
) -> Result {
Ok(Self {
post_data: req.body_bytes().await.unwrap(),
params: formal_params
.iter()
.filter_map(|param| match RequestParamValue::new(&req, param) {
Ok(None) => None,
Ok(Some(value)) => Some(Ok((param.name.clone(), value))),
Err(err) => Some(Err(err)),
})
.collect::>()?,
req: req.into(),
})
}
/// The [Method] this request was dispatched with.
///
/// This is the verb of the underlying HTTP request, converted into [Method].
pub fn method(&self) -> Method {
    let verb = self.req.method();
    verb.into()
}
/// All headers attached to the incoming request.
pub fn headers(&self) -> &Headers {
    AsRef::<Headers>::as_ref(&self.req)
}
/// The [Accept] header of this request.
///
/// The media type proposals in the resulting header are sorted in order of decreasing weight.
///
/// If no [Accept] header was explicitly set, defaults to the wildcard `Accept: *`.
///
/// # Error
///
/// Returns [RequestError::Http] if the [Accept] header is malformed.
pub fn accept(&self) -> Result {
Self::accept_from_headers(self.headers())
}
pub(crate) fn accept_from_headers(
headers: impl AsRef,
) -> Result {
match Accept::from_headers(headers).map_err(|err| RequestError::Http {
reason: err.to_string(),
})? {
Some(mut accept) => {
accept.sort();
Ok(accept)
}
None => {
let mut accept = Accept::new();
accept.set_wildcard(true);
Ok(accept)
}
}
}
/// Get the remote address for this request.
///
/// This is determined in the following priority:
/// 1. `Forwarded` header `for` key
/// 2. The first `X-Forwarded-For` header
/// 3. Peer address of the transport
pub fn remote(&self) -> Option<&str> {
self.req.remote()
}
/// Get the value of a named parameter.
///
/// The name of the parameter can be given by any type that implements [Display]. Of course, the
/// simplest option is to use [str] or [String], as in
///
/// ```
/// # use tide_disco::*;
/// # fn ex(req: &RequestParams) {
/// req.param("foo")
/// # ;}
/// ```
///
/// However, you have the option of defining a statically typed enum representing the possible
/// parameters of a given route and using enum variants as parameter names. Among other
/// benefits, this allows you to change the client-facing parameter names just by tweaking the
/// [Display] implementation of your enum, without changing other code.
///
/// ```
/// use std::fmt::{self, Display, Formatter};
///
/// enum RouteParams {
///     Param1,
///     Param2,
/// }
///
/// impl Display for RouteParams {
///     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
///         let name = match self {
///             Self::Param1 => "param1",
///             Self::Param2 => "param2",
///         };
///         write!(f, "{}", name)
///     }
/// }
///
/// # use tide_disco::*;
/// # fn ex(req: &RequestParams) {
/// req.param(&RouteParams::Param1)
/// # ;}
/// ```
///
/// You can also use [strum_macros] to automatically derive the [Display] implementation, so you
/// only have to specify the client-facing names of each parameter:
///
/// ```
/// #[derive(strum_macros::Display)]
/// enum RouteParams {
///     #[strum(serialize = "param1")]
///     Param1,
///     #[strum(serialize = "param2")]
///     Param2,
/// }
///
/// # use tide_disco::*;
/// # fn ex(req: &RequestParams) {
/// req.param(&RouteParams::Param1)
/// # ;}
/// ```
///
/// # Errors
///
/// Returns [RequestError::MissingParam] if a parameter called `name` was not provided with the
/// request.
///
/// It is recommended to implement `From<RequestError>` for the error type for your API, so that
/// you can use `?` with this function in a route handler. If your error type implements
/// [Error](crate::Error), you can easily use the [catch_all](crate::Error::catch_all)
/// constructor to do this:
///
/// ```
/// use serde::{Deserialize, Serialize};
/// use snafu::Snafu;
/// use tide_disco::{Error, RequestError, RequestParams, StatusCode};
///
/// type ApiState = ();
///
/// #[derive(Debug, Snafu, Deserialize, Serialize)]
/// struct ApiError {
///     status: StatusCode,
///     msg: String,
/// }
///
/// impl Error for ApiError {
///     fn catch_all(status: StatusCode, msg: String) -> Self {
///         Self { status, msg }
///     }
///
///     fn status(&self) -> StatusCode {
///         self.status
///     }
/// }
///
/// impl From<RequestError> for ApiError {
///     fn from(err: RequestError) -> Self {
///         Self::catch_all(StatusCode::BAD_REQUEST, err.to_string())
///     }
/// }
///
/// async fn my_route_handler(req: RequestParams, _state: &ApiState) -> Result<(), ApiError> {
///     let param = req.param("my_param")?;
///     Ok(())
/// }
/// ```
pub fn param<Name>(&self, name: &Name) -> Result<&RequestParamValue, RequestError>
where
    Name: ?Sized + Display,
{
    // Build the error context lazily so the name is only rendered to a String when missing.
    self.opt_param(name).with_context(|| MissingParamSnafu {
        name: name.to_string(),
    })
}
/// Get the value of a named optional parameter.
///
/// Like [param](Self::param), but returns [None] instead of [Err] if the parametre is missing.
pub fn opt_param(&self, name: &Name) -> Option<&RequestParamValue>
where
Name: ?Sized + Display,
{
self.params.get(&name.to_string())
}
/// Get the value of a named parameter and convert it to an integer.
///
/// Like [param](Self::param), but returns [Err] if the parameter value cannot be converted to
/// an integer of the desired size.
pub fn integer_param(&self, name: &Name) -> Result
where
Name: ?Sized + Display,
T: TryFrom,
{
self.opt_integer_param(name)?.context(MissingParamSnafu {
name: name.to_string(),
})
}
/// Get the value of a named optional parameter and convert it to an integer.
///
/// Like [opt_param](Self::opt_param), but returns [Err] if the parameter value cannot be
/// converted to an integer of the desired size.
pub fn opt_integer_param(&self, name: &Name) -> Result