import torch
from transformers import AutoTokenizer

# Import your custom model class.
# Assume the inference environment handles adding the repository root to the Python path.
from modelling_trm import TRM, TRMConfig

import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, model_path="."):  # Accept model_path during initialization
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_path = model_path  # Store model_path

        logger.info(f"Initializing model from directory: {self.model_path}")

        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Loading model on device: {self.device}")

            # Load the config
            logger.info("Attempting to load model config...")
            config = TRMConfig.from_pretrained(self.model_path)
            logger.info("Model config loaded successfully.")
            logger.debug(f"Config type: {type(config)}")

            # Load the model
            logger.info("Attempting to load model...")
            self.model = TRM.from_pretrained(self.model_path, config=config)
            logger.info("Model loaded successfully.")
            logger.debug(f"Model type: {type(self.model)}")
            logger.debug(f"Is model callable: {callable(self.model)}")

            self.model.to(self.device)
            self.model.eval()  # Set model to evaluation mode
            logger.info("Model set to evaluation mode.")

            # Load the tokenizer, falling back to a placeholder if the model path
            # does not ship one. Adapt this to your actual tokenizer.
            logger.info("Attempting to load tokenizer...")
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
                logger.info("Tokenizer loaded successfully from model path.")
                logger.debug(f"Tokenizer type: {type(self.tokenizer)}")
                logger.debug(f"Is tokenizer callable: {callable(self.tokenizer)}")
            except Exception as e:
                logger.warning(f"Failed to load tokenizer from model path: {e}. Falling back to basic tokenizer.")
                # Fall back to a basic tokenizer if loading from the model path fails
                try:
                    self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
                    logger.info("Loaded a placeholder tokenizer (bert-base-uncased) for inference.")
                    logger.debug(f"Tokenizer type (fallback): {type(self.tokenizer)}")
                    logger.debug(f"Is tokenizer callable (fallback): {callable(self.tokenizer)}")
                except Exception as e_fallback:
                    logger.error(f"Failed to load placeholder tokenizer: {e_fallback}")
                    self.tokenizer = None  # Ensure tokenizer is None if the fallback fails

            # Check that the tokenizer is callable after the loading attempts
            if self.tokenizer is None or not callable(self.tokenizer):
                logger.error("Loaded tokenizer is not callable!")
                # Optionally raise an error here if the tokenizer is essential:
                # raise TypeError("Loaded tokenizer is not callable.")

        except Exception as e:
            logger.error(f"Error during model initialization: {e}", exc_info=True)
            raise  # Re-raise the exception to signal initialization failure

        # Final checks after initialization
        if self.model is None or not callable(self.model):
            logger.error("Model is None or not callable after initialization!")
            raise RuntimeError("Model failed to initialize correctly.")
        if self.tokenizer is None or not callable(self.tokenizer):
            logger.error("Tokenizer is None or not callable after initialization!")
            # Depending on whether the tokenizer is strictly required by handle,
            # you might raise an error here or handle it in preprocess/inference:
            # raise RuntimeError("Tokenizer failed to initialize correctly.")

    def preprocess(self, inputs):
        logger.info("Starting preprocess.")
        logger.debug(f"Preprocess input type: {type(inputs)}")
        logger.debug(f"Preprocess input: {inputs}")

        # Check that the tokenizer is callable before using it
        if self.tokenizer is None or not callable(self.tokenizer):
            logger.error("Tokenizer is not available or not callable during preprocess.")
            raise RuntimeError("Tokenizer is not available for preprocessing.")

        try:
            # Preprocess the data received by the inference endpoint; adapt this to the
            # expected input format. For text generation, 'inputs' can be a string or a
            # list of strings.
            if isinstance(inputs, str):
                inputs = [inputs]
                logger.debug("Input is a string, converted to list.")
            elif not isinstance(inputs, list):
                logger.error(f"Input must be a string or a list of strings, but got {type(inputs)}.")
                raise ValueError("Input must be a string or a list of strings.")
            else:
                logger.debug("Input is already a list.")

            # Tokenize the input with padding and truncation.
            # Use a default max_length if the tokenizer does not define one.
            max_len = getattr(self.tokenizer, "model_max_length", 4096)
            logger.debug(f"Using max_length for tokenization: {max_len}")
            tokenized_inputs = self.tokenizer(
                inputs,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_len,
            )
            logger.info("Inputs tokenized.")
            logger.debug(f"Tokenized inputs keys: {tokenized_inputs.keys()}")
            logger.debug(f"Tokenized inputs['input_ids'] shape: {tokenized_inputs['input_ids'].shape}")

            # Move tokenized inputs to the model's device
            tokenized_inputs = {k: v.to(self.device) for k, v in tokenized_inputs.items()}
            logger.info("Tokenized inputs moved to device.")
            logger.debug(f"Input tensor device: {tokenized_inputs['input_ids'].device}")

            # Return only the inputs expected by the TRM model;
            # based on training, TRM only takes 'input_ids'.
            model_input = {"input_ids": tokenized_inputs["input_ids"]}
            logger.info("Preprocess complete.")
            logger.debug(f"Model input keys: {model_input.keys()}")
            logger.debug(f"Model input['input_ids'] shape: {model_input['input_ids'].shape}")

            return model_input

        except Exception as e:
            logger.error(f"Error during preprocess: {e}", exc_info=True)
            raise  # Re-raise the exception

    def inference(self, inputs):
        logger.info("Starting inference.")
        logger.debug(f"Inference input keys: {inputs.keys()}")
        logger.debug(f"Inference input['input_ids'] shape: {inputs['input_ids'].shape}")

        # Check that the model is callable before using it
        if self.model is None or not callable(self.model):
            logger.error("Model is not available or not callable during inference.")
            raise RuntimeError("Model is not available for inference.")

        try:
            # Run the model on the output of the preprocess method
            with torch.no_grad():
                # Forward pass; the model only takes 'input_ids'
                outputs = self.model(**inputs)

            logger.info("Model forward pass complete.")
            logger.debug(f"Model output type: {type(outputs)}")
            if isinstance(outputs, dict):
                logger.debug(f"Model output keys: {outputs.keys()}")
            # logger.debug(f"Model output: {outputs}")  # Uncomment for detailed output inspection

            # The output structure may differ; adapt this to the actual TRM output.
            # For text generation you might use model.generate() instead of a single
            # forward pass; this example performs one forward pass and returns logits.
            # Based on previous debugging, the output is a dict with a 'logits' key.
            if isinstance(outputs, dict) and "logits" in outputs:
                logits = outputs["logits"]
                logger.debug("Accessed logits from dictionary output.")
            elif hasattr(outputs, "logits"):
                logits = outputs.logits
                logger.debug("Accessed logits from attribute.")
            else:
                logger.error("Model output does not contain 'logits'.")
                raise AttributeError("Model output does not contain 'logits'.")

            logger.info("Inference complete.")
            logger.debug(f"Logits shape: {logits.shape}")

            return logits  # Or process the logits further for text generation

        except Exception as e:
            logger.error(f"Error during inference: {e}", exc_info=True)
            raise  # Re-raise the exception

    def postprocess(self, outputs):
        logger.info("Starting postprocess for text generation (greedy decoding).")
        # 'outputs' is the output of the inference method (logits).
        # This implements a single greedy decoding step.

        if self.tokenizer is None:
            logger.error("Tokenizer is not available for postprocessing.")
            raise RuntimeError("Tokenizer is not available for postprocessing.")

        try:
            # 'outputs' are logits of shape (batch_size, sequence_length, vocab_size).
            # A real text generation handler would take the initial input_ids and loop,
            # appending the argmax token until a stop condition is met; the handle method
            # would have to manage that loop. Since this method only sees the logits of a
            # single forward pass, it simply decodes the most probable next token.

            # Assume inputs were processed one at a time (batch size 1 for simplicity)
            # and outputs are the logits for the input sequence.
            if outputs.ndim == 3 and outputs.shape[0] == 1:  # Shape (1, seq_len, vocab_size)
                last_token_logits = outputs[0, -1, :]  # Logits for the last position in the sequence
                predicted_token_id = torch.argmax(last_token_logits).item()

                # Decode the predicted token
                predicted_text = self.tokenizer.decode([predicted_token_id])
                logger.info(f"Predicted next token: {predicted_text} (ID: {predicted_token_id})")

                # In a real generation loop you would append this token to the input and repeat.
                # For this handler, return just the predicted token text.
                return predicted_text
            else:
                logger.warning(f"Unexpected output shape for greedy decoding: {outputs.shape}. Returning raw logits list.")
                return outputs.cpu().tolist()

        except Exception as e:
            logger.error(f"Error during postprocess: {e}", exc_info=True)
            raise  # Re-raise the exception

    def handle(self, data):
        logger.info("Starting handle method.")
        logger.debug(f"Received data in handle: {data}")

        try:
            # 1. Preprocess
            logger.info("Calling preprocess...")
            model_input = self.preprocess(data)
            logger.info("Preprocessing successful.")

            # 2. Inference (single forward pass to get logits)
            logger.info("Calling inference...")
            model_output_logits = self.inference(model_input)
            logger.info("Inference successful.")

            # 3. Postprocess (greedy decoding of the next token).
            # Note: this only generates the *next* token. Full text generation would
            # require a loop here or a modified inference step; see the sketch at the
            # end of this file.
            logger.info("Calling postprocess...")
            response = self.postprocess(model_output_logits)
            logger.info("Postprocessing successful.")

            logger.info("Handle method complete.")
            return response

        except Exception as e:
            logger.error(f"Error in handle method: {e}", exc_info=True)
            # Depending on requirements, you might want to return an error response instead:
            # return {"error": str(e)}
            raise  # Re-raise the exception to be caught by the inference toolkit


# Example usage (for testing locally)
# if __name__ == "__main__":
#     # This assumes the model files are in the current directory; otherwise pass model_path,
#     # e.g. handler = EndpointHandler(model_path="./custom_training_output")
#     handler = EndpointHandler()  # Assuming model files are in "."
#     test_input = "This is a test input"
#     output = handler.handle(test_input)
#     print("Inference output:", output)
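
# --- Optional extension: full greedy generation loop ---
# The handler above only decodes the single most probable next token. The helper below is
# a minimal sketch of how a complete greedy generation loop could reuse the same model and
# tokenizer; it is illustrative, not part of the handler's interface, and assumes the TRM
# forward pass accepts 'input_ids' as a keyword argument and returns logits of shape
# (batch_size, sequence_length, vocab_size), as in inference() above.
def generate_greedy(handler, prompt, max_new_tokens=32):
    """Greedily decode up to max_new_tokens continuation tokens for a single prompt."""
    input_ids = handler.tokenizer(prompt, return_tensors="pt")["input_ids"].to(handler.device)
    eos_id = getattr(handler.tokenizer, "eos_token_id", None)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = handler.model(input_ids=input_ids)
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
            next_id = torch.argmax(logits[0, -1, :]).item()
            # Stop once the model emits the end-of-sequence token (if the tokenizer defines one)
            if eos_id is not None and next_id == eos_id:
                break
            # Append the predicted token and feed the extended sequence back into the model
            next_token = torch.tensor([[next_id]], device=handler.device)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
    return handler.tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Example (local testing only):
#     handler = EndpointHandler()
#     print(generate_greedy(handler, "This is a test input", max_new_tokens=16))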