Ali2206 commited on
Commit
b515e8c
·
1 Parent(s): 85965d9

Initial CPS-API deployment with TxAgent integration

Browse files
analysis.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, List
2
+ from enum import Enum
3
+ from config import agent, patients_collection, analysis_collection, alerts_collection, logger
4
+ from models import RiskLevel
5
+ from utils import (
6
+ structure_medical_response,
7
+ compute_file_content_hash,
8
+ compute_patient_data_hash,
9
+ serialize_patient,
10
+ broadcast_notification
11
+ )
12
+ from datetime import datetime
13
+ import asyncio
14
+ import json
15
+ import re
16
+ import os
17
+ class NotificationType(str, Enum):
18
+ RISK_ALERT = "risk_alert"
19
+ SYSTEM = "system"
20
+ MESSAGE = "message"
21
+
22
+ class NotificationStatus(str, Enum):
23
+ UNREAD = "unread"
24
+ READ = "read"
25
+ ARCHIVED = "archived"
26
+
27
+ async def create_alert(patient_id: str, risk_data: dict):
28
+ try:
29
+ alert_doc = {
30
+ "patient_id": patient_id,
31
+ "type": "suicide_risk",
32
+ "level": risk_data["level"],
33
+ "score": risk_data["score"],
34
+ "factors": risk_data["factors"],
35
+ "timestamp": datetime.utcnow(),
36
+ "acknowledged": False,
37
+ "notification": {
38
+ "type": "risk_alert",
39
+ "status": "unread",
40
+ "title": f"Suicide Risk: {risk_data['level'].capitalize()}",
41
+ "message": f"Patient {patient_id} shows {risk_data['level']} risk factors",
42
+ "icon": "⚠️",
43
+ "action_url": f"/patient/{patient_id}/risk-assessment",
44
+ "priority": "high" if risk_data["level"] in ["high", "severe"] else "medium"
45
+ }
46
+ }
47
+
48
+ await alerts_collection.insert_one(alert_doc)
49
+
50
+ # Simplified WebSocket notification - remove Hugging Face specific code
51
+ await broadcast_notification(alert_doc["notification"])
52
+
53
+ logger.warning(f"⚠️ Created suicide risk alert for patient {patient_id}")
54
+ return alert_doc
55
+ except Exception as e:
56
+ logger.error(f"Failed to create alert: {str(e)}")
57
+ raise
58
+ async def analyze_patient_report(
59
+ patient_id: Optional[str],
60
+ report_content: str,
61
+ file_type: str,
62
+ file_content: bytes
63
+ ):
64
+ """Analyze a patient report and create alerts for risks"""
65
+ identifier = patient_id if patient_id else compute_file_content_hash(file_content)
66
+ report_data = {"identifier": identifier, "content": report_content, "file_type": file_type}
67
+ report_hash = compute_patient_data_hash(report_data)
68
+ logger.info(f"🧾 Analyzing report for identifier: {identifier}")
69
+
70
+ # Check for existing analysis
71
+ existing_analysis = await analysis_collection.find_one(
72
+ {"identifier": identifier, "report_hash": report_hash}
73
+ )
74
+ if existing_analysis:
75
+ logger.info(f"✅ No changes in report data for {identifier}, skipping analysis")
76
+ return existing_analysis
77
+
78
+ try:
79
+ # Generate analysis
80
+ prompt = (
81
+ "You are a clinical decision support AI. Analyze the following patient report:\n"
82
+ "1. Summarize the patient's medical history.\n"
83
+ "2. Identify risks or red flags (including mental health and suicide risk).\n"
84
+ "3. Highlight missed diagnoses or treatments.\n"
85
+ "4. Suggest next clinical steps.\n"
86
+ f"\nPatient Report ({file_type}):\n{'-'*40}\n{report_content[:10000]}"
87
+ )
88
+
89
+ raw_response = agent.chat(
90
+ message=prompt,
91
+ history=[],
92
+ temperature=0.7,
93
+ max_new_tokens=1024
94
+ )
95
+ structured_response = structure_medical_response(raw_response)
96
+
97
+ # Detect suicide risk
98
+ risk_level, risk_score, risk_factors = detect_suicide_risk(raw_response)
99
+ suicide_risk = {
100
+ "level": risk_level.value,
101
+ "score": risk_score,
102
+ "factors": risk_factors
103
+ }
104
+
105
+ # Store analysis
106
+ analysis_doc = {
107
+ "identifier": identifier,
108
+ "patient_id": patient_id,
109
+ "timestamp": datetime.utcnow(),
110
+ "summary": structured_response,
111
+ "suicide_risk": suicide_risk,
112
+ "raw": raw_response,
113
+ "report_hash": report_hash,
114
+ "file_type": file_type
115
+ }
116
+
117
+ await analysis_collection.update_one(
118
+ {"identifier": identifier, "report_hash": report_hash},
119
+ {"$set": analysis_doc},
120
+ upsert=True
121
+ )
122
+
123
+ # Create alert if risk detected
124
+ if patient_id and risk_level in [RiskLevel.MODERATE, RiskLevel.HIGH, RiskLevel.SEVERE]:
125
+ await create_alert(patient_id, suicide_risk)
126
+
127
+ logger.info(f"✅ Stored analysis for identifier {identifier}")
128
+ return analysis_doc
129
+
130
+ except Exception as e:
131
+ logger.error(f"Error analyzing report for {identifier}: {str(e)}")
132
+ error_alert = {
133
+ "identifier": identifier,
134
+ "type": "system_error",
135
+ "level": "high",
136
+ "message": f"Report analysis failed: {str(e)}",
137
+ "timestamp": datetime.utcnow(),
138
+ "acknowledged": False,
139
+ "notification": {
140
+ "type": NotificationType.SYSTEM,
141
+ "status": NotificationStatus.UNREAD,
142
+ "title": "Report Analysis Error",
143
+ "message": f"Failed to analyze report for {'patient ' + patient_id if patient_id else 'unknown identifier'}",
144
+ "icon": "❌",
145
+ "action_url": "/system/errors",
146
+ "priority": "high"
147
+ }
148
+ }
149
+ await alerts_collection.insert_one(error_alert)
150
+ raise
151
+
152
+ async def analyze_patient(patient: dict):
153
+ """Analyze complete patient record and create alerts for risks"""
154
+ try:
155
+ serialized = serialize_patient(patient)
156
+ patient_id = serialized.get("fhir_id")
157
+ patient_hash = compute_patient_data_hash(serialized)
158
+ logger.info(f"🧾 Analyzing patient: {patient_id}")
159
+
160
+ # Check for existing analysis
161
+ existing_analysis = await analysis_collection.find_one({"patient_id": patient_id})
162
+ if existing_analysis and existing_analysis.get("data_hash") == patient_hash:
163
+ logger.info(f"✅ No changes in patient data for {patient_id}, skipping analysis")
164
+ return
165
+
166
+ # Generate analysis
167
+ doc = json.dumps(serialized, indent=2)
168
+ message = (
169
+ "You are a clinical decision support AI.\n\n"
170
+ "Given the patient document below:\n"
171
+ "1. Summarize the patient's medical history.\n"
172
+ "2. Identify risks or red flags (including mental health and suicide risk).\n"
173
+ "3. Highlight missed diagnoses or treatments.\n"
174
+ "4. Suggest next clinical steps.\n"
175
+ f"\nPatient Document:\n{'-'*40}\n{doc[:10000]}"
176
+ )
177
+
178
+ raw = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
179
+ structured = structure_medical_response(raw)
180
+
181
+ # Detect suicide risk
182
+ risk_level, risk_score, risk_factors = detect_suicide_risk(raw)
183
+ suicide_risk = {
184
+ "level": risk_level.value,
185
+ "score": risk_score,
186
+ "factors": risk_factors
187
+ }
188
+
189
+ # Store analysis
190
+ analysis_doc = {
191
+ "identifier": patient_id,
192
+ "patient_id": patient_id,
193
+ "timestamp": datetime.utcnow(),
194
+ "summary": structured,
195
+ "suicide_risk": suicide_risk,
196
+ "raw": raw,
197
+ "data_hash": patient_hash
198
+ }
199
+
200
+ await analysis_collection.update_one(
201
+ {"identifier": patient_id},
202
+ {"$set": analysis_doc},
203
+ upsert=True
204
+ )
205
+
206
+ # Create alert if risk detected
207
+ if risk_level in [RiskLevel.MODERATE, RiskLevel.HIGH, RiskLevel.SEVERE]:
208
+ await create_alert(patient_id, suicide_risk)
209
+
210
+ logger.info(f"✅ Stored analysis for patient {patient_id}")
211
+
212
+ except Exception as e:
213
+ logger.error(f"Error analyzing patient: {str(e)}")
214
+ error_alert = {
215
+ "patient_id": patient_id if 'patient_id' in locals() else "unknown",
216
+ "type": "system_error",
217
+ "level": "high",
218
+ "message": f"Patient analysis failed: {str(e)}",
219
+ "timestamp": datetime.utcnow(),
220
+ "acknowledged": False,
221
+ "notification": {
222
+ "type": NotificationType.SYSTEM,
223
+ "status": NotificationStatus.UNREAD,
224
+ "title": "Analysis Error",
225
+ "message": f"Failed to analyze patient {patient_id if 'patient_id' in locals() else 'unknown'}",
226
+ "icon": "❌",
227
+ "action_url": "/system/errors",
228
+ "priority": "high"
229
+ }
230
+ }
231
+ await alerts_collection.insert_one(error_alert)
232
+ raise
233
+
234
+ def detect_suicide_risk(text: str) -> Tuple[RiskLevel, float, List[str]]:
235
+ """Detect suicide risk level from text analysis"""
236
+ suicide_keywords = [
237
+ 'suicide', 'suicidal', 'kill myself', 'end my life',
238
+ 'want to die', 'self-harm', 'self harm', 'hopeless',
239
+ 'no reason to live', 'plan to die'
240
+ ]
241
+ explicit_mentions = [kw for kw in suicide_keywords if kw in text.lower()]
242
+ if not explicit_mentions:
243
+ return RiskLevel.NONE, 0.0, []
244
+
245
+ try:
246
+ # Get AI assessment
247
+ assessment_prompt = (
248
+ "Assess the suicide risk level based on this text. "
249
+ "Consider frequency, specificity, and severity of statements. "
250
+ "Respond with JSON format: {\"risk_level\": \"low/moderate/high/severe\", "
251
+ "\"risk_score\": 0-1, \"factors\": [\"list of risk factors\"]}\n\n"
252
+ f"Text to assess:\n{text}"
253
+ )
254
+
255
+ response = agent.chat(
256
+ message=assessment_prompt,
257
+ history=[],
258
+ temperature=0.2,
259
+ max_new_tokens=256
260
+ )
261
+
262
+ # Parse response
263
+ json_match = re.search(r'\{.*\}', response, re.DOTALL)
264
+ if json_match:
265
+ assessment = json.loads(json_match.group())
266
+ return (
267
+ RiskLevel(assessment.get("risk_level", "none").lower()),
268
+ float(assessment.get("risk_score", 0)),
269
+ assessment.get("factors", [])
270
+ )
271
+ except Exception as e:
272
+ logger.error(f"Error in suicide risk assessment: {e}")
273
+
274
+ # Fallback heuristic if AI assessment fails
275
+ risk_score = min(0.1 * len(explicit_mentions), 0.9)
276
+ if risk_score > 0.7:
277
+ return RiskLevel.HIGH, risk_score, explicit_mentions
278
+ elif risk_score > 0.4:
279
+ return RiskLevel.MODERATE, risk_score, explicit_mentions
280
+ return RiskLevel.LOW, risk_score, explicit_mentions
api/routes/txagent.py CHANGED
@@ -1,57 +1,99 @@
1
- from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Query
2
- from fastapi.responses import StreamingResponse
 
