Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, Depends, HTTPException, status, Query, WebSocket, WebSocketDisconnect, UploadFile, File | |
| from typing import List, Optional | |
| from datetime import datetime, timedelta | |
| from bson import ObjectId | |
| from motor.motor_asyncio import AsyncIOMotorClient | |
| import json | |
| import asyncio | |
| from collections import defaultdict | |
| import os | |
| import uuid | |
| from pathlib import Path | |
| from core.security import get_current_user | |
| from db.mongo import db | |
| from models.schemas import ( | |
| MessageCreate, MessageUpdate, MessageResponse, MessageListResponse, | |
| ConversationResponse, ConversationListResponse, MessageType, MessageStatus, | |
| NotificationCreate, NotificationResponse, NotificationListResponse, | |
| NotificationType, NotificationPriority | |
| ) | |
# APIRouter for this module: every route registered on it is served under
# the "/messaging" prefix and grouped under the "messaging" OpenAPI tag.
router = APIRouter(prefix="/messaging", tags=["messaging"])
class ConnectionManager:
    """Registry of open WebSocket connections, keyed by user ID.

    A user may have several connections open at once; every message pushed
    to that user is sent to all of them.
    """

    def __init__(self):
        # user_id -> list of open WebSocket connections for that user
        self.active_connections: dict = defaultdict(list)

    async def connect(self, websocket: WebSocket, user_id: str):
        """Accept the WebSocket handshake and register the connection."""
        await websocket.accept()
        self.active_connections[user_id].append(websocket)

    def disconnect(self, websocket: WebSocket, user_id: str):
        """Unregister ``websocket`` for ``user_id``; unknown pairs are ignored.

        Connections are matched by identity: Starlette's WebSocket subclasses
        Mapping over its ASGI scope, so ``==`` compares scope contents rather
        than connection objects.
        """
        connections = self.active_connections.get(user_id)
        if connections is None:
            return
        remaining = [conn for conn in connections if conn is not websocket]
        if remaining:
            self.active_connections[user_id] = remaining
        else:
            # Drop the key so users with no open connections don't linger.
            del self.active_connections[user_id]

    @staticmethod
    def _json_default(value):
        """Encode the non-JSON types that appear in messaging payloads."""
        from enum import Enum

        if isinstance(value, datetime):
            return value.isoformat()
        if isinstance(value, Enum):
            return value.value
        if isinstance(value, ObjectId):
            return str(value)
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    async def send_personal_message(self, message: dict, user_id: str):
        """Send ``message`` as JSON text to every open connection of ``user_id``.

        Delivery is best-effort. Connections whose send fails are
        unregistered. A payload that cannot be serialized is logged and
        dropped; the connection is not blamed for it and not evicted.
        """
        connections = self.active_connections.get(user_id)
        if not connections:
            return
        try:
            # Serialize once, outside the send loop.
            payload = json.dumps(message, default=self._json_default)
        except (TypeError, ValueError) as exc:
            print(f"❌ Could not serialize WebSocket message for user {user_id}: {exc}")
            return

        dead = []
        # Iterate over a snapshot: removing from the list being iterated
        # would silently skip the connection after each removed one.
        for connection in list(connections):
            try:
                await connection.send_text(payload)
            except Exception:
                dead.append(connection)
        for connection in dead:
            self.disconnect(connection, user_id)
# Process-wide registry of live WebSocket connections. websocket_endpoint
# registers clients here, and create_notification / send_message push events
# through it.
# NOTE(review): the state lives only in this process's memory. With several
# server workers, a push reaches only the clients connected to the same worker.
manager = ConnectionManager()
| # --- HELPER FUNCTIONS --- | |
def is_valid_object_id(id_str: str) -> bool:
    """Return True if ``id_str`` can be parsed as a BSON ObjectId.

    Delegates to ``ObjectId.is_valid``. Unlike a bare ``ObjectId(id_str)``
    call, it rejects ``None``/empty values instead of silently generating a
    fresh id. It also catches only the errors ObjectId actually raises,
    instead of swallowing everything.
    """
    return ObjectId.is_valid(id_str)
