Spaces:
Runtime error
Runtime error
| from text_extractor import TextExtractor | |
| from tqdm import tqdm | |
| from transformers import PegasusForConditionalGeneration, PegasusTokenizer | |
| from transformers import pipeline | |
| from mdutils.mdutils import MdUtils | |
| from pathlib import Path | |
| import gradio as gr | |
| import fitz | |
| import torch | |
| import copy | |
| import os | |
# Stem of the PDF currently being processed. inference() assigns it and
# convert2markdown() reads it to name the Markdown output file.
FILENAME = ""
# Extractor used by inference() to tag PDF text and group it into slides.
preprocess = TextExtractor()

# Summarization checkpoint shared by the tokenizer and the model below.
model_name = "sshleifer/distill-pegasus-cnn-16-4"

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): `max_length` does not appear to be a tokenizer init option;
# truncation=True in summarize() is bounded by `model_max_length` instead.
# Verify whether 500 was meant to cap the input length.
tokenizer = PegasusTokenizer.from_pretrained(model_name, max_length=500)
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
def summarize(slides):
    """Return a copy of *slides* with every paragraph replaced by its summary.

    Args:
        slides: Mapping of page key -> mutable sequence of ``(tag, content)``
            pairs, as produced by ``preprocess.get_slides``. Entries whose
            tag starts with ``'p'`` are summarized; all others are copied
            unchanged.

    Returns:
        A deep copy of *slides*; the input mapping is not modified. A
        paragraph whose summarization raises keeps its original text
        (best effort), and the failure is printed.
    """
    generated_slides = copy.deepcopy(slides)
    for page, contents in tqdm(generated_slides.items()):
        for idx, (tag, content) in enumerate(contents):
            if not tag.startswith('p'):
                continue
            try:
                encoded = tokenizer(
                    content, truncation=True, padding="longest", return_tensors="pt"
                ).to(device)
                output_ids = model.generate(**encoded)
                summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            except Exception as err:
                # Deliberate best effort: one bad paragraph must not abort the
                # whole deck, so keep the original text and report the failure.
                print(err)
                print(f"Summarization Fails (page {page})")
                continue
            # Replacing an element in place does not change the sequence's
            # length, so it is safe while enumerating it.
            contents[idx] = (tag, summary)
    return generated_slides
| def convert2markdown(generate_slides): | |
| mdFile = MdUtils(file_name=FILENAME, title=f'{FILENAME} Presentation') | |
| for k, v in generate_slides.items(): | |
| mdFile.new_paragraph('---') | |
| for section in v: | |
| tag = section[0] | |
| content = section[1] | |
| if tag.startswith('h'): | |
| mdFile.new_header(level=int(tag[1]), title=content) | |
| if tag == 'p': | |
| contents = content.split('<n>') | |
| for content in contents: | |
| mdFile.new_paragraph(content) | |
| mdFile.create_md_file() | |
| return f"{FILENAME}.md" | |
def inference(document):
    """Convert an uploaded PDF into a summarized Markdown slide deck.

    Pipeline: open the PDF, tag its text by font, group it into slides,
    summarize the paragraphs and write the result as Markdown.

    Args:
        document: Value of the ``gr.File`` input, passed straight to
            ``fitz.open`` (a path or, presumably, a temp-file object exposing
            ``.name``; confirm for the installed Gradio version). ``None``
            when the button is clicked without an upload.

    Returns:
        Path of the generated ``.md`` file, as returned by
        ``convert2markdown``.

    Raises:
        ValueError: If no document was uploaded.
    """
    # NOTE(review): FILENAME is module-global state, so concurrent requests
    # can overwrite each other's output name. Consider passing it explicitly.
    global FILENAME
    if document is None:
        # fitz.open(None) would create a new empty document rather than fail,
        # which would yield a nameless, empty Markdown file.
        raise ValueError("No PDF uploaded: please upload a PDF file first.")
    print(document)
    doc = fitz.open(document)
    try:
        FILENAME = Path(doc.name).stem
        font_counts, styles = preprocess.get_font_info(doc, granularity=False)
        size_tag = preprocess.get_font_tags(font_counts, styles)
        texts = preprocess.assign_tags(doc, size_tag)
        slides = preprocess.get_slides(texts)
        generated_slides = summarize(slides)
        markdown_path = convert2markdown(generated_slides)
    finally:
        # Release the PDF handle even if extraction or summarization fails.
        doc.close()
    return markdown_path
# Gradio UI: upload a PDF, click the button, download the generated Markdown.
with gr.Blocks() as demo:
    # Gradio file_types entries are extensions with a leading dot (or a
    # category such as "image"); a bare 'pdf' is not a valid filter.
    inp = gr.File(file_types=['.pdf'])
    out = gr.File(label="Markdown File")
    inference_btn = gr.Button("Summarized PDF")
    inference_btn.click(fn=inference, inputs=inp, outputs=out, show_progress=True, api_name="summarize")

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch()