import PIL import math import torch import random import os import numpy as np import pandas as pd import gradio as gr import threading import time import zipfile import shutil import glob from pathlib import Path from torch.utils.data import DataLoader, Dataset, random_split import torchvision.transforms.v2 as transforms from PIL import Image, ImageDraw, ImageFont import imageio from tqdm import tqdm import tarfile import queue import hashlib import json # Set seeds for reproducibility torch.manual_seed(0) random.seed(0) np.random.seed(0) # Define constants IMG_SIZE = 512 BATCH_SIZE = 32 DEVICE = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") # Add these global variables after the imports and before the CSS definition # Global variables for LoRA training lora_status = "Ready" lora_is_processing = False # Global variables for generation control generation_should_stop = False classifier_should_stop = False # New flag for classifier training embedding_should_stop = False # New flag for embedding encoding lora_should_stop = False # New flag for LoRA training generation_queue = queue.Queue() is_processing = False # Add this to prevent multiple simultaneous processes # Create temporary directories for uploads temp_dir = Path("./temp_uploads") temp_dir.mkdir(exist_ok=True, parents=True) lora_temp_dir = Path("./temp_lora_uploads") lora_temp_dir.mkdir(exist_ok=True, parents=True) # Create a global queue for real-time updates result_queue = queue.Queue() displayed_results = [] # Keep track of all displayed results # Add these global variables at the top of your file total_images_to_process = 0 images_processed = 0 # Add these global variables after the existing ones displayed_results_class0_to_class1 = [] # Results for class 0 to class 1 displayed_results_class1_to_class0 = [] # Results for class 1 to class 0 # Add global variables for caching CACHE_DIR = Path("./cached_results") CACHE_DIR.mkdir(exist_ok=True, parents=True) # CSS for styling the interface css = """ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap'); body, * { font-family: 'Inter', sans-serif !important; letter-spacing: -0.01em; } .container { max-width: 1360px; margin: auto; padding-top: 2.5rem; padding-bottom: 2.5rem; } .header { text-align: center; margin-bottom: 3rem; padding-bottom: 2rem; border-bottom: 1px solid #f0f0f0; } .header h1 { font-size: 3rem; font-weight: 700; color: #222; letter-spacing: -0.03em; margin-bottom: 1rem; background: linear-gradient(90deg, #B39CD0 0%, #9D8AC7 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; display: inline-block; } .header p { font-size: 1.1rem; color: #333; max-width: 800px; margin: 0 auto; line-height: 1.6; } .subtitle { font-size: 0.95rem; color: #777; max-width: 800px; margin: 0.5rem auto 0; line-height: 1.5; } .contact-info { font-size: 0.8rem; color: #777; margin-top: 15px; padding-top: 10px; border-top: 1px dashed #e0e0e0; width: 80%; margin-left: auto; margin-right: auto; } .paper-info { background-color: #f8f9fa; border-radius: 12px; padding: 1.8rem; margin: 1.8rem 0; box-shadow: 0 6px 20px rgba(0,0,0,0.05); border-left: 4px solid #B39CD0; } .paper-info h3 { font-size: 1.5rem; font-weight: 600; color: #B39CD0; letter-spacing: -0.02em; margin-bottom: 1rem; } .paper-info p { font-size: 1.05em; line-height: 1.7; color: #333; } .section-header { font-size: 1.8rem; font-weight: 600; color: #B39CD0; margin: 2.5rem 0 1.5rem 0; padding-bottom: 0.8rem; border-bottom: 2px solid #ECF0F1; 
letter-spacing: -0.02em; } .footer { text-align: center; margin-top: 3rem; padding: 1.5rem; border-top: 1px solid #ECF0F1; color: #666; background-color: #f8f9fa; border-radius: 0 0 12px 12px; } .btn-primary { background-color: #B39CD0 !important; border-color: #B39CD0 !important; transition: all 0.3s ease; font-weight: 500 !important; letter-spacing: 0.02em !important; padding: 0.6rem 1.5rem !important; border-radius: 8px !important; } .btn-primary:hover { background-color: #9D8AC7 !important; border-color: #9D8AC7 !important; } /* Hide the output directory */ .hidden-element { display: none !important; } /* Additional CSS for better alignment */ .container { padding: 0 1.5rem; } .main-container { display: flex; flex-direction: column; gap: 1.5rem; } .results-container { margin-top: 0; padding-top: 0; } .full-width-header { margin-bottom: 2rem; padding-bottom: 1.5rem; border-bottom: 1px solid #f0f0f0; text-align: center; } .content-row { display: flex; gap: 2rem; } .sidebar { min-width: 250px; padding-right: 1.5rem; } .section-header { margin-top: 0; } .tabs-container { margin-top: 1rem; } .gallery-container { margin-top: 1rem; } /* Hide the output directory */ .hidden-element { display: none !important; } .gallery-item img { object-fit: contain !important; height: 200px !important; width: auto !important; } /* Force GIFs to restart when tab is selected */ .tabs-container .tabitem[style*="display: block"] .gallery-container img { animation: none; animation: reload-animation 0.1s; } @keyframes reload-animation { 0% { opacity: 0.99; } 100% { opacity: 1; } } """ # Add to your global variables current_cache_key = None is_using_default_params = False # Update the EXAMPLE_DATASETS to include direct dataset paths, embeddings, and classifiers EXAMPLE_DATASETS = [ { "name": "butterfly", "display_name": "Butterfly (Monarch vs Viceroy)", "description": "Dataset containing images of Monarch and Viceroy butterflies for counterfactual generation", "path": "/proj/vondrick/datasets/magnification/butterfly.tar.gz", "direct_dataset_path": "example_images/butterfly", "checkpoint_path": "/proj/vondrick2/mia/magnificationold/output/lora/butterfly/copper-forest-49/checkpoint-1800", "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/butterfly", "classifier_path": "/proj/vondrick2/mia/diff-usion/results/ensemble/butterfly", "class_names": ["class0", "class1"] }, { "name": "afhq", "display_name": "Cats vs. 
Dogs (AFHQ)", "description": "Dataset containing images of table lamps and floor lamps", "direct_dataset_path": "example_images/afhq", "checkpoint_path": None, "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/afhq", "classifier_path": "/proj/vondrick2/mia/diff-usion/results/ensemble/afhq", "class_names": ["class0", "class1"] }, { "name": "lamp", "display_name": "Lamps", "description": "Dataset containing images of table lamps and floor lamps", "path": "compressed_datasets/lampsfar.zip", "direct_dataset_path": "example_images/lamps", "checkpoint_path": "/proj/vondrick2/mia/diff-usion/lora_output_lampsfar/checkpoint-800", "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/lampsfar", "classifier_path": "/proj/vondrick2/mia/diff-usion/results/ensemble/lampsfar", "class_names": ["class0", "class1"] }, { "name": "couches", "display_name": "Couches", "description": "Dataset containing images of chairs and floor", "path": "compressed_datasets/couches.zip", "direct_dataset_path": "example_images/couches", "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/couches", "checkpoint_path": "/proj/vondrick2/mia/diff-usion/lora_output/couches/checkpoint-1000", "class_names": ["class0", "class1"] } ] # Function to get available example datasets def get_example_datasets(): """Get list of available example datasets""" return [dataset["name"] for dataset in EXAMPLE_DATASETS] # Function to get example dataset info def get_example_dataset_info(name): """Get information about an example dataset""" for dataset in EXAMPLE_DATASETS: if dataset["name"] == name: return dataset return None #Function to check if we're using default parameters def is_using_default_params(dataset_name, custom_tskip, num_images_per_class): """Check if we're using default parameters for the given dataset""" if dataset_name is None: return False if "butterfly" in dataset_name.lower(): return (custom_tskip == 70 or custom_tskip == "70") and num_images_per_class == 10 elif "lamp" in dataset_name.lower(): return (custom_tskip == 85 or custom_tskip == "85") and num_images_per_class == 10 elif "couch" in dataset_name.lower(): return (custom_tskip == 85 or custom_tskip == "85") and num_images_per_class == 10 return False # # Function to get the output directory - either cache or regular output # def get_output_directory(dataset_name, is_default_params, cache_key): # """Get the appropriate output directory based on parameters""" # if is_default_params: # # Use cache directory # cache_path = CACHE_DIR / cache_key # cache_path.mkdir(exist_ok=True, parents=True) # # Create dataset-specific directory # dataset_dir = cache_path / dataset_name.replace(" ", "_").lower() # dataset_dir.mkdir(exist_ok=True, parents=True) # # Create class-specific directories # class0_to_class1_dir = dataset_dir / "class0_to_class1" # class1_to_class0_dir = dataset_dir / "class1_to_class0" # class0_to_class1_dir.mkdir(exist_ok=True, parents=True) # class1_to_class0_dir.mkdir(exist_ok=True, parents=True) # # Create context directory # context_dir = dataset_dir / "context" # context_dir.mkdir(exist_ok=True, parents=True) # return dataset_dir, class0_to_class1_dir, class1_to_class0_dir, context_dir # else: # # Use regular output directory # output_dir = Path(f"./results/{dataset_name.replace(' ', '_').lower()}") # output_dir.mkdir(exist_ok=True, parents=True) # # Create gifs directory with class-specific subdirectories # gifs_dir = output_dir / "gifs" # gifs_dir.mkdir(exist_ok=True, parents=True) # 
class0_to_class1_dir = gifs_dir / "class0_to_class1" # class1_to_class0_dir = gifs_dir / "class1_to_class0" # class0_to_class1_dir.mkdir(exist_ok=True, parents=True) # class1_to_class0_dir.mkdir(exist_ok=True, parents=True) # # Create context directory # context_dir = output_dir / "context" # context_dir.mkdir(exist_ok=True, parents=True) # return output_dir, class0_to_class1_dir, class1_to_class0_dir, context_dir def has_prediction_flipped(orig_preds, new_preds): """Check if any prediction has flipped from one class to another.""" return ((orig_preds.preds > 0.5) != (new_preds.preds > 0.5)).any().item() # Function to extract uploaded zip or tar.gz file def extract_archive(archive_file, extract_dir): """Extract a zip or tar.gz file to the specified directory""" # Create a temporary directory for extraction temp_dir = Path(extract_dir) temp_dir.mkdir(parents=True, exist_ok=True) # Check file extension file_path = Path(archive_file) if file_path.suffix.lower() == '.zip': # Extract the zip file with zipfile.ZipFile(archive_file, 'r') as zip_ref: zip_ref.extractall(temp_dir) elif file_path.name.endswith('.tar.gz') or file_path.name.endswith('.tgz'): # Extract the tar.gz file with tarfile.open(archive_file, 'r:gz') as tar_ref: tar_ref.extractall(temp_dir) else: raise ValueError(f"Unsupported archive format: {file_path.suffix}. Please use .zip or .tar.gz") # Check if the extracted content has class0 and class1 folders # If not, try to find them in subdirectories class0_dir = temp_dir / "class0" class1_dir = temp_dir / "class1" if not (class0_dir.exists() and class1_dir.exists()): # Look for class0 and class1 in subdirectories for subdir in temp_dir.iterdir(): if subdir.is_dir(): if (subdir / "class0").exists() and (subdir / "class1").exists(): # Move the class directories to the temp_dir shutil.move(str(subdir / "class0"), str(class0_dir)) shutil.move(str(subdir / "class1"), str(class1_dir)) break # Verify that we have the required directories if not (class0_dir.exists() and class1_dir.exists()): raise ValueError("The uploaded archive must contain 'class0' and 'class1' directories or a subdirectory containing them") return str(temp_dir) # Function to handle cached results (placeholder implementation) def get_cached_result_info(name): """Get information about a cached result (placeholder)""" # This is a placeholder - in a real implementation, you'd store and retrieve cached results return None # Modify the TwoClassDataset class to accept num_samples_per_class as a parameter class TwoClassDataset(Dataset): def __init__(self, root_dir, transform=None, num_samples_per_class=None): self.root_dir = Path(root_dir) #import pdb; pdb.set_trace() self.transform = transform if 'kermany' in str(self.root_dir): #import pdb; pdb.set_trace() self.class0_dir = self.root_dir / "NORMAL" self.class1_dir = self.root_dir / "DRUSEN" elif 'kiki_bouba' in str(self.root_dir): self.class0_dir = self.root_dir / "kiki" self.class1_dir = self.root_dir / "bouba" elif 'afhq' in str(self.root_dir): self.class0_dir = self.root_dir / "dog" self.class1_dir = self.root_dir / "cat" else: self.class0_dir = self.root_dir / "class0" self.class1_dir = self.root_dir / "class1" # Get image paths #import pdb; pdb.set_trace() self.class0_images = list(self.class0_dir.glob("*.*")) self.class1_images = list(self.class1_dir.glob("*.*")) # Limit the number of samples per class if specified if num_samples_per_class is not None: self.class0_images = self.class0_images[:num_samples_per_class] self.class1_images = 
self.class1_images[:num_samples_per_class] # Create image list and labels self.images = self.class0_images + self.class1_images self.labels = [0] * len(self.class0_images) + [1] * len(self.class1_images) def __len__(self): return len(self.images) def __getitem__(self, idx): img_path = self.images[idx] image = Image.open(img_path).convert("RGB") label = self.labels[idx] if self.transform: image = self.transform(image) return image, label, str(img_path) def compute_lpips_similarity(images1, images2, reduction=None): """Compute LPIPS similarity between two batches of images""" # This is a placeholder - in a real implementation, you'd use a proper LPIPS model # For demo purposes, we'll just return a random similarity score batch_size = images1.shape[0] similarity = torch.rand(batch_size, device=images1.device) if reduction == "mean": return similarity.mean() return similarity def get_direction_sign(idx: int): if idx == 0: sign = -1 elif idx == 1: sign = 1 else: raise ValueError("Currently two direction are supported in this script") return sign def add_text_to_image(image, text): """Add text to an image at the top with a nicer design""" draw = ImageDraw.Draw(image) # Use a default font try: font = ImageFont.truetype("arial.ttf", 24) except: font = ImageFont.load_default() # Add a semi-transparent gradient background for better readability text_width, text_height = draw.textsize(text, font=font) if hasattr(draw, 'textsize') else (200, 30) # Create gradient background for i in range(40): alpha = int(180 - i * 4) # Fade from 180 to 20 alpha if alpha < 0: alpha = 0 draw.rectangle([(0, i), (image.width, i)], fill=(0, 0, 0, alpha)) # Draw text at the top of the image draw.text((15, 10), text, fill="white", font=font) return image def create_gif(img1, img2, output_path): """Create a GIF that alternates between two images with elegant labels""" # Create copies of the images to avoid modifying the originals img1_copy = img1.copy() img2_copy = img2.copy() # Add labels to the images draw1 = ImageDraw.Draw(img1_copy) draw2 = ImageDraw.Draw(img2_copy) try: # Use a larger font size for better visibility font = ImageFont.truetype("arial.ttf", 36) # Increased from 28 to 36 except: font = ImageFont.load_default() # Add a subtle shadow effect for better visibility padding = 15 # Original image - add text with shadow effect # First draw shadow/outline for offset in [(1,1), (-1,1), (1,-1), (-1,-1)]: draw1.text( (padding + offset[0], padding + offset[1]), "Original", fill=(0, 0, 0, 180), font=font ) # Then draw the main text draw1.text( (padding, padding), "Original", fill=(255, 255, 255, 230), font=font ) # Generated image - add text with shadow effect # First draw shadow/outline for offset in [(1,1), (-1,1), (1,-1), (-1,-1)]: draw2.text( (padding + offset[0], padding + offset[1]), "Generated", fill=(0, 0, 0, 180), font=font ) # Then draw the main text draw2.text( (padding, padding), "Generated", fill=(255, 255, 255, 230), font=font ) # Increase duration to 1 second per image (1000ms) imageio.mimsave(output_path, [img1_copy, img2_copy], duration=1000, loop=0) return output_path # Modify the update_progress_status function to be more informative def update_progress_status(): """Update the progress status for the counterfactual generation""" global images_processed, total_images_to_process, is_processing if not is_processing: if images_processed > 0: return f"Processing complete. Generated {images_processed} counterfactual images." return "Ready to process images." 
if total_images_to_process == 0: return "Preparing to process images..." percentage = (images_processed / total_images_to_process) * 100 return f"Progress: {images_processed}/{total_images_to_process} images processed ({percentage:.1f}%)" # Add function to cancel generation def cancel_generation(): """Cancel all ongoing processes""" global generation_should_stop, classifier_should_stop, embedding_should_stop, lora_should_stop # Set all stop flags generation_should_stop = True classifier_should_stop = True embedding_should_stop = True lora_should_stop = True return "All processes have been requested to stop. This may take a moment to complete." def save_results_to_cache(output_dir, cache_key): """Save generated results to cache directory""" cache_path = CACHE_DIR / cache_key cache_path.mkdir(exist_ok=True, parents=True) # Copy gifs directory output_gifs_dir = Path(output_dir) / "gifs" cache_gifs_dir = cache_path / "gifs" if output_gifs_dir.exists(): # Remove existing cache if it exists if cache_gifs_dir.exists(): shutil.rmtree(cache_gifs_dir) # Copy the new results, maintaining subdirectory structure shutil.copytree(output_gifs_dir, cache_gifs_dir) # Copy context images if they exist output_context_dir = Path(output_dir) / "context" cache_context_dir = cache_path / "context" if output_context_dir.exists(): if cache_context_dir.exists(): shutil.rmtree(cache_context_dir) shutil.copytree(output_context_dir, cache_context_dir) # Update the process_with_selected_dataset function to handle the new directory structure def process_with_selected_dataset(zip_file, output_dir, dataset_display_name, checkpoint_path=None, train_clf=True, is_direct_path=False, direct_path=None, embeddings_path=None, classifier_path=None, use_classifier_stopping=True, custom_tskip=85, manip_val=2): print(f"\nProcessing with dataset: {dataset_display_name}") # Find the selected dataset selected_dataset = None for dataset in EXAMPLE_DATASETS: if dataset["display_name"] == dataset_display_name: selected_dataset = dataset break if not selected_dataset: print("Error: No dataset selected") return "No dataset selected", [], [], [], "Error: No dataset selected", None, None # Generate cache key cache_key = get_cache_key( selected_dataset["name"], checkpoint_path, False, embeddings_path, classifier_path, use_classifier_stopping, custom_tskip, manip_val, ) print(f"Generated cache key: {cache_key}") # Check if cache exists cache_path = CACHE_DIR / cache_key dataset_dir = cache_path / "gifs" print(f"Looking for cache in: {cache_path}") print(f"Looking for gifs in: {dataset_dir}") print(f"Cache exists: {cache_path.exists()}") print(f"Gifs dir exists: {dataset_dir.exists()}") #import pdb; pdb.set_trace() if cache_path.exists() and dataset_dir.exists(): current_cache_key = cache_key print(f"Found cached results for key: {cache_key}") # Get paths to class-specific directories class0_to_class1_dir = dataset_dir / "class0_to_class1" class1_to_class0_dir = dataset_dir / "class1_to_class0" context_dir = cache_path/ "context" # Get all GIF paths class0_to_class1_gifs = list(class0_to_class1_dir.glob("*.gif")) if class0_to_class1_dir.exists() else [] class1_to_class0_gifs = list(class1_to_class0_dir.glob("*.gif")) if class1_to_class0_dir.exists() else [] # Sort the GIFs by filename for consistent ordering class0_to_class1_gifs.sort(key=lambda p: p.name) class1_to_class0_gifs.sort(key=lambda p: p.name) # Get context images class0_context = context_dir / "class0_sample.jpg" if (context_dir / "class0_sample.jpg").exists() else None class1_context 
= context_dir / "class1_sample.jpg" if (context_dir / "class1_sample.jpg").exists() else None # Convert paths to strings class0_to_class1_paths = [str(p) for p in class0_to_class1_gifs] class1_to_class0_paths = [str(p) for p in class1_to_class0_gifs] all_gifs = class0_to_class1_paths + class1_to_class0_paths # Update the global gallery variables global displayed_results, displayed_results_class0_to_class1, displayed_results_class1_to_class0 displayed_results = all_gifs displayed_results_class0_to_class1 = class0_to_class1_paths displayed_results_class1_to_class0 = class1_to_class0_paths status_message = f"Using cached results with t-skip={custom_tskip}, manip_scale={manip_val}" # Return cached results return ( "Using cached results for default parameters.", displayed_results, displayed_results_class0_to_class1, displayed_results_class1_to_class0, status_message, str(class0_context) if class0_context else None, str(class1_context) if class1_context else None ) else: print("No cached results found, processing dataset...") return "No cached results found, processing dataset...", [], [], [], "No cached results found, processing dataset...", None, None return # def process_and_clear(example_datasets_dropdown, checkpoint_path_state, # is_direct_path_state, direct_path_state, embeddings_path_state, # classifier_path_state, use_classifier_stopping, custom_tskip, # manip_val): # """Clear folders first, then process the dataset""" # # Clear folders first # clear_output_folders() # # Then process the dataset # return process_with_selected_dataset( # None, # input_zip (always None) # "./output", # output_dir (hardcoded) # example_datasets_dropdown, # checkpoint_path_state, # False, # train_clf (always False) # is_direct_path_state, # direct_path_state, # embeddings_path_state, # classifier_path_state, # use_classifier_stopping, # custom_tskip, # manip_val # ) def process_and_clear(example_datasets_dropdown, checkpoint_path_state, is_direct_path_state, direct_path_state, embeddings_path_state, classifier_path_state, use_classifier_stopping, custom_tskip, manip_val): """Clear galleries first, then process the dataset""" # Clear galleries but keep example images clear_output_folders() # Process the dataset result = process_with_selected_dataset( None, # input_zip (always None) "./output", # output_dir (hardcoded) example_datasets_dropdown, checkpoint_path_state, False, # train_clf (always False) is_direct_path_state, direct_path_state, embeddings_path_state, classifier_path_state, use_classifier_stopping, custom_tskip, manip_val ) # Return all outputs except example images return ( result[1], # gallery result[2], # gallery_class0_to_class1 result[3], # gallery_class1_to_class0 result[4], # progress_status # Don't update class1_context_image ) def update_example_images(dataset_display_name): """Update the example images based on the selected dataset""" print(f"\nUpdating example images for {dataset_display_name}") # Find the dataset info selected_dataset = None for dataset in EXAMPLE_DATASETS: print(f"Checking dataset: {dataset['display_name']}", dataset_display_name) if dataset["display_name"] == dataset_display_name: selected_dataset = dataset print(f"Selected dataset: {selected_dataset}") break class_names = selected_dataset.get("class_names", None) if selected_dataset: dataset_dir = selected_dataset.get("direct_dataset_path") print(f"Dataset directory: {dataset_dir}") if dataset_dir: # Debug: List all files in the directory print("Contents of directory:") for path in Path(dataset_dir).rglob("*"): 
print(f" {path}") # Try to find class0 and class1 images class0_path = Path(dataset_dir) / class_names[0] class1_path = Path(dataset_dir) / class_names[1] print(f"Looking in class0: {class0_path}") print(f"Looking in class1: {class1_path}") class0_img = next((str(p) for p in Path(dataset_dir).glob(f"{class_names[0]}/*.*")), None) class1_img = next((str(p) for p in Path(dataset_dir).glob(f"{class_names[1]}/*.*")), None) print(f"Found images:\nclass0={class0_img}\nclass1={class1_img}") return class0_img, class1_img print("No images found") return None, None # Add a state variable to store the direct dataset path direct_path_state = gr.State(None) # Map display names back to internal names (add this back) def get_name_from_display(display_name): for dataset in EXAMPLE_DATASETS: if dataset["display_name"] == display_name: return dataset["name"] return None # Modify the use_selected_dataset function def use_selected_dataset(display_name): name = get_name_from_display(display_name) if not name: print("No dataset name found") return None, None, False, None, None, None dataset_info = get_example_dataset_info(name) # Check if there's a direct dataset path available if dataset_info and "direct_dataset_path" in dataset_info and os.path.exists(dataset_info["direct_dataset_path"]): print(f"Using direct dataset path: {dataset_info['direct_dataset_path']}") # Return paths for direct dataset, checkpoint, embeddings, and classifiers return None, dataset_info["checkpoint_path"], True, dataset_info["direct_dataset_path"], \ dataset_info.get("embeddings_path"), dataset_info.get("classifier_path") elif dataset_info and os.path.exists(dataset_info["path"]): # Return the archive path and other paths return dataset_info["path"], dataset_info["checkpoint_path"], False, None, \ dataset_info.get("embeddings_path"), dataset_info.get("classifier_path") return None, None, False, None, None, None def reset_galleries(): """Reset all galleries when changing datasets or parameters""" global displayed_results, displayed_results_class0_to_class1, displayed_results_class1_to_class0 global current_cache_key # Also reset the cache key displayed_results = [] displayed_results_class0_to_class1 = [] displayed_results_class1_to_class0 = [] current_cache_key = None # Reset the cache key # Clear the result queue if it exists while not result_queue.empty(): result_queue.get() return [], [], [], "Galleries reset" def clear_output_folders(): """Delete the output/gifs and output/context folders and their contents""" import shutil from pathlib import Path # Folders to clear folders = ["gifs", "context"] for folder in folders: folder_path = Path("./output") / folder if folder_path.exists(): shutil.rmtree(folder_path) print(f"Deleted {folder_path}") def create_gradio_interface(): # Create temporary directories for uploads temp_dir = Path("./temp_uploads") temp_dir.mkdir(exist_ok=True, parents=True) clear_output_folders() lora_temp_dir = Path("./temp_lora_uploads") lora_temp_dir.mkdir(exist_ok=True, parents=True) # Get initial list of example datasets example_datasets = get_example_datasets() with gr.Blocks(css=css) as demo: # Add the header at the top level to span across all columns with gr.Row(elem_classes="full-width-header"): with gr.Column(): gr.HTML("""
Generate fine-grained edits to images using another class of images as guidance.
For any questions/comments/issues with this demo, please email mia.chiquier@cs.columbia.edu.🤖
Text-based AI image editing can be tricky, as language often fails to capture precise visual ideas, and users may not always know what they want. Our image-guided editing method learns transformations directly from the differences between two image groups, removing the need for detailed verbal descriptions. Designed for scientific applications, it highlights subtle differences between visually similar image categories. It also applies nicely to marketing, adapting new products into scenes by handling small interior design details. Choose from four example datasets, then adjust the t-skip (higher = less edit) and the manipulation scalar (higher = more edit) to explore the editing effects. A Gradio demo in our GitHub code release lets users upload their own datasets and try the method (GPU required).
""") # Counterfactual Generation Section gr.HTML('Dataset Format: Upload a zip file containing two folders named 'class0' and 'class1', # each containing images of the respective class.
# """) # with gr.Row(): # input_zip = gr.File( # label="Upload Custom Dataset (ZIP or TAR.GZ file)", # file_types=[".zip", ".tar.gz", ".tgz"], # type="filepath" # ) # # Hide the output directory by using elem_classes # output_dir = gr.Textbox( # label="Output Directory", # value="./output", # elem_classes="hidden-element" # ) # with gr.Row(): # gr.HTML('Dataset Format: Upload a zip file containing two folders named 'class0' and 'class1', # each containing images of the respective class for training the LoRA model.
# """) # with gr.Row(): # lora_output_dir = gr.Textbox( # label="LoRA Output Directory", # value="./lora_output" # ) # gr.HTML(""" #Default LoRA Training Parameters:
#
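# A minimal sketch (an assumption, not part of the original app) of how the
# compute_lpips_similarity() placeholder above could be backed by a real perceptual metric,
# assuming the `lpips` pip package is installed. The name compute_lpips_similarity_real is
# hypothetical and is not referenced elsewhere in this file.
def compute_lpips_similarity_real(images1, images2, reduction=None):
    """Compute an LPIPS-based similarity between two image batches (NCHW, values in [-1, 1])."""
    import lpips  # pip install lpips

    # In practice the model would be created once and cached rather than rebuilt per call.
    loss_fn = lpips.LPIPS(net="alex").to(images1.device)
    with torch.no_grad():
        distance = loss_fn(images1, images2).flatten()  # one perceptual distance per image pair
    # LPIPS is a distance (lower = more similar); negate so higher values mean more similar,
    # matching the placeholder's convention.
    similarity = -distance
    if reduction == "mean":
        return similarity.mean()
    return similarity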