Spaces:
Running
Running
File size: 2,286 Bytes
9c4b1c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
"""
Download model weights for Hugging Face Spaces deployment.
This script downloads model weights on first run if they're not present.
"""
import os
import urllib.request
import ssl
# Disable HTTPS certificate verification for every urllib download made by
# this process, when the running Python exposes the unverified-context factory.
# NOTE(review): this turns off TLS verification globally, so downloaded files
# are not authenticated against the server certificate -- confirm this is
# really required for the deployment environment.
if hasattr(ssl, "_create_unverified_context"):
    _create_unverified_https_context = ssl._create_unverified_context
    ssl._create_default_https_context = _create_unverified_https_context
def download_file(url, dest_path):
    """Download a file from URL to destination path.

    The download is skipped when ``dest_path`` already exists. Failures are
    reported on stdout and swallowed so that one bad URL does not abort the
    remaining downloads.

    Args:
        url: Location to fetch; anything ``urllib.request.urlretrieve`` accepts.
        dest_path: Where to store the file. Missing parent directories are
            created. A bare filename (no directory part) is allowed.
    """
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the destination actually has a directory component.
    dest_dir = os.path.dirname(dest_path)
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)

    if os.path.exists(dest_path):
        print(f"β {dest_path} already exists")
        return

    print(f"Downloading {os.path.basename(dest_path)}...")

    # Fetch into a sibling temporary file and move it into place only once
    # the transfer has finished. Writing straight to dest_path would leave a
    # truncated file behind on failure, and the existence check above would
    # then treat that file as a valid download on every later run.
    partial_path = f"{dest_path}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, dest_path)
        print(f"β Downloaded {dest_path}")
    except Exception as e:
        # Deliberately broad: downloads are best-effort.
        print(f"β Failed to download {dest_path}: {e}")
        try:
            os.remove(partial_path)
        except OSError:
            # Nothing was written, or it cannot be removed; nothing more to do.
            pass
# Download location of each detector's checkpoint, keyed by detector name.
# Every entry still carries the placeholder file ID; swap in the real Google
# Drive ID for each model before deploying.
_DRIVE_DOWNLOAD_URL = "https://drive.google.com/uc?export=download&id={file_id}"
_UNSET_DRIVE_ID = "YOUR_GOOGLE_DRIVE_ID"

WEIGHTS_URLS = {
    "R50_TF": _DRIVE_DOWNLOAD_URL.format(file_id=_UNSET_DRIVE_ID),  # Replace
    "R50_nodown": _DRIVE_DOWNLOAD_URL.format(file_id=_UNSET_DRIVE_ID),  # Replace
    "CLIP-D": _DRIVE_DOWNLOAD_URL.format(file_id=_UNSET_DRIVE_ID),  # Replace
    "P2G": _DRIVE_DOWNLOAD_URL.format(file_id=_UNSET_DRIVE_ID),  # Replace
    "NPR": _DRIVE_DOWNLOAD_URL.format(file_id=_UNSET_DRIVE_ID),  # Replace
}
def download_all_weights():
    """Fetch every configured checkpoint, then the P2G class list.

    Entries of ``WEIGHTS_URLS`` whose URL still contains the placeholder
    Google Drive ID are reported and skipped. Each download goes through
    ``download_file``.
    """
    print("Checking model weights...")

    for detector, weights_url in WEIGHTS_URLS.items():
        if "YOUR_GOOGLE_DRIVE_ID" not in weights_url:
            download_file(
                weights_url,
                f"detectors/{detector}/checkpoint/pretrained/weights/best.pt",
            )
        else:
            # Placeholder still in place: nothing to download for this model.
            print(f"β Skipping {detector}: URL not configured")

    # The P2G detector additionally needs its classes.pkl file.
    download_file(
        "https://github.com/laitifranz/Prompt2Guard/raw/main/src/utils/classes.pkl",
        "detectors/P2G/src/utils/classes.pkl",
    )

    print("\nWeight check complete!")


if __name__ == "__main__":
    download_all_weights()
|