# FP8 Model with Per-Tensor Precision Recovery

- Source: https://huggingface.co/lllyasviel/control_v11p_sd15_seg
- Original File: diffusion_pytorch_model.fp16.safetensors
- FP8 Format: E5M2
- FP8 File: diffusion_pytorch_model.fp16-fp8-e5m2.safetensors
- Recovery File: diffusion_pytorch_model.fp16-recovery.safetensors
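E5M2 stores each value in a single byte: 1 sign bit, 5 exponent bits (bias 15), and 2 mantissa bits, keeping roughly FP16's dynamic range while sacrificing mantissa precision. As a standalone illustration (not part of this repo's tooling), one such byte can be decoded by hand:

```python
def decode_e5m2(byte: int) -> float:
    """Decode one FP8 E5M2 byte (1 sign, 5 exponent, 2 mantissa bits)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 2) & 0x1F   # 5-bit exponent, bias 15
    mant = byte & 0x03         # 2-bit mantissa
    if exp == 0x1F:            # all-ones exponent: inf or NaN
        return sign * float("inf") if mant == 0 else float("nan")
    if exp == 0:               # subnormal: no implicit leading 1
        return sign * (mant / 4) * 2.0 ** -14
    return sign * (1 + mant / 4) * 2.0 ** (exp - 15)

print(decode_e5m2(0x3C))  # 1.0  (exponent 15, mantissa 0)
print(decode_e5m2(0x42))  # 3.0  (exponent 16, mantissa 0b10 -> 1.5 * 2)
```

With only two mantissa bits, adjacent representable values near 1.0 are 0.25 apart, which is why the recovery tensors below are needed to restore precision.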
## Recovery Rules Used
```json
[
  { "key_pattern": "vae", "dim": 4, "method": "diff" },
  { "key_pattern": "encoder", "dim": 4, "method": "diff" },
  { "key_pattern": "decoder", "dim": 4, "method": "diff" },
  { "key_pattern": "text", "dim": 2, "min_size": 10000, "method": "lora", "rank": 64 },
  { "key_pattern": "emb", "dim": 2, "min_size": 10000, "method": "lora", "rank": 64 },
  { "key_pattern": "attn", "dim": 2, "min_size": 10000, "method": "lora", "rank": 128 },
  { "key_pattern": "conv", "dim": 4, "method": "diff" },
  { "key_pattern": "resnet", "dim": 4, "method": "diff" },
  { "key_pattern": "all", "method": "none" }
]
```
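The rules read as an ordered match list: the first entry whose `key_pattern` occurs in the tensor name, and whose `dim`/`min_size` constraints fit the tensor's shape, selects the recovery method, with the `"all"` entry as a catch-all. The exact matching semantics are an assumption here, but a minimal sketch of that interpretation looks like this (`pick_rule` is a hypothetical helper, not part of this repo):

```python
RULES = [
    {"key_pattern": "vae", "dim": 4, "method": "diff"},
    {"key_pattern": "attn", "dim": 2, "min_size": 10000, "method": "lora", "rank": 128},
    {"key_pattern": "all", "method": "none"},
]

def pick_rule(key: str, shape: tuple) -> dict:
    """Return the first rule matching the tensor's name and shape constraints."""
    numel = 1
    for d in shape:
        numel *= d
    for rule in RULES:
        if rule["key_pattern"] != "all" and rule["key_pattern"] not in key:
            continue
        if "dim" in rule and len(shape) != rule["dim"]:
            continue
        if "min_size" in rule and numel < rule["min_size"]:
            continue
        return rule
    return {"method": "none"}

# A large 2-D attention projection qualifies for LoRA recovery:
print(pick_rule("mid_block.attn.to_q.weight", (320, 320))["method"])  # lora
# A 1-D bias fails the dim constraint and falls through to the catch-all:
print(pick_rule("mid_block.attn.to_q.bias", (320,))["method"])        # none
```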
## Usage (Inference)
```python
import os

import torch
from safetensors.torch import load_file

# Load FP8 model
fp8_state = load_file("diffusion_pytorch_model.fp16-fp8-e5m2.safetensors")

# Load recovery weights if available
recovery_path = "diffusion_pytorch_model.fp16-recovery.safetensors"
recovery_state = load_file(recovery_path) if os.path.exists(recovery_path) else {}

# Reconstruct high-precision weights
reconstructed = {}
for key in fp8_state:
    fp8_weight = fp8_state[key].to(torch.float32)  # convert to float32 for computation

    # Apply LoRA recovery if available
    lora_a_key = f"lora_A.{key}"
    lora_b_key = f"lora_B.{key}"
    if lora_a_key in recovery_state and lora_b_key in recovery_state:
        A = recovery_state[lora_a_key].to(torch.float32)
        B = recovery_state[lora_b_key].to(torch.float32)
        # Add the low-rank approximation of the quantization error
        fp8_weight = fp8_weight + B @ A

    # Apply difference recovery if available
    diff_key = f"diff.{key}"
    if diff_key in recovery_state:
        fp8_weight = fp8_weight + recovery_state[diff_key].to(torch.float32)

    reconstructed[key] = fp8_weight

# Use reconstructed weights in your model
model.load_state_dict(reconstructed)
```
Note: For best results, use the same recovery configuration during inference as was used during extraction. Requires PyTorch ≥ 2.1 for FP8 support.
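On the extraction side, `lora_A`/`lora_B` pairs like those loaded above are typically obtained from a truncated SVD of the quantization error (original weight minus the FP8 round trip). This is a minimal sketch of that idea using NumPy, not this repo's actual extraction script; `extract_lora` is a hypothetical helper:

```python
import numpy as np

def extract_lora(error: np.ndarray, rank: int):
    """Factor a 2-D quantization-error matrix into B @ A of the given rank."""
    U, S, Vt = np.linalg.svd(error, full_matrices=False)
    B = U[:, :rank] * S[:rank]      # shape (out_features, rank)
    A = Vt[:rank, :]                # shape (rank, in_features)
    return A, B

rng = np.random.default_rng(0)
error = rng.standard_normal((64, 64)).astype(np.float32)
A, B = extract_lora(error, rank=16)
# B @ A is the best rank-16 approximation of `error` in the Frobenius norm,
# so adding it back to the FP8 weight recovers part of the lost precision.
print(A.shape, B.shape)  # (16, 64) (64, 16)
```

Storing the factors instead of the full difference is what makes `method: "lora"` cheaper than `method: "diff"` for large 2-D layers: rank 128 on a 320×320 matrix needs 2·128·320 values instead of 320·320.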
## Statistics
- Total layers: 340
- Layers with recovery: 96
  - LoRA recovery: 68
  - Difference recovery: 28