|
|
|
|
|
import os, re, io, numpy as np, pandas as pd, torch, gradio as gr |
|
|
import plotly.graph_objects as go |
|
|
from sklearn.decomposition import PCA |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
from faster_whisper import WhisperModel |
|
|
import soundfile as sf |
|
|
import tempfile, os |
|
|
|
|
|
|
|
|
|
|
|
# Sentiment classifier; the checkpoint can be overridden with the MODEL_NAME env var.
MODEL_NAME = os.getenv("MODEL_NAME", "cardiffnlp/twitter-xlm-roberta-base-sentiment")

# Short labels used by the UI, keyed by the model's logit index.
id2label = {0: "NEG", 1: "NEU", 2: "POS"}
label2id = {short: idx for idx, short in id2label.items()}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

# Long-form label names written into the model config (same index order as id2label).
_MODEL_ID2LABEL = {0: "NEGATIVE", 1: "NEUTRAL", 2: "POSITIVE"}
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=3,
    id2label=_MODEL_ID2LABEL,
    label2id={name: idx for idx, name in _MODEL_ID2LABEL.items()},
)
# Inference only: move to the selected device and switch to eval mode.
model = model.to(device).eval()
|
|
|
|
|
|
|
|
# Speech-to-text model (faster-whisper); size/name overridable via ASR_MODEL_NAME.
ASR_MODEL_NAME = os.getenv("ASR_MODEL_NAME", "small")

# float16 weights on GPU, int8 quantisation on CPU.
_asr_on_gpu = torch.cuda.is_available()
asr = WhisperModel(
    ASR_MODEL_NAME,
    device="cuda" if _asr_on_gpu else "cpu",
    compute_type="float16" if _asr_on_gpu else "int8",
)
|
|
|
|
|
|
|
|
def clean_text(x: str) -> str:
    """Collapse every whitespace run to a single space and trim both ends.

    Non-string input (e.g. ``None``) is treated as empty and yields ``""``.
    """
    if isinstance(x, str):
        collapsed = re.sub(r"\s+", " ", x)
        return collapsed.strip()
    return ""
|
|
|
|
|
|
|
|
def transcribe_numpy_audio(audio_tuple, lang_hint="ja"):
    """Transcribe a Gradio ``type="numpy"`` audio value with the global Whisper model.

    Args:
        audio_tuple: ``(sample_rate, samples)`` from ``gr.Audio(type="numpy")``, or
            ``None`` when nothing was recorded. ``samples`` is 1-D (mono) or 2-D
            with channels on the second axis.
        lang_hint: Language code passed to ``asr.transcribe`` as ``language``.

    Returns:
        The concatenated segment texts, stripped; ``""`` for ``None`` or empty audio.
    """
    if audio_tuple is None:
        return ""
    sr, data = audio_tuple
    data = np.asarray(data)
    if data.size == 0:
        return ""
    if data.ndim > 1:
        # Down-mix to mono but keep the original dtype. np.mean promotes integer
        # PCM to float64 with values in the integer range, and soundfile treats
        # float samples as [-1, 1], so writing the raw mean would clip the audio.
        data = data.mean(axis=1).astype(data.dtype, copy=False)
    # Only reserve a unique name here; the WAV is written after the handle is
    # closed so that re-opening the file by name also works on Windows.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        sf.write(tmp_path, data, sr)
        segments, _info = asr.transcribe(tmp_path, language=lang_hint, vad_filter=True)
        # segments is consumed here, while the temp file still exists.
        text = "".join(seg.text for seg in segments).strip()
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; a leftover temp file is not fatal
    return text
|
|
|
|
|
def mean_pool(last_hidden_state, attention_mask):
    """Average token vectors over the sequence axis, ignoring masked-out positions.

    Expects hidden states shaped like (batch, seq, hidden) and a matching
    (batch, seq) mask; returns a CPU numpy array with one row per batch item.
    The token count is clamped to 1e-9, so an all-zero mask yields zeros, not NaN.
    """
    weights = attention_mask.unsqueeze(-1).expand_as(last_hidden_state).float()
    token_totals = (last_hidden_state * weights).sum(dim=1)
    token_counts = weights.sum(dim=1).clamp(min=1e-9)
    pooled = token_totals / token_counts
    return pooled.cpu().numpy()
