| # """ | |
| # Retriever module for TrialGPT | |
| # """ | |
| # import json | |
| # import os | |
| # import numpy as np | |
| # import torch | |
| # from nltk import word_tokenize | |
| # from rank_bm25 import BM25Okapi | |
| # import faiss | |
| # from transformers import AutoTokenizer, AutoModel | |
| # def get_bm25_corpus_index(corpus="sigir"): | |
| # """Get BM25 corpus index for the specified corpus.""" | |
| # corpus_path = os.path.join(f"trialgpt_retrieval/bm25_corpus_{corpus}.json") | |
| # if os.path.exists(corpus_path): | |
| # corpus_data = json.load(open(corpus_path)) | |
| # tokenized_corpus = corpus_data["tokenized_corpus"] | |
| # corpus_nctids = corpus_data["corpus_nctids"] | |
| # else: | |
| # # If the pre-built index doesn't exist, we'll need to build it | |
| # # For now, return None to indicate the index needs to be built | |
| # return None, None | |
| # bm25 = BM25Okapi(tokenized_corpus) | |
| # return bm25, corpus_nctids | |
| # def get_medcpt_corpus_index(corpus="sigir"): | |
| # """Get MedCPT corpus index for the specified corpus.""" | |
| # corpus_path = f"trialgpt_retrieval/{corpus}_embeds.npy" | |
| # nctids_path = f"trialgpt_retrieval/{corpus}_nctids.json" | |
| # if os.path.exists(corpus_path): | |
| # embeds = np.load(corpus_path) | |
| # corpus_nctids = json.load(open(nctids_path)) | |
| # else: | |
| # # If the pre-built index doesn't exist, return None | |
| # return None, None | |
| # index = faiss.IndexFlatIP(768) | |
| # index.add(embeds) | |
| # return index, corpus_nctids | |
| # def retrieve_trials(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1): | |
| # """ | |
| # Retrieve clinical trials based on conditions using hybrid BM25 + MedCPT retrieval. | |
| # Args: | |
| # conditions (list): List of condition strings to search for | |
| # corpus (str): Corpus to search in ("trec_2021", "trec_2022", "sigir") | |
| # top_k (int): Number of top trials to return | |
| # bm25_weight (int): Weight for BM25 scores | |
| # medcpt_weight (int): Weight for MedCPT scores | |
| # Returns: | |
| # list: List of NCT IDs for the top matching trials | |
| # """ | |
| # # Get the retrieval indices | |
| # bm25, bm25_nctids = get_bm25_corpus_index(corpus) | |
| # medcpt, medcpt_nctids = get_medcpt_corpus_index(corpus) | |
| # if bm25 is None or medcpt is None: | |
| # print(f"Warning: Pre-built indices for corpus '{corpus}' not found.") | |
| # print("Please run the hybrid_fusion_retrieval.py script first to build the indices.") | |
| # return [] | |
| # if len(conditions) == 0: | |
| # return [] | |
| # # BM25 retrieval | |
| # bm25_condition_top_nctids = [] | |
| # if bm25_weight > 0: | |
| # for condition in conditions: | |
| # tokens = word_tokenize(condition.lower()) | |
| # top_nctids = bm25.get_top_n(tokens, bm25_nctids, n=top_k*2) # Get more for fusion | |
| # bm25_condition_top_nctids.append(top_nctids) | |
| # # MedCPT retrieval | |
| # medcpt_condition_top_nctids = [] | |
| # if medcpt_weight > 0: | |
| # try: | |
| # model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder") | |
| # tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder") | |
| # with torch.no_grad(): | |
| # encoded = tokenizer( | |
| # conditions, | |
| # truncation=True, | |
| # padding=True, | |
| # return_tensors='pt', | |
| # max_length=256, | |
| # ) | |
| # # encode the queries | |
| # embeds = model(**encoded).last_hidden_state[:, 0, :].numpy() | |
| # # search the Faiss index | |
| # scores, inds = medcpt.search(embeds, k=top_k*2) # Get more for fusion | |
| # for ind_list in inds: | |
| # top_nctids = [medcpt_nctids[ind] for ind in ind_list] | |
| # medcpt_condition_top_nctids.append(top_nctids) | |
| # except Exception as e: | |
| # print(f"Warning: MedCPT retrieval failed: {e}") | |
| # medcpt_weight = 0 | |
| # # Fusion of results | |
| # nctid2score = {} | |
| # for condition_idx, condition in enumerate(conditions): | |
| # # BM25 scoring | |
| # if bm25_weight > 0 and condition_idx < len(bm25_condition_top_nctids): | |
| # for rank, nctid in enumerate(bm25_condition_top_nctids[condition_idx]): | |
| # if nctid not in nctid2score: | |
| # nctid2score[nctid] = 0 | |
| # nctid2score[nctid] += bm25_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1)) | |
| # # MedCPT scoring | |
| # if medcpt_weight > 0 and condition_idx < len(medcpt_condition_top_nctids): | |
| # for rank, nctid in enumerate(medcpt_condition_top_nctids[condition_idx]): | |
| # if nctid not in nctid2score: | |
| # nctid2score[nctid] = 0 | |
| # nctid2score[nctid] += medcpt_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1)) | |
| # # Sort by score and return top_k | |
| # nctid2score = sorted(nctid2score.items(), key=lambda x: -x[1]) | |
| # top_nctids = [nctid for nctid, _ in nctid2score[:top_k]] | |
| # return top_nctids | |
| """ | |
"""
GPU-accelerated retriever module for TrialGPT.
"""
import json
import os

import faiss
import numpy as np
import torch
from nltk import word_tokenize
from rank_bm25 import BM25Okapi
from transformers import AutoModel, AutoTokenizer


class GPUTrialRetriever:
    """GPU-accelerated trial retriever that caches the MedCPT query encoder."""

    def __init__(self, device=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.tokenizer = None
        self._model_loaded = False

    def _load_medcpt_model(self):
        """Load the MedCPT query encoder once and cache it for later calls."""
        if not self._model_loaded:
            self.tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
            self.model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder")
            # Move the model to the GPU if one is available
            if self.device != "cpu":
                self.model = self.model.to(self.device)
            # Evaluation mode: disables dropout for deterministic inference
            self.model.eval()
            self._model_loaded = True
            print(f"MedCPT model loaded on {self.device}")


def get_bm25_corpus_index(corpus="sigir"):
    """Load the pre-built BM25 index for the specified corpus."""
    corpus_path = f"trialgpt_retrieval/bm25_corpus_{corpus}.json"
    if not os.path.exists(corpus_path):
        # The pre-built index does not exist; it must be built first
        return None, None
    with open(corpus_path) as f:
        corpus_data = json.load(f)
    bm25 = BM25Okapi(corpus_data["tokenized_corpus"])
    return bm25, corpus_data["corpus_nctids"]
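
# For reference, the JSON layout the loader above expects, inferred from the
# keys it reads (the file itself is produced by hybrid_fusion_retrieval.py):
#
#   {
#     "tokenized_corpus": [["token", ...], ...],   # one token list per trial
#     "corpus_nctids":    ["NCT01234567", ...]     # parallel list of NCT IDs
#   }
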
def get_medcpt_corpus_index_gpu(corpus="sigir", use_gpu=True):
    """Load the pre-built MedCPT embeddings into a FAISS index, on GPU when possible."""
    corpus_path = f"trialgpt_retrieval/{corpus}_embeds.npy"
    nctids_path = f"trialgpt_retrieval/{corpus}_nctids.json"
    if not os.path.exists(corpus_path):
        # The pre-built embeddings do not exist; they must be built first
        return None, None
    # FAISS requires float32; the matrix is expected to be (n_trials, 768)
    embeds = np.load(corpus_path).astype(np.float32)
    with open(nctids_path) as f:
        corpus_nctids = json.load(f)
    # Use a GPU FAISS index if available and requested
    if use_gpu and torch.cuda.is_available():
        try:
            # Build the inner-product index on CPU, then move it to GPU 0
            res = faiss.StandardGpuResources()
            cpu_index = faiss.IndexFlatIP(768)
            cpu_index.add(embeds)
            gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
            print(f"FAISS index moved to GPU with {len(corpus_nctids)} embeddings")
            return gpu_index, corpus_nctids
        except Exception as e:
            print(f"GPU FAISS failed, falling back to CPU: {e}")
    # CPU fallback: exact inner-product search over the full corpus
    index = faiss.IndexFlatIP(768)
    index.add(embeds)
    return index, corpus_nctids
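
# A hypothetical builder for the index files loaded above: a minimal sketch,
# not the actual code of hybrid_fusion_retrieval.py. It assumes the corpus side
# is encoded with "ncbi/MedCPT-Article-Encoder" (the companion of the query
# encoder used below) and that the caller supplies nctid2text, a dict mapping
# NCT IDs to trial text.
def build_medcpt_corpus_embeddings(nctid2text, corpus="sigir", batch_size=32):
    """Sketch: encode trial texts with MedCPT and save the embedding index files."""
    tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")
    model = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder").eval()
    nctids = list(nctid2text.keys())
    all_embeds = []
    with torch.no_grad():
        for i in range(0, len(nctids), batch_size):
            encoded = tokenizer(
                [nctid2text[nctid] for nctid in nctids[i:i + batch_size]],
                truncation=True,
                padding=True,
                return_tensors="pt",
                max_length=512,
            )
            # [CLS] embeddings, matching how queries are encoded in retrieve_trials_gpu
            all_embeds.append(model(**encoded).last_hidden_state[:, 0, :])
    embeds = torch.cat(all_embeds, dim=0).numpy().astype(np.float32)
    np.save(f"trialgpt_retrieval/{corpus}_embeds.npy", embeds)
    with open(f"trialgpt_retrieval/{corpus}_nctids.json", "w") as f:
        json.dump(nctids, f)
    return embeds, nctids
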
def retrieve_trials_gpu(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1,
                        use_gpu=True, batch_size=32, retriever=None):
    """
    GPU-accelerated clinical trial retrieval with batched query encoding.

    Args:
        conditions (list): List of condition strings to search for
        corpus (str): Corpus to search in ("trec_2021", "trec_2022", "sigir")
        top_k (int): Number of top trials to return
        bm25_weight (float): Weight for BM25 scores
        medcpt_weight (float): Weight for MedCPT scores
        use_gpu (bool): Whether to use GPU acceleration
        batch_size (int): Batch size for MedCPT query encoding
        retriever (GPUTrialRetriever): Cached retriever instance, reused across calls

    Returns:
        list: NCT IDs of the top matching trials
    """
    if len(conditions) == 0:
        return []
    # Load the retrieval indices
    bm25, bm25_nctids = get_bm25_corpus_index(corpus)
    medcpt, medcpt_nctids = get_medcpt_corpus_index_gpu(corpus, use_gpu)
    if bm25 is None or medcpt is None:
        print(f"Warning: Pre-built indices for corpus '{corpus}' not found.")
        print("Please run the hybrid_fusion_retrieval.py script first to build the indices.")
        return []
    # Initialize a retriever if the caller did not supply a cached one
    if retriever is None:
        device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        retriever = GPUTrialRetriever(device)
    # BM25 retrieval (CPU-bound; rank_bm25 has no GPU path)
    bm25_condition_top_nctids = []
    if bm25_weight > 0:
        for condition in conditions:
            tokens = word_tokenize(condition.lower())
            # Retrieve 2 * top_k candidates per condition so fusion has headroom
            top_nctids = bm25.get_top_n(tokens, bm25_nctids, n=top_k * 2)
            bm25_condition_top_nctids.append(top_nctids)
    # MedCPT retrieval with GPU acceleration
    medcpt_condition_top_nctids = []
    if medcpt_weight > 0:
        try:
            retriever._load_medcpt_model()
            # Encode the conditions in batches for better GPU utilization
            all_embeds = []
            for i in range(0, len(conditions), batch_size):
                batch_conditions = conditions[i:i + batch_size]
                with torch.no_grad():
                    encoded = retriever.tokenizer(
                        batch_conditions,
                        truncation=True,
                        padding=True,
                        return_tensors="pt",
                        max_length=256,
                    )
                    # Move the input tensors to the same device as the model
                    if retriever.device != "cpu":
                        encoded = {k: v.to(retriever.device) for k, v in encoded.items()}
                    # Use the [CLS] token embedding as the query vector
                    batch_embeds = retriever.model(**encoded).last_hidden_state[:, 0, :]
                    # Move back to CPU: FAISS search needs host-side numpy arrays
                    all_embeds.append(batch_embeds.cpu())
            # Concatenate all batches and convert to a numpy matrix for FAISS
            embeds = torch.cat(all_embeds, dim=0).numpy()
            # Search the FAISS index, again retrieving 2 * top_k candidates for fusion
            scores, inds = medcpt.search(embeds, k=top_k * 2)
            # Map FAISS row indices back to NCT IDs
            for ind_list in inds:
                medcpt_condition_top_nctids.append([medcpt_nctids[ind] for ind in ind_list])
        except Exception as e:
            print(f"Warning: MedCPT retrieval failed: {e}")
            medcpt_weight = 0
    # Reciprocal-rank fusion: each candidate earns weight * 1/(rank + 1) * 1/(condition_idx + 1),
    # so earlier ranks and earlier-listed conditions count more. For example, the top BM25 hit
    # for the first condition scores bm25_weight * 1 * 1, while the third hit for the second
    # condition scores bm25_weight * (1/3) * (1/2).
    nctid2score = {}
    for condition_idx in range(len(conditions)):
        # BM25 scoring
        if bm25_weight > 0 and condition_idx < len(bm25_condition_top_nctids):
            for rank, nctid in enumerate(bm25_condition_top_nctids[condition_idx]):
                nctid2score[nctid] = nctid2score.get(nctid, 0) + \
                    bm25_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))
        # MedCPT scoring
        if medcpt_weight > 0 and condition_idx < len(medcpt_condition_top_nctids):
            for rank, nctid in enumerate(medcpt_condition_top_nctids[condition_idx]):
                nctid2score[nctid] = nctid2score.get(nctid, 0) + \
                    medcpt_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))
    # Sort by descending fused score and return the top_k NCT IDs
    ranked = sorted(nctid2score.items(), key=lambda x: -x[1])
    return [nctid for nctid, _ in ranked[:top_k]]


# Convenience wrapper for backward compatibility with the original interface
def retrieve_trials(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1):
    """Original retrieval interface, now GPU-accelerated."""
    return retrieve_trials_gpu(conditions, corpus, top_k, bm25_weight, medcpt_weight)


def create_retriever_session(use_gpu=True):
    """Create a persistent retriever for multiple queries, avoiding repeated model loading."""
    device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
    return GPUTrialRetriever(device)
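
# A minimal usage sketch (not part of the original module). It assumes the
# pre-built "sigir" index files exist under trialgpt_retrieval/ and that
# NLTK's "punkt" tokenizer data has been downloaded; the condition lists are
# hypothetical examples.
if __name__ == "__main__":
    session = create_retriever_session(use_gpu=True)
    for patient_conditions in (
        ["type 2 diabetes mellitus", "chronic kidney disease"],
        ["non-small cell lung cancer"],
    ):
        # Reusing the session means the MedCPT encoder is loaded only once
        nctids = retrieve_trials_gpu(
            patient_conditions,
            corpus="sigir",
            top_k=5,
            retriever=session,
        )
        print(patient_conditions, "->", nctids)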