def get_conversation_id(user1_id, user2_id) -> str:
    """Build an order-independent conversation ID for a pair of users.

    Both IDs are stringified, so ObjectIds and strings can be mixed. They
    are then sorted lexicographically and joined with an underscore. As a
    result, ``get_conversation_id(a, b) == get_conversation_id(b, a)``.
    """
    low, high = sorted((str(user1_id), str(user2_id)))
    return "_".join((low, high))
async def create_notification(
    db_client: AsyncIOMotorClient,
    recipient_id: str,
    title: str,
    message: str,
    notification_type: NotificationType,
    priority: NotificationPriority = NotificationPriority.MEDIUM,
    data: Optional[dict] = None
):
    """Store a notification and push it to the recipient over WebSocket.

    Args:
        db_client: database handle exposing a ``notifications`` collection.
        recipient_id: ID (string or ObjectId) of the user to notify.
        title: short notification title.
        message: notification body text.
        notification_type: category of the notification.
        priority: notification priority (defaults to MEDIUM).
        data: optional extra payload; stored as ``{}`` when omitted.

    Returns:
        The inserted document, including its new ``_id``.
    """
    notification_doc = {
        "recipient_id": ObjectId(recipient_id),
        "title": title,
        "message": message,
        "notification_type": notification_type,
        "priority": priority,
        "data": data or {},
        "is_read": False,
        "created_at": datetime.now()
    }
    result = await db_client.notifications.insert_one(notification_doc)
    notification_doc["_id"] = result.inserted_id

    # WebSocket payloads go through json.dumps, which cannot encode ObjectId
    # or datetime values, so both are converted to strings here.
    notification_for_ws = {
        "id": str(notification_doc["_id"]),
        "recipient_id": str(notification_doc["recipient_id"]),
        "title": notification_doc["title"],
        "message": notification_doc["message"],
        "notification_type": notification_doc["notification_type"],
        "priority": notification_doc["priority"],
        "data": notification_doc["data"],
        "is_read": notification_doc["is_read"],
        "created_at": notification_doc["created_at"].isoformat()
    }
    # Connections are registered under string user IDs; normalise the key in
    # case the caller passed an ObjectId.
    await manager.send_personal_message({
        "type": "new_notification",
        "data": notification_for_ws
    }, str(recipient_id))
    return notification_doc
| # --- WEBSOCKET ENDPOINT --- | |
async def websocket_endpoint(websocket: WebSocket, user_id: str):
    """Keep a WebSocket open for ``user_id`` so real-time events can be pushed.

    The only client message acted on is ``{"type": "ping"}``, which is
    answered with ``{"type": "pong"}`` as a keep-alive. Everything else,
    including invalid JSON, is ignored. The connection is unregistered from
    ``manager`` however the loop ends.

    NOTE(review): ``user_id`` comes from the client and is never
    authenticated, so anyone can subscribe to another user's messages and
    notifications. Verify a token before calling ``manager.connect``.
    """
    await manager.connect(websocket, user_id)
    print(f"🔌 WebSocket connected for user: {user_id}")
    try:
        while True:
            data = await websocket.receive_text()
            try:
                message_data = json.loads(data)
            except json.JSONDecodeError:
                continue  # Ignore invalid JSON
            # Valid JSON need not be an object (e.g. "[]" or "5"); calling
            # .get() on those used to raise and tear down the connection.
            if isinstance(message_data, dict) and message_data.get("type") == "ping":
                # Send pong to keep connection alive
                await websocket.send_text(json.dumps({"type": "pong"}))
    except WebSocketDisconnect:
        print(f"🔌 WebSocket disconnected for user: {user_id}")
    except Exception as e:
        print(f"❌ WebSocket error for user {user_id}: {e}")
    finally:
        # Also runs on task cancellation, which the except clauses don't catch.
        manager.disconnect(websocket, user_id)
