# WimBERT v0

WimBERT is a dual-head, multi-label classifier for Dutch municipal complaint messages. The model uses a shared mmBERT-base encoder with two MLP heads:

- Onderwerp (topics): 96 labels
- Beleving (experience): 26 labels

It is trained with a combined objective, alpha · (1 − Soft-F1) + (1 − alpha) · BCE, where alpha = 0.15 weights the Soft-F1 term.

## Overview

- Encoder: mmBERT-base (multilingual)
- Heads: 2× MLP (Linear → Dropout → ReLU → Linear)
- Labels: 96 onderwerp, 26 beleving
- Task: multi-label classification (independent sigmoid per class)
- Thresholds: disabled (a fixed 0.5 cutoff is used for evaluation and inference)

## Intended Use

- Classify incoming Dutch complaint messages into topical (onderwerp) and experiential (beleving) labels.
- Useful for analytics, routing, and trend insights. Not intended for legal or benefit decisions without human review.

## Training Data

- Source: `UWV/wim-synthetic-data-rd` (train split)
- Samples: 9,351
- Labels: 96 onderwerp, 26 beleving
- Average labels per sample: onderwerp 1.75, beleving 1.89
- Label matrix shapes: onderwerp (9351, 96), beleving (9351, 26)
- Train/validation split: 7,480 / 1,871 (80/20)

## Training Setup

- Date: 2025-10-20
- Hardware: NVIDIA A100 GPU
- Epochs: 15
- Batch size: 16
- Sequence length: 1,408 tokens
- Optimizer: AdamW
- Scheduler: linear warmup (10%) → cosine annealing, `min_lr=1e-6`
- Gradient clipping: `max_norm=1.0`
- Random seed: 42

### Hyperparameters

- alpha (Soft-F1 weight): 0.15
- dropout: 0.20
- encoder peak LR: 8e-5
- temperature (Soft-F1): 2.0
- learnable thresholds: false
- initial_threshold: 0.565 (unused because thresholds are disabled)
- threshold LR multiplier: 5.0 (unused because thresholds are disabled)

## Metrics

Final validation (500 samples):

- Onderwerp:
  - Accuracy: 99.8%
  - Precision: 0.960
  - Recall: 0.905
  - F1: 0.932
- Beleving:
  - Accuracy: 97.1%
  - Precision: 0.859
  - Recall: 0.730
  - F1: 0.789
- Combined:
  - Average accuracy: 98.4%
  - Average F1: 0.861

## Saved Artifacts

- HF-compatible files:
  - `model.safetensors`: encoder weights
  - `config.json`: encoder config
  - `tokenizer.json`, `tokenizer_config.json`, `special_tokens_map.json`: tokenizer
- `dual_head_state.pt`: classification heads plus metadata (no thresholds are stored when thresholds are disabled)
- `label_names.json`: label names for both heads
- `inference_mmbert_hf_example.py`: example inference script (CLI)

## Inference

Quick start (script):

- `python inference_mmbert_hf_example.py [model_dir] "Uw voorbeeldzin hier"`, where the optional `model_dir` defaults to `.` and the last argument is the Dutch input text ("Uw voorbeeldzin hier" means "your example sentence here").

Minimal code (probabilities and top-k labels):

```python
import os
import json

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

model_dir = "."
tok = AutoTokenizer.from_pretrained(model_dir)
enc = AutoModel.from_pretrained(model_dir).eval()

# Classification heads and metadata stored next to the encoder weights.
state = torch.load(os.path.join(model_dir, "dual_head_state.pt"), map_location="cpu")
with open(os.path.join(model_dir, "label_names.json")) as f:
    labels = json.load(f)

hidden = enc.config.hidden_size
drop = float(state.get("dropout", 0.1))
n_on, n_be = int(state["num_onderwerp"]), int(state["num_beleving"])

# Linear -> Dropout -> ReLU -> Linear, matching the saved head layout.
on_head = nn.Sequential(nn.Linear(hidden, hidden), nn.Dropout(drop), nn.ReLU(),
                        nn.Linear(hidden, n_on)).eval()
be_head = nn.Sequential(nn.Linear(hidden, hidden), nn.Dropout(drop), nn.ReLU(),
                        nn.Linear(hidden, n_be)).eval()
on_head.load_state_dict(state["onderwerp_head_state"])
be_head.load_state_dict(state["beleving_head_state"])

# Dutch input: "Good afternoon, I have already called three times about my benefit ..."
text = "Goedemiddag, ik heb al drie keer gebeld over mijn uitkering ..."
enc_inputs = tok(text, truncation=True, padding="max_length",
                 max_length=int(state.get("max_length", 512)), return_tensors="pt")

with torch.no_grad():
    # First-token embedding is used as the pooled representation.
    pooled = enc(**enc_inputs).last_hidden_state[:, 0, :]
    on_probs = torch.sigmoid(on_head(pooled))[0]
    be_probs = torch.sigmoid(be_head(pooled))[0]

def topk(p, names, k=5):
    # Top-k (label, probability) pairs, highest probability first.
    idx = torch.topk(p, k=min(k, len(p))).indices.tolist()
    return [(names[i], float(p[i])) for i in idx]

print("Onderwerp:", topk(on_probs, labels["onderwerp"]))
print("Beleving:", topk(be_probs, labels["beleving"]))
```

## Limitations & Risks

- Domain: Dutch complaint messages; performance may degrade out-of-domain or in other languages.
- Thresholding: no learned thresholds; the fixed 0.5 cutoff is a simple heuristic and is not tuned per label.
- Label imbalance and multi-label ambiguity can affect the precision/recall trade-off.

## Reproduction

- Script: `train_mmbert_dual_soft_f1_simplified.py`
- Environment: see `requirements.txt` (PyTorch, Transformers, Datasets, wandb)
- Key config: seed 42, batch size 16, epochs 13, max_length 1408, α=0.15, encoder_peak_lr=8e-5, warmup_ratio=0.1, min_lr=1e-6

## Acknowledgements

- UWV WIM synthetic RD dataset
- Hugging Face Transformers/Datasets
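
## Appendix: Training Objective Sketch

The Training Setup and Hyperparameters sections describe the objective alpha · (1 − Soft-F1) + (1 − alpha) · BCE with a Soft-F1 temperature of 2.0, a linear-warmup/cosine learning-rate schedule, and a fixed 0.5 cutoff at inference. The following is a minimal PyTorch sketch of those pieces; it is not taken from `train_mmbert_dual_soft_f1_simplified.py`, and every function name in it is hypothetical. Several details are assumptions: the temperature divides the logits before the sigmoid, Soft-F1 is a per-class (macro) average over the batch, the two head losses are summed with equal weight, and the scheduler is stepped once per optimizer step.

```python
import math

import torch
import torch.nn.functional as F


def soft_f1(logits, targets, temperature=2.0, eps=1e-8):
    """Differentiable macro Soft-F1 for one multi-label head.

    Assumptions: the temperature divides the logits before the sigmoid, and
    F1 is computed per class over the batch and then averaged (macro).
    """
    probs = torch.sigmoid(logits / temperature)
    tp = (probs * targets).sum(dim=0)        # soft true positives per class
    fp = (probs * (1 - targets)).sum(dim=0)  # soft false positives per class
    fn = ((1 - probs) * targets).sum(dim=0)  # soft false negatives per class
    f1 = 2 * tp / (2 * tp + fp + fn + eps)
    return f1.mean()


def combined_loss(logits, targets, alpha=0.15, temperature=2.0):
    """alpha * (1 - Soft-F1) + (1 - alpha) * BCE for a single head."""
    bce = F.binary_cross_entropy_with_logits(logits, targets)
    return alpha * (1 - soft_f1(logits, targets, temperature)) + (1 - alpha) * bce


def dual_head_loss(on_logits, on_targets, be_logits, be_targets,
                   alpha=0.15, temperature=2.0):
    # Assumption: onderwerp and beleving losses are summed with equal weight.
    return (combined_loss(on_logits, on_targets, alpha, temperature)
            + combined_loss(be_logits, be_targets, alpha, temperature))


def warmup_cosine(optimizer, total_steps, warmup_ratio=0.1,
                  peak_lr=8e-5, min_lr=1e-6):
    """Linear warmup to peak_lr, then cosine decay to min_lr.

    LambdaLR multiplies the optimizer's base lr, so the optimizer must be
    created with lr=peak_lr.
    """
    warmup_steps = max(1, int(warmup_ratio * total_steps))

    def lr_factor(step):
        if step < warmup_steps:
            return step / warmup_steps  # linear ramp from 0 to 1
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        cosine = 0.5 * (1 + math.cos(math.pi * progress))  # decays from 1 to 0
        return (min_lr + (peak_lr - min_lr) * cosine) / peak_lr

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_factor)


def predict(logits, threshold=0.5):
    # Fixed cutoff on plain sigmoid probabilities, as in the inference snippet.
    return (torch.sigmoid(logits) >= threshold).int()


# Example with random data shaped like one batch of 16 messages; label
# densities roughly follow the averages above (1.75/96 and 1.89/26).
torch.manual_seed(42)
on_logits = torch.randn(16, 96, requires_grad=True)
be_logits = torch.randn(16, 26, requires_grad=True)
on_targets = (torch.rand(16, 96) < 0.02).float()
be_targets = (torch.rand(16, 26) < 0.07).float()
loss = dual_head_loss(on_logits, on_targets, be_logits, be_targets)
loss.backward()

# Scheduler on a dummy parameter, over a hypothetical 100-step run.
params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=8e-5)
sched = warmup_cosine(opt, total_steps=100)
```

With alpha = 0.15 the BCE term carries 85% of the weight, so under these assumptions Soft-F1 acts as a secondary signal that nudges training toward the F1 scores reported in Metrics.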