Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import os | |
| from huggingface_hub import login, hf_hub_download | |
| import torch | |
| import spaces | |
| from diffusers import FluxPipeline | |
| import logging | |
# Configure logging: every record goes both to a local "app.log" file and to
# the console, using one shared format.
_log_handlers = [logging.FileHandler("app.log"), logging.StreamHandler()]
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=_log_handlers,
)
logger = logging.getLogger("moroccan-ghibli-flux-compare")
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
MAX_SEED = np.iinfo(np.int32).max  # largest signed 32-bit integer (2**31 - 1)
MAX_IMAGE_SIZE = 1024
DEFAULT_SEED = 42

# Use the GPU when torch can see one, otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _use_cuda else torch.device("cpu")

# Module-level pipelines, populated by init_pipelines(); None = not loaded yet.
pipeline_base = None
pipeline_lora = None
def load_base_pipeline():
    """Load the plain FLUX.1-dev pipeline, with no LoRA attached.

    Logs in to the Hugging Face Hub with the ``HF_TOKEN`` environment
    variable before downloading the model.

    Returns:
        A ``FluxPipeline`` with bfloat16 weights; it is not moved to any
        device here.

    Raises:
        ValueError: if ``HF_TOKEN`` is unset or empty.
    """
    token = os.getenv("HF_TOKEN")
    if token:
        login(token=token)
    else:
        raise ValueError("HF_TOKEN environment variable not set.")

    logger.info("Loading base Flux model (no LoRA)")
    return FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    )
| def load_lora_pipeline(revision: str): | |
| api_key = os.getenv("HF_TOKEN") | |
| if not api_key: | |
| raise ValueError("HF_TOKEN environment variable not set.") | |
| login(token=api_key) | |
| logger.info(f"Loading Flux model with LoRA revision: {revision}") | |
| pipe = FluxPipeline.from_pretrained( | |
| "black-forest-labs/FLUX.1-dev", | |
| torch_dtype=torch.bfloat16, | |
| ) | |
| # Extract step number from revision name (e.g., "step_1500" -> "1500") | |
| step_num = revision.split('_')[-1] | |
| if len(step_num) < 4: | |
| step_num = '0' + step_num | |
| filename = f"moroccan_ghibli_flux_lora_00000{step_num}.safetensors" | |
| logger.info(f"Downloading LoRA weights: {filename}") | |
| lora_path = hf_hub_download( | |
| repo_id="atlasia/moroccan-ghibli-flux-lora", | |
| filename=filename, | |
| revision=revision, | |
| token=api_key | |
| ) | |
| logger.info("Loading LoRA weights into the pipeline") | |
| pipe.load_lora_weights(lora_path) | |
| return pipe | |
# Revision of the LoRA checkpoint currently held in ``pipeline_lora``
# (None until the first successful load).
_loaded_lora_revision = None


def init_pipelines(style_intensity: str = "low"):
    """Ensure the base pipeline and the LoRA pipeline for *style_intensity* exist.

    Args:
        style_intensity: ``"low"``, ``"medium"`` or ``"high"``; each maps to a
            checkpoint revision of the LoRA repository.

    Raises:
        KeyError: if *style_intensity* is not one of the known levels.

    Side effects:
        Sets the module-level ``pipeline_base`` on the first call. Rebinds
        ``pipeline_lora`` only when the requested revision differs from the
        one already loaded.
    """
    global pipeline_base, pipeline_lora, _loaded_lora_revision
    revision_map = {
        "low": "step_750",
        "medium": "step_1500",
        "high": "step_2500",
    }
    revision = revision_map[style_intensity]

    if pipeline_base is None:
        pipeline_base = load_base_pipeline()
        logger.info("Base pipeline loaded")

    # Skip a full FLUX.1-dev reload when this revision is already active.
    if pipeline_lora is not None and revision == _loaded_lora_revision:
        logger.info(f"LoRA pipeline already at style: {style_intensity} ({revision}); skipping reload")
        return

    logger.info(f"Initializing LoRA pipeline with style: {style_intensity} ({revision})")
    if pipeline_lora is not None:
        # Drop the previous LoRA pipeline before building its replacement, so
        # its weights can be reclaimed instead of coexisting with a second full
        # model copy. Trade-off: if the new load fails, pipeline_lora remains
        # None until a later call succeeds.
        import gc

        pipeline_lora = None
        _loaded_lora_revision = None
        gc.collect()
        torch.cuda.empty_cache()  # no-op when CUDA has not been initialised

    pipeline_lora = load_lora_pipeline(revision)
    _loaded_lora_revision = revision
    logger.info("LoRA pipeline loaded")
# Initial pipelines: load both at import time, using "low", the intensity
# dropdown's default value, so the models are ready before the first request.
init_pipelines("low")
def update_lora_pipeline(style_intensity):
    """Gradio callback for the intensity dropdown.

    Delegates to ``init_pipelines``, which swaps the module-level LoRA
    pipeline used by subsequent generation requests.
    """
    logger.info("Updating LoRA pipeline to style: %s", style_intensity)
    return init_pipelines(style_intensity)
def _run_pipeline(
    pipe: FluxPipeline,
    prompt: str,
    seed: int,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int = 50,
    max_sequence_length: int = 512,
):
    """Generate a single image from *prompt* with *pipe* and a fixed seed.

    Args:
        pipe: The pipeline to run; it is moved to the module-level ``device``.
        prompt: Text prompt.
        seed: Seed for a fresh ``torch.Generator``, so equal seeds give
            reproducible runs on the same pipeline.
        width, height: Output size in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Denoising steps (default 50, as before).
        max_sequence_length: Max prompt token length (default 512, as before).

    Returns:
        The first (and only) image in the pipeline output.
    """
    pipe.to(device)
    # Values coming from Gradio sliders/number fields may arrive as floats
    # (e.g. 42.0). ``manual_seed`` only accepts integers, and the pipeline
    # expects integer sizes, so normalise them here.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    output = pipe(
        prompt=[prompt],
        guidance_scale=float(guidance_scale),
        num_inference_steps=num_inference_steps,
        height=int(height),
        width=int(width),
        max_sequence_length=max_sequence_length,
        generator=generator,
    )
    return output.images[0]
