# cps-api-tx / api/routes/messaging.py
# Last commit by Ali2206 (c378250): Fix PermissionError with robust directory creation fallback
import asyncio
import json
import os
import uuid
from collections import defaultdict
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import List, Optional

from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException, status, Query, WebSocket, WebSocketDisconnect, UploadFile, File
from motor.motor_asyncio import AsyncIOMotorClient

from core.security import get_current_user
from db.mongo import db
from models.schemas import (
    MessageCreate, MessageUpdate, MessageResponse, MessageListResponse,
    ConversationResponse, ConversationListResponse, MessageType, MessageStatus,
    NotificationCreate, NotificationResponse, NotificationListResponse,
    NotificationType, NotificationPriority
)
# Every route in this module is registered under the "/messaging" prefix.
router = APIRouter(prefix="/messaging", tags=["messaging"])
# WebSocket connection manager
def _json_default(value):
    """Encode values that the stdlib ``json`` module cannot serialize natively.

    Payloads pushed from this module contain ``datetime`` timestamps and may
    contain ``ObjectId`` or ``Enum`` values; without this hook ``json.dumps``
    raises ``TypeError`` for them.

    Raises:
        TypeError: for any other unsupported type.
    """
    if isinstance(value, datetime):
        return value.isoformat()
    if isinstance(value, ObjectId):
        return str(value)
    if isinstance(value, Enum):
        return value.value
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


class ConnectionManager:
    """Track open WebSocket connections per user and push JSON events to them."""

    def __init__(self):
        # user_id -> list of open WebSocket connections for that user
        self.active_connections: dict = defaultdict(list)

    async def connect(self, websocket: WebSocket, user_id: str):
        """Accept the WebSocket handshake and register it under ``user_id``."""
        await websocket.accept()
        self.active_connections[user_id].append(websocket)

    def disconnect(self, websocket: WebSocket, user_id: str):
        """Unregister ``websocket``; drop the user's entry once no connections remain."""
        connections = self.active_connections.get(user_id)
        if connections is None:
            return
        remaining = [conn for conn in connections if conn != websocket]
        if remaining:
            self.active_connections[user_id] = remaining
        else:
            # Avoid accumulating empty lists for users who have gone offline.
            del self.active_connections[user_id]

    async def send_personal_message(self, message: dict, user_id: str):
        """Send ``message`` as a JSON text frame to every connection of ``user_id``.

        Delivery is best-effort: connections whose send fails are unregistered,
        and a payload that cannot be serialized is reported and skipped rather
        than raised to the caller.
        """
        connections = self.active_connections.get(user_id)
        if not connections:
            return

        # Serialize once, outside the send loop, so an encoding problem is not
        # mistaken for a dead connection (previously a datetime in the payload
        # made every send "fail" and dropped the user's live sockets).
        try:
            payload = json.dumps(message, default=_json_default)
        except (TypeError, ValueError) as e:
            print(f"❌ Could not serialize WebSocket message for user {user_id}: {e}")
            return

        dead_connections = []
        # Iterate over a snapshot: the list may change while we await sends.
        for connection in list(connections):
            try:
                await connection.send_text(payload)
            except Exception:
                dead_connections.append(connection)

        for connection in dead_connections:
            self.disconnect(connection, user_id)


manager = ConnectionManager()
# --- HELPER FUNCTIONS ---
def is_valid_object_id(id_str: str) -> bool:
    """Return True if ``id_str`` is a valid BSON ``ObjectId`` value.

    Delegates to ``ObjectId.is_valid``. That rejects falsy values, including
    ``None``, which ``ObjectId(None)`` would silently turn into a freshly
    generated id. It catches only the conversion errors (``InvalidId`` and
    ``TypeError``), instead of a bare ``except:`` that would also swallow
    ``KeyboardInterrupt`` or ``SystemExit``.
    """
    return ObjectId.is_valid(id_str)
def get_conversation_id(user1_id, user2_id) -> str:
    """Build an order-independent conversation key for a pair of users.

    Both IDs are converted to strings and joined with ``_``, lexicographically
    smaller first. ``(a, b)`` and ``(b, a)`` therefore produce the same key.
    """
    low, high = str(user1_id), str(user2_id)
    if high < low:
        low, high = high, low
    return f"{low}_{high}"
