---
language:
- nl
license: apache-2.0
library_name: transformers
pipeline_tag: text-classification
base_model: jhu-clsp/mmBERT-base
tags:
- multi-label
- dutch
- municipal-complaints
- mbert
- bert
- pytorch
- safetensors
datasets:
- UWV/wim-synthetic-data-rd
metrics:
- f1
- precision
- recall
- accuracy
model-index:
- name: WimBERT v0
  results:
  - task:
      name: Multi-label Text Classification
      type: text-classification
    dataset:
      name: UWV/WIM Synthetic RD
      type: UWV/wim-synthetic-data-rd
      split: validation
      subset: onderwerp
    metrics:
    - name: F1
      type: f1
      value: 0.932
    - name: Precision
      type: precision
      value: 0.960
    - name: Recall
      type: recall
      value: 0.905
    - name: Accuracy
      type: accuracy
      value: 0.998
  - task:
      name: Multi-label Text Classification
      type: text-classification
    dataset:
      name: UWV/WIM Synthetic RD
      type: UWV/wim-synthetic-data-rd
      split: validation
      subset: beleving
    metrics:
    - name: F1
      type: f1
      value: 0.789
    - name: Precision
      type: precision
      value: 0.859
    - name: Recall
      type: recall
      value: 0.730
    - name: Accuracy
      type: accuracy
      value: 0.971
widget:
- text: "Goedemiddag, ik heb al drie keer gebeld over mijn uitkering en krijg geen duidelijkheid."
---
# WimBERT v0
WimBERT is a dual‑head, multi‑label classifier for Dutch municipal complaint messages.
The model uses a shared mmBERT‑base encoder with two MLP heads:
- Onderwerp (topics): 96 labels
- Beleving (experience): 26 labels
It is trained with a combined objective, alpha · (1 − Soft‑F1) + (1 − alpha) · BCE, where BCE is binary cross‑entropy and alpha (0.15 in this run) weights the Soft‑F1 term.
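The objective can be sketched in plain Python for one example and one head. This is a minimal illustration, not code from the training script: it assumes Soft-F1 is computed over all labels of the head at once (micro-style) and that the Soft-F1 temperature divides the logits before the sigmoid (`sigmoid(z / T)`). Both choices are assumptions, and the helper names are hypothetical.

```python
import math


def stable_sigmoid(x: float) -> float:
    """Sigmoid that avoids overflow for large negative inputs."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def soft_f1(logits, targets, temperature=2.0, eps=1e-8):
    """Differentiable F1 that uses probabilities instead of hard 0/1 predictions.

    ASSUMPTION: temperature is applied as sigmoid(z / temperature).
    """
    probs = [stable_sigmoid(z / temperature) for z in logits]
    tp = sum(p * y for p, y in zip(probs, targets))
    fp = sum(p * (1 - y) for p, y in zip(probs, targets))
    fn = sum((1 - p) * y for p, y in zip(probs, targets))
    return 2 * tp / (2 * tp + fp + fn + eps)


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy, computed stably from raw logits."""
    losses = [max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
              for z, y in zip(logits, targets)]
    return sum(losses) / len(losses)


def combined_loss(logits, targets, alpha=0.15, temperature=2.0):
    """alpha * (1 - Soft-F1) + (1 - alpha) * BCE."""
    return (alpha * (1 - soft_f1(logits, targets, temperature))
            + (1 - alpha) * bce_with_logits(logits, targets))


# Toy example: 4 labels, two of them active
logits = [2.5, -1.0, 0.3, -3.0]
targets = [1, 0, 1, 0]
print(round(combined_loss(logits, targets), 4))
```

With alpha = 0.15 the per-label BCE term dominates the gradient, while the Soft-F1 term adds a set-level signal that rewards overlap between predicted and true labels.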
## Overview
- Encoder: mmBERT‑base (multilingual)
- Heads: 2× MLP (Linear → Dropout → ReLU → Linear)
- Labels: 96 onderwerp, 26 beleving
- Task: Multi‑label classification (sigmoid per class)
- Thresholds: not learned; a fixed 0.5 cutoff is used for evaluation and inference
## Intended Use
- Classify incoming Dutch complaint messages into topical (onderwerp) and experiential (beleving) labels.
- Useful for analytics, routing, and trend insights. Not intended for legal or benefit decisions without human review.
## Training Data
- Source: `UWV/wim-synthetic-data-rd` (train split)
- Samples: 9,351
- Labels: 96 onderwerp, 26 beleving
- Avg labels per sample: onderwerp 1.75, beleving 1.89
- Shapes: onderwerp (9351, 96), beleving (9351, 26)
- Train/Val split: 7,480 / 1,871 (80/20)
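The split sizes and the label statistics above follow from simple arithmetic: an 80/20 split of 9,351 samples gives 7,480 / 1,871 when the training share is rounded down, and "average labels per sample" is the mean row sum of the multi-hot target matrix. The sketch below uses a seeded shuffle with `random.Random(42)` as an assumption; it illustrates the computation and is not guaranteed to reproduce the exact split made by the training script.

```python
import random


def split_indices(n_samples: int, train_frac: float = 0.8, seed: int = 42):
    """Shuffle indices with a fixed seed and cut them into train/val parts.

    ASSUMPTION: the train size is floor(n_samples * train_frac).
    """
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_train = int(n_samples * train_frac)
    return indices[:n_train], indices[n_train:]


def label_cardinality(multi_hot):
    """Average number of active labels per row of a multi-hot matrix."""
    return sum(sum(row) for row in multi_hot) / len(multi_hot)


train_idx, val_idx = split_indices(9351)
print(len(train_idx), len(val_idx))  # 7480 1871

toy = [[1, 0, 1], [0, 1, 0], [1, 1, 1]]
print(label_cardinality(toy))  # 2.0
```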
## Training Setup
- Date: 2025‑10‑20
- Hardware: NVIDIA A100 GPU
- Epochs: 15
- Batch size: 16
- Sequence length: 1,408 tokens
- Optimizer: AdamW
- Scheduler: Linear warmup (10%) → cosine annealing, `min_lr=1e-6`
- Gradient clipping: max_norm 1.0
- Random seed: 42
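The scheduler can be written as a function of the optimizer step. A minimal sketch, assuming the warmup ramps linearly from 0 to the peak LR and that the cosine phase runs from the end of warmup to the final step. The step count used in the example assumes the last partial batch is kept, so one epoch is ceil(7,480 / 16) = 468 steps; `lr_at` is a hypothetical helper, not part of the training script.

```python
import math


def lr_at(step: int, total_steps: int, peak_lr: float = 8e-5,
          min_lr: float = 1e-6, warmup_ratio: float = 0.1) -> float:
    """Linear warmup to peak_lr, then cosine annealing down to min_lr."""
    warmup_steps = max(1, round(total_steps * warmup_ratio))
    if step < warmup_steps:
        # Linear ramp from 0 to peak_lr
        return peak_lr * step / warmup_steps
    # Fraction of the cosine phase completed, clamped to [0, 1]
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))


# ASSUMPTION: last partial batch kept; epochs taken from Training Setup
steps_per_epoch = math.ceil(7480 / 16)
total = steps_per_epoch * 15
for s in (0, total // 20, total // 10, total // 2, total):
    print(s, f"{lr_at(s, total):.2e}")
```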
### Hyperparameters
- alpha (F1 weight): 0.15
- dropout: 0.20
- encoder peak LR: 8e‑5
- temperature (Soft‑F1): 2.0
- learnable thresholds: false
- initial_threshold: 0.565 (unused; thresholds disabled)
- threshold LR mult: 5.0 (unused; thresholds disabled)
## Metrics
Final validation (500 samples):
- Onderwerp:
- Accuracy: 99.8%
- Precision: 0.960
- Recall: 0.905
- F1: 0.932
- Beleving:
- Accuracy: 97.1%
- Precision: 0.859
- Recall: 0.730
- F1: 0.789
- Combined:
- Average Accuracy: 98.4%
- Average F1: 0.861
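The card does not state how the multi-label metrics are averaged. One common convention, sketched below purely as an assumption, is micro-averaged precision, recall, and F1 over all (sample, label) pairs plus element-wise ("Hamming") accuracy, computed on 0/1 predictions at the fixed 0.5 cutoff. Under element-wise accuracy the many true negatives would explain why accuracy (97.1% and 99.8%) sits well above F1. The combined figures above are plain means of the two heads, e.g. (0.932 + 0.789) / 2 = 0.8605, reported as 0.861. `multilabel_metrics` is a hypothetical helper.

```python
def multilabel_metrics(y_true, probs, threshold=0.5):
    """Micro precision/recall/F1 and element-wise accuracy for multi-hot data.

    ASSUMPTION: micro averaging and element-wise accuracy; a label is
    predicted positive when its probability is >= threshold.
    """
    tp = fp = fn = tn = 0
    for true_row, prob_row in zip(y_true, probs):
        for y, p in zip(true_row, prob_row):
            pred = p >= threshold
            if pred and y:
                tp += 1
            elif pred:
                fp += 1
            elif y:
                fn += 1
            else:
                tn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    return {"precision": precision, "recall": recall, "f1": f1, "accuracy": accuracy}


y_true = [[1, 0, 0, 1], [0, 1, 0, 0]]
probs = [[0.9, 0.2, 0.6, 0.4], [0.1, 0.7, 0.3, 0.05]]
print(multilabel_metrics(y_true, probs))
```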
## Saved Artifacts
- HF‑compatible files:
- `model.safetensors` — encoder weights
- `config.json` — encoder config
- `tokenizer.json`, `tokenizer_config.json`, `special_tokens_map.json` — tokenizer
- `dual_head_state.pt` — classification heads + metadata (no thresholds included when disabled)
- `label_names.json` — label names for both heads
- `inference_mmbert_hf_example.py` — example inference script (CLI)
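Before running inference it can help to confirm that a local copy of the repository contains the files listed above and that `label_names.json` holds the expected number of labels per head. The helper `check_model_dir` below is hypothetical (it is not shipped with the model) and assumes `label_names.json` maps `"onderwerp"` and `"beleving"` to lists of label names, which matches how the Inference example reads it.

```python
import json
import os

REQUIRED_FILES = [
    "model.safetensors",
    "config.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "dual_head_state.pt",
    "label_names.json",
]


def check_model_dir(model_dir, expected_counts=None):
    """Return a list of problems found in model_dir (an empty list means OK)."""
    expected_counts = expected_counts or {"onderwerp": 96, "beleving": 26}
    problems = [f"missing file: {name}" for name in REQUIRED_FILES
                if not os.path.isfile(os.path.join(model_dir, name))]
    labels_path = os.path.join(model_dir, "label_names.json")
    if os.path.isfile(labels_path):
        with open(labels_path, encoding="utf-8") as f:
            labels = json.load(f)
        for head, count in expected_counts.items():
            found = len(labels.get(head, []))
            if found != count:
                problems.append(f"{head}: expected {count} labels, found {found}")
    return problems


print(check_model_dir("."))
```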
## Inference
Quick start (script):
- `python inference_mmbert_hf_example.py [model_dir] "Uw voorbeeldzin hier"` (`model_dir` is optional and defaults to `.`)
Minimal code (probabilities + top‑k):
```python
import os, json
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

model_dir = "."  # local directory containing the artifacts listed above

# Shared encoder and tokenizer (HF-compatible files)
tok = AutoTokenizer.from_pretrained(model_dir)
enc = AutoModel.from_pretrained(model_dir).eval()

# Classification heads + metadata, and label names for both heads
state = torch.load(os.path.join(model_dir, "dual_head_state.pt"), map_location="cpu")
with open(os.path.join(model_dir, "label_names.json"), encoding="utf-8") as f:
    labels = json.load(f)

hidden = enc.config.hidden_size
drop = float(state.get("dropout", 0.1))
n_on, n_be = int(state["num_onderwerp"]), int(state["num_beleving"])

def make_head(n_out: int) -> nn.Sequential:
    # Same layout as the training heads: Linear -> Dropout -> ReLU -> Linear
    return nn.Sequential(nn.Linear(hidden, hidden), nn.Dropout(drop), nn.ReLU(), nn.Linear(hidden, n_out))

on_head, be_head = make_head(n_on), make_head(n_be)
on_head.load_state_dict(state["onderwerp_head_state"])
be_head.load_state_dict(state["beleving_head_state"])
on_head.eval()  # eval mode disables dropout
be_head.eval()

text = "Goedemiddag, ik heb al drie keer gebeld over mijn uitkering ..."
enc_inputs = tok(text, truncation=True, padding="max_length",
                 max_length=int(state.get("max_length", 512)), return_tensors="pt")

with torch.inference_mode():
    # First-token embedding is used as the pooled representation
    pooled = enc(**enc_inputs).last_hidden_state[:, 0, :]
    on_probs = torch.sigmoid(on_head(pooled))[0]  # independent sigmoid per label
    be_probs = torch.sigmoid(be_head(pooled))[0]

def topk(p, names, k=5):
    idx = torch.topk(p, k=min(k, len(p))).indices.tolist()
    return [(names[i], float(p[i])) for i in idx]

print("Onderwerp:", topk(on_probs, labels["onderwerp"]))
print("Beleving:", topk(be_probs, labels["beleving"]))
```
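Top-k shows the most likely labels, but the card specifies a fixed 0.5 cutoff for the final label set. A small helper for that step could look like the sketch below; `predicted_labels` is hypothetical (not part of the shipped script), the label names in the example are placeholders, and it assumes a probability exactly equal to the cutoff counts as positive. Because sigmoid(z) >= 0.5 exactly when z >= 0, the same rule could be applied directly to raw logits.

```python
def predicted_labels(probs, names, threshold=0.5):
    """Return (label, probability) pairs at or above threshold, highest first."""
    hits = [(name, float(p)) for name, p in zip(names, probs) if float(p) >= threshold]
    return sorted(hits, key=lambda item: item[1], reverse=True)


# Works on plain lists; a 1-D tensor such as on_probs can be passed as well
example = predicted_labels([0.91, 0.12, 0.55], ["label_a", "label_b", "label_c"])
print(example)  # [('label_a', 0.91), ('label_c', 0.55)]
```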
## Limitations & Risks
- Domain: Dutch complaint messages; performance may degrade out‑of‑domain or in other languages.
- Thresholding: No learned thresholds; 0.5 cutoff is a simple heuristic.
- Label imbalance and multi‑label ambiguity can affect precision/recall trade‑offs.
## Reproduction
- Script: `train_mmbert_dual_soft_f1_simplified.py`
- Env: see `requirements.txt` (PyTorch, Transformers, Datasets, wandb)
- Key config: seed 42, batch size 16, epochs 13, max_length 1408, α=0.15, encoder_peak_lr=8e-5, warmup_ratio=0.1, min_lr=1e-6.
## Acknowledgements
- UWV WIM synthetic RD dataset
- Hugging Face Transformers/Datasets
## License
This model is licensed under the Apache License 2.0. See `LICENSE` for details.