import glob
import hashlib
import json
import math
import os
import queue
import random
import shutil
import tarfile
import threading
import time
import zipfile
from pathlib import Path

import gradio as gr
import imageio
import numpy as np
import pandas as pd
import PIL
import torch
import torchvision.transforms.v2 as transforms
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm

# Set seeds for reproducibility
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

# Define constants
IMG_SIZE = 512
BATCH_SIZE = 32

# Prefer GPU index 2, but only when that many GPUs are actually visible.
# Building torch.device("cuda:2") never fails by itself; the error would only
# surface later, on first use, on hosts with fewer than three GPUs.
_PREFERRED_GPU_INDEX = 2
if torch.cuda.is_available():
    _gpu_index = (
        _PREFERRED_GPU_INDEX
        if torch.cuda.device_count() > _PREFERRED_GPU_INDEX
        else 0
    )
    DEVICE = torch.device(f"cuda:{_gpu_index}")
else:
    DEVICE = torch.device("cpu")

# Global variables for LoRA training
lora_status = "Ready"
lora_is_processing = False

# Global variables for generation control (cooperative stop flags)
generation_should_stop = False
classifier_should_stop = False  # Stop flag for classifier training
embedding_should_stop = False  # Stop flag for embedding encoding
lora_should_stop = False  # Stop flag for LoRA training
generation_queue = queue.Queue()
is_processing = False  # Prevents multiple simultaneous processes

# Create temporary directories for uploads
temp_dir = Path("./temp_uploads")
temp_dir.mkdir(exist_ok=True, parents=True)
lora_temp_dir = Path("./temp_lora_uploads")
lora_temp_dir.mkdir(exist_ok=True, parents=True)

# Global queue for real-time updates
result_queue = queue.Queue()
displayed_results = []  # Keep track of all displayed results

# Progress counters
total_images_to_process = 0
images_processed = 0

# Per-direction gallery contents
displayed_results_class0_to_class1 = []  # Results for class 0 to class 1
displayed_results_class1_to_class0 = []  # Results for class 1 to class 0

# Directory holding cached generation results
CACHE_DIR = Path("./cached_results")
CACHE_DIR.mkdir(exist_ok=True, parents=True)

# CSS for styling the interface
css = """ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap'); body, * { font-family: 'Inter', sans-serif !important; letter-spacing: -0.01em; } .container { max-width: 1360px; margin: auto; padding-top: 2.5rem; padding-bottom: 2.5rem; } .header { text-align: center; margin-bottom: 3rem; padding-bottom: 2rem; border-bottom: 1px solid #f0f0f0; } .header h1 { font-size: 3rem; font-weight: 700; color: #222; letter-spacing: -0.03em; margin-bottom: 1rem; background: linear-gradient(90deg, #B39CD0 0%, #9D8AC7 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; display: inline-block; } .header p { font-size: 1.1rem; color: #333; max-width: 800px; margin: 0 auto; line-height: 1.6; } .subtitle { font-size: 0.95rem; color: #777; max-width: 800px; margin: 0.5rem auto 0; line-height: 1.5; } .contact-info { font-size: 0.8rem; color: #777; margin-top: 15px; padding-top: 10px; border-top: 1px dashed #e0e0e0; width: 80%; margin-left: auto; margin-right: auto; } .paper-info { background-color: #f8f9fa; border-radius: 12px; padding: 1.8rem; margin: 1.8rem 0; box-shadow: 0 6px 20px rgba(0,0,0,0.05); border-left: 4px solid #B39CD0; } .paper-info h3 { font-size: 1.5rem; font-weight: 600; color: #B39CD0; letter-spacing: -0.02em; margin-bottom: 1rem; } .paper-info p { font-size: 1.05em; line-height: 1.7; color: #333; } .section-header { font-size: 1.8rem; font-weight: 600; color: #B39CD0; margin: 2.5rem 0 1.5rem 0; padding-bottom: 0.8rem; border-bottom: 2px solid #ECF0F1; letter-spacing: -0.02em; } .footer { text-align: center; margin-top: 3rem; padding: 1.5rem; border-top: 1px solid #ECF0F1; color: #666; background-color: #f8f9fa; border-radius: 0 0 12px 12px; } .btn-primary { background-color: #B39CD0 !important; border-color: #B39CD0 !important; transition: all 0.3s ease; font-weight: 500 !important; letter-spacing: 0.02em !important; padding: 0.6rem 1.5rem !important; border-radius: 8px !important; } 
.btn-primary:hover { background-color: #9D8AC7 !important; border-color: #9D8AC7 !important; } /* Hide the output directory */ .hidden-element { display: none !important; } /* Additional CSS for better alignment */ .container { padding: 0 1.5rem; } .main-container { display: flex; flex-direction: column; gap: 1.5rem; } .results-container { margin-top: 0; padding-top: 0; } .full-width-header { margin-bottom: 2rem; padding-bottom: 1.5rem; border-bottom: 1px solid #f0f0f0; text-align: center; } .content-row { display: flex; gap: 2rem; } .sidebar { min-width: 250px; padding-right: 1.5rem; } .section-header { margin-top: 0; } .tabs-container { margin-top: 1rem; } .gallery-container { margin-top: 1rem; } /* Hide the output directory */ .hidden-element { display: none !important; } .gallery-item img { object-fit: contain !important; height: 200px !important; width: auto !important; } /* Force GIFs to restart when tab is selected */ .tabs-container .tabitem[style*="display: block"] .gallery-container img { animation: none; animation: reload-animation 0.1s; } @keyframes reload-animation { 0% { opacity: 0.99; } 100% { opacity: 1; } } """

# Key of the cache entry currently shown in the galleries (None = no entry).
current_cache_key = None
# NOTE: a module-level flag `is_using_default_params = False` used to be
# assigned here. It was rebound by the function of the same name defined
# below before anything could read it, so the dead assignment was dropped.

