Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, HTTPException, Depends, status, Request, Query | |
| from typing import Optional | |
| from fastapi.security import OAuth2PasswordRequestForm | |
| from db.mongo import users_collection, password_reset_codes_collection | |
| from core.security import hash_password, verify_password, create_access_token, get_current_user | |
| from core.email_service import email_service | |
| from models.schemas import ( | |
| SignupForm, | |
| TokenResponse, | |
| DoctorCreate, | |
| AdminCreate, | |
| ProfileUpdate, | |
| PasswordChange, | |
| AdminUserUpdate, | |
| AdminPasswordReset, | |
| PasswordResetRequest, | |
| PasswordResetVerify, | |
| PasswordResetConfirm, | |
| ) | |
| from datetime import datetime | |
| import logging | |
| # Configure logging at import time: INFO level, with timestamp, level, | |
| # logger name and message in every record. | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s - %(levelname)s - %(name)s - %(message)s' | |
| ) | |
| # Module-level logger shared by the handlers in this file. | |
| logger = logging.getLogger(__name__) | |
| # Router object for this module; it is re-exported as `auth` at the end of the file. | |
| router = APIRouter() | |
async def signup(data: SignupForm):
    """
    Patient registration endpoint.

    Only patients can register through signup; doctors and admins must be
    created by existing admins.

    Args:
        data: Signup payload. ``email``, ``full_name`` and ``password`` are read.

    Returns:
        A dict with ``status``, the inserted document ``id`` (as a string)
        and the normalized ``email``.

    Raises:
        HTTPException: 409 when the email is already registered, 500 when
            the insert fails.
    """
    logger.info(f"Patient signup attempt for email: {data.email}")
    # Never write the plaintext password to the logs.
    loggable = {key: value for key, value in data.dict().items() if key != "password"}
    logger.info(f"Received signup data: {loggable}")

    # Emails are stored lower-cased and stripped so lookups are consistent.
    email = data.email.lower().strip()
    existing = await users_collection.find_one({"email": email})
    if existing:
        logger.warning(f"Signup failed: Email already exists: {email}")
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Email already exists"
        )

    hashed_pw = hash_password(data.password)
    user_doc = {
        "email": email,
        "full_name": data.full_name.strip(),
        "password": hashed_pw,
        "roles": ["patient"],  # Only patients can register through signup
        "created_at": datetime.utcnow().isoformat(),
        "updated_at": datetime.utcnow().isoformat(),
        "device_token": ""  # Default empty device token for patients
    }
    try:
        result = await users_collection.insert_one(user_doc)
        logger.info(f"User created successfully: {email}")
        return {
            "status": "success",
            "id": str(result.inserted_id),
            "email": email
        }
    except Exception as e:
        logger.error(f"Failed to create user {email}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create user: {str(e)}"
        )
async def create_doctor(
    data: DoctorCreate,
    current_user: dict = Depends(get_current_user)
):
    """
    Provision a doctor account on behalf of an admin.

    The caller must carry the ``admin`` role. The new document stores the
    normalized email, the hashed password, the roles supplied in the
    payload, and the doctor's specialty and license number.

    Returns:
        A dict with ``status``, the inserted ``id`` (string) and ``email``.

    Raises:
        HTTPException: 403 for non-admin callers, 409 when the email is
            already taken, 500 when the insert fails.
    """
    logger.info(f"Doctor creation attempt by {current_user.get('email')}")

    caller_is_admin = 'admin' in current_user.get('roles', [])
    if not caller_is_admin:
        logger.warning(f"Unauthorized doctor creation attempt by {current_user.get('email')}")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Only admins can create doctor accounts"
        )

    email = data.email.lower().strip()
    if await users_collection.find_one({"email": email}):
        logger.warning(f"Doctor creation failed: Email already exists: {email}")
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Email already exists"
        )

    # Key order is kept as-is: it is the field order of the stored document.
    new_doctor = {
        "email": email,
        "full_name": data.full_name.strip(),
        "password": hash_password(data.password),
        "roles": data.roles,  # Support multiple roles
        "specialty": data.specialty,
        "license_number": data.license_number,
        "created_at": datetime.utcnow().isoformat(),
        "updated_at": datetime.utcnow().isoformat(),
        "device_token": ""  # Default empty device token for doctors
    }

    try:
        inserted = await users_collection.insert_one(new_doctor)
        logger.info(f"Doctor created successfully: {email}")
        return {"status": "success", "id": str(inserted.inserted_id), "email": email}
    except Exception as e:
        logger.error(f"Failed to create doctor {email}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create doctor: {str(e)}"
        )
