---
language:
- nl
license: apache-2.0
library_name: transformers
pipeline_tag: text-classification
base_model: bert-base-multilingual-cased
tags:
- multi-label
- dutch
- municipal-complaints
- mbert
- bert
- pytorch
- safetensors
datasets:
- UWV/wim-synthetic-data-rd
metrics:
- f1
- precision
- recall
- accuracy
model-index:
- name: WimBERT v0
results:
- task:
name: Multi-label Text Classification
type: text-classification
dataset:
name: UWV/WIM Synthetic RD
type: UWV/wim-synthetic-data-rd
split: validation
subset: onderwerp
metrics:
- name: F1
type: f1
value: 0.932
- name: Precision
type: precision
value: 0.960
- name: Recall
type: recall
value: 0.905
- name: Accuracy
type: accuracy
value: 0.998
- task:
name: Multi-label Text Classification
type: text-classification
dataset:
name: UWV/WIM Synthetic RD
type: UWV/wim-synthetic-data-rd
split: validation
subset: beleving
metrics:
- name: F1
type: f1
value: 0.789
- name: Precision
type: precision
value: 0.859
- name: Recall
type: recall
value: 0.730
- name: Accuracy
type: accuracy
value: 0.971
widget:
- text: "Goedemiddag, ik heb al drie keer gebeld over mijn uitkering en krijg geen duidelijkheid."
---
# WimBERT v0
WimBERT is a dual‑head, multi‑label classifier for Dutch municipal complaint messages.
The model uses a shared mmBERT‑base encoder with two MLP heads:
- Onderwerp (topics): 96 labels
- Beleving (experience): 26 labels
Trained with a combined objective, alpha · (1 − Soft‑F1) + (1 − alpha) · BCE, sketched in code below.
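A minimal sketch of this loss, assuming the temperature sharpens the sigmoid inside the differentiable F1 term (its exact placement in the training script is not documented); `alpha` and `temperature` default to the values listed under Hyperparameters:
```python
import torch
import torch.nn.functional as F

def combined_loss(logits, targets, alpha=0.15, temperature=2.0, eps=1e-8):
    """alpha * (1 - soft-F1) + (1 - alpha) * BCE over multi-label logits.

    logits, targets: (batch, n_classes); targets are 0/1 floats.
    """
    probs = torch.sigmoid(logits * temperature)   # sharpened soft predictions
    tp = (probs * targets).sum(dim=0)             # per-class soft true positives
    fp = (probs * (1 - targets)).sum(dim=0)       # per-class soft false positives
    fn = ((1 - probs) * targets).sum(dim=0)       # per-class soft false negatives
    soft_f1 = 2 * tp / (2 * tp + fp + fn + eps)   # differentiable per-class F1
    f1_loss = 1 - soft_f1.mean()
    bce = F.binary_cross_entropy_with_logits(logits, targets)
    return alpha * f1_loss + (1 - alpha) * bce
```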
## Overview
- Encoder: mmBERT‑base (multilingual)
- Heads: 2× MLP (Linear → Dropout → ReLU → Linear); see the sketch after this list
- Labels: 96 onderwerp, 26 beleving
- Task: Multi‑label classification (sigmoid per class)
- Thresholds: Disabled (fixed 0.5 used for evaluation/inference)
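For reference, a minimal sketch of this architecture; the class name is illustrative, and the CLS-token pooling mirrors the shipped inference example rather than a published training class:
```python
import torch.nn as nn
from transformers import AutoModel

class DualHeadClassifier(nn.Module):
    """Shared encoder feeding two independent multi-label heads."""
    def __init__(self, encoder_path, n_onderwerp=96, n_beleving=26, dropout=0.2):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_path)
        hidden = self.encoder.config.hidden_size
        def head(n_out):  # Linear -> Dropout -> ReLU -> Linear
            return nn.Sequential(nn.Linear(hidden, hidden), nn.Dropout(dropout),
                                 nn.ReLU(), nn.Linear(hidden, n_out))
        self.onderwerp_head = head(n_onderwerp)
        self.beleving_head = head(n_beleving)

    def forward(self, **enc_inputs):
        pooled = self.encoder(**enc_inputs).last_hidden_state[:, 0, :]  # CLS token
        return self.onderwerp_head(pooled), self.beleving_head(pooled)
```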
## Intended Use
- Classify incoming Dutch complaint messages into topical (onderwerp) and experiential (beleving) labels.
- Useful for analytics, routing, and trend insights. Not intended for legal or benefit decisions without human review.
## Training Data
- Source: `UWV/wim-synthetic-data-rd` (train split); see the loading sketch after this list
- Samples: 9,351
- Labels: 96 onderwerp, 26 beleving
- Avg labels per sample: onderwerp 1.75, beleving 1.89
- Shapes: onderwerp (9351, 96), beleving (9351, 26)
- Train/Val split: 7,480 / 1,871 (80/20)
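A quick sketch for loading the training split with Hugging Face Datasets; the text and label column names are not documented here, so inspect the printed schema rather than assuming them:
```python
from datasets import load_dataset

ds = load_dataset("UWV/wim-synthetic-data-rd", split="train")
print(len(ds))          # expected: 9351
print(ds.column_names)  # inspect the text and label columns
```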
## Training Setup
- Date: 2025‑10‑20
- Hardware: NVIDIA A100 GPU
- Epochs: 15
- Batch size: 16
- Sequence length: 1,408 tokens
- Optimizer: AdamW
- Scheduler: Linear warmup (10%) → cosine annealing, `min_lr=1e-6`; see the sketch after this list
- Gradient clipping: max_norm 1.0
- Random seed: 42
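One way to reproduce this schedule with stock PyTorch schedulers, as a sketch; the warmup start factor and per-step stepping are assumptions, since only the warmup ratio and `min_lr` are documented:
```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

steps_per_epoch = 7480 // 16            # train samples / batch size
total_steps = 15 * steps_per_epoch      # epochs * steps per epoch
warmup_steps = int(0.1 * total_steps)   # 10% linear warmup

model = torch.nn.Linear(8, 2)           # stand-in for the real model
opt = AdamW(model.parameters(), lr=8e-5)
sched = SequentialLR(
    opt,
    schedulers=[
        LinearLR(opt, start_factor=1e-3, total_iters=warmup_steps),
        CosineAnnealingLR(opt, T_max=total_steps - warmup_steps, eta_min=1e-6),
    ],
    milestones=[warmup_steps],
)
# call sched.step() once per optimizer step
```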
### Hyperparameters
- alpha (F1 weight): 0.15
- dropout: 0.20
- encoder peak LR: 8e-5
- temperature (Soft-F1): 2.0
- learnable thresholds: false
- initial_threshold: 0.565 (unused; thresholds disabled)
- threshold LR mult: 5.0 (unused; thresholds disabled)
## Metrics
Final validation (500 samples):

| Head | Accuracy | Precision | Recall | F1 |
|------|----------|-----------|--------|-----|
| Onderwerp | 99.8% | 0.960 | 0.905 | 0.932 |
| Beleving | 97.1% | 0.859 | 0.730 | 0.789 |
| Average | 98.4% | | | 0.861 |
## Saved Artifacts
- HF‑compatible files:
- `model.safetensors` — encoder weights
- `config.json` — encoder config
- `tokenizer.json`, `tokenizer_config.json`, `special_tokens_map.json` — tokenizer
- `dual_head_state.pt` — classification heads + metadata (no thresholds included when disabled)
- `label_names.json` — label names for both heads
- `inference_mmbert_hf_example.py` — example inference script (CLI)
## Inference
Quick start (script):
- `python inference_mmbert_hf_example.py [model_dir] "Uw voorbeeldzin hier"` (`model_dir` defaults to `.`)
Minimal code (probabilities + top‑k):
```python
import json, os
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

model_dir = "."

# Load the encoder, tokenizer, head weights, and label names.
tok = AutoTokenizer.from_pretrained(model_dir)
enc = AutoModel.from_pretrained(model_dir).eval()
state = torch.load(os.path.join(model_dir, "dual_head_state.pt"), map_location="cpu")
with open(os.path.join(model_dir, "label_names.json")) as f:
    labels = json.load(f)

# Rebuild the two MLP heads (Linear -> Dropout -> ReLU -> Linear) and load their weights.
hidden = enc.config.hidden_size
drop = float(state.get("dropout", 0.1))
n_on, n_be = int(state["num_onderwerp"]), int(state["num_beleving"])

def make_head(n_out):
    return nn.Sequential(nn.Linear(hidden, hidden), nn.Dropout(drop),
                         nn.ReLU(), nn.Linear(hidden, n_out))

on_head, be_head = make_head(n_on).eval(), make_head(n_be).eval()
on_head.load_state_dict(state["onderwerp_head_state"])
be_head.load_state_dict(state["beleving_head_state"])

text = "Goedemiddag, ik heb al drie keer gebeld over mijn uitkering ..."
# Training used max_length 1408; fall back to 512 if the state dict lacks it.
enc_inputs = tok(text, truncation=True, max_length=int(state.get("max_length", 512)), return_tensors="pt")

with torch.no_grad():
    pooled = enc(**enc_inputs).last_hidden_state[:, 0, :]  # CLS-token pooling
    on_probs = torch.sigmoid(on_head(pooled))[0]
    be_probs = torch.sigmoid(be_head(pooled))[0]

def topk(p, names, k=5):
    return [(names[int(i)], float(p[i])) for i in torch.topk(p, k=min(k, len(p))).indices]

print("Onderwerp:", topk(on_probs, labels["onderwerp"]))
print("Beleving:", topk(be_probs, labels["beleving"]))
# Label sets at the fixed 0.5 threshold used for evaluation:
print("Onderwerp >= 0.5:", [labels["onderwerp"][int(i)] for i in (on_probs > 0.5).nonzero(as_tuple=True)[0]])
print("Beleving >= 0.5:", [labels["beleving"][int(i)] for i in (be_probs > 0.5).nonzero(as_tuple=True)[0]])
```
## Limitations & Risks
- Domain: Dutch complaint messages; performance may degrade out‑of‑domain or in other languages.
- Thresholding: No learned thresholds; the 0.5 cutoff is a simple heuristic (see the tuning sketch after this list).
- Label imbalance and multi‑label ambiguity can affect precision/recall trade‑offs.
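If the fixed cutoff proves limiting, per-class thresholds can be tuned on held-out predictions. A hypothetical helper (not part of the released model) sketching a simple grid search:
```python
import numpy as np

def tune_thresholds(probs, targets, grid=np.linspace(0.05, 0.95, 19)):
    """Pick the per-class cutoff that maximizes F1 on validation data.

    probs, targets: (n_samples, n_classes) arrays; targets are 0/1.
    """
    thresholds = np.full(probs.shape[1], 0.5)
    for c in range(probs.shape[1]):
        best_f1 = -1.0
        for t in grid:
            pred = probs[:, c] >= t
            tp = np.sum(pred & (targets[:, c] == 1))
            fp = np.sum(pred & (targets[:, c] == 0))
            fn = np.sum(~pred & (targets[:, c] == 1))
            denom = 2 * tp + fp + fn
            f1 = 2 * tp / denom if denom > 0 else 0.0
            if f1 > best_f1:
                best_f1, thresholds[c] = f1, t
    return thresholds
```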
## Reproduction
- Script: `train_mmbert_dual_soft_f1_simplified.py`
- Env: see `requirements.txt` (PyTorch, Transformers, Datasets, wandb)
- Key config: seed 42, batch size 16, epochs 13, max_length 1408, α=0.15, encoder_peak_lr=8e-5, warmup_ratio=0.1, min_lr=1e-6.
## Acknowledgements
- UWV WIM synthetic RD dataset
- Hugging Face Transformers/Datasets
## License
This model is licensed under the Apache License 2.0. See `LICENSE` for details.