async def create_notification(
    db_client: AsyncIOMotorClient,
    recipient_id: str,
    title: str,
    message: str,
    notification_type: NotificationType,
    priority: NotificationPriority = NotificationPriority.MEDIUM,
    data: Optional[dict] = None
):
    """Store a notification for ``recipient_id`` and push it over WebSocket.

    Args:
        db_client: database handle exposing a ``notifications`` collection.
        recipient_id: string form of the recipient's user ObjectId.
        title: short notification title.
        message: notification body text.
        notification_type: category of the notification.
        priority: notification priority (MEDIUM by default).
        data: optional extra payload; stored as ``{}`` when omitted.

    Returns:
        The inserted document, including its generated ``_id``.
    """
    notification_doc = {
        "recipient_id": ObjectId(recipient_id),
        "title": title,
        "message": message,
        "notification_type": notification_type,
        "priority": priority,
        "data": data or {},
        "is_read": False,
        "created_at": datetime.now()
    }
    result = await db_client.notifications.insert_one(notification_doc)
    notification_doc["_id"] = result.inserted_id

    # WebSocket frames are JSON text: stringify the ObjectIds and ISO-format
    # the timestamp, because json.dumps cannot encode datetime objects.
    notification_for_ws = {
        "id": str(notification_doc["_id"]),
        "recipient_id": str(notification_doc["recipient_id"]),
        "title": notification_doc["title"],
        "message": notification_doc["message"],
        "notification_type": notification_doc["notification_type"],
        "priority": notification_doc["priority"],
        "data": notification_doc["data"],
        "is_read": notification_doc["is_read"],
        "created_at": notification_doc["created_at"].isoformat()
    }

    # Real-time push to any open sockets of the recipient.
    await manager.send_personal_message({
        "type": "new_notification",
        "data": notification_for_ws
    }, recipient_id)

    return notification_doc
# --- WEBSOCKET ENDPOINT ---
@router.websocket("/ws/{user_id}")
async def websocket_endpoint(websocket: WebSocket, user_id: str):
    """Keep a per-user WebSocket open so the server can push events to it.

    The client may send ``{"type": "ping"}`` frames, which are answered with
    ``{"type": "pong"}``. All other frames, including invalid JSON and JSON
    that is not an object, are ignored.

    NOTE(review): ``user_id`` comes straight from the URL with no
    authentication, so any client can subscribe to another user's messages
    and notifications. A token should be verified before ``manager.connect``.
    TODO: confirm what ``core.security`` offers for WebSocket auth.
    """
    await manager.connect(websocket, user_id)
    print(f"🔌 WebSocket connected for user: {user_id}")
    try:
        while True:
            # Wait for client frames; this also keeps the connection alive.
            data = await websocket.receive_text()
            try:
                message_data = json.loads(data)
            except json.JSONDecodeError:
                continue  # Ignore invalid JSON
            # Non-object JSON (e.g. a list or number) has no .get(); ignore it
            # instead of letting AttributeError tear the connection down.
            if isinstance(message_data, dict) and message_data.get("type") == "ping":
                await websocket.send_text(json.dumps({"type": "pong"}))
    except WebSocketDisconnect:
        print(f"🔌 WebSocket disconnected for user: {user_id}")
    except Exception as e:
        print(f"❌ WebSocket error for user {user_id}: {e}")
    finally:
        # Always unregister, including on task cancellation: CancelledError is
        # a BaseException and bypasses the handlers above.
        manager.disconnect(websocket, user_id)
