刘鑫 committed on
Commit
1700cda
·
1 Parent(s): 1488551

set zero gpu inference

Browse files
Files changed (1) hide show
  1. app.py +71 -57
app.py CHANGED
@@ -7,6 +7,7 @@ from typing import Optional, Tuple
7
  from pathlib import Path
8
  import tempfile
9
  import soundfile as sf
 
10
 
11
 
12
  def setup_cache_env():
@@ -45,31 +46,54 @@ if os.environ.get("HF_REPO_ID", "").strip() == "":
45
  # Global model cache for ZeroGPU
46
  _asr_model = None
47
  _voxcpm_model = None
48
- _default_local_model_dir = "./models/VoxCPM1.5"
 
 
 
49
 
50
 
51
  def predownload_models():
52
  """
53
  Pre-download models at startup (runs in main process, not GPU worker).
54
- This ensures models are cached before GPU functions are called.
55
  """
56
  print("=" * 50)
57
- print("Pre-downloading models to cache...")
58
- print(f"HF_HOME={os.environ.get('HF_HOME')}")
59
  print("=" * 50)
60
 
61
- # Pre-download ASR model (SenseVoice) from HuggingFace
62
- try:
63
- from huggingface_hub import snapshot_download
64
- asr_model_id = "FunAudioLLM/SenseVoiceSmall"
65
- print(f"Pre-downloading ASR model: {asr_model_id}")
66
- asr_local_path = snapshot_download(
67
- asr_model_id,
68
- cache_dir=os.environ.get("HF_HOME"),
69
- )
70
- print(f"ASR model downloaded to: {asr_local_path}")
71
- except Exception as e:
72
- print(f"Warning: Failed to pre-download ASR model: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  print("=" * 50)
75
  print("Model pre-download complete!")
@@ -80,49 +104,24 @@ def predownload_models():
80
  predownload_models()
81
 
82
 
83
- def _resolve_model_dir() -> str:
84
- """
85
- Resolve model directory:
86
- 1) Use local checkpoint directory if exists
87
- 2) If HF_REPO_ID env is set, download into models/{repo}
88
- 3) Fallback to 'models'
89
- """
90
- if os.path.isdir(_default_local_model_dir):
91
- return _default_local_model_dir
92
-
93
- repo_id = os.environ.get("HF_REPO_ID", "").strip()
94
- if len(repo_id) > 0:
95
- target_dir = os.path.join("models", repo_id.replace("/", "__"))
96
- if not os.path.isdir(target_dir):
97
- try:
98
- from huggingface_hub import snapshot_download
99
- os.makedirs(target_dir, exist_ok=True)
100
- print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...")
101
- snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
102
- except Exception as e:
103
- print(f"Warning: HF download failed: {e}. Falling back to 'models'.")
104
- return "models"
105
- return target_dir
106
- return "models"
107
-
108
-
109
  def get_asr_model():
110
- """Lazy load ASR model from HuggingFace."""
111
  global _asr_model
112
  if _asr_model is None:
113
- setup_cache_env()
114
-
115
  from funasr import AutoModel
 
116
  print("Loading ASR model...")
117
- print(f" HF_HOME={os.environ.get('HF_HOME')}")
 
118
  _asr_model = AutoModel(
119
- model="FunAudioLLM/SenseVoiceSmall", # HuggingFace model ID
120
- hub="hf", # Use HuggingFace Hub
121
  disable_update=True,
122
  log_level='INFO',
123
  device="cuda:0",
124
  )
125
- print("ASR model loaded.")
 
 
126
  return _asr_model
127
 
128
 
@@ -130,19 +129,19 @@ def get_voxcpm_model():
130
  """Lazy load VoxCPM model (without denoiser)."""
131
  global _voxcpm_model
132
  if _voxcpm_model is None:
133
- setup_cache_env()
134
-
135
  import voxcpm
 
136
  print("Loading VoxCPM model...")
137
- model_dir = _resolve_model_dir()
138
- print(f"Using model dir: {model_dir}")
139
-
140
  _voxcpm_model = voxcpm.VoxCPM(
141
- voxcpm_model_path=model_dir,
142
  optimize=False,
143
  enable_denoiser=False, # Disable denoiser to avoid ZipEnhancer download
144
  )
145
- print("VoxCPM model loaded.")
 
 
146
  return _voxcpm_model
147
 
148
 
@@ -151,9 +150,16 @@ def prompt_wav_recognition(prompt_wav: Optional[str]) -> str:
151
  """Use ASR to recognize prompt audio text."""
152
  if prompt_wav is None or not prompt_wav.strip():
153
  return ""
 
 
154
  asr_model = get_asr_model()
 
155
  res = asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
 
156
  text = res[0]["text"].split('|>')[-1]
 
 
 
157
  return text
158
 
159
 
@@ -187,7 +193,10 @@ def generate_tts_audio_gpu(
187
  prompt_wav_path = f.name
188
 
189
  try:
190
- print(f"Generating audio for text: '{text[:60]}...'")
 
 
 
191
  wav = voxcpm_model.generate(
192
  text=text,
193
  prompt_text=prompt_text,
@@ -197,6 +206,11 @@ def generate_tts_audio_gpu(
197
  normalize=do_normalize,
198
  denoise=False, # Denoiser disabled
199
  )
 
 
 
 
 
200
  return (voxcpm_model.tts_model.sample_rate, wav)
201
  finally:
202
  # Cleanup temp file
 
7
  from pathlib import Path
8
  import tempfile
9
  import soundfile as sf
10
+ import time
11
 
12
 
13
  def setup_cache_env():
 
46
  # Global model cache for ZeroGPU
47
  _asr_model = None
48
  _voxcpm_model = None
49
+
50
+ # Fixed local paths for models (to avoid repeated downloads in GPU workers)
51
+ ASR_LOCAL_DIR = "./models/SenseVoiceSmall"
52
+ VOXCPM_LOCAL_DIR = "./models/VoxCPM1.5"
53
 
54
 
55
  def predownload_models():
56
  """
57
  Pre-download models at startup (runs in main process, not GPU worker).
58
+ Download to fixed local directories so GPU workers can reuse them.
59
  """
60
  print("=" * 50)
61
+ print("Pre-downloading models to local directories...")
 
62
  print("=" * 50)
63
 
64
+ # Pre-download ASR model (SenseVoice) to fixed local directory
65
+ if not os.path.isdir(ASR_LOCAL_DIR) or not os.path.exists(os.path.join(ASR_LOCAL_DIR, "model.pt")):
66
+ try:
67
+ from huggingface_hub import snapshot_download
68
+ asr_model_id = "FunAudioLLM/SenseVoiceSmall"
69
+ print(f"Pre-downloading ASR model: {asr_model_id} -> {ASR_LOCAL_DIR}")
70
+ os.makedirs(ASR_LOCAL_DIR, exist_ok=True)
71
+ snapshot_download(
72
+ repo_id=asr_model_id,
73
+ local_dir=ASR_LOCAL_DIR,
74
+ )
75
+ print(f"ASR model downloaded to: {ASR_LOCAL_DIR}")
76
+ except Exception as e:
77
+ print(f"Warning: Failed to pre-download ASR model: {e}")
78
+ else:
79
+ print(f"ASR model already exists at: {ASR_LOCAL_DIR}")
80
+
81
+ # Pre-download VoxCPM model to fixed local directory
82
+ if not os.path.isdir(VOXCPM_LOCAL_DIR) or not os.path.exists(os.path.join(VOXCPM_LOCAL_DIR, "model.safetensors")):
83
+ try:
84
+ from huggingface_hub import snapshot_download
85
+ voxcpm_model_id = os.environ.get("HF_REPO_ID", "openbmb/VoxCPM1.5")
86
+ print(f"Pre-downloading VoxCPM model: {voxcpm_model_id} -> {VOXCPM_LOCAL_DIR}")
87
+ os.makedirs(VOXCPM_LOCAL_DIR, exist_ok=True)
88
+ snapshot_download(
89
+ repo_id=voxcpm_model_id,
90
+ local_dir=VOXCPM_LOCAL_DIR,
91
+ )
92
+ print(f"VoxCPM model downloaded to: {VOXCPM_LOCAL_DIR}")
93
+ except Exception as e:
94
+ print(f"Warning: Failed to pre-download VoxCPM model: {e}")
95
+ else:
96
+ print(f"VoxCPM model already exists at: {VOXCPM_LOCAL_DIR}")
97
 
98
  print("=" * 50)
99
  print("Model pre-download complete!")
 
104
  predownload_models()
105
 
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  def get_asr_model():
108
+ """Lazy load ASR model from local directory."""
109
  global _asr_model
110
  if _asr_model is None:
 
 
111
  from funasr import AutoModel
112
+ print("=" * 50)
113
  print("Loading ASR model...")
114
+ print(f" Using local path: {ASR_LOCAL_DIR}")
115
+ start_time = time.time()
116
  _asr_model = AutoModel(
117
+ model=ASR_LOCAL_DIR, # Use local directory path
 
118
  disable_update=True,
119
  log_level='INFO',
120
  device="cuda:0",
121
  )
122
+ load_time = time.time() - start_time
123
+ print(f"ASR model loaded. (耗时: {load_time:.2f}s)")
124
+ print("=" * 50)
125
  return _asr_model
126
 
127
 
 
129
  """Lazy load VoxCPM model (without denoiser)."""
130
  global _voxcpm_model
131
  if _voxcpm_model is None:
 
 
132
  import voxcpm
133
+ print("=" * 50)
134
  print("Loading VoxCPM model...")
135
+ print(f" Using local path: {VOXCPM_LOCAL_DIR}")
136
+ start_time = time.time()
 
137
  _voxcpm_model = voxcpm.VoxCPM(
138
+ voxcpm_model_path=VOXCPM_LOCAL_DIR,
139
  optimize=False,
140
  enable_denoiser=False, # Disable denoiser to avoid ZipEnhancer download
141
  )
142
+ load_time = time.time() - start_time
143
+ print(f"VoxCPM model loaded. (耗时: {load_time:.2f}s)")
144
+ print("=" * 50)
145
  return _voxcpm_model
146
 
147
 
 
150
  """Use ASR to recognize prompt audio text."""
151
  if prompt_wav is None or not prompt_wav.strip():
152
  return ""
153
+ print("=" * 50)
154
+ print("[ASR] 开始语音识别...")
155
  asr_model = get_asr_model()
156
+ start_time = time.time()
157
  res = asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
158
+ inference_time = time.time() - start_time
159
  text = res[0]["text"].split('|>')[-1]
160
+ print(f"[ASR] 识别结果: {text}")
161
+ print(f"[ASR] 推理耗时: {inference_time:.2f}s")
162
+ print("=" * 50)
163
  return text
164
 
165
 
 
193
  prompt_wav_path = f.name
194
 
195
  try:
196
+ print("=" * 50)
197
+ print("[TTS] 开始语音合成...")
198
+ print(f"[TTS] 目标文本: {text}")
199
+ start_time = time.time()
200
  wav = voxcpm_model.generate(
201
  text=text,
202
  prompt_text=prompt_text,
 
206
  normalize=do_normalize,
207
  denoise=False, # Denoiser disabled
208
  )
209
+ inference_time = time.time() - start_time
210
+ audio_duration = len(wav) / voxcpm_model.tts_model.sample_rate
211
+ rtf = inference_time / audio_duration if audio_duration > 0 else 0
212
+ print(f"[TTS] 推理耗时: {inference_time:.2f}s | 音频时长: {audio_duration:.2f}s | RTF: {rtf:.3f}")
213
+ print("=" * 50)
214
  return (voxcpm_model.tts_model.sample_rate, wav)
215
  finally:
216
  # Cleanup temp file