from .isoformer_config import IsoformerConfig
from transformers import PreTrainedModel
from .modeling_esm import NTForMaskedLM, MultiHeadAttention
from .esm_config import NTConfig
from .modeling_esm_original import EsmForMaskedLM
from transformers.models.esm.configuration_esm import EsmConfig
from enformer_pytorch import Enformer, str_to_one_hot, EnformerConfig
import torch
from torch import nn


class Isoformer(PreTrainedModel):
    config_class = IsoformerConfig

    def __init__(self, config):
        super().__init__(config)

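        # Configuration of the ESM-style protein language model branch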
        self.esm_config = EsmConfig(
            vocab_size=config.esm_vocab_size,
            mask_token_id=config.esm_mask_token_id,
            pad_token_id=config.esm_pad_token_id,
            hidden_size=config.esm_hidden_size,
            num_hidden_layers=config.esm_num_hidden_layers,
            num_attention_heads=config.esm_num_attention_heads,
            intermediate_size=config.esm_intermediate_size,
            max_position_embeddings=config.esm_max_position_embeddings,
            token_dropout=config.esm_token_dropout,
            emb_layer_norm_before=config.esm_emb_layer_norm_before,
            attention_probs_dropout_prob=0.0,
            hidden_dropout_prob=0.0,
            use_cache=False,
            add_bias_fnn=config.esm_add_bias_fnn,
            position_embedding_type="rotary",
            tie_word_embeddings=False,
        )

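        # Configuration of the Nucleotide Transformer (RNA) branch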
        self.nt_config = NTConfig(
            vocab_size=config.nt_vocab_size,
            mask_token_id=config.nt_mask_token_id,
            pad_token_id=config.nt_pad_token_id,
            hidden_size=config.nt_hidden_size,
            num_hidden_layers=config.nt_num_hidden_layers,
            num_attention_heads=config.nt_num_attention_heads,
            intermediate_size=config.nt_intermediate_size,
            max_position_embeddings=config.nt_max_position_embeddings,
            token_dropout=config.nt_token_dropout,
            emb_layer_norm_before=config.nt_emb_layer_norm_before,
            attention_probs_dropout_prob=0.0,
            hidden_dropout_prob=0.0,
            use_cache=False,
            add_bias_fnn=config.nt_add_bias_fnn,
            position_embedding_type="rotary",
            tie_word_embeddings=False,
        )
        self.config = config

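        # Backbone encoders: ESM for protein, Nucleotide Transformer for RNA,
        # and a pre-trained Enformer for DNA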
        self.esm_model = EsmForMaskedLM(self.esm_config)
        self.nt_model = NTForMaskedLM(self.nt_config)
        self.enformer_model = Enformer.from_pretrained("EleutherAI/enformer-official-rough")

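        # Cross-attention layers through which the 3072-dim Enformer (DNA) representation
        # attends to the RNA (768-dim) and protein (640-dim) embeddings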
        self.cross_attention_layer_rna = MultiHeadAttention(
            config=EsmConfig(
                num_attention_heads=config.num_heads_omics_cross_attention,
                attention_head_size=3072 // config.num_heads_omics_cross_attention,
                hidden_size=3072,
                attention_probs_dropout_prob=0,
                max_position_embeddings=0,
            ),
            omics_of_interest_size=3072,
            other_omic_size=768,
        )
        self.cross_attention_layer_protein = MultiHeadAttention(
            config=EsmConfig(
                num_attention_heads=config.num_heads_omics_cross_attention,
                attention_head_size=3072 // config.num_heads_omics_cross_attention,
                hidden_size=3072,
                attention_probs_dropout_prob=0,
                max_position_embeddings=0,
            ),
            omics_of_interest_size=3072,
            other_omic_size=640,
        )

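        # Two-layer prediction head: pooled 3072-dim DNA embedding -> 30 expression outputs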
        self.head_layer_1 = nn.Linear(3072, 2 * 3072)
        self.head_layer_2 = nn.Linear(2 * 3072, 30)

    def forward(
        self,
        tensor_dna,
        tensor_rna,
        tensor_protein,
        attention_mask_rna,
        attention_mask_protein,
    ):
        # Drop the leading special token from the DNA sequence before Enformer
        tensor_dna = tensor_dna[:, 1:]
        dna_embedding = self.enformer_model(
            tensor_dna,
            return_only_embeddings=True,
        )
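        # Last-layer hidden states from the protein (ESM) and RNA (NT) language models
        # serve as keys/values for the cross-attention below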
        protein_embedding = self.esm_model(
            tensor_protein,
            attention_mask=attention_mask_protein,
            encoder_attention_mask=attention_mask_protein,
            output_hidden_states=True,
        )
        rna_embedding = self.nt_model(
            tensor_rna,
            attention_mask=attention_mask_rna,
            encoder_attention_mask=attention_mask_rna,
            output_hidden_states=True,
        )

        # Cross-attention: DNA positions (queries) attend to RNA tokens (keys/values).
        # RNA padding positions (pad token id 1) are masked out; the mask has shape
        # (batch, 1, dna_len, rna_len) so it broadcasts over attention heads.
        encoder_attention_mask = torch.unsqueeze(
            torch.unsqueeze(tensor_rna != 1, 1), 1
        ).repeat(1, 1, dna_embedding.shape[1], 1)
        rna_to_dna = self.cross_attention_layer_rna(
            hidden_states=dna_embedding,
            encoder_hidden_states=rna_embedding["hidden_states"][-1],
            encoder_attention_mask=encoder_attention_mask,
        )

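        # Second cross-attention: the RNA-conditioned DNA embeddings attend to the protein embeddings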
        final_dna_embeddings = self.cross_attention_layer_protein(
            hidden_states=rna_to_dna["embeddings"],
            encoder_hidden_states=protein_embedding["hidden_states"][-1],
        )["embeddings"]

        # Mean-pool the DNA embeddings over the configured window of positions
        sequence_mask = torch.zeros(
            final_dna_embeddings.shape[1], device=final_dna_embeddings.device
        )
        sequence_mask[self.config.pool_window_start:self.config.pool_window_end] = 1
        x = torch.sum(
            torch.einsum("ijk,j->ijk", final_dna_embeddings, sequence_mask), dim=1
        ) / torch.sum(sequence_mask)

        # Prediction head with a softplus non-linearity between the two linear layers
        x = self.head_layer_1(x)
        x = torch.nn.functional.softplus(x)
        x = self.head_layer_2(x)

        return {
            "gene_expression_predictions": x,
            "final_dna_embeddings": final_dna_embeddings,
        }