Spaces:
Runtime error
Runtime error
NER model added
Browse files- app.py +29 -10
- news_pipeline.py +17 -0
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
|
| 3 |
from news_pipeline import NewsPipeline
|
| 4 |
|
| 5 |
CATEGORY_EMOJIS = {
|
|
@@ -34,15 +34,34 @@ def app():
|
|
| 34 |
|
| 35 |
with st.spinner("Analyzing article..."):
|
| 36 |
prediction = news_pipe(headline, content)
|
| 37 |
-
st.
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
|
| 48 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
from annotated_text import annotated_text
|
| 3 |
from news_pipeline import NewsPipeline
|
| 4 |
|
| 5 |
CATEGORY_EMOJIS = {
|
|
|
|
| 34 |
|
| 35 |
with st.spinner("Analyzing article..."):
|
| 36 |
prediction = news_pipe(headline, content)
|
| 37 |
+
col1, _, col2 = st.columns([2, 1, 6])
|
| 38 |
+
with col1:
|
| 39 |
+
st.subheader("Analysis:")
|
| 40 |
+
st.markdown(
|
| 41 |
+
f"{CATEGORY_EMOJIS[prediction['category']]} **Category**: {prediction['category']}"
|
| 42 |
+
)
|
| 43 |
+
st.markdown(
|
| 44 |
+
f"{FAKE_EMOJIS[prediction['fake']]} **Fake**: {'Yes' if prediction['fake'] == 'Fake' else 'No'}"
|
| 45 |
+
)
|
| 46 |
+
st.markdown(
|
| 47 |
+
f"{CLICKBAIT_EMOJIS[prediction['clickbait']]} **Clickbait**: {'Yes' if prediction['clickbait'] == 'Clickbait' else 'No'}"
|
| 48 |
+
)
|
| 49 |
+
with col2:
|
| 50 |
+
st.subheader("Headline")
|
| 51 |
+
annotated_text(*parse_text(headline, prediction["ner"]["headline"]))
|
| 52 |
+
st.subheader("Content")
|
| 53 |
+
annotated_text(*parse_text(content, prediction["ner"]["content"]))
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def parse_text(text, prediction):
|
| 57 |
+
start = 0
|
| 58 |
+
parsed_text = []
|
| 59 |
+
for p in prediction:
|
| 60 |
+
parsed_text.append(text[start : p["start"]])
|
| 61 |
+
parsed_text.append((p["word"], p["entity_group"]))
|
| 62 |
+
start = p["end"]
|
| 63 |
+
parsed_text.append(text[start:])
|
| 64 |
+
return parsed_text
|
| 65 |
|
| 66 |
|
| 67 |
if __name__ == "__main__":
|
news_pipeline.py
CHANGED
|
@@ -2,8 +2,10 @@ from typing import Dict
|
|
| 2 |
|
| 3 |
from transformers import (
|
| 4 |
AutoModelForSequenceClassification,
|
|
|
|
| 5 |
AutoTokenizer,
|
| 6 |
TextClassificationPipeline,
|
|
|
|
| 7 |
)
|
| 8 |
|
| 9 |
|
|
@@ -29,6 +31,13 @@ class NewsPipeline:
|
|
| 29 |
),
|
| 30 |
tokenizer=AutoTokenizer.from_pretrained("elozano/news-clickbait"),
|
| 31 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
def __call__(self, headline: str, content: str) -> Dict[str, str]:
|
| 34 |
category_article_text = f" {self.category_tokenizer.sep_token} ".join(
|
|
@@ -41,4 +50,12 @@ class NewsPipeline:
|
|
| 41 |
"category": self.category_pipeline(category_article_text)[0]["label"],
|
| 42 |
"fake": self.fake_pipeline(fake_article_text)[0]["label"],
|
| 43 |
"clickbait": self.clickbait_pipeline(headline)[0]["label"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
}
|
|
|
|
| 2 |
|
| 3 |
from transformers import (
|
| 4 |
AutoModelForSequenceClassification,
|
| 5 |
+
AutoModelForTokenClassification,
|
| 6 |
AutoTokenizer,
|
| 7 |
TextClassificationPipeline,
|
| 8 |
+
TokenClassificationPipeline,
|
| 9 |
)
|
| 10 |
|
| 11 |
|
|
|
|
| 31 |
),
|
| 32 |
tokenizer=AutoTokenizer.from_pretrained("elozano/news-clickbait"),
|
| 33 |
)
|
| 34 |
+
self.ner_pipeline = TokenClassificationPipeline(
|
| 35 |
+
tokenizer=AutoTokenizer.from_pretrained("dslim/bert-base-NER"),
|
| 36 |
+
model=AutoModelForTokenClassification.from_pretrained(
|
| 37 |
+
"dslim/bert-base-NER"
|
| 38 |
+
),
|
| 39 |
+
aggregation_strategy="simple",
|
| 40 |
+
)
|
| 41 |
|
| 42 |
def __call__(self, headline: str, content: str) -> Dict[str, str]:
|
| 43 |
category_article_text = f" {self.category_tokenizer.sep_token} ".join(
|
|
|
|
| 50 |
"category": self.category_pipeline(category_article_text)[0]["label"],
|
| 51 |
"fake": self.fake_pipeline(fake_article_text)[0]["label"],
|
| 52 |
"clickbait": self.clickbait_pipeline(headline)[0]["label"],
|
| 53 |
+
"ner": {
|
| 54 |
+
"headline": list(
|
| 55 |
+
filter(lambda x: x["score"] > 0.8, self.ner_pipeline(headline))
|
| 56 |
+
),
|
| 57 |
+
"content": list(
|
| 58 |
+
filter(lambda x: x["score"] > 0.8, self.ner_pipeline(content))
|
| 59 |
+
),
|
| 60 |
+
},
|
| 61 |
}
|
requirements.txt
CHANGED
|
@@ -1,2 +1,3 @@
|
|
| 1 |
transformers
|
| 2 |
torch
|
|
|
|
|
|
| 1 |
transformers
|
| 2 |
torch
|
| 3 |
+
st-annotated-text
|