async def create_admin(
    data: AdminCreate,
    current_user: dict = Depends(get_current_user)
):
    """
    Provision a new admin account; only an existing admin may call this.

    Returns:
        A dict with ``status``, the inserted ``id`` (string) and ``email``.

    Raises:
        HTTPException: 403 for non-admin callers, 409 when the email is
            already taken, 500 when the insert fails.
    """
    logger.info(f"Admin creation attempt by {current_user.get('email')}")

    caller_is_admin = 'admin' in current_user.get('roles', [])
    if not caller_is_admin:
        logger.warning(f"Unauthorized admin creation attempt by {current_user.get('email')}")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Only admins can create admin accounts"
        )

    email = data.email.lower().strip()
    if await users_collection.find_one({"email": email}):
        logger.warning(f"Admin creation failed: Email already exists: {email}")
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Email already exists"
        )

    # Key order is kept as-is: it is the field order of the stored document.
    new_admin = {
        "email": email,
        "full_name": data.full_name.strip(),
        "password": hash_password(data.password),
        "roles": data.roles,  # Support multiple roles
        "created_at": datetime.utcnow().isoformat(),
        "updated_at": datetime.utcnow().isoformat(),
        "device_token": ""
    }

    try:
        inserted = await users_collection.insert_one(new_admin)
        logger.info(f"Admin created successfully: {email}")
        return {"status": "success", "id": str(inserted.inserted_id), "email": email}
    except Exception as e:
        logger.error(f"Failed to create admin {email}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create admin: {str(e)}"
        )
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """
    Authenticate a user and issue a bearer access token.

    Args:
        form_data: OAuth2 password form; ``username`` carries the email.

    Returns:
        A dict with ``access_token``, ``token_type`` and the user's ``roles``
        (defaulting to ``["patient"]`` when the document has none).

    Raises:
        HTTPException: 401 when the email is unknown or the password does
            not verify.
    """
    logger.info(f"Login attempt for email: {form_data.username}")

    # Normalize the same way the account-creation endpoints do (lower + strip);
    # otherwise stray whitespace makes a valid account unfindable.
    email = form_data.username.lower().strip()
    user = await users_collection.find_one({"email": email})
    if not user or not verify_password(form_data.password, user["password"]):
        logger.warning(f"Login failed for {form_data.username}: Invalid credentials")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Update device token if the form object carries one (e.g., from frontend).
    # NOTE(review): the stock OAuth2PasswordRequestForm has no such attribute,
    # so this only fires for a customized form - confirm.
    device_token = getattr(form_data, 'device_token', None)
    if device_token:
        await users_collection.update_one(
            {"email": user["email"]},
            {"$set": {"device_token": device_token}}
        )
        logger.info(f"Device token updated for {form_data.username}")

    access_token = create_access_token(data={"sub": user["email"]})
    logger.info(f"Successful login for {form_data.username}")
    return {
        "access_token": access_token,
        "token_type": "bearer",
        "roles": user.get("roles", ["patient"])  # Return all roles
    }
