# Uploaded by Pavanb ("Upload 2 files"), commit 17b89fa (verified), 3.03 kB.
import gradio as gr
from address_extractor import AddressExtractor
import tempfile
import os
# Single module-level AddressExtractor shared by the Gradio handlers in this
# file; it is constructed once, at import time. The handlers read its
# tokenizer, bitnet_model, whisper_model, system prompts and preprocess_text
# attributes. NOTE(review): construction cost (e.g. model loading) is not
# visible here — presumably heavy; confirm in address_extractor.py.
address_extractor = AddressExtractor()
def _generate_reply(system_prompt, user_text, max_new_tokens=256) -> str:
    """Run one system+user chat turn through the BitNet model and return the reply.

    The two messages are rendered with the tokenizer's chat template, moved to
    the model's device, and passed to ``generate``. Only the newly generated
    tokens are decoded (the prompt tokens are sliced off).

    Args:
        system_prompt: Content of the "system" message.
        user_text: Content of the "user" message.
        max_new_tokens: Generation budget passed to ``generate`` (default 256,
            the value both handlers previously hard-coded).

    Returns:
        The stripped reply, or "No address detected." if it is empty.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_text},
    ]
    prompt = address_extractor.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    chat_input = address_extractor.tokenizer(prompt, return_tensors="pt").to(
        address_extractor.bitnet_model.device
    )
    chat_outputs = address_extractor.bitnet_model.generate(
        **chat_input, max_new_tokens=max_new_tokens
    )
    # generate() output starts with the prompt tokens; decode only what follows.
    prompt_length = chat_input["input_ids"].shape[-1]
    generated_text = address_extractor.tokenizer.decode(
        chat_outputs[0][prompt_length:], skip_special_tokens=True
    )
    return generated_text.strip() or "No address detected."


def _transcribe_and_extract(audio_path) -> str:
    """Transcribe the audio file at *audio_path* and extract an address from the transcript."""
    segments = address_extractor.whisper_model.transcribe(audio_path)
    # Assumes each segment exposes a `.text` attribute (as the original code did).
    transcript = " ".join(segment.text.strip() for segment in segments)
    transcript = address_extractor.preprocess_text(transcript)
    return _generate_reply(address_extractor.system_prompt_speech, transcript)


def extract_from_text(input_text) -> str:
    """Extract a US address from free-form text.

    Args:
        input_text: Text from the Gradio textbox. ``None`` is treated like
            empty input instead of raising ``AttributeError``.

    Returns:
        One of three strings:
        - the model's extracted address;
        - "No address detected." when the model returns nothing;
        - "Error: No text provided." for missing or whitespace-only input.
    """
    if not input_text or not input_text.strip():
        return "Error: No text provided."
    return _generate_reply(address_extractor.system_prompt_text, input_text)


def extract_from_audio(audio_file) -> str:
    """Transcribe an uploaded audio clip and extract a US address from it.

    Args:
        audio_file: The Gradio Audio value, in one of three forms:
            - a readable binary file-like object, whose bytes are copied to a
              temporary ".wav" file that is always deleted afterwards;
            - a ``str``/``os.PathLike`` path, which is transcribed in place and
              not deleted;
            - ``None``, meaning nothing was uploaded.

    Returns:
        One of three strings:
        - the model's extracted address;
        - "No address detected." when the model returns nothing;
        - "Error: No audio provided." when *audio_file* is ``None``.
    """
    if audio_file is None:
        return "Error: No audio provided."

    # A path we do not own: use it directly and leave it alone.
    if isinstance(audio_file, (str, os.PathLike)):
        return _transcribe_and_extract(os.fspath(audio_file))

    tmp_file_path = None
    try:
        # Record the path before writing so a failing read()/write() still
        # gets cleaned up below (delete=False means nobody else removes it).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            tmp_file_path = tmp_file.name
            tmp_file.write(audio_file.read())
        return _transcribe_and_extract(tmp_file_path)
    finally:
        if tmp_file_path is not None:
            os.remove(tmp_file_path)
# Gradio UI: a "Text Input" tab wired to extract_from_text and an
# "Audio Input (.wav)" tab wired to extract_from_audio. Each tab shows its
# result in its own "Extracted Address" textbox.
with gr.Blocks() as demo:
    gr.Markdown("## 📦 US Address Extractor")
    with gr.Tab("Text Input"):
        text_input = gr.Textbox(lines=3, label="Enter Text")
        text_output = gr.Textbox(label="Extracted Address")
        text_button = gr.Button("Extract Address")
        # On click: textbox value -> extract_from_text -> returned string shown in text_output.
        text_button.click(fn=extract_from_text, inputs=text_input, outputs=text_output)
    with gr.Tab("Audio Input (.wav)"):
        # NOTE(review): `source=` and `type="file"` look like Gradio 3.x-era
        # arguments. Newer Gradio uses `sources=[...]` and type "numpy"/"filepath".
        # The handler calls `.read()` on the value, so it expects a file-like
        # object here. Confirm against the pinned Gradio version.
        audio_input = gr.Audio(source="upload", type="file", label="Upload a .wav Audio File")
        audio_output = gr.Textbox(label="Extracted Address")
        audio_button = gr.Button("Extract Address")
        # On click: uploaded audio -> extract_from_audio -> returned string shown in audio_output.
        audio_button.click(fn=extract_from_audio, inputs=audio_input, outputs=audio_output)
# Launches at module import (no __main__ guard), listening on all interfaces, port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)