updated app.py
Browse files
app.py
CHANGED
|
@@ -41,36 +41,48 @@ import gradio as gr
|
|
| 41 |
def check_and_install_sam2():
|
| 42 |
"""Check if SAM-2 is available and attempt installation if needed."""
|
| 43 |
try:
|
| 44 |
-
|
| 45 |
from sam2.build_sam import build_sam2
|
| 46 |
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
|
|
|
| 47 |
return True, "SAM-2 already available"
|
| 48 |
-
except ImportError:
|
| 49 |
-
print("SAM-2
|
|
|
|
| 50 |
try:
|
| 51 |
# Clone SAM-2 repository
|
| 52 |
if not os.path.exists("segment-anything-2"):
|
|
|
|
| 53 |
subprocess.run([
|
| 54 |
"git", "clone",
|
| 55 |
"https://github.com/facebookresearch/segment-anything-2.git"
|
| 56 |
], check=True)
|
|
|
|
| 57 |
|
| 58 |
# Install SAM-2
|
|
|
|
| 59 |
original_dir = os.getcwd()
|
| 60 |
os.chdir("segment-anything-2")
|
| 61 |
subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
|
| 62 |
os.chdir(original_dir)
|
|
|
|
| 63 |
|
| 64 |
# Add to Python path
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
# Try importing again
|
|
|
|
| 68 |
from sam2.build_sam import build_sam2
|
| 69 |
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
|
|
|
| 70 |
return True, "SAM-2 installed successfully"
|
| 71 |
|
| 72 |
except Exception as e:
|
| 73 |
-
print(f"
|
|
|
|
| 74 |
return False, f"SAM-2 installation failed: {e}"
|
| 75 |
|
| 76 |
# Check SAM-2 availability
|
|
@@ -85,7 +97,6 @@ if SAM2_AVAILABLE:
|
|
| 85 |
from sam2.build_sam import build_sam2
|
| 86 |
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
| 87 |
from sam2.modeling.sam2_base import SAM2Base
|
| 88 |
-
from sam2.utils.misc import get_device_index
|
| 89 |
except ImportError as e:
|
| 90 |
print(f"SAM-2 import error: {e}")
|
| 91 |
SAM2_AVAILABLE = False
|
|
@@ -183,75 +194,48 @@ class MedicalVLMAgent:
|
|
| 183 |
return self.processor.decode(trimmed, skip_special_tokens=True).strip()
|
| 184 |
|
| 185 |
# =============================================================================
|
| 186 |
-
# SAM-2 model + AutomaticMaskGenerator (
|
| 187 |
# =============================================================================
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
if not os.path.exists(checkpoint_path):
|
| 195 |
-
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 196 |
-
print("Downloading SAM-2 checkpoint...")
|
| 197 |
-
try:
|
| 198 |
-
import urllib.request
|
| 199 |
-
url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
|
| 200 |
-
urllib.request.urlretrieve(url, checkpoint_path)
|
| 201 |
-
print("SAM-2 checkpoint downloaded successfully")
|
| 202 |
-
except Exception as e:
|
| 203 |
-
print(f"Failed to download SAM-2 checkpoint: {e}")
|
| 204 |
-
return None
|
| 205 |
-
|
| 206 |
-
return checkpoint_path
|
| 207 |
|
| 208 |
def initialize_sam2():
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
try:
|
| 214 |
-
# Download checkpoint if needed
|
| 215 |
-
checkpoint_path = download_sam2_checkpoint()
|
| 216 |
-
if checkpoint_path is None:
|
| 217 |
-
return None, None
|
| 218 |
-
|
| 219 |
-
# Config path (you may need to adjust this)
|
| 220 |
-
config_path = "segment-anything-2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"
|
| 221 |
-
if not os.path.exists(config_path):
|
| 222 |
-
config_path = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
| 223 |
-
|
| 224 |
-
device = get_device()
|
| 225 |
-
print(f"[SAM-2] building model on {device}")
|
| 226 |
-
|
| 227 |
-
sam2_model = build_sam2(
|
| 228 |
-
config_path,
|
| 229 |
-
checkpoint_path,
|
| 230 |
-
device=device,
|
| 231 |
-
apply_postprocessing=False,
|
| 232 |
-
)
|
| 233 |
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
points_per_side=32,
|
| 237 |
-
pred_iou_thresh=0.86,
|
| 238 |
-
stability_score_thresh=0.92,
|
| 239 |
-
crop_n_layers=0,
|
| 240 |
-
)
|
| 241 |
-
return sam2_model, mask_gen
|
| 242 |
-
|
| 243 |
-
except Exception as e:
|
| 244 |
-
print(f"[SAM-2] Failed to initialize: {e}")
|
| 245 |
-
return None, None
|
| 246 |
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
_sam2_model, _mask_generator = initialize_sam2()
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
|
| 256 |
def automatic_mask_overlay(image_np: np.ndarray) -> np.ndarray:
|
| 257 |
"""Generate masks and alpha-blend them on top of the original image."""
|
|
@@ -274,13 +258,9 @@ def automatic_mask_overlay(image_np: np.ndarray) -> np.ndarray:
|
|
| 274 |
return overlay
|
| 275 |
|
| 276 |
def tumor_segmentation_interface(image: Image.Image | None):
|
| 277 |
-
"""Tumor segmentation interface with proper error handling."""
|
| 278 |
if image is None:
|
| 279 |
return None, "Please upload an image."
|
| 280 |
|
| 281 |
-
if not SAM2_AVAILABLE:
|
| 282 |
-
return None, "SAM-2 is not available. Please check installation."
|
| 283 |
-
|
| 284 |
if _mask_generator is None:
|
| 285 |
return None, "SAM-2 not properly initialized. Check the console for errors."
|
| 286 |
|
|
@@ -338,16 +318,34 @@ def simple_segmentation_fallback(image: Image.Image | None):
|
|
| 338 |
# CheXagent set-up
|
| 339 |
# =============================================================================
|
| 340 |
try:
|
|
|
|
| 341 |
chex_name = "StanfordAIMI/CheXagent-2-3b"
|
|
|
|
| 342 |
chex_tok = AutoTokenizer.from_pretrained(chex_name, trust_remote_code=True)
|
|
|
|
|
|
|
|
|
|
| 343 |
chex_model = AutoModelForCausalLM.from_pretrained(
|
| 344 |
-
chex_name,
|
|
|
|
|
|
|
|
|
|
| 345 |
)
|
| 346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
chex_model.eval()
|
| 348 |
CHEXAGENT_AVAILABLE = True
|
|
|
|
| 349 |
except Exception as e:
|
| 350 |
-
print(f"CheXagent
|
|
|
|
| 351 |
CHEXAGENT_AVAILABLE = False
|
| 352 |
chex_tok, chex_model = None, None
|
| 353 |
|
|
|
|
| 41 |
def check_and_install_sam2():
|
| 42 |
"""Check if SAM-2 is available and attempt installation if needed."""
|
| 43 |
try:
|
| 44 |
+
print("[SAM-2 Debug] Attempting to import SAM-2 modules...")
|
| 45 |
from sam2.build_sam import build_sam2
|
| 46 |
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
| 47 |
+
print("[SAM-2 Debug] Successfully imported SAM-2 modules")
|
| 48 |
return True, "SAM-2 already available"
|
| 49 |
+
except ImportError as e:
|
| 50 |
+
print(f"[SAM-2 Debug] Import error: {str(e)}")
|
| 51 |
+
print("[SAM-2 Debug] Attempting to install SAM-2...")
|
| 52 |
try:
|
| 53 |
# Clone SAM-2 repository
|
| 54 |
if not os.path.exists("segment-anything-2"):
|
| 55 |
+
print("[SAM-2 Debug] Cloning SAM-2 repository...")
|
| 56 |
subprocess.run([
|
| 57 |
"git", "clone",
|
| 58 |
"https://github.com/facebookresearch/segment-anything-2.git"
|
| 59 |
], check=True)
|
| 60 |
+
print("[SAM-2 Debug] Repository cloned successfully")
|
| 61 |
|
| 62 |
# Install SAM-2
|
| 63 |
+
print("[SAM-2 Debug] Installing SAM-2...")
|
| 64 |
original_dir = os.getcwd()
|
| 65 |
os.chdir("segment-anything-2")
|
| 66 |
subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
|
| 67 |
os.chdir(original_dir)
|
| 68 |
+
print("[SAM-2 Debug] Installation completed")
|
| 69 |
|
| 70 |
# Add to Python path
|
| 71 |
+
sam2_path = os.path.abspath("segment-anything-2")
|
| 72 |
+
if sam2_path not in sys.path:
|
| 73 |
+
sys.path.insert(0, sam2_path)
|
| 74 |
+
print(f"[SAM-2 Debug] Added {sam2_path} to Python path")
|
| 75 |
|
| 76 |
# Try importing again
|
| 77 |
+
print("[SAM-2 Debug] Attempting to import SAM-2 modules again...")
|
| 78 |
from sam2.build_sam import build_sam2
|
| 79 |
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
| 80 |
+
print("[SAM-2 Debug] Successfully imported SAM-2 modules after installation")
|
| 81 |
return True, "SAM-2 installed successfully"
|
| 82 |
|
| 83 |
except Exception as e:
|
| 84 |
+
print(f"[SAM-2 Debug] Installation failed: {str(e)}")
|
| 85 |
+
print(f"[SAM-2 Debug] Error type: {type(e).__name__}")
|
| 86 |
return False, f"SAM-2 installation failed: {e}"
|
| 87 |
|
| 88 |
# Check SAM-2 availability
|
|
|
|
| 97 |
from sam2.build_sam import build_sam2
|
| 98 |
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
| 99 |
from sam2.modeling.sam2_base import SAM2Base
|
|
|
|
| 100 |
except ImportError as e:
|
| 101 |
print(f"SAM-2 import error: {e}")
|
| 102 |
SAM2_AVAILABLE = False
|
|
|
|
| 194 |
return self.processor.decode(trimmed, skip_special_tokens=True).strip()
|
| 195 |
|
| 196 |
# =============================================================================
|
| 197 |
+
# SAM-2 model + AutomaticMaskGenerator (final minimal version)
|
| 198 |
# =============================================================================
|
| 199 |
+
import os
|
| 200 |
+
import numpy as np
|
| 201 |
+
from PIL import Image, ImageDraw
|
| 202 |
+
from sam2.build_sam import build_sam2
|
| 203 |
+
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
def initialize_sam2():
|
| 206 |
+
# These two files are already in your repo
|
| 207 |
+
CKPT = "checkpoints/sam2.1_hiera_large.pt" # ≈2.7 GB
|
| 208 |
+
CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
+
# One chdir so Hydra's search path starts inside sam2/sam2/
|
| 211 |
+
os.chdir("sam2/sam2")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
+
device = get_device()
|
| 214 |
+
print(f"[SAM-2] building model on {device}")
|
| 215 |
+
|
| 216 |
+
sam2_model = build_sam2(
|
| 217 |
+
CFG, # relative to sam2/sam2/
|
| 218 |
+
CKPT, # relative after chdir
|
| 219 |
+
device=device,
|
| 220 |
+
apply_postprocessing=False,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
mask_gen = SAM2AutomaticMaskGenerator(
|
| 224 |
+
model=sam2_model,
|
| 225 |
+
points_per_side=32,
|
| 226 |
+
pred_iou_thresh=0.86,
|
| 227 |
+
stability_score_thresh=0.92,
|
| 228 |
+
crop_n_layers=0,
|
| 229 |
+
)
|
| 230 |
+
return sam2_model, mask_gen
|
| 231 |
+
|
| 232 |
+
# ---------------------- build once ----------------------
|
| 233 |
+
try:
|
| 234 |
_sam2_model, _mask_generator = initialize_sam2()
|
| 235 |
+
print("[SAM-2] Successfully initialized!")
|
| 236 |
+
except Exception as e:
|
| 237 |
+
print(f"[SAM-2] Failed to initialize: {e}")
|
| 238 |
+
_sam2_model, _mask_generator = None, None
|
| 239 |
|
| 240 |
def automatic_mask_overlay(image_np: np.ndarray) -> np.ndarray:
|
| 241 |
"""Generate masks and alpha-blend them on top of the original image."""
|
|
|
|
| 258 |
return overlay
|
| 259 |
|
| 260 |
def tumor_segmentation_interface(image: Image.Image | None):
|
|
|
|
| 261 |
if image is None:
|
| 262 |
return None, "Please upload an image."
|
| 263 |
|
|
|
|
|
|
|
|
|
|
| 264 |
if _mask_generator is None:
|
| 265 |
return None, "SAM-2 not properly initialized. Check the console for errors."
|
| 266 |
|
|
|
|
| 318 |
# CheXagent set-up
|
| 319 |
# =============================================================================
|
| 320 |
try:
|
| 321 |
+
print("[CheXagent] Starting initialization...")
|
| 322 |
chex_name = "StanfordAIMI/CheXagent-2-3b"
|
| 323 |
+
print(f"[CheXagent] Loading tokenizer from {chex_name}")
|
| 324 |
chex_tok = AutoTokenizer.from_pretrained(chex_name, trust_remote_code=True)
|
| 325 |
+
print("[CheXagent] Tokenizer loaded successfully")
|
| 326 |
+
|
| 327 |
+
print("[CheXagent] Loading model...")
|
| 328 |
chex_model = AutoModelForCausalLM.from_pretrained(
|
| 329 |
+
chex_name,
|
| 330 |
+
device_map="auto",
|
| 331 |
+
trust_remote_code=True,
|
| 332 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
|
| 333 |
)
|
| 334 |
+
print("[CheXagent] Model loaded successfully")
|
| 335 |
+
|
| 336 |
+
if torch.cuda.is_available():
|
| 337 |
+
print("[CheXagent] Converting to half precision for GPU")
|
| 338 |
+
chex_model = chex_model.half()
|
| 339 |
+
else:
|
| 340 |
+
print("[CheXagent] Using full precision for CPU")
|
| 341 |
+
chex_model = chex_model.float()
|
| 342 |
+
|
| 343 |
chex_model.eval()
|
| 344 |
CHEXAGENT_AVAILABLE = True
|
| 345 |
+
print("[CheXagent] Initialization complete")
|
| 346 |
except Exception as e:
|
| 347 |
+
print(f"[CheXagent] Initialization failed: {str(e)}")
|
| 348 |
+
print(f"[CheXagent] Error type: {type(e).__name__}")
|
| 349 |
CHEXAGENT_AVAILABLE = False
|
| 350 |
chex_tok, chex_model = None, None
|
| 351 |
|