# seqio task definition: UL2-style mixture-of-denoisers pretraining on a
# local Hugging Face dataset.
import functools

import seqio
import tensorflow as tf
import t5.data
from datasets import load_dataset, load_from_disk
from t5.data import postprocessors
from t5.data import preprocessors
from t5.evaluation import metrics
from seqio import FunctionDataSource, utils

from ul2_objective import ul2_objective

# UL2 mixture-of-denoisers settings: R-denoisers do regular span corruption,
# X-denoisers do extreme corruption (longer spans and/or higher rates).
R_DENOISER_SPAN_LENGTHS = [3.0, 8.0]
X_DENOISER_SPAN_LENGTHS = [3.0, 8.0, 64.0, 64.0]
R_DENOISER_CORRUPT_RATES = [0.15, 0.15]
X_DENOISER_CORRUPT_RATES = [0.5, 0.5, 0.15, 0.5]

# Mode tokens prepended to each example so the model can condition on which
# denoising objective produced it.
R_DENOISER_TOKEN_PREFIX = '[NLU]'
X_DENOISER_TOKEN_PREFIX = '[NLG]'
S_DENOISER_TOKEN_PREFIX = '[S2S]'
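
# Sanity checks (illustrative additions, not in the original): each span length
# needs a matching corruption rate, since the two lists are zipped together in
# the UL2 objective configured below.
assert len(R_DENOISER_SPAN_LENGTHS) == len(R_DENOISER_CORRUPT_RATES)
assert len(X_DENOISER_SPAN_LENGTHS) == len(X_DENOISER_CORRUPT_RATES)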

TaskRegistry = seqio.TaskRegistry

# SentencePiece vocabulary; extra_ids=0 assumes spiece.model already contains
# any sentinel and mode tokens the objective needs.
vocabulary = seqio.SentencePieceVocabulary('spiece.model', extra_ids=0)

DEFAULT_OUTPUT_FEATURES = {
    # "inputs" is only created later by ul2_objective, hence required=False.
    "inputs": seqio.Feature(vocabulary=vocabulary, add_eos=True, required=False),
    "targets": seqio.Feature(vocabulary=vocabulary, add_eos=True),
}


def gen_dataset(split, shuffle=False, seed=None, column="text", dataset=None):
    """Yield non-empty text examples from a Hugging Face dataset split, forever."""
    if shuffle:
        if seed is not None:  # `is not None` so that seed=0 still takes effect
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()
    while True:  # loop indefinitely; seqio decides how many examples to take
        for item in dataset[str(split)]:
            if item[column] is not None:
                yield item[column]


def dataset_fn(split, shuffle_files, seed=None, dataset=None):
    """Expose the generator as an infinite tf.data.Dataset of scalar strings."""
    return tf.data.Dataset.from_generator(
        functools.partial(gen_dataset, split, shuffle_files, seed, dataset=dataset),
        # `dataset_name` is defined below; it is only read when seqio calls this.
        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name),
    )


@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Assign the value from the dataset to target_key in key_map."""
    return {**key_map, target_key: x}
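
# For example (illustrative): with key_map={"inputs": None, "targets": None}
# and target_key="targets", a raw string x maps to {"inputs": None, "targets": x}.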


dataset_name = "/home/sdeshpande/data/medical_dataset"
dataset_params = {"from_disk_path": dataset_name}

# Load from local disk when a path is given; otherwise fall back to the Hub.
if "from_disk_path" in dataset_params:
    dataset = load_from_disk(dataset_params["from_disk_path"])
else:
    dataset = load_dataset(**dataset_params)

dataset_shapes = {
    "train": dataset["train"].num_rows,
    "validation": dataset["validation"].num_rows,
}


TaskRegistry.add(
    "pretrain_medical_ul2",
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(dataset_fn, dataset=dataset),
        splits=("train", "validation"),
        caching_permitted=False,
        num_input_examples=dataset_shapes,
    ),
    preprocessors=[
        # Wrap each raw string as {"inputs": None, "targets": <text>}.
        functools.partial(
            target_to_key,
            key_map={"inputs": None, "targets": None},
            target_key="targets",
        ),
        seqio.preprocessors.tokenize,
        # UL2 mixture: 40% R-denoising, 40% X-denoising, and 20% S-denoising
        # (the prefix-LM task), with the matching mode token per sub-objective.
        functools.partial(
            ul2_objective,
            shard_ds=False,
            use_prefix_lm_task=True,  # adds the S-denoiser (prefix LM) task
            rates=[0.4 / len(R_DENOISER_SPAN_LENGTHS)] * len(R_DENOISER_SPAN_LENGTHS)
            + [0.4 / len(X_DENOISER_SPAN_LENGTHS)] * len(X_DENOISER_SPAN_LENGTHS)
            + [0.2],
            mean_noise_span_lengths=R_DENOISER_SPAN_LENGTHS + X_DENOISER_SPAN_LENGTHS,
            noise_densities=R_DENOISER_CORRUPT_RATES + X_DENOISER_CORRUPT_RATES,
            optional_task_prefixes=[R_DENOISER_TOKEN_PREFIX] * len(R_DENOISER_SPAN_LENGTHS)
            + [X_DENOISER_TOKEN_PREFIX] * len(X_DENOISER_SPAN_LENGTHS)
            + [S_DENOISER_TOKEN_PREFIX],
            reserved_for_packing=1,  # reserve room for the prepended mode token
        ),
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
    metric_fns=[metrics.accuracy],
)
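
# Example usage (a sketch; the sequence lengths are illustrative): build the
# fully preprocessed dataset through seqio's standard entry point.
# task = seqio.get_mixture_or_task("pretrain_medical_ul2")
# ds = task.get_dataset(
#     sequence_length={"inputs": 512, "targets": 512},
#     split="train",
#     shuffle=True,
#     seed=42,
# )
# for ex in ds.take(1):
#     print({k: v.shape for k, v in ex.items()})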