Create fractal_mix.py
Browse files- fractal_mix.py +109 -0
fractal_mix.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ───────────────────────────────────────────────────────────────────────────
|
| 2 |
+
# MIXING AUGMENTATIONS
|
| 3 |
+
# ───────────────────────────────────────────────────────────────────────────
|
| 4 |
+
|
| 5 |
+
def alphamix_data(x, y, alpha_range=(0.3, 0.7), spatial_ratio=0.25):
    """
    Standard AlphaMix: Single spatially localized transparent overlay.

    One rectangular window, shared by every sample in the batch, is replaced
    by ``alpha * x + (1 - alpha) * x[perm]`` where ``perm`` is a random
    permutation of the batch. Pixels outside the window are copied from ``x``
    unchanged; ``x`` itself is not modified.

    Args:
        x: 4-D tensor; dim 0 is the batch and the last two dims are treated
            as height and width.
        y: Labels, indexed along dim 0 by the permutation tensor.
        alpha_range: ``(min, max)`` interval into which a Beta(2, 2) draw is
            linearly rescaled to obtain ``alpha``.
        spatial_ratio: Fraction of the image area the window should cover.
            Each window side is ``sqrt(spatial_ratio)`` times the matching
            image side, truncated to an int.

    Returns:
        Tuple ``(composited_x, y_a, y_b, alpha)``: the blended batch, the
        original labels, the permuted labels, and the Python float weight
        applied to ``x`` inside the window.
    """
    # Function-scope import keeps this block self-contained.
    import math

    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)

    y_a, y_b = y, y[index]

    # Sample alpha from Beta distribution and rescale into alpha_range
    alpha_min, alpha_max = alpha_range
    beta_sample = torch.distributions.Beta(2.0, 2.0).sample().item()
    alpha = alpha_min + (alpha_max - alpha_min) * beta_sample

    # Compute overlay region.
    # The side factor is computed in double precision: a float32 square root
    # (e.g. sqrt(0.49) -> 0.69999998...) makes the int() truncation below
    # come out one pixel short of the requested size.
    _, _, H, W = x.shape
    overlay_ratio = math.sqrt(spatial_ratio)
    overlay_h = int(H * overlay_ratio)
    overlay_w = int(W * overlay_ratio)

    top = torch.randint(0, H - overlay_h + 1, (1,), device=x.device).item()
    left = torch.randint(0, W - overlay_w + 1, (1,), device=x.device).item()

    # Blend
    composited_x = x.clone()
    overlay_region = alpha * x[:, :, top:top + overlay_h, left:left + overlay_w]
    background_region = (1 - alpha) * x[index, :, top:top + overlay_h, left:left + overlay_w]
    composited_x[:, :, top:top + overlay_h, left:left + overlay_w] = overlay_region + background_region

    return composited_x, y_a, y_b, alpha
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def alphamix_fractal(
|
| 38 |
+
x: torch.Tensor,
|
| 39 |
+
y: torch.Tensor,
|
| 40 |
+
alpha_range=(0.3, 0.7),
|
| 41 |
+
steps_range=(1, 3),
|
| 42 |
+
triad_scales=(1/3, 1/9, 1/27),
|
| 43 |
+
beta_shape=(2.0, 2.0),
|
| 44 |
+
seed: int | None = None,
|
| 45 |
+
):
|
| 46 |
+
"""
|
| 47 |
+
Fractal AlphaMix: Triadic multi-patch overlays aligned to Cantor geometry.
|
| 48 |
+
Pure torch, GPU-compatible.
|
| 49 |
+
"""
|
| 50 |
+
if seed is not None:
|
| 51 |
+
torch.manual_seed(seed)
|
| 52 |
+
|
| 53 |
+
B, C, H, W = x.shape
|
| 54 |
+
device = x.device
|
| 55 |
+
|
| 56 |
+
# Permutation for mixing
|
| 57 |
+
idx = torch.randperm(B, device=device)
|
| 58 |
+
y_a, y_b = y, y[idx]
|
| 59 |
+
|
| 60 |
+
x_mix = x.clone()
|
| 61 |
+
total_area = H * W
|
| 62 |
+
|
| 63 |
+
# Beta distribution for transparency sampling
|
| 64 |
+
k1, k2 = beta_shape
|
| 65 |
+
beta_dist = torch.distributions.Beta(k1, k2)
|
| 66 |
+
alpha_min, alpha_max = alpha_range
|
| 67 |
+
|
| 68 |
+
# Storage for effective alpha calculation
|
| 69 |
+
alpha_elems = []
|
| 70 |
+
area_weights = []
|
| 71 |
+
|
| 72 |
+
# Sample number of patches (same for all images in batch)
|
| 73 |
+
steps = torch.randint(steps_range[0], steps_range[1] + 1, (1,), device=device).item()
|
| 74 |
+
|
| 75 |
+
for _ in range(steps):
|
| 76 |
+
# Choose triadic scale
|
| 77 |
+
scale_idx = torch.randint(0, len(triad_scales), (1,), device=device).item()
|
| 78 |
+
scale = triad_scales[scale_idx]
|
| 79 |
+
|
| 80 |
+
# Compute patch dimensions (triadic area)
|
| 81 |
+
patch_area = max(1, int(total_area * scale))
|
| 82 |
+
side = int(torch.sqrt(torch.tensor(patch_area, dtype=torch.float32)).item())
|
| 83 |
+
h = max(1, min(H, side))
|
| 84 |
+
w = max(1, min(W, side))
|
| 85 |
+
|
| 86 |
+
# Random position
|
| 87 |
+
top = torch.randint(0, H - h + 1, (1,), device=device).item()
|
| 88 |
+
left = torch.randint(0, W - w + 1, (1,), device=device).item()
|
| 89 |
+
|
| 90 |
+
# Sample transparency from Beta distribution
|
| 91 |
+
alpha_raw = beta_dist.sample().item()
|
| 92 |
+
alpha = alpha_min + (alpha_max - alpha_min) * alpha_raw
|
| 93 |
+
|
| 94 |
+
# Track for effective alpha
|
| 95 |
+
alpha_elems.append(alpha)
|
| 96 |
+
area_weights.append(h * w)
|
| 97 |
+
|
| 98 |
+
# Blend patches
|
| 99 |
+
fg = alpha * x[:, :, top:top + h, left:left + w]
|
| 100 |
+
bg = (1 - alpha) * x[idx, :, top:top + h, left:left + w]
|
| 101 |
+
x_mix[:, :, top:top + h, left:left + w] = fg + bg
|
| 102 |
+
|
| 103 |
+
# Compute area-weighted effective alpha
|
| 104 |
+
alpha_t = torch.tensor(alpha_elems, dtype=torch.float32, device=device)
|
| 105 |
+
area_t = torch.tensor(area_weights, dtype=torch.float32, device=device)
|
| 106 |
+
alpha_eff = (alpha_t * area_t).sum() / (area_t.sum() + 1e-12)
|
| 107 |
+
alpha_eff = alpha_eff.item()
|
| 108 |
+
|
| 109 |
+
return x_mix, y_a, y_b, alpha_eff
|