# Example datasets, including direct dataset paths, embeddings and classifiers.
# NOTE(review): the "afhq" entry has no archive "path" and the "couches" entry
# has no "classifier_path"; code reading these keys should use .get().
EXAMPLE_DATASETS = [
    {
        "name": "butterfly",
        "display_name": "Butterfly (Monarch vs Viceroy)",
        "description": "Dataset containing images of Monarch and Viceroy butterflies for counterfactual generation",
        "path": "/proj/vondrick/datasets/magnification/butterfly.tar.gz",
        "direct_dataset_path": "example_images/butterfly",
        "checkpoint_path": "/proj/vondrick2/mia/magnificationold/output/lora/butterfly/copper-forest-49/checkpoint-1800",
        "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/butterfly",
        "classifier_path": "/proj/vondrick2/mia/diff-usion/results/ensemble/butterfly",
        "class_names": ["class0", "class1"],
    },
    {
        "name": "afhq",
        "display_name": "Cats vs. Dogs (AFHQ)",
        # Previously a copy of the lamp description; corrected to match the
        # dataset this entry actually points at.
        "description": "Dataset containing images of cats and dogs from AFHQ",
        "direct_dataset_path": "example_images/afhq",
        "checkpoint_path": None,
        "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/afhq",
        "classifier_path": "/proj/vondrick2/mia/diff-usion/results/ensemble/afhq",
        "class_names": ["class0", "class1"],
    },
    {
        "name": "lamp",
        "display_name": "Lamps",
        "description": "Dataset containing images of table lamps and floor lamps",
        "path": "compressed_datasets/lampsfar.zip",
        "direct_dataset_path": "example_images/lamps",
        "checkpoint_path": "/proj/vondrick2/mia/diff-usion/lora_output_lampsfar/checkpoint-800",
        "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/lampsfar",
        "classifier_path": "/proj/vondrick2/mia/diff-usion/results/ensemble/lampsfar",
        "class_names": ["class0", "class1"],
    },
    {
        "name": "couches",
        "display_name": "Couches",
        "description": "Dataset containing images of chairs and floor",
        "path": "compressed_datasets/couches.zip",
        "direct_dataset_path": "example_images/couches",
        "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/couches",
        "checkpoint_path": "/proj/vondrick2/mia/diff-usion/lora_output/couches/checkpoint-1000",
        "class_names": ["class0", "class1"],
    },
]


def get_example_datasets():
    """Return the internal ``name`` of every entry in ``EXAMPLE_DATASETS``."""
    return [dataset["name"] for dataset in EXAMPLE_DATASETS]


def get_example_dataset_info(name):
    """Return the ``EXAMPLE_DATASETS`` entry whose ``name`` equals *name*.

    Returns ``None`` when no entry matches.
    """
    return next(
        (dataset for dataset in EXAMPLE_DATASETS if dataset["name"] == name),
        None,
    )


# (keyword looked up in the lower-cased dataset name, default t-skip value).
# Order matters: the first keyword found in the name decides the outcome.
_DEFAULT_TSKIP_BY_KEYWORD = (
    ("butterfly", 70),
    ("lamp", 85),
    ("couch", 85),
)
_DEFAULT_IMAGES_PER_CLASS = 10


def is_using_default_params(dataset_name, custom_tskip, num_images_per_class):
    """Check whether the given settings are the defaults for the dataset.

    Args:
        dataset_name: Dataset name or display name; matched case-insensitively
            by keyword ("butterfly", "lamp", "couch"). ``None`` yields False.
        custom_tskip: Selected t-skip, either as a number or its string form
            (e.g. ``85`` or ``"85"``).
        num_images_per_class: Number of images per class; the default is 10.

    Returns:
        True only if the dataset is a known one and both values equal its
        defaults; False otherwise (including unknown datasets).
    """
    if dataset_name is None:
        return False

    lowered_name = dataset_name.lower()
    for keyword, default_tskip in _DEFAULT_TSKIP_BY_KEYWORD:
        if keyword in lowered_name:
            tskip_is_default = custom_tskip in (default_tskip, str(default_tskip))
            return tskip_is_default and num_images_per_class == _DEFAULT_IMAGES_PER_CLASS
    return False


# # Function to get the output directory - either cache or regular output
# def get_output_directory(dataset_name, is_default_params, cache_key):
#     """Get the appropriate output directory based on parameters"""
#     if is_default_params:
#         # Use cache directory
#         cache_path = CACHE_DIR / cache_key
#         cache_path.mkdir(exist_ok=True, parents=True)
#         # Create dataset-specific directory
#         dataset_dir = cache_path / dataset_name.replace(" ", "_").lower()
#         dataset_dir.mkdir(exist_ok=True, parents=True)
#         # Create class-specific directories
#         class0_to_class1_dir = dataset_dir / "class0_to_class1"
#         class1_to_class0_dir = dataset_dir / "class1_to_class0"
#         class0_to_class1_dir.mkdir(exist_ok=True, parents=True)
#         class1_to_class0_dir.mkdir(exist_ok=True, parents=True)
#         # Create context directory
#         context_dir = dataset_dir / "context"
#         context_dir.mkdir(exist_ok=True, parents=True)
#         return dataset_dir, class0_to_class1_dir, class1_to_class0_dir, context_dir
#     else:
#         # Use regular output directory
#         output_dir = Path(f"./results/{dataset_name.replace(' ', '_').lower()}")
#         output_dir.mkdir(exist_ok=True, parents=True)
#         # Create gifs directory with class-specific subdirectories
#         gifs_dir = output_dir / "gifs"
#         gifs_dir.mkdir(exist_ok=True, parents=True)
#         class0_to_class1_dir = gifs_dir / "class0_to_class1"
#         class1_to_class0_dir = gifs_dir / "class1_to_class0"
#         class0_to_class1_dir.mkdir(exist_ok=True, parents=True)
#         class1_to_class0_dir.mkdir(exist_ok=True, parents=True)
#         # Create context directory
#         context_dir = output_dir / "context"
#         context_dir.mkdir(exist_ok=True, parents=True)
#         return output_dir, class0_to_class1_dir, class1_to_class0_dir, context_dir
def has_prediction_flipped(orig_preds, new_preds):
    """Check if any prediction has flipped from one class to another.

    Both arguments must expose a ``preds`` attribute; each is thresholded at
    0.5 and the function returns True if any element ends up on a different
    side of the threshold.
    """
    return ((orig_preds.preds > 0.5) != (new_preds.preds > 0.5)).any().item()


def _ensure_inside(extract_root, member_name):
    """Raise ValueError if *member_name* would resolve outside *extract_root*."""
    root = extract_root.resolve()
    target = (root / member_name).resolve()
    if target != root and root not in target.parents:
        raise ValueError(f"Unsafe path in archive: {member_name}")


