# Copyright (c) Alibaba, Inc. and its affiliates.
import asyncio
import inspect
import multiprocessing
import time
from contextlib import asynccontextmanager, contextmanager
from dataclasses import asdict
from http import HTTPStatus
from threading import Thread
from typing import List, Optional, Union

import json
import uvicorn
from aiohttp import ClientConnectorError
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from swift.llm import AdapterRequest, DeployArguments
from swift.llm.infer.protocol import MultiModalRequestMixin
from swift.plugin import InferStats
from swift.utils import JsonlWriter, get_logger
from .infer import SwiftInfer
from .infer_engine import InferClient
from .protocol import ChatCompletionRequest, CompletionRequest, Model, ModelList

logger = get_logger()


class SwiftDeploy(SwiftInfer):
    args_class = DeployArguments
    args: args_class

    def _register_app(self):
        # OpenAI-compatible routes.
        self.app.get('/v1/models')(self.get_available_models)
        self.app.post('/v1/chat/completions')(self.create_chat_completion)
        self.app.post('/v1/completions')(self.create_completion)

    def __init__(self, args: Union[List[str], DeployArguments, None] = None) -> None:
        super().__init__(args)
        self.infer_engine.strict = True
        self.infer_stats = InferStats()
        self.app = FastAPI(lifespan=self.lifespan)
        self._register_app()

    async def _log_stats_hook(self):
        # Periodically log and reset throughput/latency statistics.
        while True:
            await asyncio.sleep(self.args.log_interval)
            self._compute_infer_stats()
            self.infer_stats.reset()

    def _compute_infer_stats(self):
        global_stats = self.infer_stats.compute()
        for k, v in global_stats.items():
            global_stats[k] = round(v, 8)
        logger.info(global_stats)

    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
        # FastAPI expects `lifespan` to be an async context manager.
        args = self.args
        if args.log_interval > 0:
            thread = Thread(target=lambda: asyncio.run(self._log_stats_hook()), daemon=True)
            thread.start()
        try:
            yield
        finally:
            if args.log_interval > 0:
                self._compute_infer_stats()

    def _get_model_list(self):
        args = self.args
        model_list = [args.served_model_name or args.model_suffix]
        if args.adapter_mapping:
            model_list += list(args.adapter_mapping.keys())
        return model_list

    async def get_available_models(self):
        model_list = self._get_model_list()
        data = [Model(id=model_id, owned_by=self.args.owned_by) for model_id in model_list]
        return ModelList(data=data)

    async def _check_model(self, request: ChatCompletionRequest) -> Optional[str]:
        available_models = await self.get_available_models()
        model_list = [model.id for model in available_models.data]
        if request.model not in model_list:
            return f'`{request.model}` is not in the model_list: `{model_list}`.'
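
    # Note: when the server is launched with `--api_key`, clients must send the
    # standard OpenAI-style header 'Authorization: Bearer <key>' (checked by
    # `_check_api_key` below); failing requests receive an HTTP 400 error
    # response. Illustrative call with placeholder values, e.g. using `requests`:
    #
    #     import requests
    #     resp = requests.post(
    #         'http://127.0.0.1:8000/v1/chat/completions',
    #         headers={'Authorization': 'Bearer <api-key>'},
    #         json={'model': '<served-model-name>',
    #               'messages': [{'role': 'user', 'content': 'hi'}]})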
    def _check_api_key(self, raw_request: Request) -> Optional[str]:
        api_key = self.args.api_key
        if api_key is None:
            return
        authorization = dict(raw_request.headers).get('authorization')
        error_msg = 'API key error'
        if authorization is None or not authorization.startswith('Bearer '):
            return error_msg
        request_api_key = authorization[7:]  # strip the 'Bearer ' prefix
        if request_api_key != api_key:
            return error_msg

    def _check_max_logprobs(self, request) -> Optional[str]:
        args = self.args
        if isinstance(request.top_logprobs, int) and request.top_logprobs > args.max_logprobs:
            return (f'The value of top_logprobs({request.top_logprobs}) is greater than '
                    f'the server\'s max_logprobs({args.max_logprobs}).')

    @staticmethod
    def create_error_response(status_code: Union[int, str, HTTPStatus], message: str) -> JSONResponse:
        status_code = int(status_code)
        return JSONResponse({'message': message, 'object': 'error'}, status_code)

    def _post_process(self, request_info, response, return_cmpl_response: bool = False):
        args = self.args
        # Re-encode any images in the response content as base64 data URIs.
        for i in range(len(response.choices)):
            if not hasattr(response.choices[i], 'message') or not isinstance(response.choices[i].message.content,
                                                                             (tuple, list)):
                continue
            for j, content in enumerate(response.choices[i].message.content):
                if content['type'] == 'image':
                    b64_image = MultiModalRequestMixin.to_base64(content['image'])
                    response.choices[i].message.content[j]['image'] = f'data:image/jpg;base64,{b64_image}'

        is_finished = all(response.choices[i].finish_reason for i in range(len(response.choices)))
        if 'stream' in response.__class__.__name__.lower():
            request_info['response'] += response.choices[0].delta.content
        else:
            request_info['response'] = response.choices[0].message.content
        if return_cmpl_response:
            response = response.to_cmpl_response()

        if is_finished:
            if args.log_interval > 0:
                self.infer_stats.update(response)
            if self.jsonl_writer:
                self.jsonl_writer.append(request_info)
            if self.args.verbose:
                logger.info(request_info)
        return response

    def _set_request_config(self, request_config) -> None:
        # Fill unset request fields with the server-side defaults.
        default_request_config = self.args.get_request_config()
        if default_request_config is None:
            return
        for key, val in asdict(request_config).items():
            default_val = getattr(default_request_config, key)
            if default_val is not None and (val is None or isinstance(val, (list, tuple)) and len(val) == 0):
                setattr(request_config, key, default_val)

    async def create_chat_completion(self,
                                     request: ChatCompletionRequest,
                                     raw_request: Request,
                                     *,
                                     return_cmpl_response: bool = False):
        args = self.args
        error_msg = (await self._check_model(request) or self._check_api_key(raw_request)
                     or self._check_max_logprobs(request))
        if error_msg:
            return self.create_error_response(HTTPStatus.BAD_REQUEST, error_msg)

        infer_kwargs = self.infer_kwargs.copy()
        adapter_path = args.adapter_mapping.get(request.model)
        if adapter_path:
            infer_kwargs['adapter_request'] = AdapterRequest(request.model, adapter_path)

        infer_request, request_config = request.parse()
        self._set_request_config(request_config)
        request_info = {'response': '', 'infer_request': infer_request.to_printable()}

        def pre_infer_hook(kwargs):
            request_info['generation_config'] = kwargs['generation_config']
            return kwargs

        infer_kwargs['pre_infer_hook'] = pre_infer_hook
        try:
            res_or_gen = await self.infer_async(infer_request, request_config, template=self.template, **infer_kwargs)
        except Exception as e:
            import traceback
            logger.info(traceback.format_exc())
            return self.create_error_response(HTTPStatus.BAD_REQUEST, str(e))
        if request_config.stream:

            async def _gen_wrapper():
                async for res in res_or_gen:
                    res = self._post_process(request_info, res, return_cmpl_response)
                    # Server-sent events: one JSON payload per chunk.
                    yield f'data: {json.dumps(asdict(res), ensure_ascii=False)}\n\n'
                yield 'data: [DONE]\n\n'

            return StreamingResponse(_gen_wrapper(), media_type='text/event-stream')
        else:
            return self._post_process(request_info, res_or_gen, return_cmpl_response)

    async def create_completion(self, request: CompletionRequest, raw_request: Request):
        # Reuse the chat-completion path, converting to/from the completions schema.
        chat_request = ChatCompletionRequest.from_cmpl_request(request)
        return await self.create_chat_completion(chat_request, raw_request, return_cmpl_response=True)

    def run(self):
        args = self.args
        self.jsonl_writer = JsonlWriter(args.result_path) if args.result_path else None
        logger.info(f'model_list: {self._get_model_list()}')
        uvicorn.run(
            self.app,
            host=args.host,
            port=args.port,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            log_level=args.log_level)


def deploy_main(args: Union[List[str], DeployArguments, None] = None) -> None:
    SwiftDeploy(args).main()


def is_accessible(port: int) -> bool:
    infer_client = InferClient(port=port)
    try:
        infer_client.get_model_list()
    except ClientConnectorError:
        return False
    return True


@contextmanager
def run_deploy(args: DeployArguments, return_url: bool = False):
    if isinstance(args, DeployArguments) and args.__class__.__name__ == 'DeployArguments':
        deploy_args = args
    else:
        # `args` may be a subclass or another dataclass; rebuild a plain
        # DeployArguments from the overlapping, non-None fields.
        args_dict = asdict(args)
        parameters = inspect.signature(DeployArguments).parameters
        for k in list(args_dict.keys()):
            if k not in parameters or args_dict[k] is None:
                args_dict.pop(k)
        deploy_args = DeployArguments(**args_dict)

    mp = multiprocessing.get_context('spawn')
    process = mp.Process(target=deploy_main, args=(deploy_args, ))
    process.start()
    try:
        # Block until the server answers `/v1/models`.
        while not is_accessible(deploy_args.port):
            time.sleep(1)
        yield f'http://127.0.0.1:{deploy_args.port}/v1' if return_url else deploy_args.port
    finally:
        process.terminate()
        logger.info('The deployment process has been terminated.')
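

# Usage sketch (illustrative, not part of the module): `run_deploy` spawns the
# server in a subprocess, blocks until `/v1/models` responds, yields the base
# URL (or just the port), and terminates the subprocess on exit. The model
# value below is a placeholder.
#
#     from swift.llm import DeployArguments
#
#     with run_deploy(DeployArguments(model='<model-id-or-path>'), return_url=True) as url:
#         print(url)  # e.g. http://127.0.0.1:8000/v1, usable with any OpenAI-compatible client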