# --- CONVERSATION ENDPOINTS ---
@router.get("/conversations", response_model=ConversationListResponse)
async def get_conversations(
    page: int = Query(1, ge=1),
    limit: int = Query(20, ge=1, le=100),
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """List the current user's conversations, most recently active first.

    Messages the user sent or received are grouped by the other participant.
    Each conversation carries:

    - ``last_message``: the latest message in the conversation.
    - ``unread_count``: the number of messages addressed to the current user
      whose ``status`` is not ``"read"``.

    Conversations whose other participant is missing from ``users`` are left
    off the page but still counted in ``total``.
    """
    skip = (page - 1) * limit
    user_id = current_user["_id"]
    user_oid = ObjectId(user_id)
    user_id_str = str(user_id)

    # Stages shared by the page query and the total-count query.
    match_stage = {
        "$match": {"$or": [{"sender_id": user_oid}, {"recipient_id": user_oid}]}
    }
    # Group key: the participant in each message who is not the current user.
    other_participant = {
        "$cond": [{"$eq": ["$sender_id", user_oid]}, "$recipient_id", "$sender_id"]
    }
    unread_for_me = {
        "$cond": [
            {
                "$and": [
                    {"$eq": ["$recipient_id", user_oid]},
                    {"$ne": ["$status", "read"]}
                ]
            },
            1,
            0
        ]
    }

    pipeline = [
        match_stage,
        {"$sort": {"created_at": -1}},
        {
            "$group": {
                "_id": other_participant,
                # Input is sorted newest-first, so $first is the latest message.
                "last_message": {"$first": "$$ROOT"},
                "unread_count": {"$sum": unread_for_me}
            }
        },
        {"$sort": {"last_message.created_at": -1}},
        {"$skip": skip},
        {"$limit": limit}
    ]
    conversations_data = await db_client.messages.aggregate(pipeline).to_list(length=limit)

    # Load every other participant with one $in query instead of one
    # find_one round trip per conversation.
    other_ids = [conv["_id"] for conv in conversations_data]
    users_by_id = {}
    if other_ids:
        other_users = await db_client.users.find(
            {"_id": {"$in": other_ids}}
        ).to_list(length=len(other_ids))
        users_by_id = {user["_id"]: user for user in other_users}

    conversations = []
    for conv_data in conversations_data:
        other_user = users_by_id.get(conv_data["_id"])
        if not other_user:
            continue
        other_user_id = str(conv_data["_id"])

        # $first of $$ROOT is always a full message document.
        last = conv_data["last_message"]
        sent_by_me = last["sender_id"] == user_oid
        last_message = MessageResponse(
            id=str(last["_id"]),
            sender_id=str(last["sender_id"]),
            recipient_id=str(last["recipient_id"]),
            sender_name=current_user["full_name"] if sent_by_me else other_user["full_name"],
            recipient_name=other_user["full_name"] if sent_by_me else current_user["full_name"],
            content=last["content"],
            message_type=last["message_type"],
            attachment_url=last.get("attachment_url"),
            reply_to_message_id=str(last["reply_to_message_id"]) if last.get("reply_to_message_id") else None,
            status=last["status"],
            is_archived=last.get("is_archived", False),
            created_at=last["created_at"],
            updated_at=last["updated_at"],
            read_at=last.get("read_at")
        )

        conversations.append(ConversationResponse(
            id=get_conversation_id(user_id_str, other_user_id),
            participant_ids=[user_id_str, other_user_id],
            participant_names=[current_user["full_name"], other_user["full_name"]],
            last_message=last_message,
            unread_count=conv_data["unread_count"],
            created_at=last["created_at"],
            updated_at=last["updated_at"]
        ))

    # Total number of distinct conversation partners (for pagination).
    total_pipeline = [
        match_stage,
        {"$group": {"_id": other_participant}},
        {"$count": "total"}
    ]
    total_result = await db_client.messages.aggregate(total_pipeline).to_list(length=1)
    total = total_result[0]["total"] if total_result else 0

    return ConversationListResponse(
        conversations=conversations,
        total=total,
        page=page,
        limit=limit
    )
# --- MESSAGE ENDPOINTS ---
@router.post("/messages", response_model=MessageResponse, status_code=status.HTTP_201_CREATED)
async def send_message(
    message_data: MessageCreate,
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """Send a message from the current user to ``message_data.recipient_id``.

    Steps:

    - Enforce the role rules: patients may only message doctors and doctors
      only patients; users with neither role are not restricted here.
    - Validate the optional reply target and store the message.
    - Push the message to the recipient over WebSocket and create a MESSAGE
      notification for them.

    Raises:
        HTTPException: 400 for an invalid recipient or reply message ID,
            403 when the role rules forbid messaging this recipient,
            404 when the recipient or the reply message does not exist.
    """
    if not is_valid_object_id(message_data.recipient_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid recipient ID"
        )

    recipient_oid = ObjectId(message_data.recipient_id)
    recipient = await db_client.users.find_one({"_id": recipient_oid})
    if not recipient:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Recipient not found"
        )

    def collect_roles(user: dict) -> list:
        """Return a new list of the user's ``roles`` plus the legacy ``role`` string."""
        # Always build a fresh list: appending to the list stored on the
        # document would permanently mutate current_user / recipient.
        roles = user.get('roles') or []
        roles = [roles] if isinstance(roles, str) else list(roles)
        role = user.get('role')
        if isinstance(role, str):
            roles.append(role)
        return roles

    current_user_roles = collect_roles(current_user)
    recipient_roles = collect_roles(recipient)

    # NOTE(review): only roles are checked, not an actual doctor-patient
    # relationship, so any doctor may message any patient. Confirm intent.
    if 'patient' in current_user_roles:
        if 'doctor' not in recipient_roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Patients can only message doctors"
            )
    elif 'doctor' in current_user_roles:
        if 'patient' not in recipient_roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Doctors can only message patients"
            )

    reply_oid = None
    if message_data.reply_to_message_id:
        if not is_valid_object_id(message_data.reply_to_message_id):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid reply message ID"
            )
        reply_oid = ObjectId(message_data.reply_to_message_id)
        # NOTE(review): the reply target is not required to belong to this conversation.
        reply_message = await db_client.messages.find_one({"_id": reply_oid})
        if not reply_message:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Reply message not found"
            )

    # One timestamp so a new message has created_at == updated_at.
    now = datetime.now()
    message_doc = {
        "sender_id": ObjectId(current_user["_id"]),
        "recipient_id": recipient_oid,
        "content": message_data.content,
        "message_type": message_data.message_type,
        "attachment_url": message_data.attachment_url,
        "reply_to_message_id": reply_oid,
        "status": MessageStatus.SENT,
        "is_archived": False,
        "created_at": now,
        "updated_at": now
    }
    result = await db_client.messages.insert_one(message_doc)
    message_doc["_id"] = result.inserted_id

    message_id_str = str(message_doc["_id"])
    sender_id_str = str(message_doc["sender_id"])
    reply_to_id_str = str(reply_oid) if reply_oid else None

    # Real-time push. Timestamps are ISO strings because json.dumps cannot
    # encode datetime objects.
    await manager.send_personal_message({
        "type": "new_message",
        "data": {
            "id": message_id_str,
            "sender_id": sender_id_str,
            "recipient_id": str(recipient_oid),
            "sender_name": current_user["full_name"],
            "recipient_name": recipient["full_name"],
            "content": message_doc["content"],
            "message_type": message_doc["message_type"],
            "attachment_url": message_doc["attachment_url"],
            "reply_to_message_id": reply_to_id_str,
            "status": message_doc["status"],
            "is_archived": message_doc["is_archived"],
            "created_at": now.isoformat(),
            "updated_at": now.isoformat()
        }
    }, message_data.recipient_id)

    # The notification body is the message truncated to 100 characters.
    content = message_data.content
    preview = content[:100] + "..." if len(content) > 100 else content
    await create_notification(
        db_client,
        message_data.recipient_id,
        f"New message from {current_user['full_name']}",
        preview,
        NotificationType.MESSAGE,
        NotificationPriority.MEDIUM,
        {"message_id": message_id_str, "sender_id": str(current_user["_id"])}
    )

    return MessageResponse(
        id=message_id_str,
        sender_id=sender_id_str,
        recipient_id=str(recipient_oid),
        sender_name=current_user["full_name"],
        recipient_name=recipient["full_name"],
        content=message_doc["content"],
        message_type=message_doc["message_type"],
        attachment_url=message_doc["attachment_url"],
        reply_to_message_id=reply_to_id_str,
        status=message_doc["status"],
        is_archived=message_doc["is_archived"],
        created_at=now,
        updated_at=now
    )
