Skip to content

Commit

Permalink
- feature: add max execution time (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
agallardol authored Jul 4, 2024
1 parent a7dd4c2 commit 771ba3a
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 27 deletions.
72 changes: 63 additions & 9 deletions libs/shinkai-tools-runner/src/lib.test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ async fn shinkai_tool_echo() {
let _ = tool
.load_from_code(&tool_definition.code.clone().unwrap(), "")
.await;
let run_result = tool.run("{ \"message\": \"valparaíso\" }").await.unwrap();
let run_result = tool
.run("{ \"message\": \"valparaíso\" }", None)
.await
.unwrap();
assert_eq!(run_result.data["message"], "echoing: valparaíso");
}

Expand All @@ -26,7 +29,7 @@ async fn shinkai_tool_weather_by_city() {
You can also call config method
let _ = tool.config("{ \"apiKey\": \"63d35ff6068c3103ccd1227526935675\" }").await;
*/
let run_result = tool.run("{ \"city\": \"valparaíso\" }").await;
let run_result = tool.run("{ \"city\": \"valparaíso\" }", None).await;
assert!(run_result.is_ok());
}

Expand Down Expand Up @@ -59,7 +62,7 @@ async fn shinkai_tool_inline() {
"#;
let mut tool = Tool::new();
let _ = tool.load_from_code(js_code, "").await;
let run_result = tool.run("{ \"name\": \"world\" }").await.unwrap();
let run_result = tool.run("{ \"name\": \"world\" }", None).await.unwrap();
assert_eq!(run_result.data, "Hello, world!");
}

Expand All @@ -71,7 +74,10 @@ async fn shinkai_tool_web3_eth_balance() {
.load_from_code(&tool_definition.code.clone().unwrap(), "")
.await;
let run_result = tool
.run("{ \"address\": \"0x388c818ca8b9251b393131c08a736a67ccb19297\" }")
.run(
"{ \"address\": \"0x388c818ca8b9251b393131c08a736a67ccb19297\" }",
None,
)
.await;
println!("{}", run_result.as_ref().unwrap().data);
assert!(run_result.is_ok());
Expand All @@ -94,6 +100,7 @@ async fn shinkai_tool_web3_eth_uniswap() {
"toAddress": "0xd8da6bf26964af9d7eed9e03e53415d37aa96045",
"slippagePercent": 0.5
}"#,
None,
)
.await;
assert!(run_result.is_ok());
Expand All @@ -111,6 +118,7 @@ async fn shinkai_tool_download_page() {
r#"{
"url": "https://shinkai.com"
}"#,
None,
)
.await;
assert!(run_result.is_ok());
Expand Down Expand Up @@ -149,7 +157,7 @@ async fn set_timeout() {
let mut tool = Tool::new();
let _ = tool.load_from_code(js_code, "").await;
let start_time = std::time::Instant::now();
let _ = tool.run("").await.unwrap();
let _ = tool.run("", None).await.unwrap();
let elapsed_time = start_time.elapsed();
assert!(elapsed_time.as_millis() > 3000);
}
Expand Down Expand Up @@ -189,7 +197,7 @@ async fn set_timeout_no_delay_param() {
let mut tool = Tool::new();
let _ = tool.load_from_code(js_code, "").await;
let start_time = std::time::Instant::now();
let run_result = tool.run("").await.unwrap();
let run_result = tool.run("", None).await.unwrap();
let elapsed_time = start_time.elapsed();
assert!(elapsed_time.as_millis() <= 50);
assert_eq!(run_result.data, 1);
Expand Down Expand Up @@ -235,7 +243,7 @@ async fn clear_timeout() {
let mut tool = Tool::new();
let _ = tool.load_from_code(js_code, "").await;
let start_time = std::time::Instant::now();
let _ = tool.run("").await.unwrap();
let _ = tool.run("", None).await.unwrap();
let elapsed_time = start_time.elapsed();
assert!(elapsed_time.as_millis() >= 1500 && elapsed_time.as_millis() <= 1550);
}
Expand Down Expand Up @@ -279,7 +287,7 @@ async fn set_interval() {
let mut tool = Tool::new();
let _ = tool.load_from_code(js_code, "").await;
let start_time = std::time::Instant::now();
let run_result = tool.run("").await.unwrap();
let run_result = tool.run("", None).await.unwrap();
let elapsed_time = start_time.elapsed();
assert_eq!(run_result.data, 5);
assert!(elapsed_time.as_millis() <= 1100);
Expand Down Expand Up @@ -330,8 +338,54 @@ async fn clear_interval() {
let mut tool = Tool::new();
let _ = tool.load_from_code(js_code, "").await;
let start_time = std::time::Instant::now();
let run_result = tool.run("").await.unwrap();
let run_result = tool.run("", None).await.unwrap();
let elapsed_time = start_time.elapsed();
assert!(run_result.data.as_number().unwrap().as_u64().unwrap() <= 11);
assert!(elapsed_time.as_millis() <= 2050);
}

#[tokio::test]
async fn max_execution_time() {
let js_code = r#"
class BaseTool {
constructor(config) {
this.config = config;
}
setConfig(value) {
this.config = value;
return this.config;
}
getConfig() {
return this.config;
}
}
class Tool extends BaseTool {
constructor(config) {
super(config);
}
async run() {
let startedAt = Date.now();
while (true) {
const elapse = Date.now() - startedAt;
console.log(`while true every ${500}ms, elapse ${elapse} ms`);
await new Promise(async (resolve) => {
setTimeout(() => {
resolve();
}, 500);
});
}
return { data: true };
}
}
globalThis.tool = { Tool };
"#;
let mut tool = Tool::new();
let _ = tool.load_from_code(js_code, "").await;
let start_time = std::time::Instant::now();
let run_result = tool.run("", Some(10000)).await;
let elapsed_time = start_time.elapsed();
assert!(run_result.is_err());
assert!(elapsed_time.as_millis() <= 10050);
assert!(run_result.err().unwrap().message().contains("time reached"));
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ fn set_timeout<'js>(
}
_ = tokio::time::sleep(Duration::from_millis(delay)) => {
println!("calling setTimeout callback after {} ms", delay);
callback.call::<_, ()>(()).unwrap();
if let Err(e) = callback.call::<_, ()>(()) {
println!("error calling callback: {}", e);
}
}
}
});
Expand Down
49 changes: 37 additions & 12 deletions libs/shinkai-tools-runner/src/tools/quickjs_runtime/script.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::time::Duration;

