|
|
""" |
|
|
Automated script that applies the authentication pattern to the remaining
handler functions in chat/tools.py.

Run this to complete the authentication implementation, then run
update_server_tools.py to update the tools in server.py.
|
|
""" |
|
|
|
|
|
import re
import textwrap
|
|
|
|
|
|
|
|
# Handlers in chat/tools.py that still need the user_id parameter, the
# authentication guard and user-scoped queries, in processing order.
HANDLER_FUNCTIONS_TO_UPDATE = [
    # Orders
    "handle_get_order_details",
    "handle_search_orders",
    "handle_get_incomplete_orders",
    "handle_update_order",
    "handle_delete_order",
    "handle_delete_all_orders",
    # Drivers
    "handle_count_drivers",
    "handle_get_driver_details",
    "handle_search_drivers",
    "handle_get_available_drivers",
    "handle_update_driver",
    "handle_delete_driver",
    "handle_delete_all_drivers",
    # Assignments and deliveries
    "handle_create_assignment",
    "handle_auto_assign_order",
    "handle_intelligent_assign_order",
    "handle_get_assignment_details",
    "handle_update_assignment",
    "handle_unassign_order",
    "handle_complete_delivery",
    "handle_fail_delivery",
]
|
|
|
|
|
# Guard inserted as the first statement of each patched handler. Handlers
# that receive no user_id return this error dict instead of running their
# query. The nesting (return inside if, dict items inside return) must be
# kept intact or the inserted code is not valid Python.
AUTH_CHECK_CODE = '''    # Authentication check
    if not user_id:
        return {
            "success": False,
            "error": "Authentication required. Please login first.",
            "auth_required": True
        }
'''
|
|
|
|
|
def update_handler_function(content: str, func_name: str, *,
                            auth_check: "str | None" = None) -> str:
    """Add a ``user_id`` parameter and an authentication guard to one handler.

    Two edits are made to ``content``:

    1. ``def <func_name>(tool_input: dict) -> dict:`` becomes
       ``def <func_name>(tool_input: dict, user_id: str = None) -> dict:``.
    2. The guard snippet is inserted as the first statement of the body,
       right after the docstring when there is one, followed by a blank line.
       The snippet is re-indented to the function's own body indentation.

    Args:
        content: Full source text of the module being patched.
        func_name: Name of the handler to patch (matched literally).
        auth_check: Guard snippet to insert. Defaults to ``AUTH_CHECK_CODE``.

    Returns:
        The patched source. If the handler's ``def`` line, or an indented
        body line after its header, cannot be found, no guard is inserted.
    """
    name = re.escape(func_name)
    snippet = AUTH_CHECK_CODE if auth_check is None else auth_check

    # Step 1: extend the signature with an optional user_id.
    content = re.sub(
        rf'(def {name}\(tool_input: dict)\) -> dict:',
        r'\1, user_id: str = None) -> dict:',
        content,
    )

    # Step 2: find the header, i.e. the signature line plus an optional
    # docstring that must start on the very next line. Anchoring the
    # docstring there keeps the match inside this function. An unanchored
    # `.*?"""` would run on into the next function's docstring (or a SQL
    # triple-quoted string) when this handler has no docstring.
    header = re.search(
        rf"^[ \t]*def {name}\(.*?\)[^\n]*:[ \t]*\n"
        r"(?:[ \t]*[rRuU]?(?P<q>\"{3}|'{3}).*?(?P=q)[ \t]*\n)?",
        content,
        re.DOTALL | re.MULTILINE,
    )
    if header is None:
        return content
    insert_at = header.end()

    # The first non-blank body line tells us the indentation to use.
    first_stmt = re.compile(r'(?:[ \t]*\n)*([ \t]+)\S').match(content, insert_at)
    if first_stmt is None:
        return content
    body_indent = first_stmt.group(1)

    # Normalise the snippet's own indentation, then indent it to the body.
    guard = textwrap.indent(textwrap.dedent(snippet).strip('\n') + '\n', body_indent)
    return content[:insert_at] + guard + '\n' + content[insert_at:]
