# NOTE(review): the three lines that were here ("Spaces:" / "Runtime error" x2)
# are residue from a Hugging Face Spaces error page this file was pasted from;
# they are kept only as this comment so the file remains valid Python.
import os
import re
import time

import lightning as L
import pandas as pd
import safetensors.torch
import torch
import yaml
from datasets import load_dataset
from peft import PeftModel
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from .callbacks import TrainingCallback
from .data import ImageConditionDataset, Subject200KDataset, CartoonDataset, SceneDataset
from .model import OminiModel
class LocalSubjectsDataset(Dataset):
    """Paired-image dataset backed by a local CSV index file.

    The CSV must contain columns ``imageA``, ``prompt`` and ``imageB``;
    the image columns hold file names resolved relative to *image_dir*.
    """

    def __init__(self, csv_file, image_dir, transform=None):
        """
        Args:
            csv_file: Path to the CSV index file.
            image_dir: Directory containing the image files.
            transform: Optional callable applied to both images.
        """
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.transform = transform
        # Advertised schema, mimicking the `features` attribute of
        # HuggingFace datasets so downstream code can introspect it.
        self.features = {
            'imageA': 'PIL.Image',
            'prompt': 'str',
            'imageB': 'PIL.Image'
        }

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _scalar(value):
        # Duplicate column labels in the CSV make `row[col]` return a
        # Series; unwrap to the first element so callers get a scalar.
        # (Previously duplicated inline for imageA/imageB and missing for
        # prompt — now applied consistently.)
        if isinstance(value, pd.Series):
            value = value.values[0]
        return value

    def __getitem__(self, idx):
        # Fetch the image-A file name, the prompt and the image-B file name.
        # Look the row up once instead of three separate .iloc calls.
        row = self.data.iloc[idx]
        imgA_name = os.path.join(self.image_dir, str(self._scalar(row['imageA'])))
        prompt = self._scalar(row['prompt'])
        imgB_name = os.path.join(self.image_dir, str(self._scalar(row['imageB'])))
        imageA = Image.open(imgA_name).convert("RGB")
        imageB = Image.open(imgB_name).convert("RGB")
        if self.transform:
            imageA = self.transform(imageA)
            imageB = self.transform(imageB)
        sample = {'imageA': imageA, 'prompt': prompt, 'imageB': imageB}
        return sample
# Default preprocessing applied to both images of each pair: resize to a
# fixed 600x600. NOTE(review): ToTensor is left commented out, presumably
# because downstream consumers expect PIL images rather than tensors —
# confirm against SceneDataset before re-enabling it.
transform = transforms.Compose([
    transforms.Resize((600, 600)),
    # transforms.ToTensor(),
])
def get_rank():
    """Return this process's local rank.

    Reads the ``LOCAL_RANK`` environment variable (set by distributed
    launchers); falls back to 0 for single-process runs.

    Returns:
        int: the local rank, or 0 when unset or malformed.
    """
    try:
        # Defaulting to 0 avoids int(None) -> TypeError when the variable
        # is unset; a malformed value still raises ValueError, which the
        # original bare `except:` silently mapped to rank 0 — keep that
        # behavior but catch only the specific exception.
        return int(os.environ.get("LOCAL_RANK", 0))
    except ValueError:
        return 0
def get_config():
    """Load the YAML training configuration pointed to by ``XFL_CONFIG``.

    Returns:
        The parsed configuration (typically a dict).

    Raises:
        RuntimeError: if the ``XFL_CONFIG`` environment variable is unset.
    """
    config_path = os.environ.get("XFL_CONFIG")
    if config_path is None:
        # Explicit raise instead of `assert` so the check survives `python -O`.
        raise RuntimeError("Please set the XFL_CONFIG environment variable")
    with open(config_path, "r") as f:
        return yaml.safe_load(f)
def init_wandb(wandb_config, run_name):
    """Start a Weights & Biases run for experiment tracking.

    Args:
        wandb_config: Mapping with at least a ``project`` key.
        run_name: Display name for this run.
    """
    # Imported lazily so wandb is only required when logging is enabled.
    import wandb

    project_name = wandb_config["project"]
    wandb.init(project=project_name, name=run_name, config={})
def main():
    """Entry point: build the dataset, dataloader, model and trainer, then fit."""
    # Initialize distributed context.
    is_main_process, rank = get_rank() == 0, get_rank()
    torch.cuda.set_device(rank)
    config = get_config()
    training_config = config["train"]
    run_name = time.strftime("%Y%m%d-%H%M%S")

    # Initialize WanDB (main process only).
    wandb_config = training_config.get("wandb", None)
    if wandb_config is not None and is_main_process:
        init_wandb(wandb_config, run_name)

    print("Rank:", rank)
    if is_main_process:
        print("Config:", config)

    # Initialize dataset and dataloader.
    dataset_config = training_config["dataset"]
    if dataset_config["type"] == "scene":
        # NOTE(review): the CSV/image locations keep the original
        # placeholder strings as defaults but can now be set in the config.
        data_valid = LocalSubjectsDataset(
            csv_file=dataset_config.get("csv_file", "csv_path"),
            image_dir=dataset_config.get("image_dir", "images_path"),
            transform=transform,
        )
        print(data_valid.features)
        print(len(data_valid))
        print(dataset_config)
        dataset = SceneDataset(
            data_valid,
            condition_size=dataset_config["condition_size"],
            target_size=dataset_config["target_size"],
            image_size=dataset_config["image_size"],
            padding=dataset_config["padding"],
            condition_type=training_config["condition_type"],
            drop_text_prob=dataset_config["drop_text_prob"],
            drop_image_prob=dataset_config["drop_image_prob"],
        )
    elif dataset_config["type"] == "img":
        # Load dataset text-to-image-2M.
        dataset = load_dataset(
            "webdataset",
            data_files={"train": dataset_config["urls"]},
            split="train",
            cache_dir="cache/t2i2m",
            num_proc=32,
        )
        dataset = ImageConditionDataset(
            dataset,
            condition_size=dataset_config["condition_size"],
            target_size=dataset_config["target_size"],
            condition_type=training_config["condition_type"],
            drop_text_prob=dataset_config["drop_text_prob"],
            drop_image_prob=dataset_config["drop_image_prob"],
            position_scale=dataset_config.get("position_scale", 1.0),
        )
    elif dataset_config["type"] == "cartoon":
        dataset = load_dataset("saquiboye/oye-cartoon", split="train")
        dataset = CartoonDataset(
            dataset,
            condition_size=dataset_config["condition_size"],
            target_size=dataset_config["target_size"],
            image_size=dataset_config["image_size"],
            padding=dataset_config["padding"],
            condition_type=training_config["condition_type"],
            drop_text_prob=dataset_config["drop_text_prob"],
            drop_image_prob=dataset_config["drop_image_prob"],
        )
    # BUG FIX: removed a second, unreachable `elif type == "scene"` branch —
    # it duplicated the first condition and its body (`dataset = dataset`)
    # was a no-op that would have used an undefined name anyway.
    else:
        raise NotImplementedError

    print("Dataset length:", len(dataset))
    train_loader = DataLoader(
        dataset,
        batch_size=training_config["batch_size"],
        shuffle=True,
        num_workers=training_config["dataloader_workers"],
    )
    print("Trainloader generated.")

    # Initialize model.
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )

    # Callbacks run on the main process only (they handle logging/saving).
    training_callbacks = (
        [TrainingCallback(run_name, training_config=training_config)]
        if is_main_process
        else []
    )

    # Initialize trainer.
    trainer = L.Trainer(
        accumulate_grad_batches=training_config["accumulate_grad_batches"],
        callbacks=training_callbacks,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
        max_steps=training_config.get("max_steps", -1),
        max_epochs=training_config.get("max_epochs", -1),
        gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
    )
    setattr(trainer, "training_config", training_config)

    # Save config next to the run's outputs.
    save_path = training_config.get("save_path", "./output")
    if is_main_process:
        # exist_ok avoids a crash when two runs start within the same second
        # or the script is restarted with the same run_name.
        os.makedirs(f"{save_path}/{run_name}", exist_ok=True)
        with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
            yaml.dump(config, f)

    # Start training.
    trainer.fit(trainable_model, train_loader)
| if __name__ == "__main__": | |
| main() | |