| # --- CONVERSATION ENDPOINTS --- | |
async def get_conversations(
    page: int = Query(1, ge=1),
    limit: int = Query(20, ge=1, le=100),
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """List the current user's conversations, most recently active first.

    A conversation is the set of messages exchanged with one other user.
    Each entry carries the latest message and the number of messages sent to
    the current user whose status is not ``"read"``. Partners whose user
    document no longer exists are left out of the page.

    Args:
        page: 1-based page number.
        limit: conversations per page (1-100).

    Returns:
        ConversationListResponse, where ``total`` is the number of distinct
        conversation partners.
    """
    skip = (page - 1) * limit
    user_id = current_user["_id"]
    user_oid = ObjectId(user_id)
    user_id_str = str(user_id)

    # Shared by the page query and the count query.
    match_stage = {
        "$match": {"$or": [{"sender_id": user_oid}, {"recipient_id": user_oid}]}
    }
    # Group key: the *other* participant of each message.
    other_party = {
        "$cond": [{"$eq": ["$sender_id", user_oid]}, "$recipient_id", "$sender_id"]
    }

    pipeline = [
        match_stage,
        {"$sort": {"created_at": -1}},
        {
            "$group": {
                "_id": other_party,
                # Input is sorted newest-first, so $first is the latest message.
                "last_message": {"$first": "$$ROOT"},
                "unread_count": {
                    "$sum": {
                        "$cond": [
                            {
                                "$and": [
                                    {"$eq": ["$recipient_id", user_oid]},
                                    {"$ne": ["$status", "read"]}
                                ]
                            },
                            1,
                            0
                        ]
                    }
                }
            }
        },
        {"$sort": {"last_message.created_at": -1}},
        {"$skip": skip},
        {"$limit": limit}
    ]
    conversations_data = await db_client.messages.aggregate(pipeline).to_list(length=limit)

    # Fetch every conversation partner in one query instead of one per row.
    partner_ids = [conv_data["_id"] for conv_data in conversations_data]
    users_by_id = {}
    if partner_ids:
        partners = await db_client.users.find(
            {"_id": {"$in": partner_ids}}
        ).to_list(length=len(partner_ids))
        users_by_id = {user["_id"]: user for user in partners}

    conversations = []
    for conv_data in conversations_data:
        other_user = users_by_id.get(conv_data["_id"])
        if not other_user:
            continue
        other_user_id = str(conv_data["_id"])
        # $first of $$ROOT always yields a document, so this is never empty.
        last = conv_data["last_message"]
        sent_by_me = last["sender_id"] == user_oid

        last_message = MessageResponse(
            id=str(last["_id"]),
            sender_id=str(last["sender_id"]),
            recipient_id=str(last["recipient_id"]),
            sender_name=current_user["full_name"] if sent_by_me else other_user["full_name"],
            recipient_name=other_user["full_name"] if sent_by_me else current_user["full_name"],
            content=last["content"],
            message_type=last["message_type"],
            attachment_url=last.get("attachment_url"),
            reply_to_message_id=str(last["reply_to_message_id"]) if last.get("reply_to_message_id") else None,
            status=last["status"],
            is_archived=last.get("is_archived", False),
            created_at=last["created_at"],
            updated_at=last["updated_at"],
            read_at=last.get("read_at")
        )
        conversations.append(ConversationResponse(
            id=get_conversation_id(user_id_str, other_user_id),
            participant_ids=[user_id_str, other_user_id],
            participant_names=[current_user["full_name"], other_user["full_name"]],
            last_message=last_message,
            unread_count=conv_data["unread_count"],
            # Conversations have no stored timestamps; use the latest message's.
            created_at=last["created_at"],
            updated_at=last["updated_at"]
        ))

    # Total number of distinct conversation partners (ignores pagination).
    total_pipeline = [
        match_stage,
        {"$group": {"_id": other_party}},
        {"$count": "total"}
    ]
    total_result = await db_client.messages.aggregate(total_pipeline).to_list(length=1)
    total = total_result[0]["total"] if total_result else 0

    return ConversationListResponse(
        conversations=conversations,
        total=total,
        page=page,
        limit=limit
    )
| # --- MESSAGE ENDPOINTS --- | |
def _collect_roles(user: dict) -> set:
    """Return a user's roles from both the ``roles`` list and a legacy ``role`` string."""
    # Copy into a new set: appending to user["roles"] would mutate the caller's
    # document (e.g. the current_user dict shared with the auth dependency).
    roles = set(user.get("roles") or [])
    role = user.get("role")
    if isinstance(role, str):
        roles.add(role)
    return roles


async def send_message(
    message_data: MessageCreate,
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """Send a message from the current user to another user.

    Validates the recipient, the patient/doctor messaging rule and the
    optional reply target. It then stores the message, pushes it to the
    recipient over WebSocket and creates a notification for them.

    Raises:
        HTTPException 400: malformed recipient or reply-to ID.
        HTTPException 403: the role rule forbids messaging this recipient.
        HTTPException 404: recipient not found, or reply-to message not found
            in this conversation.
    """
    if not is_valid_object_id(message_data.recipient_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid recipient ID"
        )
    sender_oid = ObjectId(current_user["_id"])
    recipient_oid = ObjectId(message_data.recipient_id)

    recipient = await db_client.users.find_one({"_id": recipient_oid})
    if not recipient:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Recipient not found"
        )

    # Patients may only message doctors, and doctors only patients.
    # NOTE(review): only roles are checked, not an actual care relationship.
    sender_roles = _collect_roles(current_user)
    recipient_roles = _collect_roles(recipient)
    if 'patient' in sender_roles:
        if 'doctor' not in recipient_roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Patients can only message doctors"
            )
    elif 'doctor' in sender_roles:
        if 'patient' not in recipient_roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Doctors can only message patients"
            )

    reply_oid = None
    if message_data.reply_to_message_id:
        if not is_valid_object_id(message_data.reply_to_message_id):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid reply message ID"
            )
        reply_oid = ObjectId(message_data.reply_to_message_id)
        # Only messages from this same conversation can be replied to.
        # Messages from other conversations are reported as missing, so the
        # endpoint does not confirm that foreign IDs exist.
        reply_message = await db_client.messages.find_one({
            "_id": reply_oid,
            "$or": [
                {"sender_id": sender_oid, "recipient_id": recipient_oid},
                {"sender_id": recipient_oid, "recipient_id": sender_oid}
            ]
        })
        if not reply_message:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Reply message not found"
            )

    now = datetime.now()  # one timestamp so created_at == updated_at
    message_doc = {
        "sender_id": sender_oid,
        "recipient_id": recipient_oid,
        "content": message_data.content,
        "message_type": message_data.message_type,
        "attachment_url": message_data.attachment_url,
        "reply_to_message_id": reply_oid,
        "status": MessageStatus.SENT,
        "is_archived": False,
        "created_at": now,
        "updated_at": now
    }
    result = await db_client.messages.insert_one(message_doc)
    message_doc["_id"] = result.inserted_id

    message_id_str = str(message_doc["_id"])
    reply_id_str = str(reply_oid) if reply_oid else None

    # Real-time push. Datetimes are sent as ISO strings because the payload
    # is encoded with json.dumps.
    await manager.send_personal_message({
        "type": "new_message",
        "data": {
            "id": message_id_str,
            "sender_id": str(sender_oid),
            "recipient_id": str(recipient_oid),
            "sender_name": current_user["full_name"],
            "recipient_name": recipient["full_name"],
            "content": message_doc["content"],
            "message_type": message_doc["message_type"],
            "attachment_url": message_doc["attachment_url"],
            "reply_to_message_id": reply_id_str,
            "status": message_doc["status"],
            "is_archived": message_doc["is_archived"],
            "created_at": now.isoformat(),
            "updated_at": now.isoformat()
        }
    }, message_data.recipient_id)

    content = message_data.content
    await create_notification(
        db_client,
        message_data.recipient_id,
        f"New message from {current_user['full_name']}",
        content[:100] + "..." if len(content) > 100 else content,
        NotificationType.MESSAGE,
        NotificationPriority.MEDIUM,
        {"message_id": message_id_str, "sender_id": str(current_user["_id"])}
    )

    return MessageResponse(
        id=message_id_str,
        sender_id=str(sender_oid),
        recipient_id=str(recipient_oid),
        sender_name=current_user["full_name"],
        recipient_name=recipient["full_name"],
        content=message_doc["content"],
        message_type=message_doc["message_type"],
        attachment_url=message_doc["attachment_url"],
        reply_to_message_id=reply_id_str,
        status=message_doc["status"],
        is_archived=message_doc["is_archived"],
        created_at=now,
        updated_at=now
    )
