# LLM4HEP / supervisor_coder.py
# (repository header — author ho22joshua, initial commit cfcbbc8)
import os
import re
import sys
import openai
import subprocess
import argparse
import numpy as np
import yaml
from datetime import datetime
# Resolve filesystem anchors for this run.
base_dir = os.getcwd()
# Find the project root directory (where prompts/ directory is located).
# When running from a results directory, we need to go up one level.
project_root = os.path.dirname(base_dir) if 'results' in base_dir else base_dir

# Command-line interface; the file lists are forwarded from Snakemake.
parser = argparse.ArgumentParser(description='supervisor/coder')
add_arg = parser.add_argument
add_arg('--prompt', help='prompt name')
add_arg('--out_dir', help='output directory')
add_arg('--input-files', nargs='*', help='List of input files from snakemake')
add_arg('--output-files', nargs='*', help='List of output files from snakemake')
add_arg('--config', help='Path to config file', default=os.path.join(project_root, 'config.yml'))
add_arg('--max_iterations', type=int, default=3, help='Maximum coder iterations for faster runs')
args = parser.parse_args()

# Read supervisor and coder model names from config.yml.
# Both keys are mandatory; abort with a clear message when either is missing.
config_path = args.config
if os.path.exists(config_path):
    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        supervisor = config.get('supervisor')
        if not supervisor:
            print(f"ERROR: 'supervisor' not found in {config_path}")
            sys.exit(1)
        coder = config.get('coder')
        if not coder:
            print(f"ERROR: 'coder' not found in {config_path}")
            sys.exit(1)
    except Exception as e:
        print(f"ERROR: Could not read {config_path}: {e}")
        sys.exit(1)
else:
    print(f"ERROR: config file not found at {config_path}")
    sys.exit(1)

# Optional sampling temperature; None lets the API use the model default.
temperature = config.get('temperature', None)
def get_code(output):  # convert Markdown to Python
    """Extract source code from a Markdown fenced code block.

    Prefers a ```python fence; falls back to the first generic ``` fence
    (coders sometimes omit the language tag, and previously that meant the
    raw Markdown — fences included — was written out and executed).
    Returns `output` unchanged when no fence is found.
    """
    match = re.search(r'```python\s*(.*?)\s*```', output, re.DOTALL)
    if match:
        return match.group(1).strip()
    # Fallback: a fenced block with any (or no) language tag.
    match = re.search(r'```\w*\s*(.*?)\s*```', output, re.DOTALL)
    if match:
        return match.group(1).strip()
    return output
# Per-run output layout: generated code, logs, and prompt pairs all live
# under out_dir so concurrent runs do not collide.
name = args.prompt
out_dir = args.out_dir
os.makedirs(out_dir, exist_ok=True)
log_dir = os.path.join(out_dir, 'logs')
os.makedirs(log_dir, exist_ok=True)
code_dir = os.path.join(out_dir, 'generated_code')
os.makedirs(code_dir, exist_ok=True)
prompt_pair_dir = os.path.join(out_dir, 'prompt_pairs')  # for saving supervisor/user prompt pairs
os.makedirs(prompt_pair_dir, exist_ok=True)

# Load the task prompt. Modified prompts may be staged in logs/ instead of
# prompts_temp/, so fall back there when the primary path is missing.
prompt_filepath = os.path.join(out_dir, f'prompts_temp/{name}.txt')
if not os.path.exists(prompt_filepath):
    prompt_filepath = os.path.join(out_dir, f'logs/{name}.txt')
with open(prompt_filepath, 'r') as file:
    user_prompt = file.read()

# Supervisor instruction templates: one for the initial call, one reused
# for every subsequent feedback round.
first_supervisor_instructions_filepath = os.path.join(out_dir, 'prompts_temp/supervisor_first_call.txt')
with open(first_supervisor_instructions_filepath, 'r') as file:
    first_supervisor_instructions = file.read()
supervisor_instructions_filepath = os.path.join(out_dir, 'prompts_temp/supervisor_call.txt')
with open(supervisor_instructions_filepath, 'r') as file:
    supervisor_instructions = file.read()

