File size: 3,032 Bytes
17b89fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gradio as gr
from address_extractor import AddressExtractor
import tempfile
import os

# Shared AddressExtractor instance, created once at import time. Both handlers
# below reuse it for its tokenizer, chat model, speech model, system prompts
# and text preprocessing.
address_extractor = AddressExtractor()

def extract_from_text(input_text):
    """Extract an address from free-form text using the chat model.

    The text is sent as the user turn of a two-message chat (system prompt
    ``address_extractor.system_prompt_text`` + user text); only the newly
    generated tokens are decoded and returned.

    Args:
        input_text: Text typed by the user. ``None``, an empty string and a
            whitespace-only string are all treated as "no input".

    Returns:
        The stripped model reply, ``"No address detected."`` when the reply
        is empty, or ``"Error: No text provided."`` when there is no input.
    """
    # Guard against None as well as blank text: calling .strip() on None
    # would raise AttributeError instead of giving the user a clean error.
    if not input_text or not input_text.strip():
        return "Error: No text provided."

    tokenizer = address_extractor.tokenizer
    model = address_extractor.bitnet_model

    messages = [
        {"role": "system", "content": address_extractor.system_prompt_text},
        {"role": "user", "content": input_text},
    ]
    # Render the chat as a plain prompt string first, then tokenize it into
    # PyTorch tensors placed on the same device as the model.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    chat_input = tokenizer(prompt, return_tensors="pt").to(model.device)

    chat_outputs = model.generate(**chat_input, max_new_tokens=256)

    # The generated sequence starts with the prompt tokens; skip them so only
    # the model's reply is decoded.
    prompt_length = chat_input["input_ids"].shape[-1]
    generated_text = tokenizer.decode(
        chat_outputs[0][prompt_length:], skip_special_tokens=True
    )

    return generated_text.strip() or "No address detected."

def _audio_path_for_transcription(audio_file):
    """Return ``(path, is_temporary)`` for the audio to transcribe.

    A path (``str`` / ``os.PathLike``) is used as-is and is never deleted by
    the caller. A file-like object is copied into a fresh ``.wav`` temp file
    that the caller must remove once transcription is finished.
    """
    if isinstance(audio_file, (str, os.PathLike)):
        return os.fspath(audio_file), False

    # The handle may not be positioned at the start (e.g. right after it was
    # written); rewind so the whole recording is copied, not an empty tail.
    rewind = getattr(audio_file, "seek", None)
    if callable(rewind):
        try:
            rewind(0)
        except (OSError, ValueError):
            # Not seekable: fall back to reading from the current position.
            pass

    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    try:
        with tmp_file:
            tmp_file.write(audio_file.read())
    except BaseException:
        # delete=False means nobody else cleans this up; do not leak the
        # temp file when the copy fails, then let the error propagate.
        os.remove(tmp_file.name)
        raise
    return tmp_file.name, True


def extract_from_audio(audio_file):
    """Transcribe an uploaded audio clip and extract an address from it.

    The audio is transcribed with ``address_extractor.whisper_model``, the
    transcript is passed through ``address_extractor.preprocess_text`` and
    then sent to the chat model with the speech-specific system prompt.

    Args:
        audio_file: ``None``, a filesystem path (``str`` / ``os.PathLike``),
            or a file-like object exposing ``read()`` that returns bytes.

    Returns:
        The stripped model reply, ``"No address detected."`` when the reply
        is empty, or ``"Error: No audio provided."`` when ``audio_file`` is
        ``None``.

    Raises:
        Any exception raised while reading the upload, transcribing it or
        running the chat model; a temp file created here is removed first.
    """
    if audio_file is None:
        return "Error: No audio provided."

    audio_path, is_temporary = _audio_path_for_transcription(audio_file)

    try:
        transcription = address_extractor.whisper_model.transcribe(audio_path)
        # NOTE(review): `seg.text` suggests a faster-whisper style backend,
        # whose transcribe() returns `(segments, info)` -- confirm against
        # AddressExtractor. Unwrap that shape; anything else is iterated
        # directly, exactly as before.
        if (
            isinstance(transcription, tuple)
            and len(transcription) == 2
            and not hasattr(transcription[0], "text")
        ):
            segments = transcription[0]
        else:
            segments = transcription

        # Segments may be produced lazily, so consume them while the audio
        # file still exists (i.e. before the cleanup in `finally`).
        input_text = " ".join(seg.text.strip() for seg in segments)
        input_text = address_extractor.preprocess_text(input_text)

        tokenizer = address_extractor.tokenizer
        model = address_extractor.bitnet_model

        messages = [
            {"role": "system", "content": address_extractor.system_prompt_speech},
            {"role": "user", "content": input_text},
        ]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        chat_input = tokenizer(prompt, return_tensors="pt").to(model.device)

        chat_outputs = model.generate(**chat_input, max_new_tokens=256)

        # Decode only the tokens generated after the prompt.
        prompt_length = chat_input["input_ids"].shape[-1]
        generated_text = tokenizer.decode(
            chat_outputs[0][prompt_length:], skip_special_tokens=True
        )
        result = generated_text.strip() or "No address detected."

    finally:
        # Only remove files this function created; a caller-supplied path is
        # left untouched.
        if is_temporary:
            os.remove(audio_path)

    return result

# Gradio UI: one tab per input kind (typed text, uploaded audio). Each tab has
# an input widget, an output textbox and a button wired to its handler.
with gr.Blocks() as demo:
    gr.Markdown(value="## 📦 US Address Extractor")

    # Tab 1: free-form text -> extract_from_text
    with gr.Tab(label="Text Input"):
        text_input = gr.Textbox(label="Enter Text", lines=3)
        text_output = gr.Textbox(label="Extracted Address")
        text_button = gr.Button(value="Extract Address")

        text_button.click(extract_from_text, text_input, text_output)

    # Tab 2: uploaded audio file -> extract_from_audio
    with gr.Tab(label="Audio Input (.wav)"):
        audio_input = gr.Audio(
            label="Upload a .wav Audio File",
            source="upload",
            type="file",
        )
        audio_output = gr.Textbox(label="Extracted Address")
        audio_button = gr.Button(value="Extract Address")

        audio_button.click(extract_from_audio, audio_input, audio_output)

# Serve on every network interface, fixed port 7860.
demo.launch(server_port=7860, server_name="0.0.0.0")