#!/usr/bin/env python3
"""
bin_to_safetensors.py

Convert a PyTorch checkpoint (e.g., pytorch_model.bin / .pt / .ckpt) to a
.safetensors file.

- Safe tensors only: tensors are saved; non-tensor Python objects (optimizer,
  schedulers, etc.) are ignored.
- Heuristics try to locate a model state_dict within common training
  checkpoints.

USAGE:
    python bin_to_safetensors.py --in pytorch_model.bin --out model.safetensors
    python bin_to_safetensors.py --in trainer.ckpt --out model.safetensors

NOTE: Loading with torch.load uses pickle and can execute code from untrusted
sources. Only run this on checkpoints from sources you trust.
"""

import argparse
import sys
from typing import Any, Dict, Optional

import torch

# The safetensors package is only required for writing the output file.
# Guard the import so the extraction helpers stay importable without it and a
# missing dependency yields a clear error at conversion time.
# (BUGFIX: the previous import also pulled in `is_safe_tensor`, which does not
# exist in safetensors.torch and made the whole script fail at import time.)
try:
    from safetensors.torch import save_file
except ImportError:  # pragma: no cover - environment dependent
    save_file = None


def _to_cpu_tensor(v: Any) -> torch.Tensor:
    """Return v's underlying tensor detached, moved to CPU, and contiguous.

    Accepts either a plain torch.Tensor or a ShardedTensor-like wrapper that
    exposes a `.tensor` attribute holding a torch.Tensor.
    """
    t = v if torch.is_tensor(v) else v.tensor
    return t.detach().cpu().contiguous()


def _is_tensor_dict(d: Any) -> bool:
    """Return True if d is a non-empty dict whose values are all tensors
    (or ShardedTensor-like objects wrapping a tensor)."""
    if not isinstance(d, dict) or not d:
        return False
    for v in d.values():
        if torch.is_tensor(v):
            continue
        if hasattr(v, "tensor") and torch.is_tensor(getattr(v, "tensor")):
            continue
        return False
    return True


def _extract_state_dict(obj: Any) -> Dict[str, torch.Tensor]:
    """Try to extract a {name: tensor} dict from various checkpoint formats.

    Handles, in order:
      1. A bare state_dict (typical HF "pytorch_model.bin").
      2. A training checkpoint dict with a well-known key ("state_dict",
         "model_state_dict", "model", ...).
      3. Any dict value that itself looks like a tensor dict (Lightning-style
         nesting).

    All returned tensors are detached, on CPU, and contiguous.

    Raises:
        ValueError: if no dict of tensors can be located.
    """
    # Case 1: already a tensor dict; normalize tensors to CPU + contiguous.
    if _is_tensor_dict(obj):
        return {k: _to_cpu_tensor(v) for k, v in obj.items()}

    if isinstance(obj, dict):
        # Case 2: common keys used by training frameworks.
        candidate_keys = [
            "state_dict",
            "model_state_dict",
            "model",
            "module",  # sometimes raw module.state_dict() is stored here
            "network",
            "net",
            "weights",
        ]
        for ck in candidate_keys:
            if ck in obj and _is_tensor_dict(obj[ck]):
                return {k: _to_cpu_tensor(v) for k, v in obj[ck].items()}

        # Case 3: first value that itself is a tensor dict (nested layouts).
        for v in obj.values():
            if _is_tensor_dict(v):
                return {k: _to_cpu_tensor(t) for k, t in v.items()}

    raise ValueError(
        "Could not find a model state_dict (a dict of tensors). "
        "If this is a full training checkpoint, load it in Python, extract model.state_dict(), "
        "and save that mapping instead."
    )


# Public alias so the extraction heuristic can be used programmatically
# without reaching for a private name (backward-compatible addition).
extract_state_dict = _extract_state_dict


def _looks_like_safetensors(path: str) -> bool:
    """Heuristically detect an existing safetensors file by its on-disk header.

    A safetensors file starts with an 8-byte little-endian header length
    followed by a JSON object ("{"). (BUGFIX: the old code called a
    nonexistent `is_safe_tensor` on the object *after* torch.load and only if
    it was bytes — a check that could never trigger, since torch.load raises
    on safetensors input first.)
    """
    try:
        with open(path, "rb") as f:
            prefix = f.read(9)
    except OSError:
        return False
    if len(prefix) < 9:
        return False
    header_len = int.from_bytes(prefix[:8], "little")
    # Sanity bound on the JSON header size; real headers are far smaller.
    return 0 < header_len < 100_000_000 and prefix[8:9] == b"{"


def convert_bin_to_safetensors(
    in_path: str, out_path: str, metadata: Optional[Dict[str, str]] = None
) -> None:
    """Convert a pickle-based PyTorch checkpoint at in_path into out_path.

    Args:
        in_path: path to a .bin/.pt/.ckpt checkpoint file.
        out_path: destination .safetensors path.
        metadata: optional extra string metadata stored in the output header.

    Raises:
        RuntimeError: if the safetensors package is not installed.
        ValueError: if no state_dict can be located in the checkpoint.
    """
    if save_file is None:
        raise RuntimeError(
            "The 'safetensors' package is required to write output: "
            "pip install safetensors"
        )

    # If the file is already safetensors, bail out politely. Checked BEFORE
    # torch.load, which would otherwise crash on the non-pickle input.
    if _looks_like_safetensors(in_path):
        print(f"Input appears to already be a safetensors file: {in_path}")
        return

    # TRUST WARNING: torch.load uses pickle. Prefer the restricted
    # weights_only mode; fall back to the unsafe full unpickler only when
    # necessary (older torch without the kwarg, or checkpoints containing
    # custom objects).
    try:
        obj = torch.load(in_path, map_location="cpu", weights_only=True)
    except Exception:
        print(
            "Warning: weights_only load failed; falling back to full pickle "
            "load. Only do this with checkpoints from trusted sources.",
            file=sys.stderr,
        )
        obj = torch.load(in_path, map_location="cpu")

    state = _extract_state_dict(obj)

    # Optional basic metadata (safetensors requires str -> str mappings).
    meta = {"format": "converted-from-pytorch-bin"}
    if metadata:
        meta.update({str(k): str(v) for k, v in metadata.items()})

    # NOTE(review): save_file rejects tensors that share storage (e.g. tied
    # embeddings); contiguous() copies non-contiguous views but not aliased
    # full tensors — such checkpoints still need manual cloning.
    save_file(state, out_path, metadata=meta)
    print(f"✅ Wrote {out_path} with {len(state)} tensors.")


def main(argv=None):
    """CLI entry point: parse --in/--out/--meta and run the conversion."""
    parser = argparse.ArgumentParser(
        description="Convert PyTorch .bin/.pt/.ckpt to .safetensors"
    )
    parser.add_argument("--in", dest="in_path", required=True,
                        help="Input .bin/.pt/.ckpt file path")
    parser.add_argument("--out", dest="out_path", required=True,
                        help="Output .safetensors file path")
    parser.add_argument("--meta", nargs="*", default=[],
                        help="Optional metadata entries like key=value (repeatable)")
    args = parser.parse_args(argv)

    metadata = {}
    for item in args.meta:
        if "=" in item:
            k, v = item.split("=", 1)
            metadata[k] = v
        else:
            print(
                f"Warning: ignoring malformed --meta entry (expected key=value): {item}",
                file=sys.stderr,
            )

    convert_bin_to_safetensors(args.in_path, args.out_path, metadata)


if __name__ == "__main__":
    main()