# Build the initial supervisor prompt, listing the Snakemake input/output
# files so the supervisor knows what the generated code must read and write.
file_context = ''
if args.input_files:
    file_context += '\nInput files:\n' + '\n'.join(args.input_files)
if args.output_files:
    file_context += '\nOutput files:\n' + '\n'.join(args.output_files)
supervisor_prompt = first_supervisor_instructions + file_context + '\n\n' + user_prompt

# CBORG (LBNL) OpenAI-compatible gateway; the key comes from the environment.
client = openai.OpenAI(
    api_key = os.environ.get('CBORG_API_KEY'),
    base_url = 'https://api.cborg.lbl.gov'
)

# Loop state and token accounting, bucketed by which leg of the
# conversation consumed the tokens.
done_outer = False
count_outer = 0
total_calls = 0
input_tokens = 0
output_tokens = 0
tokens = {"User Prompt": 0, "Supervisor to Coder": 0, "Coder Outputs": 0, "Feedback to Supervisor": 0}

# Single comprehensive log file for the whole run (only log file we create).
comprehensive_log_path = os.path.join(log_dir, f'{name}_comprehensive_log.txt')
start_time = datetime.now()

# Initialize the comprehensive log with a run header and the original prompt.
with open(comprehensive_log_path, 'w') as comp_log:
    comp_log.write("=" * 100 + "\n")
    comp_log.write("🎯 COMPREHENSIVE SUPERVISOR-CODER LOG\n")
    comp_log.write("=" * 100 + "\n\n")
    comp_log.write(f"πŸ“‹ Task: {name}\n")
    comp_log.write(f"πŸ€– Supervisor: {supervisor}\n")
    comp_log.write(f"πŸ€– Coder: {coder}\n")
    comp_log.write(f"πŸ• Start Time: {start_time}\n")
    comp_log.write(f"πŸ“ Working Directory: {os.getcwd()}\n")
    comp_log.write(f"πŸ“‚ Output Directory: {out_dir}\n\n")
    # Record the original user prompt verbatim for traceability.
    comp_log.write("πŸ“ ORIGINAL USER PROMPT\n")
    comp_log.write("-" * 50 + "\n")
    comp_log.write(user_prompt)
    comp_log.write("\n\n")
    # NOTE: dumping the supervisor instruction templates here was
    # intentionally disabled to keep the log compact.
    comp_log.write("πŸš€ PROCESS START\n")
    comp_log.write("-" * 50 + "\n\n")
# Function to log to comprehensive log
def log_comprehensive(message, section="", level="INFO", plain=False):
    """Append one entry to the run's comprehensive log file.

    message: text to record; entries longer than ~50KB are truncated with
             a note reporting the original length.
    section: optional banner written before the entry.
    level:   severity tag used on timestamped entries.
    plain:   when True, write the raw message with no timestamp/level prefix
             (used for prompts, generated code, and model output).
    """
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Cap individual entries to keep the log (and memory use) bounded.
    limit = 50000  # ~50KB
    if len(message) > limit:
        message = message[:limit] + f"\n\n[TRUNCATED: message was {len(message)} characters long]"
    with open(comprehensive_log_path, 'a') as handle:
        if section:
            banner = '=' * 25
            handle.write(f"\n\n{banner} {section.upper()} {banner}\n\n")
        if plain:
            handle.write(message + "\n")
        else:
            handle.write(f"[{stamp}] {level}: {message}\n")
