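"""Gradio Space comparing the base FLUX.1-dev model with the
atlasia/moroccan-ghibli-flux-lora adapter: the same prompt, seed, and
sampling settings are run through both pipelines and shown side by side."""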
import gradio as gr
import numpy as np
import random
import os
from huggingface_hub import login, hf_hub_download
import torch
import spaces
from diffusers import FluxPipeline
import logging
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("moroccan-ghibli-flux-compare")
# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
DEFAULT_SEED = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Global pipelines
pipeline_base = None
pipeline_lora = None
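# pipeline_base is loaded once and reused across requests; pipeline_lora is
# rebuilt whenever the user selects a different style intensity (LoRA checkpoint).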
def load_base_pipeline():
    """Load the vanilla FLUX.1-dev pipeline with no LoRA attached."""
    api_key = os.getenv("HF_TOKEN")
    if not api_key:
        raise ValueError("HF_TOKEN environment variable not set.")
    login(token=api_key)

    logger.info("Loading base Flux model (no LoRA)")
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    )
    return pipe
def load_lora_pipeline(revision: str):
    """Load FLUX.1-dev and attach the Moroccan Ghibli LoRA weights for the given revision."""
    api_key = os.getenv("HF_TOKEN")
    if not api_key:
        raise ValueError("HF_TOKEN environment variable not set.")
    login(token=api_key)

    logger.info(f"Loading Flux model with LoRA revision: {revision}")
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    )

    # Extract the step number from the revision name (e.g., "step_1500" -> "1500")
    # and prepend a zero when it has fewer than four digits, so the result matches
    # the checkpoint naming (e.g., "step_750" -> "moroccan_ghibli_flux_lora_000000750.safetensors").
    step_num = revision.split('_')[-1]
    if len(step_num) < 4:
        step_num = '0' + step_num
    filename = f"moroccan_ghibli_flux_lora_00000{step_num}.safetensors"

    logger.info(f"Downloading LoRA weights: {filename}")
    lora_path = hf_hub_download(
        repo_id="atlasia/moroccan-ghibli-flux-lora",
        filename=filename,
        revision=revision,
        token=api_key
    )
    logger.info("Loading LoRA weights into the pipeline")
    pipe.load_lora_weights(lora_path)
    return pipe
def init_pipelines(style_intensity: str = "low"):
    """Ensure the base pipeline is loaded and (re)load the LoRA pipeline for the chosen intensity."""
    global pipeline_base, pipeline_lora

    # Each style intensity maps to a training checkpoint (branch) of the LoRA repo.
    revision_map = {
        "low": "step_750",
        "medium": "step_1500",
        "high": "step_2500"
    }
    revision = revision_map[style_intensity]

    if pipeline_base is None:
        pipeline_base = load_base_pipeline()
        logger.info("Base pipeline loaded")

    logger.info(f"Initializing LoRA pipeline with style: {style_intensity} ({revision})")
    pipeline_lora = load_lora_pipeline(revision)
    logger.info("LoRA pipeline loaded")
# Load the pipelines at startup with the default ("low") style intensity
init_pipelines("low")
def update_lora_pipeline(style_intensity):
    logger.info(f"Updating LoRA pipeline to style: {style_intensity}")
    init_pipelines(style_intensity)
def _run_pipeline(pipe: FluxPipeline, prompt: str, seed: int, width: int, height: int, guidance_scale: float):
    """Run a single generation with a fixed seed so both pipelines can be compared fairly."""
    pipe.to(device)
    generator = torch.Generator(device=device).manual_seed(seed)
    max_sequence_length = 512

    output = pipe(
        prompt=[prompt],
        guidance_scale=guidance_scale,
        num_inference_steps=50,
        height=height,
        width=width,
        max_sequence_length=max_sequence_length,
        generator=generator,
    )
    return output.images[0]
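# Both generations below reuse the same seed (and therefore the same initial noise),
# so visual differences between the two outputs come from the LoRA alone.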
@spaces.GPU(duration=120)
def infer(prompt, seed, width, height, guidance_scale, progress=gr.Progress()):
    logger.info(f"Generating comparison for prompt: '{prompt[:50]}...'")
    logger.info(f"Parameters: seed={seed}, width={width}, height={height}, guidance={guidance_scale}")

    progress(0.1, desc="Preparing base model")
    base_image = _run_pipeline(pipeline_base, prompt, seed, width, height, guidance_scale)

    progress(0.55, desc="Preparing LoRA model")
    lora_image = _run_pipeline(pipeline_lora, prompt, seed, width, height, guidance_scale)

    progress(0.95, desc="Processing results")
    logger.info("Comparison generation completed successfully")

    # Return both images and the used seed
    return base_image, lora_image, seed
def randomize_seed():
    return random.randint(0, MAX_SEED)
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# Flux: Base vs Moroccan-Ghibli LoRA")
            gr.Markdown("Generate side-by-side images with identical settings to compare the base model and the Moroccan Ghibli LoRA.")

            style_intensity = gr.Dropdown(
                label="LoRA Style Intensity",
                choices=["low", "medium", "high"],
                value="low",
                interactive=True
            )
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt (e.g., 'Moroccan Ghibli studio style portrait')",
            )
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=DEFAULT_SEED, interactive=True)
                randomize_button = gr.Button("Randomize Seed")
            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
            guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, step=0.1, value=3.5)
            run_button = gr.Button("Generate", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("## Generated Images")
            with gr.Row():
                base_image = gr.Image(label="Base Model", height=512)
                lora_image = gr.Image(label="Moroccan Ghibli LoRA", height=512)
            output_seed = gr.Number(label="Used Seed", precision=0)

    gr.Markdown("## Example Prompts")
    examples = [
        ["Moroccan Ghibli studio style portrait of a character in a riad courtyard", DEFAULT_SEED, 1024, 1024, 3.5],
        ["Moroccan Ghibli studio style image of a bustling souk with flying carpets", DEFAULT_SEED, 1024, 1024, 3.5],
        ["Moroccan Ghibli studio style landscape of a desert oasis under starry skies", DEFAULT_SEED, 1024, 1024, 3.5],
        ["Moroccan Ghibli studio style depiction of a magical lantern festival", DEFAULT_SEED, 1024, 1024, 3.5],
        ["Moroccan Ghibli studio style portrait of a medina at sunset", DEFAULT_SEED, 1024, 1024, 3.5]
    ]
    gr.Examples(
        examples=examples,
        inputs=[prompt, seed, width, height, guidance_scale],
        outputs=[base_image, lora_image, output_seed],
        fn=infer,
        cache_examples=True
    )

    style_intensity.change(
        fn=update_lora_pipeline,
        inputs=[style_intensity],
        outputs=[]
    )
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, width, height, guidance_scale],
        outputs=[base_image, lora_image, output_seed]
    )
    randomize_button.click(
        fn=randomize_seed,
        inputs=[],
        outputs=[seed]
    )
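# Launching assumes HF_TOKEN grants access to the gated FLUX.1-dev weights and that
# a GPU is available (the @spaces.GPU decorator targets ZeroGPU Spaces).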
if __name__ == "__main__":
    logger.info("Starting application")
    demo.launch()
    logger.info("Application closed")