import re
import string
from collections import defaultdict
import seaborn as sns
import streamlit as st
from charset_normalizer import detect
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
logging,
pipeline,
)
# Setup
def setup_page():
st.set_page_config(
page_title="Juristische Anonymisierung",
page_icon="⚖️",
layout="wide",
)
logging.set_verbosity(logging.ERROR)
    st.markdown(
        """
        <style>
        /* Minimal entity styling (assumed; the original stylesheet is not preserved). */
        .entity { padding: 0 2px; border-radius: 4px; }
        .entity.marked { border: 1px solid rgba(0, 0, 0, 0.25); }
        </style>
        """,
        unsafe_allow_html=True,
    )
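# Entity tiers that drive the default anonymization choice, plus German display
# labels for the model's tag set.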
def get_constants():
entity_importance = {
"High": ["PER", "UN", "INN", "MRK", "RED"],
"Mid": ["RR", "AN", "GRT", "GS", "VO", "RS", "EUN", "LIT", "VS", "VT"],
"Low": ["LD", "ST", "STR", "LDS", "ORG"],
}
entity_labels = {
"AN": "Rechtsbeistand",
"EUN": "EUNorm",
"GRT": "Gericht",
"GS": "Norm",
"INN": "Institution",
"LD": "Land",
"LDS": "Bezirk",
"LIT": "Schrifttum",
"MRK": "Marke",
"ORG": "Organisation",
"PER": "Person",
"RR": "RichterIn",
"RS": "Entscheidung",
"ST": "Stadt",
"STR": "Strasse",
"UN": "Unternehmen",
"VO": "Verordnung",
"VS": "Richtlinie",
"VT": "Vertrag",
"RED": "Schwärzung",
}
return entity_importance, entity_labels
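# Give every entity code a fixed, semi-transparent RGBA color from seaborn's tab20 palette.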
def generate_fixed_colors(keys, alpha=0.25):
base_colors = sns.color_palette("tab20", len(keys))
return {
key: f"rgba({int(r*255)}, {int(g*255)}, {int(b*255)}, {alpha})"
for key, (r, g, b) in zip(keys, base_colors)
}
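# Load the JuraNER token-classification pipeline once and reuse it across Streamlit reruns.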
@st.cache_resource
def load_ner_model():
tokenizer = AutoTokenizer.from_pretrained("harshildarji/JuraNER")
model = AutoModelForTokenClassification.from_pretrained("harshildarji/JuraNER")
return pipeline("ner", model=model, tokenizer=tokenizer)
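# Apply NER line by line and cache the merged entities per input text.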
@st.cache_data(show_spinner=False)
def ner_merge_lines(text: str):
ner = load_ner_model()
merged_lines = []
for line in text.splitlines():
if not line.strip():
merged_lines.append((line, []))
continue
tokens = ner(line)
merged = merge_entities(tokens)
merged_lines.append((line, merged))
return merged_lines
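# Merge consecutive sub-word tokens ("##"-prefixed pieces) into whole entities,
# average their scores, and drop fragments that contain no word characters.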
def merge_entities(entities):
if not entities:
return []
ents = sorted(entities, key=lambda e: e["index"])
merged = [ents[0].copy()]
merged[0]["score_sum"] = ents[0]["score"]
merged[0]["count"] = 1
for ent in ents[1:]:
prev = merged[-1]
if ent["index"] == prev["index"] + 1:
tok = ent["word"]
prev["word"] += tok[2:] if tok.startswith("##") else " " + tok
prev["end"] = ent["end"]
prev["index"] = ent["index"]
prev["score_sum"] += ent["score"]
prev["count"] += 1
else:
prev["score"] = prev["score_sum"] / prev["count"]
del prev["score_sum"], prev["count"]
new_ent = ent.copy()
new_ent["score_sum"] = ent["score"]
new_ent["count"] = 1
merged.append(new_ent)
if "score_sum" in merged[-1]:
merged[-1]["score"] = merged[-1]["score_sum"] / merged[-1]["count"]
del merged[-1]["score_sum"], merged[-1]["count"]
final = []
for ent in merged:
w = ent["word"].strip()
w = re.sub(r"\s*\.\s*", ".", w)
w = re.sub(r"\s*,\s*", ", ", w)
w = re.sub(r"\s*/\s*", "/", w)
w = w.strip(string.whitespace + string.punctuation)
if len(w) > 1 and re.search(r"\w", w):
cleaned = ent.copy()
cleaned["word"] = w
final.append(cleaned)
return final
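# Truncate (not round) a float to the given number of decimals for display.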
def truncate(number, decimals=2):
factor = 10**decimals
return int(number * factor) / factor
# Canonical grouping
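# Normalize an entity string so that surface variants (spacing, punctuation,
# ECLI identifiers, file numbers such as "1 bvr 123/45") share one canonical key.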
def canonical_key(text: str, label: str):
s = text.casefold().strip()
if label == "RS":
m = re.search(r"(ecli:[a-z]{2}:[a-z0-9]+:\d{4}:[a-z0-9.\-]+)", s)
if m:
original = text[m.start() : m.end()]
canon = m.group(1).replace(" ", "")
return (canon, label, original)
m = re.search(
r"((?:[ivxlcdm]+|\d{1,3})\s*[a-zäöüß]{1,3}\s*\d{1,6}\s*/\s*\d{2,4})", s
)
if m:
original = text[m.start() : m.end()].strip()
canon = re.sub(r"\s+", "", m.group(1))
return (canon, label, original)
cleaned = re.sub(r"[^\w]+", "", s)
return (cleaned, label, text.strip())
cleaned_generic = re.sub(r"[^\w]+", " ", s)
cleaned_generic = re.sub(r"\s+", " ", cleaned_generic).strip()
return (cleaned_generic, label, text.strip())
# Rendering
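# Rebuild a line as HTML: entities above the confidence threshold and within the
# selected importance levels are replaced by placeholders such as "Person A";
# everything else is rendered verbatim.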
def highlight_entities(
line,
merged_entities,
importance_levels,
threshold,
label_counters,
anonymized_map,
allowed_keys,
entity_labels,
entity_importance,
ENTITY_COLORS,
):
html = ""
last_end = 0
for ent in merged_entities:
if ent["score"] < threshold:
continue
start, end = ent["start"], ent["end"]
label = ent["entity"].split("-")[-1]
label_desc = entity_labels.get(label, label)
truncated_score = truncate(ent["score"], 2)
tooltip = f"{label_desc} ({truncated_score:.2f})"
color = ENTITY_COLORS.get(label, "#cccccc")
html += line[last_end:start]
should_anonymize = any(
label in entity_importance[level] for level in importance_levels
)
if should_anonymize:
            key = (ent["word"].strip().lower(), label)
if key not in anonymized_map:
count = label_counters.get(label, 0)
suffix = chr(ord("A") + count)
label_counters[label] = count + 1
anonymized_map[key] = suffix
suffix = anonymized_map[key]
display = f"{label_desc} {suffix}"
normalized_word = ent["word"].strip().lower()
display_key = f"{label_desc} {suffix} : {normalized_word}"
if allowed_keys and display_key not in allowed_keys:
display = ent["word"]
style = ""
css_class = "entity"
else:
style = f"background-color:{color}; font-weight:600;"
css_class = "entity marked"
else:
display = ent["word"]
style = ""
css_class = "entity"
        html += f'<span class="{css_class}" style="{style}" title="{tooltip}">{display}</span>'
last_end = end
html += line[last_end:]
return html
# Main App
def main():
setup_page()
entity_importance, entity_labels = get_constants()
ENTITY_COLORS = generate_fixed_colors(list(entity_labels.keys()))
if "manual_phrases" not in st.session_state:
st.session_state.manual_phrases = []
st.markdown("#### Juristische Anonymisierung")
uploaded_file = st.file_uploader(
"Bitte laden Sie eine .txt-Datei hoch:", type="txt"
)
importance_display_to_key = {"Hoch": "High", "Mittel": "Mid", "Niedrig": "Low"}
selected_importance_display = st.multiselect(
"Wähle Wichtigkeitsstufen zur Anonymisierung:",
options=list(importance_display_to_key.keys()),
default=["Hoch"],
)
importance_levels = [
importance_display_to_key[i] for i in selected_importance_display
]
with st.expander("Übersicht: Entitätstypen nach Wichtigkeit", expanded=False):
for level in ["High", "Mid", "Low"]:
label = {"High": "Hoch", "Mid": "Mittel", "Low": "Niedrig"}[level]
ent_list = [entity_labels[k] for k in entity_importance[level]]
st.markdown(f"**{label}**: {', '.join(ent_list)}")
threshold = st.slider("Schwellenwert für das Modellvertrauen:", 0.0, 1.0, 0.5, 0.01)
st.markdown("---")
if uploaded_file:
raw_bytes = uploaded_file.read()
encoding = detect(raw_bytes)["encoding"]
if encoding is None:
st.error("Zeichenkodierung konnte nicht erkannt werden.")
return
text = raw_bytes.decode(encoding)
with st.spinner("Modell wird einmalig auf die Datei angewendet..."):
merged_all_lines = ner_merge_lines(text)
# Manual phrases to RED
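        # Matched case-insensitively per line; matches that overlap model-detected
        # entities are skipped and collected in overlap_warnings.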
manual_phrases = st.session_state.manual_phrases
overlap_warnings = set()
for idx, (line, merged) in enumerate(merged_all_lines):
for phrase in manual_phrases:
for match in re.finditer(re.escape(phrase), line.lower()):
start, end = match.start(), match.end()
if any(start < e["end"] and end > e["start"] for e in merged):
overlap_warnings.add(phrase)
continue
merged.append(
{
"start": start,
"end": end,
"word": line[start:end],
"entity": "B-RED",
"score": 1.0,
"index": 9999,
}
)
merged_all_lines[idx] = (line, sorted(merged, key=lambda x: x["start"]))
# Grouping layer for the sidebar
groups = defaultdict(
lambda: {"variants": set(), "displays": set(), "rep": None}
)
for _, merged in merged_all_lines:
for ent in merged:
label = ent["entity"].split("-")[-1]
if any(label in entity_importance[lvl] for lvl in importance_levels):
variant_norm = ent["word"].strip().lower()
canon_key, canon_label, display_key = canonical_key(
ent["word"], label
)
g = groups[(canon_key, canon_label)]
g["variants"].add(variant_norm)
g["displays"].add(display_key)
# Suffix per canonical group
label_counters_for_groups = {}
for (canon_text, label), data in groups.items():
count = label_counters_for_groups.get(label, 0)
suffix = chr(ord("A") + count)
label_counters_for_groups[label] = count + 1
data["suffix"] = suffix
for key, data in groups.items():
if data["displays"]:
data["rep"] = max(data["displays"], key=len)
else:
data["rep"] = ""
anonymized_map = {}
for (canon_text, label), data in groups.items():
suffix = data["suffix"]
for v in data["variants"]:
anonymized_map[(v, label)] = suffix
entity_labels_map = entity_labels
display_to_variants = {}
groups_by_label_desc = defaultdict(list)
all_display_keys = set()
for (canon_text, label), data in groups.items():
label_desc = entity_labels_map.get(label, label)
suffix = data["suffix"]
shown = f"{label_desc} {suffix} : {data['rep']}"
groups_by_label_desc[label_desc].append(shown)
display_keys = [f"{label_desc} {suffix} : {v}" for v in data["variants"]]
display_to_variants[shown] = display_keys
all_display_keys.update(display_keys)
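        # Fixed ordering of entity types for the sidebar listing.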
label_order = [
"RS",
"GS",
"PER",
"AN",
"GRT",
"VO",
"VS",
"VT",
"EUN",
"LIT",
"UN",
"INN",
"ORG",
"MRK",
"RR",
"LD",
"LDS",
"ST",
"STR",
"RED",
]
label_order_desc = [entity_labels_map.get(x, x) for x in label_order]
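        # Sidebar: add manual redaction phrases and toggle which entity groups stay anonymized.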
with st.sidebar:
st.markdown("### Neue Phrase schwärzen:")
if "manual_phrases" not in st.session_state:
st.session_state.manual_phrases = []
with st.form("manual_add_form"):
new_phrase = st.text_input("Neue Phrase:")
submitted = st.form_submit_button("Hinzufügen")
with st.sidebar.expander("Hinweise zu manuellen Phrasen", expanded=False):
st.markdown("**Noch in Entwicklung**")
st.markdown(
"_Manuelle Schwärzungen können fehlschlagen, wenn sich die Phrase "
"mit bereits erkannten Entitäten überschneidet oder über mehrere "
"Zeilen erstreckt._"
)
if submitted and new_phrase.strip():
cleaned = new_phrase.strip().lower()
if cleaned not in st.session_state.manual_phrases:
st.session_state.manual_phrases.append(cleaned)
st.rerun()
st.markdown("---")
st.markdown("### Anonymisierte Entitäten verwalten:")
selected_canon = []
for lab_desc in label_order_desc:
items = groups_by_label_desc.get(lab_desc, [])
if not items:
continue
st.markdown(f"**{lab_desc}**")
for shown in sorted(items, key=str.lower):
checked = st.checkbox(shown, value=True, key=f"chk::{shown}")
if checked:
selected_canon.append(shown)
if not selected_canon and groups_by_label_desc:
selected_canon = [
x for items in groups_by_label_desc.values() for x in items
]
allowed_keys = set()
for shown in selected_canon:
allowed_keys.update(display_to_variants.get(shown, []))
if not allowed_keys and all_display_keys:
allowed_keys = set(all_display_keys)
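        # Render the document line by line with highlighting and anonymization applied.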
label_counters_runtime = {}
anonymized_lines = []
for line, merged in merged_all_lines:
if not line.strip():
st.markdown("
", unsafe_allow_html=True)
anonymized_lines.append("")
continue
html_line = highlight_entities(
line,
merged,
importance_levels,
threshold,
label_counters_runtime,
anonymized_map,
allowed_keys,
entity_labels,
entity_importance,
ENTITY_COLORS,
)
st.markdown(
                # The original wrapper markup was stripped; a plain <div> is assumed here.
                f'<div>{html_line}</div>',