@router.get("/messages/{conversation_id}", response_model=MessageListResponse)
async def get_messages(
    conversation_id: str,
    page: int = Query(1, ge=1),
    limit: int = Query(50, ge=1, le=100),
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """Return one page (newest first) of messages with the other participant.

    ``conversation_id`` is two user IDs joined by ``_``, the format produced by
    ``get_conversation_id``. The ID that differs from the current user's is
    taken as the other participant. Unread messages on the returned page that
    are addressed to the current user are marked ``"read"`` in the database.

    Raises:
        HTTPException: 400 for a malformed conversation ID or an invalid
            participant ID, 404 when the other participant does not exist.
    """
    skip = (page - 1) * limit
    user_id = current_user["_id"]
    user_oid = ObjectId(user_id)
    user_id_str = str(user_id)

    # Validation is deliberately not wrapped in a blanket `except Exception`.
    # That would catch the HTTPExceptions raised here (and DB errors) and
    # report all of them as a generic 400.
    participant_ids = conversation_id.split("_")
    if len(participant_ids) != 2:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid conversation ID format"
        )

    # NOTE(review): the current user's ID is not required to appear in
    # conversation_id. If it is absent, the first ID is treated as the other
    # participant.
    other_user_id = next((pid for pid in participant_ids if pid != user_id_str), None)
    if not other_user_id or not is_valid_object_id(other_user_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid conversation ID"
        )

    other_oid = ObjectId(other_user_id)
    other_user = await db_client.users.find_one({"_id": other_oid})
    if not other_user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Conversation participant not found"
        )

    # Messages in either direction between the two users.
    filter_query = {
        "$or": [
            {"sender_id": user_oid, "recipient_id": other_oid},
            {"sender_id": other_oid, "recipient_id": user_oid}
        ]
    }
    cursor = db_client.messages.find(filter_query).sort("created_at", -1).skip(skip).limit(limit)
    messages = await cursor.to_list(length=limit)

    # Mark messages on this page that were sent to the current user as read.
    unread_ids = [
        msg["_id"] for msg in messages
        if msg["recipient_id"] == user_oid and msg["status"] != "read"
    ]
    if unread_ids:
        await db_client.messages.update_many(
            {"_id": {"$in": unread_ids}},
            {"$set": {"status": "read", "read_at": datetime.now()}}
        )

    total = await db_client.messages.count_documents(filter_query)

    # Every message matched the two-user filter above, so names come from the
    # documents already loaded rather than two user queries per message.
    names_by_id = {
        user_oid: current_user["full_name"],
        other_oid: other_user["full_name"]
    }

    message_responses = [
        MessageResponse(
            id=str(msg["_id"]),
            sender_id=str(msg["sender_id"]),
            recipient_id=str(msg["recipient_id"]),
            sender_name=names_by_id.get(msg["sender_id"], "Unknown User"),
            recipient_name=names_by_id.get(msg["recipient_id"], "Unknown User"),
            content=msg["content"],
            message_type=msg["message_type"],
            attachment_url=msg.get("attachment_url"),
            reply_to_message_id=str(msg["reply_to_message_id"]) if msg.get("reply_to_message_id") else None,
            status=msg["status"],
            is_archived=msg.get("is_archived", False),
            created_at=msg["created_at"],
            updated_at=msg["updated_at"],
            read_at=msg.get("read_at")
        )
        for msg in messages
    ]

    return MessageListResponse(
        messages=message_responses,
        total=total,
        page=page,
        limit=limit,
        conversation_id=conversation_id
    )
