DPT2 / app.py
Seth0330's picture
Update app.py
11d644c verified
raw
history blame
13.9 kB
import io, os, re, json
from typing import List, Tuple, Dict
import numpy as np
import pandas as pd
from PIL import Image, ImageOps, ImageFilter
import streamlit as st
import torch
import torchvision.transforms as T
# --- word detector (Tesseract) ---
import pytesseract
from pytesseract import Output
# --- PDF -> images ---
from pdf2image import convert_from_bytes
# ---- import the repo's models ----
# Install via requirements.txt (git+https URL) OR copy repo files into root.
# The repo defines model classes: Swin_CTC, VED
import models as pdrt_models # from dparres/Pretrained-Document-Recognition-Transformers
# Page-level Streamlit configuration followed by the sidebar widgets whose
# values drive the rest of the script.
st.set_page_config(page_title="Invoice OCR (ViT recognizer + Tesseract detector)", layout="wide")
# ========================= UI SIDEBAR =========================
st.sidebar.header("Model")
# Recognizer architecture; the two options match the branches in load_pdrt.
arch = st.sidebar.selectbox("Architecture", ["Swin_CTC", "VED"], index=0)
# Path of the weights file passed to torch.load by load_pdrt.
ckpt_path = st.sidebar.text_input("Checkpoint path (inside Space)", value="checkpoints/pdrt_weights.pth")
# Output classes of the recognizer, in order; class id i is decoded as alphabet[i].
alphabet = st.sidebar.text_input("Alphabet (ordered classes, exclude CTC blank)", value="0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-_/.,:;()[]{}#+*&%$@!?\"' ")
# Positional number_input arguments are (min, max, default, step).
img_h = st.sidebar.number_input("Recognizer input height", 64, 256, 128, 8)
img_w = st.sidebar.number_input("Recognizer input width", 128, 2048, 512, 16)
# Language string handed to Tesseract, which is used for word detection.
det_lang = st.sidebar.text_input("Tesseract lang(s) for detection only", value="eng")
show_boxes = st.sidebar.checkbox("Show word boxes", value=False)
# Run the recognizer on the GPU when one is visible to torch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# ========================= UTILITIES =========================
def load_pages(file_bytes: bytes, name: str) -> List[Image.Image]:
    """Turn the raw bytes of an upload into a list of page images.

    Args:
        file_bytes: Content of the uploaded file.
        name: Original file name; only its (case-insensitive) ``.pdf``
            suffix is inspected. ``None`` or empty is treated as "not a PDF".

    Returns:
        For a PDF, whatever ``convert_from_bytes`` yields at 300 dpi;
        otherwise a one-element list holding the image converted to RGB.
    """
    is_pdf = (name or "").lower().endswith(".pdf")
    if not is_pdf:
        single = Image.open(io.BytesIO(file_bytes))
        return [single.convert("RGB")]
    return convert_from_bytes(file_bytes, dpi=300)
def preprocess_for_detection(img: Image.Image) -> Image.Image:
    """Return a grayscale, auto-contrasted and sharpened copy of ``img``.

    The three steps run in this order: grayscale conversion, autocontrast,
    then an unsharp mask (radius 1, 150 percent, threshold 3).
    """
    sharpen = ImageFilter.UnsharpMask(radius=1, percent=150, threshold=3)
    stretched = ImageOps.autocontrast(ImageOps.grayscale(img))
    return stretched.filter(sharpen)
