---
language:
  - nl
license: apache-2.0
library_name: transformers
pipeline_tag: text-classification
base_model: bert-base-multilingual-cased
tags:
  - multi-label
  - dutch
  - municipal-complaints
  - mbert
  - bert
  - pytorch
  - safetensors
datasets:
  - UWV/wim-synthetic-data-rd
metrics:
  - f1
  - precision
  - recall
  - accuracy
model-index:
  - name: WimBERT v0
    results:
      - task:
          name: Multi-label Text Classification
          type: text-classification
        dataset:
          name: UWV/WIM Synthetic RD
          type: UWV/wim-synthetic-data-rd
          split: validation
          subset: onderwerp
        metrics:
          - name: F1
            type: f1
            value: 0.932
          - name: Precision
            type: precision
            value: 0.96
          - name: Recall
            type: recall
            value: 0.905
          - name: Accuracy
            type: accuracy
            value: 0.998
      - task:
          name: Multi-label Text Classification
          type: text-classification
        dataset:
          name: UWV/WIM Synthetic RD
          type: UWV/wim-synthetic-data-rd
          split: validation
          subset: beleving
        metrics:
          - name: F1
            type: f1
            value: 0.789
          - name: Precision
            type: precision
            value: 0.859
          - name: Recall
            type: recall
            value: 0.73
          - name: Accuracy
            type: accuracy
            value: 0.971
widget:
  - text: >-
      Goedemiddag, ik heb al drie keer gebeld over mijn uitkering en krijg geen
      duidelijkheid.

---

# WimBERT v0

WimBERT is a dual‑head, multi‑label classifier for Dutch municipal complaint messages. The model uses a shared mmBERT‑base encoder with two MLP heads:

- Onderwerp (topics): 96 labels
- Beleving (experience): 26 labels

Trained with a combined objective: alpha · (1 − Soft‑F1) + (1 − alpha) · BCE.
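
A minimal sketch of this objective, assuming a micro-averaged, temperature-sharpened soft F1 computed from sigmoid probabilities (the exact formulation lives in the training script):

```python
import torch
import torch.nn.functional as F

def combined_loss(logits, targets, alpha=0.15, temperature=2.0, eps=1e-8):
    """alpha * (1 - soft-F1) + (1 - alpha) * BCE, on multi-hot float targets."""
    probs = torch.sigmoid(logits * temperature)        # soft predictions in [0, 1]
    tp = (probs * targets).sum()
    fp = (probs * (1.0 - targets)).sum()
    fn = ((1.0 - probs) * targets).sum()
    soft_f1 = (2.0 * tp) / (2.0 * tp + fp + fn + eps)  # differentiable micro F1
    bce = F.binary_cross_entropy_with_logits(logits, targets)
    return alpha * (1.0 - soft_f1) + (1.0 - alpha) * bce
```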

## Overview

- Encoder: mmBERT-base (multilingual)
- Heads: 2× MLP (Linear → Dropout → ReLU → Linear); a model sketch follows this list
- Labels: 96 onderwerp, 26 beleving
- Task: multi-label classification (one sigmoid per class)
- Thresholds: disabled (a fixed 0.5 cutoff is used for evaluation and inference)
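
A minimal sketch of the architecture above: one shared encoder, [CLS] pooling, and two independent MLP heads. The head layout mirrors the inference example further down; the class name and everything else here are illustrative, not the training script's actual code.

```python
import torch.nn as nn
from transformers import AutoModel

class DualHeadClassifier(nn.Module):
    """Shared encoder with two independent multi-label heads (illustrative)."""

    def __init__(self, encoder_name, num_onderwerp=96, num_beleving=26, dropout=0.2):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        hidden = self.encoder.config.hidden_size

        def make_head(num_labels):
            # Linear -> Dropout -> ReLU -> Linear, as described above
            return nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.Dropout(dropout),
                nn.ReLU(),
                nn.Linear(hidden, num_labels),
            )

        self.onderwerp_head = make_head(num_onderwerp)
        self.beleving_head = make_head(num_beleving)

    def forward(self, **inputs):
        pooled = self.encoder(**inputs).last_hidden_state[:, 0, :]  # [CLS] token
        return self.onderwerp_head(pooled), self.beleving_head(pooled)
```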

## Intended Use

- Classify incoming Dutch complaint messages into topical (onderwerp) and experiential (beleving) labels.
- Useful for analytics, routing, and trend insights. Not intended for legal or benefit decisions without human review.

## Training Data

- Source: UWV/wim-synthetic-data-rd (train split)
- Samples: 9,351
- Labels: 96 onderwerp, 26 beleving
- Avg labels per sample: onderwerp 1.75, beleving 1.89
- Shapes: onderwerp (9351, 96), beleving (9351, 26)
- Train/Val split: 7,480 / 1,871 (80/20); a loading sketch follows this list
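
A loading sketch for the split above, assuming the split is recreated with `datasets.train_test_split` and the seed listed in the training setup; the label columns and their multi-hot encoding are handled by the training script and are not shown here.

```python
from datasets import load_dataset

# Load the synthetic dataset and recreate an 80/20 train/validation split.
ds = load_dataset("UWV/wim-synthetic-data-rd", split="train")
split = ds.train_test_split(test_size=0.2, seed=42)
train_ds, val_ds = split["train"], split["test"]
print(len(train_ds), len(val_ds))  # roughly 7,480 / 1,871
```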

## Training Setup

- Date: 2025-10-20
- Hardware: NVIDIA A100 GPU
- Epochs: 15
- Batch size: 16
- Sequence length: 1,408 tokens
- Optimizer: AdamW
- Scheduler: linear warmup (10%) → cosine annealing, min_lr=1e-6; a sketch follows this list
- Gradient clipping: max_norm 1.0
- Random seed: 42
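
A sketch of the optimizer and schedule above (AdamW, 10% linear warmup, then cosine annealing down to min_lr). The `LambdaLR` construction is an assumption about how the training script implements the schedule.

```python
import math
import torch

def build_optimizer_and_scheduler(model, total_steps, peak_lr=8e-5,
                                  warmup_ratio=0.1, min_lr=1e-6):
    optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr)
    warmup_steps = int(warmup_ratio * total_steps)
    floor = min_lr / peak_lr  # LR scale factor corresponding to min_lr

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)                # linear warmup to peak_lr
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))   # decays 1 -> 0
        return floor + (1.0 - floor) * cosine                 # anneal to min_lr

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
```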

## Hyperparameters

- alpha (F1 weight): 0.15
- dropout: 0.20
- encoder peak LR: 8e-5
- temperature (Soft-F1): 2.0
- learnable thresholds: false
- initial_threshold: 0.565 (not used; thresholds disabled)
- threshold LR multiplier: 5.0 (not used; thresholds disabled)

## Metrics

Final validation (500 samples):

- Onderwerp:
  - Accuracy: 99.8%
  - Precision: 0.960
  - Recall: 0.905
  - F1: 0.932
- Beleving:
  - Accuracy: 97.1%
  - Precision: 0.859
  - Recall: 0.730
  - F1: 0.789
- Combined:
  - Average accuracy: 98.4%
  - Average F1: 0.861
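
These figures are computed at the fixed 0.5 threshold. A sketch of how comparable numbers can be reproduced from held-out probabilities and multi-hot targets (NumPy arrays assumed); the micro-averaging choice and the per-cell accuracy are assumptions about the evaluation code.

```python
from sklearn.metrics import precision_recall_fscore_support

def evaluate(probs, targets, threshold=0.5):
    """probs, targets: arrays of shape (n_samples, n_labels)."""
    preds = (probs >= threshold).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(
        targets, preds, average="micro", zero_division=0
    )
    accuracy = (preds == targets).mean()  # accuracy over (sample, label) cells
    return {"precision": precision, "recall": recall, "f1": f1, "accuracy": accuracy}
```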

## Saved Artifacts

- HF-compatible files:
  - model.safetensors: encoder weights
  - config.json: encoder config
  - tokenizer.json, tokenizer_config.json, special_tokens_map.json: tokenizer
  - dual_head_state.pt: classification heads + metadata (no thresholds included when disabled)
  - label_names.json: label names for both heads
- inference_mmbert_hf_example.py: example inference script (CLI)

## Inference

Quick start (script):

- `python inference_mmbert_hf_example.py [model_dir=.] "Uw voorbeeldzin hier"`

Minimal code (probabilities + top‑k):

```python
import os, json, torch, torch.nn as nn
from transformers import AutoModel, AutoTokenizer

model_dir = "."
tok = AutoTokenizer.from_pretrained(model_dir)
enc = AutoModel.from_pretrained(model_dir).eval()

# Classification heads and their metadata are stored separately from the encoder.
state = torch.load(os.path.join(model_dir, "dual_head_state.pt"), map_location="cpu")
with open(os.path.join(model_dir, "label_names.json")) as f:
    labels = json.load(f)

# Rebuild both MLP heads (Linear -> Dropout -> ReLU -> Linear) and load their weights.
hidden = enc.config.hidden_size
drop = float(state.get("dropout", 0.1))
n_on, n_be = int(state["num_onderwerp"]), int(state["num_beleving"])
on_head = nn.Sequential(nn.Linear(hidden, hidden), nn.Dropout(drop), nn.ReLU(), nn.Linear(hidden, n_on)).eval()
be_head = nn.Sequential(nn.Linear(hidden, hidden), nn.Dropout(drop), nn.ReLU(), nn.Linear(hidden, n_be)).eval()
on_head.load_state_dict(state["onderwerp_head_state"])
be_head.load_state_dict(state["beleving_head_state"])

text = "Goedemiddag, ik heb al drie keer gebeld over mijn uitkering ..."
enc_inputs = tok(text, truncation=True, padding="max_length",
                 max_length=int(state.get("max_length", 512)), return_tensors="pt")

with torch.no_grad():
    # [CLS] pooling, then an independent sigmoid per class for each head.
    pooled = enc(**enc_inputs).last_hidden_state[:, 0, :]
    on_probs = torch.sigmoid(on_head(pooled))[0]
    be_probs = torch.sigmoid(be_head(pooled))[0]

topk = lambda p, names, k=5: [(names[i], float(p[i])) for i in torch.topk(p, k=min(k, len(p))).indices]
print("Onderwerp:", topk(on_probs, labels["onderwerp"]))
print("Beleving:",  topk(be_probs, labels["beleving"]))
```

## Limitations & Risks

- Domain: Dutch complaint messages; performance may degrade out of domain or in other languages.
- Thresholding: no learned thresholds; the fixed 0.5 cutoff is a simple heuristic (a tuning sketch follows this list).
- Label imbalance and multi-label ambiguity can affect precision/recall trade-offs.
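
If the 0.5 cutoff is too blunt for a downstream use case, per-class thresholds can instead be tuned on validation data. A simple grid-search sketch, not part of this release:

```python
import numpy as np
from sklearn.metrics import f1_score

def tune_thresholds(probs, targets, grid=np.linspace(0.1, 0.9, 17)):
    """Pick, per class, the threshold that maximizes F1 on validation data."""
    thresholds = np.full(probs.shape[1], 0.5)
    for j in range(probs.shape[1]):
        scores = [f1_score(targets[:, j], probs[:, j] >= t, zero_division=0) for t in grid]
        thresholds[j] = grid[int(np.argmax(scores))]
    return thresholds
```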

## Reproduction

- Script: train_mmbert_dual_soft_f1_simplified.py
- Env: see requirements.txt (PyTorch, Transformers, Datasets, wandb)
- Key config: seed 42, batch size 16, epochs 13, max_length 1408, α=0.15, encoder_peak_lr=8e-5, warmup_ratio=0.1, min_lr=1e-6.

## Acknowledgements

- UWV WIM synthetic RD dataset
- Hugging Face Transformers/Datasets

## License

This model is licensed under the Apache License 2.0. See LICENSE for details.