@router.put("/messages/{message_id}", response_model=MessageResponse)
async def update_message(
    message_id: str,
    message_data: MessageUpdate,
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """Edit a message's content and/or archive flag; only the sender may do this.

    ``updated_at`` is refreshed on every call, even when no field is supplied.

    Raises:
        HTTPException: 400 for a malformed ID, 404 when the message does not
            exist, 403 when the caller is not the sender.
    """
    if not is_valid_object_id(message_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid message ID"
        )

    target_oid = ObjectId(message_id)
    existing = await db_client.messages.find_one({"_id": target_oid})
    if not existing:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Message not found"
        )

    caller_oid = ObjectId(current_user["_id"])
    if existing["sender_id"] != caller_oid:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You can only update your own messages"
        )

    # Only fields the client actually provided are written.
    changes = {"updated_at": datetime.now()}
    for field_name in ("content", "is_archived"):
        new_value = getattr(message_data, field_name)
        if new_value is not None:
            changes[field_name] = new_value

    await db_client.messages.update_one(
        {"_id": target_oid},
        {"$set": changes}
    )

    doc = await db_client.messages.find_one({"_id": target_oid})
    sender_doc = await db_client.users.find_one({"_id": doc["sender_id"]})
    recipient_doc = await db_client.users.find_one({"_id": doc["recipient_id"]})
    reply_ref = doc.get("reply_to_message_id")

    return MessageResponse(
        id=str(doc["_id"]),
        sender_id=str(doc["sender_id"]),
        recipient_id=str(doc["recipient_id"]),
        sender_name=sender_doc["full_name"] if sender_doc else "Unknown User",
        recipient_name=recipient_doc["full_name"] if recipient_doc else "Unknown User",
        content=doc["content"],
        message_type=doc["message_type"],
        attachment_url=doc.get("attachment_url"),
        reply_to_message_id=str(doc["reply_to_message_id"]) if reply_ref else None,
        status=doc["status"],
        is_archived=doc.get("is_archived", False),
        created_at=doc["created_at"],
        updated_at=doc["updated_at"],
        read_at=doc.get("read_at")
    )
