import os
import shutil
import uuid

import gradio as gr
import pandas as pd
import PyPDF2
import torch
from datasets import load_dataset
from docx import Document
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
|
|
| |
# --- Configuration -----------------------------------------------------------
QDRANT_PATH = "./qdrant_db"              # on-disk directory of the local Qdrant store
COLLECTION_NAME = "my_text_collection"   # collection that is seeded, searched and appended to
MODEL_NAME = "KaLM-Embedding/KaLM-embedding-multilingual-mini-instruct-v2.5"

# --- Shared resources --------------------------------------------------------
# Sentence-embedding model used for both indexing and querying, pinned to CPU.
device = "cpu"
model = SentenceTransformer(MODEL_NAME, device=device)

# File-backed Qdrant client (path= selects the embedded local mode, no server).
qdrant_client = QdrantClient(path=QDRANT_PATH)
|
|
| |
# Decide whether the collection must be created. Checking the list of existing
# collections (instead of treating *any* exception from get_collection as
# "not found") lets genuine failures -- e.g. a storage or lock error -- surface
# rather than silently triggering re-creation.
existing_collection_names = {c.name for c in qdrant_client.get_collections().collections}
collection_exists = COLLECTION_NAME in existing_collection_names
if collection_exists:
    print("Collection already exists.")
else:
    print("Collection not found, creating a new one...")
|
|
| |
if not collection_exists:
    # Size vectors from the model itself; 768 is only a fallback for models
    # that do not report their embedding dimension.
    vector_size = model.get_sentence_embedding_dimension() or 768
    qdrant_client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=models.VectorParams(size=vector_size, distance=models.Distance.COSINE),
    )

# Seed whenever the collection is empty, not only right after creating it:
# a previous run that stopped between create_collection and upsert would
# otherwise leave an empty collection that is never populated.
if qdrant_client.count(collection_name=COLLECTION_NAME, exact=True).count == 0:
    # First 1000 headlines+descriptions of the AG News test split.
    dataset = load_dataset("ag_news", split="test")
    df = dataset.to_pandas()
    data = df["text"].head(1000).tolist()

    print("Generating and indexing embeddings...")
    embeddings = model.encode(data)

    # Sequential integer IDs are safe here because the collection is empty.
    points = [
        models.PointStruct(id=i, vector=embedding.tolist(), payload={"document": text})
        for i, (text, embedding) in enumerate(zip(data, embeddings))
    ]
    qdrant_client.upsert(collection_name=COLLECTION_NAME, points=points)
    print("Embeddings indexed successfully.")
|
|
|
|
| |
def search_in_qdrant(query):
    """Return the five stored documents most similar to *query*, as Markdown.

    Args:
        query: Search text from the UI textbox.

    Returns:
        Markdown with one "**Score:** / **Text:**" entry per hit, or a short
        message when the query is blank or the search returns no hits.
    """
    # Whitespace-only input is treated like an empty query; searching on blank
    # text is not meaningful.
    if not query or not query.strip():
        return "Please enter a search query."

    query_embedding = model.encode([query])[0].tolist()

    # NOTE(review): newer qdrant-client releases deprecate `search` in favour of
    # `query_points`; confirm against the pinned client version before switching.
    hits = qdrant_client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_embedding,
        limit=5,
    )
    if not hits:
        return "No results found."

    entries = []
    for hit in hits:
        # Points without a "document" payload still report their score.
        payload = hit.payload or {}
        text = payload.get("document", "[No document content available]")
        entries.append(f"**Score:** {hit.score:.4f}\n**Text:** {text}\n\n")
    return "".join(entries)
