import logging

import torch
from transformers import AutoTokenizer

# Import your custom model class
# Assume the inference environment handles adding the repository root to the Python path
from modelling_trm import TRM, TRMConfig

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, model_path="."):  # Accept model_path during initialization
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_path = model_path  # Store model_path
        logger.info(f"Initializing model from directory: {self.model_path}")

        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Loading model on device: {self.device}")

            # Load the config
            logger.info("Attempting to load model config...")
            config = TRMConfig.from_pretrained(self.model_path)
            logger.info("Model config loaded successfully.")
            logger.debug(f"Config type: {type(config)}")

            # Load the model
            logger.info("Attempting to load model...")
            self.model = TRM.from_pretrained(self.model_path, config=config)
            logger.info("Model loaded successfully.")
            logger.debug(f"Model type: {type(self.model)}")
            logger.debug(f"Is model callable: {callable(self.model)}")
            self.model.to(self.device)
            self.model.eval()  # Set model to evaluation mode
            logger.info("Model set to evaluation mode.")

            # Load the tokenizer (using a placeholder as the original had issues)
            # You might need to adapt this based on your actual tokenizer
            logger.info("Attempting to load tokenizer...")
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
                logger.info("Tokenizer loaded successfully from model path.")
                logger.debug(f"Tokenizer type: {type(self.tokenizer)}")
                logger.debug(f"Is tokenizer callable: {callable(self.tokenizer)}")
            except Exception as e:
                logger.warning(f"Failed to load tokenizer from model path: {e}. Falling back to basic tokenizer.")
                # Fall back to a basic tokenizer if loading from the model path fails
                try:
                    self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
                    logger.info("Loaded a placeholder tokenizer (bert-base-uncased) for inference.")
                    logger.debug(f"Tokenizer type (fallback): {type(self.tokenizer)}")
                    logger.debug(f"Is tokenizer callable (fallback): {callable(self.tokenizer)}")
                except Exception as e_fallback:
                    logger.error(f"Failed to load placeholder tokenizer: {e_fallback}")
                    self.tokenizer = None  # Ensure tokenizer is None if the fallback fails

            # Check that the tokenizer is callable after the loading attempts
            if self.tokenizer is None or not callable(self.tokenizer):
                logger.error("Loaded tokenizer is not callable!")
                # Optionally raise an error here if the tokenizer is essential
                # raise TypeError("Loaded tokenizer is not callable.")
        except Exception as e:
            logger.error(f"Error during model initialization: {e}", exc_info=True)
            raise  # Re-raise the exception to indicate initialization failure

        # Final check after initialization
        if self.model is None or not callable(self.model):
            logger.error("Model is None or not callable after initialization!")
            raise RuntimeError("Model failed to initialize correctly.")
        if self.tokenizer is None or not callable(self.tokenizer):
            logger.error("Tokenizer is None or not callable after initialization!")
            # Depending on whether the tokenizer is strictly required by handle(),
            # you might raise an error here or handle it in preprocess/inference.
            # raise RuntimeError("Tokenizer failed to initialize correctly.")  # Uncomment if tokenizer is essential

    def preprocess(self, inputs):
        logger.info("Starting preprocess.")
        logger.debug(f"Preprocess input type: {type(inputs)}")
        logger.debug(f"Preprocess input: {inputs}")

        # Check that the tokenizer is callable before using it
        if self.tokenizer is None or not callable(self.tokenizer):
            logger.error("Tokenizer is not available or not callable during preprocess.")
            raise RuntimeError("Tokenizer is not available for preprocessing.")

        try:
            # Preprocess inputs for the model.
            # 'inputs' is the data received by the inference endpoint; adapt this to your
            # expected input format. For text generation it should be a string or a list of strings.
            if isinstance(inputs, str):
                inputs = [inputs]
                logger.debug("Input is a string, converted to list.")
            elif not isinstance(inputs, list):
                logger.error(f"Input must be a string or a list of strings, but got {type(inputs)}.")
                raise ValueError("Input must be a string or a list of strings.")
            else:
                logger.debug("Input is already a list.")

            # Tokenize the input, handling padding and truncation
            max_len = getattr(self.tokenizer, "model_max_length", 4096)  # Use a default max_len if not available
            logger.debug(f"Using max_length for tokenization: {max_len}")
            tokenized_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_len)
            logger.info("Inputs tokenized.")
            logger.debug(f"Tokenized inputs keys: {tokenized_inputs.keys()}")
            logger.debug(f"Tokenized inputs['input_ids'] shape: {tokenized_inputs['input_ids'].shape}")

            # Move tokenized inputs to the model's device
            tokenized_inputs = {k: v.to(self.device) for k, v in tokenized_inputs.items()}
            logger.info("Tokenized inputs moved to device.")
            logger.debug(f"Input tensor device: {tokenized_inputs['input_ids'].device}")

            # Return only the inputs expected by the TRM model.
            # Based on training, TRM appears to take only 'input_ids'.
            model_input = {"input_ids": tokenized_inputs["input_ids"]}
            logger.info("Preprocess complete.")
            logger.debug(f"Model input keys: {model_input.keys()}")
            logger.debug(f"Model input['input_ids'] shape: {model_input['input_ids'].shape}")
            return model_input
        except Exception as e:
            logger.error(f"Error during preprocess: {e}", exc_info=True)
            raise  # Re-raise the exception

    def inference(self, inputs):
        logger.info("Starting inference.")
        logger.debug(f"Inference input keys: {inputs.keys()}")
        logger.debug(f"Inference input['input_ids'] shape: {inputs['input_ids'].shape}")

        # Check that the model is callable before using it
        if self.model is None or not callable(self.model):
            logger.error("Model is not available or not callable during inference.")
            raise RuntimeError("Model is not available for inference.")

        try:
            # Perform inference with the model.
            # 'inputs' here is the output of the preprocess method.
            with torch.no_grad():
                # Forward pass; the model is assumed to take only input_ids
                outputs = self.model(**inputs)

            logger.info("Model forward pass complete.")
            logger.debug(f"Model output type: {type(outputs)}")
            if isinstance(outputs, dict):
                logger.debug(f"Model output keys: {outputs.keys()}")
            # logger.debug(f"Model output: {outputs}")  # Uncomment for detailed output inspection

            # The model's output structure might differ; adapt this based on the actual TRM output.
            # For text generation you might use model.generate() instead of a single forward pass;
            # this example performs one forward pass and returns the logits.
            # Based on previous debugging, the output is a dict with a 'logits' key.
            if isinstance(outputs, dict) and "logits" in outputs:
                logits = outputs["logits"]
                logger.debug("Accessed logits from dictionary output.")
            elif hasattr(outputs, "logits"):
                logits = outputs.logits
                logger.debug("Accessed logits from attribute.")
            else:
                logger.error("Model output does not contain 'logits'.")
                raise AttributeError("Model output does not contain 'logits'.")

            logger.info("Inference complete.")
            logger.debug(f"Logits shape: {logits.shape}")
            return logits  # Or process logits further for text generation
        except Exception as e:
            logger.error(f"Error during inference: {e}", exc_info=True)
            raise  # Re-raise the exception

    def postprocess(self, outputs):
        logger.info("Starting postprocess for text generation (greedy decoding).")
        # 'outputs' here is the output of the inference method (logits).
        if self.tokenizer is None:
            logger.error("Tokenizer is not available for postprocessing.")
            raise RuntimeError("Tokenizer is not available for postprocessing.")

        try:
            # 'outputs' is expected to be logits of shape (batch_size, sequence_length, vocab_size).
            # A single forward pass cannot yield a full generated sequence, so this method only
            # greedily decodes the most probable *next* token from the last position. A real text
            # generation handler would manage a loop in handle() (appending each predicted token
            # to the input and repeating until a stop condition is met).
            # For simplicity, this assumes the inputs were processed with batch size 1.
            if outputs.ndim == 3 and outputs.shape[0] == 1:  # Shape (1, seq_len, vocab_size)
                last_token_logits = outputs[0, -1, :]  # Logits for the last token in the sequence
                predicted_token_id = torch.argmax(last_token_logits).item()

                # Decode the predicted token
                predicted_text = self.tokenizer.decode([predicted_token_id])
                logger.info(f"Predicted next token: {predicted_text} (ID: {predicted_token_id})")

                # In a real generation loop you would append this token to the input and repeat;
                # this handler simply returns the predicted token text.
                return predicted_text
            else:
                logger.warning(f"Unexpected output shape for greedy decoding: {outputs.shape}. Returning raw logits list.")
                return outputs.cpu().tolist()
        except Exception as e:
            logger.error(f"Error during postprocess: {e}", exc_info=True)
            raise  # Re-raise the exception
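
    # The comments in inference(), postprocess() and handle() refer to a full generation loop that
    # this handler does not implement. The method below is a minimal, illustrative sketch of such a
    # loop (greedy decoding only), assuming the TRM forward pass accepts input_ids and returns logits
    # of shape (batch, seq_len, vocab_size), exactly as inference() assumes. The name generate_greedy
    # and the max_new_tokens parameter are additions for illustration, not part of the original
    # handler, and handle() does not call this method.
    def generate_greedy(self, data, max_new_tokens=32):
        input_ids = self.preprocess(data)["input_ids"]  # (1, seq_len); batch size 1 assumed
        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        for _ in range(max_new_tokens):
            logits = self.inference({"input_ids": input_ids})
            next_id = torch.argmax(logits[0, -1, :]).item()
            if eos_id is not None and next_id == eos_id:
                break
            # Append the predicted token and feed the extended sequence back in
            next_token = torch.tensor([[next_id]], device=input_ids.device)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)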

    def handle(self, data):
        logger.info("Starting handle method.")
        logger.debug(f"Received data in handle: {data}")
        try:
            # 1. Preprocess
            logger.info("Calling preprocess...")
            model_input = self.preprocess(data)
            logger.info("Preprocessing successful.")

            # 2. Inference (single forward pass to get logits)
            logger.info("Calling inference...")
            model_output_logits = self.inference(model_input)
            logger.info("Inference successful.")

            # 3. Postprocess (basic greedy decoding of the next token).
            # Note: this only generates the *next* token; for full text generation
            # you would need a loop here or a modified inference method.
            logger.info("Calling postprocess...")
            response = self.postprocess(model_output_logits)
            logger.info("Postprocessing successful.")

            logger.info("Handle method complete.")
            return response
        except Exception as e:
            logger.error(f"Error in handle method: {e}", exc_info=True)
            # Depending on requirements, you might want to return an error response instead:
            # return {"error": str(e)}
            raise  # Re-raise the exception to be caught by the inference toolkit
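
    # Hugging Face Inference Endpoints invoke a custom handler as a callable with a payload of the
    # form {"inputs": ...}. Whether that convention applies to the toolkit serving this model is an
    # assumption; if it does, a thin __call__ wrapper that unwraps the payload and delegates to
    # handle() is sufficient.
    def __call__(self, data):
        # Accept either a raw string/list or a {"inputs": ...} payload (assumed endpoint format)
        inputs = data.get("inputs", data) if isinstance(data, dict) else data
        return self.handle(inputs)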


# Example usage (for testing locally)
if __name__ == "__main__":
    # This assumes the model files are in the current directory; otherwise pass model_path,
    # e.g. handler = EndpointHandler(model_path="./custom_training_output")
    handler = EndpointHandler()
    test_input = "This is a test input"
    output = handler.handle(test_input)
    print("Inference output:", output)