File size: 13,749 Bytes
2119d2a
 
 
 
 
670cd6b
2119d2a
899ed3a
 
 
 
 
2119d2a
670cd6b
 
bf25e1f
2119d2a
 
 
bf25e1f
2119d2a
899ed3a
2119d2a
899ed3a
 
 
 
bf25e1f
899ed3a
bf25e1f
 
 
899ed3a
 
bf25e1f
899ed3a
bf25e1f
 
 
 
899ed3a
 
bf25e1f
 
899ed3a
 
bf25e1f
 
899ed3a
 
bf25e1f
 
 
899ed3a
 
 
 
 
 
bf25e1f
 
899ed3a
 
 
 
bf25e1f
9571d80
bf25e1f
 
 
 
899ed3a
 
 
 
2119d2a
bf25e1f
 
 
 
 
 
 
 
 
 
8059d57
2119d2a
899ed3a
 
 
bf25e1f
 
 
 
 
899ed3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2119d2a
 
 
899ed3a
 
 
bf25e1f
 
 
 
 
899ed3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2119d2a
 
 
354aec4
 
 
 
 
 
 
 
899ed3a
354aec4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2119d2a
899ed3a
 
 
2119d2a
 
 
899ed3a
 
 
 
9571d80
899ed3a
 
 
354aec4
9571d80
354aec4
899ed3a
2119d2a
354aec4
 
 
9571d80
354aec4
899ed3a
2119d2a
899ed3a
 
2119d2a
899ed3a
 
 
 
 
2119d2a
 
 
 
899ed3a
 
 
 
2119d2a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278

