# LegalProdigy — app.py
# (Hugging Face Space upload by Polarisailabs; commit c2958f9, ~9.46 kB.
#  The original page-scrape header lines were not valid Python and have been
#  converted into this comment.)
# Minifier-generated aliases for frequently repeated literals. The rest of
# the module refers to these names, so their values must not change.
_L='faiss.index'  # filename of the persisted FAISS index
_K='meta.json'  # filename of the persisted chunk metadata
_J='source'  # metadata key: originating document name
_I='name'  # dict key used when unwrapping gr.File payloads
_H='\x00'  # NUL byte, stripped from extracted text
_G=False
_F='\n'
_E='text'  # metadata key: chunk text (also fitz text-extraction mode)
_D='mmr'  # retrieval diversification strategy name
_C='utf-8'
_B=None
_A=True
import os,io,json,pathlib,shutil
from typing import List,Tuple,Dict
import gradio as gr,numpy as np,faiss
from sentence_transformers import SentenceTransformer
from pypdf import PdfReader
import fitz
from collections import defaultdict
from openai import OpenAI
# OpenRouter credentials + generation/embedding configuration.
# API_KEY must be provided as an environment variable (HF Spaces secret);
# startup fails fast if it is missing.
API_KEY=os.environ.get('API_KEY')
if not API_KEY:raise RuntimeError('Missing API_KEY (set it in Hugging Face: Settings → Variables and secrets).')
# OpenAI-compatible client pointed at OpenRouter instead of api.openai.com.
client=OpenAI(base_url='https://openrouter.ai/api/v1',api_key=API_KEY)
SINGLE_MODEL_NAME='deepseek/deepseek-r1:free'  # chat model routed via OpenRouter
GEN_TEMPERATURE=.2  # low temperature for factual, citation-style answers
GEN_TOP_P=.95
GEN_MAX_TOKENS=1024
EMB_MODEL_NAME='intfloat/multilingual-e5-base'  # E5 model: needs 'query:'/'passage:' prefixes
def choose_store_dir():
    """Pick a writable directory for the vector store.

    Prefers the persistent '/data' volume (Hugging Face Spaces persistent
    storage), verifying writability with a small write/delete probe file.
    Falls back to a './store' directory under the current working directory.

    Returns:
        tuple: (directory_path, is_persistent) where is_persistent is True
        only when the '/data'-backed directory passed the write probe.
    """
    data_root = '/data'
    if os.path.isdir(data_root) and os.access(data_root, os.W_OK):
        candidate = os.path.join(data_root, 'rag_store')
        try:
            os.makedirs(candidate, exist_ok=True)
            probe = os.path.join(candidate, '.write_test')
            with open(probe, 'w', encoding='utf-8') as fh:
                fh.write('ok')
            os.remove(probe)
            return candidate, True
        except Exception:
            pass  # fall through to the non-persistent local store
    fallback = os.path.join(os.getcwd(), 'store')
    os.makedirs(fallback, exist_ok=True)
    return fallback, False
# Resolve the store location once at import time.
STORE_DIR,IS_PERSISTENT=choose_store_dir()
META_PATH=os.path.join(STORE_DIR,_K)  # chunk metadata JSON
INDEX_PATH=os.path.join(STORE_DIR,_L)  # serialized FAISS index
# Old non-persistent location; candidate for one-time migration below.
LEGACY_STORE_DIR=os.path.join(os.getcwd(),'store')
def migrate_legacy_if_any():
    """One-time, best-effort copy of a legacy './store' index into /data.

    Runs only when the persistent volume is active, the persistent store is
    missing at least one artifact, and the legacy directory holds both the
    metadata and the index. All failures are silently ignored — migration
    is an optimization, never a requirement.
    """
    try:
        if not IS_PERSISTENT:
            return
        legacy_meta = os.path.join(LEGACY_STORE_DIR, 'meta.json')
        legacy_index = os.path.join(LEGACY_STORE_DIR, 'faiss.index')
        target_incomplete = not os.path.exists(META_PATH) or not os.path.exists(INDEX_PATH)
        legacy_complete = (
            os.path.isdir(LEGACY_STORE_DIR)
            and os.path.exists(legacy_meta)
            and os.path.exists(legacy_index)
        )
        if target_incomplete and legacy_complete:
            shutil.copyfile(legacy_meta, META_PATH)
            shutil.copyfile(legacy_index, INDEX_PATH)
    except Exception:
        pass
