RFTSystems commited on
Commit
bc42ee1
·
verified ·
1 Parent(s): b834e7c

Create stage4.py

Browse files
Files changed (1) hide show
  1. stage4.py +140 -0
stage4.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # stage4.py
2
+ # Author: Liam Grinstead
3
+ # Purpose: ViT-Tiny (ImageNet Subset) Validation (Stage Four of Twelve)
4
+
5
+ import os, math, time, json, random, argparse
6
+ import torch, torch.nn as nn, torch.nn.functional as F
7
+ import torchvision, torchvision.transforms as T
8
+
9
+ # ---------------- Determinism ----------------
10
+ def set_seed(s=1234):
11
+ random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s)
12
+ torch.backends.cudnn.benchmark = False
13
+ torch.backends.cudnn.deterministic = False
14
+
15
+ # ---------------- Telemetry ------------------
16
+ class Telemetry:
17
+ def __init__(self, path="stage4_vit_tiny.jsonl"):
18
+ self.t0 = time.time(); self.f = open(path,"w")
19
+ def emit(self, **k):
20
+ k["t"] = round(time.time()-self.t0,3)
21
+ line = json.dumps(k,separators=(",",":"))
22
+ print(line); self.f.write(line+"\n"); self.f.flush()
23
+ def close(self): self.f.close()
24
+
25
+ # ---------------- Orbital Coupler ------------
26
+ class Orbital:
27
+ def __init__(self, g=0.006, floor=0.2):
28
+ self.a=0.0; self.b=math.pi/3; self.g=g; self.floor=floor
29
+ def step(self):
30
+ d=(self.b-self.a+math.pi)%(2*math.pi)-math.pi
31
+ if abs(d)<self.floor: d=self.floor*(1 if d>=0 else -1)
32
+ s=math.sin(d)
33
+ self.a=(self.a+self.g*s)%(2*math.pi)
34
+ self.b=(self.b-self.g*s)%(2*math.pi)
35
+ drift=abs((self.a-self.b+math.pi)%(2*math.pi)-math.pi)
36
+ return drift, abs(s)
37
+
38
+ # ---------------- DCLR Optimiser -------------
39
+ class DCLR(torch.optim.Optimizer):
40
+ def __init__(self, params, lr=5e-4, beta=0.9, gamma=0.999, eps=1e-8, cg=0.05):
41
+ super().__init__(params, dict(lr=lr,beta=beta,gamma=gamma,eps=eps,cg=cg))
42
+ @torch.no_grad()
43
+ def step(self, closure=None):
44
+ tot=0.0
45
+ for g in self.param_groups:
46
+ lr,beta,gamma,eps,c = g["lr"],g["beta"],g["gamma"],g["eps"],g["cg"]
47
+ for p in g["params"]:
48
+ if p.grad is None: continue
49
+ st=self.state[p]
50
+ if not st:
51
+ st["m"]=torch.zeros_like(p); st["v"]=torch.zeros_like(p); st["coh"]=torch.zeros_like(p)
52
+ m,v,h=st["m"],st["v"],st["coh"]; g0=p.grad
53
+ m.mul_(beta).add_(g0,alpha=1-beta)
54
+ v.mul_(gamma).addcmul_(g0,g0,value=1-gamma)
55
+ d=g0-m; h.mul_(0.9).add_(d.abs(),alpha=0.1)
56
+ lr_eff=lr/(1+c*h)
57
+ step=lr_eff*m/(v.sqrt()+eps)
58
+ p.add_(-step); tot += (step*step).sum().item()
59
+ return None, tot
60
+
61
+ # ---------------- ViT-Tiny -------------------
62
+ class PatchEmbed(nn.Module):
63
+ def __init__(self, img=224, patch=16, in_ch=3, dim=192):
64
+ super().__init__()
65
+ self.proj=nn.Conv2d(in_ch, dim, kernel_size=patch, stride=patch)
66
+ self.n=(img//patch)*(img//patch)
67
+ def forward(self,x):
68
+ x=self.proj(x); return x.flatten(2).transpose(1,2)
69
+
70
+ class Block(nn.Module):
71
+ def __init__(self, dim=192, heads=3, mlp_ratio=4):
72
+ super().__init__()
73
+ self.n1=nn.LayerNorm(dim)
74
+ self.attn=nn.MultiheadAttention(dim, heads, batch_first=True)
75
+ self.n2=nn.LayerNorm(dim)
76
+ self.mlp=nn.Sequential(nn.Linear(dim,int(dim*mlp_ratio)), nn.GELU(), nn.Linear(int(dim*mlp_ratio),dim))
77
+ def forward(self,x):
78
+ h=x; x=self.n1(x); x,_=self.attn(x,x,x,need_weights=False); x=x+h
79
+ h=x; x=self.n2(x); x=x+self.mlp(x); return x
80
+
81
+ class ViTTiny(nn.Module):
82
+ def __init__(self, num_classes=1000, img=224, patch=16, dim=192, depth=12, heads=3, mlp_ratio=4):
83
+ super().__init__()
84
+ self.pe=PatchEmbed(img,patch,3,dim)
85
+ self.cls=nn.Parameter(torch.zeros(1,1,dim))
86
+ self.pos=nn.Parameter(torch.zeros(1,1+self.pe.n,dim))
87
+ self.blocks=nn.ModuleList([Block(dim,heads,mlp_ratio) for _ in range(depth)])
88
+ self.norm=nn.LayerNorm(dim); self.head=nn.Linear(dim,num_classes)
89
+ nn.init.trunc_normal_(self.cls,std=0.02); nn.init.trunc_normal_(self.pos,std=0.02)
90
+ def forward(self,x):
91
+ B=x.size(0); x=self.pe(x); cls=self.cls.expand(B,-1,-1)
92
+ x=torch.cat([cls,x],dim=1)+self.pos[:,:(x.size(1)+1)]
93
+ for blk in self.blocks: x=blk(x)
94
+ x=self.norm(x); return self.head(x[:,0])
95
+
96
+ # ---------------- Data -----------------------
97
+ def get_loaders(data_dir=None, batch=256, img=224):
98
+ tf=T.Compose([T.Resize((img,img)), T.RandomHorizontalFlip(), T.ToTensor(),
99
+ T.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))])
100
+ if data_dir and os.path.isdir(os.path.join(data_dir,"train")):
101
+ train=torchvision.datasets.ImageFolder(os.path.join(data_dir,"train"), transform=tf)
102
+ val=torchvision.datasets.ImageFolder(os.path.join(data_dir,"val"), transform=tf)
103
+ else:
104
+ # synthetic fallback
105
+ C=1000
106
+ class Synth(torch.utils.data.Dataset):
107
+ def __init__(self,n): self.n=n
108
+ def __len__(self): return self.n
109
+ def __getitem__(self,i):
110
+ x=torch.randn(3,img,img); y=torch.randint(0,C,(1,)).item()
111
+ return x,y
112
+ train=Synth(4096); val=Synth(1024)
113
+ tr=torch.utils.data.DataLoader(train,batch_size=batch,shuffle=True)
114
+ va=torch.utils.data.DataLoader(val,batch_size=batch,shuffle=False)
115
+ return tr,va
116
+
117
+ # ---------------- Runner ---------------------
118
+ def train(mode="RFT", data_dir=None, steps=1000, batch=256, lr=5e-4, log_path="stage4_vit_tiny.jsonl"):
119
+ set_seed(1234); tm=Telemetry(log_path); orb=Orbital()
120
+ dev="cuda" if torch.cuda.is_available() else "cpu"
121
+ train_loader, val_loader = get_loaders(data_dir, batch)
122
+ model=ViTTiny(num_classes=1000).to(dev)
123
+ opt=DCLR(model.parameters(), lr=lr) if mode=="RFT" else torch.optim.Adam(model.parameters(), lr=lr)
124
+ ce=nn.CrossEntropyLoss()
125
+ it=0
126
+ for (x,y) in train_loader:
127
+ if it>=steps: break
128
+ it+=1
129
+ drift,flux=orb.step()
130
+ x,y=x.to(dev),y.to(dev)
131
+ opt.zero_grad(set_to_none=True)
132
+ out=model(x); loss=ce(out,y); loss.backward()
133
+ if isinstance(opt,DCLR): _,J=opt.step()
134
+ else: opt.step(); J=0.0
135
+ acc=(out.argmax(1)==y).float().mean().item()
136
+ tm.emit(mode=mode, step=it, drift=round(drift,3), flux=round(flux,3),
137
+ E_ret=0.994, coh=0.999, loss=round(float(loss.item()),4),
138
+ acc=round(float(acc),3), J_step=round(float(J*1e-6),6))
139
+ tm.close()
140
+ return f"Stage 4 complete. Telemetry saved to {log_path}"