|
|
import io, os, re, json |
|
|
from typing import List, Tuple, Dict |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from PIL import Image, ImageOps, ImageFilter |
|
|
|
|
|
import streamlit as st |
|
|
import torch |
|
|
import torchvision.transforms as T |
|
|
|
|
|
|
|
|
import pytesseract |
|
|
from pytesseract import Output |
|
|
|
|
|
|
|
|
from pdf2image import convert_from_bytes |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import models as pdrt_models |
|
|
|
|
|
st.set_page_config(page_title="Invoice OCR (ViT recognizer + Tesseract detector)", layout="wide")

# ---- Sidebar: recognizer checkpoint / input geometry and detector settings ----------
sidebar = st.sidebar
sidebar.header("Model")

arch = sidebar.selectbox("Architecture", ["Swin_CTC", "VED"], index=0)
ckpt_path = sidebar.text_input(
    "Checkpoint path (inside Space)",
    value="checkpoints/pdrt_weights.pth",
)
# Characters in class order; for CTC models the blank class is NOT part of this string.
alphabet = sidebar.text_input(
    "Alphabet (ordered classes, exclude CTC blank)",
    value="0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-_/.,:;()[]{}#+*&%$@!?\"' ",
)
# Every word crop is resized to exactly (img_h, img_w) before recognition.
img_h = sidebar.number_input("Recognizer input height", min_value=64, max_value=256, value=128, step=8)
img_w = sidebar.number_input("Recognizer input width", min_value=128, max_value=2048, value=512, step=16)
det_lang = sidebar.text_input("Tesseract lang(s) for detection only", value="eng")
show_boxes = sidebar.checkbox("Show word boxes", value=False)

# Run the recognizer on the GPU whenever torch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
|
|
|
|
|
|
|
|
def load_pages(file_bytes: bytes, name: str, dpi: int = 300) -> List[Image.Image]:
    """Decode an uploaded invoice into a list of RGB page images.

    PDFs are rasterised page by page with pdf2image at ``dpi``. A file counts as a PDF
    when its name ends in ``.pdf`` or when a ``%PDF-`` header appears in its first
    1 KiB. Any other file is opened with Pillow as a single page. The image's EXIF
    orientation is applied first, so phone photos are upright before detection/OCR.

    Args:
        file_bytes: raw contents of the uploaded file.
        name: original file name (may be empty/None); only used to sniff ``.pdf``.
        dpi: rasterisation resolution for PDF pages (default 300, as before).

    Returns:
        One ``PIL.Image`` per page, all in ``"RGB"`` mode.
    """
    lowered = (name or "").lower()
    looks_like_pdf = lowered.endswith(".pdf") or b"%PDF-" in file_bytes[:1024]
    if looks_like_pdf:
        pages = convert_from_bytes(file_bytes, dpi=dpi)
        # Normalise the mode without copying pages that are already RGB (300 dpi pages are large).
        return [p if p.mode == "RGB" else p.convert("RGB") for p in pages]
    with Image.open(io.BytesIO(file_bytes)) as im:
        # exif_transpose returns an upright copy; convert() forces the pixels to load
        # before the underlying file handle is closed.
        return [ImageOps.exif_transpose(im).convert("RGB")]
|
|
|
|
|
def preprocess_for_detection(img: Image.Image) -> Image.Image:
    """Prepare a page for Tesseract word detection.

    The page is converted to grayscale, contrast-stretched with ``autocontrast`` and
    then lightly sharpened with an unsharp mask (radius 1, 150 %, threshold 3).
    The input image is not modified; a new image is returned.
    """
    sharpen = ImageFilter.UnsharpMask(radius=1, percent=150, threshold=3)
    stretched = ImageOps.autocontrast(ImageOps.grayscale(img))
    return stretched.filter(sharpen)
