Upload app.py
Browse files
app.py
CHANGED
|
@@ -40,6 +40,19 @@ pipe = DAIPipeline(
|
|
| 40 |
t_start=0,
|
| 41 |
).to(device)
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
@spaces.GPU
|
| 44 |
def process_image(input_image, resolution_choice):
|
| 45 |
# 将 Gradio 输入转换为 PIL 图像
|
|
@@ -47,6 +60,7 @@ def process_image(input_image, resolution_choice):
|
|
| 47 |
|
| 48 |
# 根据用户选择设置处理分辨率
|
| 49 |
if resolution_choice == "768":
|
|
|
|
| 50 |
processing_resolution = None
|
| 51 |
else:
|
| 52 |
processing_resolution = 0 # 使用原始分辨率
|
|
@@ -64,7 +78,7 @@ def process_image(input_image, resolution_choice):
|
|
| 64 |
processed_frame = (processed_frame[0] * 255).astype(np.uint8)
|
| 65 |
processed_frame = Image.fromarray(processed_frame)
|
| 66 |
|
| 67 |
-
return processed_frame
|
| 68 |
|
| 69 |
# 创建 Gradio 界面
|
| 70 |
def create_gradio_interface():
|
|
@@ -108,7 +122,7 @@ def create_gradio_interface():
|
|
| 108 |
submit_btn.click(
|
| 109 |
fn=process_image,
|
| 110 |
inputs=[input_image, resolution_choice], # 输入组件列表
|
| 111 |
-
outputs=output_image,
|
| 112 |
)
|
| 113 |
|
| 114 |
return demo
|
|
|
|
| 40 |
t_start=0,
|
| 41 |
).to(device)
|
| 42 |
|
| 43 |
+
def resize_image(image, max_size):
|
| 44 |
+
"""Resize the image so that the maximum side is max_size."""
|
| 45 |
+
width, height = image.size
|
| 46 |
+
if max(width, height) > max_size:
|
| 47 |
+
if width > height:
|
| 48 |
+
new_width = max_size
|
| 49 |
+
new_height = int(height * (max_size / width))
|
| 50 |
+
else:
|
| 51 |
+
new_height = max_size
|
| 52 |
+
new_width = int(width * (max_size / height))
|
| 53 |
+
image = image.resize((new_width, new_height), Image.LANCZOS)
|
| 54 |
+
return image
|
| 55 |
+
|
| 56 |
@spaces.GPU
|
| 57 |
def process_image(input_image, resolution_choice):
|
| 58 |
# 将 Gradio 输入转换为 PIL 图像
|
|
|
|
| 60 |
|
| 61 |
# 根据用户选择设置处理分辨率
|
| 62 |
if resolution_choice == "768":
|
| 63 |
+
input_image = resize_image(input_image, 768)
|
| 64 |
processing_resolution = None
|
| 65 |
else:
|
| 66 |
processing_resolution = 0 # 使用原始分辨率
|
|
|
|
| 78 |
processed_frame = (processed_frame[0] * 255).astype(np.uint8)
|
| 79 |
processed_frame = Image.fromarray(processed_frame)
|
| 80 |
|
| 81 |
+
return input_image, processed_frame # 返回调整后的输入图片和处理后的图片
|
| 82 |
|
| 83 |
# 创建 Gradio 界面
|
| 84 |
def create_gradio_interface():
|
|
|
|
| 122 |
submit_btn.click(
|
| 123 |
fn=process_image,
|
| 124 |
inputs=[input_image, resolution_choice], # 输入组件列表
|
| 125 |
+
outputs=[input_image, output_image], # 输出组件列表
|
| 126 |
)
|
| 127 |
|
| 128 |
return demo
|