# llm_router.py
import logging

from models_config import LLM_CONFIG

logger = logging.getLogger(__name__)


class LLMRouter:
    def __init__(self, hf_token):
        self.hf_token = hf_token
        self.health_status = {}
        logger.info("LLMRouter initialized")
        if hf_token:
            logger.info("HF token available")
        else:
            logger.warning("No HF token provided")

    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """
        Smart routing based on task specialization
        """
        logger.info(f"Routing inference for task: {task_type}")

        model_config = self._select_model(task_type)
        logger.info(f"Selected model: {model_config['model_id']}")

        # Health check and fallback logic
        if not await self._is_model_healthy(model_config["model_id"]):
            logger.warning("Model unhealthy, using fallback")
            model_config = self._get_fallback_model(task_type)
            logger.info(f"Fallback model: {model_config['model_id']}")

        result = await self._call_hf_endpoint(model_config, prompt, **kwargs)
        logger.info(f"Inference complete for {task_type}")
        return result

    def _select_model(self, task_type: str) -> dict:
        model_map = {
            "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["safety_checker"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"],
        }
        return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _is_model_healthy(self, model_id: str) -> bool:
        """
        Check if the model is healthy and available.

        Models are marked healthy by default - actual availability is checked
        at API call time.
        """
        # Check cached health status
        if model_id in self.health_status:
            return self.health_status[model_id]

        # All models marked healthy initially - real check happens during API call
        self.health_status[model_id] = True
        return True

    def _get_fallback_model(self, task_type: str) -> dict:
        """
        Get the fallback model configuration for the task type
        """
        # Fallback mapping
        fallback_map = {
            "intent_classification": LLM_CONFIG["models"]["reasoning_primary"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["reasoning_primary"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"],
        }
        return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _call_hf_endpoint(self, model_config: dict, prompt: str, **kwargs):
        """
        Make the actual call to the Hugging Face Chat Completions API,
        using the chat completions protocol.
        """
        try:
            import requests

            model_id = model_config["model_id"]

            # Use the chat completions endpoint
            api_url = "https://router.huggingface.co/v1/chat/completions"

            logger.info(f"Calling HF Chat Completions API for model: {model_id}")
            logger.debug(f"Prompt length: {len(prompt)}")

            headers = {
                "Authorization": f"Bearer {self.hf_token}",
                "Content-Type": "application/json",
            }

            # Prepare payload in chat completions format.
            # Extract the actual question from the prompt if it's in a structured format.
            user_message = (
                prompt
                if "User Question:" not in prompt
                else prompt.split("User Question:")[1].split("\n")[0].strip()
            )

            payload = {
                "model": f"{model_id}:together",  # Use the Together endpoint as specified
                "messages": [
                    {
                        "role": "user",
                        "content": user_message,
                    }
                ],
                "max_tokens": kwargs.get("max_tokens", 2000),
                "temperature": kwargs.get("temperature", 0.7),
                "top_p": kwargs.get("top_p", 0.95),
            }

            # Make the API call.
            # NOTE: requests is synchronous, so this blocks the event loop while waiting.
            response = requests.post(api_url, json=payload, headers=headers, timeout=60)

            if response.status_code == 200:
                result = response.json()

                # Handle chat completions response format
                if "choices" in result and len(result["choices"]) > 0:
                    message = result["choices"][0].get("message", {})
                    generated_text = message.get("content", "")

                    # Return None on an empty or non-string response so callers can handle the failure
                    if not generated_text or not isinstance(generated_text, str):
                        logger.warning("Empty or invalid response from HF API")
                        return None

                    logger.info(f"HF API returned response (length: {len(generated_text)})")
                    return generated_text
                else:
                    logger.error(f"Unexpected response format: {result}")
                    return None

            elif response.status_code == 503:
                # Model is loading - retry with the fallback model, but only if it is a
                # different model, to avoid recursing on the same unavailable endpoint
                fallback_config = self._get_fallback_model("response_synthesis")
                if fallback_config["model_id"] != model_id:
                    logger.warning("Model loading (503), trying fallback")
                    return await self._call_hf_endpoint(fallback_config, prompt, **kwargs)
                logger.error("Model loading (503) and no distinct fallback available")
                return None

            else:
                logger.error(f"HF API error: {response.status_code} - {response.text}")
                return None

        except ImportError:
            logger.warning("requests library not available, using mock response")
            return f"[Mock] Response to: {prompt[:100]}..."
        except Exception as e:
            logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
            return None
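

# --- Illustrative usage sketch (not part of the original module) ---
# Shows how the router might be driven from an asyncio entry point. The HF_TOKEN
# environment variable name and the sample prompt below are assumptions made for
# this example only; adjust them to your deployment.
if __name__ == "__main__":
    import asyncio
    import os

    async def _demo():
        router = LLMRouter(hf_token=os.environ.get("HF_TOKEN"))
        answer = await router.route_inference(
            "general_reasoning",
            "User Question: What does this router do?\n",
            max_tokens=512,
            temperature=0.3,
        )
        print(answer)

    asyncio.run(_demo())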