Upload folder using huggingface_hub
handler.py +45 -21
handler.py
CHANGED
@@ -188,24 +188,46 @@ class EndpointHandler:
 
 
     def postprocess(self, outputs):
-        logger.info("Starting postprocess.")
-
+        logger.info("Starting postprocess for text generation (Greedy Decoding).")
+        # 'outputs' here is the output of the inference method (logits)
+        # This implements a basic greedy decoding strategy
+
+        if self.tokenizer is None:
+            logger.error("Tokenizer is not available for postprocessing.")
+            raise RuntimeError("Tokenizer is not available for postprocessing.")
+
         try:
-            #
-            #
-            #
-            # This
-
-            #
-
-            #
-            #
-
-
-
-
-
-
+            # Assuming outputs are logits of shape (batch_size, sequence_length, vocab_size)
+            # For greedy decoding, we take the argmax of the logits for the last token
+            # and append it to the input sequence.
+            # This basic handler will generate one token at a time in a loop.
+            # A real text generation handler would likely take the initial input_ids
+            # and loop until a stop condition is met.
+
+            # For a single forward pass output (like from the inference method),
+            # we can't generate a sequence directly here.
+            # The handle method would need to manage the generation loop.
+
+            # Let's adapt this postprocess to just decode the most probable token from the last position
+            # as a basic example, or return the input + the most probable next token.
+
+            # Assuming inputs were processed one by one (batch size 1 for simplicity in this example)
+            # And outputs are logits for the input sequence
+            if outputs.ndim == 3 and outputs.shape[0] == 1:  # Shape (1, seq_len, vocab_size)
+                last_token_logits = outputs[0, -1, :]  # Logits for the last token in the sequence
+                predicted_token_id = torch.argmax(last_token_logits).item()
+
+                # Decode the predicted token
+                predicted_text = self.tokenizer.decode([predicted_token_id])
+                logger.info(f"Predicted next token: {predicted_text} (ID: {predicted_token_id})")
+
+                # In a real generation loop, you would append this to the input and repeat.
+                # For this handler, let's just return the predicted token text.
+                return predicted_text
+            else:
+                logger.warning(f"Unexpected output shape for greedy decoding: {outputs.shape}. Returning raw logits list.")
+                return outputs.cpu().tolist()
+
 
         except Exception as e:
             logger.error(f"Error during postprocess: {e}", exc_info=True)
@@ -221,14 +243,16 @@ class EndpointHandler:
             model_input = self.preprocess(data)
             logger.info("Preprocessing successful.")
 
-            # 2. Inference
+            # 2. Inference (single forward pass to get logits)
             logger.info("Calling inference...")
-
+            model_output_logits = self.inference(model_input)
             logger.info("Inference successful.")
 
-            # 3. Postprocess
+            # 3. Postprocess (basic greedy decoding of the next token)
+            # Note: This postprocess only generates the *next* token.
+            # For full text generation, you would need a loop here or modify inference.
             logger.info("Calling postprocess...")
-            response = self.postprocess(
+            response = self.postprocess(model_output_logits)
             logger.info("Postprocessing successful.")
 
             logger.info("Handle method complete.")
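
The comments added to postprocess and handle both note that this version only decodes the single most probable next token, and that full text generation would need a loop in the handle method (or a modified inference). A minimal sketch of that greedy loop, written directly against a causal LM rather than this handler's methods; the model name, prompt, and max_new_tokens below are placeholders, not values from this repository:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder; the model actually served by this handler is not shown in the diff
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids

max_new_tokens = 20  # placeholder stop condition
with torch.no_grad():
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits                       # (1, seq_len, vocab_size)
        next_token_id = torch.argmax(logits[0, -1, :]).item()  # greedy: most probable next token
        if next_token_id == tokenizer.eos_token_id:
            break
        # Append the chosen token and run the model again, as the postprocess comments describe.
        input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=-1)

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))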
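
The handle comments also mention modifying inference as an alternative to looping in handle. In that case the whole greedy loop can be delegated to transformers' built-in generate, since do_sample=False with the default num_beams=1 performs the same argmax-per-step (greedy) search; again a sketch with placeholder names rather than this handler's actual code:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# generate() handles the append-and-rerun loop internally and stops at EOS or max_new_tokens.
output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))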