# FocalCodec-Demo / app.py
import torch
import torchaudio
import gradio as gr
import os
import tempfile
import numpy as np
import struct
# Define the model ID for the 0.16 kbps codec config
MODEL_CONFIG = "lucadellalib/focalcodec_12_5hz"
# Load the model globally using torch.hub
codec = None
try:
print("Loading FocalCodec model...")
codec = torch.hub.load(
repo_or_dir="lucadellalib/focalcodec",
model="focalcodec",
config=MODEL_CONFIG,
force_reload=False,
trust_repo=True
)
codec.eval()
for param in codec.parameters():
param.requires_grad = False
if torch.cuda.is_available():
codec = codec.cuda()
print("Model loaded successfully on GPU!")
else:
print("Model loaded successfully on CPU!")
except Exception as e:
print(f"ERROR loading model via torch.hub: {e}")
print("\nTrying alternative installation method...")
try:
import subprocess
subprocess.check_call(["pip", "install", "focalcodec@git+https://github.com/lucadellalib/focalcodec.git@main"])
import focalcodec
codec = focalcodec.FocalCodec.from_pretrained(MODEL_CONFIG)
codec.eval()
for param in codec.parameters():
param.requires_grad = False
if torch.cuda.is_available():
codec = codec.cuda()
print("Model loaded via pip installation!")
except Exception as e2:
print(f"ERROR with alternative method: {e2}")
codec = None
def save_compressed_codes_optimal(toks, codes, fc_file_path, codec):
"""Save codes with optimal bit packing to achieve true 160 bps"""
codes_cpu = codes.cpu().numpy()
toks_cpu = toks.cpu().numpy()
print(f"\n=== Optimal Compression ===")
print(f"Codes shape: {codes.shape}")
print(f"Codes dtype: {codes.dtype}")
    # Determine the exact number of bits needed per token from the largest token index
    max_token = int(toks_cpu.max())
    bits_needed = max(1, max_token.bit_length())
print(f"Token range: 0 to {max_token}")
print(f"Bits needed per token: {bits_needed}")
# If codes are already binary (batch, time, bits), use them directly
if len(codes.shape) == 3 and codes.dtype in [torch.bool, torch.uint8]:
print(f"Using binary codes directly: {codes.shape[2]} bits per token")
# Pack the binary codes
codes_flat = codes_cpu.flatten()
packed_bits = np.packbits(codes_flat)
bits_per_token = codes.shape[2]
num_tokens = codes.shape[1]
else:
# Pack tokens manually using exact bit width
print(f"Packing tokens with {bits_needed} bits each")
toks_flat = toks_cpu.flatten().astype(np.uint32)
num_tokens = len(toks_flat)
# Convert to binary string and pack
total_bits = num_tokens * bits_needed
# Create bit array
bit_array = []
for tok in toks_flat:
# Convert to binary with exact bit width
bits = format(int(tok), f'0{bits_needed}b')
bit_array.extend([int(b) for b in bits])
# Pad to byte boundary
while len(bit_array) % 8 != 0:
bit_array.append(0)
# Pack into bytes
packed_bits = np.packbits(np.array(bit_array, dtype=np.uint8))
bits_per_token = bits_needed
# Write to file
with open(fc_file_path, 'wb') as f:
# Magic number
f.write(b'FC01')
# Metadata
f.write(struct.pack('<I', toks.shape[0])) # batch size
f.write(struct.pack('<I', num_tokens)) # number of tokens
f.write(struct.pack('<B', bits_per_token)) # bits per token
# Packed data
f.write(packed_bits.tobytes())
file_size = os.path.getsize(fc_file_path)
header_size = 4 + 4 + 4 + 1 # magic + 2 ints + 1 byte
data_size = file_size - header_size
print(f"File size: {file_size} bytes (header: {header_size}B, data: {data_size}B)")
print(f"===========================\n")
return file_size, bits_per_token, data_size
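# .fc file layout written above and read back below (header fields are little-endian,
# payload bytes are produced by np.packbits, i.e. MSB-first within each byte):
#   bytes 0-3   : magic b'FC01'
#   bytes 4-7   : uint32  batch size
#   bytes 8-11  : uint32  number of tokens
#   byte  12    : uint8   bits per token
#   bytes 13-   : tokens packed back-to-back, zero-padded to the next byte boundary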
def load_compressed_codes_optimal(fc_file_path):
"""Load optimally packed codes"""
with open(fc_file_path, 'rb') as f:
# Verify magic
magic = f.read(4)
if magic != b'FC01':
raise ValueError("Invalid .fc file!")
# Read metadata
batch_size = struct.unpack('<I', f.read(4))[0]
num_tokens = struct.unpack('<I', f.read(4))[0]
bits_per_token = struct.unpack('<B', f.read(1))[0]
# Read packed data
packed_data = np.frombuffer(f.read(), dtype=np.uint8)
print(f"\n=== Loading Optimal Codes ===")
print(f"Batch: {batch_size}, Tokens: {num_tokens}, Bits/token: {bits_per_token}")
# Unpack bits
unpacked_bits = np.unpackbits(packed_data)
# Extract exact number of bits needed
total_bits = num_tokens * bits_per_token
token_bits = unpacked_bits[:total_bits]
# Reconstruct tokens
tokens = []
for i in range(num_tokens):
start = i * bits_per_token
end = start + bits_per_token
token_bits_slice = token_bits[start:end]
# Convert binary to integer
token_value = 0
for bit in token_bits_slice:
token_value = (token_value << 1) | bit
tokens.append(token_value)
tokens_array = np.array(tokens, dtype=np.int64).reshape(batch_size, -1)
tokens_tensor = torch.from_numpy(tokens_array)
print(f"Loaded tokens: {tokens_tensor.shape}")
print(f"==============================\n")
return tokens_tensor
def encode_decode_focal(audio_input):
"""
Processes input audio through the 160 bps FocalCodec, saves the tokens,
and returns both the decoded WAV and the path to the FC file for download.
"""
if codec is None:
return None, None, "โŒ ERROR: Model failed to load. Check console for details."
if audio_input is None:
return None, None, "โŒ Please provide audio input."
try:
sr, wav_numpy = audio_input
print(f"\n{'='*50}")
print(f"Processing new audio...")
print(f"Input audio: sample_rate={sr}, shape={wav_numpy.shape}")
# Handle stereo to mono conversion
if len(wav_numpy.shape) > 1:
if wav_numpy.shape[1] == 2:
wav_numpy = wav_numpy.mean(axis=1)
print("Converted stereo to mono")
elif wav_numpy.shape[0] == 2:
wav_numpy = wav_numpy.mean(axis=0)
print("Converted stereo to mono (channels first)")
        # Normalize integer PCM (e.g., int16 from Gradio) to [-1, 1] and ensure float32
        if np.issubdtype(wav_numpy.dtype, np.integer):
            wav_numpy = wav_numpy.astype(np.float32) / np.iinfo(wav_numpy.dtype).max
        else:
            wav_numpy = wav_numpy.astype(np.float32)
# Convert to torch tensor
sig = torch.from_numpy(wav_numpy).unsqueeze(0)
# Resample to 16kHz
if sr != codec.sample_rate_input:
print(f"Resampling from {sr}Hz to {codec.sample_rate_input}Hz...")
resampler = torchaudio.transforms.Resample(
orig_freq=sr,
new_freq=codec.sample_rate_input
)
sig = resampler(sig)
print(f"Signal shape: {sig.shape}")
if torch.cuda.is_available():
sig = sig.cuda()
# --- Encode and Decode ---
with torch.no_grad():
print("\n--- Encoding ---")
toks = codec.sig_to_toks(sig)
duration_sec = sig.shape[-1] / codec.sample_rate_input
token_rate = toks.shape[1] / duration_sec
print(f"Tokens shape: {toks.shape}")
print(f"Token range: {toks.min().item()} to {toks.max().item()}")
print(f"Duration: {duration_sec:.2f}s")
print(f"Token rate: {token_rate:.2f} tokens/sec")
# Get binary codes
codes = codec.toks_to_codes(toks)
print(f"Codes shape: {codes.shape}")
print(f"Codes dtype: {codes.dtype}")
if len(codes.shape) == 3:
print(f"Bits per token (from codes): {codes.shape[2]}")
print("\n--- Decoding ---")
rec_sig = codec.toks_to_sig(toks)
print(f"Reconstructed signal shape: {rec_sig.shape}")
# --- Save with optimal bit packing ---
temp_dir = tempfile.mkdtemp()
fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
file_size, bits_per_token, data_size = save_compressed_codes_optimal(
toks, codes, fc_file_path, codec
)
# Calculate bitrates
total_bitrate = (file_size * 8) / duration_sec
data_bitrate = (data_size * 8) / duration_sec
theoretical_bitrate = token_rate * bits_per_token
print(f"--- Results ---")
print(f"Total bitrate: {total_bitrate:.1f} bps (with header)")
print(f"Data bitrate: {data_bitrate:.1f} bps (data only)")
print(f"Theoretical: {theoretical_bitrate:.1f} bps")
print(f"Target: 160 bps")
print(f"Efficiency: {(160/data_bitrate)*100:.1f}% of target")
print(f"{'='*50}\n")
# Prepare output
decoded_wav_output = rec_sig.cpu().numpy().squeeze()
if len(decoded_wav_output.shape) == 0:
decoded_wav_output = decoded_wav_output.reshape(1)
status_msg = f"โœ… {duration_sec:.1f}s | {file_size}B | {data_bitrate:.0f} bps | {bits_per_token} bits/tok | target: 160 bps"
return (codec.sample_rate_output, decoded_wav_output), fc_file_path, status_msg
except Exception as e:
error_msg = f"โŒ Error: {str(e)}"
print(error_msg)
import traceback
traceback.print_exc()
return None, None, error_msg
def decode_from_fc_file(fc_file):
"""Decode audio from uploaded .fc file"""
if codec is None:
return None, "โŒ Model not loaded"
if fc_file is None:
return None, "โŒ Please upload a .fc file"
try:
print(f"\n{'='*50}")
print(f"Decoding from file: {fc_file.name}")
# Load tokens
toks = load_compressed_codes_optimal(fc_file.name)
if torch.cuda.is_available():
toks = toks.cuda()
# Decode to audio
with torch.no_grad():
print("Decoding tokens to audio...")
rec_sig = codec.toks_to_sig(toks)
print(f"Reconstructed signal shape: {rec_sig.shape}")
decoded_wav = rec_sig.cpu().numpy().squeeze()
# Calculate stats
duration_sec = decoded_wav.shape[0] / codec.sample_rate_output
file_size = os.path.getsize(fc_file.name)
header_size = 4 + 4 + 4 + 1
data_size = file_size - header_size
bitrate = (data_size * 8) / duration_sec
print(f"Duration: {duration_sec:.2f}s")
print(f"Bitrate: {bitrate:.1f} bps")
print(f"{'='*50}\n")
status = f"โœ… Decoded! {duration_sec:.1f}s | {bitrate:.0f} bps"
return (codec.sample_rate_output, decoded_wav), status
except Exception as e:
import traceback
traceback.print_exc()
return None, f"โŒ Error: {str(e)}"
# --- Gradio Interface ---
with gr.Blocks(title="FocalCodec 160 bps", theme=gr.themes.Soft()) as iface:
gr.Markdown("# ๐ŸŽ™๏ธ FocalCodec at 160 bps")
gr.Markdown(f"**Neural speech codec at insanely low bitrate!** Using `{MODEL_CONFIG}`")
gr.Markdown("โš ๏ธ **Optimized for speech only** - not suitable for music | ๐Ÿ”ฅ **1600x compression ratio!**")
with gr.Tab("๐ŸŽค Encode Audio"):
gr.Markdown("### Compress audio to ~160 bps with optimal bit packing")
with gr.Row():
audio_input = gr.Audio(
sources=["microphone", "upload"],
type="numpy",
label="Input Audio (any format/sample rate)"
)
with gr.Column():
audio_output = gr.Audio(
type="numpy",
label="๐Ÿ”Š Decoded Output (16kHz)"
)
file_output = gr.File(
label="๐Ÿ’พ Download Compressed .fc File"
)
status_output = gr.Textbox(label="๐Ÿ“Š Status", lines=2)
encode_btn = gr.Button("๐Ÿ”„ Encode & Decode", variant="primary", size="lg")
encode_btn.click(
fn=encode_decode_focal,
inputs=[audio_input],
outputs=[audio_output, file_output, status_output]
)
gr.Markdown("### How it works:")
gr.Markdown("- โœ… Automatically resamples to 16kHz")
gr.Markdown("- โœ… Converts stereo to mono")
gr.Markdown("- โœ… Encodes to discrete tokens (~12.5 tokens/sec)")
gr.Markdown("- โœ… Packs tokens using only needed bits (no waste!)")
gr.Markdown("- โœ… Decodes tokens back to audio")
gr.Markdown("- ๐Ÿ“ˆ Check console for detailed bitrate analysis!")
with gr.Tab("๐Ÿ“‚ Decode from .fc File"):
gr.Markdown("### Decode previously compressed audio")
with gr.Row():
fc_input = gr.File(
label="Upload .fc File",
file_types=[".fc"]
)
with gr.Column():
decoded_output = gr.Audio(
type="numpy",
label="๐Ÿ”Š Decoded Audio"
)
decode_status = gr.Textbox(label="๐Ÿ“Š Status", lines=2)
decode_btn = gr.Button("๐Ÿ”Š Decode Audio", variant="primary", size="lg")
decode_btn.click(
fn=decode_from_fc_file,
inputs=[fc_input],
outputs=[decoded_output, decode_status]
)
gr.Markdown("### Note:")
gr.Markdown("Upload a .fc file created by this tool to decode it back to audio.")
with gr.Tab("โ„น๏ธ About"):
gr.Markdown("""
## FocalCodec - Ultra Low Bitrate Neural Audio Codec
### 🎯 Compression Ratios:
| Format | Bitrate | 1-Hour File Size | Compression |
|--------|---------|------------------|-------------|
| **Uncompressed PCM** (16kHz mono) | 256 kbps | ~115 MB | 1x |
| **MP3** (standard) | 128 kbps | ~57 MB | 2x |
| **Opus** (voice optimized) | 16 kbps | ~7.2 MB | 16x |
| **FocalCodec** | **0.16 kbps** | **~72 KB** | **1600x** 🔥 |
### 💡 Use Cases:
- 📞 **Ultra-low bandwidth voice calls** (satellite, deep space)
- 🤖 **AI-generated podcasts** (NotebookLM-style apps)
- 🌍 **Low-bandwidth regions** (2G networks)
- 📻 **Emergency communications** (disaster relief)
- 🎓 **Educational content distribution** (offline learning)
- 💾 **Voice memo storage** (years of recordings in MB)
### ⚖️ Trade-offs:
**Pros:**
- ✅ Insanely efficient compression (1600x!)
- ✅ Speech remains highly intelligible
- ✅ Works on any sample rate (auto-resamples)
- ✅ Tiny storage/bandwidth requirements
**Cons:**
- ❌ Voice characteristics may change
- ❌ Emotional nuances can be lost
- ❌ Occasional pronunciation artifacts
- ❌ Not suitable for music or non-speech audio
### 🔧 Technical Details:
- **Model:** `lucadellalib/focalcodec_12_5hz`
- **Sample Rate:** 16 kHz
- **Token Rate:** ~12.5 tokens/second
- **Bits per Token:** 13 bits (auto-detected, optimally packed)
- **Target Bitrate:** 160 bps (12.5 × 13 = 162.5 bps in practice)
- **File Format:** Custom binary format with a small metadata header
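
To reproduce the pipeline outside this Space, here is a minimal sketch that uses the same torch.hub entry point and codec methods as this app (`speech.wav` is a placeholder path; adapt as needed):

```python
import torch
import torchaudio

codec = torch.hub.load(
    "lucadellalib/focalcodec", "focalcodec",
    config="lucadellalib/focalcodec_12_5hz", trust_repo=True,
).eval()

sig, sr = torchaudio.load("speech.wav")                     # (channels, samples)
sig = sig.mean(dim=0, keepdim=True)                         # mono, shape (1, samples)
sig = torchaudio.functional.resample(sig, sr, codec.sample_rate_input)

with torch.no_grad():
    toks = codec.sig_to_toks(sig)                           # discrete tokens, ~12.5/s
    rec = codec.toks_to_sig(toks)                           # reconstructed waveform

torchaudio.save("decoded.wav", rec.cpu().reshape(1, -1), codec.sample_rate_output)
```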
### 🧮 How We Achieve 160 bps:
A naive approach would waste bits:
```
Token (0-8191) → int16 (16 bits) → 16 × 12.5 = 200 bps ❌
Wasting 3 bits per token!
```
Our optimal approach:
```
Token (0-8191) → 13 bits exactly → 13 × 12.5 = 162.5 bps ✅
Zero waste!
```
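
For reference, a minimal sketch of the exact-width packing step used by `save_compressed_codes_optimal` above (only NumPy is needed; the token values are made up for illustration):

```python
import numpy as np

tokens = [4097, 12, 8191]                        # example token IDs, 0-8191 → 13 bits each
bits_per_token = max(t.bit_length() for t in tokens)
bit_string = "".join(format(t, f"0{bits_per_token}b") for t in tokens)
bits = np.array([int(b) for b in bit_string], dtype=np.uint8)
packed = np.packbits(bits)                       # zero-padded to a byte boundary
print(len(packed), "bytes for", len(tokens), "tokens")   # 5 bytes for 3 tokens (39 bits → 40)
```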
### 🔬 Debug Information:
Check the **console/terminal** for detailed encoding information:
- Actual token rate and range
- Bits per token (detected automatically)
- Expected vs actual bitrate
- File size breakdown (header vs data)
- Compression efficiency
### 📚 Example Use Case - AI Podcast Library:
Imagine storing **1000 hours** of AI-generated podcasts:
- **Uncompressed:** 115 GB
- **MP3:** 57 GB
- **Opus:** 7.2 GB
- **FocalCodec:** **72 MB** 🤯
You could fit an entire podcast library on a USB flash drive!
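(The arithmetic: 160 bps × 3600 s ≈ 72 KB per hour, so 1000 hours ≈ 72 MB; uncompressed 16 kHz / 16-bit PCM is 256 kbps ≈ 115 MB per hour.)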
---
### 🔗 Links:
- [FocalCodec GitHub](https://github.com/lucadellalib/focalcodec)
- [Research Paper](https://arxiv.org/abs/2410.03608)
### 🏗️ Built with:
- PyTorch + TorchAudio
- Gradio
- FocalCodec (Luca Della Libera et al.)
""")
if __name__ == "__main__":
print("\n" + "="*50)
print("๐ŸŽ™๏ธ FocalCodec 160 bps Demo")
print("="*50 + "\n")
iface.launch()