Spaces:
Running
Running
Bug fix (second attempt): "Request too large" — TPM limit 6000, requested 6841; reduce max_tokens for the judge and truncate oversized prompts.
Browse files- ragbench_eval/judge.py +109 -9
ragbench_eval/judge.py
CHANGED
|
@@ -1,22 +1,63 @@
|
|
| 1 |
import json
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
from typing import Any, Dict, List, Tuple
|
| 4 |
|
| 5 |
from .llm import LLMClient
|
| 6 |
from .config import JUDGE_MODEL
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
def format_docs_with_keys(
|
| 10 |
documents_sentences: List[List[Tuple[str, str]]]
|
| 11 |
) -> str:
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
for doc in documents_sentences:
|
| 14 |
for key, sent in doc:
|
| 15 |
blocks.append(f"{key}: {sent}")
|
| 16 |
-
|
|
|
|
| 17 |
return "\n".join(blocks).strip()
|
| 18 |
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
class RAGJudge:
|
| 21 |
def __init__(self, prompt_path: str = "prompts/ragbench_judge_prompt.txt"):
|
| 22 |
self.client = LLMClient(JUDGE_MODEL)
|
|
@@ -28,12 +69,18 @@ class RAGJudge:
|
|
| 28 |
answer: str,
|
| 29 |
docs_sentences: List[List[Tuple[str, str]]],
|
| 30 |
) -> Dict[str, Any]:
|
|
|
|
| 31 |
docs_block = format_docs_with_keys(docs_sentences)
|
|
|
|
|
|
|
|
|
|
| 32 |
prompt = self.prompt_template.format(
|
| 33 |
documents=docs_block,
|
| 34 |
question=question,
|
| 35 |
answer=answer,
|
| 36 |
)
|
|
|
|
|
|
|
| 37 |
messages = [
|
| 38 |
{
|
| 39 |
"role": "system",
|
|
@@ -41,21 +88,74 @@ class RAGJudge:
|
|
| 41 |
},
|
| 42 |
{"role": "user", "content": prompt},
|
| 43 |
]
|
| 44 |
-
#raw = self.client.chat(messages, max_tokens=2048)
|
| 45 |
-
raw = self.client.chat(messages, max_tokens=512)
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
try:
|
| 48 |
data = json.loads(raw)
|
| 49 |
-
except json.JSONDecodeError
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
"relevance_explanation",
|
| 53 |
"all_relevant_sentence_keys",
|
| 54 |
"overall_supported_explanation",
|
| 55 |
"overall_supported",
|
| 56 |
"sentence_support_information",
|
| 57 |
"all_utilized_sentence_keys",
|
| 58 |
-
]
|
|
|
|
|
|
|
| 59 |
if key not in data:
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
return data
|
|
|
|
| 1 |
import json
|
| 2 |
+
import logging
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Any, Dict, List, Tuple
|
| 5 |
|
| 6 |
from .llm import LLMClient
|
| 7 |
from .config import JUDGE_MODEL
|
| 8 |
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
# Hard limits to stay under Groq's token constraints
|
| 12 |
+
# Rough rule-of-thumb: 4 characters ≈ 1 token in English.
|
| 13 |
+
MAX_DOC_CHARS = 8000 # limit for the "documents" block
|
| 14 |
+
MAX_PROMPT_CHARS = 12000 # limit for the full judge prompt
|
| 15 |
+
|
| 16 |
|
| 17 |
def format_docs_with_keys(
    documents_sentences: List[List[Tuple[str, str]]]
) -> str:
    """
    Flatten a nested list of (sentence_key, sentence_text) pairs into the
    `<key>: <text>` lines expected by the judge prompt, with a blank line
    separating consecutive documents.
    """
    rendered: List[str] = []
    for sentences in documents_sentences:
        rendered.extend(f"{key}: {sent}" for key, sent in sentences)
        # Blank separator line after each document.
        rendered.append("")
    return "\n".join(rendered).strip()