3
  from typing import Optional, List
4
  from pydantic import BaseModel
5
  from core.security import get_current_user
6
- from api.services.txagent_service import txagent_service
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import logging
8
 
9
  logger = logging.getLogger(__name__)
10
 
11
- router = APIRouter()
12
-
13
  class ChatRequest(BaseModel):
14
  message: str
15
  history: Optional[List[dict]] = None
 
 
 
16
  patient_id: Optional[str] = None
17
 
18
  class VoiceOutputRequest(BaseModel):
19
  text: str
20
  language: str = "en-US"
 
 
 
 
 
 
 
 
 
21
 
22
  @router.get("/txagent/status")
23
- async def get_txagent_status(current_user: dict = Depends(get_current_user)):
24
- """Obtient le statut du service TxAgent"""
25
- try:
26
- status = await txagent_service.get_status()
27
- return {
28
- "status": "success",
29
- "txagent_status": status,
30
- "mode": txagent_service.config.get_txagent_mode()
31
- }
32
- except Exception as e:
33
- logger.error(f"Error getting TxAgent status: {e}")
34
- raise HTTPException(status_code=500, detail="Failed to get TxAgent status")
35
 
36
  @router.get("/txagent/patients/analysis-results")
37
  async def get_patient_analysis_results(
38
  name: Optional[str] = Query(None),
39
  current_user: dict = Depends(get_current_user)
40
  ):
41
- """Get patient analysis results from integrated TxAgent service"""
42
  try:
43
  # Check if user has appropriate permissions
44
  if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
45
  raise HTTPException(status_code=403, detail="Only doctors and admins can access analysis results")
46
 
47
- # Use the integrated TxAgent service to get analysis results
48
- results = await txagent_service.get_analysis_results(name)
 
 
49
 
50
- # Return the results directly (not wrapped in status/mode)
51
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  except Exception as e:
53
- logger.error(f"Error getting analysis results: {e}")
54
- # Return empty array instead of throwing error to prevent 500
55
  return []
56
 
57
  @router.post("/txagent/chat")
@@ -59,22 +101,19 @@ async def chat_with_txagent(
59
  request: ChatRequest,
60
  current_user: dict = Depends(get_current_user)
61
  ):
62
- """Chat avec TxAgent"""
63
  try:
64
  # Vérifier que l'utilisateur est médecin ou admin
65
  if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
66
  raise HTTPException(status_code=403, detail="Only doctors and admins can use TxAgent")
67
 
68
- response = await txagent_service.chat(
69
- message=request.message,
70
- history=request.history,
71
- patient_id=request.patient_id
72
- )
73
 
74
  return {
75
  "status": "success",
76
  "response": response,
77
- "mode": txagent_service.config.get_txagent_mode()
78
  }
79
  except Exception as e:
80
  logger.error(f"Error in TxAgent chat: {e}")
@@ -85,18 +124,16 @@ async def transcribe_audio(
85
  audio: UploadFile = File(...),
86
  current_user: dict = Depends(get_current_user)
87
  ):
