# app.py: mini app that shows an empty PCA map on startup and keeps adding a point for every new input
import os, re, io, numpy as np, pandas as pd, torch, gradio as gr
import plotly.graph_objects as go
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from faster_whisper import WhisperModel
import soundfile as sf
import tempfile
# ===== Model setup =====
MODEL_NAME = os.getenv("MODEL_NAME", "cardiffnlp/twitter-xlm-roberta-base-sentiment")
id2label = {0:"NEG", 1:"NEU", 2:"POS"}
label2id = {v:k for k,v in id2label.items()}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)  # slow tokenizer for stability on Spaces
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=3, id2label={0:"NEGATIVE",1:"NEUTRAL",2:"POSITIVE"},
    label2id={"NEGATIVE":0,"NEUTRAL":1,"POSITIVE":2}
).to(device).eval()
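# The module-level id2label above (NEG/NEU/POS) mirrors the same 0/1/2 index order,
# so an argmax over these logits maps directly to the short display labels.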
# === Whisper (ASR) initialization ===
ASR_MODEL_NAME = os.getenv("ASR_MODEL_NAME", "small")  # e.g. "tiny" or "small"
asr = WhisperModel(
    ASR_MODEL_NAME,
    device="cuda" if torch.cuda.is_available() else "cpu",
    compute_type="float16" if torch.cuda.is_available() else "int8"
)
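# On GPU, float16 keeps faster-whisper inference fast; on CPU, int8 quantization keeps it lightweight.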
# ===== Utilities =====
def clean_text(x: str) -> str:
    if not isinstance(x, str): return ""
    return re.sub(r"\s+", " ", x).strip()
# === Audio-to-text transcription helper ===
def transcribe_numpy_audio(audio_tuple, lang_hint="ja"):
    # Gradio's type="numpy" microphone input arrives as (sample_rate, np.ndarray)
    if audio_tuple is None:
        return ""
    sr, data = audio_tuple
    if len(data.shape) > 1:
        # average stereo to mono, casting back to the original dtype so soundfile scales it correctly
        data = data.mean(axis=1).astype(data.dtype)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        sf.write(tmp.name, data, sr)
        tmp_path = tmp.name
    try:
        segments, info = asr.transcribe(tmp_path, language=lang_hint, vad_filter=True)
        text = "".join([seg.text for seg in segments]).strip()
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
    return text
def mean_pool(last_hidden_state, attention_mask):
    # last_hidden_state: [B, T, H], attention_mask: [B, T]
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return (summed / counts).cpu().numpy()  # [B, H]
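# Plotly marker symbols and colors used for each predicted class.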
SYM = {"NEG":"triangle-left", "NEU":"circle-open", "POS":"triangle-up"}
CLR = {"NEG":"#d62728", "NEU":"#ff7f0e", "POS":"#2ca02c"}
def classify_once(text: str):
    if not text or not text.strip():
        return "—", 0.0, "テキストを入力してください。"
    with torch.no_grad():
        inputs = tok(text, return_tensors="pt", truncation=True, max_length=256)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    idx = int(np.argmax(probs))
    label = id2label[idx]
    score = float(probs[idx])
    reason = {
        "POS": "肯定的な語彙・前向きなトーンが優勢。",
        "NEU": "強い肯定/否定が少なく中立的。",
        "NEG": "否定・不満・懸念の表現が目立つ。"
    }[label]
    return label, round(score, 3), reason
def embed_and_predict(text: str):
    """Embed one sentence and predict its sentiment label."""
    text = clean_text(text)
    if not text:
        raise ValueError("テキストが空です。")
    with torch.no_grad():
        inputs = tok(text, return_tensors="pt", truncation=True, max_length=128)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # single forward pass: logits for the prediction, hidden states for the embedding
        out = model(**inputs, output_hidden_states=True)
        probs = torch.softmax(out.logits, dim=-1).cpu().numpy()[0]
        pred_id = int(np.argmax(probs)); conf = float(probs[pred_id])
        pred = id2label[pred_id]
        emb = mean_pool(out.hidden_states[-1], inputs["attention_mask"])[0]  # [H]
    return emb, pred, conf
def build_fig(points_xy, points_pred, points_conf, points_text, color_mode="PRED"):
    """Draw only the accumulated history points (no background data)."""
    fig = go.Figure()
    if len(points_xy) == 0:
        fig.update_layout(
            title="Interactive Sentiment Map (空のマップ — 入力すると点が追加されます)",
            xaxis_title="PCA 1", yaxis_title="PCA 2"
        )
        return fig
    arr = np.asarray(points_xy)      # [N, 2]
    preds = np.asarray(points_pred)
    confs = np.asarray(points_conf)
    texts = np.asarray(points_text)
    # draw the input history grouped by class so the legend stays readable
    for lab in ["NEG", "NEU", "POS"]:
        mask = (preds == lab)
        if not mask.any(): continue
        fig.add_trace(go.Scatter(
            x=arr[mask, 0], y=arr[mask, 1],
            mode="markers+text",
            name=f"INPUT ({lab})",
            marker=dict(size=11, symbol=SYM.get(lab, "x"), color=CLR.get(lab, "#000")),
            text=[f"{lab} ({c:.2f})" for c in confs[mask]],
            textposition="top center",
            customdata=[t[:180] + ("…" if len(t) > 180 else "") for t in texts[mask]],
            hovertemplate="<b>Input</b><br>%{text}<br>Text: %{customdata}<extra></extra>"
        ))
    fig.update_layout(
        title="Interactive Sentiment Map (入力履歴のみで更新)",
        xaxis_title="PCA 1", yaxis_title="PCA 2",
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0)
    )
    return fig
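# Note: only the user's own inputs are plotted; add_and_plot below refits PCA on the
# accumulated history each time a point is added, so coordinates may shift as points accrue.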
# ===== Gradio UI =====
with gr.Blocks() as demo:
    gr.Markdown("## ビジネス日本語・感情分析(ライブPCA)\n- 起動直後から空のマップが表示され、**入力するたびに点がどんどん追加**されます。\n- 背景データは不要(オプションで後日追加可能)。")
    with gr.Tab("判定(単発)"):
        inp = gr.Textbox(lines=3, label="テキストを入力", placeholder="例)迅速なご対応に感謝します。")
        btn = gr.Button("判定する")
        out1 = gr.Label(label="判定")
        out2 = gr.Number(label="確信度(0-1)")
        out3 = gr.Textbox(label="理由")
        btn.click(fn=classify_once, inputs=inp, outputs=[out1, out2, out3])
    with gr.Tab("マップ(履歴が貯まる)"):
        gr.Markdown("### 入力のたびに埋め込み→PCA再計算→マップ更新\n※ 点が増えるほど座標が安定します(**3件以上**で2次元化)。")
        text_live = gr.Textbox(lines=2, label="テキスト")
        add_btn = gr.Button("マップに追加")
        clr_btn = gr.Button("履歴クリア")
        dl_btn = gr.Button("履歴CSVをダウンロード")
        plot = gr.Plot(label="Sentiment Map")
        status_md = gr.Markdown()  # status messages for add/clear
        dl_file = gr.File(label="history.csv", visible=True)
        # Session state (embedding vectors / predictions / confidences / texts)
        embs_state = gr.State([])    # List[np.ndarray(H,)]
        preds_state = gr.State([])   # List[str]
        confs_state = gr.State([])   # List[float]
        texts_state = gr.State([])   # List[str]
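        # gr.State is per-session: each browser session accumulates its own independent history.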
        def add_and_plot(text, embs, preds, confs, texts):
            text = clean_text(text)
            if not text:
                fig = build_fig([], [], [], [])
                return fig, embs, preds, confs, texts, "テキストを入力してください。"
            try:
                emb, pred, conf = embed_and_predict(text)
            except Exception as e:
                fig = build_fig([], [], [], [])
                return fig, embs, preds, confs, texts, f"ERROR: {type(e).__name__}: {e}"
            embs = list(embs or []); preds = list(preds or []); confs = list(confs or []); texts = list(texts or [])
            embs.append(emb); preds.append(pred); confs.append(conf); texts.append(text)
            # Refit PCA on the history alone (only once there are 3 or more points)
            if len(embs) >= 3:
                X = np.stack(embs, axis=0)                # [N, H]
                pca = PCA(n_components=2, random_state=42)
                coords = pca.fit_transform(X)             # [N, 2]
                points_xy = coords.tolist()
            else:
                # with fewer than 3 points, just spread them along the x-axis
                points_xy = [[i, 0.0] for i in range(len(embs))]
            fig = build_fig(points_xy, preds, confs, texts)
            return fig, embs, preds, confs, texts, f"追加: {pred} ({conf:.2f})"
        def clear_all():
            fig = build_fig([], [], [], [])
            return fig, [], [], [], [], "履歴をクリアしました。"
        def download_csv(embs, preds, confs, texts):
            if not texts:
                return None
            df = pd.DataFrame({
                "text": texts,
                "pred": preds,
                "confidence": confs
            })
            # gr.File expects a file path, so write the CSV to a temp file and return its path
            csv_path = os.path.join(tempfile.gettempdir(), "history.csv")
            df.to_csv(csv_path, index=False)
            return csv_path
        add_btn.click(add_and_plot, inputs=[text_live, embs_state, preds_state, confs_state, texts_state],
                      outputs=[plot, embs_state, preds_state, confs_state, texts_state, status_md])
        clr_btn.click(clear_all, inputs=None,
                      outputs=[plot, embs_state, preds_state, confs_state, texts_state, status_md])
        dl_btn.click(download_csv, inputs=[embs_state, preds_state, confs_state, texts_state],
                     outputs=[dl_file])
    with gr.Tab("音声入力(マイク→文字起こし→判定&マップ追加)"):
        gr.Markdown("マイク録音 → 自動で文字起こし → 判定 → マップにも追加します。")
        audio = gr.Audio(sources=["microphone"], type="numpy", label="マイク録音(日本語)")
        asr_btn = gr.Button("文字起こし→判定&マップに追加")
        # display the transcription and the classification result
        asr_text = gr.Textbox(label="文字起こし結果", interactive=False)
        asr_label = gr.Label(label="判定")
        asr_conf = gr.Number(label="確信度(0-1)")
        asr_reason = gr.Textbox(label="理由")
        # This tab reuses the state objects defined in the map tab
        # (embs_state / preds_state / confs_state / texts_state / plot);
        # they are reachable here because everything lives in the same Blocks context.
        def asr_then_classify_and_plot(audio_tuple, embs, preds, confs, texts):
            # 1) ASR
            text = transcribe_numpy_audio(audio_tuple, lang_hint="ja")
            if not text:
                fig = build_fig([], [], [], [])
                return "", "—", 0.0, "音声が空でした。", fig, embs, preds, confs, texts
            # 2) classification
            label, score, reason = classify_once(text)
            # 3) also add the point to the map (reusing the map tab's logic)
            try:
                emb, pred, conf = embed_and_predict(text)
                embs = list(embs or []); preds = list(preds or []); confs = list(confs or []); texts = list(texts or [])
                embs.append(emb); preds.append(pred); confs.append(conf); texts.append(text)
                if len(embs) >= 3:
                    X = np.stack(embs, axis=0)
                    pca = PCA(n_components=2, random_state=42)
                    coords = pca.fit_transform(X)
                    points_xy = coords.tolist()
                else:
                    points_xy = [[i, 0.0] for i in range(len(embs))]
                fig = build_fig(points_xy, preds, confs, texts)
            except Exception as e:
                fig = build_fig([], [], [], [])
                reason = f"{reason}\n(マップ追加エラー: {type(e).__name__}: {e})"
            return text, label, score, reason, fig, embs, preds, confs, texts
        # click wiring (uses the map tab's state variables)
        asr_btn.click(
            asr_then_classify_and_plot,
            inputs=[audio, embs_state, preds_state, confs_state, texts_state],
            outputs=[asr_text, asr_label, asr_conf, asr_reason, plot, embs_state, preds_state, confs_state, texts_state]
        )
demo.launch()