GenerTeam committed
Commit bf487d7 · verified · 1 Parent(s): 96ff9ad

Update README.md

Files changed (1):
  1. README.md (+33 / -24)
README.md CHANGED
@@ -98,27 +98,31 @@ print(decoded_sequences)
 
 ```python
 
+
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-# Load the tokenizer and model.
+# Load the tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("GENERator-eukaryote-3b-base", trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained("GenerTeam/GENERator-eukaryote-3b-base")
+model = AutoModelForCausalLM.from_pretrained("GENERator-eukaryote-3b-base")
 
+# Get model configuration
 config = model.config
 max_length = config.max_position_embeddings
 
-# Define input sequences.
+# Define input sequences
 sequences = [
     "ATGAGGTGGCAAGAAATGGGCTAC",
     "GAATTCCATGAGGCTATAGAATAATCTAAGAGAAAT"
 ]
 
-# Tokenize the sequences with add_special_tokens=True to automatically add special tokens,
-# such as the BOS EOS token, at the appropriate positions.
+# Truncate each sequence to the nearest multiple of 6
+processed_sequences = [tokenizer.bos_token + seq[:len(seq)//6*6] for seq in sequences]
+
+# Tokenization
 tokenizer.padding_side = "right"
 inputs = tokenizer(
-    sequences,
+    processed_sequences,
     add_special_tokens=True,
     return_tensors="pt",
     padding=True,
@@ -126,29 +130,34 @@ inputs = tokenizer(
     max_length=max_length
 )
 
-# Perform a forward pass through the model to obtain the outputs, including hidden states.
+# Model Inference
 with torch.inference_mode():
     outputs = model(**inputs, output_hidden_states=True)
 
-# Retrieve the hidden states from the last layer.
-hidden_states = outputs.hidden_states[-1]  # Shape: (batch_size, sequence_length, hidden_size)
-
-# Use the attention_mask to determine the index of the last token in each sequence.
-# Since add_special_tokens=True is used, the last token is typically the EOS token.
+hidden_states = outputs.hidden_states[-1]
 attention_mask = inputs["attention_mask"]
-last_token_indices = attention_mask.sum(dim=1) - 1  # Index of the last token for each sequence
-
-# Extract the embedding corresponding to the EOS token for each sequence.
-seq_embeddings = []
-for i, token_index in enumerate(last_token_indices):
-    # Fetch the embedding for the last token (EOS token).
-    seq_embedding = hidden_states[i, token_index, :]
-    seq_embeddings.append(seq_embedding)
-
-# Stack the embeddings into a tensor with shape (batch_size, hidden_size)
-seq_embeddings = torch.stack(seq_embeddings)
 
-print("Sequence Embeddings:", seq_embeddings)
+# Option 1: Last token (EOS) embedding
+last_token_indices = attention_mask.sum(dim=1) - 1
+eos_embeddings = hidden_states[torch.arange(hidden_states.size(0)), last_token_indices, :]
+
+# Option 2: Mean pooling over all tokens
+expanded_mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).to(torch.float32)
+sum_embeddings = torch.sum(hidden_states * expanded_mask, dim=1)
+mean_embeddings = sum_embeddings / expanded_mask.sum(dim=1)
+
+# Output
+print("EOS (Last Token) Embeddings:", eos_embeddings)
+print("Mean Pooling Embeddings:", mean_embeddings)
+
+# ============================================================================
+# Additional notes:
+# - The preprocessing step ensures sequences are multiples of 6 for 6-mer tokenizer
+# - For causal LM, the last token embedding (EOS) is commonly used
+# - Mean pooling considers all tokens including BOS and content tokens
+# - The choice depends on your downstream task requirements
+# - Both methods handle variable sequence lengths via attention mask
+# ============================================================================
 
 ```
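The preprocessing and pooling steps introduced by this commit can be sanity-checked without downloading the 3B model. The sketch below is an illustration only and is not part of the commit: it applies the same multiple-of-6 truncation and the two pooling options to dummy tensors, assuming right padding as in the README example; the batch size, sequence length, hidden size, mask values, and the example string are all made up.

```python
import torch

# Dummy stand-ins for the model outputs (all sizes are made up for illustration).
batch_size, seq_len, hidden_size = 2, 8, 4
hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# Right-padded batch: sequence 0 has 8 real tokens, sequence 1 has 5.
attention_mask = torch.tensor([
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 0, 0, 0],
])

# The multiple-of-6 truncation used before tokenization, shown on a plain string.
seq = "ATGAGGTGGCAAGAAATGGGCTACGA"          # length 26, not a whole number of 6-mers
truncated = seq[:len(seq) // 6 * 6]          # length 24, a whole number of 6-mers
print(len(seq), "->", len(truncated))

# Option 1: last real token per sequence (the EOS position under right padding).
last_token_indices = attention_mask.sum(dim=1) - 1            # tensor([7, 4])
eos_embeddings = hidden_states[torch.arange(batch_size), last_token_indices, :]

# Option 2: mean pooling over real tokens only; padded positions contribute nothing.
expanded_mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).to(torch.float32)
sum_embeddings = torch.sum(hidden_states * expanded_mask, dim=1)
mean_embeddings = sum_embeddings / expanded_mask.sum(dim=1)

print(eos_embeddings.shape, mean_embeddings.shape)            # both torch.Size([2, 4])
```

Under right padding, `attention_mask.sum(dim=1) - 1` points at the last unpadded position, which is where the EOS token lands when special tokens are appended; the mean-pooling variant instead averages only over unmasked positions, so padding never skews the embedding.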