# --- Initial supervisor call: turn the user prompt into a coder prompt. ---
log_comprehensive("Starting supervisor-coder interaction", "🚀 INITIALIZATION")
print("=== SUPERVISOR-CODER MAIN LOOP ===")
log_comprehensive("=== MAIN LOOP STARTED ===", "πŸ”„ LOOP START")
print(f"πŸ”„ Calling supervisor {supervisor} (initial)")
log_comprehensive(f"Calling supervisor {supervisor} with initial prompt", "πŸ€– SUPERVISOR CALL")
try:
    response = client.chat.completions.create(
        model = supervisor,
        messages = [
            {
                'role': 'user',
                'content': supervisor_prompt
            }
        ],
        temperature = temperature
    )
    # Token bookkeeping for the initial exchange.
    tokens['User Prompt'] += response.usage.prompt_tokens
    input_tokens += response.usage.prompt_tokens
    output_tokens += response.usage.completion_tokens
    supervisor_response = response.choices[-1].message.content
    print(f"Initial supervisor prompt tokens: {tokens['User Prompt']}")
    log_comprehensive(f"Initial supervisor prompt tokens: {tokens['User Prompt']}", "TOKEN STATS")
    # Log token stats and content under the SUPERVISOR RESPONSE section.
    log_comprehensive(
        f"Supervisor response received ({response.usage.prompt_tokens} input tokens, {response.usage.completion_tokens} output tokens)",
        "SUPERVISOR RESPONSE"
    )
    log_comprehensive(supervisor_response, plain=True)
    # Include input and output files under the same section without extra headers.
    log_comprehensive(file_context, section="", plain=True)
except Exception as e:
    log_comprehensive(f"Supervisor API call failed: {e}", "❌ API ERROR", "ERROR")
    print(f"OpenAI API error: {e}")
    # Rate-limit/quota errors are temporary, but we still abort to avoid
    # burning resources; any other API error is treated as permanent.
    if "429" in str(e) or "rate" in str(e).lower() or "quota" in str(e).lower():
        print("Rate limit or quota exceeded. This is a temporary issue.")
        print("Stopping job to avoid wasting resources on temporary failures.")
        log_comprehensive("Terminating due to rate limit/quota issue", "πŸ›‘ TERMINATION", "ERROR")
        sys.exit(1)  # Fail for rate limits
    else:
        print("Permanent API error (auth, network, etc.). Stopping job.")
        log_comprehensive("Terminating due to permanent API error", "πŸ›‘ TERMINATION", "ERROR")
        sys.exit(1)  # Fail for permanent errors

# The supervisor is expected to end its reply with 'Call record:' followed
# by a summary; everything before the marker is the prompt for the coder.
if 'Call record:' in supervisor_response:
    coder_prompt, record = supervisor_response.split('Call record:', 1)
else:
    coder_prompt = supervisor_response
    record = "Supervisor did not provide a record."

# Persist the running record of supervisor decisions; feedback calls read
# this file back and include it in their prompts.
supervisor_record = os.path.join(log_dir, 'supervisor_record.txt')
with open(supervisor_record, 'w') as f:
    f.write('Call 1 record: ' + record)
    f.write(f'WARNING: running supervisor/coder with supervisor {supervisor} and coder {coder}')
