Spaces:
Running
Running
| # Hint: this cheatsheet is magic! https://cheat-sheet.streamlit.app/ | |
| import constants | |
| import pandas as pd | |
| import streamlit as st | |
| import matplotlib.pyplot as plt | |
| from transformers import BertForSequenceClassification, AutoTokenizer | |
| import altair as alt | |
| from altair import X, Y, Scale | |
| import base64 | |
| import re | |
def preprocess_text(arabic_text, remove_urls, remove_latin):
    """Clean an Arabic sentence before it is scored.

    Args:
        arabic_text: The raw sentence.
        remove_urls: When true, strip substrings that look like URLs.
        remove_latin: When true, drop every ASCII letter (a-z, A-Z).

    Returns:
        The cleaned sentence with leading/trailing whitespace removed.
    """
    cleaned = arabic_text

    if remove_urls:
        # NOTE(review): the character class has no "-" or "#", so URLs
        # containing them are only removed up to that character.
        url_pattern = r"(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b"
        cleaned = re.sub(url_pattern, "", cleaned, flags=re.MULTILINE)

    if remove_latin:
        cleaned = re.sub(r"[a-zA-Z]", "", cleaned)

    return cleaned.strip()
def render_svg(svg):
    """Display an SVG image in a centre-aligned paragraph.

    Args:
        svg: The SVG markup as a string.
    """
    # Embed the markup as a base64 data URI inside an <img> tag.
    encoded = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    container = st.container()
    container.write(
        rf'<p align="center"> <img src="data:image/svg+xml;base64,{encoded}"/> </p>',
        unsafe_allow_html=True,
    )
@st.cache_data
def convert_df(df):
    """Serialize a dataframe to UTF-8 encoded CSV bytes, without the index.

    The result is cached with ``st.cache_data`` so the conversion is not
    recomputed on every script rerun for an unchanged dataframe.

    Args:
        df: The dataframe to serialize.

    Returns:
        The CSV content as ``bytes``.
    """
    return df.to_csv(index=False).encode("utf-8")
@st.cache_resource
def load_model(model_name):
    """Load the sequence-classification model.

    Cached with ``st.cache_resource`` so the weights are loaded once per
    server process instead of on every script rerun.

    Args:
        model_name: Name or path passed to ``from_pretrained``.

    Returns:
        The loaded ``BertForSequenceClassification`` instance.
    """
    return BertForSequenceClassification.from_pretrained(model_name)


@st.cache_resource
def _load_tokenizer(model_name):
    """Load the tokenizer once per server process (cached like the model)."""
    return AutoTokenizer.from_pretrained(model_name)


tokenizer = _load_tokenizer(constants.MODEL_NAME)
model = load_model(constants.MODEL_NAME)
def compute_ALDi(sentences, remove_urls=True, remove_latin=True):
    """Computes the ALDi score for the given sentences.

    Sentences are preprocessed, scored in small batches by the module-level
    ``model``/``tokenizer``, and each raw output is clamped to [0, 1]. A
    Streamlit progress bar is shown while the batches run.

    Args:
        sentences: A list of Arabic sentences.
        remove_urls: Whether URLs are stripped before scoring.
        remove_latin: Whether Latin letters are stripped before scoring.

    Returns:
        A list of ALDi scores (one per input sentence, in order).
    """
    # Local import: torch is not imported at file level.
    import torch

    progress_text = "Computing ALDi..."
    my_bar = st.progress(0, text=progress_text)
    BATCH_SIZE = 4

    preprocessed_sentences = [
        preprocess_text(s, remove_urls, remove_latin) for s in sentences
    ]
    n_sentences = len(preprocessed_sentences)

    scores = []
    try:
        for first_index in range(0, n_sentences, BATCH_SIZE):
            batch = preprocessed_sentences[first_index : first_index + BATCH_SIZE]
            # truncation=True caps inputs at the tokenizer's configured
            # maximum length, if it has one (presumably the model's limit).
            inputs = tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                batch_outputs = model(**inputs).logits.reshape(-1).tolist()

            # Clamp every raw output into the [0, 1] score range.
            scores.extend(max(min(o, 1), 0) for o in batch_outputs)

            my_bar.progress(
                min((first_index + BATCH_SIZE) / n_sentences, 1),
                text=progress_text,
            )
    finally:
        # Always remove the progress bar, even if a batch fails.
        my_bar.empty()

    return scores
