Spaces:
Runtime error
Runtime error
import spaces  # imported first so ZeroGPU can hook CUDA initialisation before torch loads
import os
import subprocess
import sys


def _run_pip(*args: str) -> None:
    """Run ``pip`` with *args* under the current interpreter, best effort.

    Using ``sys.executable -m pip`` guarantees the packages are (un)installed
    in the environment this process imports from; a bare ``pip`` on PATH may
    belong to a different Python.  Like the original ``os.system`` calls, a
    failing command does not abort start-up, but it is now reported.
    """
    cmd = [sys.executable, "-m", "pip", *args]
    returncode = subprocess.run(cmd, check=False).returncode
    if returncode != 0:
        print(
            f"warning: {' '.join(cmd)} exited with status {returncode}",
            file=sys.stderr,
        )


# Reinstall torchvision at start-up.
# NOTE(review): `--force-reinstall` also reinstalls torchvision's dependencies
# (including torch) — confirm that is intended.
_run_pip("uninstall", "torchvision", "-y")
_run_pip("install", "torchvision", "--force-reinstall", "--no-cache-dir")

import torch
from diffusers import AutoPipelineForText2Image
import gradio as gr
from PIL import Image

# Load Flex.2-preview in bfloat16 using the repository's custom pipeline code
# (pipeline.py) and move it to the GPU once, at import time.
pipe = AutoPipelineForText2Image.from_pretrained(
    "ostris/Flex.2-preview",
    custom_pipeline="pipeline.py",
    torch_dtype=torch.bfloat16,
).to("cuda")
@spaces.GPU  # request a GPU for the call on ZeroGPU; no-op elsewhere
def generate_image(
    prompt: str,
    inpaint_img: Image.Image,
    inpaint_mask: Image.Image,
    control_img: Image.Image,
    height: int,
    width: int,
    guidance_scale: float,
    num_inference_steps: int,
    seed: int,
    control_strength: float,
    control_stop: float,
):
    """Run the Flex.2-preview pipeline once and return the first image.

    Args:
        prompt: Text prompt.
        inpaint_img: Image to inpaint, or ``None`` when the Gradio image
            component is left empty.
        inpaint_mask: Inpainting mask, or ``None`` when left empty.
        control_img: Control image, or ``None`` when left empty.
        height: Output height in pixels.
        width: Output width in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of inference steps.
        seed: Seed for the CUDA ``torch.Generator``.
        control_strength: Forwarded unchanged to the pipeline.
        control_stop: Forwarded unchanged to the pipeline.

    Returns:
        The first entry of the pipeline result's ``images``.
    """

    def _to_rgb(img):
        # An empty gr.Image delivers None; calling .convert on it would raise
        # AttributeError, so only convert images that were actually supplied.
        # NOTE(review): relies on pipeline.py treating None as "not provided".
        return None if img is None else img.convert("RGB")

    # Gradio number/slider values may arrive as floats; the generator seed
    # and the step count must be ints.
    gen = torch.Generator(device="cuda").manual_seed(int(seed))
    result = pipe(
        prompt=prompt,
        inpaint_image=_to_rgb(inpaint_img),
        inpaint_mask=_to_rgb(inpaint_mask),
        control_image=_to_rgb(control_img),
        control_strength=float(control_strength),
        control_stop=float(control_stop),
        height=int(height),
        width=int(width),
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_inference_steps),
        generator=gen,
    )
    return result.images[0]
# Build the Gradio UI. Components are created in display order, so the
# statement order below defines the page layout.
with gr.Blocks(title="Flex.2-preview Image Generator") as demo:
    gr.Markdown("# Flex.2-preview Text→Image Generator")

    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Enter your prompt...",
        lines=2,
    )

    # Conditioning images, each handed to the callback as a PIL image.
    with gr.Row():
        inpaint_img = gr.Image(label="Inpaint Image", type="pil")
        inpaint_mask = gr.Image(label="Inpaint Mask", type="pil")
        control_img = gr.Image(label="Control Image", type="pil")

    # Output resolution.
    with gr.Row():
        height = gr.Slider(
            minimum=64, maximum=2048, value=512, step=64, label="Height"
        )
        width = gr.Slider(
            minimum=64, maximum=2048, value=512, step=64, label="Width"
        )

    # Sampling parameters.
    with gr.Row():
        guidance_scale = gr.Slider(
            minimum=0.0, maximum=20.0, value=3.5, step=0.1, label="Guidance Scale"
        )
        num_inference_steps = gr.Slider(
            minimum=1, maximum=100, value=50, step=1, label="Inference Steps"
        )

    seed = gr.Number(label="Random Seed", value=42, precision=0)
    control_strength = gr.Slider(
        minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Control Strength"
    )
    control_stop = gr.Slider(
        minimum=0.0, maximum=1.0, value=1.0, step=0.01, label="Control Stop"
    )

    generate_btn = gr.Button("Generate")
    output = gr.Image(label="Generated Image", type="pil")

    # Order must match generate_image's positional parameters.
    generate_inputs = [
        prompt,
        inpaint_img,
        inpaint_mask,
        control_img,
        height,
        width,
        guidance_scale,
        num_inference_steps,
        seed,
        control_strength,
        control_stop,
    ]
    generate_btn.click(
        fn=generate_image,
        inputs=generate_inputs,
        outputs=[output],
    )

if __name__ == "__main__":
    # share=True asks Gradio for a public share link when run locally.
    demo.launch(share=True)