async def get_messages(
    conversation_id: str,
    page: int = Query(1, ge=1),
    limit: int = Query(50, ge=1, le=100),
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """Return one page of messages in a conversation, newest first.

    ``conversation_id`` has the form produced by ``get_conversation_id``
    (``"<id>_<id>"``) and must contain the current user's ID. Messages on
    this page that were sent to the current user are marked ``"read"`` in
    the database. The returned items still show their status from before
    that update.

    Raises:
        HTTPException 400: malformed conversation ID.
        HTTPException 403: the current user is not a participant.
        HTTPException 404: the other participant does not exist.
    """
    skip = (page - 1) * limit
    user_id = current_user["_id"]
    user_id_str = str(user_id)
    user_oid = ObjectId(user_id)

    # Explicit HTTPExceptions below propagate unchanged. The previous blanket
    # try/except turned the 404 and the format error into a generic 400.
    participant_ids = conversation_id.split("_")
    if len(participant_ids) != 2:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid conversation ID format"
        )
    if user_id_str not in participant_ids:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You are not a participant in this conversation"
        )
    other_user_id = next((pid for pid in participant_ids if pid != user_id_str), None)
    if not other_user_id or not is_valid_object_id(other_user_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid conversation ID"
        )
    other_oid = ObjectId(other_user_id)

    other_user = await db_client.users.find_one({"_id": other_oid})
    if not other_user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Conversation participant not found"
        )

    # Messages in either direction between the two users.
    filter_query = {
        "$or": [
            {"sender_id": user_oid, "recipient_id": other_oid},
            {"sender_id": other_oid, "recipient_id": user_oid}
        ]
    }
    cursor = db_client.messages.find(filter_query).sort("created_at", -1).skip(skip).limit(limit)
    messages = await cursor.to_list(length=limit)

    # Mark this page's messages addressed to the current user as read.
    unread_messages = [
        msg["_id"] for msg in messages
        if msg["recipient_id"] == user_oid and msg["status"] != "read"
    ]
    if unread_messages:
        await db_client.messages.update_many(
            {"_id": {"$in": unread_messages}},
            {"$set": {"status": "read", "read_at": datetime.now()}}
        )

    total = await db_client.messages.count_documents(filter_query)

    # Every message here is between these two users, so names come from the
    # documents already loaded. This avoids two user lookups per message.
    names = {user_oid: current_user["full_name"], other_oid: other_user["full_name"]}

    message_responses = [
        MessageResponse(
            id=str(msg["_id"]),
            sender_id=str(msg["sender_id"]),
            recipient_id=str(msg["recipient_id"]),
            sender_name=names.get(msg["sender_id"], "Unknown User"),
            recipient_name=names.get(msg["recipient_id"], "Unknown User"),
            content=msg["content"],
            message_type=msg["message_type"],
            attachment_url=msg.get("attachment_url"),
            reply_to_message_id=str(msg["reply_to_message_id"]) if msg.get("reply_to_message_id") else None,
            status=msg["status"],
            is_archived=msg.get("is_archived", False),
            created_at=msg["created_at"],
            updated_at=msg["updated_at"],
            read_at=msg.get("read_at")
        )
        for msg in messages
    ]

    return MessageListResponse(
        messages=message_responses,
        total=total,
        page=page,
        limit=limit,
        conversation_id=conversation_id
    )
