import os
import json
import requests
from typing import Optional, List
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
class LLMClient:
    def __init__(self, use_inference_api: bool = False):
        self.use_inference_api = use_inference_api
        self.hf_token = os.getenv("HF_TOKEN", None)
        self._local_pipeline = {}

    # ---------- Inference API helpers ----------
    def _hf_headers(self):
        if not self.hf_token:
            raise RuntimeError("HF_TOKEN is not set for Inference API usage.")
        return {"Authorization": f"Bearer {self.hf_token}"}
    def _hf_textgen(self, model: str, prompt: str, max_new_tokens: int = 512, temperature: float = 0.3) -> str:
        url = f"https://api-inference.huggingface.co/models/{model}"
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "return_full_text": False
            }
        }
        r = requests.post(url, headers=self._hf_headers(), json=payload, timeout=120)
        r.raise_for_status()
        data = r.json()
        # Most text-generation models return a list of {"generated_text": ...}
        if isinstance(data, list) and len(data) > 0 and "generated_text" in data[0]:
            return data[0]["generated_text"]
        # Some models return a dict {"generated_text": ...}
        if isinstance(data, dict) and "generated_text" in data:
            return data["generated_text"]
        return str(data)
    def _hf_summarize(self, model: str, text: str, max_new_tokens: int = 256) -> str:
        # Many summarization models work with this generic endpoint as well
        return self._hf_textgen(model=model, prompt=text, max_new_tokens=max_new_tokens)
    # ---------- Local pipelines ----------
    def _get_local_pipeline(self, model: str, task: str):
        key = (model, task)
        if key in self._local_pipeline:
            return self._local_pipeline[key]
        # The same construction works for "text2text-generation" (e.g., Japanese T5)
        # and "summarization" tasks
        pipe = pipeline(task=task, model=model)
        self._local_pipeline[key] = pipe
        return pipe
    # ---------- Public methods ----------
    def summarize(self, text: str, model: str, max_words: int = 200) -> str:
        if self.use_inference_api:
            out = self._hf_summarize(model=model, text=text, max_new_tokens=max_words * 2)
            return out.strip()
        # Local: try a summarization pipeline first
        try:
            if "t5" in model.lower():
                # Many Japanese T5 models expect an instruction prefix
                pipe = self._get_local_pipeline(model, "text2text-generation")
                prompt = f"要約: {text[:6000]}"
                res = pipe(prompt, max_length=max_words * 2, do_sample=False)
                return res[0]["generated_text"].strip()
            else:
                pipe = self._get_local_pipeline(model, "summarization")
                res = pipe(text[:6000], max_length=max_words * 2, min_length=max_words // 2, do_sample=False)
                return res[0]["summary_text"].strip()
        except Exception:
            # Very robust fallback: return the first few lines of the text
            return "\n".join(text.split("\n")[:6])
    def generate(self, prompt: str, model: Optional[str] = None, max_new_tokens: int = 512) -> str:
        model = model or ""  # user may leave this empty
        if self.use_inference_api and model:
            return self._hf_textgen(model, prompt, max_new_tokens=max_new_tokens)
        # Local fallback: echo-style heuristic (no heavy local chat model required)
        return ""  # We rely on rule-based extractors when no generation model is available