PHarder commited on
Commit
8a814b1
·
1 Parent(s): 5ead2c4
app.py DELETED
@@ -1,278 +0,0 @@
1
- import gradio as gr
2
- from PIL import Image, ImageDraw, ImageEnhance
3
- import torch
4
- import kornia
5
- from transformers import AutoImageProcessor, AutoModel
6
- import torchvision.transforms as T
7
- import numpy as np
8
- import cv2
9
- import random
10
- import glob
11
- from skimage.metrics import structural_similarity as ssim # Hinzugefügt für SSIM
12
-
13
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
- print(f"Using device: {device}")
15
-
16
- # Lade LightGlue-Modell (genau wie Jupyter)
17
- processor = AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint")
18
- model = AutoModel.from_pretrained("ETH-CVG/lightglue_superpoint")
19
- model.to(device)
20
-
21
- def draw_keypoints_and_matches(img_pil, keypoints, color=(0, 255, 0), radius=3):
22
- """Plots key points as colored circles on PIL image."""
23
- img_draw = ImageDraw.Draw(img_pil)
24
- kpts_np = keypoints.cpu().numpy()
25
- for kp in kpts_np:
26
- x, y = int(kp[0]), int(kp[1])
27
- img_draw.ellipse([x-radius, y-radius, x+radius, y+radius], outline=color, fill=color)
28
- return img_pil
29
-
30
- def visualize_matches(images, feature_mapping):
31
- """Visualization """
32
- if len(feature_mapping) == 0:
33
- return images[0]
34
-
35
- matches_viz = processor.visualize_keypoint_matching(images, feature_mapping)
36
- return matches_viz[0]
37
-
38
- def calculate_ssim(original_img, stitched_img, gray_scale=False):
39
- """
40
- Calculates the SSIM between the original PIL image and the stitched PIL image.
41
- """
42
- orig_np = np.array(original_img)
43
-
44
- stitch_np = stitched_img.cpu().detach().numpy()
45
- if stitch_np.ndim == 3:
46
- stitch_np = np.transpose(stitch_np, (1, 2, 0))
47
- # ensure imgs to have the same size
48
- stitch_np = cv2.resize(stitch_np, (orig_np.shape[1], orig_np.shape[0]))
49
-
50
- print(orig_np.shape, stitch_np.shape)
51
-
52
- # ensure the data types and ranges match
53
- orig_np = orig_np.astype(np.float32) / 255.0 if orig_np.max() > 1.0 else orig_np.astype(np.float32)
54
- stitch_np = stitch_np.astype(np.float32) / 255.0 if stitch_np.max() > 1.0 else stitch_np.astype(np.float32)
55
-
56
- # conversion to gray scale
57
- orig_np = np.dot(orig_np[..., :3], [0.2989, 0.5870, 0.1140]) if gray_scale else orig_np
58
- stitch_np = np.dot(stitch_np[..., :3], [0.2989, 0.5870, 0.1140]) if gray_scale else stitch_np
59
-
60
- score, diff = ssim(orig_np, stitch_np, full=True, data_range=1.0, channel_axis=None if gray_scale else 2)
61
-
62
- return score, diff
63
-
64
- def stitch_images(img0_pil, img1_pil, output, device=device):
65
- to_tensor = T.ToTensor()
66
- image0 = to_tensor(img0_pil).to(device)
67
- image1 = to_tensor(img1_pil).to(device)
68
-
69
- pts0 = output["keypoints0"].float()
70
- pts1 = output["keypoints1"].float()
71
-
72
- print(f"DEBUG: {pts0.shape[0]} Matches found!")
73
-
74
- if pts0.shape[0] < 4:
75
- print("Insufficient matches for homography.")
76
- return None
77
-
78
- p0_np = pts0.detach().cpu().numpy()
79
- p1_np = pts1.detach().cpu().numpy()
80
-
81
- H_np, mask = cv2.findHomography(
82
- p1_np,
83
- p0_np,
84
- method=cv2.USAC_MAGSAC,
85
- ransacReprojThreshold=5.0,
86
- confidence=0.999,
87
- maxIters=1000
88
- )
89
-
90
- if H_np is None:
91
- print("Homography estimation failed.")
92
- return None
93
-
94
- H = torch.from_numpy(H_np).to(device).float()
95
-
96
- c, h0, w0 = image0.shape
97
- _, h1, w1 = image1.shape
98
-
99
- corners1 = torch.tensor([[0., 0.], [float(w1), 0.], [float(w1), float(h1)], [0., float(h1)]], device=device)
100
- corners1_homo = torch.cat([corners1, torch.ones((4, 1), device=device)], dim=1).T
101
- warped_homo = H @ corners1_homo
102
- warped_corners1 = (warped_homo[:2] / warped_homo[2]).T
103
-
104
- all_coords = torch.cat([
105
- warped_corners1,
106
- torch.tensor([[0., 0.], [float(w0), 0.], [float(w0), float(h0)], [0., float(h0)]], device=device)
107
- ], dim=0)
108
-
109
- min_xy = all_coords.min(dim=0).values
110
- max_xy = all_coords.max(dim=0).values
111
-
112
- translation = torch.eye(3, device=device)
113
- translation[0, 2] = -min_xy[0]
114
- translation[1, 2] = -min_xy[1]
115
-
116
- H_final = translation @ H
117
- out_size = (int(max_xy[1] - min_xy[1]), int(max_xy[0] - min_xy[0]))
118
-
119
- warped0 = kornia.geometry.transform.warp_perspective(
120
- image0.unsqueeze(0), translation.unsqueeze(0), dsize=out_size, align_corners=True
121
- ).squeeze(0)
122
-
123
- warped1 = kornia.geometry.transform.warp_perspective(
124
- image1.unsqueeze(0), H_final.unsqueeze(0), dsize=out_size, align_corners=True
125
- ).squeeze(0)
126
-
127
- mask0 = (warped0.abs().sum(dim=0, keepdim=True) > 1e-5).float()
128
- mask1 = (warped1.abs().sum(dim=0, keepdim=True) > 1e-5).float()
129
-
130
- stitched = (warped0 + warped1) / (mask0 + mask1 + 1e-8)
131
-
132
- return stitched
133
-
134
- def feature_detection_mapping(images):
135
- inputs = processor(images, return_tensors="pt").to(device)
136
- with torch.no_grad():
137
- outputs = model(**inputs)
138
- image_sizes = [[(image.height, image.width) for image in images]]
139
- outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
140
- return outputs
141
-
142
- def apply_geometric_jitter(image, rotation_limit=10.0, translation_limit=5, perspective_limit=0.02):
143
- width, height = image.size
144
-
145
- angle = random.uniform(-rotation_limit, rotation_limit)
146
- tx = random.uniform(-translation_limit, translation_limit)
147
- ty = random.uniform(-translation_limit, translation_limit)
148
-
149
- img = image.rotate(angle, resample=Image.BILINEAR, translate=(tx, ty))
150
-
151
- coeffs = [
152
- 1 + random.uniform(-perspective_limit, perspective_limit), 0, 0,
153
- 0, 1 + random.uniform(-perspective_limit, perspective_limit), 0,
154
- random.uniform(-0.0001, 0.0001), random.uniform(-0.0001, 0.0001)
155
- ]
156
-
157
- return img.transform((width, height), Image.PERSPECTIVE, coeffs, Image.BILINEAR)
158
-
159
- def apply_brightness_jitter(image, jitter_range=(0.7, 1.3)):
160
- enhancer = ImageEnhance.Brightness(image)
161
- factor = random.uniform(*jitter_range)
162
- return enhancer.enhance(factor)
163
-
164
- def remove_alpha(img_rgba, bg_color=(0, 0, 0)):
165
- background = Image.new("RGB", img_rgba.size, bg_color)
166
- background.paste(img_rgba, mask=img_rgba.split()[3])
167
- return background
168
-
169
- def split_image_variable_diagonal(image_path, min_overlap_pct=0.1):
170
- img = Image.open(image_path).convert("RGBA")
171
- w, h = img.size
172
- margin = 50
173
-
174
- top_x = random.randint(margin, w - margin)
175
-
176
- max_slant = w
177
- min_overlap = int(w * min_overlap_pct)
178
- bottom_x = random.randint(max(margin, top_x - max_slant),
179
- min(w - margin, top_x + max_slant))
180
-
181
- mask_left = Image.new("L", (w, h), 0)
182
- draw_l = ImageDraw.Draw(mask_left)
183
- draw_l.polygon([(0, 0), (top_x + min_overlap, 0), (bottom_x + min_overlap, h), (0, h)], fill=255)
184
-
185
- mask_right = Image.new("L", (w, h), 0)
186
- draw_r = ImageDraw.Draw(mask_right)
187
- draw_r.polygon([(top_x - min_overlap, 0), (w, 0), (w, h), (bottom_x - min_overlap, h)], fill=255)
188
-
189
- left_img = img.copy()
190
- left_img.putalpha(mask_left)
191
-
192
- right_img = img.copy()
193
- right_img.putalpha(mask_right)
194
-
195
- cropped_left = remove_alpha(left_img)
196
- cropped_right = remove_alpha(right_img)
197
-
198
- return cropped_left, cropped_right, img
199
-
200
- def process_images(files, rot_limit, trans_limit, persp_limit, bright_factor, overlap_pct):
201
- if files is None or len(files) == 0:
202
- return None, None, ""
203
-
204
- imgs = [Image.open(f.name) for f in files if True]
205
- if len(imgs) == 0:
206
- return None, None, ""
207
-
208
- def jitter_image(img):
209
- img = apply_geometric_jitter(img, rotation_limit=rot_limit, translation_limit=trans_limit, perspective_limit=persp_limit)
210
- img = apply_brightness_jitter(img, jitter_range=(max(0.1, bright_factor-0.3), min(2.0, bright_factor+0.3)))
211
- return img
212
-
213
- if len(imgs) == 1:
214
- img = imgs[0]
215
- original = img.copy()
216
- f_path = files[0].name
217
-
218
- left, right, _ = split_image_variable_diagonal(f_path, min_overlap_pct=overlap_pct)
219
- left = jitter_image(left)
220
-
221
- images_to_stitch = [left, right]
222
- feature_mapping = feature_detection_mapping(images_to_stitch)
223
- matches_viz = visualize_matches(images_to_stitch, feature_mapping)
224
-
225
- stitched = stitch_images(left, right, feature_mapping[0], device=device)
226
- if stitched is not None:
227
- result_img = Image.fromarray((stitched.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
228
- ssim_score, _ = calculate_ssim(original, stitched)
229
- num_matches = (feature_mapping[0].get("matches0", torch.tensor([])) > -1).sum().item()
230
- return result_img, matches_viz, f"SSIM: {ssim_score:.4f}"
231
- else:
232
- jittered_original = jitter_image(img)
233
- return jittered_original, matches_viz, "Stitching failed"
234
-
235
- elif len(imgs) == 2:
236
- img1, img2 = imgs[0], imgs[1]
237
- feature_mapping = feature_detection_mapping([img1, img2])
238
- matches_viz = visualize_matches([img1, img2], feature_mapping)
239
-
240
- stitched = stitch_images(img1, img2, feature_mapping[0], device=device)
241
- if stitched is not None:
242
- result_img = Image.fromarray((stitched.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
243
- return result_img, matches_viz, ""
244
- else:
245
- return img1, matches_viz, "Stitching failed"
246
-
247
- return None, None, "Only support 1 or 2 images"
248
-
249
- # Gradio Interface
250
- with gr.Blocks() as demo:
251
- gr.Markdown("## Automated Panorama Stitching")
252
- gr.Markdown("**1 Picture:** Split + Jitter + Stitch | **2 Pictures:** Direct Stitch")
253
-
254
- with gr.Row():
255
- with gr.Column(scale=1):
256
- input_files = gr.File(label="Pictures (1 or 2)", file_types=["image"], file_count="multiple")
257
-
258
- with gr.Group("Jitter and Split"):
259
- rotation_slider = gr.Slider(0, 20, value=5, step=0.5, label="Rotation (°)")
260
- trans_slider = gr.Slider(0, 20, value=3, step=0.5, label="Translation (px)")
261
- persp_slider = gr.Slider(0, 0.1, value=0.02, step=0.005, label="Perspective")
262
- bright_slider = gr.Slider(0.5, 1.5, value=1.0, step=0.05, label="Brightness")
263
- overlap_slider = gr.Slider(0.05, 0.3, value=0.15, step=0.01, label="Overlap (%)")
264
-
265
- generate_btn = gr.Button("Stitch", variant="primary", size="lg")
266
-
267
- with gr.Column(scale=2):
268
- stitched_gallery = gr.Image(label="Result")
269
- matches_gallery = gr.Image(label="Keypoints & Matches")
270
- stats_md = gr.Markdown(label="Stats")
271
-
272
- generate_btn.click(
273
- fn=process_images,
274
- inputs=[input_files, rotation_slider, trans_slider, persp_slider, bright_slider, overlap_slider],
275
- outputs=[stitched_gallery, matches_gallery, stats_md]
276
- )
277
-
278
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ mit ```python app.py``` lokal starten
2
+
3
+ oder hier https://huggingface.co/spaces/PHarder/Automated_Panorama_Stitching testen
app/app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4
+ from image_augmentation import jitter_image, split_image_diagonal
5
+ from cv_utils import Stitcher, calculate_ssim
6
+
7
+ from PIL import Image
8
+ import numpy as np
9
+ import gradio as gr
10
+
11
+ stitcher = Stitcher()
12
+
13
+ def process_images(files, rot_limit, trans_limit_X, trans_limit_Y, persp_limit, bright_factor, overlap_pct):
14
+ if files is None or len(files) == 0:
15
+ return None, None, ""
16
+
17
+ imgs = [Image.open(f.name) for f in files if True]
18
+ if len(imgs) == 0:
19
+ return None, None, ""
20
+
21
+ if len(imgs) == 1:
22
+ img = imgs[0]
23
+ original = img.copy()
24
+ f_path = files[0].name
25
+
26
+ left, right, _ = split_image_diagonal(f_path, min_overlap_pct=overlap_pct)
27
+ left = jitter_image(
28
+ left,
29
+ angle=rot_limit,
30
+ tx=trans_limit_X,
31
+ ty=trans_limit_Y,
32
+ perspective_coeffs=persp_limit,
33
+ brightness_factor=bright_factor
34
+ )
35
+
36
+ images_to_stitch = [left, right]
37
+ stitched, feature_mapping, num_matches = stitcher.stitch(left, right)
38
+ matches_viz = stitcher.visualize_matches(images_to_stitch, feature_mapping)
39
+
40
+ if stitched is not None:
41
+ result_img = Image.fromarray((stitched.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
42
+ ssim_score, _ = calculate_ssim(original, stitched)
43
+ return result_img, matches_viz, f"SSIM: {ssim_score:.4f}"
44
+ else:
45
+ return img, matches_viz, "Stitching failed"
46
+
47
+ elif len(imgs) == 2:
48
+ img1, img2 = imgs[0], imgs[1]
49
+ stitched, feature_mapping, num_matches = stitcher.stitch(img1, img2)
50
+ matches_viz = stitcher.visualize_matches([img1, img2], feature_mapping)
51
+
52
+ if stitched is not None:
53
+ result_img = Image.fromarray((stitched.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
54
+ return result_img, matches_viz, ""
55
+ else:
56
+ return img1, matches_viz, "Stitching failed"
57
+
58
+ return None, None, "Only support 1 or 2 images"
59
+
60
+ # Gradio Interface
61
+ with gr.Blocks() as demo:
62
+ gr.Markdown("## Automated Panorama Stitching")
63
+ gr.Markdown("**1 Picture:** Split + Jitter + Stitch | **2 Pictures:** Direct Stitch")
64
+
65
+ with gr.Row():
66
+ with gr.Column(scale=1):
67
+ input_files = gr.File(label="Pictures (1 or 2)", file_types=["image"], file_count="multiple")
68
+
69
+ with gr.Group("Jitter and Split"):
70
+ rotation_slider = gr.Slider(0, 20, value=5, step=0.5, label="Rotation (°)")
71
+ trans_slider_X = gr.Slider(0, 20, value=3, step=0.5, label="Translation X (px)")
72
+ trans_slider_Y = gr.Slider(0, 20, value=3, step=0.5, label="Translation Y (px)")
73
+ persp_slider = gr.Slider(0, 0.1, value=0.02, step=0.005, label="Perspective")
74
+ bright_slider = gr.Slider(0.5, 1.5, value=1.0, step=0.05, label="Brightness")
75
+ overlap_slider = gr.Slider(0.05, 0.3, value=0.15, step=0.01, label="Overlap (%)")
76
+
77
+ generate_btn = gr.Button("Stitch", variant="primary", size="lg")
78
+
79
+ with gr.Column(scale=2):
80
+ stitched_gallery = gr.Image(label="Result")
81
+ matches_gallery = gr.Image(label="Keypoints & Matches")
82
+ stats_md = gr.Markdown(label="Stats")
83
+
84
+ generate_btn.click(
85
+ fn=process_images,
86
+ inputs=[input_files, rotation_slider, trans_slider_X, trans_slider_Y, persp_slider, bright_slider, overlap_slider],
87
+ outputs=[stitched_gallery, matches_gallery, stats_md]
88
+ )
89
+
90
+ demo.launch(debug=True)
app/requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ kornia
4
+ transformers
5
+ opencv-python
6
+ scikit-image
7
+ pillow
8
+ gradio
9
+ numpy
10
+ matplotlib
cv_utils/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoImageProcessor, AutoModel
3
+ from .stitching import feature_detection_mapping, stitch_images
4
+ from .metrics import calculate_ssim
5
+
6
+ class Stitcher:
7
+ def __init__(self, model_name="ETH-CVG/lightglue_superpoint"):
8
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
+ print(f"Using device: {self.device}")
10
+ self.processor = AutoImageProcessor.from_pretrained(model_name)
11
+ self.model = AutoModel.from_pretrained(model_name)
12
+ self.model.to(self.device)
13
+
14
+ def stitch(self, img0_pil, img1_pil):
15
+ feature_mapping = feature_detection_mapping([img0_pil, img1_pil], self.processor, self.model, self.device)
16
+ stitched_image, num_matches = stitch_images(img0_pil, img1_pil, feature_mapping[0], self.device)
17
+ return stitched_image, feature_mapping, num_matches
18
+
19
+ def visualize_matches(self, images, feature_mapping):
20
+ """Visualization """
21
+ if len(feature_mapping) == 0:
22
+ return images[0]
23
+
24
+ matches_viz = self.processor.visualize_keypoint_matching(images, feature_mapping)
25
+ return matches_viz[0]
cv_utils/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (2.21 kB). View file
 
cv_utils/__pycache__/metrics.cpython-313.pyc ADDED
Binary file (1.97 kB). View file
 
cv_utils/__pycache__/stitching.cpython-313.pyc ADDED
Binary file (4.86 kB). View file
 
cv_utils/metrics.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ from skimage.metrics import structural_similarity as ssim
4
+
5
+ def calculate_ssim(original_img, stitched_img, gray_scale=False):
6
+ """
7
+ Calculates the SSIM between the original PIL image and the stitched PIL image.
8
+ """
9
+ orig_np = np.array(original_img)
10
+
11
+ stitch_np = stitched_img.cpu().detach().numpy()
12
+ if stitch_np.ndim == 3:
13
+ stitch_np = np.transpose(stitch_np, (1, 2, 0))
14
+ # ensure imgs to have the same size
15
+ stitch_np = cv2.resize(stitch_np, (orig_np.shape[1], orig_np.shape[0]))
16
+
17
+ # ensure the data types and ranges match
18
+ orig_np = orig_np.astype(np.float32) / 255.0 if orig_np.max() > 1.0 else orig_np.astype(np.float32)
19
+ stitch_np = stitch_np.astype(np.float32) / 255.0 if stitch_np.max() > 1.0 else stitch_np.astype(np.float32)
20
+
21
+ # conversion to gray scale
22
+ orig_np = np.dot(orig_np[..., :3], [0.2989, 0.5870, 0.1140]) if gray_scale else orig_np
23
+ stitch_np = np.dot(stitch_np[..., :3], [0.2989, 0.5870, 0.1140]) if gray_scale else stitch_np
24
+
25
+ score, diff = ssim(orig_np, stitch_np, full=True, data_range=1.0, channel_axis=None if gray_scale else 2)
26
+
27
+ return score, diff
cv_utils/stitching.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import kornia
3
+ import cv2
4
+ import torchvision.transforms as T
5
+
6
+ def feature_detection_mapping(images, processor, model, device):
7
+ inputs = processor(images, return_tensors="pt").to(device)
8
+ with torch.no_grad():
9
+ outputs = model(**inputs)
10
+ image_sizes = [[(image.height, image.width) for image in images]]
11
+ outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
12
+ return outputs
13
+
14
+ def stitch_images(img0_pil, img1_pil, output, device):
15
+ to_tensor = T.ToTensor()
16
+ image0 = to_tensor(img0_pil).to(device)
17
+ image1 = to_tensor(img1_pil).to(device)
18
+
19
+ pts0 = output["keypoints0"].float()
20
+ pts1 = output["keypoints1"].float()
21
+
22
+ num_matches = pts0.shape[0]
23
+
24
+ if num_matches < 4:
25
+ return None, num_matches
26
+
27
+ p0_np = pts0.detach().cpu().numpy()
28
+ p1_np = pts1.detach().cpu().numpy()
29
+
30
+ H_np, mask = cv2.findHomography(
31
+ p1_np,
32
+ p0_np,
33
+ method=cv2.USAC_MAGSAC,
34
+ ransacReprojThreshold=5.0,
35
+ confidence=0.999,
36
+ maxIters=100000
37
+ )
38
+
39
+ if H_np is None:
40
+ return None, num_matches
41
+
42
+ H = torch.from_numpy(H_np).to(device).float()
43
+
44
+ _, h0, w0 = image0.shape
45
+ _, h1, w1 = image1.shape
46
+
47
+ corners1 = torch.tensor([[0., 0.], [float(w1), 0.], [float(w1), float(h1)], [0., float(h1)]], device=device)
48
+ corners1_homo = torch.cat([corners1, torch.ones((4, 1), device=device)], dim=1).T
49
+ warped_homo = H @ corners1_homo
50
+ warped_corners1 = (warped_homo[:2] / warped_homo[2]).T
51
+
52
+ all_coords = torch.cat([
53
+ warped_corners1,
54
+ torch.tensor([[0., 0.], [float(w0), 0.], [float(w0), float(h0)], [0., float(h0)]], device=device)
55
+ ], dim=0)
56
+
57
+ min_xy = all_coords.min(dim=0).values
58
+ max_xy = all_coords.max(dim=0).values
59
+
60
+ translation = torch.eye(3, device=device)
61
+ translation[0, 2] = -min_xy[0]
62
+ translation[1, 2] = -min_xy[1]
63
+
64
+ H_final = translation @ H
65
+ out_size = (int(max_xy[1] - min_xy[1]), int(max_xy[0] - min_xy[0]))
66
+
67
+ warped0 = kornia.geometry.transform.warp_perspective(
68
+ image0.unsqueeze(0), translation.unsqueeze(0), dsize=out_size, align_corners=True
69
+ ).squeeze(0)
70
+
71
+ warped1 = kornia.geometry.transform.warp_perspective(
72
+ image1.unsqueeze(0), H_final.unsqueeze(0), dsize=out_size, align_corners=True
73
+ ).squeeze(0)
74
+
75
+ mask0 = (warped0.abs().sum(dim=0, keepdim=True) > 1e-5).float()
76
+ mask1 = (warped1.abs().sum(dim=0, keepdim=True) > 1e-5).float()
77
+
78
+ stitched = (warped0 + warped1) / (mask0 + mask1 + 1e-8)
79
+
80
+ return stitched, num_matches
image_augmentation/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image Augmentation Package
3
+ This package provides functions for image jittering and augmentation.
4
+ """
5
+ from .jitter import (
6
+ apply_geometric_jitter,
7
+ apply_brightness_jitter_range,
8
+ apply_geometric_transform,
9
+ apply_brightness_jitter,
10
+ jitter_image_random,
11
+ jitter_image,
12
+ remove_alpha,
13
+ split_image_diagonal_random,
14
+ split_image_diagonal
15
+ )
16
+
17
+ __all__ = [
18
+ "apply_geometric_jitter",
19
+ "apply_brightness_jitter_range",
20
+ "apply_geometric_transform",
21
+ "apply_brightness_jitter",
22
+ "jitter_image_random",
23
+ "jitter_image",
24
+ "remove_alpha",
25
+ "split_image_diagonal_random",
26
+ "split_image_diagonal"
27
+ ]
image_augmentation/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (651 Bytes). View file
 
image_augmentation/__pycache__/jitter.cpython-313.pyc ADDED
Binary file (7.12 kB). View file
 
image_augmentation/jitter.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageDraw, ImageEnhance
2
+ import random
3
+
4
+
5
+ # Range depended functions
6
+ def apply_geometric_jitter(image, rotation_limit=10.0, translation_limit=5, perspective_limit=0.02):
7
+ """
8
+ Applies geometric jitter to an image, including rotation, translation, and perspective shifts using random values within limits.
9
+ """
10
+ width, height = image.size
11
+
12
+ angle = random.uniform(-rotation_limit, rotation_limit)
13
+ tx = random.uniform(-translation_limit, translation_limit)
14
+ ty = random.uniform(-translation_limit, translation_limit)
15
+
16
+ img = image.rotate(angle, resample=Image.BILINEAR, translate=(tx, ty))
17
+
18
+ coeffs = [
19
+ 1 + random.uniform(-perspective_limit, perspective_limit), 0, 0,
20
+ 0, 1 + random.uniform(-perspective_limit, perspective_limit), 0,
21
+ random.uniform(-0.0001, 0.0001), random.uniform(-0.0001, 0.0001)
22
+ ]
23
+
24
+ return img.transform((width, height), Image.PERSPECTIVE, coeffs, Image.BILINEAR)
25
+
26
+ def apply_brightness_jitter_range(image, jitter_range=(0.7, 1.3)):
27
+ enhancer = ImageEnhance.Brightness(image)
28
+ factor = random.uniform(*jitter_range)
29
+ return enhancer.enhance(factor)
30
+
31
+ def jitter_image_random(img, rot_limit=5, trans_limit=3, persp_limit=0.02, bright_factor=1.0, bright_range_delta=0.3):
32
+ """
33
+ Applies a combination of geometric and brightness jitter to an image.
34
+ """
35
+ img = apply_geometric_jitter(img, rotation_limit=rot_limit, translation_limit=trans_limit, perspective_limit=persp_limit)
36
+
37
+ brightness_min = max(0.1, bright_factor - bright_range_delta)
38
+ brightness_max = min(2.0, bright_factor + bright_range_delta)
39
+
40
+ img = apply_brightness_jitter_range(img, jitter_range=(brightness_min, brightness_max))
41
+ return img
42
+
43
+ # Specific augmentations
44
+ def apply_geometric_transform(image, angle, tx, ty, perspective_coeffs):
45
+ width, height = image.size
46
+ img = image.rotate(angle, resample=Image.BILINEAR, translate=(tx, ty))
47
+ coeffs = [
48
+ 1 + perspective_coeffs, 0, 0,
49
+ 0, 1 + perspective_coeffs, 0,
50
+ 0, 0
51
+ ]
52
+ return img.transform((width, height), Image.PERSPECTIVE, coeffs, Image.BILINEAR)
53
+
54
+ def apply_brightness_jitter(image, factor):
55
+ enhancer = ImageEnhance.Brightness(image)
56
+ return enhancer.enhance(factor)
57
+
58
+ def jitter_image(img, angle, tx, ty, perspective_coeffs, brightness_factor):
59
+ """
60
+ Applies a specific combination of geometric and brightness jitter to an image.
61
+ """
62
+ img = apply_geometric_transform(img, angle, tx, ty, perspective_coeffs)
63
+ img = apply_brightness_jitter(img, brightness_factor)
64
+ return img
65
+
66
+ # Split functions
67
+ def split_image_diagonal_random(image_path, min_overlap_pct=0.1):
68
+ """
69
+ Splits an image diagonally into two parts, including a certain overlap.
70
+ """
71
+ img = Image.open(image_path).convert("RGBA")
72
+ w, h = img.size
73
+ margin = 50
74
+
75
+ top_x = random.randint(margin, w - margin)
76
+
77
+ max_slant = w
78
+ min_overlap = int(w * min_overlap_pct)
79
+ bottom_x = random.randint(max(margin, top_x - max_slant),
80
+ min(w - margin, top_x + max_slant))
81
+
82
+ mask_left = Image.new("L", (w, h), 0)
83
+ draw_l = ImageDraw.Draw(mask_left)
84
+ draw_l.polygon([(0, 0), (top_x + min_overlap, 0), (bottom_x + min_overlap, h), (0, h)], fill=255)
85
+
86
+ mask_right = Image.new("L", (w, h), 0)
87
+ draw_r = ImageDraw.Draw(mask_right)
88
+ draw_r.polygon([(top_x - min_overlap, 0), (w, 0), (w, h), (bottom_x - min_overlap, h)], fill=255)
89
+
90
+ left_img = img.copy()
91
+ left_img.putalpha(mask_left)
92
+
93
+ right_img = img.copy()
94
+ right_img.putalpha(mask_right)
95
+
96
+ cropped_left = remove_alpha(left_img)
97
+ cropped_right = remove_alpha(right_img)
98
+
99
+ return cropped_left, cropped_right, img
100
+
101
+ def split_image_diagonal(image_path, min_overlap_pct):
102
+ """
103
+ Splits an image diagonally into two parts.
104
+ """
105
+ img = Image.open(image_path).convert("RGBA")
106
+ w, h = img.size
107
+ margin = 50
108
+
109
+ top_x = w // 2
110
+ bottom_x = w // 2
111
+
112
+ min_overlap = int(w * min_overlap_pct)
113
+
114
+ mask_left = Image.new("L", (w, h), 0)
115
+ draw_l = ImageDraw.Draw(mask_left)
116
+ draw_l.polygon([(0, 0), (top_x + min_overlap, 0), (bottom_x + min_overlap, h), (0, h)], fill=255)
117
+
118
+ mask_right = Image.new("L", (w, h), 0)
119
+ draw_r = ImageDraw.Draw(mask_right)
120
+ draw_r.polygon([(top_x - min_overlap, 0), (w, 0), (w, h), (bottom_x - min_overlap, h)], fill=255)
121
+
122
+ left_img = img.copy()
123
+ left_img.putalpha(mask_left)
124
+
125
+ right_img = img.copy()
126
+ right_img.putalpha(mask_right)
127
+
128
+ cropped_left = remove_alpha(left_img)
129
+ cropped_right = remove_alpha(right_img)
130
+
131
+ return cropped_left, cropped_right, img
132
+
133
+
134
+ def remove_alpha(img_rgba, bg_color=(0, 0, 0)):
135
+ """
136
+ Removes the alpha channel from an RGBA image and replaces it with a solid background color.
137
+ """
138
+ background = Image.new("RGB", img_rgba.size, bg_color)
139
+ background.paste(img_rgba, mask=img_rgba.split()[3])
140
+ return background