happyme531 committed on
Commit eb10636 · verified · 1 Parent(s): 6b5b4b0

Upload 6 files

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ librkllmrt.so filter=lfs diff=lfs merge=lfs -text
+ Qwen3-Reranker-0.6B_f16.rkllm filter=lfs diff=lfs merge=lfs -text
+ Qwen3-Reranker-0.6B_w8a8.rkllm filter=lfs diff=lfs merge=lfs -text
Qwen3-Reranker-0.6B_f16.rkllm ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a5577675b87b5d19b1fbc20360b9caa174fa9d64d08b6cbe564d74456216f117
+ size 1524801182
Qwen3-Reranker-0.6B_w8a8.rkllm ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:820c0b6b2419d3799a612751d9d8b86ed0e0e852364a6de78972a3976c34eaa8
+ size 931372078
librkllmrt.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f6a9c2de93cf94bb524eb071c27190ad4c83401e01b562534f265dff4cb40da2
+ size 6710712
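
The three files above are committed as Git LFS pointers (spec version, sha256 oid, byte size); the binaries themselves live in the LFS store and are materialized on checkout. A minimal sketch for fetching them programmatically with huggingface_hub — the repo_id placeholder is an assumption, substitute this repository's actual id:

from huggingface_hub import hf_hub_download

for name in ["librkllmrt.so", "Qwen3-Reranker-0.6B_f16.rkllm", "Qwen3-Reranker-0.6B_w8a8.rkllm"]:
    # Resolves the LFS pointer and downloads the real artifact into the local cache
    path = hf_hub_download(repo_id="<namespace>/<repo>", filename=name)
    print(path)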
rkllm-convert.py ADDED
@@ -0,0 +1,74 @@
+ import os
+ from rkllm.api import RKLLM
+
+ def convert_model(model_path, output_name, do_quantization=False):
+     """Convert a single model."""
+     llm = RKLLM()
+
+     print(f"Loading model: {model_path}")
+     ret = llm.load_huggingface(model=model_path, model_lora=None, device='cpu')
+     if ret != 0:
+         print(f'Failed to load model: {model_path}')
+         return ret
+
+     print(f"Building model: {output_name} (quantization: {do_quantization})")
+     qparams = None
+     ret = llm.build(do_quantization=do_quantization, optimization_level=1, quantized_dtype='w8a8',
+                     quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=qparams)
+
+     if ret != 0:
+         print(f'Failed to build model: {output_name}')
+         return ret
+
+     # Export the .rkllm model
+     print(f"Exporting model: {output_name}")
+     ret = llm.export_rkllm(output_name)
+     if ret != 0:
+         print(f'Failed to export model: {output_name}')
+         return ret
+
+     print(f"Successfully converted: {output_name}")
+     return 0
+
+ def main():
+     """Main entry point: walk every subdirectory and convert the model inside."""
+     current_dir = '.'
+
+     # Collect all subdirectories
+     subdirs = [d for d in os.listdir(current_dir)
+                if os.path.isdir(os.path.join(current_dir, d)) and not d.startswith('.')]
+
+     print(f"Found {len(subdirs)} model folder(s): {subdirs}")
+
+     for subdir in subdirs:
+         model_path = os.path.join(current_dir, subdir)
+
+         # Derive the output file names
+         base_name = subdir.replace('/', '_').replace('\\', '_')
+         quantized_output = f"{base_name}_w8a8.rkllm"
+         unquantized_output = f"{base_name}_f16.rkllm"
+
+         print(f"\n{'='*50}")
+         print(f"Processing model folder: {subdir}")
+         print(f"{'='*50}")
+
+         # Convert the unquantized version
+         print(f"\n--- Converting unquantized version ---")
+         ret = convert_model(model_path, unquantized_output, do_quantization=False)
+         if ret != 0:
+             print(f"Unquantized conversion failed: {subdir}")
+             continue
+
+         # Convert the quantized version
+         print(f"\n--- Converting quantized version ---")
+         ret = convert_model(model_path, quantized_output, do_quantization=True)
+         if ret != 0:
+             print(f"Quantized conversion failed: {subdir}")
+             continue
+
+         print(f"\n✓ {subdir} model conversion complete!")
+         print(f"  - unquantized version: {unquantized_output}")
+         print(f"  - quantized version: {quantized_output}")
+
+ if __name__ == "__main__":
+     main()
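
For a one-off conversion without the directory scan, the same three calls from convert_model() can be driven directly; a sketch, assuming the Hugging Face checkpoint sits in ./Qwen3-Reranker-0.6B (the folder name is an assumption):

from rkllm.api import RKLLM

llm = RKLLM()
# Load the local Hugging Face checkpoint on the CPU
assert llm.load_huggingface(model="./Qwen3-Reranker-0.6B", model_lora=None, device='cpu') == 0
# Build a w8a8-quantized graph targeting the RK3588's three NPU cores
assert llm.build(do_quantization=True, optimization_level=1, quantized_dtype='w8a8',
                 quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3,
                 extra_qparams=None) == 0
# Write the runnable .rkllm artifact
assert llm.export_rkllm("Qwen3-Reranker-0.6B_w8a8.rkllm") == 0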
rkllm_binding.py ADDED
@@ -0,0 +1,658 @@
+ import ctypes
+ import enum
+ import os
+
+ # Define constants from the header
+ CPU0 = (1 << 0)  # 0x01
+ CPU1 = (1 << 1)  # 0x02
+ CPU2 = (1 << 2)  # 0x04
+ CPU3 = (1 << 3)  # 0x08
+ CPU4 = (1 << 4)  # 0x10
+ CPU5 = (1 << 5)  # 0x20
+ CPU6 = (1 << 6)  # 0x40
+ CPU7 = (1 << 7)  # 0x80
+
+ # --- Enums ---
+ class LLMCallState(enum.IntEnum):
+     RKLLM_RUN_NORMAL = 0
+     RKLLM_RUN_WAITING = 1
+     RKLLM_RUN_FINISH = 2
+     RKLLM_RUN_ERROR = 3
+
+ class RKLLMInputType(enum.IntEnum):
+     RKLLM_INPUT_PROMPT = 0
+     RKLLM_INPUT_TOKEN = 1
+     RKLLM_INPUT_EMBED = 2
+     RKLLM_INPUT_MULTIMODAL = 3
+
+ class RKLLMInferMode(enum.IntEnum):
+     RKLLM_INFER_GENERATE = 0
+     RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
+     RKLLM_INFER_GET_LOGITS = 2
+
+ # --- Structures ---
+ class RKLLMExtendParam(ctypes.Structure):
+     # Base iommu domain ID; setting it to 1 is recommended for models over 1B parameters
+     base_domain_id: ctypes.c_int32
+     # Whether to keep the embedding table in flash storage
+     embed_flash: ctypes.c_int8
+     # Number of CPU cores to enable
+     enabled_cpus_num: ctypes.c_int8
+     # Bitmask of the CPU cores to enable
+     enabled_cpus_mask: ctypes.c_uint32
+     reserved: ctypes.c_uint8 * 106
+
+     _fields_ = [
+         ("base_domain_id", ctypes.c_int32),
+         ("embed_flash", ctypes.c_int8),
+         ("enabled_cpus_num", ctypes.c_int8),
+         ("enabled_cpus_mask", ctypes.c_uint32),
+         ("reserved", ctypes.c_uint8 * 106)
+     ]
+
+ class RKLLMParam(ctypes.Structure):
+     # Path to the model file
+     model_path: ctypes.c_char_p
+     # Maximum number of tokens in the context window
+     max_context_len: ctypes.c_int32
+     # Maximum number of new tokens to generate
+     max_new_tokens: ctypes.c_int32
+     # Top-K sampling parameter
+     top_k: ctypes.c_int32
+     # Number of KV-cache entries kept when the context window slides
+     n_keep: ctypes.c_int32
+     # Top-P sampling parameter
+     top_p: ctypes.c_float
+     # Sampling temperature; controls the randomness of token selection
+     temperature: ctypes.c_float
+     # Penalty for repeated tokens
+     repeat_penalty: ctypes.c_float
+     # Penalty for frequent tokens
+     frequency_penalty: ctypes.c_float
+     # Penalty for tokens already present in the input
+     presence_penalty: ctypes.c_float
+     # Mirostat sampling strategy flag (0 disables it)
+     mirostat: ctypes.c_int32
+     # Mirostat Tau parameter
+     mirostat_tau: ctypes.c_float
+     # Mirostat Eta parameter
+     mirostat_eta: ctypes.c_float
+     # Whether to skip special tokens
+     skip_special_token: ctypes.c_bool
+     # Whether to run inference asynchronously
+     is_async: ctypes.c_bool
+     # Image start token for multimodal input
+     img_start: ctypes.c_char_p
+     # Image end token for multimodal input
+     img_end: ctypes.c_char_p
+     # Pointer to the image content
+     img_content: ctypes.c_char_p
+     # Extended parameters
+     extend_param: RKLLMExtendParam
+
+     _fields_ = [
+         ("model_path", ctypes.c_char_p),        # path to the model file
+         ("max_context_len", ctypes.c_int32),    # maximum number of tokens in the context window
+         ("max_new_tokens", ctypes.c_int32),     # maximum number of new tokens to generate
+         ("top_k", ctypes.c_int32),              # Top-K sampling parameter
+         ("n_keep", ctypes.c_int32),             # KV-cache entries kept when the context window slides
+         ("top_p", ctypes.c_float),              # Top-P (nucleus) sampling parameter
+         ("temperature", ctypes.c_float),        # sampling temperature
+         ("repeat_penalty", ctypes.c_float),     # penalty for repeated tokens
+         ("frequency_penalty", ctypes.c_float),  # penalty for frequent tokens
+         ("presence_penalty", ctypes.c_float),   # penalty for tokens already present in the input
+         ("mirostat", ctypes.c_int32),           # Mirostat sampling strategy flag (0 disables it)
+         ("mirostat_tau", ctypes.c_float),       # Mirostat Tau parameter
+         ("mirostat_eta", ctypes.c_float),       # Mirostat Eta parameter
+         ("skip_special_token", ctypes.c_bool),  # whether to skip special tokens
+         ("is_async", ctypes.c_bool),            # whether to run inference asynchronously
+         ("img_start", ctypes.c_char_p),         # image start token for multimodal input
+         ("img_end", ctypes.c_char_p),           # image end token for multimodal input
+         ("img_content", ctypes.c_char_p),       # pointer to the image content
+         ("extend_param", RKLLMExtendParam)      # extended parameters
+     ]
+
+ class RKLLMLoraAdapter(ctypes.Structure):
+     lora_adapter_path: ctypes.c_char_p
+     lora_adapter_name: ctypes.c_char_p
+     scale: ctypes.c_float
+
+     _fields_ = [
+         ("lora_adapter_path", ctypes.c_char_p),
+         ("lora_adapter_name", ctypes.c_char_p),
+         ("scale", ctypes.c_float)
+     ]
+
+ class RKLLMEmbedInput(ctypes.Structure):
+     # Shape: [n_tokens, embed_size]
+     embed: ctypes.POINTER(ctypes.c_float)
+     n_tokens: ctypes.c_size_t
+
+     _fields_ = [
+         ("embed", ctypes.POINTER(ctypes.c_float)),
+         ("n_tokens", ctypes.c_size_t)
+     ]
+
+ class RKLLMTokenInput(ctypes.Structure):
+     # Shape: [n_tokens]
+     input_ids: ctypes.POINTER(ctypes.c_int32)
+     n_tokens: ctypes.c_size_t
+
+     _fields_ = [
+         ("input_ids", ctypes.POINTER(ctypes.c_int32)),
+         ("n_tokens", ctypes.c_size_t)
+     ]
+
+ class RKLLMMultiModelInput(ctypes.Structure):
+     prompt: ctypes.c_char_p
+     image_embed: ctypes.POINTER(ctypes.c_float)
+     n_image_tokens: ctypes.c_size_t
+     n_image: ctypes.c_size_t
+     image_width: ctypes.c_size_t
+     image_height: ctypes.c_size_t
+
+     _fields_ = [
+         ("prompt", ctypes.c_char_p),
+         ("image_embed", ctypes.POINTER(ctypes.c_float)),
+         ("n_image_tokens", ctypes.c_size_t),
+         ("n_image", ctypes.c_size_t),
+         ("image_width", ctypes.c_size_t),
+         ("image_height", ctypes.c_size_t)
+     ]
+
+ class _RKLLMInputUnion(ctypes.Union):
+     prompt_input: ctypes.c_char_p
+     embed_input: RKLLMEmbedInput
+     token_input: RKLLMTokenInput
+     multimodal_input: RKLLMMultiModelInput
+
+     _fields_ = [
+         ("prompt_input", ctypes.c_char_p),
+         ("embed_input", RKLLMEmbedInput),
+         ("token_input", RKLLMTokenInput),
+         ("multimodal_input", RKLLMMultiModelInput)
+     ]
+
+ class RKLLMInput(ctypes.Structure):
+     input_type: ctypes.c_int
+     _union_data: _RKLLMInputUnion
+
+     _fields_ = [
+         ("input_type", ctypes.c_int),  # Enum will be passed as int, changed RKLLMInputType to ctypes.c_int
+         ("_union_data", _RKLLMInputUnion)
+     ]
+
+     # Properties to make accessing union members easier
+     @property
+     def prompt_input(self) -> bytes:  # Assuming c_char_p maps to bytes
+         if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
+             return self._union_data.prompt_input
+         raise AttributeError("Not a prompt input")
+
+     @prompt_input.setter
+     def prompt_input(self, value: bytes):  # Assuming c_char_p maps to bytes
+         if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
+             self._union_data.prompt_input = value
+         else:
+             raise AttributeError("Not a prompt input")
+
+     @property
+     def embed_input(self) -> RKLLMEmbedInput:
+         if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
+             return self._union_data.embed_input
+         raise AttributeError("Not an embed input")
+
+     @embed_input.setter
+     def embed_input(self, value: RKLLMEmbedInput):
+         if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
+             self._union_data.embed_input = value
+         else:
+             raise AttributeError("Not an embed input")
+
+     @property
+     def token_input(self) -> RKLLMTokenInput:
+         if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
+             return self._union_data.token_input
+         raise AttributeError("Not a token input")
+
+     @token_input.setter
+     def token_input(self, value: RKLLMTokenInput):
+         if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
+             self._union_data.token_input = value
+         else:
+             raise AttributeError("Not a token input")
+
+     @property
+     def multimodal_input(self) -> RKLLMMultiModelInput:
+         if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
+             return self._union_data.multimodal_input
+         raise AttributeError("Not a multimodal input")
+
+     @multimodal_input.setter
+     def multimodal_input(self, value: RKLLMMultiModelInput):
+         if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
+             self._union_data.multimodal_input = value
+         else:
+             raise AttributeError("Not a multimodal input")
+
+ class RKLLMLoraParam(ctypes.Structure):  # For inference
+     lora_adapter_name: ctypes.c_char_p
+
+     _fields_ = [
+         ("lora_adapter_name", ctypes.c_char_p)
+     ]
+
+ class RKLLMPromptCacheParam(ctypes.Structure):  # For inference
+     save_prompt_cache: ctypes.c_int  # bool-like
+     prompt_cache_path: ctypes.c_char_p
+
+     _fields_ = [
+         ("save_prompt_cache", ctypes.c_int),  # bool-like
+         ("prompt_cache_path", ctypes.c_char_p)
+     ]
+
+ class RKLLMInferParam(ctypes.Structure):
+     mode: ctypes.c_int
+     lora_params: ctypes.POINTER(RKLLMLoraParam)
+     prompt_cache_params: ctypes.POINTER(RKLLMPromptCacheParam)
+     keep_history: ctypes.c_int  # bool-like
+
+     _fields_ = [
+         ("mode", ctypes.c_int),  # Enum will be passed as int, changed RKLLMInferMode to ctypes.c_int
+         ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
+         ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
+         ("keep_history", ctypes.c_int)  # bool-like
+     ]
+
+ class RKLLMResultLastHiddenLayer(ctypes.Structure):
+     # Shape: [num_tokens, embd_size]
+     hidden_states: ctypes.POINTER(ctypes.c_float)
+     # Hidden layer size
+     embd_size: ctypes.c_int
+     # Number of output tokens
+     num_tokens: ctypes.c_int
+
+     _fields_ = [
+         ("hidden_states", ctypes.POINTER(ctypes.c_float)),
+         ("embd_size", ctypes.c_int),
+         ("num_tokens", ctypes.c_int)
+     ]
+
+ class RKLLMResultLogits(ctypes.Structure):
+     # Shape: [num_tokens, vocab_size]
+     logits: ctypes.POINTER(ctypes.c_float)
+     # Vocabulary size
+     vocab_size: ctypes.c_int
+     # Number of output tokens
+     num_tokens: ctypes.c_int
+
+     _fields_ = [
+         ("logits", ctypes.POINTER(ctypes.c_float)),
+         ("vocab_size", ctypes.c_int),
+         ("num_tokens", ctypes.c_int)
+     ]
+
+ class RKLLMResult(ctypes.Structure):
+     text: ctypes.c_char_p
+     token_id: ctypes.c_int32
+     last_hidden_layer: RKLLMResultLastHiddenLayer
+     logits: RKLLMResultLogits
+
+     _fields_ = [
+         ("text", ctypes.c_char_p),
+         ("token_id", ctypes.c_int32),
+         ("last_hidden_layer", RKLLMResultLastHiddenLayer),
+         ("logits", RKLLMResultLogits)
+     ]
+
+ # --- Typedefs ---
+ LLMHandle = ctypes.c_void_p
+
+ # --- Callback Function Type ---
+ LLMResultCallback = ctypes.CFUNCTYPE(
+     None,                         # return type: void
+     ctypes.POINTER(RKLLMResult),
+     ctypes.c_void_p,              # userdata
+     ctypes.c_int                  # enum, will be passed as int. Changed LLMCallState to ctypes.c_int
+ )
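+ # A Python callback registered through this type receives three arguments:
+ #     def cb(result_ptr, userdata, state): ...
+ # where result_ptr is a ctypes.POINTER(RKLLMResult), userdata is the raw
+ # c_void_p passed to rkllm_run, and state can be wrapped back into the
+ # enum with LLMCallState(state).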
+
+
+ class RKLLMRuntime:
+     def __init__(self, library_path="./librkllmrt.so"):
+         try:
+             self.lib = ctypes.CDLL(library_path)
+         except OSError as e:
+             raise OSError(f"Failed to load RKLLM library from {library_path}. "
+                           f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}")
+         self._setup_functions()
+         self.llm_handle = LLMHandle()
+         self._c_callback = None  # To keep the callback object alive
+
+     def _setup_functions(self):
+         # RKLLMParam rkllm_createDefaultParam();
+         self.lib.rkllm_createDefaultParam.restype = RKLLMParam
+         self.lib.rkllm_createDefaultParam.argtypes = []
+
+         # int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
+         self.lib.rkllm_init.restype = ctypes.c_int
+         self.lib.rkllm_init.argtypes = [
+             ctypes.POINTER(LLMHandle),
+             ctypes.POINTER(RKLLMParam),
+             LLMResultCallback
+         ]
+
+         # int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
+         self.lib.rkllm_load_lora.restype = ctypes.c_int
+         self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]
+
+         # int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
+         self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
+         self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]
+
+         # int rkllm_release_prompt_cache(LLMHandle handle);
+         self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
+         self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]
+
+         # int rkllm_destroy(LLMHandle handle);
+         self.lib.rkllm_destroy.restype = ctypes.c_int
+         self.lib.rkllm_destroy.argtypes = [LLMHandle]
+
+         # int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
+         self.lib.rkllm_run.restype = ctypes.c_int
+         self.lib.rkllm_run.argtypes = [
+             LLMHandle,
+             ctypes.POINTER(RKLLMInput),
+             ctypes.POINTER(RKLLMInferParam),
+             ctypes.c_void_p  # userdata
+         ]
+
+         # int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
+         # Assuming async also takes userdata for the callback context
+         self.lib.rkllm_run_async.restype = ctypes.c_int
+         self.lib.rkllm_run_async.argtypes = [
+             LLMHandle,
+             ctypes.POINTER(RKLLMInput),
+             ctypes.POINTER(RKLLMInferParam),
+             ctypes.c_void_p  # userdata
+         ]
+
+         # int rkllm_abort(LLMHandle handle);
+         self.lib.rkllm_abort.restype = ctypes.c_int
+         self.lib.rkllm_abort.argtypes = [LLMHandle]
+
+         # int rkllm_is_running(LLMHandle handle);
+         self.lib.rkllm_is_running.restype = ctypes.c_int  # 0 if running, non-zero otherwise
+         self.lib.rkllm_is_running.argtypes = [LLMHandle]
+
+         # int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt);
+         self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
+         self.lib.rkllm_clear_kv_cache.argtypes = [LLMHandle, ctypes.c_int]
+
+         # int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
+         self.lib.rkllm_set_chat_template.restype = ctypes.c_int
+         self.lib.rkllm_set_chat_template.argtypes = [
+             LLMHandle,
+             ctypes.c_char_p,
+             ctypes.c_char_p,
+             ctypes.c_char_p
+         ]
+
+     def create_default_param(self) -> RKLLMParam:
+         """Creates a default RKLLMParam structure."""
+         return self.lib.rkllm_createDefaultParam()
+
+     def init(self, param: RKLLMParam, callback_func) -> int:
+         """
+         Initializes the LLM.
+         :param param: RKLLMParam structure.
+         :param callback_func: A Python function that matches the signature:
+             def my_callback(result_ptr, userdata_ptr, state_enum):
+                 result = result_ptr.contents  # RKLLMResult
+                 # Process result
+                 # userdata can be retrieved if passed during run, or ignored
+                 # state = LLMCallState(state_enum)
+         :return: 0 for success, non-zero for failure.
+         """
+         if not callable(callback_func):
+             raise ValueError("callback_func must be a callable Python function.")
+
+         # Keep a reference to the ctypes callback object to prevent it from being garbage collected
+         self._c_callback = LLMResultCallback(callback_func)
+
+         ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
+         if ret != 0:
+             raise RuntimeError(f"rkllm_init failed with error code {ret}")
+         return ret
+
+     def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
+         """Loads a LoRA adapter."""
+         ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
+         if ret != 0:
+             raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
+         return ret
+
+     def load_prompt_cache(self, prompt_cache_path: str) -> int:
+         """Loads a prompt cache from a file."""
+         c_path = prompt_cache_path.encode('utf-8')
+         ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
+         if ret != 0:
+             raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
+         return ret
+
+     def release_prompt_cache(self) -> int:
+         """Releases the prompt cache from memory."""
+         ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
+         if ret != 0:
+             raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
+         return ret
+
+     def destroy(self) -> int:
+         """Destroys the LLM instance and releases resources."""
+         if self.llm_handle and self.llm_handle.value:  # Check if handle is not NULL
+             ret = self.lib.rkllm_destroy(self.llm_handle)
+             self.llm_handle = LLMHandle()  # Reset handle
+             if ret != 0:
+                 # Don't raise here as it might be called in __del__
+                 print(f"Warning: rkllm_destroy failed with error code {ret}")
+             return ret
+         return 0  # Already destroyed or not initialized
+
+     def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
+         """Runs an LLM inference task synchronously."""
+         # userdata can be a ctypes.py_object if you want to pass Python objects,
+         # then cast to c_void_p. Or simply None.
+         if userdata is not None:
+             # Store the userdata object to keep it alive during the call
+             self._userdata_ref = userdata
+             c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
+         else:
+             c_userdata = None
+         ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
+         if ret != 0:
+             raise RuntimeError(f"rkllm_run failed with error code {ret}")
+         return ret
+
+     def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
+         """Runs an LLM inference task asynchronously."""
+         if userdata is not None:
+             # Store the userdata object to keep it alive during the call
+             self._userdata_ref = userdata
+             c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
+         else:
+             c_userdata = None
+         ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
+         if ret != 0:
+             raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
+         return ret
+
+     def abort(self) -> int:
+         """Aborts an ongoing LLM task."""
+         ret = self.lib.rkllm_abort(self.llm_handle)
+         if ret != 0:
+             raise RuntimeError(f"rkllm_abort failed with error code {ret}")
+         return ret
+
+     def is_running(self) -> bool:
+         """Checks if an LLM task is currently running. Returns True if running."""
+         # The C API returns 0 if running, non-zero otherwise.
+         # This is a bit counter-intuitive for a boolean "is_running".
+         return self.lib.rkllm_is_running(self.llm_handle) == 0
+
+     def clear_kv_cache(self, keep_system_prompt: bool) -> int:
+         """Clears the key-value cache."""
+         ret = self.lib.rkllm_clear_kv_cache(self.llm_handle, ctypes.c_int(1 if keep_system_prompt else 0))
+         if ret != 0:
+             raise RuntimeError(f"rkllm_clear_kv_cache failed with error code {ret}")
+         return ret
+
+     def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
+         """Sets the chat template for the LLM."""
+         c_system = system_prompt.encode('utf-8') if system_prompt else b""
+         c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else b""
+         c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else b""
+
+         ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
+         if ret != 0:
+             raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
+         return ret
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.destroy()
+
+     def __del__(self):
+         self.destroy()  # Ensure resources are freed if object is garbage collected
+
+ # --- Example Usage (Illustrative) ---
+ if __name__ == "__main__":
+     # This is a placeholder for how you might use it.
+     # You'll need a valid .rkllm model and librkllmrt.so in your path.
+
+     # Global list to store results from callback for demonstration
+     results_buffer = []
+
+     def my_python_callback(result_ptr, userdata_ptr, state_enum):
+         """
+         Callback function to be called by the C library.
+         """
+         global results_buffer
+         state = LLMCallState(state_enum)
+         result = result_ptr.contents
+
+         current_text = ""
+         if result.text:  # Check if the char_p is not NULL
+             current_text = result.text.decode('utf-8', errors='ignore')
+
+         print(f"Callback: State={state.name}, TokenID={result.token_id}, Text='{current_text}'")
+         results_buffer.append(current_text)
+
+         if state == LLMCallState.RKLLM_RUN_FINISH:
+             print("Inference finished.")
+         elif state == LLMCallState.RKLLM_RUN_ERROR:
+             print("Inference error.")
+
+         # Example: Accessing logits if available (and if mode was set to get logits)
+         # if result.logits.logits and result.logits.vocab_size > 0:
+         #     print(f"  Logits (first 5 of vocab_size {result.logits.vocab_size}):")
+         #     for i in range(min(5, result.logits.vocab_size)):
+         #         print(f"  {result.logits.logits[i]:.4f}", end=" ")
+         #     print()
+
+     # --- Attempt to use the wrapper ---
+     try:
+         print("Initializing RKLLMRuntime...")
+         # Adjust library_path if librkllmrt.so is not in default search paths
+         # e.g., library_path="./path/to/librkllmrt.so"
+         rk_llm = RKLLMRuntime()
+
+         print("Creating default parameters...")
+         params = rk_llm.create_default_param()
+
+         # --- Configure parameters ---
+         # THIS IS CRITICAL: model_path must point to an actual .rkllm file.
+         # For this example to run, you need a model file.
+         # Let's assume a dummy path for now; this will fail at init if not valid.
+         model_file = "dummy_model.rkllm"
+         if not os.path.exists(model_file):
+             print(f"Warning: Model file '{model_file}' does not exist. Init will likely fail.")
+             # Create a dummy file for the example to proceed further, though init will still fail
+             # with a real library unless it's a valid model.
+             with open(model_file, "w") as f:
+                 f.write("dummy content")
+
+         params.model_path = model_file.encode('utf-8')
+         params.max_context_len = 512
+         params.max_new_tokens = 128
+         params.top_k = 1  # Greedy
+         params.temperature = 0.7
+         params.repeat_penalty = 1.1
+         # ... set other params as needed
+
+         print(f"Initializing LLM with model: {params.model_path.decode()}...")
+         # This will likely fail if dummy_model.rkllm is not a valid model recognized by the library
+         try:
+             rk_llm.init(params, my_python_callback)
+             print("LLM Initialized.")
+         except RuntimeError as e:
+             print(f"Error during LLM initialization: {e}")
+             print("This is expected if 'dummy_model.rkllm' is not a valid model.")
+             print("Replace 'dummy_model.rkllm' with a real model path to test further.")
+             exit()
+
+         # --- Prepare input ---
+         print("Preparing input...")
+         rk_input = RKLLMInput()
+         rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
+
+         prompt_text = "Translate the following English text to French: 'Hello, world!'"
+         c_prompt = prompt_text.encode('utf-8')
+         rk_input._union_data.prompt_input = c_prompt  # Accessing union member directly
+
+         # --- Prepare inference parameters ---
+         print("Preparing inference parameters...")
+         infer_params = RKLLMInferParam()
+         infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
+         infer_params.keep_history = 1  # True
+         # infer_params.lora_params = None  # or set up RKLLMLoraParam if using LoRA
+         # infer_params.prompt_cache_params = None  # or set up RKLLMPromptCacheParam
+
+         # --- Run inference ---
+         print(f"Running inference with prompt: '{prompt_text}'")
+         results_buffer.clear()
+         try:
+             rk_llm.run(rk_input, infer_params)  # Userdata is None by default
+             print("\n--- Full Response ---")
+             print("".join(results_buffer))
+             print("---------------------\n")
+         except RuntimeError as e:
+             print(f"Error during LLM run: {e}")
+
+         # --- Example: Set chat template (if model supports it) ---
+         # print("Setting chat template...")
+         # try:
+         #     rk_llm.set_chat_template("You are a helpful assistant.", "<user>: ", "<assistant>: ")
+         #     print("Chat template set.")
+         # except RuntimeError as e:
+         #     print(f"Error setting chat template: {e}")
+
+         # --- Example: Clear KV Cache ---
+         # print("Clearing KV cache (keeping system prompt if any)...")
+         # try:
+         #     rk_llm.clear_kv_cache(keep_system_prompt=True)
+         #     print("KV cache cleared.")
+         # except RuntimeError as e:
+         #     print(f"Error clearing KV cache: {e}")
+
+     except OSError as e:
+         print(f"OSError: {e}. Could not load the RKLLM library.")
+         print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
+     except Exception as e:
+         print(f"An unexpected error occurred: {e}")
+     finally:
+         if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
+             print("Destroying LLM instance...")
+             rk_llm.destroy()
+             print("LLM instance destroyed.")
+         if os.path.exists(model_file) and model_file == "dummy_model.rkllm":
+             os.remove(model_file)  # Clean up dummy file
+
+     print("Example finished.")
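
Since RKLLMRuntime defines __enter__/__exit__, the wrapper can also be used as a context manager so rkllm_destroy runs even when an error is raised; a minimal sketch (the model path is a placeholder):

from rkllm_binding import RKLLMRuntime, LLMCallState

def cb(result_ptr, userdata, state):
    # Stream generated text as it arrives
    if LLMCallState(state) == LLMCallState.RKLLM_RUN_NORMAL and result_ptr.contents.text:
        print(result_ptr.contents.text.decode('utf-8', errors='ignore'), end="")

with RKLLMRuntime() as rt:
    params = rt.create_default_param()
    params.model_path = b"path/to/model.rkllm"  # placeholder path
    rt.init(params, cb)
    # ... build RKLLMInput / RKLLMInferParam and call rt.run(...) as in the example above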
test_reranker.py ADDED
@@ -0,0 +1,551 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Qwen3-Reranker inference test code.
+ Runs text-reranking inference through the RKLLM API.
+ """
+
+ import faulthandler
+ faulthandler.enable()
+ import os
+ os.environ["RKLLM_LOG_LEVEL"] = "1"
+ import numpy as np
+ import time
+ import re
+ from typing import List, Dict, Any, Tuple
+ from rkllm_binding import *
+
+
+ class Qwen3RerankerTester:
+     def __init__(self, model_path, library_path="./librkllmrt.so"):
+         """
+         Initialize the Qwen3 reranker model tester.
+
+         Args:
+             model_path: path to the model file (.rkllm format)
+             library_path: path to the RKLLM shared library
+         """
+         self.model_path = model_path
+         self.library_path = library_path
+         self.runtime = None
+         self.current_result = None
+
+         # Prompt format set according to the official README
+         self.system_prompt = "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\"."
+
+         # Candidate token IDs for "yes" and "no" (need to be confirmed by actual testing).
+         # These are common token IDs and may need adjustment in practice.
+         self.yes_token_candidates = [9693]
+         self.no_token_candidates = [2152]
+
+     def callback_function(self, result_ptr, userdata_ptr, state_enum):
+         """
+         Inference callback.
+
+         Args:
+             result_ptr: pointer to the result
+             userdata_ptr: pointer to the user data
+             state_enum: state enum value
+         """
+         state = LLMCallState(state_enum)
+
+         if state == LLMCallState.RKLLM_RUN_NORMAL:
+             result = result_ptr.contents
+             print(f"result: {result}")
+
+             # Fetch the logits
+             if result.logits.logits and result.logits.vocab_size > 0:
+                 vocab_size = result.logits.vocab_size
+                 num_tokens = result.logits.num_tokens
+
+                 print(f"Got logits: vocab_size={vocab_size}, num_tokens={num_tokens}")
+
+                 # Extract the logits of the last token
+                 if num_tokens > 0:
+                     last_token_logits = []
+                     start_idx = (num_tokens - 1) * vocab_size
+                     for i in range(vocab_size):
+                         last_token_logits.append(result.logits.logits[start_idx + i])
+
+                     self.current_result = {
+                         'logits': last_token_logits,
+                         'vocab_size': vocab_size,
+                         'num_tokens': num_tokens
+                     }
+
+                     print(f"Last-token logits range: [{min(last_token_logits):.4f}, {max(last_token_logits):.4f}]")
+             else:
+                 print("Warning: failed to obtain logits")
+
+         elif state == LLMCallState.RKLLM_RUN_ERROR:
+             print("An error occurred during inference")
+
+     def find_best_yes_no_tokens(self, logits):
+         """
+         Find the most likely "yes" and "no" token IDs.
+
+         Args:
+             logits: logits array of vocabulary size
+
+         Returns:
+             (yes_token_id, no_token_id, yes_logit, no_logit)
+         """
+         vocab_size = len(logits)
+
+         # Find the "yes" candidate with the highest logit
+         best_yes_id = None
+         best_yes_logit = float('-inf')
+         for token_id in self.yes_token_candidates:
+             if token_id < vocab_size:
+                 if logits[token_id] > best_yes_logit:
+                     best_yes_logit = logits[token_id]
+                     best_yes_id = token_id
+
+         # Find the "no" candidate with the highest logit
+         best_no_id = None
+         best_no_logit = float('-inf')
+         for token_id in self.no_token_candidates:
+             if token_id < vocab_size:
+                 if logits[token_id] > best_no_logit:
+                     best_no_logit = logits[token_id]
+                     best_no_id = token_id
+
+         # Fall back to a heuristic if the predefined tokens cannot be found
+         if best_yes_id is None or best_no_id is None:
+             print("Warning: falling back to a heuristic search for the yes/no tokens")
+
+             # Take the tokens with the highest logits
+             sorted_indices = np.argsort(logits)[::-1]
+             top_tokens = sorted_indices[:20]  # top 20 logits
+
+             # Simple heuristic: assume a higher logit corresponds to "yes" and a lower one to "no"
+             if best_yes_id is None:
+                 best_yes_id = top_tokens[0]
+                 best_yes_logit = logits[best_yes_id]
+
+             if best_no_id is None:
+                 # Pick a comparatively low but still plausible logit as "no"
+                 best_no_id = top_tokens[min(10, len(top_tokens)-1)]
+                 best_no_logit = logits[best_no_id]
+
+         return best_yes_id, best_no_id, best_yes_logit, best_no_logit
+
+     def calculate_reranker_score(self, logits):
+         """
+         Compute the reranking score (softmax probability over the "yes" and "no" tokens).
+
+         Args:
+             logits: logits array of vocabulary size
+
+         Returns:
+             Relevance score (between 0 and 1; higher means more relevant)
+         """
+         try:
+             # Locate the yes and no token logits
+             yes_id, no_id, yes_logit, no_logit = self.find_best_yes_no_tokens(logits)
+
+             print(f"Yes token ID: {yes_id}, logit: {yes_logit:.4f}")
+             print(f"No token ID: {no_id}, logit: {no_logit:.4f}")
+
+             # Compute softmax probabilities,
+             # considering only the relative probability of the yes and no tokens
+             max_logit = max(yes_logit, no_logit)
+             yes_exp = np.exp(yes_logit - max_logit)  # for numerical stability
+             no_exp = np.exp(no_logit - max_logit)
+
+             sum_exp = yes_exp + no_exp
+             yes_prob = yes_exp / sum_exp
+             no_prob = no_exp / sum_exp
+
+             print(f"Yes probability: {yes_prob:.4f}, No probability: {no_prob:.4f}")
+
+             # Return the yes probability as the relevance score
+             return float(yes_prob)
+
+         except Exception as e:
+             print(f"Error while computing the reranker score: {e}")
+             # Fall back to the simple heuristic method
+             return self.fallback_score_calculation(logits)
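+
+     # Mathematical note: a softmax restricted to the {yes, no} pair is
+     # equivalent to a sigmoid of the logit difference,
+     #     P(yes) = 1 / (1 + exp(no_logit - yes_logit)),
+     # so the score above is monotonic in (yes_logit - no_logit).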
+
+     def fallback_score_calculation(self, logits):
+         """
+         Fallback score computation (used when the yes/no tokens cannot be found).
+
+         Args:
+             logits: logits array of vocabulary size
+
+         Returns:
+             Relevance score (between 0 and 1)
+         """
+         print("Using the fallback score computation")
+
+         # Derive a score from distributional features of the logits
+         logits_array = np.array(logits)
+
+         # Entropy of the softmax distribution
+         softmax_probs = np.exp(logits_array - np.max(logits_array))
+         softmax_probs = softmax_probs / np.sum(softmax_probs)
+
+         # Lower entropy means the model is more confident (more relevant)
+         entropy = -np.sum(softmax_probs * np.log(softmax_probs + 1e-10))
+         max_entropy = np.log(len(logits))
+         normalized_entropy = entropy / max_entropy
+
+         # Convert to a relevance score (low entropy = high relevance)
+         confidence_score = 1.0 - normalized_entropy
+
+         # Blend in information from the maximum logit
+         max_logit_score = (np.max(logits_array) - np.mean(logits_array)) / (np.std(logits_array) + 1e-8)
+         max_logit_score = max(0, min(1, max_logit_score / 10))  # normalize
+
+         # Combined score
+         final_score = 0.7 * confidence_score + 0.3 * max_logit_score
+         final_score = max(0.0, min(1.0, final_score))
+
+         print(f"Fallback - entropy score: {confidence_score:.4f}, max-logit score: {max_logit_score:.4f}, final score: {final_score:.4f}")
+
+         return final_score
+
+     def init_model(self):
+         """Initialize the model."""
+         try:
+             print(f"Initializing the RKLLM runtime, library path: {self.library_path}")
+             self.runtime = RKLLMRuntime(self.library_path)
+
+             print("Creating default parameters...")
+             params = self.runtime.create_default_param()
+
+             # Configure parameters
+             params.model_path = self.model_path.encode('utf-8')
+             params.max_context_len = 1024
+             params.max_new_tokens = 1  # the reranker only needs to generate one token
+             params.temperature = 0.0   # deterministic output
+             params.top_k = 1           # greedy decoding
+             params.top_p = 1.0         # disable nucleus sampling
+
+             # Extended parameter configuration
+             params.extend_param.base_domain_id = 1
+             params.extend_param.embed_flash = 0
+             params.extend_param.enabled_cpus_num = 4
+             params.extend_param.enabled_cpus_mask = 0x0F
+
+             print(f"Initializing model: {self.model_path}")
+             self.runtime.init(params, self.callback_function)
+
+             # Set the chat template
+             self.runtime.set_chat_template(
+                 "",
+                 "",  # prefix
+                 ""   # suffix
+             )
+
+             print("Model initialized successfully!")
+
+         except Exception as e:
+             print(f"Model initialization failed: {e}")
+             raise
+
+     def format_rerank_input(self, instruction, query, document):
+         """
+         Format the reranking input (following the official README).
+
+         Args:
+             instruction: task instruction
+             query: query text
+             document: document text
+
+         Returns:
+             Formatted input text
+         """
+         if instruction is None:
+             instruction = 'Given a web search query, retrieve relevant passages that answer the query'
+
+         # Format from the official README
+         formatted_input = f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}"
+         return formatted_input
+
+     def get_reranker_score(self, instruction, query, document):
+         """
+         Get the reranking score (via logits).
+
+         Args:
+             instruction: task instruction
+             query: query text
+             document: document text
+
+         Returns:
+             Relevance score (between 0 and 1)
+         """
+         try:
+             # Format the input
+             input_text = self.format_rerank_input(instruction, query, document)
+             print(f"\nReranking input: {input_text[:200]}{'...' if len(input_text) > 200 else ''}")
+
+             # Prepare the input
+             rk_input = RKLLMInput()
+             rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
+             c_prompt = input_text.encode('utf-8')
+             rk_input._union_data.prompt_input = c_prompt
+
+             # Prepare the inference parameters -- use GET_LOGITS mode
+             infer_params = RKLLMInferParam()
+             infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LOGITS  # fetch logits
+             infer_params.keep_history = 0
+
+             # Clear the previous result
+             self.current_result = None
+             self.runtime.clear_kv_cache(False)
+
+             # Run inference
+             start_time = time.time()
+             self.runtime.run(rk_input, infer_params)
+             end_time = time.time()
+
+             print(f"\nInference took {end_time - start_time:.3f}s")
+
+             if self.current_result and 'logits' in self.current_result:
+                 # Compute the reranker score the proper way
+                 logits = self.current_result['logits']
+                 score = self.calculate_reranker_score(logits)
+
+                 print(f"Computed score: {score:.4f}")
+                 return score
+             else:
+                 print("Warning: no valid logits obtained; returning the default score")
+                 return 0.0
+
+         except Exception as e:
+             print(f"Error while scoring: {e}")
+             import traceback
+             traceback.print_exc()
+             return 0.0
+
+     def rerank_documents(self, query, documents, instruction=None):
+         """
+         Rerank a list of documents.
+
+         Args:
+             query: query text
+             documents: list of documents
+             instruction: optional task instruction
+
+         Returns:
+             List of (document, score) tuples sorted by relevance in descending order
+         """
+         print(f"\nReranking {len(documents)} documents")
+         print(f"Query: {query}")
+
+         if instruction:
+             print(f"Instruction: {instruction}")
+
+         scored_docs = []
+         for i, doc in enumerate(documents):
+             print(f"\n--- Processing document {i+1}/{len(documents)} ---")
+             print(f"Document: {doc[:100]}{'...' if len(doc) > 100 else ''}")
+
+             score = self.get_reranker_score(instruction, query, doc)
+             scored_docs.append((doc, score))
+             print(f"Score: {score:.4f}")
+
+         # Sort by score in descending order
+         scored_docs.sort(key=lambda x: x[1], reverse=True)
+         return scored_docs
+
+     def test_basic_reranking(self):
+         """Test basic reranking."""
+         print("\n" + "="*60)
+         print("Testing basic reranking")
+         print("="*60)
+
+         # Test query
+         query = "What is the capital of China?"
+
+         # Candidate documents (a mix of relevant and irrelevant)
+         documents = [
+             "Beijing is the capital city of China, located in northern China.",
+             "The Great Wall of China is an ancient fortification built to protect Chinese states.",
+             "Python is a high-level programming language used for software development.",
+             "China's capital Beijing is home to over 21 million people.",
+             "Machine learning is a subset of artificial intelligence that uses algorithms."
+         ]
+
+         # Run the reranking
+         instruction = "Given a web search query, retrieve relevant passages that answer the query"
+         ranked_docs = self.rerank_documents(query, documents, instruction)
+
+         # Show the results
+         print(f"\nReranking results (query: {query}):")
+         print("-" * 80)
+         for i, (doc, score) in enumerate(ranked_docs):
+             print(f"Rank {i+1}: score {score:.4f}")
+             print(f"Document: {doc}")
+             print()
+
+         return ranked_docs
+
+     def test_multilingual_reranking(self):
+         """Test multilingual reranking (the Chinese strings below are deliberate test inputs)."""
+         print("\n" + "="*60)
+         print("Testing multilingual reranking")
+         print("="*60)
+
+         # Chinese query
+         query = "中国的首都是什么?"
+
+         documents = [
+             "北京是中华人民共和国的首都,位于中国北部。",
+             "上海是中国的经济中心,人口超过2400万。",
+             "Python 是一种高级编程语言。",
+             "The capital of China is Beijing.",
+             "长城是中国古代的军事防御工程。"
+         ]
+
+         instruction = "Given a web search query, retrieve relevant passages that answer the query"
+         ranked_docs = self.rerank_documents(query, documents, instruction)
+
+         print(f"\nMultilingual reranking results (query: {query}):")
+         print("-" * 80)
+         for i, (doc, score) in enumerate(ranked_docs):
+             print(f"Rank {i+1}: score {score:.4f}")
+             print(f"Document: {doc}")
+             print()
+
+         return ranked_docs
+
+     def test_domain_specific_reranking(self):
+         """Test domain-specific reranking."""
+         print("\n" + "="*60)
+         print("Testing domain-specific reranking (technical documents)")
+         print("="*60)
+
+         query = "How to implement a neural network in Python?"
+
+         documents = [
+             "PyTorch is a deep learning framework that provides tensor computations with GPU acceleration.",
+             "TensorFlow is an open-source machine learning library developed by Google.",
+             "Neural networks are computing systems inspired by biological neural networks.",
+             "Python is a programming language with simple syntax and powerful libraries.",
+             "To implement a neural network in Python, you can use libraries like PyTorch or TensorFlow to define layers, loss functions, and optimization algorithms.",
+             "Cooking recipes often require precise measurements and cooking times.",
+             "Backpropagation is the algorithm used to train neural networks by computing gradients."
+         ]
+
+         # Use a custom instruction
+         instruction = "Given a technical query and a document, determine if the document provides practical information for implementing the requested technical solution"
+
+         ranked_docs = self.rerank_documents(query, documents, instruction)
+
+         print(f"\nTechnical-document reranking results (query: {query}):")
+         print("-" * 80)
+         for i, (doc, score) in enumerate(ranked_docs):
+             print(f"Rank {i+1}: score {score:.4f}")
+             print(f"Document: {doc}")
+             print()
+
+         return ranked_docs
+
+     def test_comparison_with_official_example(self):
+         """Compare against the official example."""
+         print("\n" + "="*60)
+         print("Comparison against the official example")
+         print("="*60)
+
+         # Example from the official README
+         task = 'Given a web search query, retrieve relevant passages that answer the query'
+
+         queries = [
+             "What is the capital of China?",
+             "Explain gravity",
+         ]
+
+         documents = [
+             "The capital of China is Beijing.",
+             "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
+         ]
+
+         print("Testing the official example's query-document pairs:")
+         for i, (query, doc) in enumerate(zip(queries, documents)):
+             print(f"\n=== Query-document pair {i+1} ===")
+             print(f"Query: {query}")
+             print(f"Document: {doc}")
+
+             score = self.get_reranker_score(task, query, doc)
+             print(f"Relevance score: {score:.4f}")
+
+     def cleanup(self):
+         """Release resources."""
+         if self.runtime:
+             try:
+                 self.runtime.destroy()
+                 print("Model resources released")
+             except Exception as e:
+                 print(f"Error while releasing resources: {e}")
+
+
+ def main():
+     """Main entry point."""
+     import argparse
+
+     # Parse command-line arguments
+     parser = argparse.ArgumentParser(description='Qwen3-Reranker-0.6B inference test')
+     parser.add_argument('model_path', help='path to the model file (.rkllm format)')
+     parser.add_argument('--library_path', default="./librkllmrt.so", help='path to the RKLLM library (default: ./librkllmrt.so)')
+     args = parser.parse_args()
+
+     # Check that the files exist
+     if not os.path.exists(args.model_path):
+         print(f"Error: model file does not exist: {args.model_path}")
+         print("Please make sure that:")
+         print("1. the Qwen3-Reranker-0.6B model has been downloaded")
+         print("2. the model has been converted to .rkllm format with rkllm-convert.py")
+         return
+
+     if not os.path.exists(args.library_path):
+         print(f"Error: RKLLM library file does not exist: {args.library_path}")
+         print("Please make sure librkllmrt.so is in the current directory or on LD_LIBRARY_PATH")
+         return
+
+     print("Qwen3-Reranker-0.6B inference test")
+     print("=" * 60)
+     print("Implementation based on the official README")
+     print("=" * 60)
+
+     # Create the tester
+     tester = Qwen3RerankerTester(args.model_path, args.library_path)
+
+     try:
+         # Initialize the model
+         tester.init_model()
+
+         # Run the tests
+         print("\nStarting the reranking tests...")
+
+         # Comparison against the official example
+         tester.test_comparison_with_official_example()
+
+         # Basic reranking
+         tester.test_basic_reranking()
+
+         # Multilingual reranking
+         tester.test_multilingual_reranking()
+
+         # Domain-specific reranking
+         tester.test_domain_specific_reranking()
+
+         print("\n" + "="*60)
+         print("All reranking tests finished!")
+         print("="*60)
+
+     except KeyboardInterrupt:
+         print("\nTest interrupted by the user")
+     except Exception as e:
+         print(f"\nAn error occurred during testing: {e}")
+         import traceback
+         traceback.print_exc()
+     finally:
+         # Release resources
+         tester.cleanup()
+
+
+ if __name__ == "__main__":
+     main()
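
To exercise this script against the files in this commit, pass the converted model on the command line (python test_reranker.py Qwen3-Reranker-0.6B_w8a8.rkllm). The tester class can also be driven directly; a minimal sketch reusing the API defined above:

from test_reranker import Qwen3RerankerTester

tester = Qwen3RerankerTester("Qwen3-Reranker-0.6B_w8a8.rkllm", "./librkllmrt.so")
try:
    tester.init_model()
    ranked = tester.rerank_documents(
        "What is the capital of China?",
        ["The capital of China is Beijing.",
         "Gravity is a force that attracts two bodies towards each other."],
    )
    for doc, score in ranked:
        print(f"{score:.4f}  {doc}")
finally:
    tester.cleanup()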