kairusama committed on
Commit
24ddd2a
·
verified ·
1 Parent(s): 75ea43c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -15
app.py CHANGED
@@ -1,13 +1,31 @@
1
  # app.py
2
  import gradio as gr
3
- from transformers import pipeline
 
4
 
5
  # ---- Load model via pipeline ----
6
  MODEL_NAME = "vicgalle/gpt2-open-instruct-v1"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  pipe = pipeline("text-generation", model=MODEL_NAME, device_map="auto")
8
 
9
  # ---- Inference function ----
10
- def generate_response(instruction, max_new_tokens=150, temperature=0.7, top_k=50, top_p=0.9):
 
 
 
11
  system_prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
12
 
13
  ### Instruction:
@@ -15,34 +33,48 @@ def generate_response(instruction, max_new_tokens=150, temperature=0.7, top_k=50
15
 
16
  ### Response:
17
  """
18
- output = pipe(
 
 
 
 
 
 
19
  system_prompt,
20
- max_new_tokens=max_new_tokens,
21
- temperature=temperature,
22
- top_k=top_k,
23
- top_p=top_p,
24
  do_sample=True,
 
 
 
 
 
 
 
25
  pad_token_id=pipe.tokenizer.eos_token_id,
 
 
26
  )
27
- # Clean up output text
28
- text = output[0]["generated_text"]
29
- return text.split("### Response:")[-1].strip()
 
 
30
 
31
  # ---- Gradio UI ----
32
  with gr.Blocks() as demo:
33
- gr.Markdown("# 🛸 GPT-2 Open Instruct Playground\nType an instruction and let the alien respond!")
34
  with gr.Row():
35
- with gr.Column(scale=3):
36
- instruction = gr.Textbox(label="Instruction", placeholder="Pretend you are an alien visiting Earth...", lines=6)
 
 
37
  max_new_tokens = gr.Slider(50, 500, value=150, step=10, label="Max new tokens")
38
  temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="Temperature")
39
  top_k = gr.Slider(10, 100, value=50, step=5, label="Top-K sampling")
40
  top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P (nucleus) sampling")
41
  generate_btn = gr.Button("Generate ✨")
42
- with gr.Column(scale=2):
43
- output_box = gr.Textbox(label="Model Output", lines=10)
44
  generate_btn.click(generate_response, [instruction, max_new_tokens, temperature, top_k, top_p], output_box)
45
 
 
46
  # ---- Launch ----
47
  if __name__ == "__main__":
48
  demo.launch()
 
1
  # app.py
2
  import gradio as gr
3
+ import torch
4
+ from transformers import pipeline, StoppingCriteria, StoppingCriteriaList
5
 
6
  # ---- Load model via pipeline ----
7
  MODEL_NAME = "vicgalle/gpt2-open-instruct-v1"
8
+
9
class StopOnStrings(StoppingCriteria):
    """Halt generation when the tail of the sequence matches a stop sequence.

    Parameters
    ----------
    stop_ids : list[torch.Tensor]
        1-D tensors of token ids; generation stops when the most recent
        tokens equal any one of them.
    window : int, optional
        Kept for backward compatibility; currently unused by the check.
    """

    def __init__(self, stop_ids, window=10):
        super().__init__()
        self.stop_ids = stop_ids
        # NOTE(review): `window` is stored but never read — dead parameter,
        # retained so existing callers keep working.
        self.window = window

    def __call__(self, input_ids, scores, **kwargs):
        # Compare the most recent tokens against each stop sequence.
        seq = input_ids[0]
        for stop in self.stop_ids:
            n = len(stop)
            if len(seq) < n:
                continue
            # Move the stop ids onto the input's device: with
            # device_map="auto" the generated ids may live on GPU while the
            # stop tensor was built on CPU, and torch.equal requires both
            # tensors on the same device.
            if torch.equal(seq[-n:], stop.to(seq.device)):
                return True
        return False
21
+
22
# Instantiate the HF text-generation pipeline once at import time;
# device_map="auto" lets accelerate place the model on GPU when available.
pipe = pipeline("text-generation", model=MODEL_NAME, device_map="auto")
23
 
24
  # ---- Inference function ----
25
def generate_response(instruction,
                      max_new_tokens=150,
                      temperature=0.7,
                      top_k=50,
                      top_p=0.9):
    """Generate a completion for *instruction* with the instruct-tuned GPT-2.

    Parameters mirror the Gradio sliders. The UI wires five inputs into this
    handler (instruction, max_new_tokens, temperature, top_k, top_p), so
    `top_k` must be accepted here — dropping it makes Gradio call the handler
    with five positional args against a four-parameter signature (TypeError).

    Returns the generated text with the prompt stripped (the pipeline is
    asked not to echo it) and truncated at the "### End" marker.
    """
    system_prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
"""

    # Build stop ids for "### End" so generation halts at the end marker.
    stop_text = "### End"
    stop_ids = pipe.tokenizer(stop_text, add_special_tokens=False,
                              return_tensors="pt")["input_ids"][0]
    stopping = StoppingCriteriaList([StopOnStrings([stop_ids])])

    out = pipe(
        system_prompt,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        no_repeat_ngram_size=3,
        repetition_penalty=1.15,
        eos_token_id=pipe.tokenizer.eos_token_id,
        pad_token_id=pipe.tokenizer.eos_token_id,
        return_full_text=False,   # don't echo the prompt
        stopping_criteria=stopping,
    )

    text = out[0]["generated_text"]
    # Hard stop as a second line of defense in case the stopping criterion
    # fired late or the marker was produced with different tokenization.
    text = text.split(stop_text)[0].strip()
    return text
61
 
62
# ---- Gradio UI ----
# Two-column layout: prompt + output on the left, sampling controls on the
# right; the button feeds all five inputs into generate_response.
with gr.Blocks() as demo:
    gr.Markdown("# 🛸 GPT-2 Open Instruct Playground\nThe original GPT-2 fine-tuned with Open Instruct v1.")
    with gr.Row():
        with gr.Column(scale=4):
            instruction = gr.Textbox(label="Instruction", value="What is the capital city of France?", lines=6)
            output_box = gr.Textbox(label="Model Output", lines=25)
        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(50, 500, value=150, step=10, label="Max new tokens")
            temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="Temperature")
            top_k = gr.Slider(10, 100, value=50, step=5, label="Top-K sampling")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P (nucleus) sampling")
            generate_btn = gr.Button("Generate ✨")
    generate_btn.click(generate_response, [instruction, max_new_tokens, temperature, top_k, top_p], output_box)
76
 
77
+
78
# ---- Launch ----
if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    demo.launch()