Renangi committed
Commit c2f1723 · 1 Parent(s): 714eb4c

Implement LLMClient with proper client initialization for Groq/HF

Files changed (1)
  1. ragbench_eval/llm.py +97 -25
ragbench_eval/llm.py CHANGED
@@ -1,52 +1,124 @@
 import os
 import logging
+from typing import List, Dict, Any
+
 from .config import LLM_PROVIDER
 
 logger = logging.getLogger(__name__)
 
-from .config import LLM_PROVIDER, GEN_MODEL, JUDGE_MODEL
+# Optional imports; we check availability at runtime
+try:
+    from groq import Groq  # type: ignore
+except Exception:  # pragma: no cover
+    Groq = None  # type: ignore
 
-from typing import List, Dict
-from .config import LLM_PROVIDER, HF_TOKEN, GROQ_API_KEY
+try:
+    from huggingface_hub import InferenceClient  # type: ignore
+except Exception:  # pragma: no cover
+    InferenceClient = None  # type: ignore
 
-from huggingface_hub import InferenceClient
-from groq import Groq
 
 class LLMClient:
+    """
+    Thin wrapper over either:
+      - the Groq chat.completions API, or
+      - the HuggingFace text-generation InferenceClient.
+
+    Used for both generation and judge models.
+    """
+
     def __init__(self, model_name: str, for_judge: bool = False):
         self.model_name = model_name
         self.for_judge = for_judge
 
-        # Provider from config, with fallback to "groq"
+        # Provider from config, with safe fallback
         self.provider = (LLM_PROVIDER or "groq").lower()
 
         logger.info(
-            f"LLMClient initialized with provider={self.provider!r}, model={model_name!r}"
+            f"LLMClient initialized with provider={self.provider!r}, "
+            f"model={self.model_name!r}, for_judge={self.for_judge}"
         )
 
         if self.provider not in ("groq", "hf"):
             raise ValueError(f"Unsupported provider {self.provider}")
 
-    def chat(self, messages: List[Dict[str, str]], max_tokens: int = 1024) -> str:
-        if self.provider == "hf":
-            prompt = ""
-            for m in messages:
-                role = m.get("role", "user")
-                content = m.get("content", "")
-                prompt += f"[{role.upper()}]\n{content}\n"
-            out = self.client.text_generation(
-                prompt,
-                model=self.model,
-                max_new_tokens=max_tokens,
-                temperature=0.2,
-                do_sample=False,
-            )
-            return out
-        else:
+        # Initialize underlying client
+        if self.provider == "groq":
+            if Groq is None:
+                raise RuntimeError(
+                    "groq python package is not installed, but provider=groq was selected."
+                )
+            api_key = os.getenv("GROQ_API_KEY")
+            if not api_key:
+                raise RuntimeError(
+                    "GROQ_API_KEY environment variable is required for the Groq provider."
+                )
+            self.client = Groq(api_key=api_key)
+
+        else:  # self.provider == "hf"
+            if InferenceClient is None:
+                raise RuntimeError(
+                    "huggingface_hub python package is not installed, "
+                    "but provider=hf was selected."
+                )
+            # For private models you can set HF_TOKEN as a secret in the Space
+            token = os.getenv("HF_TOKEN") or None
+            self.client = InferenceClient(model=self.model_name, token=token)
+
+    # --------------------------------------------------------------------- #
+    # Public API
+    # --------------------------------------------------------------------- #
+    def chat(
+        self,
+        messages: List[Dict[str, Any]],
+        temperature: float = 0.0,
+        max_tokens: int = 512,
+    ) -> str:
+        """
+        Simple chat wrapper.
+
+        Parameters
+        ----------
+        messages : list of {"role": "system"|"user"|"assistant", "content": str}
+        temperature : float
+        max_tokens : int
+
+        Returns
+        -------
+        str
+            The assistant's reply content.
+        """
+        if self.provider == "groq":
+            # Use Groq chat.completions
             resp = self.client.chat.completions.create(
-                model=self.model,
+                model=self.model_name,
                 messages=messages,
+                temperature=temperature,
                 max_tokens=max_tokens,
-                temperature=0.2,
             )
+            # Assume at least one choice
             return resp.choices[0].message.content
+
+        # provider == "hf" path: flatten chat into a single prompt
+        prompt_parts: List[str] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                prompt_parts.append(f"[SYSTEM] {content}")
+            elif role == "assistant":
+                prompt_parts.append(f"[ASSISTANT] {content}")
+            else:
+                prompt_parts.append(f"[USER] {content}")
+
+        prompt = "\n".join(prompt_parts) + "\n[ASSISTANT]"
+
+        # HuggingFace text generation
+        out = self.client.text_generation(
+            prompt,
+            max_new_tokens=max_tokens,
+            temperature=temperature,
+            do_sample=temperature > 0.0,
+        )
+        # InferenceClient.text_generation returns a plain string
+        return out
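
For orientation, a minimal usage sketch of the wrapper added in this commit follows. The model names, prompt text, and generator/judge split are illustrative assumptions, not part of the diff; the code itself only requires that LLM_PROVIDER is set in ragbench_eval/config.py and that GROQ_API_KEY (for provider=groq) or optionally HF_TOKEN (for provider=hf) is available in the environment, e.g. as a Space secret.

    # Usage sketch (illustrative; model names are placeholders, not from this repo).
    # Assumes LLM_PROVIDER and the relevant API key/secret are already configured.
    from ragbench_eval.llm import LLMClient

    gen = LLMClient(model_name="llama-3.1-8b-instant")                       # generation model
    judge = LLMClient(model_name="llama-3.3-70b-versatile", for_judge=True)  # judge model

    answer = gen.chat(
        messages=[
            {"role": "system", "content": "Answer using only the provided context."},
            {"role": "user", "content": "Context: ...\n\nQuestion: ..."},
        ],
        temperature=0.0,
        max_tokens=512,
    )
    print(answer)

With provider=hf, the same chat() call flattens the messages into a single [SYSTEM]/[USER]/[ASSISTANT] prompt and routes it through InferenceClient.text_generation, using greedy decoding (do_sample=False) whenever temperature is 0.0.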