use std::{
collections::HashMap,
fmt::{Debug, Formatter, Result as FmtResult},
net::{IpAddr, SocketAddr},
str::{self, Utf8Error},
};
use hyper::{body, Body, Request as HyperRequestInternal, Uri, Version};
use crate::{cookie::Cookie, HttpMethod};
pub type HyperRequest = HyperRequestInternal
;
pub struct Request {
pub(crate) socket_addr: SocketAddr,
pub(crate) body: Vec,
pub(crate) method: HttpMethod,
pub(crate) uri: Uri,
pub(crate) version: (u8, u8),
pub(crate) headers: HashMap>,
pub(crate) query: HashMap,
pub(crate) params: HashMap,
pub(crate) cookies: Vec,
pub(crate) hyper_request: HyperRequest,
}
impl Request {
pub async fn from_hyper(
socket_addr: SocketAddr,
req: HyperRequest,
) -> Self {
let (parts, hyper_body) = req.into_parts();
let mut headers = HashMap::>::new();
parts.headers.iter().for_each(|(key, value)| {
let key = key.to_string();
let value = value.to_str();
if value.is_err() {
return;
}
let value = value.unwrap().to_string();
if let Some(values) = headers.get_mut(&key) {
values.push(value);
} else {
headers.insert(key.to_string(), vec![value]);
}
});
let body = body::to_bytes(hyper_body).await.unwrap().to_vec();
Request {
socket_addr,
body: body.clone(),
method: HttpMethod::from(parts.method.clone()),
uri: parts.uri.clone(),
version: match parts.version {
Version::HTTP_09 => (0, 9),
Version::HTTP_10 => (1, 0),
Version::HTTP_11 => (1, 1),
Version::HTTP_2 => (2, 0),
Version::HTTP_3 => (3, 0),
_ => (0, 0),
},
headers: headers.clone(),
query: if let Some(query) = parts.uri.query() {
serde_urlencoded::from_str(query)
.unwrap_or_else(|_| HashMap::new())
} else {
HashMap::new()
},
params: HashMap::new(),
cookies: vec![],
hyper_request: HyperRequest::from_parts(parts, Body::from(body)),
}
}
pub fn get_body_bytes(&self) -> &[u8] {
&self.body
}
pub fn get_body(&self) -> Result {
Ok(str::from_utf8(&self.body)?.to_string())
}
pub fn get_method(&self) -> &HttpMethod {
&self.method
}
pub fn get_length(&self) -> u128 {
if let Some(length) = self.headers.get("Content-Length") {
if let Some(value) = length.get(0) {
if let Ok(value) = value.parse::() {
return value;
}
}
}
self.body.len() as u128
}
pub fn get_path(&self) -> String {
self.uri.path().to_string()
}
pub fn get_full_url(&self) -> String {
self.uri.to_string()
}
pub fn get_origin(&self) -> Option {
Some(format!(
"{}://{}",
self.uri.scheme_str()?,
self.uri.authority()?
))
}
pub fn get_query_string(&self) -> String {
self.uri.query().unwrap_or("").to_string()
}
pub fn get_host(&self) -> String {
self.uri
.host()
.map(String::from)
.unwrap_or_else(|| self.get_header("host").unwrap_or_default())
}
pub fn get_host_and_port(&self) -> String {
format!(
"{}{}",
self.uri.host().unwrap(),
if let Some(port) = self.uri.port_u16() {
format!(":{}", port)
} else {
String::new()
}
)
}
pub fn get_content_type(&self) -> String {
let c_type = self
.get_header("Content-Type")
.unwrap_or_else(|| "text/plain".to_string());
c_type.split(';').next().unwrap_or("").to_string()
}
pub fn get_charset(&self) -> Option {
let header = self.get_header("Content-Type")?;
let charset_index = header.find("charset=")?;
let data = &header[charset_index..];
Some(
data[(charset_index + 8)..data.find(';').unwrap_or(data.len())]
.to_string(),
)
}
pub fn get_protocol(&self) -> String {
// TODO support X-Forwarded-Proto
self.uri.scheme_str().unwrap_or("http").to_string()
}
pub fn is_secure(&self) -> bool {
self.get_protocol() == "https"
}
pub fn get_ip(&self) -> IpAddr {
self.socket_addr.ip()
}
pub fn is(&self, mimes: &[&str]) -> bool {
let given = self.get_content_type();
mimes.iter().any(|mime| mime == &given)
}
// TODO content negotiation
// See: https://koajs.com/#request content negotiation
pub fn get_version(&self) -> String {
format!("{}.{}", self.version.0, self.version.1)
}
pub fn get_version_major(&self) -> u8 {
self.version.0
}
pub fn get_version_minor(&self) -> u8 {
self.version.1
}
pub fn get_header(&self, field: &str) -> Option {
self.headers.iter().find_map(|(key, value)| {
if key.to_lowercase() == field.to_lowercase() {
Some(value.join("\n"))
} else {
None
}
})
}
pub fn get_headers(&self) -> &HashMap> {
&self.headers
}
pub fn set_header(&mut self, field: &str, value: &str) {
self.headers
.insert(field.to_string(), vec![value.to_string()]);
}
pub fn append_header(&mut self, key: String, value: String) {
if let Some(headers) = self.headers.get_mut(&key) {
headers.push(value);
} else {
self.headers.insert(key, vec![value]);
}
}
pub fn remove_header(&mut self, field: &str) {
self.headers.remove(field);
}
pub fn get_query(&self) -> &HashMap {
&self.query
}
pub fn get_params(&self) -> &HashMap {
&self.params
}
pub fn get_cookies(&self) -> &Vec {
&self.cookies
}
pub fn get_cookie(&self, name: &str) -> Option<&Cookie> {
self.cookies.iter().find(|cookie| cookie.key == name)
}
pub fn get_hyper_request(&self) -> &HyperRequest {
&self.hyper_request
}
pub fn get_hyper_request_mut(&mut self) -> &mut HyperRequest {
&mut self.hyper_request
}
}
#[cfg(debug_assertions)]
impl Debug for Request {
    // Debug builds: dump every field except the wrapped hyper request.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let mut builder = f.debug_struct("Request");
        builder.field("socket_addr", &self.socket_addr);
        builder.field("body", &self.body);
        builder.field("method", &self.method);
        builder.field("uri", &self.uri);
        builder.field("version", &self.version);
        builder.field("headers", &self.headers);
        builder.field("query", &self.query);
        builder.field("params", &self.params);
        builder.field("cookies", &self.cookies);
        builder.finish()
    }
}
#[cfg(not(debug_assertions))]
impl Debug for Request {
    // Release builds: a compact one-line summary, `[Request <method> <path>]`.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let path = self.uri.path();
        write!(f, "[Request {} {}]", self.method, path)
    }
}