Spaces:
Build error
Build error
| import spaces | |
| import torch | |
| import gradio as gr | |
| from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from threading import Thread | |
| from typing import Iterator | |
| import os | |
# Whisper checkpoint loaded by the speech-recognition pipeline.
MODEL_NAME = "openai/whisper-large-v3-turbo"
# Batch size passed to the ASR pipeline on every call.
BATCH_SIZE = 8
# NOTE(review): not referenced anywhere in the visible file - confirm whether
# an upload-size check that should use it is missing.
FILE_LIMIT_MB = 5000
# Upper bound on prompt tokens sent to the LLM; overridable via the environment.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
# Device for the ASR pipeline: CUDA device index 0 when a GPU is visible,
# otherwise the CPU.
device = 0 if torch.cuda.is_available() else "cpu"
# Initialize the LLM
# The model and tokenizer are loaded only when CUDA is available, so on a
# CPU-only host `llm` and `tokenizer` stay unbound and any later use of them
# raises NameError.
# NOTE(review): the original indentation was not recoverable; all four
# statements are assumed to belong to the `if` body - confirm.
if torch.cuda.is_available():
    llm_model_id = "chuanli11/Llama-3.2-3B-Instruct-uncensored"
    llm = AutoModelForCausalLM.from_pretrained(
        llm_model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
    # The system prompt is always supplied explicitly by the caller.
    tokenizer.use_default_system_prompt = False
# Initialize the transcription pipeline
# Audio is processed in 30-second chunks on the device selected above.
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,
    device=device,
)
# Function to transcribe audio inputs
# NOTE(review): `spaces` is imported at the top of the file but no
# `@spaces.GPU` decorator is applied here - confirm whether this Space is
# meant to run on ZeroGPU hardware.
def transcribe(inputs, task):
    """Run the module-level ASR pipeline on an audio input and return its text.

    Args:
        inputs: Audio input forwarded unchanged to ``pipe``; the UI supplies
            a file path. ``None`` means nothing was uploaded or recorded.
        task: Forwarded to the model as ``generate_kwargs["task"]``; the UI
            offers ``"transcribe"`` and ``"translate"``.

    Returns:
        The ``"text"`` field of the pipeline result.

    Raises:
        gr.Error: If ``inputs`` is ``None``.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    # Timestamps are requested from the pipeline, but only the text is returned.
    result = pipe(
        inputs,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": task},
        return_timestamps=True,
    )
    return result["text"]
# Function to generate SOAP notes using LLM
def generate_soap(
    transcribed_text: str,
    system_prompt: str = "You are a world class clinical assistant.",
    max_new_tokens: int = 4098,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a clinical SOAP note generated from a transcribed conversation.

    The transcript is wrapped in a fixed task prompt, rendered through the
    tokenizer's chat template and passed to ``llm.generate`` on a background
    thread; decoded text is read back through a ``TextIteratorStreamer``.

    Args:
        transcribed_text: Physician/patient dialogue to summarise.
        system_prompt: Content of the system message.
        max_new_tokens: Forwarded to ``llm.generate``.
        temperature: Sampling temperature forwarded to ``llm.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``llm.generate``.
        top_k: Top-k sampling cutoff forwarded to ``llm.generate``.
        repetition_penalty: Forwarded to ``llm.generate``.

    Yields:
        The full text generated so far, re-emitted after every new fragment,
        so each yielded value supersedes the previous one.

    Raises:
        gr.Error: If ``transcribed_text`` is empty or only whitespace.
    """
    # Reject missing input up front, mirroring transcribe(); otherwise the
    # model would be asked to write a note about nothing.
    if not transcribed_text or not transcribed_text.strip():
        raise gr.Error("No transcription provided! Please transcribe audio or enter text before generating a SOAP note.")

    # Kept flush-left so that no source indentation leaks into the prompt.
    task_prompt = """
Convert the following transcribed conversation into a clinical SOAP note.
The text includes dialogue between a physician and a patient. Please clearly distinguish between the physician's and the patient's statements.
Extract and organize the information into the relevant sections of a SOAP note:
- Subjective (symptoms and patient statements),
- Objective (clinical findings and observations),
- Assessment (diagnosis or potential diagnoses),
- Plan (treatment and follow-up).
Ensure the note is concise, clear, and accurately reflects the conversation.
"""
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{task_prompt}\n\nTranscribed conversation:\n{transcribed_text}"},
    ]

    # add_generation_prompt=True appends the assistant-turn header so the
    # model answers as the assistant instead of continuing the user message.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the trailing tokens; the start of the prompt is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(llm.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )

    # generate() blocks until completion, so it runs on a worker thread while
    # this generator consumes the streamer.
    t = Thread(target=llm.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
# Gradio Interface combining transcription and SOAP note generation
# Components are created in display order, and event listeners are registered
# in the same order as before so the layout and wiring are unchanged.
demo = gr.Blocks(theme=gr.themes.Ocean())
with demo:
    with gr.Tab("Clinical SOAP Note from Audio"):
        # Transcription Interface
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio Input")
        task_input = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
        transcription_output = gr.Textbox(label="Transcription Output")

        # Transcription button
        transcribe_button = gr.Button("Transcribe")
        transcribe_button.click(
            fn=transcribe,
            inputs=[audio_input, task_input],
            outputs=transcription_output,
        )

        # SOAP Generation Interface
        transcribed_text_input = gr.Textbox(label="Edit Transcription before SOAP Generation", lines=5)
        system_prompt_input = gr.Textbox(label="System Prompt", lines=2, value="You are a world class clinical assistant.")
        max_new_tokens_input = gr.Slider(label="Max new tokens", minimum=1, maximum=2048, value=1024, step=1)
        temperature_input = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, value=0.6, step=0.1)
        top_p_input = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, value=0.9, step=0.05)
        top_k_input = gr.Slider(label="Top-k", minimum=1, maximum=1000, value=50, step=1)
        repetition_penalty_input = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.05)
        soap_output = gr.Textbox(label="Generated SOAP Note Output")

        # SOAP generation button
        # The input order matches the positional parameters of generate_soap.
        generate_soap_button = gr.Button("Generate SOAP Note")
        generate_soap_button.click(
            fn=generate_soap,
            inputs=[
                transcribed_text_input,
                system_prompt_input,
                max_new_tokens_input,
                temperature_input,
                top_p_input,
                top_k_input,
                repetition_penalty_input,
            ],
            outputs=soap_output,
        )

        # Automatically copy transcription output to the edit box
        def update_transcription_box(transcription_text):
            """Return the transcription unchanged so it fills the editable box."""
            return transcription_text

        transcription_output.change(
            fn=update_transcription_box,
            inputs=transcription_output,
            outputs=transcribed_text_input,
        )

# Queueing is enabled before launch; generate_soap streams its output as a generator.
demo.queue().launch(ssr_mode=False)