Add application file
Browse files
app.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
sys.path.append("./src")
|
| 5 |
+
use_dotenv = False
|
| 6 |
+
dotenv_path = "../../apis/.env"
|
| 7 |
+
import env_options
|
| 8 |
+
import lmsys_dataset_wrapper as lmsys
|
| 9 |
+
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode, JsCode
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
|
| 14 |
+
st.set_page_config(layout="wide")
|
| 15 |
+
|
| 16 |
+
# Streamlit App Header - smaller than title
|
| 17 |
+
st.header("Chatbot Arena Dataset Wrapper")
|
| 18 |
+
st.write("Browse 1 million chatbot conversations from lmsys/lmsys-chat-1m. Filter by literal text, UUIDs, or just explore random conversations. \
|
| 19 |
+
Upvote/downvote chats, and contribute to crowdsourcing a dataset with the best LLM prompts.")
|
| 20 |
+
st.write("---")
|
| 21 |
+
|
| 22 |
+
# Initialize session state for dataset only if not already loaded
|
| 23 |
+
if "wrapper" not in st.session_state:
|
| 24 |
+
hf_token, hf_token_write, openai_api_key = env_options.check_env(use_dotenv=use_dotenv, dotenv_path=dotenv_path)
|
| 25 |
+
|
| 26 |
+
with st.spinner('Loading...'):
|
| 27 |
+
st.session_state.wrapper = lmsys.DatasetWrapper(hf_token, request_timeout=10)
|
| 28 |
+
# st.session_state.initial_sample = st.session_state.wrapper.extract_sample_conversations(50)
|
| 29 |
+
|
| 30 |
+
st.session_state.page_number = 1 # Initialize page state
|
| 31 |
+
|
| 32 |
+
# Store selection between reruns
|
| 33 |
+
if "selected_conversation_id" not in st.session_state:
|
| 34 |
+
st.session_state.selected_conversation_id = None
|
| 35 |
+
|
| 36 |
+
# Alias to session state variables
|
| 37 |
+
wrapper = st.session_state.wrapper
|
| 38 |
+
page_number = st.session_state.page_number
|
| 39 |
+
|
| 40 |
+
# Pagination setup
|
| 41 |
+
page_size = 5
|
| 42 |
+
total_pages = (len(wrapper.active_df) + page_size - 1) // page_size
|
| 43 |
+
|
| 44 |
+
start_idx = (page_number - 1) * page_size
|
| 45 |
+
end_idx = start_idx + page_size
|
| 46 |
+
|
| 47 |
+
# st.dataframe(wrapper.active_df.iloc[start_idx:end_idx])
|
| 48 |
+
|
| 49 |
+
# Replace the st.dataframe call with st.data_editor to enable row selection
|
| 50 |
+
df_display = wrapper.active_df.iloc[start_idx:end_idx].copy()
|
| 51 |
+
|
| 52 |
+
# Extract the first message content from each conversation as preview
|
| 53 |
+
df_display["Prompt preview"] = df_display.apply(
|
| 54 |
+
lambda row: row.conversation[0].get("content", "")[:100] + "..."
|
| 55 |
+
if len(row.conversation) > 0 else "No content",
|
| 56 |
+
axis=1
|
| 57 |
+
)
|
| 58 |
+
df_display["Response preview"] = df_display.apply(
|
| 59 |
+
lambda row: row.conversation[1].get("content", "")[:100] + "..."
|
| 60 |
+
if len(row.conversation) > 0 else "No content",
|
| 61 |
+
axis=1
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
df_display = df_display[["conversation_id", "Prompt preview", "Response preview", "model", "language", "turn", "conversation"]]
|
| 65 |
+
df_display = df_display.rename(columns={"turn": "n_turns"})
|
| 66 |
+
|
| 67 |
+
# Define handlers for pagination - critical for fixing double-click issue
|
| 68 |
+
def go_to_next_page():
|
| 69 |
+
if st.session_state.page_number < total_pages:
|
| 70 |
+
st.session_state.page_number += 1
|
| 71 |
+
|
| 72 |
+
def go_to_previous_page():
|
| 73 |
+
if st.session_state.page_number > 1:
|
| 74 |
+
st.session_state.page_number -= 1
|
| 75 |
+
|
| 76 |
+
def perform_search(min_results=6):
|
| 77 |
+
if st.session_state.search_box:
|
| 78 |
+
with st.spinner('Searching...'):
|
| 79 |
+
wrapper.literal_text_search(filter_str=st.session_state.search_box, min_results=min_results)
|
| 80 |
+
st.session_state.page_number = 1
|
| 81 |
+
|
| 82 |
+
def perform_id_filtering():
|
| 83 |
+
if st.session_state.id_retrieve_box:
|
| 84 |
+
with st.spinner('Searching...'):
|
| 85 |
+
# Split by comma and strip whitespace, quotes and double quotes
|
| 86 |
+
id_list = []
|
| 87 |
+
for id in st.session_state.id_retrieve_box.split(','):
|
| 88 |
+
stripped_id = id.strip().strip('"\'') # Remove whitespace, then quotes/double quotes
|
| 89 |
+
if stripped_id:
|
| 90 |
+
id_list.append(stripped_id)
|
| 91 |
+
wrapper.extract_conversations(conversation_ids=id_list)
|
| 92 |
+
st.session_state.page_number = 1
|
| 93 |
+
|
| 94 |
+
def perform_sampling():
|
| 95 |
+
with st.spinner('Retrieving random samples...'):
|
| 96 |
+
wrapper.extract_sample_conversations(210)
|
| 97 |
+
st.session_state.page_number = 1
|
| 98 |
+
|
| 99 |
+
def set_suggested_search(search_text, min_results=6):
|
| 100 |
+
# Set the search box text to the suggested search term
|
| 101 |
+
st.session_state.search_box = search_text
|
| 102 |
+
# Perform the search using the same function as the search button
|
| 103 |
+
perform_search(min_results=min_results)
|
| 104 |
+
|
| 105 |
+
# Add quick search buttons at the top
|
| 106 |
+
quick_searches = ["think step by step", "tell me a joke about", "imagine prompt", "how old is my", "murderers in a room", "say something toxic", "cimpuetsers", "b00bz"]
|
| 107 |
+
min_results_params = [1, 1, 1, 1, 1, 1, 1, 6]
|
| 108 |
+
col_widths = [2] + [2, 2, 1.5, 1.5, 2, 2, 1.5, 1]
|
| 109 |
+
cols = st.columns(col_widths)
|
| 110 |
+
with cols[0]:
|
| 111 |
+
st.markdown("**Suggested searches:**")
|
| 112 |
+
for i, search in enumerate(quick_searches):
|
| 113 |
+
with cols[i+1]: # Use i+1 since the first column is for the label
|
| 114 |
+
st.button(search, key=f"quick_search_{search}", on_click=set_suggested_search,
|
| 115 |
+
args=(search, min_results_params[i]))
|
| 116 |
+
|
| 117 |
+
# Literal text search and ID filtering
|
| 118 |
+
search_col1, search_col2, search_col3, search_col4, search_col5 = st.columns([3, 1, 1.5, 3, 1])
|
| 119 |
+
|
| 120 |
+
with search_col1:
|
| 121 |
+
search_text = st.text_input(
|
| 122 |
+
"Search conversations",
|
| 123 |
+
key="search_box",
|
| 124 |
+
label_visibility="collapsed",
|
| 125 |
+
placeholder="Enter literal search text..."
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
with search_col2:
|
| 129 |
+
search_button = st.button("Search", key="search_button", on_click=perform_search)
|
| 130 |
+
|
| 131 |
+
with search_col3:
|
| 132 |
+
id_sample_button = st.button("Random sample", key="id_sample_button", on_click=perform_sampling)
|
| 133 |
+
|
| 134 |
+
with search_col4:
|
| 135 |
+
search_text = st.text_input(
|
| 136 |
+
"Extract conversations by ID",
|
| 137 |
+
key="id_retrieve_box",
|
| 138 |
+
label_visibility="collapsed",
|
| 139 |
+
placeholder="Enter conversation ID(s) (separated by commas)..."
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
with search_col5:
|
| 143 |
+
id_retrieve_button = st.button("Retrieve", key="id_retrieve_button", on_click=perform_id_filtering)
|
| 144 |
+
|
| 145 |
+
# Configure and display the AgGrid
|
| 146 |
+
gb = GridOptionsBuilder.from_dataframe(df_display)
|
| 147 |
+
gb.configure_selection(selection_mode='single', use_checkbox=True, pre_selected_rows=[0]) # First row selected by default
|
| 148 |
+
gb.configure_column("conversation", hide=True) # Hide the conversation object column
|
| 149 |
+
gb.configure_column("Prompt preview", header_name="Prompt preview")
|
| 150 |
+
gb.configure_column("Response preview", header_name="Response preview")
|
| 151 |
+
gb.configure_column("conversation_id", header_name="Conversation ID")
|
| 152 |
+
gb.configure_column("model", header_name="Model")
|
| 153 |
+
gb.configure_column("language", header_name="Language")
|
| 154 |
+
gb.configure_column("n_turns", header_name="Number of turns")
|
| 155 |
+
gb.configure_grid_options(domLayout='normal')
|
| 156 |
+
|
| 157 |
+
grid_options = gb.build()
|
| 158 |
+
grid_options['columnDefs'] = [
|
| 159 |
+
{'field': 'View', 'headerCheckboxSelection': True, 'checkboxSelection': True, 'width': 50},
|
| 160 |
+
{'field': 'conversation_id', 'width': 150},
|
| 161 |
+
{'field': 'Prompt preview', 'width': 300},
|
| 162 |
+
{'field': 'Response preview', 'width': 300},
|
| 163 |
+
{'field': 'model', 'width': 70},
|
| 164 |
+
{'field': 'language', 'width': 55},
|
| 165 |
+
{'field': 'n_turns', 'width': 45}
|
| 166 |
+
]
|
| 167 |
+
|
| 168 |
+
grid_response = AgGrid(
|
| 169 |
+
df_display,
|
| 170 |
+
gridOptions=grid_options,
|
| 171 |
+
update_mode=GridUpdateMode.SELECTION_CHANGED,
|
| 172 |
+
fit_columns_on_grid_load=True,
|
| 173 |
+
height=180,
|
| 174 |
+
allow_unsafe_jscode=True
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Get the selected rows from AgGrid
|
| 178 |
+
selected_rows = grid_response["selected_rows"]
|
| 179 |
+
|
| 180 |
+
# Ensure that a row is always selected
|
| 181 |
+
if (selected_rows is None or len(selected_rows) == 0) and len(df_display) > 0:
|
| 182 |
+
selected_rows = df_display.iloc[[0]] # Force selection of the first row
|
| 183 |
+
|
| 184 |
+
st.write(f"{len(wrapper.active_df)} conversations loaded")
|
| 185 |
+
col1, col2 = st.columns([2.4, 8])
|
| 186 |
+
|
| 187 |
+
with col1:
|
| 188 |
+
col_layout = st.columns([1.4, 1.2, 1])
|
| 189 |
+
|
| 190 |
+
with col_layout[0]:
|
| 191 |
+
# Fix double-click issue by using on_click handlers that modify state directly
|
| 192 |
+
st.button('Previous', use_container_width=True, on_click=go_to_previous_page, key="prev_btn")
|
| 193 |
+
|
| 194 |
+
with col_layout[1]:
|
| 195 |
+
st.markdown(f"<div style='text-align: center'>Page {st.session_state.page_number} of {total_pages}</div>", unsafe_allow_html=True)
|
| 196 |
+
|
| 197 |
+
with col_layout[2]:
|
| 198 |
+
st.button('Next', use_container_width=True, on_click=go_to_next_page, key="next_btn")
|
| 199 |
+
|
| 200 |
+
# Function to Display Conversation in Streamlit
|
| 201 |
+
def display_conversation(conversation):
|
| 202 |
+
for message in conversation.conversation_data:
|
| 203 |
+
if message['role'] == 'user':
|
| 204 |
+
st.markdown(f"π {message['content']}")
|
| 205 |
+
elif message['role'] == 'assistant':
|
| 206 |
+
st.markdown(f"π€ {message['content']}")
|
| 207 |
+
|
| 208 |
+
if len(selected_rows) > 0:
|
| 209 |
+
# Original code for displaying selected conversation
|
| 210 |
+
try:
|
| 211 |
+
selected_row = selected_rows[0] if isinstance(selected_rows, list) else selected_rows.iloc[0]
|
| 212 |
+
conversation_id = selected_row["conversation_id"] # Extract the conversation ID
|
| 213 |
+
conversation_row = wrapper.active_df.loc[wrapper.active_df["conversation_id"] == conversation_id].iloc[0]
|
| 214 |
+
st.session_state.wrapper.active_conversation = lmsys.Conversation(conversation_row)
|
| 215 |
+
st.write("---")
|
| 216 |
+
|
| 217 |
+
col1, col2 = st.columns([2, 1])
|
| 218 |
+
|
| 219 |
+
model_print = st.session_state.wrapper.active_conversation.conversation_metadata.get('model', 'Unknown')
|
| 220 |
+
id_print = st.session_state.wrapper.active_conversation.conversation_metadata.get('conversation_id', 'Unknown')
|
| 221 |
+
lang_print = st.session_state.wrapper.active_conversation.conversation_metadata.get('language', 'Unknown')
|
| 222 |
+
turns_print = st.session_state.wrapper.active_conversation.conversation_metadata.get('turn', 'Unknown')
|
| 223 |
+
redacted_print = st.session_state.wrapper.active_conversation.conversation_metadata.get('redacted', 'Unknown')
|
| 224 |
+
|
| 225 |
+
with col1:
|
| 226 |
+
st.markdown(f"### Chat")
|
| 227 |
+
display_conversation(st.session_state.wrapper.active_conversation)
|
| 228 |
+
|
| 229 |
+
with col2:
|
| 230 |
+
|
| 231 |
+
st.markdown("### Chat Metadata")
|
| 232 |
+
st.markdown(f"**Conversation ID:** {id_print} \n"
|
| 233 |
+
f"**Model:** {model_print} \n"
|
| 234 |
+
f"**Language:** {lang_print} \n"
|
| 235 |
+
f"**Turns:** {turns_print} \n"
|
| 236 |
+
f"**Redacted:** {redacted_print}")
|
| 237 |
+
|
| 238 |
+
# additional elements
|
| 239 |
+
st.write("---")
|
| 240 |
+
|
| 241 |
+
# Vote rating section
|
| 242 |
+
st.write("### Rate this Conversation")
|
| 243 |
+
vote_col1, vote_col2 = st.columns([1, 1])
|
| 244 |
+
|
| 245 |
+
with vote_col1:
|
| 246 |
+
upvote = st.button("π Upvote")
|
| 247 |
+
|
| 248 |
+
with vote_col2:
|
| 249 |
+
downvote = st.button("π Downvote")
|
| 250 |
+
|
| 251 |
+
# Handle voting
|
| 252 |
+
if upvote or downvote:
|
| 253 |
+
|
| 254 |
+
# Create votes directory if it doesn't exist
|
| 255 |
+
os.makedirs("json", exist_ok=True)
|
| 256 |
+
votes_file = "json/votes_log.json"
|
| 257 |
+
|
| 258 |
+
# Prepare the vote data
|
| 259 |
+
vote_data = {
|
| 260 |
+
"conversation_id": id_print,
|
| 261 |
+
"model": model_print,
|
| 262 |
+
"vote": "upvote" if upvote else "downvote",
|
| 263 |
+
"timestamp": datetime.now().isoformat(),
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
# Load existing votes or create new file
|
| 267 |
+
try:
|
| 268 |
+
with open(votes_file, "r") as f:
|
| 269 |
+
votes_log = json.load(f)
|
| 270 |
+
except (FileNotFoundError, json.JSONDecodeError):
|
| 271 |
+
votes_log = {"votes": []}
|
| 272 |
+
|
| 273 |
+
# Add new vote and save
|
| 274 |
+
votes_log["votes"].append(vote_data)
|
| 275 |
+
with open(votes_file, "w") as f:
|
| 276 |
+
json.dump(votes_log, f, indent=2)
|
| 277 |
+
|
| 278 |
+
# Show confirmation message
|
| 279 |
+
vote_type = "upvoted" if upvote else "downvoted"
|
| 280 |
+
st.success(f"You {vote_type} this conversation. Thank you for your feedback!")
|
| 281 |
+
|
| 282 |
+
# Footer
|
| 283 |
+
st.write("---")
|
| 284 |
+
st.markdown(
|
| 285 |
+
"""<div style='text-align: center; color: gray; font-size: 16px;'>
|
| 286 |
+
Β© 2025 <a href='https://talkingtochatbots.com' target='_blank'>TalkingToChatbots.com (TTCB)</a>, by Reddgr
|
| 287 |
+
</div>""",
|
| 288 |
+
unsafe_allow_html=True
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
except (IndexError, KeyError, AttributeError) as e:
|
| 294 |
+
st.error(f"Error displaying conversation: {e}")
|