async def get_me(request: Request, current_user: dict = Depends(get_current_user)):
    """
    Return the profile of the authenticated user.

    Args:
        request: Incoming request; kept in the signature for compatibility,
            not read by the handler.
        current_user: Authenticated user; ``email`` is used for the lookup.

    Returns:
        A dict with ``id``, ``email``, ``full_name``, ``roles``, ``specialty``,
        ``created_at``, ``updated_at`` and ``device_token``.

    Raises:
        HTTPException: 404 when the user document no longer exists, 500 on
            any other failure while reading it.
    """
    logger.info(f"Fetching user profile for {current_user['email']} at {datetime.utcnow().isoformat()}")
    # Request headers and the raw user document are deliberately not printed:
    # they contain the bearer token and the password hash.
    try:
        user = await users_collection.find_one({"email": current_user["email"]})
        if not user:
            logger.warning(f"User not found: {current_user['email']}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found"
            )

        # Handle both "role" (singular) and "roles" (array) formats
        user_roles = user.get("roles", [])
        if not user_roles and user.get("role"):
            # If roles array is empty but role field exists, convert to array
            user_roles = [user.get("role")]
        logger.debug(f"User roles for {current_user['email']}: {user_roles}")

        response = {
            "id": str(user["_id"]),
            "email": user["email"],
            "full_name": user.get("full_name", ""),
            "roles": user_roles,  # Return all roles
            "specialty": user.get("specialty"),
            "created_at": user.get("created_at"),
            "updated_at": user.get("updated_at"),
            "device_token": user.get("device_token", "")  # Include device token in response
        }
        logger.info(f"User profile retrieved for {current_user['email']} at {datetime.utcnow().isoformat()}")
        return response
    except HTTPException:
        # Let the deliberate 404 through instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Database error for user {current_user['email']}: {str(e)} at {datetime.utcnow().isoformat()}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Database error: {str(e)}"
        )
async def list_users(
    role: Optional[str] = None,
    search: Optional[str] = Query(None, description="Search by name or email"),
    current_user: dict = Depends(get_current_user)
):
    """
    List users visible to the caller.

    Admins can see all users, doctors can only see patients, and every other
    caller (patients, or accounts without a recognized role) can only see
    their own record.

    Args:
        role: Optional filter on the ``roles`` array. Ignored for doctors,
            whose view is always restricted to patients.
        search: Optional case-insensitive substring matched against
            ``full_name`` and ``email``. It is treated as literal text.
        current_user: Authenticated caller; ``roles`` and ``email`` are read.

    Returns:
        Up to 500 user documents with ``_id`` replaced by a string ``id`` and
        ``password`` / ``device_token`` removed.

    Raises:
        HTTPException: 500 when the query fails.
    """
    import re  # local import: only needed to escape the search term

    logger.info(f"User list request by {current_user.get('email')} with role filter: {role}")

    caller_roles = current_user.get('roles', [])

    # Build query based on requested filters
    query = {}
    if role:
        # Matches documents whose "roles" array contains the requested role.
        query["roles"] = {"$in": [role]}
    if search:
        # Escape user input so it is matched literally, not as a regex.
        pattern = re.escape(search)
        query["$or"] = [
            {"full_name": {"$regex": pattern, "$options": "i"}},
            {"email": {"$regex": pattern, "$options": "i"}},
        ]

    # Role-based access control
    if 'admin' in caller_roles:
        # Admins can see all users
        pass
    elif 'doctor' in caller_roles:
        # Doctors can only see patients
        query["roles"] = {"$in": ["patient"]}
    else:
        # Least privilege: patients and callers without a recognized role
        # can only see themselves.
        query["email"] = current_user.get('email')

    try:
        users = await users_collection.find(query).limit(500).to_list(length=500)
        # Remove sensitive information
        for user in users:
            user["id"] = str(user.pop("_id"))
            user.pop("password", None)
            user.pop("device_token", None)
        logger.info(f"Retrieved {len(users)} users for {current_user.get('email')}")
        return users
    except Exception as e:
        logger.error(f"Error retrieving users: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve users: {str(e)}"
        )
