# Introspective Prisma-VL-8B Architecture
## Overview
Prisma-VL-8B includes an introspective feedback mechanism that gives the model fine-grained, self-monitored awareness of the uncertainty of its own predictions.
## Core Innovation
The model now tracks its own prediction uncertainty and uses this as a feedback signal for subsequent predictions. This creates a temporal awareness loop:
```
Token t-1: "What's next?" β†’ Prediction + Uncertainty measurement
Token t: [Previous uncertainty signal] + "What's next?" β†’ Better calibrated prediction
```
## Architecture Changes
### 1. Uncertainty Embeddings (PrismaVLModel)
Added to `PrismaVLModel.__init__()`:
```python
# 65,536-level uncertainty embedding table
self.n_bits = 16 # 16-bit quantization
self.n_uncertainty_levels = 65536 # 2^16
# Learned embeddings: one vector per uncertainty level
self.uncertainty_embeddings = nn.Embedding(65536, hidden_dim)
# Cache for uncertainty codes from previous step
self.prev_uncertainty_code = None # [batch_size, seq_len] with values [0-65535]
```
**Parameter cost**: 65,536 Γ— 4096 = 268,435,456 parameters (3.35% overhead)
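A quick sanity check of those numbers (the 8B denominator below is the nominal size; the quoted 3.35% uses the model's exact parameter count, so a round 8B gives a slightly different ratio):
```python
levels, hidden_dim = 65_536, 4_096
extra_params = levels * hidden_dim            # 268,435,456 new parameters
base_params = 8_000_000_000                   # nominal "8B"; the exact base count differs slightly
print(f"{extra_params:,} extra parameters, ~{extra_params / base_params:.2%} of the base")
```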
### 2. Uncertainty Injection (PrismaVLModel.forward)
During forward pass, after creating input embeddings:
```python
# Look up the uncertainty embedding for each position's previous-step code
uncertainty_embeds = self.uncertainty_embeddings(prev_uncertainty_code)
# Shift right so position i receives the uncertainty measured at position i-1
uncertainty_shifted = torch.nn.functional.pad(uncertainty_embeds[:, :-1, :], (0, 0, 1, 0))
# Inject the signal into the token embeddings
inputs_embeds = inputs_embeds + uncertainty_shifted
```
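As a standalone illustration of the shift-and-inject step, here is a toy-sized sketch (the hidden size is deliberately small; the real model uses 4096):
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, seq_len, hidden_dim, levels = 2, 5, 16, 65_536   # toy hidden size

uncertainty_embeddings = nn.Embedding(levels, hidden_dim)
inputs_embeds = torch.randn(batch, seq_len, hidden_dim)
# Previous-step codes, one per position (the neutral code 32768 is used as a stand-in)
prev_uncertainty_code = torch.full((batch, seq_len), 32_768, dtype=torch.long)

uncertainty_embeds = uncertainty_embeddings(prev_uncertainty_code)        # [B, T, H]
# Drop the last position and zero-pad the front: position i sees position i-1's code
uncertainty_shifted = F.pad(uncertainty_embeds[:, :-1, :], (0, 0, 1, 0))  # [B, T, H]
inputs_embeds = inputs_embeds + uncertainty_shifted

print(inputs_embeds.shape)  # torch.Size([2, 5, 16])
```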
Now the model sees: **[Token embedding] + [How uncertain was I last time?]**
### 3. Uncertainty Computation (PrismaVLForConditionalGeneration.forward)
After computing logits, during training:
```python
import math

# Compute the entropy (uncertainty) of the prediction at each position
probs = logits.softmax(dim=-1)
entropy = -(probs * torch.log(probs + 1e-12)).sum(dim=-1)
# Normalize to [0, 1]; log(vocab_size) is the maximum possible entropy
entropy_norm = entropy / math.log(vocab_size)
# Quantize to 16 bits (0-65535)
uncertainty_code = (entropy_norm * 65535).long()
# Store for the next step
self.model.prev_uncertainty_code = uncertainty_code
```
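To see what the codes look like in practice, here is a self-contained toy example (the vocabulary size is illustrative):
```python
import math
import torch

vocab_size = 32_000   # toy value; the real vocabulary is larger

def to_code(logits: torch.Tensor) -> int:
    """Map a single next-token logit vector to a 16-bit uncertainty code."""
    probs = logits.softmax(dim=-1)
    entropy = -(probs * torch.log(probs + 1e-12)).sum(dim=-1)
    return int((entropy / math.log(vocab_size) * 65535).long())

confident = torch.zeros(vocab_size)
confident[0] = 20.0                    # sharply peaked distribution -> low entropy
uncertain = torch.zeros(vocab_size)    # uniform after softmax -> near-maximum entropy

print(to_code(confident))   # close to 0      ("I was very confident")
print(to_code(uncertain))   # close to 65535  ("I was very uncertain")
```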
## How It Works (Step by Step)
### Inference/Generation:
1. **Token 0**: No previous uncertainty β†’ Use neutral (32768)
2. **Token 1**: Predict β†’ Measure confidence β†’ Encode as 0-65535
3. **Token 2**: Inject uncertainty signal from Token 1 β†’ Predict (now calibrated)
4. **Token 3**: Inject uncertainty from Token 2 β†’ Predict
5. ... and so on
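Sketched as code, one pass of this loop looks roughly like the following (`model_forward` and its return value are placeholders for illustration, not the actual generate API):
```python
import math
import torch

NEUTRAL_CODE = 32_768   # midpoint of the 0-65535 range

def generate_with_introspection(model_forward, prompt_ids, vocab_size, max_new_tokens=32):
    """Illustrative loop: every step receives the previous step's uncertainty code."""
    ids = list(prompt_ids)
    prev_code = NEUTRAL_CODE                         # token 0 has no history -> neutral
    for _ in range(max_new_tokens):
        # Placeholder forward pass: injects prev_code's embedding, returns [vocab_size] logits
        logits = model_forward(torch.tensor(ids), prev_code)
        probs = logits.softmax(dim=-1)
        ids.append(int(probs.argmax()))
        # Measure this step's uncertainty and carry it into the next step
        entropy = -(probs * torch.log(probs + 1e-12)).sum()
        prev_code = int(entropy / math.log(vocab_size) * 65535)
    return ids
```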
### Training:
Model learns the uncertainty embeddings through backpropagation:
- Embedding #0-16383: "I was very confident" β†’ Model learns to stay confident
- Embedding #16384-32767: "I had medium confidence" β†’ Model learns moderate caution
- Embedding #32768-49151: "I was uncertain" β†’ Model learns to hedge
- Embedding #49152-65535: "I was very uncertain" β†’ Model learns to be conservative
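These four buckets are simply the quarters of the 16-bit code space; in reality each of the 65,536 codes gets its own learned vector. A tiny helper (purely illustrative) makes the bucketing explicit:
```python
def describe_code(code: int) -> str:
    """Map a 16-bit uncertainty code onto the coarse buckets listed above."""
    if code < 16_384:
        return "very confident"
    if code < 32_768:
        return "medium confidence"
    if code < 49_152:
        return "uncertain"
    return "very uncertain"

print(describe_code(120))      # very confident
print(describe_code(60_000))   # very uncertain
```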
## Key Properties
### 1. Moderate Overhead
- **Parameters**: 268M additional (3.35% of 8B base)
- **Memory**: 2 bytes per token (uncertainty code)
- **Compute**: Negligible (one embedding lookup per token)
### 2. Temporal Awareness
- Model builds a "confidence history" across generation
- Can detect when it's going into unfamiliar territory
- Can recover calibration after uncertain predictions
### 3. Self-Calibration
- No external signals needed
- Model learns its own uncertainty language
- Improves through standard supervised training
### 4. Architecture-Agnostic
- Works with any transformer-based model
- Doesn't modify attention, FFN, or other core components
- Clean separation: uncertainty mechanism vs. base model
## Usage
### Standard Inference
```python
import torch
from modeling import PrismaVLForConditionalGeneration
from transformers import AutoProcessor
# Load model (introspective mechanism is built-in)
model = PrismaVLForConditionalGeneration.from_pretrained(
    ".",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(".", trust_remote_code=True)
# Use normally - uncertainty tracking happens automatically
messages = [{"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": prompt}]}]
inputs = processor.apply_chat_template(messages, ...)
outputs = model.generate(**inputs)
```
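Because the cache from section 1 lives on the inner model, the most recent codes can in principle be inspected after a forward pass; this is speculative, since the snippet above only shows the cache being written in the training path:
```python
# Speculative: peek at the cached 16-bit codes from the last forward pass.
# Values near 0 mean the model was confident; values near 65535 mean it was not.
codes = model.model.prev_uncertainty_code
if codes is not None:
    print(codes.float().mean().item())
```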
### Training
```python
import torch

# Train normally - the uncertainty mechanism learns automatically
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

for batch in dataloader:
    optimizer.zero_grad()
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
# The uncertainty embeddings will learn to represent
# "how to adjust predictions based on previous confidence"
```
### Resetting Uncertainty (Between Sequences)
```python
# Reset uncertainty cache between independent generations
model.model.reset_uncertainty()
# Generate
outputs = model.generate(...)
```
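`reset_uncertainty()` itself isn't shown in this document; given the cache defined in section 1, it presumably just clears the stored codes. A hypothetical sketch:
```python
def reset_uncertainty(self):
    """Hypothetical sketch: clear the cached codes so the next sequence starts neutral."""
    self.prev_uncertainty_code = None
```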
## What Gets Learned
The 65,536 uncertainty embedding vectors learn to encode:
1. **Confidence Continuation**:
- "Last token was confident" β†’ Maintain confidence (if appropriate)
2. **Uncertainty Propagation**:
- "Last token was uncertain" β†’ Be more conservative
3. **Domain Shifts**:
- Sequence of low uncertainty β†’ sudden high uncertainty β†’ Domain boundary detected
4. **Recovery Patterns**:
- High uncertainty β†’ Gradual return to confidence β†’ Model finding its footing
## Benefits
1. **Better Calibration**: Model knows when it doesn't know
2. **Hallucination Awareness**: Uncertain predictions less likely to compound
3. **Adaptive Confidence**: Can adjust based on recent performance
4. **Interpretability**: Uncertainty codes provide insight into model state
5. **Low Overhead at Inference**: The uncertainty bookkeeping adds only an entropy computation and an embedding lookup per token
## Implementation Details
### Files Modified
- `modeling.py`:
- `PrismaVLModel.__init__()`: Add uncertainty embeddings
- `PrismaVLModel.forward()`: Inject uncertainty signal
- `PrismaVLForConditionalGeneration.forward()`: Compute uncertainty
- Added `reset_uncertainty()` method
### Initialization
- Uncertainty embeddings initialized with `std = config.text_config.initializer_range` (typically 0.02)
- Start neutral: the first token uses code 32768 (the midpoint of the 0-65535 range)
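A minimal sketch of that initialization, using the attribute names from the snippets above (the 0.02 value stands in for `config.text_config.initializer_range`):
```python
import torch.nn as nn

hidden_dim = 4096
initializer_range = 0.02   # typically config.text_config.initializer_range

uncertainty_embeddings = nn.Embedding(65_536, hidden_dim)
nn.init.normal_(uncertainty_embeddings.weight, mean=0.0, std=initializer_range)
```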
### Compatibility
- Fully backward compatible: model can load existing checkpoints
- New uncertainty embeddings initialize randomly (will be trained)
- No changes to base model weights or architecture
## Comparison to Original Llama 3.2 Example
### Similarities:
- Entropy-based uncertainty measurement
- Temporal feedback loop
- Embedding-based uncertainty representation
### Differences:
- **Quantization**: 16-bit (65,536 levels) vs. 8-bit (256 levels)
- **Resolution**: Fine-grained uncertainty vs. coarse-grained
- **Overhead**: 3.35% parameter overhead vs. ~0.04%
- **Applied to**: Vision-language model (Prisma-VL) vs. pure language model (Llama)
- **Integration**: Built into core architecture vs. wrapper class
- **Scope**: Uncertainty only for text generation (not vision encoding)
## Future Enhancements
Potential extensions:
1. **Multi-resolution Uncertainty**: Track uncertainty at token, word, and sentence levels
2. **Uncertainty-aware Generation**: Sample more conservatively when uncertain (e.g., lower the temperature)
3. **Visual Uncertainty**: Extend mechanism to vision encoder
4. **Cross-modal Uncertainty**: Track alignment confidence between vision and text
5. **Explicit Uncertainty Tokens**: Add special tokens to express uncertainty in output
## Citation
Inspired by temporal feedback loop patterns, enhanced with 16-bit high-resolution quantization for fine-grained uncertainty representation.
---
**Model**: Prisma-VL-8B
**Date**: 2025
**Architecture**: Integrated 16-bit temporal uncertainty feedback mechanism
**Parameter Overhead**: 268M (3.35%)
**Memory Overhead**: 2 bytes/token