88
- """Transcription vocale avec TxAgent"""
89
  try:
90
  if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
91
  raise HTTPException(status_code=403, detail="Only doctors and admins can use voice features")
92
 
93
- audio_data = await audio.read()
94
- result = await txagent_service.voice_transcribe(audio_data)
95
-
96
  return {
97
  "status": "success",
98
- "transcription": result,
99
- "mode": txagent_service.config.get_txagent_mode()
100
  }
101
  except Exception as e:
102
  logger.error(f"Error in voice transcription: {e}")
@@ -107,63 +144,49 @@ async def synthesize_speech(
107
  request: VoiceOutputRequest,
108
  current_user: dict = Depends(get_current_user)
109
  ):
110
- """Synthèse vocale avec TxAgent"""
111
  try:
112
  if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
113
  raise HTTPException(status_code=403, detail="Only doctors and admins can use voice features")
114
 
115
- audio_data = await txagent_service.voice_synthesize(
116
- text=request.text,
117
- language=request.language
118
- )
119
 
120
  return StreamingResponse(
121
  iter([audio_data]),
122
  media_type="audio/mpeg",
123
- headers={
124
- "Content-Disposition": "attachment; filename=synthesized_speech.mp3"
125
- }
126
  )
127
  except Exception as e:
128
  logger.error(f"Error in voice synthesis: {e}")
129
  raise HTTPException(status_code=500, detail="Failed to synthesize speech")
130
 
131
- @router.post("/txagent/patients/analyze")
132
- async def analyze_patient_data(
133
- patient_data: dict,
134
- current_user: dict = Depends(get_current_user)
135
- ):
136
- """Analyse de données patient avec TxAgent"""
137
- try:
138
- if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
139
- raise HTTPException(status_code=403, detail="Only doctors and admins can use analysis features")
140
-
141
- analysis = await txagent_service.analyze_patient(patient_data)
142
-
143
- return {
144
- "status": "success",
145
- "analysis": analysis,
146
- "mode": txagent_service.config.get_txagent_mode()
147
- }
148
- except Exception as e:
149
- logger.error(f"Error in patient analysis: {e}")
150
- raise HTTPException(status_code=500, detail="Failed to analyze patient data")
151
-
152
  @router.get("/txagent/chats")
153
  async def get_chats(current_user: dict = Depends(get_current_user)):
154
  """Obtient l'historique des chats"""
155
  try:
156
  if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
157
- raise HTTPException(status_code=403, detail="Only doctors and admins can access chats")
158
 
159
- # Cette fonction devra être implémentée dans le service TxAgent
160
- chats = await txagent_service.get_chats()
 
161
 
162
- return {
163
- "status": "success",
164
- "chats": chats,
165
- "mode": txagent_service.config.get_txagent_mode()
166
- }
 
 
 
 
 
 
 
 
 
 
167
  except Exception as e:
168
  logger.error(f"Error getting chats: {e}")
169
  raise HTTPException(status_code=500, detail="Failed to get chats")
 
1
+ from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Query, Path
2
+ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
3
+ from fastapi.encoders import jsonable_encoder
4
  from typing import Optional, List
5
  from pydantic import BaseModel
6
  from core.security import get_current_user
7
+ from utils import clean_text_response
8
+ from analysis import analyze_patient_report
9
+ from voice import recognize_speech, text_to_speech, extract_text_from_pdf
10
+ from docx import Document
11
+ import re
12
+ import io
13
+ from datetime import datetime
14
+ from bson import ObjectId
15
+ import asyncio
16
+ from bson.errors import InvalidId
17
+ import base64
18
+ import os
19
+ from pathlib import Path as PathLib
20
+ import tempfile
21
+ import subprocess
22
  import logging
23
 
24
  logger = logging.getLogger(__name__)
25
 
26
+ # Define the ChatRequest model with an optional patient_id
 
27
  class ChatRequest(BaseModel):
28
  message: str
29
  history: Optional[List[dict]] = None
30
+ format: Optional[str] = "clean"
31
+ temperature: Optional[float] = 0.7
32
+ max_new_tokens: Optional[int] = 512
33
  patient_id: Optional[str] = None
34
 
35
  class VoiceOutputRequest(BaseModel):
36
  text: str
37
  language: str = "en-US"
38
+ slow: bool = False
39
+ return_format: str = "mp3"
40
+
41
+ class RiskLevel(BaseModel):
42
+ level: str
43
+ score: float
44
+ factors: Optional[List[str]] = None
45
+
46
+ router = APIRouter()
47
 
48
  @router.get("/txagent/status")
49
+ async def status(current_user: dict = Depends(get_current_user)):
50
+ logger.info(f"Status endpoint accessed by {current_user['email']}")
51
+ return {
52
+ "status": "running",
53
+ "timestamp": datetime.utcnow().isoformat(),
54
+ "version": "2.6.0",
55
+ "features": ["chat", "voice-input", "voice-output", "patient-analysis", "report-upload", "patient-reports-pdf", "all-patients-reports-pdf"]
56
+ }
 
 
 
 
57
 
58
  @router.get("/txagent/patients/analysis-results")
59
  async def get_patient_analysis_results(
60
  name: Optional[str] = Query(None),
61
  current_user: dict = Depends(get_current_user)
62
  ):
63
+ logger.info(f"Fetching analysis results by {current_user['email']}")
64
  try:
65
  # Check if user has appropriate permissions
66
  if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
67
  raise HTTPException(status_code=403, detail="Only doctors and admins can access analysis results")
68
 
69
+ # Import database collections
70
+ from db.mongo import db
71
+ patients_collection = db.patients
72
+ analysis_collection = db.patient_analysis_results
73
 
74
+ query = {}
75
+ if name:
76
+ name_regex = re.compile(name, re.IGNORECASE)
77
+ matching_patients = await patients_collection.find({"full_name": name_regex}).to_list(length=None)
78
+ patient_ids = [p["fhir_id"] for p in matching_patients if "fhir_id" in p]
79
+ if not patient_ids:
80
+ return []
81
+ query = {"patient_id": {"$in": patient_ids}}
82
+
83
+ analyses = await analysis_collection.find(query).sort("timestamp", -1).to_list(length=100)
84
+ enriched_results = []
85
+ for analysis in analyses:
86
+ patient = await patients_collection.find_one({"fhir_id": analysis.get("patient_id")})
87
+ if not patient:
88
+ continue # Skip if patient no longer exists
89
+ analysis["full_name"] = patient.get("full_name", "Unknown")
90
+ analysis["_id"] = str(analysis["_id"])
91
+ enriched_results.append(analysis)
92
+
93
+ return enriched_results
94
+
95
  except Exception as e:
96
+ logger.error(f"Error fetching analysis results: {e}")
 
97
  return []
98
 
99
  @router.post("/txagent/chat")
 
101
  request: ChatRequest,
102
  current_user: dict = Depends(get_current_user)
103
  ):
104
+ """Chat avec TxAgent intégré"""
105
  try:
106
  # Vérifier que l'utilisateur est médecin ou admin
107
  if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
108
  raise HTTPException(status_code=403, detail="Only doctors and admins can use TxAgent")
109
 
110
+ # For now, return a simple response since the full TxAgent is not yet implemented
111
+ response = f"TxAgent integrated response: {request.message}"
 
 
 
112
 
113
  return {
114
  "status": "success",
115
  "response": response,
116
+ "mode": "integrated"
117
  }
