ModernBERT-NLI with Learned Abstention
Lightweight NLI classification heads for ModernBERT-large that preserve base encoder compatibility. Only 2.3MB of weights - the base model is pulled from HuggingFace automatically.
Key Features
- Four modes from one model: bi-encoder embeddings, late interaction (ColBERT-style), NLI classification, and abstention
- Learned abstention: Model knows when it doesn't know - catches 77% of its own errors (76.6% recall)
- Minimal overhead: Task heads are only 594K parameters (0.15% of base model)
- Preserves embeddings: Encoder frozen during training, so embeddings are fully compatible with base ModernBERT
Quick Start
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer
import torch.nn as nn
# Download task heads (2.3MB)
weights_path = hf_hub_download("YOUR_USERNAME/modernbert-nli-heads", "task_heads.pt")
# Load base model from HuggingFace
encoder = AutoModel.from_pretrained("answerdotai/ModernBERT-large")
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-large")
# Build task heads (layer sizes must match the saved weights)
nli_hidden = nn.Sequential(
    nn.Linear(1024, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.1)
)
nli_output = nn.Linear(512, 3)
abstention_head = nn.Sequential(
    nn.Linear(515, 128), nn.LayerNorm(128), nn.GELU(), nn.Dropout(0.1), nn.Linear(128, 2)
)
# Load weights: one flat state dict, keys prefixed with the head name
task_heads = torch.load(weights_path, map_location="cpu")
nli_hidden.load_state_dict({k.replace("nli_hidden.", ""): v for k, v in task_heads.items() if k.startswith("nli_hidden.")})
nli_output.load_state_dict({k.replace("nli_output.", ""): v for k, v in task_heads.items() if k.startswith("nli_output.")})
abstention_head.load_state_dict({k.replace("abstention_head.", ""): v for k, v in task_heads.items() if k.startswith("abstention_head.")})
# Switch to eval mode so Dropout is disabled at inference time
for head in (nli_hidden, nli_output, abstention_head):
    head.eval()
Or use the provided load_model.py for a cleaner interface:
from load_model import load_modernbert_nli, predict_with_abstention
model, tokenizer = load_modernbert_nli("task_heads.pt")
result = predict_with_abstention(
    model, tokenizer,
    premise="A man is playing guitar on stage.",
    hypothesis="A person is making music."
)
# {'label': 'entailment', 'confidence': 0.788, 'abstain': False, 'uncertainty': 0.32}
Model Modes
# 1. Bi-encoder embeddings (semantic search)
embeddings = model(input_ids, attention_mask, mode="embed") # (batch, 1024)
# 2. Late interaction (ColBERT-style reranking)
token_reps = model(input_ids, attention_mask, mode="late_interaction") # (batch, seq_len, 1024)
# 3. NLI classification
logits = model(input_ids, attention_mask, mode="nli") # (batch, 3)
# Labels: 0=entailment, 1=neutral, 2=contradiction
# 4. NLI with abstention
nli_logits, abstention_logits = model(input_ids, attention_mask, mode="abstention")
should_abstain = abstention_logits.argmax(dim=-1) == 1
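The raw outputs of the abstention mode still need to be turned into labels and a decision. The following is a minimal sketch of that step, not the logic in `load_model.py`: the function name `decode_abstention_outputs` is hypothetical, the dict keys simply mirror the example result shown earlier, and defining `uncertainty` as the softmax probability of the "uncertain" class is an assumption. Hand-written logits stand in for real model outputs.

```python
import torch

# Label order from the comment above: 0=entailment, 1=neutral, 2=contradiction
NLI_LABELS = ["entailment", "neutral", "contradiction"]

def decode_abstention_outputs(nli_logits, abstention_logits):
    """Turn raw logits from mode="abstention" into one dict per example.

    Hypothetical helper; "uncertainty" is assumed to be the softmax
    probability of the abstention head's class 1 ("uncertain").
    """
    nli_probs = torch.softmax(nli_logits, dim=-1)
    abstain_probs = torch.softmax(abstention_logits, dim=-1)
    confidence, label_ids = nli_probs.max(dim=-1)
    results = []
    for i in range(nli_logits.shape[0]):
        results.append({
            "label": NLI_LABELS[label_ids[i].item()],
            "confidence": confidence[i].item(),
            "uncertainty": abstain_probs[i, 1].item(),
            # Same decision as abstention_logits.argmax(dim=-1) == 1
            "abstain": bool(abstention_logits[i].argmax().item() == 1),
        })
    return results

# Stand-in logits for a batch of two premise/hypothesis pairs
nli_logits = torch.tensor([[4.0, 0.5, -1.0], [0.2, 0.3, 0.1]])
abstention_logits = torch.tensor([[2.0, -1.0], [-0.5, 1.5]])
decoded = decode_abstention_outputs(nli_logits, abstention_logits)
for row in decoded:
    print(row)
```

The first example decodes to a confident entailment, the second to a neutral prediction that the abstention head flags.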
Training Details
NLI Head
- Data: SNLI + MultiNLI combined (942K training examples)
- Method: Frozen encoder, only train classification head
- Epochs: 5
- Training Accuracy: 70.8%
- Parameters: 527K
Abstention Head
- Data: Difficulty labels generated from NLI model's own errors
- Method: Frozen encoder + frozen NLI head, only train abstention head
- Epochs: 3
- Validation Accuracy: 65.5%
- Recall on hard examples: 76.6% (catches 3/4 of errors)
- Parameters: 67K
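One plausible reading of "difficulty labels generated from the NLI model's own errors" is that an example is marked hard whenever the frozen NLI head predicts the wrong class. The sketch below illustrates that reading only; the function name `make_difficulty_labels` and the toy predictions are hypothetical, and the actual labeling script may differ.

```python
def make_difficulty_labels(predicted, gold):
    """Return 1 ("hard") where the NLI prediction was wrong, else 0 ("easy").

    Assumption: difficulty is defined purely by prediction error.
    """
    if len(predicted) != len(gold):
        raise ValueError("predicted and gold must have the same length")
    return [int(p != g) for p, g in zip(predicted, gold)]

# Toy class ids (0=entailment, 1=neutral, 2=contradiction)
predicted = [0, 1, 2, 0, 1]
gold = [0, 2, 2, 1, 1]

difficulty = make_difficulty_labels(predicted, gold)
hard_fraction = sum(difficulty) / len(difficulty)
print(difficulty, hard_fraction)
```

Under this scheme the share of hard examples equals the NLI head's error rate on the labeling set, so the abstention head trains on an imbalanced binary task.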
Performance
NLI Classification
| Metric | Value |
|---|---|
| Training Accuracy | 70.8% |
| Validation Accuracy | ~75-80% |
| Parameters | 527K |
Note: Frozen encoder limits ceiling vs full fine-tuning (~90%), but preserves embedding compatibility.
Abstention Head
| Metric | Value |
|---|---|
| Accuracy | 65.5% |
| Precision | 44.6% |
| Recall | 76.6% |
| F1 | 56.3% |
What this means in practice:
- When the model says "I'm uncertain", it's catching a real error 45% of the time
- Of all errors the model makes, it flags 77% of them for abstention
- Accuracy on confident predictions improves from ~75% to ~85%
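These figures can be cross-checked with a little arithmetic. The sketch below assumes that the positive class of the abstention head is "the NLI prediction is wrong" and that accuracy, precision, and recall were all measured on the same evaluation set; neither assumption is stated explicitly in this card. Under those assumptions the three reported metrics determine the implied NLI error rate, the share of inputs that get flagged, and the accuracy on the predictions that remain.

```python
# Reported abstention-head metrics
precision = 0.446
recall = 0.766
accuracy = 0.655

f1 = 2 * precision * recall / (precision + recall)

# Let e be the NLI error rate (prevalence of the positive class). Then:
#   TP = recall * e
#   FN = (1 - recall) * e
#   FP = TP * (1 - precision) / precision
#   accuracy = 1 - FP - FN
fp_per_error = recall * (1 - precision) / precision
fn_per_error = 1 - recall
error_rate = (1 - accuracy) / (fp_per_error + fn_per_error)

flagged = recall * error_rate / precision       # fraction of inputs abstained on
missed_errors = fn_per_error * error_rate       # errors that slip through
selective_accuracy = 1 - missed_errors / (1 - flagged)

print(f"F1                 = {f1:.3f}")
print(f"implied error rate = {error_rate:.3f}")
print(f"fraction flagged   = {flagged:.3f}")
print(f"selective accuracy = {selective_accuracy:.3f}")
```

With these inputs the implied error rate is about 29%, roughly half of all inputs are flagged, and accuracy on the unflagged half comes out near 86%, which is in line with the ~85% quoted above. The gain in accuracy is paid for with coverage.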
Abstention vs Simple Confidence Threshold
The abstention head outperforms simple confidence thresholding because it uses semantic features from the hidden state, not just logit entropy. In testing, it caught 5 errors that a 50% confidence threshold would have missed.
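For comparison, here is a minimal sketch of the kind of baseline this paragraph refers to: abstain whenever the top softmax probability falls below a fixed threshold. The helper names and example logits are illustrative, not taken from the evaluation code. The point is that this baseline sees only the three logits, so a confidently wrong prediction can never be flagged.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [x / total for x in exps]

def entropy(probs):
    """Shannon entropy in nats; higher means a flatter distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def threshold_abstain(logits, min_confidence=0.5):
    """Baseline rule: abstain when the top class probability is below the threshold."""
    return max(softmax(logits)) < min_confidence

peaked_logits = [3.0, 0.0, -1.0]   # top probability is about 0.94
flat_logits = [0.4, 0.3, 0.2]      # top probability is about 0.37

print(threshold_abstain(peaked_logits), entropy(softmax(peaked_logits)))
print(threshold_abstain(flat_logits), entropy(softmax(flat_logits)))
```

If the peaked prediction happens to be wrong, the threshold rule lets it through. The learned head receives the 512-dim hidden features alongside the logits, which gives it a chance to flag such cases.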
Intended Uses
Query Routing
categories = {
    "code": "This is a programming-related request",
    "factual": "This is a request for factual information",
    "creative": "This is a request for creative content",
}
def route_query(query):
    results = []
    for name, hypothesis in categories.items():
        result = predict_with_abstention(model, tokenizer, query, hypothesis)
        results.append((name, result))
    # Pick highest entailment score, respecting abstention
    confident_results = [(n, r) for n, r in results if not r["abstain"]]
    if confident_results:
        # Assumes each result carries a "probs" mapping of label -> probability.
        # Returns a (category_name, result) tuple.
        return max(confident_results, key=lambda x: x[1]["probs"]["entailment"])
    else:
        return None, "uncertain"  # All categories abstained
Fact Validation
def validate_fact(source: str, claim: str) -> dict:
    result = predict_with_abstention(model, tokenizer, source, claim)
    return {
        "supported": result["label"] == "entailment",
        "contradicted": result["label"] == "contradiction",
        "uncertain": result["abstain"],
        "confidence": result["confidence"]
    }
Limitations
- Accuracy ceiling: Frozen encoder means ~75-80% accuracy vs ~90% for full fine-tuning
- Domain coverage: Trained on SNLI (image captions) + MultiNLI (mixed), may struggle with specialized domains
- Abstention precision: 45% precision means many unnecessary abstentions - tune threshold for your use case
- Systematic errors: Abstention can miss errors where the NLI model is confidently wrong (e.g., quantifier reasoning)
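One way to act on the threshold advice above is to replace the `argmax` decision with a cutoff on the abstention head's "uncertain" probability and sweep that cutoff on held-out data. The sketch below uses made-up validation numbers; `precision_recall`, `uncertain_probs`, and `is_error` are hypothetical names, not part of `load_model.py`.

```python
def precision_recall(uncertain_probs, is_error, threshold):
    """Precision and recall of "flag if P(uncertain) >= threshold" at catching errors."""
    flagged = [p >= threshold for p in uncertain_probs]
    tp = sum(f and e for f, e in zip(flagged, is_error))
    fp = sum(f and not e for f, e in zip(flagged, is_error))
    fn = sum((not f) and e for f, e in zip(flagged, is_error))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

# Hypothetical validation set: P(uncertain) per example, and whether
# the NLI prediction for that example was actually wrong
uncertain_probs = [0.9, 0.8, 0.7, 0.6, 0.55, 0.4, 0.3, 0.2]
is_error = [True, True, False, True, False, False, True, False]

for threshold in (0.5, 0.65, 0.85):
    p, r = precision_recall(uncertain_probs, is_error, threshold)
    print(f"threshold={threshold:.2f}  precision={p:.2f}  recall={r:.2f}")
```

Raising the cutoff trades recall for precision: fewer unnecessary abstentions, but more errors pass through unflagged. Which point is right depends on the relative cost of a wrong answer versus a deferred one.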
Architecture
ModernBERT-large (394.8M params, frozen)
        ↓
[CLS] token (1024 dim)
        ↓
┌─────────────────────────────────┐
│ NLI Hidden (525K params)        │
│ Linear(1024→512) + LN + GELU    │
└─────────────────────────────────┘
        │
        ├── NLI Output (1.5K params)
        │   Linear(512→3) → [ent, neu, con]
        │
        └── Abstention Head (67K params)
            Concat([hidden, logits]) → 515 dim
            Linear(515→128) + LN + GELU
            Linear(128→2) → [confident, uncertain]
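The data flow in the diagram can be sketched as a single PyTorch module. This is an illustration under assumptions rather than the class used in `load_model.py`: the name `TaskHeads` is hypothetical, and feeding the post-dropout hidden state and the raw (pre-softmax) NLI logits into the abstention head is inferred from the diagram. A random tensor stands in for the encoder's [CLS] embedding so the example runs without downloading the base model.

```python
import torch
import torch.nn as nn

class TaskHeads(nn.Module):
    """Hypothetical wrapper around the three heads shown in the diagram."""

    def __init__(self, hidden_size=1024, dropout=0.1):
        super().__init__()
        self.nli_hidden = nn.Sequential(
            nn.Linear(hidden_size, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(dropout)
        )
        self.nli_output = nn.Linear(512, 3)
        self.abstention_head = nn.Sequential(
            nn.Linear(512 + 3, 128), nn.LayerNorm(128), nn.GELU(),
            nn.Dropout(dropout), nn.Linear(128, 2),
        )

    def forward(self, cls_embedding):
        hidden = self.nli_hidden(cls_embedding)              # (batch, 512)
        nli_logits = self.nli_output(hidden)                 # (batch, 3)
        features = torch.cat([hidden, nli_logits], dim=-1)   # (batch, 515)
        abstention_logits = self.abstention_head(features)   # (batch, 2)
        return nli_logits, abstention_logits

heads = TaskHeads().eval()  # eval mode disables dropout
with torch.no_grad():
    # Random stand-in for a batch of four [CLS] embeddings
    nli_logits, abstention_logits = heads(torch.randn(4, 1024))

n_params = sum(p.numel() for p in heads.parameters())
print(nli_logits.shape, abstention_logits.shape, n_params)
```

Counting parameters this way gives 593,925, which matches the 594K total quoted for the task heads.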
Files
- task_heads.pt (2.3MB) - PyTorch state dict with all task head weights
- config.json - Model configuration and training metadata
- load_model.py - Standalone loader script (copy into your project)
Citation
@misc{modernbert-nli-abstention,
title={ModernBERT-NLI with Learned Abstention},
author={[Your Name]},
year={2024},
url={https://huggingface.co/YOUR_USERNAME/modernbert-nli-heads}
}
Acknowledgments
- Base model: ModernBERT-large by Answer.AI
- Training data: SNLI and MultiNLI