PlantGenoANN

PlantGenoANN is a plant genomic segmentation model that predicts various plant genomic elements at single-nucleotide resolution. Designed for automated plant genome annotation, it is built on the PlantBiMoE architecture with a 1D U-Net segmentation head. It predicts gene structures (genes, CDSs, and exons) on both the forward and reverse strands. PlantGenoANN can also serve as a long-context plant genomic foundation model (up to 49,152 bp) that can be fine-tuned to predict plant omic signal tracks such as RNA-seq or ATAC-seq.
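Sequences longer than the 49,152 bp context (for example whole chromosomes) have to be split before inference. One option is to tile them into overlapping windows. The sketch below is a minimal illustration under stated assumptions, not the authors' pipeline: it assumes one token per nucleotide, that "N" is an acceptable padding base, and an arbitrarily chosen stride of 40,960 bp (8,192 bp overlap). `tile_genome` is a hypothetical helper, not part of the model's API. The last window is padded so its length respects the divisible-by-8 rule stated in the usage example.

```python
from typing import List, Tuple


def tile_genome(
    sequence: str,
    window: int = 49_152,
    stride: int = 40_960,
    multiple: int = 8,
) -> List[Tuple[int, str]]:
    """Split a DNA string into overlapping (start, chunk) windows.

    Hypothetical helper. Assumes one token per nucleotide and 'N' as a
    valid padding base; check the tokenizer before relying on either.
    """
    if window % multiple != 0:
        raise ValueError("window must be a multiple of `multiple`")
    if not 0 < stride <= window:
        raise ValueError("stride must be in (0, window]")

    windows = []
    start = 0
    while True:
        chunk = sequence[start:start + window]
        # Right-pad the final (shorter) chunk with 'N' so its length is
        # divisible by `multiple` (assumed padding convention).
        remainder = len(chunk) % multiple
        if remainder:
            chunk += "N" * (multiple - remainder)
        windows.append((start, chunk))
        if start + window >= len(sequence):
            break
        start += stride
    return windows
```

Keeping the start offset with each chunk lets per-window predictions be mapped back to genome coordinates, and the overlap gives room to discard predictions near window edges, where context is one-sided.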

Developed by: hu-lab

Model Sources

How to use

The model requires the mamba-ssm and causal-conv1d libraries for its backbone. Both rely on custom GPU kernels, and the example runs on a CUDA device. You can retrieve both genomic feature probabilities and sequence embeddings using the following snippet:

import torch
from transformers import AutoTokenizer, AutoModel

# Load model and tokenizer
repo_id = "qzzhang/PlantGenoANN"
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

# The number of DNA tokens (excluding the [CLS] and [SEP] tokens) must be divisible by 8,
# as required by the U-Net downsampling blocks.
sequences = ["ACTAGAGCGAGAGAAA","TTTGAGAGCGCGCGGA"] 

# Tokenize
tokenized_sequences = tokenizer(
    sequences, 
    return_tensors="pt", 
    padding="longest"
)["input_ids"]

# Infer
model.to("cuda")
model.eval()
with torch.no_grad():
    outs = model(input_ids=tokenized_sequences.to("cuda"))

# Obtain the logits over the genomic features
# Shape: [batch, sequence_length, num_features]
logits = outs.logits

# Get probabilities associated with CDS on the forward strand (+)
pos_strand_cds_probs = model.get_feature_logits(feature="CDS", strand="+", logits=logits).detach()
print(f"CDS probabilities on the forward strand: {pos_strand_cds_probs}")

# Get the sequence embeddings
# Shape: [batch, sequence_length, 1024]
hidden_states = outs.hidden_states.detach()
print(f"Sequence embeddings shape is: {hidden_states.shape}")
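To turn per-nucleotide scores such as `pos_strand_cds_probs` into annotation-style intervals, one simple approach is to threshold them and merge runs of consecutive positive positions. This is a minimal sketch, not the authors' post-processing: the 0.5 threshold and the `min_length` filter are assumptions, and `probs_to_intervals` is a hypothetical helper. Before using it, check whether the returned values are probabilities or raw logits (convert logits first) and whether output positions include the [CLS]/[SEP] tokens, since that shifts every index.

```python
from typing import List, Sequence, Tuple


def probs_to_intervals(
    probs: Sequence[float],
    threshold: float = 0.5,
    min_length: int = 1,
) -> List[Tuple[int, int]]:
    """Convert per-position probabilities into half-open [start, end) intervals.

    Hypothetical helper. A position is positive when its probability is
    >= threshold; each run of positive positions becomes one interval, and
    intervals shorter than `min_length` are dropped.
    """
    intervals = []
    start = None
    for i, p in enumerate(probs):
        if p >= threshold and start is None:
            start = i  # a positive run begins
        elif p < threshold and start is not None:
            if i - start >= min_length:
                intervals.append((start, i))
            start = None  # the run ended
    # Close a run that extends to the end of the sequence.
    if start is not None and len(probs) - start >= min_length:
        intervals.append((start, len(probs)))
    return intervals
```

For example, `probs_to_intervals([0.1, 0.8, 0.9, 0.2, 0.7])` returns `[(1, 3), (4, 5)]`. With the model output, something like `probs_to_intervals(pos_strand_cds_probs[0].float().cpu().tolist())` should work if that tensor has shape [batch, sequence_length], which the card does not state explicitly.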

Architecture

PlantGenoANN is composed of the PlantBiMoE encoder (a 116M-parameter foundation model) coupled with a custom 1D U-Net segmentation head.
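The divisible-by-8 requirement from the usage example is consistent with a U-Net that halves the sequence length three times (2^3 = 8) and then doubles it back, concatenating each decoder stage with the encoder stage of the same resolution. The sketch below traces sequence lengths through such a network to show why other lengths break the skip connections. The three stride-2 stages are an assumption inferred from the factor of 8, not a documented detail of the PlantGenoANN head, and both functions are hypothetical.

```python
from typing import List, Tuple


def unet_lengths(length: int, num_downsamples: int = 3, factor: int = 2) -> Tuple[List[int], List[int]]:
    """Trace lengths through assumed stride-`factor` down/upsampling stages."""
    down = [length]
    for _ in range(num_downsamples):
        # Pooling or a strided convolution floors odd lengths.
        down.append(down[-1] // factor)
    up = [down[-1]]
    for _ in range(num_downsamples):
        up.append(up[-1] * factor)
    return down, up


def skip_connections_align(length: int, num_downsamples: int = 3, factor: int = 2) -> bool:
    """True when every decoder stage matches the encoder stage it is concatenated with."""
    down, up = unet_lengths(length, num_downsamples, factor)
    return all(d == u for d, u in zip(reversed(down), up))
```

With these assumptions, a length of 20 shrinks to 10, 5, 2 and grows back to 4, 8, 16, so the skip connections no longer line up, whereas any multiple of 8 (such as 49,152) round-trips exactly.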

🛠️ Training Procedure

PlantGenoANN was trained for 30 hours on 4x NVIDIA A800-80G GPUs, processing a total of 18B tokens. The training data consisted of 9 high-quality model plant genomes and their annotations. The model was optimized with AdamW (learning rate 1e-4, weight decay 0.01) and a cosine learning-rate schedule, aiming for robust convergence across diverse plant genomic contexts.
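The training figures above imply a rough throughput, and the cosine schedule can be written in closed form. This is a minimal sketch: it assumes no warmup and decay to a learning rate of 0, neither of which is stated in this card, and `cosine_lr` is a hypothetical helper. In PyTorch, a comparable setup would typically combine `torch.optim.AdamW(params, lr=1e-4, weight_decay=0.01)` with `torch.optim.lr_scheduler.CosineAnnealingLR`.

```python
import math


def cosine_lr(step: int, total_steps: int, peak_lr: float = 1e-4, min_lr: float = 0.0) -> float:
    """Cosine decay from peak_lr at step 0 to min_lr at total_steps (assumes no warmup)."""
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Back-of-the-envelope throughput implied by the reported numbers.
TOTAL_TOKENS = 18e9
HOURS = 30
NUM_GPUS = 4
tokens_per_second = TOTAL_TOKENS / (HOURS * 3600)
tokens_per_second_per_gpu = tokens_per_second / NUM_GPUS

if __name__ == "__main__":
    print(f"~{tokens_per_second:,.0f} tokens/s overall, ~{tokens_per_second_per_gpu:,.0f} per GPU")
    for step in (0, 25, 50, 75, 100):
        print(step, cosine_lr(step, total_steps=100))
```

Running the script prints about 166,667 tokens/s overall and about 41,667 tokens/s per GPU. These are averages over the full 30 hours, so they include any time not spent on forward and backward passes.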

BibTeX entry and citation info


Model size: 0.2B params (Safetensors, tensor type F32)