diff-usion / app.py
miachiquier's picture
Update app.py
836d983 verified
import PIL
import math
import torch
import random
import os
import numpy as np
import pandas as pd
import gradio as gr
import threading
import time
import zipfile
import shutil
import glob
from pathlib import Path
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision.transforms.v2 as transforms
from PIL import Image, ImageDraw, ImageFont
import imageio
from tqdm import tqdm
import tarfile
import queue
import hashlib
import json
# Seed the torch, random and numpy RNGs so runs are reproducible
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
# Image size, batch size and compute device used by the app
IMG_SIZE = 512
BATCH_SIZE = 32
# NOTE(review): the CUDA device index is hard-coded to 2 — confirm it matches the deployment host
DEVICE = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
# Module-level state shared between UI callbacks
# LoRA training status
lora_status = "Ready"
lora_is_processing = False
# Cooperative stop flags; cancel_generation() sets all four to True
generation_should_stop = False
classifier_should_stop = False # stop flag for classifier training
embedding_should_stop = False # stop flag for embedding encoding
lora_should_stop = False # stop flag for LoRA training
generation_queue = queue.Queue()
is_processing = False # True while a generation run is active; read by update_progress_status()
# Upload directories, created at import time if missing
temp_dir = Path("./temp_uploads")
temp_dir.mkdir(exist_ok=True, parents=True)
lora_temp_dir = Path("./temp_lora_uploads")
lora_temp_dir.mkdir(exist_ok=True, parents=True)
# Queue of results for real-time UI updates; drained by reset_galleries()
result_queue = queue.Queue()
displayed_results = [] # Every result currently shown in the "All Results" gallery
# Progress counters reported by update_progress_status()
total_images_to_process = 0
images_processed = 0
# Per-direction gallery contents
displayed_results_class0_to_class1 = [] # Results for class 0 to class 1
displayed_results_class1_to_class0 = [] # Results for class 1 to class 0
# Root directory for cached generation results, created at import time if missing
CACHE_DIR = Path("./cached_results")
CACHE_DIR.mkdir(exist_ok=True, parents=True)
# Stylesheet for the Gradio UI; create_gradio_interface() passes it to gr.Blocks(css=css).
# NOTE(review): .container, .section-header and .hidden-element are each declared twice
# inside the string; the later declarations are applied on top of the earlier ones.
css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
body, * {
font-family: 'Inter', sans-serif !important;
letter-spacing: -0.01em;
}
.container {
max-width: 1360px;
margin: auto;
padding-top: 2.5rem;
padding-bottom: 2.5rem;
}
.header {
text-align: center;
margin-bottom: 3rem;
padding-bottom: 2rem;
border-bottom: 1px solid #f0f0f0;
}
.header h1 {
font-size: 3rem;
font-weight: 700;
color: #222;
letter-spacing: -0.03em;
margin-bottom: 1rem;
background: linear-gradient(90deg, #B39CD0 0%, #9D8AC7 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
display: inline-block;
}
.header p {
font-size: 1.1rem;
color: #333;
max-width: 800px;
margin: 0 auto;
line-height: 1.6;
}
.subtitle {
font-size: 0.95rem;
color: #777;
max-width: 800px;
margin: 0.5rem auto 0;
line-height: 1.5;
}
.contact-info {
font-size: 0.8rem;
color: #777;
margin-top: 15px;
padding-top: 10px;
border-top: 1px dashed #e0e0e0;
width: 80%;
margin-left: auto;
margin-right: auto;
}
.paper-info {
background-color: #f8f9fa;
border-radius: 12px;
padding: 1.8rem;
margin: 1.8rem 0;
box-shadow: 0 6px 20px rgba(0,0,0,0.05);
border-left: 4px solid #B39CD0;
}
.paper-info h3 {
font-size: 1.5rem;
font-weight: 600;
color: #B39CD0;
letter-spacing: -0.02em;
margin-bottom: 1rem;
}
.paper-info p {
font-size: 1.05em;
line-height: 1.7;
color: #333;
}
.section-header {
font-size: 1.8rem;
font-weight: 600;
color: #B39CD0;
margin: 2.5rem 0 1.5rem 0;
padding-bottom: 0.8rem;
border-bottom: 2px solid #ECF0F1;
letter-spacing: -0.02em;
}
.footer {
text-align: center;
margin-top: 3rem;
padding: 1.5rem;
border-top: 1px solid #ECF0F1;
color: #666;
background-color: #f8f9fa;
border-radius: 0 0 12px 12px;
}
.btn-primary {
background-color: #B39CD0 !important;
border-color: #B39CD0 !important;
transition: all 0.3s ease;
font-weight: 500 !important;
letter-spacing: 0.02em !important;
padding: 0.6rem 1.5rem !important;
border-radius: 8px !important;
}
.btn-primary:hover {
background-color: #9D8AC7 !important;
border-color: #9D8AC7 !important;
}
/* Hide the output directory */
.hidden-element {
display: none !important;
}
/* Additional CSS for better alignment */
.container {
padding: 0 1.5rem;
}
.main-container {
display: flex;
flex-direction: column;
gap: 1.5rem;
}
.results-container {
margin-top: 0;
padding-top: 0;
}
.full-width-header {
margin-bottom: 2rem;
padding-bottom: 1.5rem;
border-bottom: 1px solid #f0f0f0;
text-align: center;
}
.content-row {
display: flex;
gap: 2rem;
}
.sidebar {
min-width: 250px;
padding-right: 1.5rem;
}
.section-header {
margin-top: 0;
}
.tabs-container {
margin-top: 1rem;
}
.gallery-container {
margin-top: 1rem;
}
/* Hide the output directory */
.hidden-element {
display: none !important;
}
.gallery-item img {
object-fit: contain !important;
height: 200px !important;
width: auto !important;
}
/* Force GIFs to restart when tab is selected */
.tabs-container .tabitem[style*="display: block"] .gallery-container img {
animation: none;
animation: reload-animation 0.1s;
}
@keyframes reload-animation {
0% { opacity: 0.99; }
100% { opacity: 1; }
}
"""
# Cache key of the results currently shown; reset_galleries() sets it back to None.
current_cache_key = None
# NOTE: this name is rebound by the `is_using_default_params` function defined
# further down in this module, so the False value is only visible until then.
is_using_default_params = False

# Registry of the pre-loaded example datasets offered in the UI.
# Keys used by the code in this module:
#   name                 - internal identifier (also fed into the cache key)
#   display_name         - label shown in the dropdown and used for lookups
#   description          - human-readable summary
#   path                 - optional archive with the dataset
#   direct_dataset_path  - directory with already-extracted example images
#   checkpoint_path      - LoRA checkpoint directory, or None
#   embeddings_path      - precomputed embeddings directory
#   classifier_path      - optional classifier directory
#   class_names          - sub-folder names of the two classes
EXAMPLE_DATASETS = [
    {
        "name": "butterfly",
        "display_name": "Butterfly (Monarch vs Viceroy)",
        "description": "Dataset containing images of Monarch and Viceroy butterflies for counterfactual generation",
        "path": "/proj/vondrick/datasets/magnification/butterfly.tar.gz",
        "direct_dataset_path": "example_images/butterfly",
        "checkpoint_path": "/proj/vondrick2/mia/magnificationold/output/lora/butterfly/copper-forest-49/checkpoint-1800",
        "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/butterfly",
        "classifier_path": "/proj/vondrick2/mia/diff-usion/results/ensemble/butterfly",
        "class_names": ["class0", "class1"],
    },
    {
        "name": "afhq",
        "display_name": "Cats vs. Dogs (AFHQ)",
        # Fixed: this entry previously carried the lamp dataset's description.
        "description": "Dataset containing images of cats and dogs from AFHQ",
        "direct_dataset_path": "example_images/afhq",
        "checkpoint_path": None,
        "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/afhq",
        "classifier_path": "/proj/vondrick2/mia/diff-usion/results/ensemble/afhq",
        "class_names": ["class0", "class1"],
    },
    {
        "name": "lamp",
        "display_name": "Lamps",
        "description": "Dataset containing images of table lamps and floor lamps",
        "path": "compressed_datasets/lampsfar.zip",
        "direct_dataset_path": "example_images/lamps",
        "checkpoint_path": "/proj/vondrick2/mia/diff-usion/lora_output_lampsfar/checkpoint-800",
        "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/lampsfar",
        "classifier_path": "/proj/vondrick2/mia/diff-usion/results/ensemble/lampsfar",
        "class_names": ["class0", "class1"],
    },
    {
        "name": "couches",
        "display_name": "Couches",
        "description": "Dataset containing images of chairs and floor",
        "path": "compressed_datasets/couches.zip",
        "direct_dataset_path": "example_images/couches",
        "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/couches",
        "checkpoint_path": "/proj/vondrick2/mia/diff-usion/lora_output/couches/checkpoint-1000",
        "class_names": ["class0", "class1"],
    },
]
def get_example_datasets():
    """Return the internal ``name`` of every entry in ``EXAMPLE_DATASETS``, in registry order."""
    names = []
    for entry in EXAMPLE_DATASETS:
        names.append(entry["name"])
    return names
def get_example_dataset_info(name):
    """Return the first ``EXAMPLE_DATASETS`` entry whose ``name`` equals *name*, or None."""
    matching_entries = (entry for entry in EXAMPLE_DATASETS if entry["name"] == name)
    return next(matching_entries, None)
def is_using_default_params(dataset_name, custom_tskip, num_images_per_class):
    """Tell whether the given settings are the defaults for *dataset_name*.

    The dataset is recognised by a case-insensitive keyword in its name. The
    first keyword that matches decides the outcome: the t-skip must equal that
    dataset's default (as an int or its string form) and
    ``num_images_per_class`` must be 10. Unknown datasets and ``None`` give False.
    """
    if dataset_name is None:
        return False

    # (keyword, default t-skip); order matters because the first hit wins.
    default_tskip_by_keyword = (
        ("butterfly", 70),
        ("lamp", 85),
        ("couch", 85),
    )
    lowered_name = dataset_name.lower()
    for keyword, default_tskip in default_tskip_by_keyword:
        if keyword in lowered_name:
            tskip_is_default = custom_tskip in (default_tskip, str(default_tskip))
            return tskip_is_default and num_images_per_class == 10
    return False
# # Function to get the output directory - either cache or regular output
# def get_output_directory(dataset_name, is_default_params, cache_key):
# """Get the appropriate output directory based on parameters"""
# if is_default_params:
# # Use cache directory
# cache_path = CACHE_DIR / cache_key
# cache_path.mkdir(exist_ok=True, parents=True)
# # Create dataset-specific directory
# dataset_dir = cache_path / dataset_name.replace(" ", "_").lower()
# dataset_dir.mkdir(exist_ok=True, parents=True)
# # Create class-specific directories
# class0_to_class1_dir = dataset_dir / "class0_to_class1"
# class1_to_class0_dir = dataset_dir / "class1_to_class0"
# class0_to_class1_dir.mkdir(exist_ok=True, parents=True)
# class1_to_class0_dir.mkdir(exist_ok=True, parents=True)
# # Create context directory
# context_dir = dataset_dir / "context"
# context_dir.mkdir(exist_ok=True, parents=True)
# return dataset_dir, class0_to_class1_dir, class1_to_class0_dir, context_dir
# else:
# # Use regular output directory
# output_dir = Path(f"./results/{dataset_name.replace(' ', '_').lower()}")
# output_dir.mkdir(exist_ok=True, parents=True)
# # Create gifs directory with class-specific subdirectories
# gifs_dir = output_dir / "gifs"
# gifs_dir.mkdir(exist_ok=True, parents=True)
# class0_to_class1_dir = gifs_dir / "class0_to_class1"
# class1_to_class0_dir = gifs_dir / "class1_to_class0"
# class0_to_class1_dir.mkdir(exist_ok=True, parents=True)
# class1_to_class0_dir.mkdir(exist_ok=True, parents=True)
# # Create context directory
# context_dir = output_dir / "context"
# context_dir.mkdir(exist_ok=True, parents=True)
# return output_dir, class0_to_class1_dir, class1_to_class0_dir, context_dir
def has_prediction_flipped(orig_preds, new_preds):
    """Return True when at least one prediction lies on a different side of 0.5.

    Both arguments expose a ``preds`` attribute; the two are thresholded at 0.5
    and compared element-wise.
    """
    positive_before = orig_preds.preds > 0.5
    positive_after = new_preds.preds > 0.5
    changed = positive_before != positive_after
    return changed.any().item()
def _extract_tar_safely(tar_ref, destination):
    """Extract *tar_ref* into *destination*, refusing members that escape it."""
    if hasattr(tarfile, "data_filter"):
        # Interpreters with extraction filters: the stdlib rejects absolute
        # paths, ".." traversal, unsafe links and special files.
        tar_ref.extractall(destination, filter="data")
        return
    # Older interpreters: check every member's resolved target path first.
    root = destination.resolve()
    members = tar_ref.getmembers()
    for member in members:
        target = (root / member.name).resolve()
        if target != root and root not in target.parents:
            raise ValueError(f"Unsafe path in archive: {member.name}")
    tar_ref.extractall(destination, members=members)


def extract_archive(archive_file, extract_dir):
    """Extract an uploaded .zip / .tar.gz / .tgz dataset into *extract_dir*.

    After extraction the directory must contain ``class0`` and ``class1``
    folders. If they are not at the top level, the first immediate
    sub-directory holding both is used and the two folders are moved up.

    Args:
        archive_file: Path to the archive.
        extract_dir: Directory to extract into (created if missing).

    Returns:
        ``extract_dir`` as a string.

    Raises:
        ValueError: If the archive format is unsupported, a tar member would
            be written outside *extract_dir* (fallback check), or the class
            folders cannot be found.
        tarfile.TarError: If the stdlib extraction filter rejects a member.
    """
    target_dir = Path(extract_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    file_path = Path(archive_file)
    lowered_name = file_path.name.lower()
    if file_path.suffix.lower() == '.zip':
        # ZipFile.extractall already strips absolute paths and ".." components.
        with zipfile.ZipFile(archive_file, 'r') as zip_ref:
            zip_ref.extractall(target_dir)
    elif lowered_name.endswith(('.tar.gz', '.tgz')):
        # The archive comes from a user upload, so tar members are untrusted.
        with tarfile.open(archive_file, 'r:gz') as tar_ref:
            _extract_tar_safely(tar_ref, target_dir)
    else:
        raise ValueError(f"Unsupported archive format: {file_path.suffix}. Please use .zip or .tar.gz")

    class0_dir = target_dir / "class0"
    class1_dir = target_dir / "class1"
    if not (class0_dir.exists() and class1_dir.exists()):
        # Accept archives that wrap the class folders in one extra directory.
        for subdir in target_dir.iterdir():
            if not subdir.is_dir():
                continue
            if (subdir / "class0").exists() and (subdir / "class1").exists():
                shutil.move(str(subdir / "class0"), str(class0_dir))
                shutil.move(str(subdir / "class1"), str(class1_dir))
                break

    if not (class0_dir.exists() and class1_dir.exists()):
        raise ValueError("The uploaded archive must contain 'class0' and 'class1' directories or a subdirectory containing them")
    return str(target_dir)
def get_cached_result_info(name):
    """Placeholder lookup for a cached result.

    No cache metadata is stored yet, so every *name* yields ``None``.
    """
    cached_info = None  # nothing is persisted for any name at this point
    return cached_info
class TwoClassDataset(Dataset):
    """Binary image dataset read from two class folders under ``root_dir``.

    The folders are ``class0`` / ``class1`` unless the root path contains one
    of the markers in ``_FOLDER_OVERRIDES``. Images of the first folder get
    label 0, images of the second label 1.

    Args:
        root_dir: Directory containing the two class folders.
        transform: Optional callable applied to each loaded RGB image.
        num_samples_per_class: If given, keep only the first N files of each
            class (after sorting by path).
    """

    # (marker in the root path, class-0 folder, class-1 folder); first match wins.
    _FOLDER_OVERRIDES = (
        ("kermany", "NORMAL", "DRUSEN"),
        ("kiki_bouba", "kiki", "bouba"),
        ("afhq", "dog", "cat"),
    )

    def __init__(self, root_dir, transform=None, num_samples_per_class=None):
        self.root_dir = Path(root_dir)
        self.transform = transform

        class0_name, class1_name = "class0", "class1"
        root_text = str(self.root_dir)
        for marker, folder0, folder1 in self._FOLDER_OVERRIDES:
            if marker in root_text:
                class0_name, class1_name = folder0, folder1
                break
        self.class0_dir = self.root_dir / class0_name
        self.class1_dir = self.root_dir / class1_name

        # Sort so the selected subset does not depend on filesystem listing
        # order, and skip directories whose name happens to contain a dot.
        self.class0_images = sorted(p for p in self.class0_dir.glob("*.*") if p.is_file())
        self.class1_images = sorted(p for p in self.class1_dir.glob("*.*") if p.is_file())

        if num_samples_per_class is not None:
            self.class0_images = self.class0_images[:num_samples_per_class]
            self.class1_images = self.class1_images[:num_samples_per_class]

        self.images = self.class0_images + self.class1_images
        self.labels = [0] * len(self.class0_images) + [1] * len(self.class1_images)

    def __len__(self):
        """Return the total number of images across both classes."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label, path_string)`` for the sample at *idx*."""
        img_path = self.images[idx]
        # convert() yields an independent image, so the file can be closed here.
        with Image.open(img_path) as opened:
            image = opened.convert("RGB")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label, str(img_path)
def compute_lpips_similarity(images1, images2, reduction=None):
    """Placeholder for an LPIPS similarity between two image batches.

    No LPIPS model is evaluated: one uniform random score per item of
    ``images1`` is drawn on ``images1``'s device, and ``images2`` is ignored.
    With ``reduction="mean"`` the mean of those scores is returned, otherwise
    the per-item scores.
    """
    per_item_scores = torch.rand(images1.shape[0], device=images1.device)
    return per_item_scores.mean() if reduction == "mean" else per_item_scores
def get_direction_sign(idx: int):
    """Map a direction index to its sign: 0 gives -1, 1 gives +1.

    Raises:
        ValueError: For any other index.
    """
    if idx == 0:
        return -1
    if idx == 1:
        return 1
    raise ValueError("Currently two direction are supported in this script")
def add_text_to_image(image, text):
    """Draw *text* at the top of *image* over a fading dark band.

    The image is modified in place and also returned.

    Args:
        image: PIL image to annotate.
        text: Caption drawn in white at position (15, 10).

    Returns:
        The same image object that was passed in.
    """
    # On an RGB image the alpha component of a fill is ignored unless the
    # drawing context is opened in "RGBA" mode, in which case Pillow blends
    # the fill into the pixels. Other image modes only accept their own mode.
    if image.mode == "RGB":
        draw = ImageDraw.Draw(image, "RGBA")
    else:
        draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype("arial.ttf", 24)
    except (OSError, ImportError):
        # Font file missing or FreeType support unavailable.
        font = ImageFont.load_default()

    # 40 one-pixel rows whose opacity fades from 180 down to 24.
    for row in range(40):
        alpha = max(0, 180 - row * 4)
        draw.rectangle([(0, row), (image.width, row)], fill=(0, 0, 0, alpha))

    draw.text((15, 10), text, fill="white", font=font)
    return image
def _draw_outlined_label(image, label, font, padding=15):
    """Draw *label* in the top-left corner of *image* with a 1px dark outline."""
    draw = ImageDraw.Draw(image)
    # Four diagonal offsets drawn first act as a shadow/outline.
    for dx, dy in ((1, 1), (-1, 1), (1, -1), (-1, -1)):
        draw.text((padding + dx, padding + dy), label, fill=(0, 0, 0, 180), font=font)
    draw.text((padding, padding), label, fill=(255, 255, 255, 230), font=font)


def create_gif(img1, img2, output_path):
    """Write a looping two-frame GIF alternating between *img1* and *img2*.

    Copies of the images are labelled "Original" and "Generated"; the inputs
    themselves are not modified.

    Args:
        img1: PIL image shown as the "Original" frame.
        img2: PIL image shown as the "Generated" frame.
        output_path: Destination path of the GIF.

    Returns:
        ``output_path`` unchanged.
    """
    try:
        font = ImageFont.truetype("arial.ttf", 36)
    except (OSError, ImportError):
        # Font file missing or FreeType support unavailable.
        font = ImageFont.load_default()

    frames = []
    for source, label in ((img1, "Original"), (img2, "Generated")):
        frame = source.copy()
        _draw_outlined_label(frame, label, font)
        frames.append(frame)

    # NOTE(review): 1000 is meant as one second per frame in milliseconds;
    # the unit imageio expects for `duration` depends on its version — confirm.
    imageio.mimsave(output_path, frames, duration=1000, loop=0)
    return output_path
def update_progress_status():
    """Build the status line for counterfactual generation from the module-level counters.

    Reads ``is_processing``, ``images_processed`` and
    ``total_images_to_process`` and returns a human-readable message.
    """
    if is_processing:
        if total_images_to_process == 0:
            return "Preparing to process images..."
        percentage = (images_processed / total_images_to_process) * 100
        return f"Progress: {images_processed}/{total_images_to_process} images processed ({percentage:.1f}%)"

    # Not running: either a run has finished or nothing has started yet.
    if images_processed > 0:
        return f"Processing complete. Generated {images_processed} counterfactual images."
    return "Ready to process images."
def cancel_generation():
    """Raise every module-level stop flag and return a notice for the UI.

    Sets the generation, classifier, embedding and LoRA stop flags to True.
    """
    global generation_should_stop, classifier_should_stop, embedding_should_stop, lora_should_stop
    generation_should_stop = classifier_should_stop = embedding_should_stop = lora_should_stop = True
    return "All processes have been requested to stop. This may take a moment to complete."
def save_results_to_cache(output_dir, cache_key):
    """Mirror the ``gifs`` and ``context`` folders of *output_dir* into the cache.

    The destination is ``CACHE_DIR / cache_key``. A folder already present in
    the cache is removed before the fresh copy is made; a folder missing from
    *output_dir* is left alone.
    """
    cache_root = CACHE_DIR / cache_key
    cache_root.mkdir(exist_ok=True, parents=True)

    source_root = Path(output_dir)
    for folder_name in ("gifs", "context"):
        source = source_root / folder_name
        if not source.exists():
            continue
        destination = cache_root / folder_name
        if destination.exists():
            # Drop the stale copy so the cache matches the new results exactly.
            shutil.rmtree(destination)
        shutil.copytree(source, destination)
def process_with_selected_dataset(zip_file, output_dir, dataset_display_name, checkpoint_path=None, train_clf=True,
                                  is_direct_path=False, direct_path=None, embeddings_path=None,
                                  classifier_path=None, use_classifier_stopping=True, custom_tskip=85,
                                  manip_val=2):
    """Look up cached counterfactual results for the selected example dataset.

    Nothing is generated here: the function computes a cache key from the
    parameters and, if ``CACHE_DIR / key / "gifs"`` exists, loads the GIF paths
    found there into the module-level gallery lists.

    Args:
        zip_file, output_dir, train_clf, is_direct_path, direct_path: Accepted
            for interface compatibility; not used by this function.
        dataset_display_name: ``display_name`` of an ``EXAMPLE_DATASETS`` entry.
        checkpoint_path, embeddings_path, classifier_path,
        use_classifier_stopping, custom_tskip, manip_val: Forwarded to
            ``get_cache_key``; the last two also appear in the status text.

    Returns:
        A 7-tuple ``(message, all_gifs, class0_to_class1_gifs,
        class1_to_class0_gifs, status_message, class0_context, class1_context)``.
        The galleries are empty lists and the context entries ``None`` when no
        dataset matches or no cache exists.
    """
    # Fixed: current_cache_key used to be assigned as a function-local name,
    # so the module-level value that reset_galleries() clears was never set.
    global current_cache_key
    global displayed_results, displayed_results_class0_to_class1, displayed_results_class1_to_class0

    print(f"\nProcessing with dataset: {dataset_display_name}")
    selected_dataset = next(
        (dataset for dataset in EXAMPLE_DATASETS if dataset["display_name"] == dataset_display_name),
        None,
    )
    if not selected_dataset:
        print("Error: No dataset selected")
        return "No dataset selected", [], [], [], "Error: No dataset selected", None, None

    cache_key = get_cache_key(
        selected_dataset["name"], checkpoint_path, False, embeddings_path,
        classifier_path, use_classifier_stopping, custom_tskip,
        manip_val,
    )
    print(f"Generated cache key: {cache_key}")

    cache_path = CACHE_DIR / cache_key
    dataset_dir = cache_path / "gifs"
    print(f"Looking for cache in: {cache_path}")
    print(f"Looking for gifs in: {dataset_dir}")
    print(f"Cache exists: {cache_path.exists()}")
    print(f"Gifs dir exists: {dataset_dir.exists()}")

    if not (cache_path.exists() and dataset_dir.exists()):
        print("No cached results found, processing dataset...")
        return "No cached results found, processing dataset...", [], [], [], "No cached results found, processing dataset...", None, None

    current_cache_key = cache_key
    print(f"Found cached results for key: {cache_key}")

    class0_to_class1_dir = dataset_dir / "class0_to_class1"
    class1_to_class0_dir = dataset_dir / "class1_to_class0"
    context_dir = cache_path / "context"

    # Sort by filename so the galleries have a stable order.
    class0_to_class1_gifs = sorted(class0_to_class1_dir.glob("*.gif"), key=lambda p: p.name) if class0_to_class1_dir.exists() else []
    class1_to_class0_gifs = sorted(class1_to_class0_dir.glob("*.gif"), key=lambda p: p.name) if class1_to_class0_dir.exists() else []

    class0_context = context_dir / "class0_sample.jpg"
    class1_context = context_dir / "class1_sample.jpg"

    class0_to_class1_paths = [str(p) for p in class0_to_class1_gifs]
    class1_to_class0_paths = [str(p) for p in class1_to_class0_gifs]

    displayed_results = class0_to_class1_paths + class1_to_class0_paths
    displayed_results_class0_to_class1 = class0_to_class1_paths
    displayed_results_class1_to_class0 = class1_to_class0_paths

    status_message = f"Using cached results with t-skip={custom_tskip}, manip_scale={manip_val}"
    return (
        "Using cached results for default parameters.",
        displayed_results,
        displayed_results_class0_to_class1,
        displayed_results_class1_to_class0,
        status_message,
        str(class0_context) if class0_context.exists() else None,
        str(class1_context) if class1_context.exists() else None,
    )
# def process_and_clear(example_datasets_dropdown, checkpoint_path_state,
# is_direct_path_state, direct_path_state, embeddings_path_state,
# classifier_path_state, use_classifier_stopping, custom_tskip,
# manip_val):
# """Clear folders first, then process the dataset"""
# # Clear folders first
# clear_output_folders()
# # Then process the dataset
# return process_with_selected_dataset(
# None, # input_zip (always None)
# "./output", # output_dir (hardcoded)
# example_datasets_dropdown,
# checkpoint_path_state,
# False, # train_clf (always False)
# is_direct_path_state,
# direct_path_state,
# embeddings_path_state,
# classifier_path_state,
# use_classifier_stopping,
# custom_tskip,
# manip_val
# )
def process_and_clear(example_datasets_dropdown, checkpoint_path_state,
                      is_direct_path_state, direct_path_state, embeddings_path_state,
                      classifier_path_state, use_classifier_stopping, custom_tskip,
                      manip_val):
    """Wipe the output folders, then look up results for the selected dataset.

    Returns:
        ``(gallery, gallery_class0_to_class1, gallery_class1_to_class0,
        progress_status)``. The message and the two context-image entries
        returned by ``process_with_selected_dataset`` are dropped, so the
        example images are not updated.
    """
    clear_output_folders()

    outcome = process_with_selected_dataset(
        None,         # zip_file: no upload in this demo
        "./output",   # output_dir: fixed
        example_datasets_dropdown,
        checkpoint_path_state,
        False,        # train_clf: never retrain here
        is_direct_path_state,
        direct_path_state,
        embeddings_path_state,
        classifier_path_state,
        use_classifier_stopping,
        custom_tskip,
        manip_val,
    )

    all_results = outcome[1]
    class0_to_class1_results = outcome[2]
    class1_to_class0_results = outcome[3]
    progress_status = outcome[4]
    return all_results, class0_to_class1_results, class1_to_class0_results, progress_status
def update_example_images(dataset_display_name):
    """Pick one example image per class for the dataset with the given display name.

    Args:
        dataset_display_name: ``display_name`` of an ``EXAMPLE_DATASETS`` entry.

    Returns:
        ``(class0_image_path, class1_image_path)`` as strings. Either entry is
        ``None`` when no file is found for that class; both are ``None`` when
        the dataset is unknown or has no ``direct_dataset_path``.
    """
    print(f"\nUpdating example images for {dataset_display_name}")
    selected_dataset = None
    for dataset in EXAMPLE_DATASETS:
        print(f"Checking dataset: {dataset['display_name']}", dataset_display_name)
        if dataset["display_name"] == dataset_display_name:
            selected_dataset = dataset
            print(f"Selected dataset: {selected_dataset}")
            break

    # Fixed: class_names used to be read from selected_dataset before this
    # check, which raised AttributeError for an unknown display name.
    dataset_dir = selected_dataset.get("direct_dataset_path") if selected_dataset else None
    if selected_dataset:
        print(f"Dataset directory: {dataset_dir}")
    if not dataset_dir:
        print("No images found")
        return None, None

    class_names = selected_dataset.get("class_names") or ["class0", "class1"]
    dataset_root = Path(dataset_dir)

    # Debug aid: list everything under the dataset directory.
    print("Contents of directory:")
    for path in dataset_root.rglob("*"):
        print(f" {path}")

    class0_path = dataset_root / class_names[0]
    class1_path = dataset_root / class_names[1]
    print(f"Looking in class0: {class0_path}")
    print(f"Looking in class1: {class1_path}")

    class0_img = next((str(p) for p in dataset_root.glob(f"{class_names[0]}/*.*")), None)
    class1_img = next((str(p) for p in dataset_root.glob(f"{class_names[1]}/*.*")), None)
    print(f"Found images:\nclass0={class0_img}\nclass1={class1_img}")
    return class0_img, class1_img
# Module-level gr.State holding the direct dataset path (initially None).
# NOTE(review): this is created outside any gr.Blocks context — confirm it is actually used.
direct_path_state = gr.State(None)
def get_name_from_display(display_name):
    """Translate a dataset's ``display_name`` into its internal ``name``; None if unknown."""
    matching_names = (
        entry["name"]
        for entry in EXAMPLE_DATASETS
        if entry["display_name"] == display_name
    )
    return next(matching_names, None)
def use_selected_dataset(display_name):
    """Resolve the paths to use for the dataset selected in the UI.

    An existing ``direct_dataset_path`` is preferred; otherwise an existing
    archive ``path`` is used.

    Args:
        display_name: ``display_name`` of an ``EXAMPLE_DATASETS`` entry.

    Returns:
        ``(archive_path, checkpoint_path, is_direct_path, direct_path,
        embeddings_path, classifier_path)``. When nothing usable is found the
        result is ``(None, None, False, None, None, None)``.
    """
    nothing_found = (None, None, False, None, None, None)

    name = get_name_from_display(display_name)
    if not name:
        print("No dataset name found")
        return nothing_found

    dataset_info = get_example_dataset_info(name)
    if not dataset_info:
        return nothing_found

    checkpoint_path = dataset_info.get("checkpoint_path")
    embeddings_path = dataset_info.get("embeddings_path")
    classifier_path = dataset_info.get("classifier_path")

    direct_path = dataset_info.get("direct_dataset_path")
    if direct_path and os.path.exists(direct_path):
        print(f"Using direct dataset path: {direct_path}")
        return None, checkpoint_path, True, direct_path, embeddings_path, classifier_path

    # Fixed: not every entry defines "path", so indexing it raised KeyError
    # whenever the direct dataset directory was missing.
    archive_path = dataset_info.get("path")
    if archive_path and os.path.exists(archive_path):
        return archive_path, checkpoint_path, False, None, embeddings_path, classifier_path

    return nothing_found
def reset_galleries():
    """Empty the gallery lists, forget the cache key and drain ``result_queue``.

    Returns:
        Three empty lists (one per gallery) followed by the status text
        ``"Galleries reset"``.
    """
    global displayed_results, displayed_results_class0_to_class1, displayed_results_class1_to_class0
    global current_cache_key

    # Three distinct lists so the galleries never share one object.
    displayed_results, displayed_results_class0_to_class1, displayed_results_class1_to_class0 = [], [], []
    current_cache_key = None

    # Discard any results still waiting to be displayed.
    while not result_queue.empty():
        result_queue.get()

    return [], [], [], "Galleries reset"
def clear_output_folders():
    """Remove ``./output/gifs`` and ``./output/context`` (with contents) if present.

    Prints one line for each folder that was deleted.
    """
    import shutil
    from pathlib import Path

    output_root = Path("./output")
    for folder in ("gifs", "context"):
        folder_path = output_root / folder
        if not folder_path.exists():
            continue
        shutil.rmtree(folder_path)
        print(f"Deleted {folder_path}")
def create_gradio_interface():
# Create temporary directories for uploads
temp_dir = Path("./temp_uploads")
temp_dir.mkdir(exist_ok=True, parents=True)
clear_output_folders()
lora_temp_dir = Path("./temp_lora_uploads")
lora_temp_dir.mkdir(exist_ok=True, parents=True)
# Get initial list of example datasets
example_datasets = get_example_datasets()
with gr.Blocks(css=css) as demo:
# Add the header at the top level to span across all columns
with gr.Row(elem_classes="full-width-header"):
with gr.Column():
gr.HTML("""
<div class="header">
<h1>DIFFusion Demo</h1>
<p class="subtitle">Generate fine-grained edits to images using another class of images as guidance.</p>
<p class="contact-info">For any questions/comments/issues with this demo, please email [email protected].🤖</p>
</div>
""")
# Main content row with sidebar, config column and results column
with gr.Row(elem_classes="content-row"):
# Sidebar for example datasets
with gr.Column(scale=1, elem_classes="sidebar"):
gr.HTML('<div class="section-header">Example Datasets</div>')
# Create a dropdown for example datasets
example_datasets_dropdown = gr.Dropdown(
choices=[dataset["display_name"] for dataset in EXAMPLE_DATASETS],
value=next((dataset["display_name"] for dataset in EXAMPLE_DATASETS if "lamp" in dataset["display_name"].lower()), None), # Set lamp as default
label="Example Datasets",
info="Select a pre-loaded dataset to use"
)
# Add dataset descriptions directly in the dropdown info
dataset_descriptions = {dataset["display_name"]: dataset.get("description", "") for dataset in EXAMPLE_DATASETS}
# Add some spacing
gr.HTML("<div style='height: 20px;'></div>")
# Add a hidden state for the dataset description (we'll still update it but not display it)
dataset_description = gr.Textbox(visible=False)
# Main content area
with gr.Column(scale=2, elem_classes="main-container"):
# Paper info and configuration
with gr.Column():
with gr.Column(elem_classes="paper-info"):
gr.HTML("""
<h3>DIFFusion Demo</h3>
<p>Text-based AI image editing can be tricky, as language often fails to capture precise visual ideas, and users may not always know what they want. Our image-guided editing method learns transformations directly from the differences between two image groups, removing the need for detailed verbal descriptions. Designed for scientific applications, it highlights subtle differences in visually similar image categories. It also applies to nicely to marketing, adapting new products into scenes by managing small interior design details. Choose between four example datasets, then adjust the tskip (higher = less edit) and manipulation scalar (higher = more edit) to explore the editing effects. A Gradio demo in our GitHub code release lets users upload datasets and try the method (GPU required).</p>
""")
# Counterfactual Generation Section
gr.HTML('<div class="section-header">Counterfactual Generation</div>')
# with gr.Column(elem_classes="upload-info"):
# gr.HTML("""
# <p><strong>Dataset Format:</strong> Upload a zip file containing two folders named 'class0' and 'class1',
# each containing images of the respective class.</p>
# """)
# with gr.Row():
# input_zip = gr.File(
# label="Upload Custom Dataset (ZIP or TAR.GZ file)",
# file_types=[".zip", ".tar.gz", ".tgz"],
# type="filepath"
# )
# # Hide the output directory by using elem_classes
# output_dir = gr.Textbox(
# label="Output Directory",
# value="./output",
# elem_classes="hidden-element"
# )
# with gr.Row():
# gr.HTML('<div class="section-header">LoRA Training</div>')
# with gr.Column(elem_classes="upload-info"):
# gr.HTML("""
# <p><strong>Dataset Format:</strong> Upload a zip file containing two folders named 'class0' and 'class1',
# each containing images of the respective class for training the LoRA model.</p>
# """)
# with gr.Row():
# lora_output_dir = gr.Textbox(
# label="LoRA Output Directory",
# value="./lora_output"
# )
# gr.HTML("""
# <div class="parameter-box">
# <p>Default LoRA Training Parameters:</p>
# <ul>
# <li>Epochs: 5</li>
# <li>Learning Rate: 1e-4</li>
# <li>Batch Size: 32</li>
# <li>LoRA Rank: 4</li>
# <li>LoRA Alpha: 32</li>
# <li>Max Training Steps: 1000</li>
# </ul>
# </div>
# """)
# train_lora_btn = gr.Button("Train LoRA Model", elem_classes="btn-primary")
# lora_status_box = gr.Textbox(label="LoRA Training Status", value="Ready to train LoRA model")
# train_clf = gr.Checkbox(label="Train New Classifiers", value=False)
with gr.Row():
use_classifier_stopping = gr.State(False)#
custom_tskip = gr.Dropdown(
choices=[55, 60, 65, 70, 75, 80, 85, 90, 95],
value=85, # default value
label="Custom T-Skip Value",
info="Select a t-skip value",
visible=True
)
# Add a text box for number of images per class
with gr.Row():
manip_val = gr.Dropdown(
choices=[1.0, 1.5, 2.0],
value=2.0, # default value
label="Manip scale",
info="Select a manip scale",
visible=True
)
#
with gr.Row():
process_btn = gr.Button("Generate Counterfactuals", elem_classes="btn-primary")
cancel_btn = gr.Button("Cancel Generation", elem_classes="btn-primary")
# Status for the main column
#status = gr.Textbox(label="Status", value="Ready to generate counterfactuals")
# Results column
with gr.Column(scale=2, elem_classes="results-container"):
# Class Examples section header - MOVED HERE
gr.HTML('<div class="section-header">Class Examples</div>')
# Class example images - MOVED HERE
with gr.Row():
class0_context_image = gr.Image(label="Class 0 Example", type="filepath", height=256)
class1_context_image = gr.Image(label="Class 1 Example", type="filepath", height=256)
# Results section header
gr.HTML('<div class="section-header">Results</div>')
default_dataset = next((dataset["display_name"] for dataset in EXAMPLE_DATASETS if "lamps" in dataset["display_name"].lower()), None)
if default_dataset:
# Initial load of example images
class0_img, class1_img = update_example_images(default_dataset)
if class0_img and class1_img:
class0_context_image.value = class0_img # Directly set the value
class1_context_image.value = class1_img
print(f"Class 0 image: {class0_context_image.value}")
print(f"Class 1 image: {class1_context_image.value}")
# Add tabs for different direction signs - make "All Results" the default tab
with gr.Tabs(elem_classes="tabs-container") as result_tabs:
with gr.TabItem("All Results"):
gallery = gr.Gallery(
label="Generated Images",
show_label=False,
elem_id="gallery_all",
columns=4, # Show 4 images per row
rows=None, # Let it adjust rows automatically
height="auto",
allow_preview=True,
preview=False,
object_fit="contain"
)
with gr.TabItem("Class 0 → Class 1"):
gallery_class0_to_class1 = gr.Gallery(
label="Class 0 to Class 1",
show_label=False,
elem_id="gallery_0to1",
columns=4, # Show 4 images per row
rows=None, # Let it adjust rows automatically
height="auto",
allow_preview=True,
preview=True,
object_fit="contain"
)
with gr.TabItem("Class 1 → Class 0"):
gallery_class1_to_class0 = gr.Gallery(
label="Class 1 to Class 0",
show_label=False,
elem_id="gallery_1to0",
columns=4, # Show 4 images per row
rows=None, # Let it adjust rows automatically
height="auto",
allow_preview=True,
preview=True,
object_fit="contain"
)
# with gr.TabItem("All Results"):
# gallery = gr.Gallery(
# columns=[3],
# rows=[3],
# height="auto",
# allow_preview=True, # Make sure this is enabled
# preview=True, # Try setting this explicitly
# object_fit="contain" # Try different fit modes
# )
# with gr.TabItem("Class 0 → Class 1"):
# gallery_class0_to_class1 = gr.Gallery(
# columns=[3],
# rows=[3],
# height="auto",
# allow_preview=True, # Make sure this is enabled
# preview=True, # Try setting this explicitly
# object_fit="contain" # Try different fit modes
# )
# with gr.TabItem("Class 1 → Class 0"):
# gallery_class1_to_class0 = gr.Gallery(
# columns=[3],
# rows=[3],
# height="auto",
# allow_preview=True, # Make sure this is enabled
# preview=True, # Try setting this explicitly
# object_fit="contain" # Try different fit modes
# )
# Add a progress status box in the results column
progress_status = gr.Textbox(
label="Progress",
value="Ready to process",
interactive=False
)
# Define state variables inside the function
#set the default to these to be those for the lamp dataset
default_dataset = next((dataset for dataset in EXAMPLE_DATASETS if "lamp" in dataset["display_name"].lower()), None)
if default_dataset:
checkpoint_path_state = gr.State(default_dataset["checkpoint_path"])
is_direct_path_state = gr.State(False)
direct_path_state = gr.State(None)
embeddings_path_state = gr.State(default_dataset["embeddings_path"])
classifier_path_state = gr.State(default_dataset["classifier_path"])
process_btn.click(
fn=process_and_clear,
inputs=[
example_datasets_dropdown, checkpoint_path_state,
is_direct_path_state, direct_path_state, embeddings_path_state,
classifier_path_state, use_classifier_stopping, custom_tskip,
manip_val
],
outputs=[
gallery, # Make sure these variables are all defined
gallery_class0_to_class1, # and not None
gallery_class1_to_class0,
progress_status
] # Removed 'status' since it wasn't defined
)
# Set up the cancel button click handler
cancel_btn.click(
fn=cancel_generation,
inputs=None,
outputs=None
)
num_images_per_class = gr.State(10)
example_datasets_dropdown.change(
fn=reset_galleries, # Reset galleries but not example images
inputs=None,
outputs=[gallery, gallery_class0_to_class1, gallery_class1_to_class0, progress_status]
).then( # Update dataset info
fn=update_dataset_info,
inputs=example_datasets_dropdown,
outputs=[dataset_description, checkpoint_path_state, is_direct_path_state, direct_path_state,
embeddings_path_state, classifier_path_state, custom_tskip]
).then( # Set custom t-skip
fn=set_custom_tskip_for_dataset,
inputs=example_datasets_dropdown,
outputs=custom_tskip
).then( # Change cache key
fn=change_cache_key,
inputs=[example_datasets_dropdown, num_images_per_class, use_classifier_stopping, custom_tskip],
outputs=None
).then( # Update example images
fn=update_example_images,
inputs=example_datasets_dropdown,
outputs=[class0_context_image, class1_context_image]
).then( # Automatically generate counterfactuals when dataset changes
fn=process_and_clear,
inputs=[
example_datasets_dropdown, checkpoint_path_state,
is_direct_path_state, direct_path_state, embeddings_path_state,
classifier_path_state, use_classifier_stopping, custom_tskip,
manip_val
],
outputs=[gallery, gallery_class0_to_class1, gallery_class1_to_class0,
progress_status]
)
# Load initial example images and generate counterfactuals for default dataset (Lamps)
demo.load(
fn=update_example_images,
inputs=example_datasets_dropdown,
outputs=[class0_context_image, class1_context_image]
).then( # Initial counterfactual generation
fn=process_and_clear,
inputs=[
example_datasets_dropdown, checkpoint_path_state,
is_direct_path_state, direct_path_state, embeddings_path_state,
classifier_path_state, use_classifier_stopping, custom_tskip,
manip_val
],
outputs=[gallery, gallery_class0_to_class1, gallery_class1_to_class0,
progress_status]
)
# example_datasets_dropdown.change(
# fn=reset_galleries, # Reset first
# inputs=None,
# outputs=[gallery, gallery_class0_to_class1, gallery_class1_to_class0, progress_status]
# ).then( # Update dataset info
# fn=update_dataset_info,
# inputs=example_datasets_dropdown,
# outputs=[dataset_description, checkpoint_path_state, is_direct_path_state, direct_path_state,
# embeddings_path_state, classifier_path_state, custom_tskip_state]
# ).then( # Set custom t-skip
# fn=set_custom_tskip_for_dataset,
# inputs=example_datasets_dropdown,
# outputs=custom_tskip
# ).then( # Change cache key
# fn=change_cache_key,
# inputs=[example_datasets_dropdown, manip_val, use_classifier_stopping, custom_tskip],
# outputs=None
# ).then( # Update example images
# fn=lambda display_name: update_example_images(display_name),
# inputs=example_datasets_dropdown,
# outputs=[class0_context_image, class1_context_image]
# )
# process_btn.click(
# fn=process_and_clear,
# inputs=[
# example_datasets_dropdown, checkpoint_path_state,
# is_direct_path_state, direct_path_state, embeddings_path_state,
# classifier_path_state, use_classifier_stopping, custom_tskip,
# manip_val
# ],
# outputs=[status, gallery, gallery_class0_to_class1, gallery_class1_to_class0,
# progress_status, class0_context_image, class1_context_image]
# )
# # Set up the click event for LoRA training
# train_lora_btn.click(
# fn=start_lora_training,
# inputs=[input_zip, lora_output_dir],
# outputs=[lora_status_box]
# )
# # Set up periodic status checking for LoRA training
# demo.load(
# fn=check_lora_status,
# inputs=None,
# outputs=lora_status_box,
# every=5 # Check every 5 seconds
# )
# Add a periodic refresh for the galleries
# Add a periodic refresh for the galleries
# Add this event handler:
# example_datasets_dropdown.change(
# fn=reset_galleries,
# inputs=None,
# outputs=[gallery, gallery_class0_to_class1, gallery_class1_to_class0, progress_status]
# )
return demo
def update_dataset_info(dataset_display_name):
    """Resolve the dropdown selection into the values the UI state needs.

    Args:
        dataset_display_name: Value compared against each entry's
            "display_name" in EXAMPLE_DATASETS.

    Returns:
        A 7-tuple ``(description, checkpoint_path, is_direct_path,
        direct_path, embeddings_path, classifier_path, custom_tskip)``.
        When no entry matches, a placeholder tuple is returned whose first
        item is "No dataset selected".
    """
    # First entry whose display name matches wins.
    selected_dataset = next(
        (
            entry
            for entry in EXAMPLE_DATASETS
            if entry["display_name"] == dataset_display_name
        ),
        None,
    )
    if not selected_dataset:
        return "No dataset selected", None, False, None, None, None, None

    direct_path = selected_dataset.get("direct_dataset_path")

    # Per-dataset t-skip default; datasets without a rule get None.
    lowered_name = dataset_display_name.lower()
    if "butterfly" in lowered_name:
        custom_tskip = 70
    elif "lamp" in lowered_name:
        custom_tskip = 85
    else:
        custom_tskip = None
    print(f"Setting custom_tskip to {custom_tskip} for dataset {dataset_display_name}")

    return (
        selected_dataset.get("description", "No description available"),
        selected_dataset.get("checkpoint_path"),
        direct_path is not None,  # a direct path is in use only when one is configured
        direct_path,
        selected_dataset.get("embeddings_path"),
        selected_dataset.get("classifier_path"),
        custom_tskip,
    )
# Function to generate a cache key based on parameters
def get_cache_key(dataset_name, checkpoint_path, train_clf, embeddings_path,
                  classifier_path, use_classifier_stopping, custom_tskip, manip_val):
    """Derive a deterministic cache key from the processing parameters.

    The three path arguments are stringified and ``manip_val`` is coerced to
    float before hashing; the remaining arguments are used as given.

    Returns:
        The MD5 hex digest of the JSON encoding (keys sorted) of the
        parameters.
    """
    params = {
        "dataset_name": dataset_name,
        "checkpoint_path": str(checkpoint_path),
        "train_clf": train_clf,
        "embeddings_path": str(embeddings_path),
        "classifier_path": str(classifier_path),
        "use_classifier_stopping": use_classifier_stopping,
        "custom_tskip": custom_tskip,
        "manip_val": float(manip_val),
    }
    print(f"Params: {params}")

    # Sorted keys make the serialisation independent of dict ordering.
    serialized = json.dumps(params, sort_keys=True)
    digest = hashlib.md5(serialized.encode())
    return digest.hexdigest()
def change_cache_key(dataset_name, manip_val, use_classifier_stopping, custom_tskip):
    """Recompute the module-level ``current_cache_key`` for a dataset.

    Args:
        dataset_name: Display name matched against EXAMPLE_DATASETS.
        manip_val: Forwarded to get_cache_key.
        use_classifier_stopping: Forwarded to get_cache_key.
        custom_tskip: Forwarded to get_cache_key.

    Returns:
        None. If no dataset matches, a message is printed and the current
        key is left untouched.
    """
    global current_cache_key

    entry = next(
        (
            candidate
            for candidate in EXAMPLE_DATASETS
            if candidate["display_name"] == dataset_name
        ),
        None,
    )
    if not entry:
        print(f"No dataset found for name: {dataset_name}")
        return

    current_cache_key = get_cache_key(
        dataset_name=entry["name"],  # internal name, not the display name
        checkpoint_path=entry.get("checkpoint_path"),
        train_clf=False,  # train_clf is always False
        embeddings_path=entry.get("embeddings_path"),
        classifier_path=entry.get("classifier_path"),
        use_classifier_stopping=use_classifier_stopping,
        custom_tskip=custom_tskip,
        manip_val=manip_val,
    )
# Function to check if cached results exist
def check_cache(cache_key):
    """Report whether a cache entry with a ``gifs`` subfolder exists.

    Args:
        cache_key: Name of the entry directory under CACHE_DIR.

    Returns:
        True only if both ``CACHE_DIR / cache_key`` and its ``gifs`` child
        exist; False otherwise.
    """
    entry_dir = CACHE_DIR / cache_key
    if not entry_dir.exists():
        return False
    return (entry_dir / "gifs").exists()
# Create a context (example) image for a class
def _load_context_font(size=32):
    """Return Arial at *size*, falling back to Pillow's built-in font."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        # truetype() raises OSError when the font file cannot be found or read.
        return ImageFont.load_default()


def create_context_image(image_paths, output_path, title, preferred_index=0):
    """Create a context image showing a sample from a class.

    With no samples, a 512x512 light-grey placeholder reading
    "No samples available" is written instead. Otherwise the chosen sample
    is resized to 512x512, a dark translucent banner is blended over its top
    50 pixels, and ``title`` (when non-empty) is drawn on the banner.

    Args:
        image_paths: List of paths to images in the class.
        output_path: Where to save the context image.
        title: Title drawn on the banner; skipped when empty or None.
        preferred_index: Index of the preferred image to use (default: 0).
            Indices past the end are clamped to the last image.

    Returns:
        None. The result is written to ``output_path``.
    """
    font = _load_context_font()

    if not image_paths:
        # Create a blank image if no samples are available
        placeholder = Image.new('RGB', (512, 512), color=(240, 240, 240))
        draw = ImageDraw.Draw(placeholder)
        draw.text((256, 256), "No samples available", fill=(80, 80, 80), font=font, anchor="mm")
        placeholder.save(output_path)
        return

    # Use the preferred index if available, otherwise clamp to the last image
    img_index = min(preferred_index, len(image_paths) - 1)
    # The context manager closes the source file; convert() returns a
    # loaded copy that stays valid afterwards.
    with Image.open(image_paths[img_index]) as source:
        img = source.convert("RGB").resize((512, 512), Image.LANCZOS)

    # An RGB canvas has no alpha channel, so the banner is blended through
    # an RGBA overlay to be genuinely semi-transparent.
    overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
    ImageDraw.Draw(overlay).rectangle([(0, 0), (img.width, 50)], fill=(0, 0, 0, 180))
    img = Image.alpha_composite(img.convert("RGBA"), overlay).convert("RGB")

    # Render the title on the banner (previously the title was never drawn).
    if title:
        ImageDraw.Draw(img).text((10, 8), str(title), fill=(255, 255, 255), font=font)

    # Save the context image
    img.save(output_path)
# Convert the stored t-skip value into text for the custom_tskip input field
def update_custom_tskip(tskip_value):
    """Render a stored t-skip value as text for the custom_tskip input.

    Args:
        tskip_value: The value held in state, or None.

    Returns:
        An empty string for None, otherwise ``str(tskip_value)``.
    """
    print(f"Updating custom_tskip input with value: {tskip_value}")
    # The text input expects a string, so None becomes an empty field.
    return "" if tskip_value is None else str(tskip_value)
# Pick the custom_tskip value directly from the dataset name
def set_custom_tskip_for_dataset(dataset_name):
    """Choose the custom_tskip value for the selected dataset.

    Args:
        dataset_name: Dataset name, matched case-insensitively by
            substring; may be None.

    Returns:
        70 when the name contains "butterfly"; 85 for names containing
        "lamp", for any other name, and for None.
    """
    fallback_tskip = 85
    if dataset_name is None:
        return fallback_tskip

    lowered = dataset_name.lower()
    # Checked in order, so "butterfly" takes precedence over "lamp".
    for keyword, tskip in (("butterfly", 70), ("lamp", 85)):
        if keyword in lowered:
            return tskip
    return fallback_tskip
if __name__ == "__main__":
    # Script entry point: build the Gradio interface and launch it.
    # Uncomment this line to save current results to cache
    #save_current_results_to_cache()
    demo = create_gradio_interface()
    demo.launch()
# Add these functions at the top of the file, after the imports and global variables
# but before any other function definitions
#