|
| 31 |
|
| 32 |
|
| 33 |
+
def _truncate(text: str, limit: int) -> str:
|
| 34 |
+
"""
|
| 35 |
+
Truncate a long string to at most `limit` characters, appending a marker
|
| 36 |
+
so the judge knows context was cut.
|
| 37 |
+
"""
|
| 38 |
+
if len(text) <= limit:
|
| 39 |
+
return text
|
| 40 |
+
return text[:limit] + "\n[TRUNCATED]\n"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _default_annotation(reason: str) -> Dict[str, Any]:
|
| 44 |
+
"""
|
| 45 |
+
Safe fallback annotation used when the judge LLM fails
|
| 46 |
+
(size/rate limit, invalid JSON, etc.).
|
| 47 |
+
"""
|
| 48 |
+
return {
|
| 49 |
+
"relevance_explanation": f"Automatic fallback: {reason}",
|
| 50 |
+
"all_relevant_sentence_keys": [],
|
| 51 |
+
"overall_supported_explanation": (
|
| 52 |
+
"No reliable judgement could be produced because the judge LLM "
|
| 53 |
+
"call failed or the output was not valid JSON."
|
| 54 |
+
),
|
| 55 |
+
"overall_supported": False,
|
| 56 |
+
"sentence_support_information": [],
|
| 57 |
+
"all_utilized_sentence_keys": [],
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
class RAGJudge:
|
| 62 |
def __init__(self, prompt_path: str = "prompts/ragbench_judge_prompt.txt"):
|
| 63 |
self.client = LLMClient(JUDGE_MODEL)
|
|
|
|
| 69 |
answer: str,
|
| 70 |
docs_sentences: List[List[Tuple[str, str]]],
|
| 71 |
) -> Dict[str, Any]:
|
| 72 |
+
# 1) Format docs and truncate to stay under token limits
|
| 73 |
docs_block = format_docs_with_keys(docs_sentences)
|
| 74 |
+
docs_block = _truncate(docs_block, MAX_DOC_CHARS)
|
| 75 |
+
|
| 76 |
+
# 2) Build prompt and also apply a global char-limit
|
| 77 |
prompt = self.prompt_template.format(
|
| 78 |
documents=docs_block,
|
| 79 |
question=question,
|
| 80 |
answer=answer,
|
| 81 |
)
|
| 82 |
+
prompt = _truncate(prompt, MAX_PROMPT_CHARS)
|
| 83 |
+
|
| 84 |
messages = [
|
| 85 |
{
|
| 86 |
"role": "system",
|
|
|
|
| 88 |
},
|
| 89 |
{"role": "user", "content": prompt},
|
| 90 |
]
|
|
|
|
|
|
|
| 91 |
|
| 92 |
+
# 3) Call LLM with smaller max_tokens and catch Groq 413 / rate limit errors
|
| 93 |
+
try:
|
| 94 |
+
raw = self.client.chat(messages, max_tokens=512)
|
| 95 |
+
except Exception as e:
|
| 96 |
+
msg = str(e)
|
| 97 |
+
if (
|
| 98 |
+
"rate_limit_exceeded" in msg
|
| 99 |
+
or "Request too large" in msg
|
| 100 |
+
or "413" in msg
|
| 101 |
+
):
|
| 102 |
+
logger.warning("Judge LLM call failed due to size/limit: %s", msg)
|
| 103 |
+
return _default_annotation(
|
| 104 |
+
"judge LLM request was too large or hit a rate limit."
|
| 105 |
+
)
|
| 106 |
+
# Other errors should still surface
|
| 107 |
+
raise
|
| 108 |
+
|
| 109 |
+
if not isinstance(raw, str):
|
| 110 |
+
raw = str(raw)
|
| 111 |
+
|
| 112 |
+
# 4) Parse JSON robustly
|
| 113 |
try:
|
| 114 |
data = json.loads(raw)
|
| 115 |
+
except json.JSONDecodeError:
|
| 116 |
+
# Try to salvage JSON between first '{' and last '}'
|
| 117 |
+
start = raw.find("{")
|
| 118 |
+
end = raw.rfind("}")
|
| 119 |
+
if start != -1 and end != -1 and end > start:
|
| 120 |
+
candidate = raw[start : end + 1]
|
| 121 |
+
try:
|
| 122 |
+
data = json.loads(candidate)
|
| 123 |
+
except json.JSONDecodeError as e2:
|
| 124 |
+
logger.error("Judge JSON parse error after salvage: %s", e2)
|
| 125 |
+
logger.debug(
|
| 126 |
+
"Raw judge output (first 500 chars): %s", raw[:500]
|
| 127 |
+
)
|
| 128 |
+
return _default_annotation("could not parse judge JSON output.")
|
| 129 |
+
else:
|
| 130 |
+
logger.error(
|
| 131 |
+
"Judge JSON parse error: could not find JSON object in output."
|
| 132 |
+
)
|
| 133 |
+
logger.debug(
|
| 134 |
+
"Raw judge output (first 500 chars): %s", raw[:500]
|
| 135 |
+
)
|
| 136 |
+
return _default_annotation("could not parse judge JSON output.")
|
| 137 |
+
|
| 138 |
+
# 5) Ensure required keys exist; fill missing with safe defaults
|
| 139 |
+
required_keys = [
|
| 140 |
"relevance_explanation",
|
| 141 |
"all_relevant_sentence_keys",
|
| 142 |
"overall_supported_explanation",
|
| 143 |
"overall_supported",
|
| 144 |
"sentence_support_information",
|
| 145 |
"all_utilized_sentence_keys",
|
| 146 |
+
]
|
| 147 |
+
|
| 148 |
+
for key in required_keys:
|
| 149 |
if key not in data:
|
| 150 |
+
if key in ("relevance_explanation", "overall_supported_explanation"):
|
| 151 |
+
data[key] = ""
|
| 152 |
+
elif key in (
|
| 153 |
+
"all_relevant_sentence_keys",
|
| 154 |
+
"sentence_support_information",
|
| 155 |
+
"all_utilized_sentence_keys",
|
| 156 |
+
):
|
| 157 |
+
data[key] = []
|
| 158 |
+
elif key == "overall_supported":
|
| 159 |
+
data[key] = False
|
| 160 |
+
|
| 161 |
return data
|