@st.cache_resource
def load_pdrt(arch_name: str, ckpt: str, num_classes: int):
    """Build the requested recognizer and load its weights from ``ckpt``.

    Args:
        arch_name: ``"Swin_CTC"`` or ``"VED"``.
        ckpt: Path of the checkpoint file read with ``torch.load``.
        num_classes: Passed to the model constructor as ``num_classes``.

    Returns:
        The model in eval mode, moved to the module-level ``device``.

    Raises:
        ValueError: If ``arch_name`` is not one of the two known names.
    """
    import warnings

    if arch_name == "Swin_CTC":
        model = pdrt_models.Swin_CTC(num_classes=num_classes)
    elif arch_name == "VED":
        model = pdrt_models.VED(num_classes=num_classes)
    else:
        raise ValueError("Unknown model")

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point this at checkpoints from a trusted source.
    state = torch.load(ckpt, map_location="cpu")

    # Training scripts often save {"state_dict": ...} or {"model": ...}
    # instead of the bare state dict. With strict=False such a file would
    # match no parameter at all and the model would silently keep its
    # random initialization, so unwrap the common containers first.
    # A real state dict maps names to tensors, never to dicts, so this
    # check cannot misfire on one.
    if isinstance(state, dict):
        for wrapper_key in ("state_dict", "model_state_dict", "model"):
            inner = state.get(wrapper_key)
            if isinstance(inner, dict):
                state = inner
                break

    result = model.load_state_dict(state, strict=False)
    # strict=False tolerates mismatches; report them instead of hiding them.
    missing = list(getattr(result, "missing_keys", []))
    unexpected = list(getattr(result, "unexpected_keys", []))
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint {ckpt!r} does not fully match {arch_name}: "
            f"{len(missing)} missing key(s) (e.g. {missing[:3]}), "
            f"{len(unexpected)} unexpected key(s) (e.g. {unexpected[:3]}). "
            "Parameters that were not loaded keep their initial values."
        )

    model.eval().to(device)
    return model
