import os
import logging
from typing import List, Dict, Any

from .config import LLM_PROVIDER

logger = logging.getLogger(__name__)

# Optional imports; we check availability at runtime
try:
    from groq import Groq  # type: ignore
except Exception:  # pragma: no cover
    Groq = None  # type: ignore

try:
    from huggingface_hub import InferenceClient  # type: ignore
except Exception:  # pragma: no cover
    InferenceClient = None  # type: ignore

class LLMClient:
    """
    Thin wrapper over either:
      - the Groq chat.completions API, or
      - the HuggingFace text-generation InferenceClient.

    Used for both generation and judge models.
    """

    def __init__(self, model_name: str, for_judge: bool = False):
        self.model_name = model_name
        self.for_judge = for_judge

        # Provider from config, with safe fallback
        self.provider = (LLM_PROVIDER or "groq").lower()
        logger.info(
            f"LLMClient initialized with provider={self.provider!r}, "
            f"model={self.model_name!r}, for_judge={self.for_judge}"
        )
        if self.provider not in ("groq", "hf"):
            raise ValueError(f"Unsupported provider {self.provider!r}")
        # Initialize underlying client
        if self.provider == "groq":
            if Groq is None:
                raise RuntimeError(
                    "groq python package is not installed, but provider=groq was selected."
                )
            api_key = os.getenv("GROQ_API_KEY")
            if not api_key:
                raise RuntimeError(
                    "GROQ_API_KEY environment variable is required for the Groq provider."
                )
            self.client = Groq(api_key=api_key)
        else:  # self.provider == "hf"
            if InferenceClient is None:
                raise RuntimeError(
                    "huggingface_hub python package is not installed, "
                    "but provider=hf was selected."
                )
            # For private models you can set HF_TOKEN as a secret in the Space
            token = os.getenv("HF_TOKEN") or None
            self.client = InferenceClient(model=self.model_name, token=token)
    # --------------------------------------------------------------------- #
    # Public API
    # --------------------------------------------------------------------- #
    def chat(
        self,
        messages: List[Dict[str, Any]],
        temperature: float = 0.0,
        max_tokens: int = 512,
    ) -> str:
| """ | |
| Simple chat wrapper. | |
| Parameters | |
| ---------- | |
| messages : list of {"role": "system"|"user"|"assistant", "content": str} | |
| temperature : float | |
| max_tokens : int | |
| Returns | |
| ------- | |
| str | |
| The assistant's reply content. | |
| """ | |
        if self.provider == "groq":
            # Use Groq chat.completions
            resp = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # Assume at least one choice
            return resp.choices[0].message.content

        # provider == "hf" path: flatten chat into a single prompt
        prompt_parts: List[str] = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                prompt_parts.append(f"[SYSTEM] {content}")
            elif role == "assistant":
                prompt_parts.append(f"[ASSISTANT] {content}")
            else:
                prompt_parts.append(f"[USER] {content}")
        prompt = "\n".join(prompt_parts) + "\n[ASSISTANT]"
        # HuggingFace text generation. Some hosted backends reject
        # temperature=0.0 outright, so sampling parameters are only
        # forwarded when sampling is actually requested.
        gen_kwargs: Dict[str, Any] = {"max_new_tokens": max_tokens}
        if temperature > 0.0:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["do_sample"] = True
        out = self.client.text_generation(prompt, **gen_kwargs)
        # InferenceClient.text_generation returns a plain string
        # (when stream=False and details=False, the defaults)
        return out
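

# --------------------------------------------------------------------- #
# Usage sketch (illustrative, not part of the module's public surface).
# Assumes LLM_PROVIDER resolves to "groq" in .config and GROQ_API_KEY is
# set in the environment; the model name below is a placeholder for
# whatever the provider actually serves. Run it as a module
# (python -m <package>.<this_module>) so the relative import resolves.
# --------------------------------------------------------------------- #
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    client = LLMClient(model_name="llama-3.1-8b-instant")  # placeholder model
    reply = client.chat(
        [
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Say hello in one sentence."},
        ],
        temperature=0.0,
        max_tokens=64,
    )
    print(reply)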