def extract_archive(archive_file, extract_dir):
    """Extract a zip or tar.gz/tgz archive and locate its class folders.

    After extraction the directory must contain ``class0`` and ``class1``
    folders, either at the top level or inside one first-level subdirectory
    (in which case they are moved up to the top level).

    Args:
        archive_file: Path to the ``.zip``, ``.tar.gz`` or ``.tgz`` file.
        extract_dir: Directory to extract into; created if missing.

    Returns:
        The extraction directory as a string.

    Raises:
        ValueError: If the format is unsupported, a tar member would be
            written outside *extract_dir*, or the class folders are missing.
    """
    extract_root = Path(extract_dir)
    extract_root.mkdir(parents=True, exist_ok=True)

    file_path = Path(archive_file)
    lowered_name = file_path.name.lower()
    if file_path.suffix.lower() == '.zip':
        # zipfile strips absolute prefixes and ".." components on extraction.
        with zipfile.ZipFile(archive_file, 'r') as zip_ref:
            zip_ref.extractall(extract_root)
    elif lowered_name.endswith(('.tar.gz', '.tgz')):
        with tarfile.open(archive_file, 'r:gz') as tar_ref:
            members = tar_ref.getmembers()
            # Archives are user uploads: refuse members (and link targets)
            # that would land outside the extraction directory.
            for member in members:
                _ensure_inside(extract_root, member.name)
                if member.issym():
                    # Symlink targets are relative to the link's own folder.
                    _ensure_inside(
                        extract_root,
                        str(Path(member.name).parent / member.linkname),
                    )
                elif member.islnk():
                    # Hard-link targets are relative to the archive root.
                    _ensure_inside(extract_root, member.linkname)
            # Use the stdlib "data" extraction filter when this Python has it.
            extra = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
            tar_ref.extractall(extract_root, members=members, **extra)
    else:
        raise ValueError(f"Unsupported archive format: {file_path.suffix}. Please use .zip or .tar.gz")

    class0_dir = extract_root / "class0"
    class1_dir = extract_root / "class1"

    if not (class0_dir.exists() and class1_dir.exists()):
        # Look for class0 and class1 one level down and hoist the first match.
        for subdir in extract_root.iterdir():
            if not subdir.is_dir():
                continue
            if (subdir / "class0").exists() and (subdir / "class1").exists():
                shutil.move(str(subdir / "class0"), str(class0_dir))
                shutil.move(str(subdir / "class1"), str(class1_dir))
                break

    if not (class0_dir.exists() and class1_dir.exists()):
        raise ValueError("The uploaded archive must contain 'class0' and 'class1' directories or a subdirectory containing them")

    return str(extract_root)


def get_cached_result_info(name):
    """Get information about a cached result (placeholder).

    Always returns ``None``; no cache lookup is implemented here.
    """
    return None


class TwoClassDataset(Dataset):
    """Two-class image dataset read from a pair of folders under one root.

    Folder names depend on a marker found in the root path ("kermany",
    "kiki_bouba", "afhq"); any other root uses ``class0`` / ``class1``.
    Each item is ``(image, label, path_string)`` with label 0 for the first
    folder and 1 for the second.
    """

    # (marker searched in the root path, (class-0 folder, class-1 folder)).
    # Checked in order; the first marker found wins.
    _KNOWN_LAYOUTS = (
        ("kermany", ("NORMAL", "DRUSEN")),
        ("kiki_bouba", ("kiki", "bouba")),
        ("afhq", ("dog", "cat")),
    )

    def __init__(self, root_dir, transform=None, num_samples_per_class=None):
        """Index the images of both classes.

        Args:
            root_dir: Directory containing the two class folders.
            transform: Optional callable applied to each loaded image.
            num_samples_per_class: If given, keep only the first N files of
                each class (in sorted filename order).
        """
        self.root_dir = Path(root_dir)
        self.transform = transform

        class0_name, class1_name = "class0", "class1"
        root_text = str(self.root_dir)
        for marker, folder_names in self._KNOWN_LAYOUTS:
            if marker in root_text:
                class0_name, class1_name = folder_names
                break
        self.class0_dir = self.root_dir / class0_name
        self.class1_dir = self.root_dir / class1_name

        self.class0_images = self._list_files(self.class0_dir)
        self.class1_images = self._list_files(self.class1_dir)

        # Limit the number of samples per class if specified
        if num_samples_per_class is not None:
            self.class0_images = self.class0_images[:num_samples_per_class]
            self.class1_images = self.class1_images[:num_samples_per_class]

        self.images = self.class0_images + self.class1_images
        self.labels = [0] * len(self.class0_images) + [1] * len(self.class1_images)

    @staticmethod
    def _list_files(directory):
        """Return the ``*.*`` files in *directory*, sorted by path.

        Sorting makes the order (and therefore the ``num_samples_per_class``
        subset) independent of the filesystem's listing order; directories
        whose names contain a dot are skipped.
        """
        return sorted(path for path in directory.glob("*.*") if path.is_file())

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load item *idx* as ``(RGB image or transformed image, label, path)``."""
        img_path = self.images[idx]
        image = Image.open(img_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label, str(img_path)


def compute_lpips_similarity(images1, images2, reduction=None):
    """Compute LPIPS similarity between two batches of images (placeholder).

    This is NOT a real LPIPS computation: it returns one random score per
    item of ``images1`` (``images2`` is ignored), on the same device as
    ``images1``. With ``reduction="mean"`` the mean of those scores is
    returned instead.
    """
    batch_size = images1.shape[0]
    similarity = torch.rand(batch_size, device=images1.device)

    if reduction == "mean":
        return similarity.mean()
    return similarity


def get_direction_sign(idx: int):
    """Map a direction index to its sign: 0 -> -1, 1 -> +1.

    Raises:
        ValueError: For any other index.
    """
    if idx == 0:
        return -1
    if idx == 1:
        return 1
    raise ValueError("Currently two direction are supported in this script")


def _load_label_font(size):
    """Return Arial at *size*, or Pillow's default font when it is unavailable."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except (OSError, ImportError):
        # Font file missing, or FreeType support not available.
        return ImageFont.load_default()


def add_text_to_image(image, text):
    """Add text to an image at the top with a nicer design.

    Draws a 40-row dark gradient at the top of *image* (in place), writes
    *text* over it in white and returns the same image object.
    """
    # With the "RGBA" drawing mode Pillow blends alpha fills into RGB images;
    # in the default mode the alpha component is not blended and the band
    # would not fade. Other image modes keep their native drawing mode.
    if image.mode == "RGB":
        draw = ImageDraw.Draw(image, "RGBA")
    else:
        draw = ImageDraw.Draw(image)

    font = _load_label_font(24)

    # Gradient background, one pixel row at a time.
    for row in range(40):
        alpha = max(0, 180 - row * 4)
        draw.rectangle([(0, row), (image.width, row)], fill=(0, 0, 0, alpha))

    draw.text((15, 10), text, fill="white", font=font)
    return image


