AbstractPhil commited on
Commit
5cbc675
Β·
verified Β·
1 Parent(s): 0088abf

Create fractal_mix.py

Browse files
Files changed (1) hide show
  1. fractal_mix.py +109 -0
fractal_mix.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
2
+ # MIXING AUGMENTATIONS
3
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
4
+
5
+ def alphamix_data(x, y, alpha_range=(0.3, 0.7), spatial_ratio=0.25):
6
+ """
7
+ Standard AlphaMix: Single spatially localized transparent overlay.
8
+ """
9
+ batch_size = x.size(0)
10
+ index = torch.randperm(batch_size, device=x.device)
11
+
12
+ y_a, y_b = y, y[index]
13
+
14
+ # Sample alpha from Beta distribution
15
+ alpha_min, alpha_max = alpha_range
16
+ beta_sample = torch.distributions.Beta(2.0, 2.0).sample().item()
17
+ alpha = alpha_min + (alpha_max - alpha_min) * beta_sample
18
+
19
+ # Compute overlay region
20
+ _, _, H, W = x.shape
21
+ overlay_ratio = torch.sqrt(torch.tensor(spatial_ratio)).item()
22
+ overlay_h = int(H * overlay_ratio)
23
+ overlay_w = int(W * overlay_ratio)
24
+
25
+ top = torch.randint(0, H - overlay_h + 1, (1,), device=x.device).item()
26
+ left = torch.randint(0, W - overlay_w + 1, (1,), device=x.device).item()
27
+
28
+ # Blend
29
+ composited_x = x.clone()
30
+ overlay_region = alpha * x[:, :, top:top+overlay_h, left:left+overlay_w]
31
+ background_region = (1 - alpha) * x[index, :, top:top+overlay_h, left:left+overlay_w]
32
+ composited_x[:, :, top:top+overlay_h, left:left+overlay_w] = overlay_region + background_region
33
+
34
+ return composited_x, y_a, y_b, alpha
35
+
36
+
37
+ def alphamix_fractal(
38
+ x: torch.Tensor,
39
+ y: torch.Tensor,
40
+ alpha_range=(0.3, 0.7),
41
+ steps_range=(1, 3),
42
+ triad_scales=(1/3, 1/9, 1/27),
43
+ beta_shape=(2.0, 2.0),
44
+ seed: int | None = None,
45
+ ):
46
+ """
47
+ Fractal AlphaMix: Triadic multi-patch overlays aligned to Cantor geometry.
48
+ Pure torch, GPU-compatible.
49
+ """
50
+ if seed is not None:
51
+ torch.manual_seed(seed)
52
+
53
+ B, C, H, W = x.shape
54
+ device = x.device
55
+
56
+ # Permutation for mixing
57
+ idx = torch.randperm(B, device=device)
58
+ y_a, y_b = y, y[idx]
59
+
60
+ x_mix = x.clone()
61
+ total_area = H * W
62
+
63
+ # Beta distribution for transparency sampling
64
+ k1, k2 = beta_shape
65
+ beta_dist = torch.distributions.Beta(k1, k2)
66
+ alpha_min, alpha_max = alpha_range
67
+
68
+ # Storage for effective alpha calculation
69
+ alpha_elems = []
70
+ area_weights = []
71
+
72
+ # Sample number of patches (same for all images in batch)
73
+ steps = torch.randint(steps_range[0], steps_range[1] + 1, (1,), device=device).item()
74
+
75
+ for _ in range(steps):
76
+ # Choose triadic scale
77
+ scale_idx = torch.randint(0, len(triad_scales), (1,), device=device).item()
78
+ scale = triad_scales[scale_idx]
79
+
80
+ # Compute patch dimensions (triadic area)
81
+ patch_area = max(1, int(total_area * scale))
82
+ side = int(torch.sqrt(torch.tensor(patch_area, dtype=torch.float32)).item())
83
+ h = max(1, min(H, side))
84
+ w = max(1, min(W, side))
85
+
86
+ # Random position
87
+ top = torch.randint(0, H - h + 1, (1,), device=device).item()
88
+ left = torch.randint(0, W - w + 1, (1,), device=device).item()
89
+
90
+ # Sample transparency from Beta distribution
91
+ alpha_raw = beta_dist.sample().item()
92
+ alpha = alpha_min + (alpha_max - alpha_min) * alpha_raw
93
+
94
+ # Track for effective alpha
95
+ alpha_elems.append(alpha)
96
+ area_weights.append(h * w)
97
+
98
+ # Blend patches
99
+ fg = alpha * x[:, :, top:top + h, left:left + w]
100
+ bg = (1 - alpha) * x[idx, :, top:top + h, left:left + w]
101
+ x_mix[:, :, top:top + h, left:left + w] = fg + bg
102
+
103
+ # Compute area-weighted effective alpha
104
+ alpha_t = torch.tensor(alpha_elems, dtype=torch.float32, device=device)
105
+ area_t = torch.tensor(area_weights, dtype=torch.float32, device=device)
106
+ alpha_eff = (alpha_t * area_t).sum() / (area_t.sum() + 1e-12)
107
+ alpha_eff = alpha_eff.item()
108
+
109
+ return x_mix, y_a, y_b, alpha_eff