pascal-maker commited on
Commit
6cd7b7a
·
verified ·
1 Parent(s): 2fb54d3

updated app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -76
app.py CHANGED
@@ -41,36 +41,48 @@ import gradio as gr
41
  def check_and_install_sam2():
42
  """Check if SAM-2 is available and attempt installation if needed."""
43
  try:
44
- # Try importing SAM-2
45
  from sam2.build_sam import build_sam2
46
  from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 
47
  return True, "SAM-2 already available"
48
- except ImportError:
49
- print("SAM-2 not found. Attempting to install...")
 
50
  try:
51
  # Clone SAM-2 repository
52
  if not os.path.exists("segment-anything-2"):
 
53
  subprocess.run([
54
  "git", "clone",
55
  "https://github.com/facebookresearch/segment-anything-2.git"
56
  ], check=True)
 
57
 
58
  # Install SAM-2
 
59
  original_dir = os.getcwd()
60
  os.chdir("segment-anything-2")
61
  subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
62
  os.chdir(original_dir)
 
63
 
64
  # Add to Python path
65
- sys.path.insert(0, os.path.abspath("segment-anything-2"))
 
 
 
66
 
67
  # Try importing again
 
68
  from sam2.build_sam import build_sam2
69
  from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 
70
  return True, "SAM-2 installed successfully"
71
 
72
  except Exception as e:
73
- print(f"Failed to install SAM-2: {e}")
 
74
  return False, f"SAM-2 installation failed: {e}"
75
 
76
  # Check SAM-2 availability
@@ -85,7 +97,6 @@ if SAM2_AVAILABLE:
85
  from sam2.build_sam import build_sam2
86
  from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
87
  from sam2.modeling.sam2_base import SAM2Base
88
- from sam2.utils.misc import get_device_index
89
  except ImportError as e:
90
  print(f"SAM-2 import error: {e}")
91
  SAM2_AVAILABLE = False
@@ -183,75 +194,48 @@ class MedicalVLMAgent:
183
  return self.processor.decode(trimmed, skip_special_tokens=True).strip()
184
 
185
  # =============================================================================
186
- # SAM-2 model + AutomaticMaskGenerator (conditional)
187
  # =============================================================================
188
- def download_sam2_checkpoint():
189
- """Download SAM-2 checkpoint if not present."""
190
- checkpoint_dir = "checkpoints"
191
- checkpoint_file = "sam2.1_hiera_large.pt"
192
- checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
193
-
194
- if not os.path.exists(checkpoint_path):
195
- os.makedirs(checkpoint_dir, exist_ok=True)
196
- print("Downloading SAM-2 checkpoint...")
197
- try:
198
- import urllib.request
199
- url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
200
- urllib.request.urlretrieve(url, checkpoint_path)
201
- print("SAM-2 checkpoint downloaded successfully")
202
- except Exception as e:
203
- print(f"Failed to download SAM-2 checkpoint: {e}")
204
- return None
205
-
206
- return checkpoint_path
207
 
208
  def initialize_sam2():
209
- """Initialize SAM-2 model and mask generator."""
210
- if not SAM2_AVAILABLE:
211
- return None, None
212
-
213
- try:
214
- # Download checkpoint if needed
215
- checkpoint_path = download_sam2_checkpoint()
216
- if checkpoint_path is None:
217
- return None, None
218
-
219
- # Config path (you may need to adjust this)
220
- config_path = "segment-anything-2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"
221
- if not os.path.exists(config_path):
222
- config_path = "configs/sam2.1/sam2.1_hiera_l.yaml"
223
-
224
- device = get_device()
225
- print(f"[SAM-2] building model on {device}")
226
-
227
- sam2_model = build_sam2(
228
- config_path,
229
- checkpoint_path,
230
- device=device,
231
- apply_postprocessing=False,
232
- )
233
 
234
- mask_gen = SAM2AutomaticMaskGenerator(
235
- model=sam2_model,
236
- points_per_side=32,
237
- pred_iou_thresh=0.86,
238
- stability_score_thresh=0.92,
239
- crop_n_layers=0,
240
- )
241
- return sam2_model, mask_gen
242
-
243
- except Exception as e:
244
- print(f"[SAM-2] Failed to initialize: {e}")
245
- return None, None
246
 