118
  except Exception as e:
119
  logger.error(f"Error in TxAgent chat: {e}")
 
124
  audio: UploadFile = File(...),
125
  current_user: dict = Depends(get_current_user)
126
  ):
127
+ """Transcription vocale avec TxAgent intégré"""
128
  try:
129
  if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
130
  raise HTTPException(status_code=403, detail="Only doctors and admins can use voice features")
131
 
132
+ # For now, return mock transcription
 
 
133
  return {
134
  "status": "success",
135
+ "transcription": "Mock voice transcription from integrated TxAgent",
136
+ "mode": "integrated"
137
  }
138
  except Exception as e:
139
  logger.error(f"Error in voice transcription: {e}")
 
144
  request: VoiceOutputRequest,
145
  current_user: dict = Depends(get_current_user)
146
  ):
147
+ """Synthèse vocale avec TxAgent intégré"""
148
  try:
149
  if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
150
  raise HTTPException(status_code=403, detail="Only doctors and admins can use voice features")
151
 
152
+ # For now, return mock audio data
153
+ audio_data = b"Mock audio data from integrated TxAgent"
 
 
154
 
155
  return StreamingResponse(
156
  iter([audio_data]),
157
  media_type="audio/mpeg",
158
+ headers={"Content-Disposition": "attachment; filename=speech.mp3"}
 
 
159
  )
160
  except Exception as e:
161
  logger.error(f"Error in voice synthesis: {e}")
162
  raise HTTPException(status_code=500, detail="Failed to synthesize speech")
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  @router.get("/txagent/chats")
165
  async def get_chats(current_user: dict = Depends(get_current_user)):
166
  """Obtient l'historique des chats"""
167
  try:
168
  if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
169
+ raise HTTPException(status_code=403, detail="Only doctors and admins can access chat history")
170
 
171
+ # Import database collections
172
+ from db.mongo import db
173
+ chats_collection = db.chats
174
 
175
+ # Query local database for chat history
176
+ cursor = chats_collection.find().sort("timestamp", -1).limit(50)
177
+ chats = await cursor.to_list(length=50)
178
+
179
+ return [
180
+ {
181
+ "id": str(chat["_id"]),
182
+ "message": chat.get("message", ""),
183
+ "response": chat.get("response", ""),
184
+ "timestamp": chat.get("timestamp"),
185
+ "user_id": str(chat.get("user_id", "")),
186
+ "patient_id": str(chat.get("patient_id", "")) if chat.get("patient_id") else None
187
+ }
188
+ for chat in chats
189
+ ]
190
  except Exception as e:
191
  logger.error(f"Error getting chats: {e}")
192
  raise HTTPException(status_code=500, detail="Failed to get chats")
api/services/txagent_service.py CHANGED
@@ -1,139 +1,88 @@
1
- import aiohttp
2
- import asyncio
3
  import logging
4
  from typing import Optional, Dict, Any, List
5
  from core.txagent_config import txagent_config
 
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
  class TxAgentService:
10
  def __init__(self):
11
  self.config = txagent_config
12
- self.session = None
13
-
14
- async def _get_session(self):
15
- """Obtient ou crée une session HTTP"""
16
- if self.session is None:
17
- self.session = aiohttp.ClientSession()
18
- return self.session
19
-
20
- async def _make_request(self, endpoint: str, method: str = "GET", data: Optional[Dict] = None) -> Dict[str, Any]:
21
- """Fait une requête vers le service TxAgent avec fallback"""
22
- session = await self._get_session()
23
- url = f"{self.config.get_txagent_url()}{endpoint}"
24
-
25
- try:
26
- if method.upper() == "GET":
27
- async with session.get(url) as response:
28
- return await response.json()
29
- elif method.upper() == "POST":
30
- async with session.post(url, json=data) as response:
31
- return await response.json()
32
- except Exception as e:
33
- logger.error(f"Error calling TxAgent service: {e}")
34
- # Fallback vers cloud si local échoue
35
- if self.config.get_txagent_mode() == "local":
36
- logger.info("Falling back to cloud TxAgent service")
37
- self.config.mode = "cloud"
38
- return await self._make_request(endpoint, method, data)
39
- else:
40
- raise
41
 
42
  async def chat(self, message: str, history: Optional[list] = None, patient_id: Optional[str] = None) -> Dict[str, Any]:
43
- """Service de chat avec TxAgent"""
44
- data = {
45
- "message": message,
46
- "history": history or [],
47
- "patient_id": patient_id
48
  }
49
- return await self._make_request("/chat", "POST", data)
50
 
