Skip to content

Commit

Permalink
fix http_get issue with latest ollama
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Aug 3, 2024
1 parent 57d985d commit 1f0d89b
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 102 deletions.
9 changes: 8 additions & 1 deletion lumni/src/apps/api/error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fmt;

use rusqlite::Error as SqliteError;
use tokio::task::JoinError;

// export the http client error via api::error
pub use crate::http::client::HttpClientError;
Expand Down Expand Up @@ -175,6 +176,12 @@ impl From<anyhow::Error> for ApplicationError {
}
}

impl From<JoinError> for ApplicationError {
fn from(error: JoinError) -> Self {
ApplicationError::Runtime(format!("Task join error: {}", error))
}
}

impl From<&str> for LumniError {
fn from(error: &str) -> Self {
LumniError::Any(error.to_owned())
Expand All @@ -185,4 +192,4 @@ impl From<std::string::String> for LumniError {
fn from(error: std::string::String) -> Self {
LumniError::Any(error.to_owned())
}
}
}
1 change: 1 addition & 0 deletions lumni/src/apps/builtin/llm/prompt/src/server/ollama/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ impl ServerTrait for Ollama {
payload,
)
.await;

if let Ok(response) = response {
// check if model is available by validating the response format
// at this moment we not yet need the response itself
Expand Down
59 changes: 34 additions & 25 deletions lumni/src/apps/builtin/llm/prompt/src/server/send.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,25 @@ pub async fn http_get_with_response(
"application/json".to_string(),
)]);
let (tx, mut rx) = mpsc::channel(1);

// Spawn a task to handle the HTTP request
let request_task = tokio::spawn(async move {
http_client.get(&url, Some(&header), None, Some(tx), None).await
});

let mut response_bytes = BytesMut::new();

// Receive chunks from the channel
while let Some(response) = rx.recv().await {
response_bytes.extend_from_slice(&response);
}

let result = http_client
.get(&url, Some(&header), None, Some(tx), None)
.await;
// Wait for the request task to complete
let result = request_task.await?;

// Handle the result
match result {
Ok(_) => {
let mut response_bytes = BytesMut::new();
while let Some(response) = rx.recv().await {
response_bytes.extend_from_slice(&response);
}
drop(rx); // drop the receiver to close the channel
Ok(response_bytes.freeze())
}
Ok(_) => Ok(response_bytes.freeze()),
Err(e) => Err(e.into()),
}
}
Expand All @@ -81,28 +86,32 @@ pub async fn http_post_with_response(
)]);
let (tx, mut rx) = mpsc::channel(1);
let payload_bytes = Bytes::from(payload.into_bytes());

let result = http_client
.post(

// Spawn a task to handle the HTTP request
let request_task = tokio::spawn(async move {
http_client.post(
&url,
Some(&headers),
None,
Some(&payload_bytes),
Some(tx),
None,
)
.await;
).await
});

let mut response_bytes = BytesMut::new();

// Receive chunks from the channel
while let Some(response) = rx.recv().await {
response_bytes.extend_from_slice(&response);
}

// Handle the result of the HTTP POST request
// Wait for the request task to complete
let result = request_task.await?;

// Handle the result
match result {
Ok(_) => {
let mut response_bytes = BytesMut::new();
while let Some(response) = rx.recv().await {
response_bytes.extend_from_slice(&response);
}
drop(rx); // drop the receiver to close the channel
Ok(response_bytes.freeze())
}
Ok(_) => Ok(response_bytes.freeze()),
Err(e) => Err(e.into()),
}
}
155 changes: 79 additions & 76 deletions lumni/src/http/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::time::Duration;

