use std::net::IpAddr; use std::sync::Arc; use async_trait::async_trait; use json::JsonValue; use log::{debug, warn}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; use crate::async_tls::AsyncTlsConnector; use crate::line_reader::LineReader; use crate::resolver::Resolver; use crate::util::allocate_vec; pub struct DohResolver { tls_connector: Arc> } impl DohResolver { pub fn new(tls_connector: Arc>) -> Self { Self { tls_connector } } } #[async_trait] impl Resolver for DohResolver { async fn resolve_host(&self, host: &str) -> std::io::Result> { lookup_host(host, &self.tls_connector).await } } async fn lookup_host( host: &str, tls_connector: &Arc>, ) -> std::io::Result> { let stream = TcpStream::connect("8.8.8.8:443").await?; let mut stream = tls_connector.connect("dns.google", stream).await?; let mut request = String::with_capacity(4096); request.push_str("GET /resolve?name="); request.push_str(host); request.push_str(" HTTP/1.1\r\n"); request.push_str("Host: dns.google\r\n"); request.push_str("User-Agent: curl/7.68.0\r\n"); request.push_str("Accept: application/dns-json\r\n"); request.push_str("Accept-Encoding: identity\r\n"); request.push_str("Connection: close\r\n\r\n"); debug!("request: {}", &request); stream.write_all(&request.into_bytes()).await?; let mut line_reader = LineReader::new(); let line = line_reader.read_line(&mut stream).await?; if !line.starts_with("HTTP/1.1 200 OK") { loop { let line = line_reader.read_line(&mut stream).await?; debug!("FAILURE LINE: {}", line); if line.is_empty() { break; } } if let Ok(s) = std::str::from_utf8(line_reader.unparsed_data()) { debug!("MSGUP: {}", s); } let mut buf = allocate_vec(40960); let len = stream.read(&mut buf).await?; debug!("READ {}", len); if let Ok(s) = std::str::from_utf8(&buf[0..len]) { debug!("MSG: {}", s); } return Err(std::io::Error::new( std::io::ErrorKind::Other, format!("DoH request failed: {}", "blurghaaa"), )); } let mut content_length: Option = None; loop { let line = line_reader.read_line(&mut stream).await?; if line.is_empty() { break; } if line.to_ascii_lowercase().starts_with("content-length: ") { let len = line[16..].parse::().map_err(|e| { std::io::Error::new( std::io::ErrorKind::InvalidData, format!("failed to parse content length: {}", e), ) })?; if content_length.is_some() { return Err(std::io::Error::new( std::io::ErrorKind::Other, format!( "Got content-length header but it was already received: {}", line ), )); } content_length = Some(len); } } let json_val = match content_length { Some(len) => { if len > 255_000 { return Err(std::io::Error::new( std::io::ErrorKind::Other, format!("JSON content length is too big ({})", len), )); } let mut json_bytes = allocate_vec(len); stream.read_exact(&mut json_bytes).await?; // TODO: we could set connection: keep-alive and reuse the stream. let json_str = String::from_utf8(json_bytes).map_err(|e| { std::io::Error::new( std::io::ErrorKind::InvalidData, format!("failed to parse dns JSON: {}", e), ) })?; JsonValue::from(json_str) } None => { panic!("TODO: handle transfer-encoding chunked"); } }; let answer = &json_val["Answer"]; if !answer.is_array() { return Err(std::io::Error::new( std::io::ErrorKind::Other, "failed to read json answer", )); } let mut results = vec![]; for item in answer.members() { // 1 is A, 28 is AAAA match item["type"].as_str() { Some("1") | Some("28") => { // A or AAAA // ref: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4 let data = match item["data"].as_str() { Some(s) => s, None => { warn!("Answer item did not have data field"); continue; } }; let ip_addr = data.parse::().map_err(|e| { std::io::Error::new( std::io::ErrorKind::InvalidData, format!("failed to parse ip address: {}", e), ) })?; results.push(ip_addr); } Some(_) => { continue; } None => { warn!("Answer item did not have type field"); continue; } } } if results.is_empty() { return Err(std::io::Error::new( std::io::ErrorKind::Other, format!("Failed to resolve {}", host) )); } debug!("Successfully resolved {}: {:?}", host, &results); Ok(results) }