---
language:
- nl
license: apache-2.0
library_name: transformers
pipeline_tag: text-classification
base_model: bert-base-multilingual-cased
tags:
- multi-label
- dutch
- municipal-complaints
- mbert
- bert
- pytorch
- safetensors
datasets:
- UWV/wim-synthetic-data-rd
metrics:
- f1
- precision
- recall
- accuracy
model-index:
- name: WimBERT v0
  results:
  - task:
      name: Multi-label Text Classification
      type: text-classification
    dataset:
      name: UWV/WIM Synthetic RD
      type: UWV/wim-synthetic-data-rd
      split: validation
      subset: onderwerp
    metrics:
    - name: F1
      type: f1
      value: 0.932
    - name: Precision
      type: precision
      value: 0.960
    - name: Recall
      type: recall
      value: 0.905
    - name: Accuracy
      type: accuracy
      value: 0.998
  - task:
      name: Multi-label Text Classification
      type: text-classification
    dataset:
      name: UWV/WIM Synthetic RD
      type: UWV/wim-synthetic-data-rd
      split: validation
      subset: beleving
    metrics:
    - name: F1
      type: f1
      value: 0.789
    - name: Precision
      type: precision
      value: 0.859
    - name: Recall
      type: recall
      value: 0.730
    - name: Accuracy
      type: accuracy
      value: 0.971
widget:
- text: "Goedemiddag, ik heb al drie keer gebeld over mijn uitkering en krijg geen duidelijkheid."
---

# WimBERT v0

WimBERT is a dual‑head, multi‑label classifier for Dutch municipal complaint messages.
The model uses a shared mmBERT‑base encoder with two MLP heads:
- Onderwerp (topics): 96 labels
- Beleving (experience): 26 labels

Trained with a combined objective: alpha · (1 − Soft‑F1) + (1 − alpha) · BCE.
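
A minimal sketch of this objective, assuming the temperature scales the logits before the sigmoid and a standard per-class soft (differentiable) F1; the exact formulation in the training script may differ:

```python
import torch
import torch.nn.functional as F

def combined_loss(logits, targets, alpha=0.15, temperature=2.0, eps=1e-8):
    """Sketch of alpha * (1 - soft-F1) + (1 - alpha) * BCE for multi-label training.
    Assumptions: temperature-scaled sigmoid probabilities and a per-class soft F1
    averaged over classes; the released training script may differ in detail."""
    probs = torch.sigmoid(logits * temperature)
    tp = (probs * targets).sum(dim=0)
    fp = (probs * (1.0 - targets)).sum(dim=0)
    fn = ((1.0 - probs) * targets).sum(dim=0)
    soft_f1 = (2.0 * tp) / (2.0 * tp + fp + fn + eps)   # differentiable per-class F1
    bce = F.binary_cross_entropy_with_logits(logits, targets)
    return alpha * (1.0 - soft_f1.mean()) + (1.0 - alpha) * bce
```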

## Overview
- Encoder: mmBERT‑base (multilingual)
- Heads: 2× MLP (Linear → Dropout → ReLU → Linear)
- Labels: 96 onderwerp, 26 beleving
- Task: Multi‑label classification (sigmoid per class)
- Thresholds: Disabled (fixed 0.5 used for evaluation/inference)
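
A minimal sketch of this dual‑head layout; first‑token pooling is an assumption here (consistent with the inference example below), not necessarily the exact training code:

```python
import torch.nn as nn
from transformers import AutoModel

class DualHeadClassifier(nn.Module):
    """Shared encoder with two MLP heads (Linear -> Dropout -> ReLU -> Linear)."""

    def __init__(self, encoder_name: str, n_onderwerp: int = 96, n_beleving: int = 26, dropout: float = 0.2):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        hidden = self.encoder.config.hidden_size

        def make_head(n_labels: int) -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.Dropout(dropout),
                nn.ReLU(),
                nn.Linear(hidden, n_labels),
            )

        self.onderwerp_head = make_head(n_onderwerp)
        self.beleving_head = make_head(n_beleving)

    def forward(self, input_ids, attention_mask=None):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = out.last_hidden_state[:, 0, :]  # first-token pooling (assumption)
        # Raw logits; apply a per-class sigmoid for multi-label probabilities.
        return self.onderwerp_head(pooled), self.beleving_head(pooled)
```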

## Intended Use
- Classify incoming Dutch complaint messages into topical (onderwerp) and experiential (beleving) labels.
- Useful for analytics, routing, and trend insights. Not intended for legal or benefit decisions without human review.

## Training Data
- Source: `UWV/wim-synthetic-data-rd` (train split)
- Samples: 9,351
- Labels: 96 onderwerp, 26 beleving
- Avg labels per sample: onderwerp 1.75, beleving 1.89
- Label matrix shapes: onderwerp (9351, 96), beleving (9351, 26)
- Train/Val split: 7,480 / 1,871 (80/20)
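
The 80/20 split can be recreated roughly as sketched below; the column names are placeholders for illustration and may differ from the actual dataset schema:

```python
from datasets import load_dataset

# Load the synthetic complaints data and recreate an 80/20 split with seed 42
# (matching the card). Check the dataset card for the real column names;
# "text", "onderwerp" and "beleving" are only assumed here.
ds = load_dataset("UWV/wim-synthetic-data-rd", split="train")
split = ds.train_test_split(test_size=0.2, seed=42)
train_ds, val_ds = split["train"], split["test"]

print(len(train_ds), len(val_ds))   # expected: 7480 1871
print(train_ds.column_names)        # inspect the actual schema
```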

## Training Setup
- Date: 2025‑10‑20
- Hardware: NVIDIA A100 GPU
- Epochs: 15
- Batch size: 16
- Sequence length: 1,408 tokens
- Optimizer: AdamW
- Scheduler: Linear warmup (10%) → cosine annealing, `min_lr=1e-6`
- Gradient clipping: max_norm 1.0
- Random seed: 42
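
A minimal sketch of this schedule and clipping in plain PyTorch; the training script may implement it differently (e.g. via a Transformers scheduler):

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

model = torch.nn.Linear(10, 2)                  # placeholder model for illustration
total_steps = 10_000                            # placeholder step count
warmup_steps = int(0.1 * total_steps)           # 10% linear warmup
optimizer = AdamW(model.parameters(), lr=8e-5)  # encoder peak LR from the card

scheduler = SequentialLR(
    optimizer,
    schedulers=[
        LinearLR(optimizer, start_factor=1e-3, end_factor=1.0, total_iters=warmup_steps),
        CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=1e-6),
    ],
    milestones=[warmup_steps],
)

for step in range(total_steps):
    # ... forward pass, loss, loss.backward() ...
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
```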

### Hyperparameters
- alpha (F1 weight): 0.15
- dropout: 0.20
- encoder peak LR: 8e-5
- temperature (Soft‑F1): 2.0
- learnable thresholds: false
- initial_threshold: 0.565 (unused; thresholds disabled)
- threshold LR multiplier: 5.0 (unused; thresholds disabled)

## Metrics
Final validation (500 samples):
- Onderwerp:
  - Accuracy: 99.8%
  - Precision: 0.960
  - Recall: 0.905
  - F1: 0.932
- Beleving:
  - Accuracy: 97.1%
  - Precision: 0.859
  - Recall: 0.730
  - F1: 0.789
- Combined:
  - Average accuracy: 98.4%
  - Average F1: 0.861
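
These are multi‑label metrics computed at the fixed 0.5 cutoff. A minimal sketch of how such numbers can be reproduced from per‑class probabilities; the averaging convention behind the reported values is not stated, so micro‑averaging and element‑wise accuracy below are assumptions:

```python
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

def multilabel_report(probs: np.ndarray, targets: np.ndarray, threshold: float = 0.5) -> dict:
    """Binarise probabilities at a fixed cutoff and compute micro-averaged P/R/F1.
    "Accuracy" is element-wise (per label slot) accuracy, one common multi-label
    convention; the card does not state which convention was used."""
    preds = (probs >= threshold).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(
        targets, preds, average="micro", zero_division=0
    )
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "accuracy": float((preds == targets).mean()),
    }
```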

## Saved Artifacts
- HF‑compatible files:
  - `model.safetensors` — encoder weights
  - `config.json` — encoder config
  - `tokenizer.json`, `tokenizer_config.json`, `special_tokens_map.json` — tokenizer
- `dual_head_state.pt` — classification heads + metadata (no thresholds included, since they are disabled)
- `label_names.json` — label names for both heads
- `inference_mmbert_hf_example.py` — example inference script (CLI)

## Inference
Quick start (script):
- `python inference_mmbert_hf_example.py [model_dir=.] "Uw voorbeeldzin hier"`

Minimal code (probabilities + top‑k):
```python
import os, json, torch, torch.nn as nn
from transformers import AutoModel, AutoTokenizer

model_dir = "."
tok = AutoTokenizer.from_pretrained(model_dir)
enc = AutoModel.from_pretrained(model_dir).eval()
state = torch.load(os.path.join(model_dir, "dual_head_state.pt"), map_location="cpu")
with open(os.path.join(model_dir, "label_names.json")) as f:
    labels = json.load(f)

# Rebuild the two MLP heads (Linear -> Dropout -> ReLU -> Linear) and load their weights.
hidden = enc.config.hidden_size
drop = float(state.get("dropout", 0.1))
n_on, n_be = int(state["num_onderwerp"]), int(state["num_beleving"])
on_head = nn.Sequential(nn.Linear(hidden, hidden), nn.Dropout(drop), nn.ReLU(), nn.Linear(hidden, n_on)).eval()
be_head = nn.Sequential(nn.Linear(hidden, hidden), nn.Dropout(drop), nn.ReLU(), nn.Linear(hidden, n_be)).eval()
on_head.load_state_dict(state["onderwerp_head_state"])
be_head.load_state_dict(state["beleving_head_state"])

text = "Goedemiddag, ik heb al drie keer gebeld over mijn uitkering ..."
enc_inputs = tok(text, truncation=True, padding="max_length", max_length=int(state.get("max_length", 512)), return_tensors="pt")
with torch.no_grad():
    pooled = enc(**enc_inputs).last_hidden_state[:, 0, :]  # first-token ([CLS]) pooling
    on_probs = torch.sigmoid(on_head(pooled))[0]
    be_probs = torch.sigmoid(be_head(pooled))[0]

# Top-k labels per head with their sigmoid probabilities.
topk = lambda p, names, k=5: [(names[i], float(p[i])) for i in torch.topk(p, k=min(k, len(p))).indices]
print("Onderwerp:", topk(on_probs, labels["onderwerp"]))
print("Beleving:", topk(be_probs, labels["beleving"]))
```

## Limitations & Risks
- Domain: Dutch complaint messages; performance may degrade out‑of‑domain or in other languages.
- Thresholding: No learned thresholds; the 0.5 cutoff is a simple heuristic.
- Label imbalance and multi‑label ambiguity can affect precision/recall trade‑offs.
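
If the fixed cutoff is limiting, per‑class thresholds can be tuned post hoc on held‑out data; a minimal sketch (not part of the released model), assuming validation probabilities and multi‑hot targets as NumPy arrays:

```python
import numpy as np
from sklearn.metrics import f1_score

def tune_thresholds(probs: np.ndarray, targets: np.ndarray, grid=None) -> np.ndarray:
    """Pick, per class, the cutoff that maximises F1 on validation data.
    Optional post-hoc step; the released checkpoint ships no tuned thresholds."""
    grid = np.linspace(0.05, 0.95, 19) if grid is None else grid
    thresholds = np.full(probs.shape[1], 0.5)
    for c in range(probs.shape[1]):
        scores = [f1_score(targets[:, c], (probs[:, c] >= t).astype(int), zero_division=0) for t in grid]
        thresholds[c] = grid[int(np.argmax(scores))]
    return thresholds
```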

## Reproduction
- Script: `train_mmbert_dual_soft_f1_simplified.py`
- Env: see `requirements.txt` (PyTorch, Transformers, Datasets, wandb)
- Key config: seed 42, batch size 16, epochs 13, max_length 1408, α=0.15, encoder_peak_lr=8e-5, warmup_ratio=0.1, min_lr=1e-6.

## Acknowledgements
- UWV WIM synthetic RD dataset
- Hugging Face Transformers/Datasets

## License
This model is licensed under the Apache License 2.0. See `LICENSE` for details.