from transformers import AutoTokenizer, T5ForConditionalGeneration
import torch
from difflib import SequenceMatcher
class EndpointHandler:
def __init__(self, path=""):
# Load model and tokenizer
model_name = path if path else "grammarly/coedit-large"
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = T5ForConditionalGeneration.from_pretrained(model_name)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
    def paraphrase_batch(self, sentences, num_return_sequences=1, temperature=1.0):
        # Despite the name, this runs the CoEdIT grammar-correction task:
        # each sentence gets the instruction prefix the model was trained on.
        prefix = "Fix the grammar: "
        sentences_with_prefix = [prefix + s for s in sentences]
inputs = self.tokenizer(
sentences_with_prefix,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
).to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_length=512,
            # generate() requires num_beams >= num_return_sequences,
            # so widen the beam when more candidates are requested.
            num_beams=max(5, num_return_sequences),
            # Note: temperature only takes effect when do_sample=True;
            # under pure beam search it is ignored.
            temperature=temperature,
            num_return_sequences=num_return_sequences,
            early_stopping=True
        )
decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
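        # generate() flattens candidates in sentence-major order: for two
        # inputs and num_return_sequences=3, decoded holds
        # [s0_c0, s0_c1, s0_c2, s1_c0, s1_c1, s1_c2], which the slicing
        # below regroups per input sentence.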
if num_return_sequences > 1:
grouped = [
decoded[i * num_return_sequences:(i + 1) * num_return_sequences]
for i in range(len(sentences))
]
return grouped
else:
return decoded
def compute_changes(self, original, enhanced):
changes = []
matcher = SequenceMatcher(None, original, enhanced) # char-level, not token-level
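        # Example (illustrative input): for ("I has a cat", "I have a cat"),
        # get_opcodes() yields one non-equal opcode, ("replace", 4, 5, 4, 6),
        # i.e. "s" -> "ve" at character positions 4-5.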
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag in ("replace", "insert", "delete"):
                original_phrase = original[i1:i2]
                new_phrase = enhanced[j1:j2]
                # Flag pure-whitespace edits so the client can surface a
                # spacing tip instead of a generic grammar change.
                is_whitespace = original_phrase.isspace() or new_phrase.isspace()
                changes.append({
                    "original_phrase": original_phrase,
                    "new_phrase": new_phrase,
                    "char_start": i1,
                    "char_end": i2,
                    "token_start": None,  # diff is character-based, not token-based
                    "token_end": None,
                    "explanation": f"{tag} change",
                    "error_type": "whitespace" if is_whitespace else "",
                    "tip": "Avoid extra spaces between words." if is_whitespace else ""
                })
return changes
    def __call__(self, inputs):
        # Main entry point for the Hugging Face Inference Endpoint.
        # Accepts a bare string, a list of strings, or the wrapped
        # {"inputs": ..., "parameters": {...}} JSON format.
        if isinstance(inputs, str):
            sentences = [inputs]
            parameters = {}
        elif isinstance(inputs, list):
            sentences = inputs
            parameters = {}
        elif isinstance(inputs, dict):
            sentences = inputs.get("inputs", [])
            # A single string under "inputs" is wrapped in a list.
            if isinstance(sentences, str):
                sentences = [sentences]
            parameters = inputs.get("parameters", {})
        else:
            return {
                "success": False,
                "error": "Invalid input format. Expected a string, list of strings, or a dictionary with 'inputs' and 'parameters' keys."
            }
# Handle optional parameters
num_return_sequences = parameters.get("num_return_sequences", 1)
temperature = parameters.get("temperature", 1.0)
if not sentences:
return {
"success": False,
"error": "No sentences provided."
}
try:
paraphrased = self.paraphrase_batch(sentences, num_return_sequences, temperature)
results = []
if num_return_sequences > 1:
# Logic for multiple return sequences
for i, orig in enumerate(sentences):
for cand in paraphrased[i]:
results.append({
"original_sentence": orig,
"enhanced_sentence": cand,
"changes": self.compute_changes(orig, cand)
})
else:
# Logic for single return sequence
for orig, cand in zip(sentences, paraphrased):
results.append({
"original_sentence": orig,
"enhanced_sentence": cand,
"changes": self.compute_changes(orig, cand)
})
return {
"success": True,
"results": results,
"sentences_count": len(sentences),
"processed_count": len(results),
"skipped_count": 0,
"error_count": 0
}
except Exception as e:
return {
"success": False,
"error": str(e),
"sentences_count": len(sentences),
"processed_count": 0,
"skipped_count": 0,
"error_count": 1
}
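
# A minimal local smoke test (a sketch, not part of the Endpoint contract;
# the sample sentence is made up for illustration, and instantiating the
# handler downloads grammarly/coedit-large on first run):
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": ["She go to school every day."],
        "parameters": {"num_return_sequences": 1, "temperature": 1.0},
    }
    print(handler(payload))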