async def update_message(
    message_id: str,
    message_data: MessageUpdate,
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """Edit the content or archive flag of a message; only its sender may do so.

    Fields left as ``None`` in ``message_data`` are not changed.
    ``updated_at`` is always refreshed.

    Raises:
        HTTPException 400: ``message_id`` is not a valid ObjectId.
        HTTPException 403: the current user did not send the message.
        HTTPException 404: the message does not exist, or was deleted
            while being updated.
    """
    if not is_valid_object_id(message_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid message ID"
        )
    message_oid = ObjectId(message_id)
    sender_oid = ObjectId(current_user["_id"])

    message = await db_client.messages.find_one({"_id": message_oid})
    if not message:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Message not found"
        )
    if message["sender_id"] != sender_oid:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You can only update your own messages"
        )

    update_data = {"updated_at": datetime.now()}
    if message_data.content is not None:
        update_data["content"] = message_data.content
    if message_data.is_archived is not None:
        update_data["is_archived"] = message_data.is_archived

    # Including sender_id in the filter keeps the ownership check atomic
    # with the write.
    await db_client.messages.update_one(
        {"_id": message_oid, "sender_id": sender_oid},
        {"$set": update_data}
    )

    updated_message = await db_client.messages.find_one({"_id": message_oid})
    if not updated_message:
        # Deleted between the update and this read.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Message not found"
        )

    # The sender is the current user, so only the recipient needs a lookup.
    recipient = await db_client.users.find_one({"_id": updated_message["recipient_id"]})

    return MessageResponse(
        id=str(updated_message["_id"]),
        sender_id=str(updated_message["sender_id"]),
        recipient_id=str(updated_message["recipient_id"]),
        sender_name=current_user["full_name"],
        recipient_name=recipient["full_name"] if recipient else "Unknown User",
        content=updated_message["content"],
        message_type=updated_message["message_type"],
        attachment_url=updated_message.get("attachment_url"),
        reply_to_message_id=str(updated_message["reply_to_message_id"]) if updated_message.get("reply_to_message_id") else None,
        status=updated_message["status"],
        is_archived=updated_message.get("is_archived", False),
        created_at=updated_message["created_at"],
        updated_at=updated_message["updated_at"],
        read_at=updated_message.get("read_at")
    )
