import os
import logging
from typing import List, Dict, Any

from .config import LLM_PROVIDER

logger = logging.getLogger(__name__)

# Optional imports – we check availability at runtime
try:
    from groq import Groq  # type: ignore
except Exception:  # pragma: no cover
    Groq = None  # type: ignore

try:
    from huggingface_hub import InferenceClient  # type: ignore
except Exception:  # pragma: no cover
    InferenceClient = None  # type: ignore


class LLMClient:
    """
    Thin wrapper over either:
      - the Groq chat.completions API, or
      - the HuggingFace text-generation InferenceClient.

    Used for both generation and judge models.
    """

    def __init__(self, model_name: str, for_judge: bool = False):
        self.model_name = model_name
        self.for_judge = for_judge
        # Provider from config, with safe fallback
        self.provider = (LLM_PROVIDER or "groq").lower()
        logger.info(
            f"LLMClient initialized with provider={self.provider!r}, "
            f"model={self.model_name!r}, for_judge={self.for_judge}"
        )
        if self.provider not in ("groq", "hf"):
            raise ValueError(
                f"Unsupported provider {self.provider!r}; expected 'groq' or 'hf'."
            )

        # Initialize underlying client
        if self.provider == "groq":
            if Groq is None:
                raise RuntimeError(
                    "groq python package is not installed, but provider=groq was selected."
                )
            api_key = os.getenv("GROQ_API_KEY")
            if not api_key:
                raise RuntimeError(
                    "GROQ_API_KEY environment variable is required for Groq provider."
                )
            self.client = Groq(api_key=api_key)
        else:  # self.provider == "hf"
            if InferenceClient is None:
                raise RuntimeError(
                    "huggingface_hub python package is not installed, "
                    "but provider=hf was selected."
                )
            # For private models you can set HF_TOKEN as a secret in the Space
            token = os.getenv("HF_TOKEN") or None
            self.client = InferenceClient(model=self.model_name, token=token)

    # --------------------------------------------------------------------- #
    # Public API
    # --------------------------------------------------------------------- #
    def chat(
        self,
        messages: List[Dict[str, Any]],
        temperature: float = 0.0,
        max_tokens: int = 512,
    ) -> str:
        """
        Simple chat wrapper.

        Parameters
        ----------
        messages : list of {"role": "system"|"user"|"assistant", "content": str}
        temperature : float
        max_tokens : int

        Returns
        -------
        str
            The assistant's reply content.
        """
        if self.provider == "groq":
            # Use Groq chat.completions
            resp = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # Assume at least one choice; guard against a null content field
            return resp.choices[0].message.content or ""

        # provider == "hf" path: flatten chat into a single prompt
        prompt_parts: List[str] = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                prompt_parts.append(f"[SYSTEM] {content}")
            elif role == "assistant":
                prompt_parts.append(f"[ASSISTANT] {content}")
            else:
                prompt_parts.append(f"[USER] {content}")
        prompt = "\n".join(prompt_parts) + "\n[ASSISTANT]"
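        # Illustrative example of the flattened prompt for a two-message chat:
        #   [SYSTEM] You are a helpful assistant.
        #   [USER] Hello!
        #   [ASSISTANT]
        # The trailing [ASSISTANT] tag cues the model to continue as the reply.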

        # HuggingFace text generation. TGI backends reject temperature=0.0,
        # so only pass a temperature when sampling is actually enabled.
        out = self.client.text_generation(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature if temperature > 0.0 else None,
            do_sample=temperature > 0.0,
        )
        # InferenceClient.text_generation returns a plain string
        return out
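

if __name__ == "__main__":
    # Minimal usage sketch, not part of the wrapper itself. It assumes
    # LLM_PROVIDER resolves to "groq", that GROQ_API_KEY is set in the
    # environment, and uses "llama-3.1-8b-instant" purely as an example
    # model id; adjust both to your own setup.
    client = LLMClient(model_name="llama-3.1-8b-instant")
    reply = client.chat(
        messages=[
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Say hello in one sentence."},
        ],
        temperature=0.0,
        max_tokens=64,
    )
    print(reply)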