test new
Browse files- app.py +0 -278
- app/README.md +3 -0
- app/app.py +90 -0
- app/requirements.txt +10 -0
- cv_utils/__init__.py +25 -0
- cv_utils/__pycache__/__init__.cpython-313.pyc +0 -0
- cv_utils/__pycache__/metrics.cpython-313.pyc +0 -0
- cv_utils/__pycache__/stitching.cpython-313.pyc +0 -0
- cv_utils/metrics.py +27 -0
- cv_utils/stitching.py +80 -0
- image_augmentation/__init__.py +27 -0
- image_augmentation/__pycache__/__init__.cpython-313.pyc +0 -0
- image_augmentation/__pycache__/jitter.cpython-313.pyc +0 -0
- image_augmentation/jitter.py +140 -0
app.py
DELETED
|
@@ -1,278 +0,0 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
-
from PIL import Image, ImageDraw, ImageEnhance
|
| 3 |
-
import torch
|
| 4 |
-
import kornia
|
| 5 |
-
from transformers import AutoImageProcessor, AutoModel
|
| 6 |
-
import torchvision.transforms as T
|
| 7 |
-
import numpy as np
|
| 8 |
-
import cv2
|
| 9 |
-
import random
|
| 10 |
-
import glob
|
| 11 |
-
from skimage.metrics import structural_similarity as ssim # Hinzugefügt für SSIM
|
| 12 |
-
|
| 13 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 14 |
-
print(f"Using device: {device}")
|
| 15 |
-
|
| 16 |
-
# Lade LightGlue-Modell (genau wie Jupyter)
|
| 17 |
-
processor = AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint")
|
| 18 |
-
model = AutoModel.from_pretrained("ETH-CVG/lightglue_superpoint")
|
| 19 |
-
model.to(device)
|
| 20 |
-
|
| 21 |
-
def draw_keypoints_and_matches(img_pil, keypoints, color=(0, 255, 0), radius=3):
|
| 22 |
-
"""Plots key points as colored circles on PIL image."""
|
| 23 |
-
img_draw = ImageDraw.Draw(img_pil)
|
| 24 |
-
kpts_np = keypoints.cpu().numpy()
|
| 25 |
-
for kp in kpts_np:
|
| 26 |
-
x, y = int(kp[0]), int(kp[1])
|
| 27 |
-
img_draw.ellipse([x-radius, y-radius, x+radius, y+radius], outline=color, fill=color)
|
| 28 |
-
return img_pil
|
| 29 |
-
|
| 30 |
-
def visualize_matches(images, feature_mapping):
|
| 31 |
-
"""Visualization """
|
| 32 |
-
if len(feature_mapping) == 0:
|
| 33 |
-
return images[0]
|
| 34 |
-
|
| 35 |
-
matches_viz = processor.visualize_keypoint_matching(images, feature_mapping)
|
| 36 |
-
return matches_viz[0]
|
| 37 |
-
|
| 38 |
-
def calculate_ssim(original_img, stitched_img, gray_scale=False):
|
| 39 |
-
"""
|
| 40 |
-
Calculates the SSIM between the original PIL image and the stitched PIL image.
|
| 41 |
-
"""
|
| 42 |
-
orig_np = np.array(original_img)
|
| 43 |
-
|
| 44 |
-
stitch_np = stitched_img.cpu().detach().numpy()
|
| 45 |
-
if stitch_np.ndim == 3:
|
| 46 |
-
stitch_np = np.transpose(stitch_np, (1, 2, 0))
|
| 47 |
-
# ensure imgs to have the same size
|
| 48 |
-
stitch_np = cv2.resize(stitch_np, (orig_np.shape[1], orig_np.shape[0]))
|
| 49 |
-
|
| 50 |
-
print(orig_np.shape, stitch_np.shape)
|
| 51 |
-
|
| 52 |
-
# ensure the data types and ranges match
|
| 53 |
-
orig_np = orig_np.astype(np.float32) / 255.0 if orig_np.max() > 1.0 else orig_np.astype(np.float32)
|
| 54 |
-
stitch_np = stitch_np.astype(np.float32) / 255.0 if stitch_np.max() > 1.0 else stitch_np.astype(np.float32)
|
| 55 |
-
|
| 56 |
-
# conversion to gray scale
|
| 57 |
-
orig_np = np.dot(orig_np[..., :3], [0.2989, 0.5870, 0.1140]) if gray_scale else orig_np
|
| 58 |
-
stitch_np = np.dot(stitch_np[..., :3], [0.2989, 0.5870, 0.1140]) if gray_scale else stitch_np
|
| 59 |
-
|
| 60 |
-
score, diff = ssim(orig_np, stitch_np, full=True, data_range=1.0, channel_axis=None if gray_scale else 2)
|
| 61 |
-
|
| 62 |
-
return score, diff
|
| 63 |
-
|
| 64 |
-
def stitch_images(img0_pil, img1_pil, output, device=device):
|
| 65 |
-
to_tensor = T.ToTensor()
|
| 66 |
-
image0 = to_tensor(img0_pil).to(device)
|
| 67 |
-
image1 = to_tensor(img1_pil).to(device)
|
| 68 |
-
|
| 69 |
-
pts0 = output["keypoints0"].float()
|
| 70 |
-
pts1 = output["keypoints1"].float()
|
| 71 |
-
|
| 72 |
-
print(f"DEBUG: {pts0.shape[0]} Matches found!")
|
| 73 |
-
|
| 74 |
-
if pts0.shape[0] < 4:
|
| 75 |
-
print("Insufficient matches for homography.")
|
| 76 |
-
return None
|
| 77 |
-
|
| 78 |
-
p0_np = pts0.detach().cpu().numpy()
|
| 79 |
-
p1_np = pts1.detach().cpu().numpy()
|
| 80 |
-
|
| 81 |
-
H_np, mask = cv2.findHomography(
|
| 82 |
-
p1_np,
|
| 83 |
-
p0_np,
|
| 84 |
-
method=cv2.USAC_MAGSAC,
|
| 85 |
-
ransacReprojThreshold=5.0,
|
| 86 |
-
confidence=0.999,
|
| 87 |
-
maxIters=1000
|
| 88 |
-
)
|
| 89 |
-
|
| 90 |
-
if H_np is None:
|
| 91 |
-
print("Homography estimation failed.")
|
| 92 |
-
return None
|
| 93 |
-
|
| 94 |
-
H = torch.from_numpy(H_np).to(device).float()
|
| 95 |
-
|
| 96 |
-
c, h0, w0 = image0.shape
|
| 97 |
-
_, h1, w1 = image1.shape
|
| 98 |
-
|
| 99 |
-
corners1 = torch.tensor([[0., 0.], [float(w1), 0.], [float(w1), float(h1)], [0., float(h1)]], device=device)
|
| 100 |
-
corners1_homo = torch.cat([corners1, torch.ones((4, 1), device=device)], dim=1).T
|
| 101 |
-
warped_homo = H @ corners1_homo
|
| 102 |
-
warped_corners1 = (warped_homo[:2] / warped_homo[2]).T
|
| 103 |
-
|
| 104 |
-
all_coords = torch.cat([
|
| 105 |
-
warped_corners1,
|
| 106 |
-
torch.tensor([[0., 0.], [float(w0), 0.], [float(w0), float(h0)], [0., float(h0)]], device=device)
|
| 107 |
-
], dim=0)
|
| 108 |
-
|
| 109 |
-
min_xy = all_coords.min(dim=0).values
|
| 110 |
-
max_xy = all_coords.max(dim=0).values
|
| 111 |
-
|
| 112 |
-
translation = torch.eye(3, device=device)
|
| 113 |
-
translation[0, 2] = -min_xy[0]
|
| 114 |
-
translation[1, 2] = -min_xy[1]
|
| 115 |
-
|
| 116 |
-
H_final = translation @ H
|
| 117 |
-
out_size = (int(max_xy[1] - min_xy[1]), int(max_xy[0] - min_xy[0]))
|
| 118 |
-
|
| 119 |
-
warped0 = kornia.geometry.transform.warp_perspective(
|
| 120 |
-
image0.unsqueeze(0), translation.unsqueeze(0), dsize=out_size, align_corners=True
|
| 121 |
-
).squeeze(0)
|
| 122 |
-
|
| 123 |
-
warped1 = kornia.geometry.transform.warp_perspective(
|
| 124 |
-
image1.unsqueeze(0), H_final.unsqueeze(0), dsize=out_size, align_corners=True
|
| 125 |
-
).squeeze(0)
|
| 126 |
-
|
| 127 |
-
mask0 = (warped0.abs().sum(dim=0, keepdim=True) > 1e-5).float()
|
| 128 |
-
mask1 = (warped1.abs().sum(dim=0, keepdim=True) > 1e-5).float()
|
| 129 |
-
|
| 130 |
-
stitched = (warped0 + warped1) / (mask0 + mask1 + 1e-8)
|
| 131 |
-
|
| 132 |
-
return stitched
|
| 133 |
-
|
| 134 |
-
def feature_detection_mapping(images):
|
| 135 |
-
inputs = processor(images, return_tensors="pt").to(device)
|
| 136 |
-
with torch.no_grad():
|
| 137 |
-
outputs = model(**inputs)
|
| 138 |
-
image_sizes = [[(image.height, image.width) for image in images]]
|
| 139 |
-
outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
|
| 140 |
-
return outputs
|
| 141 |
-
|
| 142 |
-
def apply_geometric_jitter(image, rotation_limit=10.0, translation_limit=5, perspective_limit=0.02):
|
| 143 |
-
width, height = image.size
|
| 144 |
-
|
| 145 |
-
angle = random.uniform(-rotation_limit, rotation_limit)
|
| 146 |
-
tx = random.uniform(-translation_limit, translation_limit)
|
| 147 |
-
ty = random.uniform(-translation_limit, translation_limit)
|
| 148 |
-
|
| 149 |
-
img = image.rotate(angle, resample=Image.BILINEAR, translate=(tx, ty))
|
| 150 |
-
|
| 151 |
-
coeffs = [
|
| 152 |
-
1 + random.uniform(-perspective_limit, perspective_limit), 0, 0,
|
| 153 |
-
0, 1 + random.uniform(-perspective_limit, perspective_limit), 0,
|
| 154 |
-
random.uniform(-0.0001, 0.0001), random.uniform(-0.0001, 0.0001)
|
| 155 |
-
]
|
| 156 |
-
|
| 157 |
-
return img.transform((width, height), Image.PERSPECTIVE, coeffs, Image.BILINEAR)
|
| 158 |
-
|
| 159 |
-
def apply_brightness_jitter(image, jitter_range=(0.7, 1.3)):
|
| 160 |
-
enhancer = ImageEnhance.Brightness(image)
|
| 161 |
-
factor = random.uniform(*jitter_range)
|
| 162 |
-
return enhancer.enhance(factor)
|
| 163 |
-
|
| 164 |
-
def remove_alpha(img_rgba, bg_color=(0, 0, 0)):
|
| 165 |
-
background = Image.new("RGB", img_rgba.size, bg_color)
|
| 166 |
-
background.paste(img_rgba, mask=img_rgba.split()[3])
|
| 167 |
-
return background
|
| 168 |
-
|
| 169 |
-
def split_image_variable_diagonal(image_path, min_overlap_pct=0.1):
|
| 170 |
-
img = Image.open(image_path).convert("RGBA")
|
| 171 |
-
w, h = img.size
|
| 172 |
-
margin = 50
|
| 173 |
-
|
| 174 |
-
top_x = random.randint(margin, w - margin)
|
| 175 |
-
|
| 176 |
-
max_slant = w
|
| 177 |
-
min_overlap = int(w * min_overlap_pct)
|
| 178 |
-
bottom_x = random.randint(max(margin, top_x - max_slant),
|
| 179 |
-
min(w - margin, top_x + max_slant))
|
| 180 |
-
|
| 181 |
-
mask_left = Image.new("L", (w, h), 0)
|
| 182 |
-
draw_l = ImageDraw.Draw(mask_left)
|
| 183 |
-
draw_l.polygon([(0, 0), (top_x + min_overlap, 0), (bottom_x + min_overlap, h), (0, h)], fill=255)
|
| 184 |
-
|
| 185 |
-
mask_right = Image.new("L", (w, h), 0)
|
| 186 |
-
draw_r = ImageDraw.Draw(mask_right)
|
| 187 |
-
draw_r.polygon([(top_x - min_overlap, 0), (w, 0), (w, h), (bottom_x - min_overlap, h)], fill=255)
|
| 188 |
-
|
| 189 |
-
left_img = img.copy()
|
| 190 |
-
left_img.putalpha(mask_left)
|
| 191 |
-
|
| 192 |
-
right_img = img.copy()
|
| 193 |
-
right_img.putalpha(mask_right)
|
| 194 |
-
|
| 195 |
-
cropped_left = remove_alpha(left_img)
|
| 196 |
-
cropped_right = remove_alpha(right_img)
|
| 197 |
-
|
| 198 |
-
return cropped_left, cropped_right, img
|
| 199 |
-
|
| 200 |
-
def process_images(files, rot_limit, trans_limit, persp_limit, bright_factor, overlap_pct):
|
| 201 |
-
if files is None or len(files) == 0:
|
| 202 |
-
return None, None, ""
|
| 203 |
-
|
| 204 |
-
imgs = [Image.open(f.name) for f in files if True]
|
| 205 |
-
if len(imgs) == 0:
|
| 206 |
-
return None, None, ""
|
| 207 |
-
|
| 208 |
-
def jitter_image(img):
|
| 209 |
-
img = apply_geometric_jitter(img, rotation_limit=rot_limit, translation_limit=trans_limit, perspective_limit=persp_limit)
|
| 210 |
-
img = apply_brightness_jitter(img, jitter_range=(max(0.1, bright_factor-0.3), min(2.0, bright_factor+0.3)))
|
| 211 |
-
return img
|
| 212 |
-
|
| 213 |
-
if len(imgs) == 1:
|
| 214 |
-
img = imgs[0]
|
| 215 |
-
original = img.copy()
|
| 216 |
-
f_path = files[0].name
|
| 217 |
-
|
| 218 |
-
left, right, _ = split_image_variable_diagonal(f_path, min_overlap_pct=overlap_pct)
|
| 219 |
-
left = jitter_image(left)
|
| 220 |
-
|
| 221 |
-
images_to_stitch = [left, right]
|
| 222 |
-
feature_mapping = feature_detection_mapping(images_to_stitch)
|
| 223 |
-
matches_viz = visualize_matches(images_to_stitch, feature_mapping)
|
| 224 |
-
|
| 225 |
-
stitched = stitch_images(left, right, feature_mapping[0], device=device)
|
| 226 |
-
if stitched is not None:
|
| 227 |
-
result_img = Image.fromarray((stitched.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
|
| 228 |
-
ssim_score, _ = calculate_ssim(original, stitched)
|
| 229 |
-
num_matches = (feature_mapping[0].get("matches0", torch.tensor([])) > -1).sum().item()
|
| 230 |
-
return result_img, matches_viz, f"SSIM: {ssim_score:.4f}"
|
| 231 |
-
else:
|
| 232 |
-
jittered_original = jitter_image(img)
|
| 233 |
-
return jittered_original, matches_viz, "Stitching failed"
|
| 234 |
-
|
| 235 |
-
elif len(imgs) == 2:
|
| 236 |
-
img1, img2 = imgs[0], imgs[1]
|
| 237 |
-
feature_mapping = feature_detection_mapping([img1, img2])
|
| 238 |
-
matches_viz = visualize_matches([img1, img2], feature_mapping)
|
| 239 |
-
|
| 240 |
-
stitched = stitch_images(img1, img2, feature_mapping[0], device=device)
|
| 241 |
-
if stitched is not None:
|
| 242 |
-
result_img = Image.fromarray((stitched.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
|
| 243 |
-
return result_img, matches_viz, ""
|
| 244 |
-
else:
|
| 245 |
-
return img1, matches_viz, "Stitching failed"
|
| 246 |
-
|
| 247 |
-
return None, None, "Only support 1 or 2 images"
|
| 248 |
-
|
| 249 |
-
# Gradio Interface
|
| 250 |
-
with gr.Blocks() as demo:
|
| 251 |
-
gr.Markdown("## Automated Panorama Stitching")
|
| 252 |
-
gr.Markdown("**1 Picture:** Split + Jitter + Stitch | **2 Pictures:** Direct Stitch")
|
| 253 |
-
|
| 254 |
-
with gr.Row():
|
| 255 |
-
with gr.Column(scale=1):
|
| 256 |
-
input_files = gr.File(label="Pictures (1 or 2)", file_types=["image"], file_count="multiple")
|
| 257 |
-
|
| 258 |
-
with gr.Group("Jitter and Split"):
|
| 259 |
-
rotation_slider = gr.Slider(0, 20, value=5, step=0.5, label="Rotation (°)")
|
| 260 |
-
trans_slider = gr.Slider(0, 20, value=3, step=0.5, label="Translation (px)")
|
| 261 |
-
persp_slider = gr.Slider(0, 0.1, value=0.02, step=0.005, label="Perspective")
|
| 262 |
-
bright_slider = gr.Slider(0.5, 1.5, value=1.0, step=0.05, label="Brightness")
|
| 263 |
-
overlap_slider = gr.Slider(0.05, 0.3, value=0.15, step=0.01, label="Overlap (%)")
|
| 264 |
-
|
| 265 |
-
generate_btn = gr.Button("Stitch", variant="primary", size="lg")
|
| 266 |
-
|
| 267 |
-
with gr.Column(scale=2):
|
| 268 |
-
stitched_gallery = gr.Image(label="Result")
|
| 269 |
-
matches_gallery = gr.Image(label="Keypoints & Matches")
|
| 270 |
-
stats_md = gr.Markdown(label="Stats")
|
| 271 |
-
|
| 272 |
-
generate_btn.click(
|
| 273 |
-
fn=process_images,
|
| 274 |
-
inputs=[input_files, rotation_slider, trans_slider, persp_slider, bright_slider, overlap_slider],
|
| 275 |
-
outputs=[stitched_gallery, matches_gallery, stats_md]
|
| 276 |
-
)
|
| 277 |
-
|
| 278 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/README.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mit ```python app.py``` lokal starten
|
| 2 |
+
|
| 3 |
+
oder hier https://huggingface.co/spaces/PHarder/Automated_Panorama_Stitching testen
|
app/app.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 4 |
+
from image_augmentation import jitter_image, split_image_diagonal
|
| 5 |
+
from cv_utils import Stitcher, calculate_ssim
|
| 6 |
+
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import numpy as np
|
| 9 |
+
import gradio as gr
|
| 10 |
+
|
| 11 |
+
stitcher = Stitcher()
|
| 12 |
+
|
| 13 |
+
def _stitched_to_pil(stitched):
    """Convert the stitched tensor into an 8-bit PIL image.

    The tensor is expected to be channel-first with values in [0, 1]
    (it is built from ``ToTensor`` outputs by the stitcher).
    """
    # Clamp before the cast: a value marginally outside [0, 1] would
    # otherwise wrap around when converted to uint8.
    array = stitched.detach().clamp(0.0, 1.0).permute(1, 2, 0).cpu().numpy()
    return Image.fromarray((array * 255).astype(np.uint8))


def process_images(files, rot_limit, trans_limit_X, trans_limit_Y, persp_limit, bright_factor, overlap_pct):
    """Gradio callback: stitch one self-split image or two uploaded images.

    With one file, the image is split into two overlapping halves, the left
    half is jittered with the given parameters, both halves are stitched and
    the result is scored against the original with SSIM. With two files the
    images are stitched directly and the jitter parameters are ignored.

    Args:
        files: Uploaded files; each entry is either a path string or an
            object exposing the path as ``.name``.
        rot_limit: Rotation angle passed to ``jitter_image`` as ``angle``.
        trans_limit_X: Horizontal shift passed to ``jitter_image`` as ``tx``.
        trans_limit_Y: Vertical shift passed to ``jitter_image`` as ``ty``.
        persp_limit: Value passed to ``jitter_image`` as ``perspective_coeffs``.
        bright_factor: Brightness factor passed to ``jitter_image``.
        overlap_pct: Overlap between the two halves as a fraction of the
            image width.

    Returns:
        Tuple ``(result_image, matches_visualisation, status_text)``. The
        first two entries are ``None`` when there is nothing to process or
        the number of files is unsupported.
    """
    if not files:
        return None, None, ""

    # Accept both plain path strings and wrappers that carry the path in .name.
    paths = [getattr(f, "name", f) for f in files]

    # Reject unsupported counts before touching the files on disk.
    if len(paths) > 2:
        return None, None, "Only support 1 or 2 images"

    imgs = [Image.open(path) for path in paths]
    # Shown unchanged when stitching fails.
    fallback = imgs[0]

    if len(imgs) == 1:
        # The stitched result has three channels, so the SSIM reference must
        # be RGB as well (uploads may be RGBA, palette or grayscale).
        original = fallback.convert("RGB")
        left, right, _ = split_image_diagonal(paths[0], min_overlap_pct=overlap_pct)
        left = jitter_image(
            left,
            angle=rot_limit,
            tx=trans_limit_X,
            ty=trans_limit_Y,
            perspective_coeffs=persp_limit,
            brightness_factor=bright_factor,
        )
    else:
        original = None
        # Normalise both inputs to RGB so their tensors can be blended.
        left, right = (image.convert("RGB") for image in imgs)

    stitched, feature_mapping, _num_matches = stitcher.stitch(left, right)
    matches_viz = stitcher.visualize_matches([left, right], feature_mapping)

    if stitched is None:
        return fallback, matches_viz, "Stitching failed"

    result_img = _stitched_to_pil(stitched)
    if original is None:
        return result_img, matches_viz, ""

    ssim_score, _ = calculate_ssim(original, stitched)
    return result_img, matches_viz, f"SSIM: {ssim_score:.4f}"
|
| 59 |
+
|
| 60 |
+
# Gradio interface: upload widget and jitter controls in the left column,
# stitched result, match visualisation and stats in the right column.
with gr.Blocks() as demo:
    gr.Markdown("## Automated Panorama Stitching")
    gr.Markdown("**1 Picture:** Split + Jitter + Stitch | **2 Pictures:** Direct Stitch")

    with gr.Row():
        with gr.Column(scale=1):
            input_files = gr.File(label="Pictures (1 or 2)", file_types=["image"], file_count="multiple")

            # gr.Group has no title argument (its options are keyword-only),
            # so the heading is rendered as a separate Markdown element.
            with gr.Group():
                gr.Markdown("**Jitter and Split**")
                rotation_slider = gr.Slider(0, 20, value=5, step=0.5, label="Rotation (°)")
                trans_slider_X = gr.Slider(0, 20, value=3, step=0.5, label="Translation X (px)")
                trans_slider_Y = gr.Slider(0, 20, value=3, step=0.5, label="Translation Y (px)")
                persp_slider = gr.Slider(0, 0.1, value=0.02, step=0.005, label="Perspective")
                bright_slider = gr.Slider(0.5, 1.5, value=1.0, step=0.05, label="Brightness")
                # The value is used as a fraction of the image width, not a percentage.
                overlap_slider = gr.Slider(0.05, 0.3, value=0.15, step=0.01, label="Overlap (fraction of width)")

            generate_btn = gr.Button("Stitch", variant="primary", size="lg")

        with gr.Column(scale=2):
            stitched_gallery = gr.Image(label="Result")
            matches_gallery = gr.Image(label="Keypoints & Matches")
            stats_md = gr.Markdown(label="Stats")

    # The input order must match the parameter order of process_images.
    generate_btn.click(
        fn=process_images,
        inputs=[input_files, rotation_slider, trans_slider_X, trans_slider_Y, persp_slider, bright_slider, overlap_slider],
        outputs=[stitched_gallery, matches_gallery, stats_md],
    )

demo.launch(debug=True)
|
app/requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
kornia
|
| 4 |
+
transformers
|
| 5 |
+
opencv-python
|
| 6 |
+
scikit-image
|
| 7 |
+
pillow
|
| 8 |
+
gradio
|
| 9 |
+
numpy
|
| 10 |
+
matplotlib
|
cv_utils/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoImageProcessor, AutoModel
|
| 3 |
+
from .stitching import feature_detection_mapping, stitch_images
|
| 4 |
+
from .metrics import calculate_ssim
|
| 5 |
+
|
| 6 |
+
class Stitcher:
    """Keypoint-matching model plus the two-image stitching pipeline.

    The processor and model are loaded once in the constructor and reused
    for every call to ``stitch``.
    """

    def __init__(self, model_name="ETH-CVG/lightglue_superpoint", device=None):
        """Load the matching model.

        Args:
            model_name: Hub identifier handed to ``from_pretrained``.
            device: Optional device (string or ``torch.device``). When
                omitted, CUDA is used if available, otherwise the CPU.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        print(f"Using device: {self.device}")
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.to(self.device)
        # The model is only used for inference.
        self.model.eval()

    def stitch(self, img0_pil, img1_pil):
        """Match keypoints between two images and stitch them.

        Returns:
            Tuple ``(stitched_image, feature_mapping, num_matches)``;
            ``stitched_image`` is ``None`` when stitching was not possible.
        """
        feature_mapping = feature_detection_mapping([img0_pil, img1_pil], self.processor, self.model, self.device)
        # Only one image pair is processed, so its result sits at index 0.
        stitched_image, num_matches = stitch_images(img0_pil, img1_pil, feature_mapping[0], self.device)
        return stitched_image, feature_mapping, num_matches

    def visualize_matches(self, images, feature_mapping):
        """Return the processor's match visualisation for the image pair.

        Falls back to the first input image when there is no mapping to draw.
        """
        if len(feature_mapping) == 0:
            return images[0]

        matches_viz = self.processor.visualize_keypoint_matching(images, feature_mapping)
        return matches_viz[0]
|
cv_utils/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (2.21 kB). View file
|
|
|
cv_utils/__pycache__/metrics.cpython-313.pyc
ADDED
|
Binary file (1.97 kB). View file
|
|
|
cv_utils/__pycache__/stitching.cpython-313.pyc
ADDED
|
Binary file (4.86 kB). View file
|
|
|
cv_utils/metrics.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
from skimage.metrics import structural_similarity as ssim
|
| 4 |
+
|
| 5 |
+
def _to_unit_rgb(array):
    """Return ``array`` as float32 of shape (H, W, 3) with values in [0, 1]."""
    if np.issubdtype(array.dtype, np.integer):
        # Integer images are assumed to be 8-bit. Deciding by dtype instead of
        # by max() keeps near-black images from being left unscaled.
        unit = array.astype(np.float32) / 255.0
    else:
        unit = array.astype(np.float32)
        # Floats clearly above 1 are treated as 0-255 data; the tolerance
        # keeps rounding overshoot (e.g. 1.0000001) from being divided by 255.
        if unit.size and unit.max() > 1.0 + 1e-3:
            unit = unit / 255.0
    unit = np.clip(unit, 0.0, 1.0)

    # Harmonise channel layout: grayscale (with or without alpha) is
    # replicated to three channels, a trailing alpha channel is dropped.
    if unit.ndim == 2:
        unit = unit[..., np.newaxis]
    if unit.shape[2] < 3:
        unit = np.repeat(unit[..., :1], 3, axis=2)
    return unit[..., :3]


def calculate_ssim(original_img, stitched_img, gray_scale=False):
    """Compute the SSIM between an original image and a stitched result.

    Args:
        original_img: Reference image (PIL image or array-like, H x W [x C]).
        stitched_img: Stitched result, either a torch tensor (channel-first
            when three-dimensional) or a PIL image / array-like.
        gray_scale: When true, both images are reduced to luminance before
            the comparison; otherwise SSIM is computed over three channels.

    Returns:
        Tuple ``(score, diff)`` as returned by ``structural_similarity`` with
        ``full=True``.
    """
    orig_np = np.array(original_img)

    if hasattr(stitched_img, "detach"):
        stitch_np = stitched_img.cpu().detach().numpy()
        if stitch_np.ndim == 3:
            # Channel-first tensor -> channel-last array.
            stitch_np = np.transpose(stitch_np, (1, 2, 0))
    else:
        stitch_np = np.array(stitched_img)

    # SSIM needs equally sized inputs; the stitched canvas is scaled to the
    # reference size (cv2.resize takes (width, height)).
    stitch_np = cv2.resize(stitch_np, (orig_np.shape[1], orig_np.shape[0]))

    orig_np = _to_unit_rgb(orig_np)
    stitch_np = _to_unit_rgb(stitch_np)

    if gray_scale:
        luma_weights = [0.2989, 0.5870, 0.1140]
        orig_np = np.dot(orig_np, luma_weights)
        stitch_np = np.dot(stitch_np, luma_weights)

    score, diff = ssim(orig_np, stitch_np, full=True, data_range=1.0, channel_axis=None if gray_scale else 2)

    return score, diff
|
cv_utils/stitching.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import kornia
|
| 3 |
+
import cv2
|
| 4 |
+
import torchvision.transforms as T
|
| 5 |
+
|
| 6 |
+
def feature_detection_mapping(images, processor, model, device, threshold=0.2):
    """Detect and match keypoints between the given images.

    Args:
        images: Images to match; each must expose ``.height`` and ``.width``.
        processor: Image processor providing preprocessing and
            ``post_process_keypoint_matching``.
        model: Matching model called with the processor's output.
        device: Device the preprocessed inputs are moved to.
        threshold: Threshold forwarded to
            ``post_process_keypoint_matching``; defaults to the previously
            hard-coded 0.2.

    Returns:
        The result of ``processor.post_process_keypoint_matching``. Callers
        in this project index it with ``[0]`` and read ``"keypoints0"`` and
        ``"keypoints1"`` from that entry.
    """
    inputs = processor(images, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # One inner list of (height, width) tuples describes one group of images.
    image_sizes = [[(image.height, image.width) for image in images]]
    return processor.post_process_keypoint_matching(outputs, image_sizes, threshold=threshold)
|
| 13 |
+
|
| 14 |
+
def stitch_images(img0_pil, img1_pil, output, device):
    """Warp the second image into the frame of the first and blend both.

    Args:
        img0_pil: Reference PIL image; its frame defines the output frame.
        img1_pil: PIL image that is warped onto the reference.
        output: Mapping with ``"keypoints0"`` and ``"keypoints1"`` tensors
            holding the matched points of image 0 and image 1 row by row.
        device: Device on which warping and blending are carried out.

    Returns:
        Tuple ``(stitched, num_matches)``. ``stitched`` is a channel-first
        RGB float tensor clamped to [0, 1], or ``None`` when there are fewer
        than four matches or no usable homography could be estimated.
    """
    pts0 = output["keypoints0"].float()
    pts1 = output["keypoints1"].float()
    num_matches = pts0.shape[0]

    # A homography needs at least four correspondences. Check this before
    # any image is converted or moved to the device.
    if num_matches < 4:
        return None, num_matches

    # Source points come from image 1 and destination points from image 0,
    # so the homography maps image-1 coordinates into the image-0 frame.
    H_np, _inlier_mask = cv2.findHomography(
        pts1.detach().cpu().numpy(),
        pts0.detach().cpu().numpy(),
        method=cv2.USAC_MAGSAC,
        ransacReprojThreshold=5.0,
        confidence=0.999,
        maxIters=100000,
    )
    if H_np is None:
        return None, num_matches

    # Force both inputs to three channels; RGBA or grayscale inputs would
    # otherwise produce tensors that cannot be added together.
    to_tensor = T.ToTensor()
    image0 = to_tensor(img0_pil.convert("RGB")).to(device)
    image1 = to_tensor(img1_pil.convert("RGB")).to(device)
    H = torch.from_numpy(H_np).to(device).float()

    _, h0, w0 = image0.shape
    _, h1, w1 = image1.shape

    # Project the corners of image 1 to find the extent of the joint canvas.
    corners1 = torch.tensor([[0., 0.], [float(w1), 0.], [float(w1), float(h1)], [0., float(h1)]], device=device)
    corners1_homo = torch.cat([corners1, torch.ones((4, 1), device=device)], dim=1).T
    warped_homo = H @ corners1_homo
    warped_corners1 = (warped_homo[:2] / warped_homo[2]).T

    corners0 = torch.tensor([[0., 0.], [float(w0), 0.], [float(w0), float(h0)], [0., float(h0)]], device=device)
    all_coords = torch.cat([warped_corners1, corners0], dim=0)

    # A degenerate homography can send a corner to infinity (division by a
    # zero homogeneous coordinate); treat that as a failed stitch.
    if not torch.isfinite(all_coords).all():
        return None, num_matches
    # NOTE(review): a near-degenerate homography can still request a very
    # large canvas here; consider capping out_size.

    min_xy = all_coords.min(dim=0).values
    max_xy = all_coords.max(dim=0).values

    # Shift everything so that the top-left of the joint extent is (0, 0).
    translation = torch.eye(3, device=device)
    translation[0, 2] = -min_xy[0]
    translation[1, 2] = -min_xy[1]

    H_final = translation @ H
    # dsize is given as (height, width).
    out_size = (int(max_xy[1] - min_xy[1]), int(max_xy[0] - min_xy[0]))

    warped0 = kornia.geometry.transform.warp_perspective(
        image0.unsqueeze(0), translation.unsqueeze(0), dsize=out_size, align_corners=True
    ).squeeze(0)

    warped1 = kornia.geometry.transform.warp_perspective(
        image1.unsqueeze(0), H_final.unsqueeze(0), dsize=out_size, align_corners=True
    ).squeeze(0)

    # Content-based coverage masks: (near-)black pixels count as "no data",
    # which also excludes black fill inside the input images themselves.
    mask0 = (warped0.abs().sum(dim=0, keepdim=True) > 1e-5).float()
    mask1 = (warped1.abs().sum(dim=0, keepdim=True) > 1e-5).float()

    # Average where both images contribute. The divisor is at least 1 so that
    # faint pixels below the mask threshold are not amplified by a division
    # by (almost) zero.
    coverage = (mask0 + mask1).clamp(min=1.0)
    stitched = ((warped0 + warped1) / coverage).clamp(0.0, 1.0)

    return stitched, num_matches
|
image_augmentation/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Image Augmentation Package
|
| 3 |
+
This package provides functions for image jittering and augmentation.
|
| 4 |
+
"""
|
| 5 |
+
from .jitter import (
|
| 6 |
+
apply_geometric_jitter,
|
| 7 |
+
apply_brightness_jitter_range,
|
| 8 |
+
apply_geometric_transform,
|
| 9 |
+
apply_brightness_jitter,
|
| 10 |
+
jitter_image_random,
|
| 11 |
+
jitter_image,
|
| 12 |
+
remove_alpha,
|
| 13 |
+
split_image_diagonal_random,
|
| 14 |
+
split_image_diagonal
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"apply_geometric_jitter",
|
| 19 |
+
"apply_brightness_jitter_range",
|
| 20 |
+
"apply_geometric_transform",
|
| 21 |
+
"apply_brightness_jitter",
|
| 22 |
+
"jitter_image_random",
|
| 23 |
+
"jitter_image",
|
| 24 |
+
"remove_alpha",
|
| 25 |
+
"split_image_diagonal_random",
|
| 26 |
+
"split_image_diagonal"
|
| 27 |
+
]
|
image_augmentation/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (651 Bytes). View file
|
|
|
image_augmentation/__pycache__/jitter.cpython-313.pyc
ADDED
|
Binary file (7.12 kB). View file
|
|
|
image_augmentation/jitter.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image, ImageDraw, ImageEnhance
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Range-dependent functions
|
| 6 |
+
def apply_geometric_jitter(image, rotation_limit=10.0, translation_limit=5, perspective_limit=0.02):
    """Randomly rotate, shift and slightly distort an image.

    Every parameter is a symmetric limit: the value actually applied is
    drawn uniformly from ``[-limit, +limit]``. The output has the same size
    as the input.
    """
    def _draw(limit):
        return random.uniform(-limit, limit)

    # The order of the random draws (angle, x shift, y shift, then the
    # transform coefficients) is fixed so that seeded runs stay reproducible.
    angle = _draw(rotation_limit)
    shift = (_draw(translation_limit), _draw(translation_limit))
    rotated = image.rotate(angle, resample=Image.BILINEAR, translate=shift)

    scale_x = 1 + _draw(perspective_limit)
    scale_y = 1 + _draw(perspective_limit)
    tilt_x = _draw(0.0001)
    tilt_y = _draw(0.0001)

    # Eight coefficients of the PERSPECTIVE transform: the two diagonal terms
    # scale the axes, the last two add a very small projective component.
    coefficients = [scale_x, 0, 0, 0, scale_y, 0, tilt_x, tilt_y]
    return rotated.transform(image.size, Image.PERSPECTIVE, coefficients, Image.BILINEAR)
|
| 25 |
+
|
| 26 |
+
def apply_brightness_jitter_range(image, jitter_range=(0.7, 1.3)):
    """Scale the brightness by a factor drawn uniformly from ``jitter_range``.

    ``jitter_range`` is a ``(low, high)`` pair of brightness factors.
    """
    brightness = ImageEnhance.Brightness(image)
    return brightness.enhance(random.uniform(*jitter_range))
|
| 30 |
+
|
| 31 |
+
def jitter_image_random(img, rot_limit=5, trans_limit=3, persp_limit=0.02, bright_factor=1.0, bright_range_delta=0.3):
    """Apply random geometric jitter followed by random brightness jitter.

    The brightness factor is drawn from
    ``bright_factor ± bright_range_delta``, with the window limited to
    ``[0.1, 2.0]``.
    """
    warped = apply_geometric_jitter(
        img,
        rotation_limit=rot_limit,
        translation_limit=trans_limit,
        perspective_limit=persp_limit,
    )

    # Keep the brightness window inside [0.1, 2.0].
    window = (
        max(0.1, bright_factor - bright_range_delta),
        min(2.0, bright_factor + bright_range_delta),
    )
    return apply_brightness_jitter_range(warped, jitter_range=window)
|
| 42 |
+
|
| 43 |
+
# Specific augmentations
|
| 44 |
+
def apply_geometric_transform(image, angle, tx, ty, perspective_coeffs):
    """Rotate, shift and rescale an image with fixed (non-random) values.

    Args:
        image: PIL image to transform.
        angle: Rotation angle handed to ``Image.rotate``.
        tx: Horizontal shift handed to ``Image.rotate`` via ``translate``.
        ty: Vertical shift handed to ``Image.rotate`` via ``translate``.
        perspective_coeffs: Single number added to 1 to form both diagonal
            coefficients of the PERSPECTIVE transform; all other
            coefficients are zero.

    Returns:
        The transformed image, same size as the input.
    """
    moved = image.rotate(angle, resample=Image.BILINEAR, translate=(tx, ty))
    diagonal = 1 + perspective_coeffs
    coefficients = [diagonal, 0, 0, 0, diagonal, 0, 0, 0]
    return moved.transform(image.size, Image.PERSPECTIVE, coefficients, Image.BILINEAR)
|
| 53 |
+
|
| 54 |
+
def apply_brightness_jitter(image, factor):
    """Return ``image`` with its brightness scaled by the fixed ``factor``."""
    return ImageEnhance.Brightness(image).enhance(factor)
|
| 57 |
+
|
| 58 |
+
def jitter_image(img, angle, tx, ty, perspective_coeffs, brightness_factor):
    """Apply a fixed geometric transform, then a fixed brightness change.

    The geometric arguments are forwarded to ``apply_geometric_transform``
    and ``brightness_factor`` to ``apply_brightness_jitter``.
    """
    transformed = apply_geometric_transform(img, angle, tx, ty, perspective_coeffs)
    return apply_brightness_jitter(transformed, brightness_factor)
|
| 65 |
+
|
| 66 |
+
# Split functions
|
| 67 |
+
def split_image_diagonal_random(image_path, min_overlap_pct=0.1, margin=50, max_slant=None):
    """Split an image along a random slanted line into two overlapping parts.

    The cut runs from a random x position on the top edge to a random x
    position on the bottom edge. Each part extends past the cut by the
    overlap width; everything outside a part is filled with black.

    Args:
        image_path: Path of the image to split.
        min_overlap_pct: Overlap on each side of the cut as a fraction of
            the image width.
        margin: Minimum distance of the cut end points from the left and
            right border, in pixels. It is reduced automatically for images
            narrower than twice the margin.
        max_slant: Maximum horizontal distance between the top and bottom
            end points of the cut. ``None`` (default) allows the full width.

    Returns:
        Tuple ``(left_rgb, right_rgb, full_rgba)``.
    """
    # convert() returns an independent, fully loaded copy, so the file can
    # be closed right away.
    with Image.open(image_path) as source:
        img = source.convert("RGBA")
    w, h = img.size

    # On narrow images the fixed margin would make the randint range empty.
    margin = min(margin, w // 2)
    if max_slant is None:
        max_slant = w

    top_x = random.randint(margin, w - margin)
    bottom_x = random.randint(max(margin, top_x - max_slant),
                              min(w - margin, top_x + max_slant))
    min_overlap = int(w * min_overlap_pct)

    # Left part: from the left border up to the cut line plus the overlap.
    mask_left = Image.new("L", (w, h), 0)
    ImageDraw.Draw(mask_left).polygon(
        [(0, 0), (top_x + min_overlap, 0), (bottom_x + min_overlap, h), (0, h)], fill=255
    )

    # Right part: from the cut line minus the overlap to the right border.
    mask_right = Image.new("L", (w, h), 0)
    ImageDraw.Draw(mask_right).polygon(
        [(top_x - min_overlap, 0), (w, 0), (w, h), (bottom_x - min_overlap, h)], fill=255
    )

    left_img = img.copy()
    left_img.putalpha(mask_left)

    right_img = img.copy()
    right_img.putalpha(mask_right)

    return remove_alpha(left_img), remove_alpha(right_img), img
|
| 100 |
+
|
| 101 |
+
def split_image_diagonal(image_path, min_overlap_pct):
    """Split an image at its horizontal centre into two overlapping parts.

    Despite the name, the cut is currently vertical: its top and bottom end
    points both sit at half the image width. Each part extends past the cut
    by the overlap width; everything outside a part is filled with black.

    Args:
        image_path: Path of the image to split.
        min_overlap_pct: Overlap on each side of the cut as a fraction of
            the image width.

    Returns:
        Tuple ``(left_rgb, right_rgb, full_rgba)``.
    """
    # convert() returns an independent, fully loaded copy, so the file can
    # be closed right away.
    with Image.open(image_path) as source:
        img = source.convert("RGBA")
    w, h = img.size

    top_x = w // 2
    bottom_x = w // 2
    min_overlap = int(w * min_overlap_pct)

    # Left part: from the left border up to the cut line plus the overlap.
    mask_left = Image.new("L", (w, h), 0)
    ImageDraw.Draw(mask_left).polygon(
        [(0, 0), (top_x + min_overlap, 0), (bottom_x + min_overlap, h), (0, h)], fill=255
    )

    # Right part: from the cut line minus the overlap to the right border.
    mask_right = Image.new("L", (w, h), 0)
    ImageDraw.Draw(mask_right).polygon(
        [(top_x - min_overlap, 0), (w, 0), (w, h), (bottom_x - min_overlap, h)], fill=255
    )

    left_img = img.copy()
    left_img.putalpha(mask_left)

    right_img = img.copy()
    right_img.putalpha(mask_right)

    return remove_alpha(left_img), remove_alpha(right_img), img
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def remove_alpha(img_rgba, bg_color=(0, 0, 0)):
    """Flatten an image onto a solid background and return it as RGB.

    Args:
        img_rgba: PIL image. Images in another mode are converted to RGBA
            first instead of failing on the missing alpha band.
        bg_color: Background colour used wherever the image is transparent.

    Returns:
        A new RGB image of the same size.
    """
    rgba = img_rgba if img_rgba.mode == "RGBA" else img_rgba.convert("RGBA")
    background = Image.new("RGB", rgba.size, bg_color)
    # The alpha band is the paste mask: transparent pixels keep bg_color.
    background.paste(rgba, mask=rgba.getchannel("A"))
    return background
|