File size: 9,457 Bytes
c2958f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01ae44e
 
c2958f9
01ae44e
 
c2958f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01ae44e
c2958f9
 
 
 
 
01ae44e
c2958f9
 
 
 
 
 
 
 
01ae44e
c2958f9
 
 
 
 
 
01ae44e
c2958f9
 
01ae44e
c2958f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01ae44e
c2958f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
# Shared literals bound to short aliases; the rest of this module refers to
# these names instead of repeating the values.
_L='faiss.index'  # file name of the serialized FAISS index inside a store directory
_K='meta.json'  # file name of the chunk-metadata JSON inside a store directory
_J='source'  # metadata key: name of the file a chunk was taken from
_I='name'  # dict key / attribute name used to look up a file's name
_H='\x00'  # NUL character, removed from extracted text
_G=False
_F='\n'
_E='text'  # metadata key for chunk text; also the mode passed to page.get_text()
_D='mmr'  # name of the MMR selection strategy
_C='utf-8'
_B=None
_A=True
import os,io,json,pathlib,shutil
from typing import List,Tuple,Dict
import gradio as gr,numpy as np,faiss
from sentence_transformers import SentenceTransformer
from pypdf import PdfReader
import fitz
from collections import defaultdict
from openai import OpenAI
# --- Generation backend ------------------------------------------------------
# The key comes from the environment; importing the module fails without it.
API_KEY = os.environ.get("API_KEY")
if not API_KEY:
    raise RuntimeError(
        "Missing API_KEY (set it in Hugging Face: Settings → Variables and secrets)."
    )

# OpenAI-compatible client pointed at the OpenRouter endpoint.
client = OpenAI(api_key=API_KEY, base_url="https://openrouter.ai/api/v1")

# Chat model and the sampling parameters sent with every completion request.
SINGLE_MODEL_NAME = "deepseek/deepseek-r1:free"
GEN_TEMPERATURE = 0.2
GEN_TOP_P = 0.95
GEN_MAX_TOKENS = 1024

# sentence-transformers model used to embed both passages and queries.
EMB_MODEL_NAME = "intfloat/multilingual-e5-base"
def choose_store_dir():
    """Pick the directory that will hold the index files.

    Returns:
        A ``(path, is_persistent)`` tuple. ``/data/rag_store`` with ``True``
        is returned when ``/data`` exists, is reported writable, and a probe
        file can be written and removed inside it. Otherwise ``<cwd>/store``
        (created if missing) is returned with ``False``.
    """
    persistent_root = "/data"
    if os.path.isdir(persistent_root) and os.access(persistent_root, os.W_OK):
        candidate = os.path.join(persistent_root, "rag_store")
        try:
            os.makedirs(candidate, exist_ok=True)
            probe = os.path.join(candidate, ".write_test")
            with open(probe, "w", encoding="utf-8") as handle:
                handle.write("ok")
            os.remove(probe)
        except Exception:
            # Any failure while probing means the volume is not usable;
            # fall through to the working-directory store.
            pass
        else:
            return candidate, True

    fallback = os.path.join(os.getcwd(), "store")
    os.makedirs(fallback, exist_ok=True)
    return fallback, False
# Resolve the store location once, at import time.
STORE_DIR, IS_PERSISTENT = choose_store_dir()

# Files that make up a persisted index inside STORE_DIR.
META_PATH = os.path.join(STORE_DIR, "meta.json")
INDEX_PATH = os.path.join(STORE_DIR, "faiss.index")

# Working-directory store, used as the source for the legacy migration.
LEGACY_STORE_DIR = os.path.join(os.getcwd(), "store")
def migrate_legacy_if_any():
    """Copy a legacy ``<cwd>/store`` index into the persistent store.

    Runs only when the active store is persistent, at least one of its two
    files is missing, and the legacy directory contains both files. The
    whole operation is best effort: any exception is swallowed.
    """
    try:
        if not IS_PERSISTENT:
            return

        legacy_meta = os.path.join(LEGACY_STORE_DIR, "meta.json")
        legacy_index = os.path.join(LEGACY_STORE_DIR, "faiss.index")

        if os.path.exists(META_PATH) and os.path.exists(INDEX_PATH):
            # The persistent store is already complete.
            return

        legacy_complete = (
            os.path.isdir(LEGACY_STORE_DIR)
            and os.path.exists(legacy_meta)
            and os.path.exists(legacy_index)
        )
        if legacy_complete:
            shutil.copyfile(legacy_meta, META_PATH)
            shutil.copyfile(legacy_index, INDEX_PATH)
    except Exception:
        pass