async def delete_message(
    message_id: str,
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """Permanently delete a message; only the user who sent it may do so.

    Returns nothing on success.

    Raises:
        HTTPException 400: ``message_id`` is not a valid ObjectId.
        HTTPException 404: no message has that ID.
        HTTPException 403: the current user is not the sender.
    """
    if not is_valid_object_id(message_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid message ID"
        )

    target_oid = ObjectId(message_id)
    messages = db_client.messages

    existing = await messages.find_one({"_id": target_oid})
    if not existing:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Message not found"
        )

    is_owner = existing["sender_id"] == ObjectId(current_user["_id"])
    if not is_owner:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You can only delete your own messages"
        )

    await messages.delete_one({"_id": target_oid})
| # --- NOTIFICATION ENDPOINTS --- | |
async def get_notifications(
    page: int = Query(1, ge=1),
    limit: int = Query(20, ge=1, le=100),
    unread_only: bool = Query(False),
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """Return a page of the current user's notifications, newest first.

    Args:
        page: 1-based page number.
        limit: notifications per page (1-100).
        unread_only: when True, only notifications with ``is_read`` False.

    Returns:
        NotificationListResponse. ``total`` counts the notifications matching
        the filter above. ``unread_count`` counts all unread notifications,
        whatever ``unread_only`` is.
    """
    owner_oid = ObjectId(current_user["_id"])
    criteria = {"recipient_id": owner_oid}
    if unread_only:
        criteria["is_read"] = False

    offset = (page - 1) * limit
    docs = await (
        db_client.notifications.find(criteria)
        .sort("created_at", -1)
        .skip(offset)
        .limit(limit)
        .to_list(length=limit)
    )

    total = await db_client.notifications.count_documents(criteria)
    unread_count = await db_client.notifications.count_documents({
        "recipient_id": ObjectId(current_user["_id"]),
        "is_read": False
    })

    def as_response(doc):
        # ObjectIds inside the free-form "data" payload are not
        # JSON-friendly; expose them as strings.
        extra = doc.get("data", {})
        if extra:
            extra = {
                key: str(val) if isinstance(val, ObjectId) else val
                for key, val in extra.items()
            }
        return NotificationResponse(
            id=str(doc["_id"]),
            recipient_id=str(doc["recipient_id"]),
            recipient_name=current_user["full_name"],
            title=doc["title"],
            message=doc["message"],
            notification_type=doc["notification_type"],
            priority=doc["priority"],
            data=extra,
            is_read=doc.get("is_read", False),
            created_at=doc["created_at"],
            read_at=doc.get("read_at")
        )

    return NotificationListResponse(
        notifications=[as_response(doc) for doc in docs],
        total=total,
        unread_count=unread_count,
        page=page,
        limit=limit
    )
async def mark_notification_read(
    notification_id: str,
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """Flag one of the current user's notifications as read and return it.

    Raises:
        HTTPException 400: ``notification_id`` is not a valid ObjectId.
        HTTPException 404: no notification has that ID.
        HTTPException 403: the notification belongs to another user.
    """
    if not is_valid_object_id(notification_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid notification ID"
        )

    notif_oid = ObjectId(notification_id)
    collection = db_client.notifications

    existing = await collection.find_one({"_id": notif_oid})
    if not existing:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Notification not found"
        )
    if existing["recipient_id"] != ObjectId(current_user["_id"]):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You can only mark your own notifications as read"
        )

    await collection.update_one(
        {"_id": notif_oid},
        {"$set": {"is_read": True, "read_at": datetime.now()}}
    )
    refreshed = await collection.find_one({"_id": notif_oid})

    # Stringify ObjectIds inside the free-form "data" payload.
    extra = refreshed.get("data", {})
    if extra:
        extra = {
            key: str(val) if isinstance(val, ObjectId) else val
            for key, val in extra.items()
        }

    return NotificationResponse(
        id=str(refreshed["_id"]),
        recipient_id=str(refreshed["recipient_id"]),
        recipient_name=current_user["full_name"],
        title=refreshed["title"],
        message=refreshed["message"],
        notification_type=refreshed["notification_type"],
        priority=refreshed["priority"],
        data=extra,
        is_read=refreshed.get("is_read", False),
        created_at=refreshed["created_at"],
        read_at=refreshed.get("read_at")
    )
