# luminar-nano / handler.py
import torch
import os
from transformers import AutoTokenizer
# Import your custom model class
# Assume the inference environment handles adding the repository root to the Python path
from modelling_trm import TRM, TRMConfig
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
def __init__(self, model_path="."): # Accept model_path during initialization
self.model = None
self.tokenizer = None
self.device = None
self.model_path = model_path # Store model_path
logger.info(f"Initializing model from directory: {self.model_path}")
try:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Loading model on device: {self.device}")
# Load the config
logger.info("Attempting to load model config...")
config = TRMConfig.from_pretrained(self.model_path)
logger.info("Model config loaded successfully.")
logger.debug(f"Config type: {type(config)}")
# Load the model
logger.info("Attempting to load model...")
self.model = TRM.from_pretrained(self.model_path, config=config)
logger.info("Model loaded successfully.")
logger.debug(f"Model type: {type(self.model)}")
logger.debug(f"Is model callable: {callable(self.model)}")
self.model.to(self.device)
self.model.eval() # Set model to evaluation mode
logger.info("Model set to evaluation mode.")
# Load the tokenizer (using a placeholder as the original had issues)
# You might need to adapt this based on your actual tokenizer
logger.info("Attempting to load tokenizer...")
try:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
logger.info("Tokenizer loaded successfully from model path.")
logger.debug(f"Tokenizer type: {type(self.tokenizer)}")
logger.debug(f"Is tokenizer callable: {callable(self.tokenizer)}")
except Exception as e:
logger.warning(f"Failed to load tokenizer from model path: {e}. Falling back to basic tokenizer.")
# Fallback to a basic tokenizer if loading from path fails
try:
self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
logger.info("Loaded a placeholder tokenizer (bert-base-uncased) for inference.")
logger.debug(f"Tokenizer type (fallback): {type(self.tokenizer)}")
logger.debug(f"Is tokenizer callable (fallback): {callable(self.tokenizer)}")
except Exception as e_fallback:
logger.error(f"Failed to load placeholder tokenizer: {e_fallback}")
self.tokenizer = None # Ensure tokenizer is None if fallback fails
# Check if tokenizer is callable after loading attempts
if self.tokenizer is None or not callable(self.tokenizer):
logger.error("Loaded tokenizer is not callable!")
# Optionally raise an error here if tokenizer is essential
# raise TypeError("Loaded tokenizer is not callable.")
except Exception as e:
logger.error(f"Error during model initialization: {e}", exc_info=True)
raise # Re-raise the exception to indicate initialization failure
# Final check after initialization
if self.model is None or not callable(self.model):
logger.error("Model is None or not callable after initialization!")
raise RuntimeError("Model failed to initialize correctly.")
if self.tokenizer is None or not callable(self.tokenizer):
logger.error("Tokenizer is None or not callable after initialization!")
            # Depending on whether the tokenizer is strictly required for handle,
            # you might raise an error here or handle it in preprocess/inference.
# raise RuntimeError("Tokenizer failed to initialize correctly.") # Uncomment if tokenizer is essential
def preprocess(self, inputs):
logger.info("Starting preprocess.")
logger.debug(f"Preprocess input type: {type(inputs)}")
logger.debug(f"Preprocess input: {inputs}")
# Check if tokenizer is callable before using it
if self.tokenizer is None or not callable(self.tokenizer):
logger.error("Tokenizer is not available or not callable during preprocess.")
raise RuntimeError("Tokenizer is not available for preprocessing.")
try:
# Preprocess inputs for the model
# 'inputs' will be the data received by the inference endpoint
# This needs to be adapted based on the expected input format (e.g., text string)
# For text generation, 'inputs' could be a string or a list of strings.
if isinstance(inputs, str):
inputs = [inputs]
logger.debug("Input is a string, converted to list.")
elif not isinstance(inputs, list):
logger.error(f"Input must be a string or a list of strings, but got {type(inputs)}.")
raise ValueError("Input must be a string or a list of strings.")
else:
logger.debug("Input is already a list.")
# Tokenize the input
# Ensure padding and truncation are handled
max_len = getattr(self.tokenizer, 'model_max_length', 4096) # Use a default max_len if not available
logger.debug(f"Using max_length for tokenization: {max_len}")
tokenized_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_len)
logger.info("Inputs tokenized.")
logger.debug(f"Tokenized inputs keys: {tokenized_inputs.keys()}")
logger.debug(f"Tokenized inputs['input_ids'] shape: {tokenized_inputs['input_ids'].shape}")
# Move tokenized inputs to the model's device
tokenized_inputs = {k: v.to(self.device) for k, v in tokenized_inputs.items()}
logger.info("Tokenized inputs moved to device.")
logger.debug(f"Input tensor device: {tokenized_inputs['input_ids'].device}")
# Return only the inputs expected by the TRM model
# Based on training, TRM seems to only take 'input_ids'
model_input = {'input_ids': tokenized_inputs['input_ids']}
logger.info("Preprocess complete.")
logger.debug(f"Model input keys: {model_input.keys()}")
logger.debug(f"Model input['input_ids'] shape: {model_input['input_ids'].shape}")
return model_input
except Exception as e:
logger.error(f"Error during preprocess: {e}", exc_info=True)
raise # Re-raise the exception
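    # A minimal sketch (not used by handle(), and an assumption about the serving
    # toolkit) for unwrapping the common {"inputs": ...} payload shape before calling
    # preprocess(), in case the endpoint delivers a dict rather than a raw string or
    # list of strings.
    @staticmethod
    def _extract_inputs(data):
        if isinstance(data, dict) and 'inputs' in data:
            return data['inputs']
        return data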
def inference(self, inputs):
logger.info("Starting inference.")
logger.debug(f"Inference input keys: {inputs.keys()}")
logger.debug(f"Inference input['input_ids'] shape: {inputs['input_ids'].shape}")
# Check if model is callable before using it
if self.model is None or not callable(self.model):
logger.error("Model is not available or not callable during inference.")
raise RuntimeError("Model is not available for inference.")
try:
# Perform inference with the model
# 'inputs' here is the output of the preprocess method
with torch.no_grad():
# Perform the forward pass
# Assuming the model only takes input_ids
outputs = self.model(**inputs)
logger.info("Model forward pass complete.")
logger.debug(f"Model output type: {type(outputs)}")
if isinstance(outputs, dict):
logger.debug(f"Model output keys: {outputs.keys()}")
# logger.debug(f"Model output: {outputs}") # Uncomment for detailed output inspection
# The model's output structure might differ, assuming it returns logits
# You might need to adapt this based on the actual TRM output for inference
# For text generation, you might use model.generate() instead of a simple forward pass
# This example performs a simple forward pass and returns logits
# Based on previous debugging, output is a dict with 'logits' key
if isinstance(outputs, dict) and 'logits' in outputs:
logits = outputs['logits']
logger.debug("Accessed logits from dictionary output.")
elif hasattr(outputs, 'logits'):
logits = outputs.logits
logger.debug("Accessed logits from attribute.")
else:
logger.error("Model output does not contain 'logits'.")
raise AttributeError("Model output does not contain 'logits'.")
logger.info("Inference complete.")
logger.debug(f"Logits shape: {logits.shape}")
return logits # Or process logits further for text generation
except Exception as e:
logger.error(f"Error during inference: {e}", exc_info=True)
raise # Re-raise the exception
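    # A minimal sketch of the generate()-based alternative mentioned in the comments
    # above. It is NOT called by handle(), and it assumes TRM inherits generation
    # support from transformers (e.g. GenerationMixin via PreTrainedModel) and that
    # the tokenizer defines eos/pad tokens. The method name and defaults are
    # illustrative only.
    def generate_text(self, text, max_new_tokens=32):
        encoded = self.tokenizer(text, return_tensors="pt").to(self.device)
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        with torch.no_grad():
            generated_ids = self.model.generate(
                encoded["input_ids"],
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding
                pad_token_id=pad_id,
            )
        return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)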
def postprocess(self, outputs):
logger.info("Starting postprocess for text generation (Greedy Decoding).")
# 'outputs' here is the output of the inference method (logits)
# This implements a basic greedy decoding strategy
if self.tokenizer is None:
logger.error("Tokenizer is not available for postprocessing.")
raise RuntimeError("Tokenizer is not available for postprocessing.")
try:
            # 'outputs' are logits of shape (batch_size, sequence_length, vocab_size).
            # Greedy decoding normally takes the argmax at the last position, appends it
            # to the input sequence, and repeats until a stop condition is met. A single
            # forward pass (as produced by inference) cannot generate a full sequence,
            # so this postprocess only decodes the most probable *next* token from the
            # last position; see the _greedy_generate sketch after this method for the
            # full loop. Assumes batch size 1 for simplicity.
if outputs.ndim == 3 and outputs.shape[0] == 1: # Shape (1, seq_len, vocab_size)
last_token_logits = outputs[0, -1, :] # Logits for the last token in the sequence
predicted_token_id = torch.argmax(last_token_logits).item()
# Decode the predicted token
predicted_text = self.tokenizer.decode([predicted_token_id])
logger.info(f"Predicted next token: {predicted_text} (ID: {predicted_token_id})")
# In a real generation loop, you would append this to the input and repeat.
# For this handler, let's just return the predicted token text.
return predicted_text
else:
logger.warning(f"Unexpected output shape for greedy decoding: {outputs.shape}. Returning raw logits list.")
return outputs.cpu().tolist()
except Exception as e:
logger.error(f"Error during postprocess: {e}", exc_info=True)
raise # Re-raise the exception
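    # A minimal sketch of the generation loop described in postprocess() above:
    # repeatedly run the forward pass, take the argmax of the last position's logits,
    # and append it to the running sequence. NOT wired into handle(); it assumes the
    # model returns a dict or object exposing 'logits' of shape (batch, seq_len, vocab)
    # and works on a single sequence (batch size 1). Name and defaults are illustrative.
    def _greedy_generate(self, input_ids, max_new_tokens=32):
        generated = input_ids
        eos_id = getattr(self.tokenizer, 'eos_token_id', None)
        with torch.no_grad():
            for _ in range(max_new_tokens):
                outputs = self.model(input_ids=generated)
                logits = outputs['logits'] if isinstance(outputs, dict) else outputs.logits
                next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                generated = torch.cat([generated, next_token], dim=-1)
                if eos_id is not None and next_token.item() == eos_id:
                    break
        return self.tokenizer.decode(generated[0], skip_special_tokens=True)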
def handle(self, data):
logger.info("Starting handle method.")
logger.debug(f"Received data in handle: {data}")
try:
# 1. Preprocess
logger.info("Calling preprocess...")
model_input = self.preprocess(data)
logger.info("Preprocessing successful.")
# 2. Inference (single forward pass to get logits)
logger.info("Calling inference...")
model_output_logits = self.inference(model_input)
logger.info("Inference successful.")
            # 3. Postprocess (basic greedy decoding of the next token)
            # Note: this postprocess only generates the *next* token. For full text
            # generation you would need a loop here or in inference; see the
            # handle_generate sketch after this method.
logger.info("Calling postprocess...")
response = self.postprocess(model_output_logits)
logger.info("Postprocessing successful.")
logger.info("Handle method complete.")
return response
except Exception as e:
logger.error(f"Error in handle method: {e}", exc_info=True)
# Depending on requirements, you might want to return an error response
# return {"error": str(e)}
raise # Re-raise the exception to be caught by the inference toolkit
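    # A minimal sketch of the full-generation variant mentioned in handle() above:
    # preprocess as usual, then drive the _greedy_generate loop sketched earlier
    # instead of the single forward pass + next-token postprocess. Illustrative only;
    # the name and default are assumptions, and handle() remains the entry point.
    def handle_generate(self, data, max_new_tokens=32):
        model_input = self.preprocess(data)
        return self._greedy_generate(model_input['input_ids'], max_new_tokens=max_new_tokens)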
# Example usage (for testing locally)
# if __name__ == "__main__":
# # This part assumes you have the model files in the current directory or specify model_path
# # For local testing, you might need to adjust model_path
# # Example: handler = EndpointHandler(model_path="./custom_training_output")
# handler = EndpointHandler() # Assuming model files are in "."
# test_input = "This is a test input"
# output = handler.handle(test_input)
# print("Inference output:", output)