def _draw_outlined_label(draw, label, font, padding):
    """Draw *label* at (padding, padding) in white over a 1-pixel dark outline."""
    for dx, dy in [(1, 1), (-1, 1), (1, -1), (-1, -1)]:
        draw.text((padding + dx, padding + dy), label, fill=(0, 0, 0, 180), font=font)
    draw.text((padding, padding), label, fill=(255, 255, 255, 230), font=font)


def create_gif(img1, img2, output_path):
    """Create a GIF that alternates between two labelled images.

    Copies of *img1* and *img2* are labelled "Original" and "Generated"
    (the inputs are not modified) and saved to *output_path* as a looping
    GIF.

    Returns:
        *output_path*, unchanged.
    """
    img1_copy = img1.copy()
    img2_copy = img2.copy()

    font = _load_label_font(36)
    padding = 15

    _draw_outlined_label(ImageDraw.Draw(img1_copy), "Original", font, padding)
    _draw_outlined_label(ImageDraw.Draw(img2_copy), "Generated", font, padding)

    imageio.mimsave(output_path, [img1_copy, img2_copy], duration=1000, loop=0)
    return output_path


def update_progress_status():
    """Return a status string for the counterfactual generation.

    Reads the module-level ``is_processing``, ``images_processed`` and
    ``total_images_to_process`` counters.
    """
    if not is_processing:
        if images_processed > 0:
            return f"Processing complete. Generated {images_processed} counterfactual images."
        return "Ready to process images."

    if total_images_to_process == 0:
        return "Preparing to process images..."

    percentage = (images_processed / total_images_to_process) * 100
    return f"Progress: {images_processed}/{total_images_to_process} images processed ({percentage:.1f}%)"


def cancel_generation():
    """Request cancellation of all ongoing processes.

    Sets every module-level ``*_should_stop`` flag to True and returns a
    message for the UI.
    """
    global generation_should_stop, classifier_should_stop, embedding_should_stop, lora_should_stop

    generation_should_stop = True
    classifier_should_stop = True
    embedding_should_stop = True
    lora_should_stop = True

    return "All processes have been requested to stop. This may take a moment to complete."
