-
Notifications
You must be signed in to change notification settings - Fork 0
/
streamlit.py
113 lines (95 loc) · 4.13 KB
/
streamlit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import json
import os
import tempfile
import time

import requests
import streamlit as st
import streamlit.components.v1 as components
# Base URL of the RAG backend API. Can be overridden with the API_BASE_URL
# environment variable; defaults to a backend on localhost:8000.
API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:8000")

# Streamlit app
st.title("RAG Chatbot")

# Session state initialization. Streamlit re-executes this whole script on
# every interaction, so each key is only set the first time; later runs keep
# the stored values.
#   chats           -> chat_id: list of {"role": ..., "content": ...} messages
#   current_chat_id -> chat shown in the main area, or None if none selected
#   uploaded_docs   -> uploaded file name: asset_id returned by the API
#   chat_started    -> set True once a chat is opened or created
if 'chats' not in st.session_state:
    st.session_state.chats = {}
if 'current_chat_id' not in st.session_state:
    st.session_state.current_chat_id = None
if 'uploaded_docs' not in st.session_state:
    st.session_state.uploaded_docs = {}
if 'chat_started' not in st.session_state:
    st.session_state.chat_started = False
# Sidebar: list of chats, document upload, and creation of new chats.
with st.sidebar:
    st.header("Chats")
    # One button per chat created in this session; clicking one switches the
    # main area to that chat.
    for chat_id in st.session_state.chats:
        if st.button(f"Chat {chat_id}"):
            st.session_state.current_chat_id = chat_id
            st.session_state.chat_started = True

    # File uploader
    uploaded_file = st.file_uploader("Upload a document", type=['pdf', 'txt', 'doc'])
    # Streamlit reruns the script on every interaction and the uploader keeps
    # returning the same file, so only send a given file to the API once;
    # otherwise every chat message would re-process the document.
    if uploaded_file is not None and uploaded_file.name not in st.session_state.uploaded_docs:
        try:
            # Save the upload inside a private temporary directory rather than
            # the working directory, so an upload can never overwrite (and then
            # delete) an existing file of the same name. The base name is kept
            # so the API still sees the original file extension. The directory
            # is removed on exit from this `with`, even if the request fails.
            # NOTE(review): assumes the API reads the file on the same machine
            # and finishes reading before responding (as the original did).
            with tempfile.TemporaryDirectory() as tmp_dir:
                tmp_path = os.path.join(tmp_dir, os.path.basename(uploaded_file.name))
                with open(tmp_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())
                response = requests.post(
                    f"{API_BASE_URL}/api/documents/process",
                    json={"file_path": tmp_path},
                )
        except requests.exceptions.RequestException as exc:
            # Backend unreachable / connection dropped: report instead of crashing.
            st.error(f"Error processing document: {exc}")
        else:
            if response.status_code == 200:
                asset_id = response.json()['asset_id']
                st.session_state.uploaded_docs[uploaded_file.name] = asset_id
                st.success(f"Document processed. Asset ID: {asset_id}")
            else:
                st.error("Error processing document")

    # Document selector: start a new chat against one of the processed documents.
    if st.session_state.uploaded_docs:
        selected_doc = st.selectbox("Select a document", list(st.session_state.uploaded_docs.keys()))
        if st.button("Start New Chat"):
            asset_id = st.session_state.uploaded_docs[selected_doc]
            try:
                response = requests.post(f"{API_BASE_URL}/api/chat/start", json={'asset_id': asset_id})
            except requests.exceptions.RequestException as exc:
                st.error(f"Error starting new chat: {exc}")
            else:
                if response.status_code == 200:
                    chat_id = response.json()['chat_id']
                    st.session_state.chats[chat_id] = []
                    st.session_state.current_chat_id = chat_id
                    st.session_state.chat_started = True
                    st.success(f"New chat started with ID: {chat_id}")
                else:
                    st.error("Error starting new chat")
# Main chat area: history of the selected chat plus the message input.
# Compare against None explicitly (the initial value) so a falsy chat id such
# as 0 would still be shown.
if st.session_state.current_chat_id is not None:
    current_chat_id = st.session_state.current_chat_id
    # List stored in session state; appending to it updates the chat history.
    history = st.session_state.chats[current_chat_id]
    st.header(f"Chat {current_chat_id}")

    # Display chat history
    for message in history:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # Chat input
    user_input = st.chat_input("Type your message here")
    if user_input:
        # Add user message to chat history
        history.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.write(user_input)

        # Send message to API and render the streamed reply incrementally.
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""
            try:
                # Context manager closes the streamed connection when done.
                with requests.post(
                    f"{API_BASE_URL}/api/chat/message",
                    json={'chat_id': current_chat_id, 'message': user_input},
                    stream=True,
                ) as response:
                    # Don't display an HTTP error body as the assistant's answer.
                    response.raise_for_status()
                    # With no charset in the response headers, iter_content(
                    # decode_unicode=True) yields bytes, which cannot be added
                    # to a str; fall back to UTF-8 in that case.
                    if response.encoding is None:
                        response.encoding = "utf-8"
                    for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
                        if chunk:
                            full_response += chunk
                            # Trailing block character acts as a typing cursor.
                            message_placeholder.markdown(full_response + "▌")
            except requests.exceptions.RequestException as exc:
                # Show whatever arrived before the failure, but don't store a
                # partial/failed answer in the history.
                message_placeholder.markdown(full_response)
                st.error(f"Error getting response: {exc}")
            else:
                message_placeholder.markdown(full_response)
                # Add assistant response to chat history
                history.append({"role": "assistant", "content": full_response})
# Warning about data loss on reload: chats live only in st.session_state.
st.warning("Warning: All chat data will be lost if you reload the page.")

# Ask the browser to confirm before the page is reloaded or closed.
# st.markdown does not execute <script> tags (even with unsafe_allow_html=True),
# so the handler is injected through a components iframe instead and installed
# on the parent (app) window.
# NOTE(review): relies on the components iframe being same-origin with the app
# so that window.parent is accessible; verify on the deployed Streamlit version.
# Modern browsers show their own generic prompt and ignore the custom text.
components.html(
    """
<script>
window.parent.onbeforeunload = function (event) {
    event.preventDefault();
    event.returnValue = "Are you sure you want to reload? All chat data will be lost.";
    return event.returnValue;
};
</script>
""",
    height=0,
)