use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; use log::{error, info, LevelFilter}; use rand_core::OsRng; use russh::server::{Auth, Msg, Server as _, Session}; use russh::{Channel, ChannelId}; use russh_sftp::protocol::{File, FileAttributes, Handle, Name, Status, StatusCode, Version}; use tokio::sync::Mutex; #[derive(Clone)] struct Server; impl russh::server::Server for Server { type Handler = SshSession; fn new_client(&mut self, _: Option) -> Self::Handler { SshSession::default() } } struct SshSession { clients: Arc>>>, } impl Default for SshSession { fn default() -> Self { Self { clients: Arc::new(Mutex::new(HashMap::new())), } } } impl SshSession { pub async fn get_channel(&mut self, channel_id: ChannelId) -> Channel { let mut clients = self.clients.lock().await; clients.remove(&channel_id).unwrap() } } #[async_trait] impl russh::server::Handler for SshSession { type Error = anyhow::Error; async fn auth_password(&mut self, user: &str, password: &str) -> Result { info!("credentials: {}, {}", user, password); Ok(Auth::Accept) } async fn auth_publickey( &mut self, user: &str, public_key: &russh_keys::ssh_key::PublicKey, ) -> Result { info!("credentials: {}, {:?}", user, public_key); Ok(Auth::Accept) } async fn channel_open_session( &mut self, channel: Channel, _session: &mut Session, ) -> Result { { let mut clients = self.clients.lock().await; clients.insert(channel.id(), channel); } Ok(true) } async fn channel_eof( &mut self, channel: ChannelId, session: &mut Session, ) -> Result<(), Self::Error> { // After a client has sent an EOF, indicating that they don't want // to send more data in this session, the channel can be closed. session.close(channel)?; Ok(()) } async fn subsystem_request( &mut self, channel_id: ChannelId, name: &str, session: &mut Session, ) -> Result<(), Self::Error> { info!("subsystem: {}", name); if name == "sftp" { let channel = self.get_channel(channel_id).await; let sftp = SftpSession::default(); session.channel_success(channel_id)?; russh_sftp::server::run(channel.into_stream(), sftp).await; } else { session.channel_failure(channel_id)?; } Ok(()) } } #[derive(Default)] struct SftpSession { version: Option, root_dir_read_done: bool, } #[async_trait] impl russh_sftp::server::Handler for SftpSession { type Error = StatusCode; fn unimplemented(&self) -> Self::Error { StatusCode::OpUnsupported } async fn init( &mut self, version: u32, extensions: HashMap, ) -> Result { if self.version.is_some() { error!("duplicate SSH_FXP_VERSION packet"); return Err(StatusCode::ConnectionLost); } self.version = Some(version); info!("version: {:?}, extensions: {:?}", self.version, extensions); Ok(Version::new()) } async fn close(&mut self, id: u32, _handle: String) -> Result { Ok(Status { id, status_code: StatusCode::Ok, error_message: "Ok".to_string(), language_tag: "en-US".to_string(), }) } async fn opendir(&mut self, id: u32, path: String) -> Result { info!("opendir: {}", path); self.root_dir_read_done = false; Ok(Handle { id, handle: path }) } async fn readdir(&mut self, id: u32, handle: String) -> Result { info!("readdir handle: {}", handle); if handle == "/" && !self.root_dir_read_done { self.root_dir_read_done = true; return Ok(Name { id, files: vec![ File::new("foo", FileAttributes::default()), File::new("bar", FileAttributes::default()), ], }); } // If all files have been sent to the client, respond with an EOF Err(StatusCode::Eof) } async fn realpath(&mut self, id: u32, path: String) -> Result { info!("realpath: {}", path); Ok(Name { id, files: vec![File::dummy("/")], }) } } #[tokio::main] async fn main() { env_logger::builder() .filter_level(LevelFilter::Debug) .init(); let config = russh::server::Config { auth_rejection_time: Duration::from_secs(3), auth_rejection_time_initial: Some(Duration::from_secs(0)), keys: vec![ russh_keys::PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap(), ], ..Default::default() }; let mut server = Server; server .run_on_address( Arc::new(config), ( "0.0.0.0", std::env::var("PORT") .unwrap_or("22".to_string()) .parse() .unwrap(), ), ) .await .unwrap(); }