async def mark_all_notifications_read(
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """Mark every unread notification of the current user as read.

    Returns:
        A confirmation message dict.
    """
    unread_filter = {"recipient_id": ObjectId(current_user["_id"]), "is_read": False}
    mark_read = {"$set": {"is_read": True, "read_at": datetime.now()}}
    await db_client.notifications.update_many(unread_filter, mark_read)
    return {"message": "All notifications marked as read"}
| # --- FILE UPLOAD ENDPOINT --- | |
def _get_upload_dir() -> Path:
    """Return a writable root directory for uploads, creating it if needed."""
    import tempfile

    # ./uploads first; fall back to the system temp dir when the working
    # directory is read-only (e.g. some containers).
    candidates = (Path.cwd() / "uploads", Path(tempfile.gettempdir()) / "uploads")
    for candidate in candidates:
        try:
            candidate.mkdir(exist_ok=True)
            return candidate
        except OSError:  # includes PermissionError
            continue
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail="No writable upload directory available"
    )


async def upload_file(
    file: UploadFile = File(...),
    current_user: dict = Depends(get_current_user)
):
    """Store an uploaded attachment and return the URL/metadata for a message.

    Files are accepted by extension only (see ``allowed_types``) and must be
    at most 10MB. They are saved under ``<upload root>/<category>/`` with a
    random UUID name, so client-supplied names never reach the filesystem.

    Returns:
        dict with ``filename`` (original name), ``file_url``, ``file_size``
        (bytes written), ``file_type`` (category) and ``message_type``.

    Raises:
        HTTPException 400: disallowed or missing extension, or file too large.
        HTTPException 500: the file could not be written.
    """
    allowed_types = {
        'image': ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'],
        'document': ['.pdf', '.doc', '.docx', '.txt', '.rtf'],
        'spreadsheet': ['.xls', '.xlsx', '.csv'],
        'presentation': ['.ppt', '.pptx'],
        'archive': ['.zip', '.rar', '.7z']
    }
    max_size = 10 * 1024 * 1024  # 10MB
    chunk_size = 1024 * 1024

    # UploadFile.filename may be None; Path(None) would raise TypeError.
    file_ext = Path(file.filename or "").suffix.lower()
    file_category = next(
        (category for category, extensions in allowed_types.items() if file_ext in extensions),
        None
    )
    if file_category is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"File type {file_ext} is not allowed. Allowed types: {', '.join([ext for exts in allowed_types.values() for ext in exts])}"
        )

    # Reject early when the client declared its size. The size may be absent
    # or wrong, so the limit is enforced again while streaming below.
    declared_size = getattr(file, "size", None)
    if declared_size and declared_size > max_size:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="File size exceeds maximum limit of 10MB"
        )

    category_dir = _get_upload_dir() / file_category
    category_dir.mkdir(exist_ok=True)
    unique_filename = f"{uuid.uuid4()}{file_ext}"
    file_path = category_dir / unique_filename

    written = 0
    try:
        with open(file_path, "wb") as buffer:
            while True:
                chunk = await file.read(chunk_size)
                if not chunk:
                    break
                written += len(chunk)
                if written > max_size:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail="File size exceeds maximum limit of 10MB"
                    )
                buffer.write(chunk)
    except HTTPException:
        # Client error: never leave a partial file behind, keep the 400.
        file_path.unlink(missing_ok=True)
        raise
    except Exception as e:
        file_path.unlink(missing_ok=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to upload file: {str(e)}"
        ) from e

    return {
        "filename": file.filename,
        "file_url": f"/uploads/{file_category}/{unique_filename}",
        "file_size": written,
        "file_type": file_category,
        "message_type": MessageType.IMAGE if file_category == 'image' else MessageType.FILE
    }