247
- # Initialize SAM-2 (conditional)
248
- _sam2_model, _mask_generator = None, None
249
- if SAM2_AVAILABLE:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  _sam2_model, _mask_generator = initialize_sam2()
251
- if _sam2_model is not None:
252
- print("[SAM-2] Successfully initialized!")
253
- else:
254
- print("[SAM-2] Initialization failed")
255
 
256
  def automatic_mask_overlay(image_np: np.ndarray) -> np.ndarray:
257
  """Generate masks and alpha-blend them on top of the original image."""
@@ -274,13 +258,9 @@ def automatic_mask_overlay(image_np: np.ndarray) -> np.ndarray:
274
  return overlay
275
 
276
  def tumor_segmentation_interface(image: Image.Image | None):
277
- """Tumor segmentation interface with proper error handling."""
278
  if image is None:
279
  return None, "Please upload an image."
280
 
281
- if not SAM2_AVAILABLE:
282
- return None, "SAM-2 is not available. Please check installation."
283
-
284
  if _mask_generator is None:
285
  return None, "SAM-2 not properly initialized. Check the console for errors."
286
 
@@ -338,16 +318,34 @@ def simple_segmentation_fallback(image: Image.Image | None):
338
  # CheXagent set-up
339
  # =============================================================================
340
  try:
 
341
  chex_name = "StanfordAIMI/CheXagent-2-3b"
 
342
  chex_tok = AutoTokenizer.from_pretrained(chex_name, trust_remote_code=True)
 
 
 
343
  chex_model = AutoModelForCausalLM.from_pretrained(
344
- chex_name, device_map="auto", trust_remote_code=True
 
 
 
345
  )
346
- chex_model = chex_model.half() if torch.cuda.is_available() else chex_model.float()
 
 
 
 
 
 
 
 
347
  chex_model.eval()
348
  CHEXAGENT_AVAILABLE = True
 
349
  except Exception as e:
350
- print(f"CheXagent not available: {e}")
 
351
  CHEXAGENT_AVAILABLE = False
352
  chex_tok, chex_model = None, None
353
 
 
41
  def check_and_install_sam2():
42
  """Check if SAM-2 is available and attempt installation if needed."""
43
  try:
44
+ print("[SAM-2 Debug] Attempting to import SAM-2 modules...")
45
  from sam2.build_sam import build_sam2
46
  from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
47
+ print("[SAM-2 Debug] Successfully imported SAM-2 modules")
48
  return True, "SAM-2 already available"
49
+ except ImportError as e:
50
+ print(f"[SAM-2 Debug] Import error: {str(e)}")
51
+ print("[SAM-2 Debug] Attempting to install SAM-2...")
52
  try:
53
  # Clone SAM-2 repository
54
  if not os.path.exists("segment-anything-2"):
55
+ print("[SAM-2 Debug] Cloning SAM-2 repository...")
56
  subprocess.run([
57
  "git", "clone",
58
  "https://github.com/facebookresearch/segment-anything-2.git"
59
  ], check=True)
60
+ print("[SAM-2 Debug] Repository cloned successfully")
61
 
62
  # Install SAM-2
63
+ print("[SAM-2 Debug] Installing SAM-2...")
64
  original_dir = os.getcwd()
65
  os.chdir("segment-anything-2")
66
  subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
67
  os.chdir(original_dir)
68
+ print("[SAM-2 Debug] Installation completed")
69
 
70
  # Add to Python path
71
+ sam2_path = os.path.abspath("segment-anything-2")
72
+ if sam2_path not in sys.path:
73
+ sys.path.insert(0, sam2_path)
74
+ print(f"[SAM-2 Debug] Added {sam2_path} to Python path")
75
 
76
  # Try importing again
77
+ print("[SAM-2 Debug] Attempting to import SAM-2 modules again...")
78
  from sam2.build_sam import build_sam2
79
  from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
80
+ print("[SAM-2 Debug] Successfully imported SAM-2 modules after installation")
81
  return True, "SAM-2 installed successfully"
82
 
83
  except Exception as e:
84
+ print(f"[SAM-2 Debug] Installation failed: {str(e)}")
85
+ print(f"[SAM-2 Debug] Error type: {type(e).__name__}")
86
  return False, f"SAM-2 installation failed: {e}"
87
 
88
  # Check SAM-2 availability
 
97
  from sam2.build_sam import build_sam2
98
  from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
99
  from sam2.modeling.sam2_base import SAM2Base
 
100
  except ImportError as e:
101
  print(f"SAM-2 import error: {e}")
102
  SAM2_AVAILABLE = False
 
194
  return self.processor.decode(trimmed, skip_special_tokens=True).strip()
195
 
196
  # =============================================================================
197
+ # SAM-2 model + AutomaticMaskGenerator (final minimal version)
198
  # =============================================================================
199
+ import os
200
+ import numpy as np
201
+ from PIL import Image, ImageDraw
202
+ from sam2.build_sam import build_sam2
203
+ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  def initialize_sam2():
206
+ # These two files are already in your repo
207
+ CKPT = "checkpoints/sam2.1_hiera_large.pt" # ≈2.7 GB
208
+ CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
+ # One chdir so Hydra's search path starts inside sam2/sam2/
211
+ os.chdir("sam2/sam2")
 
 
 
 
 
 
 
 
 
 
212
 
213
+ device = get_device()
214
+ print(f"[SAM-2] building model on {device}")
215
+
216
+ sam2_model = build_sam2(
217
+ CFG, # relative to sam2/sam2/
218
+ CKPT, # relative after chdir
219
+ device=device,
220
+ apply_postprocessing=False,
221
+ )
222
+
223
+ mask_gen = SAM2AutomaticMaskGenerator(
224
+ model=sam2_model,
225
+ points_per_side=32,
226
+ pred_iou_thresh=0.86,
227
+ stability_score_thresh=0.92,
228
+ crop_n_layers=0,
229
+ )
230
+ return sam2_model, mask_gen
231
+
232
+ # ---------------------- build once ----------------------
233
+ try:
234
  _sam2_model, _mask_generator = initialize_sam2()
235
+ print("[SAM-2] Successfully initialized!")
236
+ except Exception as e:
237
+ print(f"[SAM-2] Failed to initialize: {e}")
238
+ _sam2_model, _mask_generator = None, None
239
 
240
  def automatic_mask_overlay(image_np: np.ndarray) -> np.ndarray:
241
  """Generate masks and alpha-blend them on top of the original image."""
 
258
  return overlay
259
 
260
  def tumor_segmentation_interface(image: Image.Image | None):
 
261
  if image is None:
262
  return None, "Please upload an image."
263
 
 
 
 
264
  if _mask_generator is None:
265
  return None, "SAM-2 not properly initialized. Check the console for errors."
266
 
 
318
  # CheXagent set-up
319
  # =============================================================================
320
  try:
321
+ print("[CheXagent] Starting initialization...")
322
  chex_name = "StanfordAIMI/CheXagent-2-3b"
323
+ print(f"[CheXagent] Loading tokenizer from {chex_name}")
324
  chex_tok = AutoTokenizer.from_pretrained(chex_name, trust_remote_code=True)
325
+ print("[CheXagent] Tokenizer loaded successfully")
326
+
327
+ print("[CheXagent] Loading model...")
328
  chex_model = AutoModelForCausalLM.from_pretrained(
329
+ chex_name,
330
+ device_map="auto",
331
+ trust_remote_code=True,
332
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
333
  )
334
+ print("[CheXagent] Model loaded successfully")
335
+
336
+ if torch.cuda.is_available():
337
+ print("[CheXagent] Converting to half precision for GPU")
338
+ chex_model = chex_model.half()
339
+ else:
340
+ print("[CheXagent] Using full precision for CPU")
341
+ chex_model = chex_model.float()
342
+
343
  chex_model.eval()
344
  CHEXAGENT_AVAILABLE = True
345
+ print("[CheXagent] Initialization complete")
346
  except Exception as e:
347
+ print(f"[CheXagent] Initialization failed: {str(e)}")
348
+ print(f"[CheXagent] Error type: {type(e).__name__}")
349
  CHEXAGENT_AVAILABLE = False
350
  chex_tok, chex_model = None, None
351