|
|
| |
def extract_text_from_file(file_path):
    """Extract plain text from a file, choosing a parser by its extension.

    * ``.pdf`` -- text of every page via PyPDF2, each page followed by a newline.
    * ``.docx`` / ``.doc`` -- paragraph text via python-docx, one per line.
    * ``.csv`` / ``.xlsx`` / ``.xls`` -- loaded with pandas and rendered with
      ``DataFrame.to_string()``.
    * anything else (including ``.txt`` and ``.md``) -- read as text, trying
      UTF-8 first and falling back to Latin-1.

    Args:
        file_path: Path of the file to read.

    Returns:
        The extracted text. On the plain-text path, an error message string is
        returned if the Latin-1 retry cannot open the file.

    Raises:
        Errors from the PDF/Word/spreadsheet parsers, and OSError from the
        first plain-text open, propagate to the caller.
    """
    # splitext looks only at the final path component, unlike split('.').
    extension = os.path.splitext(file_path)[1].lower()[1:]

    if extension == "pdf":
        with open(file_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            # `or ''` guards pages that yield no extractable text.
            return "".join(f"{page.extract_text() or ''}\n" for page in reader.pages)

    if extension in ("docx", "doc"):
        # NOTE(review): python-docx reads the OOXML .docx format; legacy binary
        # .doc files will likely raise here.
        doc = Document(file_path)
        return "".join(f"{paragraph.text}\n" for paragraph in doc.paragraphs)

    if extension in ("csv", "xlsx", "xls"):
        df = pd.read_csv(file_path) if extension == "csv" else pd.read_excel(file_path)
        return df.to_string()

    # Plain text (.txt, .md, unknown extensions). .txt used to be UTF-8-only and
    # raised on e.g. cp1252 files; it now shares the Latin-1 fallback, which
    # can decode any byte sequence.
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()
    except UnicodeDecodeError:
        pass
    try:
        with open(file_path, "r", encoding="latin-1") as f:
            return f.read()
    except OSError:
        return "Could not read file: unsupported format or encoding issue"
|
|
def upload_to_qdrant(text_content, file_upload=None):
    """Embed pasted text and/or an uploaded file and add them to the collection.

    Args:
        text_content: Free text from the UI textbox; blank text is ignored.
        file_upload: Value of the Gradio File component -- either a path
            (str / os.PathLike) or an object whose ``.name`` is the path.
            May be None.

    Returns:
        A human-readable status message for the UI.
    """
    if not text_content and not file_upload:
        return "Please provide text content or upload a file."

    documents_to_add = []

    # Skip whitespace-only text instead of indexing an empty document.
    if text_content and text_content.strip():
        documents_to_add.append(text_content)

    if file_upload:
        try:
            # Accept plain paths as well as file-like wrappers exposing `.name`.
            if isinstance(file_upload, (str, os.PathLike)):
                file_path = os.fspath(file_upload)
            else:
                file_path = file_upload.name
            content = extract_text_from_file(file_path)
        except Exception as e:
            return f"Error reading file: {str(e)}"
        if content and content.strip():
            documents_to_add.append(content)

    if not documents_to_add:
        return "No content to upload."

    embeddings = model.encode(documents_to_add)

    # Random UUID point IDs (a form Qdrant accepts alongside integers) cannot
    # collide with existing points. The previous points_count-based IDs could:
    # points_count is approximate, and on any lookup error the base fell back
    # to 0, so upsert silently overwrote the existing points 1..k.
    points = [
        models.PointStruct(
            id=str(uuid.uuid4()),
            vector=embedding.tolist(),
            payload={"document": doc},
        )
        for doc, embedding in zip(documents_to_add, embeddings)
    ]

    qdrant_client.upsert(
        collection_name=COLLECTION_NAME,
        points=points,
    )

    return f"Successfully added {len(documents_to_add)} document(s) to the collection."
|
|
| |
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Search with Qdrant and Gradio")
    gr.Markdown("Enter a query to search for similar news articles from the AG News dataset.")

    # Search tab: query box and button; hits are rendered as Markdown.
    with gr.Tab("Search"):
        with gr.Row():
            search_input = gr.Textbox(label="Search Query", placeholder="e.g., 'Latest news on space exploration'")
            search_button = gr.Button("Search")
        search_output = gr.Markdown()
        search_button.click(search_in_qdrant, inputs=search_input, outputs=search_output)
        # Pressing Enter in the query box runs the same search as the button.
        search_input.submit(search_in_qdrant, inputs=search_input, outputs=search_output)

    # Upload tab: pasted text and/or one file are embedded and added to the collection.
    with gr.Tab("Upload"):
        with gr.Row():
            text_input = gr.Textbox(label="Text Content", placeholder="Enter text to add to the collection", lines=5)
        with gr.Row():
            # Extensions offered by the file picker; parsing happens in extract_text_from_file.
            file_input = gr.File(label="Or Upload a File", file_types=['.txt', '.pdf', '.docx', '.csv', '.xlsx', '.xls', '.md'])
        upload_button = gr.Button("Upload to Collection")
        upload_output = gr.Textbox(label="Upload Status", interactive=False)
        upload_button.click(upload_to_qdrant, inputs=[text_input, file_input], outputs=upload_output)


if __name__ == "__main__":
    demo.launch()