Spaces:
Running
Running
| import transformers | |
| import re | |
| from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
| import torch | |
| import gradio as gr | |
| import difflib | |
| from concurrent.futures import ThreadPoolExecutor | |
| import os | |
# --- OCR correction model ---------------------------------------------------
# Hugging Face Hub identifier of the checkpoint used for correction.
model_name = "PleIAs/OCRonos-Vintage"

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pre-trained weights and move them to the selected device, then
# load the matching tokenizer from the same checkpoint.
model = GPT2LMHeadModel.from_pretrained(model_name)
model = model.to(device)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# --- Presentation -----------------------------------------------------------
# Stylesheet prepended to the generated HTML: `.generation` frames the
# corrected text, `.inserted` highlights words introduced by the correction.
css = """
<style>
.generation {
margin-left: 2em;
margin-right: 2em;
font-size: 1.2em;
}
.inserted {
background-color: #90EE90;
}
</style>
"""
def generate_html_diff(old_text, new_text):
    """Render ``new_text`` as HTML, highlighting words not found in ``old_text``.

    Both texts are split on whitespace and compared word by word with
    ``difflib.Differ``. Words common to both texts are emitted unchanged,
    words that only occur in ``new_text`` are wrapped in
    ``<span class="inserted">``, and everything else Differ reports (words
    only present in ``old_text`` and its ``?`` hint lines) is left out, so
    the result reads as the new text.

    Every emitted word is HTML-escaped: the words come from user input and
    model output and end up in an HTML component, so characters such as
    ``<`` or ``&`` must not be interpreted as markup.

    Args:
        old_text: The original text.
        new_text: The corrected text.

    Returns:
        The words of the new text joined by single spaces, as an HTML
        fragment.
    """
    import html  # function-scope import keeps this block self-contained

    pieces = []
    for entry in difflib.Differ().compare(old_text.split(), new_text.split()):
        # Differ prefixes each item with a two-character code.
        word = html.escape(entry[2:])
        if entry.startswith(' '):
            # Unchanged word.
            pieces.append(word)
        elif entry.startswith('+ '):
            # Word introduced by the new text.
            pieces.append(f'<span class="inserted">{word}</span>')
    return ' '.join(pieces)
def split_text(text, max_tokens=400):
    """Cut ``text`` into consecutive chunks of ``max_tokens`` tokens each.

    The text is tokenized with the module-level ``tokenizer``; the token
    sequence is sliced into fixed-size windows (the last one may be
    shorter) and each window is converted back to a string.

    Args:
        text: The text to split.
        max_tokens: Number of tokens per chunk. Values below 1 result in
            one chunk per token.

    Returns:
        A list of chunk strings in their original order; empty when the
        text yields no tokens.
    """
    pieces = tokenizer.tokenize(text)
    # A window is never narrower than one token.
    width = max(max_tokens, 1)
    return [
        tokenizer.convert_tokens_to_string(pieces[start:start + width])
        for start in range(0, len(pieces), width)
    ]
def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
    """Return the model's corrected version of an OCR text snippet.

    The snippet is wrapped in the ``### Text ### / ### Correction ###``
    template, encoded with the module-level ``tokenizer`` and completed by
    the module-level ``model`` without sampling (``do_sample=False``).

    Args:
        prompt: The raw OCR text to correct.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_threads: Thread count handed to ``torch.set_num_threads``.
            ``None`` (which ``os.cpu_count()`` may return) leaves torch's
            current setting untouched.

    Returns:
        The generated correction with surrounding whitespace stripped.
    """
    prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # os.cpu_count() can return None; torch.set_num_threads requires an int.
    if num_threads is not None:
        torch.set_num_threads(num_threads)

    # Called directly: handing a single job to a thread pool and waiting
    # for it right away adds no parallelism.
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=50,
        num_return_sequences=1,
        do_sample=False,
    )

    # The returned sequence starts with the prompt tokens. Decode only the
    # continuation so that a "### Correction ###" marker inside the user's
    # own text cannot be mistaken for the template's marker.
    continuation = output[0][input_ids.shape[-1]:]
    decoded = tokenizer.decode(continuation, skip_special_tokens=True)
    # If the model starts another correction section, keep only the first.
    return decoded.split("### Correction ###")[0].strip()
def process_text(user_message):
    """Correct ``user_message`` and return the result as an HTML page fragment.

    The message is split into chunks with ``split_text``, each chunk is
    corrected with ``ocr_correction``, and the corrected chunks are joined
    with single spaces. The joined text is then diffed against the original
    message with ``generate_html_diff``.

    Args:
        user_message: The text entered by the user.

    Returns:
        The module-level ``css`` followed by a centered heading and a
        ``generation`` div containing the highlighted correction.
    """
    corrected_text = ' '.join(
        ocr_correction(piece) for piece in split_text(user_message)
    )
    highlighted = generate_html_diff(user_message, corrected_text)

    heading = '<h2 style="text-align:center">OCR Correction</h2>'
    container = f'<div class="generation">{highlighted}</div>'
    return f"{css}{heading}\n{container}"
# Gradio UI: a title, a multi-line input box, a button, and an HTML pane
# that displays whatever process_text returns.
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Vintage OCR corrector (CPU)</h1>""")
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")

    # Clicking the button sends the textbox content through process_text.
    process_button.click(
        fn=process_text,
        inputs=text_input,
        outputs=[text_output],
    )

# Start the app only when executed as a script.
if __name__ == "__main__":
    app = demo.queue()
    app.launch()