|
|
|
|
|
def _unwrap_checkpoint(obj, model_keys: set):
    """Reduce a loaded checkpoint to a ``{name: tensor}`` mapping that matches ``model_keys``.

    Only adjusts layouts that would otherwise match none of the model's keys:
    - a pickled ``nn.Module`` (its ``state_dict()`` is used);
    - a ``{"state_dict" | "model_state_dict" | "model": ...}`` wrapper;
    - ``nn.DataParallel``-style ``"module."`` key prefixes.
    """
    if isinstance(obj, torch.nn.Module):
        obj = obj.state_dict()
    if not isinstance(obj, dict):
        return obj
    if model_keys.isdisjoint(obj):
        for key in ("state_dict", "model_state_dict", "model"):
            inner = obj.get(key)
            if isinstance(inner, torch.nn.Module):
                inner = inner.state_dict()
            if isinstance(inner, dict):
                obj = inner
                break
    if model_keys.isdisjoint(obj):
        prefix = "module."
        stripped = {
            k[len(prefix):]: v
            for k, v in obj.items()
            if isinstance(k, str) and k.startswith(prefix)
        }
        if not model_keys.isdisjoint(stripped):
            obj = stripped
    return obj


@st.cache_resource
def load_pdrt(arch_name: str, ckpt: str, num_classes: int):
    """Build a PDRT recognizer and load its weights from ``ckpt``.

    The model is returned in eval mode and moved to ``device``. Streamlit caches the
    result per argument combination (``st.cache_resource``), so the model is built
    once rather than on every script rerun.

    Args:
        arch_name: ``"Swin_CTC"`` or ``"VED"`` — class names in the local ``models`` module.
        ckpt: path to a file readable by ``torch.load``.
        num_classes: number of output classes passed to the model constructor.

    Raises:
        ValueError: for an unknown ``arch_name``.
        RuntimeError: if none of the model's parameters/buffers were found in the checkpoint.
            With ``strict=False`` this would otherwise silently leave a random model.
    """
    if arch_name == "Swin_CTC":
        model = pdrt_models.Swin_CTC(num_classes=num_classes)
    elif arch_name == "VED":
        model = pdrt_models.VED(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model: {arch_name!r}")

    # NOTE(review): torch.load unpickles arbitrary objects, and this path comes from a
    # sidebar text box — only point it at trusted checkpoint files.
    raw = torch.load(ckpt, map_location="cpu")
    model_keys = set(model.state_dict().keys())
    state = _unwrap_checkpoint(raw, model_keys)

    # strict=False tolerates partial matches, but a checkpoint that matches nothing
    # is almost certainly the wrong file/architecture, so fail loudly.
    result = model.load_state_dict(state, strict=False)
    if model_keys and len(result.missing_keys) == len(model_keys):
        raise RuntimeError(
            f"No weights in {ckpt!r} match the {arch_name} model "
            f"(e.g. expected key {sorted(model_keys)[0]!r})."
        )

    model.eval().to(device)
    return model
|
|
|
|
|
def build_transform(img_h: int, img_w: int):
    """Return the torchvision pipeline that turns a word crop into recognizer input.

    Each crop is converted to grayscale and replicated to 3 channels, then resized to
    exactly ``(img_h, img_w)``. It is then converted to a float tensor and mapped from
    [0, 1] to [-1, 1] per channel by the 0.5/0.5 normalisation.
    """
    half = [0.5, 0.5, 0.5]
    steps = [
        T.Grayscale(num_output_channels=3),
        T.Resize((img_h, img_w)),
        T.ToTensor(),
        T.Normalize(mean=half, std=list(half)),
    ]
    return T.Compose(steps)
|
|
|
|
|
def greedy_ctc_decode(logits: torch.Tensor, alphabet: str, blank_id=None) -> str:
    """Greedy (best-path) CTC decoding of a single sequence.

    Takes the arg-max class at every time step, collapses consecutive repeats and
    drops the blank class. The remaining class ids are mapped onto ``alphabet``.

    Args:
        logits: class scores for ONE sequence, shaped ``(T, C)``, ``(1, T, C)``
            (batch-first) or ``(T, 1, C)`` (time-first). For a 3-D tensor with both
            leading dims equal to 1, it is treated as batch-first.
        alphabet: characters for the non-blank classes, in class order.
        blank_id: index of the CTC blank class. Defaults to ``len(alphabet)``, i.e.
            the blank is the last class. Class ids above the blank are shifted down
            by one before indexing ``alphabet``.

    Returns:
        The decoded string. Class ids with no corresponding alphabet character are dropped.

    Raises:
        ValueError: if ``logits`` is not 2-D/3-D, or holds more than one sequence.
    """
    if blank_id is None:
        blank_id = len(alphabet)

    # Only squeeze a singleton batch axis of a 3-D tensor. Squeezing a 2-D (T, C)
    # tensor with T == 1 would leave a 1-D vector and break decoding.
    if logits.dim() == 3:
        if logits.shape[0] == 1:
            logits = logits[0]
        elif logits.shape[1] == 1:
            logits = logits[:, 0, :]
        else:
            raise ValueError(
                f"expected a single sequence, got logits of shape {tuple(logits.shape)}"
            )
    if logits.dim() != 2:
        raise ValueError(f"expected (T, C) logits, got shape {tuple(logits.shape)}")

    # argmax of the raw scores equals argmax of their softmax, so no softmax needed.
    ids = logits.argmax(-1).tolist()

    chars = []
    prev = None
    for i in ids:
        if i != prev and i != blank_id:
            j = i if i < blank_id else i - 1
            if 0 <= j < len(alphabet):
                chars.append(alphabet[j])
        prev = i
    return "".join(chars)
|
|
|
|
|
def recognize_word_crops(model, crops: List[Image.Image], tfm, arch_name: str, alphabet: str) -> List[str]:
    """Run the recognizer on each word crop, one crop at a time, and return one string per crop.

    Every crop goes through ``tfm``, gets a batch axis, is moved to ``device`` and is
    fed to ``model`` under ``torch.no_grad()``.

    - ``"Swin_CTC"``: a singleton batch axis (first or second dim of a 3-D output) is
      removed, and the result is decoded with ``greedy_ctc_decode``.
    - otherwise (VED): a leading singleton axis is removed, the per-step arg-max ids are
      mapped onto ``alphabet`` (out-of-range ids become ""), and the text is stripped.
      NOTE(review): this assumes ``model(x)`` returns per-step class scores with no
      special tokens — confirm against the VED implementation.
    """
    def decode(out):
        if arch_name == "Swin_CTC":
            is_3d = out.dim() == 3
            if is_3d and out.shape[0] == 1:
                seq = out[0]
            elif is_3d and out.shape[1] == 1:
                seq = out[:, 0, :]
            else:
                seq = out
            return greedy_ctc_decode(seq, alphabet)

        if out.dim() == 3 and out.shape[0] == 1:
            out = out[0]
        size = len(alphabet)
        return "".join(alphabet[k] if k < size else "" for k in out.argmax(-1).tolist()).strip()

    results = []
    with torch.no_grad():
        for crop in crops:
            batch = tfm(crop).unsqueeze(0).to(device)
            results.append(decode(model(batch)))
    return results
|
|
|
|
|
def detect_words(img: Image.Image, lang="eng") -> pd.DataFrame:
    """Locate words with Tesseract and return their boxes as a DataFrame.

    Args:
        img: page image to analyse.
        lang: Tesseract language code(s), passed straight to ``image_to_data``.

    Returns:
        The ``pytesseract.image_to_data`` table keeping only rows that satisfy all of:
        - ``text`` is non-missing and not blank/whitespace-only;
        - ``conf`` > -1 (a conf of -1 presumably marks a layout entry, not a word —
          verify against your Tesseract version).
        The result is re-indexed 0..n-1 and gains ``x2`` / ``y2`` (right / bottom edge)
        columns.

        Whitespace-only words are removed here because ``crop_words`` skips them. The
        table must stay row-aligned with the crops and predictions built from it.
    """
    df = pytesseract.image_to_data(img, lang=lang, output_type=Output.DATAFRAME)
    text = df["text"]
    # Coerce defensively in case conf comes back as strings; unparsable values drop out.
    conf = pd.to_numeric(df["conf"], errors="coerce")
    keep = text.notna() & (text.astype(str).str.strip() != "") & (conf > -1)
    df = df.loc[keep].reset_index(drop=True)
    df["x2"] = df["left"] + df["width"]
    df["y2"] = df["top"] + df["height"]
    return df
|
|
|
|
|
def crop_words(img: Image.Image, df: pd.DataFrame) -> Tuple[List[Image.Image], List[Dict]]:
    """Cut one image per non-blank word row of ``df``.

    Args:
        img: page to crop from.
        df: word table with ``text``, ``left``, ``top``, ``x2`` and ``y2`` columns
            (as produced by ``detect_words``).

    Returns:
        ``(crops, metas)``: two parallel lists. ``metas[k]`` holds the ``"box"``
        ``(left, top, right, bottom)`` used for ``crops[k]``, and the ``"index"`` label
        of the ``df`` row it came from. Rows whose text is blank are skipped, so the
        index lets callers re-align results with ``df``.
    """
    crops, metas = [], []
    for idx, row in df.iterrows():
        if str(row["text"]).strip() == "":
            continue
        box = (int(row["left"]), int(row["top"]), int(row["x2"]), int(row["y2"]))
        crops.append(img.crop(box))
        metas.append({"box": box, "index": idx})
    return crops, metas
|
|
|
|
|
|
|
|
# --- Key-field regexes ---------------------------------------------------------------
# parse_fields applies these with re.I, so letter classes such as [A-Z] also match
# lower case.

# Optional currency marker in front of an amount (ISO code or symbol).
CURRENCY = r"(?P<curr>USD|CAD|EUR|GBP|\$|C\$|€|£)?"

# Amount: comma-grouped thousands ("1,234") or a plain digit run ("1234"), each with
# optional cents. The plain-run branch is needed; otherwise "1234.56" would be cut
# down to "123".
MONEY = rf"{CURRENCY}\s?(?P<amt>(?:\d{{1,3}}(?:,\d{{3}})+|\d+)(?:\.\d{{2}})?)"

DATE = r"(?P<date>(?:\d{4}[-/]\d{1,2}[-/]\d{1,2})|(?:\d{1,2}[-/]\d{1,2}[-/]\d{2,4})|(?:[A-Za-z]{3,9}\s+\d{1,2},\s*\d{2,4}))"

# Identifiers must contain at least one digit (the lookahead). This stops a heading
# such as "INVOICE\nBill To", or the label "Invoice Date", from being read as the number.
# \b stops "po" from matching inside words like "Deposit".
INV_PAT = r"(?:\binvoice\s*(?:no\.?|#|number)?\s*[:\-]?\s*(?P<inv>(?=[A-Z\-_/]*\d)[A-Z0-9\-_/]{4,}))"

PO_PAT = r"(?:\bpo\s*(?:no\.?|#|number)?\s*[:\-]?\s*(?P<po>(?=[A-Z\-_/]*\d)[A-Z0-9\-_/]{3,}))"

# The lookbehinds keep "Sub Total" / "Sub-Total" from being read as the grand total.
# NOTE: with re.S the lazy ".*?" may cross lines. A "Total" column header can
# therefore still pick up the first line-item amount.
TOTAL_PAT = rf"(?:\b(?<!sub\s)(?<!sub-)(total(?:\s*amount)?|amount\s*due|grand\s*total)\b.*?{MONEY})"

SUBTOTAL_PAT = rf"(?:\bsub[\s\-]*total\b.*?{MONEY})"

TAX_PAT = rf"(?:\b(tax|gst|vat|hst)\b.*?{MONEY})"

# Currency symbols normalised to ISO codes in parse_fields' output.
CURRENCY_SYMBOLS = {"$": "USD", "C$": "CAD", "€": "EUR", "£": "GBP"}


def parse_fields(fulltext: str):
    """Extract invoice header fields from OCR text with the regexes above.

    Runs of spaces/tabs are collapsed and blank lines removed before matching. The
    first match of each pattern wins.
    - ``invoice_date``: a date following an "invoice date" label; otherwise the first
      date anywhere in the text.
    - Amounts: thousands separators are removed, e.g. "1,234.50" becomes "1234.50".
    - ``currency``: the marker from the last amount that had one (subtotal, tax, total
      in that order). Symbols are mapped to ISO codes.

    Returns:
        A dict with keys ``invoice_number``, ``invoice_date``, ``po_number``,
        ``subtotal``, ``tax``, ``total`` and ``currency``. Each value is a str, or
        None if not found.
    """
    t = re.sub(r"[ \t]+", " ", fulltext or "")
    t = re.sub(r"\n{2,}", "\n", t)
    out = {"invoice_number": None, "invoice_date": None, "po_number": None,
           "subtotal": None, "tax": None, "total": None, "currency": None}

    m = re.search(INV_PAT, t, re.I)
    out["invoice_number"] = m.group("inv") if m else None

    m = re.search(PO_PAT, t, re.I)
    out["po_number"] = m.group("po") if m else None

    m = re.search(rf"(invoice\s*date[:\-\s]*){DATE}", t, re.I) or re.search(DATE, t, re.I)
    out["invoice_date"] = m.group("date") if m else None

    # Later amounts override the currency seen earlier, matching subtotal -> tax -> total.
    for key, pattern in (("subtotal", SUBTOTAL_PAT), ("tax", TAX_PAT), ("total", TOTAL_PAT)):
        m = re.search(pattern, t, re.I | re.S)
        if m:
            out[key] = m.group("amt").replace(",", "")
            out["currency"] = m.group("curr") or out["currency"]

    out["currency"] = CURRENCY_SYMBOLS.get(out["currency"], out["currency"])
    return out
|
|
|
|
|
HEAD_CANDIDATES = ["description","item","qty","quantity","price","unit","rate","amount","total"]


def _canonical_column_name(word, idx: int) -> str:
    """Map one header word to a canonical line-item column name (fallback ``col_<idx>``)."""
    wl = str(word).lower()
    if "desc" in wl or wl in ("item", "description"):
        return "description"
    if wl in ("qty", "quantity"):
        return "quantity"
    if "unit" in wl or "rate" in wl or "price" in wl:
        return "unit_price"
    if "amount" in wl or "total" in wl:
        return "line_total"
    return f"col_{idx}"


def _unique_names(names: List[str]) -> List[str]:
    """Suffix repeated names as ``x``, ``x_2``, ``x_3`` … so every column name is unique."""
    counts: Dict[str, int] = {}
    unique = []
    for name in names:
        counts[name] = counts.get(name, 0) + 1
        unique.append(name if counts[name] == 1 else f"{name}_{counts[name]}")
    return unique


def items_from_wordgrid(df: pd.DataFrame) -> pd.DataFrame:
    """Heuristically rebuild the line-item table from an OCR word grid.

    Expects one row per word with ``text``, ``left``, ``top``, ``width``, ``height``,
    ``block_num``, ``par_num`` and ``line_num`` columns.

    Steps:
      1. Words are grouped into lines by (block, paragraph, line). The line containing
         the most ``HEAD_CANDIDATES`` substrings (at least two; the top-most on ties)
         is taken as the table header.
      2. Every word in the header's vertical band (±5 px) anchors a column at its left
         edge.
      3. Lines starting more than 4 px below the header become rows, stopping before
         the first sub-total/total/amount-due/balance word. Each word goes to the
         column whose anchor is closest to its left edge.
      4. Header words are mapped to ``description`` / ``quantity`` / ``unit_price`` /
         ``line_total`` (else ``col_<i>``). Repeated names get ``_2``, ``_3`` suffixes;
         e.g. "Unit Price" spans two header words. Previously these produced duplicate
         column names, which lose data in ``to_dict`` and break dataframe display.

    Returns:
        One row per item line with string cells, or an empty DataFrame when no header
        or no item rows are found.
    """
    if df is None or df.empty:
        return pd.DataFrame()

    df = df.copy()
    # Work on strings throughout; numeric or NaN OCR text would break .lower()/.str.
    df["text"] = df["text"].astype(str)
    line_keys = ["block_num", "par_num", "line_num"]

    lines = []
    for _, g in df.groupby(line_keys):
        text = " ".join(t for t in g["text"] if t.strip())
        if text.strip():
            lines.append({
                "text": text.lower(),
                "top": g["top"].min(),
                "bottom": (g["top"] + g["height"]).max(),
            })
    L = pd.DataFrame(lines)
    if L.empty:
        return pd.DataFrame()

    L["score"] = L["text"].apply(lambda s: sum(1 for h in HEAD_CANDIDATES if h in s))
    headers = L[L["score"] >= 2].sort_values(["score", "top"], ascending=[False, True])
    if headers.empty:
        return pd.DataFrame()
    H = headers.iloc[0]
    header_y = H["bottom"] + 4

    # Column anchors = left edges of every word in the header's vertical band.
    in_band = (df["top"] >= H["top"] - 5) & ((df["top"] + df["height"]) <= H["bottom"] + 5)
    header_band = df[in_band].sort_values("left")
    col_x = header_band["left"].tolist()
    if len(col_x) < 2:
        return pd.DataFrame()
    xs = np.array(col_x)

    # Item rows: below the header, above the first totals/balance line.
    below = df[df["top"] > header_y]
    totals_mask = below["text"].str.lower().str.contains(
        r"(?:sub\s*total|amount\s*due|total|grand\s*total|balance)", regex=True, na=False
    )
    if totals_mask.any():
        stop_y = below.loc[totals_mask, "top"].min()
        below = below[below["top"] < stop_y - 4]

    rows = []
    for _, g in below.groupby(line_keys):
        if g["text"].str.strip().eq("").all():
            continue
        buckets = [[] for _ in range(len(xs))]
        for _, w in g.sort_values("left").iterrows():
            buckets[int(np.abs(xs - w["left"]).argmin())].append(str(w["text"]))
        rows.append([" ".join(b).strip() for b in buckets])
    if not rows:
        return pd.DataFrame()

    df_rows = pd.DataFrame(rows).fillna("")
    header_words = header_band["text"].tolist()[:df_rows.shape[1]]
    df_rows.columns = _unique_names(
        [_canonical_column_name(w, i) for i, w in enumerate(header_words)]
    )

    # Drop rows whose cells are all blank.
    nonblank = df_rows.apply(lambda r: "".join(map(str, r.values)), axis=1).str.strip() != ""
    return df_rows[nonblank].reset_index(drop=True)
|
|
|
|
|
|
|
|
st.title("Invoice Extraction — ViT recognizer (dparres) + Tesseract detector")

up = st.file_uploader("Upload an invoice (PDF/JPG/PNG)", type=["pdf", "png", "jpg", "jpeg"])
if not up:
    st.info("Upload a scanned invoice to begin.")
    st.stop()

# getvalue() returns the whole upload regardless of the buffer's current read position.
try:
    pages = load_pages(up.getvalue(), up.name)
except Exception as err:  # pdf2image / Pillow raise many different error types
    st.error(f"Could not read '{up.name}': {err}")
    st.stop()
if not pages:
    st.error(f"No pages found in '{up.name}'.")
    st.stop()

# CTC models have one extra output class (the blank) beyond the alphabet.
num_classes = len(alphabet) + (1 if arch == "Swin_CTC" else 0)

# Explicit check rather than `assert`, which is stripped when Python runs with -O.
if not os.path.exists(ckpt_path):
    st.error(f"Checkpoint not found: {ckpt_path}")
    st.stop()

model = load_pdrt(arch, ckpt_path, num_classes)
tfm = build_transform(img_h, img_w)
|
|
|
|
|
# 0-based index of the page being processed; the picker only appears for multi-page files.
page_idx = st.number_input("Page", min_value=1, max_value=len(pages), value=1) - 1 if len(pages) > 1 else 0
img = pages[page_idx]

col1, col2 = st.columns(spec=[1.1, 1.3], gap="large")

# Left column: the original page, plus the enhanced copy that Tesseract will see.
with col1:
    st.subheader("Preview")
    st.image(img, use_column_width=True)

    det_img = preprocess_for_detection(img)
    with st.expander("Detection view"):
        st.image(det_img, use_column_width=True)
|
|
|
|
|
with col2:
    st.subheader("OCR & Extraction")

    # Tesseract supplies only the word boxes; the text for each box comes from the
    # recognizer. crop_words() skips rows with blank text, so drop those rows here too.
    # The recognizer output then lines up one-to-one with det_df (otherwise the
    # `pred` assignment below fails with a length mismatch).
    det_df = detect_words(det_img, lang=det_lang)
    det_df = det_df[det_df["text"].astype(str).str.strip() != ""].reset_index(drop=True)

    crops, metas = crop_words(det_img, det_df)
    texts = recognize_word_crops(model, crops, tfm, arch, alphabet)
    det_df["pred"] = texts

    # Rebuild the page text: one line per Tesseract (block, paragraph, line) group,
    # words ordered left to right, empty predictions and empty lines skipped.
    lines = []
    for _, g in det_df.groupby(["block_num", "par_num", "line_num"]):
        line = " ".join(t for t in g.sort_values("left")["pred"].tolist() if t)
        if line.strip():
            lines.append(line)
    full_text = "\n".join(lines)

    if show_boxes:
        st.caption("First 15 predicted words")
        st.write(det_df[["left", "top", "width", "height", "text", "pred"]].head(15))

    # Header fields; "—" marks a field that was not found.
    key_fields = parse_fields(full_text)
    k1, k2, k3 = st.columns(3)
    with k1:
        st.write(f"**Invoice #:** {key_fields.get('invoice_number') or '—'}")
        st.write(f"**Invoice Date:** {key_fields.get('invoice_date') or '—'}")
    with k2:
        st.write(f"**PO #:** {key_fields.get('po_number') or '—'}")
        st.write(f"**Subtotal:** {key_fields.get('subtotal') or '—'}")
    with k3:
        st.write(f"**Tax:** {key_fields.get('tax') or '—'}")
        tot = key_fields.get('total') or '—'
        cur = key_fields.get('currency') or ''
        st.write(f"**Total:** {tot} {cur}".strip())

    # Line items are rebuilt from the recognizer's words, not Tesseract's own text.
    items = items_from_wordgrid(det_df.assign(text=det_df["pred"]))
    st.markdown("**Line Items**")
    if items.empty:
        st.caption("No line items confidently detected.")
    else:
        st.dataframe(items, use_container_width=True)

    result = {
        "file": up.name, "page": page_idx + 1,
        "key_fields": key_fields,
        "items": items.to_dict(orient="records") if not items.empty else [],
        "full_text": full_text,
    }
    st.download_button("Download JSON", data=json.dumps(result, indent=2), file_name="invoice_extraction.json", mime="application/json")
    if not items.empty:
        st.download_button("Download Items CSV", data=items.to_csv(index=False), file_name="invoice_items.csv", mime="text/csv")
|
|
|