@router.delete("/messages/{message_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_message(
    message_id: str,
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """Permanently remove a message; only its sender is allowed to.

    Raises:
        HTTPException: 400 for a malformed ID, 404 when the message does not
            exist, 403 when the caller is not the sender.
    """
    if not is_valid_object_id(message_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid message ID"
        )

    by_id = {"_id": ObjectId(message_id)}
    target = await db_client.messages.find_one(by_id)
    if not target:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Message not found"
        )

    is_sender = target["sender_id"] == ObjectId(current_user["_id"])
    if not is_sender:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You can only delete your own messages"
        )

    await db_client.messages.delete_one(by_id)
# --- NOTIFICATION ENDPOINTS ---
@router.get("/notifications", response_model=NotificationListResponse)
async def get_notifications(
    page: int = Query(1, ge=1),
    limit: int = Query(20, ge=1, le=100),
    unread_only: bool = Query(False),
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """Page through the current user's notifications, newest first.

    With ``unread_only`` set, only notifications with ``is_read == False`` are
    listed and counted in ``total``. ``unread_count`` always counts all of the
    user's unread notifications.
    """
    offset = (page - 1) * limit
    owner_oid = ObjectId(current_user["_id"])

    query = {"recipient_id": owner_oid}
    if unread_only:
        query["is_read"] = False

    page_cursor = (
        db_client.notifications
        .find(query)
        .sort("created_at", -1)
        .skip(offset)
        .limit(limit)
    )
    page_docs = await page_cursor.to_list(length=limit)

    total = await db_client.notifications.count_documents(query)
    unread_count = await db_client.notifications.count_documents({
        "recipient_id": owner_oid,
        "is_read": False
    })

    def stringify_object_ids(payload):
        """Copy ``payload`` with ObjectId values turned into strings; falsy payloads pass through."""
        if not payload:
            return payload
        return {
            key: str(val) if isinstance(val, ObjectId) else val
            for key, val in payload.items()
        }

    items = [
        NotificationResponse(
            id=str(doc["_id"]),
            recipient_id=str(doc["recipient_id"]),
            recipient_name=current_user["full_name"],
            title=doc["title"],
            message=doc["message"],
            notification_type=doc["notification_type"],
            priority=doc["priority"],
            data=stringify_object_ids(doc.get("data", {})),
            is_read=doc.get("is_read", False),
            created_at=doc["created_at"],
            read_at=doc.get("read_at")
        )
        for doc in page_docs
    ]

    return NotificationListResponse(
        notifications=items,
        total=total,
        unread_count=unread_count,
        page=page,
        limit=limit
    )
@router.put("/notifications/{notification_id}/read", response_model=NotificationResponse)
async def mark_notification_read(
    notification_id: str,
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """Flag one of the current user's notifications as read and return it.

    Raises:
        HTTPException: 400 for a malformed ID, 404 when the notification does
            not exist, 403 when it belongs to another user.
    """
    if not is_valid_object_id(notification_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid notification ID"
        )

    by_id = {"_id": ObjectId(notification_id)}
    found = await db_client.notifications.find_one(by_id)
    if not found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Notification not found"
        )

    owns_it = found["recipient_id"] == ObjectId(current_user["_id"])
    if not owns_it:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You can only mark your own notifications as read"
        )

    await db_client.notifications.update_one(
        by_id,
        {"$set": {"is_read": True, "read_at": datetime.now()}}
    )

    refreshed = await db_client.notifications.find_one(by_id)

    # ObjectId values in the payload are converted to strings for the response.
    payload = refreshed.get("data", {})
    if payload:
        payload = {
            key: str(val) if isinstance(val, ObjectId) else val
            for key, val in payload.items()
        }

    return NotificationResponse(
        id=str(refreshed["_id"]),
        recipient_id=str(refreshed["recipient_id"]),
        recipient_name=current_user["full_name"],
        title=refreshed["title"],
        message=refreshed["message"],
        notification_type=refreshed["notification_type"],
        priority=refreshed["priority"],
        data=payload,
        is_read=refreshed.get("is_read", False),
        created_at=refreshed["created_at"],
        read_at=refreshed.get("read_at")
    )
