from difflib import SequenceMatcher

import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration


class EndpointHandler:
    def __init__(self, path=""):
        # Load from the local checkpoint path when one is provided; otherwise
        # fall back to pulling grammarly/coedit-large from the Hub.
        model_name = path if path else "grammarly/coedit-large"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()  # inference only: disable dropout

    def paraphrase_batch(self, sentences, num_return_sequences=1, temperature=1.0):
        # CoEdIT is instruction-tuned: the edit task is selected by a
        # natural-language prefix prepended to each input sentence.
        prefix = "Fix the grammar: "
        sentences_with_prefix = [prefix + s for s in sentences]

        inputs = self.tokenizer(
            sentences_with_prefix,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)

        # Temperature only affects sampling; with plain beam search it is
        # ignored (recent transformers versions warn about this), so enable
        # sampling only when a non-default temperature is requested.
        do_sample = temperature != 1.0
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=512,
                num_beams=5,  # num_return_sequences must not exceed num_beams
                do_sample=do_sample,
                temperature=temperature,
                num_return_sequences=num_return_sequences,
                early_stopping=True,
            )

        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # generate() returns a flat list of candidates; regroup them so each
        # input sentence maps to its own list of alternatives.
        if num_return_sequences > 1:
            return [
                decoded[i * num_return_sequences:(i + 1) * num_return_sequences]
                for i in range(len(sentences))
            ]
        return decoded

    def compute_changes(self, original, enhanced):
        # Character-level diff between the original and enhanced sentence,
        # reported as change dicts with offsets into the original string.
        changes = []
        matcher = SequenceMatcher(None, original, enhanced)

        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag in ("replace", "insert", "delete"):
                original_phrase = original[i1:i2]
                new_phrase = enhanced[j1:j2]
                is_whitespace = original_phrase.isspace() or new_phrase.isspace()
                changes.append({
                    "original_phrase": original_phrase,
                    "new_phrase": new_phrase,
                    "char_start": i1,
                    "char_end": i2,
                    "token_start": None,  # token offsets are not computed here
                    "token_end": None,
                    "explanation": f"{tag} change",
                    "error_type": "whitespace" if is_whitespace else "",
                    "tip": "Avoid extra spaces between words." if is_whitespace else "",
                })
        return changes

    def __call__(self, inputs):
        # Normalize the three accepted payload shapes (bare string, list of
        # strings, or {"inputs": ..., "parameters": ...}) to a sentence list.
        if isinstance(inputs, str):
            sentences = [inputs]
            parameters = {}
        elif isinstance(inputs, list):
            sentences = inputs
            parameters = {}
        elif isinstance(inputs, dict):
            sentences = inputs.get("inputs", [])
            if isinstance(sentences, str):
                sentences = [sentences]
            parameters = inputs.get("parameters", {})
        else:
            return {
                "success": False,
                "error": "Invalid input format. Expected a string, a list of strings, or a dictionary with 'inputs' and 'parameters' keys.",
            }

        num_return_sequences = parameters.get("num_return_sequences", 1)
        temperature = parameters.get("temperature", 1.0)

        if not sentences:
            return {
                "success": False,
                "error": "No sentences provided.",
            }

        try:
            paraphrased = self.paraphrase_batch(sentences, num_return_sequences, temperature)
            results = []

            if num_return_sequences > 1:
                # paraphrased is a list of candidate lists, one per sentence.
                for i, orig in enumerate(sentences):
                    for cand in paraphrased[i]:
                        results.append({
                            "original_sentence": orig,
                            "enhanced_sentence": cand,
                            "changes": self.compute_changes(orig, cand),
                        })
            else:
                # paraphrased is a flat list aligned with the input sentences.
                for orig, cand in zip(sentences, paraphrased):
                    results.append({
                        "original_sentence": orig,
                        "enhanced_sentence": cand,
                        "changes": self.compute_changes(orig, cand),
                    })

            return {
                "success": True,
                "results": results,
                "sentences_count": len(sentences),
                "processed_count": len(results),
                "skipped_count": 0,
                "error_count": 0,
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "sentences_count": len(sentences),
                "processed_count": 0,
                "skipped_count": 0,
                "error_count": 1,
            }
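

# A minimal local smoke test, assuming the grammarly/coedit-large weights are
# available for download. This mirrors how an inference endpoint would invoke
# the handler; the example payload below is illustrative only.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": ["She no went to the market.", "I has a apple."],
        "parameters": {"num_return_sequences": 1, "temperature": 1.0},
    }
    print(handler(payload))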