# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

import os
import random
from abc import ABC
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from functools import partial
from itertools import chain
from typing import Any, Dict, List, Optional, Union

import pyarrow as pa
import pyarrow.parquet as pq
from omegaconf import DictConfig

from common.distributed import get_global_rank, get_world_size
from common.fs import copy, exists, listdir, mkdir, remove
from common.partition import partition_by_groups
from common.persistence.utils import get_local_path
from data.common.parquet_sampler import (
    IdentityParquetSampler,
    ParquetSampler,
    create_parquet_sampler,
)
from data.common.utils import filter_parquets, get_parquet_metadata


# Write a Parquet table to a local path, then copy it to the target path.
def save_and_copy(
    pa_table,
    local_path: str,
    target_path: str,
    row_group_size: int,
    executor: ThreadPoolExecutor,
    do_async: bool = False,
    futures: Optional[List[Future]] = None,
):
    # Avoid a shared mutable default; callers may pass their own accumulator.
    if futures is None:
        futures = []

    # Build a callback that removes the local copy once the upload finishes.
    def _make_on_complete(local_path):
        def _on_complete(future):
            target_path = future.result()
            remove(local_path)
            print(f"Target path saved: {target_path}")

        return _on_complete

    # Write the Parquet table locally, then copy it to the target path.
    def _fn(pa_table, local_path, target_path, row_group_size):
        pq.write_table(
            pa_table,
            local_path,
            row_group_size=row_group_size,
        )
        mkdir(os.path.dirname(target_path))
        copy(local_path, target_path)
        return target_path

    # Submit the write-and-copy task to the executor.
    future = executor.submit(_fn, pa_table, local_path, target_path, row_group_size)
    future.add_done_callback(_make_on_complete(local_path))
    futures.append(future)

    # If not asynchronous, wait for all pending futures and shut the executor down.
    if not do_async:
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as exc:
                print(f"Generated an exception: {exc}")
        executor.shutdown(wait=True)

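
# Usage sketch for save_and_copy (hypothetical paths and table, not from this
# module). With do_async=True the call returns immediately and the write/copy
# runs on the executor; a final do_async=False call drains every pending future:
#
#   executor = ThreadPoolExecutor(max_workers=4)
#   futures = []
#   table = pa.Table.from_pydict({"text": ["a", "b"]})
#   save_and_copy(table, "/tmp/00000.parquet", "remote://bucket/00000.parquet",
#                 row_group_size=1024, executor=executor, do_async=True,
#                 futures=futures)
#   save_and_copy(table, "/tmp/00001.parquet", "remote://bucket/00001.parquet",
#                 row_group_size=1024, executor=executor, do_async=False,
#                 futures=futures)  # waits for both files, then shuts down
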

# Container for the file lists produced by configure_task_path. The @dataclass
# decorator is required here: instances are constructed positionally below.
@dataclass
class FileListOutput:
    existing_files: List[str]
    source_files: List[Any]
    target_files: List[str]


# Handle for a single Parquet file at `path`; @dataclass enables the
# PersistedParquet(file_name) construction used below.
@dataclass
class PersistedParquet:
    path: str

    # Save the Parquet file from either a ready pa.Table or a plain dict.
    def save(
        self,
        row_group_size: int,
        executor: ThreadPoolExecutor,
        pa_table: Optional[pa.Table] = None,
        data_dict: Optional[Dict[str, List[Union[str, bytes]]]] = None,
        is_last_file: bool = False,
        futures: Optional[List[Future]] = None,
    ):
        # Exactly one of pa_table and data_dict must be provided.
        assert (pa_table is None) != (data_dict is None)
        local_path = get_local_path(self.path)
        # Check against None explicitly: `not pa_table` would also trigger on a
        # valid zero-row table, since pa.Table defines __len__.
        if pa_table is None:
            schema_dict = self.generate_schema_from_dict(data_dict)
            pa_table = pa.Table.from_pydict(data_dict, schema=schema_dict)
        save_and_copy(
            pa_table,
            local_path=local_path,
            target_path=self.path,
            row_group_size=row_group_size,
            executor=executor,
            do_async=not is_last_file,
            futures=futures,
        )

    # Infer a pyarrow schema from the first element of each column.
    def generate_schema_from_dict(
        self,
        data_dict: Dict[str, List[Union[str, bytes]]],
    ):
        schema_dict = {}
        for key, value in data_dict.items():
            if isinstance(value[0], str):
                schema_dict[key] = pa.string()
            elif isinstance(value[0], bytes):
                schema_dict[key] = pa.binary()
            else:
                raise ValueError(f"Unsupported data type for key '{key}': {type(value[0])}")
        return pa.schema(schema_dict)

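
# Example of the schema inference above: each column's type is taken from its
# first element, so (hypothetical data)
#
#   {"caption": ["a dog", "a cat"], "image": [b"\x89PNG", b"\x89PNG"]}
#
# yields pa.schema({"caption": pa.string(), "image": pa.binary()}); a column
# whose first element is neither str nor bytes raises ValueError.
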

# Base class for managing Parquet files.
class ParquetManager(ABC):
    """
    Base class for the DumpingManager and RepackingManager.
    """

    def __init__(
        self,
        task: Optional[DictConfig] = None,
        target_dir: str = ".",
    ):
        self.task = task
        self.target_dir = target_dir.rstrip("/")
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.futures = []

    # Collect Parquet file paths from source_path, optionally subsampling them.
    def get_parquet_files(
        self,
        source_path: str,
        parquet_sampler: ParquetSampler = IdentityParquetSampler(),
        path_mode: str = "dir",
    ):
        # Flatten a path, list of paths, or list of lists of paths into one list.
        def _flatten(paths):
            if isinstance(paths, list):
                if any(isinstance(i, list) for i in paths):
                    return list(chain(*paths))
                else:
                    return paths
            else:
                return [paths]

        file_paths = _flatten(source_path)
        if path_mode == "dir":
            file_paths = map(listdir, file_paths)
            if isinstance(parquet_sampler.size, float):
                # A float size samples a fraction of the files in each directory.
                file_paths = map(filter_parquets, file_paths)
                file_paths = map(parquet_sampler, file_paths)
                file_paths = list(chain(*file_paths))
            else:
                # An integer size samples over the merged file list.
                file_paths = chain(*file_paths)
                file_paths = parquet_sampler(filter_parquets(file_paths))
        return file_paths
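
    # Example call (hypothetical directories; "remote://" is a placeholder
    # scheme, not from this repo). In "dir" mode each source entry is listed
    # first and non-Parquet files are filtered out before sampling:
    #
    #   files = self.get_parquet_files(
    #       source_path=["remote://bucket/ds_a", "remote://bucket/ds_b"],
    #       path_mode="dir",
    #   )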

    # Save a Parquet file, optionally skipping targets that already exist.
    def save_parquet(
        self,
        *,
        file_name: str,
        row_group_size: int,
        pa_table: Optional[pa.Table] = None,
        data_dict: Optional[Dict[str, List[Union[str, bytes]]]] = None,
        override: bool = True,
        is_last_file: bool = False,
    ):
        persist = self._get_parquet(file_name)
        if override or not exists(persist.path):
            persist.save(
                pa_table=pa_table,
                data_dict=data_dict,
                executor=self.executor,
                row_group_size=row_group_size,
                is_last_file=is_last_file,
                futures=self.futures,
            )

    # Wrap a file name in a PersistedParquet handle.
    def _get_parquet(self, file_name: str) -> PersistedParquet:
        return PersistedParquet(file_name)


# Manager that dumps Parquet files and supports resuming from existing output.
class DumpingManager(ParquetManager):
    """
    Dumping manager handles parquet saving and resuming.
    """

    def __init__(
        self,
        task: DictConfig,
        target_dir: str,
    ):
        super().__init__(task=task, target_dir=target_dir)

    # Map a source file path to its target folder and file under target_dir.
    def generate_saving_path(self, file_path: str, rsplit: int):
        part_list = file_path.rsplit("/", rsplit)
        result_folder = "/".join(
            [self.target_dir] + [f"epoch_{self.task.epoch}"] + part_list[-rsplit:-1]
        )
        result_file = "/".join([result_folder, part_list[-1]])
        return result_folder, result_file
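
    # Worked example (hypothetical path). With target_dir="out", task.epoch=0,
    # and rsplit=2, "remote://bucket/ds_a/part/0001.parquet" rsplits into
    # ["remote://bucket/ds_a", "part", "0001.parquet"], giving:
    #
    #   result_folder = "out/epoch_0/part"
    #   result_file   = "out/epoch_0/part/0001.parquet"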

    # Build the source/target file lists for this task and rank.
    def configure_task_path(self, source_path: str, rsplit: int, path_mode: str = "dir"):
        file_paths = self.get_parquet_files(
            source_path=source_path,
            path_mode=path_mode,
        )
        # Shuffle file paths with a fixed seed so every rank sees the same order.
        random.Random(0).shuffle(file_paths)
        # Partition the file paths across tasks, then across ranks.
        full_source_files = partition_by_groups(file_paths, self.task.total_count)[self.task.index]
        full_source_files = partition_by_groups(full_source_files, get_world_size())[
            get_global_rank()
        ]
        if not full_source_files:
            return FileListOutput([], [], [])
        generate_saving_path = partial(self.generate_saving_path, rsplit=rsplit)
        full_paths = map(generate_saving_path, full_source_files)
        full_target_folders, full_target_files = map(list, zip(*full_paths))
        full_target_folders = set(full_target_folders)
        # Collect target files that have already been written.
        existing_file_paths = map(
            lambda folder: listdir(folder) if exists(folder) else [], full_target_folders
        )
        existing_file_paths = chain(*existing_file_paths)
        self.existing_files = list(
            filter(
                lambda path: path.endswith(".parquet") and path in full_target_files,
                existing_file_paths,
            )
        )
        # Keep only the (source, target) pairs whose target does not exist yet.
        filtered_pairs = list(
            filter(
                lambda pair: pair[1] not in self.existing_files,
                zip(full_source_files, full_target_files),
            )
        )
        if filtered_pairs:
            filtered_source_files, filtered_target_files = map(list, zip(*filtered_pairs))
        else:
            filtered_source_files, filtered_target_files = [], []
        # Skip existing file paths if specified.
        skip_exists = self.task.skip_exists
        self.source_files = filtered_source_files if skip_exists else full_source_files
        self.target_files = filtered_target_files if skip_exists else full_target_files
        return FileListOutput(self.existing_files, self.source_files, self.target_files)

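
# Usage sketch (hypothetical task config; the exact schema of `task` comes from
# the caller's DictConfig and must provide epoch, index, total_count, and
# skip_exists). Each rank gets its own slice of the shuffled sources:
#
#   manager = DumpingManager(task=task_cfg, target_dir="remote://bucket/dump")
#   out = manager.configure_task_path(source_path="remote://bucket/raw", rsplit=2)
#   for src, dst in zip(out.source_files, out.target_files):
#       ...  # read src, transform, then manager.save_parquet(file_name=dst, ...)
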

class RepackingManager(ParquetManager):
    """
    Repacking manager handles parquet splitting and saving.
    """

    def __init__(
        self,
        task: DictConfig,
        target_dir: str,
        repackaging: DictConfig,
    ):
        super().__init__(task=task, target_dir=target_dir)
        self.repackaging = repackaging

    # Build the per-row source assignments and target paths for repacking.
    def configure_task_path(
        self,
        source_path: str,
        parquet_sampler: Optional[DictConfig] = None,
        path_mode: str = "dir",
    ):
        parquet_sampler = create_parquet_sampler(config=parquet_sampler)
        file_paths = self.get_parquet_files(
            source_path=source_path,
            parquet_sampler=parquet_sampler,
            path_mode=path_mode,
        )
        random.Random(0).shuffle(file_paths)
        target_dir = self.target_dir
        size = abs(parquet_sampler.size)
        if self.task:
            # Partition the file paths based on task configuration.
            file_paths = partition_by_groups(file_paths, self.task.total_count)[self.task.index]
            target_dir = os.path.join(target_dir, f"{self.task.total_count}_{self.task.index}")
            if size > 1:
                # Scale the row budget down to this task's share.
                size = len(
                    partition_by_groups(range(size), self.task.total_count)[self.task.index]
                )
        # Get metadata for each Parquet file.
        metadatas = get_parquet_metadata(file_paths, self.repackaging.num_processes)
        # Create a (file_path, row) tuple for each row in the files.
        target_items = [
            (file_path, row)
            for file_path, metadata in zip(file_paths, metadatas)
            for row in range(metadata.num_rows)
        ]
        # Shuffle the target items with a fixed seed, then cap them at the row budget.
        random.Random(0).shuffle(target_items)
        if size > 1:
            target_items = target_items[:size]
        # Partition the items into groups, one group per target file.
        items_per_file = partition_by_groups(target_items, self.repackaging.num_files)
        # Generate target file paths.
        target_files = [
            os.path.join(target_dir, f"{str(i).zfill(5)}.parquet")
            for i in range(self.repackaging.num_files)
        ]
        existing_file_paths = listdir(target_dir) if exists(target_dir) else []
        self.existing_files = list(
            filter(
                lambda path: path.endswith(".parquet"),
                existing_file_paths,
            )
        )
        self.source_files = items_per_file
        self.target_files = target_files
        return FileListOutput(self.existing_files, self.source_files, self.target_files)
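
# Usage sketch (hypothetical config values). Each entry of out.source_files is
# a list of (file_path, row) tuples destined for the matching entry of
# out.target_files, so a caller can gather those rows and write one repacked
# file per group via save_parquet:
#
#   manager = RepackingManager(
#       task=task_cfg,
#       target_dir="remote://bucket/repacked",
#       repackaging=repack_cfg,  # assumed to provide .num_processes and .num_files
#   )
#   out = manager.configure_task_path(source_path="remote://bucket/dump")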