Spaces:
Runtime error
Runtime error
| import os | |
| import cv2 | |
| import torch | |
| import albumentations as A | |
| import config as CFG | |
class PoemTextDataset(torch.utils.data.Dataset):
    """
    Map-style torch Dataset serving tokenized poem-text pairs for PoemTextModel.

    All beyts and all texts are tokenized once, when the dataset is constructed;
    indexing only picks row ``idx`` out of the stored tokenizer outputs and wraps
    each field in a tensor.
    ...
    Attributes:
    -----------
    dataset_dict : list of dict
        the raw records; each record is read through its "beyt", "text" and "id" keys
    encoded_poems : mapping
        output of the poem tokenizer called on every beyt with padding=True,
        truncation=True and max_length=CFG.poems_max_length
    encoded_texts : mapping
        output of the text tokenizer called on every text with padding=True,
        truncation=True and max_length=CFG.text_max_length
    Methods:
    --------
    __getitem__(idx)
        returns the model input for the record at index idx.
    __len__()
        number of records in the dataset
    """

    def __init__(self, dataset_dict):
        """
        Store the records and tokenize all beyts and texts up front.
        The tokenizer classes are selected through CFG.tokenizers using the configured encoder model names.
        Parameters:
        -----------
        dataset_dict: list of dict
            a list of dictionaries which have "beyt", "text" and "id" keys.
        """
        self.dataset_dict = dataset_dict

        # Both tokenizers are loaded before any record is touched.
        poem_tokenizer_cls = CFG.tokenizers[CFG.poem_encoder_model]
        poem_tok = poem_tokenizer_cls.from_pretrained(CFG.poem_tokenizer)
        text_tokenizer_cls = CFG.tokenizers[CFG.text_encoder_model]
        text_tok = text_tokenizer_cls.from_pretrained(CFG.text_tokenizer)

        # One batched tokenizer call per modality covers the whole dataset.
        beyts = [record['beyt'] for record in dataset_dict]
        self.encoded_poems = poem_tok(
            beyts, padding=True, truncation=True, max_length=CFG.poems_max_length
        )
        texts = [record['text'] for record in dataset_dict]
        self.encoded_texts = text_tok(
            texts, padding=True, truncation=True, max_length=CFG.text_max_length
        )

    def __getitem__(self, idx):
        """
        Assemble the model input for the record at index idx.
        Parameters:
        -----------
        idx: int
            index of the record to fetch
        Returns:
        --------
        item: dict
            "beyt" and "text" map every tokenizer output field to a tensor built from row idx;
            "id" is the id stored in the record at index idx
        """
        poem_fields = {}
        for field, rows in self.encoded_poems.items():
            poem_fields[field] = torch.tensor(rows[idx])

        text_fields = {}
        for field, rows in self.encoded_texts.items():
            text_fields[field] = torch.tensor(rows[idx])

        return {
            "beyt": poem_fields,
            "text": text_fields,
            "id": self.dataset_dict[idx]['id'],
        }

    def __len__(self):
        """
        Report how many records the dataset holds.
        Returns:
        --------
        length: int
            length of the dataset_dict saved at construction
        """
        return len(self.dataset_dict)
class CLIPDataset(torch.utils.data.Dataset):
    """
    torch Dataset for CLIPModel.
    ...
    Attributes:
    -----------
    dataset_dict : list of dict
        dataset containing poem-image or text-image pairs; each record is read through
        its "image" key plus either "beyt" or "text"
    encoded : mapping
        output of the selected tokenizer for the beyts/texts found in dataset_dict, called with
        padding=True, truncation=True and the max_length specified in configs.
    transforms: albumentations.BasicTransform
        transforms to apply to the images
    Methods:
    --------
    __getitem__(idx)
        returns item with index idx.
    __len__()
        represents length of dataset
    """

    def __init__(self, dataset_dict, transforms, is_image_poem_pair=True):
        """
        Init class, save dataset_dict and transforms and tokenize either the poems or the texts.
        The tokenizer is chosen based on configs.
        Parameters:
        -----------
        dataset_dict: list of dict
            a list containing dictionaries which have an "image" key and a "beyt" key
            (poem-image pairs) or a "text" key (text-image pairs).
        transforms: albumentations.BasicTransform
            transforms to apply to the images
        is_image_poem_pair: Bool, optional
            if set to False, dataset has text-image pairs and must use the corresponding text tokenizer.
            else has poem-images pairs and uses the poem tokenizer.
        """
        self.dataset_dict = dataset_dict
        # using the poem tokenizer to encode poems or text tokenizer to encode text (based on configs).
        if is_image_poem_pair:
            poem_tokenizer = CFG.tokenizers[CFG.poem_encoder_model].from_pretrained(CFG.poem_tokenizer)
            self.encoded = poem_tokenizer(
                [item['beyt'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.poems_max_length
            )
        else:
            text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
            self.encoded = text_tokenizer(
                [item['text'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.text_max_length
            )
        self.transforms = transforms

    def __getitem__(self, idx):
        """
        returns a dict having data with index idx. the dict is used as an input to the CLIPModel.
        Parameters:
        -----------
        idx: int
            index of the data to get
        Returns:
        --------
        item: dict
            "text" maps every tokenizer output field to a tensor built from row idx
            (poem or text encoding, depending on how the dataset was built);
            "image" is the transformed image as a float tensor with its last axis moved to the front
        Raises:
        -------
        FileNotFoundError
            if the image file is missing or cannot be decoded
        """
        item = {}
        # getting text from encoded texts
        item["text"] = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded.items()
        }
        # opening the image
        image_file = f"{CFG.image_path}{self.dataset_dict[idx]['image']}"
        image = cv2.imread(image_file)
        # cv2.imread signals failure by returning None instead of raising; fail here with the
        # offending path rather than letting cvtColor die on an opaque assertion.
        if image is None:
            raise FileNotFoundError(f"could not read image file: {image_file!r}")
        # converting BGR to RGB for transforms
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # apply transforms
        image = self.transforms(image=image)['image']
        # move the last axis (channels, for an H x W x C array) to the front
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        return item

    def __len__(self):
        """
        returns the length of the dataset
        Returns:
        --------
        length: int
            length using the length of dataset_dict we saved in class
        """
        return len(self.dataset_dict)
def get_transforms(mode="train"):
    """
    returns transforms to use on image based on mode
    Parameters:
    -----------
    mode: str, optional
        to distinguish between train and val/test transforms; every mode currently
        gets the same pipeline, so the value does not affect the result
    Returns:
    --------
    transforms: albumentations.Compose
        pipeline that resizes the image to CFG.size x CFG.size and then normalizes its values
    """
    # `mode` is kept for interface compatibility; there is no train-only augmentation,
    # so a single pipeline replaces the two identical branches.
    # p=1.0 makes each step unconditional without the deprecated `always_apply` flag.
    return A.Compose(
        [
            A.Resize(CFG.size, CFG.size, p=1.0),  # resizing image to CFG.size
            A.Normalize(max_pixel_value=255.0, p=1.0),  # normalizing image values
        ]
    )