async def admin_update_user(
    user_id: str,
    data: AdminUserUpdate,
    current_user: dict = Depends(get_current_user)
):
    """
    Update another user's document; admin only.

    Every field of ``data`` that is not ``None`` is written with ``$set``,
    together with a fresh ``updated_at`` timestamp.

    Args:
        user_id: Hex string of the target user's ObjectId.
        data: Fields to update; ``None`` values are skipped.
        current_user: Authenticated caller; must carry the ``admin`` role.

    Returns:
        ``{"status": "success"}``.

    Raises:
        HTTPException: 403 for non-admins, 400 for a malformed ``user_id``,
            404 when no user matches, 500 on any other failure.
    """
    from bson import ObjectId
    from bson.errors import InvalidId

    if 'admin' not in current_user.get('roles', []):
        raise HTTPException(status_code=403, detail="Admins only")

    try:
        object_id = ObjectId(user_id)
    except (InvalidId, TypeError):
        # A malformed id is a client error, not a server failure.
        raise HTTPException(status_code=400, detail="Invalid user id")

    try:
        # NOTE(review): values are written as-is; if AdminUserUpdate exposes a
        # password field it would be stored unhashed - confirm the schema.
        update_data = {k: v for k, v in data.dict().items() if v is not None}
        update_data["updated_at"] = datetime.utcnow().isoformat()
        result = await users_collection.update_one({"_id": object_id}, {"$set": update_data})
        if result.matched_count == 0:
            raise HTTPException(status_code=404, detail="User not found")
        return {"status": "success"}
    except HTTPException:
        # Keep the 404 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update user: {str(e)}")
async def admin_delete_user(
    user_id: str,
    current_user: dict = Depends(get_current_user)
):
    """
    Delete a user document by id; admin only.

    Args:
        user_id: Hex string of the target user's ObjectId.
        current_user: Authenticated caller; must carry the ``admin`` role.

    Returns:
        ``{"status": "success"}``.

    Raises:
        HTTPException: 403 for non-admins, 400 for a malformed ``user_id``,
            404 when no user matches, 500 on any other failure.
    """
    from bson import ObjectId
    from bson.errors import InvalidId

    if 'admin' not in current_user.get('roles', []):
        raise HTTPException(status_code=403, detail="Admins only")

    try:
        object_id = ObjectId(user_id)
    except (InvalidId, TypeError):
        # A malformed id is a client error, not a server failure.
        raise HTTPException(status_code=400, detail="Invalid user id")

    try:
        result = await users_collection.delete_one({"_id": object_id})
        if result.deleted_count == 0:
            raise HTTPException(status_code=404, detail="User not found")
        return {"status": "success"}
    except HTTPException:
        # Keep the 404 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to delete user: {str(e)}")
async def admin_reset_password(
    user_id: str,
    data: AdminPasswordReset,
    current_user: dict = Depends(get_current_user)
):
    """
    Set a new password for another user; admin only.

    Args:
        user_id: Hex string of the target user's ObjectId.
        data: Payload carrying ``new_password`` (at least 6 characters).
        current_user: Authenticated caller; must carry the ``admin`` role.

    Returns:
        ``{"status": "success"}``.

    Raises:
        HTTPException: 403 for non-admins, 400 for a too-short password or a
            malformed ``user_id``, 404 when no user matches, 500 on any
            other failure.
    """
    from bson import ObjectId
    from bson.errors import InvalidId

    if 'admin' not in current_user.get('roles', []):
        raise HTTPException(status_code=403, detail="Admins only")
    if len(data.new_password) < 6:
        raise HTTPException(status_code=400, detail="Password must be at least 6 characters")

    try:
        object_id = ObjectId(user_id)
    except (InvalidId, TypeError):
        # A malformed id is a client error, not a server failure.
        raise HTTPException(status_code=400, detail="Invalid user id")

    try:
        hashed = hash_password(data.new_password)
        result = await users_collection.update_one(
            {"_id": object_id},
            {"$set": {"password": hashed, "updated_at": datetime.utcnow().isoformat()}}
        )
        if result.matched_count == 0:
            raise HTTPException(status_code=404, detail="User not found")
        return {"status": "success"}
    except HTTPException:
        # Keep the 404 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to reset password: {str(e)}")
