Commit

update
XiangpengHao committed Dec 27, 2024
1 parent 09d4720 commit 6d58bb6
Showing 4 changed files with 58 additions and 42 deletions.
src/file_reader.rs (10 changes: 9 additions & 1 deletion)
@@ -6,7 +6,7 @@ use object_store::{ObjectStore, PutPayload};
 use opendal::{services::Http, services::S3, Operator};
 use web_sys::{js_sys, Url};
 
-use crate::{ParquetTable, INMEMORY_STORE};
+use crate::{ParquetTable, INMEMORY_STORE, SESSION_CTX};
 
 const S3_ENDPOINT_KEY: &str = "s3_endpoint";
 const S3_ACCESS_KEY_ID_KEY: &str = "s3_access_key_id";
@@ -36,10 +36,18 @@ async fn update_file(
     parquet_table: ParquetTable,
     parquet_table_setter: WriteSignal<Option<ParquetTable>>,
 ) {
+    let ctx = SESSION_CTX.as_ref();
     let object_store = &*INMEMORY_STORE;
     let path = Path::parse(&parquet_table.table_name).unwrap();
     let payload = PutPayload::from_bytes(parquet_table.bytes.clone());
     object_store.put(&path, payload).await.unwrap();
+    ctx.register_parquet(
+        &parquet_table.table_name,
+        &format!("mem:///{}", parquet_table.table_name),
+        Default::default(),
+    )
+    .await
+    .unwrap();
     parquet_table_setter.set(Some(parquet_table));
 }

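With this change, update_file not only writes the uploaded bytes into the in-memory object store but also registers them as a named table on the shared session context (the SESSION_CTX static introduced in src/main.rs below), so the file is queryable by name as soon as the upload completes. A minimal sketch of what a caller can do afterwards — the table name "trips" is hypothetical and error handling is elided:

use crate::SESSION_CTX;
use datafusion::error::Result;

// Query a table that update_file registered under mem:///trips. The mem://
// URL resolves against the InMemory store registered on SESSION_CTX.
async fn preview_uploaded() -> Result<()> {
    let ctx = SESSION_CTX.as_ref();
    let df = ctx.sql("SELECT * FROM trips LIMIT 5").await?;
    let batches = df.collect().await?;
    println!("fetched {} batch(es)", batches.len());
    Ok(())
}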
src/main.rs (51 changes: 33 additions & 18 deletions)
@@ -1,6 +1,11 @@
 mod schema;
-use datafusion::physical_plan::ExecutionPlan;
-use file_reader::{get_stored_value, FileReader};
+use datafusion::{
+    datasource::MemTable,
+    execution::object_store::ObjectStoreUrl,
+    physical_plan::ExecutionPlan,
+    prelude::{SessionConfig, SessionContext},
+};
+use file_reader::FileReader;
 use leptos_router::{
     components::Router,
     hooks::{query_signal, use_query_map},
@@ -32,11 +37,21 @@ mod query_input;
 use query_input::{execute_query_inner, QueryInput};
 
 mod settings;
-use settings::{Settings, ANTHROPIC_API_KEY};
+use settings::Settings;
 
 pub(crate) static INMEMORY_STORE: LazyLock<Arc<InMemory>> =
     LazyLock::new(|| Arc::new(InMemory::new()));
 
+pub(crate) static SESSION_CTX: LazyLock<Arc<SessionContext>> = LazyLock::new(|| {
+    let mut config = SessionConfig::new();
+    config.options_mut().sql_parser.dialect = "PostgreSQL".to_string();
+    let ctx = Arc::new(SessionContext::new_with_config(config));
+    let object_store_url = ObjectStoreUrl::parse("mem://").unwrap();
+    let object_store = INMEMORY_STORE.clone();
+    ctx.register_object_store(object_store_url.as_ref(), object_store);
+    ctx
+});
+
 #[derive(Debug, Clone, PartialEq)]
 pub(crate) struct ParquetReader {
     parquet_table: ParquetTable,
@@ -190,10 +205,9 @@ impl std::fmt::Display for ParquetInfo {
 }
 
 async fn execute_query_async(
-    query: String,
-    table_name: String,
+    query: &str,
 ) -> Result<(Vec<arrow::array::RecordBatch>, Arc<dyn ExecutionPlan>), String> {
-    let (results, physical_plan) = execute_query_inner(&table_name, &query)
+    let (results, physical_plan) = execute_query_inner(query)
         .await
         .map_err(|e| format!("Failed to execute query: {}", e))?;
@@ -269,15 +283,7 @@ fn App() -> impl IntoView {
         let Some(parquet_reader) = parquet_reader.get() else {
             return;
         };
-        let api_key = get_stored_value(ANTHROPIC_API_KEY, "");
-        let sql = match query_input::user_input_to_sql(
-            &user_input,
-            &parquet_reader.info().schema,
-            parquet_reader.table_name(),
-            &api_key,
-        )
-        .await
-        {
+        let sql = match query_input::user_input_to_sql(&user_input, &parquet_reader).await {
             Ok(response) => response,
             Err(e) => {
                 set_error_message.set(Some(e));
@@ -301,13 +307,12 @@ fn App() -> impl IntoView {
             return;
         }
 
-        if let Some(parquet_table) = bytes_opt {
+        if let Some(_parquet_table) = bytes_opt {
             let query = query.clone();
             let export_to = export_to.clone();
-            let table_name = parquet_table.table_name;
 
             leptos::task::spawn_local(async move {
-                match execute_query_async(query.clone(), table_name).await {
+                match execute_query_async(&query).await {
                     Ok((results, physical_plan)) => {
                         if let Some(export_to) = export_to {
                             if export_to == "csv" {
@@ -316,8 +321,18 @@ fn App() -> impl IntoView {
                                 export_to_parquet_inner(&results);
                             }
                         }
+
                         set_query_results.update(|r| {
                             let id = r.len();
+                            if let Some(first_batch) = results.first() {
+                                let schema = first_batch.schema();
+                                let mem_table =
+                                    MemTable::try_new(schema, vec![results.clone()]).unwrap();
+                                SESSION_CTX
+                                    .as_ref()
+                                    .register_table(format!("view_{}", id), Arc::new(mem_table))
+                                    .unwrap();
+                            }
                             r.push(QueryResult::new(
                                 id,
                                 query,
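The new block in the App closure registers every successful result set back into SESSION_CTX as a MemTable named view_{id}, so follow-up queries can build on earlier results. A self-contained sketch of the same pattern, written as a plain native program for illustration, with a hypothetical context and a hand-built batch standing in for real query output:

use std::sync::Arc;
use datafusion::arrow::array::{Int32Array, RecordBatch};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::{datasource::MemTable, error::Result, prelude::SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // A stand-in for a finished query result.
    let schema = Arc::new(Schema::new(vec![Field::new("n", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;
    // Register the batches under the view name, as the closure above does.
    let mem_table = MemTable::try_new(schema, vec![vec![batch]])?;
    ctx.register_table("view_0", Arc::new(mem_table))?;
    // Later queries can now reference the earlier result by name.
    ctx.sql("SELECT n * 2 AS doubled FROM view_0").await?.show().await?;
    Ok(())
}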
src/query_input.rs (33 changes: 10 additions & 23 deletions)
@@ -4,9 +4,7 @@ use arrow_array::RecordBatch;
 use arrow_schema::SchemaRef;
 use datafusion::{
     error::DataFusionError,
-    execution::object_store::ObjectStoreUrl,
     physical_plan::{collect, ExecutionPlan},
-    prelude::{ParquetReadOptions, SessionConfig},
 };
 use leptos::{logging, prelude::*};
 use leptos::{
@@ -17,27 +15,15 @@ use serde_json::json;
 use wasm_bindgen_futures::JsFuture;
 use web_sys::{js_sys, Headers, Request, RequestInit, RequestMode, Response};
 
-use crate::INMEMORY_STORE;
+use crate::{
+    settings::{get_stored_value, ANTHROPIC_API_KEY},
+    ParquetReader, SESSION_CTX,
+};
 
 pub(crate) async fn execute_query_inner(
-    table_name: &str,
     query: &str,
 ) -> Result<(Vec<RecordBatch>, Arc<dyn ExecutionPlan>), DataFusionError> {
-    let mut config = SessionConfig::new();
-    config.options_mut().sql_parser.dialect = "PostgreSQL".to_string();
-
-    let ctx = datafusion::prelude::SessionContext::new_with_config(config);
-
-    let object_store_url = ObjectStoreUrl::parse("mem://").unwrap();
-    let object_store = INMEMORY_STORE.clone();
-    ctx.register_object_store(object_store_url.as_ref(), object_store);
-    ctx.register_parquet(
-        table_name,
-        &format!("mem:///{}", table_name),
-        ParquetReadOptions::default(),
-    )
-    .await?;
-
+    let ctx = SESSION_CTX.as_ref();
     let plan = ctx.sql(query).await?;
 
     let (state, plan) = plan.into_parts();
@@ -137,9 +123,7 @@ pub fn QueryInput(
 
 pub(crate) async fn user_input_to_sql(
     input: &str,
-    schema: &SchemaRef,
-    file_name: &str,
-    api_key: &str,
+    parquet_reader: &ParquetReader,
 ) -> Result<String, String> {
     // if the input seems to be a SQL query, return it as is
     if input.starts_with("select") || input.starts_with("SELECT") {
@@ -148,6 +132,9 @@ pub(crate) async fn user_input_to_sql(
 
     // otherwise, treat it as some natural language
 
+    let schema = &parquet_reader.info().schema;
+    let file_name = parquet_reader.table_name();
+    let api_key = get_stored_value(ANTHROPIC_API_KEY, "");
     let schema_str = schema_to_brief_str(schema);
     logging::log!("Processing user input: {}", input);

@@ -157,7 +144,7 @@ pub(crate) async fn user_input_to_sql(
     );
     logging::log!("{}", prompt);
 
-    let sql = match generate_sql_via_claude(&prompt, api_key).await {
+    let sql = match generate_sql_via_claude(&prompt, &api_key).await {
         Ok(response) => response,
         Err(e) => {
             logging::log!("{}", e);
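execute_query_inner no longer spins up a fresh SessionContext (dialect, object store, Parquet registration) on every call; it reuses the process-wide SESSION_CTX, and user_input_to_sql now derives the schema, table name, and API key from the ParquetReader itself rather than taking them as separate parameters. A minimal usage sketch against the new signature — the helper name and query are illustrative:

// Any caller now passes only SQL; table names resolve against whatever is
// registered on SESSION_CTX (uploaded files, view_{id} result tables).
async fn run_and_count_rows(query: &str) -> Result<usize, String> {
    let (batches, _physical_plan) = execute_query_inner(query)
        .await
        .map_err(|e| format!("Failed to execute query: {}", e))?;
    Ok(batches.iter().map(|b| b.num_rows()).sum())
}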
src/query_results.rs (6 changes: 6 additions & 0 deletions)
@@ -144,6 +144,12 @@ pub fn QueryResultView(result: QueryResult) -> impl IntoView {
         <div class="relative">
             <div class="absolute top-0 right-0 z-10">
                 <div class="flex items-center gap-1 rounded-md">
+                    <div class="text-sm text-gray-500 font-mono relative group">
+                        <span class="absolute bottom-full left-1/2 transform -translate-x-1/2 px-2 py-1 bg-gray-800 text-white text-xs rounded opacity-0 group-hover:opacity-100 whitespace-nowrap pointer-events-none">
+                            {format!("SELECT * FROM view_{}", result.id())}
+                        </span>
+                        {format!("view_{}", result.id())}
+                    </div>
                     <button
                         class="p-2 text-gray-500 hover:text-gray-700 relative group"
                         aria-label="Export to CSV"
