# Corin1998 — "Create modules/llm.py" (commit cdeffb5, verified)
# Hugging Face file-viewer residue: raw | history | blame | 3.61 kB
import os
import json
import requests
from typing import Optional, List
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
class LLMClient:
    """Thin wrapper around Hugging Face models for summarization and generation.

    Text is produced either remotely through the Hugging Face Inference API
    (when ``use_inference_api`` is true) or locally through ``transformers``
    pipelines that are loaded on first use and cached per instance.
    """

    # Local pipelines only ever see this many leading characters of the input.
    _MAX_INPUT_CHARS = 6000
    # Number of leading input lines returned when local summarization fails.
    _FALLBACK_LINES = 6

    def __init__(self, use_inference_api: bool = False):
        """Create a client.

        Args:
            use_inference_api: Route requests to the Hugging Face Inference
                API instead of running local pipelines.
        """
        self.use_inference_api = use_inference_api
        # Only needed for Inference API calls; read once from the environment.
        self.hf_token = os.getenv("HF_TOKEN", None)
        # Cache of loaded local pipelines, keyed by (model, task).
        self._local_pipeline = {}

    # ---------- Inference API helpers ----------
    def _hf_headers(self):
        """Return the HTTP authorization headers for the Inference API.

        Raises:
            RuntimeError: If ``HF_TOKEN`` was not set in the environment.
        """
        if not self.hf_token:
            raise RuntimeError("HF_TOKEN is not set for Inference API usage.")
        return {"Authorization": f"Bearer {self.hf_token}"}

    @staticmethod
    def _extract_generated_text(data) -> str:
        """Pull the output text out of a decoded Inference API response.

        Accepts both the list shape (``[{"generated_text": ...}]``) and the
        dict shape (``{"generated_text": ...}``), and also the
        ``summary_text`` key that summarization models return. Any other
        payload (e.g. an error dict) is stringified so callers always get
        a ``str``.
        """
        record = data[0] if isinstance(data, list) and data else data
        if isinstance(record, dict):
            for key in ("generated_text", "summary_text"):
                if key in record:
                    return record[key]
        return str(data)

    def _hf_textgen(
        self,
        model: str,
        prompt: str = "",
        max_new_tokens: int = 512,
        temperature: float = 0.3,
    ) -> str:
        """Run ``prompt`` through ``model`` on the Hugging Face Inference API.

        Args:
            model: Hub model id; appended to the Inference API URL.
            prompt: Input text, sent as the request's ``inputs`` field.
            max_new_tokens: Generation length limit forwarded to the API.
            temperature: Sampling temperature forwarded to the API.

        Returns:
            The generated text, or the stringified response body when it has
            an unexpected shape.

        Raises:
            RuntimeError: If ``HF_TOKEN`` is not set.
            requests.RequestException: On network failure, timeout, or a
                non-2xx HTTP status.
        """
        url = f"https://api-inference.huggingface.co/models/{model}"
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                # Return only the continuation, not the echoed prompt.
                "return_full_text": False,
            },
        }
        r = requests.post(url, headers=self._hf_headers(), json=payload, timeout=120)
        r.raise_for_status()
        return self._extract_generated_text(r.json())

    def _hf_summarize(self, model: str, text: str, max_new_tokens: int = 256) -> str:
        """Summarize ``text`` with ``model`` via the Inference API.

        Many summarization models accept the same generic payload as text
        generation, so this simply delegates to ``_hf_textgen``.
        """
        return self._hf_textgen(model=model, prompt=text, max_new_tokens=max_new_tokens)

    # ---------- Local pipelines ----------
    def _get_local_pipeline(self, model: str, task: str):
        """Return the ``transformers`` pipeline for ``(model, task)``.

        The pipeline is built on first request and cached on the instance,
        so later calls with the same pair reuse the loaded model.
        """
        key = (model, task)
        pipe = self._local_pipeline.get(key)
        if pipe is None:
            # Construction is identical for every task (e.g. Japanese T5
            # models use "text2text-generation").
            pipe = pipeline(task=task, model=model)
            self._local_pipeline[key] = pipe
        return pipe

    # ---------- Public methods ----------
    def summarize(self, text: str, model: str, max_words: int = 200) -> str:
        """Summarize ``text`` with ``model``.

        With the Inference API enabled (and a model id given) the request is
        sent remotely. Otherwise a local pipeline is used:
        ``text2text-generation`` with a ``"要約: "`` ("summary:") prefix for
        T5-style models, ``summarization`` for everything else. If the local
        pipeline fails for any reason, the first few lines of ``text`` are
        returned instead.

        Args:
            text: Input document. Local pipelines only receive its first
                6000 characters.
            model: Model id.
            max_words: Length budget, passed on as ``max_length=max_words*2``
                (and ``min_length=max_words//2`` for ``summarization``). The
                unit is whatever the pipeline counts — presumably tokens.

        Returns:
            The summary with surrounding whitespace stripped; ``""`` for
            empty or whitespace-only input.

        Raises:
            RuntimeError, requests.RequestException: Only on the Inference
                API path; the local path never raises.
        """
        if not text or not text.strip():
            # Nothing to summarize; avoid loading a model for it.
            return ""
        if self.use_inference_api and model:
            out = self._hf_summarize(model, text, max_new_tokens=max_words * 2)
            return out.strip()

        snippet = text[: self._MAX_INPUT_CHARS]
        try:
            if "t5" in model.lower():
                # Many Japanese T5 models expect an instruction prefix.
                pipe = self._get_local_pipeline(model, "text2text-generation")
                res = pipe(f"要約: {snippet}", max_length=max_words * 2, do_sample=False)
                return res[0]["generated_text"].strip()
            pipe = self._get_local_pipeline(model, "summarization")
            res = pipe(
                snippet,
                max_length=max_words * 2,
                min_length=max_words // 2,
                do_sample=False,
            )
            return res[0]["summary_text"].strip()
        except Exception:
            # Deliberate best-effort fallback (model missing, load failure,
            # unexpected output shape): return the first few input lines.
            return "\n".join(text.split("\n")[: self._FALLBACK_LINES])

    def generate(self, prompt: str, model: Optional[str] = None, max_new_tokens: int = 512) -> str:
        """Generate text for ``prompt``.

        Generation is only available through the Inference API. Without it,
        or without a model id, this returns ``""`` and callers are expected
        to fall back to rule-based extractors.

        Raises:
            RuntimeError, requests.RequestException: From the Inference API
                call, when one is made.
        """
        model = model or ""  # user may leave the model empty
        if self.use_inference_api and model:
            return self._hf_textgen(model, prompt, max_new_tokens=max_new_tokens)
        # No local chat model is loaded; callers use rule-based extractors.
        return ""