import time
import xml.etree.ElementTree as ET
from datetime import datetime
from functools import lru_cache
from typing import Dict, Iterator, List, Tuple

import gradio as gr
import polars as pl
import requests
import spaces  # Hugging Face Spaces helper (ZeroGPU support)
from transformers import pipeline
label_lookup = {
    "LABEL_0": "NOT_CURATEABLE",
    "LABEL_1": "CURATEABLE"
}
@lru_cache(maxsize=1)
def get_pipeline():
    # Cache the pipeline so the model is downloaded and built only once
    print("fetching model and building pipeline")
    model_name = "afg1/pombe_curation_fold_0"
    pipe = pipeline(model=model_name, task="text-classification")
    return pipe
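
# Illustrative usage sketch (the abstract text and score below are made up):
#   pipe = get_pipeline()
#   pipe(["We characterised a novel S. pombe kinase..."])
#   # -> [{'label': 'LABEL_1', 'score': 0.97}]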
def classify_abstracts(abstracts: Dict[str, str], batch_size: int = 64, progress=gr.Progress()) -> List[Dict]:
    pipe = get_pipeline()
    results = []
    total = len(abstracts)

    # Convert dictionary to lists of PMIDs and abstracts, preserving order
    pmids = list(abstracts.keys())
    abstract_texts = list(abstracts.values())

    # Initialize progress bar
    progress(0, desc="Starting classification...")

    # Process in batches
    for i in range(0, total, batch_size):
        # Get current batch
        batch_abstracts = abstract_texts[i:i + batch_size]
        batch_pmids = pmids[i:i + batch_size]

        try:
            # Classify the batch
            classifications = pipe(batch_abstracts)

            # Process each result in the batch
            for pmid, classification in zip(batch_pmids, classifications):
                results.append({
                    'pmid': pmid,
                    'classification': label_lookup[classification['label']],
                    'score': classification['score']
                })

            # Update progress
            progress(min((i + batch_size) / total, 1.0),
                     desc=f"Classified {min(i + batch_size, total)}/{total} abstracts...")
        except Exception as e:
            print(f"Error classifying batch starting at index {i}: {str(e)}")
            continue

    progress(1.0, desc="Classification complete!")
    return results
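
# Example of the returned shape (PMID and score are hypothetical):
#   classify_abstracts({"12345678": "Some abstract text..."})
#   # -> [{'pmid': '12345678', 'classification': 'CURATEABLE', 'score': 0.93}]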
def fetch_latest_canto_dump() -> pl.DataFrame:
    """
    Read the latest PomBase canto dump directly from the URL
    """
    url = "https://curation.pombase.org/kmr44/canto_pombe_pubs.tsv"
    return pl.read_csv(url, separator='\t')
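
# The dump's "pmid" column appears to hold prefixed IDs (e.g. "PMID:12345"),
# which is why search_pubmed() splits on ":" and keeps the last element
# before joining against plain PMIDs.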
def filter_new_hits(canto_pmids: pl.DataFrame, new_pmids: List[str]) -> List[str]:
    """
    Convert the list of PMIDs from the search to a dataframe and do an
    anti-join against the canto dump to find papers not yet curated
    """
    search_pmids = pl.DataFrame({"pmid": new_pmids})
    uncurated = search_pmids.join(canto_pmids, on="pmid", how="anti")
    return uncurated.get_column("pmid").to_list()
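
# A minimal sketch of the anti-join semantics (toy IDs, not real PMIDs):
#   canto = pl.DataFrame({"pmid": ["1", "2"]})
#   filter_new_hits(canto, ["2", "3"])  # -> ["3"]
# Only PMIDs absent from the canto dump survive.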
def fetch_abstracts_batch(pmids: List[str], batch_size: int = 200) -> Dict[str, str]:
    """
    Fetch abstracts for a list of PMIDs in batches

    Args:
        pmids (List[str]): List of PMIDs to fetch abstracts for
        batch_size (int): Number of PMIDs to process per batch

    Returns:
        Dict[str, str]: Dictionary mapping PMIDs to their abstracts
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    all_abstracts = {}

    # Process PMIDs in batches
    for i in range(0, len(pmids), batch_size):
        batch_pmids = pmids[i:i + batch_size]
        pmids_string = ",".join(batch_pmids)
        print(f"Processing batch {i//batch_size + 1} of {(len(pmids) + batch_size - 1)//batch_size}")

        params = {
            "db": "pubmed",
            "id": pmids_string,
            "retmode": "xml",
            "rettype": "abstract"
        }

        try:
            response = requests.get(base_url, params=params)
            response.raise_for_status()

            # Parse XML response
            root = ET.fromstring(response.content)

            # Iterate through each article in the batch
            for article in root.findall(".//PubmedArticle"):
                # Get PMID
                pmid = article.find(".//PMID").text

                # Find abstract text
                abstract_element = article.find(".//Abstract/AbstractText")
                if abstract_element is not None:
                    # Handle structured abstracts
                    if 'Label' in abstract_element.attrib:
                        abstract_sections = article.findall(".//Abstract/AbstractText")
                        abstract_text = "\n".join(
                            f"{section.attrib.get('Label', 'Abstract')}: {section.text}"
                            for section in abstract_sections
                            if section.text is not None
                        )
                    else:
                        # Simple abstract; .text can be None for an empty element
                        abstract_text = abstract_element.text or ""
                else:
                    abstract_text = ""

                if abstract_text:
                    all_abstracts[pmid] = abstract_text

            # Respect NCBI's rate limits
            time.sleep(0.34)

        except requests.exceptions.RequestException as e:
            print(f"Error accessing PubMed API for batch {i//batch_size + 1}: {str(e)}")
            continue
        except ET.ParseError as e:
            print(f"Error parsing PubMed response for batch {i//batch_size + 1}: {str(e)}")
            continue
        except Exception as e:
            print(f"Unexpected error in batch {i//batch_size + 1}: {str(e)}")
            continue

    print("All abstracts retrieved")
    return all_abstracts
def chunk_search(query: str, year_start: int, year_end: int) -> List[str]:
    """
    Perform a PubMed search for a specific year range
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    retmax = 9999  # Just under NCBI's 10,000-UID cap per esearch query

    date_query = f"{query} AND {year_start}:{year_end}[dp]"
    params = {
        "db": "pubmed",
        "term": date_query,
        "retmax": retmax,
        "retmode": "xml"
    }

    response = requests.get(base_url, params=params)
    response.raise_for_status()

    root = ET.fromstring(response.content)
    id_list = root.findall(".//Id")
    return [id_elem.text for id_elem in id_list]
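
# For example, chunk_search('pombe OR "fission yeast"', 2020, 2024) issues the
# term: pombe OR "fission yeast" AND 2020:2024[dp]
# where [dp] restricts matches by publication date.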
def search_pubmed(query: str, start_year: int, end_year: int) -> Iterator[Tuple[str, gr.DownloadButton]]:
    """
    Search PubMed and collect all matching PMIDs by breaking the search into
    year chunks, filter out already-curated papers, fetch the remaining
    abstracts, classify them, and write the results to a CSV.

    This is a generator: it yields (status message, download button) pairs
    so the Gradio UI can show progress.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    all_pmids = []

    yield "Loading current canto dump...", gr.DownloadButton(visible=True, interactive=False)
    canto_pmids = fetch_latest_canto_dump().select("pmid").with_columns(pl.col("pmid").str.split(":").list.last())

    try:
        # First, get the total count
        params = {
            "db": "pubmed",
            "term": query,
            "retmax": 0,
            "retmode": "xml"
        }
        response = requests.get(base_url, params=params)
        response.raise_for_status()

        root = ET.fromstring(response.content)
        total_count = int(root.find(".//Count").text)
        if total_count == 0:
            yield "No results found.", gr.DownloadButton(visible=True, interactive=False)
            return
        print(total_count)

        # Break the search into year chunks
        year_chunks = []
        chunk_size = 5  # Number of years per chunk
        for year in range(start_year, end_year + 1, chunk_size):
            chunk_end = min(year + chunk_size - 1, end_year)
            year_chunks.append((year, chunk_end))

        # Search each year chunk
        for chunk_start, chunk_end in year_chunks:
            current_status = f"Searching years {chunk_start}-{chunk_end}..."
            yield current_status, gr.DownloadButton(visible=True, interactive=False)
            try:
                chunk_pmids = chunk_search(query, chunk_start, chunk_end)
                all_pmids.extend(chunk_pmids)

                # Status update
                yield f"Retrieved {len(all_pmids)} total results so far...", gr.DownloadButton(visible=True, interactive=False)

                # Respect NCBI's rate limits
                time.sleep(0.34)
            except Exception as e:
                print(f"Error processing years {chunk_start}-{chunk_end}: {str(e)}")
                continue

        uncurated_pmids = filter_new_hits(canto_pmids, all_pmids)
        final_message = f"Retrieved {len(uncurated_pmids)} uncurated pmids!"
        yield final_message, gr.DownloadButton(visible=True, interactive=False)

        abstracts = fetch_abstracts_batch(uncurated_pmids)
        yield f"Fetched {len(abstracts)} abstracts", gr.DownloadButton(visible=True, interactive=False)

        classifications = pl.DataFrame(classify_abstracts(abstracts))
        print(classifications)
        yield f"Classified {len(abstracts)} abstracts", gr.DownloadButton(visible=True, interactive=False)

        classification_date = datetime.today().strftime('%Y%m%d')
        csv_filename = f"classified_pmids_{classification_date}.csv"
        yield "Writing csv file...", gr.DownloadButton(visible=True, interactive=False)

        # Write the file before exposing it through the download button
        classifications.write_csv(csv_filename)
        yield final_message, gr.DownloadButton(visible=True, value=csv_filename, interactive=True)

    except requests.exceptions.RequestException as e:
        yield f"Error accessing PubMed API: {str(e)}", gr.DownloadButton(visible=True, interactive=False)
    except ET.ParseError as e:
        yield f"Error parsing PubMed response: {str(e)}", gr.DownloadButton(visible=True, interactive=False)
    except Exception as e:
        yield f"Unexpected error: {str(e)}", gr.DownloadButton(visible=True, interactive=False)
def download_file():
    return gr.DownloadButton("Download results", visible=True, interactive=True)
# Create Gradio interface
def create_interface():
    with gr.Blocks() as app:
        gr.Markdown("## PomBase PubMed PMID Search")
        gr.Markdown("Enter a search term to find ALL relevant PubMed articles. Large searches may take several minutes.")
        gr.Markdown("We then filter for new pmids, then classify them with a transformer model.")

        with gr.Row():
            search_input = gr.Textbox(
                label="Search Term",
                placeholder="Enter search terms...",
                lines=1,
                value='pombe OR "fission yeast"'
            )
            search_button = gr.Button("Search")

        with gr.Row():
            max_year = datetime.now().year + 1  # allow publication dates into next year
            start_year = gr.Slider(label="Start year", minimum=1900, maximum=max_year, value=2020)
            end_year = gr.Slider(label="End year", minimum=1900, maximum=max_year, value=max_year)

        with gr.Row():
            status_output = gr.Textbox(
                label="Status",
                value="Ready to search..."
            )

        with gr.Row():
            d = gr.DownloadButton("Download results", visible=True, interactive=False)

        d.click(download_file, None, d)
        search_button.click(
            fn=search_pubmed,
            inputs=[search_input, start_year, end_year],
            outputs=[status_output, d]
        )

    return app
if __name__ == "__main__":
    app = create_interface()
    app.launch()