def save_results_to_cache(output_dir, cache_key): """Save generated results to cache directory""" cache_path = CACHE_DIR / cache_key cache_path.mkdir(exist_ok=True, parents=True) # Copy gifs directory output_gifs_dir = Path(output_dir) / "gifs" cache_gifs_dir = cache_path / "gifs" if output_gifs_dir.exists(): # Remove existing cache if it exists if cache_gifs_dir.exists(): shutil.rmtree(cache_gifs_dir) # Copy the new results, maintaining subdirectory structure shutil.copytree(output_gifs_dir, cache_gifs_dir) # Copy context images if they exist output_context_dir = Path(output_dir) / "context" cache_context_dir = cache_path / "context" if output_context_dir.exists(): if cache_context_dir.exists(): shutil.rmtree(cache_context_dir) shutil.copytree(output_context_dir, cache_context_dir) # Update the process_with_selected_dataset function to handle the new directory structure def process_with_selected_dataset(zip_file, output_dir, dataset_display_name, checkpoint_path=None, train_clf=True, is_direct_path=False, direct_path=None, embeddings_path=None, classifier_path=None, use_classifier_stopping=True, custom_tskip=85, manip_val=2): print(f"\nProcessing with dataset: {dataset_display_name}") # Find the selected dataset selected_dataset = None for dataset in EXAMPLE_DATASETS: if dataset["display_name"] == dataset_display_name: selected_dataset = dataset break if not selected_dataset: print("Error: No dataset selected") return "No dataset selected", [], [], [], "Error: No dataset selected", None, None # Generate cache key cache_key = get_cache_key( selected_dataset["name"], checkpoint_path, False, embeddings_path, classifier_path, use_classifier_stopping, custom_tskip, manip_val, ) print(f"Generated cache key: {cache_key}") # Check if cache exists cache_path = CACHE_DIR / cache_key dataset_dir = cache_path / "gifs" print(f"Looking for cache in: {cache_path}") print(f"Looking for gifs in: {dataset_dir}") print(f"Cache exists: {cache_path.exists()}") print(f"Gifs dir exists: 
{dataset_dir.exists()}") #import pdb; pdb.set_trace() if cache_path.exists() and dataset_dir.exists(): current_cache_key = cache_key print(f"Found cached results for key: {cache_key}") # Get paths to class-specific directories class0_to_class1_dir = dataset_dir / "class0_to_class1" class1_to_class0_dir = dataset_dir / "class1_to_class0" context_dir = cache_path/ "context" # Get all GIF paths class0_to_class1_gifs = list(class0_to_class1_dir.glob("*.gif")) if class0_to_class1_dir.exists() else [] class1_to_class0_gifs = list(class1_to_class0_dir.glob("*.gif")) if class1_to_class0_dir.exists() else [] # Sort the GIFs by filename for consistent ordering class0_to_class1_gifs.sort(key=lambda p: p.name) class1_to_class0_gifs.sort(key=lambda p: p.name) # Get context images class0_context = context_dir / "class0_sample.jpg" if (context_dir / "class0_sample.jpg").exists() else None class1_context = context_dir / "class1_sample.jpg" if (context_dir / "class1_sample.jpg").exists() else None # Convert paths to strings class0_to_class1_paths = [str(p) for p in class0_to_class1_gifs] class1_to_class0_paths = [str(p) for p in class1_to_class0_gifs] all_gifs = class0_to_class1_paths + class1_to_class0_paths # Update the global gallery variables global displayed_results, displayed_results_class0_to_class1, displayed_results_class1_to_class0 displayed_results = all_gifs displayed_results_class0_to_class1 = class0_to_class1_paths displayed_results_class1_to_class0 = class1_to_class0_paths status_message = f"Using cached results with t-skip={custom_tskip}, manip_scale={manip_val}" # Return cached results return ( "Using cached results for default parameters.", displayed_results, displayed_results_class0_to_class1, displayed_results_class1_to_class0, status_message, str(class0_context) if class0_context else None, str(class1_context) if class1_context else None ) else: print("No cached results found, processing dataset...") return "No cached results found, processing dataset...", [], 
[], [], "No cached results found, processing dataset...", None, None return # def process_and_clear(example_datasets_dropdown, checkpoint_path_state, # is_direct_path_state, direct_path_state, embeddings_path_state, # classifier_path_state, use_classifier_stopping, custom_tskip, # manip_val): # """Clear folders first, then process the dataset""" # # Clear folders first # clear_output_folders() # # Then process the dataset # return process_with_selected_dataset( # None, # input_zip (always None) # "./output", # output_dir (hardcoded) # example_datasets_dropdown, # checkpoint_path_state, # False, # train_clf (always False) # is_direct_path_state, # direct_path_state, # embeddings_path_state, # classifier_path_state, # use_classifier_stopping, # custom_tskip, # manip_val # ) def process_and_clear(example_datasets_dropdown, checkpoint_path_state, is_direct_path_state, direct_path_state, embeddings_path_state, classifier_path_state, use_classifier_stopping, custom_tskip, manip_val): """Clear galleries first, then process the dataset""" # Clear galleries but keep example images clear_output_folders() # Process the dataset result = process_with_selected_dataset( None, # input_zip (always None) "./output", # output_dir (hardcoded) example_datasets_dropdown, checkpoint_path_state, False, # train_clf (always False) is_direct_path_state, direct_path_state, embeddings_path_state, classifier_path_state, use_classifier_stopping, custom_tskip, manip_val ) # Return all outputs except example images return ( result[1], # gallery result[2], # gallery_class0_to_class1 result[3], # gallery_class1_to_class0 result[4], # progress_status # Don't update class1_context_image ) def update_example_images(dataset_display_name): """Update the example images based on the selected dataset""" print(f"\nUpdating example images for {dataset_display_name}") # Find the dataset info selected_dataset = None for dataset in EXAMPLE_DATASETS: print(f"Checking dataset: {dataset['display_name']}", 
dataset_display_name) if dataset["display_name"] == dataset_display_name: selected_dataset = dataset print(f"Selected dataset: {selected_dataset}") break class_names = selected_dataset.get("class_names", None) if selected_dataset: dataset_dir = selected_dataset.get("direct_dataset_path") print(f"Dataset directory: {dataset_dir}") if dataset_dir: # Debug: List all files in the directory print("Contents of directory:") for path in Path(dataset_dir).rglob("*"): print(f" {path}") # Try to find class0 and class1 images class0_path = Path(dataset_dir) / class_names[0] class1_path = Path(dataset_dir) / class_names[1] print(f"Looking in class0: {class0_path}") print(f"Looking in class1: {class1_path}") class0_img = next((str(p) for p in Path(dataset_dir).glob(f"{class_names[0]}/*.*")), None) class1_img = next((str(p) for p in Path(dataset_dir).glob(f"{class_names[1]}/*.*")), None) print(f"Found images:\nclass0={class0_img}\nclass1={class1_img}") return class0_img, class1_img print("No images found") return None, None # Add a state variable to store the direct dataset path direct_path_state = gr.State(None) # Map display names back to internal names (add this back) def get_name_from_display(display_name): for dataset in EXAMPLE_DATASETS: if dataset["display_name"] == display_name: return dataset["name"] return None # Modify the use_selected_dataset function def use_selected_dataset(display_name): name = get_name_from_display(display_name) if not name: print("No dataset name found") return None, None, False, None, None, None dataset_info = get_example_dataset_info(name) # Check if there's a direct dataset path available if dataset_info and "direct_dataset_path" in dataset_info and os.path.exists(dataset_info["direct_dataset_path"]): print(f"Using direct dataset path: {dataset_info['direct_dataset_path']}") # Return paths for direct dataset, checkpoint, embeddings, and classifiers return None, dataset_info["checkpoint_path"], True, dataset_info["direct_dataset_path"], \ 
dataset_info.get("embeddings_path"), dataset_info.get("classifier_path") elif dataset_info and os.path.exists(dataset_info["path"]): # Return the archive path and other paths return dataset_info["path"], dataset_info["checkpoint_path"], False, None, \ dataset_info.get("embeddings_path"), dataset_info.get("classifier_path") return None, None, False, None, None, None def reset_galleries(): """Reset all galleries when changing datasets or parameters""" global displayed_results, displayed_results_class0_to_class1, displayed_results_class1_to_class0 global current_cache_key # Also reset the cache key displayed_results = [] displayed_results_class0_to_class1 = [] displayed_results_class1_to_class0 = [] current_cache_key = None # Reset the cache key # Clear the result queue if it exists while not result_queue.empty(): result_queue.get() return [], [], [], "Galleries reset" def clear_output_folders(): """Delete the output/gifs and output/context folders and their contents""" import shutil from pathlib import Path # Folders to clear folders = ["gifs", "context"] for folder in folders: folder_path = Path("./output") / folder if folder_path.exists(): shutil.rmtree(folder_path) print(f"Deleted {folder_path}") def create_gradio_interface(): # Create temporary directories for uploads temp_dir = Path("./temp_uploads") temp_dir.mkdir(exist_ok=True, parents=True) clear_output_folders() lora_temp_dir = Path("./temp_lora_uploads") lora_temp_dir.mkdir(exist_ok=True, parents=True) # Get initial list of example datasets example_datasets = get_example_datasets() with gr.Blocks(css=css) as demo: # Add the header at the top level to span across all columns with gr.Row(elem_classes="full-width-header"): with gr.Column(): gr.HTML("""

DIFFusion Demo

Generate fine-grained edits to images using another class of images as guidance.

For any questions/comments/issues with this demo, please email mia.chiquier@cs.columbia.edu.🤖

""") # Main content row with sidebar, config column and results column with gr.Row(elem_classes="content-row"): # Sidebar for example datasets with gr.Column(scale=1, elem_classes="sidebar"): gr.HTML('
Example Datasets
') # Create a dropdown for example datasets example_datasets_dropdown = gr.Dropdown( choices=[dataset["display_name"] for dataset in EXAMPLE_DATASETS], value=next((dataset["display_name"] for dataset in EXAMPLE_DATASETS if "lamp" in dataset["display_name"].lower()), None), # Set lamp as default label="Example Datasets", info="Select a pre-loaded dataset to use" ) # Add dataset descriptions directly in the dropdown info dataset_descriptions = {dataset["display_name"]: dataset.get("description", "") for dataset in EXAMPLE_DATASETS} # Add some spacing gr.HTML("
") # Add a hidden state for the dataset description (we'll still update it but not display it) dataset_description = gr.Textbox(visible=False) # Main content area with gr.Column(scale=2, elem_classes="main-container"): # Paper info and configuration with gr.Column(): with gr.Column(elem_classes="paper-info"): gr.HTML("""

