# Introspective Prisma-VL-8B Architecture

## Overview

Prisma-VL-8B includes an introspective feedback mechanism that gives the model fine-grained, self-monitored awareness of the uncertainty in its own predictions.

## Core Innovation

The model tracks its own prediction uncertainty and feeds that measurement back as a signal for subsequent predictions. This creates a temporal awareness loop:

```
Token t-1: "What's next?" → Prediction + Uncertainty measurement
Token t:   [Previous uncertainty signal] + "What's next?" → Better-calibrated prediction
```
## Architecture Changes

### 1. Uncertainty Embeddings (PrismaVLModel)

Added to `PrismaVLModel.__init__()`:

```python
# 65,536-level uncertainty embedding table
self.n_bits = 16                              # 16-bit quantization
self.n_uncertainty_levels = 2 ** self.n_bits  # 65,536

# Learned embeddings: one vector per uncertainty level
self.uncertainty_embeddings = nn.Embedding(self.n_uncertainty_levels, hidden_dim)

# Cache for uncertainty codes from the previous step
self.prev_uncertainty_code = None             # [batch_size, seq_len], values in [0, 65535]
```
**Parameter cost**: 65,536 × 4,096 = 268,435,456 parameters (3.35% overhead on the 8B base)
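A quick sanity check of that figure (the 8B denominator is taken from the model name):

```python
n_levels, hidden_dim = 65536, 4096
extra_params = n_levels * hidden_dim    # 268,435,456
print(100 * extra_params / 8e9)         # 3.3554... -> the ~3.35% quoted above
```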
### 2. Uncertainty Injection (PrismaVLModel.forward)

During the forward pass, after creating the input embeddings:

```python
import torch.nn.functional as F

# Look up uncertainty embeddings from the previous step
uncertainty_embeds = self.uncertainty_embeddings(prev_uncertainty_code)

# Shift right: position i receives the uncertainty measured at position i-1
uncertainty_shifted = F.pad(uncertainty_embeds[:, :-1, :], (0, 0, 1, 0))

# Inject into the input
inputs_embeds = inputs_embeds + uncertainty_shifted
```

Now the model sees: **[Token embedding] + [How uncertain was I last time?]**
### 3. Uncertainty Computation (PrismaVLForConditionalGeneration.forward)

After computing logits (during training and generation):

```python
import math

# Compute the entropy (uncertainty) of the predictive distribution;
# log_softmax is used instead of softmax + log for numerical stability
log_probs = logits.log_softmax(-1)
entropy = -(log_probs.exp() * log_probs).sum(-1)

# Normalize to [0, 1]; maximum entropy is log(vocab_size)
entropy_norm = entropy / math.log(vocab_size)

# Quantize to 16 bits (0-65535); .long() is non-differentiable,
# so no gradient flows back through the cached codes
uncertainty_code = (entropy_norm * 65535).long()

# Store for the next step
self.model.prev_uncertainty_code = uncertainty_code
```
## How It Works (Step by Step)

### Inference/Generation

1. **Token 0**: No previous uncertainty → use the neutral code (32768)
2. **Token 1**: Predict → measure confidence → encode as 0-65535
3. **Token 2**: Inject the uncertainty signal from Token 1 → predict (now calibrated)
4. **Token 3**: Inject the uncertainty signal from Token 2 → predict
5. ...and so on (see the sketch below)
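A sketch of this loop written out as explicit greedy decoding. It assumes the cache attribute from Section 1 and, for clarity, recomputes full-sequence logits each step (no KV cache); `generate_with_feedback` and the neutral-padding detail at the end are illustrative, not the actual `generate` internals:

```python
import math

import torch

NEUTRAL_CODE = 32768  # mid-range code: "no history yet"

@torch.no_grad()
def generate_with_feedback(model, input_ids, max_new_tokens):
    """Greedy decoding with the temporal uncertainty feedback loop (sketch)."""
    # Token 0 and the prompt positions start with the neutral code
    codes = torch.full_like(input_ids, NEUTRAL_CODE)
    for _ in range(max_new_tokens):
        # Section 2: forward() injects the embeddings for these cached codes
        model.model.prev_uncertainty_code = codes
        logits = model(input_ids=input_ids).logits        # [batch, seq, vocab]
        next_token = logits[:, -1, :].argmax(-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        # Section 3: measure per-position entropy, re-encode as 0-65535
        log_probs = logits.log_softmax(-1)
        entropy = -(log_probs.exp() * log_probs).sum(-1)  # [batch, seq]
        codes = (entropy / math.log(logits.size(-1)) * 65535).long()
        # The freshly appended position has no measurement yet: pad with neutral
        codes = torch.cat([codes, torch.full_like(next_token, NEUTRAL_CODE)], dim=-1)
    return input_ids
```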
### Training

The model learns the uncertainty embeddings through backpropagation:

- Embeddings #0-16383: "I was very confident" → the model learns to stay confident
- Embeddings #16384-32767: "I had medium confidence" → the model learns moderate caution
- Embeddings #32768-49151: "I was uncertain" → the model learns to hedge
- Embeddings #49152-65535: "I was very uncertain" → the model learns to be conservative
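For logging or analysis it helps to map raw codes back to these four bands; a small hypothetical helper (the function name and exact labels are ours):

```python
def confidence_band(code: int) -> str:
    """Map a 16-bit uncertainty code to its coarse confidence band (illustrative)."""
    bands = ["very confident", "medium confidence", "uncertain", "very uncertain"]
    return bands[min(code // 16384, 3)]

assert confidence_band(0) == "very confident"      # band #0-16383
assert confidence_band(40000) == "uncertain"       # band #32768-49151
```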
## Key Properties

### 1. Moderate Overhead

- **Parameters**: 268M additional (3.35% of the 8B base)
- **Memory**: 2 bytes per token (one 16-bit uncertainty code)
- **Compute**: negligible (one embedding lookup per token)

### 2. Temporal Awareness

- The model builds a "confidence history" across a generation
- It can detect when it is moving into unfamiliar territory
- It can recover calibration after uncertain predictions

### 3. Self-Calibration

- No external signals needed
- The model learns its own uncertainty language
- Improves through standard supervised training

### 4. Architecture-Agnostic

- Works with any transformer-based model
- Doesn't modify attention, FFN, or other core components
- Clean separation between the uncertainty mechanism and the base model
## Usage

### Standard Inference

```python
import torch

from modeling import PrismaVLForConditionalGeneration
from transformers import AutoProcessor

# Load the model (the introspective mechanism is built in)
model = PrismaVLForConditionalGeneration.from_pretrained(
    ".",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(".", trust_remote_code=True)

# Use normally; uncertainty tracking happens automatically
messages = [{"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": prompt}]}]
inputs = processor.apply_chat_template(messages, ...)
outputs = model.generate(**inputs)
```
### Training

```python
# Train normally; the uncertainty mechanism learns automatically
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

for batch in dataloader:
    optimizer.zero_grad()
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()

# The uncertainty embeddings learn to represent
# "how to adjust predictions based on previous confidence"
```
### Resetting Uncertainty (Between Sequences)

```python
# Reset the uncertainty cache between independent generations
model.model.reset_uncertainty()

# Generate
outputs = model.generate(...)
```
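The method body is not shown in this document; a minimal sketch of what it presumably does, given the cache attribute from Section 1 (the implementation is our assumption):

```python
def reset_uncertainty(self):
    """Clear cached uncertainty codes so the next sequence starts neutral (assumed body)."""
    # With no history, the first forward pass falls back to the neutral code (32768)
    self.prev_uncertainty_code = None
```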
## What Gets Learned

The 65,536 uncertainty embedding vectors learn to encode:

1. **Confidence Continuation**: "the last token was confident" → maintain confidence (if appropriate)
2. **Uncertainty Propagation**: "the last token was uncertain" → be more conservative
3. **Domain Shifts**: a run of low uncertainty followed by sudden high uncertainty → a domain boundary was crossed
4. **Recovery Patterns**: high uncertainty followed by a gradual return to confidence → the model finding its footing
## Benefits

1. **Better Calibration**: the model knows when it doesn't know
2. **Hallucination Awareness**: uncertain predictions are less likely to compound
3. **Adaptive Confidence**: the model can adjust based on its recent performance
4. **Interpretability**: uncertainty codes provide insight into the model's state
5. **Low Inference Cost**: one embedding lookup plus one entropy computation per generated token
## Implementation Details

### Files Modified

- `modeling.py`:
  - `PrismaVLModel.__init__()`: add the uncertainty embeddings
  - `PrismaVLModel.forward()`: inject the uncertainty signal
  - `PrismaVLForConditionalGeneration.forward()`: compute the uncertainty
  - Added the `reset_uncertainty()` method
### Initialization

- Uncertainty embeddings are initialized with `std = config.text_config.initializer_range` (typically 0.02)
- Start neutral: the first token uses code 32768 (the middle of the 16-bit range)
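A sketch of that initialization, assuming the standard `nn.init` API (its exact placement inside `__init__` is our assumption):

```python
import torch.nn as nn

# Match the base model's initializer scale (typically 0.02)
nn.init.normal_(
    self.uncertainty_embeddings.weight,
    mean=0.0,
    std=config.text_config.initializer_range,
)
```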
### Compatibility

- Fully backward compatible: the model can load existing checkpoints
- New uncertainty embeddings initialize randomly (and are then trained)
- No changes to the base model's weights or architecture
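One way to check this compatibility in practice, assuming a plain PyTorch state dict (`strict=False` tolerates the new table being absent from old checkpoints; the path is illustrative):

```python
import torch

# Load a pre-introspection checkpoint; the uncertainty embedding weights
# are missing from it, so they keep their random initialization
state_dict = torch.load("old_checkpoint.pt", map_location="cpu")
missing, unexpected = model.load_state_dict(state_dict, strict=False)
assert all("uncertainty_embeddings" in k for k in missing)
```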
## Comparison to the Original Llama 3.2 Example

### Similarities

- Entropy-based uncertainty measurement
- Temporal feedback loop
- Embedding-based uncertainty representation

### Differences

- **Quantization**: 16-bit (65,536 levels) vs. 8-bit (256 levels)
- **Resolution**: fine-grained vs. coarse-grained uncertainty
- **Overhead**: 3.35% parameter overhead vs. ~0.04%
- **Applied to**: a vision-language model (Prisma-VL) vs. a pure language model (Llama)
- **Integration**: built into the core architecture vs. a wrapper class
- **Scope**: uncertainty is tracked only for text generation (not vision encoding)
## Future Enhancements

Potential extensions:

1. **Multi-resolution Uncertainty**: track uncertainty at the token, word, and sentence levels
2. **Uncertainty-aware Generation**: sample less aggressively (lower temperature) when uncertain
3. **Visual Uncertainty**: extend the mechanism to the vision encoder
4. **Cross-modal Uncertainty**: track alignment confidence between vision and text
5. **Explicit Uncertainty Tokens**: add special tokens for expressing uncertainty in the output
## Citation

Inspired by temporal feedback loop patterns, enhanced with 16-bit high-resolution quantization for fine-grained uncertainty representation.

---

**Model**: Prisma-VL-8B
**Date**: 2025
**Architecture**: integrated 16-bit temporal uncertainty feedback mechanism
**Parameter Overhead**: 268M (3.35%)
**Memory Overhead**: 2 bytes/token