import torch
import os
from transformers import AutoTokenizer
# Import your custom model class
# Assume the inference environment handles adding the repository root to the Python path
from modelling_trm import TRM, TRMConfig
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference handler that returns a greedy next-token prediction from a TRM model.

    The request pipeline is ``preprocess`` (tokenize) -> ``inference`` (one
    forward pass) -> ``postprocess`` (argmax over the last position), chained
    by ``handle``. ``__call__`` delegates to ``handle`` so the object itself
    can be invoked with the request payload, which is how Hugging Face
    custom handlers are driven.

    Attributes set by ``__init__``:
        model: the loaded ``TRM`` instance, in eval mode on ``device``.
        tokenizer: tokenizer loaded from ``model_path``, or the
            ``bert-base-uncased`` placeholder, or ``None`` if both failed.
        device: ``"cuda"`` when available, otherwise ``"cpu"``.
        model_path: directory the artifacts were loaded from.
    """

    # Fallback truncation length when the tokenizer reports no usable limit.
    DEFAULT_MAX_LENGTH = 4096
    # Tokenizers without a configured limit report a huge sentinel value
    # (transformers uses int(1e30)); anything above this bound is treated as
    # "unset" rather than forwarded to the tokenizer as max_length.
    MAX_REASONABLE_LENGTH = 1_000_000

    def __init__(self, model_path="."):
        """Load config, model and tokenizer from ``model_path``.

        Args:
            model_path: directory containing the model artifacts.

        Raises:
            Exception: whatever the config/model loading raised (logged and
                re-raised unchanged).
            RuntimeError: if the model is missing or not callable afterwards.

        A tokenizer failure is only logged here; ``preprocess`` raises later
        if the tokenizer is unusable.
        """
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_path = model_path

        logger.info(f"Initializing model from directory: {self.model_path}")
        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Loading model on device: {self.device}")

            logger.info("Attempting to load model config...")
            config = TRMConfig.from_pretrained(self.model_path)
            logger.info("Model config loaded successfully.")
            logger.debug(f"Config type: {type(config)}")

            logger.info("Attempting to load model...")
            self.model = TRM.from_pretrained(self.model_path, config=config)
            logger.info("Model loaded successfully.")
            logger.debug(f"Model type: {type(self.model)}")
            logger.debug(f"Is model callable: {callable(self.model)}")

            self.model.to(self.device)
            self.model.eval()  # disable training-only behaviour for inference
            logger.info("Model set to evaluation mode.")

            self.tokenizer = self._load_tokenizer()
        except Exception as e:
            logger.error(f"Error during model initialization: {e}", exc_info=True)
            raise

        if self.model is None or not callable(self.model):
            logger.error("Model is None or not callable after initialization!")
            raise RuntimeError("Model failed to initialize correctly.")
        if self.tokenizer is None or not callable(self.tokenizer):
            # Deliberately not fatal: preprocess() raises if it is ever needed.
            logger.error("Tokenizer is None or not callable after initialization!")

    def _load_tokenizer(self):
        """Return a tokenizer for ``model_path``, a placeholder, or ``None``.

        Tries the model directory first, then ``bert-base-uncased``. Returns
        ``None`` (after logging) when both attempts fail.
        """
        logger.info("Attempting to load tokenizer...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        except Exception as e:
            logger.warning(f"Failed to load tokenizer from model path: {e}. Falling back to basic tokenizer.")
        else:
            logger.info("Tokenizer loaded successfully from model path.")
            logger.debug(f"Tokenizer type: {type(tokenizer)}")
            logger.debug(f"Is tokenizer callable: {callable(tokenizer)}")
            return tokenizer

        # NOTE(review): the placeholder's vocabulary may not match the one the
        # model was trained with - confirm before relying on this fallback.
        try:
            tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        except Exception as e_fallback:
            logger.error(f"Failed to load placeholder tokenizer: {e_fallback}")
            return None
        logger.info("Loaded a placeholder tokenizer (bert-base-uncased) for inference.")
        logger.debug(f"Tokenizer type (fallback): {type(tokenizer)}")
        logger.debug(f"Is tokenizer callable (fallback): {callable(tokenizer)}")
        return tokenizer

    def _normalize_inputs(self, inputs):
        """Turn a request payload into a non-empty list of strings.

        Accepts a string, a list of strings, or a dict carrying either of
        those under the ``"inputs"`` key.

        Raises:
            ValueError: for any other shape, an empty list, or non-string items.
        """
        if isinstance(inputs, dict):
            if "inputs" not in inputs:
                logger.error("Input dict does not contain an 'inputs' key.")
                raise ValueError("Input dict must contain an 'inputs' key.")
            inputs = inputs["inputs"]
            logger.debug("Input is a dict, extracted 'inputs' field.")

        if isinstance(inputs, str):
            logger.debug("Input is a string, converted to list.")
            return [inputs]
        if not isinstance(inputs, list):
            logger.error(f"Input must be a string or a list of strings, but got {type(inputs)}.")
            raise ValueError("Input must be a string or a list of strings.")
        if not inputs:
            logger.error("Input list is empty.")
            raise ValueError("Input list must not be empty.")
        if not all(isinstance(item, str) for item in inputs):
            logger.error("Input list contains non-string items.")
            raise ValueError("Input must be a string or a list of strings.")
        logger.debug("Input is already a list.")
        return inputs

    def _resolve_max_length(self):
        """Return the truncation length to pass to the tokenizer.

        Uses the tokenizer's ``model_max_length`` when it is a positive value
        no larger than ``MAX_REASONABLE_LENGTH``; otherwise
        ``DEFAULT_MAX_LENGTH``.
        """
        reported = getattr(self.tokenizer, 'model_max_length', None)
        try:
            candidate = int(reported)
        except (TypeError, ValueError, OverflowError):
            return self.DEFAULT_MAX_LENGTH
        if 0 < candidate <= self.MAX_REASONABLE_LENGTH:
            return candidate
        return self.DEFAULT_MAX_LENGTH

    def preprocess(self, inputs):
        """Tokenize the request and return the model's keyword arguments.

        Args:
            inputs: a string, a list of strings, or a dict with one of those
                under ``"inputs"``.

        Returns:
            ``{'input_ids': tensor}`` with the tensor on ``self.device``. Only
            ``input_ids`` is forwarded; other tokenizer outputs are dropped.

        Raises:
            RuntimeError: if no callable tokenizer is available.
            ValueError: if ``inputs`` has an unsupported shape.
        """
        logger.info("Starting preprocess.")
        logger.debug(f"Preprocess input type: {type(inputs)}")
        logger.debug(f"Preprocess input: {inputs}")
        if self.tokenizer is None or not callable(self.tokenizer):
            logger.error("Tokenizer is not available or not callable during preprocess.")
            raise RuntimeError("Tokenizer is not available for preprocessing.")

        try:
            texts = self._normalize_inputs(inputs)

            max_len = self._resolve_max_length()
            logger.debug(f"Using max_length for tokenization: {max_len}")
            tokenized_inputs = self.tokenizer(
                texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_len,
            )
            logger.info("Inputs tokenized.")
            logger.debug(f"Tokenized inputs keys: {tokenized_inputs.keys()}")
            logger.debug(f"Tokenized inputs['input_ids'] shape: {tokenized_inputs['input_ids'].shape}")

            # Only input_ids is passed to the model, so only it is moved.
            input_ids = tokenized_inputs['input_ids'].to(self.device)
            logger.info("Tokenized inputs moved to device.")
            logger.debug(f"Input tensor device: {input_ids.device}")

            model_input = {'input_ids': input_ids}
            logger.info("Preprocess complete.")
            logger.debug(f"Model input keys: {model_input.keys()}")
            logger.debug(f"Model input['input_ids'] shape: {model_input['input_ids'].shape}")
            return model_input

        except Exception as e:
            logger.error(f"Error during preprocess: {e}", exc_info=True)
            raise

    def inference(self, inputs):
        """Run one gradient-free forward pass and return the logits.

        Args:
            inputs: keyword arguments for the model, as built by ``preprocess``.

        Returns:
            The ``logits`` entry (dict key or attribute) of the model output.

        Raises:
            RuntimeError: if no callable model is available.
            AttributeError: if the model output exposes no ``logits``.
        """
        logger.info("Starting inference.")
        # Validate the model before touching the inputs, so a broken handler
        # reports the real cause rather than an unrelated lookup error.
        if self.model is None or not callable(self.model):
            logger.error("Model is not available or not callable during inference.")
            raise RuntimeError("Model is not available for inference.")

        try:
            logger.debug(f"Inference input keys: {inputs.keys()}")
            logger.debug(f"Inference input['input_ids'] shape: {inputs['input_ids'].shape}")

            with torch.no_grad():
                outputs = self.model(**inputs)
            logger.info("Model forward pass complete.")
            logger.debug(f"Model output type: {type(outputs)}")
            if isinstance(outputs, dict):
                logger.debug(f"Model output keys: {outputs.keys()}")

            if isinstance(outputs, dict) and 'logits' in outputs:
                logits = outputs['logits']
                logger.debug("Accessed logits from dictionary output.")
            elif hasattr(outputs, 'logits'):
                logits = outputs.logits
                logger.debug("Accessed logits from attribute.")
            else:
                logger.error("Model output does not contain 'logits'.")
                raise AttributeError("Model output does not contain 'logits'.")

            logger.info("Inference complete.")
            logger.debug(f"Logits shape: {logits.shape}")
            return logits

        except Exception as e:
            logger.error(f"Error during inference: {e}", exc_info=True)
            raise

    def postprocess(self, outputs):
        """Decode the most likely next token from the logits.

        Args:
            outputs: logits as returned by ``inference``.

        Returns:
            For a 3-D tensor whose first dimension is 1: the decoded text of
            the argmax token at the last position. For any other shape: the
            raw logits as nested Python lists (a warning is logged). Padded
            batches are not decoded because the last position of a padded
            row is not necessarily its last real token.

        Raises:
            RuntimeError: if the tokenizer is ``None``.
        """
        logger.info("Starting postprocess for text generation (Greedy Decoding).")

        if self.tokenizer is None:
            logger.error("Tokenizer is not available for postprocessing.")
            raise RuntimeError("Tokenizer is not available for postprocessing.")

        try:
            if outputs.ndim == 3 and outputs.shape[0] == 1:
                last_token_logits = outputs[0, -1, :]
                predicted_token_id = torch.argmax(last_token_logits).item()

                predicted_text = self.tokenizer.decode([predicted_token_id])
                logger.info(f"Predicted next token: {predicted_text} (ID: {predicted_token_id})")
                return predicted_text

            logger.warning(f"Unexpected output shape for greedy decoding: {outputs.shape}. Returning raw logits list.")
            return outputs.cpu().tolist()

        except Exception as e:
            logger.error(f"Error during postprocess: {e}", exc_info=True)
            raise

    def handle(self, data):
        """Run preprocess, inference and postprocess on one request.

        Args:
            data: request payload; see ``preprocess`` for accepted shapes.

        Returns:
            Whatever ``postprocess`` returns.

        Raises:
            Exception: any error from the three stages, logged and re-raised.
        """
        logger.info("Starting handle method.")
        logger.debug(f"Received data in handle: {data}")
        try:
            logger.info("Calling preprocess...")
            model_input = self.preprocess(data)
            logger.info("Preprocessing successful.")

            logger.info("Calling inference...")
            model_output_logits = self.inference(model_input)
            logger.info("Inference successful.")

            # Only the single next token is produced; there is no generation loop.
            logger.info("Calling postprocess...")
            response = self.postprocess(model_output_logits)
            logger.info("Postprocessing successful.")

            logger.info("Handle method complete.")
            return response

        except Exception as e:
            logger.error(f"Error in handle method: {e}", exc_info=True)
            raise

    def __call__(self, data):
        """Invoke the handler directly; equivalent to ``handle(data)``."""
        return self.handle(data)


# Example usage (for testing locally)
# if __name__ == "__main__":
#     # This part assumes you have the model files in the current directory or specify model_path
#     # For local testing, you might need to adjust model_path
#     # Example: handler = EndpointHandler(model_path="./custom_training_output")
#     handler = EndpointHandler() # Assuming model files are in "."
#     test_input = "This is a test input"
#     output = handler.handle(test_input)
#     print("Inference output:", output)