Spaces:
Running
on
Zero
Running
on
Zero
刘鑫
committed on
Commit
·
1700cda
1
Parent(s):
1488551
set zero gpu inference
Browse files
app.py
CHANGED
|
@@ -7,6 +7,7 @@ from typing import Optional, Tuple
|
|
| 7 |
from pathlib import Path
|
| 8 |
import tempfile
|
| 9 |
import soundfile as sf
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def setup_cache_env():
|
|
@@ -45,31 +46,54 @@ if os.environ.get("HF_REPO_ID", "").strip() == "":
|
|
| 45 |
# Global model cache for ZeroGPU
|
| 46 |
_asr_model = None
|
| 47 |
_voxcpm_model = None
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
|
| 51 |
def predownload_models():
|
| 52 |
"""
|
| 53 |
Pre-download models at startup (runs in main process, not GPU worker).
|
| 54 |
-
|
| 55 |
"""
|
| 56 |
print("=" * 50)
|
| 57 |
-
print("Pre-downloading models to
|
| 58 |
-
print(f"HF_HOME={os.environ.get('HF_HOME')}")
|
| 59 |
print("=" * 50)
|
| 60 |
|
| 61 |
-
# Pre-download ASR model (SenseVoice)
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
print("=" * 50)
|
| 75 |
print("Model pre-download complete!")
|
|
@@ -80,49 +104,24 @@ def predownload_models():
|
|
| 80 |
predownload_models()
|
| 81 |
|
| 82 |
|
| 83 |
-
def _resolve_model_dir() -> str:
|
| 84 |
-
"""
|
| 85 |
-
Resolve model directory:
|
| 86 |
-
1) Use local checkpoint directory if exists
|
| 87 |
-
2) If HF_REPO_ID env is set, download into models/{repo}
|
| 88 |
-
3) Fallback to 'models'
|
| 89 |
-
"""
|
| 90 |
-
if os.path.isdir(_default_local_model_dir):
|
| 91 |
-
return _default_local_model_dir
|
| 92 |
-
|
| 93 |
-
repo_id = os.environ.get("HF_REPO_ID", "").strip()
|
| 94 |
-
if len(repo_id) > 0:
|
| 95 |
-
target_dir = os.path.join("models", repo_id.replace("/", "__"))
|
| 96 |
-
if not os.path.isdir(target_dir):
|
| 97 |
-
try:
|
| 98 |
-
from huggingface_hub import snapshot_download
|
| 99 |
-
os.makedirs(target_dir, exist_ok=True)
|
| 100 |
-
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...")
|
| 101 |
-
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
|
| 102 |
-
except Exception as e:
|
| 103 |
-
print(f"Warning: HF download failed: {e}. Falling back to 'models'.")
|
| 104 |
-
return "models"
|
| 105 |
-
return target_dir
|
| 106 |
-
return "models"
|
| 107 |
-
|
| 108 |
-
|
| 109 |
def get_asr_model():
|
| 110 |
-
"""Lazy load ASR model from
|
| 111 |
global _asr_model
|
| 112 |
if _asr_model is None:
|
| 113 |
-
setup_cache_env()
|
| 114 |
-
|
| 115 |
from funasr import AutoModel
|
|
|
|
| 116 |
print("Loading ASR model...")
|
| 117 |
-
print(f"
|
|
|
|
| 118 |
_asr_model = AutoModel(
|
| 119 |
-
model=
|
| 120 |
-
hub="hf", # Use HuggingFace Hub
|
| 121 |
disable_update=True,
|
| 122 |
log_level='INFO',
|
| 123 |
device="cuda:0",
|
| 124 |
)
|
| 125 |
-
|
|
|
|
|
|
|
| 126 |
return _asr_model
|
| 127 |
|
| 128 |
|
|
@@ -130,19 +129,19 @@ def get_voxcpm_model():
|
|
| 130 |
"""Lazy load VoxCPM model (without denoiser)."""
|
| 131 |
global _voxcpm_model
|
| 132 |
if _voxcpm_model is None:
|
| 133 |
-
setup_cache_env()
|
| 134 |
-
|
| 135 |
import voxcpm
|
|
|
|
| 136 |
print("Loading VoxCPM model...")
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
_voxcpm_model = voxcpm.VoxCPM(
|
| 141 |
-
voxcpm_model_path=
|
| 142 |
optimize=False,
|
| 143 |
enable_denoiser=False, # Disable denoiser to avoid ZipEnhancer download
|
| 144 |
)
|
| 145 |
-
|
|
|
|
|
|
|
| 146 |
return _voxcpm_model
|
| 147 |
|
| 148 |
|
|
@@ -151,9 +150,16 @@ def prompt_wav_recognition(prompt_wav: Optional[str]) -> str:
|
|
| 151 |
"""Use ASR to recognize prompt audio text."""
|
| 152 |
if prompt_wav is None or not prompt_wav.strip():
|
| 153 |
return ""
|
|
|
|
|
|
|
| 154 |
asr_model = get_asr_model()
|
|
|
|
| 155 |
res = asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
|
|
|
|
| 156 |
text = res[0]["text"].split('|>')[-1]
|
|
|
|
|
|
|
|
|
|
| 157 |
return text
|
| 158 |
|
| 159 |
|
|
@@ -187,7 +193,10 @@ def generate_tts_audio_gpu(
|
|
| 187 |
prompt_wav_path = f.name
|
| 188 |
|
| 189 |
try:
|
| 190 |
-
print(
|
|
|
|
|
|
|
|
|
|
| 191 |
wav = voxcpm_model.generate(
|
| 192 |
text=text,
|
| 193 |
prompt_text=prompt_text,
|
|
@@ -197,6 +206,11 @@ def generate_tts_audio_gpu(
|
|
| 197 |
normalize=do_normalize,
|
| 198 |
denoise=False, # Denoiser disabled
|
| 199 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
return (voxcpm_model.tts_model.sample_rate, wav)
|
| 201 |
finally:
|
| 202 |
# Cleanup temp file
|
|
|
|
| 7 |
from pathlib import Path
|
| 8 |
import tempfile
|
| 9 |
import soundfile as sf
|
| 10 |
+
import time
|
| 11 |
|
| 12 |
|
| 13 |
def setup_cache_env():
|
|
|
|
| 46 |
# Global model cache for ZeroGPU
|
| 47 |
_asr_model = None
|
| 48 |
_voxcpm_model = None
|
| 49 |
+
|
| 50 |
+
# Fixed local paths for models (to avoid repeated downloads in GPU workers)
|
| 51 |
+
ASR_LOCAL_DIR = "./models/SenseVoiceSmall"
|
| 52 |
+
VOXCPM_LOCAL_DIR = "./models/VoxCPM1.5"
|
| 53 |
|
| 54 |
|
| 55 |
def predownload_models():
|
| 56 |
"""
|
| 57 |
Pre-download models at startup (runs in main process, not GPU worker).
|
| 58 |
+
Download to fixed local directories so GPU workers can reuse them.
|
| 59 |
"""
|
| 60 |
print("=" * 50)
|
| 61 |
+
print("Pre-downloading models to local directories...")
|
|
|
|
| 62 |
print("=" * 50)
|
| 63 |
|
| 64 |
+
# Pre-download ASR model (SenseVoice) to fixed local directory
|
| 65 |
+
if not os.path.isdir(ASR_LOCAL_DIR) or not os.path.exists(os.path.join(ASR_LOCAL_DIR, "model.pt")):
|
| 66 |
+
try:
|
| 67 |
+
from huggingface_hub import snapshot_download
|
| 68 |
+
asr_model_id = "FunAudioLLM/SenseVoiceSmall"
|
| 69 |
+
print(f"Pre-downloading ASR model: {asr_model_id} -> {ASR_LOCAL_DIR}")
|
| 70 |
+
os.makedirs(ASR_LOCAL_DIR, exist_ok=True)
|
| 71 |
+
snapshot_download(
|
| 72 |
+
repo_id=asr_model_id,
|
| 73 |
+
local_dir=ASR_LOCAL_DIR,
|
| 74 |
+
)
|
| 75 |
+
print(f"ASR model downloaded to: {ASR_LOCAL_DIR}")
|
| 76 |
+
except Exception as e:
|
| 77 |
+
print(f"Warning: Failed to pre-download ASR model: {e}")
|
| 78 |
+
else:
|
| 79 |
+
print(f"ASR model already exists at: {ASR_LOCAL_DIR}")
|
| 80 |
+
|
| 81 |
+
# Pre-download VoxCPM model to fixed local directory
|
| 82 |
+
if not os.path.isdir(VOXCPM_LOCAL_DIR) or not os.path.exists(os.path.join(VOXCPM_LOCAL_DIR, "model.safetensors")):
|
| 83 |
+
try:
|
| 84 |
+
from huggingface_hub import snapshot_download
|
| 85 |
+
voxcpm_model_id = os.environ.get("HF_REPO_ID", "openbmb/VoxCPM1.5")
|
| 86 |
+
print(f"Pre-downloading VoxCPM model: {voxcpm_model_id} -> {VOXCPM_LOCAL_DIR}")
|
| 87 |
+
os.makedirs(VOXCPM_LOCAL_DIR, exist_ok=True)
|
| 88 |
+
snapshot_download(
|
| 89 |
+
repo_id=voxcpm_model_id,
|
| 90 |
+
local_dir=VOXCPM_LOCAL_DIR,
|
| 91 |
+
)
|
| 92 |
+
print(f"VoxCPM model downloaded to: {VOXCPM_LOCAL_DIR}")
|
| 93 |
+
except Exception as e:
|
| 94 |
+
print(f"Warning: Failed to pre-download VoxCPM model: {e}")
|
| 95 |
+
else:
|
| 96 |
+
print(f"VoxCPM model already exists at: {VOXCPM_LOCAL_DIR}")
|
| 97 |
|
| 98 |
print("=" * 50)
|
| 99 |
print("Model pre-download complete!")
|
|
|
|
| 104 |
predownload_models()
|
| 105 |
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
def get_asr_model():
|
| 108 |
+
"""Lazy load ASR model from local directory."""
|
| 109 |
global _asr_model
|
| 110 |
if _asr_model is None:
|
|
|
|
|
|
|
| 111 |
from funasr import AutoModel
|
| 112 |
+
print("=" * 50)
|
| 113 |
print("Loading ASR model...")
|
| 114 |
+
print(f" Using local path: {ASR_LOCAL_DIR}")
|
| 115 |
+
start_time = time.time()
|
| 116 |
_asr_model = AutoModel(
|
| 117 |
+
model=ASR_LOCAL_DIR, # Use local directory path
|
|
|
|
| 118 |
disable_update=True,
|
| 119 |
log_level='INFO',
|
| 120 |
device="cuda:0",
|
| 121 |
)
|
| 122 |
+
load_time = time.time() - start_time
|
| 123 |
+
print(f"ASR model loaded. (耗时: {load_time:.2f}s)")
|
| 124 |
+
print("=" * 50)
|
| 125 |
return _asr_model
|
| 126 |
|
| 127 |
|
|
|
|
| 129 |
"""Lazy load VoxCPM model (without denoiser)."""
|
| 130 |
global _voxcpm_model
|
| 131 |
if _voxcpm_model is None:
|
|
|
|
|
|
|
| 132 |
import voxcpm
|
| 133 |
+
print("=" * 50)
|
| 134 |
print("Loading VoxCPM model...")
|
| 135 |
+
print(f" Using local path: {VOXCPM_LOCAL_DIR}")
|
| 136 |
+
start_time = time.time()
|
|
|
|
| 137 |
_voxcpm_model = voxcpm.VoxCPM(
|
| 138 |
+
voxcpm_model_path=VOXCPM_LOCAL_DIR,
|
| 139 |
optimize=False,
|
| 140 |
enable_denoiser=False, # Disable denoiser to avoid ZipEnhancer download
|
| 141 |
)
|
| 142 |
+
load_time = time.time() - start_time
|
| 143 |
+
print(f"VoxCPM model loaded. (耗时: {load_time:.2f}s)")
|
| 144 |
+
print("=" * 50)
|
| 145 |
return _voxcpm_model
|
| 146 |
|
| 147 |
|
|
|
|
| 150 |
"""Use ASR to recognize prompt audio text."""
|
| 151 |
if prompt_wav is None or not prompt_wav.strip():
|
| 152 |
return ""
|
| 153 |
+
print("=" * 50)
|
| 154 |
+
print("[ASR] 开始语音识别...")
|
| 155 |
asr_model = get_asr_model()
|
| 156 |
+
start_time = time.time()
|
| 157 |
res = asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
|
| 158 |
+
inference_time = time.time() - start_time
|
| 159 |
text = res[0]["text"].split('|>')[-1]
|
| 160 |
+
print(f"[ASR] 识别结果: {text}")
|
| 161 |
+
print(f"[ASR] 推理耗时: {inference_time:.2f}s")
|
| 162 |
+
print("=" * 50)
|
| 163 |
return text
|
| 164 |
|
| 165 |
|
|
|
|
| 193 |
prompt_wav_path = f.name
|
| 194 |
|
| 195 |
try:
|
| 196 |
+
print("=" * 50)
|
| 197 |
+
print("[TTS] 开始语音合成...")
|
| 198 |
+
print(f"[TTS] 目标文本: {text}")
|
| 199 |
+
start_time = time.time()
|
| 200 |
wav = voxcpm_model.generate(
|
| 201 |
text=text,
|
| 202 |
prompt_text=prompt_text,
|
|
|
|
| 206 |
normalize=do_normalize,
|
| 207 |
denoise=False, # Denoiser disabled
|
| 208 |
)
|
| 209 |
+
inference_time = time.time() - start_time
|
| 210 |
+
audio_duration = len(wav) / voxcpm_model.tts_model.sample_rate
|
| 211 |
+
rtf = inference_time / audio_duration if audio_duration > 0 else 0
|
| 212 |
+
print(f"[TTS] 推理耗时: {inference_time:.2f}s | 音频时长: {audio_duration:.2f}s | RTF: {rtf:.3f}")
|
| 213 |
+
print("=" * 50)
|
| 214 |
return (voxcpm_model.tts_model.sample_rate, wav)
|
| 215 |
finally:
|
| 216 |
# Cleanup temp file
|