import gc

import yaml
import torch
from transformers import GenerationConfig

from models import alpaca, stablelm, koalpaca, flan_alpaca, mpt
from models import camel, t5_vicuna, vicuna, starchat, redpajama, bloom
from models import baize, guanaco, falcon, kullm, replit, airoboros
from models import samantha_vicuna, wizard_coder, xgen, freewilly
from models import byom

cuda_availability = False
available_vrams_mb = 0
mps_availability = False

if torch.cuda.is_available():
    cuda_availability = True
    # total VRAM across all visible GPUs, in MB
    available_vrams_mb = sum(
        torch.cuda.get_device_properties(i).total_memory
        for i in range(torch.cuda.device_count())
    ) / 1024. / 1024

if torch.backends.mps.is_available():
    mps_availability = True

def initialize_globals_byom(
    base, ckpt, model_cls, tokenizer_cls,
    bos_token_id, eos_token_id, pad_token_id,
    mode_cpu, mode_mps, mode_8bit, mode_4bit, mode_full_gpu
):
    global model, model_type, stream_model, tokenizer
    global model_thumbnail_tiny, device
    global gen_config, gen_config_raw
    global gen_config_summarization

    model_type = "custom"
    model, tokenizer = byom.load_model(
        base=base,
        finetuned=ckpt,
        mode_cpu=mode_cpu,
        mode_mps=mode_mps,
        mode_full_gpu=mode_full_gpu,
        mode_8bit=mode_8bit,
        mode_4bit=mode_4bit,
        model_cls=model_cls if model_cls != "" else None,
        tokenizer_cls=tokenizer_cls if tokenizer_cls != "" else None
    )
    stream_model = model

    gen_config, gen_config_raw = get_generation_config("configs/response_configs/default.yaml")
    gen_config_summarization, _ = get_generation_config("configs/summarization_configs/default.yaml")

    # only override token ids when a non-empty numeric string is provided
    if bos_token_id != "" and bos_token_id.isdigit():
        gen_config.bos_token_id = int(bos_token_id)
    if eos_token_id != "" and eos_token_id.isdigit():
        gen_config.eos_token_id = int(eos_token_id)
    if pad_token_id != "" and pad_token_id.isdigit():
        gen_config.pad_token_id = int(pad_token_id)

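# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module). initialize_globals_byom()
# takes class names and token ids as strings; an empty string means "unset"
# (classes fall back to byom.load_model's defaults, token ids keep whatever the
# generation config defines). The repo id and values below are placeholders for
# illustration only.
#
#   initialize_globals_byom(
#       base="some-org/some-base-model",   # hypothetical Hugging Face repo id
#       ckpt="",                           # no separate finetuned checkpoint
#       model_cls="", tokenizer_cls="",    # let byom.load_model pick the classes
#       bos_token_id="", eos_token_id="", pad_token_id="0",
#       mode_cpu=False, mode_mps=False,
#       mode_8bit=True, mode_4bit=False, mode_full_gpu=False,
#   )
# ---------------------------------------------------------------------------
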
def initialize_globals(args):
    global device, model_thumbnail_tiny
    global model, model_type, stream_model, tokenizer
    global gen_config, gen_config_raw
    global gen_config_summarization

    # determine the model family from substrings of the base / finetuned checkpoint URLs
    model_type_tmp = "alpaca"
    if "stabilityai/freewilly2" in args.base_url.lower():
        model_type_tmp = "free-willy"
    elif "upstage/llama-" in args.base_url.lower():
        model_type_tmp = "upstage-llama"
    elif "llama-2" in args.base_url.lower():
        model_type_tmp = "llama2"
    elif "xgen" in args.base_url.lower():
        model_type_tmp = "xgen"
    elif "orca_mini" in args.base_url.lower():
        model_type_tmp = "orcamini"
    elif "open-llama" in args.base_url.lower():
        model_type_tmp = "openllama"
    elif "wizardcoder" in args.base_url.lower():
        model_type_tmp = "wizard-coder"
    elif "wizard-vicuna" in args.base_url.lower():
        model_type_tmp = "wizard-vicuna"
    elif "llms/wizardlm" in args.base_url.lower():
        model_type_tmp = "wizardlm"
    elif "chronos" in args.base_url.lower():
        model_type_tmp = "chronos"
    elif "lazarus" in args.base_url.lower():
        model_type_tmp = "lazarus"
    elif "samantha" in args.base_url.lower():
        model_type_tmp = "samantha-vicuna"
    elif "airoboros" in args.base_url.lower():
        model_type_tmp = "airoboros"
    elif "replit" in args.base_url.lower():
        model_type_tmp = "replit-instruct"
    elif "kullm" in args.base_url.lower():
        model_type_tmp = "kullm-polyglot"
    elif "nous-hermes" in args.base_url.lower():
        model_type_tmp = "nous-hermes"
    elif "guanaco" in args.base_url.lower():
        model_type_tmp = "guanaco"
    elif "wizardlm-uncensored-falcon" in args.base_url.lower():
        model_type_tmp = "wizard-falcon"
    elif "falcon" in args.base_url.lower():
        model_type_tmp = "falcon"
    elif "baize" in args.base_url.lower():
        model_type_tmp = "baize"
    elif "stable-vicuna" in args.base_url.lower():
        model_type_tmp = "stable-vicuna"
    elif "vicuna" in args.base_url.lower():
        model_type_tmp = "vicuna"
    elif "mpt" in args.base_url.lower():
        model_type_tmp = "mpt"
    elif "redpajama-incite-7b-instruct" in args.base_url.lower():
        model_type_tmp = "redpajama-instruct"
    elif "redpajama" in args.base_url.lower():
        model_type_tmp = "redpajama"
    elif "starchat" in args.base_url.lower():
        model_type_tmp = "starchat"
    elif "camel" in args.base_url.lower():
        model_type_tmp = "camel"
    elif "flan-alpaca" in args.base_url.lower():
        model_type_tmp = "flan-alpaca"
    elif "openassistant/stablelm" in args.base_url.lower():
        model_type_tmp = "os-stablelm"
    elif "stablelm" in args.base_url.lower():
        model_type_tmp = "stablelm"
    elif "fastchat-t5" in args.base_url.lower():
        model_type_tmp = "t5-vicuna"
    elif "koalpaca-polyglot" in args.base_url.lower():
        model_type_tmp = "koalpaca-polyglot"
    elif "alpacagpt4" in args.ft_ckpt_url.lower():
        model_type_tmp = "alpaca-gpt4"
    elif "alpaca" in args.ft_ckpt_url.lower():
        model_type_tmp = "alpaca"
    elif "llama-deus" in args.ft_ckpt_url.lower():
        model_type_tmp = "llama-deus"
    elif "vicuna-lora-evolinstruct" in args.ft_ckpt_url.lower():
        model_type_tmp = "evolinstruct-vicuna"
    elif "alpacoom" in args.ft_ckpt_url.lower():
        model_type_tmp = "alpacoom"
    elif "guanaco" in args.ft_ckpt_url.lower():
        model_type_tmp = "guanaco"
    else:
        print("unsupported model type")
        quit()

    print(f"determined model type: {model_type_tmp}")
| device = "cpu" | |
| if args.mode_cpu: | |
| device = "cpu" | |
| elif args.mode_mps: | |
| device = "mps" | |
| else: | |
| device = "cuda" | |
| try: | |
| if model is not None: | |
| del model | |
| if stream_model is not None: | |
| del stream_model | |
| if tokenizer is not None: | |
| del tokenizer | |
| gc.collect() | |
| if device == "cuda": | |
| torch.cuda.empty_cache() | |
| elif device == "mps": | |
| torch.mps.empty_cache() | |
| except NameError: | |
| pass | |
| model_type = model_type_tmp | |
| load_model = get_load_model(model_type_tmp) | |
| model, tokenizer = load_model( | |
| base=args.base_url, | |
| finetuned=args.ft_ckpt_url, | |
| mode_cpu=args.mode_cpu, | |
| mode_mps=args.mode_mps, | |
| mode_full_gpu=args.mode_full_gpu, | |
| mode_8bit=args.mode_8bit, | |
| mode_4bit=args.mode_4bit, | |
| force_download_ckpt=args.force_download_ckpt, | |
| local_files_only=args.local_files_only | |
| ) | |
| model.eval() | |
| model_thumbnail_tiny = args.thumbnail_tiny | |
| gen_config, gen_config_raw = get_generation_config(args.gen_config_path) | |
| gen_config_summarization, _ = get_generation_config(args.gen_config_summarization_path) | |
| stream_model = model | |
def get_load_model(model_type):
    if model_type in (
        "alpaca", "alpaca-gpt4", "llama-deus", "nous-hermes", "lazarus",
        "chronos", "wizardlm", "openllama", "orcamini", "llama2", "upstage-llama"
    ):
        return alpaca.load_model
    elif model_type == "free-willy":
        return freewilly.load_model
    elif model_type in ("stablelm", "os-stablelm"):
        return stablelm.load_model
    elif model_type == "koalpaca-polyglot":
        return koalpaca.load_model
    elif model_type == "kullm-polyglot":
        return kullm.load_model
    elif model_type == "flan-alpaca":
        return flan_alpaca.load_model
    elif model_type == "camel":
        return camel.load_model
    elif model_type == "t5-vicuna":
        return t5_vicuna.load_model
    elif model_type == "stable-vicuna":
        return vicuna.load_model
    elif model_type == "starchat":
        return starchat.load_model
    elif model_type == "wizard-coder":
        return wizard_coder.load_model
    elif model_type == "mpt":
        return mpt.load_model
    elif model_type in ("redpajama", "redpajama-instruct"):
        return redpajama.load_model
    elif model_type == "vicuna":
        return vicuna.load_model
    elif model_type in ("evolinstruct-vicuna", "wizard-vicuna"):
        return alpaca.load_model
    elif model_type == "alpacoom":
        return bloom.load_model
    elif model_type == "baize":
        return baize.load_model
    elif model_type == "guanaco":
        return guanaco.load_model
    elif model_type in ("falcon", "wizard-falcon"):
        return falcon.load_model
    elif model_type == "replit-instruct":
        return replit.load_model
    elif model_type == "airoboros":
        return airoboros.load_model
    elif model_type == "samantha-vicuna":
        return samantha_vicuna.load_model
    elif model_type == "xgen":
        return xgen.load_model
    else:
        return None

def get_generation_config(path):
    with open(path, 'rb') as f:
        generation_config = yaml.safe_load(f.read())

    generation_config = generation_config["generation_config"]
    return GenerationConfig(**generation_config), generation_config

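# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not a copy of the project's config files):
# get_generation_config() expects a YAML document with a top-level
# "generation_config" key whose fields are valid transformers.GenerationConfig
# arguments, for example:
#
#   generation_config:
#     do_sample: true
#     temperature: 0.7
#     top_p: 0.9
#     max_new_tokens: 256
#
# Example call:
#
#   gen_config, gen_config_raw = get_generation_config(
#       "configs/response_configs/default.yaml"
#   )
# ---------------------------------------------------------------------------
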
def get_constraints_config(path):
    with open(path, 'rb') as f:
        constraints_config = yaml.safe_load(f.read())

    # ConstraintsConfig is not defined in this module; it is expected to be
    # imported from elsewhere in the project.
    return ConstraintsConfig(**constraints_config), constraints_config["constraints"]
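
# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module). initialize_globals()
# only assumes an object exposing the attributes read above, so a minimal
# driver could pass an argparse.Namespace. The repo ids and paths below are
# placeholders, not the project's actual defaults.
#
#   from argparse import Namespace
#
#   args = Namespace(
#       base_url="some-org/llama-2-style-base",      # matched as "llama2" above
#       ft_ckpt_url="",                              # no adapter checkpoint
#       mode_cpu=False, mode_mps=False, mode_full_gpu=True,
#       mode_8bit=False, mode_4bit=False,
#       force_download_ckpt=False, local_files_only=False,
#       thumbnail_tiny="assets/some-thumbnail.png",  # hypothetical path
#       gen_config_path="configs/response_configs/default.yaml",
#       gen_config_summarization_path="configs/summarization_configs/default.yaml",
#   )
#   initialize_globals(args)
# ---------------------------------------------------------------------------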