# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
An naive implementation of split placment example
"""
import uuid
from copy import deepcopy
from pprint import pprint
import numpy as np
import torch
from verl import DataProto
from verl.trainer.ppo.ray_trainer import (
AdvantageEstimator,
apply_kl_penalty,
compute_advantage,
compute_data_metrics,
compute_timing_metrics,
marked_timer,
)
from verl.utils.metric import reduce_metrics
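# fit() for the split placement example: the critic and actor updates are dispatched as
# non-blocking calls and their results are gathered afterwards, so the two worker groups
# can run their updates concurrently (see the NOTE about blocking=False below).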
def fit(self):
"""
The training loop of PPO.
The driver process only need to call the compute functions of the worker group through RPC
to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.
"""
from omegaconf import OmegaConf
from verl.utils.tracking import Tracking
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)
self.global_steps = 0
# load checkpoint before doing anything
self._load_checkpoint()
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return
# we start from step 1
self.global_steps += 1
last_val_metrics = None
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
timing_raw = {}
batch: DataProto = DataProto.from_single_dict(batch_dict)
# pop those keys for generation
gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"])
is_last_step = self.global_steps >= self.total_training_steps
with marked_timer("step", timing_raw):
# generate a batch
with marked_timer("gen", timing_raw):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)
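# For REMAX, generate greedy baseline responses (do_sample=False) and use their
# summed rewards as per-prompt reward baselines.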
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with marked_timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
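# tag every prompt with a unique id so that its repeated rollouts can be grouped
# together during advantage estimation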
batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
# Balance the number of valid tokens across DP ranks.
# NOTE: This usually changes the order of data in the `batch`,
# which won't affect the advantage calculation (since it's based on uid),
# but might affect the loss calculation (due to the change of mini-batching).
# TODO: Decouple the DP balancing and mini-batching.
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# recompute old_log_probs
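# (the log-probs returned by the rollout engine may differ numerically from the
# actor's, so they are recomputed with the actor before the PPO update)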
with marked_timer("old_log_prob", timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
batch = batch.union(old_log_prob)
if self.use_reference_policy:
# compute reference log_prob
with marked_timer("ref", timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic:
with marked_timer("values", timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with marked_timer("adv", timing_raw):
# compute scores. Supports both model-based and function-based rewards.
# We first compute the scores using the reward model, then call reward_fn to combine
# the reward-model scores with the rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)
# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
batch.batch["token_level_scores"] = reward_tensor
# compute rewards: apply the KL penalty if enabled
if self.config.algorithm.use_kl_in_reward:
batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
)
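# Dispatch the critic and actor updates without blocking; both calls return futures
# that are resolved together further below.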
# update critic
if self.use_critic:
with marked_timer("update_critic_call", timing_raw):
critic_output = self.critic_wg.update_critic(batch)
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with marked_timer("update_actor_call", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
# NOTE: make sure you set blocking=False in update_actor and update_critic in the worker class
with marked_timer("update_actor_critic", timing_raw):
critic_output = critic_output.get()
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
actor_output = actor_output.get()
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
# validate
if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
with marked_timer("testing", timing_raw):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
with marked_timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)
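# stop once the configured number of training steps has been reached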
if self.global_steps >= self.total_training_steps:
pprint(f"Final validation metrics: {last_val_metrics}")
return
self.global_steps += 1