from transformers import AutoTokenizer, AutoModel
from accelerate import Accelerator
from accelerate.utils import gather_object
from tqdm import tqdm
import torch
import torch.nn as nn
import gc
class EmbeddingModelWrapper:
    DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"

    def __init__(self, model_path=None, bs=8):
        if model_path is None:
            model_path = self.DEFAULT_MODEL
        self.model, self.tokenizer = self.load_model(model_path)
        self.bs = bs
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    def load_model(self, model_path):
        model = AutoModel.from_pretrained(model_path).to("mps")
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        return model, tokenizer

    def emb_mean_pooling(self, model_output, attention_mask):
        # mean-pool the token embeddings, weighted by the attention mask
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def get_embeddings(self, sentences):
        embeddings = torch.tensor([], device="mps")

        if self.bs is None:
            batches = [sentences]
        else:
            batches = [sentences[i:i + self.bs] for i in range(0, len(sentences), self.bs)]

        for batch in batches:
            encoded_input = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to("mps")
            with torch.no_grad():
                model_output = self.model(**encoded_input)
            batch_embeddings = self.emb_mean_pooling(model_output, encoded_input["attention_mask"])
            embeddings = torch.cat((embeddings, batch_embeddings), dim=0)
        return embeddings

    def get_similarities(self, x, y=None):
        if y is None:
            # pairwise similarities of x with itself, returned as a lower-triangular matrix
            num_samples = x.shape[0]
            similarities = [[0 for i in range(num_samples)] for f in range(num_samples)]
            for row in tqdm(range(num_samples)):
                similarities[row][0:row + 1] = self.cos(x[row].repeat(row + 1, 1), x[0:row + 1]).tolist()
            return similarities
        else:
            # element-wise cosine similarity between two embedding tensors
            return self.cos(x, y).tolist()
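
# A minimal usage sketch for EmbeddingModelWrapper (illustrative, not part of the classes in this file):
# embed a handful of sentences and compare them with cosine similarity. The sentences below are
# made-up examples; any list of strings works.
#
#   em = EmbeddingModelWrapper(bs=8)
#   sentences = ["The cat sits on the mat.", "A cat is sitting on a mat.", "Stocks fell sharply today."]
#   embeddings = em.get_embeddings(sentences)
#   print(em.get_similarities(embeddings))                        # lower-triangular matrix of pairwise similarities
#   print(em.get_similarities(embeddings[0:1], embeddings[1:2]))  # similarity of sentence 0 vs. sentence 1
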
class ModelPredictionGenerator:
    def __init__(self, model, tokenizer, eval_dataset, use_accelerate=False, bs=8, generation_config=None):
        self.model = model
        self.tokenizer = tokenizer
        self.bs = bs
        self.eval_prompts = self.messages_to_prompts(eval_dataset)
        self.use_accelerate = use_accelerate
        self.accelerator = Accelerator()

        assert tokenizer.eos_token_id is not None
        assert tokenizer.chat_template is not None
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        # llama-precise
        if generation_config is None:
            self.generation_config = {
                "temperature": 0.7,
                "top_p": 0.1,
                "repetition_penalty": 1.18,
                "top_k": 40,
                "do_sample": True,
                "max_new_tokens": 100,
                "pad_token_id": tokenizer.pad_token_id,
            }
        else:
            self.generation_config = generation_config

    def clear_cache(self):
        torch.mps.empty_cache()
        gc.collect()

    def messages_to_prompts(self, ds):
        prompts = []
        for conversation in ds["messages"]:
            for i, msg in enumerate(conversation):
                if msg["role"] == "user":
                    prompts.append(
                        dict(
                            # prompt: conversation up to and including the current user message, plus a generation prompt
                            prompt=self.tokenizer.apply_chat_template(conversation[:i + 1], add_generation_prompt=True, tokenize=False),
                            # reference answer: the assistant message that follows
                            answer_ref=conversation[i + 1]["content"],
                        )
                    )
        return prompts
    def get_batches(self, dataset, batch_size):
        return [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]

    def tokenize_batch(self, batch):
        pad_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"     # left-pad for inference

        prompts = [item["prompt"] for item in batch]
        prompts_tok = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_length=True,
            pad_to_multiple_of=8,
            add_special_tokens=False,
        ).to(self.model.device)
        self.tokenizer.padding_side = pad_side   # restore orig. padding side

        return prompts_tok

    def generate_batch(self, batch_tok):
        with torch.no_grad():
            outputs_tok = self.model.generate(
                input_ids=batch_tok["input_ids"],
                attention_mask=batch_tok["attention_mask"],
                **self.generation_config
            ).to("cpu")

        outputs = [
            # cut prompt from output and drop padding tokens before decoding
            self.tokenizer.decode(
                outputs_tok[i][outputs_tok[i] != self.tokenizer.pad_token_id][batch_tok["length"][i]:],
                spaces_between_special_tokens=False,
                skip_special_tokens=True,
            ).strip()
            for i, t in enumerate(outputs_tok)
        ]
        return outputs
    def run(self):
        self.model.eval()
        self.clear_cache()

        if self.use_accelerate:
            # split prompt indices across processes; each process generates answers for its own shard
            with self.accelerator.split_between_processes(list(range(len(self.eval_prompts)))) as eval_prompts_local_idcs:
                eval_prompts_local = [self.eval_prompts[i] for i in eval_prompts_local_idcs]
        else:
            eval_prompts_local = self.eval_prompts

        for batch in tqdm(self.get_batches(eval_prompts_local, self.bs)):
            batch_tok = self.tokenize_batch(batch)
            answers = self.generate_batch(batch_tok)

            for i in range(len(batch)):
                batch[i]["answer_pred"] = answers[i]
                batch[i]["GPU"] = self.accelerator.process_index

        if self.use_accelerate:
            # collect the per-process results on all ranks
            return gather_object(eval_prompts_local)
        else:
            return eval_prompts_local
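
# A minimal end-to-end sketch (illustrative, not part of the classes above): generate answers for a
# small in-memory eval set with ModelPredictionGenerator and score them against the reference answers
# with EmbeddingModelWrapper. The model name is an assumption; any causal LM whose tokenizer provides
# a chat_template should work. Like the code above, this expects an Apple Silicon ("mps") device.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"   # assumed example model
    model = AutoModelForCausalLM.from_pretrained(model_path).to("mps")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # two single-turn conversations in the chat format messages_to_prompts() expects
    eval_dataset = {
        "messages": [
            [
                {"role": "user", "content": "What is the capital of France?"},
                {"role": "assistant", "content": "The capital of France is Paris."},
            ],
            [
                {"role": "user", "content": "Name a primary color."},
                {"role": "assistant", "content": "Red is a primary color."},
            ],
        ]
    }

    generator = ModelPredictionGenerator(model=model, tokenizer=tokenizer, eval_dataset=eval_dataset, bs=2)
    results = generator.run()   # each entry now holds prompt, answer_ref, answer_pred and GPU index

    # score predictions against the reference answers
    em = EmbeddingModelWrapper()
    emb_pred = em.get_embeddings([r["answer_pred"] for r in results])
    emb_ref = em.get_embeddings([r["answer_ref"] for r in results])
    for r, score in zip(results, em.get_similarities(emb_pred, emb_ref)):
        print(f"similarity {score:.3f} | {r['answer_pred'][:60]}")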