Weight-only quantization (W4A16) that uses the Marlin/Machete kernels in vLLM (useful on Ampere and Hopper). On recent vLLM, throughput can be increased further by running it as W4A8 via `VLLM_MARLIN_INPUT_DTYPE=int8`.
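
As a quick way to try the checkpoint, here is a minimal offline-inference sketch with vLLM. It is not part of the original card: it assumes a recent vLLM build with Qwen3-VL support, and the `max_model_len` and sampling settings are illustrative only.

```python
import os

# Optional W4A8 path per the note above; set before the engine is created.
os.environ["VLLM_MARLIN_INPUT_DTYPE"] = "int8"

from vllm import LLM, SamplingParams

llm = LLM(model="shivak/Qwen3-VL-30B-A3B-Thinking-W4A16", max_model_len=8192)
outputs = llm.generate(
    ["Explain in one sentence what weight-only quantization is."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```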

## Benchmarks

Original BF16:

|        Tasks        |Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|---------------------|-------|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k                |      3|flexible-extract|     5|exact_match|↑  |0.100|±  |0.0190|
|                     |       |strict-match    |     5|exact_match|↑  |0.072|±  |0.0164|
|kormedmcqa           |      2|none            |      |exact_match|↑  |0.755|±  |0.0134|
| - kormedmcqa_dentist|      2|none            |     5|exact_match|↑  |0.704|±  |0.0289|
| - kormedmcqa_doctor |      2|none            |     5|exact_match|↑  |0.644|±  |0.0303|
| - kormedmcqa_nurse  |      2|none            |     5|exact_match|↑  |0.844|±  |0.0230|
| - kormedmcqa_pharm  |      2|none            |     5|exact_match|↑  |0.828|±  |0.0239|
|medmcqa              |Yaml   |none            |     5|acc        |↑  |0.676|±  |0.0297|
|                     |       |none            |     5|acc_norm   |↑  |0.676|±  |0.0297|

|  Groups  |Version|Filter|n-shot|  Metric   |   |Value|   |Stderr|
|----------|------:|------|------|-----------|---|----:|---|-----:|
|kormedmcqa|      2|none  |      |exact_match|↑  |0.755|±  |0.0134|

This W4A16:

|        Tasks        |Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|---------------------|-------|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k                |      3|flexible-extract|     5|exact_match|↑  |0.088|±  |0.0180|
|                     |       |strict-match    |     5|exact_match|↑  |0.056|±  |0.0146|
|kormedmcqa           |      2|none            |      |exact_match|↑  |0.733|±  |0.0137|
| - kormedmcqa_dentist|      2|none            |     5|exact_match|↑  |0.672|±  |0.0298|
| - kormedmcqa_doctor |      2|none            |     5|exact_match|↑  |0.620|±  |0.0308|
| - kormedmcqa_nurse  |      2|none            |     5|exact_match|↑  |0.828|±  |0.0239|
| - kormedmcqa_pharm  |      2|none            |     5|exact_match|↑  |0.812|±  |0.0248|
|medmcqa              |Yaml   |none            |     5|acc        |↑  |0.672|±  |0.0298|
|                     |       |none            |     5|acc_norm   |↑  |0.672|±  |0.0298|

|  Groups  |Version|Filter|n-shot|  Metric   |   |Value|   |Stderr|
|----------|------:|------|------|-----------|---|----:|---|-----:|
|kormedmcqa|      2|none  |      |exact_match|↑  |0.733|±  |0.0137|
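
These tables are lm-evaluation-harness output. The exact evaluation command is not given in the card; the sketch below shows one way comparable numbers could be produced with the harness's Python API. The backend choice, `model_args`, and batch size are assumptions.

```python
import lm_eval

# Hedged sketch, not the exact command used for the tables above.
results = lm_eval.simple_evaluate(
    model="vllm",  # assumption; the "hf" backend is another option
    model_args="pretrained=shivak/Qwen3-VL-30B-A3B-Thinking-W4A16,max_model_len=4096",
    tasks=["gsm8k", "kormedmcqa", "medmcqa"],
    num_fewshot=5,  # matches the n-shot column above
    batch_size="auto",
)
for task, metrics in results["results"].items():
    print(task, metrics)
```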

## Reproduction

```python
import torch
import random
import base64
from io import BytesIO
from datasets import load_dataset, Dataset, VerificationMode
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
from qwen_vl_utils import process_vision_info

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation

# ==========================================
# 1. Configuration
# ==========================================
MODEL_ID = "Qwen/Qwen3-VL-30B-A3B-Thinking"
DATASET_TEXT_ID = "neuralmagic/calibration"
DATASET_IMG_ID = "lmms-lab/flickr30k"

# Total samples for calibration
TOTAL_SAMPLES = 512
# 80% Text, 20% Image
NUM_TEXT_SAMPLES = int(TOTAL_SAMPLES * 0.8)
NUM_IMG_SAMPLES = TOTAL_SAMPLES - NUM_TEXT_SAMPLES
MAX_SEQUENCE_LENGTH = 1024

# ==========================================
# 2. Model Loading
# ==========================================
print(f"Loading {MODEL_ID}...")
model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
    MODEL_ID, 
    dtype="auto", 
    device_map="auto", 
    trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# ==========================================
# 3. Dataset Processing (Hybrid 80/20)
# ==========================================

def process_text_sample(example):
    """Preprocesses a text-only sample from neuralmagic/calibration."""
    messages = []
    for message in example["messages"]:
        messages.append(
            {
                "role": message["role"],
                "content": [{"type": "text", "text": message["content"]}],
            }
        )

    text_inputs = processor.apply_chat_template(
        messages,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        tokenize=True,
        add_special_tokens=False,
        return_dict=True,
        add_generation_prompt=False,
    )
    
    return {
        "input_ids": text_inputs.input_ids[0],
        "attention_mask": text_inputs.attention_mask[0],
        # Explicitly set visual keys to None for text samples
        "pixel_values": None,
        "image_grid_thw": None,
        "video_grid_thw": None
    }

def process_image_sample(example):
    """Preprocesses an image sample from flickr30k."""
    # Convert PIL image to base64 for the chat template
    buffered = BytesIO()
    example["image"].save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue())
    encoded_image_text = encoded_image.decode("utf-8")
    base64_qwen = f"data:image;base64,{encoded_image_text}"

    # Create a generic prompt for the image
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": base64_qwen},
                {"type": "text", "text": "Describe this image in detail."},
            ],
        }
    ]
    
    # Process text formatting
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    
    # Process visual info (extracts pixel values and grid info)
    image_inputs, video_inputs = process_vision_info(messages)

    # Tokenize and create tensors
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
    )

    return {
        "input_ids": torch.tensor(inputs["input_ids"][0]),
        "attention_mask": torch.tensor(inputs["attention_mask"][0]),
        # Qwen3 VL visual outputs
        "pixel_values": torch.tensor(inputs["pixel_values"]), 
        "image_grid_thw": torch.tensor(inputs["image_grid_thw"]),
        "video_grid_thw": None # Assuming no video in flickr30k
    }

print(f"Preparing datasets: {NUM_TEXT_SAMPLES} Text / {NUM_IMG_SAMPLES} Images...")

# Load and process Text Dataset
ds_text = load_dataset(DATASET_TEXT_ID, name="LLM", split=f"train[:{NUM_TEXT_SAMPLES}]")
processed_data = []

for sample in ds_text:
    processed_data.append(process_text_sample(sample))

# Load and process Image Dataset
ds_img = load_dataset(DATASET_IMG_ID, data_files="data/test-00001-of-00009.parquet", split=f"train[:{NUM_IMG_SAMPLES}]", verification_mode=VerificationMode.NO_CHECKS)
for sample in ds_img:
    processed_data.append(process_image_sample(sample))

random.shuffle(processed_data)
combined_ds = Dataset.from_list(processed_data)

# ==========================================
# 4. Data Collator
# ==========================================
def hybrid_data_collator(batch):
    """
    Handles batches that might be text-only OR multimodal.
    Removes None values (which represent missing visual data in text samples).
    """
    assert len(batch) == 1, "Batch size must be 1 for oneshot calibration"
    sample = batch[0]
    
    batch_out = {}
    
    # Handle Input IDs and Mask (Always present)
    batch_out["input_ids"] = torch.tensor(sample["input_ids"]).unsqueeze(0)
    batch_out["attention_mask"] = torch.tensor(sample["attention_mask"]).unsqueeze(0)
    
    # Handle Visuals (Only present if not None)
    if sample.get("pixel_values") is not None:
        # Convert list back to tensor if dataset conversion made them lists
        batch_out["pixel_values"] = torch.tensor(sample["pixel_values"])
        
        # Qwen3 usually requires bfloat16 for pixel values
        batch_out["pixel_values"] = batch_out["pixel_values"].to(dtype=torch.bfloat16)

    if sample.get("image_grid_thw") is not None:
        batch_out["image_grid_thw"] = torch.tensor(sample["image_grid_thw"])
        
    if sample.get("video_grid_thw") is not None:
        batch_out["video_grid_thw"] = torch.tensor(sample["video_grid_thw"])

    return batch_out

# ==========================================
# 5. Quantization 
# ==========================================
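# Note (added commentary, not from the original script): "W4A16" is a
# compressed-tensors preset meaning 4-bit weights with 16-bit activations,
# quantized group-wise (typically group size 128). The ignore list keeps the
# embeddings, lm_head, vision tower, and the MoE router gates unquantized.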
recipe = QuantizationModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=[
        "re:.*embed_tokens", # Don't mess with token embedding space
#        "re:.*self_attn.*",
        "re:.*lm_head",
        "re:visual.*",
        "re:model.visual.*",
        "re:.*mlp.gate$", # Gate is crucial for MoE
    ],
)

print("Starting One-Shot Calibration...")
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16"
oneshot(
    model=model,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=TOTAL_SAMPLES,
    dataset=combined_ds,
    data_collator=hybrid_data_collator,
    moe_calibrate_all_experts=True,
    output_dir=SAVE_DIR
)
```
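
The script relies on `oneshot(output_dir=...)` to write the compressed checkpoint. Depending on the llmcompressor version, the processor/tokenizer files may not be exported automatically; the one-liner below (an addition, not part of the original script) makes the output folder directly loadable for inference.

```python
# Assumption: run right after oneshot() so `processor` and SAVE_DIR are in scope.
processor.save_pretrained(SAVE_DIR)
```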