import os
import logging
from typing import List, Dict, Any

from .config import LLM_PROVIDER

logger = logging.getLogger(__name__)

# Optional imports – we check availability at runtime
try:
    from groq import Groq  # type: ignore
except Exception:  # pragma: no cover
    Groq = None  # type: ignore

try:
    from huggingface_hub import InferenceClient  # type: ignore
except Exception:  # pragma: no cover
    InferenceClient = None  # type: ignore


class LLMClient:
    """
    Thin wrapper over either:
    - Groq chat.completions API, or
    - HuggingFace text-generation InferenceClient.

    Used for both generation and judge models.
    """

    def __init__(self, model_name: str, for_judge: bool = False):
        self.model_name = model_name
        self.for_judge = for_judge

        # Provider from config, with safe fallback
        self.provider = (LLM_PROVIDER or "groq").lower()

        logger.info(
            f"LLMClient initialized with provider={self.provider!r}, "
            f"model={self.model_name!r}, for_judge={self.for_judge}"
        )

        if self.provider not in ("groq", "hf"):
            raise ValueError(
                f"Unsupported provider {self.provider!r}; expected 'groq' or 'hf'."
            )

        # Initialize underlying client
        if self.provider == "groq":
            if Groq is None:
                raise RuntimeError(
                    "groq python package is not installed, but provider=groq was selected."
                )
            api_key = os.getenv("GROQ_API_KEY")
            if not api_key:
                raise RuntimeError(
                    "GROQ_API_KEY environment variable is required for Groq provider."
                )
            self.client = Groq(api_key=api_key)

        else:  # self.provider == "hf"
            if InferenceClient is None:
                raise RuntimeError(
                    "huggingface_hub python package is not installed, "
                    "but provider=hf was selected."
                )
            # For private models you can set HF_TOKEN as a secret in the Space
            token = os.getenv("HF_TOKEN") or None
            self.client = InferenceClient(model=self.model_name, token=token)

    # --------------------------------------------------------------------- #
    # Public API
    # --------------------------------------------------------------------- #
    def chat(
        self,
        messages: List[Dict[str, Any]],
        temperature: float = 0.0,
        max_tokens: int = 512,
    ) -> str:
        """
        Simple chat wrapper.

        Parameters
        ----------
        messages : list of {"role": "system"|"user"|"assistant", "content": str}
        temperature : float
        max_tokens : int

        Returns
        -------
        str
            The assistant's reply content.
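
        Examples
        --------
        A minimal call; the model name here is illustrative and must be
        valid for the configured provider::

            client = LLMClient("llama-3.1-8b-instant")
            reply = client.chat([{"role": "user", "content": "Say hi."}])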
        """
        if self.provider == "groq":
            # Use Groq chat.completions
            resp = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # Defensive: guard against an empty choices list and a None
            # content field so the declared str return type always holds.
            if not resp.choices:
                raise RuntimeError("Groq chat completion returned no choices.")
            return resp.choices[0].message.content or ""

        # provider == "hf" path: flatten chat into a single prompt
        prompt_parts: List[str] = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                prompt_parts.append(f"[SYSTEM] {content}")
            elif role == "assistant":
                prompt_parts.append(f"[ASSISTANT] {content}")
            else:
                prompt_parts.append(f"[USER] {content}")

        prompt = "\n".join(prompt_parts) + "\n[ASSISTANT]"

        # HuggingFace text generation. Some backends reject temperature=0.0,
        # so only pass sampling parameters when sampling is requested.
        gen_kwargs: Dict[str, Any] = {"max_new_tokens": max_tokens}
        if temperature > 0.0:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["do_sample"] = True

        # With stream and details left at their defaults, text_generation
        # returns the generated text as a plain string.
        return self.client.text_generation(prompt, **gen_kwargs)
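

# --------------------------------------------------------------------- #
# Usage sketch (illustrative only): the model name below is hypothetical
# and the call requires GROQ_API_KEY (provider "groq") or a reachable HF
# endpoint, optionally with HF_TOKEN (provider "hf").
#
#     client = LLMClient("llama-3.1-8b-instant")
#     reply = client.chat(
#         [
#             {"role": "system", "content": "You are concise."},
#             {"role": "user", "content": "Say hi."},
#         ],
#         temperature=0.0,
#         max_tokens=64,
#     )
# --------------------------------------------------------------------- #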