DIFFusion Demo

Text-based AI image editing can be tricky, as language often fails to capture precise visual ideas, and users may not always know what they want. Our image-guided editing method learns transformations directly from the differences between two image groups, removing the need for detailed verbal descriptions. Designed for scientific applications, it highlights subtle differences in visually similar image categories. It also applies nicely to marketing, adapting new products into scenes by managing small interior design details. Choose among four example datasets, then adjust the tskip (higher = less edit) and manipulation scalar (higher = more edit) to explore the editing effects. A Gradio demo in our GitHub code release lets users upload datasets and try the method (GPU required).

""") # Counterfactual Generation Section gr.HTML('
Counterfactual Generation
') # with gr.Column(elem_classes="upload-info"): # gr.HTML(""" #

Dataset Format: Upload a zip file containing two folders named 'class0' and 'class1', # each containing images of the respective class.

# """) # with gr.Row(): # input_zip = gr.File( # label="Upload Custom Dataset (ZIP or TAR.GZ file)", # file_types=[".zip", ".tar.gz", ".tgz"], # type="filepath" # ) # # Hide the output directory by using elem_classes # output_dir = gr.Textbox( # label="Output Directory", # value="./output", # elem_classes="hidden-element" # ) # with gr.Row(): # gr.HTML('
LoRA Training
') # with gr.Column(elem_classes="upload-info"): # gr.HTML(""" #

Dataset Format: Upload a zip file containing two folders named 'class0' and 'class1', # each containing images of the respective class for training the LoRA model.

# """) # with gr.Row(): # lora_output_dir = gr.Textbox( # label="LoRA Output Directory", # value="./lora_output" # ) # gr.HTML(""" #
#

Default LoRA Training Parameters:

