Spaces:
Runtime error
Runtime error
# Settings noted by the original author: beams = 5, return_seq = 1, max_length = 300
def get_question(sentence, answer, mdl, tknizer, num_seq, num_beams, max_length,
                 *, max_input_length=256):
    """Generate question strings for an answer found in a context sentence.

    The prompt ``"context: <sentence> answer: <answer>"`` is tokenized
    (truncated to ``max_input_length`` tokens), passed to ``mdl.generate``
    with beam search, and every returned sequence is decoded back to text.

    Args:
        sentence: Context text placed after ``context:`` in the prompt.
        answer: Answer text placed after ``answer:`` in the prompt.
        mdl: Model object exposing ``generate(...)``.
        tknizer: Tokenizer object exposing ``encode_plus(...)`` and
            ``decode(...)``.
        num_seq: Number of sequences to return; clamped to ``num_beams``
            because the value forwarded as ``num_return_sequences`` must not
            exceed the beam count.
        num_beams: Beam width forwarded to ``generate``.
        max_length: Maximum generated length forwarded to ``generate``.
        max_input_length: Truncation length for the tokenized prompt
            (defaults to the previously hard-coded 256).

    Returns:
        A list with one string per generated sequence, with every
        ``"question:"`` marker removed and surrounding whitespace stripped.
    """
    num_seq = min(num_seq, num_beams)

    prompt = "context: {} answer: {}".format(sentence, answer)
    print(prompt)

    # ``padding=False`` is the supported spelling of the deprecated
    # ``pad_to_max_length=False``; a single prompt needs no padding.
    encoding = tknizer.encode_plus(
        prompt,
        max_length=max_input_length,
        padding=False,
        truncation=True,
        return_tensors="pt",
    )

    generated = mdl.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        early_stopping=True,
        num_beams=num_beams,
        num_return_sequences=num_seq,
        no_repeat_ngram_size=2,
        max_length=max_length,
    )

    decoded = [tknizer.decode(ids, skip_special_tokens=True) for ids in generated]
    # Removing the marker leaves the space that followed it, so strip the
    # result to avoid returning questions with leading whitespace.
    return [text.replace("question:", "").strip() for text in decoded]