async def update_profile(
    data: ProfileUpdate,
    current_user: dict = Depends(get_current_user)
):
    """
    Update the authenticated user's own profile.

    Only fields that are not ``None`` in ``data`` are written; string fields
    are stripped first. ``updated_at`` is always refreshed.

    Args:
        data: Profile fields (``full_name``, ``phone``, ``address``,
            ``date_of_birth``, ``gender``, ``specialty``, ``license_number``).
        current_user: Authenticated caller; ``email`` selects the document.

    Returns:
        A dict with ``status`` and a confirmation ``message``.

    Raises:
        HTTPException: 404 when the caller's document does not exist, 500 on
            any other failure.
    """
    logger.info(f"Profile update attempt by {current_user.get('email')}")

    # Build update data (only include fields that are provided)
    update_data = {}
    for field in ("full_name", "phone", "address"):
        value = getattr(data, field)
        if value is not None:
            update_data[field] = value.strip()
    if data.date_of_birth is not None:
        # Stored as provided; unlike the other fields it is not stripped.
        update_data["date_of_birth"] = data.date_of_birth
    for field in ("gender", "specialty", "license_number"):
        value = getattr(data, field)
        if value is not None:
            update_data[field] = value.strip()

    # Add updated timestamp
    update_data["updated_at"] = datetime.utcnow().isoformat()

    try:
        result = await users_collection.update_one(
            {"email": current_user.get('email')},
            {"$set": update_data}
        )
        # matched_count says whether the user exists; modified_count would
        # also be 0 when the stored values happen to be unchanged.
        if result.matched_count == 0:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found"
            )
        logger.info(f"Profile updated successfully for {current_user.get('email')}")
        return {
            "status": "success",
            "message": "Profile updated successfully"
        }
    except HTTPException:
        # Keep the 404 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Failed to update profile for {current_user.get('email')}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to update profile: {str(e)}"
        )
async def change_password(
    data: PasswordChange,
    current_user: dict = Depends(get_current_user)
):
    """
    Change the authenticated user's own password.

    Args:
        data: Payload with ``current_password`` and ``new_password``
            (at least 6 characters).
        current_user: Authenticated caller; ``email`` and ``password`` are read.

    Returns:
        A dict with ``status`` and a confirmation ``message``.

    Raises:
        HTTPException: 400 when the current password does not verify or the
            new one is too short, 404 when the caller's document does not
            exist, 500 on any other failure.
    """
    logger.info(f"Password change attempt by {current_user.get('email')}")

    # Verify current password
    # NOTE(review): relies on get_current_user returning the stored password
    # hash under "password" - confirm.
    if not verify_password(data.current_password, current_user.get('password')):
        logger.warning(f"Password change failed: incorrect current password for {current_user.get('email')}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Current password is incorrect"
        )

    # Validate new password
    if len(data.new_password) < 6:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="New password must be at least 6 characters long"
        )

    # Hash new password
    hashed_new_password = hash_password(data.new_password)

    try:
        result = await users_collection.update_one(
            {"email": current_user.get('email')},
            {
                "$set": {
                    "password": hashed_new_password,
                    "updated_at": datetime.utcnow().isoformat()
                }
            }
        )
        # matched_count is the existence check, consistent with the admin endpoints.
        if result.matched_count == 0:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found"
            )
        logger.info(f"Password changed successfully for {current_user.get('email')}")
        return {
            "status": "success",
            "message": "Password changed successfully"
        }
    except HTTPException:
        # Keep the 404 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Failed to change password for {current_user.get('email')}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to change password: {str(e)}"
        )