# --- Inner refinement loop: coder writes code, we execute it, supervisor
# reviews the result, until satisfied or max_iterations is reached. ---
done_inner = False
count_inner = 0
old_coder_prompt = "Default coder prompt"  # Initialize to avoid unbound variable
while not done_inner and count_inner < args.max_iterations:
    count_inner += 1
    print(f"πŸ€– Calling coder {coder} (attempt {count_inner}/{args.max_iterations})")
    log_comprehensive(f"Calling coder {coder} (attempt {count_inner}/{args.max_iterations})", f"πŸ€– CODER CALL #{count_inner}")
    try:
        print("DEBUG: Making coder API call...")
        response = client.chat.completions.create(
            model = coder,
            messages = [
                {
                    'role': 'user',
                    'content': coder_prompt
                }
            ],
            temperature = temperature
        )
        print("DEBUG: Coder API call completed successfully")
        # Token bookkeeping: coder input is the supervisor-authored prompt.
        input_tokens += response.usage.prompt_tokens
        output_tokens += response.usage.completion_tokens
        tokens['Supervisor to Coder'] += response.usage.prompt_tokens
        tokens['Coder Outputs'] += response.usage.completion_tokens
        supervisor_response = response.choices[-1].message.content
        output = response.choices[-1].message.content
        print(f"DEBUG: Extracted output, type: {type(output)}, length: {len(output) if output else 0}")
        log_comprehensive(f"Coder response received ({response.usage.prompt_tokens} input tokens, {response.usage.completion_tokens} output tokens)", f"CODER RESPONSE #{count_inner}")
        log_comprehensive("Generated Code:", plain=True)
        log_comprehensive(output, plain=True)
        print("DEBUG: Logged coder response successfully")
    except Exception as e:
        log_comprehensive(f"Coder API call failed: {e}", f"❌ CODER ERROR #{count_inner}", "ERROR")
        print(f"OpenAI API error: {e}")
        if "429" in str(e) or "rate" in str(e).lower() or "quota" in str(e).lower():
            print("Rate limit or quota exceeded. This is a temporary issue.")
            print("Stopping job to avoid wasting resources on temporary failures.")
            log_comprehensive("Terminating due to rate limit/quota issue", "πŸ›‘ TERMINATION", "ERROR")
            sys.exit(1)  # Fail for rate limits
        else:
            print("Permanent API error (auth, network, etc.). Stopping job.")
            log_comprehensive("Terminating due to permanent API error", "πŸ›‘ TERMINATION", "ERROR")
            sys.exit(1)  # Fail for permanent errors
    # Extract the Python source from the coder's Markdown reply and save it.
    code = get_code(output)
    code_filepath = os.path.join(code_dir, f'{name}.py')
    with open(code_filepath, 'w') as f:
        f.write(code)
    print("⚑ Executing code from coder")
    # Execute the generated script, forwarding the Snakemake file lists as
    # general-purpose arguments so the script knows what to read/write.
    cmd = ['python', code_filepath]
    if args.input_files:
        cmd.extend(['--input-files', *args.input_files])
    if args.output_files:
        cmd.extend(['--output-files', *args.output_files])
    # Log the execution command.
    log_comprehensive("Executing generated code.", f"CODE EXECUTION #{count_inner}")
    log_comprehensive(f"Command: {' '.join(cmd)}", level="DEBUG")
    # Ensure generated code writes job-scoped artifacts.
    env = os.environ.copy()
    if out_dir:
        env['OUTPUT_DIR'] = out_dir
    process = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into stdout: one transcript
        text=True,
        env=env,
    )
    # Log execution results to the comprehensive log.
    log_comprehensive(f"Execution completed with exit code {process.returncode}", f"EXECUTION RESULT #{count_inner}")
    if process.returncode != 0:
        log_comprehensive("Execution failed. Output:", level="ERROR")
        log_comprehensive(process.stdout, plain=True)
    else:
        log_comprehensive("Execution successful. Output:", level="INFO")
        log_comprehensive(process.stdout, plain=True)
    # The execution transcript is fed back to the supervisor verbatim.
    command_line_output = process.stdout or ""
    # Verify the declared output files were actually produced; append any
    # problems to the transcript so the supervisor can react to them.
    if args.output_files:
        for output_file in args.output_files:
            if not os.path.exists(output_file):
                error_message = f"\nERROR: Output file was not created: {output_file}"
                print(error_message)
                command_line_output += error_message
                log_comprehensive(f"Output file not found after execution: {output_file}", f"EXECUTION RESULT #{count_inner}", "ERROR")
            elif os.path.getsize(output_file) == 0:
                error_message = f"\nERROR: Output file exists but is empty: {output_file}"
                print(error_message)
                command_line_output += error_message
                log_comprehensive(f"Output file exists but is empty: {output_file}", f"EXECUTION RESULT #{count_inner}", "ERROR")
            else:
                # File exists and has content — warn if it looks stale
                # (possibly left over from a previous run rather than
                # written by this execution).
                import time
                mtime = os.path.getmtime(output_file)
                current_time = time.time()
                if current_time - mtime > 60:  # Modified more than 1 minute ago
                    error_message = f"\nWARNING: Output file exists but was not recently modified: {output_file} (modified {current_time - mtime:.1f} seconds ago)"
                    print(error_message)
                    command_line_output += error_message
                    log_comprehensive(f"Output file exists but not recently modified: {output_file}", f"EXECUTION RESULT #{count_inner}", "WARNING")
    # Construct the feedback prompt: template head + file context + user
    # prompt + generated code + runtime output + accumulated record.
    with open(supervisor_record, 'r') as f:
        record = f.read()
    before_user, after_user = supervisor_instructions.split('User prompt:')
    # Rebuild file context for feedback prompts.
    file_context = ''
    if args.input_files:
        file_context += '\nInput files:\n' + '\n'.join(args.input_files)
    if args.output_files:
        file_context += '\nOutput files:\n' + '\n'.join(args.output_files)
    new_supervisor_prompt = (
        before_user + file_context + '\nUser prompt:\n' + user_prompt +
        '\nGenerated code:\n' + code +
        '\nCommand line output:\n' + command_line_output +
        '\nRecord:\n' + record
    )
    print(f"Calling supervisor {supervisor} (iteration {count_inner})")
    log_comprehensive(f"Calling supervisor {supervisor} for feedback (iteration {count_inner})", f"πŸ€– SUPERVISOR FEEDBACK CALL #{count_inner}")
    supervisor_response_inner = ""  # Initialize to avoid unbound variable
    try:
        response = client.chat.completions.create(
            model = supervisor,
            messages = [
                {
                    'role': 'user',
                    'content': new_supervisor_prompt
                }
            ],
            temperature = temperature
        )
        input_tokens += response.usage.prompt_tokens
        output_tokens += response.usage.completion_tokens
        tokens['Feedback to Supervisor'] += response.usage.prompt_tokens
        supervisor_response_inner = response.choices[-1].message.content
        log_comprehensive(f"Supervisor feedback received ({response.usage.prompt_tokens} input tokens, {response.usage.completion_tokens} output tokens)", f"SUPERVISOR FEEDBACK RESPONSE #{count_inner}")
        log_comprehensive("Supervisor Feedback:", plain=True)
        log_comprehensive(supervisor_response_inner, plain=True)
    except Exception as e:
        log_comprehensive(f"Supervisor feedback API call failed: {e}", f"❌ SUPERVISOR ERROR #{count_inner}", "ERROR")
        print(f"OpenAI API error: {e}")
        if "429" in str(e) or "rate" in str(e).lower() or "quota" in str(e).lower():
            print("Rate limit or quota exceeded. This is a temporary issue.")
            print("Stopping job to avoid wasting resources on temporary failures.")
            log_comprehensive("Terminating due to rate limit/quota issue", "πŸ›‘ TERMINATION", "ERROR")
            sys.exit(1)  # Fail for rate limits
        else:
            print("Permanent API error (auth, network, etc.). Stopping job.")
            log_comprehensive("Terminating due to permanent API error", "πŸ›‘ TERMINATION", "ERROR")
            sys.exit(1)  # Fail for permanent errors
    # Split the supervisor's reply into the next coder prompt and a record
    # entry that is appended to the running record file.
    supervisor_response = supervisor_response_inner
    if 'Call record:' in supervisor_response:
        try:
            parts = supervisor_response.split('Call record:', 1)
            coder_prompt = parts[0]
            record = parts[-1]
            with open(supervisor_record, 'a') as f:
                f.write(f'\nCall {count_inner} record: ' + record)
        except Exception as e:
            log_comprehensive(f"Error parsing supervisor response: {e}", f"⚠️ PARSING WARNING #{count_inner}", "WARNING")
            print(f"Error parsing supervisor response: {e}")
            coder_prompt = supervisor_response
    else:
        coder_prompt = supervisor_response
    old_coder_prompt = coder_prompt
    # The supervisor signals completion with this exact sentinel phrase.
    if 'Supervisor is satisfied with current results' in coder_prompt:
        done_inner = True
        print(f"βœ… Supervisor satisfied after {count_inner} iterations!")
        log_comprehensive(f"SUCCESS: Supervisor satisfied after {count_inner} iterations!", "πŸŽ‰ SUCCESS")
    else:
        log_comprehensive(f"Supervisor not satisfied, continuing to iteration {count_inner + 1}", f"πŸ”„ CONTINUING TO ITERATION #{count_inner + 1}")
        # Ensure coder_prompt is set for the next iteration.
        if 'Call record:' not in supervisor_response:
            coder_prompt = supervisor_response
