"""
Analyzer Agent: Analyze individual papers using RAG context.
"""
import os
import json
import logging
import threading
from typing import Dict, Any, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from openai import AzureOpenAI
from utils.schemas import Analysis, Paper
from rag.retrieval import RAGRetriever
from utils.langfuse_client import observe
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class AnalyzerAgent:
"""Agent for analyzing individual papers with RAG."""
def __init__(
self,
rag_retriever: RAGRetriever,
        model: Optional[str] = None,
temperature: float = 0.0,
timeout: int = 60
):
"""
Initialize Analyzer Agent.
Args:
rag_retriever: RAGRetriever instance
            model: Azure OpenAI deployment name (defaults to the AZURE_OPENAI_DEPLOYMENT_NAME environment variable)
temperature: Temperature for generation (0 for deterministic)
timeout: Request timeout in seconds (default: 60)
"""
self.rag_retriever = rag_retriever
        # Resolve the deployment name at call time rather than at import time
        self.model = model or os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")
self.temperature = temperature
self.timeout = timeout
        # Circuit breaker for consecutive failures. analyze_paper runs in
        # worker threads, so guard the shared counter with a lock.
        self.failure_lock = threading.Lock()
        self.consecutive_failures = 0
        self.max_consecutive_failures = 2
        # Thread-safe token tracking for parallel processing
        self.token_lock = threading.Lock()
        self.batch_tokens = {"input": 0, "output": 0}
# Initialize Azure OpenAI client with timeout
self.client = AzureOpenAI(
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
timeout=timeout,
max_retries=2 # SDK-level retries
)
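        # Assumed required environment variables (inferred from the
        # os.getenv calls in this module): AZURE_OPENAI_API_KEY,
        # AZURE_OPENAI_API_VERSION, AZURE_OPENAI_ENDPOINT,
        # AZURE_OPENAI_DEPLOYMENT_NAME.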
def _create_analysis_prompt(
self,
paper: Paper,
context: str
) -> str:
"""Create prompt for paper analysis."""
prompt = f"""You are a research paper analyst. Analyze the following paper using ONLY the provided context.
Paper Title: {paper.title}
Authors: {", ".join(paper.authors)}
Abstract: {paper.abstract}
Context from Paper:
{context}
Analyze this paper and extract the following information. You MUST ground every statement in the provided context.
Provide your analysis in the following JSON format:
{{
"methodology": "Description of research methodology used",
"key_findings": ["Finding 1", "Finding 2", "Finding 3"],
"conclusions": "Main conclusions of the paper",
"limitations": ["Limitation 1", "Limitation 2"],
"main_contributions": ["Contribution 1", "Contribution 2"],
"citations": ["Reference 1", "Reference 2", "Reference 3"]
}}
CRITICAL JSON FORMATTING RULES:
- Use ONLY information from the provided context
- Be specific and cite which parts of the context support your statements
- For string fields (methodology, conclusions): use "Not available in provided context" if information is missing
- For array fields (key_findings, limitations, main_contributions, citations):
* MUST be flat arrays of strings ONLY: ["item1", "item2"]
* If no information available, use empty array: []
* NEVER nest arrays: [[], "text"] or [["nested"]] are INVALID
* NEVER include null, empty strings, or non-string values
* Each array element must be a non-empty string
- ALWAYS maintain correct JSON types: strings for text fields, flat arrays of strings for list fields
"""
return prompt
def _normalize_analysis_response(self, data: dict) -> dict:
"""
Normalize LLM response to ensure list fields contain only strings.
Handles multiple edge cases:
- Strings converted to single-element lists
- Nested lists flattened recursively
- None values filtered out
- Empty strings removed
- Mixed types converted to strings
This prevents Pydantic validation errors from malformed LLM responses.
Args:
data: Raw analysis data dictionary from LLM
Returns:
Normalized dictionary with correct types for all fields
"""
list_fields = ['key_findings', 'limitations', 'main_contributions', 'citations']
def flatten_and_clean(value):
"""Recursively flatten nested lists and clean values."""
if isinstance(value, str):
# Single string - return as list if non-empty
return [value.strip()] if value.strip() else []
elif isinstance(value, list):
# List - recursively flatten and filter
cleaned = []
for item in value:
if isinstance(item, str):
# Add non-empty strings
if item.strip():
cleaned.append(item.strip())
elif isinstance(item, list):
# Recursively flatten nested lists
cleaned.extend(flatten_and_clean(item))
elif item is not None and str(item).strip():
# Convert non-None, non-string values to strings
cleaned.append(str(item).strip())
return cleaned
elif value is not None:
# Non-list, non-string, non-None - stringify
str_value = str(value).strip()
return [str_value] if str_value else []
else:
# None value
return []
for field in list_fields:
if field not in data:
# Missing field - set to empty list
data[field] = []
logger.debug(f"Field '{field}' missing in LLM response, set to []")
else:
original_value = data[field]
normalized_value = flatten_and_clean(original_value)
# Log if normalization changed the structure
if original_value != normalized_value:
logger.warning(
f"Normalized '{field}': {type(original_value).__name__} "
f"with {len(original_value) if isinstance(original_value, list) else 1} items "
f"-> list with {len(normalized_value)} items"
)
data[field] = normalized_value
return data
def analyze_paper(
self,
paper: Paper,
top_k_chunks: int = 10
) -> Analysis:
"""
        Analyze a single paper behind a circuit breaker; transient API errors are retried at the SDK level (max_retries=2).
Args:
paper: Paper object
top_k_chunks: Number of chunks to retrieve for context
Returns:
Analysis object
"""
        # Circuit breaker: skip if too many consecutive failures
        with self.failure_lock:
            tripped = self.consecutive_failures >= self.max_consecutive_failures
            failures = self.consecutive_failures
        if tripped:
            logger.warning(
                f"Circuit breaker active: Skipping {paper.arxiv_id} after "
                f"{failures} consecutive failures"
            )
            raise RuntimeError("Circuit breaker active - too many consecutive failures")
try:
logger.info(f"Analyzing paper: {paper.arxiv_id}")
# Retrieve relevant chunks for this paper
# Use broad queries to get comprehensive coverage
queries = [
"methodology approach methods",
"results findings experiments",
"conclusions contributions implications",
"limitations future work challenges"
]
all_chunks = []
chunk_ids = set()
for query in queries:
result = self.rag_retriever.retrieve(
query=query,
top_k=top_k_chunks // len(queries),
paper_ids=[paper.arxiv_id]
)
for chunk in result["chunks"]:
if chunk["chunk_id"] not in chunk_ids:
all_chunks.append(chunk)
chunk_ids.add(chunk["chunk_id"])
# Format context
context = self.rag_retriever.format_context(all_chunks)
# Create prompt
prompt = self._create_analysis_prompt(paper, context)
            # Call Azure OpenAI with temperature=0 and an output cap.
            # Note: json_object response format requires the word "JSON"
            # to appear in the messages, which the prompt above satisfies.
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a research paper analyst. Provide accurate, grounded analysis based only on the provided context."},
{"role": "user", "content": prompt}
],
temperature=self.temperature,
max_tokens=1500, # Limit output to prevent slow responses
response_format={"type": "json_object"}
)
# Track token usage (thread-safe)
if hasattr(response, 'usage') and response.usage:
with self.token_lock:
self.batch_tokens["input"] += response.usage.prompt_tokens
self.batch_tokens["output"] += response.usage.completion_tokens
logger.info(f"Analyzer token usage for {paper.arxiv_id}: "
f"{response.usage.prompt_tokens} input, "
f"{response.usage.completion_tokens} output")
# Parse response
analysis_data = json.loads(response.choices[0].message.content)
# Normalize response to ensure list fields are lists (not strings)
analysis_data = self._normalize_analysis_response(analysis_data)
# Calculate confidence based on context completeness
confidence = min(len(all_chunks) / top_k_chunks, 1.0)
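            # Note: the integer division in top_k above requests at most
            # (top_k_chunks // len(queries)) * len(queries) chunks, so with
            # the defaults (10 // 4 = 2 per query) confidence caps at 0.8.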
# Create Analysis object
analysis = Analysis(
paper_id=paper.arxiv_id,
methodology=analysis_data.get("methodology", "Not available"),
key_findings=analysis_data.get("key_findings", []),
conclusions=analysis_data.get("conclusions", "Not available"),
limitations=analysis_data.get("limitations", []),
citations=analysis_data.get("citations", []),
main_contributions=analysis_data.get("main_contributions", []),
confidence_score=confidence
)
logger.info(f"Analysis completed for {paper.arxiv_id} with confidence {confidence:.2f}")
            # Reset circuit breaker on success
            with self.failure_lock:
                self.consecutive_failures = 0
return analysis
        except Exception as e:
            # Increment circuit breaker on failure (under the lock, since
            # several worker threads may fail concurrently)
            with self.failure_lock:
                self.consecutive_failures += 1
                failures = self.consecutive_failures
            logger.error(
                f"Error analyzing paper {paper.arxiv_id} ({str(e)}). "
                f"Consecutive failures: {failures}"
            )
            # Return a minimal placeholder analysis; confidence_score=0.0
            # signals the failure to downstream consumers
return Analysis(
paper_id=paper.arxiv_id,
methodology="Analysis failed",
key_findings=[],
conclusions="Analysis failed",
limitations=[],
citations=[],
main_contributions=[],
confidence_score=0.0
)
@observe(name="analyzer_agent_run", as_type="generation")
def run(self, state: Dict[str, Any]) -> Dict[str, Any]:
"""
Execute analyzer agent with parallel processing.
Args:
state: Current agent state
Returns:
Updated state with analyses
"""
try:
logger.info("=== Analyzer Agent Started ===")
papers = state.get("papers", [])
if not papers:
error_msg = "No papers to analyze"
logger.error(error_msg)
state["errors"].append(error_msg)
return state
            # Reset circuit breaker and token counters for the new batch
            with self.failure_lock:
                self.consecutive_failures = 0
            logger.info("Circuit breaker reset for new batch")
            with self.token_lock:
                self.batch_tokens = {"input": 0, "output": 0}
            # Analyze papers in parallel (cap at 4 workers to balance throughput against API rate limits)
max_workers = min(4, len(papers))
logger.info(f"Analyzing {len(papers)} papers with {max_workers} parallel workers")
analyses = []
failed_papers = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all papers for analysis
future_to_paper = {
executor.submit(self.analyze_paper, paper): paper
for paper in papers
}
# Collect results as they complete
for future in as_completed(future_to_paper):
paper = future_to_paper[future]
try:
analysis = future.result()
analyses.append(analysis)
logger.info(f"Successfully analyzed paper {paper.arxiv_id}")
except Exception as e:
error_msg = f"Failed to analyze paper {paper.arxiv_id}: {str(e)}"
logger.error(error_msg)
state["errors"].append(error_msg)
failed_papers.append(paper.arxiv_id)
            # Accumulate batch tokens into state (all workers have joined, so no lock is needed here)
state["token_usage"]["input_tokens"] += self.batch_tokens["input"]
state["token_usage"]["output_tokens"] += self.batch_tokens["output"]
logger.info(f"Total analyzer batch tokens: {self.batch_tokens['input']} input, "
f"{self.batch_tokens['output']} output")
if not analyses:
error_msg = "Failed to analyze any papers"
logger.error(error_msg)
state["errors"].append(error_msg)
return state
if failed_papers:
logger.warning(f"Failed to analyze {len(failed_papers)} papers: {failed_papers}")
state["analyses"] = analyses
logger.info(f"=== Analyzer Agent Completed: {len(analyses)}/{len(papers)} papers analyzed ===")
return state
except Exception as e:
error_msg = f"Analyzer Agent error: {str(e)}"
logger.error(error_msg)
state["errors"].append(error_msg)
return state
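if __name__ == "__main__":
    # Minimal wiring sketch for local experimentation, not part of the
    # pipeline. It assumes RAGRetriever can be constructed without
    # arguments and that Paper's required fields match those referenced
    # above (arxiv_id, title, authors, abstract); adjust to your schemas.
    retriever = RAGRetriever()
    agent = AnalyzerAgent(rag_retriever=retriever)
    demo_paper = Paper(
        arxiv_id="0000.00000",  # placeholder identifier
        title="An Example Paper",
        authors=["A. Author"],
        abstract="A short example abstract.",
    )
    state = {
        "papers": [demo_paper],
        "errors": [],
        "token_usage": {"input_tokens": 0, "output_tokens": 0},
    }
    result = agent.run(state)
    print(f"Analyses produced: {len(result.get('analyses', []))}")
    print(f"Errors: {result['errors']}")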