inoryQwQ committed on
Commit ca02ffa · 1 Parent(s): 58d5563

Update README, Update python API

README.md CHANGED
@@ -5,7 +5,199 @@ pipeline_tag: automatic-speech-recognition
 
 # Whisper
 
- ## CPP
+ OpenAI Whisper on Axera
+ 
+ - Both C++ and Python are currently supported
+ - Prebuilt model download
+   - [Huggingface](https://huggingface.co/AXERA-TECH/Whisper)
+ 
+ - To convert the models yourself, see [Model Conversion](https://github.com/ml-inory/whisper.axera/blob/main/model_convert/README.md)
+ 
+ ## Supported Platforms
+ 
+ - [x] AX650N
+ - [x] AX630C
+ 
+ ## Model Conversion
+ 
+ [Model Conversion](https://github.com/ml-inory/whisper.axera/blob/main/model_convert/README.md)
+ 
+ ## On-Board Deployment
+ 
+ - Devices based on AX650N and AX630C ship with Ubuntu 22.04 preinstalled
+ - Connect the device to the internet and make sure commands such as `apt install` and `pip install` work
+ - Verified devices:
+   - [爱芯派Pro (AX650N)](https://wiki.sipeed.com/hardware/zh/maixIV/m4ndock/m4ndock.html)
+   - [M.2 Accelerator card (AX650N)](https://axcl-docs.readthedocs.io/zh-cn/latest/doc_guide_hardware.html)
+   - [爱芯派2 (AX630C)](https://axera-pi-2-docs-cn.readthedocs.io/zh-cn/latest/index.html)
+   - [Module-LLM (AX630C)](https://docs.m5stack.com/zh_CN/module/Module-LLM)
+   - [LLM630 Compute Kit (AX630C)](https://docs.m5stack.com/zh_CN/core/LLM630%20Compute%20Kit)
+ - Supported programming languages:
+   - [Python](#Python)
+   - [C++](#CPP)
+ 
+ <h3 id="Python">Python</h3>
+ 
+ #### Requirements
+ 
+ We recommend installing Miniconda on the board to manage virtual environments:
+ ```
+ mkdir -p ~/miniconda3
+ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh -O ~/miniconda3/miniconda.sh
+ bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
+ rm ~/miniconda3/miniconda.sh
+ 
+ source ~/miniconda3/bin/activate
+ 
+ conda init --all
+ ```
+ 
+ Install the Whisper dependencies:
+ ```
+ cd python
+ 
+ conda create -n whisper python=3.12
+ conda activate whisper
+ pip3 install -r requirements.txt
+ ```
+ 
+ #### Install pyaxengine
+ 
+ Install the NPU Python API by following https://github.com/AXERA-TECH/pyaxengine
+ 
+ Tested against 0.1.3rc2, which can be installed with
+ ```
+ pip install https://github.com/AXERA-TECH/pyaxengine/releases/download/0.1.3.rc2/axengine-0.1.3-py3-none-any.whl
+ ```
+ or change the version number in the URL to the release you want to use.
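For reference, a minimal sketch of driving a compiled `.axmodel` through the pyaxengine API, mirroring how `python/whisper.py` in this commit loads the encoder (the model path and the 80×3000 mel shape are illustrative values for tiny/base/small):

```python
import numpy as np
import axengine as axe

# Open the compiled encoder on the NPU (the path is a placeholder).
session = axe.InferenceSession(
    "../models-ax650/small/small-encoder.axmodel",
    providers=["AxEngineExecutionProvider"],
)

# The encoder expects a log-mel spectrogram input named "mel".
mel = np.zeros((1, 80, 3000), dtype=np.float32)
n_layer_cross_k, n_layer_cross_v = session.run(None, input_feed={"mel": mel})
print(n_layer_cross_k.shape, n_layer_cross_v.shape)
```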
+ 
+ #### Run
+ 
+ Log in to the dev board and run:
+ 
+ ```
+ cd python
+ conda activate whisper
+ python3 main.py --model_type small --model_path ../models-ax650 --wav ../demo.wav --language zh
+ ```
+ 
+ Example output:
+ 
+ ```
+ root@ax650:/mnt/qtang/whisper.axera/python# python3 main.py --wav ../demo.wav --model_type small --model_path ../models/ --language zh
+ [INFO] Available providers: ['AxEngineExecutionProvider']
+ wav: ../demo.wav
+ model_type: small
+ model_path: ../models/
+ language: zh
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Chip type: ChipType.MC50
+ [INFO] VNPU type: VNPUType.DISABLED
+ [INFO] Engine version: 2.10.1s
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 3.2-patch1 117f5fd4
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 3.2-patch1 117f5fd4
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 3.2-patch1 117f5fd4
+ Load models take 2322.563409805298ms
+ Preprocess wav take 6971.68493270874ms
+ Run encoder take 211.52877807617188ms
+ Run decoder_main take 79.00094985961914ms
+ First token: 17556
+ Run decoder_loop take 101.91774368286133ms
+ Iter 0 Token: 20844
+ Run decoder_loop take 60.30416488647461ms
+ Iter 1 Token: 7781
+ Run decoder_loop take 60.22000312805176ms
+ Iter 2 Token: 20204
+ Run decoder_loop take 60.23716926574707ms
+ Iter 3 Token: 28455
+ Run decoder_loop take 60.214996337890625ms
+ Iter 4 Token: 31962
+ Run decoder_loop take 60.17565727233887ms
+ Iter 5 Token: 6336
+ Run decoder_loop take 60.94002723693848ms
+ Iter 6 Token: 254
+ Run decoder_loop take 60.71639060974121ms
+ Iter 7 Token: 2930
+ Run decoder_loop take 60.225725173950195ms
+ Iter 8 Token: 236
+ Run decoder_loop take 60.167789459228516ms
+ Iter 9 Token: 36135
+ Run decoder_loop take 60.29987335205078ms
+ Iter 10 Token: 15868
+ Run decoder_loop take 61.163902282714844ms
+ Iter 11 Token: 252
+ Run decoder_loop take 60.273170471191406ms
+ Iter 12 Token: 1546
+ Run decoder_loop take 60.23144721984863ms
+ Iter 13 Token: 46514
+ Run decoder_loop take 60.31966209411621ms
+ Iter 14 Token: 50257
+ Result: 甚至出现交易几乎停滞的情况
+ ```
+ 
+ Command-line options:
+ | Option | Description | Default |
+ | --- | --- | --- |
+ | --wav | Input audio file | |
+ | --model_type/-t | Model type: tiny/base/small | |
+ | --model_path/-p | Directory containing the model files | ../models |
+ | --language/-l | Recognition language | zh |
+ 
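The same pipeline can also be used as a library; a minimal sketch based on the `Whisper` class added in `python/whisper.py` below (paths are the example values used above):

```python
from whisper import Whisper

# Arguments mirror the CLI options: model_type, model_path, language, task.
model = Whisper("small", "../models-ax650", "zh", "transcribe")

# run() accepts a wav path (or an already-loaded 16 kHz sample array) and returns the text.
print(model.run("../demo.wav"))
```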
+ 
+ <h3 id="CPP">CPP</h3>
+ 
+ #### Run
+ 
+ On an AX650N device, run:
+ 
+ ```
+ cd cpp
+ ./whisper -w ../demo.wav
+ ```
+ 
+ or specify the model type and path explicitly:
+ 
+ ```
+ cd cpp
+ ./whisper --model_type small --model_path ../models -w ../demo.wav
+ ```
+ 
+ Example output:
+ 
+ ```
+ root@ax650:/mnt/qtang/whisper.axera/cpp# ./install/whisper --wav ../demo.wav --model_type small --model_path ../models/ --language zh
+ wav_file: ../demo.wav
+ model_path: ../models/
+ model_type: small
+ language: zh
+ Encoder run take 188.30 ms
+ First token: 17556 take 81.88ms
+ Next Token: 20844 take 29.64ms
+ Next Token: 7781 take 29.70ms
+ Next Token: 20204 take 29.64ms
+ Next Token: 28455 take 29.65ms
+ Next Token: 31962 take 29.61ms
+ Next Token: 6336 take 29.67ms
+ Next Token: 254 take 29.63ms
+ Next Token: 2930 take 29.61ms
+ Next Token: 236 take 29.56ms
+ Next Token: 36135 take 29.64ms
+ Next Token: 15868 take 29.71ms
+ Next Token: 252 take 29.51ms
+ Next Token: 1546 take 29.63ms
+ Next Token: 46514 take 29.51ms
+ Next Token: 50257 take 29.69ms
+ All take 801.13 ms
+ Result: 甚至出现交易几乎停滞的情况
+ ```
 
 ### 服务端 (Server)
 
@@ -16,10 +208,86 @@ cd cpp
 
 ### 客户端 (Client)
 
- curl command-line test:
+ curl command-line test (replace the IP and port with your own):
 ```
 ffmpeg -i demo.wav -f f32le -c:a pcm_f32le - 2>/dev/null | \
 curl -X POST 10.126.33.192:8080/asr \
 -H "Content-Type: application/octet-stream" \
 --data-binary @-
- ```
+ ```
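For reference, a rough Python equivalent of the curl test above; it assumes the server consumes raw float32 PCM exactly like the ffmpeg pipe, and reuses the example host and port:

```python
import librosa
import requests

# Decode the file to 16 kHz mono float32 samples, which is what the Whisper
# pipeline in this repo expects.
samples, _ = librosa.load("demo.wav", sr=16000, mono=True)

resp = requests.post(
    "http://10.126.33.192:8080/asr",
    headers={"Content-Type": "application/octet-stream"},
    data=samples.astype("float32").tobytes(),
)
print(resp.text)
```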
+ 
+ ## Model Performance
+ 
+ ### Latency
+ 
+ RTF: Real-Time Factor (processing time divided by audio duration; lower is better)
+ 
+ CPP:
+ 
+ | Models | AX650N | AX630C |
+ | ------------- | ------ | ------ |
+ | Whisper-Tiny | 0.08 | |
+ | Whisper-Base | 0.11 | 0.35 |
+ | Whisper-Small | 0.24 | |
+ | Whisper-Turbo | 0.48 | |
+ 
+ Python:
+ 
+ | Models | AX650N | AX630C |
+ | ------------- | ------ | ------ |
+ | Whisper-Tiny | 0.12 | |
+ | Whisper-Base | 0.16 | 0.35 |
+ | Whisper-Small | 0.50 | |
+ | Whisper-Turbo | 0.60 | |
+ 
+ ### Word Error Rate (tested on the AIShell dataset)
+ 
+ | Models | AX650N | AX630C |
+ | ------------- | ------ | ------ |
+ | Whisper-Tiny | 0.24 | |
+ | Whisper-Base | 0.18 | |
+ | Whisper-Small | 0.11 | |
+ | Whisper-Turbo | 0.06 | |
+ 
+ To reproduce the test results, follow these steps:
+ 
+ Unpack the dataset:
+ ```
+ unzip datasets.zip
+ ```
+ 
+ Run the test script:
+ ```
+ cd python
+ conda activate whisper
+ python test_wer.py -d aishell --gt_path ../datasets/ground_truth.txt --model_type tiny
+ ```
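The error rate above is computed at the character level: `python/test_wer.py` (added below in this commit) takes the edit distance between the reference and the hypothesis and divides by the reference length. A condensed sketch of the same metric (the example strings are made up):

```python
def edit_distance(ref: str, hyp: str) -> int:
    # Classic dynamic-programming Levenshtein distance over characters.
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (r != h))
    return dp[-1]

# Per-utterance error rate: edits needed / reference length.
ref = "甚至出现交易几乎停滞的情况"
hyp = "甚至出现交易几乎停止的情况"   # one substituted character
print(edit_distance(ref, hyp) / len(ref))  # 1/13 ≈ 0.077
```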
+ 
+ ### Memory Usage
+ 
+ * CMM stands for the physical memory used by Axera modules such as VDEC (video decoder), VENC (video encoder), the NPU, etc.
+ 
+ Python:
+ 
+ | Models | CMM(MB)| OS(MB) |
+ | ------------- | ------ | ------ |
+ | Whisper-Tiny | 332 | 512 |
+ | Whisper-Base | 533 | 644 |
+ | Whisper-Small | 1106 | 906 |
+ | Whisper-Turbo | 2065 | 2084 |
+ 
+ C++:
+ 
+ | Models | CMM(MB)| OS(MB) |
+ | ------------- | ------ | ------ |
+ | Whisper-Tiny | 332 | 31 |
+ | Whisper-Base | 533 | 54 |
+ | Whisper-Small | 1106 | 146 |
+ | Whisper-Turbo | 2065 | 86 |
+ 
+ ## Technical Discussion
+ 
+ - Github issues
+ - QQ group: 139953715
python/main.py ADDED
@@ -0,0 +1,54 @@
1
+ import argparse
2
+ import os
3
+ from whisper import Whisper
4
+ import time
5
+
6
+
7
+ def get_args():
8
+ parser = argparse.ArgumentParser(
9
+ prog="whisper",
10
+ description="Run Whisper on input audio file"
11
+ )
12
+ parser.add_argument("--wav", "-w", type=str, required=True, help="Input audio file")
13
+ parser.add_argument("--model_type", "-t", type=str, choices=["tiny", "base", "small", "large", "large-v3", "turbo"], required=True, help="model type, only support tiny, base and small currently")
14
+ parser.add_argument("--model_path", "-p", type=str, required=False, default="../models/models-ax650", help="model path for *.axmodel, tokens.txt, positional_embedding.bin")
15
+ parser.add_argument("--language", "-l", type=str, required=False, default="zh", help="Target language, support en, zh, ja, and others. See languages.py for more options.")
16
+ parser.add_argument("--task", type=str, required=False, choices=["translate", "transcribe"], default="transcribe")
17
+ parser.add_argument("--print_rtf", action="store_true", help="Print Real-Time Factor")
18
+ return parser.parse_args()
19
+
20
+
21
+ def print_args(args):
22
+ print(f"wav: {args.wav}")
23
+ print(f"model_type: {args.model_type}")
24
+ print(f"model_path: {args.model_path}")
25
+ print(f"language: {args.language}")
26
+ print(f"task: {args.task}")
27
+
28
+
29
+ def main():
30
+ args = get_args()
31
+ print_args(args)
32
+
33
+ # Check wav existence
34
+ wav_path = args.wav
35
+ assert os.path.exists(wav_path), f"{wav_path} NOT exist"
36
+
37
+ model = Whisper(args.model_type, args.model_path, args.language, args.task)
38
+
39
+
40
+
41
+ print("\n预测结果:")
42
+ start = time.time()
43
+ print(model.run(wav_path))
44
+ end = time.time()
45
+
46
+ if args.print_rtf:
47
+ import librosa
48
+ samples, sr = librosa.load(wav_path, sr=16000)
49
+ duration = len(samples) / sr
50
+ process_time = end - start
51
+ print(f"RTF: {process_time / duration}")
52
+
53
+ if __name__ == "__main__":
54
+ main()
python/requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  numpy==1.26.4
2
  soundfile
3
- librosa
4
- zhconv
 
 
 
1
  numpy==1.26.4
2
  soundfile
3
+ librosa==0.9.1
4
+ zhconv
5
+ jiwer
6
+ tiktoken
python/test_wer.py ADDED
@@ -0,0 +1,272 @@
1
+ import argparse
2
+ import os
3
+ import logging
4
+ import re
5
+ from whisper import Whisper
6
+
7
+
8
+ def setup_logging():
9
+ """配置日志系统,同时输出到控制台和文件"""
10
+ # 获取脚本所在目录
11
+ script_dir = os.path.dirname(os.path.abspath(__file__))
12
+ log_file = os.path.join(script_dir, "test_wer.log")
13
+
14
+ # 配置日志格式
15
+ log_format = '%(asctime)s - %(levelname)s - %(message)s'
16
+ date_format = '%Y-%m-%d %H:%M:%S'
17
+
18
+ # 创建logger
19
+ logger = logging.getLogger()
20
+ logger.setLevel(logging.INFO)
21
+
22
+ # 清除现有的handler
23
+ for handler in logger.handlers[:]:
24
+ logger.removeHandler(handler)
25
+
26
+ # 创建文件handler
27
+ file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
28
+ file_handler.setLevel(logging.INFO)
29
+ file_formatter = logging.Formatter(log_format, date_format)
30
+ file_handler.setFormatter(file_formatter)
31
+
32
+ # 创建控制台handler
33
+ console_handler = logging.StreamHandler()
34
+ console_handler.setLevel(logging.INFO)
35
+ console_formatter = logging.Formatter(log_format, date_format)
36
+ console_handler.setFormatter(console_formatter)
37
+
38
+ # 添加handler到logger
39
+ logger.addHandler(file_handler)
40
+ logger.addHandler(console_handler)
41
+
42
+ return logger
43
+
44
+
45
+ class AIShellDataset:
46
+ def __init__(self, gt_path: str):
47
+ """
48
+ 初始化数据集
49
+
50
+ Args:
51
+ json_path: voice.json文件的路径
52
+ """
53
+ self.gt_path = gt_path
54
+ self.dataset_dir = os.path.dirname(gt_path)
55
+ self.voice_dir = os.path.join(self.dataset_dir, "aishell_S0764")
56
+
57
+ # 检查必要文件和文件夹是否存在
58
+ assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
59
+ assert os.path.exists(self.voice_dir), f"aishell_S0764文件夹不存在: {self.voice_dir}"
60
+
61
+ # 加载数据
62
+ self.data = []
63
+ with open(gt_path, 'r', encoding='utf-8') as f:
64
+ for line in f:
65
+ line = line.strip()
66
+ audio_path, gt = line.split(" ")
67
+ audio_path = os.path.join(self.voice_dir, audio_path + ".wav")
68
+ self.data.append({"audio_path": audio_path, "gt": gt})
69
+
70
+ # 使用logging而不是print
71
+ logger = logging.getLogger()
72
+ logger.info(f"加载了 {len(self.data)} 条数据")
73
+
74
+ def __iter__(self):
75
+ """返回迭代器"""
76
+ self.index = 0
77
+ return self
78
+
79
+ def __next__(self):
80
+ """返回下一个数据项"""
81
+ if self.index >= len(self.data):
82
+ raise StopIteration
83
+
84
+ item = self.data[self.index]
85
+ audio_path = item["audio_path"]
86
+ ground_truth = item["gt"]
87
+
88
+ self.index += 1
89
+ return audio_path, ground_truth
90
+
91
+ def __len__(self):
92
+ """返回数据集大小"""
93
+ return len(self.data)
94
+
95
+
96
+ class CommonVoiceDataset:
97
+ """Common Voice数据集解析器"""
98
+
99
+ def __init__(self, tsv_path: str):
100
+ """
101
+ 初始化数据集
102
+
103
+ Args:
104
+ json_path: voice.json文件的路径
105
+ """
106
+ self.tsv_path = tsv_path
107
+ self.dataset_dir = os.path.dirname(tsv_path)
108
+ self.voice_dir = os.path.join(self.dataset_dir, "clips")
109
+
110
+ # 检查必要文件和文件夹是否存在
111
+ assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
112
+ assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"
113
+
114
+ # 加载JSON数据
115
+ self.data = []
116
+ with open(tsv_path, 'r', encoding='utf-8') as f:
117
+ f.readline()
118
+ for line in f:
119
+ line = line.strip()
120
+ splits = line.split("\t")
121
+ audio_path = splits[1]
122
+ gt = splits[2]
123
+ audio_path = os.path.join(self.voice_dir, audio_path)
124
+ self.data.append({"audio_path": audio_path, "gt": gt})
125
+
126
+ # 使用logging而不是print
127
+ logger = logging.getLogger()
128
+ logger.info(f"加载了 {len(self.data)} 条数据")
129
+
130
+ def __iter__(self):
131
+ """返回迭代器"""
132
+ self.index = 0
133
+ return self
134
+
135
+ def __next__(self):
136
+ """返回下一个数据项"""
137
+ if self.index >= len(self.data):
138
+ raise StopIteration
139
+
140
+ item = self.data[self.index]
141
+ audio_path = item["audio_path"]
142
+ ground_truth = item["gt"]
143
+
144
+ self.index += 1
145
+ return audio_path, ground_truth
146
+
147
+ def __len__(self):
148
+ """返回数据集大小"""
149
+ return len(self.data)
150
+
151
+ def get_args():
152
+ parser = argparse.ArgumentParser(
153
+ prog="whisper",
154
+ description="Test WER on dataset"
155
+ )
156
+ parser.add_argument("--dataset", "-d", type=str, required=True, choices=["aishell", "common_voice"], help="Test dataset")
157
+ parser.add_argument("--gt_path", "-g", type=str, required=True, help="Test dataset ground truth file")
158
+ parser.add_argument("--max_num", type=int, default=-1, required=False, help="Maximum test data num")
159
+ parser.add_argument("--model_type", "-t", type=str, choices=["tiny", "base", "small", "large", "large-v3", "turbo"], required=True, help="model type, only support tiny, base and small currently")
160
+ parser.add_argument("--model_path", "-p", type=str, required=False, default="../models/models-ax650", help="model path for *.axmodel, tokens.txt, positional_embedding.bin")
161
+ parser.add_argument("--language", "-l", type=str, required=False, default="zh", help="Target language, support en, zh, ja, and others. See languages.py for more options.")
162
+ return parser.parse_args()
163
+
164
+
165
+ def print_args(args):
166
+ logger = logging.getLogger()
167
+ logger.info(f"dataset: {args.dataset}")
168
+ logger.info(f"gt_path: {args.gt_path}")
169
+ logger.info(f"max_num: {args.max_num}")
170
+ logger.info(f"model_type: {args.model_type}")
171
+ logger.info(f"model_path: {args.model_path}")
172
+ logger.info(f"language: {args.language}")
173
+
174
+
175
+ def min_distance(word1: str, word2: str) -> int:
176
+
177
+ row = len(word1) + 1
178
+ column = len(word2) + 1
179
+
180
+ cache = [ [0]*column for i in range(row) ]
181
+
182
+ for i in range(row):
183
+ for j in range(column):
184
+
185
+ if i ==0 and j ==0:
186
+ cache[i][j] = 0
187
+ elif i == 0 and j!=0:
188
+ cache[i][j] = j
189
+ elif j == 0 and i!=0:
190
+ cache[i][j] = i
191
+ else:
192
+ if word1[i-1] == word2[j-1]:
193
+ cache[i][j] = cache[i-1][j-1]
194
+ else:
195
+ replace = cache[i-1][j-1] + 1
196
+ insert = cache[i][j-1] + 1
197
+ remove = cache[i-1][j] + 1
198
+
199
+ cache[i][j] = min(replace, insert, remove)
200
+
201
+ return cache[row-1][column-1]
202
+
203
+
204
+ def remove_punctuation(text):
205
+ # 定义正则表达式模式,匹配所有标点符号
206
+ # 这个模式包括常见的标点符号和中文标点
207
+ pattern = r'[^\w\s]|_'
208
+
209
+ # 使用sub方法将所有匹配的标点符号替换为空字符串
210
+ cleaned_text = re.sub(pattern, '', text)
211
+
212
+ return cleaned_text
213
+
214
+
215
+ def main():
216
+ # 设置日志系统
217
+ logger = setup_logging()
218
+
219
+ args = get_args()
220
+ print_args(args)
221
+
222
+ dataset_type = args.dataset.lower()
223
+ if dataset_type == "aishell":
224
+ dataset = AIShellDataset(args.gt_path)
225
+ elif dataset_type == "common_voice":
226
+ dataset = CommonVoiceDataset(args.gt_path)
227
+ else:
228
+ raise ValueError(f"Unknown dataset type {dataset_type}")
229
+
230
+ max_num = args.max_num
231
+
232
+ # Load model
233
+ model = Whisper(args.model_type, args.model_path, args.language, "transcribe")
234
+
235
+ # Iterate over dataset
236
+ references = []
237
+ hyp = []
238
+ all_character_error_num = 0
239
+ all_character_num = 0
240
+ wer_file = open("wer.txt", "w")
241
+ max_data_num = max_num if max_num > 0 else len(dataset)
242
+ for n, (audio_path, reference) in enumerate(dataset):
243
+ hypothesis = model.run(audio_path)
244
+
245
+ hypothesis = remove_punctuation(hypothesis)
246
+ reference = remove_punctuation(reference)
247
+
248
+ character_error_num = min_distance(reference, hypothesis)
249
+ character_num = len(reference)
250
+ character_error_rate = character_error_num / character_num * 100
251
+
252
+ all_character_error_num += character_error_num
253
+ all_character_num += character_num
254
+
255
+ hyp.append(hypothesis)
256
+ references.append(reference)
257
+
258
+ line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
259
+ wer_file.write(line_content + "\n")
260
+ logger.info(line_content)
261
+
262
+ if n + 1 >= max_data_num:
263
+ break
264
+
265
+ total_character_error_rate = all_character_error_num / all_character_num * 100
266
+
267
+ logger.info(f"Total WER: {total_character_error_rate}%")
268
+ wer_file.write(f"Total WER: {total_character_error_rate}%")
269
+ wer_file.close()
270
+
271
+ if __name__ == "__main__":
272
+ main()
python/whisper.py CHANGED
@@ -1,240 +1,224 @@
1
- import argparse
2
  import axengine as axe
3
  import numpy as np
4
  import librosa
5
  import os
6
- from typing import Tuple
7
- import soundfile as sf
8
- import base64
 
9
  import zhconv
10
- import time
11
- from languages import WHISPER_LANGUAGES
12
 
13
 
14
- WHISPER_N_MELS = 80
15
- WHISPER_SAMPLE_RATE = 16000
16
- WHISPER_N_FFT = 480
17
- WHISPER_HOP_LENGTH = 160
18
 
19
- WHISPER_SOT = 50258
20
- WHISPER_EOT = 50257
21
- WHISPER_BLANK = 220
22
- WHISPER_NO_TIMESTAMPS = 50363
23
- WHISPER_NO_SPEECH = 50362
24
- WHISPER_TRANSLATE = 50358
25
- WHISPER_TRANSCRIBE = 50359
26
- WHISPER_VOCAB_SIZE = 51865
27
- WHISPER_N_TEXT_CTX = 448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- NEG_INF = float("-inf")
30
- SOT_SEQUENCE = np.array([WHISPER_SOT,WHISPER_SOT + 1 + tuple(WHISPER_LANGUAGES).index("zh"),WHISPER_TRANSCRIBE,WHISPER_NO_TIMESTAMPS], dtype=np.int32)
31
- WHISPER_N_TEXT_STATE_MAP = {
32
- "tiny": 384,
33
- "base": 512,
34
- "small": 768
35
- }
36
-
37
-
38
- def get_args():
39
- parser = argparse.ArgumentParser(
40
- prog="whisper",
41
- description="Run Whisper on input audio file"
42
- )
43
- parser.add_argument("--wav", "-w", type=str, required=True, help="Input audio file")
44
- parser.add_argument("--model_type", "-t", type=str, choices=["tiny", "base", "small"], required=True, help="model type, only support tiny, base and small currently")
45
- parser.add_argument("--model_path", "-p", type=str, required=False, default="../models", help="model path for *.axmodel, tokens.txt, positional_embedding.bin")
46
- parser.add_argument("--language", "-l", type=str, required=False, default="zh", help="Target language, support en, zh, ja, and others. See languages.py for more options.")
47
- return parser.parse_args()
48
-
49
-
50
- def print_args(args):
51
- print(f"wav: {args.wav}")
52
- print(f"model_type: {args.model_type}")
53
- print(f"model_path: {args.model_path}")
54
- print(f"language: {args.language}")
55
-
56
-
57
- def load_audio(filename: str) -> Tuple[np.ndarray, int]:
58
- data, sample_rate = sf.read(
59
- filename,
60
- always_2d=True,
61
- dtype="float32",
62
- )
63
- data = data[:, 0] # use only the first channel
64
- data = librosa.resample(data, orig_sr=sample_rate, target_sr=WHISPER_SAMPLE_RATE)
65
- samples = np.ascontiguousarray(data)
66
- return samples, sample_rate
67
-
68
-
69
- def load_models(model_path, model_type):
70
- encoder_path = f"{model_type}-encoder.axmodel"
71
- decoder_main_path = f"{model_type}-decoder-main.axmodel"
72
- decoder_loop_path = f"{model_type}-decoder-loop.axmodel"
73
- pe_path = f"{model_type}-positional_embedding.bin"
74
- token_path = f"{model_type}-tokens.txt"
75
-
76
- required_files = [os.path.join(model_path, i) for i in (encoder_path, decoder_main_path, decoder_loop_path, pe_path, token_path)]
77
- # Check file existence
78
- for i, file_path in enumerate(required_files):
79
- assert os.path.exists(file_path), f"{file_path} NOT exist"
80
-
81
- # Load encoder
82
- encoder = axe.InferenceSession(required_files[0])
83
- # Load decoder main
84
- decoder_main = axe.InferenceSession(required_files[1])
85
- # Load decoder loop
86
- decoder_loop = axe.InferenceSession(required_files[2])
87
- # Load position embedding
88
- pe = np.fromfile(required_files[3], dtype=np.float32)
89
- # Load tokens
90
- tokens = []
91
- with open(required_files[4], "r") as f:
92
- for line in f:
93
- line = line.strip()
94
- tokens.append(line.split(" ")[0])
95
-
96
- return encoder, decoder_main, decoder_loop, pe, tokens
97
-
98
-
99
- def compute_feature(wav_path, n_mels = WHISPER_N_MELS, padding = 480000):
100
- audio, sr = load_audio(wav_path)
101
-
102
- audio = np.concatenate((audio, np.zeros((padding,), dtype=np.float32)), axis=-1)
103
-
104
- mel = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=WHISPER_N_FFT, hop_length=WHISPER_HOP_LENGTH, window="hann", center=True, pad_mode="reflect", power=2.0, n_mels=n_mels)
105
- log_spec = np.log10(np.maximum(mel, 1e-10))
106
- log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
107
- mel = (log_spec + 4.0) / 4.0
108
-
109
- # We pad 1500 frames at the end so that it is able to detect eot
110
- # You can use another value instead of 1500.
111
- # mel = np.concatenate((mel, np.zeros((n_mels, 1500), dtype=np.float32)), axis=-1)
112
-
113
- target = 3000
114
- if mel.shape[1] > target:
115
- # -50 so that there are some zero tail paddings.
116
- mel = mel[:, : target]
117
- mel[:, -50:] = 0
118
-
119
- # We don't need to pad it to 30 seconds now!
120
- if mel.shape[1] < target:
121
- mel = np.concatenate((mel, np.zeros((n_mels, target - mel.shape[1]), dtype=np.float32)), axis=-1)
122
-
123
- return mel
124
-
125
-
126
- def supress_tokens(logits, is_initial):
127
- if is_initial:
128
- logits[WHISPER_EOT] = NEG_INF
129
- logits[WHISPER_BLANK] = NEG_INF
130
-
131
- logits[WHISPER_NO_TIMESTAMPS] = NEG_INF
132
- logits[WHISPER_SOT] = NEG_INF
133
- logits[WHISPER_NO_SPEECH] = NEG_INF
134
- logits[WHISPER_TRANSLATE] = NEG_INF
135
- return logits
136
-
137
-
138
- def choose_language(lang):
139
- if lang not in WHISPER_LANGUAGES.keys():
140
- raise Exception(f"Unknown language: {lang}. Check languages.py for correct options.")
141
- SOT_SEQUENCE[1] = WHISPER_SOT + 1 + tuple(WHISPER_LANGUAGES.keys()).index(lang)
142
-
143
-
144
- def main():
145
- args = get_args()
146
- print_args(args)
147
-
148
- # Check wav existence
149
- wav_path = args.wav
150
- assert os.path.exists(wav_path), f"{wav_path} NOT exist"
151
-
152
- # Choose language
153
- choose_language(args.language)
154
-
155
- # Load models and other stuff
156
- start = time.time()
157
- encoder, decoder_main, decoder_loop, pe, token_table = load_models(args.model_path, args.model_type)
158
- print(f"Load models take {(time.time() - start) * 1000}ms")
159
- WHISPER_N_TEXT_STATE = WHISPER_N_TEXT_STATE_MAP[args.model_type]
160
-
161
- # Preprocess
162
- start = time.time()
163
- mel = compute_feature(wav_path, n_mels=WHISPER_N_MELS)
164
- print(f"Preprocess wav take {(time.time() - start) * 1000}ms")
165
- # mel.tofile("mel.bin")
166
-
167
- # Run encoder
168
- start = time.time()
169
- x = encoder.run(None, input_feed={"mel": mel[None, ...]})
170
- n_layer_cross_k, n_layer_cross_v = x
171
- print(f"Run encoder take {(time.time() - start) * 1000}ms")
172
-
173
- # n_layer_cross_k.tofile("n_layer_cross_k.bin")
174
- # n_layer_cross_v.tofile("n_layer_cross_v.bin")
175
-
176
- # Run decoder_main
177
- start = time.time()
178
- x = decoder_main.run(None, input_feed={
179
- "tokens": SOT_SEQUENCE[None, ...],
180
- "n_layer_cross_k": n_layer_cross_k,
181
- "n_layer_cross_v": n_layer_cross_v
182
- })
183
- logits, n_layer_self_k_cache, n_layer_self_v_cache = x
184
- print(f"Run decoder_main take {(time.time() - start) * 1000}ms")
185
-
186
- # Decode token
187
- logits = logits[0, -1, :]
188
- logits = supress_tokens(logits, is_initial=True)
189
- # logits.tofile("logits.bin")
190
- max_token_id = np.argmax(logits)
191
- output_tokens = []
192
- print(f"First token: {max_token_id}")
193
-
194
- # Position embedding offset
195
- offset = SOT_SEQUENCE.shape[0]
196
-
197
- # Autoregressively run decoder until token meets EOT
198
- for i in range(WHISPER_N_TEXT_CTX - SOT_SEQUENCE.shape[0]):
199
- if max_token_id == WHISPER_EOT:
200
- break
201
-
202
- output_tokens.append(max_token_id)
203
-
204
- mask = np.zeros((WHISPER_N_TEXT_CTX,), dtype=np.float32)
205
- mask[: WHISPER_N_TEXT_CTX - offset - 1] = NEG_INF
206
-
207
- # Run decoder_loop
208
- start = time.time()
209
- x = decoder_loop.run(None, input_feed={
210
- "tokens": np.array([[output_tokens[-1]]], dtype=np.int32),
211
- "in_n_layer_self_k_cache": n_layer_self_k_cache,
212
- "in_n_layer_self_v_cache": n_layer_self_v_cache,
213
  "n_layer_cross_k": n_layer_cross_k,
214
- "n_layer_cross_v": n_layer_cross_v,
215
- "positional_embedding": pe[offset * WHISPER_N_TEXT_STATE : (offset + 1) * WHISPER_N_TEXT_STATE][None, ...],
216
- "mask": mask
217
  })
218
  logits, n_layer_self_k_cache, n_layer_self_v_cache = x
219
- print(f"Run decoder_loop take {(time.time() - start) * 1000}ms")
220
 
221
  # Decode token
222
- offset += 1
223
- logits = supress_tokens(logits.flatten(), is_initial=False)
 
224
  max_token_id = np.argmax(logits)
225
-
226
- print(f"Iter {i} \t Token: {max_token_id}")
227
-
228
- s = b""
229
- for i in output_tokens:
230
- s += base64.b64decode(token_table[i])
231
- # print(s.decode().strip())
232
- pd = s.decode().strip()
233
- if args.language == "zh":
234
- pd = zhconv.convert(pd, 'zh-hans')
235
-
236
- print(f"Result: {pd}")
237
-
238
-
239
- if __name__ == "__main__":
240
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import axengine as axe
2
  import numpy as np
3
  import librosa
4
  import os
5
+ from typing import Union
6
+ from whisper_tokenizer import *
7
+ import json
8
+ from dataclasses import dataclass, field
9
  import zhconv
 
 
10
 
11
 
12
+ NEG_INF = float("-inf")
 
 
 
13
 
14
+ @dataclass
15
+ class WhisperConfig:
16
+ n_mels : int = 0
17
+ sample_rate : int = 0
18
+ n_fft : int = 0
19
+ hop_length : int = 0
20
+
21
+ sot : int = 0
22
+ eot : int = 0
23
+ blank_id : int = 0
24
+ no_timestamps : int = 0
25
+ no_speech : int = 0
26
+ translate : int = 0
27
+ transcribe : int = 0
28
+ n_vocab : int = 0
29
+ n_text_ctx : int = 0
30
+ n_text_state : int = 0
31
+
32
+ sot_sequence : np.ndarray = field(default_factory=lambda: np.array([0,0,0,0], dtype=np.int32))
33
+
34
+
35
+ class Whisper:
36
+ def __init__(self, model_type: str, model_path: str, language: str, task: str):
37
+ assert task in ["translate", "transcribe"]
38
+
39
+ self.language = language
40
+ self.task = task
41
+ self.encoder, self.decoder_main, self.decoder_loop, self.pe, self.tokenizer, model_config = \
42
+ self.load_model(model_type, model_path, language, task)
43
+ self.config = self.load_config(model_config)
44
+
45
+
46
+ def load_model(self, model_type, model_path, language, task):
47
+ encoder_path = f"{model_type}/{model_type}-encoder.axmodel"
48
+ decoder_main_path = f"{model_type}/{model_type}-decoder-main.axmodel"
49
+ decoder_loop_path = f"{model_type}/{model_type}-decoder-loop.axmodel"
50
+ pe_path = f"{model_type}/{model_type}-positional_embedding.bin"
51
+ model_config_file = f"{model_type}/{model_type}_config.json"
52
+
53
+ required_files = [os.path.join(model_path, i) for i in (encoder_path, decoder_main_path, decoder_loop_path, pe_path, model_config_file)]
54
+ # Check file existence
55
+ for i, file_path in enumerate(required_files):
56
+ assert os.path.exists(file_path), f"{file_path} NOT exist"
57
+
58
+ # Load encoder
59
+ encoder = axe.InferenceSession(required_files[0], providers=['AxEngineExecutionProvider'])
60
+ # Load decoder main
61
+ decoder_main = axe.InferenceSession(required_files[1], providers=['AxEngineExecutionProvider'])
62
+ # Load decoder loop
63
+ decoder_loop = axe.InferenceSession(required_files[2], providers=['AxEngineExecutionProvider'])
64
+ # Load position embedding
65
+ pe = np.fromfile(required_files[3], dtype=np.float32)
66
+ # Load tokens
67
+ model_config = json.load(open(required_files[4], "r"))
68
+ model_config["all_language_tokens"] = [int(i) for i in model_config["all_language_tokens"].split(",")]
69
+ model_config["all_language_codes"] = [i for i in model_config["all_language_codes"].split(",")]
70
+ tokenizer = get_tokenizer(
71
+ model_config["is_multilingual"],
72
+ num_languages=len(model_config["all_language_codes"]),
73
+ language=language,
74
+ task=task,
75
+ )
76
+
77
+ return encoder, decoder_main, decoder_loop, pe, tokenizer, model_config
78
+
79
 
80
+ def load_config(self, model_config):
81
+ config = WhisperConfig()
82
+ config.n_mels = model_config["n_mels"]
83
+ config.sample_rate = 16000
84
+ config.n_fft = 480
85
+ config.hop_length = 160
86
+
87
+ config.sot = model_config["sot"]
88
+ config.eot = model_config["eot"]
89
+ config.blank_id = model_config["blank_id"]
90
+ config.no_timestamps = model_config["no_timestamps"]
91
+ config.no_speech = model_config["no_speech"]
92
+ config.translate = model_config["translate"]
93
+ config.transcribe = model_config["transcribe"]
94
+ config.n_vocab = model_config["n_vocab"]
95
+ config.n_text_ctx = model_config["n_text_ctx"]
96
+ config.n_text_state = model_config["n_text_state"]
97
+
98
+ lang_token = model_config["all_language_tokens"][model_config["all_language_codes"].index(self.language)]
99
+ task_token = config.transcribe if self.task == "transcribe" else config.translate
100
+ config.sot_sequence = np.array([config.sot, lang_token, task_token, config.no_timestamps], dtype=np.int32)
101
+
102
+ return config
103
+
104
+
105
+ def load_audio(self, audio: str):
106
+ data, sample_rate = librosa.load(audio, sr=self.config.sample_rate)
107
+ samples = np.ascontiguousarray(data)
108
+ return samples, sample_rate
109
+
110
+
111
+ def compute_feature(self, audio: np.ndarray, padding = 480000):
112
+ if padding > 0:
113
+ audio = np.concatenate((audio, np.zeros((padding,), dtype=np.float32)), axis=-1)
114
+
115
+ mel = librosa.feature.melspectrogram(y=audio,
116
+ sr=self.config.sample_rate,
117
+ n_fft=self.config.n_fft,
118
+ hop_length=self.config.hop_length,
119
+ window="hann",
120
+ center=True,
121
+ pad_mode="reflect",
122
+ power=2.0,
123
+ n_mels=self.config.n_mels)
124
+
125
+ log_spec = np.log10(np.maximum(mel, 1e-10))
126
+ log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
127
+ mel = (log_spec + 4.0) / 4.0
128
+
129
+ target = 3000
130
+ if mel.shape[1] > target:
131
+ # -50 so that there are some zero tail paddings.
132
+ mel = mel[:, : target]
133
+ mel[:, -50:] = 0
134
+
135
+ # We don't need to pad it to 30 seconds now!
136
+ if mel.shape[1] < target:
137
+ mel = np.concatenate((mel, np.zeros((self.config.n_mels, target - mel.shape[1]), dtype=np.float32)), axis=-1)
138
+
139
+ return mel
140
+
141
+
142
+ def supress_tokens(self, logits, is_initial):
143
+ if is_initial:
144
+ logits[self.config.eot] = NEG_INF
145
+ logits[self.config.blank_id] = NEG_INF
146
+
147
+ logits[self.config.no_timestamps] = NEG_INF
148
+ logits[self.config.sot] = NEG_INF
149
+ logits[self.config.no_speech] = NEG_INF
150
+
151
+ if self.task == "transcribe":
152
+ logits[self.config.translate] = NEG_INF
153
+ else:
154
+ logits[self.config.transcribe] = NEG_INF
155
+ return logits
156
+
157
+
158
+ def run(self, audio: Union[str, np.ndarray]) -> str:
159
+ if isinstance(audio, str):
160
+ audio, sample_rate = self.load_audio(audio)
161
+
162
+ mel = self.compute_feature(audio)
163
+
164
+ # Run encoder
165
+ x = self.encoder.run(None, input_feed={"mel": mel[None, ...]})
166
+ n_layer_cross_k, n_layer_cross_v = x
167
+
168
+ # Run decoder_main
169
+ x = self.decoder_main.run(None, input_feed={
170
+ "tokens": self.config.sot_sequence[None, ...],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  "n_layer_cross_k": n_layer_cross_k,
172
+ "n_layer_cross_v": n_layer_cross_v
 
 
173
  })
174
  logits, n_layer_self_k_cache, n_layer_self_v_cache = x
 
175
 
176
  # Decode token
177
+ logits = logits[0, -1, :]
178
+ logits = self.supress_tokens(logits, is_initial=True)
179
+ # logits.tofile("logits.bin")
180
  max_token_id = np.argmax(logits)
181
+ output_tokens = []
182
+
183
+ # Position embedding offset
184
+ offset = self.config.sot_sequence.shape[0]
185
+
186
+ # Autoregressively run decoder until token meets EOT
187
+ for i in range(self.config.n_text_ctx - self.config.sot_sequence.shape[0]):
188
+ if max_token_id >= self.config.eot:
189
+ break
190
+
191
+ output_tokens.append(max_token_id)
192
+
193
+ mask = np.zeros((self.config.n_text_ctx,), dtype=np.float32)
194
+ mask[: self.config.n_text_ctx - offset - 1] = NEG_INF
195
+
196
+ # Run decoder_loop
197
+ x = self.decoder_loop.run(None, input_feed={
198
+ "tokens": np.array([[output_tokens[-1]]], dtype=np.int32),
199
+ "in_n_layer_self_k_cache": n_layer_self_k_cache,
200
+ "in_n_layer_self_v_cache": n_layer_self_v_cache,
201
+ "n_layer_cross_k": n_layer_cross_k,
202
+ "n_layer_cross_v": n_layer_cross_v,
203
+ "positional_embedding": self.pe[offset * self.config.n_text_state : (offset + 1) * self.config.n_text_state][None, ...],
204
+ "mask": mask
205
+ })
206
+ logits, n_layer_self_k_cache, n_layer_self_v_cache = x
207
+
208
+ # Decode token
209
+ offset += 1
210
+ logits = self.supress_tokens(logits.flatten(), is_initial=False)
211
+ max_token_id = np.argmax(logits)
212
+
213
+ text = self.tokenizer.decode(output_tokens)
214
+
215
+ if self.language == "zh":
216
+ try:
217
+ sim_zh = zhconv.convert(text, 'zh-hans')
218
+ return sim_zh
219
+ except Exception:
220
+ return text
221
+
222
+ return text
223
+
224
+
python/whisper_cli.py ADDED
@@ -0,0 +1,45 @@
1
+ import requests
2
+
3
+ def transcribe_audio(
4
+ server_url: str,
5
+ wav_path: str,
6
+ model_type: str = "tiny",
7
+ model_path: str = "../models/models-ax650",
8
+ language: str = "zh",
9
+ task: str = "transcribe"
10
+ ):
11
+ url = f"{server_url.rstrip('/')}/asr"
12
+
13
+ files = {
14
+ "wav": open(wav_path, "rb"),
15
+ }
16
+
17
+ data = {
18
+ "model_type": model_type,
19
+ "model_path": model_path,
20
+ "language": language,
21
+ "task": task,
22
+ }
23
+
24
+ print(f"Sending request to: {url}")
25
+
26
+ response = requests.post(url, files=files, data=data)
27
+ if response.status_code != 200:
28
+ print("❌ Error:", response.text)
29
+ return None
30
+
31
+ result = response.json()
32
+ print("服务器返回结果:")
33
+ print(result)
34
+
35
+ return result
36
+
37
+
38
+ if __name__ == "__main__":
39
+ # 你的服务器地址
40
+ SERVER = "http://127.0.0.1:8000"
41
+
42
+ # 本地 wav 文件路径
43
+ WAV = "../demo.wav"
44
+
45
+ transcribe_audio(SERVER, WAV)
python/whisper_onnx.py DELETED
@@ -1,239 +0,0 @@
1
- import argparse
2
- import onnxruntime as ort
3
- import numpy as np
4
- import librosa
5
- import os
6
- from typing import Tuple
7
- import soundfile as sf
8
- import base64
9
- import zhconv
10
- import time
11
- import torch
12
- from torch.nn import functional as F
13
- from languages import WHISPER_LANGUAGES
14
-
15
-
16
- WHISPER_N_MELS = 80
17
- WHISPER_SAMPLE_RATE = 16000
18
- WHISPER_N_FFT = 480
19
- WHISPER_HOP_LENGTH = 160
20
-
21
- WHISPER_SOT = 50258
22
- WHISPER_EOT = 50257
23
- WHISPER_BLANK = 220
24
- WHISPER_NO_TIMESTAMPS = 50363
25
- WHISPER_NO_SPEECH = 50362
26
- WHISPER_TRANSLATE = 50358
27
- WHISPER_TRANSCRIBE = 50359
28
- WHISPER_VOCAB_SIZE = 51865
29
- WHISPER_N_TEXT_CTX = 448
30
-
31
- NEG_INF = float("-inf")
32
- SOT_SEQUENCE = np.array([WHISPER_SOT,WHISPER_SOT + 1 + tuple(WHISPER_LANGUAGES).index("zh"),WHISPER_TRANSCRIBE,WHISPER_NO_TIMESTAMPS], dtype=np.int64)
33
- WHISPER_N_TEXT_STATE_MAP = {
34
- "tiny": 384,
35
- "base": 512,
36
- "small": 768
37
- }
38
-
39
-
40
- def get_args():
41
- parser = argparse.ArgumentParser(
42
- prog="whisper",
43
- description="Run Whisper on input audio file"
44
- )
45
- parser.add_argument("--wav", "-w", type=str, required=True, help="Input audio file")
46
- parser.add_argument("--model_type", "-t", type=str, choices=["tiny", "base", "small"], required=True, help="model type, only support tiny/base/small currently")
47
- parser.add_argument("--model_path", "-p", type=str, required=False, default="../models", help="model path for *.axmodel, tokens.txt, positional_embedding.bin")
48
- parser.add_argument("--language", "-l", type=str, required=False, default="zh", help="Target language, support en, zh, ja, and others. See languages.py for more options.")
49
- return parser.parse_args()
50
-
51
-
52
- def print_args(args):
53
- print(f"wav: {args.wav}")
54
- print(f"model_type: {args.model_type}")
55
- print(f"model_path: {args.model_path}")
56
- print(f"language: {args.language}")
57
-
58
-
59
- def load_audio(filename: str) -> Tuple[np.ndarray, int]:
60
- data, sample_rate = sf.read(
61
- filename,
62
- always_2d=True,
63
- dtype="float32",
64
- )
65
- data = data[:, 0] # use only the first channel
66
- data = librosa.resample(data, orig_sr=sample_rate, target_sr=WHISPER_SAMPLE_RATE)
67
- samples = np.ascontiguousarray(data)
68
- return samples, sample_rate
69
-
70
-
71
- def load_models(model_path, model_type):
72
- encoder_path = f"{model_type}-encoder.onnx"
73
- decoder_main_path = f"{model_type}-decoder-main.onnx"
74
- decoder_loop_path = f"{model_type}-decoder-loop.onnx"
75
- pe_path = f"{model_type}-positional_embedding.bin"
76
- token_path = f"{model_type}-tokens.txt"
77
-
78
- required_files = [os.path.join(model_path, i) for i in (encoder_path, decoder_main_path, decoder_loop_path, pe_path, token_path)]
79
- # Check file existence
80
- for i, file_path in enumerate(required_files):
81
- assert os.path.exists(file_path), f"{file_path} NOT exist"
82
-
83
- # Load encoder
84
- encoder = ort.InferenceSession(required_files[0], providers=['CPUExecutionProvider'])
85
- # Load decoder main
86
- decoder_main = ort.InferenceSession(required_files[1], providers=['CPUExecutionProvider'])
87
- # Load decoder loop
88
- decoder_loop = ort.InferenceSession(required_files[2], providers=['CPUExecutionProvider'])
89
- # Load position embedding
90
- pe = np.fromfile(required_files[3], dtype=np.float32)
91
- # Load tokens
92
- tokens = []
93
- with open(required_files[4], "r") as f:
94
- for line in f:
95
- line = line.strip()
96
- tokens.append(line.split(" ")[0])
97
-
98
- return encoder, decoder_main, decoder_loop, pe, tokens
99
-
100
-
101
- def compute_feature(wav_path, n_mels = WHISPER_N_MELS, padding = 480000):
102
- audio, sr = load_audio(wav_path)
103
-
104
- audio = np.concatenate((audio, np.zeros((padding,), dtype=np.float32)), axis=-1)
105
-
106
- mel = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=WHISPER_N_FFT, hop_length=WHISPER_HOP_LENGTH, window="hann", center=True, pad_mode="reflect", power=2.0, n_mels=n_mels)
107
- log_spec = np.log10(np.maximum(mel, 1e-10))
108
- log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
109
- mel = (log_spec + 4.0) / 4.0
110
-
111
- # We pad 1500 frames at the end so that it is able to detect eot
112
- # You can use another value instead of 1500.
113
- # mel = np.concatenate((mel, np.zeros((n_mels, 1500), dtype=np.float32)), axis=-1)
114
-
115
- target = 3000
116
- if mel.shape[1] > target:
117
- # -50 so that there are some zero tail paddings.
118
- mel = mel[:, : target]
119
- mel[:, -50:] = 0
120
-
121
- # We don't need to pad it to 30 seconds now!
122
- if mel.shape[1] < target:
123
- mel = np.concatenate((mel, np.zeros((n_mels, target - mel.shape[1]), dtype=np.float32)), axis=-1)
124
-
125
- return mel
126
-
127
-
128
- def supress_tokens(logits, is_initial):
129
- if is_initial:
130
- logits[WHISPER_EOT] = NEG_INF
131
- logits[WHISPER_BLANK] = NEG_INF
132
-
133
- logits[WHISPER_NO_TIMESTAMPS] = NEG_INF
134
- logits[WHISPER_SOT] = NEG_INF
135
- logits[WHISPER_NO_SPEECH] = NEG_INF
136
- logits[WHISPER_TRANSLATE] = NEG_INF
137
- return logits
138
-
139
-
140
- def choose_language(lang):
141
- if lang not in WHISPER_LANGUAGES.keys():
142
- raise Exception(f"Unknown language: {lang}. Check languages.py for correct options.")
143
- SOT_SEQUENCE[1] = WHISPER_SOT + 1 + tuple(WHISPER_LANGUAGES.keys()).index(lang)
144
-
145
-
146
- def main():
147
- args = get_args()
148
- print_args(args)
149
-
150
- # Check wav existence
151
- wav_path = args.wav
152
- assert os.path.exists(wav_path), f"{wav_path} NOT exist"
153
-
154
- # Choose language
155
- choose_language(args.language)
156
-
157
- # Load models and other stuff
158
- encoder, decoder_main, decoder_loop, pe, token_table = load_models(args.model_path, args.model_type)
159
- WHISPER_N_TEXT_STATE = WHISPER_N_TEXT_STATE_MAP[args.model_type]
160
-
161
- # Preprocess
162
- mel = compute_feature(wav_path, n_mels=WHISPER_N_MELS)
163
- # mel.tofile("mel.bin")
164
- # mel = np.load("../mel.npy")[..., :3000]
165
-
166
- # Run encoder
167
- start = time.time()
168
- x = encoder.run(None, input_feed={"mel": mel[None, ...]})
169
- n_layer_cross_k, n_layer_cross_v = x
170
- print(f"Run encoder take {(time.time() - start) * 1000}ms")
171
-
172
- # n_layer_cross_k.tofile("n_layer_cross_k.bin")
173
- # n_layer_cross_v.tofile("n_layer_cross_v.bin")
174
-
175
- # Run decoder_main
176
- start = time.time()
177
- x = decoder_main.run(None, input_feed={
178
- "tokens": SOT_SEQUENCE[None, ...],
179
- "n_layer_cross_k": n_layer_cross_k,
180
- "n_layer_cross_v": n_layer_cross_v
181
- })
182
- logits, n_layer_self_k_cache, n_layer_self_v_cache = x
183
- print(f"Run decoder_main take {(time.time() - start) * 1000}ms")
184
-
185
- # Decode token
186
- logits = logits[0, -1, :]
187
- logits = supress_tokens(logits, is_initial=True)
188
- # logits.tofile("logits.bin")
189
- max_token_id = np.argmax(logits)
190
- output_tokens = []
191
- print(f"First token: {max_token_id}")
192
-
193
- # Position embedding offset
194
- offset = SOT_SEQUENCE.shape[0]
195
-
196
- # Autoregressively run decoder until token meets EOT
197
- for i in range(WHISPER_N_TEXT_CTX - SOT_SEQUENCE.shape[0]):
198
- if max_token_id == WHISPER_EOT:
199
- break
200
-
201
- output_tokens.append(max_token_id)
202
-
203
- mask = np.zeros((WHISPER_N_TEXT_CTX,), dtype=np.float32)
204
- mask[: WHISPER_N_TEXT_CTX - offset - 1] = NEG_INF
205
-
206
- # Run decoder_loop
207
- start = time.time()
208
- x = decoder_loop.run(None, input_feed={
209
- "tokens": np.array([[output_tokens[-1]]], dtype=np.int64),
210
- "in_n_layer_self_k_cache": n_layer_self_k_cache,
211
- "in_n_layer_self_v_cache": n_layer_self_v_cache,
212
- "n_layer_cross_k": n_layer_cross_k,
213
- "n_layer_cross_v": n_layer_cross_v,
214
- "positional_embedding": pe[offset * WHISPER_N_TEXT_STATE : (offset + 1) * WHISPER_N_TEXT_STATE][None, ...],
215
- "mask": mask
216
- })
217
- logits, n_layer_self_k_cache, n_layer_self_v_cache = x
218
- print(f"Run decoder_loop take {(time.time() - start) * 1000}ms")
219
-
220
- # Decode token
221
- offset += 1
222
- logits = supress_tokens(logits.flatten(), is_initial=False)
223
- max_token_id = np.argmax(logits)
224
-
225
- print(f"Iter {i} \t Token: {max_token_id}")
226
-
227
- s = b""
228
- for i in output_tokens:
229
- s += base64.b64decode(token_table[i])
230
- # print(s.decode().strip())
231
- pd = s.decode().strip()
232
- if args.language == "zh":
233
- pd = zhconv.convert(pd, 'zh-hans')
234
-
235
- print(f"Result: {pd}")
236
-
237
-
238
- if __name__ == "__main__":
239
- main()
python/whisper_svr.py ADDED
@@ -0,0 +1,104 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import tempfile
5
+ from http.server import BaseHTTPRequestHandler, HTTPServer
6
+ from urllib.parse import parse_qs
7
+
8
+ from whisper import Whisper
9
+ import cgi
10
+
11
+
12
+ # Model cache: avoid reloading the model on every request
13
+ _model_cache = {}
14
+
15
+ def get_model(model_type, model_path, language, task):
16
+ key = (model_type, model_path, language, task)
17
+ if key not in _model_cache:
18
+ print(f"Loading model: type={model_type}, path={model_path}, lang={language}, task={task}")
19
+ _model_cache[key] = Whisper(model_type, model_path, language, task)
20
+ return _model_cache[key]
21
+
22
+
23
+ class WhisperHandler(BaseHTTPRequestHandler):
24
+
25
+ def _send_json(self, obj, status=200):
26
+ data = json.dumps(obj, ensure_ascii=False).encode("utf-8")
27
+ self.send_response(status)
28
+ self.send_header("Content-Type", "application/json; charset=utf-8")
29
+ self.send_header("Content-Length", str(len(data)))
30
+ self.end_headers()
31
+ self.wfile.write(data)
32
+
33
+ def do_GET(self):
34
+ if self.path == "/health":
35
+ self._send_json({"status": "ok"})
36
+ else:
37
+ self._send_json({"error": "not found"}, 404)
38
+
39
+ def do_POST(self):
40
+ if self.path != "/asr":
41
+ self._send_json({"error": "not found"}, 404)
42
+ return
43
+
44
+ # Parse multipart/form-data
45
+ content_type = self.headers.get('Content-Type')
46
+ if not content_type:
47
+ self._send_json({"error": "Missing Content-Type"}, 400)
48
+ return
49
+
50
+ ctype, pdict = cgi.parse_header(content_type)
51
+
52
+ if ctype != 'multipart/form-data':
53
+ self._send_json({"error": "Only multipart/form-data is supported"}, 400)
54
+ return
55
+
56
+ pdict['boundary'] = bytes(pdict['boundary'], "utf-8")
57
+ pdict['CONTENT-LENGTH'] = int(self.headers['Content-Length'])
58
+
59
+ form = cgi.parse_multipart(self.rfile, pdict)
60
+
61
+ # A wav file field is required
62
+ if "wav" not in form:
63
+ self._send_json({"error": "Field 'wav' is required"}, 400)
64
+ return
65
+
66
+ # Read optional parameters (fall back to defaults if omitted)
67
+ model_type = form.get("model_type", ["tiny"])[0]
68
+ model_path = form.get("model_path", ["../models/models-ax650"])[0]
69
+ language = form.get("language", ["zh"])[0]
70
+ task = form.get("task", ["transcribe"])[0]
71
+
72
+ if task not in ("transcribe", "translate"):
73
+ self._send_json({"error": "task must be 'transcribe' or 'translate'"}, 400)
74
+ return
75
+
76
+ wav_bytes = form["wav"][0]
77
+
78
+ # Write the upload to a temporary file
79
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
80
+ tmp.write(wav_bytes)
81
+ wav_path = tmp.name
82
+
83
+ # Load the model and run inference
84
+ try:
85
+ model = get_model(model_type, model_path, language, task)
86
+ result_text = model.run(wav_path)
87
+ except Exception as e:
88
+ self._send_json({"error": str(e)}, 500)
89
+ return
90
+ finally:
91
+ if os.path.exists(wav_path):
92
+ os.remove(wav_path)
93
+
94
+ self._send_json({"text": result_text})
95
+
96
+
97
+ if __name__ == "__main__":
98
+ parser = argparse.ArgumentParser(description="Whisper Server")
99
+ parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
100
+ args = parser.parse_args()
101
+ port = args.port
102
+ server = HTTPServer(("0.0.0.0", port), WhisperHandler)
103
+ print(f"Server started at http://0.0.0.0:{port}")
104
+ server.serve_forever()
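For reference, a minimal way to exercise the new server and client together, using the defaults hard-coded in the two files above:

```python
# Terminal 1: start the HTTP server (defaults to port 8000).
#   python whisper_svr.py --port 8000
#
# Terminal 2: send a wav file with the bundled client.
from whisper_cli import transcribe_audio

result = transcribe_audio(
    server_url="http://127.0.0.1:8000",
    wav_path="../demo.wav",
    model_type="tiny",
    language="zh",
)
print(result)  # e.g. {"text": "..."}
```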
python/whisper_tokenizer.py ADDED
@@ -0,0 +1,395 @@
1
+ import base64
2
+ import os
3
+ import string
4
+ from dataclasses import dataclass, field
5
+ from functools import cached_property, lru_cache
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ import tiktoken
9
+
10
+ LANGUAGES = {
11
+ "en": "english",
12
+ "zh": "chinese",
13
+ "de": "german",
14
+ "es": "spanish",
15
+ "ru": "russian",
16
+ "ko": "korean",
17
+ "fr": "french",
18
+ "ja": "japanese",
19
+ "pt": "portuguese",
20
+ "tr": "turkish",
21
+ "pl": "polish",
22
+ "ca": "catalan",
23
+ "nl": "dutch",
24
+ "ar": "arabic",
25
+ "sv": "swedish",
26
+ "it": "italian",
27
+ "id": "indonesian",
28
+ "hi": "hindi",
29
+ "fi": "finnish",
30
+ "vi": "vietnamese",
31
+ "he": "hebrew",
32
+ "uk": "ukrainian",
33
+ "el": "greek",
34
+ "ms": "malay",
35
+ "cs": "czech",
36
+ "ro": "romanian",
37
+ "da": "danish",
38
+ "hu": "hungarian",
39
+ "ta": "tamil",
40
+ "no": "norwegian",
41
+ "th": "thai",
42
+ "ur": "urdu",
43
+ "hr": "croatian",
44
+ "bg": "bulgarian",
45
+ "lt": "lithuanian",
46
+ "la": "latin",
47
+ "mi": "maori",
48
+ "ml": "malayalam",
49
+ "cy": "welsh",
50
+ "sk": "slovak",
51
+ "te": "telugu",
52
+ "fa": "persian",
53
+ "lv": "latvian",
54
+ "bn": "bengali",
55
+ "sr": "serbian",
56
+ "az": "azerbaijani",
57
+ "sl": "slovenian",
58
+ "kn": "kannada",
59
+ "et": "estonian",
60
+ "mk": "macedonian",
61
+ "br": "breton",
62
+ "eu": "basque",
63
+ "is": "icelandic",
64
+ "hy": "armenian",
65
+ "ne": "nepali",
66
+ "mn": "mongolian",
67
+ "bs": "bosnian",
68
+ "kk": "kazakh",
69
+ "sq": "albanian",
70
+ "sw": "swahili",
71
+ "gl": "galician",
72
+ "mr": "marathi",
73
+ "pa": "punjabi",
74
+ "si": "sinhala",
75
+ "km": "khmer",
76
+ "sn": "shona",
77
+ "yo": "yoruba",
78
+ "so": "somali",
79
+ "af": "afrikaans",
80
+ "oc": "occitan",
81
+ "ka": "georgian",
82
+ "be": "belarusian",
83
+ "tg": "tajik",
84
+ "sd": "sindhi",
85
+ "gu": "gujarati",
86
+ "am": "amharic",
87
+ "yi": "yiddish",
88
+ "lo": "lao",
89
+ "uz": "uzbek",
90
+ "fo": "faroese",
91
+ "ht": "haitian creole",
92
+ "ps": "pashto",
93
+ "tk": "turkmen",
94
+ "nn": "nynorsk",
95
+ "mt": "maltese",
96
+ "sa": "sanskrit",
97
+ "lb": "luxembourgish",
98
+ "my": "myanmar",
99
+ "bo": "tibetan",
100
+ "tl": "tagalog",
101
+ "mg": "malagasy",
102
+ "as": "assamese",
103
+ "tt": "tatar",
104
+ "haw": "hawaiian",
105
+ "ln": "lingala",
106
+ "ha": "hausa",
107
+ "ba": "bashkir",
108
+ "jw": "javanese",
109
+ "su": "sundanese",
110
+ "yue": "cantonese",
111
+ }
112
+
113
+ # language code lookup by name, with a few language aliases
114
+ TO_LANGUAGE_CODE = {
115
+ **{language: code for code, language in LANGUAGES.items()},
116
+ "burmese": "my",
117
+ "valencian": "ca",
118
+ "flemish": "nl",
119
+ "haitian": "ht",
120
+ "letzeburgesch": "lb",
121
+ "pushto": "ps",
122
+ "panjabi": "pa",
123
+ "moldavian": "ro",
124
+ "moldovan": "ro",
125
+ "sinhalese": "si",
126
+ "castilian": "es",
127
+ "mandarin": "zh",
128
+ }
129
+
130
+
131
+ @dataclass
132
+ class Tokenizer:
133
+ """A thin wrapper around `tiktoken` providing quick access to special tokens"""
134
+
135
+ encoding: tiktoken.Encoding
136
+ num_languages: int
137
+ language: Optional[str] = None
138
+ task: Optional[str] = None
139
+ sot_sequence: Tuple[int] = ()
140
+ special_tokens: Dict[str, int] = field(default_factory=dict)
141
+
142
+ def __post_init__(self):
143
+ for special in self.encoding.special_tokens_set:
144
+ special_token = self.encoding.encode_single_token(special)
145
+ self.special_tokens[special] = special_token
146
+
147
+ sot: int = self.special_tokens["<|startoftranscript|>"]
148
+ translate: int = self.special_tokens["<|translate|>"]
149
+ transcribe: int = self.special_tokens["<|transcribe|>"]
150
+
151
+ langs = tuple(LANGUAGES.keys())[: self.num_languages]
152
+ sot_sequence = [sot]
153
+ if self.language is not None:
154
+ sot_sequence.append(sot + 1 + langs.index(self.language))
155
+ if self.task is not None:
156
+ task_token: int = transcribe if self.task == "transcribe" else translate
157
+ sot_sequence.append(task_token)
158
+
159
+ self.sot_sequence = tuple(sot_sequence)
160
+
161
+ def encode(self, text, **kwargs):
162
+ return self.encoding.encode(text, **kwargs)
163
+
164
+ def decode(self, token_ids: List[int], **kwargs) -> str:
165
+ token_ids = [t for t in token_ids if t < self.timestamp_begin]
166
+ return self.encoding.decode(token_ids, **kwargs)
167
+
168
+ def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
169
+ """
170
+ Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
171
+ This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
172
+ """
173
+ return self.encoding.decode(token_ids, **kwargs)
174
+
+     @cached_property
+     def eot(self) -> int:
+         return self.encoding.eot_token
+
+     @cached_property
+     def transcribe(self) -> int:
+         return self.special_tokens["<|transcribe|>"]
+
+     @cached_property
+     def translate(self) -> int:
+         return self.special_tokens["<|translate|>"]
+
+     @cached_property
+     def sot(self) -> int:
+         return self.special_tokens["<|startoftranscript|>"]
+
+     @cached_property
+     def sot_lm(self) -> int:
+         return self.special_tokens["<|startoflm|>"]
+
+     @cached_property
+     def sot_prev(self) -> int:
+         return self.special_tokens["<|startofprev|>"]
+
+     @cached_property
+     def no_speech(self) -> int:
+         return self.special_tokens["<|nospeech|>"]
+
+     @cached_property
+     def no_timestamps(self) -> int:
+         return self.special_tokens["<|notimestamps|>"]
+
+     @cached_property
+     def timestamp_begin(self) -> int:
+         return self.special_tokens["<|0.00|>"]
+
+     @cached_property
+     def language_token(self) -> int:
+         """Returns the token id corresponding to the value of the `language` field"""
+         if self.language is None:
+             raise ValueError("This tokenizer does not have language token configured")
+
+         return self.to_language_token(self.language)
+
+     def to_language_token(self, language):
+         if token := self.special_tokens.get(f"<|{language}|>", None):
+             return token
+
+         raise KeyError(f"Language {language} not found in tokenizer.")
+
+     @cached_property
+     def all_language_tokens(self) -> Tuple[int]:
+         result = []
+         for token, token_id in self.special_tokens.items():
+             if token.strip("<|>") in LANGUAGES:
+                 result.append(token_id)
+         return tuple(result)[: self.num_languages]
+
+     @cached_property
+     def all_language_codes(self) -> Tuple[str]:
+         return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
+
+     @cached_property
+     def sot_sequence_including_notimestamps(self) -> Tuple[int]:
+         return tuple(list(self.sot_sequence) + [self.no_timestamps])
+
+     @cached_property
+     def non_speech_tokens(self) -> Tuple[int]:
+         """
+         Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
+         annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
+
+         - ♪♪♪
+         - ( SPEAKING FOREIGN LANGUAGE )
+         - [DAVID] Hey there,
+
+         keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
+         """
+         symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
+         symbols += (
+             "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
+         )
+
+         # symbols that may be a single token or multiple tokens depending on the tokenizer.
+         # In case they're multiple tokens, suppress the first token, which is safe because:
+         # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
+         # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
+         miscellaneous = set("♩♪♫♬♭♮♯")
+         assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
+
+         # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
+         result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
+         for symbol in symbols + list(miscellaneous):
+             for tokens in [
+                 self.encoding.encode(symbol),
+                 self.encoding.encode(" " + symbol),
+             ]:
+                 if len(tokens) == 1 or symbol in miscellaneous:
+                     result.add(tokens[0])
+
+         return tuple(sorted(result))
+
+     def split_to_word_tokens(self, tokens: List[int]):
+         if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
+             # These languages don't typically use spaces, so it is difficult to split words
+             # without morpheme analysis. Here, we instead split words at any
+             # position where the tokens are decoded as valid unicode points
+             return self.split_tokens_on_unicode(tokens)
+
+         return self.split_tokens_on_spaces(tokens)
+
+     def split_tokens_on_unicode(self, tokens: List[int]):
+         decoded_full = self.decode_with_timestamps(tokens)
+         replacement_char = "\ufffd"
+
+         words = []
+         word_tokens = []
+         current_tokens = []
+         unicode_offset = 0
+
+         for token in tokens:
+             current_tokens.append(token)
+             decoded = self.decode_with_timestamps(current_tokens)
+
+             if (
+                 replacement_char not in decoded
+                 or decoded_full[unicode_offset + decoded.index(replacement_char)]
+                 == replacement_char
+             ):
+                 words.append(decoded)
+                 word_tokens.append(current_tokens)
+                 current_tokens = []
+                 unicode_offset += len(decoded)
+
+         return words, word_tokens
+
+     def split_tokens_on_spaces(self, tokens: List[int]):
+         subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
+         words = []
+         word_tokens = []
+
+         for subword, subword_tokens in zip(subwords, subword_tokens_list):
+             special = subword_tokens[0] >= self.eot
+             with_space = subword.startswith(" ")
+             punctuation = subword.strip() in string.punctuation
+             if special or with_space or punctuation or len(words) == 0:
+                 words.append(subword)
+                 word_tokens.append(subword_tokens)
+             else:
+                 words[-1] = words[-1] + subword
+                 word_tokens[-1].extend(subword_tokens)
+
+         return words, word_tokens
+
+
+ @lru_cache(maxsize=None)
+ def get_encoding(name: str = "gpt2", num_languages: int = 99):
+     vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+     ranks = {
+         base64.b64decode(token): int(rank)
+         for token, rank in (line.split() for line in open(vocab_path) if line)
+     }
+     n_vocab = len(ranks)
+     special_tokens = {}
+
+     specials = [
+         "<|endoftext|>",
+         "<|startoftranscript|>",
+         *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
+         "<|translate|>",
+         "<|transcribe|>",
+         "<|startoflm|>",
+         "<|startofprev|>",
+         "<|nospeech|>",
+         "<|notimestamps|>",
+         *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
+     ]
+
+     for token in specials:
+         special_tokens[token] = n_vocab
+         n_vocab += 1
+
+     return tiktoken.Encoding(
+         name=os.path.basename(vocab_path),
+         explicit_n_vocab=n_vocab,
+         pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+         mergeable_ranks=ranks,
+         special_tokens=special_tokens,
+     )
+
+
+ @lru_cache(maxsize=None)
+ def get_tokenizer(
+     multilingual: bool,
+     *,
+     num_languages: int = 99,
+     language: Optional[str] = None,
+     task: Optional[str] = None,  # Literal["transcribe", "translate", None]
+ ) -> Tokenizer:
+     if language is not None:
+         language = language.lower()
+         if language not in LANGUAGES:
+             if language in TO_LANGUAGE_CODE:
+                 language = TO_LANGUAGE_CODE[language]
+             else:
+                 raise ValueError(f"Unsupported language: {language}")
+
+     if multilingual:
+         encoding_name = "multilingual"
+         language = language or "en"
+         task = task or "transcribe"
+     else:
+         encoding_name = "gpt2"
+         language = None
+         task = None
+
+     encoding = get_encoding(name=encoding_name, num_languages=num_languages)
+
+     return Tokenizer(
+         encoding=encoding, num_languages=num_languages, language=language, task=task
+     )