# Import-time side effect: seed the persistent store from a legacy one.
migrate_legacy_if_any()

# Process-wide state, filled in lazily.
_emb_model = None  # embedding model instance, created on first use
_index = None  # FAISS index, created or loaded on demand
_meta = {}  # chunk id (str) -> {"source": ..., "text": ...}

# Retrieval defaults.
DEFAULT_TOP_K = 6
DEFAULT_POOL_K = 40
DEFAULT_PER_SOURCE_CAP = 2
DEFAULT_STRATEGY = "mmr"
DEFAULT_MMR_LAMBDA = 0.5
def get_emb_model():
    """Return the shared embedding model, constructing it on the first call."""
    global _emb_model
    if _emb_model is not None:
        return _emb_model
    _emb_model = SentenceTransformer(EMB_MODEL_NAME)
    return _emb_model
def _ensure_index(dim):
    """Create the global inner-product FAISS index if none exists yet.

    Args:
        dim: Vector dimensionality for a newly created index. Ignored when
            an index is already present.
    """
    global _index
    if _index is not None:
        return
    _index = faiss.IndexFlatIP(dim)
def _persist():
    """Write the in-memory index to INDEX_PATH and the metadata to META_PATH."""
    faiss.write_index(_index, INDEX_PATH)
    with open(META_PATH, "w", encoding="utf-8") as meta_file:
        # ensure_ascii=False keeps non-ASCII chunk text readable in the file.
        json.dump(_meta, meta_file, ensure_ascii=False)
def _load_if_any():
    """Load the persisted index and metadata when both files are present.

    Leaves the global state untouched if either file is missing.
    """
    global _index, _meta
    have_both = os.path.exists(INDEX_PATH) and os.path.exists(META_PATH)
    if not have_both:
        return
    _index = faiss.read_index(INDEX_PATH)
    with open(META_PATH, "r", encoding="utf-8") as meta_file:
        _meta = json.load(meta_file)