# #
# """) # train_lora_btn = gr.Button("Train LoRA Model", elem_classes="btn-primary") # lora_status_box = gr.Textbox(label="LoRA Training Status", value="Ready to train LoRA model") # train_clf = gr.Checkbox(label="Train New Classifiers", value=False) with gr.Row(): use_classifier_stopping = gr.State(False)# custom_tskip = gr.Dropdown( choices=[55, 60, 65, 70, 75, 80, 85, 90, 95], value=85, # default value label="Custom T-Skip Value", info="Select a t-skip value", visible=True ) # Add a text box for number of images per class with gr.Row(): manip_val = gr.Dropdown( choices=[1.0, 1.5, 2.0], value=2.0, # default value label="Manip scale", info="Select a manip scale", visible=True ) # with gr.Row(): process_btn = gr.Button("Generate Counterfactuals", elem_classes="btn-primary") cancel_btn = gr.Button("Cancel Generation", elem_classes="btn-primary") # Status for the main column #status = gr.Textbox(label="Status", value="Ready to generate counterfactuals") # Results column with gr.Column(scale=2, elem_classes="results-container"): # Class Examples section header - MOVED HERE gr.HTML('
Class Examples
') # Class example images - MOVED HERE with gr.Row(): class0_context_image = gr.Image(label="Class 0 Example", type="filepath", height=256) class1_context_image = gr.Image(label="Class 1 Example", type="filepath", height=256) # Results section header gr.HTML('
Results
') default_dataset = next((dataset["display_name"] for dataset in EXAMPLE_DATASETS if "lamps" in dataset["display_name"].lower()), None) if default_dataset: # Initial load of example images class0_img, class1_img = update_example_images(default_dataset) if class0_img and class1_img: class0_context_image.value = class0_img # Directly set the value class1_context_image.value = class1_img print(f"Class 0 image: {class0_context_image.value}") print(f"Class 1 image: {class1_context_image.value}") # Add tabs for different direction signs - make "All Results" the default tab with gr.Tabs(elem_classes="tabs-container") as result_tabs: with gr.TabItem("All Results"): gallery = gr.Gallery( label="Generated Images", show_label=False, elem_id="gallery_all", columns=4, # Show 4 images per row rows=None, # Let it adjust rows automatically height="auto", allow_preview=True, preview=False, object_fit="contain" ) with gr.TabItem("Class 0 → Class 1"): gallery_class0_to_class1 = gr.Gallery( label="Class 0 to Class 1", show_label=False, elem_id="gallery_0to1", columns=4, # Show 4 images per row rows=None, # Let it adjust rows automatically height="auto", allow_preview=True, preview=True, object_fit="contain" ) with gr.TabItem("Class 1 → Class 0"): gallery_class1_to_class0 = gr.Gallery( label="Class 1 to Class 0", show_label=False, elem_id="gallery_1to0", columns=4, # Show 4 images per row rows=None, # Let it adjust rows automatically height="auto", allow_preview=True, preview=True, object_fit="contain" ) # with gr.TabItem("All Results"): # gallery = gr.Gallery( # columns=[3], # rows=[3], # height="auto", # allow_preview=True, # Make sure this is enabled # preview=True, # Try setting this explicitly # object_fit="contain" # Try different fit modes # ) # with gr.TabItem("Class 0 → Class 1"): # gallery_class0_to_class1 = gr.Gallery( # columns=[3], # rows=[3], # height="auto", # allow_preview=True, # Make sure this is enabled # preview=True, # Try setting this explicitly # 
object_fit="contain" # Try different fit modes # ) # with gr.TabItem("Class 1 → Class 0"): # gallery_class1_to_class0 = gr.Gallery( # columns=[3], # rows=[3], # height="auto", # allow_preview=True, # Make sure this is enabled # preview=True, # Try setting this explicitly # object_fit="contain" # Try different fit modes # ) # Add a progress status box in the results column progress_status = gr.Textbox( label="Progress", value="Ready to process", interactive=False ) # Define state variables inside the function #set the default to these to be those for the lamp dataset default_dataset = next((dataset for dataset in EXAMPLE_DATASETS if "lamp" in dataset["display_name"].lower()), None) if default_dataset: checkpoint_path_state = gr.State(default_dataset["checkpoint_path"]) is_direct_path_state = gr.State(False) direct_path_state = gr.State(None) embeddings_path_state = gr.State(default_dataset["embeddings_path"]) classifier_path_state = gr.State(default_dataset["classifier_path"]) process_btn.click( fn=process_and_clear, inputs=[ example_datasets_dropdown, checkpoint_path_state, is_direct_path_state, direct_path_state, embeddings_path_state, classifier_path_state, use_classifier_stopping, custom_tskip, manip_val ], outputs=[ gallery, # Make sure these variables are all defined gallery_class0_to_class1, # and not None gallery_class1_to_class0, progress_status ] # Removed 'status' since it wasn't defined ) # Set up the cancel button click handler cancel_btn.click( fn=cancel_generation, inputs=None, outputs=None ) num_images_per_class = gr.State(10) example_datasets_dropdown.change( fn=reset_galleries, # Reset galleries but not example images inputs=None, outputs=[gallery, gallery_class0_to_class1, gallery_class1_to_class0, progress_status] ).then( # Update dataset info fn=update_dataset_info, inputs=example_datasets_dropdown, outputs=[dataset_description, checkpoint_path_state, is_direct_path_state, direct_path_state, embeddings_path_state, classifier_path_state, 
custom_tskip] ).then( # Set custom t-skip fn=set_custom_tskip_for_dataset, inputs=example_datasets_dropdown, outputs=custom_tskip ).then( # Change cache key fn=change_cache_key, inputs=[example_datasets_dropdown, num_images_per_class, use_classifier_stopping, custom_tskip], outputs=None ).then( # Update example images fn=update_example_images, inputs=example_datasets_dropdown, outputs=[class0_context_image, class1_context_image] ).then( # Automatically generate counterfactuals when dataset changes fn=process_and_clear, inputs=[ example_datasets_dropdown, checkpoint_path_state, is_direct_path_state, direct_path_state, embeddings_path_state, classifier_path_state, use_classifier_stopping, custom_tskip, manip_val ], outputs=[gallery, gallery_class0_to_class1, gallery_class1_to_class0, progress_status] ) # Load initial example images and generate counterfactuals for default dataset (Lamps) demo.load( fn=update_example_images, inputs=example_datasets_dropdown, outputs=[class0_context_image, class1_context_image] ).then( # Initial counterfactual generation fn=process_and_clear, inputs=[ example_datasets_dropdown, checkpoint_path_state, is_direct_path_state, direct_path_state, embeddings_path_state, classifier_path_state, use_classifier_stopping, custom_tskip, manip_val ], outputs=[gallery, gallery_class0_to_class1, gallery_class1_to_class0, progress_status] ) # example_datasets_dropdown.change( # fn=reset_galleries, # Reset first # inputs=None, # outputs=[gallery, gallery_class0_to_class1, gallery_class1_to_class0, progress_status] # ).then( # Update dataset info # fn=update_dataset_info, # inputs=example_datasets_dropdown, # outputs=[dataset_description, checkpoint_path_state, is_direct_path_state, direct_path_state, # embeddings_path_state, classifier_path_state, custom_tskip_state] # ).then( # Set custom t-skip # fn=set_custom_tskip_for_dataset, # inputs=example_datasets_dropdown, # outputs=custom_tskip # ).then( # Change cache key # fn=change_cache_key, # 
# inputs=[example_datasets_dropdown, manip_val, use_classifier_stopping, custom_tskip], # outputs=None # ).then( # Update example images # fn=lambda display_name: update_example_images(display_name), # inputs=example_datasets_dropdown, # outputs=[class0_context_image, class1_context_image] # ) # process_btn.click( # fn=process_and_clear, # inputs=[ # example_datasets_dropdown, checkpoint_path_state, # is_direct_path_state, direct_path_state, embeddings_path_state, # classifier_path_state, use_classifier_stopping, custom_tskip, # manip_val # ], # outputs=[status, gallery, gallery_class0_to_class1, gallery_class1_to_class0, # progress_status, class0_context_image, class1_context_image] # ) # # Set up the click event for LoRA training # train_lora_btn.click( # fn=start_lora_training, # inputs=[input_zip, lora_output_dir], # outputs=[lora_status_box] # ) # # Set up periodic status checking for LoRA training # demo.load( # fn=check_lora_status, # inputs=None, # outputs=lora_status_box, # every=5 # Check every 5 seconds # ) # Add a periodic refresh for the galleries # Add a periodic refresh for the galleries # Add this event handler: # example_datasets_dropdown.change( # fn=reset_galleries, # inputs=None, # outputs=[gallery, gallery_class0_to_class1, gallery_class1_to_class0, progress_status] # ) return demo
# NOTE(review): the line above is the whitespace-collapsed tail of
# create_gradio_interface(), whose beginning lies outside this block. It is
# kept verbatim (only re-attached to its comment marker) and must be
# re-indented together with the rest of that function.


def _find_example_dataset(display_name):
    """Return the EXAMPLE_DATASETS entry whose "display_name" matches, or None."""
    return next(
        (entry for entry in EXAMPLE_DATASETS if entry["display_name"] == display_name),
        None,
    )


def update_dataset_info(dataset_display_name):
    """Update dataset description and paths when dropdown changes.

    Args:
        dataset_display_name: Value of the dataset dropdown; compared against
            the "display_name" key of each entry in EXAMPLE_DATASETS.

    Returns:
        A 7-tuple ``(description, checkpoint_path, is_direct_path,
        direct_path, embeddings_path, classifier_path, custom_tskip)``.
        When no entry matches, the tuple is
        ``("No dataset selected", None, False, None, None, None, None)``.
    """
    selected_dataset = _find_example_dataset(dataset_display_name)
    if not selected_dataset:
        return "No dataset selected", None, False, None, None, None, None

    description = selected_dataset.get("description", "No description available")
    checkpoint_path = selected_dataset.get("checkpoint_path")
    direct_path = selected_dataset.get("direct_dataset_path")
    is_direct_path = direct_path is not None
    embeddings_path = selected_dataset.get("embeddings_path")
    classifier_path = selected_dataset.get("classifier_path")

    # Per-dataset default t-skip; stays None for datasets without a known value.
    lowered_name = dataset_display_name.lower()
    custom_tskip = None
    if "butterfly" in lowered_name:
        custom_tskip = 70
    elif "lamp" in lowered_name:
        custom_tskip = 85
    print(f"Setting custom_tskip to {custom_tskip} for dataset {dataset_display_name}")

    return (
        description,
        checkpoint_path,
        is_direct_path,
        direct_path,
        embeddings_path,
        classifier_path,
        custom_tskip,
    )


# Function to generate a cache key based on parameters
def get_cache_key(dataset_name, checkpoint_path, train_clf, embeddings_path,
                  classifier_path, use_classifier_stopping, custom_tskip, manip_val):
    """Generate a unique cache key based on the processing parameters.

    The parameters are collected in a dict, serialised as JSON with sorted
    keys, and hashed with MD5 (used only as a cache identifier, not for
    security).

    Args:
        dataset_name: Dataset identifier stored under "dataset_name".
        checkpoint_path: Stored as ``str(checkpoint_path)``.
        train_clf: Stored unchanged.
        embeddings_path: Stored as ``str(embeddings_path)``.
        classifier_path: Stored as ``str(classifier_path)``.
        use_classifier_stopping: Stored unchanged.
        custom_tskip: Stored unchanged (must be JSON-serialisable).
        manip_val: Stored as ``float(manip_val)``; ``None`` is stored as
            ``None`` instead of raising ``TypeError``.

    Returns:
        The 32-character hexadecimal MD5 digest of the serialised parameters.
    """
    params = {
        "dataset_name": dataset_name,
        "checkpoint_path": str(checkpoint_path),
        "train_clf": train_clf,
        "embeddings_path": str(embeddings_path),
        "classifier_path": str(classifier_path),
        "use_classifier_stopping": use_classifier_stopping,
        "custom_tskip": custom_tskip,
        # Keys for non-None values are unchanged, so existing cache entries
        # keep matching; None no longer crashes float().
        "manip_val": None if manip_val is None else float(manip_val),
    }
    print(f"Params: {params}")

    # sort_keys makes the serialisation independent of dict insertion order.
    params_str = json.dumps(params, sort_keys=True)
    return hashlib.md5(params_str.encode()).hexdigest()


def change_cache_key(dataset_name, manip_val, use_classifier_stopping, custom_tskip):
    """Change the cache key based on the selected dataset.

    Looks up the dataset by display name and stores the key produced by
    get_cache_key() in the module-level ``current_cache_key``. If the dataset
    is not found, a message is printed and the global is left untouched.

    Args:
        dataset_name: Display name of the selected dataset.
        manip_val: Forwarded to get_cache_key().
        use_classifier_stopping: Forwarded to get_cache_key().
        custom_tskip: Forwarded to get_cache_key().
    """
    # NOTE(review): the dropdown .change() chain in create_gradio_interface
    # passes `num_images_per_class` in the `manip_val` position -- confirm
    # that this is intended.
    global current_cache_key

    selected_dataset = _find_example_dataset(dataset_name)
    if not selected_dataset:
        print(f"No dataset found for name: {dataset_name}")
        return

    current_cache_key = get_cache_key(
        selected_dataset["name"],  # Use internal name instead of display name
        selected_dataset.get("checkpoint_path"),
        False,  # train_clf is always False
        selected_dataset.get("embeddings_path"),
        selected_dataset.get("classifier_path"),
        use_classifier_stopping,
        custom_tskip,
        manip_val,
    )


# Function to check if cached results exist
def check_cache(cache_key):
    """Check if cached results exist for the given key.

    Args:
        cache_key: Name of the sub-directory of CACHE_DIR to look for.

    Returns:
        True when ``CACHE_DIR / cache_key / "gifs"`` exists, False otherwise.
        An empty or ``None`` key is reported as "not cached" instead of
        raising ``TypeError``.
    """
    if not cache_key:
        return False
    # If the "gifs" entry exists, its parent cache directory exists as well.
    return (CACHE_DIR / cache_key / "gifs").exists()


def _load_title_font(size=32):
    """Return Arial at *size*, falling back to Pillow's default font."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        # truetype() raises OSError when the font file cannot be read.
        return ImageFont.load_default()


# Add this function to create context images for each class
def create_context_image(image_paths, output_path, title, preferred_index=0):
    """Create a context image showing a sample from a class.

    Args:
        image_paths: List of paths to images in the class.
        output_path: Where to save the context image.
        title: Title drawn in the bar at the top of the image (skipped when
            empty or None).
        preferred_index: Index of the preferred image to use (default: 0);
            capped at the last available image.
    """
    font = _load_title_font()

    if not image_paths:
        # Create a blank placeholder if no samples are available.
        img = Image.new('RGB', (512, 512), color=(240, 240, 240))
        draw = ImageDraw.Draw(img)
        draw.text((256, 256), "No samples available", fill=(80, 80, 80),
                  font=font, anchor="mm")
        img.save(output_path)
        return

    # Use the preferred index if available, otherwise the last image.
    img_index = min(preferred_index, len(image_paths) - 1)
    # convert() returns an independent copy, so the file can be closed here.
    with Image.open(image_paths[img_index]) as source:
        img = source.convert("RGB")
    img = img.resize((512, 512), Image.LANCZOS)

    # An "RGBA" draw mode on an RGB image blends fills into the picture; with
    # the default mode the alpha of the bar colour was silently ignored.
    draw = ImageDraw.Draw(img, "RGBA")
    draw.rectangle([(0, 0), (img.width, 50)], fill=(0, 0, 0, 180))
    if title:
        # The original drew the bar but never the title itself.
        draw.text((img.width // 2, 25), str(title), fill=(255, 255, 255),
                  font=font, anchor="mm")

    img.save(output_path)


# Fix the update_custom_tskip function
def update_custom_tskip(tskip_value):
    """Update the custom_tskip input field with the value from the state.

    Returns:
        ``""`` when *tskip_value* is None, otherwise ``str(tskip_value)``
        (the text input expects a string).
    """
    print(f"Updating custom_tskip input with value: {tskip_value}")
    return "" if tskip_value is None else str(tskip_value)


# Add this function to directly set the custom_tskip based on dataset name
def set_custom_tskip_for_dataset(dataset_name):
    """Set the custom_tskip value based on the selected dataset.

    Returns:
        70 when the name contains "butterfly" (case-insensitive); 85 for
        every other name, including None.
    """
    if dataset_name is not None and "butterfly" in dataset_name.lower():
        return 70
    return 85


if __name__ == "__main__":
    # Uncomment this line to save current results to cache
    # save_current_results_to_cache()
    demo = create_gradio_interface()
    demo.launch()

# Add these functions at the top of the file, after the imports and global variables
# but before any other function definitions
#