ModernBERT-NLI with Learned Abstention

Lightweight NLI classification heads for ModernBERT-large that preserve base encoder compatibility. Only 2.3 MB of weights are distributed here; the base model is downloaded from Hugging Face automatically.

Key Features

  • Four modes from one model: bi-encoder embeddings, late interaction (ColBERT-style), NLI classification, and abstention
  • Learned abstention: Model knows when it doesn't know - catches ~77% of its own errors (76.6% recall, see below)
  • Minimal overhead: Task heads are only 594K parameters (0.15% of base model)
  • Preserves embeddings: Encoder frozen during training, so embeddings are fully compatible with base ModernBERT

Quick Start

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer

# Download task heads (2.3MB)
weights_path = hf_hub_download("jmccardle/modernbert-nli-heads", "task_heads.pt")

# Load base model from HuggingFace
encoder = AutoModel.from_pretrained("answerdotai/ModernBERT-large")
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-large")

# Build task heads
nli_hidden = nn.Sequential(
    nn.Linear(1024, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.1)
)
nli_output = nn.Linear(512, 3)
abstention_head = nn.Sequential(
    nn.Linear(515, 128), nn.LayerNorm(128), nn.GELU(), nn.Dropout(0.1), nn.Linear(128, 2)
)

# Load weights
task_heads = torch.load(weights_path, map_location="cpu")
def extract(prefix):
    # Pull one head's weights out of the combined state dict
    return {k.removeprefix(prefix): v for k, v in task_heads.items() if k.startswith(prefix)}

nli_hidden.load_state_dict(extract("nli_hidden."))
nli_output.load_state_dict(extract("nli_output."))
abstention_head.load_state_dict(extract("abstention_head."))
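
With the heads loaded, inference wires them together as in the Architecture section below: pool the [CLS] token, run the NLI head, then feed the hidden features plus logits to the abstention head. A minimal sketch, assuming [CLS] pooling (check load_model.py for the authoritative wiring):

# Put everything in eval mode so dropout is disabled
for m in (encoder, nli_hidden, nli_output, abstention_head):
    m.eval()

inputs = tokenizer(
    "A man is playing guitar on stage.",  # premise
    "A person is making music.",          # hypothesis
    return_tensors="pt",
)
with torch.no_grad():
    cls = encoder(**inputs).last_hidden_state[:, 0]          # (1, 1024) [CLS]
    hidden = nli_hidden(cls)                                 # (1, 512)
    logits = nli_output(hidden)                              # (1, 3)
    abst = abstention_head(torch.cat([hidden, logits], -1))  # (1, 2)

label = ["entailment", "neutral", "contradiction"][logits.argmax(-1).item()]
abstain = abst.argmax(-1).item() == 1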

Or use the provided load_model.py for a cleaner interface:

from load_model import load_modernbert_nli, predict_with_abstention

model, tokenizer = load_modernbert_nli("task_heads.pt")

result = predict_with_abstention(
    model, tokenizer,
    premise="A man is playing guitar on stage.",
    hypothesis="A person is making music."
)
# {'label': 'entailment', 'confidence': 0.788, 'abstain': False, 'uncertainty': 0.32,
#  'probs': {'entailment': 0.788, 'neutral': ..., 'contradiction': ...}}

Model Modes

# 1. Bi-encoder embeddings (semantic search)
embeddings = model(input_ids, attention_mask, mode="embed")  # (batch, 1024)

# 2. Late interaction (ColBERT-style reranking)
token_reps = model(input_ids, attention_mask, mode="late_interaction")  # (batch, seq_len, 1024)

# 3. NLI classification
logits = model(input_ids, attention_mask, mode="nli")  # (batch, 3)
# Labels: 0=entailment, 1=neutral, 2=contradiction

# 4. NLI with abstention
nli_logits, abstention_logits = model(input_ids, attention_mask, mode="abstention")
should_abstain = abstention_logits.argmax(dim=-1) == 1
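
For late interaction, the usual ColBERT-style relevance score takes each query token's maximum cosine similarity over the document tokens and sums those maxima. A sketch of MaxSim scoring over the outputs above (the L2 normalization and the omission of padding masking are simplifying assumptions; q_ids/q_mask and d_ids/d_mask are tokenized query and document inputs):

import torch.nn.functional as F

query_reps = model(q_ids, q_mask, mode="late_interaction")  # (1, q_len, 1024)
doc_reps = model(d_ids, d_mask, mode="late_interaction")    # (1, d_len, 1024)

q = F.normalize(query_reps, dim=-1)          # normalize so dot product = cosine
d = F.normalize(doc_reps, dim=-1)
sim = q @ d.transpose(-1, -2)                # (1, q_len, d_len)
score = sim.max(dim=-1).values.sum(dim=-1)   # MaxSim: best doc match per query token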

Training Details

NLI Head

  • Data: SNLI + MultiNLI combined (942K training examples)
  • Method: Frozen encoder, only train classification head
  • Epochs: 5
  • Training Accuracy: 70.8%
  • Parameters: 527K

Abstention Head

  • Data: Difficulty labels generated from the NLI model's own errors (see the sketch after this list)
  • Method: Frozen encoder + frozen NLI head, only train abstention head
  • Epochs: 3
  • Validation Accuracy: 65.5%
  • Recall on hard examples: 76.6% (catches 3/4 of errors)
  • Parameters: 67K
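
The difficulty labels are simple to reproduce in spirit: run the frozen NLI head over the training set and mark each example by whether the prediction matched gold. A hedged sketch (nli_logits and gold_labels stand in for the precomputed train-set outputs; the exact pipeline isn't published here):

# 1 = "uncertain" (the NLI head got this example wrong), 0 = "confident"
preds = nli_logits.argmax(dim=-1)                 # nli_logits: (N, 3) over the train set
difficulty_labels = (preds != gold_labels).long() # gold_labels: (N,) NLI targets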

Performance

NLI Classification

  • Training Accuracy: 70.8%
  • Validation Accuracy: ~75-80%
  • Parameters: 527K

Note: Frozen encoder limits ceiling vs full fine-tuning (~90%), but preserves embedding compatibility.

Abstention Head

  • Accuracy: 65.5%
  • Precision: 44.6%
  • Recall: 76.6%
  • F1: 56.3%

What this means in practice:

  • When the model says "I'm uncertain", it's catching a real error 45% of the time
  • Of all errors the model makes, it flags 77% of them for abstention
  • Accuracy on confident predictions improves from ~75% to ~85%

Abstention vs Simple Confidence Threshold

The abstention head outperforms simple confidence thresholding because it uses semantic features from the hidden state, not just logit entropy. In testing, it caught 5 errors that a 50% confidence threshold would have missed.
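
The two approaches are easy to compare side by side: a threshold sees only the softmax over the 3 NLI logits, while the learned head also conditions on the 512-dim hidden features. A sketch reusing the tensors from the Quick Start example (the 0.5 cutoff is illustrative):

# Baseline: abstain when the max softmax probability falls below a cutoff
probs = logits.softmax(dim=-1)
threshold_abstain = probs.max(dim=-1).values < 0.5

# Learned head: also sees the hidden features, so it can flag examples that
# look confident in probability space but sit in semantically hard regions
learned_abstain = abstention_head(torch.cat([hidden, logits], -1)).argmax(-1) == 1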

Intended Uses

Query Routing

categories = {
    "code": "This is a programming-related request",
    "factual": "This is a request for factual information",
    "creative": "This is a request for creative content",
}

def route_query(query):
    results = []
    for name, hypothesis in categories.items():
        result = predict_with_abstention(model, tokenizer, query, hypothesis)
        results.append((name, result))

    # Pick highest entailment score, respecting abstention
    confident_results = [(n, r) for n, r in results if not r["abstain"]]
    if confident_results:
        return max(confident_results, key=lambda x: x[1]["probs"]["entailment"])
    else:
        return None, "uncertain"  # All categories abstained
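
An illustrative call (the chosen category and scores depend on your inputs):

category, result = route_query("How do I merge two dictionaries in Python?")
if category is not None:
    print(category, result["label"], result["confidence"])  # e.g. "code", "entailment", ...
else:
    print("All categories abstained")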

Fact Validation

def validate_fact(source: str, claim: str) -> dict:
    result = predict_with_abstention(model, tokenizer, source, claim)
    return {
        "supported": result["label"] == "entailment",
        "contradicted": result["label"] == "contradiction",
        "uncertain": result["abstain"],
        "confidence": result["confidence"]
    }
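
For example (return values illustrative):

check = validate_fact(
    source="The meeting was moved to Thursday afternoon.",
    claim="The meeting is on Thursday.",
)
# -> {'supported': True, 'contradicted': False, 'uncertain': False, 'confidence': ...}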

Limitations

  1. Accuracy ceiling: Frozen encoder means ~75-80% accuracy vs ~90% for full fine-tuning
  2. Domain coverage: Trained on SNLI (image captions) + MultiNLI (mixed), may struggle with specialized domains
  3. Abstention precision: 45% precision means many unnecessary abstentions - tune the threshold for your use case (see the sketch after this list)
  4. Systematic errors: Abstention can miss errors where the NLI model is confidently wrong (e.g., quantifier reasoning)
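
For limitation 3, the operating point can be shifted by thresholding the abstention head's softmax instead of taking its argmax. A sketch (the 0.7 cutoff is illustrative; tune it on held-out data):

p_uncertain = abstention_logits.softmax(dim=-1)[:, 1]
abstain = p_uncertain > 0.7  # raise to abstain less often, lower to catch more errors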

Architecture

ModernBERT-large (394.8M params, frozen)
    ↓
[CLS] token (1024 dim)
    ↓
┌─────────────────────────────────┐
│ NLI Hidden (525K params)        │
│ Linear(1024→512) + LN + GELU    │
└─────────────────────────────────┘
    ↓
    ├── NLI Output (1.5K params)
    │   Linear(512→3) → [ent, neu, con]
    │
    └── Abstention Head (67K params)
        Concat([hidden, logits]) → 515 dim
        Linear(515→128) + LN + GELU
        Linear(128→2) → [confident, uncertain]
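
The diagram maps directly onto a small wrapper module. A hedged sketch of the four modes (the class name and the [CLS] pooling for mode="embed" are assumptions; the card does not pin down the embedding pooling):

import torch
import torch.nn as nn

class ModernBertNLI(nn.Module):
    def __init__(self, encoder, nli_hidden, nli_output, abstention_head):
        super().__init__()
        self.encoder = encoder
        self.nli_hidden = nli_hidden
        self.nli_output = nli_output
        self.abstention_head = abstention_head

    def forward(self, input_ids, attention_mask, mode="nli"):
        tokens = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state                  # (batch, seq_len, 1024)
        if mode == "late_interaction":
            return tokens                    # per-token reps for MaxSim
        cls = tokens[:, 0]                   # (batch, 1024), [CLS] pooling
        if mode == "embed":
            return cls
        hidden = self.nli_hidden(cls)        # (batch, 512)
        logits = self.nli_output(hidden)     # (batch, 3)
        if mode == "nli":
            return logits
        # mode == "abstention": the head sees hidden features + NLI logits
        return logits, self.abstention_head(torch.cat([hidden, logits], dim=-1))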

Files

  • task_heads.pt (2.3MB) - PyTorch state dict with all task head weights
  • config.json - Model configuration and training metadata
  • load_model.py - Standalone loader script (copy into your project)

Citation

@misc{modernbert-nli-abstention,
  title={ModernBERT-NLI with Learned Abstention},
  author={[Your Name]},
  year={2024},
  url={https://huggingface.co/jmccardle/modernbert-nli-heads}
}