@router.put("/notifications/read-all", status_code=status.HTTP_200_OK)
async def mark_all_notifications_read(
    current_user: dict = Depends(get_current_user),
    db_client: AsyncIOMotorClient = Depends(lambda: db)
):
    """Set ``is_read``/``read_at`` on every unread notification of the current user."""
    unread_of_user = {
        "recipient_id": ObjectId(current_user["_id"]),
        "is_read": False
    }
    mark_read = {"$set": {"is_read": True, "read_at": datetime.now()}}
    await db_client.notifications.update_many(unread_of_user, mark_read)
    return {"message": "All notifications marked as read"}
# --- FILE UPLOAD ENDPOINT ---
@router.post("/upload", status_code=status.HTTP_201_CREATED)
async def upload_file(
    file: UploadFile = File(...),
    current_user: dict = Depends(get_current_user)  # enforces authentication only
):
    """Store an uploaded attachment and return the metadata needed to send it.

    The file is saved as ``<upload dir>/<category>/<uuid><ext>``, where
    ``category`` is derived from the extension whitelist below.
    ``<upload dir>`` is ``./uploads`` when it is writable, otherwise
    ``<tempdir>/uploads``.

    Returns:
        Dict with the original filename, ``file_url``, stored size in bytes,
        category, and the MessageType to use when sending it.

    Raises:
        HTTPException: 400 for a disallowed extension or a file over 10MB,
            500 when no writable upload directory exists or the write fails.
    """
    # Allowed extensions, grouped by the category used as storage subdirectory.
    allowed_types = {
        'image': ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'],
        'document': ['.pdf', '.doc', '.docx', '.txt', '.rtf'],
        'spreadsheet': ['.xls', '.xlsx', '.csv'],
        'presentation': ['.ppt', '.pptx'],
        'archive': ['.zip', '.rar', '.7z']
    }
    ext_to_category = {
        ext: category
        for category, extensions in allowed_types.items()
        for ext in extensions
    }

    # UploadFile.filename may be None; treat that as "no extension".
    file_ext = Path(file.filename or "").suffix.lower()
    file_category = ext_to_category.get(file_ext)
    if file_category is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"File type {file_ext} is not allowed. Allowed types: {', '.join(ext_to_category)}"
        )

    max_size = 10 * 1024 * 1024  # 10MB
    too_large_detail = "File size exceeds maximum limit of 10MB"
    # Fast rejection when the client declared a size.
    if file.size and file.size > max_size:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=too_large_detail)
    # file.size can be missing, so enforce the limit on the actual bytes too.
    # Reading at most one byte past the limit avoids buffering huge uploads.
    content = await file.read(max_size + 1)
    if len(content) > max_size:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=too_large_detail)

    def get_upload_dir():
        """Return the first writable upload base directory, or None if none is usable."""
        # In containerized environments ./uploads may be read-only; fall back
        # to the system temp directory.
        import tempfile
        candidates = (Path("uploads"), Path(tempfile.gettempdir()) / "uploads")
        for candidate in candidates:
            try:
                candidate.mkdir(exist_ok=True)
            except OSError:  # includes PermissionError
                continue
            # mkdir(exist_ok=True) succeeds on an existing read-only directory,
            # so check that files can actually be created in it.
            if os.access(candidate, os.W_OK | os.X_OK):
                return candidate
        return None

    upload_dir = get_upload_dir()
    if upload_dir is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to upload file: no writable upload directory available"
        )

    category_dir = upload_dir / file_category
    # Random name: never trust the client-supplied filename for storage.
    unique_filename = f"{uuid.uuid4()}{file_ext}"
    file_path = category_dir / unique_filename

    try:
        category_dir.mkdir(exist_ok=True)
        with open(file_path, "wb") as buffer:
            buffer.write(content)
    except OSError as e:
        # Don't leave a partially written file behind.
        file_path.unlink(missing_ok=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to upload file: {str(e)}"
        )

    return {
        "filename": file.filename,
        "file_url": f"/uploads/{file_category}/{unique_filename}",
        "file_size": len(content),
        "file_type": file_category,
        "message_type": MessageType.IMAGE if file_category == 'image' else MessageType.FILE
    }