|
|
|
|
|
def _scope_select_filters(func_content: str) -> str:
    """Seed a handler's dynamic WHERE builder with a leading user_id filter."""
    where_re = re.compile(r'\n([ \t]*)where_clauses = \[\]')
    params_re = re.compile(r'\n([ \t]*)params = \[\]')
    # The %s placeholder and its parameter must be added together. Adding
    # only one of them leaves the query with mismatched placeholder/param
    # counts.
    if not (where_re.search(func_content) and params_re.search(func_content)):
        return func_content
    # Reuse each matched line's own indentation so insertions inside nested
    # blocks (e.g. `try:`) stay valid.
    func_content = where_re.sub(
        lambda m: (
            f'{m.group(0)}\n'
            f'{m.group(1)}# IMPORTANT: Always filter by user_id FIRST\n'
            f'{m.group(1)}where_clauses = ["user_id = %s"]'
        ),
        func_content,
    )
    return params_re.sub(
        lambda m: f'{m.group(0)}\n{m.group(1)}params = [user_id]',
        func_content,
    )


def _scope_unfiltered_writes(func_content: str) -> str:
    """Append ``WHERE user_id = %s`` to DELETE/UPDATE statements lacking a WHERE."""
    # The trailing \b stops \w+ from backtracking into the table name
    # (e.g. matching "order" out of "orders WHERE ...").
    func_content = re.sub(
        r'(DELETE FROM \w+\b)(?!\s+WHERE\b)',
        r'\1 WHERE user_id = %s',
        func_content,
    )
    # [^"] keeps the match inside one string literal. The lookahead refuses
    # to pass a WHERE, so UPDATEs that already filter are left alone.
    return re.sub(
        r'(UPDATE \w+ SET(?:(?!\bWHERE\b)[^"])*?)(\s+""")',
        r'\1 WHERE user_id = %s\2',
        func_content,
    )


def update_handler_queries(content: str, func_name: str) -> str:
    """Scope the SQL in one handler's body to the calling ``user_id``.

    The handler's source is taken to run from its ``def`` line up to the
    next top-level ``def`` (or end of file). Within that span:

    * if it mentions both SELECT and WHERE and builds its filter with
      ``where_clauses = []`` and ``params = []``, both lists are seeded with
      the user filter (``"user_id = %s"`` and ``user_id``);
    * each ``DELETE FROM <table>`` or triple-quoted ``UPDATE <table> SET ...``
      statement that has no WHERE of its own gets ``WHERE user_id = %s``.

    NOTE(review): the DELETE/UPDATE rewrite adds a placeholder but does not
    touch the arguments passed to ``execute``. Those call sites need
    ``user_id`` added by hand. Statements that already have a WHERE are not
    scoped to the user.

    Args:
        content: Full source text of the module being patched.
        func_name: Name of the handler to patch (matched literally).

    Returns:
        The patched source, or ``content`` unchanged if the handler is not found.
    """
    func_match = re.search(
        rf'def {re.escape(func_name)}\(.*?\).*?(?=\ndef\s|\Z)',
        content,
        re.DOTALL,
    )
    if not func_match:
        return content

    func_content = func_match.group(0)
    if 'SELECT' in func_content and 'WHERE' in func_content:
        func_content = _scope_select_filters(func_content)
    # Checked per statement: a WHERE elsewhere in the function must not
    # let an unscoped DELETE/UPDATE through.
    func_content = _scope_unfiltered_writes(func_content)

    # Splice back exactly the matched span.
    return content[:func_match.start()] + func_content + content[func_match.end():]
|
|
|
|
|
def main():
    """Patch chat/tools.py in place, adding authentication to each listed handler.

    Every name in ``HANDLER_FUNCTIONS_TO_UPDATE`` whose signature is still
    exactly ``def <name>(tool_input: dict) -> dict:`` is passed through
    ``update_handler_function`` and ``update_handler_queries``. All other
    names are reported as skipped. The file is rewritten only when at least
    one handler was updated, so a re-run leaves it untouched.
    """
    path = 'chat/tools.py'
    print("Applying authentication pattern to all handler functions...")

    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()

    updated_count = 0
    for func_name in HANDLER_FUNCTIONS_TO_UPDATE:
        # The untouched signature doubles as the "not yet migrated" marker.
        if f'def {func_name}(tool_input: dict) -> dict:' not in content:
            print(f" Skipping {func_name} (already updated or not found)")
            continue
        print(f" Updating {func_name}...")
        content = update_handler_function(content, func_name)
        content = update_handler_queries(content, func_name)
        updated_count += 1

    # Avoid rewriting an unchanged file (mtime bump, newline normalisation).
    if updated_count:
        with open(path, 'w', encoding='utf-8') as f:
            f.write(content)

    print(f"\nCompleted! Updated {updated_count} handler functions.")
    print("\nNext: Run 'python update_server_tools.py' to update server.py tools")
|
|
|
|
|
# Run the migration only when executed as a script, never on import.
if __name__ == "__main__":
    main()
|
|
|