migrate_legacy_if_any()
# Process-wide lazy singletons: embedding model, FAISS index, and the
# chunk-id -> {source, text} metadata map.
_emb_model=_B
_index=_B
_meta={}
# Default retrieval knobs (mirrored by the literals used in chat_fn).
DEFAULT_TOP_K=6
DEFAULT_POOL_K=40
DEFAULT_PER_SOURCE_CAP=2
DEFAULT_STRATEGY=_D  # 'mmr'
DEFAULT_MMR_LAMBDA=.5
def get_emb_model():
    """Return the process-wide SentenceTransformer, loading it on first use."""
    global _emb_model
    if _emb_model is None:
        _emb_model = SentenceTransformer(EMB_MODEL_NAME)
    return _emb_model
def _ensure_index(dim):
    """Lazily create the global FAISS inner-product index with dimension `dim`.

    Inner product on normalized embeddings is equivalent to cosine similarity.
    """
    global _index
    if _index is None:
        _index = faiss.IndexFlatIP(dim)
def _persist():
    """Write the FAISS index and the chunk metadata to the store directory."""
    faiss.write_index(_index, INDEX_PATH)
    with open(META_PATH, 'w', encoding='utf-8') as fh:
        json.dump(_meta, fh, ensure_ascii=False)
def _load_if_any():
    """Load a previously persisted index + metadata when both files exist.

    Requires both artifacts: a lone index without its metadata (or vice
    versa) is ignored and the globals are left untouched.
    """
    global _index, _meta
    if not (os.path.exists(INDEX_PATH) and os.path.exists(META_PATH)):
        return
    _index = faiss.read_index(INDEX_PATH)
    with open(META_PATH, 'r', encoding='utf-8') as fh:
        _meta = json.load(fh)
def _chunk_text(text,chunk_size=800,overlap=120):
A=text;A=A.replace(_H,'');E,B,C=[],0,len(A)
while B<C:
D=min(B+chunk_size,C);F=A[B:D].strip()
if F:E.append(F)
B=max(0,D-overlap)
if D>=C:break
return E
def _read_bytes(file):
D='data';A=file
if isinstance(A,dict):
B=A.get('path')or A.get(_I)
if B and os.path.exists(B):
with open(B,'rb')as C:return C.read()
if D in A and isinstance(A[D],(bytes,bytearray)):return bytes(A[D])
if isinstance(A,(str,pathlib.Path)):
with open(A,'rb')as C:return C.read()
if hasattr(A,'read'):
try:
if hasattr(A,'seek'):
try:A.seek(0)
except Exception:pass
return A.read()
finally:
try:A.close()
except Exception:pass
raise ValueError('Unsupported file type from gr.File')
def _decode_best_effort(raw):
for A in[_C,'cp932','shift_jis','cp950','big5','gb18030','latin-1']:
try:return raw.decode(A)
except Exception:continue
return raw.decode(_C,errors='ignore')
def _read_pdf(file_bytes):
    """Extract text from a PDF, preferring PyMuPDF with a pypdf fallback.

    First pass uses fitz (attempting an empty-password unlock for encrypted
    files); its result is used only when some non-blank text was extracted.
    Otherwise pypdf is tried page by page (blank string per failing page).
    Returns '' when both extractors fail entirely.
    """
    try:
        with fitz.open(stream=file_bytes, filetype='pdf') as doc:
            if doc.is_encrypted:
                try:
                    doc.authenticate('')  # many PDFs are "encrypted" with an empty owner password
                except Exception:
                    pass
            pages = [page.get_text('text') or '' for page in doc]
            joined = '\n'.join(pages)
            if joined.strip():
                return joined
    except Exception:
        pass  # fall through to the pypdf extractor
    try:
        reader = PdfReader(io.BytesIO(file_bytes))
        texts = []
        for page in reader.pages:
            try:
                texts.append(page.extract_text() or '')
            except Exception:
                texts.append('')
        return '\n'.join(texts)
    except Exception:
        return ''