def build_transform(img_h: int, img_w: int):
    """Compose the preprocessing applied to every word crop.

    Steps, in order: grayscale replicated to 3 channels, resize to
    ``(img_h, img_w)``, conversion to a tensor, and per-channel
    normalization with mean 0.5 and std 0.5.
    """
    steps = [
        # keep 3ch if encoder expects RGB
        T.Grayscale(num_output_channels=3),
        T.Resize((img_h, img_w)),
        T.ToTensor(),
        T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
    return T.Compose(steps)
def greedy_ctc_decode(logits: torch.Tensor, alphabet: str, blank_id=None) -> str:
    """Greedy (best-path) CTC decoding of a single sequence.

    Takes the argmax class at every time step, collapses consecutive
    repeats, drops the blank class and maps the remaining ids to characters.

    Args:
        logits: Scores of shape ``(T, C)``, ``(1, T, C)`` or ``(T, 1, C)``.
        alphabet: The non-blank classes, in class-id order.
        blank_id: Class id of the CTC blank. Defaults to ``len(alphabet)``
            (blank is the last class). Class ids above ``blank_id`` are
            shifted down by one when indexing ``alphabet``, so ``blank_id=0``
            maps class 1 to ``alphabet[0]``.

    Returns:
        The decoded text; ids that fall outside ``alphabet`` are ignored.

    Raises:
        ValueError: If ``logits`` is not a single sequence in one of the
            supported layouts.
    """
    if blank_id is None:
        blank_id = len(alphabet)

    # Only strip a singleton batch axis from 3-D input. Applying the same
    # test to an already 2-D (T, C) tensor would collapse a one-step
    # sequence to 1-D and break the iteration below.
    if logits.dim() == 3:
        if logits.shape[0] == 1:
            logits = logits[0]
        elif logits.shape[1] == 1:
            logits = logits[:, 0, :]
        else:
            raise ValueError(
                f"expected a single sequence, got batched logits of shape {tuple(logits.shape)}"
            )
    if logits.dim() != 2:
        raise ValueError(f"expected (T, C) logits, got shape {tuple(logits.shape)}")

    # argmax is invariant under softmax, so the raw scores are enough.
    ids = logits.argmax(-1).tolist()

    chars = []
    prev = None
    for class_id in ids:
        if class_id != prev and class_id != blank_id:
            char_index = class_id - 1 if class_id > blank_id else class_id
            if 0 <= char_index < len(alphabet):
                chars.append(alphabet[char_index])
        prev = class_id
    return "".join(chars)
def recognize_word_crops(model, crops: List[Image.Image], tfm, arch_name: str, alphabet: str) -> List[str]:
    """Run the recognizer on each crop and return one string per crop.

    Every crop is transformed with ``tfm``, given a leading batch axis,
    moved to the module-level ``device`` and passed through ``model`` under
    ``torch.no_grad()``. The output is assumed to be a tensor -- TODO confirm
    against the model classes.

    * ``arch_name == "Swin_CTC"``: the output is reduced to 2-D when it has a
      singleton batch axis and decoded with ``greedy_ctc_decode``.
    * any other ``arch_name``: fallback decoding -- argmax over the last axis,
      ids mapped straight to ``alphabet`` (no blank), result stripped.
    """
    uses_ctc = arch_name == "Swin_CTC"
    n_symbols = len(alphabet)
    results = []
    with torch.no_grad():
        for crop in crops:
            batch = tfm(crop).unsqueeze(0).to(device)
            out = model(batch)

            # Drop a singleton batch axis: [1, T, C] for both architectures,
            # and additionally [T, 1, C] on the CTC path.
            if out.dim() == 3 and out.shape[0] == 1:
                out = out[0]
            elif uses_ctc and out.dim() == 3 and out.shape[1] == 1:
                out = out[:, 0, :]

            if uses_ctc:
                decoded = greedy_ctc_decode(out, alphabet)
            else:
                # VED: if returns token ids/logits, plug your repo's decoding here.
                token_ids = out.argmax(-1).tolist()
                pieces = [alphabet[t] if t < n_symbols else "" for t in token_ids]
                decoded = "".join(pieces).strip()
            results.append(decoded)
    return results
def detect_words(img: Image.Image, lang="eng") -> pd.DataFrame:
    """Run Tesseract on ``img`` and return one row per detected word.

    Rows are kept only when they have a non-blank ``text`` and a ``conf``
    above -1. Whitespace-only words are dropped here so that every returned
    row yields exactly one crop and one prediction downstream; otherwise a
    per-row list of predictions would not line up with the frame.

    Args:
        img: Page image handed to ``pytesseract.image_to_data``.
        lang: Tesseract language string.

    Returns:
        The Tesseract data frame with a fresh 0..n-1 index and two extra
        columns, ``x2`` (``left + width``) and ``y2`` (``top + height``).
    """
    df = pytesseract.image_to_data(img, lang=lang, output_type=Output.DATAFRAME)
    df = df.dropna(subset=["text"])
    # astype(str) because the text column can hold non-string values.
    has_text = df["text"].astype(str).str.strip() != ""
    df = df[has_text & (df["conf"] > -1)].reset_index(drop=True)
    df["x2"] = df["left"] + df["width"]
    df["y2"] = df["top"] + df["height"]
    return df
def crop_words(img: Image.Image, df: pd.DataFrame) -> Tuple[List[Image.Image], List[Dict]]:
    """Cut one sub-image out of ``img`` for every non-blank word row of ``df``.

    Args:
        img: Image the boxes refer to.
        df: Word table with ``text``, ``left``, ``top``, ``x2`` and ``y2``
            columns.

    Returns:
        A pair ``(crops, metas)`` of equally long lists. ``metas[i]`` holds
        the ``box`` ``(left, top, x2, y2)`` of ``crops[i]`` and the ``index``
        label of the row it came from. Rows with blank text are skipped, so
        ``index`` is the way to map a crop back to its row.
    """
    crops: List[Image.Image] = []
    metas: List[Dict] = []
    for row_label, row in df.iterrows():
        if str(row["text"]).strip() == "":
            continue
        box = (int(row["left"]), int(row["top"]), int(row["x2"]), int(row["y2"]))
        crops.append(img.crop(box))
        metas.append({"box": box, "index": row_label})
    return crops, metas
# ---------------- key fields & table (same logic as earlier Tesseract app) ----------------
# Optional currency marker: ISO code or symbol. The pound sign is written as
# the real character (it was previously a mis-encoded two-character sequence
# that could never match).
CURRENCY = r"(?P<curr>USD|CAD|EUR|GBP|\$|C\$|€|£)?"
# Amount: either digits grouped by thousands separators ("1,234") or a plain
# run of digits ("1234"), followed by optional cents. The plain-digits branch
# keeps unseparated amounts of 1000 and above from being cut to three digits.
MONEY = rf"{CURRENCY}\s?(?P<amt>(?:\d{{1,3}}(?:,\d{{3}})+|\d+)(?:\.\d{{2}})?)"
# Dates as YYYY-MM-DD, D/M/Y-style numeric dates, or "Month D, YYYY".
DATE = r"(?P<date>(?:\d{4}[-/]\d{1,2}[-/]\d{1,2})|(?:\d{1,2}[-/]\d{1,2}[-/]\d{2,4})|(?:[A-Za-z]{3,9}\s+\d{1,2},\s*\d{2,4}))"
# The lookahead requires a digit in the captured token, so a following label
# such as "Date" in "Invoice Date" is not taken for the invoice number.
INV_PAT = r"(?:invoice\s*(?:no\.?|#|number)?\s*[:\-]?\s*(?P<inv>(?=[A-Z0-9\-_/]*\d)[A-Z0-9\-_/]{4,}))"
# \b keeps "po" from matching inside words such as "Corporation"; the
# lookahead again requires at least one digit in the captured token.
PO_PAT = r"(?:\bpo\s*(?:no\.?|#|number)?\s*[:\-]?\s*(?P<po>(?=[A-Z0-9\-_/]*\d)[A-Z0-9\-_/]{3,}))"
# The lookbehinds stop the "total" of "sub total" / "sub-total" from being
# read as the invoice total.
TOTAL_PAT = rf"(?:(?<!sub)(?<!sub[\s\-])\b(total(?:\s*amount)?|amount\s*due|grand\s*total)\b.*?{MONEY})"
SUBTOTAL_PAT = rf"(?:\bsub\s*total\b.*?{MONEY})"
TAX_PAT = rf"(?:\b(tax|gst|vat|hst)\b.*?{MONEY})"
def parse_fields(fulltext: str):
    """Extract the key invoice fields from recognized page text.

    Args:
        fulltext: Page text with one recognized line per text line.

    Returns:
        A dict with the keys ``invoice_number``, ``invoice_date``,
        ``po_number``, ``subtotal``, ``tax``, ``total`` and ``currency``.
        A field that was not found is ``None``. Amounts are strings with
        thousands separators removed. ``currency`` is upper-cased and
        currency symbols are mapped to a three-letter code.
    """
    # Collapse runs of spaces/tabs and blank lines before matching.
    t = re.sub(r"[ \t]+", " ", fulltext)
    t = re.sub(r"\n{2,}", "\n", t)

    out = {"invoice_number":None,"invoice_date":None,"po_number":None,"subtotal":None,"tax":None,"total":None,"currency":None}

    m = re.search(INV_PAT, t, re.I)
    if m:
        out["invoice_number"] = m.group("inv")

    m = re.search(PO_PAT, t, re.I)
    if m:
        out["po_number"] = m.group("po")

    # Prefer a date labelled "invoice date"; otherwise take the first date.
    m = re.search(rf"(invoice\s*date[:\-\s]*){DATE}", t, re.I) or re.search(DATE, t, re.I)
    if m:
        out["invoice_date"] = m.group("date")

    # Later matches overwrite the currency only when they carry one.
    amount_patterns = (("subtotal", SUBTOTAL_PAT), ("tax", TAX_PAT), ("total", TOTAL_PAT))
    for key, pattern in amount_patterns:
        m = re.search(pattern, t, re.I | re.S)
        if not m:
            continue
        out[key] = m.group("amt").replace(",", "")
        curr = m.group("curr")
        if curr:
            # Matching is case-insensitive, so "usd" / "c$" are normalized.
            out["currency"] = curr.upper()

    # Both the pound sign and its mis-encoded two-character form are mapped,
    # so the symbol is normalized whichever one the pattern captured.
    symbol_to_code = {"$": "USD", "C$": "CAD", "€": "EUR", "£": "GBP", "Β£": "GBP"}
    out["currency"] = symbol_to_code.get(out["currency"], out["currency"])
    return out
HEAD_CANDIDATES = ["description","item","qty","quantity","price","unit","rate","amount","total"]

# Words that mark the start of the totals section below the item rows.
# Non-capturing, so pandas does not warn about unused match groups.
_TOTALS_RE = r"(?:sub\s*total|amount\s*due|total|grand\s*total|balance)"


def _header_column_name(word: str, position: int) -> str:
    """Map one header word to a canonical column name (``col_<position>`` if unknown)."""
    wl = word.lower()
    if "desc" in wl or wl in ("item", "description"):
        return "description"
    if wl in ("qty", "quantity"):
        return "quantity"
    if "unit" in wl or "rate" in wl or "price" in wl:
        return "unit_price"
    if "amount" in wl or "total" in wl:
        return "line_total"
    return f"col_{position}"


def _make_unique(names: List[str]) -> List[str]:
    """Suffix repeated names with ``_2``, ``_3``, ... so every name is distinct."""
    seen = {}
    unique = []
    for name in names:
        seen[name] = seen.get(name, 0) + 1
        unique.append(name if seen[name] == 1 else f"{name}_{seen[name]}")
    return unique


def items_from_wordgrid(df: pd.DataFrame) -> pd.DataFrame:
    """Reconstruct the line-item table from word boxes with a geometry heuristic.

    The header is the text line containing the most ``HEAD_CANDIDATES``
    (at least two). Every word on that line becomes a column anchored at
    the word's left edge. Each text line below the header, up to the first
    totals keyword, becomes a row whose words go to the column with the
    nearest anchor.

    Args:
        df: Word table with ``text``, ``left``, ``top``, ``width``,
            ``height``, ``block_num``, ``par_num`` and ``line_num`` columns.

    Returns:
        One row per item line with string cells and unique column names, or
        an empty ``DataFrame`` when no header or no item rows are found.
    """
    line_keys = ["block_num", "par_num", "line_num"]

    df = df.copy()
    # Normalize text once: the column may hold NaN or numbers, which would
    # break the string operations below.
    df["text"] = df["text"].fillna("").astype(str)
    # Blank words carry no information and must not become column anchors.
    df = df[df["text"].str.strip() != ""]
    df["bottom"] = df["top"] + df["height"]

    # Group words into text lines and score each line as a header candidate.
    line_records = [
        {
            "text": " ".join(g["text"]).lower(),
            "top": g["top"].min(),
            "bottom": g["bottom"].max(),
        }
        for _, g in df.groupby(line_keys)
    ]
    lines = pd.DataFrame(line_records)
    if lines.empty:
        return pd.DataFrame()
    lines["score"] = lines["text"].apply(lambda s: sum(h in s for h in HEAD_CANDIDATES))
    headers = lines[lines["score"] >= 2].sort_values(["score", "top"], ascending=[False, True])
    if headers.empty:
        return pd.DataFrame()
    header = headers.iloc[0]

    # Column anchors: left edges of the words inside the header band (+/- 5 px).
    in_band = (df["top"] >= header["top"] - 5) & (df["bottom"] <= header["bottom"] + 5)
    header_band = df[in_band].sort_values("left")
    col_x = header_band["left"].to_numpy()
    if len(col_x) < 2:
        return pd.DataFrame()

    # Item region: below the header, above the first totals keyword.
    below = df[df["top"] > header["bottom"] + 4]
    totals_mask = below["text"].str.lower().str.contains(_TOTALS_RE, regex=True, na=False)
    if totals_mask.any():
        stop_y = below.loc[totals_mask, "top"].min()
        below = below[below["top"] < stop_y - 4]

    rows = []
    for _, g in below.groupby(line_keys):
        g = g.sort_values("left")
        cells = [[] for _ in range(len(col_x))]
        for left, text in zip(g["left"], g["text"]):
            # assign to nearest header word x
            cells[int(np.abs(col_x - left).argmin())].append(text)
        rows.append([" ".join(cell).strip() for cell in cells])
    if not rows:
        return pd.DataFrame()

    # Two header words can map to the same canonical name (e.g. "Unit" and
    # "Price"); duplicate column names break display and dict export, so
    # repeats get a numeric suffix.
    names = _make_unique(
        [_header_column_name(word, i) for i, word in enumerate(header_band["text"])]
    )
    items = pd.DataFrame(rows, columns=names)

    # drop empty lines
    items = items[(items != "").any(axis=1)]
    return items.reset_index(drop=True)
# ========================= APP =========================
# Main page flow: upload -> page selection -> word detection -> per-word
# recognition -> key-field and line-item extraction -> downloads.
MISSING = "—"  # placeholder shown for a field that was not found

st.title("Invoice Extraction — ViT recognizer (dparres) + Tesseract detector")
up = st.file_uploader("Upload an invoice (PDF/JPG/PNG)", type=["pdf","png","jpg","jpeg"])
if not up:
    st.info("Upload a scanned invoice to begin.")
    st.stop()
pages = load_pages(up.read(), up.name)

# load model once
num_classes = len(alphabet) + (1 if arch == "Swin_CTC" else 0)  # add CTC blank for Swin_CTC
# The path comes from a text box, so report a bad value in the UI instead of
# asserting (asserts are stripped under -O and surface as a raw traceback).
if not os.path.exists(ckpt_path):
    st.error(f"Checkpoint not found: {ckpt_path}")
    st.stop()
model = load_pdrt(arch, ckpt_path, num_classes)
tfm = build_transform(img_h, img_w)

page_idx = 0
if len(pages) > 1:
    page_idx = int(st.number_input("Page", 1, len(pages), 1)) - 1
img = pages[page_idx]

col1, col2 = st.columns([1.1,1.3], gap="large")
with col1:
    st.subheader("Preview")
    st.image(img, use_column_width=True)
    det_img = preprocess_for_detection(img)
    with st.expander("Detection view"):
        st.image(det_img, use_column_width=True)

with col2:
    st.subheader("OCR & Extraction")
    # 1) detect words (boxes only)
    det_df = detect_words(det_img, lang=det_lang)
    # crop_words skips rows whose text is blank; drop those rows first so
    # the list of predictions lines up one-to-one with the frame's rows.
    det_df = det_df[det_df["text"].astype(str).str.strip() != ""].reset_index(drop=True)

    # 2) crop & recognize each word via ViT recognizer
    crops, metas = crop_words(det_img, det_df)
    texts = recognize_word_crops(model, crops, tfm, arch, alphabet)

    # 3) stitch line-by-line using tesseract line indices
    det_df["pred"] = texts
    lines = []
    for _, g in det_df.groupby(["block_num","par_num","line_num"]):
        g = g.sort_values("left")
        lines.append(" ".join([t for t in g["pred"].tolist() if t]))
    full_text = "\n".join([ln for ln in lines if ln.strip()])

    if show_boxes:
        st.caption("First 15 predicted words")
        st.write(det_df[["left","top","width","height","text","pred"]].head(15))

    # 4) key fields
    key_fields = parse_fields(full_text)
    k1, k2, k3 = st.columns(3)
    with k1:
        st.write(f"**Invoice #:** {key_fields.get('invoice_number') or MISSING}")
        st.write(f"**Invoice Date:** {key_fields.get('invoice_date') or MISSING}")
    with k2:
        st.write(f"**PO #:** {key_fields.get('po_number') or MISSING}")
        st.write(f"**Subtotal:** {key_fields.get('subtotal') or MISSING}")
    with k3:
        st.write(f"**Tax:** {key_fields.get('tax') or MISSING}")
        tot = key_fields.get('total') or MISSING
        cur = key_fields.get('currency') or ''
        st.write(f"**Total:** {tot} {cur}".strip())

    # 5) line items (geometry heuristic), run on the recognizer's text
    items = items_from_wordgrid(det_df.assign(text=det_df["pred"]))
    st.markdown("**Line Items**")
    if items.empty:
        st.caption("No line items confidently detected.")
    else:
        st.dataframe(items, use_container_width=True)

    # 6) downloads
    result = {
        "file": up.name, "page": page_idx+1,
        "key_fields": key_fields,
        "items": items.to_dict(orient="records") if not items.empty else [],
        "full_text": full_text
    }
    st.download_button("Download JSON", data=json.dumps(result, indent=2), file_name="invoice_extraction.json", mime="application/json")
    if not items.empty:
        st.download_button("Download Items CSV", data=items.to_csv(index=False), file_name="invoice_items.csv", mime="text/csv")