# Introspective Prisma-VL-8B Architecture

## Overview

Prisma-VL-8B includes an introspective feedback mechanism that gives the model fine-grained, self-monitored awareness of the uncertainty in its own predictions.

## Core Innovation

The model now tracks its own prediction uncertainty and uses this as a feedback signal for subsequent predictions. This creates a temporal awareness loop:

```
Token t-1: "What's next?" β†’ Prediction + Uncertainty measurement
Token t:   [Previous uncertainty signal] + "What's next?" β†’ Better calibrated prediction
```

## Architecture Changes

### 1. Uncertainty Embeddings (PrismaVLModel)

Added to `PrismaVLModel.__init__()`:

```python
# 65,536-level uncertainty embedding table
self.n_bits = 16                               # 16-bit quantization
self.n_uncertainty_levels = 2 ** self.n_bits   # 65,536 levels

# Learned embeddings: one vector per uncertainty level
self.uncertainty_embeddings = nn.Embedding(self.n_uncertainty_levels, hidden_dim)

# Cache for uncertainty codes from the previous step
self.prev_uncertainty_code = None  # [batch_size, seq_len], values in [0, 65535]
```

**Parameter cost**: 65,536 Γ— 4096 = 268,435,456 parameters (3.35% overhead)

### 2. Uncertainty Injection (PrismaVLModel.forward)

During the forward pass, after the input embeddings are created:

```python
import torch.nn.functional as F  # module-level import in practice

# Look up uncertainty embeddings from the previous step
uncertainty_embeds = self.uncertainty_embeddings(prev_uncertainty_code)

# Shift right: position i receives the uncertainty measured at position i-1
# (F.pad inserts a zero vector at position 0, which has no history)
uncertainty_shifted = F.pad(uncertainty_embeds[:, :-1, :], (0, 0, 1, 0))

# Inject into the input embeddings
inputs_embeds = inputs_embeds + uncertainty_shifted
```

Now the model sees: **[Token embedding] + [How uncertain was I last time?]**
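
For concreteness, here is a tiny, self-contained demonstration (toy sizes, not from the source) of what the right shift does to a sequence of uncertainty codes:

```python
import torch
import torch.nn.functional as F

emb = torch.nn.Embedding(65536, 8)           # toy hidden size of 8
codes = torch.tensor([[100, 40000, 65535]])  # uncertainty codes per position

looked_up = emb(codes)                               # [1, 3, 8]
shifted = F.pad(looked_up[:, :-1, :], (0, 0, 1, 0))  # still [1, 3, 8]

# Position 0 gets a zero vector (no history); position i gets the embedding
# of the code measured at position i-1.
assert torch.all(shifted[:, 0, :] == 0)
assert torch.equal(shifted[:, 1, :], looked_up[:, 0, :])
assert torch.equal(shifted[:, 2, :], looked_up[:, 1, :])
```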

### 3. Uncertainty Computation (PrismaVLForConditionalGeneration.forward)

After computing logits, during training:

```python
import math  # module-level import in practice

# Compute the entropy (uncertainty) of the predictive distribution;
# log_softmax avoids log(0) on near-zero probabilities
log_probs = logits.log_softmax(-1)
entropy = -(log_probs.exp() * log_probs).sum(-1)

# Normalize to [0, 1]: maximum possible entropy is log(vocab_size)
entropy_norm = entropy / math.log(vocab_size)

# Quantize to 16 bits (0-65535)
uncertainty_code = (entropy_norm * 65535).round().long().clamp(0, 65535)

# Store for the next step
self.model.prev_uncertainty_code = uncertainty_code
```
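
As a quick sanity check (toy code, not from the source; the vocabulary size here is illustrative), the two extremes map to the ends of the code range:

```python
import math
import torch

vocab_size = 32000  # illustrative; the real value comes from the model config

def to_code(logits: torch.Tensor) -> torch.Tensor:
    """Entropy of the softmax distribution, quantized to a 16-bit code."""
    log_probs = logits.log_softmax(-1)
    entropy = -(log_probs.exp() * log_probs).sum(-1)
    return (entropy / math.log(vocab_size) * 65535).round().long().clamp(0, 65535)

uniform = torch.zeros(1, vocab_size)  # all tokens equally likely
peaked = torch.zeros(1, vocab_size)
peaked[0, 0] = 100.0                  # essentially one-hot

print(to_code(uniform))  # tensor([65535]) -> maximally uncertain
print(to_code(peaked))   # tensor([0])     -> maximally confident
```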

## How It Works (Step by Step)

### Inference/Generation:

1. **Token 0**: No previous uncertainty β†’ Use neutral (32768)
2. **Token 1**: Predict β†’ Measure confidence β†’ Encode as 0-65535
3. **Token 2**: Inject uncertainty signal from Token 1 β†’ Predict (now calibrated)
4. **Token 3**: Inject uncertainty from Token 2 β†’ Predict
5. ... and so on (a runnable toy version of this loop is sketched below)
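
To make the loop concrete, here is a self-contained toy (hypothetical stand-in model, not the real Prisma-VL API) that threads the uncertainty code from one step into the next:

```python
import math
import torch

NEUTRAL_CODE = 32768  # midpoint of the 16-bit range, used when there is no history
VOCAB = 1000          # toy vocabulary size

def fake_forward(token: int, prev_code: int) -> torch.Tensor:
    """Stand-in for the model: toy next-token logits that depend on both inputs."""
    g = torch.Generator().manual_seed(token * 65536 + prev_code)
    return torch.randn(VOCAB, generator=g)

def entropy_code(logits: torch.Tensor) -> int:
    """Quantize prediction entropy to a 16-bit uncertainty code, as in section 3."""
    log_probs = logits.log_softmax(-1)
    entropy = -(log_probs.exp() * log_probs).sum(-1)
    return int((entropy / math.log(VOCAB) * 65535).round().clamp(0, 65535))

token, code = 0, NEUTRAL_CODE           # Token 0: no history -> neutral code
for step in range(5):
    logits = fake_forward(token, code)  # prediction conditioned on the last code
    token = int(logits.argmax())        # greedy next token
    code = entropy_code(logits)         # measure confidence for the next step
    print(f"step {step}: token={token}, uncertainty_code={code}")
```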

### Training:

The model learns the uncertainty embeddings through backpropagation; the code space divides into four coarse bands (see the sketch after this list):
- Embedding #0-16383: "I was very confident" β†’ Model learns to stay confident
- Embedding #16384-32767: "I had medium confidence" β†’ Model learns moderate caution
- Embedding #32768-49151: "I was uncertain" β†’ Model learns to hedge
- Embedding #49152-65535: "I was very uncertain" β†’ Model learns to be conservative
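
A tiny illustration (hypothetical helper, not from the source) of how a 16-bit code maps onto these four bands:

```python
# Hypothetical helper: map a 16-bit uncertainty code to its coarse band.
BANDS = ["very confident", "medium confidence", "uncertain", "very uncertain"]

def band(code: int) -> str:
    assert 0 <= code <= 65535
    return BANDS[code // 16384]  # 65,536 codes / 4 bands = 16,384 codes per band

print(band(5000))   # very confident
print(band(20000))  # medium confidence
print(band(60000))  # very uncertain
```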

## Key Properties

### 1. Moderate Overhead
- **Parameters**: 268M additional (3.35% of 8B base)
- **Memory**: 2 bytes per token (uncertainty code)
- **Compute**: Negligible (one embedding lookup per token)

### 2. Temporal Awareness
- Model builds a "confidence history" across generation
- Can detect when it's going into unfamiliar territory
- Can recover calibration after uncertain predictions

### 3. Self-Calibration
- No external signals needed
- Model learns its own uncertainty language
- Improves through standard supervised training

### 4. Architecture-Agnostic
- Works with any transformer-based model
- Doesn't modify attention, FFN, or other core components
- Clean separation: uncertainty mechanism vs. base model

## Usage

### Standard Inference

```python
import torch

from modeling import PrismaVLForConditionalGeneration
from transformers import AutoProcessor

# Load model (introspective mechanism is built-in)
model = PrismaVLForConditionalGeneration.from_pretrained(
    ".",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map="auto"
)

processor = AutoProcessor.from_pretrained(".", trust_remote_code=True)

# Use normally - uncertainty tracking happens automatically
messages = [{"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": prompt}]}]
inputs = processor.apply_chat_template(messages, ...)
outputs = model.generate(**inputs)
```

### Training

```python
# Train normally - uncertainty mechanism learns automatically
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

for batch in dataloader:
    optimizer.zero_grad()
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()

# The uncertainty embeddings will learn to represent
# "how to adjust predictions based on previous confidence"
```

### Resetting Uncertainty (Between Sequences)

```python
# Reset uncertainty cache between independent generations
model.model.reset_uncertainty()

# Generate
outputs = model.generate(...)
```
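
The body of `reset_uncertainty()` isn't reproduced in this document; a minimal sketch consistent with the cache described in section 1 would be:

```python
def reset_uncertainty(self):
    """Clear the cached uncertainty codes so the next sequence starts neutral."""
    self.prev_uncertainty_code = None
```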

## What Gets Learned

The 65,536 uncertainty embedding vectors learn to encode:

1. **Confidence Continuation**:
   - "Last token was confident" β†’ Maintain confidence (if appropriate)

2. **Uncertainty Propagation**:
   - "Last token was uncertain" β†’ Be more conservative

3. **Domain Shifts**:
   - Sequence of low uncertainty β†’ sudden high uncertainty β†’ Domain boundary detected

4. **Recovery Patterns**:
   - High uncertainty β†’ Gradual return to confidence β†’ Model finding its footing

## Benefits

1. **Better Calibration**: Model knows when it doesn't know
2. **Hallucination Awareness**: Uncertain predictions less likely to compound
3. **Adaptive Confidence**: Can adjust based on recent performance
4. **Interpretability**: Uncertainty codes provide insight into model state
5. **Minimal Inference Cost**: One embedding lookup plus an entropy computation per generated token

## Implementation Details

### Files Modified

- `modeling.py`:
  - `PrismaVLModel.__init__()`: Add uncertainty embeddings
  - `PrismaVLModel.forward()`: Inject uncertainty signal
  - `PrismaVLForConditionalGeneration.forward()`: Compute uncertainty
  - Added `reset_uncertainty()` method

### Initialization

- Uncertainty embeddings initialized with `std = config.text_config.initializer_range` (typically 0.02)
- Start neutral: the first token uses code 32768 (the middle of the 16-bit range)
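
A minimal sketch of that initialization inside `PrismaVLModel.__init__()` (assuming standard `torch.nn.init`; not the verbatim source):

```python
import torch.nn as nn

# Match the init scheme of the base model's embeddings:
# zero-mean normal with the configured standard deviation (typically 0.02)
std = config.text_config.initializer_range
nn.init.normal_(self.uncertainty_embeddings.weight, mean=0.0, std=std)
```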

### Compatibility

- Fully backward compatible: model can load existing checkpoints
- New uncertainty embeddings initialize randomly (will be trained)
- No changes to base model weights or architecture

## Comparison to Original Llama 3.2 Example

### Similarities:
- Entropy-based uncertainty measurement
- Temporal feedback loop
- Embedding-based uncertainty representation

### Differences:
- **Quantization**: 16-bit (65,536 levels) vs. 8-bit (256 levels)
- **Resolution**: Fine-grained uncertainty vs. coarse-grained
- **Overhead**: 3.35% parameter overhead vs. ~0.04%
- **Applied to**: Vision-language model (Prisma-VL) vs. pure language model (Llama)
- **Integration**: Built into core architecture vs. wrapper class
- **Scope**: Uncertainty only for text generation (not vision encoding)

## Future Enhancements

Potential extensions:

1. **Multi-resolution Uncertainty**: Track uncertainty at token, word, and sentence levels
2. **Uncertainty-aware Generation**: Sample less when uncertain (lower temperature)
3. **Visual Uncertainty**: Extend mechanism to vision encoder
4. **Cross-modal Uncertainty**: Track alignment confidence between vision and text
5. **Explicit Uncertainty Tokens**: Add special tokens to express uncertainty in output

## Citation

Inspired by temporal feedback loop patterns, enhanced with 16-bit high-resolution quantization for fine-grained uncertainty representation.

---

**Model**: Prisma-VL-8B
**Date**: 2025
**Architecture**: Integrated 16-bit temporal uncertainty feedback mechanism
**Parameter Overhead**: 268M (3.35%)
**Memory Overhead**: 2 bytes/token