AxionLab-official committed on
Commit
3f7e5c2
·
verified ·
1 Parent(s): 321d5f9

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +34 -9
README.md CHANGED
@@ -93,7 +93,7 @@ active_params/tok : ~160,000
93
  ## Usage
94
 
95
  ```python
96
- from transformers import AutoModelForCausalLM
97
  from tokenizer import BPETokenizer
98
  import torch
99
 
@@ -103,9 +103,25 @@ model = AutoModelForCausalLM.from_pretrained(
103
  )
104
  model.eval()
105
 
106
- # model.vocab and model.model must be in the same folder
107
  tok = BPETokenizer.load("model.vocab", "model.model")
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  prompt = "# Pergunta:\nQuanto é 5 + 3?\n--\n# Resposta:\n"
110
  ids = tok.encode(prompt, add_bos=True, add_eos=False)
111
  input_ids = torch.tensor([ids])
@@ -113,16 +129,25 @@ input_ids = torch.tensor([ids])
113
  with torch.no_grad():
114
  output = model.generate(
115
  input_ids,
116
- max_new_tokens=60,
117
- temperature=0.8,
118
  do_sample=True,
119
- top_k=40,
120
- top_p=0.9,
121
- eos_token_id=tok.token2id["<eos>"],
122
- pad_token_id=tok.token2id["<pad>"],
 
 
 
 
123
  )
124
 
125
- print(tok.decode(output[0].tolist()))
 
 
 
 
 
126
  ```
127
 
128
  ---
 
93
  ## Usage
94
 
95
  ```python
96
+ from transformers import AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
97
  from tokenizer import BPETokenizer
98
  import torch
99
 
 
103
  )
104
  model.eval()
105
 
 
106
  tok = BPETokenizer.load("model.vocab", "model.model")
107
 
108
+ # Bloqueia EOS e PAD nos primeiros min_tokens gerados
109
+ class MinNewTokens(LogitsProcessor):
110
+ def __init__(self, min_tokens: int, eos_id: int, pad_id: int):
111
+ self.min_tokens = min_tokens
112
+ self.bad = [eos_id, pad_id]
113
+ self.generated = 0
114
+
115
+ def __call__(self, input_ids, scores):
116
+ if self.generated < self.min_tokens:
117
+ for bid in self.bad:
118
+ scores[:, bid] = float("-inf")
119
+ self.generated += 1
120
+ return scores
121
+
122
+ eos_id = tok.token2id["<eos>"]
123
+ pad_id = tok.token2id["<pad>"]
124
+
125
  prompt = "# Pergunta:\nQuanto é 5 + 3?\n--\n# Resposta:\n"
126
  ids = tok.encode(prompt, add_bos=True, add_eos=False)
127
  input_ids = torch.tensor([ids])
 
129
  with torch.no_grad():
130
  output = model.generate(
131
  input_ids,
132
+ max_new_tokens=80,
133
+ temperature=0.9,
134
  do_sample=True,
135
+ top_k=50,
136
+ top_p=0.95,
137
+ eos_token_id=eos_id,
138
+ pad_token_id=pad_id,
139
+ use_cache=False,
140
+ logits_processor=LogitsProcessorList([
141
+ MinNewTokens(min_tokens=5, eos_id=eos_id, pad_id=pad_id)
142
+ ]),
143
  )
144
 
145
+ new_tokens = output[0][len(ids):].tolist()
146
+ # Remove EOS do final se presente
147
+ if new_tokens and new_tokens[-1] == eos_id:
148
+ new_tokens = new_tokens[:-1]
149
+
150
+ print("Resposta:", tok.decode(new_tokens))
151
  ```
152
 
153
  ---