Lorenzob committed
Commit 354aec4 · verified · 1 Parent(s): 9571d80

Upload folder using huggingface_hub

Files changed (1)
  1. handler.py +45 -21
handler.py CHANGED
@@ -188,24 +188,46 @@ class EndpointHandler:
 
 
     def postprocess(self, outputs):
-        logger.info("Starting postprocess.")
-        logger.debug(f"Postprocess input shape: {outputs.shape}")
+        logger.info("Starting postprocess for text generation (Greedy Decoding).")
+        # 'outputs' here is the output of the inference method (logits)
+        # This implements a basic greedy decoding strategy
+
+        if self.tokenizer is None:
+            logger.error("Tokenizer is not available for postprocessing.")
+            raise RuntimeError("Tokenizer is not available for postprocessing.")
+
         try:
-            # Postprocess the model outputs
-            # 'outputs' here is the output of the inference method (e.g., logits)
-            # For text generation, you would typically decode the generated token IDs
-            # This is a placeholder postprocessing step (e.g., returning the raw logits as a list)
-
-            # Example: decode token IDs if using model.generate()
-            # generated_ids = outputs[0]  # Assuming outputs from generate() is a tensor
-            # generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
-            # return generated_text
-
-            # For this basic handler returning logits, just convert to CPU and list
-            response = outputs.cpu().tolist()
-            logger.info("Postprocess complete.")
-            # logger.debug(f"Postprocess output (partial): {response[:10]}...")  # Avoid printing very large outputs
-            return response
+            # Assuming outputs are logits of shape (batch_size, sequence_length, vocab_size)
+            # For greedy decoding, we take the argmax of the logits for the last token
+            # and append it to the input sequence.
+            # This basic handler will generate one token at a time in a loop.
+            # A real text generation handler would likely take the initial input_ids
+            # and loop until a stop condition is met.
+
+            # For a single forward pass output (like from the inference method),
+            # we can't generate a sequence directly here.
+            # The handle method would need to manage the generation loop.
+
+            # Let's adapt this postprocess to just decode the most probable token from the last position
+            # as a basic example, or return the input + the most probable next token.
+
+            # Assuming inputs were processed one by one (batch size 1 for simplicity in this example)
+            # And outputs are logits for the input sequence
+            if outputs.ndim == 3 and outputs.shape[0] == 1:  # Shape (1, seq_len, vocab_size)
+                last_token_logits = outputs[0, -1, :]  # Logits for the last token in the sequence
+                predicted_token_id = torch.argmax(last_token_logits).item()
+
+                # Decode the predicted token
+                predicted_text = self.tokenizer.decode([predicted_token_id])
+                logger.info(f"Predicted next token: {predicted_text} (ID: {predicted_token_id})")
+
+                # In a real generation loop, you would append this to the input and repeat.
+                # For this handler, let's just return the predicted token text.
+                return predicted_text
+            else:
+                logger.warning(f"Unexpected output shape for greedy decoding: {outputs.shape}. Returning raw logits list.")
+                return outputs.cpu().tolist()
+
 
         except Exception as e:
             logger.error(f"Error during postprocess: {e}", exc_info=True)
@@ -221,14 +243,16 @@ class EndpointHandler:
             model_input = self.preprocess(data)
             logger.info("Preprocessing successful.")
 
-            # 2. Inference
+            # 2. Inference (single forward pass to get logits)
             logger.info("Calling inference...")
-            model_output = self.inference(model_input)
+            model_output_logits = self.inference(model_input)
             logger.info("Inference successful.")
 
-            # 3. Postprocess
+            # 3. Postprocess (basic greedy decoding of the next token)
+            # Note: This postprocess only generates the *next* token.
+            # For full text generation, you would need a loop here or modify inference.
             logger.info("Calling postprocess...")
-            response = self.postprocess(model_output)
+            response = self.postprocess(model_output_logits)
             logger.info("Postprocessing successful.")
 
             logger.info("Handle method complete.")