Upload 6 files
- .gitattributes +3 -0
- Qwen3-Reranker-0.6B_f16.rkllm +3 -0
- Qwen3-Reranker-0.6B_w8a8.rkllm +3 -0
- librkllmrt.so +3 -0
- rkllm-convert.py +74 -0
- rkllm_binding.py +658 -0
- test_reranker.py +551 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+librkllmrt.so filter=lfs diff=lfs merge=lfs -text
+Qwen3-Reranker-0.6B_f16.rkllm filter=lfs diff=lfs merge=lfs -text
+Qwen3-Reranker-0.6B_w8a8.rkllm filter=lfs diff=lfs merge=lfs -text
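Note, not part of the commit: the three new rules route exactly the three binary artifacts of this upload through LFS, matched by literal file name. A quick throwaway check in Python, using fnmatch as an approximation of gitattributes glob matching for patterns this simple:

from fnmatch import fnmatch

lfs_rules = ["librkllmrt.so", "Qwen3-Reranker-0.6B_f16.rkllm", "Qwen3-Reranker-0.6B_w8a8.rkllm"]
uploaded = ["librkllmrt.so", "Qwen3-Reranker-0.6B_f16.rkllm", "Qwen3-Reranker-0.6B_w8a8.rkllm",
            "rkllm-convert.py", "rkllm_binding.py", "test_reranker.py"]

for name in uploaded:
    # True only for the three binaries; the Python sources stay as plain git blobs
    print(name, any(fnmatch(name, rule) for rule in lfs_rules))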
Qwen3-Reranker-0.6B_f16.rkllm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a5577675b87b5d19b1fbc20360b9caa174fa9d64d08b6cbe564d74456216f117
size 1524801182
Qwen3-Reranker-0.6B_w8a8.rkllm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:820c0b6b2419d3799a612751d9d8b86ed0e0e852364a6de78972a3976c34eaa8
size 931372078
librkllmrt.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f6a9c2de93cf94bb524eb071c27190ad4c83401e01b562534f265dff4cb40da2
size 6710712
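The three entries above are Git LFS pointer files: the repository tracks only the spec version, the SHA-256 of the actual blob, and its size in bytes, while the binaries live in LFS storage. A minimal verification sketch, not part of the commit, assuming the real blobs have been fetched:

import hashlib

def verify_lfs_object(pointer_path, blob_path):
    """Check a downloaded blob against the oid/size recorded in its LFS pointer."""
    fields = dict(line.split(" ", 1) for line in open(pointer_path) if " " in line)
    expected_oid = fields["oid"].strip().split(":", 1)[1]  # "sha256:<hex>" -> "<hex>"
    expected_size = int(fields["size"])

    h = hashlib.sha256()
    size = 0
    with open(blob_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
            h.update(chunk)
            size += len(chunk)
    return h.hexdigest() == expected_oid and size == expected_size

# e.g. verify_lfs_object("Qwen3-Reranker-0.6B_w8a8.pointer", "Qwen3-Reranker-0.6B_w8a8.rkllm")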
rkllm-convert.py
ADDED
@@ -0,0 +1,74 @@
import os
from rkllm.api import RKLLM

def convert_model(model_path, output_name, do_quantization=False):
    """Convert a single model."""
    llm = RKLLM()

    print(f"Loading model: {model_path}")
    ret = llm.load_huggingface(model=model_path, model_lora=None, device='cpu')
    if ret != 0:
        print(f'Failed to load model: {model_path}')
        return ret

    print(f"Building model: {output_name} (quantization: {do_quantization})")
    qparams = None
    ret = llm.build(do_quantization=do_quantization, optimization_level=1, quantized_dtype='w8a8',
                    quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=qparams)

    if ret != 0:
        print(f'Failed to build model: {output_name}')
        return ret

    # Export the rkllm model
    print(f"Exporting model: {output_name}")
    ret = llm.export_rkllm(output_name)
    if ret != 0:
        print(f'Failed to export model: {output_name}')
        return ret

    print(f"Successfully converted: {output_name}")
    return 0

def main():
    """Main entry point: walk all subdirectories and convert each model."""
    current_dir = '.'

    # Collect all subdirectories
    subdirs = [d for d in os.listdir(current_dir)
               if os.path.isdir(os.path.join(current_dir, d)) and not d.startswith('.')]

    print(f"Found {len(subdirs)} model folders: {subdirs}")

    for subdir in subdirs:
        model_path = os.path.join(current_dir, subdir)

        # Build the output file names
        base_name = subdir.replace('/', '_').replace('\\', '_')
        quantized_output = f"{base_name}_w8a8.rkllm"
        unquantized_output = f"{base_name}_f16.rkllm"

        print(f"\n{'='*50}")
        print(f"Processing model folder: {subdir}")
        print(f"{'='*50}")

        # Convert the unquantized version
        print(f"\n--- Converting unquantized version ---")
        ret = convert_model(model_path, unquantized_output, do_quantization=False)
        if ret != 0:
            print(f"Unquantized conversion failed: {subdir}")
            continue

        # Convert the quantized version
        print(f"\n--- Converting quantized version ---")
        ret = convert_model(model_path, quantized_output, do_quantization=True)
        if ret != 0:
            print(f"Quantized conversion failed: {subdir}")
            continue

        print(f"\n✓ {subdir} model conversion complete!")
        print(f"  - Unquantized version: {unquantized_output}")
        print(f"  - Quantized version: {quantized_output}")

if __name__ == "__main__":
    main()
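Usage note, not part of the upload: main() converts every non-hidden subdirectory of the working directory, so running `python rkllm-convert.py` next to a Hugging Face checkout of the reranker should produce the two .rkllm files in this commit (assuming the rkllm-toolkit is installed). To convert a single folder, one possible sketch is to load the module by path, since the hyphen in the file name rules out a plain import, and call convert_model directly:

import importlib.util

# Load rkllm-convert.py as a module despite the hyphen in its file name
spec = importlib.util.spec_from_file_location("rkllm_convert", "rkllm-convert.py")
rkllm_convert = importlib.util.module_from_spec(spec)
spec.loader.exec_module(rkllm_convert)

# Convert one Hugging Face folder to a quantized .rkllm file
rkllm_convert.convert_model("./Qwen3-Reranker-0.6B", "Qwen3-Reranker-0.6B_w8a8.rkllm",
                            do_quantization=True)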
rkllm_binding.py
ADDED
@@ -0,0 +1,658 @@
import ctypes
import enum
import os

# Define constants from the header
CPU0 = (1 << 0)  # 0x01
CPU1 = (1 << 1)  # 0x02
CPU2 = (1 << 2)  # 0x04
CPU3 = (1 << 3)  # 0x08
CPU4 = (1 << 4)  # 0x10
CPU5 = (1 << 5)  # 0x20
CPU6 = (1 << 6)  # 0x40
CPU7 = (1 << 7)  # 0x80

# --- Enums ---
class LLMCallState(enum.IntEnum):
    RKLLM_RUN_NORMAL = 0
    RKLLM_RUN_WAITING = 1
    RKLLM_RUN_FINISH = 2
    RKLLM_RUN_ERROR = 3

class RKLLMInputType(enum.IntEnum):
    RKLLM_INPUT_PROMPT = 0
    RKLLM_INPUT_TOKEN = 1
    RKLLM_INPUT_EMBED = 2
    RKLLM_INPUT_MULTIMODAL = 3

class RKLLMInferMode(enum.IntEnum):
    RKLLM_INFER_GENERATE = 0
    RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
    RKLLM_INFER_GET_LOGITS = 2

# --- Structures ---
class RKLLMExtendParam(ctypes.Structure):
    # Base iommu domain ID; setting it to 1 is recommended for models larger than 1B
    base_domain_id: ctypes.c_int32
    # Whether to store the embedding in flash
    embed_flash: ctypes.c_int8
    # Number of enabled CPU cores
    enabled_cpus_num: ctypes.c_int8
    # Mask of enabled CPU cores
    enabled_cpus_mask: ctypes.c_uint32
    reserved: ctypes.c_uint8 * 106

    _fields_ = [
        ("base_domain_id", ctypes.c_int32),
        ("embed_flash", ctypes.c_int8),
        ("enabled_cpus_num", ctypes.c_int8),
        ("enabled_cpus_mask", ctypes.c_uint32),
        ("reserved", ctypes.c_uint8 * 106)
    ]

class RKLLMParam(ctypes.Structure):
    # Path to the model file
    model_path: ctypes.c_char_p
    # Maximum number of tokens in the context window
    max_context_len: ctypes.c_int32
    # Maximum number of new tokens to generate
    max_new_tokens: ctypes.c_int32
    # Top-K sampling parameter
    top_k: ctypes.c_int32
    # Number of KV-cache entries kept when the context window shifts
    n_keep: ctypes.c_int32
    # Top-P sampling parameter
    top_p: ctypes.c_float
    # Sampling temperature; controls the randomness of token selection
    temperature: ctypes.c_float
    # Penalty for repeated tokens
    repeat_penalty: ctypes.c_float
    # Penalty for frequent tokens
    frequency_penalty: ctypes.c_float
    # Penalty for tokens already present in the input
    presence_penalty: ctypes.c_float
    # Mirostat sampling flag (0 disables it)
    mirostat: ctypes.c_int32
    # Mirostat tau parameter
    mirostat_tau: ctypes.c_float
    # Mirostat eta parameter
    mirostat_eta: ctypes.c_float
    # Whether to skip special tokens
    skip_special_token: ctypes.c_bool
    # Whether inference runs asynchronously
    is_async: ctypes.c_bool
    # Start token for images in multimodal input
    img_start: ctypes.c_char_p
    # End token for images in multimodal input
    img_end: ctypes.c_char_p
    # Pointer to image content
    img_content: ctypes.c_char_p
    # Extended parameters
    extend_param: RKLLMExtendParam

    _fields_ = [
        ("model_path", ctypes.c_char_p),        # path to the model file
        ("max_context_len", ctypes.c_int32),    # maximum number of tokens in the context window
        ("max_new_tokens", ctypes.c_int32),     # maximum number of new tokens to generate
        ("top_k", ctypes.c_int32),              # Top-K sampling parameter
        ("n_keep", ctypes.c_int32),             # KV-cache entries kept when the context window shifts
        ("top_p", ctypes.c_float),              # Top-P (nucleus) sampling parameter
        ("temperature", ctypes.c_float),        # sampling temperature
        ("repeat_penalty", ctypes.c_float),     # penalty for repeated tokens
        ("frequency_penalty", ctypes.c_float),  # penalty for frequent tokens
        ("presence_penalty", ctypes.c_float),   # penalty for tokens already present in the input
        ("mirostat", ctypes.c_int32),           # Mirostat sampling flag (0 disables it)
        ("mirostat_tau", ctypes.c_float),       # Mirostat tau parameter
        ("mirostat_eta", ctypes.c_float),       # Mirostat eta parameter
        ("skip_special_token", ctypes.c_bool),  # whether to skip special tokens
        ("is_async", ctypes.c_bool),            # whether inference runs asynchronously
        ("img_start", ctypes.c_char_p),         # start token for images in multimodal input
        ("img_end", ctypes.c_char_p),           # end token for images in multimodal input
        ("img_content", ctypes.c_char_p),       # pointer to image content
        ("extend_param", RKLLMExtendParam)      # extended parameters
    ]

class RKLLMLoraAdapter(ctypes.Structure):
    lora_adapter_path: ctypes.c_char_p
    lora_adapter_name: ctypes.c_char_p
    scale: ctypes.c_float

    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),
        ("lora_adapter_name", ctypes.c_char_p),
        ("scale", ctypes.c_float)
    ]

class RKLLMEmbedInput(ctypes.Structure):
    # Shape: [n_tokens, embed_size]
    embed: ctypes.POINTER(ctypes.c_float)
    n_tokens: ctypes.c_size_t

    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMTokenInput(ctypes.Structure):
    # Shape: [n_tokens]
    input_ids: ctypes.POINTER(ctypes.c_int32)
    n_tokens: ctypes.c_size_t

    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMMultiModelInput(ctypes.Structure):
    prompt: ctypes.c_char_p
    image_embed: ctypes.POINTER(ctypes.c_float)
    n_image_tokens: ctypes.c_size_t
    n_image: ctypes.c_size_t
    image_width: ctypes.c_size_t
    image_height: ctypes.c_size_t

    _fields_ = [
        ("prompt", ctypes.c_char_p),
        ("image_embed", ctypes.POINTER(ctypes.c_float)),
        ("n_image_tokens", ctypes.c_size_t),
        ("n_image", ctypes.c_size_t),
        ("image_width", ctypes.c_size_t),
        ("image_height", ctypes.c_size_t)
    ]

class _RKLLMInputUnion(ctypes.Union):
    prompt_input: ctypes.c_char_p
    embed_input: RKLLMEmbedInput
    token_input: RKLLMTokenInput
    multimodal_input: RKLLMMultiModelInput

    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModelInput)
    ]

class RKLLMInput(ctypes.Structure):
    input_type: ctypes.c_int
    _union_data: _RKLLMInputUnion

    _fields_ = [
        ("input_type", ctypes.c_int),  # enum passed as int (RKLLMInputType)
        ("_union_data", _RKLLMInputUnion)
    ]

    # Properties to make accessing union members easier
    @property
    def prompt_input(self) -> bytes:  # c_char_p maps to bytes
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            return self._union_data.prompt_input
        raise AttributeError("Not a prompt input")

    @prompt_input.setter
    def prompt_input(self, value: bytes):  # c_char_p maps to bytes
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            self._union_data.prompt_input = value
        else:
            raise AttributeError("Not a prompt input")

    @property
    def embed_input(self) -> RKLLMEmbedInput:
        if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
            return self._union_data.embed_input
        raise AttributeError("Not an embed input")

    @embed_input.setter
    def embed_input(self, value: RKLLMEmbedInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
            self._union_data.embed_input = value
        else:
            raise AttributeError("Not an embed input")

    @property
    def token_input(self) -> RKLLMTokenInput:
        if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
            return self._union_data.token_input
        raise AttributeError("Not a token input")

    @token_input.setter
    def token_input(self, value: RKLLMTokenInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
            self._union_data.token_input = value
        else:
            raise AttributeError("Not a token input")

    @property
    def multimodal_input(self) -> RKLLMMultiModelInput:
        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            return self._union_data.multimodal_input
        raise AttributeError("Not a multimodal input")

    @multimodal_input.setter
    def multimodal_input(self, value: RKLLMMultiModelInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            self._union_data.multimodal_input = value
        else:
            raise AttributeError("Not a multimodal input")

class RKLLMLoraParam(ctypes.Structure):  # For inference
    lora_adapter_name: ctypes.c_char_p

    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p)
    ]

class RKLLMPromptCacheParam(ctypes.Structure):  # For inference
    save_prompt_cache: ctypes.c_int  # bool-like
    prompt_cache_path: ctypes.c_char_p

    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),  # bool-like
        ("prompt_cache_path", ctypes.c_char_p)
    ]

class RKLLMInferParam(ctypes.Structure):
    mode: ctypes.c_int
    lora_params: ctypes.POINTER(RKLLMLoraParam)
    prompt_cache_params: ctypes.POINTER(RKLLMPromptCacheParam)
    keep_history: ctypes.c_int  # bool-like

    _fields_ = [
        ("mode", ctypes.c_int),  # enum passed as int (RKLLMInferMode)
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
        ("keep_history", ctypes.c_int)  # bool-like
    ]

class RKLLMResultLastHiddenLayer(ctypes.Structure):
    # Shape: [num_tokens, embd_size]
    hidden_states: ctypes.POINTER(ctypes.c_float)
    # Hidden layer size
    embd_size: ctypes.c_int
    # Number of output tokens
    num_tokens: ctypes.c_int

    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]

class RKLLMResultLogits(ctypes.Structure):
    # Shape: [num_tokens, vocab_size]
    logits: ctypes.POINTER(ctypes.c_float)
    # Vocabulary size
    vocab_size: ctypes.c_int
    # Number of output tokens
    num_tokens: ctypes.c_int

    _fields_ = [
        ("logits", ctypes.POINTER(ctypes.c_float)),
        ("vocab_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]

class RKLLMResult(ctypes.Structure):
    text: ctypes.c_char_p
    token_id: ctypes.c_int32
    last_hidden_layer: RKLLMResultLastHiddenLayer
    logits: RKLLMResultLogits

    _fields_ = [
        ("text", ctypes.c_char_p),
        ("token_id", ctypes.c_int32),
        ("last_hidden_layer", RKLLMResultLastHiddenLayer),
        ("logits", RKLLMResultLogits)
    ]

# --- Typedefs ---
LLMHandle = ctypes.c_void_p

# --- Callback Function Type ---
LLMResultCallback = ctypes.CFUNCTYPE(
    None,                          # return type: void
    ctypes.POINTER(RKLLMResult),
    ctypes.c_void_p,               # userdata
    ctypes.c_int                   # enum passed as int (LLMCallState)
)


class RKLLMRuntime:
    def __init__(self, library_path="./librkllmrt.so"):
        try:
            self.lib = ctypes.CDLL(library_path)
        except OSError as e:
            raise OSError(f"Failed to load RKLLM library from {library_path}. "
                          f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}")
        self._setup_functions()
        self.llm_handle = LLMHandle()
        self._c_callback = None  # To keep the callback object alive

    def _setup_functions(self):
        # RKLLMParam rkllm_createDefaultParam();
        self.lib.rkllm_createDefaultParam.restype = RKLLMParam
        self.lib.rkllm_createDefaultParam.argtypes = []

        # int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
        self.lib.rkllm_init.restype = ctypes.c_int
        self.lib.rkllm_init.argtypes = [
            ctypes.POINTER(LLMHandle),
            ctypes.POINTER(RKLLMParam),
            LLMResultCallback
        ]

        # int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
        self.lib.rkllm_load_lora.restype = ctypes.c_int
        self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]

        # int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
        self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]

        # int rkllm_release_prompt_cache(LLMHandle handle);
        self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]

        # int rkllm_destroy(LLMHandle handle);
        self.lib.rkllm_destroy.restype = ctypes.c_int
        self.lib.rkllm_destroy.argtypes = [LLMHandle]

        # int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        self.lib.rkllm_run.restype = ctypes.c_int
        self.lib.rkllm_run.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        # Assuming async also takes userdata for the callback context
        self.lib.rkllm_run_async.restype = ctypes.c_int
        self.lib.rkllm_run_async.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_abort(LLMHandle handle);
        self.lib.rkllm_abort.restype = ctypes.c_int
        self.lib.rkllm_abort.argtypes = [LLMHandle]

        # int rkllm_is_running(LLMHandle handle);
        self.lib.rkllm_is_running.restype = ctypes.c_int  # 0 if running, non-zero otherwise
        self.lib.rkllm_is_running.argtypes = [LLMHandle]

        # int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt);
        self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
        self.lib.rkllm_clear_kv_cache.argtypes = [LLMHandle, ctypes.c_int]

        # int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
        self.lib.rkllm_set_chat_template.restype = ctypes.c_int
        self.lib.rkllm_set_chat_template.argtypes = [
            LLMHandle,
            ctypes.c_char_p,
            ctypes.c_char_p,
            ctypes.c_char_p
        ]

    def create_default_param(self) -> RKLLMParam:
        """Creates a default RKLLMParam structure."""
        return self.lib.rkllm_createDefaultParam()

    def init(self, param: RKLLMParam, callback_func) -> int:
        """
        Initializes the LLM.
        :param param: RKLLMParam structure.
        :param callback_func: A Python function that matches the signature:
                              def my_callback(result_ptr, userdata_ptr, state_enum):
                                  result = result_ptr.contents  # RKLLMResult
                                  # Process result
                                  # userdata can be retrieved if passed during run, or ignored
                                  # state = LLMCallState(state_enum)
        :return: 0 for success, non-zero for failure.
        """
        if not callable(callback_func):
            raise ValueError("callback_func must be a callable Python function.")

        # Keep a reference to the ctypes callback object to prevent it from being garbage collected
        self._c_callback = LLMResultCallback(callback_func)

        ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
        if ret != 0:
            raise RuntimeError(f"rkllm_init failed with error code {ret}")
        return ret

    def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
        """Loads a LoRA adapter."""
        ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
        if ret != 0:
            raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
        return ret

    def load_prompt_cache(self, prompt_cache_path: str) -> int:
        """Loads a prompt cache from a file."""
        c_path = prompt_cache_path.encode('utf-8')
        ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
        if ret != 0:
            raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
        return ret

    def release_prompt_cache(self) -> int:
        """Releases the prompt cache from memory."""
        ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
        return ret

    def destroy(self) -> int:
        """Destroys the LLM instance and releases resources."""
        if self.llm_handle and self.llm_handle.value:  # Check if handle is not NULL
            ret = self.lib.rkllm_destroy(self.llm_handle)
            self.llm_handle = LLMHandle()  # Reset handle
            if ret != 0:
                # Don't raise here as it might be called in __del__
                print(f"Warning: rkllm_destroy failed with error code {ret}")
            return ret
        return 0  # Already destroyed or not initialized

    def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task synchronously."""
        # userdata can be a ctypes.py_object if you want to pass Python objects,
        # then cast to c_void_p. Or simply None.
        if userdata is not None:
            # Store the userdata object to keep it alive during the call
            self._userdata_ref = userdata
            c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
        else:
            c_userdata = None
        ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run failed with error code {ret}")
        return ret

    def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task asynchronously."""
        if userdata is not None:
            # Store the userdata object to keep it alive during the call
            self._userdata_ref = userdata
            c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
        else:
            c_userdata = None
        ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
        return ret

    def abort(self) -> int:
        """Aborts an ongoing LLM task."""
        ret = self.lib.rkllm_abort(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_abort failed with error code {ret}")
        return ret

    def is_running(self) -> bool:
        """Checks if an LLM task is currently running. Returns True if running."""
        # The C API returns 0 if running, non-zero otherwise.
        # This is a bit counter-intuitive for a boolean "is_running".
        return self.lib.rkllm_is_running(self.llm_handle) == 0

    def clear_kv_cache(self, keep_system_prompt: bool) -> int:
        """Clears the key-value cache."""
        ret = self.lib.rkllm_clear_kv_cache(self.llm_handle, ctypes.c_int(1 if keep_system_prompt else 0))
        if ret != 0:
            raise RuntimeError(f"rkllm_clear_kv_cache failed with error code {ret}")
        return ret

    def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
        """Sets the chat template for the LLM."""
        c_system = system_prompt.encode('utf-8') if system_prompt else b""
        c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else b""
        c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else b""

        ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
        if ret != 0:
            raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
        return ret

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.destroy()

    def __del__(self):
        self.destroy()  # Ensure resources are freed if object is garbage collected

# --- Example Usage (Illustrative) ---
if __name__ == "__main__":
    # This is a placeholder for how you might use it.
    # You'll need a valid .rkllm model and librkllmrt.so in your path.

    # Global list to store results from callback for demonstration
    results_buffer = []

    def my_python_callback(result_ptr, userdata_ptr, state_enum):
        """
        Callback function to be called by the C library.
        """
        global results_buffer
        state = LLMCallState(state_enum)
        result = result_ptr.contents

        current_text = ""
        if result.text:  # Check if the char_p is not NULL
            current_text = result.text.decode('utf-8', errors='ignore')

        print(f"Callback: State={state.name}, TokenID={result.token_id}, Text='{current_text}'")
        results_buffer.append(current_text)

        if state == LLMCallState.RKLLM_RUN_FINISH:
            print("Inference finished.")
        elif state == LLMCallState.RKLLM_RUN_ERROR:
            print("Inference error.")

        # Example: Accessing logits if available (and if mode was set to get logits)
        # if result.logits.logits and result.logits.vocab_size > 0:
        #     print(f"  Logits (first 5 of vocab_size {result.logits.vocab_size}):")
        #     for i in range(min(5, result.logits.vocab_size)):
        #         print(f"    {result.logits.logits[i]:.4f}", end=" ")
        #     print()


    # --- Attempt to use the wrapper ---
    try:
        print("Initializing RKLLMRuntime...")
        # Adjust library_path if librkllmrt.so is not in default search paths
        # e.g., library_path="./path/to/librkllmrt.so"
        rk_llm = RKLLMRuntime()

        print("Creating default parameters...")
        params = rk_llm.create_default_param()

        # --- Configure parameters ---
        # THIS IS CRITICAL: model_path must point to an actual .rkllm file
        # For this example to run, you need a model file.
        # Let's assume a dummy path for now, this will fail at init if not valid.
        model_file = "dummy_model.rkllm"
        if not os.path.exists(model_file):
            print(f"Warning: Model file '{model_file}' does not exist. Init will likely fail.")
            # Create a dummy file for the example to proceed further, though init will still fail
            # with a real library unless it's a valid model.
            with open(model_file, "w") as f:
                f.write("dummy content")

        params.model_path = model_file.encode('utf-8')
        params.max_context_len = 512
        params.max_new_tokens = 128
        params.top_k = 1  # Greedy
        params.temperature = 0.7
        params.repeat_penalty = 1.1
        # ... set other params as needed

        print(f"Initializing LLM with model: {params.model_path.decode()}...")
        # This will likely fail if dummy_model.rkllm is not a valid model recognized by the library
        try:
            rk_llm.init(params, my_python_callback)
            print("LLM Initialized.")
        except RuntimeError as e:
            print(f"Error during LLM initialization: {e}")
            print("This is expected if 'dummy_model.rkllm' is not a valid model.")
            print("Replace 'dummy_model.rkllm' with a real model path to test further.")
            exit()


        # --- Prepare input ---
        print("Preparing input...")
        rk_input = RKLLMInput()
        rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT

        prompt_text = "Translate the following English text to French: 'Hello, world!'"
        c_prompt = prompt_text.encode('utf-8')
        rk_input._union_data.prompt_input = c_prompt  # Accessing union member directly

        # --- Prepare inference parameters ---
        print("Preparing inference parameters...")
        infer_params = RKLLMInferParam()
        infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        infer_params.keep_history = 1  # True
        # infer_params.lora_params = None  # or set up RKLLMLoraParam if using LoRA
        # infer_params.prompt_cache_params = None  # or set up RKLLMPromptCacheParam

        # --- Run inference ---
        print(f"Running inference with prompt: '{prompt_text}'")
        results_buffer.clear()
        try:
            rk_llm.run(rk_input, infer_params)  # Userdata is None by default
            print("\n--- Full Response ---")
            print("".join(results_buffer))
            print("---------------------\n")
        except RuntimeError as e:
            print(f"Error during LLM run: {e}")


        # --- Example: Set chat template (if model supports it) ---
        # print("Setting chat template...")
        # try:
        #     rk_llm.set_chat_template("You are a helpful assistant.", "<user>: ", "<assistant>: ")
        #     print("Chat template set.")
        # except RuntimeError as e:
        #     print(f"Error setting chat template: {e}")

        # --- Example: Clear KV Cache ---
        # print("Clearing KV cache (keeping system prompt if any)...")
        # try:
        #     rk_llm.clear_kv_cache(keep_system_prompt=True)
        #     print("KV cache cleared.")
        # except RuntimeError as e:
        #     print(f"Error clearing KV cache: {e}")

    except OSError as e:
        print(f"OSError: {e}. Could not load the RKLLM library.")
        print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    finally:
        if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
            print("Destroying LLM instance...")
            rk_llm.destroy()
            print("LLM instance destroyed.")
        if 'model_file' in locals() and os.path.exists(model_file) and model_file == "dummy_model.rkllm":
            os.remove(model_file)  # Clean up dummy file

    print("Example finished.")
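A minimal driver sketch, not part of the upload, that exercises the context-manager support defined above (__enter__/__exit__) so destroy() runs even if init or inference raises; it assumes a valid model such as the w8a8 file from this commit:

from rkllm_binding import RKLLMRuntime, LLMCallState

def print_tokens(result_ptr, userdata_ptr, state_enum):
    # Stream decoded text as it arrives from the runtime
    if LLMCallState(state_enum) == LLMCallState.RKLLM_RUN_NORMAL and result_ptr.contents.text:
        print(result_ptr.contents.text.decode('utf-8', errors='ignore'), end="", flush=True)

with RKLLMRuntime("./librkllmrt.so") as rt:  # destroy() is called automatically on exit
    params = rt.create_default_param()
    params.model_path = b"Qwen3-Reranker-0.6B_w8a8.rkllm"
    rt.init(params, print_tokens)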
test_reranker.py
ADDED
@@ -0,0 +1,551 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Qwen3-Reranker inference test code.
Runs text-reranking inference through the RKLLM API.
"""

import faulthandler
faulthandler.enable()
import os
os.environ["RKLLM_LOG_LEVEL"] = "1"
import numpy as np
import time
import re
from typing import List, Dict, Any, Tuple
from rkllm_binding import *


class Qwen3RerankerTester:
    def __init__(self, model_path, library_path="./librkllmrt.so"):
        """
        Initialize the Qwen3 reranker model tester.

        Args:
            model_path: path to the model file (.rkllm format)
            library_path: path to the RKLLM shared library
        """
        self.model_path = model_path
        self.library_path = library_path
        self.runtime = None
        self.current_result = None

        # Format taken from the official README
        self.system_prompt = "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\"."

        # Candidate token IDs for "yes" and "no" (to be confirmed by actual testing)
        # These are common token IDs and may need adjustment in practice
        self.yes_token_candidates = [9693]
        self.no_token_candidates = [2152]

    def callback_function(self, result_ptr, userdata_ptr, state_enum):
        """
        Inference callback.

        Args:
            result_ptr: pointer to the result
            userdata_ptr: pointer to user data
            state_enum: state enum value
        """
        state = LLMCallState(state_enum)

        if state == LLMCallState.RKLLM_RUN_NORMAL:
            result = result_ptr.contents
            print(f"result: {result}")

            # Fetch the logits
            if result.logits.logits and result.logits.vocab_size > 0:
                vocab_size = result.logits.vocab_size
                num_tokens = result.logits.num_tokens

                print(f"Got logits: vocab_size={vocab_size}, num_tokens={num_tokens}")

                # Take the logits of the last token
                if num_tokens > 0:
                    last_token_logits = []
                    start_idx = (num_tokens - 1) * vocab_size
                    for i in range(vocab_size):
                        last_token_logits.append(result.logits.logits[start_idx + i])

                    self.current_result = {
                        'logits': last_token_logits,
                        'vocab_size': vocab_size,
                        'num_tokens': num_tokens
                    }

                    print(f"Last-token logits range: [{min(last_token_logits):.4f}, {max(last_token_logits):.4f}]")
            else:
                print("Warning: failed to obtain logits")

        elif state == LLMCallState.RKLLM_RUN_ERROR:
            print("An error occurred during inference")

    def find_best_yes_no_tokens(self, logits):
        """
        Find the most likely "yes" and "no" token IDs.

        Args:
            logits: array of logits over the vocabulary

        Returns:
            (yes_token_id, no_token_id, yes_logit, no_logit)
        """
        vocab_size = len(logits)

        # Find the yes candidate with the highest logit
        best_yes_id = None
        best_yes_logit = float('-inf')
        for token_id in self.yes_token_candidates:
            if token_id < vocab_size:
                if logits[token_id] > best_yes_logit:
                    best_yes_logit = logits[token_id]
                    best_yes_id = token_id

        # Find the no candidate with the highest logit
        best_no_id = None
        best_no_logit = float('-inf')
        for token_id in self.no_token_candidates:
            if token_id < vocab_size:
                if logits[token_id] > best_no_logit:
                    best_no_logit = logits[token_id]
                    best_no_id = token_id

        # If the predefined tokens are not found, fall back to a heuristic
        if best_yes_id is None or best_no_id is None:
            print("Warning: using a heuristic to locate the yes/no tokens")

            # Take the tokens with the highest logits
            sorted_indices = np.argsort(logits)[::-1]
            top_tokens = sorted_indices[:20]  # top 20 logits

            # Simple heuristic: assume a higher logit corresponds to "yes", a lower one to "no"
            if best_yes_id is None:
                best_yes_id = top_tokens[0]
                best_yes_logit = logits[best_yes_id]

            if best_no_id is None:
                # Pick a comparatively low but plausible logit as "no"
                best_no_id = top_tokens[min(10, len(top_tokens)-1)]
                best_no_logit = logits[best_no_id]

        return best_yes_id, best_no_id, best_yes_logit, best_no_logit

    def calculate_reranker_score(self, logits):
        """
        Compute the reranker score (softmax probability over the "yes" and "no" tokens).

        Args:
            logits: array of logits over the vocabulary

        Returns:
            Relevance score in [0, 1]; higher means more relevant
        """
        try:
            # Locate the logits of the yes and no tokens
            yes_id, no_id, yes_logit, no_logit = self.find_best_yes_no_tokens(logits)

            print(f"Yes token ID: {yes_id}, logit: {yes_logit:.4f}")
            print(f"No token ID: {no_id}, logit: {no_logit:.4f}")

            # Compute the softmax probabilities
            # Only the relative probability of the yes and no tokens matters
            max_logit = max(yes_logit, no_logit)
            yes_exp = np.exp(yes_logit - max_logit)  # numerical stability
            no_exp = np.exp(no_logit - max_logit)

            sum_exp = yes_exp + no_exp
            yes_prob = yes_exp / sum_exp
            no_prob = no_exp / sum_exp

            print(f"Yes probability: {yes_prob:.4f}, No probability: {no_prob:.4f}")

            # Return the yes probability as the relevance score
            return float(yes_prob)

        except Exception as e:
            print(f"Error while computing the reranker score: {e}")
            # Fall back to the simple heuristic
            return self.fallback_score_calculation(logits)

    def fallback_score_calculation(self, logits):
        """
        Fallback score computation (used when the yes/no tokens cannot be found).

        Args:
            logits: array of logits over the vocabulary

        Returns:
            Relevance score in [0, 1]
        """
        print("Using the fallback score computation")

        # Score from distributional features of the logits
        logits_array = np.array(logits)

        # Entropy of the softmax distribution
        softmax_probs = np.exp(logits_array - np.max(logits_array))
        softmax_probs = softmax_probs / np.sum(softmax_probs)

        # Lower entropy means the model is more confident (more relevant)
        entropy = -np.sum(softmax_probs * np.log(softmax_probs + 1e-10))
        max_entropy = np.log(len(logits))
        normalized_entropy = entropy / max_entropy

        # Convert to a relevance score (low entropy = high relevance)
        confidence_score = 1.0 - normalized_entropy

        # Combine with information from the maximum logit
        max_logit_score = (np.max(logits_array) - np.mean(logits_array)) / (np.std(logits_array) + 1e-8)
        max_logit_score = max(0, min(1, max_logit_score / 10))  # normalize

        # Combined score
        final_score = 0.7 * confidence_score + 0.3 * max_logit_score
        final_score = max(0.0, min(1.0, final_score))

        print(f"Fallback - entropy score: {confidence_score:.4f}, max-logit score: {max_logit_score:.4f}, final score: {final_score:.4f}")

        return final_score

    def init_model(self):
        """Initialize the model."""
        try:
            print(f"Initializing the RKLLM runtime, library path: {self.library_path}")
            self.runtime = RKLLMRuntime(self.library_path)

            print("Creating default parameters...")
            params = self.runtime.create_default_param()

            # Configure parameters
            params.model_path = self.model_path.encode('utf-8')
            params.max_context_len = 1024
            params.max_new_tokens = 1  # the reranker only needs to generate one token
            params.temperature = 0.0   # deterministic output
            params.top_k = 1           # greedy decoding
            params.top_p = 1.0         # disable nucleus sampling

            # Extended parameter configuration
            params.extend_param.base_domain_id = 1
            params.extend_param.embed_flash = 0
            params.extend_param.enabled_cpus_num = 4
            params.extend_param.enabled_cpus_mask = 0x0F

            print(f"Initializing model: {self.model_path}")
            self.runtime.init(params, self.callback_function)

            # Set the chat template
            self.runtime.set_chat_template(
                "",
                "",  # prefix
                ""   # suffix
            )

            print("Model initialized successfully!")

        except Exception as e:
            print(f"Model initialization failed: {e}")
            raise

    def format_rerank_input(self, instruction, query, document):
        """
        Format the reranking input (per the official README format).

        Args:
            instruction: task instruction
            query: query text
            document: document text

        Returns:
            The formatted input text
        """
        if instruction is None:
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'

        # Format from the official README
        formatted_input = f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}"
        return formatted_input

    def get_reranker_score(self, instruction, query, document):
        """
        Get the reranking score (via logits).

        Args:
            instruction: task instruction
            query: query text
            document: document text

        Returns:
            Relevance score in [0, 1]
        """
        try:
            # Format the input
            input_text = self.format_rerank_input(instruction, query, document)
            print(f"\nReranking input: {input_text[:200]}{'...' if len(input_text) > 200 else ''}")

            # Prepare the input
            rk_input = RKLLMInput()
            rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
            c_prompt = input_text.encode('utf-8')
            rk_input._union_data.prompt_input = c_prompt

            # Prepare inference parameters - use GET_LOGITS mode
            infer_params = RKLLMInferParam()
            infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LOGITS  # fetch logits
            infer_params.keep_history = 0

            # Clear the previous result
            self.current_result = None
            self.runtime.clear_kv_cache(False)

            # Run inference
            start_time = time.time()
            self.runtime.run(rk_input, infer_params)
            end_time = time.time()

            print(f"\nInference time: {end_time - start_time:.3f}s")

            if self.current_result and 'logits' in self.current_result:
                # Compute the reranker score from the captured logits
                logits = self.current_result['logits']
                score = self.calculate_reranker_score(logits)

                print(f"Computed score: {score:.4f}")
                return score
            else:
                print("Warning: no valid logits obtained, returning the default score")
                return 0.0

        except Exception as e:
            print(f"Error while scoring for reranking: {e}")
            import traceback
            traceback.print_exc()
            return 0.0

    def rerank_documents(self, query, documents, instruction=None):
        """
        Rerank a list of documents.

        Args:
            query: query text
            documents: list of documents
            instruction: optional task instruction

        Returns:
            List of (document, score) tuples sorted by relevance score, descending
        """
        print(f"\nReranking {len(documents)} documents")
        print(f"Query: {query}")

        if instruction:
            print(f"Instruction: {instruction}")

        scored_docs = []
        for i, doc in enumerate(documents):
            print(f"\n--- Processing document {i+1}/{len(documents)} ---")
            print(f"Document: {doc[:100]}{'...' if len(doc) > 100 else ''}")

            score = self.get_reranker_score(instruction, query, doc)
            scored_docs.append((doc, score))
            print(f"Score: {score:.4f}")

        # Sort by score, descending
        scored_docs.sort(key=lambda x: x[1], reverse=True)
        return scored_docs

    def test_basic_reranking(self):
        """Test basic reranking."""
        print("\n" + "="*60)
        print("Testing basic reranking")
        print("="*60)

        # Test query
        query = "What is the capital of China?"

        # Candidate documents (a mix of relevant and irrelevant)
        documents = [
            "Beijing is the capital city of China, located in northern China.",
            "The Great Wall of China is an ancient fortification built to protect Chinese states.",
            "Python is a high-level programming language used for software development.",
            "China's capital Beijing is home to over 21 million people.",
            "Machine learning is a subset of artificial intelligence that uses algorithms."
        ]

        # Run the reranking
        instruction = "Given a web search query, retrieve relevant passages that answer the query"
        ranked_docs = self.rerank_documents(query, documents, instruction)

        # Show the results
        print(f"\nReranking results (query: {query}):")
        print("-" * 80)
        for i, (doc, score) in enumerate(ranked_docs):
            print(f"Rank {i+1}: score {score:.4f}")
            print(f"Document: {doc}")
            print()

        return ranked_docs

    def test_multilingual_reranking(self):
        """Test multilingual reranking."""
        print("\n" + "="*60)
        print("Testing multilingual reranking")
        print("="*60)

        # Chinese query ("What is the capital of China?")
        query = "中国的首都是什么?"

        # Test data deliberately mixes Chinese and English documents
        documents = [
            "北京是中华人民共和国的首都,位于中国北部。",
            "上海是中国的经济中心,人口超过2400万。",
            "Python 是一种高级编程语言。",
            "The capital of China is Beijing.",
            "长城是中国古代的军事防御工程。"
        ]

        instruction = "Given a web search query, retrieve relevant passages that answer the query"
        ranked_docs = self.rerank_documents(query, documents, instruction)

        print(f"\nMultilingual reranking results (query: {query}):")
        print("-" * 80)
        for i, (doc, score) in enumerate(ranked_docs):
            print(f"Rank {i+1}: score {score:.4f}")
            print(f"Document: {doc}")
            print()

        return ranked_docs

    def test_domain_specific_reranking(self):
        """Test domain-specific reranking."""
        print("\n" + "="*60)
        print("Testing domain-specific reranking (technical documents)")
        print("="*60)

        query = "How to implement a neural network in Python?"

        documents = [
            "PyTorch is a deep learning framework that provides tensor computations with GPU acceleration.",
            "TensorFlow is an open-source machine learning library developed by Google.",
            "Neural networks are computing systems inspired by biological neural networks.",
            "Python is a programming language with simple syntax and powerful libraries.",
            "To implement a neural network in Python, you can use libraries like PyTorch or TensorFlow to define layers, loss functions, and optimization algorithms.",
            "Cooking recipes often require precise measurements and cooking times.",
            "Backpropagation is the algorithm used to train neural networks by computing gradients."
        ]

        # Use a custom instruction
        instruction = "Given a technical query and a document, determine if the document provides practical information for implementing the requested technical solution"

        ranked_docs = self.rerank_documents(query, documents, instruction)

        print(f"\nTechnical-document reranking results (query: {query}):")
        print("-" * 80)
        for i, (doc, score) in enumerate(ranked_docs):
            print(f"Rank {i+1}: score {score:.4f}")
            print(f"Document: {doc}")
            print()

        return ranked_docs

    def test_comparison_with_official_example(self):
        """Test against the official example."""
        print("\n" + "="*60)
        print("Testing against the official example")
        print("="*60)

        # Example from the official README
        task = 'Given a web search query, retrieve relevant passages that answer the query'

        queries = [
            "What is the capital of China?",
            "Explain gravity",
        ]

        documents = [
            "The capital of China is Beijing.",
            "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
        ]

        print("Testing the official example's query-document pairs:")
        for i, (query, doc) in enumerate(zip(queries, documents)):
            print(f"\n=== Query-document pair {i+1} ===")
            print(f"Query: {query}")
            print(f"Document: {doc}")

            score = self.get_reranker_score(task, query, doc)
            print(f"Relevance score: {score:.4f}")

    def cleanup(self):
        """Release resources."""
        if self.runtime:
            try:
                self.runtime.destroy()
                print("Model resources released")
            except Exception as e:
                print(f"Error while releasing resources: {e}")


def main():
    """Main entry point."""
    import argparse

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Qwen3-Reranker-0.6B inference test')
    parser.add_argument('model_path', help='path to the model file (.rkllm format)')
    parser.add_argument('--library_path', default="./librkllmrt.so", help='path to the RKLLM shared library (default: ./librkllmrt.so)')
    args = parser.parse_args()

    # Check that the files exist
    if not os.path.exists(args.model_path):
        print(f"Error: model file does not exist: {args.model_path}")
        print("Please make sure that:")
        print("1. The Qwen3-Reranker-0.6B model has been downloaded")
        print("2. The model has been converted to .rkllm format with rkllm-convert.py")
        return

    if not os.path.exists(args.library_path):
        print(f"Error: RKLLM library file does not exist: {args.library_path}")
        print("Please make sure librkllmrt.so is in the current directory or on LD_LIBRARY_PATH")
        return

    print("Qwen3-Reranker-0.6B inference test")
    print("=" * 60)
    print("Implementation based on the official README")
    print("=" * 60)

    # Create the tester
    tester = Qwen3RerankerTester(args.model_path, args.library_path)

    try:
        # Initialize the model
        tester.init_model()

        # Run the tests
        print("\nStarting the reranking tests...")

        # Comparison against the official example
        tester.test_comparison_with_official_example()

        # Basic reranking
        tester.test_basic_reranking()

        # Multilingual reranking
        tester.test_multilingual_reranking()

        # Domain-specific reranking
        tester.test_domain_specific_reranking()

        print("\n" + "="*60)
        print("All reranking tests finished!")
        print("="*60)

    except KeyboardInterrupt:
        print("\nTest interrupted by the user")
    except Exception as e:
        print(f"\nAn error occurred during testing: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Release resources
        tester.cleanup()


if __name__ == "__main__":
    main()
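Two closing notes, neither part of the upload. First, the yes/no candidate token IDs hard-coded above ([9693] and [2152]) are flagged in the source as unconfirmed; assuming the transformers package and access to the original Qwen/Qwen3-Reranker-0.6B checkpoint, they can be checked against the tokenizer, which is how the official usage example derives them:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B")
print("yes ->", tok.convert_tokens_to_ids("yes"))
print("no  ->", tok.convert_tokens_to_ids("no"))

Second, callback_function copies the last token's logits element by element in a Python loop over the whole vocabulary; assuming the C buffer stays valid for the duration of the callback, a single vectorized copy via numpy.ctypeslib is a common alternative:

import numpy as np

# Inside the callback, given result.logits with fields logits/vocab_size/num_tokens:
# arr = np.ctypeslib.as_array(result.logits.logits,
#                             shape=(result.logits.num_tokens, result.logits.vocab_size))
# last_token_logits = arr[-1].copy()  # copy out before the C buffer is reused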