src/server/ssh/mod.rs
Ref: Size: 2.3 KiB
pub mod auth;
pub mod session;
use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;
use async_trait::async_trait;
use russh_keys::key::KeyPair;
use russh::server::Server as _;
use tracing::{error, info};
use session::{SshHandler, SshServerConfig};
/// Load the SSH host key stored at `key_path`, or generate and persist a new
/// ed25519 key there if no file exists yet.
///
/// A newly generated key is written as PKCS#8 PEM, after creating any missing
/// parent directories. On Unix the file is created with mode `0o600`
/// (owner-only) in the same call that creates it. Other users therefore never
/// get a chance to open it.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - an existing key cannot be loaded;
/// - the directory or file cannot be created, including when a file appears at
///   `key_path` between the existence check and creation;
/// - the key cannot be written.
///
/// On a write failure the partially written file is removed on a best-effort
/// basis.
pub fn load_or_generate_host_key(key_path: &Path) -> Result<KeyPair, Box<dyn std::error::Error>> {
    if key_path.exists() {
        info!("Loading SSH host key from {:?}", key_path);
        let key = russh_keys::load_secret_key(key_path, None)?;
        return Ok(key);
    }

    info!("Generating new SSH host key at {:?}", key_path);
    let key = KeyPair::generate_ed25519();
    // Ensure parent directory exists
    if let Some(parent) = key_path.parent() {
        std::fs::create_dir_all(parent)?;
    }

    let mut options = std::fs::OpenOptions::new();
    // `create_new` refuses to overwrite a key that appeared after the `exists()` check.
    options.write(true).create_new(true);
    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;
        // Apply owner-only permissions at creation time. Creating with default
        // permissions and chmod-ing afterwards leaves a window in which another
        // user can open the file and later read the key through that handle.
        options.mode(0o600);
    }
    let mut file = options.open(key_path)?;

    if let Err(err) = write_host_key(&key, &mut file) {
        // Close the handle first so the removal also works on platforms that
        // refuse to delete open files.
        drop(file);
        // Best-effort cleanup: a truncated key file would be picked up by the
        // load branch on the next start and would likely fail to parse.
        let _ = std::fs::remove_file(key_path);
        return Err(err);
    }
    Ok(key)
}

/// Serialize `key` as PKCS#8 PEM into `file` and flush it to stable storage.
fn write_host_key(
    key: &KeyPair,
    file: &mut std::fs::File,
) -> Result<(), Box<dyn std::error::Error>> {
    russh_keys::encode_pkcs8_pem(key, &mut *file)?;
    file.sync_all()?;
    Ok(())
}
/// russh server type that gives every accepted connection its own
/// [`SshHandler`]. All handlers share a single configuration.
struct CollabSshServer {
    /// Configuration shared by reference count with each per-connection handler.
    config: Arc<SshServerConfig>,
}

#[async_trait]
impl russh::server::Server for CollabSshServer {
    type Handler = SshHandler;

    /// Builds the handler for a newly connected client and logs the peer address.
    fn new_client(&mut self, peer_addr: Option<SocketAddr>) -> SshHandler {
        info!("New SSH connection from {:?}", peer_addr);
        // Cheap refcount bump; the configuration itself is not copied.
        let shared_config = Arc::clone(&self.config);
        SshHandler::new(shared_config)
    }
}
/// Start the SSH server on the given bind address.
pub async fn serve(
bind_addr: SocketAddr,
host_key: KeyPair,
ssh_config: SshServerConfig,
) -> Result<(), std::io::Error> {
let russh_config = russh::server::Config {
keys: vec![host_key],
methods: russh::MethodSet::PUBLICKEY,
..Default::default()
};
let mut server = CollabSshServer {
config: Arc::new(ssh_config),
};
info!("SSH server listening on {}", bind_addr);
match server
.run_on_address(Arc::new(russh_config), bind_addr)
.await
{
Ok(()) => Ok(()),
Err(e) => {
error!("SSH server error: {}", e);
Err(e)
}
}
}