File size: 3,058 Bytes
bdd5464
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/usr/bin/env python3
"""
Concise inference: load HF bundle and predict on one text.
Usage:
  python inference_mmbert_hf_example.py [model_dir] [text]
Defaults:
  model_dir = .
  text      = simple Dutch example
"""

import os, sys, json, torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer


def _select_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older torch builds have no torch.backends.mps at all.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _build_head(hidden, n_out, dropout, head_state, device):
    """Rebuild one MLP head, load its trained weights, and return it in eval mode.

    The layer order (Linear -> Dropout -> ReLU -> Linear) must stay exactly as
    is: ``load_state_dict(strict=True)`` matches parameters by Sequential index
    ("0.*" for the first Linear, "3.*" for the last).
    """
    head = nn.Sequential(
        nn.Linear(hidden, hidden),
        nn.Dropout(dropout),
        nn.ReLU(),
        nn.Linear(hidden, n_out),
    )
    head.load_state_dict(head_state, strict=True)
    return head.to(device).eval()


def _topk(probs, names, k=5):
    """Return up to ``k`` (label name, probability) pairs, highest probability first."""
    idx = torch.topk(probs, k=min(k, len(probs))).indices.tolist()
    return [(names[i], float(probs[i])) for i in idx]


def main():
    """Load the model bundle, score one text, and print the top-5 labels per head.

    Command-line arguments (both optional):
        argv[1]: bundle directory (default ".").
        argv[2]: text to score (default: a built-in Dutch example message).

    Raises:
        ValueError: if a label list in label_names.json does not have exactly
            one entry per output of the corresponding head.
    """
    model_dir = sys.argv[1] if len(sys.argv) > 1 else "."
    text = sys.argv[2] if len(sys.argv) > 2 else (
        "Het is echt NIET te doen hier!!! Door dat hele filmfestival zijn er elke avond mensen aan het schreeuwen en harde muziek tot laat Ik kan gewoon niet meer slapen  Hoe is dit ooit goedgekeurd zo vlak na de feestdagen?????? Heb al beelden gemaakt als bewijs, kan ik die ergens heen sturen?? Het moet toch snellre opgelost kunnen worden dan dit, het duurt allemaal veel te lang Kunnen jullie dr ff naar kijken????"
    )

    device = _select_device()

    # Load encoder + tokenizer + head metadata.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    encoder = AutoModel.from_pretrained(model_dir).to(device).eval()
    # NOTE(security): torch.load unpickles this file, which can execute code.
    # Only load bundles from a trusted source (torch >= 2.6 defaults to
    # weights_only=True, which restricts what can be unpickled).
    state = torch.load(os.path.join(model_dir, "dual_head_state.pt"), map_location="cpu")
    # Explicit UTF-8 so non-ASCII label names decode identically on every platform.
    with open(os.path.join(model_dir, "label_names.json"), encoding="utf-8") as f:
        labels = json.load(f)

    hidden = encoder.config.hidden_size
    n_on, n_be = int(state["num_onderwerp"]), int(state["num_beleving"])
    # Dropout is inactive at inference because the heads are put in eval mode.
    drop = float(state.get("dropout", 0.1))
    max_len = int(state.get("max_length", 512))

    # Fail early if a label list does not line up with its head's outputs;
    # otherwise names would be silently mismatched or indexing would fail later.
    for head_name, n_out in (("onderwerp", n_on), ("beleving", n_be)):
        n_names = len(labels[head_name])
        if n_names != n_out:
            raise ValueError(
                f"label_names.json has {n_names} '{head_name}' labels, "
                f"but the head has {n_out} outputs"
            )

    # Rebuild heads and load weights.
    onderwerp_head = _build_head(hidden, n_on, drop, state["onderwerp_head_state"], device)
    beleving_head = _build_head(hidden, n_be, drop, state["beleving_head_state"], device)

    # Encode and predict.
    with torch.inference_mode():
        # A single sequence needs no padding; padding to max_length would only
        # add masked-out encoder compute. Truncation still caps the length.
        enc = tokenizer(text, truncation=True, max_length=max_len, return_tensors="pt")
        input_ids = enc["input_ids"].to(device)
        attn = enc["attention_mask"].to(device)
        # First-token pooling; presumably the [CLS]-style token the heads were
        # trained on -- confirm against the training code.
        pooled = encoder(input_ids=input_ids, attention_mask=attn).last_hidden_state[:, 0, :]
        # Sigmoid yields an independent probability per label; [0] selects the
        # only item in the batch.
        on_probs = torch.sigmoid(onderwerp_head(pooled))[0].cpu()
        be_probs = torch.sigmoid(beleving_head(pooled))[0].cpu()

    # Top-5 per head (probability).
    print(f"Onderwerp top-5: {[f'{n}: {p:.3f}' for n,p in _topk(on_probs, labels['onderwerp'])]}")
    print(f"Beleving top-5:  {[f'{n}: {p:.3f}' for n,p in _topk(be_probs, labels['beleving'])]}")
    print(f"Device: {device} | max_length: {max_len} | model_dir: {model_dir}")


if __name__ == "__main__":
    main()