Skip to content

Commit

Permalink
wip - support secure strings encrypted with ssh-key
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Aug 9, 2024
1 parent 48147d6 commit 1cc7adc
Show file tree
Hide file tree
Showing 7 changed files with 444 additions and 177 deletions.
6 changes: 3 additions & 3 deletions lumni/src/apps/api/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub enum ApplicationError {
NotFound(String),
ServerConfigurationError(String),
HttpClientError(HttpClientError),
IoError(std::io::Error),
IOError(std::io::Error),
DatabaseError(String),
NotImplemented(String),
NotReady(String),
Expand Down Expand Up @@ -132,7 +132,7 @@ impl fmt::Display for ApplicationError {
ApplicationError::HttpClientError(e) => {
write!(f, "HttpClientError: {}", e)
}
ApplicationError::IoError(e) => write!(f, "IoError: {}", e),
ApplicationError::IOError(e) => write!(f, "IoError: {}", e),
ApplicationError::DatabaseError(s) => {
write!(f, "DatabaseError: {}", s)
}
Expand Down Expand Up @@ -171,7 +171,7 @@ impl From<HttpClientError> for ApplicationError {

impl From<std::io::Error> for ApplicationError {
fn from(error: std::io::Error) -> Self {
ApplicationError::IoError(error)
ApplicationError::IOError(error)
}
}

Expand Down
12 changes: 10 additions & 2 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::io::{self, Write};
use std::path::PathBuf;
use std::sync::Arc;

use clap::{ArgMatches, Command};
Expand All @@ -19,7 +20,9 @@ use tokio::signal;
use tokio::sync::Mutex;
use tokio::time::{timeout, Duration};

use super::chat::db::{ConversationDatabase, ModelServerName};
use super::chat::db::{
ConversationDatabase, EncryptionHandler, ModelServerName,
};
use super::chat::{
prompt_app, App, AssistantManager, ChatEvent, NewConversation,
PromptInstruction, ThreadedChatSession,
Expand Down Expand Up @@ -165,7 +168,12 @@ pub async fn run_cli(
let config_dir =
env.get_config_dir().expect("Config directory not defined");
let sqlite_file = config_dir.join("chat.db");
let db_conn = Arc::new(ConversationDatabase::new(&sqlite_file)?);

let encryption_handler =
EncryptionHandler::new_from_path(None)?.map(Arc::new);

let db_conn =
Arc::new(ConversationDatabase::new(&sqlite_file, encryption_handler)?);

let mut profile_handler = db_conn.get_profile_handler(None);
if let Some(ref matches) = matches {
Expand Down
196 changes: 194 additions & 2 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/encryption/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
use std::fs;
use std::path::PathBuf;

use base64::engine::general_purpose;
use base64::Engine as _;
use base64::{Engine, Engine as _};
use lumni::api::error::{ApplicationError, EncryptionError};
use ring::aead;
use ring::rand::{SecureRandom, SystemRandom};
use rsa::pkcs1::DecodeRsaPrivateKey;
use rsa::pkcs1v15::Pkcs1v15Encrypt;
use rsa::pkcs8::{
DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey,
LineEnding,
};
use rsa::{Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey};
use rsa::{BigUint, RsaPrivateKey, RsaPublicKey};

use crate::external as lumni;

#[derive(Debug)]
pub struct EncryptionHandler {
public_key: RsaPublicKey,
private_key: RsaPrivateKey,
Expand All @@ -31,6 +37,192 @@ impl EncryptionHandler {
})
}

pub fn new_from_path(
private_key_path: Option<&PathBuf>,
) -> Result<Option<Self>, ApplicationError> {
match private_key_path {
Some(path) => {
if !path.exists() {
return Err(ApplicationError::NotFound(format!(
"Private key file not found: {:?}",
path
)));
}

let private_key = Self::parse_private_key(
path.to_str().ok_or_else(|| {
ApplicationError::InvalidInput(
"Invalid path".to_string(),
)
})?,
)?;
let public_key = RsaPublicKey::from(&private_key);
let public_key_pem = public_key
.to_public_key_pem(LineEnding::LF)
.map_err(|e| EncryptionError::Other(Box::new(e)))?;

let private_key_pem = private_key
.to_pkcs8_pem(LineEnding::LF)
.map_err(|e| EncryptionError::Pkcs8Error(e))?;

Ok(Some(Self::new(&public_key_pem, &private_key_pem)?))
}
None => {
if let Some(home_dir) = dirs::home_dir() {
let default_path = home_dir.join(".ssh").join("id_rsa");
if default_path.exists() {
Self::new_from_path(Some(&default_path))
} else {
Ok(None)
}
} else {
Ok(None)
}
}
}
}

fn parse_private_key(
key_path: &str,
) -> Result<RsaPrivateKey, ApplicationError> {
let key_data = fs::read_to_string(key_path)
.map_err(|e| ApplicationError::IOError(e))?;

// Try parsing as OpenSSH format
if key_data.starts_with("-----BEGIN OPENSSH PRIVATE KEY-----") {
return Self::parse_openssh_private_key(&key_data);
}

// Try parsing as PKCS#8 PEM
if let Ok(key) = RsaPrivateKey::from_pkcs8_pem(&key_data) {
return Ok(key);
}

// Try parsing as PKCS#1 PEM
if let Ok(key) = RsaPrivateKey::from_pkcs1_pem(&key_data) {
return Ok(key);
}

// If all parsing attempts fail, return an error
Err(ApplicationError::InvalidInput(
"Unable to parse private key: unsupported format".to_string(),
))
}

fn parse_openssh_private_key(
key_data: &str,
) -> Result<RsaPrivateKey, ApplicationError> {
let lines: Vec<&str> = key_data.lines().collect();

if !lines[0].starts_with("-----BEGIN OPENSSH PRIVATE KEY-----") {
return Err(ApplicationError::InvalidInput(
"Not an OpenSSH private key".to_string(),
));
}

let base64_data = lines[1..lines.len() - 1].join("");
let decoded =
general_purpose::STANDARD.decode(base64_data).map_err(|e| {
ApplicationError::from(EncryptionError::Base64Error(e))
})?;

// OpenSSH magic header
if &decoded[0..15] != b"openssh-key-v1\0" {
return Err(ApplicationError::InvalidInput(
"Invalid OpenSSH key format".to_string(),
));
}

let mut index = 15;

// Skip ciphername, kdfname, kdfoptions
for _ in 0..3 {
let len = u32::from_be_bytes([
decoded[index],
decoded[index + 1],
decoded[index + 2],
decoded[index + 3],
]) as usize;
index += 4 + len;
}

// Number of keys (should be 1)
index += 4;

// Public key length
let pubkey_len = u32::from_be_bytes([
decoded[index],
decoded[index + 1],
decoded[index + 2],
decoded[index + 3],
]) as usize;
index += 4;

// Skip public key
index += pubkey_len;

// Private key length
let privkey_len = u32::from_be_bytes([
decoded[index],
decoded[index + 1],
decoded[index + 2],
decoded[index + 3],
]) as usize;
index += 4;

// Skip checkints
index += 8;

// Key type length
let key_type_len = u32::from_be_bytes([
decoded[index],
decoded[index + 1],
decoded[index + 2],
decoded[index + 3],
]) as usize;
index += 4;

// Skip key type
index += key_type_len;

// Extract n
let n_len = u32::from_be_bytes([
decoded[index],
decoded[index + 1],
decoded[index + 2],
decoded[index + 3],
]) as usize;
index += 4;
let n = BigUint::from_bytes_be(&decoded[index..index + n_len]);
index += n_len;

// Extract e
let e_len = u32::from_be_bytes([
decoded[index],
decoded[index + 1],
decoded[index + 2],
decoded[index + 3],
]) as usize;
index += 4;
let e = BigUint::from_bytes_be(&decoded[index..index + e_len]);
index += e_len;

// Extract d
let d_len = u32::from_be_bytes([
decoded[index],
decoded[index + 1],
decoded[index + 2],
decoded[index + 3],
]) as usize;
index += 4;
let d = BigUint::from_bytes_be(&decoded[index..index + d_len]);

// We're ignoring iqmp, p, and q for simplicity, but a full implementation should use these

RsaPrivateKey::from_components(n, e, d, vec![])
.map_err(|e| ApplicationError::from(EncryptionError::RsaError(e)))
}

pub fn get_ssh_private_key(&self) -> Result<Vec<u8>, EncryptionError> {
// Convert the RSA private key to PKCS#8 PEM format with LF line endings
let pem = self
Expand Down
1 change: 1 addition & 0 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mod store;
mod user_profiles;

pub use conversations::ConversationDbHandler;
pub use encryption::EncryptionHandler;
pub use lumni::Timestamp;
pub use model::{ModelIdentifier, ModelSpec};
use serde::{Deserialize, Serialize};
Expand Down
9 changes: 6 additions & 3 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tokio::sync::Mutex as TokioMutex;

use super::connector::{DatabaseConnector, DatabaseOperationError};
use super::conversations::ConversationDbHandler;
use super::encryption::EncryptionHandler;
use super::encryption::{self, EncryptionHandler};
use super::user_profiles::UserProfileDbHandler;
use super::{
Conversation, ConversationId, ConversationStatus, Message, MessageId,
Expand All @@ -23,7 +23,10 @@ pub struct ConversationDatabase {
}

impl ConversationDatabase {
pub fn new(sqlite_file: &PathBuf) -> Result<Self, DatabaseOperationError> {
pub fn new(
sqlite_file: &PathBuf,
encryption_handler: Option<Arc<EncryptionHandler>>,
) -> Result<Self, DatabaseOperationError> {
PROMPT_SQLITE_FILEPATH
.set(sqlite_file.clone())
.map_err(|_| {
Expand All @@ -35,7 +38,7 @@ impl ConversationDatabase {
})?;
Ok(Self {
db: Arc::new(TokioMutex::new(DatabaseConnector::new(sqlite_file)?)),
encryption_handler: None,
encryption_handler,
})
}

Expand Down
Loading

0 comments on commit 1cc7adc

Please sign in to comment.