|
|
|
|
|
# Plotly marker symbol and colour for each short sentiment label.
SYM = dict(NEG="triangle-left", NEU="circle-open", POS="triangle-up")
CLR = dict(NEG="#d62728", NEU="#ff7f0e", POS="#2ca02c")
|
|
|
|
|
def classify_once(text: str):
    """Classify one text for the single-shot tab.

    Returns ``(label, score, reason)``:
    - ``label`` is a short label from ``id2label`` ("NEG"/"NEU"/"POS").
    - ``score`` is its softmax probability, rounded to 3 decimals.
    - ``reason`` is a fixed Japanese explanation for that label.

    Blank input returns a placeholder triple without running the model.
    """
    if not (text and text.strip()):
        return "—", 0.0, "テキストを入力してください。"

    explanations = {
        "POS": "肯定的な語彙・前向きなトーンが優勢。",
        "NEU": "強い肯定/否定が少なく中立的。",
        "NEG": "否定・不満・懸念の表現が目立つ。",
    }
    with torch.no_grad():
        encoded = tok(text, return_tensors="pt", truncation=True, max_length=256)
        batch = {name: tensor.to(device) for name, tensor in encoded.items()}
        probabilities = torch.softmax(model(**batch).logits, dim=-1).cpu().numpy()[0]
    best = int(np.argmax(probabilities))
    label = id2label[best]
    return label, round(float(probabilities[best]), 3), explanations[label]
|
|
|
|
|
def embed_and_predict(text: str, max_length: int = 256):
    """Embed one sentence and predict its sentiment with a single forward pass.

    Args:
        text: Raw input; normalised with ``clean_text`` first.
        max_length: Token truncation limit. Defaults to 256, the same limit
            ``classify_once`` uses, so a map point's prediction agrees with the
            label shown for the same text.

    Returns:
        ``(emb, pred, conf)``:
        - ``emb`` is the mean-pooled last-layer hidden state, as a numpy vector.
        - ``pred`` is the short label from ``id2label``.
        - ``conf`` is that label's softmax probability, as a float.

    Raises:
        ValueError: if ``text`` is empty after cleaning.
    """
    text = clean_text(text)
    if not text:
        raise ValueError("テキストが空です。")
    with torch.no_grad():
        inputs = tok(text, return_tensors="pt", truncation=True, max_length=max_length)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # A single pass provides both the logits and the hidden states.
        out = model(**inputs, output_hidden_states=True)
        probs = torch.softmax(out.logits, dim=-1).cpu().numpy()[0]
        emb = mean_pool(out.hidden_states[-1], inputs["attention_mask"])[0]
    pred_id = int(np.argmax(probs))
    return emb, id2label[pred_id], float(probs[pred_id])
|
|
|
|
|
def build_fig(points_xy, points_pred, points_conf, points_text, color_mode="PRED"):
    """Draw the history points as a Plotly scatter map (no background data).

    Args:
        points_xy: One ``[x, y]`` pair per point.
        points_pred: Short label per point; one trace is drawn per "NEG"/"NEU"/"POS".
        points_conf: Confidence per point, shown next to the marker.
        points_text: Source text per point, shown on hover (truncated to 180 chars).
        color_mode: Not used by this function; accepted for call compatibility.

    Returns:
        A ``go.Figure``; a titled empty map when there are no points.
    """
    fig = go.Figure()

    if len(points_xy) == 0:
        fig.update_layout(
            title="Interactive Sentiment Map (空のマップ — 入力すると点が追加されます)",
            xaxis_title="PCA 1",
            yaxis_title="PCA 2",
        )
        return fig

    coords = np.asarray(points_xy)
    labels = np.asarray(points_pred)
    confidences = np.asarray(points_conf)
    sentences = np.asarray(points_text)

    def _preview(s):
        # Keep hover boxes short; mark truncation with an ellipsis.
        return s[:180] + ("…" if len(s) > 180 else "")

    for lab in ("NEG", "NEU", "POS"):
        selected = labels == lab
        if not selected.any():
            continue
        fig.add_trace(
            go.Scatter(
                x=coords[selected, 0],
                y=coords[selected, 1],
                mode="markers+text",
                name=f"INPUT ({lab})",
                marker=dict(size=11, symbol=SYM.get(lab, "x"), color=CLR.get(lab, "#000")),
                text=[f"{lab} ({c:.2f})" for c in confidences[selected]],
                textposition="top center",
                customdata=[_preview(s) for s in sentences[selected]],
                hovertemplate="<b>Input</b><br>%{text}<br>Text: %{customdata}<extra></extra>",
            )
        )

    fig.update_layout(
        title="Interactive Sentiment Map (入力履歴のみで更新)",
        xaxis_title="PCA 1",
        yaxis_title="PCA 2",
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
    )
    return fig
