import torch
import numpy as np
import openvino as ov
from typing import List, Dict
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


def init_past_inputs(model_inputs: List):
    """
    Helper function for initialization of past inputs on the first inference step
    Parameters:
      model_inputs (List): list of model inputs
    Returns:
      pkv (List[ov.Tensor]): list of empty past key value tensors
    """
    pkv = []
    for input_tensor in model_inputs[4:]:
        partial_shape = input_tensor.partial_shape
        partial_shape[0] = 1
        partial_shape[2] = 0
        pkv.append(ov.Tensor(ov.Type.f32, partial_shape.get_shape()))
    return pkv
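
# For illustration: the compiled decoder's first four inputs are input_ids,
# attention_mask, encoder_hidden_states and encoder_attention_mask, so
# everything from index 4 onward is a past key/value input. With a
# BERT-base-like BLIP text decoder (12 heads, head size 64), a partial shape
# such as [?, 12, ?, 64] becomes an empty placeholder of shape (1, 12, 0, 64)
# on the first step.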


def postprocess_text_decoder_outputs(output: Dict):
    """
    Helper function for rearranging model outputs and wrapping them in CausalLMOutputWithCrossAttentions
    Parameters:
      output (Dict): dictionary with model output
    Returns:
      wrapped_outputs (CausalLMOutputWithCrossAttentions): outputs wrapped in CausalLMOutputWithCrossAttentions format
    """
    logits = torch.from_numpy(output[0])
    past_kv = list(output.values())[1:]
    return CausalLMOutputWithCrossAttentions(
        loss=None,
        logits=logits,
        past_key_values=past_kv,
        hidden_states=None,
        attentions=None,
        cross_attentions=None,
    )


def text_decoder_forward(
    ov_text_decoder_with_past: ov.CompiledModel,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    past_key_values: List[ov.Tensor],
    encoder_hidden_states: torch.Tensor,
    encoder_attention_mask: torch.Tensor,
    **kwargs
):
    """
    Inference function for text_decoder in one generation step
    Parameters:
      ov_text_decoder_with_past (ov.CompiledModel): compiled text decoder model with past key value inputs
      input_ids (torch.Tensor): input token ids
      attention_mask (torch.Tensor): attention mask for input token ids
      past_key_values (List[ov.Tensor]): list of cached decoder hidden states from the previous step
      encoder_hidden_states (torch.Tensor): encoder (vision or text) hidden states
      encoder_attention_mask (torch.Tensor): attention mask for encoder hidden states
    Returns:
      model outputs (CausalLMOutputWithCrossAttentions): model prediction wrapped in CausalLMOutputWithCrossAttentions,
      including predicted logits and hidden states for caching
    """
    inputs = [input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask]
    if past_key_values is None:
        inputs.extend(init_past_inputs(ov_text_decoder_with_past.inputs))
    else:
        inputs.extend(past_key_values)
    outputs = ov_text_decoder_with_past(inputs)
    return postprocess_text_decoder_outputs(outputs)
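
# Because text_decoder_forward mirrors the signature of the original PyTorch
# decoder's forward (minus the leading compiled-model argument), it can be
# bound onto a HuggingFace BLIP text decoder so that generate() transparently
# runs through OpenVINO. A minimal sketch, assuming `text_decoder` is a
# transformers BlipTextLMHeadModel and `ov_text_decoder_with_past` is the
# compiled decoder:
#
#   from functools import partial
#   text_decoder.forward = partial(text_decoder_forward, ov_text_decoder_with_past)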


class OVBlipModel:
    """
    Model class for running BLIP model inference with OpenVINO
    """

    def __init__(
        self,
        config,
        decoder_start_token_id: int,
        vision_model,
        text_encoder,
        text_decoder,
    ):
        """
        Initialize class parameters
        """
        self.vision_model = vision_model
        self.vision_model_out = vision_model.output(0)
        self.text_encoder = text_encoder
        self.text_encoder_out = text_encoder.output(0)
        self.text_decoder = text_decoder
        self.config = config
        self.decoder_start_token_id = decoder_start_token_id
        self.decoder_input_ids = config.text_config.bos_token_id

    def generate_answer(self, pixel_values: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs):
        """
        Visual Question Answering prediction
        Parameters:
          pixel_values (torch.Tensor): preprocessed image pixel values
          input_ids (torch.Tensor): question token ids after tokenization
          attention_mask (torch.Tensor): attention mask for question tokens
        Returns:
          generation output (torch.Tensor): tensor representing the sequence of generated answer token ids
        """
        image_embed = self.vision_model(pixel_values.detach().numpy())[self.vision_model_out]
        image_attention_mask = np.ones(image_embed.shape[:-1], dtype=int)
        if isinstance(input_ids, list):
            input_ids = torch.LongTensor(input_ids)
        question_embeds = self.text_encoder(
            [
                input_ids.detach().numpy(),
                attention_mask.detach().numpy(),
                image_embed,
                image_attention_mask,
            ]
        )[self.text_encoder_out]
        question_attention_mask = np.ones(question_embeds.shape[:-1], dtype=int)
        bos_ids = np.full((question_embeds.shape[0], 1), fill_value=self.decoder_start_token_id)
        outputs = self.text_decoder.generate(
            input_ids=torch.from_numpy(bos_ids),
            eos_token_id=self.config.text_config.sep_token_id,
            pad_token_id=self.config.text_config.pad_token_id,
            encoder_hidden_states=torch.from_numpy(question_embeds),
            encoder_attention_mask=torch.from_numpy(question_attention_mask),
            **generate_kwargs,
        )
        return outputs

    def generate_caption(self, pixel_values: torch.Tensor, input_ids: torch.Tensor = None, attention_mask: torch.Tensor = None, **generate_kwargs):
        """
        Image Captioning prediction
        Parameters:
          pixel_values (torch.Tensor): preprocessed image pixel values
          input_ids (torch.Tensor, *optional*, None): pre-generated caption token ids after tokenization; if provided, caption generation continues the provided text
          attention_mask (torch.Tensor): attention mask for caption tokens, used only if input_ids is provided
        Returns:
          generation output (torch.Tensor): tensor representing the sequence of generated caption token ids
        """
        batch_size = pixel_values.shape[0]
        image_embeds = self.vision_model(pixel_values.detach().numpy())[self.vision_model_out]
        image_attention_mask = torch.ones(image_embeds.shape[:-1], dtype=torch.long)
        if isinstance(input_ids, list):
            input_ids = torch.LongTensor(input_ids)
        elif input_ids is None:
            input_ids = torch.LongTensor(
                [
                    [
                        self.config.text_config.bos_token_id,
                        self.config.text_config.eos_token_id,
                    ]
                ]
            ).repeat(batch_size, 1)
        input_ids[:, 0] = self.config.text_config.bos_token_id
        attention_mask = attention_mask[:, :-1] if attention_mask is not None else None
        outputs = self.text_decoder.generate(
            input_ids=input_ids[:, :-1],
            eos_token_id=self.config.text_config.sep_token_id,
            pad_token_id=self.config.text_config.pad_token_id,
            attention_mask=attention_mask,
            encoder_hidden_states=torch.from_numpy(image_embeds),
            encoder_attention_mask=image_attention_mask,
            **generate_kwargs,
        )
        return outputs
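

# A minimal end-to-end usage sketch, not part of the wrapper above. Assumptions
# (not taken from the original): the IR file names below are placeholders for
# models exported beforehand from the Salesforce/blip-vqa-base checkpoint, and
# "demo.jpg" is any test image.
if __name__ == "__main__":
    from functools import partial

    from PIL import Image
    from transformers import BlipForQuestionAnswering, BlipProcessor

    processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
    pt_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

    core = ov.Core()
    # Route the PyTorch text decoder's forward pass through OpenVINO,
    # as sketched after text_decoder_forward above.
    ov_text_decoder_with_past = core.compile_model("blip_text_decoder_with_past.xml", "CPU")
    text_decoder = pt_model.text_decoder
    text_decoder.forward = partial(text_decoder_forward, ov_text_decoder_with_past)

    ov_model = OVBlipModel(
        pt_model.config,
        pt_model.decoder_start_token_id,
        core.compile_model("blip_vision_model.xml", "CPU"),
        core.compile_model("blip_text_encoder.xml", "CPU"),
        text_decoder,
    )

    image = Image.open("demo.jpg")
    inputs = processor(image, "how many dogs are in the picture?", return_tensors="pt")
    answer = ov_model.generate_answer(**inputs, max_length=20)
    print(processor.decode(answer[0], skip_special_tokens=True))

    caption = ov_model.generate_caption(inputs["pixel_values"], max_length=20)
    print(processor.decode(caption[0], skip_special_tokens=True))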