async def forgot_password(data: PasswordResetRequest):
    """
    Request a password reset; a verification code is sent by email.

    For an unknown email the same success message is returned, so the
    response does not reveal whether the account exists.

    Args:
        data: Payload carrying the account ``email``.

    Returns:
        A dict with ``status`` and a generic ``message``.

    Raises:
        HTTPException: 500 when the code cannot be stored or the email
            cannot be sent.
    """
    logger.info(f"Password reset request for email: {data.email}")
    email = data.email.lower().strip()

    # Check if user exists
    user = await users_collection.find_one({"email": email})
    if not user:
        # Don't reveal if email exists or not for security
        logger.info(f"Password reset requested for non-existent email: {email}")
        return {
            "status": "success",
            "message": "If the email exists, a verification code has been sent"
        }

    # Generate verification code and the document that stores it
    verification_code = email_service.generate_verification_code()
    code_document = email_service.create_verification_code_document(email, verification_code)

    try:
        # Remove any existing codes for this email, then insert the new one
        await password_reset_codes_collection.delete_many({"email": email})
        await password_reset_codes_collection.insert_one(code_document)

        # Send email
        user_name = user.get('full_name', 'User')
        email_sent = await email_service.send_password_reset_email(email, verification_code, user_name)
        if not email_sent:
            logger.error(f"Failed to send password reset email to {email}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to send verification email"
            )

        logger.info(f"Password reset email sent successfully to {email}")
        return {
            "status": "success",
            "message": "If the email exists, a verification code has been sent"
        }
    except HTTPException:
        # Preserve the specific "failed to send" error instead of masking it
        # with the generic detail below.
        raise
    except Exception as e:
        logger.error(f"Failed to process password reset for {email}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to process password reset request"
        )
async def verify_reset_code(data: PasswordResetVerify):
    """
    Check a password reset code without consuming it.

    The code must exist for the normalized email, be unused, and not be
    expired. An expired code is deleted before the error is raised.

    Returns:
        A dict with ``status`` and a confirmation ``message``.

    Raises:
        HTTPException: 400 when no unused code matches or it has expired.
    """
    logger.info(f"Verification code check for email: {data.email}")

    email = data.email.lower().strip()
    submitted_code = data.verification_code.strip()

    # Look up an unused code belonging to this email
    lookup = {
        "email": email,
        "verification_code": submitted_code,
        "used": False
    }
    stored = await password_reset_codes_collection.find_one(lookup)
    if not stored:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or expired verification code"
        )

    # Expired codes are purged on sight
    expired = email_service.is_code_expired(stored["expires_at"])
    if expired:
        await password_reset_codes_collection.delete_one({"_id": stored["_id"]})
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Verification code has expired"
        )

    logger.info(f"Verification code validated for {email}")
    return {"status": "success", "message": "Verification code is valid"}
async def reset_password(data: PasswordResetConfirm):
    """
    Reset a password using an emailed verification code.

    Args:
        data: Payload with ``email``, ``verification_code`` and
            ``new_password`` (at least 6 characters).

    Returns:
        A dict with ``status`` and a confirmation ``message``.

    Raises:
        HTTPException: 400 when the password is too short or the code is
            invalid/expired, 404 when no user has that email, 500 on any
            other failure while updating.
    """
    logger.info(f"Password reset attempt for email: {data.email}")
    email = data.email.lower().strip()
    code = data.verification_code.strip()

    # Validate new password
    if len(data.new_password) < 6:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="New password must be at least 6 characters long"
        )

    # Find and validate the verification code
    code_document = await password_reset_codes_collection.find_one({
        "email": email,
        "verification_code": code,
        "used": False
    })
    if not code_document:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or expired verification code"
        )

    # Check if code has expired
    if email_service.is_code_expired(code_document["expires_at"]):
        # Remove expired code
        await password_reset_codes_collection.delete_one({"_id": code_document["_id"]})
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Verification code has expired"
        )

    # Hash new password
    hashed_new_password = hash_password(data.new_password)

    try:
        # Update user password
        result = await users_collection.update_one(
            {"email": email},
            {
                "$set": {
                    "password": hashed_new_password,
                    "updated_at": datetime.utcnow().isoformat()
                }
            }
        )
        # matched_count is the existence check, consistent with the admin endpoints.
        if result.matched_count == 0:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found"
            )

        # Mark verification code as used so it cannot be replayed
        await password_reset_codes_collection.update_one(
            {"_id": code_document["_id"]},
            {"$set": {"used": True}}
        )

        logger.info(f"Password reset successfully for {email}")
        return {
            "status": "success",
            "message": "Password reset successfully"
        }
    except HTTPException:
        # Keep the 404 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Failed to reset password for {email}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to reset password"
        )
| # Re-export the router under the name 'auth', the name api.__init__.py expects. | |
| auth = router | |