|
|
|
|
|
|
|
|
def _project_2d(embs):
    """Return 2-D map coordinates for the stored embeddings.

    With 3 or more points, PCA is fit on the current history and projects
    every point to 2 components. With fewer points, they are placed on the
    x axis by insertion index.
    """
    if len(embs) >= 3:
        X = np.stack(embs, axis=0)
        return PCA(n_components=2, random_state=42).fit_transform(X).tolist()
    return [[i, 0.0] for i in range(len(embs))]


def _history_fig(embs, preds, confs, texts):
    """Build the map figure for the current history (an empty map when there is none)."""
    embs = list(embs or [])
    if not embs:
        return build_fig([], [], [], [])
    return build_fig(_project_2d(embs), list(preds or []), list(confs or []), list(texts or []))


def _append_history(embs, preds, confs, texts, emb, pred, conf, text):
    """Return new history lists with one point appended; the inputs are not mutated."""
    return (
        list(embs or []) + [emb],
        list(preds or []) + [pred],
        list(confs or []) + [conf],
        list(texts or []) + [text],
    )


def add_and_plot(text, embs, preds, confs, texts):
    """Map-tab callback: embed/predict ``text``, append it to the history and redraw.

    Returns ``(figure, embs, preds, confs, texts, status_markdown)``. On empty
    input or a model error, the history is returned unchanged and the current
    map is redrawn, rather than replaced by an empty figure.
    """
    text = clean_text(text)
    if not text:
        return _history_fig(embs, preds, confs, texts), embs, preds, confs, texts, "テキストを入力してください。"
    try:
        emb, pred, conf = embed_and_predict(text)
    except Exception as e:  # UI boundary: report the failure instead of crashing the event
        return _history_fig(embs, preds, confs, texts), embs, preds, confs, texts, f"ERROR: {type(e).__name__}: {e}"
    embs, preds, confs, texts = _append_history(embs, preds, confs, texts, emb, pred, conf, text)
    fig = _history_fig(embs, preds, confs, texts)
    return fig, embs, preds, confs, texts, f"追加: {pred} ({conf:.2f})"


def clear_all():
    """Reset the map and all history states."""
    return build_fig([], [], [], []), [], [], [], [], "履歴をクリアしました。"


def download_csv(embs, preds, confs, texts):
    """Write the history to ``history.csv`` in a fresh temp directory and return its path.

    Returns ``None`` (no file) when the history is empty. ``embs`` is accepted
    for symmetry with the other callbacks but is not exported.
    """
    if not texts:
        return None
    df = pd.DataFrame({
        "text": texts,
        "pred": preds,
        "confidence": confs,
    })
    # gr.File takes a file path (Component.update was removed in Gradio 4);
    # a dedicated temp dir lets the download keep the name history.csv.
    path = os.path.join(tempfile.mkdtemp(), "history.csv")
    # utf-8-sig writes a BOM so spreadsheet apps such as Excel detect UTF-8 for the Japanese text.
    df.to_csv(path, index=False, encoding="utf-8-sig")
    return path