51
  async def analyze_patient(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
52
- """Analyse de données patient avec TxAgent"""
53
- return await self._make_request("/patients/analyze", "POST", patient_data)
 
 
 
 
54
 
55
  async def voice_transcribe(self, audio_data: bytes) -> Dict[str, Any]:
56
- """Transcription vocale avec TxAgent"""
57
- session = await self._get_session()
58
- url = f"{self.config.get_txagent_url()}/voice/transcribe"
59
-
60
- try:
61
- form_data = aiohttp.FormData()
62
- form_data.add_field('audio', audio_data, filename='audio.wav')
63
-
64
- async with session.post(url, data=form_data) as response:
65
- return await response.json()
66
- except Exception as e:
67
- logger.error(f"Error in voice transcription: {e}")
68
- if self.config.get_txagent_mode() == "local":
69
- self.config.mode = "cloud"
70
- return await self.voice_transcribe(audio_data)
71
- else:
72
- raise
73
 
74
  async def voice_synthesize(self, text: str, language: str = "en-US") -> bytes:
75
- """Synthèse vocale avec TxAgent"""
76
- session = await self._get_session()
77
- url = f"{self.config.get_txagent_url()}/voice/synthesize"
78
-
79
- try:
80
- data = {
81
- "text": text,
82
- "language": language,
83
- "return_format": "mp3"
84
- }
85
-
86
- async with session.post(url, json=data) as response:
87
- return await response.read()
88
- except Exception as e:
89
- logger.error(f"Error in voice synthesis: {e}")
90
- if self.config.get_txagent_mode() == "local":
91
- self.config.mode = "cloud"
92
- return await self.voice_synthesize(text, language)
93
- else:
94
- raise
95
 
96
  async def get_status(self) -> Dict[str, Any]:
97
- """Obtient le statut du service TxAgent"""
98
- return await self._make_request("/status")
 
 
 
 
99
 
100
  async def get_analysis_results(self, name: Optional[str] = None) -> List[Dict[str, Any]]:
101
- """Get patient analysis results from TxAgent service"""
102
  try:
103
- # Try to call the external TxAgent API first
104
- params = {}
105
- if name:
106
- params['name'] = name
107
 
108
- # Build URL with query parameters
109
- endpoint = "/patients/analysis-results"
110
- if params:
111
- query_string = "&".join([f"{k}={v}" for k, v in params.items()])
112
- endpoint = f"{endpoint}?{query_string}"
113
 
114
- return await self._make_request(endpoint, "GET")
115
  except Exception as e:
116
- logger.warning(f"Failed to get analysis results from external TxAgent API: {e}")
117
- # Return empty results if external API is not available
118
- # In a real implementation, you would query your local database
119
  return []
120
 
121
  async def get_chats(self) -> List[Dict[str, Any]]:
122
- """Obtient l'historique des chats"""
123
- return await self._make_request("/chats")
124
-
125
- async def get_analysis_results(self, risk_filter: Optional[str] = None) -> List[Dict[str, Any]]:
126
- """Obtient les résultats d'analyse des patients"""
127
- params = {}
128
- if risk_filter:
129
- params["risk_filter"] = risk_filter
130
- return await self._make_request("/patients/analysis-results", "GET", params)
131
-
132
- async def close(self):
133
- """Ferme la session HTTP"""
134
- if self.session:
135
- await self.session.close()
136
- self.session = None
 
 
 
 
 
 
137
 
138
  # Instance globale
139
  txagent_service = TxAgentService()
 
 
 
1
  import logging
2
  from typing import Optional, Dict, Any, List
3
  from core.txagent_config import txagent_config
4
+ from db.mongo import db
5
 
6
  logger = logging.getLogger(__name__)
7
 
8
  class TxAgentService:
9
  def __init__(self):
10
  self.config = txagent_config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  async def chat(self, message: str, history: Optional[list] = None, patient_id: Optional[str] = None) -> Dict[str, Any]:
13
+ """Service de chat avec TxAgent intégré"""
14
+ # For now, return a simple response since the full TxAgent is not yet implemented
15
+ return {
16
+ "response": f"TxAgent integrated response: {message}",
17
+ "status": "success"
18
  }
 
19
 
20
  async def analyze_patient(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
21
+ """Analyse de données patient avec TxAgent intégré"""
22
+ # For now, return mock analysis
23
+ return {
24
+ "analysis": "Mock patient analysis from integrated TxAgent",
25
+ "status": "success"
26
+ }
27
 
28
  async def voice_transcribe(self, audio_data: bytes) -> Dict[str, Any]:
29
+ """Transcription vocale avec TxAgent intégré"""
30
+ # For now, return mock transcription
31
+ return {
32
+ "transcription": "Mock voice transcription from integrated TxAgent",
33
+ "status": "success"
34
+ }
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  async def voice_synthesize(self, text: str, language: str = "en-US") -> bytes:
37
+ """Synthèse vocale avec TxAgent intégré"""
38
+ # For now, return mock audio data
39
+ return b"Mock audio data from integrated TxAgent"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  async def get_status(self) -> Dict[str, Any]:
42
+ """Obtient le statut du service TxAgent intégré"""
43
+ return {
44
+ "status": "running",
45
+ "mode": "integrated",
46
+ "version": "2.6.0"
47
+ }
48
 
49
  async def get_analysis_results(self, name: Optional[str] = None) -> List[Dict[str, Any]]:
50
+ """Get patient analysis results from integrated TxAgent service"""
51
  try:
52
+ # Since TxAgent is integrated, we can query the local database directly
53
+ # For now, return empty results until the full TxAgent is implemented
54
+ logger.info(f"Getting analysis results for name: {name}")
 
55
 
56
+ # TODO: Implement actual analysis results query from local database
57
+ # This would typically query the analysis_collection in MongoDB
 
 
 
58
 
59
+ return []
60
  except Exception as e:
61
+ logger.error(f"Error getting analysis results from integrated TxAgent: {e}")
 
 
62
  return []
63
 
64
  async def get_chats(self) -> List[Dict[str, Any]]:
65
+ """Obtient l'historique des chats depuis le service intégré"""
66
+ try:
67
+ # Query local database for chat history
68
+ chats_collection = db.chats
69
+ cursor = chats_collection.find().sort("timestamp", -1).limit(50)
70
+ chats = await cursor.to_list(length=50)
71
+
72
+ return [
73
+ {
74
+ "id": str(chat["_id"]),
75
+ "message": chat.get("message", ""),
76
+ "response": chat.get("response", ""),
77
+ "timestamp": chat.get("timestamp"),
78
+ "user_id": str(chat.get("user_id", "")),
79
+ "patient_id": str(chat.get("patient_id", "")) if chat.get("patient_id") else None
80
+ }
81
+ for chat in chats
82
+ ]
83
+ except Exception as e:
84
+ logger.error(f"Error getting chats from integrated service: {e}")
85
+ return []
86
 
87
  # Instance globale
88
  txagent_service = TxAgentService()
core/txagent_config.py CHANGED
@@ -6,9 +6,10 @@ logger = logging.getLogger(__name__)
6
 
7
  class TxAgentConfig:
8
  def __init__(self):
 
9
  self.mode = os.getenv("TXAGENT_MODE", "local") # local, cloud, hybrid
10
  self.cloud_url = os.getenv("TXAGENT_CLOUD_URL", "https://rocketfarmstudios-txagent-api.hf.space")
11
- self.local_enabled = os.getenv("TXAGENT_LOCAL_ENABLED", "false").lower() == "true"
12
  self.gpu_available = self._check_gpu_availability()
13
 
14
  def _check_gpu_availability(self) -> bool:
@@ -21,23 +22,17 @@ class TxAgentConfig:
21
 
22
  def get_txagent_mode(self) -> str:
23
  """Détermine le mode optimal pour TxAgent"""
24
- if self.mode == "cloud":
25
- return "cloud"
26
- elif self.mode == "local" and self.local_enabled and self.gpu_available:
27
- return "local"
28
- else:
29
- return "cloud" # Fallback vers cloud
30
 
31
  def get_txagent_url(self) -> str:
32
  """Retourne l'URL du service TxAgent"""
33
- if self.get_txagent_mode() == "local":
34
- return "http://localhost:8001" # Port local pour TxAgent
35
- else:
36
- return self.cloud_url
37
 
38
  def is_local_available(self) -> bool:
39
  """Vérifie si le mode local est disponible"""
40
- return self.local_enabled and self.gpu_available
41
 
42
  # Instance globale
43
  txagent_config = TxAgentConfig()
 
6
 
7
  class TxAgentConfig:
8
  def __init__(self):
9
+ # Since TxAgent is now integrated, default to local mode
10
  self.mode = os.getenv("TXAGENT_MODE", "local") # local, cloud, hybrid
11
  self.cloud_url = os.getenv("TXAGENT_CLOUD_URL", "https://rocketfarmstudios-txagent-api.hf.space")
12
+ self.local_enabled = os.getenv("TXAGENT_LOCAL_ENABLED", "true").lower() == "true"
13
  self.gpu_available = self._check_gpu_availability()
14
 
15
  def _check_gpu_availability(self) -> bool:
 
22
 
23
  def get_txagent_mode(self) -> str:
24
  """Détermine le mode optimal pour TxAgent"""
25
+ # Since TxAgent is integrated, always use local mode
26
+ return "local"
 
 
 
 
27
 
28
  def get_txagent_url(self) -> str:
29
  """Retourne l'URL du service TxAgent"""
30
+ # Since TxAgent is integrated, return localhost
31
+ return "http://localhost:7860" # Same port as the main API
 
 
32
 
33
  def is_local_available(self) -> bool:
34
  """Vérifie si le mode local est disponible"""
35
+ return True # Always available since it's integrated
36
 
37
  # Instance globale
38
  txagent_config = TxAgentConfig()
data/new_tool.json ADDED
@@ -0,0 +1 @@
 
 
1
+ []
db/mongo.py CHANGED
@@ -15,6 +15,12 @@ appointments_collection = db.appointments
15
  messages_collection = db.messages
16
  password_reset_codes_collection = db.password_reset_codes
17
 
 
 
 
 
 
 
18
  # Create indexes for better duplicate detection
19
  async def create_indexes():
20
  """Create database indexes for better performance and duplicate detection"""
@@ -51,6 +57,27 @@ async def create_indexes():
51
  ("source", 1)
52
  ])
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  print("Database indexes created successfully")
55
 