use super::context_globals::init_globals;
use super::execution_error::ExecutionError;

Expand All @@ -7,46 +9,60 @@ use rquickjs::{async_with, AsyncContext, AsyncRuntime, Object, Value};
pub struct Script {
runtime: Option<AsyncRuntime>,
context: Option<AsyncContext>,
terminate: bool,
}

impl Script {
pub fn new() -> Self {
Script {
runtime: None,
context: None,
terminate: false,
}
}

pub async fn init(&mut self) {
let (runtime, context) = Self::build_runtime().await;
let (runtime, context) = self.build_runtime().await;
self.runtime = Some(runtime);
self.context = Some(context);
let terminate = self.terminate;
self.runtime
.as_ref()
.unwrap()
.set_interrupt_handler(Some(Box::new(move || terminate)))
.await;
}

async fn build_runtime() -> (AsyncRuntime, AsyncContext) {
async fn build_runtime(&self) -> (AsyncRuntime, AsyncContext) {
let runtime: AsyncRuntime = AsyncRuntime::new().unwrap();
runtime.set_memory_limit(1024 * 1024 * 1024).await; // 1 GB
runtime.set_max_stack_size(1024 * 1024).await; // 1 MB
let context = AsyncContext::full(&runtime).await;
context.as_ref().unwrap().with(|ctx| {
let _ = init_globals(&ctx);
}).await;
context
.as_ref()
.unwrap()
.with(|ctx| {
let _ = init_globals(&ctx);
})
.await;
(runtime, context.unwrap())
}

pub async fn call_promise(
&mut self,
fn_name: &str,
json_args: &str,
max_execution_time_ms: u64,
) -> Result<serde_json::Value, ExecutionError> {
println!("calling fn:{}", fn_name);
let js_code: String = format!("await {fn_name}({json_args})");
self.execute_promise(js_code).await
self.execute_promise(js_code, max_execution_time_ms).await
}

pub async fn execute_promise(
&mut self,
js_code: String,
max_execution_time_ms: u64,
) -> Result<serde_json::Value, ExecutionError> {
let id = nanoid!();
let id_clone = id.clone();
Expand All @@ -56,7 +72,8 @@ impl Script {
&js_code[..20.min(js_code.len())]
);
let js_code_clone = js_code.clone(); // Clone js_code here
let result = async_with!(self.context.clone().unwrap() => |ctx|{
let context_clone = self.context.clone().unwrap();
let result = async_with!(context_clone => |ctx|{
let eval_promise_result = ctx.eval_promise::<_>(js_code);

let result = eval_promise_result.unwrap().into_future::<Object>().await.map_err(|e| {
Expand Down Expand Up @@ -112,11 +129,19 @@ impl Script {
return Err(error)
}
}
})
.await;
if let Some(error) = result.as_ref().err() {
println!("id:{} run error result: {}", id_clone, error.message());
});
tokio::select! {
result = result => {
if let Some(error) = result.as_ref().err() {
println!("id:{} run error result: {}", id_clone, error.message());
}
result
}
_ = tokio::time::sleep(Duration::from_millis(max_execution_time_ms)) => {
println!("sending termination signal");
self.terminate = true;
Err(ExecutionError::new("max execution time reached".to_string(), None))
}
}
result
}
}
32 changes: 27 additions & 5 deletions libs/shinkai-tools-runner/src/tools/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ impl Default for Tool {
}

impl Tool {
pub const MAX_EXECUTION_TIME_MS_INTERNAL_OPS: u64 = 1000;
pub const MAX_EXECUTION_TIME_MS_INTERNAL_RUN_OP: u64 = 60 * 1000;

pub fn new() -> Self {
Tool {
script: Script::new(),
Expand All @@ -27,7 +30,9 @@ impl Tool {
configurations: &str,
) -> Result<(), ExecutionError> {
self.script.init().await;
self.script.execute_promise(code.to_string()).await?;
self.script
.execute_promise(code.to_string(), Self::MAX_EXECUTION_TIME_MS_INTERNAL_OPS)
.await?;
self.script
.execute_promise(
format!(
Expand All @@ -36,6 +41,7 @@ impl Tool {
"#
)
.to_string(),
1000,
)
.await?;
Ok(())
Expand All @@ -44,28 +50,44 @@ impl Tool {
pub async fn get_definition(&mut self) -> Result<ToolDefinition, ExecutionError> {
let run_result = self
.script
.call_promise("toolInstance.getDefinition", "")
.call_promise(
"toolInstance.getDefinition",
"",
Self::MAX_EXECUTION_TIME_MS_INTERNAL_OPS,
)
.await?;
Ok(serde_json::from_value::<ToolDefinition>(run_result).unwrap())
}

pub async fn config(&mut self, configurations: &str) -> Result<(), ExecutionError> {
let result = self
.script
.call_promise("toolInstance.setConfig", configurations)
.call_promise(
"toolInstance.setConfig",
configurations,
Self::MAX_EXECUTION_TIME_MS_INTERNAL_OPS,
)
.await;
match result {
Ok(_) => Ok(()),
Err(error) => Err(error),
}
}

pub async fn run(&mut self, parameters: &str) -> Result<RunResult, ExecutionError> {
pub async fn run(
&mut self,
parameters: &str,
max_execution_time_ms: Option<u64>,
) -> Result<RunResult, ExecutionError> {
// This String generic type is hardcoded atm
// We should decide what's is going to be the output for run method
let run_result = self
.script
.call_promise("toolInstance.run", parameters)
.call_promise(
"toolInstance.run",
parameters,
max_execution_time_ms.unwrap_or(Self::MAX_EXECUTION_TIME_MS_INTERNAL_RUN_OP),
)
.await?;
Ok(serde_json::from_value::<RunResult>(run_result).unwrap())
}
Expand Down

0 comments on commit 771ba3a

Please sign in to comment.