def asr_then_classify_and_plot(audio_tuple, embs, preds, confs, texts):
    """Voice-tab callback: transcribe, classify, then add the text to the map.

    Returns ``(text, label, score, reason, figure, embs, preds, confs, texts)``.
    If adding to the map fails, the classification is still returned, the
    error is appended to ``reason``, and the history is left unchanged.
    """
    text = transcribe_numpy_audio(audio_tuple, lang_hint="ja")
    if not text:
        return "", "—", 0.0, "音声が空でした。", _history_fig(embs, preds, confs, texts), embs, preds, confs, texts

    label, score, reason = classify_once(text)

    try:
        emb, pred, conf = embed_and_predict(text)
        new_history = _append_history(embs, preds, confs, texts, emb, pred, conf, text)
        fig = _history_fig(*new_history)
    except Exception as e:  # UI boundary: keep the classification result, report the map error
        reason = f"{reason}\n(マップ追加エラー: {type(e).__name__}: {e})"
        fig = _history_fig(embs, preds, confs, texts)
    else:
        embs, preds, confs, texts = new_history

    return text, label, score, reason, fig, embs, preds, confs, texts


with gr.Blocks() as demo:
    gr.Markdown("## ビジネス日本語・感情分析(ライブPCA)\n- 起動直後から空のマップが表示され、**入力するたびに点がどんどん追加**されます。\n- 背景データは不要(オプションで後日追加可能)。")

    with gr.Tab("判定(単発)"):
        inp = gr.Textbox(lines=3, label="テキストを入力", placeholder="例)迅速なご対応に感謝します。")
        btn = gr.Button("判定する")
        out1 = gr.Label(label="判定")
        out2 = gr.Number(label="確信度(0-1)")
        out3 = gr.Textbox(label="理由")
        btn.click(fn=classify_once, inputs=inp, outputs=[out1, out2, out3])

    with gr.Tab("マップ(履歴が貯まる)"):
        gr.Markdown("### 入力のたびに埋め込み→PCA再計算→マップ更新\n※ 点が増えるほど座標が安定します(**3件以上**で2次元化)。")
        text_live = gr.Textbox(lines=2, label="テキスト")
        add_btn = gr.Button("マップに追加")
        clr_btn = gr.Button("履歴クリア")
        dl_btn = gr.Button("履歴CSVをダウンロード")
        # Start with an empty map so it is visible right after launch, as the intro text promises.
        plot = gr.Plot(value=build_fig([], [], [], []), label="Sentiment Map")
        # Single status line shared by the add and clear buttons.
        status = gr.Markdown()
        csv_file = gr.File(label="history.csv", visible=True)

        # Per-session history, kept as parallel lists (one entry per added text).
        embs_state = gr.State([])
        preds_state = gr.State([])
        confs_state = gr.State([])
        texts_state = gr.State([])

        add_btn.click(
            add_and_plot,
            inputs=[text_live, embs_state, preds_state, confs_state, texts_state],
            outputs=[plot, embs_state, preds_state, confs_state, texts_state, status],
        )
        clr_btn.click(
            clear_all,
            inputs=None,
            outputs=[plot, embs_state, preds_state, confs_state, texts_state, status],
        )
        dl_btn.click(
            download_csv,
            inputs=[embs_state, preds_state, confs_state, texts_state],
            outputs=[csv_file],
        )

    with gr.Tab("音声入力(マイク→文字起こし→判定&マップ追加)"):
        gr.Markdown("マイク録音 → 自動で文字起こし → 判定 → マップにも追加します。")
        audio = gr.Audio(sources=["microphone"], type="numpy", label="マイク録音(日本語)")
        asr_btn = gr.Button("文字起こし→判定&マップに追加")
        asr_text = gr.Textbox(label="文字起こし結果", interactive=False)
        asr_label = gr.Label(label="判定")
        asr_conf = gr.Number(label="確信度(0-1)")
        asr_reason = gr.Textbox(label="理由")

        # Writes into the same plot and history states as the map tab.
        asr_btn.click(
            asr_then_classify_and_plot,
            inputs=[audio, embs_state, preds_state, confs_state, texts_state],
            outputs=[asr_text, asr_label, asr_conf, asr_reason, plot, embs_state, preds_state, confs_state, texts_state],
        )

demo.launch()
|
|
|