add dataset scripts
- .gitignore +2 -0
- convert_files.py +17 -0
- get_data.sh +23 -0
- merge_datasets.py +12 -0
- prepare_data.sh +0 -0
- train.py +41 -215
- train.sh +22 -0
- wiki_sentences.py +46 -0
.gitignore
CHANGED
@@ -1,3 +1,5 @@
 .vscode
 venv
 *.pyc
+segment_*
+dataset.csv
convert_files.py
ADDED
@@ -0,0 +1,17 @@
+import json
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained('gpt2')
+
+for i in tqdm(range(298)):
+
+    with open(f'wikipedia_json_64_filtered/wikipedia.segmented.nltk.split.seq64.{i}.json', 'r') as f:
+        rows = json.load(f)
+
+    tokens = [row['gpt2_token'] for row in rows]
+    texts = tokenizer.batch_decode(tokens)
+
+    with open(f'wikipedia/{i}.txt', 'w') as f:
+        for txt in texts:
+            f.write(txt.strip() + '\n')
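convert_files.py decodes the GPT-2 token ids stored in the 298 filtered Wikipedia JSON segments back into plain text, one decoded text per line, but it assumes the wikipedia/ output directory already exists. A small pre-flight sketch (not part of the commit; the paths are taken from the script above):

import os

# create the output directory convert_files.py writes into
os.makedirs('wikipedia', exist_ok=True)

# fail early if any of the 298 input segments is missing
missing = [
    i for i in range(298)
    if not os.path.exists(f'wikipedia_json_64_filtered/wikipedia.segmented.nltk.split.seq64.{i}.json')
]
if missing:
    raise FileNotFoundError(f'missing input segments: {missing[:5]} ...')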
get_data.sh
ADDED
@@ -0,0 +1,23 @@
+
+
+wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=13NnkYAhwszQxc1C5HHfThnF7c1cjzjAD' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=13NnkYAhwszQxc1C5HHfThnF7c1cjzjAD" -O segment_1.zip && rm -rf /tmp/cookies.txt
+wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=14p6FHip_hGTXC-_7SYaK32BpEhZRDJI4' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=14p6FHip_hGTXC-_7SYaK32BpEhZRDJI4" -O segment_2.zip && rm -rf /tmp/cookies.txt
+wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1IaRfTFh51Wf_zPtK6tjE6xw-up_Z6EyN' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1IaRfTFh51Wf_zPtK6tjE6xw-up_Z6EyN" -O segment_3.zip && rm -rf /tmp/cookies.txt
+wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1KGhV397Xfej56uJ9H10xD7tfLdhWlg4q' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1KGhV397Xfej56uJ9H10xD7tfLdhWlg4q" -O segment_4.zip && rm -rf /tmp/cookies.txt
+wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1LfsQ1s9wr1mBG3I1bbvnbyrYmnsrXxZt' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1LfsQ1s9wr1mBG3I1bbvnbyrYmnsrXxZt" -O segment_5.zip && rm -rf /tmp/cookies.txt
+wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1OctFe_JPR0Ajh77FzWdfeYnWZinKl2sW' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1OctFe_JPR0Ajh77FzWdfeYnWZinKl2sW" -O segment_6.zip && rm -rf /tmp/cookies.txt
+wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1W-Yi8gHCcT8O5F4TcDHScH7pOb0GQZdu' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1W-Yi8gHCcT8O5F4TcDHScH7pOb0GQZdu" -O segment_7.zip && rm -rf /tmp/cookies.txt
+wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1jgHjnpe7Vk1pvRgfnH4S4KiRrpUQyqyp' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1jgHjnpe7Vk1pvRgfnH4S4KiRrpUQyqyp" -O segment_8.zip && rm -rf /tmp/cookies.txt
+wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1oVst8RG8G2d21DL6q4DwO7aJxE1vA2fc' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1oVst8RG8G2d21DL6q4DwO7aJxE1vA2fc" -O segment_9.zip && rm -rf /tmp/cookies.txt
+wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1qwckIM8YBbU9bnArB6bAoStY3e9I1kqU' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1qwckIM8YBbU9bnArB6bAoStY3e9I1kqU" -O segment_0.zip && rm -rf /tmp/cookies.txt
+
+unzip segment_1.zip
+unzip segment_2.zip
+unzip segment_3.zip
+unzip segment_4.zip
+unzip segment_5.zip
+unzip segment_6.zip
+unzip segment_7.zip
+unzip segment_8.zip
+unzip segment_9.zip
+
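Each of the ten wget calls above repeats the same Google Drive confirm-token dance for a different file id. An equivalent loop, sketched with the third-party gdown package (an assumption; the commit itself only uses wget), keeps the ids in one place:

import zipfile
import gdown

# file ids copied from get_data.sh; segment_0 is the last download in that script
SEGMENT_IDS = {
    1: '13NnkYAhwszQxc1C5HHfThnF7c1cjzjAD',
    2: '14p6FHip_hGTXC-_7SYaK32BpEhZRDJI4',
    3: '1IaRfTFh51Wf_zPtK6tjE6xw-up_Z6EyN',
    4: '1KGhV397Xfej56uJ9H10xD7tfLdhWlg4q',
    5: '1LfsQ1s9wr1mBG3I1bbvnbyrYmnsrXxZt',
    6: '1OctFe_JPR0Ajh77FzWdfeYnWZinKl2sW',
    7: '1W-Yi8gHCcT8O5F4TcDHScH7pOb0GQZdu',
    8: '1jgHjnpe7Vk1pvRgfnH4S4KiRrpUQyqyp',
    9: '1oVst8RG8G2d21DL6q4DwO7aJxE1vA2fc',
    0: '1qwckIM8YBbU9bnArB6bAoStY3e9I1kqU',
}

for segment, file_id in SEGMENT_IDS.items():
    out = f'segment_{segment}.zip'
    gdown.download(f'https://drive.google.com/uc?id={file_id}', out, quiet=False)
    with zipfile.ZipFile(out) as zf:
        zf.extractall()  # note: get_data.sh itself only unzips segments 1-9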
merge_datasets.py
ADDED
@@ -0,0 +1,12 @@
+import datasets
+import pandas as pd
+
+dfs = []
+
+for i in range(10):
+    dfs.append(
+        datasets.ArrowReader.read_table(f'segment_{i}/dataset.arrow').to_pandas()
+    )
+
+full_df = pd.concat(dfs, ignore_index=True)
+full_df.to_csv('dataset.csv')
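merge_datasets.py routes every Arrow segment through pandas before writing a single dataset.csv. A hedged alternative that stays inside the datasets API (assuming each segment_{i}/dataset.arrow is a datasets-formatted Arrow file, as the read_table call above implies):

from datasets import Dataset, concatenate_datasets

segments = [Dataset.from_file(f'segment_{i}/dataset.arrow') for i in range(10)]
full = concatenate_datasets(segments)
full.to_csv('dataset.csv')  # same output file as merge_datasets.py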
prepare_data.sh
ADDED
File without changes
train.py
CHANGED
@@ -17,7 +17,6 @@
 - [ ] Send the schedule time to the compute_loss method and calculate a coefficient based on that.
 '''
 import logging
-import math
 import os
 import sys
 import time
@@ -31,6 +30,7 @@ from tqdm import tqdm
 
 import jax
 import jax.numpy as jnp
+import numpy as onp
 import optax
 import transformers
 from flax import jax_utils, traverse_util
@@ -44,7 +44,6 @@ from transformers import (
     is_tensorboard_available,
 )
 from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
-from transformers.testing_utils import CaptureLogger
 
 from t5_vae_flax.src.t5_vae import FlaxT5VaeForAutoencoding
 from t5_vae_flax.src.config import T5VaeConfig
@@ -113,7 +112,7 @@ class ModelArguments:
 @dataclass
 class DataTrainingArguments:
     """
-    Arguments pertaining to what data we are going to input our model for training
+    Arguments pertaining to what data we are going to input our model for training.
     """
 
     dataset_name: Optional[str] = field(
@@ -123,10 +122,6 @@ class DataTrainingArguments:
         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
     )
     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
-    validation_file: Optional[str] = field(
-        default=None,
-        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
-    )
     max_train_samples: Optional[int] = field(
         default=None,
         metadata={
@@ -134,21 +129,8 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    max_eval_samples: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
-            "value if set."
-        },
-    )
     overwrite_cache: bool = field(
-        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
-    )
-    validation_split_percentage: Optional[int] = field(
-        default=5,
-        metadata={
-            "help": "The percentage of the train set used as validation set in case there's no validation split"
-        },
+        default=False, metadata={"help": "Overwrite the cached training sets"}
     )
     block_size: Optional[int] = field(
         default=None,
@@ -162,7 +144,7 @@ class DataTrainingArguments:
         default=False, metadata={"help": "Stream the dataset."}
     )
     overwrite_cache: bool = field(
-        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+        default=False, metadata={"help": "Overwrite the cached training sets"}
     )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
@@ -170,15 +152,12 @@ class DataTrainingArguments:
     )
 
     def __post_init__(self):
-        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
-            raise ValueError("Need either a dataset name or a training/validation file.")
+        if self.dataset_name is None and self.train_file is None:
+            raise ValueError("Need either a dataset name or a training file.")
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
                 assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
-            if self.validation_file is not None:
-                extension = self.validation_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
 
 
 class TrainState(train_state.TrainState):
@@ -188,28 +167,19 @@ class TrainState(train_state.TrainState):
         return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
 
 
-def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
+def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int):
     """
     Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
     Shuffle batches if `shuffle` is `True`.
     """
-    steps_per_epoch = len(dataset) // batch_size
-
-    if shuffle:
-        batch_idx = jax.random.permutation(rng, len(dataset))
-    else:
-        batch_idx = jnp.arange(len(dataset))
-
-    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
-    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
-
-    for idx in batch_idx:
-        batch = dataset[idx]
-        batch = {k: jnp.array(v) for k, v in batch.items()}
-
-        batch = shard(batch)
-
-        yield batch
+    batch = []
+    for row in dataset:
+        batch.append(row)
+        if len(batch) >= batch_size:
+            batch = {k: jnp.stack([row[k] for row in batch]) for k in batch[0].keys()}
+            batch = shard(batch)
+            yield batch
+            batch = []
 
 
 def write_train_metric(summary_writer, train_metrics, train_time, step):
@@ -222,11 +192,6 @@ def write_train_metric(summary_writer, train_metrics, train_time, step):
             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
 
 
-def write_eval_metric(summary_writer, eval_metrics, step):
-    for metric_name, value in eval_metrics.items():
-        summary_writer.scalar(f"eval_{metric_name}", value, step)
-
-
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
 ) -> Callable[[int], jnp.array]:
@@ -284,9 +249,9 @@ def main():
         transformers.utils.logging.set_verbosity_error()
 
     # Set the verbosity to info of the Transformers logger (on main process only):
-    logger.info(f"Training/evaluation parameters {training_args}")
+    logger.info(f"Training parameters {training_args}")
 
-    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
+    # Get the datasets: you can either provide your own CSV/JSON/TXT training files (see below)
     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
     # (the dataset will be downloaded automatically from the datasets Hub).
     #
@@ -295,35 +260,7 @@ def main():
     #
     # In distributed training, the load_dataset function guarantees that only one local process can concurrently
     # download the dataset.
-
-    # Downloading and loading a dataset from the hub.
-    dataset = load_dataset(
-        data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, streaming=data_args.streaming, keep_in_memory=False
-    )
-
-    if "validation" not in dataset.keys():
-        dataset["validation"] = load_dataset(
-            data_args.dataset_name,
-            data_args.dataset_config_name,
-            split=f"train[:{data_args.validation_split_percentage}%]",
-            cache_dir=model_args.cache_dir,
-        )
-        dataset["train"] = load_dataset(
-            data_args.dataset_name,
-            data_args.dataset_config_name,
-            split=f"train[{data_args.validation_split_percentage}%:]",
-            cache_dir=model_args.cache_dir,
-        )
-    else:
-        data_files = {}
-        if data_args.train_file is not None:
-            data_files["train"] = data_args.train_file
-        if data_args.validation_file is not None:
-            data_files["validation"] = data_args.validation_file
-        extension = data_args.train_file.split(".")[-1]
-        if extension == "txt":
-            extension = "text"
-        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+    dataset = load_dataset('text', data_files=[f'wikipedia/{i}.txt' for i in range(298)], cache_dir=model_args.cache_dir, streaming=True)['train']
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
@@ -381,37 +318,6 @@ def main():
     assert tokenizer.pad_token == '<PAD>'
 
     # Preprocessing the datasets.
-    # First we tokenize all the texts.
-    if training_args.do_train:
-        column_names = dataset["train"].column_names
-    else:
-        column_names = dataset["validation"].column_names
-    text_column_name = "text" if "text" in column_names else column_names[0]
-
-    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
-    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
-
-    def tokenize_function(examples):
-        with CaptureLogger(tok_logger) as cl:
-            output = tokenizer(examples[text_column_name])
-        # clm input could be much much longer than block_size
-        if "Token indices sequence length is longer than the" in cl.out:
-            tok_logger.warning(
-                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
-            )
-        return output
-
-    # remove dataset tasks
-    for k in dataset.keys():
-        dataset[k].info.task_templates = []
-
-    tokenized_datasets = dataset.map(
-        tokenize_function,
-        batched=True,
-        num_proc=data_args.preprocessing_num_workers,
-        remove_columns=column_names,
-        load_from_cache_file=not data_args.overwrite_cache,
-    )
 
     if data_args.block_size > tokenizer.model_max_length:
         logger.warning(
@@ -422,65 +328,27 @@ def main():
 
     pad_token_id, start_token_id = tokenizer.pad_token_id, config.decoder_start_token_id
 
-    def clip_texts(examples):
-        examples["labels"] = examples["input_ids"].copy()
-
-        for i, input_ids in enumerate(examples["input_ids"]):
-            if len(input_ids) > block_size:
-                for k in examples.keys():
-                    examples[k][i] = examples[k][i][:block_size]
-            elif len(input_ids) < block_size:
-                delta = block_size - len(input_ids)
-                examples['input_ids'][i] = examples['input_ids'][i] + [pad_token_id] * delta
-                examples['attention_mask'][i] = examples['attention_mask'][i] + [0] * delta
-                examples['labels'][i] = examples['labels'][i] + [-100] * delta
-
-        return examples
-
-    logger.info('clip_texts...')
-    clipped_lm_datasets = tokenized_datasets.map(
-        clip_texts,
-        batched=True,
-        num_proc=data_args.preprocessing_num_workers,
-        load_from_cache_file=not data_args.overwrite_cache,
-    )
-
-    def add_decoder_input_ids(examples):
-        arr_input_ids = jnp.array(examples["input_ids"])
-        pad = pad_token_id * jnp.ones((arr_input_ids.shape[0], 1), dtype=jnp.int32)
-        arr_pad_input_ids = jnp.concatenate((arr_input_ids, pad), axis=1)
-        examples['decoder_input_ids'] = shift_tokens_right(arr_pad_input_ids, pad_token_id, start_token_id)
+    def tokenize_function(examples):
+        output = tokenizer(examples["text"], return_tensors='jax', padding='max_length', max_length=block_size, truncation=True)
 
-
-
-
+        output['labels'] = onp.array(output['input_ids'].copy())
+        output['labels'][output['labels'] == pad_token_id] = -100
+        output['labels'] = jnp.array(output['labels'])
 
-
-
+        pad = pad_token_id * jnp.ones((output['input_ids'].shape[0], 1), dtype=jnp.int32)
+        arr_pad_input_ids = jnp.concatenate((output['input_ids'], pad), axis=1)
+        output['decoder_input_ids'] = shift_tokens_right(arr_pad_input_ids, pad_token_id, start_token_id)
 
-
+        ones = jnp.ones((output['attention_mask'].shape[0], 1), dtype=jnp.int32)
+        output['decoder_attention_mask'] = jnp.concatenate((ones, output['attention_mask']), axis=1)
 
-
-    lm_datasets = clipped_lm_datasets.map(
-        add_decoder_input_ids,
-        batched=True,
-        num_proc=data_args.preprocessing_num_workers,
-        load_from_cache_file=not data_args.overwrite_cache,
-    )
+        return output
 
-
-    if "train" not in tokenized_datasets:
-        raise ValueError("--do_train requires a train dataset")
-    train_dataset = lm_datasets["train"]
-    if data_args.max_train_samples is not None:
-        train_dataset = train_dataset.select(range(data_args.max_train_samples))
+    tokenized_datasets = dataset.map(tokenize_function, batched=True)
 
-
-
-
-    eval_dataset = lm_datasets["validation"]
-    if data_args.max_eval_samples is not None:
-        eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+    train_dataset = tokenized_datasets
+    if data_args.max_train_samples is not None:
+        train_dataset = train_dataset.select(range(data_args.max_train_samples))
 
     # Enable tensorboard only on the master node
     has_tensorboard = is_tensorboard_available()
@@ -507,13 +375,13 @@ def main():
     # Store some constant
    num_epochs = int(training_args.num_train_epochs)
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
-    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
-    steps_per_epoch = len(train_dataset) // train_batch_size
+    train_dataset_len = 97876602
+    steps_per_epoch = train_dataset_len // train_batch_size
     total_train_steps = steps_per_epoch * num_epochs
 
     # Create learning rate schedule
     linear_decay_lr_schedule_fn = create_learning_rate_fn(
-        len(train_dataset),
+        train_dataset_len,
         train_batch_size,
         training_args.num_train_epochs,
         training_args.warmup_steps,
@@ -602,26 +470,14 @@ def main():
 
         return new_state, metrics
 
-    # Define eval fn
-    def eval_step(params, rng, batch):
-        labels = batch.pop("labels")
-        logits, latent_codes = model(**batch, params=params, train=False)[:2]
-        loss = loss_fn(logits, labels, latent_codes, rng)
-
-        # summarize metrics
-        metrics = {"loss": loss}
-        metrics = jax.lax.pmean(metrics, axis_name="batch")
-        return metrics
-
-    # Create parallel version of the train and eval step
+    # Create parallel version of the train step
     p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
-    p_eval_step = jax.pmap(eval_step, "batch")
 
     # Replicate the train state on each device
     state = state.replicate()
 
     logger.info("***** Running training *****")
-    logger.info(f" Num examples = {len(train_dataset)}")
+    logger.info(f" Num examples = {train_dataset_len}")
     logger.info(f" Num Epochs = {num_epochs}")
     logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
     logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
@@ -638,15 +494,15 @@ def main():
         rng, input_rng = jax.random.split(rng)
 
         # Generate an epoch by shuffling sampling indices from the train dataset
-        train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
-        steps_per_epoch = len(train_dataset) // train_batch_size
+        train_loader = data_loader(input_rng, train_dataset, train_batch_size)
+        steps_per_epoch = train_dataset_len // train_batch_size
         # train
         for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
             batch = next(train_loader)
             state, train_metric = p_train_step(state, batch)
             train_metrics.append(train_metric)
 
-            cur_step = epoch * (len(train_dataset) // train_batch_size) + step
+            cur_step = epoch * (train_dataset_len // train_batch_size) + step
 
             if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                 # Save metrics
@@ -661,36 +517,6 @@ def main():
 
                 train_metrics = []
 
-            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
-                # ======================== Evaluating ==============================
-                eval_metrics = []
-                eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
-                eval_steps = len(eval_dataset) // eval_batch_size
-                for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
-                    # Model forward
-                    batch = next(eval_loader)
-                    metrics = p_eval_step(state.params, state.dropout_rng, batch)
-                    eval_metrics.append(metrics)
-
-                # normalize eval metrics
-                eval_metrics = get_metrics(eval_metrics)
-                eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
-
-                try:
-                    eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
-                except OverflowError:
-                    eval_metrics["perplexity"] = float("inf")
-
-                # Print metrics and update progress bar
-                desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
-                epochs.write(desc)
-                epochs.desc = desc
-
-                # Save metrics
-                if has_tensorboard and jax.process_index() == 0:
-                    cur_step = epoch * (len(train_dataset) // train_batch_size)
-                    write_eval_metric(summary_writer, eval_metrics, cur_step)
-
             if cur_step % training_args.save_steps == 0 and cur_step > 0:
                 # save checkpoint after each epoch and push checkpoint to the hub
                 if jax.process_index() == 0:
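The rewritten data_loader in train.py streams rows instead of indexing an in-memory Dataset: it buffers batch_size rows, stacks each column, shards the result across local devices, and silently drops the trailing partial batch; the rng argument is kept for signature compatibility but no longer shuffles anything. A self-contained sketch of that behaviour with toy rows (the names below are illustrative, not from the commit):

import jax
import jax.numpy as jnp
from flax.training.common_utils import shard

def data_loader(rng, dataset, batch_size):
    # mirrors the new loader in train.py: buffer rows, stack columns, shard, repeat
    batch = []
    for row in dataset:
        batch.append(row)
        if len(batch) >= batch_size:
            stacked = {k: jnp.stack([r[k] for r in batch]) for k in batch[0].keys()}
            yield shard(stacked)
            batch = []

toy_rows = ({'input_ids': jnp.arange(8)} for _ in range(64))
batch_size = 2 * jax.local_device_count()  # shard needs a multiple of the device count
first_batch = next(data_loader(None, toy_rows, batch_size))
assert first_batch['input_ids'].shape == (jax.local_device_count(), 2, 8)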
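The new tokenize_function collapses the old clip_texts / add_decoder_input_ids pipeline into a single map call: texts are padded and truncated to block_size in one tokenizer call, pad positions in labels are masked to -100, and decoder inputs are built by appending one pad column and calling T5's shift_tokens_right. A tiny illustration of that last step (the token ids below are made up; the real script uses its tokenizer's pad id and the config's decoder start id):

import jax.numpy as jnp
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right

pad_token_id, start_token_id = 0, 0  # illustrative values only
input_ids = jnp.array([[11, 12, 13, pad_token_id]])             # one padded row
pad_col = pad_token_id * jnp.ones((input_ids.shape[0], 1), dtype=jnp.int32)
padded = jnp.concatenate((input_ids, pad_col), axis=1)          # length block_size + 1
decoder_input_ids = shift_tokens_right(padded, pad_token_id, start_token_id)
# everything moves one step right and the start token is prepended:
# [[ 0 11 12 13  0]]
print(decoder_input_ids)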
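Because the corpus is now loaded with streaming=True it has no len(), so train.py hard-codes train_dataset_len = 97876602 and derives steps_per_epoch and the learning-rate schedule from it. A sketch of how that constant can be re-checked from the text shards (the 'text' loader yields one example per line):

total = 0
for i in range(298):
    with open(f'wikipedia/{i}.txt', encoding='utf-8') as f:
        total += sum(1 for _ in f)

print(total)  # expected to equal the hard-coded 97876602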
train.sh
ADDED
@@ -0,0 +1,22 @@
+export RUN_NAME=single_latent
+
+# TODO update to not use tokenizer, instead use gpt2 one
+./venv/bin/python train.py \
+    --t5_model_name_or_path="t5-base" \
+    --output_dir="output/${RUN_NAME}" \
+    --overwrite_output_dir \
+    --do_train \
+    --n_latent_tokens 1 \
+    --latent_token_size 32 \
+    --save_steps="2000" \
+    --block_size="128" \
+    --per_device_train_batch_size="100" \
+    --train_file="INVALID.txt" \
+    --overwrite_output_dir \
+    --num_train_epochs="1" \
+
+# 200 batch size, 128 sequence len: ? (breaks)
+# 100 batch size, 128 sequence len: 252:38:58
+# 10 batch size, 128 sequence len: 281:32:53
+
+# Got ~12 hours to train, want 3 saves so one save every 4 hours
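The comments at the bottom of train.sh record projected epoch times and a roughly 12-hour budget with three checkpoints. A back-of-the-envelope sketch of where --save_steps="2000" comes from, assuming 8 local devices (for example a TPU v3-8; the device count is not stated in the commit) and reusing the dataset size hard-coded in train.py:

train_dataset_len = 97_876_602        # from train.py
per_device_batch = 100                # --per_device_train_batch_size
devices = 8                           # assumption, not from the commit
train_batch_size = per_device_batch * devices

steps_per_epoch = train_dataset_len // train_batch_size        # roughly 122k steps
projected_epoch_hours = 252 + 39 / 60                          # the "252:38:58" estimate
steps_in_12_hours = steps_per_epoch * 12 / projected_epoch_hours
print(round(steps_in_12_hours / 3))   # roughly 1900-2000 steps between saves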
wiki_sentences.py
ADDED
@@ -0,0 +1,46 @@
+# unused
+"""Wikipedia Sentences"""
+
+from __future__ import absolute_import, division, print_function
+
+import os
+import json
+
+import datasets
+
+
+_DESCRIPTION = """\
+Dataset of sentences from Wikipedia (from the [Optimus paper](https://arxiv.org/abs/2004.04092)).
+Each is of max 64 words & <=256 GPT2 tokens.
+Each row is a tokenised sentence.
+{'token_ids': '{gpt2 token ids}'}
+This is to test the semantics of a Transformer-VAE's latent space by interpolating on sentences.
+"""
+
+NUM_SEGMENTS = 5
+DOWNLOAD_URLS = 'https://drive.google.com/file/d/13NnkYAhwszQxc1C5HHfThnF7c1cjzjAD/view?usp=sharing, https://drive.google.com/file/d/14p6FHip_hGTXC-_7SYaK32BpEhZRDJI4/view?usp=sharing, https://drive.google.com/file/d/1IaRfTFh51Wf_zPtK6tjE6xw-up_Z6EyN/view?usp=sharing, https://drive.google.com/file/d/1KGhV397Xfej56uJ9H10xD7tfLdhWlg4q/view?usp=sharing, https://drive.google.com/file/d/1LfsQ1s9wr1mBG3I1bbvnbyrYmnsrXxZt/view?usp=sharing, https://drive.google.com/file/d/1OctFe_JPR0Ajh77FzWdfeYnWZinKl2sW/view?usp=sharing, https://drive.google.com/file/d/1W-Yi8gHCcT8O5F4TcDHScH7pOb0GQZdu/view?usp=sharing, https://drive.google.com/file/d/1jgHjnpe7Vk1pvRgfnH4S4KiRrpUQyqyp/view?usp=sharing, https://drive.google.com/file/d/1oVst8RG8G2d21DL6q4DwO7aJxE1vA2fc/view?usp=sharing, https://drive.google.com/file/d/1qwckIM8YBbU9bnArB6bAoStY3e9I1kqU/view?usp=sharing'.split(', ')
+
+
+class WikiSentences(datasets.GeneratorBasedBuilder):
+    """Sentences from Wikipedia."""
+
+    BUILDER_CONFIGS = [datasets.BuilderConfig(name="main", description="Run through json files one by one.",)]
+
+    def _info(self):
+        return datasets.DatasetInfo(
+            description=_DESCRIPTION,
+            features=datasets.Features(
+                {
+                    'token_ids': [datasets.Value("int32")],
+                }
+            ),
+            homepage="https://github.com/Fraser-Greenlee/transformer-vae",
+        )
+
+    def _generate_examples(self, filepath):
+        """Generate examples."""
+        with open(filepath, encoding="utf-8") as json_lines_file:
+            for id_, line in enumerate(json_lines_file):
+                yield id_, json.loads(line)
+                if id_ >= self.config.max_num_samples:
+                    break
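wiki_sentences.py is marked unused and is not wired into train.py. As committed it also reads self.config.max_num_samples, which the stock datasets.BuilderConfig does not define, and it has no _split_generators, so it cannot be loaded as-is. A sketch of a config subclass that would carry the missing field (hypothetical, not part of the commit):

import datasets

class WikiSentencesConfig(datasets.BuilderConfig):
    """BuilderConfig carrying the max_num_samples cap used by _generate_examples."""

    def __init__(self, max_num_samples=None, **kwargs):
        super().__init__(**kwargs)
        self.max_num_samples = max_num_samples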