from datetime import datetime

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline

from aoti import aoti_load

# --- Model Loading ---
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=dtype
).to(device)

# Fuse the Q/K/V projections into single matmuls; the ahead-of-time compiled
# graph loaded below expects this fused module layout.
pipeline.transformer.fuse_qkv_projections()
# Load the ahead-of-time (AOTInductor) compiled transformer artifact from the Hub.
aoti_load(pipeline.transformer, "sayakpaul/flux-dev-aot", "flux-dev-aot.pt2")
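

# --- Inference ---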
@spaces.GPU
def generate_image(prompt: str, progress=gr.Progress(track_tqdm=True)):
    # Fixed seed so results are reproducible across runs.
    generator = torch.Generator(device='cuda').manual_seed(42)
    t0 = datetime.now()
    output = pipeline(
        prompt=prompt,
        num_inference_steps=28,
        generator=generator,
    )
    # Return an (image, caption) pair for the Gallery; the caption is the wall-clock time.
    return [(output.images[0], f'{(datetime.now() - t0).total_seconds():.2f}s')]
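

# --- Gradio UI ---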
gr.Interface(
    fn=generate_image,
    inputs=gr.Text(label="Prompt"),
    outputs=gr.Gallery(),
    examples=["A cat playing with a ball of yarn"],
    cache_examples=False,
).launch()