import os

import torch.distributed as dist


class COMM_INFO:
    """Mutable record of this process's sequence-parallel communication settings.

    A single module-level instance, ``nccl_info``, is populated by the
    ``initialize_*`` functions in this module.
    """

    def __init__(self):
        # Process group used for sequence-parallel collectives; None until a
        # group is created or supplied.
        self.group = None
        # Number of ranks in this process's sequence-parallel group.
        self.sp_size = 1
        # Rank of this process in the global (world) numbering.
        self.global_rank = 0
        # Rank of this process inside its sequence-parallel group.
        self.rank_within_group = 0
        # Index of the sequence-parallel group this process belongs to.
        self.group_id = 0


nccl_info = COMM_INFO()

# True once sequence parallelism has been enabled by one of the initializers
# or by set_sequence_parallel_state().
_SEQUENCE_PARALLEL_STATE = False


def initialize_sequence_parallel_state(sequence_parallel_size: int) -> None:
    """Configure sequence parallelism for ``sequence_parallel_size`` ranks per group.

    If ``sequence_parallel_size > 1``, sequence-parallel groups are created via
    initialize_sequence_parallel_group() and the state flag is set to True.
    Otherwise sequence parallelism is disabled: the flag is cleared and
    ``nccl_info`` is reset so that every rank is its own group of size 1
    (``group_id`` equals the ``RANK`` env var, default 0).

    Args:
        sequence_parallel_size: number of ranks that share one sequence.

    Raises:
        ValueError: propagated from initialize_sequence_parallel_group() when
            the size is invalid for the current ``WORLD_SIZE``.
    """
    global _SEQUENCE_PARALLEL_STATE
    if sequence_parallel_size > 1:
        initialize_sequence_parallel_group(sequence_parallel_size)
        # Only flip the flag once the groups were created successfully.
        _SEQUENCE_PARALLEL_STATE = True
        return

    rank = int(os.getenv("RANK", "0"))
    # Reset everything, including the flag and any previously recorded
    # group, so state and nccl_info agree that SP is off.
    _SEQUENCE_PARALLEL_STATE = False
    nccl_info.group = None
    nccl_info.sp_size = 1
    nccl_info.global_rank = rank
    nccl_info.rank_within_group = 0
    # With groups of size 1, the group index is the rank itself.
    nccl_info.group_id = rank


def set_sequence_parallel_state(state: bool) -> None:
    """Overwrite the module-level sequence-parallel flag with ``state``."""
    global _SEQUENCE_PARALLEL_STATE
    _SEQUENCE_PARALLEL_STATE = state


def get_sequence_parallel_state() -> bool:
    """Return the module-level sequence-parallel flag."""
    return _SEQUENCE_PARALLEL_STATE


def initialize_sequence_parallel_group(sequence_parallel_size: int) -> None:
    """Initialize the sequence parallel group.

    The world (``WORLD_SIZE`` env var, default 1) is split into
    ``WORLD_SIZE // sequence_parallel_size`` groups of consecutive ranks:
    group ``i`` holds ranks ``[i * size, (i + 1) * size)``. ``nccl_info`` is
    updated with the group that contains this process's ``RANK`` (env var,
    default 0). This function does not change the flag returned by
    get_sequence_parallel_state().

    Args:
        sequence_parallel_size: ranks per group; must be a positive divisor
            of ``WORLD_SIZE``.

    Raises:
        ValueError: if ``sequence_parallel_size`` is not positive or does not
            divide ``WORLD_SIZE``.
    """
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    # Explicit checks instead of `assert` so validation survives `python -O`;
    # a non-positive size would otherwise divide by zero or create no groups.
    if sequence_parallel_size <= 0:
        raise ValueError(
            f"sequence_parallel_size must be positive, but got {sequence_parallel_size}"
        )
    if world_size % sequence_parallel_size != 0:
        raise ValueError(
            "world_size must be divisible by sequence_parallel_size, but got world_size: {}, sequence_parallel_size: {}".format(
                world_size, sequence_parallel_size
            )
        )

    nccl_info.sp_size = sequence_parallel_size
    nccl_info.global_rank = rank

    num_sequence_parallel_groups: int = world_size // sequence_parallel_size
    for i in range(num_sequence_parallel_groups):
        ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
        # torch.distributed.new_group must be entered by all processes, even
        # non-members, so every rank creates every group and keeps its own.
        group = dist.new_group(ranks)
        if rank in ranks:
            nccl_info.group = group
            nccl_info.rank_within_group = rank - i * sequence_parallel_size
            nccl_info.group_id = i


def initialize_sequence_parallel_group_custom(process_group) -> None:
    """Initialize an unsafe sequence parallel group with a pre-formed group.

    Sets the sequence-parallel flag to True and records ``process_group`` in
    ``nccl_info`` as-is; no validation of the group is performed here.

    Args:
        process_group: an already-created torch.distributed process group.
    """
    set_sequence_parallel_state(True)
    rank = dist.get_rank(group=process_group)
    sequence_parallel_size = dist.get_world_size(group=process_group)
    nccl_info.sp_size = sequence_parallel_size
    nccl_info.global_rank = dist.get_rank()  # global rank
    nccl_info.group = process_group
    nccl_info.rank_within_group = rank
    # Only the one supplied group is tracked, so its id is fixed at 0.
    nccl_info.group_id = 0


def destroy_sequence_parallel_group() -> None:
    """Tear down torch.distributed process groups.

    NOTE(review): ``dist.destroy_process_group()`` is called without a group
    argument, which per the PyTorch docs destroys the default (WORLD) group
    together with all other groups, not only the sequence-parallel one.
    ``nccl_info`` and the state flag are left unchanged.
    """
    dist.destroy_process_group()