# Introspective Prisma-VL-8B Architecture
## Overview
Prisma-VL-8B includes an introspective feedback mechanism that gives the model fine-grained, self-monitored awareness of how uncertain its own predictions are.
## Core Innovation
The model now tracks its own prediction uncertainty and uses this as a feedback signal for subsequent predictions. This creates a temporal awareness loop:
```
Token t-1: "What's next?" → Prediction + uncertainty measurement
Token t:   [Previous uncertainty signal] + "What's next?" → Better-calibrated prediction
```
## Architecture Changes
### 1. Uncertainty Embeddings (PrismaVLModel)
Added to `PrismaVLModel.__init__()`:
```python
# 65,536-level uncertainty embedding table
self.n_bits = 16 # 16-bit quantization
self.n_uncertainty_levels = 65536 # 2^16
# Learned embeddings: one vector per uncertainty level
self.uncertainty_embeddings = nn.Embedding(65536, hidden_dim)
# Cache for uncertainty codes from previous step
self.prev_uncertainty_code = None # [batch_size, seq_len] with values [0-65535]
```
**Parameter cost**: 65,536 × 4,096 = 268,435,456 parameters (3.35% overhead)
### 2. Uncertainty Injection (PrismaVLModel.forward)
During forward pass, after creating input embeddings:
```python
import torch.nn.functional as F

# Look up uncertainty embeddings from the previous step
uncertainty_embeds = self.uncertainty_embeddings(prev_uncertainty_code)
# Shift right: position i gets the uncertainty measured at position i-1
uncertainty_shifted = F.pad(uncertainty_embeds[:, :-1, :], (0, 0, 1, 0))
# Inject into the input embeddings
inputs_embeds = inputs_embeds + uncertainty_shifted
```
Now the model sees: **[Token embedding] + [How uncertain was I last time?]**
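To make the right-shift concrete, here is a tiny standalone demo (toy tensor sizes, not taken from `modeling.py`) showing that after the padded slice, row i of the output holds row i-1 of the input, with zeros at the first position:
```python
import torch
import torch.nn.functional as F

# Toy "uncertainty embeddings": batch 1, seq_len 3, hidden 2
x = torch.tensor([[[1., 1.], [2., 2.], [3., 3.]]])
shifted = F.pad(x[:, :-1, :], (0, 0, 1, 0))
print(shifted)
# tensor([[[0., 0.],    <- position 0 has no predecessor
#          [1., 1.],    <- uncertainty measured at position 0
#          [2., 2.]]])  <- uncertainty measured at position 1
```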
### 3. Uncertainty Computation (PrismaVLForConditionalGeneration.forward)
After computing logits, during training:
```python
import math

# Compute entropy (uncertainty) of the predictive distribution
probs = logits.softmax(-1)
entropy = -(probs * probs.clamp_min(1e-12).log()).sum(-1)
# Normalize to [0, 1] (maximum entropy is log(vocab_size))
entropy_norm = entropy / math.log(vocab_size)
# Quantize to 16 bits (codes 0-65535)
uncertainty_code = (entropy_norm * 65535).long()
# Store for the next step
self.model.prev_uncertainty_code = uncertainty_code
```
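As a quick sanity check of that quantization, the following standalone sketch (toy vocabulary of 4, not from `modeling.py`) maps two extreme distributions to their 16-bit codes: a uniform distribution lands at the top of the range, a near-one-hot distribution near zero.
```python
import math

import torch

vocab_size = 4
logits = torch.tensor([
    [0.0, 0.0, 0.0, 0.0],         # uniform -> maximum entropy
    [10.0, -10.0, -10.0, -10.0],  # near one-hot -> near-zero entropy
])

probs = logits.softmax(-1)
entropy = -(probs * probs.clamp_min(1e-12).log()).sum(-1)
entropy_norm = entropy / math.log(vocab_size)
uncertainty_code = (entropy_norm * 65535).long()
print(uncertainty_code)  # high code (~65535) for the first row, ~0 for the second
```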
## How It Works (Step by Step)
### Inference/Generation:
1. **Token 0**: No previous uncertainty → Use the neutral code (32768)
2. **Token 1**: Predict → Measure confidence → Encode as a code in 0-65535
3. **Token 2**: Inject the uncertainty signal from Token 1 → Predict (now calibrated)
4. **Token 3**: Inject the uncertainty from Token 2 → Predict
5. ... and so on (a minimal sketch of this loop follows below)
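Put together, a minimal greedy-decoding sketch of that loop looks roughly like the following. `decode_step` is a hypothetical callable standing in for one forward pass that accepts the previous codes, and `VOCAB_SIZE` is a placeholder; the released `generate()` handles all of this internally.
```python
import math

import torch

NEUTRAL_CODE = 32768      # middle of the 16-bit range, used before any measurement
VOCAB_SIZE = 151_936      # placeholder; use the actual tokenizer vocabulary size

def generate_with_feedback(decode_step, input_ids, max_new_tokens=32):
    """Greedy-decoding sketch that makes the temporal uncertainty loop explicit.

    `decode_step(input_ids, prev_uncertainty_code)` is a hypothetical callable
    standing in for one forward pass; it returns logits of shape [B, S, V].
    """
    prev_code = torch.full(input_ids.shape, NEUTRAL_CODE, device=input_ids.device)
    for _ in range(max_new_tokens):
        logits = decode_step(input_ids, prev_uncertainty_code=prev_code)
        probs = logits[:, -1, :].softmax(-1)
        # Measure the uncertainty of this prediction and quantize it to 16 bits
        entropy = -(probs * probs.clamp_min(1e-12).log()).sum(-1)
        code = (entropy / math.log(VOCAB_SIZE) * 65535).long()
        # The code measured now is what the *next* step sees as its signal
        next_token = probs.argmax(-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        prev_code = torch.cat([prev_code, code.unsqueeze(-1)], dim=-1)
    return input_ids
```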
### Training:
Model learns the uncertainty embeddings through backpropagation:
- Embedding #0-16383: "I was very confident" → Model learns to stay confident
- Embedding #16384-32767: "I had medium confidence" → Model learns moderate caution
- Embedding #32768-49151: "I was uncertain" → Model learns to hedge
- Embedding #49152-65535: "I was very uncertain" → Model learns to be conservative
## Key Properties
### 1. Moderate Overhead
- **Parameters**: 268M additional (3.35% of 8B base)
- **Memory**: 2 bytes per token (uncertainty code)
- **Compute**: Negligible (one embedding lookup per token)
### 2. Temporal Awareness
- Model builds a "confidence history" across generation
- Can detect when it's going into unfamiliar territory
- Can recover calibration after uncertain predictions
### 3. Self-Calibration
- No external signals needed
- Model learns its own uncertainty language
- Improves through standard supervised training
### 4. Architecture-Agnostic
- Works with any transformer-based model
- Doesn't modify attention, FFN, or other core components
- Clean separation: uncertainty mechanism vs. base model
## Usage
### Standard Inference
```python
import torch

from modeling import PrismaVLForConditionalGeneration
from transformers import AutoProcessor

# Load model (introspective mechanism is built-in)
model = PrismaVLForConditionalGeneration.from_pretrained(
    ".",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(".", trust_remote_code=True)
# Use normally - uncertainty tracking happens automatically
messages = [{"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": prompt}]}]
inputs = processor.apply_chat_template(messages, ...)
outputs = model.generate(**inputs)
```
### Training
```python
# Train normally - the uncertainty mechanism learns automatically
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

for batch in dataloader:
    optimizer.zero_grad()
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()

# The uncertainty embeddings learn to represent
# "how to adjust predictions based on previous confidence"
```
### Resetting Uncertainty (Between Sequences)
```python
# Reset uncertainty cache between independent generations
model.model.reset_uncertainty()
# Generate
outputs = model.generate(...)
```
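The method itself is assumed to be a one-liner that clears the cached codes so the next sequence starts from the neutral state (a sketch; the shipped `modeling.py` may differ in detail):
```python
# Sketch of the method inside PrismaVLModel
def reset_uncertainty(self):
    """Clear the cached uncertainty codes so the next sequence starts neutral."""
    self.prev_uncertainty_code = None
```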
## What Gets Learned
The 65,536 uncertainty embedding vectors learn to encode:
1. **Confidence Continuation**:
   - "Last token was confident" → Maintain confidence (if appropriate)
2. **Uncertainty Propagation**:
   - "Last token was uncertain" → Be more conservative
3. **Domain Shifts**:
   - Sequence of low uncertainty → sudden high uncertainty → Domain boundary detected
4. **Recovery Patterns**:
   - High uncertainty → Gradual return to confidence → Model finding its footing (the cached codes can be read back for inspection, as sketched below)
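Because the codes are cached on the model, they can be read back after a forward pass for inspection; a rough sketch (the 0.3 jump threshold is arbitrary, chosen only for illustration):
```python
# Read back the cached codes after a forward pass (shape: [batch_size, seq_len])
codes = model.model.prev_uncertainty_code
entropy_fraction = codes.float() / 65535.0      # back onto the [0, 1] entropy scale

# Flag positions where normalized entropy jumps sharply (possible domain shift)
jumps = entropy_fraction[:, 1:] - entropy_fraction[:, :-1]
domain_shift_positions = (jumps > 0.3).nonzero()
```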
## Benefits
1. **Better Calibration**: Model knows when it doesn't know
2. **Hallucination Awareness**: Uncertain predictions less likely to compound
3. **Adaptive Confidence**: Can adjust based on recent performance
4. **Interpretability**: Uncertainty codes provide insight into model state
5. **No Inference Cost**: Only active during training (for computing new uncertainties)
## Implementation Details
### Files Modified
- `modeling.py`:
- `PrismaVLModel.__init__()`: Add uncertainty embeddings
- `PrismaVLModel.forward()`: Inject uncertainty signal
- `PrismaVLForConditionalGeneration.forward()`: Compute uncertainty
- Added `reset_uncertainty()` method
### Initialization
- Uncertainty embeddings initialized with `std = config.text_config.initializer_range` (typically 0.02)
- Start neutral: the first token uses code 32768 (the middle of the 16-bit range), as sketched below
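A sketch of that initialization as it might appear in `PrismaVLModel.__init__()` (exact placement and naming in `modeling.py` may differ):
```python
import torch.nn as nn

# Inside PrismaVLModel.__init__(), after creating the embedding table
self.uncertainty_embeddings = nn.Embedding(self.n_uncertainty_levels, hidden_dim)
nn.init.normal_(
    self.uncertainty_embeddings.weight,
    mean=0.0,
    std=config.text_config.initializer_range,  # typically 0.02
)
```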
### Compatibility
- Fully backward compatible: model can load existing checkpoints
- New uncertainty embeddings initialize randomly (will be trained)
- No changes to base model weights or architecture
## Comparison to Original Llama 3.2 Example
### Similarities:
- Entropy-based uncertainty measurement
- Temporal feedback loop
- Embedding-based uncertainty representation
### Differences:
- **Quantization**: 16-bit (65,536 levels) vs. 8-bit (256 levels)
- **Resolution**: Fine-grained uncertainty vs. coarse-grained
- **Overhead**: 3.35% parameter overhead vs. ~0.04%
- **Applied to**: Vision-language model (Prisma-VL) vs. pure language model (Llama)
- **Integration**: Built into core architecture vs. wrapper class
- **Scope**: Uncertainty only for text generation (not vision encoding)
## Future Enhancements
Potential extensions:
1. **Multi-resolution Uncertainty**: Track uncertainty at token, word, and sentence levels
2. **Uncertainty-aware Generation**: Sample less when uncertain (lower temperature)
3. **Visual Uncertainty**: Extend mechanism to vision encoder
4. **Cross-modal Uncertainty**: Track alignment confidence between vision and text
5. **Explicit Uncertainty Tokens**: Add special tokens to express uncertainty in output
## Citation
Inspired by temporal feedback loop patterns, enhanced with 16-bit high-resolution quantization for fine-grained uncertainty representation.
---
**Model**: Prisma-VL-8B
**Date**: 2025
**Architecture**: Integrated 16-bit temporal uncertainty feedback mechanism
**Parameter Overhead**: 268M (3.35%)
**Memory Overhead**: 2 bytes/token