56
  except Exception as e:
 
15
  messages_collection = db.messages
16
  password_reset_codes_collection = db.password_reset_codes
17
 
18
+ # TxAgent Collections
19
+ patient_analysis_results_collection = db.patient_analysis_results
20
+ chats_collection = db.chats
21
+ clinical_alerts_collection = db.clinical_alerts
22
+ notifications_collection = db.notifications
23
+
24
  # Create indexes for better duplicate detection
25
  async def create_indexes():
26
  """Create database indexes for better performance and duplicate detection"""
 
57
  ("source", 1)
58
  ])
59
 
60
+ # TxAgent indexes
61
+ await patient_analysis_results_collection.create_index([
62
+ ("patient_id", 1),
63
+ ("timestamp", -1)
64
+ ])
65
+
66
+ await chats_collection.create_index([
67
+ ("user_id", 1),
68
+ ("timestamp", -1)
69
+ ])
70
+
71
+ await clinical_alerts_collection.create_index([
72
+ ("patient_id", 1),
73
+ ("timestamp", -1)
74
+ ])
75
+
76
+ await notifications_collection.create_index([
77
+ ("user_id", 1),
78
+ ("timestamp", -1)
79
+ ])
80
+
81
  print("Database indexes created successfully")
82
 
83
  except Exception as e:
requirements.txt CHANGED
@@ -1,15 +1,33 @@
1
- fastapi
2
- uvicorn
3
  motor
4
  python-jose[cryptography]
5
  passlib[bcrypt]
6
  certifi
7
  bcrypt==4.0.1
8
  email-validator
9
- python-multipart
10
  requests
11
  gradio
12
  python-dotenv>=0.21.0
13
- aiohttp
14
  fastapi-mail
15
- jinja2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.68.0
2
+ uvicorn>=0.15.0
3
  motor
4
  python-jose[cryptography]
5
  passlib[bcrypt]
6
  certifi
7
  bcrypt==4.0.1
8
  email-validator
9
+ python-multipart>=0.0.5
10
  requests
11
  gradio
12
  python-dotenv>=0.21.0
 
13
  fastapi-mail