def _read_any(file):
    """Return the text content of an uploaded file (PDF or plain text).

    Derives a lowercase filename from whatever shape `file` has, dispatches
    to the PDF extractor for '.pdf', otherwise decodes as text. NUL bytes
    are stripped from the result in both paths.
    """
    if isinstance(file, dict):
        label = (file.get('orig_name') or file.get('name') or file.get('path') or 'upload').lower()
    else:
        label = getattr(file, 'name', None) or (str(file) if isinstance(file, (str, pathlib.Path)) else 'upload')
    label = label.lower()
    payload = _read_bytes(file)
    if label.endswith('.pdf'):
        return _read_pdf(payload).replace('\x00', '')
    return _decode_best_effort(payload).replace('\x00', '')
DOCS_DIR=os.path.join(os.getcwd(),'docs')  # folder scanned for the PDF/TXT corpus
def get_docs_files():
    """List absolute paths of .pdf/.txt files directly inside DOCS_DIR.

    Non-recursive; returns [] when the docs directory does not exist.
    """
    if not os.path.isdir(DOCS_DIR):
        return []
    return [
        os.path.join(DOCS_DIR, entry)
        for entry in os.listdir(DOCS_DIR)
        if entry.lower().endswith(('.pdf', '.txt'))
    ]
def build_corpus_from_docs():
    """Rebuild the FAISS index and metadata from every file in ./docs.

    Resets the in-memory index/metadata, extracts and chunks each document
    (files that raise or yield no chunks are recorded as failed), embeds all
    chunks with the E5 'passage: ' prefix as normalized vectors, adds them
    to a fresh index, and persists everything to disk.

    Returns:
        str: human-readable status, including any failed filenames.
    """
    global _index, _meta
    files = get_docs_files()
    if not files:
        return 'No files found in docs folder.'
    model = get_emb_model()
    chunks, sources, failed = [], [], []
    _index = None
    _meta = {}
    for path in files:
        base = os.path.basename(path)
        try:
            raw = _read_any(path) or ''
            pieces = _chunk_text(raw)
            if not pieces:
                failed.append(base)
                continue
            chunks.extend(pieces)
            sources.extend([base] * len(pieces))
        except Exception:
            failed.append(base)
    if not chunks:
        return 'No text extracted from docs.'
    prefixed = [f"passage: {chunk}" for chunk in chunks]
    vectors = model.encode(prefixed, batch_size=64, convert_to_numpy=True, normalize_embeddings=True)
    _ensure_index(vectors.shape[1])
    _index.add(vectors)
    # Metadata ids are string row indices matching FAISS insertion order.
    for row, (src, txt) in enumerate(zip(sources, chunks)):
        _meta[str(row)] = {'source': src, 'text': txt}
    _persist()
    status = f"Indexed {len(chunks)} chunks from {len(files)} files."
    if failed:
        status += f" Failed files: {', '.join(failed)}"
    return status
def _encode_query_vec(query):
    """Embed a query with the E5 'query: ' prefix as a normalized row vector."""
    model = get_emb_model()
    return model.encode([f"query: {query}"], convert_to_numpy=True, normalize_embeddings=True)
def retrieve_candidates(qvec, pool_k=40):
    """Search the FAISS index for up to `pool_k` nearest chunks.

    Returns:
        list[tuple[str, float]]: (chunk_id, score) pairs; empty when no
        index is loaded. FAISS pads missing results with id -1, which is
        filtered out here.
    """
    if _index is None or _index.ntotal == 0:
        return []
    k = min(pool_k, _index.ntotal)
    scores, ids = _index.search(qvec, k)
    return [(str(i), float(s)) for i, s in zip(ids[0], scores[0]) if i != -1]
