Try avoid hf hub git rate limits
- config.gin +2 -3
- config.json +1 -1
- small_nl24_pretrain.gin +2 -3
- start_train.sh +2 -1
- tasks.py +70 -37
- train.py +0 -689
- train/events.out.tfevents.1649073594.t1v-n-304587cf-w-0.1316481.0.v2 +0 -3
- train/events.out.tfevents.1649092520.t1v-n-304587cf-w-0.1399566.0.v2 +0 -3
- train/{events.out.tfevents.1649056216.t1v-n-304587cf-w-0.1239745.0.v2 → events.out.tfevents.1649705066.t1v-n-304587cf-w-0.2549834.0.v2} +2 -2
- training_eval/pretrain_finnish/events.out.tfevents.1649073594.t1v-n-304587cf-w-0.1316481.1.v2 +0 -3
- training_eval/pretrain_finnish/events.out.tfevents.1649092520.t1v-n-304587cf-w-0.1399566.1.v2 +0 -3
- training_eval/pretrain_finnish/{events.out.tfevents.1649056216.t1v-n-304587cf-w-0.1239745.1.v2 → events.out.tfevents.1649705066.t1v-n-304587cf-w-0.2549834.1.v2} +2 -2
config.gin
CHANGED

@@ -12,7 +12,7 @@ import tasks
 
 # Macros:
 # ==============================================================================
-BATCH_SIZE =
+BATCH_SIZE = 256
 DROPOUT_RATE = 0.0
 LABEL_SMOOTHING = 0.0
 LOSS_NORMALIZING_FACTOR = None
@@ -23,7 +23,7 @@ MODEL_DIR = '/researchdisk/t5x-small-nl24-finnish'
 OPTIMIZER = @adafactor.Adafactor()
 RANDOM_SEED = None
 SHUFFLE_TRAIN_EXAMPLES = True
-TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets':
+TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 512}
 TRAIN_STEPS = 500000
 USE_CACHED_TASKS = False
 USE_HARDWARE_RNG = False
@@ -123,7 +123,6 @@ network.T5Config.vocab_size = 32128
 train_script.train.checkpoint_cfg = @utils.CheckpointConfig()
 train_script.train.eval_period = 10000
 train_script.train.eval_steps = 20
-train_script.train.hub_model_id = 'Finnish-NLP/t5x-small-nl24-finnish'
 train_script.train.infer_eval_dataset_cfg = None
 train_script.train.model = %MODEL
 train_script.train.model_dir = %MODEL_DIR
config.json
CHANGED

@@ -7,7 +7,7 @@
   "d_kv": 64,
   "d_model": 512,
   "decoder_start_token_id": 0,
-  "dropout_rate": 0.
+  "dropout_rate": 0.1,
   "eos_token_id": 1,
   "feed_forward_proj": "gated-gelu",
   "initializer_factor": 1.0,
small_nl24_pretrain.gin
CHANGED

@@ -11,7 +11,6 @@ include 't5x/configs/runs/pretrain.gin'
 # ------------------- Training specification overrides --------------------------
 train_script.train:
   eval_period = 10000
-  hub_model_id = "Finnish-NLP/t5x-small-nl24-finnish"
 
 utils.SaveCheckpointConfig:
   period = 10000
@@ -19,7 +18,7 @@ utils.SaveCheckpointConfig:
 
 MIXTURE_OR_TASK_NAME = "pretrain_finnish"
 USE_CACHED_TASKS = False
-TASK_FEATURE_LENGTHS = {"inputs": 512, "targets":
+TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
 TRAIN_STEPS = 500000
 DROPOUT_RATE = 0.0
-BATCH_SIZE =
+BATCH_SIZE = 256
start_train.sh
CHANGED

@@ -2,10 +2,11 @@
 unset LD_PRELOAD
 
 PROJECT_DIR="/researchdisk/t5x-small-nl24-finnish"
+T5X_DIR=${HOME}"/t5x"  # directory where the t5x is cloned.
 MODEL_DIR="/researchdisk/t5x-small-nl24-finnish"
 export PYTHONPATH=${PROJECT_DIR}
 
-python3 train.py \
+python3 ${T5X_DIR}/t5x/train.py \
   --gin_search_paths=${PROJECT_DIR} \
   --gin_file="small_nl24_pretrain.gin" \
   --gin.MODEL_DIR=\"${MODEL_DIR}\"
tasks.py
CHANGED

@@ -1,49 +1,82 @@
+# adapted from https://huggingface.co/pere/pk-nb-t5x/blob/main/tasks.py
+
 import functools
+
 import seqio
+import tensorflow as tf
+import t5.data
+from datasets import load_dataset, load_from_disk
+from t5.data import postprocessors
 from t5.data import preprocessors
+from t5.evaluation import metrics
+from seqio import FunctionDataSource, utils
 
+TaskRegistry = seqio.TaskRegistry
+
+vocabulary = seqio.SentencePieceVocabulary('spiece.model', extra_ids=0)
+
+DEFAULT_OUTPUT_FEATURES = {
+    "inputs": seqio.Feature(
+        vocabulary=vocabulary, add_eos=True,
+        required=False),
+    "targets": seqio.Feature(
+        vocabulary=vocabulary, add_eos=True)
 }
 
+
+def gen_dataset(split, shuffle=False, seed=None, column="text", dataset=None):
+    if shuffle:
+        if seed:
+            dataset = dataset.shuffle(seed=seed)
+        else:
+            dataset = dataset.shuffle()
+    while True:
+        for item in dataset[str(split)]:
+            yield item[column]
+
+
+def dataset_fn(split, shuffle_files, seed=None, dataset=None):
+    return tf.data.Dataset.from_generator(
+        functools.partial(gen_dataset, split, shuffle_files, seed, dataset=dataset),
+        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name)
+    )
+
+
+@utils.map_over_dataset
+def target_to_key(x, key_map, target_key):
+    """Assign the value from the dataset to target_key in key_map"""
+    return {**key_map, target_key: x}
+
+
+# Final pretraining task used in Raffel et al., 2019 adaptated to NCC
+dataset_name = "/researchdisk/lm_training_dataset_full"
+dataset_params = {"from_disk_path": dataset_name}
+
+if "from_disk_path" in dataset_params:
+    dataset = load_from_disk(dataset_params.get("from_disk_path"))
+else:
+    dataset = load_dataset(**dataset_params)
+
+dataset_shapes = {"train": dataset["train"].num_rows, "validation": dataset["validation"].num_rows}
+TaskRegistry.add(
+    "pretrain_finnish",
+    source=seqio.FunctionDataSource(
+        dataset_fn=functools.partial(dataset_fn, dataset=dataset),
+        splits=("train", "validation"),
+        caching_permitted=False,
+        num_input_examples=dataset_shapes,
+    ),
     preprocessors=[
         functools.partial(
-            field_names=["text"],
-            field_delim="\n"),
-        functools.partial(
-            preprocessors.rekey, key_map={
+            target_to_key, key_map={
                 "inputs": None,
-                "targets":
-            }),
+                "targets": None,
+            }, target_key="targets"),
         seqio.preprocessors.tokenize,
-        seqio.CacheDatasetPlaceholder(),
-        preprocessors.span_corruption,
+        # seqio.CacheDatasetPlaceholder(),
+        preprocessors.span_corruption,
         seqio.preprocessors.append_eos_after_trim,
     ],
-
-# dataset = seqio.get_mixture_or_task("pretrain_finnish").get_dataset(
-#     sequence_length={"inputs": 512, "targets": 114},
-#     split="train",
-#     shuffle=True,
-#     num_epochs=1,
-#     #shard_info=seqio.ShardInfo(index=0, num_shards=10),
-#     use_cached=False,
-#     seed=42
-# )
-
-# # Print the first 5 examples.
-# for _, ex in zip(range(5), dataset.as_numpy_iterator()):
-#     print(ex)
+    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
+    metric_fns=[metrics.accuracy]
+)
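For reference, the debugging snippet dropped from the end of tasks.py is still handy for sanity-checking the new task definition. A minimal sketch, assuming it is run from the repository root (so that spiece.model and /researchdisk/lm_training_dataset_full resolve) and using the new 512/512 feature lengths in place of the old targets value:

import seqio
import tasks  # noqa: F401 -- importing registers the "pretrain_finnish" task

# Build a small shuffled sample of the span-corruption pretraining data.
dataset = seqio.get_mixture_or_task("pretrain_finnish").get_dataset(
    sequence_length={"inputs": 512, "targets": 512},
    split="train",
    shuffle=True,
    num_epochs=1,
    use_cached=False,
    seed=42,
)

# Print the first 5 examples to verify tokenization and span corruption.
for _, ex in zip(range(5), dataset.as_numpy_iterator()):
    print(ex)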
train.py
DELETED

@@ -1,689 +0,0 @@
-# Copyright 2022 The T5X Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-r"""Script to pretrain or finetune in JAX using a SeqIO pipeline.
-
-"""
-import functools
-import itertools
-import math
-import os
-import time
-from typing import Callable, Iterator, Sequence, Mapping, Tuple, Type, Optional
-import subprocess
-
-# Set Linen to add profiling information when constructing Modules.
-# Must be set before flax imports.
-# pylint:disable=g-import-not-at-top
-os.environ['FLAX_PROFILE'] = 'true'
-# TODO(adarob): Re-enable once users are notified and tests are updated.
-os.environ['FLAX_LAZY_RNG'] = 'no'
-from absl import logging
-from clu import metric_writers
-import jax
-from jax import random
-from jax.experimental import multihost_utils
-import jax.numpy as jnp
-import numpy as np
-import seqio
-from t5x import models
-from t5x import partitioning
-from t5x import train_state as train_state_lib
-from t5x import trainer as trainer_lib
-from t5x import utils
-from t5x import checkpoint_importer
-LazyArray = checkpoint_importer.LazyArray
-import tensorflow as tf
-
-
-# Automatically search for gin files relative to the T5X package.
-_DEFAULT_GIN_SEARCH_PATHS = [
-    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-]
-PyTreeDef = type(jax.tree_structure(None))
-P = partitioning.PartitionSpec
-# Special key that used to distinguish train metrics.
-TRAIN_METRIC_KEY = 'train'
-# String keys that is acceptable from config.
-_ACTION_KEYS = frozenset(trainer_lib.ActionMode.__members__.keys())
-
-
-def run_actions(
-    mode: trainer_lib.ActionMode, actions: trainer_lib.ActionMapType,
-    train_state: train_state_lib.TrainState,
-    metrics_by_task: Mapping[str, trainer_lib.MetricValueMapType]) -> bool:
-  """Invokes all actions on the given mode on host 0, then broadcasts to all.
-
-  Args:
-    mode: The mode to run the actions. e.g., if mode is `train`, only actions
-      configured to run with `train` mode will be invoked.
-    actions: A mapping of actions that runs after train, eval or infer_eval, to
-      inspect the model and perform useful operations, e.g., early stopping.
-    train_state: The current train_state of the trainer.
-    metrics_by_task: A map of metrics keyed by task name.
-
-  Returns:
-    A bool indicating whether training should be halted.
-
-  Raises:
-    RuntimeError: When the metrics processed on host 0 is None.
-  """
-  stop_training = False
-  if jax.process_index() == 0:
-    if not metrics_by_task:
-      raise RuntimeError('Metric is unexpectedly empty on process 0')
-    for action in actions.get(mode, []):
-      stop_training |= action.run(train_state, metrics_by_task=metrics_by_task)
-  # Broadcast result from host 0 to others.
-  return bool(multihost_utils.broadcast_one_to_all(jnp.array(stop_training)))
-
-
-def train(
-    *,
-    model: models.BaseTransformerModel,
-    train_dataset_cfg: utils.DatasetConfig,
-    train_eval_dataset_cfg: Optional[utils.DatasetConfig],
-    infer_eval_dataset_cfg: Optional[utils.DatasetConfig],
-    checkpoint_cfg: utils.CheckpointConfig,
-    partitioner: partitioning.BasePartitioner,
-    trainer_cls: Type[trainer_lib.BaseTrainer],
-    model_dir: str,
-    total_steps: int,
-    eval_steps: int,
-    eval_period: int,
-    stats_period: Optional[int] = None,
-    random_seed: Optional[int],
-    use_hardware_rng: bool = False,
-    summarize_config_fn: Callable[[str, metric_writers.MetricWriter, int],
-                                  None],
-    inference_evaluator_cls: Type[seqio.Evaluator] = seqio.Evaluator,
-    get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset,
-    concurrent_metrics: bool = True,
-    actions: Optional[Mapping[str, Sequence[trainer_lib.BaseAction]]] = None,
-    train_eval_get_dataset_fn: Optional[utils.GetDatasetCallable] = None,
-    run_eval_before_training: bool = False,
-    hub_model_id: str = None,
-) -> Tuple[int, train_state_lib.TrainState]:
-  """Train function.
-
-  Args:
-    model: The model object to use for training.
-    train_dataset_cfg: Specification for the dataset to train with.
-    train_eval_dataset_cfg: Specification for the dataset to evaluate with using
-      the train metrics and no inference (e.g., uses teacher forcing). If None,
-      train eval is disabled.
-    infer_eval_dataset_cfg: Specification for the dataset to evaluate with using
-      the inference metrics (e.g., uses sampled decoding). If None, inference
-      eval is disabled.
-    checkpoint_cfg: Specification for saving and restoring model parameters and
-      dataset state to/from checkpoints.
-    partitioner: Partitioner for model parameters and data across devices.
-    trainer_cls: An implementation of BaseTrainer.
-    model_dir: Path of directory to store checkpoints and metric summaries.
-    total_steps: The step number to stop training after. The number of actual
-      steps trained in this run will be this number minus the starting step from
-      the checkpoint.
-    eval_steps: The number of batches to process for each train-eval loop.
-    eval_period: The number of train steps between each evaluation (both
-      train-eval and infer-eval).
-    stats_period: The number of train steps between writing scalar stats. If
-      None, defaults to eval_period.
-    random_seed: A random seed to use for dropout and initialization. If None, a
-      fast, non-deterministic hardware-based RNG is used.
-    use_hardware_rng: Whether to force using the RngBitGenerator based hardware
-      rng, which takes seeds and acts similarly to software PRNG in that it
-      should be seed-deterministic. The new RngBitGenerator custom PRNG system
-      should be reproducible for a given sharding, but the numbers will change
-      for different shardings of the same model.
-    summarize_config_fn: A function that takes in the model directory, a
-      SummaryWriter, and the step number, and writes a summary of the
-    inference_evaluator_cls: seqio.Evaluator class to use for inference
-      evaluation, potentially with bound configuration args.
-    get_dataset_fn: The callable use to get the train and train-eval datasets
-      based on the DatasetConfig and shard information.
-    concurrent_metrics: If True, allow metrics computation and logging to
-      overlap with training. Will likely result in additional TPU memory usage.
-    actions: A mapping of actions that runs after train, eval or infer_eval, to
-      inspect the model and perform useful operations, e.g., early stopping. The
-      key must have a 1:1 mapping to ActionMode enum. For EVAL actions to
-      actually work, this requires `concurrent_metrics` to be turned off,
-      since chaining futures and mutating states concurrently might be
-      error-prone.
-    train_eval_get_dataset_fn: Optional callable use to get the train-eval
-      datasets based on the DatasetConfig and shard information. If missing, it
-      defaults to `get_dataset_fn`.
-    run_eval_before_training: If True, calculate training eval and inference
-      eval metrics before training begins.
-
-  Returns:
-    The tuple of (last_step, last_train_state).
-  """
-  logging.info('Process ID: %d', jax.process_index())
-  tf.io.gfile.makedirs(model_dir)
-
-  # Each "epoch" of the training loop should be the min of the eval period,
-  # checkpoint period or the full training.
-  # We compute here to ensure that the eval period and checkpoint period are
-  # divisible by this number, otherwise we fail.
-  eval_enabled = (train_eval_dataset_cfg or infer_eval_dataset_cfg)
-  eval_period = eval_period if eval_enabled else 0
-  checkpoint_period = checkpoint_cfg.save.period if checkpoint_cfg.save else 0
-  if eval_period or checkpoint_period:
-    steps_per_epoch = min(eval_period or np.inf, checkpoint_period or np.inf)
-  else:
-    steps_per_epoch = total_steps
-  stats_period = stats_period or steps_per_epoch
-  if (eval_period and eval_period % steps_per_epoch or
-      checkpoint_period and checkpoint_period % steps_per_epoch):
-    raise ValueError(
-        f'Checkpoint period ({checkpoint_period}) must evenly divide eval '
-        f'period ({eval_period}), or vice-versa.')
-
-  if use_hardware_rng or random_seed is None:
-    logging.info(
-        'Using fast RngBitGenerator PRNG for initialization and dropout.')
-
-    if random_seed is None:
-      random_seed = multihost_utils.broadcast_one_to_all(np.int32(time.time()))
-      logging.info('Random seed not provided, using RNG seed %s', random_seed)
-    else:
-      logging.warning(
-          'When using hardware RNG with a fixed seed, repeatability is only '
-          'guaranteed for fixed hardware and partitioning schemes and for a '
-          'fixed version of this code and its dependencies.')
-    utils.set_hardware_rng_ops()
-    rng = random.PRNGKey(random_seed)
-  else:
-    logging.info('Using seed for initialization and dropout RNG: %d',
-                 random_seed)
-    rng = random.PRNGKey(random_seed)
-
-  init_rng, trainer_rng = random.split(rng, 2)
-
-  # ---------------------------------------------------------------------------
-  # Initialize datasets
-  # ---------------------------------------------------------------------------
-
-  if (train_dataset_cfg.seed and
-      not (checkpoint_cfg.save or checkpoint_cfg.save.save_dataset)):
-    logging.warning(
-        'Providing a random seed for the train dataset with '
-        '`checkpoint_train_ds=False` is dangerous since each '
-        'preemption/restart will cause the dataset to deterministically replay '
-        'from the beginning.')
-
-  data_layout = partitioner.get_data_layout(train_dataset_cfg.batch_size)
-  ds_shard_id = data_layout.shard_id
-  num_ds_shards = data_layout.num_shards
-
-  def _verify_matching_vocabs(cfg: utils.DatasetConfig):
-    ds_vocabs = utils.get_vocabulary(cfg)
-    if (ds_vocabs[0] != model.input_vocabulary or
-        ds_vocabs[1] != model.output_vocabulary):
-      raise ValueError(f'Model and Task vocabularies do not match:\n'
-                       f'  task={cfg.mixture_or_task_name}\n'
-                       f'  ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n'
-                       f'  model.input_vocabulary={model.input_vocabulary}\n'
-                       f'  model.output_vocabulary={model.output_vocabulary}\n')
-
-  _verify_matching_vocabs(train_dataset_cfg)
-
-  train_ds = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
-                            model.FEATURE_CONVERTER_CLS)
-
-  if train_eval_dataset_cfg:
-    _verify_matching_vocabs(train_eval_dataset_cfg)
-    train_eval_datasets = utils.get_training_eval_datasets(
-        train_eval_dataset_cfg,
-        ds_shard_id,
-        num_ds_shards,
-        eval_steps,
-        model.FEATURE_CONVERTER_CLS,
-        get_dataset_fn=train_eval_get_dataset_fn if train_eval_get_dataset_fn
-        is not None else get_dataset_fn)  # type: Mapping[str, tf.data.Dataset]
-    if not train_eval_datasets:
-      logging.warning(
-          'No train_eval datasets loaded from config `train_eval_dataset_cfg`: '
-          '%s', train_eval_dataset_cfg)
-  else:
-    train_eval_datasets = {}
-
-  # Initialize optimizer, maybe from an existing checkpoint.
-  checkpointable_train_iter: tf.data.Iterator = iter(train_ds)  # pytype:disable=annotation-type-mismatch
-  train_iter: Iterator[trainer_lib.BatchType] = map(
-      lambda x: jax.tree_map(np.array, x), checkpointable_train_iter)
-
-  # The manner in which parameters are initialized follows this order of
-  # preference:
-  #  1. From a T5X checkpoint in `model_dir`, if one exists.
-  #  2. From a T5X or TF checkpoint specified by `cfg.path`, if set.
-  #  3. From scratch using `init_fn`.
-
-  # 1. From a T5X checkpoint in `model_dir`, if one exists.
-  if checkpoint_cfg.restore is not None:
-    state_transforms_for_restore = [
-        functools.partial(fn, is_resuming=True)
-        for fn in checkpoint_cfg.restore.state_transformation_fns
-    ]
-  else:
-    state_transforms_for_restore = []
-  restore_cfgs = [
-      utils.RestoreCheckpointConfig(
-          path=model_dir,
-          mode='latest',
-          dtype=checkpoint_cfg.save.dtype,
-          checkpointer_cls=checkpoint_cfg.save.checkpointer_cls,
-          # Restore dataset state if it is being saved.
-          restore_dataset=(checkpoint_cfg.save and
-                           checkpoint_cfg.save.save_dataset),
-          state_transformation_fns=state_transforms_for_restore)
-  ]
-  # 2. From a checkpoint specified by `checkpoint_cfg.restore.path`, if set.
-  if checkpoint_cfg.restore:
-    if checkpoint_cfg.restore.mode == 'all':
-      raise ValueError(
-          "Restore checkpoint mode 'all' is not supported in training.")
-
-    # TODO(dhgarrette): Split "restore" behavior into separate configurations
-    # for the initial restoration for a new run, vs resuming a stopped run.
-    if isinstance(checkpoint_cfg.restore.path, str):
-      restore_cfgs.append(checkpoint_cfg.restore)
-    elif not checkpoint_cfg.restore.path:
-      # `path` is an empty (non-`str`) sequence, so there is nothing to restore.
-      pass
-    else:
-      raise ValueError(
-          'Restore checkpoint config may only have a single path in training.')
-
-  # Need to use full batch size.
-  input_shapes = {
-      k: (data_layout.batch_size, *v.shape[1:])
-      for k, v in train_ds.element_spec.items()
-  }
-  input_types = {
-      k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items()
-  }
-  init_or_restore_tick = time.time()
-  train_state_initializer = utils.TrainStateInitializer(
-      optimizer_def=model.optimizer_def,
-      init_fn=model.get_initial_variables,
-      input_shapes=input_shapes,
-      input_types=input_types,
-      partitioner=partitioner)
-  # 3. From scratch using `init_fn`.
-  train_state = train_state_initializer.from_checkpoint_or_scratch(
-      restore_cfgs, init_rng=init_rng, ds_iter=checkpointable_train_iter)
-  train_state_axes = train_state_initializer.train_state_axes
-  init_or_restore_secs = time.time() - init_or_restore_tick
-  logging.info('Initialize/restore complete (%.2f seconds).',
-               init_or_restore_secs)
-
-  # Log the variable shapes information and write to a file.
-  log_file = os.path.join(model_dir, 'model-info.txt')
-  utils.log_model_info(log_file,
-                       train_state_initializer.global_train_state_shape,
-                       partitioner)
-
-  if checkpoint_period:
-    checkpointer = checkpoint_cfg.save.checkpointer_cls(
-        train_state=train_state_initializer.global_train_state_shape,
-        partitioner=partitioner,
-        checkpoints_dir=model_dir,
-        dataset_iterator=(checkpointable_train_iter
-                          if checkpoint_cfg.save.save_dataset else None),
-        save_dtype=checkpoint_cfg.save.dtype,
-        keep=checkpoint_cfg.save.keep)
-
-
-  # Restore step from last checkpoint or set to 0 if training from scratch.
-  host_step = int(train_state.step)
-
-  # ---------------------------------------------------------------------------
-  # Trainer
-  # ---------------------------------------------------------------------------
-
-  trainer: trainer_lib.BaseTrainer = trainer_cls(
-      model=model,
-      train_state=train_state,
-      partitioner=partitioner,
-      train_state_axes=train_state_axes,
-      eval_names=train_eval_datasets.keys(),
-      summary_dir=model_dir,
-      rng=trainer_rng)
-  del train_state
-
-  train_metrics = trainer.train_metrics_manager
-  summarize_config_fn(model_dir, train_metrics.summary_writer, host_step)
-
-  train_metrics.write_scalar('timing/init_or_restore_seconds',
-                             init_or_restore_secs, host_step)
-
-  # ----------------------------------------------------------------------------
-  # SeqIO (inference-based) evaluation setup
-  # ----------------------------------------------------------------------------
-  # Init evaluator to set up cached datasets
-  evaluator = None
-  if infer_eval_dataset_cfg is not None:
-    _verify_matching_vocabs(infer_eval_dataset_cfg)
-    evaluator = inference_evaluator_cls(
-        log_dir=os.path.join(model_dir, 'inference_eval'),
-        mixture_or_task_name=infer_eval_dataset_cfg.mixture_or_task_name,
-        feature_converter=model.FEATURE_CONVERTER_CLS(pack=False),
-        eval_split=infer_eval_dataset_cfg.split,
-        use_cached=infer_eval_dataset_cfg.use_cached,
-        seed=infer_eval_dataset_cfg.seed,
-        sequence_length=infer_eval_dataset_cfg.task_feature_lengths,
-        use_memory_cache=infer_eval_dataset_cfg.use_memory_cache)
-    if not evaluator.eval_tasks:
-      # Skip evaluaton.
-      evaluator = None
-
-  if evaluator is not None:
-    predict_fn = utils.get_infer_fn(
-        infer_step=model.predict_batch,
-        batch_size=infer_eval_dataset_cfg.batch_size,
-        train_state_axes=train_state_axes,
-        partitioner=partitioner)
-
-    score_fn = utils.get_infer_fn(
-        infer_step=model.score_batch,
-        batch_size=infer_eval_dataset_cfg.batch_size,
-        train_state_axes=train_state_axes,
-        partitioner=partitioner)
-
-  if actions is None:
-    actions = {}
-
-  if set(actions.keys()).difference(_ACTION_KEYS):
-    raise ValueError(f'actions keys must be one of {_ACTION_KEYS}, but got : '
-                     f'{actions.keys()}')
-
-  # Transform the string key into proper ActionMode enum.
-  actions = {trainer_lib.ActionMode[k]: v for k, v in actions.items()}
-
-  if concurrent_metrics and actions.get(trainer_lib.ActionMode.INFER_EVAL,
-                                        None) is not None:
-    logging.warning('Actions for INFER_EVAL will not be triggered when async '
-                    'metrics computation is enabled')
-  if concurrent_metrics and actions.get(trainer_lib.ActionMode.TRAIN,
-                                        None) is not None:
-    logging.warning('Actions for TRAIN will not be triggered when async '
-                    'metrics computation is enabled')
-
-  # ----------------------------------------------------------------------------
-  # Setup Eval Utility Functions
-  # ----------------------------------------------------------------------------
-  def _run_training_eval(first_run: bool = False):
-    if first_run:
-      logging.info('Compiling training eval loop.')
-      trainer.compile_eval({
-          task: utils.get_zeros_batch_like_dataset(ds)
-          for task, ds in train_eval_datasets.items()
-      })
-    logging.info('Computing training evaluation metrics.')
-    eval_batch_iters = {
-        task: ds.as_numpy_iterator()
-        for task, ds in train_eval_datasets.items()
-    }
-    eval_summaries = trainer.eval(eval_batch_iters)
-    trainer.stop_training = run_actions(trainer_lib.ActionMode.TRAIN_EVAL,
-                                        actions, trainer.train_state,
-                                        eval_summaries)
-
-  def _run_inference_eval():
-    """Run prediction based inference eval."""
-    if evaluator is None:
-      return
-    logging.info('Running inference evaluation.')
-    evaluate_tick = time.time()
-    all_metrics, _, _ = evaluator.evaluate(
-        compute_metrics=jax.process_index() == 0,
-        step=host_step,
-        predict_fn=functools.partial(
-            predict_fn,
-            train_state=trainer.train_state,
-            rng=jax.random.PRNGKey(0)),
-        score_fn=functools.partial(score_fn, train_state=trainer.train_state))
-    if not concurrent_metrics:
-      # Ensure metrics are finished being computed.
-      all_metrics_done = all_metrics.result() or {}
-      trainer.stop_training = run_actions(trainer_lib.ActionMode.INFER_EVAL,
-                                          actions, trainer.train_state,
-                                          all_metrics_done)
-    train_metrics.write_scalar('timing/evaluate_seconds',
-                               time.time() - evaluate_tick, host_step)
-
-  # Optionally run teacher-forcing training eval and SeqIO inference-base eval
-  # before training. Useful for testing how much a model knows before any
-  # finetuning.
-  if run_eval_before_training:
-    if train_eval_datasets:
-      logging.info('Running training eval before training.')
-      _run_training_eval(first_run=True)
-    if evaluator is not None:
-      logging.info('Running inference eval before training.')
-      _run_inference_eval()
-
-  # ----------------------------------------------------------------------------
-  # Main training loop
-  # ----------------------------------------------------------------------------
-  logging.info('Starting training loop.')
-
-  first_step = host_step
-
-  if total_steps < first_step:
-    raise ValueError(
-        f'Unexpected total_steps ({total_steps}) < checkpoint step '
-        f' ({first_step}).')
-
-  logging.info('Starting main loop over steps %d-%d', first_step, total_steps)
-
-  steps_per_epoch = min(steps_per_epoch, total_steps)
-  first_epoch = first_step // steps_per_epoch
-  num_epochs = first_epoch + math.ceil(
-      (total_steps - first_step) / steps_per_epoch)
-  logging.info('Training with artificial "epochs" of %d steps.',
-               steps_per_epoch)
-
-  # Kickstart training dataset and compile train loop.
-  logging.info('Kickstarting train dataset prefetch.')
-  logging.flush()
-
-  ds_tick = time.time()
-  # Get first batch to warm up the dataset pipeline.
-  first_batch = next(train_iter)
-  # Prepend first batch back to iterator to be used by trainer.
-  train_iter = itertools.chain([first_batch], train_iter)
-  train_metrics.write_scalar('timing/dataset_warmup_seconds',
-                             time.time() - ds_tick, host_step)
-  logging.info('Compiling train loop.')
-  logging.flush()
-  trainer.compile_train(first_batch)
-
-  # Main Loop over "epochs".
-  for epoch in range(first_epoch, num_epochs):
-    final_epoch = epoch == num_epochs - 1
-    logging.info('Epoch %d of %d', epoch, num_epochs)
-
-    # `stop_training` is requested, break out the main loop immediately.
-    if trainer.stop_training:
-      break
-
-    logging.info('BEGIN Train loop.')
-    try:
-      # Until the last epoch, `num_steps = steps_per_epoch`
-      num_steps = min(total_steps - host_step, steps_per_epoch)
-      epoch_end_step = host_step + num_steps
-      logging.info('Training for %d steps.', num_steps)
-      while host_step < epoch_end_step:
-        if trainer.stop_training:
-          logging.info('Saving a checkpoint before early stopping...')
-          checkpointer.save(trainer.train_state,
-                            checkpoint_cfg.save.state_transformation_fns)
-
-          if hub_model_id:
-            # convert checkpoint to HF Flax model and push to hub
-            checkpoint_step = trainer.train_state.step
-            checkpoint_step = checkpoint_step.get() if isinstance(checkpoint_step, LazyArray) else checkpoint_step
-            checkpoint_step = int(checkpoint_step)  # Integer, to avoid side effects in the checkpoint path.
-            config_path = os.path.join(model_dir, 'config.json')
-            subprocess.run(["python3", "convert_t5x_checkpoint_to_flax.py", f"--t5x_checkpoint_path='checkpoint_{checkpoint_step}'/'", f'--config_name="{config_path}"', "--flax_dump_folder_path='./'"])
-            subprocess.run("git lfs prune --verify-remote", shell=True)
-            subprocess.run("git add .", shell=True)
-            subprocess.run(f'git commit -m "Saving weights and logs of step {checkpoint_step}"', shell=True)
-            subprocess.Popen("git push", shell=True)
-
-          logging.info('Stopping training loop early since `stop_training` is '
-                       'requested.')
-          break
-
-        inner_num_steps = min(epoch_end_step - host_step, stats_period)
-        train_summary = trainer.train(
-            train_iter, inner_num_steps, start_step=host_step)
-        if not concurrent_metrics:
-          # Note that we always pass the dictionary of `tasks` -> summary so
-          # that the actions can be performed without special casing. The only
-          # caveat is that train would need its own special `key` given no
-          # `task` will be applied.
-          trainer.stop_training = run_actions(
-              trainer_lib.ActionMode.TRAIN, actions, trainer.train_state,
-              {TRAIN_METRIC_KEY: train_summary.result()})
-
-        host_step += inner_num_steps
-      logging.info('END Train loop.')
-    except trainer_lib.PreemptionError as e:
-      logging.info('Saving emergency checkpoint.')
-      checkpointer.save(trainer.train_state,
-                        checkpoint_cfg.save.state_transformation_fns)
-      logging.info('Saving emergency checkpoint done.')
-      raise e
-
-    step_offset = host_step - first_step
-
-    is_eval_epoch = eval_period and (final_epoch or
-                                     step_offset % eval_period == 0)
-
-    # Training Evaluation (i.e., with teacher forcing).
-    if is_eval_epoch and train_eval_datasets:
-      # Maybe less if final step < period.
-      first_run = step_offset // eval_period <= 1
-      _run_training_eval(first_run and not run_eval_before_training)
-
-    # Maybe save a checkpoint.
-    if checkpoint_period and (final_epoch or
-                              step_offset % checkpoint_period == 0):
-      # Make sure last train step has completed before starting the clock.
-      train_summary.result()
-      logging.info('Saving checkpoint.')
-      checkpoint_tick = time.time()
-      checkpointer.save(trainer.train_state,
-                        checkpoint_cfg.save.state_transformation_fns)
-      checkpoint_tock = time.time()
-      train_metrics.write_scalar('timing/checkpoint_seconds',
-                                 checkpoint_tock - checkpoint_tick, host_step)
-
-      if hub_model_id:
-        # convert checkpoint to HF Flax model and push to hub
-        checkpoint_step = trainer.train_state.step
-        checkpoint_step = checkpoint_step.get() if isinstance(checkpoint_step, LazyArray) else checkpoint_step
-        checkpoint_step = int(checkpoint_step)  # Integer, to avoid side effects in the checkpoint path.
-        config_path = os.path.join(model_dir, 'config.json')
-        subprocess.run(["python3", "convert_t5x_checkpoint_to_flax.py", f"--t5x_checkpoint_path='checkpoint_{checkpoint_step}'/'", f'--config_name="{config_path}"', "--flax_dump_folder_path='./'"])
-        subprocess.run("git lfs prune --verify-remote", shell=True)
-        subprocess.run("git add .", shell=True)
-        subprocess.run(f'git commit -m "Saving weights and logs of step {checkpoint_step}"', shell=True)
-        subprocess.Popen("git push", shell=True)
-
-    # Inference Evaluation (i.e., with decoding or scoring).
-    if evaluator is not None:
-      _run_inference_eval()
-
-  # Wait until computations are done before exiting
-  logging.info('Finished.')
-  trainer.close()
-  if evaluator:
-    evaluator.close()
-  multihost_utils.sync_global_devices('complete')
-
-  return host_step, trainer.train_state
-
-
-if __name__ == '__main__':
-  # pylint: disable=g-import-not-at-top
-  from absl import app
-  from absl import flags
-  import gin
-  from t5x import gin_utils
-  # pylint: enable=g-import-not-at-top
-
-  FLAGS = flags.FLAGS
-
-  jax.config.parse_flags_with_absl()
-
-  flags.DEFINE_multi_string(
-      'gin_file',
-      default=None,
-      help='Path to gin configuration file. Multiple paths may be passed and '
-      'will be imported in the given order, with later configurations '
-      'overriding earlier ones.')
-
-  flags.DEFINE_multi_string(
-      'gin_bindings', default=[], help='Individual gin bindings.')
-
-  flags.DEFINE_list(
-      'gin_search_paths',
-      default=['.'],
-      help='Comma-separated list of gin config path prefixes to be prepended '
-      'to suffixes given via `--gin_file`. If a file appears in. Only the '
-      'first prefix that produces a valid path for each suffix will be '
-      'used.')
-
-  flags.DEFINE_string(
-      'tfds_data_dir', None,
-      'If set, this directory will be used to store datasets prepared by '
-      'TensorFlow Datasets that are not available in the public TFDS GCS '
-      'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
-      'all `Task`s.')
-
-  flags.DEFINE_list(
-      'seqio_additional_cache_dirs', [],
-      'Directories to search for cached Tasks in addition to defaults.')
-
-
-
-  def main(argv: Sequence[str]):
-    """Wrapper for pdb post mortems."""
-    _main(argv)
-
-  def _main(argv: Sequence[str]):
-    """True main function."""
-    if len(argv) > 1:
-      raise app.UsageError('Too many command-line arguments.')
-
-    if FLAGS.tfds_data_dir:
-      seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)
-
-    seqio.add_global_cache_dirs(FLAGS.seqio_additional_cache_dirs)
-
-    # Create gin-configurable version of `train`.
-    train_using_gin = gin.configurable(train)
-
-    gin_utils.parse_gin_flags(
-        # User-provided gin paths take precedence if relative paths conflict.
-        FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
-        FLAGS.gin_file,
-        FLAGS.gin_bindings)
-    train_using_gin()
-
-  gin_utils.run(main)
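With the forked train.py gone, the checkpoint-to-Flax conversion and git push it used to run after every save now have to happen outside the training loop. A rough standalone sketch of that same sequence, assuming convert_t5x_checkpoint_to_flax.py, config.json and a checkpoint_<step> directory are present in the model repository working copy (the step number below is only an example placeholder):

import subprocess

checkpoint_step = 100000  # example value; point at an existing checkpoint_<step> directory

# Convert the T5X checkpoint to HF Flax weights inside the working copy.
subprocess.run([
    "python3", "convert_t5x_checkpoint_to_flax.py",
    f"--t5x_checkpoint_path=checkpoint_{checkpoint_step}",
    "--config_name=config.json",
    "--flax_dump_folder_path=./",
], check=True)

# Commit and push once, instead of on every checkpoint during training.
subprocess.run("git lfs prune --verify-remote", shell=True, check=True)
subprocess.run("git add .", shell=True, check=True)
subprocess.run(f'git commit -m "Saving weights and logs of step {checkpoint_step}"', shell=True, check=True)
subprocess.run("git push", shell=True, check=True)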
train/events.out.tfevents.1649073594.t1v-n-304587cf-w-0.1316481.0.v2
DELETED

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:224a0411c5fc4e0e882c7a647ff554b58fec3f79dc12f9809b26b3a319225c1d
-size 7585
train/events.out.tfevents.1649092520.t1v-n-304587cf-w-0.1399566.0.v2
DELETED

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d5a8424af960443ad6fbb097216b8add4fb9af5298424f440afb56a27ee260b9
-size 16363
train/{events.out.tfevents.1649056216.t1v-n-304587cf-w-0.1239745.0.v2 → events.out.tfevents.1649705066.t1v-n-304587cf-w-0.2549834.0.v2}
RENAMED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:59fc20fda0f88a5e31f18b0ebc9497d4131b6458ac28d6d2a52875d1ba5c5b13
+size 10402
training_eval/pretrain_finnish/events.out.tfevents.1649073594.t1v-n-304587cf-w-0.1316481.1.v2
DELETED

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7a4fbca2952ba8dbad98a5e752b6bcf0e80ea9d1376380bbbbd669b7fb0897e7
-size 1431
training_eval/pretrain_finnish/events.out.tfevents.1649092520.t1v-n-304587cf-w-0.1399566.1.v2
DELETED

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b3efa0a2b6af0ef032441aaf0e97ac69bab14c84b90bbbfa22001dc094926fb2
-size 9261
training_eval/pretrain_finnish/{events.out.tfevents.1649056216.t1v-n-304587cf-w-0.1239745.1.v2 → events.out.tfevents.1649705066.t1v-n-304587cf-w-0.2549834.1.v2}
RENAMED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:6ec5c452edec8f5036cab8e5f3d67f492e55a15ff004d4ee5847b2d2cd56f2df
+size 4024