inoryQwQ committed
Commit ca02ffa · 1 parent: 58d5563

Update README, Update python API

Files changed:
- README.md +271 -3
- python/main.py +54 -0
- python/requirements.txt +4 -2
- python/test_wer.py +272 -0
- python/whisper.py +209 -225
- python/whisper_cli.py +45 -0
- python/whisper_onnx.py +0 -239
- python/whisper_svr.py +104 -0
- python/whisper_tokenizer.py +395 -0
README.md
CHANGED

@@ -5,7 +5,199 @@ pipeline_tag: automatic-speech-recognition

# Whisper

OpenAI Whisper on Axera

- Both C++ and Python are currently supported
- Precompiled model download:
  - [Huggingface](https://huggingface.co/AXERA-TECH/Whisper)
- To convert the models yourself, see [Model conversion](https://github.com/ml-inory/whisper.axera/blob/main/model_convert/README.md)

## Supported platforms

- [x] AX650N
- [x] AX630C

## Model conversion

[Model conversion](https://github.com/ml-inory/whisper.axera/blob/main/model_convert/README.md)

## Running on the board

- Devices based on AX650N or AX630C ship with Ubuntu 22.04 preinstalled
- Connect the device to the internet and make sure it can run `apt install`, `pip install`, and similar commands
- Verified devices:
  - [爱芯派Pro (AX650N)](https://wiki.sipeed.com/hardware/zh/maixIV/m4ndock/m4ndock.html)
  - [M.2 Accelerator card (AX650N)](https://axcl-docs.readthedocs.io/zh-cn/latest/doc_guide_hardware.html)
  - [爱芯派2 (AX630C)](https://axera-pi-2-docs-cn.readthedocs.io/zh-cn/latest/index.html)
  - [Module-LLM (AX630C)](https://docs.m5stack.com/zh_CN/module/Module-LLM)
  - [LLM630 Compute Kit (AX630C)](https://docs.m5stack.com/zh_CN/core/LLM630%20Compute%20Kit)
- Supported programming languages:
  - [Python](#Python)
  - [C++](#CPP)

<h3 id="Python">Python</h3>

#### Requirements

We recommend installing Miniconda on the board to manage virtual environments:

```
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh -O ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh

source ~/miniconda3/bin/activate

conda init --all
```

Install the Whisper dependencies:

```
cd python

conda create -n whisper python=3.12
conda activate whisper
pip3 install -r requirements.txt
```

#### Install pyaxengine

Install the NPU Python API as described at https://github.com/AXERA-TECH/pyaxengine

Tested against 0.1.3rc2; install it with

```
pip install https://github.com/AXERA-TECH/pyaxengine/releases/download/0.1.3.rc2/axengine-0.1.3-py3-none-any.whl
```

or change the version number to the release you want to use.

#### Run

Log in to the board, then run:

```
cd python
conda activate whisper
python3 main.py --model_type small --model_path ../models-ax650 --wav ../demo.wav --language zh
```

Expected output:

```
root@ax650:/mnt/qtang/whisper.axera/python# python3 main.py --wav ../demo.wav --model_type small --model_path ../models/ --language zh
[INFO] Available providers: ['AxEngineExecutionProvider']
wav: ../demo.wav
model_type: small
model_path: ../models/
language: zh
[INFO] Using provider: AxEngineExecutionProvider
[INFO] Chip type: ChipType.MC50
[INFO] VNPU type: VNPUType.DISABLED
[INFO] Engine version: 2.10.1s
[INFO] Model type: 2 (triple core)
[INFO] Compiler version: 3.2-patch1 117f5fd4
[INFO] Using provider: AxEngineExecutionProvider
[INFO] Model type: 2 (triple core)
[INFO] Compiler version: 3.2-patch1 117f5fd4
[INFO] Using provider: AxEngineExecutionProvider
[INFO] Model type: 2 (triple core)
[INFO] Compiler version: 3.2-patch1 117f5fd4
Load models take 2322.563409805298ms
Preprocess wav take 6971.68493270874ms
Run encoder take 211.52877807617188ms
Run decoder_main take 79.00094985961914ms
First token: 17556
Run decoder_loop take 101.91774368286133ms
Iter 0 Token: 20844
Run decoder_loop take 60.30416488647461ms
Iter 1 Token: 7781
Run decoder_loop take 60.22000312805176ms
Iter 2 Token: 20204
Run decoder_loop take 60.23716926574707ms
Iter 3 Token: 28455
Run decoder_loop take 60.214996337890625ms
Iter 4 Token: 31962
Run decoder_loop take 60.17565727233887ms
Iter 5 Token: 6336
Run decoder_loop take 60.94002723693848ms
Iter 6 Token: 254
Run decoder_loop take 60.71639060974121ms
Iter 7 Token: 2930
Run decoder_loop take 60.225725173950195ms
Iter 8 Token: 236
Run decoder_loop take 60.167789459228516ms
Iter 9 Token: 36135
Run decoder_loop take 60.29987335205078ms
Iter 10 Token: 15868
Run decoder_loop take 61.163902282714844ms
Iter 11 Token: 252
Run decoder_loop take 60.273170471191406ms
Iter 12 Token: 1546
Run decoder_loop take 60.23144721984863ms
Iter 13 Token: 46514
Run decoder_loop take 60.31966209411621ms
Iter 14 Token: 50257
Result: 甚至出现交易几乎停滞的情况
```

Command-line options:

| Option | Description | Default |
| --- | --- | --- |
| --wav | Input audio file | |
| --model_type/-t | Model type: tiny/base/small | |
| --model_path/-p | Directory containing the models | ../models |
| --language/-l | Recognition language | zh |
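
The pipeline can also be driven from Python code rather than the CLI; a minimal sketch based on `python/main.py` from this commit (paths are the example values used above):

```python
# Minimal programmatic use of the Whisper class added in this commit
# (python/whisper.py); model_path and wav are the example paths from above.
from whisper import Whisper

model = Whisper("small", "../models-ax650", "zh", "transcribe")
print(model.run("../demo.wav"))  # run() accepts a wav path or a float32 sample array
```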

<h3 id="CPP">CPP</h3>

#### Run

On an AX650N device, run:

```
cd cpp
./whisper -w ../demo.wav
```

or

```
cd cpp
./whisper --model_type small --model_path ../models -w ../demo.wav
```

Expected output:

```
root@ax650:/mnt/qtang/whisper.axera/cpp# ./install/whisper --wav ../demo.wav --model_type small --model_path ../models/ --language zh
wav_file: ../demo.wav
model_path: ../models/
model_type: small
language: zh
Encoder run take 188.30 ms
First token: 17556 take 81.88ms
Next Token: 20844 take 29.64ms
Next Token: 7781 take 29.70ms
Next Token: 20204 take 29.64ms
Next Token: 28455 take 29.65ms
Next Token: 31962 take 29.61ms
Next Token: 6336 take 29.67ms
Next Token: 254 take 29.63ms
Next Token: 2930 take 29.61ms
Next Token: 236 take 29.56ms
Next Token: 36135 take 29.64ms
Next Token: 15868 take 29.71ms
Next Token: 252 take 29.51ms
Next Token: 1546 take 29.63ms
Next Token: 46514 take 29.51ms
Next Token: 50257 take 29.69ms
All take 801.13 ms
Result: 甚至出现交易几乎停滞的情况
```

### Server

…

### Client

Test from the command line with curl (replace the IP and port with your own); a Python sketch of the same request follows the block:

```
ffmpeg -i demo.wav -f f32le -c:a pcm_f32le - 2>/dev/null | \
curl -X POST 10.126.33.192:8080/asr \
-H "Content-Type: application/octet-stream" \
--data-binary @-
```
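
For reference, a hedged Python equivalent of the curl pipeline above, assuming the server accepts raw 16 kHz mono float32 PCM on `/asr` (as the octet-stream request suggests; the IP and port are the same placeholders):

```python
# Sketch of the curl request above in Python: decode the wav to float32 PCM
# and POST the raw bytes. Assumes the server expects a 16 kHz mono stream.
import librosa
import numpy as np
import requests

samples, _ = librosa.load("demo.wav", sr=16000)                 # float32 mono
pcm = np.ascontiguousarray(samples, dtype=np.float32).tobytes()
resp = requests.post(
    "http://10.126.33.192:8080/asr",                            # replace IP and port
    headers={"Content-Type": "application/octet-stream"},
    data=pcm,
)
print(resp.text)
```

Note that this targets the C++ server; the Python HTTP server added in this commit (`python/whisper_svr.py`) instead expects a multipart/form-data upload, see `python/whisper_cli.py`.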

## Model performance

### Latency

RTF: Real-Time Factor, i.e. processing time divided by audio duration (for example, an RTF of 0.24 means 10 s of audio is transcribed in about 2.4 s).

CPP:

| Models | AX650N | AX630C |
| ------------- | ------ | ------ |
| Whisper-Tiny | 0.08 | |
| Whisper-Base | 0.11 | 0.35 |
| Whisper-Small | 0.24 | |
| Whisper-Turbo | 0.48 | |

Python:

| Models | AX650N | AX630C |
| ------------- | ------ | ------ |
| Whisper-Tiny | 0.12 | |
| Whisper-Base | 0.16 | 0.35 |
| Whisper-Small | 0.50 | |
| Whisper-Turbo | 0.60 | |

### Word Error Rate (tested on the AIShell dataset)

| Models | AX650N | AX630C |
| ------------- | ------ | ------ |
| Whisper-Tiny | 0.24 | |
| Whisper-Base | 0.18 | |
| Whisper-Small | 0.11 | |
| Whisper-Turbo | 0.06 | |

To reproduce these results:

Unpack the dataset:

```
unzip datasets.zip
```

Run the test script:

```
cd python
conda activate whisper
python test_wer.py -d aishell --gt_path ../datasets/ground_truth.txt --model_type tiny
```
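
For Chinese, the "WER" reported by the script is in fact a character error rate: `test_wer.py` strips punctuation from both strings, computes the Levenshtein edit distance between reference and hypothesis (`min_distance`), and reports `edit_distance / len(reference) * 100` per utterance and over the whole set, writing results to `wer.txt` and `test_wer.log`.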

### MEM Usage

* CMM stands for the physical memory used by Axera modules such as VDEC (video decoder), VENC (video encoder), the NPU, etc.

Python:

| Models | CMM(MB)| OS(MB) |
| ------------- | ------ | ------ |
| Whisper-Tiny | 332 | 512 |
| Whisper-Base | 533 | 644 |
| Whisper-Small | 1106 | 906 |
| Whisper-Turbo | 2065 | 2084 |

C++:

| Models | CMM(MB)| OS(MB) |
| ------------- | ------ | ------ |
| Whisper-Tiny | 332 | 31 |
| Whisper-Base | 533 | 54 |
| Whisper-Small | 1106 | 146 |
| Whisper-Turbo | 2065 | 86 |

## Technical discussion

- Github issues
- QQ group: 139953715
python/main.py
ADDED

@@ -0,0 +1,54 @@

```python
import argparse
import os
from whisper import Whisper
import time


def get_args():
    parser = argparse.ArgumentParser(
        prog="whisper",
        description="Run Whisper on input audio file"
    )
    parser.add_argument("--wav", "-w", type=str, required=True, help="Input audio file")
    parser.add_argument("--model_type", "-t", type=str, choices=["tiny", "base", "small", "large", "large-v3", "turbo"], required=True, help="model type, only support tiny, base and small currently")
    parser.add_argument("--model_path", "-p", type=str, required=False, default="../models/models-ax650", help="model path for *.axmodel, tokens.txt, positional_embedding.bin")
    parser.add_argument("--language", "-l", type=str, required=False, default="zh", help="Target language, support en, zh, ja, and others. See languages.py for more options.")
    parser.add_argument("--task", type=str, required=False, choices=["translate", "transcribe"], default="transcribe")
    parser.add_argument("--print_rtf", action="store_true", help="Print Real-Time Factor")
    return parser.parse_args()


def print_args(args):
    print(f"wav: {args.wav}")
    print(f"model_type: {args.model_type}")
    print(f"model_path: {args.model_path}")
    print(f"language: {args.language}")
    print(f"task: {args.task}")


def main():
    args = get_args()
    print_args(args)

    # Check wav existence
    wav_path = args.wav
    assert os.path.exists(wav_path), f"{wav_path} NOT exist"

    model = Whisper(args.model_type, args.model_path, args.language, args.task)

    print("\n预测结果:")
    start = time.time()
    print(model.run(wav_path))
    end = time.time()

    if args.print_rtf:
        import librosa
        samples, sr = librosa.load(wav_path, sr=16000)
        duration = len(samples) / sr
        process_time = end - start
        print(f"RTF: {process_time / duration}")

if __name__ == "__main__":
    main()
```
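With `--print_rtf`, `main.py` additionally loads the audio to measure its duration and prints `RTF = processing_time / audio_duration`, e.g. `python3 main.py -w ../demo.wav -t small -p ../models-ax650 --print_rtf`. Note that the reported figure covers preprocessing and decoding (everything inside `model.run`) but not model loading.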
python/requirements.txt
CHANGED

```diff
@@ -1,4 +1,6 @@
 numpy==1.26.4
 soundfile
-librosa
-zhconv
+librosa==0.9.1
+zhconv
+jiwer
+tiktoken
```
python/test_wer.py
ADDED

@@ -0,0 +1,272 @@

```python
import argparse
import os
import logging
import re
from whisper import Whisper


def setup_logging():
    """Configure logging to write to both the console and a file."""
    # Directory containing this script
    script_dir = os.path.dirname(os.path.abspath(__file__))
    log_file = os.path.join(script_dir, "test_wer.log")

    # Log format
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    date_format = '%Y-%m-%d %H:%M:%S'

    # Create the root logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Remove any existing handlers
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)

    # File handler
    file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
    file_handler.setLevel(logging.INFO)
    file_formatter = logging.Formatter(log_format, date_format)
    file_handler.setFormatter(file_formatter)

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_formatter = logging.Formatter(log_format, date_format)
    console_handler.setFormatter(console_formatter)

    # Attach the handlers
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return logger


class AIShellDataset:
    def __init__(self, gt_path: str):
        """
        Initialize the dataset.

        Args:
            gt_path: path to the ground-truth text file
        """
        self.gt_path = gt_path
        self.dataset_dir = os.path.dirname(gt_path)
        self.voice_dir = os.path.join(self.dataset_dir, "aishell_S0764")

        # Check that the required file and folder exist
        assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
        assert os.path.exists(self.voice_dir), f"aishell_S0764文件夹不存在: {self.voice_dir}"

        # Load the data
        self.data = []
        with open(gt_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                audio_path, gt = line.split(" ")
                audio_path = os.path.join(self.voice_dir, audio_path + ".wav")
                self.data.append({"audio_path": audio_path, "gt": gt})

        # Use logging rather than print
        logger = logging.getLogger()
        logger.info(f"加载了 {len(self.data)} 条数据")

    def __iter__(self):
        """Return an iterator over the dataset."""
        self.index = 0
        return self

    def __next__(self):
        """Return the next item."""
        if self.index >= len(self.data):
            raise StopIteration

        item = self.data[self.index]
        audio_path = item["audio_path"]
        ground_truth = item["gt"]

        self.index += 1
        return audio_path, ground_truth

    def __len__(self):
        """Return the dataset size."""
        return len(self.data)


class CommonVoiceDataset:
    """Parser for the Common Voice dataset."""

    def __init__(self, tsv_path: str):
        """
        Initialize the dataset.

        Args:
            tsv_path: path to the dataset's .tsv index file
        """
        self.tsv_path = tsv_path
        self.dataset_dir = os.path.dirname(tsv_path)
        self.voice_dir = os.path.join(self.dataset_dir, "clips")

        # Check that the required file and folder exist
        assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
        assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"

        # Load the index
        self.data = []
        with open(tsv_path, 'r', encoding='utf-8') as f:
            f.readline()  # skip the header line
            for line in f:
                line = line.strip()
                splits = line.split("\t")
                audio_path = splits[1]
                gt = splits[2]
                audio_path = os.path.join(self.voice_dir, audio_path)
                self.data.append({"audio_path": audio_path, "gt": gt})

        # Use logging rather than print
        logger = logging.getLogger()
        logger.info(f"加载了 {len(self.data)} 条数据")

    def __iter__(self):
        """Return an iterator over the dataset."""
        self.index = 0
        return self

    def __next__(self):
        """Return the next item."""
        if self.index >= len(self.data):
            raise StopIteration

        item = self.data[self.index]
        audio_path = item["audio_path"]
        ground_truth = item["gt"]

        self.index += 1
        return audio_path, ground_truth

    def __len__(self):
        """Return the dataset size."""
        return len(self.data)

def get_args():
    parser = argparse.ArgumentParser(
        prog="whisper",
        description="Test WER on dataset"
    )
    parser.add_argument("--dataset", "-d", type=str, required=True, choices=["aishell", "common_voice"], help="Test dataset")
    parser.add_argument("--gt_path", "-g", type=str, required=True, help="Test dataset ground truth file")
    parser.add_argument("--max_num", type=int, default=-1, required=False, help="Maximum test data num")
    parser.add_argument("--model_type", "-t", type=str, choices=["tiny", "base", "small", "large", "large-v3", "turbo"], required=True, help="model type, only support tiny, base and small currently")
    parser.add_argument("--model_path", "-p", type=str, required=False, default="../models/models-ax650", help="model path for *.axmodel, tokens.txt, positional_embedding.bin")
    parser.add_argument("--language", "-l", type=str, required=False, default="zh", help="Target language, support en, zh, ja, and others. See languages.py for more options.")
    return parser.parse_args()


def print_args(args):
    logger = logging.getLogger()
    logger.info(f"dataset: {args.dataset}")
    logger.info(f"gt_path: {args.gt_path}")
    logger.info(f"max_num: {args.max_num}")
    logger.info(f"model_type: {args.model_type}")
    logger.info(f"model_path: {args.model_path}")
    logger.info(f"language: {args.language}")


def min_distance(word1: str, word2: str) -> int:
    """Levenshtein (edit) distance between two strings, via dynamic programming."""
    row = len(word1) + 1
    column = len(word2) + 1

    cache = [[0] * column for i in range(row)]

    for i in range(row):
        for j in range(column):

            if i == 0 and j == 0:
                cache[i][j] = 0
            elif i == 0 and j != 0:
                cache[i][j] = j
            elif j == 0 and i != 0:
                cache[i][j] = i
            else:
                if word1[i-1] == word2[j-1]:
                    cache[i][j] = cache[i-1][j-1]
                else:
                    replace = cache[i-1][j-1] + 1
                    insert = cache[i][j-1] + 1
                    remove = cache[i-1][j] + 1

                    cache[i][j] = min(replace, insert, remove)

    return cache[row-1][column-1]


def remove_punctuation(text):
    # Regular expression matching punctuation,
    # covering common ASCII and Chinese punctuation marks
    pattern = r'[^\w\s]|_'

    # Replace every match with the empty string
    cleaned_text = re.sub(pattern, '', text)

    return cleaned_text


def main():
    # Set up logging
    logger = setup_logging()

    args = get_args()
    print_args(args)

    dataset_type = args.dataset.lower()
    if dataset_type == "aishell":
        dataset = AIShellDataset(args.gt_path)
    elif dataset_type == "common_voice":
        dataset = CommonVoiceDataset(args.gt_path)
    else:
        raise ValueError(f"Unknown dataset type {dataset_type}")

    max_num = args.max_num

    # Load model
    model = Whisper(args.model_type, args.model_path, args.language, "transcribe")

    # Iterate over dataset
    references = []
    hyp = []
    all_character_error_num = 0
    all_character_num = 0
    wer_file = open("wer.txt", "w")
    max_data_num = max_num if max_num > 0 else len(dataset)
    for n, (audio_path, reference) in enumerate(dataset):
        hypothesis = model.run(audio_path)

        hypothesis = remove_punctuation(hypothesis)
        reference = remove_punctuation(reference)

        character_error_num = min_distance(reference, hypothesis)
        character_num = len(reference)
        character_error_rate = character_error_num / character_num * 100

        all_character_error_num += character_error_num
        all_character_num += character_num

        hyp.append(hypothesis)
        references.append(reference)

        line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
        wer_file.write(line_content + "\n")
        logger.info(line_content)

        if n + 1 >= max_data_num:
            break

    total_character_error_rate = all_character_error_num / all_character_num * 100

    logger.info(f"Total WER: {total_character_error_rate}%")
    wer_file.write(f"Total WER: {total_character_error_rate}%")
    wer_file.close()

if __name__ == "__main__":
    main()
```
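A quick sanity check of the edit-distance-based CER above, with made-up strings (`test_wer.py` imports `whisper` and thus `axengine`, so this only runs on the board, from the `python/` directory):

```python
# Hypothetical example: one substituted character out of 13 -> CER ~7.69%.
from test_wer import min_distance, remove_punctuation

ref = remove_punctuation("甚至出现交易几乎停滞的情况")  # 13 characters
hyp = remove_punctuation("甚至出现交易几乎停止的情况")  # one substitution
print(f"{min_distance(ref, hyp) / len(ref) * 100:.2f}%")  # -> 7.69%
```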
python/whisper.py
CHANGED

@@ -1,240 +1,224 @@

The previous script-style implementation (module-level constants plus an argparse `main()`, essentially the axengine twin of the deleted `python/whisper_onnx.py` below) is replaced by a reusable `Whisper` class:

```python
import axengine as axe
import numpy as np
import librosa
import os
from typing import Union
from whisper_tokenizer import *
import json
from dataclasses import dataclass, field  # field is needed by sot_sequence's default_factory
import zhconv


NEG_INF = float("-inf")

@dataclass
class WhisperConfig:
    n_mels : int = 0
    sample_rate : int = 0
    n_fft : int = 0
    hop_length : int = 0

    sot : int = 0
    eot : int = 0
    blank_id : int = 0
    no_timestamps : int = 0
    no_speech : int = 0
    translate : int = 0
    transcribe : int = 0
    n_vocab : int = 0
    n_text_ctx : int = 0
    n_text_state : int = 0

    sot_sequence : np.ndarray = field(default_factory=lambda: np.array([0, 0, 0, 0], dtype=np.int32))


class Whisper:
    def __init__(self, model_type: str, model_path: str, language: str, task: str):
        assert task in ["translate", "transcribe"]

        self.language = language
        self.task = task
        self.encoder, self.decoder_main, self.decoder_loop, self.pe, self.tokenizer, model_config = \
            self.load_model(model_type, model_path, language, task)
        self.config = self.load_config(model_config)

    def load_model(self, model_type, model_path, language, task):
        encoder_path = f"{model_type}/{model_type}-encoder.axmodel"
        decoder_main_path = f"{model_type}/{model_type}-decoder-main.axmodel"
        decoder_loop_path = f"{model_type}/{model_type}-decoder-loop.axmodel"
        pe_path = f"{model_type}/{model_type}-positional_embedding.bin"
        model_config_file = f"{model_type}/{model_type}_config.json"

        required_files = [os.path.join(model_path, i) for i in (encoder_path, decoder_main_path, decoder_loop_path, pe_path, model_config_file)]
        # Check file existence
        for i, file_path in enumerate(required_files):
            assert os.path.exists(file_path), f"{file_path} NOT exist"

        # Load encoder
        encoder = axe.InferenceSession(required_files[0], providers=['AxEngineExecutionProvider'])
        # Load decoder main
        decoder_main = axe.InferenceSession(required_files[1], providers=['AxEngineExecutionProvider'])
        # Load decoder loop
        decoder_loop = axe.InferenceSession(required_files[2], providers=['AxEngineExecutionProvider'])
        # Load position embedding
        pe = np.fromfile(required_files[3], dtype=np.float32)
        # Load model config and tokenizer
        model_config = json.load(open(required_files[4], "r"))
        model_config["all_language_tokens"] = [int(i) for i in model_config["all_language_tokens"].split(",")]
        model_config["all_language_codes"] = [i for i in model_config["all_language_codes"].split(",")]
        tokenizer = get_tokenizer(
            model_config["is_multilingual"],
            num_languages=len(model_config["all_language_codes"]),
            language=language,
            task=task,
        )

        return encoder, decoder_main, decoder_loop, pe, tokenizer, model_config

    def load_config(self, model_config):
        config = WhisperConfig()
        config.n_mels = model_config["n_mels"]
        config.sample_rate = 16000
        config.n_fft = 480
        config.hop_length = 160

        config.sot = model_config["sot"]
        config.eot = model_config["eot"]
        config.blank_id = model_config["blank_id"]
        config.no_timestamps = model_config["no_timestamps"]
        config.no_speech = model_config["no_speech"]
        config.translate = model_config["translate"]
        config.transcribe = model_config["transcribe"]
        config.n_vocab = model_config["n_vocab"]
        config.n_text_ctx = model_config["n_text_ctx"]
        config.n_text_state = model_config["n_text_state"]

        lang_token = model_config["all_language_tokens"][model_config["all_language_codes"].index(self.language)]
        task_token = config.transcribe if self.task == "transcribe" else config.translate
        config.sot_sequence = np.array([config.sot, lang_token, task_token, config.no_timestamps], dtype=np.int32)

        return config

    def load_audio(self, audio: str):
        data, sample_rate = librosa.load(audio, sr=self.config.sample_rate)
        samples = np.ascontiguousarray(data)
        return samples, sample_rate

    def compute_feature(self, audio: np.ndarray, padding = 480000):
        if padding > 0:
            audio = np.concatenate((audio, np.zeros((padding,), dtype=np.float32)), axis=-1)

        mel = librosa.feature.melspectrogram(y=audio,
                                             sr=self.config.sample_rate,
                                             n_fft=self.config.n_fft,
                                             hop_length=self.config.hop_length,
                                             window="hann",
                                             center=True,
                                             pad_mode="reflect",
                                             power=2.0,
                                             n_mels=self.config.n_mels)

        log_spec = np.log10(np.maximum(mel, 1e-10))
        log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
        mel = (log_spec + 4.0) / 4.0

        target = 3000
        if mel.shape[1] > target:
            # -50 so that there are some zero tail paddings.
            mel = mel[:, : target]
            mel[:, -50:] = 0

        # We don't need to pad it to 30 seconds now!
        if mel.shape[1] < target:
            mel = np.concatenate((mel, np.zeros((self.config.n_mels, target - mel.shape[1]), dtype=np.float32)), axis=-1)

        return mel

    def supress_tokens(self, logits, is_initial):
        if is_initial:
            logits[self.config.eot] = NEG_INF
            logits[self.config.blank_id] = NEG_INF

        logits[self.config.no_timestamps] = NEG_INF
        logits[self.config.sot] = NEG_INF
        logits[self.config.no_speech] = NEG_INF

        if self.task == "transcribe":
            logits[self.config.translate] = NEG_INF
        else:
            logits[self.config.transcribe] = NEG_INF
        return logits

    def run(self, audio: Union[str, np.ndarray]) -> str:
        if isinstance(audio, str):
            audio, sample_rate = self.load_audio(audio)

        mel = self.compute_feature(audio)

        # Run encoder
        x = self.encoder.run(None, input_feed={"mel": mel[None, ...]})
        n_layer_cross_k, n_layer_cross_v = x

        # Run decoder_main
        x = self.decoder_main.run(None, input_feed={
            "tokens": self.config.sot_sequence[None, ...],
            "n_layer_cross_k": n_layer_cross_k,
            "n_layer_cross_v": n_layer_cross_v
        })
        logits, n_layer_self_k_cache, n_layer_self_v_cache = x

        # Decode token
        logits = logits[0, -1, :]
        logits = self.supress_tokens(logits, is_initial=True)
        # logits.tofile("logits.bin")
        max_token_id = np.argmax(logits)
        output_tokens = []

        # Position embedding offset
        offset = self.config.sot_sequence.shape[0]

        # Autoregressively run decoder until token meets EOT
        for i in range(self.config.n_text_ctx - self.config.sot_sequence.shape[0]):
            if max_token_id >= self.config.eot:
                break

            output_tokens.append(max_token_id)

            mask = np.zeros((self.config.n_text_ctx,), dtype=np.float32)
            mask[: self.config.n_text_ctx - offset - 1] = NEG_INF

            # Run decoder_loop
            x = self.decoder_loop.run(None, input_feed={
                "tokens": np.array([[output_tokens[-1]]], dtype=np.int32),
                "in_n_layer_self_k_cache": n_layer_self_k_cache,
                "in_n_layer_self_v_cache": n_layer_self_v_cache,
                "n_layer_cross_k": n_layer_cross_k,
                "n_layer_cross_v": n_layer_cross_v,
                "positional_embedding": self.pe[offset * self.config.n_text_state : (offset + 1) * self.config.n_text_state][None, ...],
                "mask": mask
            })
            logits, n_layer_self_k_cache, n_layer_self_v_cache = x

            # Decode token
            offset += 1
            logits = self.supress_tokens(logits.flatten(), is_initial=False)
            max_token_id = np.argmax(logits)

        text = self.tokenizer.decode(output_tokens)

        if self.language == "zh":
            try:
                sim_zh = zhconv.convert(text, 'zh-hans')
                return sim_zh
            except:
                return text

        return text
```
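One detail worth calling out in `run()`: the decoder-loop attention mask is rebuilt on every step so that exactly `offset + 1` trailing positions are visible, opening one more self-attention slot per decoded token. A toy illustration with a made-up context length of 8 (the deleted scripts hard-coded `WHISPER_N_TEXT_CTX = 448`):

```python
# Toy version of the mask logic from run(), with a made-up n_text_ctx of 8.
# All but the last offset+1 positions are masked to -inf.
import numpy as np

NEG_INF = float("-inf")
n_text_ctx = 8
for offset in (4, 5, 6):              # offset starts at len(sot_sequence) == 4
    mask = np.zeros((n_text_ctx,), dtype=np.float32)
    mask[: n_text_ctx - offset - 1] = NEG_INF
    print(offset, mask)
# 4 [-inf -inf -inf   0.   0.   0.   0.   0.]
# 5 [-inf -inf   0.   0.   0.   0.   0.   0.]
# 6 [-inf   0.   0.   0.   0.   0.   0.   0.]
```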
python/whisper_cli.py
ADDED

@@ -0,0 +1,45 @@

```python
import requests

def transcribe_audio(
    server_url: str,
    wav_path: str,
    model_type: str = "tiny",
    model_path: str = "../models/models-ax650",
    language: str = "zh",
    task: str = "transcribe"
):
    url = f"{server_url.rstrip('/')}/asr"

    files = {
        "wav": open(wav_path, "rb"),
    }

    data = {
        "model_type": model_type,
        "model_path": model_path,
        "language": language,
        "task": task,
    }

    print(f"Sending request to: {url}")

    response = requests.post(url, files=files, data=data)
    if response.status_code != 200:
        print("❌ Error:", response.text)
        return None

    result = response.json()
    print("服务器返回结果:")
    print(result)

    return result


if __name__ == "__main__":
    # Your server address
    SERVER = "http://127.0.0.1:8000"

    # Path to a local wav file
    WAV = "../demo.wav"

    transcribe_audio(SERVER, WAV)
```
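Note that `transcribe_audio` sends a multipart/form-data upload, so it evidently pairs with the Python HTTP server added in this commit (`python/whisper_svr.py`, defaulting here to http://127.0.0.1:8000), not with the raw-PCM `/asr` endpoint of the C++ server shown in the README.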
python/whisper_onnx.py
DELETED

@@ -1,239 +0,0 @@

```python
import argparse
import onnxruntime as ort
import numpy as np
import librosa
import os
from typing import Tuple
import soundfile as sf
import base64
import zhconv
import time
import torch
from torch.nn import functional as F
from languages import WHISPER_LANGUAGES


WHISPER_N_MELS = 80
WHISPER_SAMPLE_RATE = 16000
WHISPER_N_FFT = 480
WHISPER_HOP_LENGTH = 160

WHISPER_SOT = 50258
WHISPER_EOT = 50257
WHISPER_BLANK = 220
WHISPER_NO_TIMESTAMPS = 50363
WHISPER_NO_SPEECH = 50362
WHISPER_TRANSLATE = 50358
WHISPER_TRANSCRIBE = 50359
WHISPER_VOCAB_SIZE = 51865
WHISPER_N_TEXT_CTX = 448

NEG_INF = float("-inf")
SOT_SEQUENCE = np.array([WHISPER_SOT,WHISPER_SOT + 1 + tuple(WHISPER_LANGUAGES).index("zh"),WHISPER_TRANSCRIBE,WHISPER_NO_TIMESTAMPS], dtype=np.int64)
WHISPER_N_TEXT_STATE_MAP = {
    "tiny": 384,
    "base": 512,
    "small": 768
}


def get_args():
    parser = argparse.ArgumentParser(
        prog="whisper",
        description="Run Whisper on input audio file"
    )
    parser.add_argument("--wav", "-w", type=str, required=True, help="Input audio file")
    parser.add_argument("--model_type", "-t", type=str, choices=["tiny", "base", "small"], required=True, help="model type, only support tiny/base/small currently")
    parser.add_argument("--model_path", "-p", type=str, required=False, default="../models", help="model path for *.axmodel, tokens.txt, positional_embedding.bin")
    parser.add_argument("--language", "-l", type=str, required=False, default="zh", help="Target language, support en, zh, ja, and others. See languages.py for more options.")
    return parser.parse_args()


def print_args(args):
    print(f"wav: {args.wav}")
    print(f"model_type: {args.model_type}")
    print(f"model_path: {args.model_path}")
    print(f"language: {args.language}")


def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    data, sample_rate = sf.read(
        filename,
        always_2d=True,
        dtype="float32",
    )
    data = data[:, 0]  # use only the first channel
    data = librosa.resample(data, orig_sr=sample_rate, target_sr=WHISPER_SAMPLE_RATE)
    samples = np.ascontiguousarray(data)
    return samples, sample_rate


def load_models(model_path, model_type):
    encoder_path = f"{model_type}-encoder.onnx"
    decoder_main_path = f"{model_type}-decoder-main.onnx"
    decoder_loop_path = f"{model_type}-decoder-loop.onnx"
    pe_path = f"{model_type}-positional_embedding.bin"
    token_path = f"{model_type}-tokens.txt"

    required_files = [os.path.join(model_path, i) for i in (encoder_path, decoder_main_path, decoder_loop_path, pe_path, token_path)]
    # Check file existence
    for i, file_path in enumerate(required_files):
        assert os.path.exists(file_path), f"{file_path} NOT exist"

    # Load encoder
    encoder = ort.InferenceSession(required_files[0], providers=['CPUExecutionProvider'])
    # Load decoder main
    decoder_main = ort.InferenceSession(required_files[1], providers=['CPUExecutionProvider'])
    # Load decoder loop
    decoder_loop = ort.InferenceSession(required_files[2], providers=['CPUExecutionProvider'])
    # Load position embedding
    pe = np.fromfile(required_files[3], dtype=np.float32)
    # Load tokens
    tokens = []
    with open(required_files[4], "r") as f:
        for line in f:
            line = line.strip()
            tokens.append(line.split(" ")[0])

    return encoder, decoder_main, decoder_loop, pe, tokens


def compute_feature(wav_path, n_mels = WHISPER_N_MELS, padding = 480000):
    audio, sr = load_audio(wav_path)

    audio = np.concatenate((audio, np.zeros((padding,), dtype=np.float32)), axis=-1)

    mel = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=WHISPER_N_FFT, hop_length=WHISPER_HOP_LENGTH, window="hann", center=True, pad_mode="reflect", power=2.0, n_mels=n_mels)
    log_spec = np.log10(np.maximum(mel, 1e-10))
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
    mel = (log_spec + 4.0) / 4.0

    # We pad 1500 frames at the end so that it is able to detect eot
    # You can use another value instead of 1500.
    # mel = np.concatenate((mel, np.zeros((n_mels, 1500), dtype=np.float32)), axis=-1)

    target = 3000
    if mel.shape[1] > target:
        # -50 so that there are some zero tail paddings.
        mel = mel[:, : target]
        mel[:, -50:] = 0

    # We don't need to pad it to 30 seconds now!
    if mel.shape[1] < target:
        mel = np.concatenate((mel, np.zeros((n_mels, target - mel.shape[1]), dtype=np.float32)), axis=-1)

    return mel


def supress_tokens(logits, is_initial):
    if is_initial:
        logits[WHISPER_EOT] = NEG_INF
        logits[WHISPER_BLANK] = NEG_INF

    logits[WHISPER_NO_TIMESTAMPS] = NEG_INF
    logits[WHISPER_SOT] = NEG_INF
    logits[WHISPER_NO_SPEECH] = NEG_INF
    logits[WHISPER_TRANSLATE] = NEG_INF
    return logits


def choose_language(lang):
    if lang not in WHISPER_LANGUAGES.keys():
        raise Exception(f"Unknown language: {lang}. Check languages.py for correct options.")
    SOT_SEQUENCE[1] = WHISPER_SOT + 1 + tuple(WHISPER_LANGUAGES.keys()).index(lang)


def main():
    args = get_args()
    print_args(args)

    # Check wav existence
    wav_path = args.wav
    assert os.path.exists(wav_path), f"{wav_path} NOT exist"

    # Choose language
    choose_language(args.language)

    # Load models and other stuff
    encoder, decoder_main, decoder_loop, pe, token_table = load_models(args.model_path, args.model_type)
    WHISPER_N_TEXT_STATE = WHISPER_N_TEXT_STATE_MAP[args.model_type]

    # Preprocess
    mel = compute_feature(wav_path, n_mels=WHISPER_N_MELS)
    # mel.tofile("mel.bin")
    # mel = np.load("../mel.npy")[..., :3000]

    # Run encoder
    start = time.time()
    x = encoder.run(None, input_feed={"mel": mel[None, ...]})
    n_layer_cross_k, n_layer_cross_v = x
    print(f"Run encoder take {(time.time() - start) * 1000}ms")

    # n_layer_cross_k.tofile("n_layer_cross_k.bin")
    # n_layer_cross_v.tofile("n_layer_cross_v.bin")

    # Run decoder_main
    start = time.time()
    x = decoder_main.run(None, input_feed={
        "tokens": SOT_SEQUENCE[None, ...],
        "n_layer_cross_k": n_layer_cross_k,
        "n_layer_cross_v": n_layer_cross_v
    })
    logits, n_layer_self_k_cache, n_layer_self_v_cache = x
    print(f"Run decoder_main take {(time.time() - start) * 1000}ms")

    # Decode token
    logits = logits[0, -1, :]
    logits = supress_tokens(logits, is_initial=True)
    # logits.tofile("logits.bin")
    max_token_id = np.argmax(logits)
    output_tokens = []
    print(f"First token: {max_token_id}")

    # Position embedding offset
    offset = SOT_SEQUENCE.shape[0]

    # Autoregressively run decoder until token meets EOT
    for i in range(WHISPER_N_TEXT_CTX - SOT_SEQUENCE.shape[0]):
        if max_token_id == WHISPER_EOT:
            break

        output_tokens.append(max_token_id)

        mask = np.zeros((WHISPER_N_TEXT_CTX,), dtype=np.float32)
        mask[: WHISPER_N_TEXT_CTX - offset - 1] = NEG_INF

        # Run decoder_loop
        start = time.time()
        x = decoder_loop.run(None, input_feed={
            "tokens": np.array([[output_tokens[-1]]], dtype=np.int64),
            "in_n_layer_self_k_cache": n_layer_self_k_cache,
            "in_n_layer_self_v_cache": n_layer_self_v_cache,
            "n_layer_cross_k": n_layer_cross_k,
            "n_layer_cross_v": n_layer_cross_v,
            "positional_embedding": pe[offset * WHISPER_N_TEXT_STATE : (offset + 1) * WHISPER_N_TEXT_STATE][None, ...],
            "mask": mask
        })
        logits, n_layer_self_k_cache, n_layer_self_v_cache = x
        print(f"Run decoder_loop take {(time.time() - start) * 1000}ms")

        # Decode token
        offset += 1
        logits = supress_tokens(logits.flatten(), is_initial=False)
        max_token_id = np.argmax(logits)

        print(f"Iter {i} \t Token: {max_token_id}")

    s = b""
    for i in output_tokens:
        s += base64.b64decode(token_table[i])
    # print(s.decode().strip())
    pd = s.decode().strip()
    if args.language == "zh":
        pd = zhconv.convert(pd, 'zh-hans')

    print(f"Result: {pd}")


if __name__ == "__main__":
    main()
```
python/whisper_svr.py
ADDED
|
@@ -0,0 +1,104 @@
import argparse
import json
import os
import tempfile
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.parse import parse_qs

from whisper import Whisper
import cgi


# Model cache: avoid reloading the model on every request
_model_cache = {}

def get_model(model_type, model_path, language, task):
    key = (model_type, model_path, language, task)
    if key not in _model_cache:
        print(f"Loading model: type={model_type}, path={model_path}, lang={language}, task={task}")
        _model_cache[key] = Whisper(model_type, model_path, language, task)
    return _model_cache[key]


class WhisperHandler(BaseHTTPRequestHandler):

    def _send_json(self, obj, status=200):
        data = json.dumps(obj, ensure_ascii=False).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def do_GET(self):
        if self.path == "/health":
            self._send_json({"status": "ok"})
        else:
            self._send_json({"error": "not found"}, 404)

    def do_POST(self):
        if self.path != "/asr":
            self._send_json({"error": "not found"}, 404)
            return

        # Parse multipart/form-data
        content_type = self.headers.get('Content-Type')
        if not content_type:
            self._send_json({"error": "Missing Content-Type"}, 400)
            return

        ctype, pdict = cgi.parse_header(content_type)

        if ctype != 'multipart/form-data':
            self._send_json({"error": "Only multipart/form-data is supported"}, 400)
            return

        pdict['boundary'] = bytes(pdict['boundary'], "utf-8")
        pdict['CONTENT-LENGTH'] = int(self.headers['Content-Length'])

        form = cgi.parse_multipart(self.rfile, pdict)

        # The request must contain a wav file
        if "wav" not in form:
            self._send_json({"error": "Field 'wav' is required"}, 400)
            return

        # Read parameters (fall back to defaults when omitted)
        model_type = form.get("model_type", ["tiny"])[0]
        model_path = form.get("model_path", ["../models/models-ax650"])[0]
        language = form.get("language", ["zh"])[0]
        task = form.get("task", ["transcribe"])[0]

        if task not in ("transcribe", "translate"):
            self._send_json({"error": "task must be 'transcribe' or 'translate'"}, 400)
            return

        wav_bytes = form["wav"][0]

        # Write the upload to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp.write(wav_bytes)
            wav_path = tmp.name

        # Load the model and run inference
        try:
            model = get_model(model_type, model_path, language, task)
            result_text = model.run(wav_path)
        except Exception as e:
            self._send_json({"error": str(e)}, 500)
            return
        finally:
            if os.path.exists(wav_path):
                os.remove(wav_path)

        self._send_json({"text": result_text})


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Whisper Server")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
    args = parser.parse_args()
    port = args.port
    server = HTTPServer(("0.0.0.0", port), WhisperHandler)
    print(f"Server started at http://0.0.0.0:{port}")
    server.serve_forever()
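For reference, a minimal client sketch for the `/asr` endpoint above (assumptions: the server runs on localhost port 8000, `demo.wav` exists, and the third-party `requests` package is installed; the field names follow the handler's `wav`, `model_type`, `language`, and `task`):

```
import requests

with open("demo.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/asr",
        files={"wav": ("demo.wav", f, "audio/wav")},  # required file field
        data={"model_type": "tiny", "language": "zh", "task": "transcribe"},
    )
print(resp.json())  # {"text": "..."} on success, {"error": "..."} on failure
```

Note that the server relies on the standard-library `cgi` module, which is deprecated since Python 3.11 and removed in 3.13, so it requires Python 3.12 or earlier.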
python/whisper_tokenizer.py
ADDED
@@ -0,0 +1,395 @@
import base64
import os
import string
from dataclasses import dataclass, field
from functools import cached_property, lru_cache
from typing import Dict, List, Optional, Tuple

import tiktoken

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}


@dataclass
class Tokenizer:
    """A thin wrapper around `tiktoken` providing quick access to special tokens"""

    encoding: tiktoken.Encoding
    num_languages: int
    language: Optional[str] = None
    task: Optional[str] = None
    sot_sequence: Tuple[int] = ()
    special_tokens: Dict[str, int] = field(default_factory=dict)

    def __post_init__(self):
        for special in self.encoding.special_tokens_set:
            special_token = self.encoding.encode_single_token(special)
            self.special_tokens[special] = special_token

        sot: int = self.special_tokens["<|startoftranscript|>"]
        translate: int = self.special_tokens["<|translate|>"]
        transcribe: int = self.special_tokens["<|transcribe|>"]

        langs = tuple(LANGUAGES.keys())[: self.num_languages]
        sot_sequence = [sot]
        if self.language is not None:
            sot_sequence.append(sot + 1 + langs.index(self.language))
        if self.task is not None:
            task_token: int = transcribe if self.task == "transcribe" else translate
            sot_sequence.append(task_token)

        self.sot_sequence = tuple(sot_sequence)

    def encode(self, text, **kwargs):
        return self.encoding.encode(text, **kwargs)

    def decode(self, token_ids: List[int], **kwargs) -> str:
        token_ids = [t for t in token_ids if t < self.timestamp_begin]
        return self.encoding.decode(token_ids, **kwargs)

    def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
        """
        Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
        This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
        """
        return self.encoding.decode(token_ids, **kwargs)

    @cached_property
    def eot(self) -> int:
        return self.encoding.eot_token

    @cached_property
    def transcribe(self) -> int:
        return self.special_tokens["<|transcribe|>"]

    @cached_property
    def translate(self) -> int:
        return self.special_tokens["<|translate|>"]

    @cached_property
    def sot(self) -> int:
        return self.special_tokens["<|startoftranscript|>"]

    @cached_property
    def sot_lm(self) -> int:
        return self.special_tokens["<|startoflm|>"]

    @cached_property
    def sot_prev(self) -> int:
        return self.special_tokens["<|startofprev|>"]

    @cached_property
    def no_speech(self) -> int:
        return self.special_tokens["<|nospeech|>"]

    @cached_property
    def no_timestamps(self) -> int:
        return self.special_tokens["<|notimestamps|>"]

    @cached_property
    def timestamp_begin(self) -> int:
        return self.special_tokens["<|0.00|>"]

    @cached_property
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")

        return self.to_language_token(self.language)

    def to_language_token(self, language):
        if token := self.special_tokens.get(f"<|{language}|>", None):
            return token

        raise KeyError(f"Language {language} not found in tokenizer.")

    @cached_property
    def all_language_tokens(self) -> Tuple[int]:
        result = []
        for token, token_id in self.special_tokens.items():
            if token.strip("<|>") in LANGUAGES:
                result.append(token_id)
        return tuple(result)[: self.num_languages]

    @cached_property
    def all_language_codes(self) -> Tuple[str]:
        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)

    @cached_property
    def sot_sequence_including_notimestamps(self) -> Tuple[int]:
        return tuple(list(self.sot_sequence) + [self.no_timestamps])

    @cached_property
    def non_speech_tokens(self) -> Tuple[int]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.encoding.encode(symbol),
                self.encoding.encode(" " + symbol),
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))

    def split_to_word_tokens(self, tokens: List[int]):
        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
            # These languages don't typically use spaces, so it is difficult to split words
            # without morpheme analysis. Here, we instead split words at any
            # position where the tokens are decoded as valid unicode points
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)

    def split_tokens_on_unicode(self, tokens: List[int]):
        decoded_full = self.decode_with_timestamps(tokens)
        replacement_char = "\ufffd"

        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)

            if (
                replacement_char not in decoded
                or decoded_full[unicode_offset + decoded.index(replacement_char)]
                == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens

    def split_tokens_on_spaces(self, tokens: List[int]):
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation
            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens


@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    ranks = {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in open(vocab_path) if line)
    }
    n_vocab = len(ranks)
    special_tokens = {}

    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]

    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )


@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
) -> Tokenizer:
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        encoding_name = "multilingual"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name = "gpt2"
        language = None
        task = None

    encoding = get_encoding(name=encoding_name, num_languages=num_languages)

    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )
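A minimal usage sketch for this tokenizer (assumes the matching `assets/multilingual.tiktoken` vocabulary file sits next to the module, as `get_encoding` expects):

```
from whisper_tokenizer import get_tokenizer

tok = get_tokenizer(multilingual=True, language="zh", task="transcribe")
# (sot, language, task) ids; (50258, 50260, 50359) for the standard multilingual vocab
print(tok.sot_sequence)
ids = tok.encode("hello world")
print(ids, "->", tok.decode(ids))
```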