# --- Final statistics: each inner iteration made one coder call plus one
# feedback supervisor call, on top of the single initial supervisor call. ---
total_calls = 1 + 2*count_inner  # Initial supervisor + coder calls + feedback supervisor calls
print(f"πŸ“Š Total API calls made: {total_calls}")
log_comprehensive(f"Total API calls made: {total_calls}", "πŸ“Š FINAL STATISTICS")
log_comprehensive(f"Final token counts: {input_tokens} input, {output_tokens} output")
# Final summary
end_time = datetime.now()
duration = end_time - start_time
log_comprehensive(f"Interaction completed in {duration}", "🏁 PROCESS COMPLETED")
log_comprehensive("=== END OF COMPREHENSIVE LOG ===", "🏁 END OF LOG")
# Append a human-readable summary footer to the comprehensive log.
with open(comprehensive_log_path, 'a') as comp_log:
    comp_log.write(f"\n{'='*100}\n")
    comp_log.write("πŸ“Š FINAL SUMMARY\n")
    comp_log.write(f"{'='*100}\n")
    comp_log.write(f"Task: {name}\n")
    comp_log.write(f"Supervisor: {supervisor}\n")
    comp_log.write(f"Coder: {coder}\n")
    # NOTE(review): at module level locals() == globals() and count_inner is
    # always bound above, so this guard is defensive only.
    comp_log.write(f"Total Iterations: {count_inner if 'count_inner' in locals() else 0}\n")
    comp_log.write(f"Start Time: {start_time}\n")
    comp_log.write(f"End Time: {end_time}\n")
    comp_log.write(f"Duration: {duration}\n")
    comp_log.write(f"Total API Calls: {total_calls}\n")
    comp_log.write(f"Total Input Tokens: {input_tokens}\n")
    comp_log.write(f"Total Output Tokens: {output_tokens}\n")
    comp_log.write(f"User Prompt Tokens: {tokens['User Prompt']}\n")
    comp_log.write(f"Supervisor to Coder Tokens: {tokens['Supervisor to Coder']}\n")
    comp_log.write(f"Coder Output Tokens: {tokens['Coder Outputs']}\n")
    comp_log.write(f"Feedback to Supervisor Tokens: {tokens['Feedback to Supervisor']}\n")
    final_status = 'SUCCESS' if ('done_inner' in locals() and done_inner) else 'INCOMPLETE'
    comp_log.write(f"Final Status: {final_status}\n")
    comp_log.write(f"{'='*100}\n")