| # --- STATIC FILE SERVING --- | |
async def serve_file(category: str, filename: str):
    """Return an uploaded attachment, with a content type chosen by extension.

    Searches ``<cwd>/uploads`` and then ``<system temp dir>/uploads``, the
    locations the upload code in this module writes to. Each path segment
    must be a single plain name, so a request such as ``category=".."``
    cannot escape the upload root.

    NOTE(review): no authentication is performed; anyone with the URL can
    download the file. Consider requiring ``get_current_user`` and checking
    that the requester takes part in a message that references it.

    Raises:
        HTTPException 404: invalid path segment, or file not found.
        HTTPException 500: the file exists but could not be read.
    """
    import tempfile
    from fastapi.responses import Response

    # Reject separators, "." / "..", and hidden names in either segment.
    for segment in (category, filename):
        if not segment or segment != Path(segment).name or segment.startswith("."):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="File not found"
            )

    upload_roots = (Path.cwd() / "uploads", Path(tempfile.gettempdir()) / "uploads")
    file_path = next(
        (root / category / filename for root in upload_roots
         if (root / category / filename).is_file()),
        None
    )
    if file_path is None:
        # Don't echo server filesystem paths back to the client.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )

    content_types = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.bmp': 'image/bmp',
        '.webp': 'image/webp',
        '.pdf': 'application/pdf',
        '.doc': 'application/msword',
        '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
        '.txt': 'text/plain',
        '.rtf': 'application/rtf',
        '.xls': 'application/vnd.ms-excel',
        '.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
        '.csv': 'text/csv',
        '.ppt': 'application/vnd.ms-powerpoint',
        '.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
        '.zip': 'application/zip',
        '.rar': 'application/x-rar-compressed',
        '.7z': 'application/x-7z-compressed'
    }
    content_type = content_types.get(file_path.suffix.lower(), 'application/octet-stream')

    try:
        content = file_path.read_bytes()
    except OSError as e:
        print(f"❌ Error reading file: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error reading file: {str(e)}"
        ) from e

    # Quote the filename per RFC 6266, escaping backslashes and quotes.
    quoted_name = filename.replace("\\", "\\\\").replace('"', '\\"')
    return Response(
        content=content,
        media_type=content_type,
        headers={
            "Content-Disposition": f'inline; filename="{quoted_name}"',
            # Attachments may be sensitive; shared caches must not store them.
            # Names are random UUIDs and never reused, so a long max-age is safe.
            "Cache-Control": "private, max-age=31536000"
        }
    )