Corin1998 committed on
Commit
cdeffb5
·
verified ·
1 Parent(s): b32e168

Create modules/llm.py

Files changed (1)
  1. modules/llm.py +81 -0
modules/llm.py ADDED
@@ -0,0 +1,81 @@
+ import os
+ import requests
+ from typing import Optional
+
+ from transformers import pipeline
+
+ class LLMClient:
+     def __init__(self, use_inference_api: bool = False):
+         self.use_inference_api = use_inference_api
+         self.hf_token = os.getenv("HF_TOKEN", None)
+         self._local_pipeline = {}
+
+     # ---------- Inference API helpers ----------
+     def _hf_headers(self):
+         if not self.hf_token:
+             raise RuntimeError("HF_TOKEN is not set for Inference API usage.")
+         return {"Authorization": f"Bearer {self.hf_token}"}
+
+     def _hf_textgen(self, model: str, prompt: str, max_new_tokens: int = 512, temperature: float = 0.3) -> str:
+         url = f"https://api-inference.huggingface.co/models/{model}"
+         payload = {
+             "inputs": prompt,
+             "parameters": {
+                 "max_new_tokens": max_new_tokens,
+                 "temperature": temperature,
+                 "return_full_text": False
+             }
+         }
+         r = requests.post(url, headers=self._hf_headers(), json=payload, timeout=120)
+         r.raise_for_status()
+         data = r.json()
+         # Most text-generation models return a list: [{"generated_text": ...}]
+         if isinstance(data, list) and len(data) > 0 and "generated_text" in data[0]:
+             return data[0]["generated_text"]
+
+         # Some models return dict {"generated_text": ...}
+         if isinstance(data, dict) and "generated_text" in data:
+             return data["generated_text"]
+         return str(data)
+
+     def _hf_summarize(self, model: str, text: str, max_new_tokens: int = 256) -> str:
+         # Many summarization models work with this generic endpoint as well
+         return self._hf_textgen(model=model, prompt=text, max_new_tokens=max_new_tokens)
+
+     # ---------- Local pipelines ----------
+     def _get_local_pipeline(self, model: str, task: str):
+         key = (model, task)
+         if key in self._local_pipeline:
+             return self._local_pipeline[key]
+         # Same construction for every task (e.g. "text2text-generation" for Japanese T5)
+         pipe = pipeline(task=task, model=model)
+         self._local_pipeline[key] = pipe
+         return pipe
+
+     # ---------- Public methods ----------
+     def summarize(self, text: str, model: str, max_words: int = 200) -> str:
+         if self.use_inference_api:
+             out = self._hf_summarize(model, text)
+             return out.strip()
+         # Local: try summarization pipeline first
+         try:
+             if "t5" in model.lower():
+                 # Many Japanese T5 models expect an instruction prefix
+                 pipe = self._get_local_pipeline(model, "text2text-generation")
+                 prompt = f"要約: {text[:6000]}"  # "要約" means "summary"
+                 res = pipe(prompt, max_length=max_words * 2, do_sample=False)
+                 return res[0]["generated_text"].strip()
+             else:
+                 pipe = self._get_local_pipeline(model, "summarization")
+                 res = pipe(text[:6000], max_length=max_words * 2, min_length=max_words // 2, do_sample=False)
+                 return res[0]["summary_text"].strip()
+         except Exception:
+             # Very robust fallback: return the first few lines of the input
+             return "\n".join(text.split("\n")[:6])
+
+     def generate(self, prompt: str, model: Optional[str] = None, max_new_tokens: int = 512) -> str:
+         model = model or ""  # user may leave empty
+         if self.use_inference_api and model:
+             return self._hf_textgen(model, prompt, max_new_tokens=max_new_tokens)
+         # Local fallback: echo-style heuristic (no heavy local chat model required)
+         return ""  # We rely on rule-based extractors when no gen model is available