import re
import string
from collections import defaultdict

import seaborn as sns
import streamlit as st
from charset_normalizer import detect
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    logging,
    pipeline,
)


# Setup
def setup_page():
    st.set_page_config(
        page_title="Juristische Anonymisierung",
        page_icon="⚖️",
        layout="wide",
    )
    logging.set_verbosity(logging.ERROR)
    st.markdown(
        """
        """,
        unsafe_allow_html=True,
    )


def get_constants():
    entity_importance = {
        "High": ["PER", "UN", "INN", "MRK", "RED"],
        "Mid": ["RR", "AN", "GRT", "GS", "VO", "RS", "EUN", "LIT", "VS", "VT"],
        "Low": ["LD", "ST", "STR", "LDS", "ORG"],
    }
    entity_labels = {
        "AN": "Rechtsbeistand",
        "EUN": "EUNorm",
        "GRT": "Gericht",
        "GS": "Norm",
        "INN": "Institution",
        "LD": "Land",
        "LDS": "Bezirk",
        "LIT": "Schrifttum",
        "MRK": "Marke",
        "ORG": "Organisation",
        "PER": "Person",
        "RR": "RichterIn",
        "RS": "Entscheidung",
        "ST": "Stadt",
        "STR": "Strasse",
        "UN": "Unternehmen",
        "VO": "Verordnung",
        "VS": "Richtlinie",
        "VT": "Vertrag",
        "RED": "Schwärzung",
    }
    return entity_importance, entity_labels


def generate_fixed_colors(keys, alpha=0.25):
    base_colors = sns.color_palette("tab20", len(keys))
    return {
        key: f"rgba({int(r*255)}, {int(g*255)}, {int(b*255)}, {alpha})"
        for key, (r, g, b) in zip(keys, base_colors)
    }


@st.cache_resource
def load_ner_model():
    tokenizer = AutoTokenizer.from_pretrained("harshildarji/JuraNER")
    model = AutoModelForTokenClassification.from_pretrained("harshildarji/JuraNER")
    return pipeline("ner", model=model, tokenizer=tokenizer)


@st.cache_data(show_spinner=False)
def ner_merge_lines(text: str):
    ner = load_ner_model()
    merged_lines = []
    for line in text.splitlines():
        if not line.strip():
            merged_lines.append((line, []))
            continue
        tokens = ner(line)
        merged = merge_entities(tokens)
        merged_lines.append((line, merged))
    return merged_lines


def merge_entities(entities):
    if not entities:
        return []

    ents = sorted(entities, key=lambda e: e["index"])
    merged = [ents[0].copy()]
    merged[0]["score_sum"] = ents[0]["score"]
    merged[0]["count"] = 1

    for ent in ents[1:]:
        prev = merged[-1]
        if ent["index"] == prev["index"] + 1:
            tok = ent["word"]
            prev["word"] += tok[2:] if tok.startswith("##") else " " + tok
            prev["end"] = ent["end"]
            prev["index"] = ent["index"]
            prev["score_sum"] += ent["score"]
            prev["count"] += 1
        else:
            prev["score"] = prev["score_sum"] / prev["count"]
            del prev["score_sum"], prev["count"]
            new_ent = ent.copy()
            new_ent["score_sum"] = ent["score"]
            new_ent["count"] = 1
            merged.append(new_ent)

    if "score_sum" in merged[-1]:
        merged[-1]["score"] = merged[-1]["score_sum"] / merged[-1]["count"]
        del merged[-1]["score_sum"], merged[-1]["count"]

    final = []
    for ent in merged:
        w = ent["word"].strip()
        w = re.sub(r"\s*\.\s*", ".", w)
        w = re.sub(r"\s*,\s*", ", ", w)
        w = re.sub(r"\s*/\s*", "/", w)
        w = w.strip(string.whitespace + string.punctuation)
        if len(w) > 1 and re.search(r"\w", w):
            cleaned = ent.copy()
            cleaned["word"] = w
            final.append(cleaned)
    return final


def truncate(number, decimals=2):
    factor = 10**decimals
    return int(number * factor) / factor


# Canonical grouping
def canonical_key(text: str, label: str):
    s = text.casefold().strip()

    if label == "RS":
        m = re.search(r"(ecli:[a-z]{2}:[a-z0-9]+:\d{4}:[a-z0-9.\-]+)", s)
        if m:
            original = text[m.start() : m.end()]
            canon = m.group(1).replace(" ", "")
            return (canon, label, original)

        m = re.search(
            r"((?:[ivxlcdm]+|\d{1,3})\s*[a-zäöüß]{1,3}\s*\d{1,6}\s*/\s*\d{2,4})", s
        )
        if m:
            original = text[m.start() : m.end()].strip()
            canon = re.sub(r"\s+", "", m.group(1))
            return (canon, label, original)
        cleaned = re.sub(r"[^\w]+", "", s)
        return (cleaned, label, text.strip())

    cleaned_generic = re.sub(r"[^\w]+", " ", s)
    cleaned_generic = re.sub(r"\s+", " ", cleaned_generic).strip()
    return (cleaned_generic, label, text.strip())


# Rendering
def highlight_entities(
    line,
    merged_entities,
    importance_levels,
    threshold,
    label_counters,
    anonymized_map,
    allowed_keys,
    entity_labels,
    entity_importance,
    ENTITY_COLORS,
):
    html = ""
    last_end = 0
    for ent in merged_entities:
        if ent["score"] < threshold:
            continue
        start, end = ent["start"], ent["end"]
        label = ent["entity"].split("-")[-1]
        label_desc = entity_labels.get(label, label)
        truncated_score = truncate(ent["score"], 2)
        tooltip = f"{label_desc} ({truncated_score:.2f})"
        color = ENTITY_COLORS.get(label, "#cccccc")

        html += line[last_end:start]

        should_anonymize = any(
            label in entity_importance[level] for level in importance_levels
        )

        if should_anonymize:
            key = (ent["word"].lower(), label)
            if key not in anonymized_map:
                count = label_counters.get(label, 0)
                suffix = chr(ord("A") + count)
                label_counters[label] = count + 1
                anonymized_map[key] = suffix
            suffix = anonymized_map[key]
            display = f"{label_desc} {suffix}"

            normalized_word = ent["word"].strip().lower()
            display_key = f"{label_desc} {suffix} : {normalized_word}"
            if allowed_keys and display_key not in allowed_keys:
                display = ent["word"]
                style = ""
                css_class = "entity"
            else:
                style = f"background-color:{color}; font-weight:600;"
                css_class = "entity marked"
        else:
            display = ent["word"]
            style = ""
            css_class = "entity"

        # Wrap the (possibly replaced) entity text in a span carrying its CSS
        # class, inline highlight style, and tooltip.
        html += f'<span class="{css_class}" style="{style}" title="{tooltip}">{display}</span>'
        last_end = end

    html += line[last_end:]
    return html


# Main App
def main():
    setup_page()
    entity_importance, entity_labels = get_constants()
    ENTITY_COLORS = generate_fixed_colors(list(entity_labels.keys()))

    if "manual_phrases" not in st.session_state:
        st.session_state.manual_phrases = []

    st.markdown("#### Juristische Anonymisierung")

    uploaded_file = st.file_uploader(
        "Bitte laden Sie eine .txt-Datei hoch:", type="txt"
    )

    importance_display_to_key = {"Hoch": "High", "Mittel": "Mid", "Niedrig": "Low"}
    selected_importance_display = st.multiselect(
        "Wähle Wichtigkeitsstufen zur Anonymisierung:",
        options=list(importance_display_to_key.keys()),
        default=["Hoch"],
    )
    importance_levels = [
        importance_display_to_key[i] for i in selected_importance_display
    ]

    with st.expander("Übersicht: Entitätstypen nach Wichtigkeit", expanded=False):
        for level in ["High", "Mid", "Low"]:
            label = {"High": "Hoch", "Mid": "Mittel", "Low": "Niedrig"}[level]
            ent_list = [entity_labels[k] for k in entity_importance[level]]
            st.markdown(f"**{label}**: {', '.join(ent_list)}")

    threshold = st.slider(
        "Schwellenwert für das Modellvertrauen:", 0.0, 1.0, 0.5, 0.01
    )
    st.markdown("---")

    if uploaded_file:
        raw_bytes = uploaded_file.read()
        encoding = detect(raw_bytes)["encoding"]
        if encoding is None:
            st.error("Zeichenkodierung konnte nicht erkannt werden.")
            return
        text = raw_bytes.decode(encoding)

        with st.spinner("Modell wird einmalig auf die Datei angewendet..."):
            merged_all_lines = ner_merge_lines(text)

        # Manual phrases to RED
        manual_phrases = st.session_state.manual_phrases
        overlap_warnings = set()
        for idx, (line, merged) in enumerate(merged_all_lines):
            for phrase in manual_phrases:
                for match in re.finditer(re.escape(phrase), line.lower()):
                    start, end = match.start(), match.end()
                    if any(start < e["end"] and end > e["start"] for e in merged):
                        overlap_warnings.add(phrase)
                        continue
                    merged.append(
                        {
                            "start": start,
                            "end": end,
                            "word": line[start:end],
                            "entity": "B-RED",
                            "score": 1.0,
                            "index": 9999,
                        }
                    )
            merged_all_lines[idx] = (line, sorted(merged, key=lambda x: x["start"]))

        # Grouping layer for the sidebar
        groups = defaultdict(
            lambda: {"variants": set(), "displays": set(), "rep": None}
        )
        for _, merged in merged_all_lines:
            for ent in merged:
                label = ent["entity"].split("-")[-1]
                if any(label in entity_importance[lvl] for lvl in importance_levels):
                    variant_norm = ent["word"].strip().lower()
                    canon_key, canon_label, display_key = canonical_key(
                        ent["word"], label
                    )
                    g = groups[(canon_key, canon_label)]
                    g["variants"].add(variant_norm)
                    g["displays"].add(display_key)

        # Suffix per canonical group
        label_counters_for_groups = {}
        for (canon_text, label), data in groups.items():
            count = label_counters_for_groups.get(label, 0)
            suffix = chr(ord("A") + count)
            label_counters_for_groups[label] = count + 1
            data["suffix"] = suffix

        for key, data in groups.items():
            if data["displays"]:
                data["rep"] = max(data["displays"], key=len)
            else:
                data["rep"] = ""

        anonymized_map = {}
        for (canon_text, label), data in groups.items():
            suffix = data["suffix"]
            for v in data["variants"]:
                anonymized_map[(v, label)] = suffix

        entity_labels_map = entity_labels
        display_to_variants = {}
        groups_by_label_desc = defaultdict(list)
        all_display_keys = set()
        for (canon_text, label), data in groups.items():
            label_desc = entity_labels_map.get(label, label)
            suffix = data["suffix"]
            shown = f"{label_desc} {suffix} : {data['rep']}"
            groups_by_label_desc[label_desc].append(shown)
            display_keys = [f"{label_desc} {suffix} : {v}" for v in data["variants"]]
            display_to_variants[shown] = display_keys
            all_display_keys.update(display_keys)

        label_order = [
            "RS",
            "GS",
            "PER",
            "AN",
            "GRT",
            "VO",
            "VS",
            "VT",
            "EUN",
            "LIT",
            "UN",
            "INN",
            "ORG",
            "MRK",
            "RR",
            "LD",
            "LDS",
            "ST",
            "STR",
            "RED",
        ]
        label_order_desc = [entity_labels_map.get(x, x) for x in label_order]

        # Sidebar: manual redaction phrases and per-entity toggles
        with st.sidebar:
            st.markdown("### Neue Phrase schwärzen:")
            if "manual_phrases" not in st.session_state:
                st.session_state.manual_phrases = []

            with st.form("manual_add_form"):
                new_phrase = st.text_input("Neue Phrase:")
                submitted = st.form_submit_button("Hinzufügen")

            with st.sidebar.expander("Hinweise zu manuellen Phrasen", expanded=False):
                st.markdown("**Noch in Entwicklung**")
                st.markdown(
                    "_Manuelle Schwärzungen können fehlschlagen, wenn sich die Phrase "
                    "mit bereits erkannten Entitäten überschneidet oder über mehrere "
                    "Zeilen erstreckt._"
                )

            if submitted and new_phrase.strip():
                cleaned = new_phrase.strip().lower()
                if cleaned not in st.session_state.manual_phrases:
                    st.session_state.manual_phrases.append(cleaned)
                    st.rerun()

            st.markdown("---")
            st.markdown("### Anonymisierte Entitäten verwalten:")

            selected_canon = []
            for lab_desc in label_order_desc:
                items = groups_by_label_desc.get(lab_desc, [])
                if not items:
                    continue
                st.markdown(f"**{lab_desc}**")
                for shown in sorted(items, key=str.lower):
                    checked = st.checkbox(shown, value=True, key=f"chk::{shown}")
                    if checked:
                        selected_canon.append(shown)

        if not selected_canon and groups_by_label_desc:
            selected_canon = [
                x for items in groups_by_label_desc.values() for x in items
            ]

        allowed_keys = set()
        for shown in selected_canon:
            allowed_keys.update(display_to_variants.get(shown, []))
        if not allowed_keys and all_display_keys:
            allowed_keys = set(all_display_keys)

        label_counters_runtime = {}
        anonymized_lines = []

        # Render each line with highlighted/anonymized entities and collect the
        # plain-text version for the download.
        for line, merged in merged_all_lines:
            if not line.strip():
                st.markdown("<br>", unsafe_allow_html=True)
                anonymized_lines.append("")
                continue

            html_line = highlight_entities(
                line,
                merged,
                importance_levels,
                threshold,
                label_counters_runtime,
                anonymized_map,
                allowed_keys,
                entity_labels,
                entity_importance,
                ENTITY_COLORS,
            )
            st.markdown(
                f'<div class="line">{html_line}</div>',
                unsafe_allow_html=True,
            )

            # Strip the injected markup to recover plain text for the download.
            text_only = re.sub(r"<[^>]+>", "", html_line)
            anonymized_lines.append(text_only.strip())

        st.markdown("---")
        st.download_button(
            label="Anonymisierten Text herunterladen",
            data="\n".join(anonymized_lines),
            file_name=f"anonymisiert_{uploaded_file.name}",
            mime="text/plain",
        )


if __name__ == "__main__":
    main()
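
# Usage (a minimal sketch; assumes this script is saved as app.py and that a
# PyTorch backend is installed for the transformers NER pipeline):
#
#   pip install streamlit transformers torch seaborn charset-normalizer
#   streamlit run app.py
#
# On first start, the harshildarji/JuraNER tokenizer and model are downloaded
# from the Hugging Face Hub; load_ner_model() caches the resulting pipeline via
# st.cache_resource so subsequent reruns reuse it.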