import copy
import re

import regex
import torch
import torch.nn.functional as F
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

MAX_TOKENS_IOB_SENT = 256
OVERLAPPING_LEN = 0


def split_sentence_with_indices(text):
    """Split text into word-level regex matches (numbers, words, brackets,
    punctuation), keeping character offsets via the returned match objects."""
    pattern = r'''
        (?:
            \p{N}+[.,]?\p{N}*\s*[%$€]?      # numbers, optionally with a decimal part and a unit symbol
        )
        |
        \p{L}+(?:-\p{L}+)*                  # words, including hyphenated compounds
        |
        [()\[\]{}]                          # brackets
        |
        [^\p{L}\p{N}\s]                     # any other single non-space symbol
    '''
    return list(regex.finditer(pattern, text, flags=regex.VERBOSE))


class PredictionNER:
    """Lazy-loading token-classification predictor: chunks long texts, tags
    words with IOB labels, and aggregates them into character-offset entities."""

    def __init__(self, model_checkpoint, revision) -> None:
        self.model_checkpoint = model_checkpoint
        self.revision = revision
        self.model = None
        self.tokenizer = None
        self.pipe = None
        self.text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            encoding_name='o200k_base',
            separators=["\n\n\n", "\n\n", "\n", " .", " !", " ?", " ،", " ,", " ", ""],
            keep_separator=True,
            chunk_size=MAX_TOKENS_IOB_SENT,
            chunk_overlap=OVERLAPPING_LEN,
        )

    def _load(self):
        """Load model and tokenizer only when first used."""
        if self.model is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_checkpoint,
                revision=self.revision
            )
            self.model = AutoModelForTokenClassification.from_pretrained(
                self.model_checkpoint,
                revision=self.revision
            )
            self.pipe = pipeline(
                "token-classification",
                model=self.model,
                tokenizer=self.tokenizer,
                aggregation_strategy="average"
            )

    def split_text_with_indices(self, text):
        """Split text into chunks and yield (chunk, start, end) character offsets,
        resolving duplicate chunks to their next occurrence in the source text."""
        raw_chunks = self.text_splitter.split_text(text)
        used_indices = set()
        for chunk_text in raw_chunks:
            start_index = text.find(chunk_text)
            while start_index in used_indices:
                start_index = text.find(chunk_text, start_index + 1)
            used_indices.add(start_index)
            end_index = start_index + len(chunk_text)
            yield chunk_text, start_index, end_index

    def predict_text(self, text: str, o_confidence_threshold: float = 0.70):
        """Tag each word with a label, its confidence, and character offsets.
        Low-confidence "O" predictions fall back to the best non-"O" label."""
        self._load()  # Lazy load here
        text_matches = split_sentence_with_indices(text)
        text_words = [m.group().strip() for m in text_matches if m.group().strip()]
        if not text_words:
            return []

        inputs = self.tokenizer(
            text_words,
            return_tensors="pt",
            is_split_into_words=True,
            truncation=False
        )
        word_ids = inputs.word_ids()

        with torch.no_grad():
            logits = self.model(**inputs).logits
        probs = F.softmax(logits, dim=-1)
        predictions = torch.argmax(logits, dim=2)[0]

        results = []
        seen = set()
        non_empty_matches = [m for m in text_matches if m.group().strip()]
        id2label = self.model.config.id2label

        for i, word_idx in enumerate(word_ids):
            # Keep only the first sub-token of each word.
            if word_idx is None or word_idx in seen:
                continue
            seen.add(word_idx)

            word = text_words[word_idx]
            start = non_empty_matches[word_idx].start()
            end = non_empty_matches[word_idx].end()
            tag_id = predictions[i].item()
            tag = id2label[tag_id]
            score = probs[0, i, tag_id].item()

            # If "O" is predicted with low confidence, use the most probable
            # non-"O" label instead.
            if tag == "O" and score < o_confidence_threshold:
                sorted_probs = torch.argsort(probs[0, i], descending=True)
                for alt_id in sorted_probs:
                    alt_tag = id2label[alt_id.item()]
                    if alt_tag != "O":
                        tag_id = alt_id.item()
                        tag = alt_tag
                        score = probs[0, i, tag_id].item()
                        break

            results.append({
                'word': word,
                'tag': tag,
                'start': start,
                'end': end,
                'score': score
            })
        return results

    def aggregate_entities(self, tagged_tokens, original_text, confidence_threshold=0.3):
        """Merge word-level IOB tags into entity spans, repairing simple tag
        inconsistencies and dropping low-confidence or punctuation-only spans."""

        def is_special_char(text):
            return bool(re.fullmatch(r"\W+", text.strip()))

        def finalize_entity(entity):
            if all(s >= confidence_threshold for s in entity["scores"]):
                entity_text = original_text[entity["start"]:entity["end"]]
                if not is_special_char(entity_text):
                    entity["text"] = entity_text
                    entity["score"] = sum(entity["scores"]) / len(entity["scores"])
                    del entity["scores"]
                    return entity
            return None

        corrected_tokens = copy.deepcopy(tagged_tokens)

        # Fill a single "O" gap inside an entity: B-X, O, I-X -> B-X, I-X, I-X.
        for i in range(1, len(corrected_tokens) - 1):
            prev_tag = corrected_tokens[i - 1]["tag"]
            curr_tag = corrected_tokens[i]["tag"]
            next_tag = corrected_tokens[i + 1]["tag"]
            if curr_tag == "O" and prev_tag.startswith("B-") and next_tag.startswith("I-"):
                prev_type = prev_tag[2:]
                next_type = next_tag[2:]
                if prev_type == next_type:
                    corrected_tokens[i]["tag"] = "I-" + prev_type

        # Promote an "I-" tag that starts a new entity type to "B-".
        last_tag_type = None
        for i in range(len(corrected_tokens)):
            tag = corrected_tokens[i]["tag"]
            if tag.startswith("I-"):
                tag_type = tag[2:]
                if last_tag_type != tag_type:
                    corrected_tokens[i]["tag"] = "B-" + tag_type
                last_tag_type = tag_type
            elif tag.startswith("B-"):
                last_tag_type = tag[2:]
            else:
                last_tag_type = None

        # Group consecutive B-/I- tokens of the same type into entities.
        entities = []
        current_entity = None
        for item in corrected_tokens:
            tag = item["tag"]
            start = item["start"]
            end = item["end"]
            score = item["score"]
            if tag.startswith("B-"):
                if current_entity:
                    finalized = finalize_entity(current_entity)
                    if finalized:
                        entities.append(finalized)
                current_entity = {
                    "start": start,
                    "end": end,
                    "tag": tag[2:],
                    "scores": [score]
                }
            elif tag.startswith("I-") and current_entity and current_entity["tag"] == tag[2:]:
                current_entity["end"] = end
                current_entity["scores"].append(score)
            else:
                if current_entity:
                    finalized = finalize_entity(current_entity)
                    if finalized:
                        entities.append(finalized)
                current_entity = None

        if current_entity:
            finalized = finalize_entity(current_entity)
            if finalized:
                entities.append(finalized)
        return entities

    def do_prediction(self, text, confidence_threshold=0.6) -> list:
        """Run NER over arbitrarily long text: chunk, predict, aggregate, and
        shift entity offsets back into the coordinates of the full text."""
        final_prediction = []
        for sub_text, sub_text_start, sub_text_end in self.split_text_with_indices(text):
            tokens = self.predict_text(text=sub_text)
            predictions = self.aggregate_entities(tokens, sub_text, confidence_threshold=confidence_threshold)
            for pred in predictions:
                pred["start"] += sub_text_start
                pred["end"] += sub_text_start
                pred['entity'] = pred.pop('tag')
                final_prediction.append(pred)
        return final_prediction
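# --- Minimal usage sketch (not part of the original module) ---
# The checkpoint name and revision below are placeholders, not values from the
# source; substitute any token-classification checkpoint whose labels follow
# the IOB scheme (B-*/I-*/O).
if __name__ == "__main__":
    ner = PredictionNER(
        model_checkpoint="your-org/your-iob-ner-model",  # placeholder checkpoint
        revision="main",  # placeholder revision
    )
    sample = "Example sentence mentioning an organization worth 1,500 $."
    for entity in ner.do_prediction(sample, confidence_threshold=0.6):
        # Each item carries: text, entity, start, end, score
        print(entity)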