Spaces:
Sleeping
Sleeping
| import json | |
| import logging | |
| import os | |
| import urllib.parse | |
| from typing import Any | |
| import gradio as gr | |
| import requests | |
| from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
# Module-level logger; configuration is left to the hosting runtime.
logger = logging.getLogger(__name__)

# NOTE(review): `example` is never referenced below — presumably kept to warm up
# the search component at import time; confirm before removing.
example = HuggingfaceHubSearch().example_value()

# Markdown banner rendered at the top of the page.
HEADER_CONTENT = (
    "# 🤗 Dataset DuckDB Query Chatbot\n\n"
    "This is a basic text to SQL tool that allows you to query datasets on Hugging Face Hub. "
    "It's a fork of "
    "[davidberenstein1957/text-to-sql-hub-datasets](https://huggingface.co/spaces/davidberenstein1957/text-to-sql-hub-datasets) "
    "that adds chat capability and table name generation."
)

# Markdown body of the collapsible "About/Help" accordion.
ABOUT_CONTENT = """
This space uses [LLama 3.1 70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct).
via [together.ai](https://together.ai)
Also, it uses the
[dataset-server API](https://redocly.github.io/redoc/?url=https://datasets-server.huggingface.co/openapi.json#operation/isValidDataset).
Query history is saved and given to the chat model so you can chat to refine your query as you go.
When the DuckDB modal is presented, you may need to click on the name of the
config/split at the base of the modal to get the table loaded for DuckDB's use.
Search for and select a dataset to begin.
"""

# System prompt for the LLM; filled with the DuckDB table name and the
# dataset's feature schema via str.format in get_system_prompt().
SYSTEM_PROMPT_TEMPLATE = (
    "You are a SQL query expert assistant that returns a DuckDB SQL queries "
    "based on the user's natural language query and dataset features. "
    "You might need to use DuckDB functions for lists and aggregations, "
    "given the features. Only return the SQL query, no other text. The "
    "user may ask you to make various adjustments to the query. Every "
    "time your response should only include the refined SQL query and "
    "nothing else.\n\n"
    "The table being queried is named: {table_name}.\n\n"
    "# Features\n"
    "{features}"
)
def get_iframe(hub_repo_id, sql_query=None):
    """Return HTML embedding the Hub dataset viewer for *hub_repo_id*.

    When *sql_query* is given, the viewer opens with its SQL console
    pre-populated (the query is percent-encoded into the URL).

    Raises:
        ValueError: if *hub_repo_id* is empty/falsy.
    """
    if not hub_repo_id:
        raise ValueError("Hub repo id is required")
    base_url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
    if sql_query:
        encoded = urllib.parse.quote(sql_query)
        src = f"{base_url}?sql_console=true&sql={encoded}"
    else:
        src = base_url
    return f"""
    <iframe
        src="{src}"
        frameborder="0"
        width="100%"
        height="800px"
    ></iframe>
    """
def get_table_info(hub_repo_id):
    """Fetch dataset info (configs, splits, features) from the datasets server.

    Returns the ``dataset_info`` payload serialized as a JSON string (the UI
    stores it in a hidden ``gr.Code`` component).

    Raises:
        gr.Error: if the response cannot be parsed, so Gradio shows the
            failure to the user.
    """
    url: str = f"https://datasets-server.huggingface.co/info?dataset={hub_repo_id}"
    # Timeout so a slow/unreachable API cannot hang the UI callback forever.
    response = requests.get(url, timeout=30)
    try:
        data = response.json()
        data = data.get("dataset_info")
        return json.dumps(data)
    except Exception as e:
        # Fix: the original constructed gr.Error without raising it, which is
        # a no-op (the function silently returned None). Raising lets Gradio
        # surface the error modal.
        raise gr.Error(f"Error getting column info: {e}") from e