def render_metadata():
    """Show the centred badge links: Hugging Face model, GitHub and arXiv."""
    badges_html = r"""<p align="center">
<a href="https://huggingface.co/AMR-KELEG/Sentence-ALDi"><img alt="HuggingFace Model" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-8A2BE2"></a>
<a href="https://github.com/AMR-KELEG/ALDi"><img alt="GitHub" src="https://img.shields.io/badge/%F0%9F%93%A6%20GitHub-orange"></a>
<a href="https://arxiv.org/abs/2310.13747"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2310.13747-b31b1b.svg"></a>
</p>"""
    # Raw HTML is required for the badge images, hence unsafe_allow_html.
    st.container().write(badges_html, unsafe_allow_html=True)
# Page header: the logo followed by the badge links.
# Use a context manager so the file handle is closed, and read the SVG as
# UTF-8 explicitly (render_svg re-encodes it as UTF-8).
with open("assets/ALDi_logo.svg", encoding="utf-8") as logo_file:
    render_svg(logo_file.read())
render_metadata()

# One tab for scoring a single sentence, one for scoring an uploaded file.
tab1, tab2 = st.tabs(["Input a Sentence", "Upload a File"])
# Tab 1: score a single sentence typed by the user and plot the result.
with tab1:
    sent = st.text_input(
        "Arabic Sentence:", placeholder="Enter an Arabic sentence.", on_change=None
    )

    # The return value is not used below; clicking the button simply
    # triggers a script rerun, which re-scores the current text.
    clicked = st.button("Submit")

    remove_urls = st.toggle("Remove urls")
    remove_latin = st.toggle("Remove Latin characters")

    if sent:
        ALDi_score = compute_ALDi(
            [sent], remove_urls=remove_urls, remove_latin=remove_latin
        )[0]

        ORANGE_COLOR = "#FF8000"

        # A single horizontal bar on a transparent background, styled in orange.
        fig, ax = plt.subplots(figsize=(8, 1))
        fig.patch.set_facecolor("none")
        ax.set_facecolor("none")

        ax.spines["left"].set_color(ORANGE_COLOR)
        ax.spines["bottom"].set_color(ORANGE_COLOR)
        ax.tick_params(axis="x", colors=ORANGE_COLOR)
        ax.spines[["right", "top"]].set_visible(False)

        ax.barh(y=[0], width=[ALDi_score], color=ORANGE_COLOR)
        ax.set_xlim(0, 1)
        ax.set_ylim(-1, 1)
        ax.set_title(f"ALDi score is: {round(ALDi_score, 3)}", color=ORANGE_COLOR)
        ax.get_yaxis().set_visible(False)
        ax.set_xlabel("ALDi score", color=ORANGE_COLOR)
        st.pyplot(fig)
        # Release the figure; otherwise pyplot keeps one open figure per rerun.
        plt.close(fig)

        # Log the submitted sentence to stdout and to a local file.
        print(sent)
        # Explicit UTF-8 so Arabic text can be written regardless of locale.
        with open("logs.txt", "a", encoding="utf-8") as f:
            f.write(sent + "\n")
# Tab 2: score every line of an uploaded text file, then chart, tabulate and
# offer the scores for download.
with tab2:
    file = st.file_uploader("Upload a file", type=["txt"])
    if file is not None:
        # Read the upload as plain text, one sentence per non-blank line.
        # A CSV parser is avoided on purpose: tabs or quotes inside a
        # sentence, or lines such as "NA" or "123", must stay literal strings.
        # "utf-8-sig" also drops a leading byte-order mark if present.
        raw_text = file.getvalue().decode("utf-8-sig")
        sentences = [line for line in raw_text.splitlines() if line.strip()]

        if not sentences:
            st.warning("The uploaded file does not contain any sentences.")
        else:
            df = pd.DataFrame({"Sentence": sentences})

            # Uses compute_ALDi's default preprocessing (URLs and Latin
            # characters removed); the toggles in the first tab do not apply.
            df["ALDi"] = compute_ALDi(df["Sentence"].tolist())

            # A horizontal rule
            st.markdown("""---""")

            # Area chart of the score against the sentence's position in the file.
            chart = (
                alt.Chart(df.reset_index())
                .mark_area(color="darkorange", opacity=0.5)
                .encode(
                    x=X(field="index", title="Sentence Index"),
                    y=Y("ALDi", scale=Scale(domain=[0, 1])),
                )
            )
            st.altair_chart(chart.interactive(), use_container_width=True)

            col1, col2 = st.columns([4, 1])

            with col1:
                # Display the output
                st.table(
                    df,
                )

            with col2:
                # Add a download button
                csv = convert_df(df)
                st.download_button(
                    label=":file_folder: Download predictions as CSV",
                    data=csv,
                    file_name="ALDi_scores.csv",
                    mime="text/csv",
                )