| import os |
| import numpy as np |
| import pickle |
| from torch.utils import data |
| import torchaudio.transforms as T |
| import torchaudio |
| import torch |
| import csv |
| import pytorch_lightning as pl |
| from music2latent import EncoderDecoder |
| import json |
| import math |
| from sklearn.preprocessing import StandardScaler |
|
|
| from dataset_loaders.jamendo import JamendoDataset |
| from dataset_loaders.pmemo import PMEmoDataset |
| from dataset_loaders.deam import DEAMDataset |
| from dataset_loaders.emomusic import EmoMusicDataset |
|
|
| from omegaconf import DictConfig |
|
|
# Lookup table from the dataset names listed in the config to the dataset
# class that is instantiated for each of them.
DATASET_REGISTRY = {
    "jamendo": JamendoDataset,
    "pmemo": PMEmoDataset,
    "deam": DEAMDataset,
    "emomusic": EmoMusicDataset,
}
|
|
class DataModule(pl.LightningDataModule):
    """Lightning data module that serves several datasets side by side.

    For every name in ``cfg.datasets`` the matching class from
    ``DATASET_REGISTRY`` is instantiated three times (``tr_val`` set to
    ``'train'``, ``'validation'`` and ``'test'``). Each ``*_dataloader`` hook
    returns one ``DataLoader`` per configured dataset, in the same order as
    ``cfg.datasets``.

    Args:
        cfg: Full configuration. ``cfg.datasets`` lists the dataset names to
            use, and ``cfg.dataset[name]`` holds the per-dataset settings.
            Those settings are unpacked as keyword arguments into the dataset
            constructor and also provide ``batch_size`` and ``num_workers``
            for the loaders.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg

        self.train_datasets = []
        self.val_datasets = []
        self.test_datasets = []

    def setup(self, stage=None) -> None:
        """Instantiate the train/validation/test split of every configured dataset.

        All three splits are built regardless of ``stage``.

        Args:
            stage: Accepted for compatibility with the Lightning hook; unused.

        Raises:
            ValueError: If a name in ``cfg.datasets`` is not a key of
                ``DATASET_REGISTRY``.
        """
        # Start from empty lists so a repeated setup() call does not append
        # a second copy of every dataset.
        self.train_datasets = []
        self.val_datasets = []
        self.test_datasets = []

        for dataset_name in self.cfg.datasets:
            # Validate the name first so an unknown dataset always surfaces as
            # this ValueError rather than as a config-lookup failure.
            if dataset_name not in DATASET_REGISTRY:
                raise ValueError(f"Dataset {dataset_name} not found in registry")

            dataset_cls = DATASET_REGISTRY[dataset_name]
            dataset_cfg = self.cfg.dataset[dataset_name]

            # Build all three splits before storing any of them, so a failing
            # constructor does not leave the lists with mismatched lengths.
            train_dataset = dataset_cls(**dataset_cfg, cfg=self.cfg, tr_val='train')
            val_dataset = dataset_cls(**dataset_cfg, cfg=self.cfg, tr_val='validation')
            test_dataset = dataset_cls(**dataset_cfg, cfg=self.cfg, tr_val='test')

            self.train_datasets.append(train_dataset)
            self.val_datasets.append(val_dataset)
            self.test_datasets.append(test_dataset)

    def _build_dataloaders(self, datasets, shuffle) -> list:
        """Wrap each dataset in a DataLoader using its per-dataset settings.

        Args:
            datasets: Datasets in the same order as ``cfg.datasets``.
            shuffle: Passed through to every ``DataLoader``.

        Returns:
            A list with one ``DataLoader`` per dataset.
        """
        loaders = []
        for ds, ds_name in zip(datasets, self.cfg.datasets):
            ds_cfg = self.cfg.dataset[ds_name]
            num_workers = ds_cfg.num_workers
            loaders.append(
                data.DataLoader(
                    ds,
                    batch_size=ds_cfg.batch_size,
                    shuffle=shuffle,
                    num_workers=num_workers,
                    # DataLoader raises ValueError if persistent_workers is
                    # True while num_workers is 0, so only enable it when
                    # worker processes actually exist.
                    persistent_workers=num_workers > 0,
                )
            )
        return loaders

    def train_dataloader(self) -> list:
        """Return one shuffled DataLoader per training dataset."""
        return self._build_dataloaders(self.train_datasets, shuffle=True)

    def val_dataloader(self) -> list:
        """Return one unshuffled DataLoader per validation dataset."""
        return self._build_dataloaders(self.val_datasets, shuffle=False)

    def test_dataloader(self) -> list:
        """Return one unshuffled DataLoader per test dataset."""
        return self._build_dataloaders(self.test_datasets, shuffle=False)
|
|
|
|
|
|
|
|