Add extra dataset, also filter no subtopic from beleving
train/rd_dataset_loader.py CHANGED (+47 -18)
@@ -5,21 +5,25 @@ Carmack-style: minimal abstraction, direct data flow, fast operations.
 """
 
 import numpy as np
-from datasets import load_dataset
+from datasets import load_dataset, concatenate_datasets
 
 
 def load_rd_wim_dataset(max_samples=None, split='train', filter_calamity=True):
     """
-    Load UWV/wim-synthetic-data-rd dataset and encode multi-labels.
-
+    Load combined UWV datasets and encode multi-labels.
+
+    Combines two datasets:
+    - UWV/wim-synthetic-data-rd: Original RD dataset
+    - UWV/wim_synthetic_data_for_testing_split_labels: Validated testing dataset
+
     Dataset contains Dutch municipal complaint conversations with two types of labels:
-    - onderwerp: What the message is about
-    - beleving: How the citizen experienced the interaction
+    - onderwerp: What the message is about
+    - beleving: How the citizen experienced the interaction
 
     Args:
         max_samples: Limit number of samples (None = all samples)
         split: Dataset split to load (default: 'train')
-        filter_calamity: If True, exclude samples with is_calamity=True (default: True)
+        filter_calamity: If True, exclude samples with is_calamity=True from RD dataset (default: True)
 
     Returns:
         texts: List of conversation strings
@@ -29,25 +33,50 @@ def load_rd_wim_dataset(max_samples=None, split='train', filter_calamity=True):
         beleving_labels: List of beleving label names (sorted alphabetically)
     """
 
-    # Load dataset
+    # Load RD dataset
     print(f"Loading UWV/wim-synthetic-data-rd dataset (split={split})...")
-    ds = load_dataset('UWV/wim-synthetic-data-rd', split=split)
-
+    ds_rd = load_dataset('UWV/wim-synthetic-data-rd', split=split)
+
     # Filter out calamity samples if requested
     if filter_calamity:
-        original_len = len(ds)
-        ds = ds.filter(lambda x: not x['is_calamity'])
-        filtered_len = len(ds)
+        original_len = len(ds_rd)
+        ds_rd = ds_rd.filter(lambda x: not x['is_calamity'])
+        filtered_len = len(ds_rd)
         print(f"Filtered out {original_len - filtered_len} calamity samples ({filtered_len} remaining)")
-
-    # Replace "No subtopic found" with empty list
-
+
+    # Keep only essential columns from RD dataset
+    ds_rd = ds_rd.select_columns(['text', 'onderwerp_labels', 'beleving_labels'])
+    print(f"RD dataset: {len(ds_rd)} samples")
+
+    # Load testing dataset
+    print(f"Loading UWV/wim_synthetic_data_for_testing_split_labels dataset (split={split})...")
+    ds_test = load_dataset('UWV/wim_synthetic_data_for_testing_split_labels', split=split)
+
+    # Rename columns to match RD dataset structure
+    ds_test = ds_test.map(lambda x: {
+        'text': x['Synthetic Text'],
+        'onderwerp_labels': x['validated_onderwerp_labels'],
+        'beleving_labels': x['validated_beleving_labels']
+    }, remove_columns=ds_test.column_names)
+    print(f"Testing dataset: {len(ds_test)} samples")
+
+    # Concatenate datasets
+    ds = concatenate_datasets([ds_rd, ds_test])
+    print(f"Combined dataset: {len(ds)} samples")
+
+    # Shuffle with fixed seed for reproducibility
+    ds = ds.shuffle(seed=42)
+    print(f"Shuffled combined dataset")
+
+    # Replace "No subtopic found" with empty list (for both onderwerp and beleving)
     ds = ds.map(lambda x: {
         **x,
-        'onderwerp_labels': [] if x['onderwerp_labels'] == ['No subtopic found'] else x['onderwerp_labels']
+        'onderwerp_labels': [] if x['onderwerp_labels'] == ['No subtopic found'] else x['onderwerp_labels'],
+        'beleving_labels': [] if x['beleving_labels'] == ['No subtopic found'] else x['beleving_labels']
    })
-    no_subtopic_count = sum(1 for sample in ds if len(sample['onderwerp_labels']) == 0)
-    print(f"Replaced 'No subtopic found' with empty list: {no_subtopic_count} samples")
+    no_onderwerp_count = sum(1 for sample in ds if len(sample['onderwerp_labels']) == 0)
+    no_beleving_count = sum(1 for sample in ds if len(sample['beleving_labels']) == 0)
+    print(f"Replaced 'No subtopic found' with empty list: {no_onderwerp_count} onderwerp, {no_beleving_count} beleving")
 
     # Limit samples if requested
    if max_samples is not None: