|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Patient Matching Pipeline - Gradio Web Interface |
|
|
|
|
|
This interface allows users to: |
|
|
1. Configure models (embedder, trial_checker, boilerplate_checker) |
|
|
2. Upload patient database OR load pre-embedded patients |
|
|
3. Enter a set of clinical criteria (trial eligibility criteria)
|
|
4. Get ranked patient recommendations with eligibility predictions |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import torch |
|
|
import os |
|
|
import json |
|
|
import pickle |
|
|
import html |
|
|
from typing import List, Tuple |
|
|
from pathlib import Path |
|
|
import pyarrow.parquet as pq |
|
|
|
|
|
|
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModelForSequenceClassification, |
|
|
) |
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
|
try: |
|
|
from vllm import LLM, SamplingParams |
|
|
HAS_VLLM = True |
|
|
except ImportError: |
|
|
HAS_VLLM = False |
|
|
print("○ vLLM not installed - will use HuggingFace transformers for LLM inference") |
|
|
|
|
|
|
|
|
try: |
|
|
import config |
|
|
HAS_CONFIG = True |
|
|
print("✓ Found config.py - will auto-load models on startup") |
|
|
except ImportError: |
|
|
HAS_CONFIG = False |
|
|
print("○ No config.py found - using manual model loading") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AppState: |
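    """Global container for loaded models, tokenizers, patient data, and results."""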
|
|
def __init__(self): |
|
|
self.embedder_model = None |
|
|
self.embedder_tokenizer = None |
|
|
self.trial_checker_model = None |
|
|
self.trial_checker_tokenizer = None |
|
|
self.boilerplate_checker_model = None |
|
|
self.boilerplate_checker_tokenizer = None |
|
|
self.llm_model = None |
|
|
self.llm_tokenizer = None |
|
|
|
|
|
self.patient_df = None |
|
|
self.patient_embeddings = None |
|
|
self.patient_preview_df = None |
|
|
|
|
|
|
|
|
self.last_results_df = None |
|
|
|
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
self.auto_load_status = { |
|
|
"embedder": "", |
|
|
"trial_checker": "", |
|
|
"boilerplate_checker": "", |
|
|
"llm": "", |
|
|
"patients": "" |
|
|
} |
|
|
|
|
|
def reset_patients(self): |
|
|
self.patient_df = None |
|
|
self.patient_embeddings = None |
|
|
self.patient_preview_df = None |
|
|
|
|
|
state = AppState() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MAX_EMBEDDER_SEQ_LEN = 2500  # token cap applied to embedder inputs (summaries and queries)

MAX_TRIAL_CHECKER_LENGTH = 4096  # token cap for trial-checker classifier inputs

MAX_BOILERPLATE_CHECKER_LENGTH = 3192  # token cap for boilerplate-checker classifier inputs

CLASSIFIER_BATCH_SIZE = 32  # batch size for both ModernBERT classifiers
|
|
|
|
|
|
|
|
DEFAULT_CLINICAL_SPACE_TEMPLATE = """Age range allowed: |
|
|
Sex allowed: |
|
|
Cancer type allowed: |
|
|
Histology allowed: |
|
|
Cancer burden allowed: |
|
|
Prior treatment required: |
|
|
Prior treatment excluded: |
|
|
Biomarkers required: |
|
|
Biomarkers excluded: """ |
|
|
|
|
|
DEFAULT_BOILERPLATE_TEMPLATE = """History of pneumonitis: |
|
|
Heart failure or cardiac dysfunction: |
|
|
Renal dysfunction: |
|
|
Liver dysfunction: |
|
|
Uncontrolled brain metastases: |
|
|
HIV or hepatitis infection: |
|
|
Poor performance status (ECOG >= 2): |
|
|
Other relevant exclusions: """ |
|
|
|
|
|
|
|
|
REASONING_MARKER = "assistantfinal" |
|
|
|
|
|
DEEPER_SCREEN_TRIAL_PROMPT = ( |
|
|
"You are a brilliant oncologist with encyclopedic knowledge about cancer and its treatment. " |
|
|
"Your job is to evaluate whether a given clinical trial is a reasonable consideration for a patient, " |
|
|
"given a clinical trial summary and a patient summary.\n\n" |
|
|
"Here is a summary of the clinical trial:\n{trial_space}\n" |
|
|
"Here is a summary of the patient:\n{patient_summary}\n" |
|
|
"Base your judgment on whether the patient generally fits the age requirements if any, sex requirements if any, cancer type(s), cancer burden, prior treatment(s), " |
|
|
"and biomarker criteria specified for the trial.\n" |
|
|
"You do not have to determine if the patient is actually eligible; instead please just evaluate whether it is reasonable " |
|
|
"for the trial to be considered further by the patient's oncologist.\n" |
|
|
"Biomarker criteria have to be considered carefully. Some trials have biomarker requirements that are not assessed until " |
|
|
"formal trial screening. A trial may therefore sometimes be a reasonable consideration for a patient even if a required " |
|
|
"biomarker is not known to be present in the patient.\n" |
|
|
"However, if a required biomarker is known to be absent, or can be assumed to be absent based on other information, the trial " |
|
|
"is not a reasonable consideration. For example, if a trial for lung cancer requires an EGFR mutation, documentation that there " |
|
|
"is no EGFR mutation indicates the trial is not a reasonable consideration. Similarly, documentation of a KRAS mutation in the " |
|
|
"patient indicates the trial is not a reasonable consideration, since, as you know, KRAS and EGFR driver mutations in lung cancer " |
|
|
"are mutually exclusive.\n" |
|
|
"Many trials describe required washout periods for prior treatments for eligibility. For example, the eligibility criteria might state " |
|
|
"that patients may not have received radiation or chemotherapy in the last 14 days or 30 days. It is CRITICAL that you IGNORE these " |
|
|
"eligibility criteria when considering prior treatment requirements. Assume that patients could wait for the washout period to enroll. " |
|
|
"Also CRITICAL: Ignore your knowledge of today's current date. Pretend that you are evaluating the patient's evaluation based on the " |
|
|
"most recent information available in their summary, at the time of that most recently available information. " |
|
|
"Do not provide ethical judgments or comment on resource constraints with respect whether the trial is a reasonable clinical " |
|
|
"consideration; just evaluate whether it is, given the available information.\n" |
|
|
'Reason step by step, then answer the question "Is this trial a reasonable consideration for this patient?" with a one-word ' |
|
|
'"Yes!" or "No!" answer.\n' |
|
|
"Make sure to include the exclamation point in your final one-word answer." |
|
|
) |
|
|
|
|
|
DEEPER_SCREEN_BOILERPLATE_PROMPT = ( |
|
|
"You are a brilliant oncologist with encyclopedic knowledge about cancer and its treatment.\n" |
|
|
"Your job is to evaluate whether a patient has any underlying medical conditions that would exclude him or her from a specific clinical trial.\n\n" |
|
|
"Here is an extract of the patient's history:\n{patient_boilerplate}\n" |
|
|
"Here are the exclusion criteria for the trial:\n{trial_boilerplate}\n" |
|
|
"Note that the extract was generated by prompting an LLM to determine whether the patient meets specific common exclusion criteria, " |
|
|
"such as uncontrolled brain metastases, lack of measurable disease, congestive heart failure, pneumonitis, renal dysfunction, " |
|
|
"liver dysfunction, and HIV or hepatitis infection, and to present evidence for whether the patient met the criterion.\n" |
|
|
"You should therefore not assume that mention of such condition means the patient has the condition; it may represent the LLM reasoning " |
|
|
"about whether the patient has the condition.\n" |
|
|
"Based on the extract, you should determine whether the patient clearly meets one of the exclusion criteria for this specific trial.\n" |
|
|
"Do not evaluate exclusion criteria other than those listed for this trial.\n" |
|
|
"Reason through one exclusion criterion at a time. Generate a numbered list of the criteria as you go. For each one, decide whether the patient clearly " |
|
|
"meets the exclusion criteron. If it is not completely clear that the patient meets the exclusion criterion, give the patient the benefit of the doubt, " |
|
|
"and err on the side of deciding the patient is not excluded. A description in the patient extract that a condition is mild, low-grade, or resolved is even " |
|
|
"more of a reason not to exclude the patient based on that condition.\n" |
|
|
'Once you have evaluated all exclusion criteria, answer the question "Is this patient clearly excluded from this trial?" with a one-word "Yes!" or "No!" answer, ' |
|
|
"based on whether the patient clearly met any of the individual exclusion criteria. It is critical that your final word be either \"Yes!\" or \"No!\", verbatim, and case-sensitive.\n" |
|
|
"Make sure to include the exclamation point in your final one-word answer.\n" |
|
|
"No introductory text or concluding text after that final answer." |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def truncate_text(text: str, tokenizer, max_tokens: int = 1500) -> str: |
|
|
"""Truncate text to a maximum number of tokens.""" |
|
|
return tokenizer.decode( |
|
|
tokenizer.encode(text, add_special_tokens=True, truncation=True, max_length=max_tokens), |
|
|
skip_special_tokens=True |
|
|
) |
|
|
|
|
|
|
|
|
def format_probability_visual(val, is_exclusion=False): |
|
|
"""Format probabilities with visual indicators.""" |
|
|
try: |
|
|
val_float = float(val) |
|
|
    except (TypeError, ValueError):
|
|
return val |
|
|
|
|
|
if not is_exclusion: |
|
|
|
|
|
if val_float >= 0.8: |
|
|
return f"🟢 **{val_float:.2f}**" |
|
|
elif val_float >= 0.5: |
|
|
return f"🟡 {val_float:.2f}" |
|
|
else: |
|
|
return f"🔴 {val_float:.2f}" |
|
|
else: |
|
|
|
|
|
if val_float >= 0.5: |
|
|
return f"🔴 **{val_float:.2f}**" |
|
|
elif val_float >= 0.2: |
|
|
return f"🟡 {val_float:.2f}" |
|
|
else: |
|
|
return f"🟢 {val_float:.2f}" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def auto_load_models_from_config(): |
|
|
"""Auto-load models specified in config.py""" |
|
|
if not HAS_CONFIG: |
|
|
return |
|
|
|
|
|
print("\n" + "="*70) |
|
|
print("AUTO-LOADING MODELS FROM CONFIG") |
|
|
print("="*70) |
|
|
|
|
|
|
|
|
if config.MODEL_CONFIG.get("embedder"): |
|
|
print(f"\n[1/4] Loading embedder: {config.MODEL_CONFIG['embedder']}") |
|
|
status, _, _ = load_embedder_model(config.MODEL_CONFIG["embedder"]) |
|
|
state.auto_load_status["embedder"] = status |
|
|
print(status) |
|
|
|
|
|
|
|
|
if config.MODEL_CONFIG.get("trial_checker"): |
|
|
print(f"\n[2/4] Loading trial checker: {config.MODEL_CONFIG['trial_checker']}") |
|
|
status, _ = load_trial_checker(config.MODEL_CONFIG["trial_checker"]) |
|
|
state.auto_load_status["trial_checker"] = status |
|
|
print(status) |
|
|
|
|
|
|
|
|
if config.MODEL_CONFIG.get("boilerplate_checker"): |
|
|
print(f"\n[3/4] Loading boilerplate checker: {config.MODEL_CONFIG['boilerplate_checker']}") |
|
|
status, _ = load_boilerplate_checker(config.MODEL_CONFIG["boilerplate_checker"]) |
|
|
state.auto_load_status["boilerplate_checker"] = status |
|
|
print(status) |
|
|
|
|
|
|
|
|
if config.MODEL_CONFIG.get("llm"): |
|
|
print(f"\n[4/4] Loading LLM: {config.MODEL_CONFIG['llm']}") |
|
|
status, _ = load_llm_model(config.MODEL_CONFIG["llm"]) |
|
|
state.auto_load_status["llm"] = status |
|
|
print(status) |
|
|
|
|
|
print("\n" + "="*70) |
|
|
print("MODEL AUTO-LOADING COMPLETE") |
|
|
print("="*70 + "\n") |
|
|
|
|
|
|
|
|
def auto_load_patients_from_config(): |
|
|
"""Auto-load patient database from config.py - prefers pre-embedded over fresh embedding.""" |
|
|
if not HAS_CONFIG: |
|
|
return |
|
|
|
|
|
|
|
|
if hasattr(config, 'PREEMBEDDED_PATIENTS') and config.PREEMBEDDED_PATIENTS: |
|
|
preembed_path = config.PREEMBEDDED_PATIENTS |
|
|
|
|
|
|
|
|
if preembed_path.startswith("http://") or preembed_path.startswith("https://"): |
|
|
print("\n" + "="*70) |
|
|
print(f"AUTO-LOADING PRE-EMBEDDED PATIENTS (URL): {preembed_path}") |
|
|
print("="*70) |
|
|
|
|
|
status, preview = load_preembedded_patients(preembed_path) |
|
|
state.auto_load_status["patients"] = status |
|
|
state.patient_preview_df = preview |
|
|
|
|
|
print("="*70) |
|
|
print("PRE-EMBEDDED PATIENTS AUTO-LOADING COMPLETE") |
|
|
print("="*70 + "\n") |
|
|
return |
|
|
|
|
|
|
|
|
parquet_path = preembed_path if preembed_path.endswith('.parquet') else f"{preembed_path}.parquet" |
|
|
old_format_data = f"{preembed_path}_data.pkl" |
|
|
|
|
|
if os.path.exists(parquet_path): |
|
|
|
|
|
print("\n" + "="*70) |
|
|
print(f"AUTO-LOADING PRE-EMBEDDED PATIENTS (parquet): {parquet_path}") |
|
|
print("="*70) |
|
|
|
|
|
status, preview = load_preembedded_patients(parquet_path) |
|
|
state.auto_load_status["patients"] = status |
|
|
state.patient_preview_df = preview |
|
|
|
|
|
print("="*70) |
|
|
print("PRE-EMBEDDED PATIENTS AUTO-LOADING COMPLETE") |
|
|
print("="*70 + "\n") |
|
|
return |
|
|
elif os.path.exists(old_format_data): |
|
|
|
|
|
print("\n" + "="*70) |
|
|
print(f"AUTO-LOADING PRE-EMBEDDED PATIENTS (legacy): {preembed_path}") |
|
|
print("="*70) |
|
|
|
|
|
status, preview = load_preembedded_patients(preembed_path) |
|
|
state.auto_load_status["patients"] = status |
|
|
state.patient_preview_df = preview |
|
|
|
|
|
print("="*70) |
|
|
print("PRE-EMBEDDED PATIENTS AUTO-LOADING COMPLETE") |
|
|
print("="*70 + "\n") |
|
|
return |
|
|
else: |
|
|
print(f"✗ Pre-embedded patient files not found: {preembed_path}") |
|
|
state.auto_load_status["patients"] = f"✗ Pre-embedded files not found: {preembed_path}" |
|
|
return |
|
|
|
|
|
|
|
|
if not hasattr(config, 'DEFAULT_PATIENT_DB') or not config.DEFAULT_PATIENT_DB: |
|
|
print("○ No patient database specified in config") |
|
|
return |
|
|
|
|
|
if not os.path.exists(config.DEFAULT_PATIENT_DB): |
|
|
print(f"✗ Default patient database not found: {config.DEFAULT_PATIENT_DB}") |
|
|
state.auto_load_status["patients"] = f"✗ Patient database file not found: {config.DEFAULT_PATIENT_DB}" |
|
|
return |
|
|
|
|
|
if state.embedder_model is None: |
|
|
print("○ Embedder not loaded yet - skipping patient database auto-load") |
|
|
state.auto_load_status["patients"] = "○ Waiting for embedder model to be loaded..." |
|
|
return |
|
|
|
|
|
print("\n" + "="*70) |
|
|
print(f"AUTO-LOADING PATIENT DATABASE: {config.DEFAULT_PATIENT_DB}") |
|
|
print("="*70) |
|
|
|
|
|
class FilePath: |
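        """Minimal stand-in for Gradio's uploaded-file object; only .name is read."""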
|
|
def __init__(self, path): |
|
|
self.name = path |
|
|
|
|
|
status, preview = load_and_embed_patients(FilePath(config.DEFAULT_PATIENT_DB), show_progress=True) |
|
|
state.auto_load_status["patients"] = status |
|
|
state.patient_preview_df = preview |
|
|
|
|
|
print("="*70) |
|
|
print("PATIENT DATABASE AUTO-LOADING COMPLETE") |
|
|
print("="*70 + "\n") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_embedder_model(model_path: str) -> Tuple[str, str, str]: |
|
|
"""Load sentence transformer embedder model.""" |
|
|
try: |
|
|
will_need_reembed = state.patient_df is not None and len(state.patient_df) > 0 |
|
|
|
|
|
if will_need_reembed: |
|
|
warning_msg = f"\n⚠️ Warning: {len(state.patient_df)} patients are currently loaded. They will need to be re-embedded with the new model." |
|
|
else: |
|
|
warning_msg = "" |
|
|
|
|
|
state.embedder_model = SentenceTransformer(model_path, device=state.device, trust_remote_code=True) |
|
|
state.embedder_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) |
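
        # Some embedders (e.g., Qwen3-Embedding) accept a named "query" prompt;
        # register the retrieval instruction if this model exposes a prompts dict.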
|
|
|
|
|
|
|
|
try: |
|
|
state.embedder_model.prompts['query'] = ( |
|
|
"Instruct: Given a cancer patient summary, retrieve clinical trial options " |
|
|
"that are reasonable for that patient; or, given a clinical trial option, " |
|
|
"retrieve cancer patients who are reasonable candidates for that trial." |
|
|
) |
|
|
        except Exception:
|
|
pass |
|
|
|
|
|
try: |
|
|
state.embedder_model.max_seq_length = MAX_EMBEDDER_SEQ_LEN |
|
|
        except Exception:
|
|
pass |
|
|
|
|
|
success_msg = f"✓ Embedder model loaded from {model_path}{warning_msg}" |
|
|
|
|
|
if will_need_reembed: |
|
|
state.patient_embeddings = None |
|
|
success_msg += "\n→ Patient embeddings cleared. Please reload patient database to re-embed." |
|
|
|
|
|
return success_msg, "", warning_msg |
|
|
except Exception as e: |
|
|
return f"✗ Error loading embedder model: {str(e)}", str(e), "" |
|
|
|
|
|
|
|
|
def load_trial_checker(model_path: str) -> Tuple[str, str]: |
|
|
"""Load ModernBERT trial checker.""" |
|
|
try: |
|
|
state.trial_checker_tokenizer = AutoTokenizer.from_pretrained(model_path) |
|
|
state.trial_checker_model = AutoModelForSequenceClassification.from_pretrained( |
|
|
model_path, |
|
|
torch_dtype=torch.float16 if state.device == "cuda" else torch.float32 |
|
|
).to(state.device) |
|
|
state.trial_checker_model.eval() |
|
|
return f"✓ Trial checker loaded from {model_path}", "" |
|
|
except Exception as e: |
|
|
return f"✗ Error loading trial checker: {str(e)}", str(e) |
|
|
|
|
|
|
|
|
def load_boilerplate_checker(model_path: str) -> Tuple[str, str]: |
|
|
"""Load ModernBERT boilerplate checker.""" |
|
|
try: |
|
|
state.boilerplate_checker_tokenizer = AutoTokenizer.from_pretrained(model_path) |
|
|
state.boilerplate_checker_model = AutoModelForSequenceClassification.from_pretrained( |
|
|
model_path, |
|
|
torch_dtype=torch.float16 if state.device == "cuda" else torch.float32 |
|
|
).to(state.device) |
|
|
state.boilerplate_checker_model.eval() |
|
|
return f"✓ Boilerplate checker loaded from {model_path}", "" |
|
|
except Exception as e: |
|
|
return f"✗ Error loading boilerplate checker: {str(e)}", str(e) |
|
|
|
|
|
|
|
|
def load_llm_model(model_path: str) -> Tuple[str, str]: |
|
|
"""Load LLM for deep screen reasoning.""" |
|
|
try: |
|
|
if HAS_VLLM: |
|
|
|
|
|
gpu_count = torch.cuda.device_count() |
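            # Cap tensor parallelism at 4 GPUs; single-GPU machines use tp=1.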
|
|
tp_size = min(gpu_count, 4) if gpu_count > 1 else 1 |
|
|
|
|
|
state.llm_model = LLM( |
|
|
model=model_path, |
|
|
tensor_parallel_size=tp_size, |
|
|
gpu_memory_utilization=0.60, |
|
|
max_model_len=15000, |
|
|
) |
|
|
state.llm_tokenizer = state.llm_model.get_tokenizer() |
|
|
return f"✓ LLM loaded from {model_path} (vLLM, tp={tp_size})", "" |
|
|
else: |
|
|
|
|
|
from transformers import AutoModelForCausalLM |
|
|
state.llm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) |
|
|
state.llm_model = AutoModelForCausalLM.from_pretrained( |
|
|
model_path, |
|
|
torch_dtype=torch.float16 if state.device == "cuda" else torch.float32, |
|
|
device_map="auto", |
|
|
trust_remote_code=True |
|
|
) |
|
|
return f"✓ LLM loaded from {model_path} (HuggingFace)", "" |
|
|
except Exception as e: |
|
|
return f"✗ Error loading LLM: {str(e)}", str(e) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_preembedded_patients(preembedded_path: str) -> Tuple[str, pd.DataFrame]: |
|
|
"""Load pre-embedded patient database from disk. |
|
|
|
|
|
Supports two formats: |
|
|
1. New format: Single parquet file with patient_embedding column |
|
|
- Path should end with .parquet |
|
|
- Embeddings stored as lists in patient_embedding column |
|
|
- Metadata stored in parquet file metadata |
|
|
|
|
|
2. Legacy format: Separate pkl/npy/json files |
|
|
- Path is a prefix (e.g., "patient_embeddings") |
|
|
- Creates patient_embeddings_data.pkl, _vectors.npy, _metadata.json |
|
|
""" |
|
|
try: |
|
|
|
|
|
        # Treat the path as parquet if it ends in .parquet or a sibling .parquet file exists.
        is_parquet = preembedded_path.endswith('.parquet') or os.path.exists(f"{preembedded_path}.parquet")
|
|
|
|
|
if is_parquet: |
|
|
return _load_preembedded_parquet(preembedded_path) |
|
|
else: |
|
|
return _load_preembedded_legacy(preembedded_path) |
|
|
|
|
|
except Exception as e: |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
return f"✗ Error loading pre-embedded patients: {str(e)}", None |
|
|
|
|
|
|
|
|
def _load_preembedded_parquet(parquet_path: str) -> Tuple[str, pd.DataFrame]: |
|
|
"""Load pre-embedded patients from new single parquet format.""" |
|
|
is_url = parquet_path.startswith("http://") or parquet_path.startswith("https://") |
|
|
|
|
|
|
|
|
if not is_url and not parquet_path.endswith('.parquet'): |
|
|
parquet_path = f"{parquet_path}.parquet" |
|
|
|
|
|
if not is_url and not os.path.exists(parquet_path): |
|
|
return f"✗ Pre-embedded parquet file not found: {parquet_path}", None |
|
|
|
|
|
print(f"\n{'='*70}") |
|
|
print(f"LOADING PRE-EMBEDDED PATIENTS (Parquet Format)") |
|
|
print(f"{'='*70}") |
|
|
print(f"Loading from: {parquet_path}") |
|
|
|
|
|
try: |
|
|
|
|
|
if is_url: |
|
|
df = pd.read_parquet(parquet_path) |
|
|
|
|
|
|
|
|
print(f"Metadata: (Skipped for URL)") |
|
|
else: |
|
|
|
|
|
parquet_file = pq.read_table(parquet_path) |
|
|
|
|
|
|
|
|
if parquet_file.schema.metadata and b'patient_embedding_metadata' in parquet_file.schema.metadata: |
|
|
metadata = json.loads(parquet_file.schema.metadata[b'patient_embedding_metadata'].decode('utf-8')) |
|
|
print(f"Metadata:") |
|
|
print(f" Created: {metadata.get('created_at', 'unknown')}") |
|
|
print(f" Embedder: {metadata.get('embedder_model', 'unknown')}") |
|
|
print(f" Patients: {metadata.get('num_patients', 'unknown')}") |
|
|
print(f" Embedding dim: {metadata.get('embedding_dim', 'unknown')}") |
|
|
|
|
|
|
|
|
df = parquet_file.to_pandas() |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"✗ Failed to read parquet file from {parquet_path}: {str(e)}" |
|
|
print(error_msg) |
|
|
return error_msg, None |
|
|
|
|
|
print(f"✓ Loaded {len(df)} patients") |
|
|
print(f" Columns: {', '.join(df.columns.tolist())}") |
|
|
|
|
|
|
|
|
if 'patient_embedding' not in df.columns: |
|
|
return f"✗ Parquet file missing 'patient_embedding' column: {parquet_path}", None |
|
|
|
|
|
if 'patient_id' not in df.columns: |
|
|
return f"✗ Parquet file missing 'patient_id' column: {parquet_path}", None |
|
|
|
|
|
if 'patient_summary' not in df.columns: |
|
|
return f"✗ Parquet file missing 'patient_summary' column: {parquet_path}", None |
|
|
|
|
|
|
|
|
if 'patient_boilerplate' in df.columns: |
|
|
non_empty_bp = (df['patient_boilerplate'].astype(str).str.strip().str.len() > 0).sum() |
|
|
print(f" ✓ patient_boilerplate column: {non_empty_bp}/{len(df)} patients have boilerplate text") |
|
|
else: |
|
|
print(f" ⚠ No patient_boilerplate column found") |
|
|
df['patient_boilerplate'] = '' |
|
|
|
|
|
|
|
|
print(f"Converting embeddings to numpy array...") |
|
|
embeddings = np.array(df['patient_embedding'].tolist(), dtype=np.float32) |
|
|
print(f"✓ Loaded embeddings: {embeddings.shape}") |
|
|
|
|
|
|
|
|
df_without_embeddings = df.drop(columns=['patient_embedding']) |
|
|
|
|
|
state.patient_df = df_without_embeddings |
|
|
state.patient_embeddings = embeddings |
|
|
|
|
|
print(f"{'='*70}") |
|
|
print(f"PRE-EMBEDDED PATIENTS LOADED SUCCESSFULLY") |
|
|
print(f"{'='*70}\n") |
|
|
|
|
|
preview = df_without_embeddings[['patient_id', 'patient_summary']].head(10) |
|
|
return f"✓ Loaded {len(df)} pre-embedded patients from {os.path.basename(parquet_path)}", preview |
|
|
|
|
|
|
|
|
def _load_preembedded_legacy(preembedded_prefix: str) -> Tuple[str, pd.DataFrame]: |
|
|
"""Load pre-embedded patients from legacy format (pkl + npy + json files).""" |
|
|
data_file = f"{preembedded_prefix}_data.pkl" |
|
|
vectors_file = f"{preembedded_prefix}_vectors.npy" |
|
|
metadata_file = f"{preembedded_prefix}_metadata.json" |
|
|
|
|
|
if not os.path.exists(data_file): |
|
|
return f"✗ Pre-embedded data file not found: {data_file}", None |
|
|
if not os.path.exists(vectors_file): |
|
|
return f"✗ Pre-embedded vectors file not found: {vectors_file}", None |
|
|
|
|
|
print(f"\n{'='*70}") |
|
|
print(f"LOADING PRE-EMBEDDED PATIENTS (Legacy Format)") |
|
|
print(f"{'='*70}") |
|
|
print(f"Loading from: {preembedded_prefix}_*") |
|
|
|
|
|
if os.path.exists(metadata_file): |
|
|
with open(metadata_file, 'r') as f: |
|
|
metadata = json.load(f) |
|
|
print(f"Metadata:") |
|
|
print(f" Created: {metadata.get('created_at', 'unknown')}") |
|
|
print(f" Embedder: {metadata.get('embedder_model', 'unknown')}") |
|
|
print(f" Patients: {metadata.get('num_patients', 'unknown')}") |
|
|
print(f" Embedding dim: {metadata.get('embedding_dim', 'unknown')}") |
|
|
|
|
|
print(f"Loading patient dataframe...") |
|
|
with open(data_file, 'rb') as f: |
|
|
df = pickle.load(f) |
|
|
print(f"✓ Loaded {len(df)} patients") |
|
|
print(f" Columns: {', '.join(df.columns.tolist())}") |
|
|
|
|
|
|
|
|
if 'patient_boilerplate' in df.columns: |
|
|
non_empty_bp = (df['patient_boilerplate'].astype(str).str.strip().str.len() > 0).sum() |
|
|
print(f" ✓ patient_boilerplate column: {non_empty_bp}/{len(df)} patients have boilerplate text") |
|
|
else: |
|
|
print(f" ⚠ No patient_boilerplate column found") |
|
|
df['patient_boilerplate'] = '' |
|
|
|
|
|
print(f"Loading embeddings...") |
|
|
embeddings = np.load(vectors_file) |
|
|
print(f"✓ Loaded embeddings: {embeddings.shape}") |
|
|
|
|
|
if len(df) != embeddings.shape[0]: |
|
|
return ( |
|
|
f"✗ Mismatch: {len(df)} patients but {embeddings.shape[0]} embeddings", |
|
|
None |
|
|
) |
|
|
|
|
|
state.patient_df = df |
|
|
state.patient_embeddings = embeddings |
|
|
|
|
|
print(f"{'='*70}") |
|
|
print(f"PRE-EMBEDDED PATIENTS LOADED SUCCESSFULLY") |
|
|
print(f"{'='*70}\n") |
|
|
|
|
|
preview = df[['patient_id', 'patient_summary']].head(10) |
|
|
return f"✓ Loaded {len(df)} pre-embedded patients from {preembedded_prefix}_*", preview |
|
|
|
|
|
|
|
|
def load_and_embed_patients(file, show_progress: bool = False) -> Tuple[str, pd.DataFrame]: |
|
|
"""Load patient database and embed summaries.""" |
|
|
try: |
|
|
if state.embedder_model is None: |
|
|
return "✗ Please load the embedder model first!", None |
|
|
|
|
|
|
|
|
if file.name.endswith('.parquet'): |
|
|
df = pd.read_parquet(file.name) |
|
|
elif file.name.endswith('.csv'): |
|
|
df = pd.read_csv(file.name) |
|
|
elif file.name.endswith(('.xlsx', '.xls')): |
|
|
df = pd.read_excel(file.name) |
|
|
else: |
|
|
return "✗ Unsupported format. Use Parquet, CSV, or Excel.", None |
|
|
|
|
|
|
|
|
required_cols = ['patient_id', 'patient_summary'] |
|
|
missing = [col for col in required_cols if col not in df.columns] |
|
|
if missing: |
|
|
return f"✗ Missing columns: {', '.join(missing)}", None |
|
|
|
|
|
|
|
|
df = df[~df['patient_summary'].isnull()].copy() |
|
|
df = df[df['patient_summary'].astype(str).str.strip().str.len() > 0].copy() |
|
|
|
|
|
if 'patient_boilerplate' not in df.columns: |
|
|
df['patient_boilerplate'] = '' |
|
|
else: |
|
|
df['patient_boilerplate'] = df['patient_boilerplate'].fillna('') |
|
|
|
|
|
|
|
|
df['patient_summary_trunc'] = df['patient_summary'].apply( |
|
|
lambda x: truncate_text(str(x), state.embedder_tokenizer, max_tokens=1500) |
|
|
) |
|
|
|
|
|
prefix = ( |
|
|
"Instruct: Given a cancer patient summary, retrieve clinical trial options " |
|
|
"that are reasonable for that patient; or, given a clinical trial option, " |
|
|
"retrieve cancer patients who are reasonable candidates for that trial. " |
|
|
) |
|
|
texts_to_embed = [prefix + txt for txt in df['patient_summary_trunc'].tolist()] |
|
|
|
|
|
if not show_progress: |
|
|
gr.Info(f"Embedding {len(df)} patient summaries...") |
|
|
else: |
|
|
print(f"Embedding {len(df)} patient summaries...") |
|
|
|
|
|
with torch.no_grad(): |
|
|
embeddings = state.embedder_model.encode( |
|
|
texts_to_embed, |
|
|
batch_size=64, |
|
|
convert_to_tensor=True, |
|
|
normalize_embeddings=True, |
|
|
show_progress_bar=show_progress, |
|
|
prompt='query' |
|
|
) |
|
|
|
|
|
state.patient_df = df |
|
|
state.patient_embeddings = embeddings.cpu().numpy() |
|
|
|
|
|
preview = df[['patient_id', 'patient_summary']].head(10) |
|
|
|
|
|
success_msg = f"✓ Loaded and embedded {len(df)} patients" |
|
|
if show_progress: |
|
|
print(success_msg) |
|
|
|
|
|
return success_msg, preview |
|
|
|
|
|
except Exception as e: |
|
|
return f"✗ Error processing patients: {str(e)}", None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def match_patients( |
|
|
clinical_space: str, |
|
|
boilerplate_criteria: str, |
|
|
top_k_check: int = 1000, |
|
|
eligibility_threshold: float = 0.5 |
|
|
) -> Tuple[pd.DataFrame, str]: |
|
|
"""Match clinical query to patients and run eligibility checks.""" |
|
|
try: |
|
|
if state.embedder_model is None: |
|
|
raise ValueError("Embedder model not loaded") |
|
|
if state.patient_embeddings is None: |
|
|
raise ValueError("Patient database not loaded") |
|
|
if state.trial_checker_model is None: |
|
|
raise ValueError("Trial checker model not loaded") |
|
|
if state.boilerplate_checker_model is None: |
|
|
raise ValueError("Boilerplate checker model not loaded") |
|
|
|
|
|
if not clinical_space or not clinical_space.strip(): |
|
|
raise ValueError("Please enter clinical criteria") |
|
|
|
|
|
|
|
|
prefix = ( |
|
|
"Instruct: Given a cancer patient summary, retrieve clinical trial options " |
|
|
"that are reasonable for that patient; or, given a clinical trial option, " |
|
|
"retrieve cancer patients who are reasonable candidates for that trial. " |
|
|
) |
|
|
|
|
|
query_text = truncate_text(clinical_space, state.embedder_tokenizer, max_tokens=MAX_EMBEDDER_SEQ_LEN) |
|
|
query_text_with_prefix = prefix + query_text |
|
|
|
|
|
gr.Info("Ranking all patients by similarity...") |
|
|
|
|
|
with torch.no_grad(): |
|
|
query_emb = state.embedder_model.encode( |
|
|
[query_text_with_prefix], |
|
|
convert_to_tensor=True, |
|
|
normalize_embeddings=True, |
|
|
prompt='query' |
|
|
) |
|
|
|
|
|
|
|
|
query_emb_np = query_emb.cpu().numpy() |
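        # Embeddings are L2-normalized, so the dot product below is cosine similarity.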
|
|
similarities = np.dot(state.patient_embeddings, query_emb_np.T).squeeze() |
|
|
|
|
|
|
|
|
sorted_indices = np.argsort(similarities)[::-1] |
|
|
|
|
|
|
|
|
all_patients_ranked = state.patient_df.iloc[sorted_indices].copy() |
|
|
all_patients_ranked['similarity_score'] = similarities[sorted_indices] |
|
|
|
|
|
|
|
|
top_k_check = min(top_k_check, len(all_patients_ranked)) |
|
|
patients_to_check = all_patients_ranked.head(top_k_check).copy() |
|
|
|
|
|
gr.Info(f"Running eligibility checks on top {len(patients_to_check)} patients...") |
|
|
|
|
|
|
|
|
trial_check_inputs = [ |
|
|
f"{clinical_space}\nNow here is the patient summary:{row['patient_summary']}" |
|
|
for _, row in patients_to_check.iterrows() |
|
|
] |
|
|
|
|
|
trial_probs_list = [] |
|
|
for i in range(0, len(trial_check_inputs), CLASSIFIER_BATCH_SIZE): |
|
|
batch_inputs = trial_check_inputs[i:i + CLASSIFIER_BATCH_SIZE] |
|
|
|
|
|
batch_encodings = state.trial_checker_tokenizer( |
|
|
batch_inputs, |
|
|
truncation=True, |
|
|
max_length=MAX_TRIAL_CHECKER_LENGTH, |
|
|
padding=True, |
|
|
return_tensors='pt' |
|
|
).to(state.device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
batch_outputs = state.trial_checker_model(**batch_encodings) |
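                # Softmax over the two logits; column 1 is P(eligible).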
|
|
batch_probs = torch.softmax(batch_outputs.logits, dim=1)[:, 1].cpu().numpy() |
|
|
trial_probs_list.append(batch_probs) |
|
|
|
|
|
trial_probs = np.concatenate(trial_probs_list) |
|
|
patients_to_check['eligibility_probability'] = trial_probs |
|
|
|
|
|
|
|
|
|
|
|
def get_boilerplate_text(row): |
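            """Prefer the dedicated boilerplate extract; fall back to the full summary."""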
|
|
bp = row.get('patient_boilerplate', '') |
|
|
if bp and isinstance(bp, str) and bp.strip(): |
|
|
return bp |
|
|
return row['patient_summary'] |
|
|
|
|
|
boilerplate_check_inputs = [ |
|
|
f"Patient history: {get_boilerplate_text(row)}\nTrial exclusions:{boilerplate_criteria}" |
|
|
for _, row in patients_to_check.iterrows() |
|
|
] |
|
|
|
|
|
boilerplate_probs_list = [] |
|
|
for i in range(0, len(boilerplate_check_inputs), CLASSIFIER_BATCH_SIZE): |
|
|
batch_inputs = boilerplate_check_inputs[i:i + CLASSIFIER_BATCH_SIZE] |
|
|
|
|
|
batch_encodings = state.boilerplate_checker_tokenizer( |
|
|
batch_inputs, |
|
|
truncation=True, |
|
|
max_length=MAX_BOILERPLATE_CHECKER_LENGTH, |
|
|
padding=True, |
|
|
return_tensors='pt' |
|
|
).to(state.device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
batch_outputs = state.boilerplate_checker_model(**batch_encodings) |
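                # Column 1 here is P(excluded) under the boilerplate criteria.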
|
|
batch_probs = torch.softmax(batch_outputs.logits, dim=1)[:, 1].cpu().numpy() |
|
|
boilerplate_probs_list.append(batch_probs) |
|
|
|
|
|
boilerplate_probs = np.concatenate(boilerplate_probs_list) |
|
|
patients_to_check['exclusion_probability'] = boilerplate_probs |
|
|
|
|
|
|
|
|
patients_to_check = patients_to_check.sort_values('eligibility_probability', ascending=False) |
|
|
|
|
|
|
|
|
state.last_results_df = patients_to_check.copy() |
|
|
|
|
|
|
|
|
num_eligible = (patients_to_check['eligibility_probability'] >= eligibility_threshold).sum() |
|
|
num_no_exclusion = (patients_to_check['exclusion_probability'] < 0.5).sum() |
|
|
num_both = ((patients_to_check['eligibility_probability'] >= eligibility_threshold) & |
|
|
(patients_to_check['exclusion_probability'] < 0.5)).sum() |
|
|
|
|
|
bottom_line = f""" |
|
|
### 📊 Summary: Patients Meeting Your Criteria |
|
|
| Metric | Count | |
|
|
|--------|-------| |
|
|
| Total patients in database | **{len(state.patient_df)}** | |
|
|
| Top patients checked with classifiers | **{len(patients_to_check)}** | |
|
|
| Meeting eligibility criteria (≥{eligibility_threshold}) | **{num_eligible}** | |
|
|
| Without boilerplate exclusions (<0.5) | **{num_no_exclusion}** | |
|
|
| **Meeting BOTH criteria** | **{num_both}** | |
|
|
""" |
|
|
|
|
|
|
|
|
patients_to_check['eligibility_display'] = patients_to_check['eligibility_probability'].apply( |
|
|
lambda x: format_probability_visual(x, is_exclusion=False) |
|
|
) |
|
|
patients_to_check['exclusion_display'] = patients_to_check['exclusion_probability'].apply( |
|
|
lambda x: format_probability_visual(x, is_exclusion=True) |
|
|
) |
|
|
patients_to_check['similarity_display'] = patients_to_check['similarity_score'].apply( |
|
|
lambda x: f"{x:.3f}" |
|
|
) |
|
|
|
|
|
|
|
|
patients_to_check['summary_preview'] = patients_to_check['patient_summary'].apply( |
|
|
            lambda x: (str(x)[:300] + "...") if len(str(x)) > 300 else str(x)
|
|
) |
|
|
|
|
|
|
|
|
display_cols = [ |
|
|
'patient_id', |
|
|
'eligibility_display', |
|
|
'exclusion_display', |
|
|
'similarity_display', |
|
|
'summary_preview' |
|
|
] |
|
|
|
|
|
result_df = patients_to_check[display_cols].reset_index(drop=True) |
|
|
result_df.columns = [ |
|
|
'Patient ID', |
|
|
'Eligibility', |
|
|
'Exclusion', |
|
|
'Similarity', |
|
|
'Summary Preview' |
|
|
] |
|
|
|
|
|
return result_df, bottom_line |
|
|
|
|
|
except Exception as e: |
|
|
gr.Error(f"Error matching patients: {str(e)}") |
|
|
return pd.DataFrame(), f"**Error:** {str(e)}" |
|
|
|
|
|
|
|
|
def get_patient_details(df: pd.DataFrame, evt: gr.SelectData) -> str: |
|
|
"""Get full patient details when user clicks on a row.""" |
|
|
try: |
|
|
if df is None or len(df) == 0: |
|
|
return "No patient selected" |
|
|
|
|
|
row_idx = evt.index[0] |
|
|
patient_id = df.iloc[row_idx]['Patient ID'] |
|
|
|
|
|
|
|
|
if state.last_results_df is None: |
|
|
return "No results available" |
|
|
|
|
|
matching_rows = state.last_results_df[ |
|
|
state.last_results_df['patient_id'] == patient_id |
|
|
] |
|
|
|
|
|
if len(matching_rows) == 0: |
|
|
return f"Error: Could not find patient {patient_id}" |
|
|
|
|
|
patient_row = matching_rows.iloc[0] |
|
|
|
|
|
|
|
|
raw_boilerplate = patient_row.get('patient_boilerplate', '') |
|
|
has_separate_boilerplate = raw_boilerplate and isinstance(raw_boilerplate, str) and raw_boilerplate.strip() |
|
|
|
|
|
if has_separate_boilerplate: |
|
|
boilerplate_text = raw_boilerplate |
|
|
else: |
|
|
boilerplate_text = "(No separate boilerplate column - patient summary was used for boilerplate exclusion check)" |
|
|
|
|
|
|
|
|
summary_escaped = html.escape(str(patient_row['patient_summary'])) |
|
|
boilerplate_escaped = html.escape(str(boilerplate_text)) |
|
|
|
|
|
details = f""" |
|
|
# Patient Details: {patient_id} |
|
|
|
|
|
--- |
|
|
|
|
|
## Scores |
|
|
- **Eligibility Probability:** {patient_row['eligibility_probability']:.3f} |
|
|
- **Exclusion Probability:** {patient_row['exclusion_probability']:.3f} |
|
|
- **Similarity Score:** {patient_row['similarity_score']:.3f} |
|
|
|
|
|
--- |
|
|
|
|
|
## Full Patient Summary |
|
|
<pre style="white-space: pre-wrap; word-wrap: break-word; background-color: #1a1a1a; color: #ffffff; padding: 10px; border-radius: 5px; font-family: monospace; font-size: 0.9em;">{summary_escaped}</pre> |
|
|
|
|
|
--- |
|
|
|
|
|
## Boilerplate Exclusion Check Input |
|
|
<pre style="white-space: pre-wrap; word-wrap: break-word; background-color: #1a1a1a; color: #ffffff; padding: 10px; border-radius: 5px; font-family: monospace; font-size: 0.9em;">{boilerplate_escaped}</pre> |
|
|
""" |
|
|
return details |
|
|
|
|
|
except Exception as e: |
|
|
return f"Error retrieving patient details: {str(e)}" |
|
|
|
|
|
|
|
|
def request_identified_patients(): |
|
|
"""Placeholder for requesting identified patient list.""" |
|
|
if state.last_results_df is None or len(state.last_results_df) == 0: |
|
|
gr.Warning("No results to request - run a search first") |
|
|
return |
|
|
|
|
|
|
|
|
gr.Info("Request functionality not yet implemented") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def toggle_reasoning_display(raw_text: str, show: bool) -> str: |
|
|
"""Helper to toggle reasoning display.""" |
|
|
if not raw_text: |
|
|
return "" |
|
|
if show: |
|
|
return raw_text.strip() |
|
|
|
|
|
if REASONING_MARKER in raw_text: |
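        # Hide the chain of thought: keep only the text after the final-answer marker.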
|
|
return raw_text.split(REASONING_MARKER, 1)[-1].strip() |
|
|
|
|
|
return raw_text.strip() |
|
|
|
|
|
|
|
|
def _run_llm_inference(messages: List[dict], temperature: float = 0.0, top_p: float = 1.0) -> str: |
|
|
"""Helper to run inference with loaded LLM.""" |
|
|
if hasattr(state.llm_model, 'generate') and hasattr(state.llm_model, 'get_tokenizer'): |
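        # vLLM path: only vLLM's LLM object exposes both generate() and get_tokenizer().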
|
|
|
|
|
prompt = state.llm_tokenizer.apply_chat_template( |
|
|
conversation=messages, |
|
|
add_generation_prompt=True, |
|
|
tokenize=False |
|
|
) |
|
|
|
|
|
response = state.llm_model.generate( |
|
|
[prompt], |
|
|
SamplingParams( |
|
|
temperature=temperature, |
|
|
top_p=top_p if temperature > 0 else 1.0, |
|
|
max_tokens=7500, |
|
|
repetition_penalty=1.2 |
|
|
) |
|
|
) |
|
|
return response[0].outputs[0].text |
|
|
else: |
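        # HuggingFace path: apply the chat template to token IDs, then model.generate().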
|
|
|
|
|
input_ids = state.llm_tokenizer.apply_chat_template( |
|
|
conversation=messages, |
|
|
add_generation_prompt=True, |
|
|
return_tensors="pt" |
|
|
).to(state.device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = state.llm_model.generate( |
|
|
input_ids, |
|
|
max_new_tokens=7500, |
|
|
temperature=temperature, |
|
|
top_p=top_p if temperature > 0 else 1.0, |
|
|
do_sample=(temperature > 0), |
|
|
repetition_penalty=1.2 |
|
|
) |
|
|
        # Decode only the newly generated tokens, not the echoed prompt.
        return state.llm_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
|
|
|
|
|
|
|
|
def run_deeper_screen( |
|
|
patient_data: dict, |
|
|
clinical_space: str, |
|
|
boilerplate_criteria: str, |
|
|
show_reasoning: bool, |
|
|
progress=gr.Progress(track_tqdm=True) |
|
|
) -> Tuple[str, str, dict]: |
|
|
"""Run deeper screen using LLM for patient-criteria and boilerplate checks.""" |
|
|
if not patient_data: |
|
|
return "Please select a patient first.", "Please select a patient first.", {} |
|
|
|
|
|
if state.llm_model is None: |
|
|
return "Please load LLM model first.", "Please load LLM model first.", {} |
|
|
|
|
|
patient_summary = patient_data.get("patient_summary", "") |
|
|
patient_boilerplate = patient_data.get("patient_boilerplate", "") |
|
|
|
|
|
if not patient_summary: |
|
|
return "Missing patient summary.", "Missing patient summary.", {} |
|
|
|
|
|
if not clinical_space: |
|
|
return "Missing clinical criteria.", "Missing clinical criteria.", {} |
|
|
|
|
|
try: |
|
|
progress(0, desc="Starting deeper screen...") |
|
|
gr.Info("Starting deeper screen analysis... This may take 1-2 minutes.") |
|
|
|
|
|
|
|
|
trial_msg = [ |
|
|
{"role": "system", "content": "Reasoning: high."}, |
|
|
{"role": "user", "content": DEEPER_SCREEN_TRIAL_PROMPT.format( |
|
|
trial_space=clinical_space, |
|
|
patient_summary=patient_summary |
|
|
)} |
|
|
] |
|
|
|
|
|
print("Running Deeper Screen: Checking Criteria Match...") |
|
|
progress(0.2, desc="Checking Criteria Match...") |
|
|
gr.Info("Analyzing eligibility criteria match...") |
|
|
|
|
|
|
|
|
trial_raw = _run_llm_inference(trial_msg, temperature=0.5, top_p=0.9) |
|
|
|
|
|
|
|
|
bp_msg = [ |
|
|
{"role": "system", "content": "Reasoning: high."}, |
|
|
{"role": "user", "content": DEEPER_SCREEN_BOILERPLATE_PROMPT.format( |
|
|
trial_boilerplate=boilerplate_criteria, |
|
|
patient_boilerplate=patient_boilerplate if patient_boilerplate else patient_summary |
|
|
)} |
|
|
] |
|
|
|
|
|
print("Running Deeper Screen: Checking Boilerplate...") |
|
|
progress(0.6, desc="Checking Boilerplate Exclusions...") |
|
|
gr.Info("Analyzing boilerplate exclusion criteria...") |
|
|
|
|
|
|
|
|
bp_raw = _run_llm_inference(bp_msg, temperature=1.0, top_p=0.9) |
|
|
|
|
|
progress(1.0, desc="Deeper screen complete!") |
|
|
gr.Info("Deeper screen analysis complete!") |
|
|
|
|
|
|
|
|
raw_outputs = { |
|
|
"trial_raw": trial_raw, |
|
|
"bp_raw": bp_raw |
|
|
} |
|
|
|
|
|
|
|
|
trial_display = toggle_reasoning_display(trial_raw, show_reasoning) |
|
|
bp_display = toggle_reasoning_display(bp_raw, show_reasoning) |
|
|
|
|
|
return trial_display, bp_display, raw_outputs |
|
|
|
|
|
except Exception as e: |
|
|
err = f"Error in deeper screen: {str(e)}" |
|
|
gr.Error(err) |
|
|
return err, err, {} |
|
|
|
|
|
|
|
|
def update_deeper_screen_display(raw_outputs: dict, show_reasoning: bool) -> Tuple[str, str]: |
|
|
"""Toggle reasoning display for existing deeper screen results.""" |
|
|
if not raw_outputs: |
|
|
return "", "" |
|
|
|
|
|
trial_display = toggle_reasoning_display(raw_outputs.get("trial_raw", ""), show_reasoning) |
|
|
bp_display = toggle_reasoning_display(raw_outputs.get("bp_raw", ""), show_reasoning) |
|
|
|
|
|
return trial_display, bp_display |
|
|
|
|
|
|
|
|
def get_patient_data_for_deep_screen(df: pd.DataFrame, evt: gr.SelectData) -> Tuple[str, dict]: |
|
|
"""Get patient data when user clicks on a row - returns both display and data for deep screen.""" |
|
|
try: |
|
|
if df is None or len(df) == 0: |
|
|
return "No patient selected", {} |
|
|
|
|
|
row_idx = evt.index[0] |
|
|
patient_id = df.iloc[row_idx]['Patient ID'] |
|
|
|
|
|
|
|
|
if state.last_results_df is None: |
|
|
return "No results available", {} |
|
|
|
|
|
matching_rows = state.last_results_df[ |
|
|
state.last_results_df['patient_id'] == patient_id |
|
|
] |
|
|
|
|
|
if len(matching_rows) == 0: |
|
|
return f"Error: Could not find patient {patient_id}", {} |
|
|
|
|
|
patient_row = matching_rows.iloc[0] |
|
|
|
|
|
|
|
|
raw_boilerplate = patient_row.get('patient_boilerplate', '') |
|
|
has_separate_boilerplate = raw_boilerplate and isinstance(raw_boilerplate, str) and raw_boilerplate.strip() |
|
|
|
|
|
if has_separate_boilerplate: |
|
|
boilerplate_text = raw_boilerplate |
|
|
else: |
|
|
boilerplate_text = "(No separate boilerplate column - patient summary was used for boilerplate exclusion check)" |
|
|
|
|
|
|
|
|
summary_escaped = html.escape(str(patient_row['patient_summary'])) |
|
|
boilerplate_escaped = html.escape(str(boilerplate_text)) |
|
|
|
|
|
details = f""" |
|
|
# Patient Details: {patient_id} |
|
|
|
|
|
--- |
|
|
|
|
|
## Scores |
|
|
- **Eligibility Probability:** {patient_row['eligibility_probability']:.3f} |
|
|
- **Exclusion Probability:** {patient_row['exclusion_probability']:.3f} |
|
|
- **Similarity Score:** {patient_row['similarity_score']:.3f} |
|
|
|
|
|
--- |
|
|
|
|
|
## Full Patient Summary |
|
|
<pre style="white-space: pre-wrap; word-wrap: break-word; background-color: #1a1a1a; color: #ffffff; padding: 10px; border-radius: 5px; font-family: monospace; font-size: 0.9em;">{summary_escaped}</pre> |
|
|
|
|
|
--- |
|
|
|
|
|
## Boilerplate Exclusion Check Input |
|
|
<pre style="white-space: pre-wrap; word-wrap: break-word; background-color: #1a1a1a; color: #ffffff; padding: 10px; border-radius: 5px; font-family: monospace; font-size: 0.9em;">{boilerplate_escaped}</pre> |
|
|
""" |
|
|
|
|
|
|
|
|
patient_data = { |
|
|
"patient_id": patient_id, |
|
|
"patient_summary": str(patient_row['patient_summary']), |
|
|
"patient_boilerplate": str(raw_boilerplate) if has_separate_boilerplate else str(patient_row['patient_summary']) |
|
|
} |
|
|
|
|
|
return details, patient_data |
|
|
|
|
|
except Exception as e: |
|
|
return f"Error retrieving patient details: {str(e)}", {} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_interface(): |
|
|
|
|
|
theme = gr.themes.Soft( |
|
|
primary_hue="teal", |
|
|
secondary_hue="slate", |
|
|
).set( |
|
|
body_background_fill="*neutral_50", |
|
|
block_background_fill="white", |
|
|
block_border_width="1px", |
|
|
block_label_background_fill="*primary_50", |
|
|
) |
|
|
|
|
|
custom_css = """ |
|
|
.gradio-container { font-family: 'Inter', Arial, sans-serif !important; } |
|
|
.model-status { min-height: 80px !important; font-size: 0.9em; } |
|
|
.status-box { background: #f9fafb; border: 1px solid #e5e7eb; border-radius: 8px; padding: 10px; } |
|
|
h1 { color: #0d9488; } |
|
|
""" |
|
|
|
|
|
|
|
|
clinical_space_template = getattr(config, 'CLINICAL_SPACE_TEMPLATE', DEFAULT_CLINICAL_SPACE_TEMPLATE) if HAS_CONFIG else DEFAULT_CLINICAL_SPACE_TEMPLATE |
|
|
boilerplate_template = getattr(config, 'BOILERPLATE_TEMPLATE', DEFAULT_BOILERPLATE_TEMPLATE) if HAS_CONFIG else DEFAULT_BOILERPLATE_TEMPLATE |
|
|
|
|
|
with gr.Blocks(title="Patient Search Prototype", theme=theme, css=custom_css) as demo: |
|
|
|
|
|
with gr.Row(variant="panel"): |
|
|
with gr.Column(scale=4): |
|
|
gr.Markdown(""" |
|
|
# 🔬 Patient Search Prototype |
|
|
**Find patients matching clinical criteria. Designed for clinical trial matching.** |
|
|
""") |
|
|
with gr.Column(scale=1): |
|
|
pass |
|
|
|
|
|
with gr.Tabs(): |
|
|
|
|
|
with gr.Tab("1️⃣ Search"): |
|
|
gr.Markdown(""" |
|
|
### Define Your Search Criteria |
|
|
Enter the clinical criteria to search for matching patients. |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
clinical_space_input = gr.Textbox( |
|
|
label="Clinical Criteria", |
|
|
placeholder="Enter eligibility criteria...", |
|
|
value=clinical_space_template, |
|
|
lines=12, |
|
|
info="Define age, sex, cancer type, histology, treatments, biomarkers, etc." |
|
|
) |
|
|
|
|
|
with gr.Column(): |
|
|
boilerplate_input = gr.Textbox( |
|
|
label="Boilerplate Exclusion Criteria", |
|
|
placeholder="Enter boilerplate exclusions...", |
|
|
value=boilerplate_template, |
|
|
lines=12, |
|
|
info="Common exclusions like organ dysfunction, infections, etc." |
|
|
) |
|
|
|
|
|
gr.Markdown("---") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
match_btn = gr.Button("🔍 Find Matching Patients", variant="primary", size="lg") |
|
|
with gr.Column(scale=3): |
|
|
with gr.Accordion("Search Settings", open=False): |
|
|
top_k_check_slider = gr.Slider( |
|
|
minimum=5, maximum=10000, value=500, step=50, |
|
|
label="Patients to Check with Classifiers", |
|
|
info="Number of top-ranked patients to run through eligibility/boilerplate models (larger queries take more time)" |
|
|
) |
|
|
eligibility_threshold_slider = gr.Slider( |
|
|
minimum=0.0, maximum=1.0, value=0.5, step=0.05, |
|
|
label="Eligibility Threshold", |
|
|
info="Threshold for counting patients as 'eligible'" |
|
|
) |
|
|
|
|
|
gr.Markdown("### 📊 Results") |
|
|
|
|
|
|
|
|
bottom_line_output = gr.Markdown( |
|
|
value="*Run a search to see results*" |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=7): |
|
|
results_df = gr.Dataframe( |
|
|
label="Matched Patients", |
|
|
interactive=False, |
|
|
wrap=True, |
|
|
datatype=["str", "markdown", "markdown", "str", "str"], |
|
|
column_widths=["12%", "12%", "12%", "10%", "54%"] |
|
|
) |
|
|
|
|
|
with gr.Column(scale=5): |
|
|
patient_details = gr.Markdown( |
|
|
label="Patient Details", |
|
|
value="<div style='text-align: center; padding: 50px; color: #666;'>👈 Click on a patient row to see full details here</div>" |
|
|
) |
|
|
|
|
|
gr.Markdown("---") |
|
|
gr.Markdown("### Deep Screen (LLM)") |
|
|
|
|
|
with gr.Row(): |
|
|
deeper_screen_btn = gr.Button("Run Deep Screen", variant="secondary") |
|
|
|
|
|
show_reasoning_chk = gr.Checkbox(label="Show Chain of Thought", value=False) |
|
|
|
|
|
with gr.Accordion("Criteria Match Reasoning", open=True): |
|
|
trial_reasoning_output = gr.Textbox(show_label=False, lines=8, placeholder="Run deep screen to see reasoning...") |
|
|
|
|
|
with gr.Accordion("Boilerplate Check Reasoning", open=True): |
|
|
bp_reasoning_output = gr.Textbox(show_label=False, lines=8, placeholder="Run deep screen to see reasoning...") |
|
|
|
|
|
|
|
|
patient_data_state = gr.State({}) |
|
|
deeper_screen_raw_state = gr.State({}) |
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
request_btn = gr.Button("📋 Request Identified Patient List", variant="secondary") |
|
|
|
|
|
|
|
|
match_btn.click( |
|
|
fn=match_patients, |
|
|
inputs=[clinical_space_input, boilerplate_input, top_k_check_slider, eligibility_threshold_slider], |
|
|
outputs=[results_df, bottom_line_output] |
|
|
) |
|
|
|
|
|
results_df.select( |
|
|
fn=get_patient_data_for_deep_screen, |
|
|
inputs=[results_df], |
|
|
outputs=[patient_details, patient_data_state] |
|
|
) |
|
|
|
|
|
deeper_screen_btn.click( |
|
|
fn=run_deeper_screen, |
|
|
inputs=[patient_data_state, clinical_space_input, boilerplate_input, show_reasoning_chk], |
|
|
outputs=[trial_reasoning_output, bp_reasoning_output, deeper_screen_raw_state] |
|
|
) |
|
|
|
|
|
show_reasoning_chk.change( |
|
|
fn=update_deeper_screen_display, |
|
|
inputs=[deeper_screen_raw_state, show_reasoning_chk], |
|
|
outputs=[trial_reasoning_output, bp_reasoning_output] |
|
|
) |
|
|
|
|
|
request_btn.click( |
|
|
fn=request_identified_patients, |
|
|
inputs=[], |
|
|
outputs=[] |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Tab("2️⃣ Patient Database"): |
|
|
gr.Markdown("### 📊 Patient Database Management") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
gr.Markdown("#### Load Pre-embedded Patients (Fast)") |
|
|
preembed_prefix = gr.Textbox( |
|
|
label="Pre-embedded Prefix", |
|
|
placeholder="patient_embeddings", |
|
|
                            value=(getattr(config, 'PREEMBEDDED_PATIENTS', '') or "") if HAS_CONFIG else ""
|
|
) |
|
|
preembed_btn = gr.Button("Load Pre-embedded", variant="secondary") |
|
|
|
|
|
with gr.Column(): |
|
|
gr.Markdown("#### Upload & Embed New Database") |
|
|
patient_file = gr.File( |
|
|
label="Upload Patient Database (Parquet/CSV/Excel)", |
|
|
file_types=[".parquet", ".csv", ".xlsx", ".xls"] |
|
|
) |
|
|
patient_upload_btn = gr.Button("Process & Embed", variant="secondary") |
|
|
|
|
|
patient_status = gr.Textbox( |
|
|
label="Status", |
|
|
interactive=False, |
|
|
value=state.auto_load_status.get("patients", "No patients loaded") |
|
|
) |
|
|
|
|
|
patient_preview = gr.Dataframe( |
|
|
label="Patient Preview (first 10)", |
|
|
value=state.patient_preview_df, |
|
|
wrap=True |
|
|
) |
|
|
|
|
|
preembed_btn.click( |
|
|
fn=load_preembedded_patients, |
|
|
inputs=[preembed_prefix], |
|
|
outputs=[patient_status, patient_preview] |
|
|
) |
|
|
|
|
|
patient_upload_btn.click( |
|
|
fn=load_and_embed_patients, |
|
|
inputs=[patient_file], |
|
|
outputs=[patient_status, patient_preview] |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Tab("3️⃣ Model Configuration"): |
|
|
gr.Markdown("### 🧠 Model Management") |
|
|
|
|
|
status_msg = """ |
|
|
**Config file detected** - Models will auto-load on startup. |
|
|
""" if HAS_CONFIG else """ |
|
|
**No config file found** - Please load models manually below. |
|
|
""" |
|
|
                gr.Markdown(status_msg)  # static text; gr.Info only works inside event handlers
|
|
|
|
|
with gr.Group(): |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
embedder_input = gr.Textbox( |
|
|
label="Embedder Model", |
|
|
placeholder="Qwen/Qwen3-Embedding-0.6B", |
|
|
value=config.MODEL_CONFIG.get("embedder", "") if HAS_CONFIG else "" |
|
|
) |
|
|
embedder_btn = gr.Button("Load Embedder") |
|
|
embedder_status = gr.Textbox( |
|
|
label="Status", |
|
|
interactive=False, |
|
|
value=state.auto_load_status.get("embedder", ""), |
|
|
elem_classes=["model-status"] |
|
|
) |
|
|
embedder_warning = gr.Textbox(visible=False) |
|
|
|
|
|
with gr.Column(): |
|
|
trial_checker_input = gr.Textbox( |
|
|
label="Trial Checker Model", |
|
|
placeholder="answerdotai/ModernBERT-large", |
|
|
value=config.MODEL_CONFIG.get("trial_checker", "") if HAS_CONFIG else "" |
|
|
) |
|
|
trial_checker_btn = gr.Button("Load Trial Checker") |
|
|
trial_checker_status = gr.Textbox( |
|
|
label="Status", |
|
|
interactive=False, |
|
|
value=state.auto_load_status.get("trial_checker", ""), |
|
|
elem_classes=["model-status"] |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
boilerplate_checker_input = gr.Textbox( |
|
|
label="Boilerplate Checker Model", |
|
|
placeholder="answerdotai/ModernBERT-large", |
|
|
value=config.MODEL_CONFIG.get("boilerplate_checker", "") if HAS_CONFIG else "" |
|
|
) |
|
|
boilerplate_checker_btn = gr.Button("Load Boilerplate Checker") |
|
|
boilerplate_checker_status = gr.Textbox( |
|
|
label="Status", |
|
|
interactive=False, |
|
|
value=state.auto_load_status.get("boilerplate_checker", ""), |
|
|
elem_classes=["model-status"] |
|
|
) |
|
|
with gr.Column(scale=1): |
|
|
llm_input = gr.Textbox( |
|
|
label="LLM Model (Deep Screen)", |
|
|
placeholder="ksg-dfci/OncoReasoning-3B-1225", |
|
|
value=config.MODEL_CONFIG.get("llm", "") if HAS_CONFIG else "" |
|
|
) |
|
|
llm_btn = gr.Button("Load LLM") |
|
|
llm_status = gr.Textbox( |
|
|
label="Status", |
|
|
interactive=False, |
|
|
value=state.auto_load_status.get("llm", ""), |
|
|
elem_classes=["model-status"] |
|
|
) |
|
|
|
|
|
|
|
|
embedder_btn.click( |
|
|
fn=load_embedder_model, |
|
|
inputs=[embedder_input], |
|
|
outputs=[embedder_status, gr.Textbox(visible=False), embedder_warning] |
|
|
) |
|
|
trial_checker_btn.click( |
|
|
fn=load_trial_checker, |
|
|
inputs=[trial_checker_input], |
|
|
outputs=[trial_checker_status, gr.Textbox(visible=False)] |
|
|
) |
|
|
boilerplate_checker_btn.click( |
|
|
fn=load_boilerplate_checker, |
|
|
inputs=[boilerplate_checker_input], |
|
|
outputs=[boilerplate_checker_status, gr.Textbox(visible=False)] |
|
|
) |
|
|
llm_btn.click( |
|
|
fn=load_llm_model, |
|
|
inputs=[llm_input], |
|
|
outputs=[llm_status, gr.Textbox(visible=False)] |
|
|
) |
|
|
|
|
|
return demo |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
print(f"Device: {state.device}") |
|
|
print(f"GPU Available: {torch.cuda.is_available()}") |
|
|
if torch.cuda.is_available(): |
|
|
print(f"GPU Count: {torch.cuda.device_count()}") |
|
|
|
|
|
|
|
|
if HAS_CONFIG: |
|
|
auto_load_models_from_config() |
|
|
|
|
|
|
|
|
if state.embedder_model is not None or (hasattr(config, 'PREEMBEDDED_PATIENTS') and config.PREEMBEDDED_PATIENTS): |
|
|
auto_load_patients_from_config() |
|
|
|
|
|
demo = create_interface() |
|
|
demo.launch( |
|
|
server_name="0.0.0.0", |
|
|
ssr_mode=False, |
|
|
|
|
|
share=False |
|
|
) |
|
|
|