| def get_table_name( | |
| config: str | None, | |
| split: str | None, | |
| config_choices: list[str], | |
| split_choices: list[str], | |
| ): | |
| if len(config_choices) > 0 and config is None: | |
| config = config_choices[0] | |
| if len(split_choices) > 0 and split is None: | |
| split = split_choices[0] | |
| if len(config_choices) > 1 and len(split_choices) > 1: | |
| base_name = f"{config}_{split}" | |
| elif len(config_choices) >= 1 and len(split_choices) <= 1: | |
| base_name = config | |
| else: | |
| base_name = split | |
| def replace_char(c): | |
| if c.isalnum(): | |
| return c | |
| if c in ["-", "_", "/"]: | |
| return "_" | |
| return "" | |
| table_name = "".join(replace_char(c) for c in base_name) | |
| if table_name[0].isdigit(): | |
| table_name = f"_{table_name}" | |
| return table_name.lower() | |
def get_system_prompt(
    card_data: dict[str, Any],
    config: str | None,
    split: str | None,
):
    """Build the LLM system prompt for the selected dataset config/split.

    Fills SYSTEM_PROMPT_TEMPLATE with the derived table name and the
    config's feature schema.
    """
    config_choices = get_config_choices(card_data)
    split_choices = get_split_choices(card_data)
    # Fix: default the config before the features lookup. get_table_name
    # already falls back to the first choice, but the original then did
    # card_data[None], raising KeyError when no config was selected yet.
    if config is None and config_choices:
        config = config_choices[0]
    table_name = get_table_name(config, split, config_choices, split_choices)
    features = card_data[config]["features"]
    return SYSTEM_PROMPT_TEMPLATE.format(
        table_name=table_name,
        features=features,
    )
def get_config_choices(card_data: dict[str, Any]) -> list[str]:
    """Return the dataset's config names (insertion order preserved)."""
    return list(card_data)
def get_split_choices(card_data: dict[str, Any]) -> list[str]:
    """Return the union of split names across all configs.

    Fix: the original collected names into a ``set`` and listed it, so the
    order was nondeterministic across runs — yet the first element is used
    as the dropdown default and feeds table naming. Dedupe while preserving
    first-seen order instead.
    """
    seen: dict[str, None] = {}
    for config in card_data.values():
        for split in config.get("splits", {}):
            seen.setdefault(split, None)
    return list(seen)
def query_dataset(hub_repo_id, card_data, query, config, split, history):
    """Turn a natural-language *query* into DuckDB SQL via the Together API.

    Args:
        hub_repo_id: Hub dataset id (e.g. "user/dataset").
        card_data: JSON string of dataset info (from get_table_info), or falsy.
        query: the user's natural-language request.
        config, split: current dropdown selections (may be None).
        history: list of (user, assistant) tuples from the chatbot.

    Returns:
        (sql, iframe_html, updated_history, "") — the final "" clears the
        query textbox.
    """
    if card_data is None or len(card_data) == 0:
        # No dataset info yet: just (re)render the viewer or a placeholder.
        if hub_repo_id:
            iframe = get_iframe(hub_repo_id)
        else:
            iframe = "<p>No dataset selected.</p>"
        return "", iframe, [], ""
    card_data = json.loads(card_data)
    system_prompt = get_system_prompt(card_data, config, split)
    messages = [{"role": "system", "content": system_prompt}]
    # Replay prior turns so the model can refine earlier SQL suggestions.
    for user, assistant in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": query})
    api_key = os.environ["API_KEY_TOGETHER_AI"].strip()
    response = requests.post(
        "https://api.together.xyz/v1/chat/completions",
        json={
            "model": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
            "messages": messages,
            "max_tokens": 1000,
        },
        headers={"Authorization": f"Bearer {api_key}"},
        # Fix: no timeout could hang the Gradio worker indefinitely.
        timeout=60,
    )
    if response.status_code != 200:
        logger.warning(response.text)
        try:
            response.raise_for_status()
        except Exception as e:
            # Fix: the original constructed gr.Error without raising it (a
            # no-op), then fell through to response.json() on a failed call.
            raise gr.Error(f"Could not query LLM for suggestion: {e}") from e
    response_dict = response.json()
    duck_query = response_dict["choices"][0]["message"]["content"]
    duck_query = _sanitize_duck_query(duck_query)
    history.append((query, duck_query))
    return duck_query, get_iframe(hub_repo_id, duck_query), history, ""
