yhavinga committed on
Commit
23bf736
·
1 Parent(s): 7a24f2a

Add extra dataset, also filter no subtopic from beleving

Browse files
Files changed (1) hide show
  1. train/rd_dataset_loader.py +47 -18
train/rd_dataset_loader.py CHANGED
@@ -5,21 +5,25 @@ Carmack-style: minimal abstraction, direct data flow, fast operations.
5
  """
6
 
7
  import numpy as np
8
- from datasets import load_dataset
9
 
10
 
11
  def load_rd_wim_dataset(max_samples=None, split='train', filter_calamity=True):
12
  """
13
- Load UWV/wim-synthetic-data-rd dataset and encode multi-labels.
14
-
 
 
 
 
15
  Dataset contains Dutch municipal complaint conversations with two types of labels:
16
- - onderwerp: What the message is about (96 unique labels)
17
- - beleving: How the citizen experienced the interaction (26 unique labels)
18
 
19
  Args:
20
  max_samples: Limit number of samples (None = all samples)
21
  split: Dataset split to load (default: 'train')
22
- filter_calamity: If True, exclude samples with is_calamity=True (default: True)
23
 
24
  Returns:
25
  texts: List of conversation strings
@@ -29,25 +33,50 @@ def load_rd_wim_dataset(max_samples=None, split='train', filter_calamity=True):
29
  beleving_labels: List of beleving label names (sorted alphabetically)
30
  """
31
 
32
- # Load dataset from HuggingFace
33
  print(f"Loading UWV/wim-synthetic-data-rd dataset (split={split})...")
34
- ds = load_dataset('UWV/wim-synthetic-data-rd', split=split)
35
-
36
  # Filter out calamity samples if requested
37
  if filter_calamity:
38
- original_len = len(ds)
39
- ds = ds.filter(lambda x: not x['is_calamity'])
40
- filtered_len = len(ds)
41
  print(f"Filtered out {original_len - filtered_len} calamity samples ({filtered_len} remaining)")
42
-
43
- # Replace "No subtopic found" with empty list
44
- original_len_before_replacement = len(ds)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  ds = ds.map(lambda x: {
46
  **x,
47
- 'onderwerp_labels': [] if x['onderwerp_labels'] == ['No subtopic found'] else x['onderwerp_labels']
 
48
  })
49
- no_subtopic_count = sum(1 for sample in ds if len(sample['onderwerp_labels']) == 0)
50
- print(f"Replaced 'No subtopic found' with empty list for samples with no valid subtopic ({no_subtopic_count} samples)")
 
51
 
52
  # Limit samples if requested
53
  if max_samples is not None:
 
5
  """
6
 
7
  import numpy as np
8
+ from datasets import load_dataset, concatenate_datasets
9
 
10
 
11
  def load_rd_wim_dataset(max_samples=None, split='train', filter_calamity=True):
12
  """
13
+ Load combined UWV datasets and encode multi-labels.
14
+
15
+ Combines two datasets:
16
+ - UWV/wim-synthetic-data-rd: Original RD dataset
17
+ - UWV/wim_synthetic_data_for_testing_split_labels: Validated testing dataset
18
+
19
  Dataset contains Dutch municipal complaint conversations with two types of labels:
20
+ - onderwerp: What the message is about
21
+ - beleving: How the citizen experienced the interaction
22
 
23
  Args:
24
  max_samples: Limit number of samples (None = all samples)
25
  split: Dataset split to load (default: 'train')
26
+ filter_calamity: If True, exclude samples with is_calamity=True from RD dataset (default: True)
27
 
28
  Returns:
29
  texts: List of conversation strings
 
33
  beleving_labels: List of beleving label names (sorted alphabetically)
34
  """
35
 
36
+ # Load RD dataset
37
  print(f"Loading UWV/wim-synthetic-data-rd dataset (split={split})...")
38
+ ds_rd = load_dataset('UWV/wim-synthetic-data-rd', split=split)
39
+
40
  # Filter out calamity samples if requested
41
  if filter_calamity:
42
+ original_len = len(ds_rd)
43
+ ds_rd = ds_rd.filter(lambda x: not x['is_calamity'])
44
+ filtered_len = len(ds_rd)
45
  print(f"Filtered out {original_len - filtered_len} calamity samples ({filtered_len} remaining)")
46
+
47
+ # Keep only essential columns from RD dataset
48
+ ds_rd = ds_rd.select_columns(['text', 'onderwerp_labels', 'beleving_labels'])
49
+ print(f"RD dataset: {len(ds_rd)} samples")
50
+
51
+ # Load testing dataset
52
+ print(f"Loading UWV/wim_synthetic_data_for_testing_split_labels dataset (split={split})...")
53
+ ds_test = load_dataset('UWV/wim_synthetic_data_for_testing_split_labels', split=split)
54
+
55
+ # Rename columns to match RD dataset structure
56
+ ds_test = ds_test.map(lambda x: {
57
+ 'text': x['Synthetic Text'],
58
+ 'onderwerp_labels': x['validated_onderwerp_labels'],
59
+ 'beleving_labels': x['validated_beleving_labels']
60
+ }, remove_columns=ds_test.column_names)
61
+ print(f"Testing dataset: {len(ds_test)} samples")
62
+
63
+ # Concatenate datasets
64
+ ds = concatenate_datasets([ds_rd, ds_test])
65
+ print(f"Combined dataset: {len(ds)} samples")
66
+
67
+ # Shuffle with fixed seed for reproducibility
68
+ ds = ds.shuffle(seed=42)
69
+ print(f"Shuffled combined dataset")
70
+
71
+ # Replace "No subtopic found" with empty list (for both onderwerp and beleving)
72
  ds = ds.map(lambda x: {
73
  **x,
74
+ 'onderwerp_labels': [] if x['onderwerp_labels'] == ['No subtopic found'] else x['onderwerp_labels'],
75
+ 'beleving_labels': [] if x['beleving_labels'] == ['No subtopic found'] else x['beleving_labels']
76
  })
77
+ no_onderwerp_count = sum(1 for sample in ds if len(sample['onderwerp_labels']) == 0)
78
+ no_beleving_count = sum(1 for sample in ds if len(sample['beleving_labels']) == 0)
79
+ print(f"Replaced 'No subtopic found' with empty list: {no_onderwerp_count} onderwerp, {no_beleving_count} beleving")
80
 
81
  # Limit samples if requested
82
  if max_samples is not None: