Spaces:
Running
Running
| import gradio as gr | |
| import subprocess | |
| import os | |
| import logging | |
| from pathlib import Path | |
| import spaces | |
def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message):
    """Merge two Hugging Face models via a DARE (Super Mario) merge and upload the result.

    Delegates the actual merge + Hub upload to the companion script
    ``hf_merge.py``, run as a subprocess.

    Args:
        base_model: Repo id of the base model (safetensors format).
        model_to_merge: Repo id of the model merged into the base.
        scaling_factor: DARE lambda scaling factor, passed as ``-lambda``.
        weight_drop_prob: DARE weight-drop probability, passed as ``-p``.
        repo_name: Target repo id for the merged model (``user/new-model``).
        token: Hugging Face write token.
        commit_message: Commit message used for the upload.

    Returns:
        Tuple ``(repo_url, status_message, log_output)``; ``repo_url`` is
        ``None`` when the merge failed.
    """
    # Fail fast on missing required fields instead of spawning a subprocess
    # that is guaranteed to fail (and would leak a less clear error).
    if not token or not repo_name:
        return None, "Error: a write token and a target repo name are required.", ""

    # List form (shell=False) avoids shell injection via user-supplied fields.
    # NOTE(review): the token is passed on the command line, which is visible
    # in the process table — acceptable inside a single-tenant Space.
    command = [
        "python3", "hf_merge.py",
        base_model,
        model_to_merge,
        "-p", str(weight_drop_prob),
        "-lambda", str(scaling_factor),
        "--token", token,
        "--repo", repo_name,
        "--commit-message", commit_message,
        "-U",
    ]

    logging.basicConfig(level=logging.INFO)

    # Run the merge script and capture both streams for the UI log pane.
    result = subprocess.run(command, capture_output=True, text=True)

    log_output = result.stdout + "\n" + result.stderr + "\n"
    if result.stdout:
        logging.info(result.stdout)
    if result.stderr:
        # Only log stderr at ERROR level when something was actually emitted;
        # the original unconditionally logged an empty error line on success.
        logging.error(result.stderr)

    if result.returncode != 0:
        return None, f"Error in merging models: {result.stderr}", log_output

    # hf_merge.py performs the upload itself; report the resulting repo URL.
    repo_url = f"https://huggingface.co/{repo_name}"
    return repo_url, "Model merged and uploaded successfully!", log_output
# Define the Gradio interface
with gr.Blocks(theme="Ytheme/Minecraft", fill_width=True, delete_cache=(60, 3600)) as demo:
    gr.Markdown("# SuperMario Safetensors Merger")
    gr.Markdown("Combine any two models using a Super Mario merge(DARE)")
    gr.Markdown("Based on: https://github.com/martyn/safetensors-merge-supermario")
    gr.Markdown("Works with:")
    gr.Markdown("* Stable Diffusion (1.5, XL/XL Turbo)")
    # Typo fix in user-facing text: "Visison" -> "Vision".
    gr.Markdown("* LLMs (Mistral, Llama, etc) (also works with Llava, Vision models) ")
    gr.Markdown("* LoRas (must be same size)")
    gr.Markdown("* Any two homologous models")
    with gr.Column():
        with gr.Row():
            token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1)
        with gr.Row():
            base_model = gr.Textbox(label="Base Model", placeholder="meta-llama/Llama-3.2-11B-Vision-Instruct", info="Safetensors format")
        with gr.Row():
            model_to_merge = gr.Textbox(label="Merge Model", placeholder="Qwen/Qwen2.5-Coder-7B-Instruct", info="Safetensors or .bin")
        with gr.Row():
            repo_name = gr.Textbox(label="New Model", placeholder="Llama-Qwen-Vision_Instruct", info="your-username/new-model-name", value="", max_lines=1)
        with gr.Row():
            scaling_factor = gr.Slider(minimum=0, maximum=10, value=3.0, label="Scaling Factor")
        with gr.Row():
            weight_drop_prob = gr.Slider(minimum=0, maximum=1, value=0.3, label="Weight Drop Probability")
        with gr.Row():
            commit_message = gr.Textbox(label="Commit Message", value="Upload merged model", max_lines=1)
    progress = gr.Progress()
    repo_url = gr.Markdown(label="Repository URL")
    output = gr.Textbox(label="Output")
    # merge_and_upload returns THREE values (url, status, log); the original
    # wired only two output components, which makes Gradio raise on click.
    log_output = gr.Textbox(label="Log", lines=10)
    gr.Button("Merge").click(
        merge_and_upload,
        inputs=[base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message],
        outputs=[repo_url, output, log_output],
    )
demo.launch()