# --- STATIC FILE SERVING ---
@router.get("/uploads/{category}/{filename}")
async def serve_file(category: str, filename: str):
    """Return a previously uploaded file with a Content-Type based on its extension.

    The file is looked up under ``./uploads`` and then ``<tempdir>/uploads``,
    the fallback location used when ``./uploads`` is not writable.
    ``category`` must be one of the upload categories, and the resolved path
    must lie directly inside that category directory. This prevents segments
    such as ``..`` from reaching other files on the server.

    NOTE(review): this endpoint is unauthenticated; anyone with the URL can
    fetch the file.

    Raises:
        HTTPException: 404 for an unknown category, a path outside the upload
            directory, or a missing file; 500 when the file cannot be read.
    """
    import tempfile
    from fastapi.responses import Response

    # Must match the category names used when uploading.
    allowed_categories = {'image', 'document', 'spreadsheet', 'presentation', 'archive'}
    not_found = HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="File not found"
    )
    if category not in allowed_categories:
        raise not_found

    file_path = None
    base_dirs = (Path.cwd() / "uploads", Path(tempfile.gettempdir()) / "uploads")
    for base_dir in base_dirs:
        try:
            category_dir = (base_dir / category).resolve()
            candidate = (category_dir / filename).resolve()
            # Reject anything resolving outside the category directory
            # (e.g. "..", ".", or a symlink pointing elsewhere).
            if candidate.parent != category_dir:
                continue
            if candidate.is_file():
                file_path = candidate
                break
        except (OSError, ValueError):  # e.g. embedded NUL byte in the name
            continue
    if file_path is None:
        raise not_found

    content_types = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.bmp': 'image/bmp',
        '.webp': 'image/webp',
        '.pdf': 'application/pdf',
        '.doc': 'application/msword',
        '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
        '.txt': 'text/plain',
        '.rtf': 'application/rtf',
        '.xls': 'application/vnd.ms-excel',
        '.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
        '.csv': 'text/csv',
        '.ppt': 'application/vnd.ms-powerpoint',
        '.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
        '.zip': 'application/zip',
        '.rar': 'application/x-rar-compressed',
        '.7z': 'application/x-7z-compressed'
    }
    content_type = content_types.get(file_path.suffix.lower(), 'application/octet-stream')

    try:
        content = file_path.read_bytes()
    except OSError as e:
        # Log the details server-side; don't expose filesystem paths to clients.
        print(f"❌ Error reading file: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error reading file"
        )

    return Response(
        content=content,
        media_type=content_type,
        headers={
            "Content-Disposition": f'inline; filename="{filename}"',
            "Cache-Control": "public, max-age=31536000"
        }
    )