SpliceBERT (manuscript, preprint) is a primary RNA sequence language model pre-trained on over 2 million vertebrate RNA sequences. It can be used to study RNA splicing and other biological problems related to RNA sequences.
For additional benchmarks and applications of SpliceBERT (e.g., on SpliceAI's and DeepSTARR's datasets), see SpliceBERT-analysis.
Data availability
How to use SpliceBERT?
Reproduce the analysis
Contact
Citation
The model weights and data for analysis are available at zenodo:7995778.
SpliceBERT is implemented with Huggingface transformers and FlashAttention in PyTorch. Users should install pytorch, transformers, and (optionally) FlashAttention to load the SpliceBERT model; a quick sanity check is sketched after the list below.
Install PyTorch: https://pytorch.org/get-started/locally/
Install Huggingface transformers: https://huggingface.co/docs/transformers/installation
Install FlashAttention (optional): https://github.com/Dao-AILab/flash-attention
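As a quick sanity check (this is only a sketch, not part of the SpliceBERT codebase), the dependencies can be verified by importing them and checking whether a GPU is visible:

```python
# Sanity check: confirm that PyTorch and transformers are importable
# and whether a CUDA-capable GPU is available (optional but recommended).
import torch
import transformers

print("PyTorch version:", torch.__version__)
print("transformers version:", transformers.__version__)
print("CUDA available:", torch.cuda.is_available())
```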
SpliceBERT can be easily used for a series of downstream tasks through the official API. See the official guide for more details.
Download SpliceBERT
The weights of SpliceBERT can be downloaded from zenodo: https://zenodo.org/record/7995778/files/models.tar.gz?download=1
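As an alternative to downloading the archive in a browser, it can also be fetched and unpacked from Python. This is only a convenience sketch; it assumes the zenodo URL above is reachable and that the archive unpacks into a `models/` folder:

```python
# Download the pre-trained SpliceBERT weights from zenodo and unpack them locally.
import tarfile
import urllib.request

url = "https://zenodo.org/record/7995778/files/models.tar.gz?download=1"
archive = "models.tar.gz"

urllib.request.urlretrieve(url, archive)  # download the archive
with tarfile.open(archive, "r:gz") as tar:
    tar.extractall(".")  # assumed to unpack into a `models/` folder containing the weights
```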
System requirements
We recommend running SpliceBERT on a Linux system with an NVIDIA GPU with at least 4 GB of memory. (Running the model on CPU only is possible, but it will be very slow.)
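To check whether a suitable GPU is visible before running the model, a check along these lines can be used (only a sketch; the 4 GB threshold mirrors the recommendation above):

```python
# Pick a CUDA GPU with at least ~4 GB of memory if one is available;
# otherwise fall back to CPU (much slower, but functional).
import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    total_gb = props.total_memory / 1024**3
    print(f"GPU: {props.name}, {total_gb:.1f} GB")
    device = torch.device("cuda" if total_gb >= 4 else "cpu")
else:
    device = torch.device("cpu")
print("Using device:", device)
```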
Examples
We provide a demo script showing how to use SpliceBERT through the official API of Huggingface transformers in the first part of the following code block.
Users can also use SpliceBERT with FlashAttention by replacing the official API with the custom API, as shown in the second part of the following code block. Note that flash-attention requires automatic mixed precision (amp) mode to be enabled and currently does not support attention_mask.
Use SpliceBERT through the official API of Huggingface transformers:
```python
SPLICEBERT_PATH = "/path/to/SpliceBERT/models/model_folder"  # set the path to the folder of pre-trained SpliceBERT

import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, AutoModelForTokenClassification, AutoModelForSequenceClassification

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(SPLICEBERT_PATH)

# prepare input sequence
seq = "ACGUACGuacguaCGu"  # WARNING: this is just a demo. SpliceBERT may not work on sequences shorter than 64nt as it was trained on sequences of 64-1024nt in length
seq = ' '.join(list(seq.upper().replace("U", "T")))  # U -> T and add whitespace
input_ids = tokenizer.encode(seq)  # N -> 5, A -> 6, C -> 7, G -> 8, T(U) -> 9. NOTE: a [CLS] and a [SEP] token will be added to the start and the end of seq
input_ids = torch.as_tensor(input_ids)  # convert python list to Tensor
input_ids = input_ids.unsqueeze(0)  # add batch dimension, shape: (batch_size, sequence_length)

# use Huggingface's official API to use SpliceBERT

# get nucleotide embeddings (hidden states)
model = AutoModel.from_pretrained(SPLICEBERT_PATH)  # load model
last_hidden_state = model(input_ids).last_hidden_state  # get hidden states from last layer
hidden_states = model(input_ids, output_hidden_states=True).hidden_states  # hidden states from the embedding layer (nn.Embedding) and the 6 transformer encoder layers

# get nucleotide type logits in masked language modeling
model = AutoModelForMaskedLM.from_pretrained(SPLICEBERT_PATH)  # load model
logits = model(input_ids).logits  # shape: (batch_size, sequence_length, vocab_size)

# finetuning SpliceBERT for token classification tasks
model = AutoModelForTokenClassification.from_pretrained(SPLICEBERT_PATH, num_labels=3)  # assume the class number is 3, logits shape: (batch_size, sequence_length, num_labels)

# finetuning SpliceBERT for sequence classification tasks
model = AutoModelForSequenceClassification.from_pretrained(SPLICEBERT_PATH, num_labels=3)  # assume the class number is 3, logits shape: (batch_size, num_labels)
```
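Because the tokenizer adds a [CLS] token at the start and a [SEP] token at the end of the sequence, the hidden states contain two extra positions. Below is a minimal sketch (reusing `SPLICEBERT_PATH`, `seq`, and `input_ids` from the block above) of how one might recover one embedding per nucleotide:

```python
# Drop the [CLS] (first) and [SEP] (last) positions so that the embeddings
# line up one-to-one with the nucleotides of the input sequence.
model = AutoModel.from_pretrained(SPLICEBERT_PATH)
with torch.no_grad():
    last_hidden_state = model(input_ids).last_hidden_state  # (batch_size, seq_len + 2, hidden_size)
nucleotide_embeddings = last_hidden_state[:, 1:-1, :]        # (batch_size, seq_len, hidden_size)
assert nucleotide_embeddings.shape[1] == len(seq.split())    # one embedding per nucleotide
```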
Or use SpliceBERT with FlashAttention by replacing the official API with the custom API (currently flash-attention does not support attention_mask, so all sequences in a batch must have the same length):
```python
SPLICEBERT_PATH = "/path/to/SpliceBERT/models/model_folder"  # set the path to the folder of pre-trained SpliceBERT

import os
import sys
import torch
from torch.cuda.amp import autocast

sys.path.append(os.path.dirname(os.path.abspath(SPLICEBERT_PATH)))
from transformers import AutoTokenizer
from splicebert_model import BertModel, BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(SPLICEBERT_PATH)

# prepare input sequence
seq = "ACGUACGuacguaCGu"  # WARNING: this is just a demo. SpliceBERT may not work on sequences shorter than 64nt as it was trained on sequences of 64-1024nt in length
seq = ' '.join(list(seq.upper().replace("U", "T")))  # U -> T and add whitespace
input_ids = tokenizer.encode(seq)  # N -> 5, A -> 6, C -> 7, G -> 8, T(U) -> 9. NOTE: a [CLS] and a [SEP] token will be added to the start and the end of seq
input_ids = torch.as_tensor(input_ids)  # convert python list to Tensor
input_ids = input_ids.unsqueeze(0)  # add batch dimension, shape: (batch_size, sequence_length)

# use the custom BertModel with FlashAttention

# get nucleotide embeddings (hidden states)
model = BertModel.from_pretrained(SPLICEBERT_PATH)  # load model
with autocast():
    last_hidden_state = model(input_ids).last_hidden_state  # get hidden states from last layer
    hidden_states = model(input_ids, output_hidden_states=True).hidden_states  # hidden states from the embedding layer (nn.Embedding) and the 6 transformer encoder layers

# get nucleotide type logits in masked language modeling
model = BertForMaskedLM.from_pretrained(SPLICEBERT_PATH)  # load model
with autocast():
    logits = model(input_ids).logits  # shape: (batch_size, sequence_length, vocab_size)

# finetuning SpliceBERT for token classification tasks
model = BertForTokenClassification.from_pretrained(SPLICEBERT_PATH, num_labels=3)  # assume the class number is 3, logits shape: (batch_size, sequence_length, num_labels)

# finetuning SpliceBERT for sequence classification tasks
model = BertForSequenceClassification.from_pretrained(SPLICEBERT_PATH, num_labels=3)  # assume the class number is 3, logits shape: (batch_size, num_labels)
```
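Since the FlashAttention path does not support attention_mask, all sequences in a batch must have the same length. The following is only a sketch (hypothetical equal-length sequences, reusing the imports, tokenizer, and `SPLICEBERT_PATH` from the block above) of how such a batch could be built:

```python
# Build a batch of equal-length sequences for the FlashAttention model
# (no attention_mask, so padding to different lengths is not an option).
sequences = ["ACGTACGTACGTACGT", "TTTTCCCCGGGGAAAA"]  # hypothetical sequences of identical length
batch = []
for s in sequences:
    s = ' '.join(list(s.upper().replace("U", "T")))  # U -> T and add whitespace between bases
    batch.append(tokenizer.encode(s))                # [CLS]/[SEP] are added to every sequence
input_ids = torch.as_tensor(batch)                   # shape: (batch_size, sequence_length)

model = BertModel.from_pretrained(SPLICEBERT_PATH)   # custom BertModel with FlashAttention
with autocast():                                     # amp mode is required by flash-attention
    last_hidden_state = model(input_ids).last_hidden_state  # (batch_size, sequence_length, hidden_size)
```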
Configure the environment.
We run the scripts in a conda environment with python 3.9.7 on a Linux system (Ubuntu 20.04.3 LTS). The required packages are:
Note: the version numbers only indicate the software versions used in our study. In most cases, users do not need exactly the same versions to run the code.
Python packages:

- Python (3.9.7)
- pytorch (1.12.1)
- transformers (4.24.0)
- h5py (3.2.1)
- numpy (1.23.3)
- scipy (1.8.0)
- scikit-learn (1.1.1)
- scanpy (1.8.2)
- matplotlib (3.5.1)
- seaborn (0.11.2)
- tqdm (4.64.0)
- pyBigWig (0.3.18)
- cython (0.29.28)

Command line tools (optional):

- bedtools (2.30.0)
- MaxEntScan (2004)
- gtfToGenePred (v377)
Clone this repository, download the data, and run the setup scripts.
```bash
git clone [email protected]:biomed-AI/SpliceBERT.git
cd SpliceBERT
bash download.sh  # download model weights and data, or manually download them from zenodo (https://doi.org/10.5281/zenodo.7995778)
cd examples
bash setup.sh  # compile selene utils, cython is required
```
(Optional) Download the pre-computed results for sections 1-4 from Google Drive and decompress them in the `examples` folder.
```bash
# users should manually download `pre-computed_results.tar.gz`, put it in the `./examples` folder,
# and run the following command to decompress it
tar -zxvf pre-computed_results.tar.gz
```
If the pre-computed results have been downloaded and decompressed correctly, users can skip running `pipeline.sh` in the Jupyter notebooks of sections 1-4.
Run the Jupyter notebooks (sections 1-4) or the bash scripts `pipeline.sh` (sections 5-6):
1. evolutionary conservation analysis (related to Figure 1)
2. nucleotide embedding analysis (related to Figure 2)
3. attention weight analysis (related to Figure 3)
4. variant effect analysis (related to Figure 4)
5. branchpoint prediction (related to Figure 5)
6. splice site prediction (related to Figure 6)
For issues related to the scripts, create an issue at https://github.com/biomed-AI/SpliceBERT/issues.
For any other questions, feel free to contact chenkenbio {at} gmail.com.
```bibtex
@article{Chen2023.01.31.526427,
    author = {Chen, Ken and Zhou, Yue and Ding, Maolin and Wang, Yu and Ren, Zhixiang and Yang, Yuedong},
    title = {Self-supervised learning on millions of primary RNA sequences from 72 vertebrates improves sequence-based RNA splicing prediction},
    year = {2024},
    doi = {10.1093/bib/bbae163},
    publisher = {Oxford University Press},
    URL = {https://doi.org/10.1093/bib/bbae163},
    journal = {Briefings in Bioinformatics}
}
```