def _chunk_text(text, chunk_size=800, overlap=120):
    """Split ``text`` into overlapping, whitespace-stripped character windows.

    NUL characters are removed first. Windows that are empty after stripping
    are dropped.

    Args:
        text: The string to split.
        chunk_size: Maximum number of characters per window; must be > 0.
        overlap: Number of characters by which the next window starts before
            the end of the previous one; must be smaller than ``chunk_size``.

    Returns:
        List of chunk strings in document order (empty for empty/blank text).

    Raises:
        ValueError: If ``chunk_size <= 0`` or ``overlap >= chunk_size``. With
            such values the window start would never advance, so they are
            rejected up front instead of looping forever.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap >= chunk_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    cleaned = text.replace("\x00", "")
    total = len(cleaned)
    chunks = []
    start = 0
    while start < total:
        end = min(start + chunk_size, total)
        piece = cleaned[start:end].strip()
        if piece:
            chunks.append(piece)
        if end >= total:
            break
        # overlap < chunk_size guarantees this moves strictly forward.
        start = end - overlap
    return chunks
def _read_bytes(file):
    """Return the raw bytes behind the supported file representations.

    Accepted inputs, checked in this order:
      * ``dict`` with a ``'path'`` or ``'name'`` entry naming an existing
        file, or a ``'data'`` entry holding ``bytes``/``bytearray``;
      * ``str`` or ``pathlib.Path`` naming a file on disk;
      * any object with a ``read()`` method (rewound first when it supports
        ``seek``; it is closed after reading).

    Args:
        file: One of the representations above.

    Returns:
        The file content (whatever ``read()`` returns for file-like objects).

    Raises:
        ValueError: If ``file`` matches none of the supported shapes.
        OSError: If a path cannot be opened or read.
    """
    if isinstance(file, dict):
        path = file.get("path") or file.get("name")
        if path and os.path.exists(path):
            with open(path, "rb") as handle:
                return handle.read()
        # The 'data' lookup only makes sense for dict payloads. Applying
        # `'data' in file` to a str is a substring test (and indexing the
        # str then fails), and applying it to a file object iterates it.
        payload = file.get("data")
        if isinstance(payload, (bytes, bytearray)):
            return bytes(payload)

    if isinstance(file, (str, pathlib.Path)):
        with open(file, "rb") as handle:
            return handle.read()

    if hasattr(file, "read"):
        try:
            if hasattr(file, "seek"):
                try:
                    file.seek(0)
                except Exception:
                    # Non-seekable stream: read from the current position.
                    pass
            return file.read()
        finally:
            try:
                file.close()
            except Exception:
                pass

    raise ValueError('Unsupported file type from gr.File')
def _decode_best_effort(raw):
    """Decode ``raw`` bytes using the first codec in a fixed list that works.

    Codecs are tried in order: UTF-8, several Japanese/Chinese legacy
    encodings, then latin-1.

    Args:
        raw: Bytes-like object exposing ``decode``.

    Returns:
        The decoded string.
    """
    encodings = ("utf-8", "cp932", "shift_jis", "cp950", "big5", "gb18030", "latin-1")
    for encoding in encodings:
        try:
            return raw.decode(encoding)
        except (UnicodeDecodeError, LookupError):
            # Wrong encoding for this data, or codec missing in this build:
            # try the next one. Other errors (e.g. `raw` is not bytes) are
            # programming errors and are allowed to propagate.
            continue
    # latin-1 maps every byte, so this is reached only if codecs were
    # unavailable; kept as a last-resort lossy decode.
    return raw.decode("utf-8", errors="ignore")
def _read_pdf(file_bytes):
    """Extract text from PDF bytes, trying PyMuPDF first and pypdf second.

    Args:
        file_bytes: Raw content of the PDF.

    Returns:
        Page texts joined with newlines. The PyMuPDF result is used only when
        it contains non-whitespace text; otherwise the pypdf result is
        returned. An empty string is returned when pypdf fails as well.
    """
    # Attempt 1: PyMuPDF. Any failure here simply hands over to pypdf.
    try:
        with fitz.open(stream=file_bytes, filetype="pdf") as doc:
            if doc.is_encrypted:
                try:
                    # Try the empty password; a failure is ignored.
                    doc.authenticate("")
                except Exception:
                    pass
            page_texts = []
            for page in doc:
                page_texts.append(page.get_text("text") or "")
            combined = "\n".join(page_texts)
            if combined.strip():
                return combined
    except Exception:
        pass

    # Attempt 2: pypdf, tolerating failures on individual pages.
    try:
        reader = PdfReader(io.BytesIO(file_bytes))
        extracted = []
        for page in reader.pages:
            try:
                page_text = page.extract_text() or ""
            except Exception:
                page_text = ""
            extracted.append(page_text)
        return "\n".join(extracted)
    except Exception:
        return ""
def _read_any(file):
    """Read a PDF or text file and return its text with NUL characters removed.

    The file's lower-cased name decides the route: names ending in ``.pdf``
    go through ``_read_pdf``; everything else is decoded as text.

    Args:
        file: Any representation accepted by ``_read_bytes``.

    Returns:
        The extracted text.
    """
    if isinstance(file, dict):
        label = (
            file.get("orig_name")
            or file.get("name")
            or file.get("path")
            or "upload"
        )
    else:
        label = getattr(file, "name", None)
        if not label:
            is_path = isinstance(file, (str, pathlib.Path))
            label = str(file) if is_path else "upload"
    label = label.lower()

    payload = _read_bytes(file)
    if label.endswith(".pdf"):
        text = _read_pdf(payload)
    else:
        text = _decode_best_effort(payload)
    return text.replace("\x00", "")
# Folder under the current working directory that is scanned for documents.
DOCS_DIR=os.path.join(os.getcwd(),'docs')
def get_docs_files():
    """List the ``.pdf`` and ``.txt`` files directly inside ``DOCS_DIR``.

    Returns:
        Paths (``DOCS_DIR`` joined with each entry name) in ``os.listdir``
        order; an empty list when ``DOCS_DIR`` is not a directory. The
        extension match is case-insensitive.
    """
    if not os.path.isdir(DOCS_DIR):
        return []
    wanted = (".pdf", ".txt")
    return [
        os.path.join(DOCS_DIR, entry)
        for entry in os.listdir(DOCS_DIR)
        if entry.lower().endswith(wanted)
    ]
def build_corpus_from_docs():
    """Rebuild the global index and metadata from the files in the docs folder.

    Every file is read, chunked, embedded with a ``"passage: "`` prefix and
    added to a fresh index; metadata maps the chunk position (as a string)
    to its source file name and text. The result is persisted to disk.

    Returns:
        A human-readable status string: a summary of indexed chunks (with the
        names of files that failed or yielded no text), or a message saying
        that no files / no text were found.
    """
    global _index, _meta

    doc_paths = get_docs_files()
    if not doc_paths:
        return "No files found in docs folder."

    model = get_emb_model()

    # Discard any previous in-memory state before collecting new chunks.
    _index = None
    _meta = {}

    chunk_texts = []
    chunk_sources = []
    failed = []
    for path in doc_paths:
        file_name = os.path.basename(path)
        try:
            pieces = _chunk_text(_read_any(path) or "")
            if pieces:
                chunk_texts.extend(pieces)
                chunk_sources.extend([file_name] * len(pieces))
            else:
                # Readable but empty files are reported alongside failures.
                failed.append(file_name)
        except Exception:
            failed.append(file_name)

    if not chunk_texts:
        return "No text extracted from docs."

    passages = [f"passage: {chunk}" for chunk in chunk_texts]
    vectors = model.encode(
        passages,
        batch_size=64,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    _ensure_index(vectors.shape[1])
    _index.add(vectors)

    for position, source_name in enumerate(chunk_sources):
        _meta[str(position)] = {
            "source": source_name,
            "text": chunk_texts[position],
        }
    _persist()

    summary = f"Indexed {len(chunk_texts)} chunks from {len(doc_paths)} files."
    if failed:
        summary += f" Failed files: {', '.join(failed)}"
    return summary
def _encode_query_vec(query):
    """Embed ``query`` (prefixed with ``"query: "``) as a normalized vector.

    Returns:
        The encoder output for a one-element batch, as a NumPy array.
    """
    model = get_emb_model()
    prefixed = f"query: {query}"
    return model.encode(
        [prefixed],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
def retrieve_candidates(qvec, pool_k=40):
    """Search the global index for the nearest chunks to ``qvec``.

    Args:
        qvec: Query embedding batch; only the first row of results is used.
        pool_k: Maximum number of candidates, capped at the index size.

    Returns:
        List of ``(chunk_id, score)`` pairs in search order, with the id as a
        string and the score as a float. Empty when there is no index or the
        index holds no vectors.
    """
    if _index is None or _index.ntotal == 0:
        return []

    k = min(pool_k, _index.ntotal)
    scores, ids = _index.search(qvec, k)

    hits = []
    for idx, score in zip(ids[0], scores[0]):
        if idx == -1:
            # Result slots with id -1 are skipped.
            continue
        hits.append((str(idx), float(score)))
    return hits
def select_diverse_by_source(cands, top_k=6, per_source_cap=2):
    """Pick up to ``top_k`` candidates while spreading them across sources.

    Candidates are grouped by the ``source`` of their metadata (ids without
    metadata are left out of the groups), each group is trimmed to
    ``per_source_cap`` entries, and groups are then visited round-robin in
    first-seen order. If that yields fewer than ``top_k`` items, the
    remaining slots are filled from ``cands`` in their original order,
    skipping ids already chosen.

    Args:
        cands: ``(chunk_id, score)`` pairs, best first.
        top_k: Maximum number of results.
        per_source_cap: Maximum entries taken per source in the round-robin.

    Returns:
        List of ``(chunk_id, score)`` pairs, at most ``top_k`` long.
    """
    if not cands:
        return []

    # Group by source, keeping candidate order inside each group.
    by_source = defaultdict(list)
    for cid, score in cands:
        record = _meta.get(cid)
        if record:
            by_source[record["source"]].append((cid, score))
    buckets = [hits[:per_source_cap] for hits in by_source.values()]

    # Round-robin: pass number `depth` takes the depth-th entry of every
    # bucket that still has one.
    picked = []
    depth = 0
    while len(picked) < top_k:
        took_any = False
        for bucket in buckets:
            if depth < len(bucket):
                picked.append(bucket[depth])
                took_any = True
            if len(picked) >= top_k:
                break
        if not took_any:
            break
        depth += 1

    # Backfill from the raw candidate list when the buckets ran dry.
    if len(picked) < top_k:
        chosen_ids = {cid for cid, _ in picked}
        for cid, score in cands:
            if cid in chosen_ids:
                continue
            picked.append((cid, score))
            chosen_ids.add(cid)
            if len(picked) >= top_k:
                break

    return picked[:top_k]
def _encode_chunks_text(cids):
    """Embed the stored text of the given chunk ids as normalized vectors.

    Ids without metadata (or without a ``text`` entry) are embedded as an
    empty passage, so the output has one row per input id.
    """
    passages = []
    for cid in cids:
        record = _meta.get(cid) or {}
        passages.append(f"passage: {record.get('text', '')}")
    return get_emb_model().encode(
        passages,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
def select_diverse_mmr(cands, qvec, top_k=6, mmr_lambda=0.5, doc_vecs=None):
    """Select up to ``top_k`` candidates by Maximal Marginal Relevance.

    The first pick is the candidate most similar to the query. Every later
    pick maximizes
    ``mmr_lambda * sim(query, c) - (1 - mmr_lambda) * max sim(c, picked)``.

    Args:
        cands: ``(chunk_id, score)`` pairs; only the ids are used.
        qvec: Query embedding, a single row (or a 1-D vector).
        top_k: Maximum number of results.
        mmr_lambda: Trade-off weight; 1.0 ranks purely by query similarity.
        doc_vecs: Optional precomputed embeddings, one row per candidate in
            ``cands`` order. When omitted, the chunk texts are embedded via
            ``_encode_chunks_text``.

    Returns:
        List of ``(chunk_id, query_similarity)`` pairs in selection order.
    """
    if not cands:
        return []

    cids = [cid for cid, _ in cands]
    if doc_vecs is None:
        doc_vecs = _encode_chunks_text(cids)
    doc_vecs = np.asarray(doc_vecs)

    # Similarity of every candidate to the query, flattened to one value
    # per candidate.
    relevance = (doc_vecs @ np.asarray(qvec).T).reshape(-1)

    limit = min(top_k, len(cids))
    selected = []
    # An ordered list (rather than a set) keeps tie-breaking deterministic.
    remaining = list(range(len(cids)))
    while len(selected) < limit:
        if not selected:
            best = int(np.argmax(relevance))
        else:
            # Penalize each remaining candidate by its closest similarity
            # to anything already selected.
            sim_to_selected = doc_vecs[remaining] @ doc_vecs[selected].T
            redundancy = sim_to_selected.max(axis=1)
            mmr_scores = (
                mmr_lambda * relevance[remaining]
                - (1.0 - mmr_lambda) * redundancy
            )
            best = remaining[int(np.argmax(mmr_scores))]
        selected.append(best)
        remaining.remove(best)

    return [(cids[i], float(relevance[i])) for i in selected]
def retrieve_diverse(query, top_k=6, pool_k=40, per_source_cap=2, strategy='mmr', mmr_lambda=.5):
    """Retrieve a candidate pool for ``query`` and pick a diverse subset.

    Args:
        query: The user question.
        top_k: Number of chunks to return.
        pool_k: Size of the candidate pool fetched from the index.
        per_source_cap: Per-source limit, used by the by-source strategy.
        strategy: ``'mmr'`` selects by MMR; any other value selects by source.
        mmr_lambda: Trade-off weight passed to the MMR selection.

    Returns:
        List of ``(chunk_id, score)`` pairs.
    """
    query_vec = _encode_query_vec(query)
    pool = retrieve_candidates(query_vec, pool_k=pool_k)
    if strategy != "mmr":
        return select_diverse_by_source(
            pool, top_k=top_k, per_source_cap=per_source_cap
        )
    return select_diverse_mmr(pool, query_vec, top_k=top_k, mmr_lambda=mmr_lambda)
def _format_ctx(hits):
    """Render retrieved chunks as prompt context, one line per chunk.

    Each line reads ``[chunk_id] (source) text`` with newlines inside the
    chunk text replaced by spaces. Hits without metadata are skipped and at
    most ten lines are emitted.

    Args:
        hits: ``(chunk_id, score)`` pairs; the score is not used.

    Returns:
        The context block, or an empty string when ``hits`` is empty.
    """
    if not hits:
        return ""

    lines = []
    for cid, _score in hits:
        record = _meta.get(cid)
        if not record:
            continue
        source = record.get("source", "")
        body = (record.get("text", "") or "").replace("\n", " ")
        lines.append(f"[{cid}] ({source}) " + body)
    return "\n".join(lines[:10])
def chat_fn(message, history):
    """Answer ``message`` with retrieval-augmented generation, streaming text.

    If no index is loaded, the corpus is built from the docs folder first.
    The retrieved chunks are formatted and appended to the system prompt so
    the model can ground and cite its answer, then the chat completion is
    streamed.

    Args:
        message: The user's question.
        history: Previous turns, either ``(user, assistant)`` pairs or
            ``{"role": ..., "content": ...}`` dicts; ``None`` means none.

    Yields:
        The answer accumulated so far (each value supersedes the previous
        one). If the index cannot be built, a single status message is
        yielded. If the request fails, an ``"[Exception] ..."`` string is
        yielded.
    """
    if _index is None or _index.ntotal == 0:
        status = build_corpus_from_docs()
        if _index is None or _index.ntotal == 0:
            yield f"**Index Status:** {status}\n\nPlease ensure you have a 'docs' folder with PDF/TXT files and try again."
            return

    hits = retrieve_diverse(
        message,
        top_k=DEFAULT_TOP_K,
        pool_k=DEFAULT_POOL_K,
        per_source_cap=DEFAULT_PER_SOURCE_CAP,
        strategy=DEFAULT_STRATEGY,
        mmr_lambda=DEFAULT_MMR_LAMBDA,
    )
    context = _format_ctx(hits) or '(Current index is empty or no matching chunks found)'

    system_parts = [
        'You are a research assistant who has an excellent factual understanding of the legal policies, '
        'regulations, and compliance of enterprises, governments, and global organizations. '
        'You are a research assistant who reads Legal papers and provides factual answers to queries. '
        'If you do not know the answer, you should convey that to the user instead of hallucinating. '
        'Answers must be based on retrieved content with evidence and source numbers cited. '
        'If retrieval is insufficient, please clearly explain the shortcomings. '
        'When answering, please cite the numbers, e.g., [3]',
        # The retrieved chunks must be part of the prompt; without them the
        # model has nothing to base its citations on.
        'Retrieved context (each line starts with its [number] and (source)):\n' + context,
    ]
    messages = [{'role': 'system', 'content': '\n\n'.join(system_parts)}]

    for turn in history or []:
        if isinstance(turn, dict):
            # Already in role/content form.
            messages.append({'role': turn.get('role'), 'content': turn.get('content')})
            continue
        user_text, assistant_text = turn
        messages.append({'role': 'user', 'content': user_text})
        messages.append({'role': 'assistant', 'content': assistant_text})
    messages.append({'role': 'user', 'content': message})

    try:
        stream = client.chat.completions.create(
            model=SINGLE_MODEL_NAME,
            messages=messages,
            temperature=GEN_TEMPERATURE,
            top_p=GEN_TOP_P,
            max_tokens=GEN_MAX_TOKENS,
            stream=True,
        )
        answer = ''
        for chunk in stream:
            if not getattr(chunk, 'choices', None):
                # Chunks without choices carry no text; indexing them would
                # raise and replace the answer with an error string.
                continue
            choice = chunk.choices[0]
            piece = None
            if hasattr(choice, 'delta'):
                piece = choice.delta.content
            if piece is None and hasattr(choice, 'message'):
                piece = choice.message.content
            if piece is not None:
                answer += piece
                yield answer
    except Exception as exc:
        yield f"[Exception] {repr(exc)}"
# Restore a previously persisted index at import time; a failure is only
# reported, not raised.
try:
    _load_if_any()
except Exception as exc:
    print(f"Notice: Could not load existing index. A new one will be created. Error: {exc}")
if __name__ == "__main__":

    def chatbot_interface(user_message):
        """Run one query without history and return the final answer text.

        ``chat_fn`` yields progressively longer answers; this adapter drains
        the generator and returns the last value (an empty string if nothing
        was yielded).
        """
        final_text = ""
        for partial in chat_fn(user_message, []):
            final_text = partial
        return final_text

    with gr.Blocks(theme=gr.themes.Default(primary_hue="sky")) as legalprodigy:
        inputs = gr.Textbox(
            lines=7,
            label="LegalProdigy Query:",
            placeholder="Try: Explain Arbitration Process",
        )
        outputs = gr.Textbox(lines=10, label="LegalProdigy Response:")
        gr.Interface(fn=chatbot_interface, inputs=inputs, outputs=outputs)

    legalprodigy.launch()