def select_diverse_by_source(cands, top_k=6, per_source_cap=2):
    """Pick up to `top_k` candidates, round-robin across source documents.

    Candidates are grouped by source (candidates without metadata are
    dropped) and each source is trimmed to its `per_source_cap` best-ranked
    entries. A round-robin over the sources then fills the result; any
    remaining slots are topped up from the raw candidate order.
    """
    if not cands:
        return []
    by_source = defaultdict(list)
    for cid, score in cands:
        info = _meta.get(cid)
        if not info:
            continue
        by_source[info['source']].append((cid, score))
    for src in by_source:
        by_source[src] = by_source[src][:per_source_cap]
    picked = []
    buckets = list(by_source.items())
    cursor = {src: 0 for src in by_source}
    # Round-robin phase: take one item per source per sweep.
    while len(picked) < top_k:
        advanced = False
        for src, items in buckets:
            pos = cursor[src]
            if pos < len(items):
                picked.append(items[pos])
                cursor[src] = pos + 1
                advanced = True
            if len(picked) >= top_k:
                break
        if not advanced:
            break
    # Top-up phase: fill leftover slots in original candidate order.
    if len(picked) < top_k:
        seen = {cid for cid, _ in picked}
        for cid, score in cands:
            if cid not in seen:
                picked.append((cid, score))
                seen.add(cid)
                if len(picked) >= top_k:
                    break
    return picked[:top_k]
def _encode_chunks_text(cids):
    """Re-embed the stored text of the given chunk ids as normalized vectors."""
    passages = [f"passage: {(_meta.get(cid) or {}).get('text', '')}" for cid in cids]
    return get_emb_model().encode(passages, convert_to_numpy=True, normalize_embeddings=True)
def select_diverse_mmr(cands, qvec, top_k=6, mmr_lambda=.5):
    """Select up to `top_k` candidates via Maximal Marginal Relevance.

    Greedily picks the candidate maximizing
    ``mmr_lambda * sim(query) - (1 - mmr_lambda) * max_sim(already picked)``,
    seeded with the single most query-similar chunk. Candidate texts are
    re-embedded to score inter-chunk redundancy.

    Returns:
        list[tuple[str, float]]: (chunk_id, query_similarity) pairs.
    """
    if not cands:
        return []
    cids = [cid for cid, _ in cands]
    doc_vecs = _encode_chunks_text(cids)
    query_sims = (doc_vecs @ qvec.T).reshape(-1)
    selected = []
    remaining = set(range(len(cids)))
    limit = min(top_k, len(cids))
    while len(selected) < limit:
        if not selected:
            seed = int(np.argmax(query_sims))
            selected.append(seed)
            remaining.remove(seed)
            continue
        rem = list(remaining)
        cross = doc_vecs[rem] @ doc_vecs[selected].T
        redundancy = cross.max(axis=1) if cross.size > 0 else np.zeros((len(rem),), dtype=np.float32)
        relevance = query_sims[rem]
        mmr_scores = mmr_lambda * relevance - (1. - mmr_lambda) * redundancy
        winner = rem[int(np.argmax(mmr_scores))]
        selected.append(winner)
        remaining.remove(winner)
    return [(cids[i], float(query_sims[i])) for i in selected][:top_k]
def retrieve_diverse(query, top_k=6, pool_k=40, per_source_cap=2, strategy='mmr', mmr_lambda=.5):
    """Retrieve a diversified top-k for `query`.

    Embeds the query, pulls a `pool_k`-sized candidate pool from the index,
    then diversifies either with MMR (strategy 'mmr') or with a per-source
    round-robin (any other strategy value).
    """
    qvec = _encode_query_vec(query)
    pool = retrieve_candidates(qvec, pool_k=pool_k)
    if strategy == 'mmr':
        return select_diverse_mmr(pool, qvec, top_k=top_k, mmr_lambda=mmr_lambda)
    return select_diverse_by_source(pool, top_k=top_k, per_source_cap=per_source_cap)
