Create modules/llm.py
modules/llm.py (ADDED, +81 −0)
import os
import requests
from typing import Optional

from transformers import pipeline


class LLMClient:
    def __init__(self, use_inference_api: bool = False):
        self.use_inference_api = use_inference_api
        self.hf_token = os.getenv("HF_TOKEN", None)
        self._local_pipeline = {}

    # ---------- Inference API helpers ----------
    def _hf_headers(self):
        if not self.hf_token:
            raise RuntimeError("HF_TOKEN is not set for Inference API usage.")
        return {"Authorization": f"Bearer {self.hf_token}"}

    def _hf_textgen(self, model: str, prompt: str, max_new_tokens: int = 512, temperature: float = 0.3) -> str:
        url = f"https://api-inference.huggingface.co/models/{model}"
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "return_full_text": False
            }
        }
        r = requests.post(url, headers=self._hf_headers(), json=payload, timeout=120)
        r.raise_for_status()
        data = r.json()
        # Most text-generation models return a list of {"generated_text": ...}
        if isinstance(data, list) and len(data) > 0 and "generated_text" in data[0]:
            return data[0]["generated_text"]
        # Some models return a dict {"generated_text": ...}
        if isinstance(data, dict) and "generated_text" in data:
            return data["generated_text"]
        return str(data)

    def _hf_summarize(self, model: str, text: str, max_new_tokens: int = 256) -> str:
        # Many summarization models work with this generic endpoint as well
        return self._hf_textgen(model=model, prompt=text, max_new_tokens=max_new_tokens)

    # ---------- Local pipelines ----------
    def _get_local_pipeline(self, model: str, task: str):
        key = (model, task)
        if key in self._local_pipeline:
            return self._local_pipeline[key]
        # Works for both "text2text-generation" (e.g., Japanese T5) and "summarization"
        pipe = pipeline(task=task, model=model)
        self._local_pipeline[key] = pipe
        return pipe

    # ---------- Public methods ----------
    def summarize(self, text: str, model: str, max_words: int = 200) -> str:
        # Remote: use the Inference API when enabled and a model is given
        if self.use_inference_api and model:
            out = self._hf_summarize(model, text, max_new_tokens=max_words * 2)
            return out.strip()
        # Local: try a summarization pipeline first
        try:
            if "t5" in model.lower():
                # Many Japanese T5 models expect an instruction prefix
                pipe = self._get_local_pipeline(model, "text2text-generation")
                prompt = f"要約: {text[:6000]}"
                res = pipe(prompt, max_length=max_words * 2, do_sample=False)
                return res[0]["generated_text"].strip()
            else:
                pipe = self._get_local_pipeline(model, "summarization")
                res = pipe(text[:6000], max_length=max_words * 2, min_length=max_words // 2, do_sample=False)
                return res[0]["summary_text"].strip()
        except Exception:
            # Very robust fallback: return the first few lines of the text
            return "\n".join(text.split("\n")[:6])

    def generate(self, prompt: str, model: Optional[str] = None, max_new_tokens: int = 512) -> str:
        model = model or ""  # user may leave this empty
        if self.use_inference_api and model:
            return self._hf_textgen(model, prompt, max_new_tokens=max_new_tokens)
        # Local fallback: echo-style heuristic (no heavy local chat model required)
        return ""  # rely on rule-based extractors when no generation model is available
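For reference, a minimal usage sketch of this client (not part of the commit). It assumes the module is importable as modules.llm, that HF_TOKEN is set when use_inference_api=True, and that the model IDs shown are placeholders rather than choices made by this commit:

# Hypothetical usage sketch; model IDs are placeholders, not pinned by modules/llm.py.
from modules.llm import LLMClient

client = LLMClient(use_inference_api=False)   # local transformers pipelines
summary = client.summarize(
    text="長い記事テキスト ...",               # any long text to condense
    model="sonoisa/t5-base-japanese",          # placeholder Japanese T5 checkpoint
    max_words=120,
)
print(summary)

# With use_inference_api=True and HF_TOKEN exported, generate() calls the hosted Inference API.
api_client = LLMClient(use_inference_api=True)
answer = api_client.generate("Summarize: ...", model="google/flan-t5-large", max_new_tokens=256)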