# On ZeroGPU Spaces a GPU is attached only while a ``spaces.GPU``-decorated
# function runs; the module imports ``spaces`` but never applied it, so
# ``pipe.to(device)`` had no GPU to use. Outside ZeroGPU the decorator has no
# effect. NOTE(review): two 50-step FLUX.1-dev runs may exceed the 60 s default
# slot; 120 s is an estimate — tune from observed timings.
@spaces.GPU(duration=120)
def infer(prompt, seed, width, height, guidance_scale, progress=gr.Progress()):
    """Generate a base-model image and a LoRA image with identical settings.

    Args:
        prompt: Text prompt sent to both pipelines.
        seed: Seed used for both generations, so the outputs differ only by
            the pipeline used.
        width, height: Output size in pixels.
        guidance_scale: Guidance scale passed to both pipelines.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(base_image, lora_image, seed)``: the two images and the seed used.
    """
    logger.info(f"Generating comparison for prompt: '{prompt[:50]}...'")
    logger.info(f"Parameters: seed={seed}, width={width}, height={height}, guidance={guidance_scale}")

    progress(0.1, desc="Preparing base model")
    base_image = _run_pipeline(pipeline_base, prompt, seed, width, height, guidance_scale)

    progress(0.55, desc="Preparing LoRA model")
    lora_image = _run_pipeline(pipeline_lora, prompt, seed, width, height, guidance_scale)

    progress(0.95, desc="Processing results")
    logger.info("Comparison generation completed successfully")
    # Return both images and the used seed
    return base_image, lora_image, seed
def randomize_seed():
    """Return a random seed in the inclusive range ``[0, MAX_SEED]``."""
    return random.randrange(MAX_SEED + 1)
# ---------------------------------------------------------------------------
# UI: controls in a narrow left column, the two generated images side by side
# in a wider right column, then example prompts and event wiring.
# NOTE(review): the scraped source lost its indentation; the nesting below is
# reconstructed from the layout calls — confirm against the deployed app.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: style selection and generation parameters.
        with gr.Column(scale=1):
            gr.Markdown("# Flux: Base vs Moroccan-Ghibli LoRA")
            gr.Markdown("Generate side-by-side images with identical settings to compare the base model and the Moroccan Ghibli LoRA.")
            # Values must match the keys of revision_map in init_pipelines.
            style_intensity = gr.Dropdown(
                label="LoRA Style Intensity",
                choices=["low", "medium", "high"],
                value="low",
                interactive=True
            )
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt (e.g., 'Moroccan Ghibli studio style portrait')",
            )
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=DEFAULT_SEED, interactive=True)
                randomize_button = gr.Button("Randomize Seed")
            with gr.Row():
                # step=32 keeps sizes on a multiple of 32 pixels.
                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
            guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, step=0.1, value=3.5)
            run_button = gr.Button("Generate", variant="primary")
        # Right column: base and LoRA outputs side by side, plus the seed used.
        with gr.Column(scale=2):
            gr.Markdown("## Generated Images")
            with gr.Row():
                base_image = gr.Image(label="Base Model", height=512)
                lora_image = gr.Image(label="Moroccan Ghibli LoRA", height=512)
            output_seed = gr.Number(label="Used Seed", precision=0)
    gr.Markdown("## Example Prompts")
    # Each row follows the Examples `inputs` order:
    # [prompt, seed, width, height, guidance_scale].
    examples = [
        ["Moroccan Ghibli studio style portrait of a character in a riad courtyard", DEFAULT_SEED, 1024, 1024, 3.5],
        ["Moroccan Ghibli studio style image of a bustling souk with flying carpets", DEFAULT_SEED, 1024, 1024, 3.5],
        ["Moroccan Ghibli studio style landscape of a desert oasis under starry skies", DEFAULT_SEED, 1024, 1024, 3.5],
        ["Moroccan Ghibli studio style depiction of a magical lantern festival", DEFAULT_SEED, 1024, 1024, 3.5],
        ["Moroccan Ghibli studio style portrait of a medina at sunset", DEFAULT_SEED, 1024, 1024, 3.5]
    ]
    # cache_examples=True stores infer's outputs for these rows.
    # NOTE(review): the LoRA image comes from whichever LoRA pipeline is loaded
    # when caching runs ("low" is loaded at import), so cached results may not
    # match a later intensity selection.
    gr.Examples(
        examples=examples,
        inputs=[prompt, seed, width, height, guidance_scale],
        outputs=[base_image, lora_image, output_seed],
        fn=infer,
        cache_examples=True
    )
    # Changing intensity swaps the module-level (global) LoRA pipeline, so the
    # change applies to every session, not only the user who made it.
    style_intensity.change(
        fn=update_lora_pipeline,
        inputs=[style_intensity],
        outputs=[]
    )
    # Both the Generate button and pressing Enter in the prompt box run infer.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, width, height, guidance_scale],
        outputs=[base_image, lora_image, output_seed]
    )
    # Only updates the seed slider; it does not start a generation.
    randomize_button.click(
        fn=randomize_seed,
        inputs=[],
        outputs=[seed]
    )
def _serve():
    """Launch the Gradio app, logging before start-up and after launch() returns."""
    logger.info("Starting application")
    demo.launch()
    logger.info("Application closed")


if __name__ == "__main__":
    _serve()