use anyhow::{anyhow, Error as AnyhowError, Result};
use bytes::{Bytes, BytesMut};
use futures::future::pending;
use tokio::time::timeout;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty, Full};
use hyper::header::{HeaderName, HeaderValue};
Expand Down Expand Up @@ -45,7 +45,7 @@ impl HttpClientResponse {
#[derive(Debug, Clone)]
pub enum HttpClientError {
ConnectionError(String),
TimeoutError,
Timeout,
HttpError(u16, String), // Status code, status text
Utf8Error(String),
RequestCancelled,
Expand All @@ -58,7 +58,7 @@ impl fmt::Display for HttpClientError {
HttpClientError::ConnectionError(e) => {
write!(f, "ConnectionError: {}", e)
}
HttpClientError::TimeoutError => write!(f, "TimeoutError"),
HttpClientError::Timeout => write!(f, "Timeout"),
HttpClientError::HttpError(code, message) => {
write!(f, "HTTPError: {} {}", code, message)
}
Expand Down Expand Up @@ -175,86 +175,89 @@ impl HttpClient {
let request = req_builder
.body(request_body)
.expect("Failed to build the request");
// Send the request and await the response, handling timeout as needed
let mut response =
self.client.request(request).await.map_err(|_| {
HttpClientError::ConnectionError(url.to_string())
})?;

if !response.status().is_success() {
let canonical_reason = response
.status()
.canonical_reason()
.unwrap_or("")
.to_string();
if let Some(error_handler) = &self.error_handler {
// Custom error handling
let http_client_response = HttpClientResponse {
body: None,
status_code: response.status().as_u16(),
headers: response.headers().clone(),
};
return Err(error_handler
.handle_error(http_client_response, canonical_reason));
}
return Err(HttpClientError::HttpError(
response.status().as_u16(),
canonical_reason,
));
}

let status_code = response.status().as_u16();
let headers = response.headers().clone();
let body;

if let Some(tx) = &tx {
body = None;
loop {
let frame_future = response.frame();
tokio::select! {
next = frame_future => {
match next {
Some(Ok(frame)) => {
if let Ok(chunk) = frame.into_data() {
if let Err(e) = tx.send(chunk).await {
return Err(HttpClientError::Other(e.to_string()));
}
match timeout(self.timeout, self.client.request(request)).await {
Ok(result) => {
let mut response = result.map_err(|_| {
HttpClientError::ConnectionError(url.to_string())
})?;

if !response.status().is_success() {
let canonical_reason = response
.status()
.canonical_reason()
.unwrap_or("")
.to_string();
if let Some(error_handler) = &self.error_handler {
// Custom error handling
let http_client_response = HttpClientResponse {
body: None,
status_code: response.status().as_u16(),
headers: response.headers().clone(),
};
return Err(error_handler
.handle_error(http_client_response, canonical_reason));
}
return Err(HttpClientError::HttpError(
response.status().as_u16(),
canonical_reason,
));
}

let status_code = response.status().as_u16();
let headers = response.headers().clone();
let body;

if let Some(tx) = &tx {
body = None;
loop {
let frame_future = response.frame();
tokio::select! {
next = frame_future => {
match next {
Some(Ok(frame)) => {
if let Ok(chunk) = frame.into_data() {
if let Err(e) = tx.send(chunk).await {
return Err(HttpClientError::Other(e.to_string()));
}
}
},
Some(Err(e)) => return Err(HttpClientError::Other(e.to_string())),
None => break, // End of the stream
}
},
// Check if the request has been cancelled
_ = async {
if let Some(rx) = &mut cancel_rx {
rx.await.ok();
} else {
std::future::pending::<()>().await;
}
} => {
drop(response); // Optionally drop the response to close the connection
return Err(HttpClientError::RequestCancelled);
},
Some(Err(e)) => return Err(HttpClientError::Other(e.to_string())),
None => break, // End of the stream
}
},
// Check if the request has been cancelled
_ = async {
if let Some(rx) = &mut cancel_rx {
rx.await.ok();
} else {
pending::<()>().await;
}
} else {
let mut body_bytes = BytesMut::new();
while let Some(next) = response.frame().await {
let frame = next.map_err(|e| anyhow!(e))?;
if let Some(chunk) = frame.data_ref() {
body_bytes.extend_from_slice(chunk);
}
} => {
drop(response); // Optionally drop the response to close the connection
return Err(HttpClientError::RequestCancelled);
},
}
}
} else {
let mut body_bytes = BytesMut::new();
while let Some(next) = response.frame().await {
// get headers for debugging
let frame = next.map_err(|e| anyhow!(e))?;
if let Some(chunk) = frame.data_ref() {
body_bytes.extend_from_slice(chunk);
}
body = Some(body_bytes.into());
}
}
body = Some(body_bytes.into());
}

Ok(HttpClientResponse {
body,
status_code,
headers,
})
Ok(HttpClientResponse {
body,
status_code,
headers,
})
},
Err(_) => Err(HttpClientError::Timeout),
}
}

pub async fn get(
Expand Down

0 comments on commit 1f0d89b

Please sign in to comment.