| def _sanitize_duck_query(duck_query: str) -> str: | |
| # Sometimes the LLM wraps the query like this: | |
| # ```sql | |
| # select * from x; | |
| # ``` | |
| # This removes that wrapping if present. | |
| if "```" not in duck_query: | |
| return duck_query | |
| start_idx = duck_query.index("```") + len("```") | |
| end_idx = duck_query.rindex("```") | |
| duck_query = duck_query[start_idx:end_idx] | |
| if duck_query.startswith("sql\n"): | |
| duck_query = duck_query.replace("sql\n", "", 1) | |
| return duck_query | |
# --- UI definition -----------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(HEADER_CONTENT)
    with gr.Accordion("About/Help", open=False):
        gr.Markdown(ABOUT_CONTENT)
    with gr.Row():
        # Dataset search box; submits automatically when a result is picked.
        search_in = HuggingfaceHubSearch(
            label="Search Hugging Face Hub",
            placeholder="Search for models on Huggingface",
            search_type="dataset",
            # NOTE(review): "sumbit_on_select" appears to be the component's
            # actual (typoed) keyword — verify against the installed
            # gradio_huggingfacehub_search version before renaming it.
            sumbit_on_select=True,
        )
    with gr.Row():
        show_btn = gr.Button("Show Dataset")
    with gr.Row():
        # Generated SQL; hidden until a suggestion exists.
        sql_out = gr.Code(
            label="DuckDB SQL Query",
            interactive=True,
            language="sql",
            lines=1,
            visible=False,
        )
    with gr.Row():
        # Hidden holder for the /info JSON payload of the selected dataset.
        card_data = gr.Code(label="Card data", language="json", visible=False)

    # NOTE(review): indentation was lost in extraction; this nesting is a
    # reconstruction. As written the function is defined but never invoked,
    # so it presumably carried a `@gr.render(inputs=card_data)` decorator in
    # the original (it both creates components and wires events that depend
    # on card_data) — confirm against the original source.
    def show_config_split_choices(data):
        # Parse the card-data JSON to populate config/split dropdowns;
        # fall back to empty choices on any parse failure.
        try:
            data = json.loads(data.strip())
            config_choices = get_config_choices(data)
            split_choices = get_split_choices(data)
        except Exception:
            config_choices = []
            split_choices = []
        initial_config = config_choices[0] if len(config_choices) > 0 else None
        initial_split = split_choices[0] if len(split_choices) > 0 else None
        with gr.Row():
            with gr.Column():
                config_selection = gr.Dropdown(
                    label="Config Name", choices=config_choices, value=initial_config
                )
            with gr.Column():
                split_selection = gr.Dropdown(
                    label="Split Name", choices=split_choices, value=initial_split
                )
        with gr.Accordion("Query Suggestion History.", open=False) as accordion:
            chatbot = gr.Chatbot(height=200, layout="bubble")
        with gr.Row():
            query = gr.Textbox(
                label="Query Description",
                placeholder="Enter a natural language query to generate SQL",
            )
        with gr.Row():
            with gr.Column():
                query_btn = gr.Button("Get Suggested Query")
            with gr.Column():
                clear = gr.ClearButton([query, chatbot], value="Reset Query History")
        with gr.Row():
            search_out = gr.HTML(label="Search Results")
        # Show the dataset viewer, then fetch its card data.
        gr.on(
            [show_btn.click, search_in.submit],
            fn=get_iframe,
            inputs=[search_in],
            outputs=[search_out],
        ).then(
            fn=get_table_info,
            inputs=[search_in],
            outputs=[card_data],
        )
        # Generate (or refine) the SQL suggestion from the user's description.
        gr.on(
            [query_btn.click, query.submit],
            fn=query_dataset,
            inputs=[
                search_in,
                card_data,
                query,
                config_selection,
                split_selection,
                chatbot,
            ],
            outputs=[sql_out, search_out, chatbot, query],
        )
        # Pop the history accordion open whenever a suggestion is requested.
        gr.on([query_btn.click], fn=lambda: gr.update(open=True), outputs=[accordion])

if __name__ == "__main__":
    demo.launch()