14
+ jinja2
15
+ pandas>=1.3.0
16
+ pdfplumber>=0.6.0
17
+ fpdf2>=2.5.5
18
+ matplotlib>=3.4.0
19
+ transformers>=4.36.0
20
+ sentence-transformers>=2.2.2
21
+ accelerate>=0.24.1
22
+ tooluniverse
23
+ markdown
24
+ PyPDF2
25
+ pymongo
26
+ SpeechRecognition
27
+ gTTS
28
+ pydub
29
+ fitz
30
+ python-docx
31
+ pyfcm
32
+ httpx
33
+ jwt
src/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .txagent import TxAgent
2
+ from .toolrag import ToolRAGModel
3
+ __all__ = [
4
+ "TxAgent",
5
+ "ToolRAGModel",
6
+ ]
src/toolrag.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ from sentence_transformers import SentenceTransformer
5
+ from .utils import get_md5
6
+
7
+
8
+ class ToolRAGModel:
9
+ def __init__(self, rag_model_name):
10
+ self.rag_model_name = rag_model_name
11
+ self.rag_model = None
12
+ self.tool_desc_embedding = None
13
+ self.tool_name = None
14
+ self.tool_embedding_path = None
15
+ self.load_rag_model()
16
+
17
+ def load_rag_model(self):
18
+ self.rag_model = SentenceTransformer(self.rag_model_name)
19
+ self.rag_model.max_seq_length = 4096
20
+ self.rag_model.tokenizer.padding_side = "right"
21
+
22
+ def load_tool_desc_embedding(self, toolbox):
23
+ self.tool_name, _ = toolbox.refresh_tool_name_desc(enable_full_desc=True)
24
+ all_tools_str = [json.dumps(each) for each in toolbox.prepare_tool_prompts(toolbox.all_tools)]
25
+ md5_value = get_md5(str(all_tools_str))
26
+ print("Computed MD5 for tool embedding:", md5_value)
27
+
28
+ self.tool_embedding_path = os.path.join(
29
+ os.path.dirname(__file__),
30
+ self.rag_model_name.split("/")[-1] + f"_tool_embedding_{md5_value}.pt"
31
+ )
32
+
33
+ if os.path.exists(self.tool_embedding_path):
34
+ try:
35
+ self.tool_desc_embedding = torch.load(self.tool_embedding_path, map_location="cpu")
36
+ assert len(self.tool_desc_embedding) == len(toolbox.all_tools), \
37
+ "Tool count mismatch with loaded embeddings."
38
+ print("\033[92mLoaded cached tool_desc_embedding.\033[0m")
39
+ return
40
+ except Exception as e:
41
+ print(f"⚠️ Failed loading cached embeddings: {e}")
42
+ self.tool_desc_embedding = None
43
+
44
+ print("\033[93mGenerating new tool_desc_embedding...\033[0m")
45
+ self.tool_desc_embedding = self.rag_model.encode(
46
+ all_tools_str, prompt="", normalize_embeddings=True
47
+ )
48
+
49
+ torch.save(self.tool_desc_embedding, self.tool_embedding_path)
50
+ print(f"\033[92mSaved new tool_desc_embedding to {self.tool_embedding_path}\033[0m")
51
+
52
+ def rag_infer(self, query, top_k=5):
53
+ torch.cuda.empty_cache()
54
+ queries = [query]
55
+ query_embeddings = self.rag_model.encode(
56
+ queries, prompt="", normalize_embeddings=True
57
+ )
58
+ if self.tool_desc_embedding is None:
59
+ raise RuntimeError("❌ tool_desc_embedding is not initialized. Did you forget to call load_tool_desc_embedding()?")
60
+
61
+ scores = self.rag_model.similarity(
62
+ query_embeddings, self.tool_desc_embedding
63
+ )
64
+ top_k = min(top_k, len(self.tool_name))
65
+ top_k_indices = torch.topk(scores, top_k).indices.tolist()[0]
66
+ top_k_tool_names = [self.tool_name[i] for i in top_k_indices]
67
+ return top_k_tool_names
src/txagent.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import torch
4
+ from typing import Dict, Optional, List, Union
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
6
+ from sentence_transformers import SentenceTransformer
7
+
8
+ # Configure logging for Hugging Face Spaces
9
+ logging.basicConfig(
10
+ level=logging.INFO,
11
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
12
+ )
13
+ logger = logging.getLogger("TxAgent")
14
+
15
+ class TxAgent:
16
+ def __init__(self,
17
+ model_name: str,
18
+ rag_model_name: str,
19
+ tool_files_dict: Optional[Dict] = None,
20
+ enable_finish: bool = True,
21
+ enable_rag: bool = False,
22
+ force_finish: bool = True,
23
+ enable_checker: bool = True,
24
+ step_rag_num: int = 4,
25
+ seed: Optional[int] = None):
26
+
27
+ # Initialization parameters
28
+ self.model_name = model_name
29
+ self.rag_model_name = rag_model_name
30
+ self.tool_files_dict = tool_files_dict or {}
31
+ self.enable_finish = enable_finish
32
+ self.enable_rag = enable_rag
33
+ self.force_finish = force_finish
34
+ self.enable_checker = enable_checker
35
+ self.step_rag_num = step_rag_num
36
+ self.seed = seed
37
+
38
+ # Device setup
39
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+
41
+ # Models
42
+ self.model = None
43
+ self.tokenizer = None
44
+ self.rag_model = None
45
+
46
+ # Prompts
47
+ self.chat_prompt = "You are a helpful assistant for user chat."
48
+
49
+ logger.info(f"Initialized TxAgent with model: {model_name}")
50
+
51
+ def init_model(self):
52
+ """Initialize all models and components"""
53
+ try:
54
+ self.load_llm_model()
55
+ if self.enable_rag:
56
+ self.load_rag_model()
57
+ logger.info("Models initialized successfully")
58
+ except Exception as e:
59
+ logger.error(f"Model initialization failed: {str(e)}")
60
+ raise
61
+
62
+ def load_llm_model(self):
63
+ """Load the main LLM model"""
64
+ try:
65
+ logger.info(f"Loading LLM model: {self.model_name}")
66
+ self.tokenizer = AutoTokenizer.from_pretrained(
67
+ self.model_name,
68
+ trust_remote_code=True
69
+ )
70
+
71
+ self.model = AutoModelForCausalLM.from_pretrained(
72
+ self.model_name,
73
+ torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
74
+ device_map="auto",
75
+ trust_remote_code=True
76
+ )
77
+ logger.info(f"LLM model loaded on {self.device}")
78
+ except Exception as e:
79
+ logger.error(f"Failed to load LLM model: {str(e)}")
80
+ raise
81
+
82
+ def load_rag_model(self):
83
+ """Load the RAG model"""
84
+ try:
85
+ logger.info(f"Loading RAG model: {self.rag_model_name}")
86
+ self.rag_model = SentenceTransformer(
87
+ self.rag_model_name,
88
+ device=str(self.device)
89
+ )
90
+ logger.info("RAG model loaded successfully")
91
+ except Exception as e:
92
+ logger.error(f"Failed to load RAG model: {str(e)}")
93
+ raise
94
+
95
+ def chat(self, message: str, history: Optional[List[Dict]] = None,
96
+ temperature: float = 0.7, max_new_tokens: int = 512) -> str:
97
+ """Handle chat conversations"""
98
+ try:
99
+ conversation = []
100
+
101
+ # Initialize with system prompt
102
+ conversation.append({"role": "system", "content": self.chat_prompt})
103
+
104
+ # Add history if provided
105
+ if history:
106
+ for msg in history:
107
+ conversation.append({"role": msg["role"], "content": msg["content"]})
108
+
109
+ # Add current message
110
+ conversation.append({"role": "user", "content": message})
111
+
112
+ # Generate response
113
+ inputs = self.tokenizer.apply_chat_template(
114
+ conversation,
115
+ add_generation_prompt=True,
116
+ return_tensors="pt"
117
+ ).to(self.device)
118
+
119
+ generation_config = GenerationConfig(
120
+ max_new_tokens=max_new_tokens,
121
+ temperature=temperature,
122
+ do_sample=True,
123
+ pad_token_id=self.tokenizer.eos_token_id
124
+ )
125
+
126
+ outputs = self.model.generate(
127
+ inputs,
128
+ generation_config=generation_config
129
+ )
130
+
131
+ # Decode and clean up response
132
+ response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
133
+ return response.strip()
134
+
135
+ except Exception as e:
136
+ logger.error(f"Chat failed: {str(e)}")
137
+ raise RuntimeError(f"Chat failed: {str(e)}")
138
+
139
+ def cleanup(self):
140
+ """Clean up resources"""
141
+ try:
142
+ if hasattr(self, 'model'):
143
+ del self.model
144
+ if hasattr(self, 'rag_model'):
145
+ del self.rag_model
146
+ torch.cuda.empty_cache()
147
+ logger.info("Resources cleaned up")
148
+ except Exception as e:
149
+ logger.error(f"Cleanup failed: {str(e)}")
150
+ raise
151
+
152
+ def __del__(self):
153
+ """Destructor to ensure proper cleanup"""
154
+ self.cleanup()
src/utils.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import json
3
+ import hashlib
4
+ import torch
5
+ from typing import List
6
+
7
+
8
+ def get_md5(input_str):
9
+ # Create an MD5 hash object
10
+ md5_hash = hashlib.md5()
11
+ md5_hash.update(input_str.encode('utf-8'))
12
+ return md5_hash.hexdigest()
13
+
14
+
15
+ def tool_result_format(function_call_messages):
16
+ current_output = "\n\n<details>\n<summary> <strong>Verified Feedback from Tools</strong>, click to see details:</summary>\n\n"
17
+ for each_message in function_call_messages:
18
+ if each_message['role'] == 'tool':
19
+ try:
20
+ parsed = json.loads(each_message['content'])
21
+ tool_name = parsed.get("tool_name", "Unknown Tool")
22
+ tool_output = parsed.get("content", each_message['content'])
23
+ current_output += f"**🔧 Tool: {tool_name}**\n\n{tool_output}\n\n"
24
+ except Exception:
25
+ current_output += f"{each_message['content']}\n\n"
26
+ current_output += "</details>\n\n\n"
27
+ return current_output
28
+
29
+
30
+ class NoRepeatSentenceProcessor:
31
+ def __init__(self, forbidden_sequences: List[List[int]], allowed_prefix_length: int):
32
+ self.allowed_prefix_length = allowed_prefix_length
33
+ self.forbidden_prefix_dict = {}
34
+ for seq in forbidden_sequences:
35
+ if len(seq) > allowed_prefix_length:
36
+ prefix = tuple(seq[:allowed_prefix_length])
37
+ next_token = seq[allowed_prefix_length]
38
+ self.forbidden_prefix_dict.setdefault(prefix, set()).add(next_token)
39
+
40
+ def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
41
+ if len(token_ids) >= self.allowed_prefix_length:
42
+ prefix = tuple(token_ids[:self.allowed_prefix_length])
43
+ if prefix in self.forbidden_prefix_dict:
44
+ for token_id in self.forbidden_prefix_dict[prefix]:
45
+ logits[token_id] = -float("inf")
46
+ return logits
47
+
48
+
49
+ class ReasoningTraceChecker:
50
+ def __init__(self, question, conversation, init_index=None):
51
+ self.question = question.lower()
52
+ self.conversation = conversation
53
+ self.existing_thoughts = []
54
+ self.existing_actions = []
55
+ self.new_thoughts = []
56
+ self.new_actions = []
57
+ self.index = init_index if init_index is not None else 1
58
+
59
+ def check_conversation(self):
60
+ info = ''
61
+ current_index = self.index
62
+ for i in range(current_index, len(self.conversation)):
63
+ each = self.conversation[i]
64
+ self.index = i
65
+ if each['role'] == 'assistant':
66
+ thought = each['content']
67
+ actions = each['tool_calls']
68
+ good_status, current_info = self.check_repeat_thought(thought)
69
+ info += current_info
70
+ if not good_status:
71
+ return False, info
72
+ good_status, current_info = self.check_repeat_action(actions)
73
+ info += current_info
74
+ if not good_status:
75
+ return False, info
76
+ return True, info
77
+
78
+ def check_repeat_thought(self, thought):
79
+ if thought in self.existing_thoughts:
80
+ return False, "repeat_thought"
81
+ self.existing_thoughts.append(thought)
82
+ return True, ''
83
+
84
+ def check_repeat_action(self, actions):
85
+ if type(actions) != list:
86
+ actions = json.loads(actions)
87
+ for each_action in actions:
88
+ if 'call_id' in each_action:
89
+ del each_action['call_id']
90
+ each_action = json.dumps(each_action)
91
+ if each_action in self.existing_actions:
92
+ return False, "repeat_action"
93
+ self.existing_actions.append(each_action)
94
+ return True, ''
utils.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import hashlib
3
+ import io
4
+ import json
5
+ from datetime import datetime
6
+ from typing import Dict, List, Tuple
7
+ from bson import ObjectId
8
+ import logging
9
+ from config import logger
10
+ # Add to your utils.py
11
+ from fastapi import WebSocket
12
+ import asyncio
13
+
14
+ class NotificationManager:
15
+ def __init__(self):
16
+ self.active_connections = {}
17
+ self.notification_queue = asyncio.Queue()
18
+
19
+ async def connect(self, websocket: WebSocket, user_id: str):
20
+ await websocket.accept()
21
+ self.active_connections[user_id] = websocket
22
+
23
+ def disconnect(self, user_id: str):
24
+ if user_id in self.active_connections:
25
+ del self.active_connections[user_id]
26
+
27
+ async def broadcast_notification(self, notification: dict):
28
+ """Broadcast to all connected clients"""
29
+ for connection in self.active_connections.values():
30
+ try:
31
+ await connection.send_json({
32
+ "type": "notification",
33
+ "data": notification
34
+ })
35
+ except Exception as e:
36
+ logger.error(f"Error sending notification: {e}")
37
+
38
+ notification_manager = NotificationManager()
39
+
40
+ async def broadcast_notification(notification: dict):
41
+ """Broadcast notification to relevant users"""
42
+ # Determine recipients based on notification type/priority
43
+ recipients = []
44
+ if notification["priority"] == "high":
45
+ recipients = ["psychiatrist", "emergency_team", "primary_care"]
46
+ else:
47
+ recipients = ["primary_care", "case_manager"]
48
+
49
+ # Add to each recipient's notification queue
50
+ await notification_manager.notification_queue.put({
51
+ "recipients": recipients,
52
+ "notification": notification
53
+ })
54
+
55
+
56
+
57
+ def clean_text_response(text: str) -> str:
58
+ text = re.sub(r'\n\s*\n', '\n\n', text)
59
+ text = re.sub(r'[ ]+', ' ', text)
60
+ return text.replace("**", "").replace("__", "").strip()
61
+
62
+ def extract_section(text: str, heading: str) -> str:
63
+ try:
64
+ pattern = rf"{re.escape(heading)}:\s*\n(.*?)(?=\n[A-Z][^\n]*:|\Z)"
65
+ match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
66
+ return match.group(1).strip() if match else ""
67
+ except Exception as e:
68
+ logger.error(f"Section extraction failed for heading '{heading}': {e}")
69
+ return ""
70
+
71
+ def structure_medical_response(text: str) -> Dict:
72
+ def extract_improved(text: str, heading: str) -> str:
73
+ patterns = [
74
+ rf"{re.escape(heading)}:\s*\n(.*?)(?=\n\s*\n|\Z)",
75
+ rf"\*\*{re.escape(heading)}\*\*:\s*\n(.*?)(?=\n\s*\n|\Z)",
76
+ rf"{re.escape(heading)}[\s\-]+(.*?)(?=\n\s*\n|\Z)",
77
+ rf"\n{re.escape(heading)}\s*\n(.*?)(?=\n\s*\n|\Z)"
78
+ ]
79
+ for pattern in patterns:
80
+ match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
81
+ if match:
82
+ content = match.group(1).strip()
83
+ content = re.sub(r'^\s*[\-\*]\s*', '', content, flags=re.MULTILINE)
84
+ return content
85
+ return ""
86
+
87
+ text = text.replace('**', '').replace('__', '')
88
+ return {
89
+ "summary": extract_improved(text, "Summary of Patient's Medical History") or
90
+ extract_improved(text, "Summarize the patient's medical history"),
91
+ "risks": extract_improved(text, "Identify Risks or Red Flags") or
92
+ extract_improved(text, "Risks or Red Flags"),
93
+ "missed_issues": extract_improved(text, "Missed Diagnoses or Treatments") or
94
+ extract_improved(text, "What the doctor might have missed"),
95
+ "recommendations": extract_improved(text, "Suggest Next Clinical Steps") or
96
+ extract_improved(text, "Suggested Clinical Actions")
97
+ }
98
+
99
+ def serialize_patient(patient: dict) -> dict:
100
+ patient_copy = patient.copy()
101
+ if "_id" in patient_copy:
102
+ patient_copy["_id"] = str(patient_copy["_id"])
103
+ return patient_copy
104
+
105
+ def compute_patient_data_hash(data: dict) -> str:
106
+ serialized = json.dumps(data, sort_keys=True)
107
+ return hashlib.sha256(serialized.encode()).hexdigest()
108
+
109
+ def compute_file_content_hash(file_content: bytes) -> str:
110
+ return hashlib.sha256(file_content).hexdigest()
voice.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from fastapi import HTTPException
3
+ from config import logger
4
+ import io
5
+ import speech_recognition as sr
6
+ from gtts import gTTS
7
+ from pydub import AudioSegment
8
+ import base64
9
+ from utils import clean_text_response # Added this import
10
+
11
+ def recognize_speech(audio_data: bytes, language: str = "en-US") -> str:
12
+ recognizer = sr.Recognizer()
13
+ try:
14
+ with io.BytesIO(audio_data) as audio_file:
15
+ with sr.AudioFile(audio_file) as source:
16
+ audio = recognizer.record(source)
17
+ text = recognizer.recognize_google(audio, language=language)
18
+ return text
19
+ except sr.UnknownValueError:
20
+ logger.error("Google Speech Recognition could not understand audio")
21
+ raise HTTPException(status_code=400, detail="Could not understand audio")
22
+ except sr.RequestError as e:
23
+ logger.error(f"Could not request results from Google Speech Recognition service; {e}")
24
+ raise HTTPException(status_code=503, detail="Speech recognition service unavailable")
25
+ except Exception as e:
26
+ logger.error(f"Error in speech recognition: {e}")
27
+ raise HTTPException(status_code=500, detail="Error processing speech")
28
+
29
+ def text_to_speech(text: str, language: str = "en", slow: bool = False) -> bytes:
30
+ try:
31
+ tts = gTTS(text=text, lang=language, slow=slow)
32
+ mp3_fp = io.BytesIO()
33
+ tts.write_to_fp(mp3_fp)
34
+ mp3_fp.seek(0)
35
+ return mp3_fp.read()
36
+ except Exception as e:
37
+ logger.error(f"Error in text-to-speech conversion: {e}")
38
+ raise HTTPException(status_code=500, detail="Error generating speech")
39
+
40
+ def extract_text_from_pdf(pdf_data: bytes) -> str:
41
+ try:
42
+ from PyPDF2 import PdfReader
43
+ pdf_reader = PdfReader(io.BytesIO(pdf_data))
44
+ text = ""
45
+ for page in pdf_reader.pages:
46
+ text += page.extract_text() or ""
47
+ return clean_text_response(text) # Now works with the import
48
+ except Exception as e:
49
+ logger.error(f"Error extracting text from PDF: {e}")
50
+ raise HTTPException(status_code=400, detail="Failed to extract text from PDF")