# Save the (user prompt, final supervisor-to-coder prompt) pair for later
# prompt-engineering analysis.
prompt_pair = f'User Prompt: \n {user_prompt} \n Supervisor Prompt: \n {old_coder_prompt}'
prompt_pair_path = os.path.join(prompt_pair_dir, f'{name}.txt')
with open(prompt_pair_path, 'w') as f:
    f.write(prompt_pair)
# Save metrics arrays under the run's logs directory. Each metric is an
# append-only .npy history: one value is appended per run, so repeated runs
# accumulate a time series.
def _append_metric(filepath, value):
    """Append `value` to the .npy array at `filepath`, creating it if absent."""
    if os.path.exists(filepath):
        history = np.append(np.load(filepath), value)
    else:
        history = np.array([value])
    np.save(filepath, history)

_append_metric(os.path.join(out_dir, 'logs', 'calls.npy'), total_calls)
_append_metric(os.path.join(out_dir, 'logs', 'input_tokens.npy'), input_tokens)
_append_metric(os.path.join(out_dir, 'logs', 'output_tokens.npy'), output_tokens)
# Before printing final success messages, verify the declared output files
# exist and are non-empty; exit non-zero so Snakemake marks the job failed
# when the generated code did not produce its deliverables.
if args.output_files:
    missing = [f for f in args.output_files if not os.path.exists(f) or os.path.getsize(f) == 0]
    if missing:
        print(f"βœ— Missing or empty output files: {missing}")
        sys.exit(1)
    else:
        print("βœ… All expected output files created.")
print("πŸŽ‰ Supervisor-coder interaction completed!")
print(f"πŸ“ˆ Final stats: {total_calls} total API calls, {input_tokens} input tokens, {output_tokens} output tokens")
print(f"πŸ“‹ Comprehensive log saved to: {comprehensive_log_path}")