bbb / swift /llm /export /ollama.py
novateur's picture
Add files using upload-large-folder tool
0955071 verified
raw
history blame
3.3 kB
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import List
from swift.llm import ExportArguments, PtEngine, RequestConfig, Template, prepare_model_template
from swift.utils import get_logger
# Module-level logger used by the export routine in this file.
logger = get_logger()
def replace_and_concat(template: 'Template', template_list: List, placeholder: str, keyword: str) -> str:
    """Render a list of template fragments into one string.

    Each element of ``template_list`` is handled according to its type:

    * ``str``: appended after every occurrence of ``placeholder`` has been
      replaced with ``keyword``.
    * non-empty ``list``/``tuple`` whose first item is an ``int``: passed to
      ``template.tokenizer.decode`` and the decoded text is appended.
    * any other non-empty ``list``/``tuple``: every item must be
      ``'bos_token_id'`` or ``'eos_token_id'`` and is replaced by
      ``template.tokenizer.bos_token`` / ``template.tokenizer.eos_token``.
    * empty ``list``/``tuple``: contributes nothing.

    Elements of any other type are ignored.

    Args:
        template: Object exposing a ``tokenizer`` attribute with ``decode``,
            ``bos_token`` and ``eos_token``.
        template_list: The fragments to render, in order.
        placeholder: Substring to look for inside string fragments.
        keyword: Replacement text for ``placeholder``.

    Returns:
        The concatenation of all rendered fragments.

    Raises:
        ValueError: If a non-integer sequence contains a name other than
            ``'bos_token_id'`` or ``'eos_token_id'``.
    """
    tokenizer = template.tokenizer
    parts = []
    for t in template_list:
        if isinstance(t, str):
            parts.append(t.replace(placeholder, keyword))
        elif isinstance(t, (tuple, list)):
            if not t:
                # Nothing to render; indexing t[0] below would raise IndexError.
                continue
            if isinstance(t[0], int):
                # A sequence of token ids: turn it back into text.
                parts.append(tokenizer.decode(t))
            else:
                # A sequence of special-token attribute names.
                for attr in t:
                    if attr == 'bos_token_id':
                        parts.append(tokenizer.bos_token)
                    elif attr == 'eos_token_id':
                        parts.append(tokenizer.eos_token)
                    else:
                        raise ValueError(f'Unknown token: {attr}')
    return ''.join(parts)
def export_to_ollama(args: ExportArguments):
    """Generate an Ollama ``Modelfile`` inside ``args.output_dir``.

    The model and template are prepared from ``args``. The template's
    ``system_prefix``, ``prefix``, ``prompt`` and ``suffix`` fragments are
    rendered into a ``TEMPLATE`` block that uses the ``{{ .System }}``,
    ``{{ .Prompt }}`` and ``{{ .Response }}`` variables. The rendered suffix
    and the stop words from the engine's generation config are written as
    ``PARAMETER stop`` lines, followed by the sampling parameters.

    The full file content is assembled before the file is opened, so an error
    while rendering the template or preparing the generation config does not
    leave a truncated ``Modelfile`` behind.

    Args:
        args: Export arguments. ``output_dir``, ``temperature``, ``top_k``,
            ``top_p`` and ``repetition_penalty`` are read here, and
            ``device_map`` is overwritten with ``'meta'``.
    """
    args.device_map = 'meta'  # Accelerate load speed.
    logger.info('Exporting to ollama:')
    os.makedirs(args.output_dir, exist_ok=True)
    model, template = prepare_model_template(args)
    pt_engine = PtEngine.from_model_template(model, template)
    logger.info(f'Using model_dir: {pt_engine.model_dir}')
    template_meta = template.template_meta

    # Render every template fragment once, up front.
    system_prefix = replace_and_concat(template, template_meta.system_prefix, '{{SYSTEM}}', '{{ .System }}')
    prefix = replace_and_concat(template, template_meta.prefix, '', '')
    prompt = replace_and_concat(template, template_meta.prompt, '{{QUERY}}', '{{ .Prompt }}')
    # The suffix is used twice: it closes the TEMPLATE block and is a stop string.
    suffix = replace_and_concat(template, template_meta.suffix, '', '')

    request_config = RequestConfig(
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty)
    generation_config = pt_engine._prepare_generation_config(request_config)
    pt_engine._add_stop_words(generation_config, request_config, template.template_meta)

    content = [
        f'FROM {pt_engine.model_dir}\n',
        # Use the system prefix when a system message is given, else the plain prefix.
        'TEMPLATE """{{ if .System }}',
        system_prefix,
        '{{ else }}',
        prefix,
        '{{ end }}',
        '{{ if .Prompt }}',
        prompt,
        '{{ end }}',
        '{{ .Response }}',
        suffix,
        '"""\n',
        f'PARAMETER stop "{suffix}"\n',
    ]
    content.extend(f'PARAMETER stop "{stop_word}"\n' for stop_word in generation_config.stop_words)
    content.extend([
        f'PARAMETER temperature {generation_config.temperature}\n',
        f'PARAMETER top_k {generation_config.top_k}\n',
        f'PARAMETER top_p {generation_config.top_p}\n',
        # Ollama's parameter is called repeat_penalty.
        f'PARAMETER repeat_penalty {generation_config.repetition_penalty}\n',
    ])

    modelfile_path = os.path.join(args.output_dir, 'Modelfile')
    with open(modelfile_path, 'w', encoding='utf-8') as f:
        f.write(''.join(content))

    logger.info('Save Modelfile done, you can start ollama by:')
    logger.info('> ollama serve')
    logger.info('In another terminal:')
    logger.info(f'> ollama create my-custom-model -f {modelfile_path}')
    logger.info('> ollama run my-custom-model')