def _format_ctx(hits):
    """Render retrieval hits as '[id] (source) text' lines.

    Hits without metadata are skipped; newlines inside chunk text are
    flattened to spaces; at most the first 10 lines are joined with '\\n'.
    """
    if not hits:
        return ''
    lines = []
    for cid, _score in hits:
        info = _meta.get(cid)
        if not info:
            continue
        src = info.get('source', '')
        body = (info.get('text', '') or '').replace('\n', ' ')
        lines.append(f"[{cid}] ({src}) " + body)
    return '\n'.join(lines[:10])
def chat_fn(message, history):
    """Streaming RAG chat handler (generator yielding accumulated text).

    Lazily builds the index from ./docs on first use (aborting with a status
    message when that fails), retrieves diverse context for `message`, and
    streams the model's answer, yielding the growing response after every
    received chunk. Exceptions from the API call are yielded as text.

    Bug fix: the retrieved context was formatted into a local variable but
    never placed into the messages sent to the model, so the assistant could
    not actually cite the indexed sources. The context is now appended to
    the system prompt.

    Args:
        message: the current user query.
        history: list of (user, assistant) message pairs.
    """
    if _index is None or _index.ntotal == 0:
        status = build_corpus_from_docs()
        if not (_index and _index.ntotal > 0):
            yield f"**Index Status:** {status}\n\nPlease ensure you have a 'docs' folder with PDF/TXT files and try again."
            return
    hits = retrieve_diverse(message, top_k=6, pool_k=40, per_source_cap=2, strategy='mmr', mmr_lambda=.5)
    ctx = _format_ctx(hits) if hits else '(Current index is empty or no matching chunks found)'
    system_parts = [
        'You are a research assistant who has an excellent factual understanding of the legal policies, regulations, and compliance of enterprises, governments, and global organizations. You are a research assistant who reads Legal papers and provides factual answers to queries. If you do not know the answer, you should convey that to the user instead of hallucinating. Answers must be based on retrieved content with evidence and source numbers cited. If retrieval is insufficient, please clearly explain the shortcomings. When answering, please cite the numbers, e.g., [3]',
        # Previously computed but dropped — the whole point of the RAG pipeline.
        f"Retrieved context:\n{ctx}",
    ]
    messages = [{'role': 'system', 'content': '\n\n'.join(system_parts)}]
    for user_turn, assistant_turn in history:
        messages.append({'role': 'user', 'content': user_turn})
        messages.append({'role': 'assistant', 'content': assistant_turn})
    messages.append({'role': 'user', 'content': message})
    try:
        stream = client.chat.completions.create(
            model=SINGLE_MODEL_NAME,
            messages=messages,
            temperature=GEN_TEMPERATURE,
            top_p=GEN_TOP_P,
            max_tokens=GEN_MAX_TOKENS,
            stream=True,
        )
        acc = ''
        for chunk in stream:
            choice = chunk.choices[0]
            # Streaming responses carry `delta`; some providers return a
            # full `message` object instead — handle both.
            if hasattr(choice, 'delta') and choice.delta.content is not None:
                acc += choice.delta.content
                yield acc
            elif hasattr(choice, 'message') and choice.message.content is not None:
                acc += choice.message.content
                yield acc
    except Exception as exc:
        yield f"[Exception] {repr(exc)}"
# Best-effort load of a previously persisted index at import time; failure
# is non-fatal because the index is rebuilt lazily on the first chat request.
try:_load_if_any()
except Exception as e:print(f"Notice: Could not load existing index. A new one will be created. Error: {e}")
if __name__ == '__main__':
    def chatbot_interface(user_message):
        """Adapter connecting the stateless gr.Interface to the streaming
        backend: drains chat_fn's generator (with empty history) and returns
        only the final accumulated answer."""
        final = ''
        for partial in chat_fn(user_message, []):
            final = partial
        return final

    with gr.Blocks(theme=gr.themes.Default(primary_hue='sky')) as legalprodigy:
        inputs = gr.Textbox(lines=7, label='LegalProdigy Query:', placeholder='Try: Explain Arbitration Process')
        outputs = gr.Textbox(lines=10, label='LegalProdigy Response:')
        gr.Interface(fn=chatbot_interface, inputs=inputs, outputs=outputs)
    legalprodigy.launch()