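"""Inference endpoint handler for the TRM model.

Loads the model and tokenizer from the repository directory and chains
preprocess -> inference -> postprocess behind a single handle() entry point
that greedily decodes one next token.
"""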
import logging

import torch
from transformers import AutoTokenizer

from modelling_trm import TRM, TRMConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, model_path="."):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_path = model_path

        logger.info(f"Initializing model from directory: {self.model_path}")
        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Loading model on device: {self.device}")

            logger.info("Attempting to load model config...")
            config = TRMConfig.from_pretrained(self.model_path)
            logger.info("Model config loaded successfully.")
            logger.debug(f"Config type: {type(config)}")

            logger.info("Attempting to load model...")
            self.model = TRM.from_pretrained(self.model_path, config=config)
            logger.info("Model loaded successfully.")
            logger.debug(f"Model type: {type(self.model)}")
            logger.debug(f"Is model callable: {callable(self.model)}")

            # Move to the target device and switch to eval mode (disables dropout).
            self.model.to(self.device)
            self.model.eval()
            logger.info("Model set to evaluation mode.")
logger.info("Attempting to load tokenizer...") |
|
|
try: |
|
|
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) |
|
|
logger.info("Tokenizer loaded successfully from model path.") |
|
|
logger.debug(f"Tokenizer type: {type(self.tokenizer)}") |
|
|
logger.debug(f"Is tokenizer callable: {callable(self.tokenizer)}") |
|
|
except Exception as e: |
|
|
logger.warning(f"Failed to load tokenizer from model path: {e}. Falling back to basic tokenizer.") |
|
|
|
|
|
try: |
|
|
self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') |
|
|
logger.info("Loaded a placeholder tokenizer (bert-base-uncased) for inference.") |
|
|
logger.debug(f"Tokenizer type (fallback): {type(self.tokenizer)}") |
|
|
logger.debug(f"Is tokenizer callable (fallback): {callable(self.tokenizer)}") |
|
|
except Exception as e_fallback: |
|
|
logger.error(f"Failed to load placeholder tokenizer: {e_fallback}") |
|
|
self.tokenizer = None |
|
|
|
|
|
|
|
|
if self.tokenizer is None or not callable(self.tokenizer): |
|
|
logger.error("Loaded tokenizer is not callable!") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error during model initialization: {e}", exc_info=True) |
|
|
raise |

        # Fail fast if initialization left the handler in an unusable state.
        if self.model is None or not callable(self.model):
            logger.error("Model is None or not callable after initialization!")
            raise RuntimeError("Model failed to initialize correctly.")
        if self.tokenizer is None or not callable(self.tokenizer):
            logger.error("Tokenizer is None or not callable after initialization!")
            raise RuntimeError("Tokenizer failed to initialize correctly.")

    def preprocess(self, inputs):
        logger.info("Starting preprocess.")
        logger.debug(f"Preprocess input type: {type(inputs)}")
        logger.debug(f"Preprocess input: {inputs}")

        if self.tokenizer is None or not callable(self.tokenizer):
            logger.error("Tokenizer is not available or not callable during preprocess.")
            raise RuntimeError("Tokenizer is not available for preprocessing.")

        try:
            # Normalize the input to a list of strings.
            if isinstance(inputs, str):
                inputs = [inputs]
                logger.debug("Input is a string, converted to list.")
            elif not isinstance(inputs, list):
                logger.error(f"Input must be a string or a list of strings, but got {type(inputs)}.")
                raise ValueError("Input must be a string or a list of strings.")
            else:
                logger.debug("Input is already a list.")

            # Cap tokenization at the tokenizer's declared limit (default 4096).
            # Note: some tokenizers report a very large sentinel model_max_length.
            max_len = getattr(self.tokenizer, 'model_max_length', 4096)
            logger.debug(f"Using max_length for tokenization: {max_len}")
            tokenized_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_len)
            logger.info("Inputs tokenized.")
            logger.debug(f"Tokenized inputs keys: {tokenized_inputs.keys()}")
            logger.debug(f"Tokenized inputs['input_ids'] shape: {tokenized_inputs['input_ids'].shape}")

            tokenized_inputs = {k: v.to(self.device) for k, v in tokenized_inputs.items()}
            logger.info("Tokenized inputs moved to device.")
            logger.debug(f"Input tensor device: {tokenized_inputs['input_ids'].device}")

            # Pass only input_ids to the model; other tokenizer outputs
            # (e.g. attention_mask) are not consumed by this forward pass.
            model_input = {'input_ids': tokenized_inputs['input_ids']}
            logger.info("Preprocess complete.")
            logger.debug(f"Model input keys: {model_input.keys()}")
            logger.debug(f"Model input['input_ids'] shape: {model_input['input_ids'].shape}")
            return model_input

        except Exception as e:
            logger.error(f"Error during preprocess: {e}", exc_info=True)
            raise

    def inference(self, inputs):
        logger.info("Starting inference.")
        logger.debug(f"Inference input keys: {inputs.keys()}")
        logger.debug(f"Inference input['input_ids'] shape: {inputs['input_ids'].shape}")

        if self.model is None or not callable(self.model):
            logger.error("Model is not available or not callable during inference.")
            raise RuntimeError("Model is not available for inference.")

        try:
            # Run the forward pass without building the autograd graph.
            with torch.no_grad():
                outputs = self.model(**inputs)
            logger.info("Model forward pass complete.")
            logger.debug(f"Model output type: {type(outputs)}")
            if isinstance(outputs, dict):
                logger.debug(f"Model output keys: {outputs.keys()}")

            # The model may return a plain dict or an output object; handle both.
            if isinstance(outputs, dict) and 'logits' in outputs:
                logits = outputs['logits']
                logger.debug("Accessed logits from dictionary output.")
            elif hasattr(outputs, 'logits'):
                logits = outputs.logits
                logger.debug("Accessed logits from attribute.")
            else:
                logger.error("Model output does not contain 'logits'.")
                raise AttributeError("Model output does not contain 'logits'.")

            logger.info("Inference complete.")
            logger.debug(f"Logits shape: {logits.shape}")
            return logits

        except Exception as e:
            logger.error(f"Error during inference: {e}", exc_info=True)
            raise

    def postprocess(self, outputs):
        logger.info("Starting postprocess for text generation (Greedy Decoding).")

        if self.tokenizer is None:
            logger.error("Tokenizer is not available for postprocessing.")
            raise RuntimeError("Tokenizer is not available for postprocessing.")

        try:
            # Greedy single-step decoding: expect logits of shape
            # (batch=1, sequence_length, vocab_size) and take the argmax
            # over the vocabulary at the final position.
            if outputs.ndim == 3 and outputs.shape[0] == 1:
                last_token_logits = outputs[0, -1, :]
                predicted_token_id = torch.argmax(last_token_logits).item()

                predicted_text = self.tokenizer.decode([predicted_token_id])
                logger.info(f"Predicted next token: {predicted_text} (ID: {predicted_token_id})")
                return predicted_text
            else:
                logger.warning(f"Unexpected output shape for greedy decoding: {outputs.shape}. Returning raw logits list.")
                return outputs.cpu().tolist()

        except Exception as e:
            logger.error(f"Error during postprocess: {e}", exc_info=True)
            raise

    def handle(self, data):
        logger.info("Starting handle method.")
        logger.debug(f"Received data in handle: {data}")
        try:
            # Full pipeline: text -> token ids -> logits -> decoded next token.
            logger.info("Calling preprocess...")
            model_input = self.preprocess(data)
            logger.info("Preprocessing successful.")

            logger.info("Calling inference...")
            model_output_logits = self.inference(model_input)
            logger.info("Inference successful.")

            logger.info("Calling postprocess...")
            response = self.postprocess(model_output_logits)
            logger.info("Postprocessing successful.")

            logger.info("Handle method complete.")
            return response

        except Exception as e:
            logger.error(f"Error in handle method: {e}", exc_info=True)
            raise
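

# Minimal local smoke test, not part of the serving path. A sketch that
# assumes the TRM weights, config, and tokenizer files are present in the
# current directory (the default model_path); the prompt below is an
# arbitrary placeholder.
if __name__ == "__main__":
    handler = EndpointHandler(model_path=".")
    result = handler.handle("The quick brown fox")
    print(f"Predicted next token: {result!r}")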