kmohan27 committed on
Commit
5cd5882
·
verified ·
1 Parent(s): 8f5adf7

Upload 2 files

Files changed (2)
  1. Test_SemanticSearch.py +149 -0
  2. requirements.txt +15 -3
Test_SemanticSearch.py ADDED
@@ -0,0 +1,149 @@
+ import numpy as np
+ import pandas as pd
+ from pathlib import Path
+ import streamlit as st
+ import pickle
+ import plotly.express as px
+ import plotly.graph_objects as go
+ import gdown
+ import torch
+ from sentence_transformers import SentenceTransformer, util
+
+ # -------------------------------
+ # Setup
+ # -------------------------------
+ st.set_page_config(page_title="Alibaba Semantic Search", layout="wide")
+
+ MODEL_DIR = Path("models")
+ MODEL_DIR.mkdir(exist_ok=True)
+
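+ # Precomputed assets: corpus embeddings, their 2D UMAP projection, the source
+ # dataframe, and the fitted UMAP and PCA models used to project new queries.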
+ embeddings_path = MODEL_DIR / 'desc_embeddings_Alibaba_20251016.npy'
+ umap_embeddings_path = MODEL_DIR / 'descs_umap_2d_AB_20251016.npy'
+ data_file_path = MODEL_DIR / 'full_df_minus_nan_descs.csv'
+ umap_model_path = MODEL_DIR / 'umap_2d_AB_written.pkl'
+ pca_model_path = MODEL_DIR / 'pca_AB_written.pkl'
+
+ emb_ID = '1QQ_QfFTSzTLNkp6Sr4jux_ZTJjMhSyah'
+ umap_emb_ID = '1a5t5iWOAVgUmYXzrWXctATkDyx9rRF4F'
+ data_ID = '1tzM67Lg3R-rAvRtol0VGHx6zGW_tdx60'
+ umap_mod_ID = '1x8PK1Gn72YYBZ4po-0guZMUBtL8oSn1i'
+ pca_mod_ID = '1jIxBBAZOy8OAzGxBCG4jy7244Wb_TjP9'
+
+ paths = [embeddings_path, umap_embeddings_path, data_file_path, umap_model_path, pca_model_path]
+ ids = [emb_ID, umap_emb_ID, data_ID, umap_mod_ID, pca_mod_ID]
+ assets_links = [f"https://drive.google.com/uc?id={x}" for x in ids]
+
+ # -------------------------------
+ # Download + Load Data
+ # -------------------------------
+ def load_assets():
+     st.info("Downloading assets from Google Drive (only if missing)...")
+     for url, path in zip(assets_links, paths):
+         if not path.exists():
+             gdown.download(url, str(path), quiet=False)
+     st.success("Assets ready ✅")
+
+     embeddings = np.load(embeddings_path)
+     umap_2d = np.load(umap_embeddings_path)
+     docs = pd.read_csv(data_file_path)
+     with open(umap_model_path, "rb") as f:
+         umap_model = pickle.load(f)
+     with open(pca_model_path, "rb") as f:
+         pca_model = pickle.load(f)
+     return embeddings, umap_2d, docs, umap_model, pca_model
+
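+ # load_assets() runs on every Streamlit rerun: the gdown downloads are skipped once
+ # the files exist locally, but the arrays, CSV, and pickles are re-read each time.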
55
+ embeddings, umap_2d, docs, umap_model, pca_model = load_assets()
+
+ # -------------------------------
+ # Load SentenceTransformer (cached)
+ # -------------------------------
+ @st.cache_resource
+ def load_text_encoder():
+     return SentenceTransformer('Alibaba-NLP/gte-multilingual-base', trust_remote_code=True)
+
+ model = load_text_encoder()
+
+ # -------------------------------
+ # UI
+ # -------------------------------
+ st.title("🔍 Semantic Search — Alibaba Embeddings")
+ st.markdown("Enter a query to highlight semantically similar documents on the 2D UMAP plot.")
+
+ query = st.text_input("Enter search query:")
+ top_k = st.slider("Number of matches to highlight", min_value=10, max_value=2500, value=100)
+
+ similarity_measure = st.radio(
+     "Similarity measure",
+     ["Cosine", "Euclidean", "Manhattan (L1)"],
+     horizontal=True
+ )
+
+ # -------------------------------
+ # Search logic
+ # -------------------------------
+ if query:
+     with st.spinner("Encoding and searching..."):
+         query_embedding = model.encode(query, convert_to_tensor=True)
+         query_numpy = query_embedding.cpu().numpy().reshape(1, -1)
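+         # Run the query through the same PCA -> UMAP pipeline used for the corpus,
+         # so it can be placed on the existing 2D map alongside the documents.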
+         query_pca = pca_model.transform(query_numpy)
+         query_umap = umap_model.transform(query_pca)
+
+         # torch.cdist expects 2D tensors, so move the corpus embeddings onto the
+         # query's device as a tensor and give the query a batch dimension.
+         emb_tensor = torch.as_tensor(embeddings, dtype=query_embedding.dtype, device=query_embedding.device)
+         if similarity_measure == "Cosine":
+             scores = util.cos_sim(query_embedding, emb_tensor)[0]
+         elif similarity_measure == "Euclidean":
+             scores = -torch.cdist(query_embedding.unsqueeze(0), emb_tensor, p=2)[0]
+         elif similarity_measure == "Manhattan (L1)":
+             scores = -torch.cdist(query_embedding.unsqueeze(0), emb_tensor, p=1)[0]
+
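+         # Cosine similarity is already "larger is better" and the distances are negated
+         # above, so a single descending sort ranks the nearest documents first.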
+         top_results = scores.argsort(descending=True)
+         highlight_indices = top_results[:top_k].cpu().numpy()
+
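+     # Document text and metadata columns used for the hover tooltips and the results list.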
+     documents = docs.title_narrative
+     reporting_org = docs.reporting_org_name
+     funding = docs.Funding
+
+     labels = ["Match" if i in highlight_indices else "Other" for i in range(len(documents))]
+
+     df = pd.DataFrame({
+         "UMAP_1": umap_2d[:, 0],
+         "UMAP_2": umap_2d[:, 1],
+         "Label": labels,
+         "Text": documents,
+         "Reporting Org": reporting_org,
+         "Funding": funding,
+     })
+
+     df["Title"] = df["Text"].str.slice(0, 100) + "..."
+     df["Index"] = df.index
+
+     color_discrete_map = {"Match": "red", "Other": "lightgray"}
+
+     fig = px.scatter(
+         df,
+         x="UMAP_1",
+         y="UMAP_2",
+         color="Label",
+         color_discrete_map=color_discrete_map,
+         hover_data={"Text": False, "Title": True, "Index": True, "Reporting Org": True, "Funding": True, "UMAP_1": False, "UMAP_2": False},
+         opacity=0.7,
+         title=f"Top {top_k} semantic matches for: '{query}' ({similarity_measure})",
+         width=900,
+         height=700
+     )
+
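+     # Overlay the query's projected 2D position as a blue "x" marker so it can be
+     # located relative to the highlighted matches.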
+     fig.add_trace(go.Scatter(
+         x=[query_umap[0][0]], y=[query_umap[0][1]],
+         mode='markers+text',
+         marker=dict(size=10, color='blue', symbol='x'),
+         name='Query',
+         text=['Query'], textposition='top center'
+     ))
+
+     # Shrink only the document markers; the selector leaves the query marker at size 10.
+     fig.update_traces(marker=dict(size=4), selector=dict(mode="markers"))
+     st.plotly_chart(fig, use_container_width=True)
+
+     st.subheader("Top 10 matched documents")
+     for rank, idx in enumerate(highlight_indices[:10], start=1):
+         st.markdown(f"{rank}. {documents.iloc[idx]}")
+ else:
+     st.info("Enter a search query to begin.")
requirements.txt CHANGED
@@ -1,3 +1,15 @@
- altair
- pandas
- streamlit
+ # Core app
+ streamlit==1.39.0
+ gdown==5.2.0
+ pandas==2.2.3
+ numpy==2.1.2
+ numba==0.61.2
+ joblib==1.4.2
+ plotly==5.24.1
+ torch==2.6.0
+ sentence-transformers==3.2.1
+ umap-learn==0.5.6
+ scikit-learn==1.5.2
+
+ # Optional (used internally by SentenceTransformer models)
+ transformers==4.45.2