brity committed on
Commit 58e01c5 · verified · 1 Parent(s): 4426ed0

End of training

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+images_0.png filter=lfs diff=lfs merge=lfs -text
+images_1.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,45 @@
---
base_model: stable-diffusion-v1-5/stable-diffusion-v1-5
library_name: diffusers
license: creativeml-openrail-m
inference: true
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- controlnet
- diffusers-training
---

<!-- This model card has been generated automatically according to the information the training script had access to. You
should probably proofread and complete it, then remove this comment. -->

# controlnet-brity/controlnet-2210

These are ControlNet weights trained on stable-diffusion-v1-5/stable-diffusion-v1-5 with a new type of conditioning.
You can find some example images below.

prompt: red circle with blue background
![images_0](./images_0.png)

prompt: cyan circle with brown floral background
![images_1](./images_1.png)

## Intended uses & limitations

#### How to use

```python
# TODO: add an example code snippet for running this diffusion pipeline
```
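The usage snippet above is still a TODO. As a hedged sketch, loading these weights with diffusers' standard SD1.5 ControlNet pipeline could look like the following; note that the hub id `brity/controlnet-2210` is only inferred from the card title and the conditioning-image path is a placeholder, so both are assumptions to verify:

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# Hypothetical hub id inferred from the card title -- verify before use.
controlnet = ControlNetModel.from_pretrained(
    "brity/controlnet-2210", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()

# Placeholder conditioning image path.
conditioning = load_image("./conditioning_image_1.png")
image = pipe(
    "red circle with blue background",
    image=conditioning,
    num_inference_steps=20,
).images[0]
image.save("./output.png")
```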

#### Limitations and bias

[TODO: provide examples of latent issues and potential remediations]

## Training details

[TODO: describe the data used to train the model]
README_flux.md ADDED
@@ -0,0 +1,430 @@
# ControlNet training example for FLUX

The `train_controlnet_flux.py` script shows how to implement the ControlNet training procedure and adapt it for [FLUX](https://github.com/black-forest-labs/flux).

Training script provided by LibAI, which is an institution dedicated to the progress and achievement of artificial general intelligence. LibAI is the developer of [cutout.pro](https://www.cutout.pro/) and [promeai.pro](https://www.promeai.pro/).

> [!NOTE]
> **Memory consumption**
>
> Flux can be quite expensive to run on consumer hardware devices and, as a result, ControlNet training of it comes with higher memory requirements than usual.

> [!NOTE]
> **Gated access**
>
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in: `huggingface-cli login`

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the `examples/controlnet` folder and run
```bash
pip install -r requirements_flux.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or for a default accelerate configuration without answering questions about your environment

```bash
accelerate config default
```

Or if your environment doesn't support an interactive shell (e.g., a notebook)

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

When running `accelerate config`, specifying torch compile mode as True can bring dramatic speedups.

## Custom Datasets

We support the following dataset formats:

The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script. To use our example, add `--dataset_name=fusing/fill50k \` to the script and remove the `--jsonl_for_train` line mentioned below.

We also support importing data from a jsonl file (xxx.jsonl); use `--jsonl_for_train` to enable it. Here is a brief example of a jsonl file:
```sh
{"image": "xxx", "text": "xxx", "conditioning_image": "xxx"}
{"image": "xxx", "text": "xxx", "conditioning_image": "xxx"}
```
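A manifest in this shape is one standalone JSON object per line. As an illustrative sketch (the file names and captions below are made up), such a file could be generated like this:

```python
import json

# Illustrative records only -- real entries would point at your own
# images, captions, and conditioning images.
records = [
    {"image": "imgs/0001.png", "text": "red circle with blue background",
     "conditioning_image": "cond/0001.png"},
    {"image": "imgs/0002.png", "text": "cyan circle with brown floral background",
     "conditioning_image": "cond/0002.png"},
]

# Write one JSON object per line, which is the shape --jsonl_for_train expects.
with open("train.jsonl", "w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")
```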
## Training

Our training examples use two test conditioning images. They can be downloaded by running

```sh
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
```

Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to the Hugging Face Hub.

We can define `num_layers` and `num_single_layers`, which determine the size of the control network (default values are num_layers=4, num_single_layers=10).

```bash
accelerate launch train_controlnet_flux.py \
    --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
    --dataset_name=fusing/fill50k \
    --conditioning_image_column=conditioning_image \
    --image_column=image \
    --caption_column=text \
    --output_dir="path to save model" \
    --mixed_precision="bf16" \
    --resolution=512 \
    --learning_rate=1e-5 \
    --max_train_steps=15000 \
    --validation_steps=100 \
    --checkpointing_steps=200 \
    --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
    --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
    --train_batch_size=1 \
    --gradient_accumulation_steps=4 \
    --report_to="wandb" \
    --num_double_layers=4 \
    --num_single_layers=0 \
    --seed=42 \
    --push_to_hub
```

To better track our training experiments, we're using the following flags in the command above:

* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases.
* `validation_image`, `validation_prompt`, and `validation_steps` allow the script to do a few validation inference runs. This lets us qualitatively check if the training is progressing as expected.

Our experiments were conducted on a single 80GB A100 GPU.

### Inference

Once training is done, we can perform inference like so:

```python
import torch
from diffusers.utils import load_image
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
from diffusers.models.controlnet_flux import FluxControlNetModel

base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model = 'promeai/FLUX.1-controlnet-lineart-promeai'
controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(
    base_model,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16
)
# enable memory optimizations
pipe.enable_model_cpu_offload()

control_image = load_image("https://huggingface.co/promeai/FLUX.1-controlnet-lineart-promeai/resolve/main/images/example-control.jpg").resize((1024, 1024))
prompt = "cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black gold leaf pattern dress and a white apron mouth open holding a fancy black forest cake with candles on top in the kitchen of an old dark Victorian mansion lit by candlelight with a bright window to the foggy forest and very expensive stuff everywhere"

image = pipe(
    prompt,
    control_image=control_image,
    controlnet_conditioning_scale=0.6,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("./output.png")
```

## Apply DeepSpeed ZeRO-3

This is an experimental process; I am not sure if it is suitable for everyone. We used this process to successfully train at 512 resolution on 8×A100 (40GB).
Please modify some of the code in the script.

### 1. Customize ZeRO-3 settings

Copy **accelerate_config_zero3.yaml** and modify `num_processes` according to the number of GPUs you want to use:

```yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 8
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

### 2. Precompute all inputs (latents, embeddings)

In `train_controlnet_flux.py`, we need to pre-calculate all parameters and put them into batches, so we first need to rewrite the `compute_embeddings` function.

```python
def compute_embeddings(batch, proportion_empty_prompts, vae, flux_controlnet_pipeline, weight_dtype, is_train=True):

    ### compute text embeddings
    prompt_batch = batch[args.caption_column]
    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])
    prompt_batch = captions
    prompt_embeds, pooled_prompt_embeds, text_ids = flux_controlnet_pipeline.encode_prompt(
        prompt_batch, prompt_2=prompt_batch
    )
    prompt_embeds = prompt_embeds.to(dtype=weight_dtype)
    pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=weight_dtype)
    text_ids = text_ids.to(dtype=weight_dtype)

    # text_ids [512, 3] to [bs, 512, 3]
    text_ids = text_ids.unsqueeze(0).expand(prompt_embeds.shape[0], -1, -1)

    ### compute latents
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
        return latents

    # vae encode
    pixel_values = batch["pixel_values"]
    pixel_values = torch.stack([image for image in pixel_values]).to(dtype=weight_dtype).to(vae.device)
    pixel_latents_tmp = vae.encode(pixel_values).latent_dist.sample()
    pixel_latents_tmp = (pixel_latents_tmp - vae.config.shift_factor) * vae.config.scaling_factor
    pixel_latents = _pack_latents(
        pixel_latents_tmp,
        pixel_values.shape[0],
        pixel_latents_tmp.shape[1],
        pixel_latents_tmp.shape[2],
        pixel_latents_tmp.shape[3],
    )

    control_values = batch["conditioning_pixel_values"]
    control_values = torch.stack([image for image in control_values]).to(dtype=weight_dtype).to(vae.device)
    control_latents = vae.encode(control_values).latent_dist.sample()
    control_latents = (control_latents - vae.config.shift_factor) * vae.config.scaling_factor
    control_latents = _pack_latents(
        control_latents,
        control_values.shape[0],
        control_latents.shape[1],
        control_latents.shape[2],
        control_latents.shape[3],
    )

    # copied from pipeline_flux_controlnet
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
        latent_image_ids = latent_image_ids.reshape(
            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )

        return latent_image_ids.to(device=device, dtype=dtype)

    latent_image_ids = _prepare_latent_image_ids(
        batch_size=pixel_latents_tmp.shape[0],
        height=pixel_latents_tmp.shape[2],
        width=pixel_latents_tmp.shape[3],
        device=pixel_values.device,
        dtype=pixel_values.dtype,
    )

    # unet_added_cond_kwargs = {"pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids}
    return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids, "pixel_latents": pixel_latents, "control_latents": control_latents, "latent_image_ids": latent_image_ids}
```
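`_pack_latents` above turns a `(B, C, H, W)` latent into a sequence of `(H/2)·(W/2)` tokens with `4·C` channels each: every 2×2 spatial patch becomes one token. A NumPy sketch of the same reshape (the input shapes here are purely illustrative):

```python
import numpy as np

def pack_latents(latents):
    # Mirrors _pack_latents above: group each 2x2 spatial patch of the
    # latent into one token with 4*C channels.
    b, c, h, w = latents.shape
    x = latents.reshape(b, c, h // 2, 2, w // 2, 2)
    x = x.transpose(0, 2, 4, 1, 3, 5)   # (B, H/2, W/2, C, 2, 2)
    return x.reshape(b, (h // 2) * (w // 2), c * 4)

# e.g. 16-channel VAE latents at 64x64 become 1024 tokens of 64 channels
packed = pack_latents(np.zeros((1, 16, 64, 64)))
# packed.shape == (1, 1024, 64)
```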
Because the images need to pass through the VAE, we have to preprocess the images in the dataset first. At the same time, the VAE requires more GPU memory, so you may need to lower the `batch_size` below.
```diff
+train_dataset = prepare_train_dataset(train_dataset, accelerator)
 with accelerator.main_process_first():
     from datasets.fingerprint import Hasher

     # fingerprint used by the cache for the other processes to load the result
     # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
     new_fingerprint = Hasher.hash(args)
     train_dataset = train_dataset.map(
-        compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=100
+        compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=10
     )

 del text_encoders, tokenizers
 gc.collect()
 torch.cuda.empty_cache()

 # Then get the training dataset ready to be passed to the dataloader.
-train_dataset = prepare_train_dataset(train_dataset, accelerator)
```

### 3. Redefine the behavior of getting a batch

Now that we have all the preprocessing done, we need to modify the `collate_fn` function.

```python
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()

    pixel_latents = torch.stack([torch.tensor(example["pixel_latents"]) for example in examples])
    pixel_latents = pixel_latents.to(memory_format=torch.contiguous_format).float()

    control_latents = torch.stack([torch.tensor(example["control_latents"]) for example in examples])
    control_latents = control_latents.to(memory_format=torch.contiguous_format).float()

    latent_image_ids = torch.stack([torch.tensor(example["latent_image_ids"]) for example in examples])

    prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])

    pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
    text_ids = torch.stack([torch.tensor(example["text_ids"]) for example in examples])

    return {
        "pixel_values": pixel_values,
        "conditioning_pixel_values": conditioning_pixel_values,
        "pixel_latents": pixel_latents,
        "control_latents": control_latents,
        "latent_image_ids": latent_image_ids,
        "prompt_ids": prompt_ids,
        "unet_added_conditions": {"pooled_prompt_embeds": pooled_prompt_embeds, "time_ids": text_ids},
    }
```

Finally, we just need to modify the way the various parameters are obtained during training.

```python
for epoch in range(first_epoch, args.num_train_epochs):
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(flux_controlnet):
            # Convert images to latent space
            pixel_latents = batch["pixel_latents"].to(dtype=weight_dtype)
            control_image = batch["control_latents"].to(dtype=weight_dtype)
            latent_image_ids = batch["latent_image_ids"].to(dtype=weight_dtype)

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(pixel_latents).to(accelerator.device).to(dtype=weight_dtype)
            bsz = pixel_latents.shape[0]

            # Sample a random timestep for each image
            t = torch.sigmoid(torch.randn((bsz,), device=accelerator.device, dtype=weight_dtype))

            # apply flow matching
            noisy_latents = (
                1 - t.unsqueeze(1).unsqueeze(2).repeat(1, pixel_latents.shape[1], pixel_latents.shape[2])
            ) * pixel_latents + t.unsqueeze(1).unsqueeze(2).repeat(
                1, pixel_latents.shape[1], pixel_latents.shape[2]
            ) * noise

            guidance_vec = torch.full(
                (noisy_latents.shape[0],), 3.5, device=noisy_latents.device, dtype=weight_dtype
            )

            controlnet_block_samples, controlnet_single_block_samples = flux_controlnet(
                hidden_states=noisy_latents,
                controlnet_cond=control_image,
                timestep=t,
                guidance=guidance_vec,
                pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype),
                encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype),
                txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype),
                img_ids=latent_image_ids[0],
                return_dict=False,
            )

            noise_pred = flux_transformer(
                hidden_states=noisy_latents,
                timestep=t,
                guidance=guidance_vec,
                pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype),
                encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype),
                controlnet_block_samples=[sample.to(dtype=weight_dtype) for sample in controlnet_block_samples]
                if controlnet_block_samples is not None
                else None,
                controlnet_single_block_samples=[
                    sample.to(dtype=weight_dtype) for sample in controlnet_single_block_samples
                ]
                if controlnet_single_block_samples is not None
                else None,
                txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype),
                img_ids=latent_image_ids[0],
                return_dict=False,
            )[0]
```
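The `noisy_latents` construction in the loop is plain flow-matching interpolation with a logit-normal timestep: `t = sigmoid(normal)` lies strictly in (0, 1), and the latents are blended linearly from clean data (t=0) toward pure noise (t=1). The `unsqueeze`/`repeat` calls only broadcast `t` over the token and channel dimensions. A NumPy sketch (shapes illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Logit-normal timestep sampling, as in the loop above.
bsz = 4
t = sigmoid(rng.standard_normal(bsz))  # each t_i strictly in (0, 1)

# Flow-matching interpolation between clean latents and noise.
pixel_latents = rng.standard_normal((bsz, 1024, 64))
noise = rng.standard_normal((bsz, 1024, 64))
t_b = t[:, None, None]  # broadcast t over sequence and channel dims
noisy_latents = (1.0 - t_b) * pixel_latents + t_b * noise
```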
Congratulations! You have completed all the code modifications required for DeepSpeed ZeRO-3.

### 4. Training with DeepSpeed ZeRO-3

Start!!!

```bash
export pretrained_model_name_or_path='flux-dev-model-path'
export MODEL_TYPE='train_model_type'
export TRAIN_JSON_FILE="your_json_file"
export CONTROL_TYPE='control_preprocessor_type'
export CAPTION_COLUMN='caption_column'

export CACHE_DIR="/data/train_csr/.cache/huggingface/"
export OUTPUT_DIR='/data/train_csr/FLUX/MODEL_OUT/'$MODEL_TYPE

# The first step is to use Python to precompute all caches. Replace the accelerate launch line below with this line. (I am not sure why using accelerate would cause problems.)
CUDA_VISIBLE_DEVICES=0 python3 train_controlnet_flux.py \

# The second step is to use the above accelerate config to train
accelerate launch --config_file "./accelerate_config_zero3.yaml" train_controlnet_flux.py \
    --pretrained_model_name_or_path=$pretrained_model_name_or_path \
    --jsonl_for_train=$TRAIN_JSON_FILE \
    --conditioning_image_column=$CONTROL_TYPE \
    --image_column=image \
    --caption_column=$CAPTION_COLUMN \
    --cache_dir=$CACHE_DIR \
    --tracker_project_name=$MODEL_TYPE \
    --output_dir=$OUTPUT_DIR \
    --max_train_steps=500000 \
    --mixed_precision bf16 \
    --checkpointing_steps=1000 \
    --gradient_accumulation_steps=8 \
    --resolution=512 \
    --train_batch_size=1 \
    --learning_rate=1e-5 \
    --num_double_layers=4 \
    --num_single_layers=0 \
    --gradient_checkpointing \
    --resume_from_checkpoint="latest" \
    # --use_adafactor \ # don't use
    # --validation_steps=3 \ # not supported
    # --validation_image $VALIDATION_IMAGE \ # not supported
    # --validation_prompt "xxx" \ # not supported
```
README_sd3.md ADDED
@@ -0,0 +1,175 @@
# ControlNet training example for Stable Diffusion 3/3.5 (SD3/3.5)

The `train_controlnet_sd3.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion 3](https://arxiv.org/abs/2403.03206) and [Stable Diffusion 3.5](https://stability.ai/news/introducing-stable-diffusion-3-5).

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the `examples/controlnet` folder and run
```bash
pip install -r requirements_sd3.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or for a default accelerate configuration without answering questions about your environment

```bash
accelerate config default
```

Or if your environment doesn't support an interactive shell (e.g., a notebook)

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

When running `accelerate config`, specifying torch compile mode as True can bring dramatic speedups.

## Circle filling dataset

The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
Please download the dataset and unzip it in the directory `fill50k` in the `examples/controlnet` folder.

## Training

First download the SD3 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) or the SD3.5 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). We will use it as a base model for the ControlNet training.

> [!NOTE]
> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) or [Stable Diffusion 3.5 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:

```bash
huggingface-cli login
```

This will also allow us to push the trained model parameters to the Hugging Face Hub platform.

Our training examples use two test conditioning images. They can be downloaded by running

```sh
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
```

Then run the following commands to train a ControlNet model.

```bash
export MODEL_DIR="stabilityai/stable-diffusion-3-medium-diffusers"
export OUTPUT_DIR="sd3-controlnet-out"

accelerate launch train_controlnet_sd3.py \
    --pretrained_model_name_or_path=$MODEL_DIR \
    --output_dir=$OUTPUT_DIR \
    --train_data_dir="fill50k" \
    --resolution=1024 \
    --learning_rate=1e-5 \
    --max_train_steps=15000 \
    --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
    --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
    --validation_steps=100 \
    --train_batch_size=1 \
    --gradient_accumulation_steps=4
```

To train a ControlNet model for Stable Diffusion 3.5, replace `MODEL_DIR` with `stabilityai/stable-diffusion-3.5-medium`.

To better track our training experiments, we're using the flags `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This lets us qualitatively check if the training is progressing as expected.

Our experiments were conducted on a single 40GB A100 GPU.

### Inference

Once training is done, we can perform inference like so:

```python
from diffusers import StableDiffusion3ControlNetPipeline, SD3ControlNetModel
from diffusers.utils import load_image
import torch

base_model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
controlnet_path = "DavyMorgan/sd3-controlnet-out"

controlnet = SD3ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    base_model_path, controlnet=controlnet
)
pipe.to("cuda", torch.float16)

control_image = load_image("./conditioning_image_1.png").resize((1024, 1024))
prompt = "pale golden rod circle with old lace background"

# generate image
generator = torch.manual_seed(0)
image = pipe(
    prompt, num_inference_steps=20, generator=generator, control_image=control_image
).images[0]
image.save("./output.png")
```

Similarly, for SD3.5, replace `base_model_path` with `stabilityai/stable-diffusion-3.5-medium` and `controlnet_path` with `DavyMorgan/sd35-controlnet-out`.

## Notes

### GPU usage

SD3 is a large model and requires a lot of GPU memory.
We recommend using one GPU with at least 80GB of memory.
Make sure to use the right GPU when configuring the [accelerator](https://huggingface.co/docs/transformers/en/accelerate).

## Example results

### SD3

#### After 500 steps with batch size 8

| conditioning image | pale golden rod circle with old lace background |
|:------------------:|:-----------------------------------------------:|
| ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-500.png) |

#### After 6500 steps with batch size 8

| conditioning image | pale golden rod circle with old lace background |
|:------------------:|:-----------------------------------------------:|
| ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-6500.png) |

### SD3.5

#### After 500 steps with batch size 8

| conditioning image | pale golden rod circle with old lace background |
|:------------------:|:-----------------------------------------------:|
| ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-500-3.5.png) |

#### After 3000 steps with batch size 8

| conditioning image | pale golden rod circle with old lace background |
|:------------------:|:-----------------------------------------------:|
| ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-3000-3.5.png) |
README_sdxl.md ADDED
@@ -0,0 +1,141 @@
+ # ControlNet training example for Stable Diffusion XL (SDXL)
+
+ The `train_controlnet_sdxl.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).
+
+ ## Running locally with PyTorch
+
+ ### Installing the dependencies
+
+ Before running the scripts, make sure to install the library's training dependencies:
+
+ **Important**
+
+ To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+ ```bash
+ git clone https://github.com/huggingface/diffusers
+ cd diffusers
+ pip install -e .
+ ```
+
+ Then cd into the `examples/controlnet` folder and run
+ ```bash
+ pip install -r requirements_sdxl.txt
+ ```
+
+ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+ ```bash
+ accelerate config
+ ```
+
+ Or, for a default accelerate configuration without answering questions about your environment:
+
+ ```bash
+ accelerate config default
+ ```
+
+ Or, if your environment doesn't support an interactive shell (e.g., a notebook):
+
+ ```python
+ from accelerate.utils import write_basic_config
+ write_basic_config()
+ ```
+
+ When running `accelerate config`, enabling torch compile mode can yield dramatic speedups.
+
+ ## Circle filling dataset
+
+ The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
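+ The fill50k examples are simple filled-circle images paired with outline conditioning images. A pair in the same style can be sketched with Pillow; the size, center, radius, and colors below are illustrative values, not taken from the dataset:
+
+ ```python
+ from PIL import Image, ImageDraw
+
+ def make_circle_pair(size=512, center=(256, 256), radius=120,
+                      circle_color="red", background="blue"):
+     """Build a (conditioning, target) pair in the fill50k style:
+     the conditioning image is a white circle outline on black,
+     the target is the filled circle on a colored background."""
+     box = [center[0] - radius, center[1] - radius,
+            center[0] + radius, center[1] + radius]
+     conditioning = Image.new("RGB", (size, size), "black")
+     ImageDraw.Draw(conditioning).ellipse(box, outline="white", width=3)
+     target = Image.new("RGB", (size, size), background)
+     ImageDraw.Draw(target).ellipse(box, fill=circle_color)
+     return conditioning, target
+
+ cond, tgt = make_circle_pair()
+ print(cond.size, tgt.getpixel((256, 256)))  # (512, 512) (255, 0, 0)
+ ```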
+
+ ## Training
+
+ Our training examples use two test conditioning images. They can be downloaded by running
+
+ ```sh
+ wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
+
+ wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
+ ```
+
+ Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to the Hugging Face Hub.
+
+ ```bash
+ export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
+ export OUTPUT_DIR="path to save model"
+
+ accelerate launch train_controlnet_sdxl.py \
+  --pretrained_model_name_or_path=$MODEL_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --dataset_name=fusing/fill50k \
+  --mixed_precision="fp16" \
+  --resolution=1024 \
+  --learning_rate=1e-5 \
+  --max_train_steps=15000 \
+  --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+  --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+  --validation_steps=100 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=4 \
+  --report_to="wandb" \
+  --seed=42 \
+  --push_to_hub
+ ```
+
+ To better track our training experiments, we're using the following flags in the command above:
+
+ * `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+ * `validation_image`, `validation_prompt`, and `validation_steps` allow the script to do a few validation inference runs, which lets us qualitatively check whether training is progressing as expected.
+
+ Our experiments were conducted on a single 40GB A100 GPU.
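+ With `--train_batch_size=1` and `--gradient_accumulation_steps=4`, each optimizer update effectively sees a batch of 4 images. The arithmetic behind the command above can be sketched as plain Python (this is a sanity-check sketch, not diffusers code; `max_train_steps` counts optimizer updates here):
+
+ ```python
+ # Effective batch size and total images seen for the SDXL command above.
+ train_batch_size = 1              # per-device batch size
+ gradient_accumulation_steps = 4   # micro-batches per optimizer update
+ num_devices = 1                   # single 40GB A100, as noted above
+ max_train_steps = 15000           # optimizer updates
+
+ effective_batch_size = train_batch_size * gradient_accumulation_steps * num_devices
+ images_seen = effective_batch_size * max_train_steps
+
+ print(effective_batch_size)  # 4
+ print(images_seen)           # 60000
+ ```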
+
+ ### Inference
+
+ Once training is done, we can perform inference like so:
+
+ ```python
+ from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
+ from diffusers.utils import load_image
+ import torch
+
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+ controlnet_path = "path to controlnet"
+
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+     base_model_path, controlnet=controlnet, torch_dtype=torch.float16
+ )
+
+ # speed up the diffusion process with a faster scheduler and memory optimizations
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ # remove the following line if xformers is not installed or when using Torch 2.0.
+ pipe.enable_xformers_memory_efficient_attention()
+ # memory optimization.
+ pipe.enable_model_cpu_offload()
+
+ control_image = load_image("./conditioning_image_1.png").resize((1024, 1024))
+ prompt = "pale golden rod circle with old lace background"
+
+ # generate image
+ generator = torch.manual_seed(0)
+ image = pipe(
+     prompt, num_inference_steps=20, generator=generator, image=control_image
+ ).images[0]
+ image.save("./output.png")
+ ```
+
+ ## Notes
+
+ ### Specifying a better VAE
+
+ SDXL's VAE is known to suffer from numerical instability issues. For this reason, we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of an alternative VAE (such as [`madebyollin/sdxl-vae-fp16-fix`](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
+
+ If you use this VAE during training, you need to ensure you use it during inference too. You do so by:
+
+ ```diff
+ + vae = AutoencoderKL.from_pretrained(vae_path_or_repo_id, torch_dtype=torch.float16)
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+     base_model_path, controlnet=controlnet, torch_dtype=torch.float16,
+ +    vae=vae,
+ )
+ ```
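+ The instability comes from the original VAE producing internal activations too large for float16's narrow range (the fp16-fix VAE was finetuned to keep activations in range). That range limit can be illustrated with the standard library alone, using Python's half-precision `struct` format; this is a generic fp16 demo, not diffusers code:
+
+ ```python
+ import struct
+
+ FP16_MAX = 65504.0  # largest finite float16 value
+
+ # Round-tripping a value inside the representable range works fine:
+ packed = struct.pack("<e", FP16_MAX)
+ assert struct.unpack("<e", packed)[0] == FP16_MAX
+
+ # A magnitude that is unremarkable in float32 simply cannot be
+ # represented in float16 and raises on conversion:
+ try:
+     struct.pack("<e", 1.0e5)
+ except OverflowError as err:
+     print("overflow:", err)
+ ```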
conditioning_image_1.png ADDED
conditioning_image_2.png ADDED
config.json ADDED
@@ -0,0 +1,52 @@
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.33.0.dev0",
4
+ "_name_or_path": "lllyasviel/sd-controlnet-canny",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "block_out_channels": [
11
+ 320,
12
+ 640,
13
+ 1280,
14
+ 1280
15
+ ],
16
+ "class_embed_type": null,
17
+ "conditioning_channels": 3,
18
+ "conditioning_embedding_out_channels": [
19
+ 16,
20
+ 32,
21
+ 96,
22
+ 256
23
+ ],
24
+ "controlnet_conditioning_channel_order": "rgb",
25
+ "cross_attention_dim": 768,
26
+ "down_block_types": [
27
+ "CrossAttnDownBlock2D",
28
+ "CrossAttnDownBlock2D",
29
+ "CrossAttnDownBlock2D",
30
+ "DownBlock2D"
31
+ ],
32
+ "downsample_padding": 1,
33
+ "encoder_hid_dim": null,
34
+ "encoder_hid_dim_type": null,
35
+ "flip_sin_to_cos": true,
36
+ "freq_shift": 0,
37
+ "global_pool_conditions": false,
38
+ "in_channels": 4,
39
+ "layers_per_block": 2,
40
+ "mid_block_scale_factor": 1,
41
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
42
+ "norm_eps": 1e-05,
43
+ "norm_num_groups": 32,
44
+ "num_attention_heads": null,
45
+ "num_class_embeds": null,
46
+ "only_cross_attention": false,
47
+ "projection_class_embeddings_input_dim": null,
48
+ "resnet_time_scale_shift": "default",
49
+ "transformer_layers_per_block": 1,
50
+ "upcast_attention": false,
51
+ "use_linear_projection": false
52
+ }
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b554b86ce5bfdda631f6aa02875fe5cfeabbb019b1de2eddaaf5f939cc9191b
3
+ size 1445157120
image_control.png ADDED
images_0.png ADDED

Git LFS Details

  • SHA256: ae023af508eb36f6ecc7197c4372b5a97f9c604fe648e08d02f49f3ba09d1706
  • Pointer size: 132 Bytes
  • Size of remote file: 1.78 MB
images_1.png ADDED

Git LFS Details

  • SHA256: ab0a8fe9e051ec9c79277627bd01530caad0bf78f465dc1fd1b3f4dcef1acf0b
  • Pointer size: 132 Bytes
  • Size of remote file: 2.39 MB
logs/controlnet-2210/1743377945.961673/events.out.tfevents.1743377945.79c7fdbf50b7.8420.1 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:389766e60dd13e2b0eed03964d3ab8b8bcc8cd678fa14134ce4467d466684b9c
3
+ size 2436
logs/controlnet-2210/1743377945.9631884/hparams.yml ADDED
@@ -0,0 +1,50 @@
1
+ adam_beta1: 0.9
2
+ adam_beta2: 0.999
3
+ adam_epsilon: 1.0e-08
4
+ adam_weight_decay: 0.01
5
+ allow_tf32: false
6
+ cache_dir: null
7
+ caption_column: text
8
+ checkpointing_steps: 500
9
+ checkpoints_total_limit: null
10
+ conditioning_image_column: conditioning_image
11
+ controlnet_model_name_or_path: lllyasviel/sd-controlnet-canny
12
+ dataloader_num_workers: 0
13
+ dataset_config_name: null
14
+ dataset_name: brity01/csc2210_1
15
+ enable_xformers_memory_efficient_attention: false
16
+ gradient_accumulation_steps: 4
17
+ gradient_checkpointing: false
18
+ hub_model_id: controlnet-2210
19
+ hub_token: null
20
+ image_column: image
21
+ learning_rate: 1.0e-05
22
+ logging_dir: logs
23
+ lr_num_cycles: 1
24
+ lr_power: 1.0
25
+ lr_scheduler: constant
26
+ lr_warmup_steps: 500
27
+ max_grad_norm: 1.0
28
+ max_train_samples: null
29
+ max_train_steps: 62
30
+ mixed_precision: null
31
+ num_train_epochs: 1
32
+ num_validation_images: 4
33
+ output_dir: ./
34
+ pretrained_model_name_or_path: stable-diffusion-v1-5/stable-diffusion-v1-5
35
+ proportion_empty_prompts: 0
36
+ push_to_hub: true
37
+ report_to: tensorboard
38
+ resolution: 512
39
+ resume_from_checkpoint: null
40
+ revision: null
41
+ scale_lr: false
42
+ seed: null
43
+ set_grads_to_none: false
44
+ tokenizer_name: null
45
+ tracker_project_name: controlnet-2210
46
+ train_batch_size: 1
47
+ train_data_dir: null
48
+ use_8bit_adam: false
49
+ validation_steps: 100
50
+ variant: null
logs/controlnet-2210/events.out.tfevents.1743377945.79c7fdbf50b7.8420.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:108cf9cd114aeee66f4524efbc9c1053764fcf1dbf2dd2c7d6ae5213888357e6
3
+ size 4185798
logs/train_controlnet/1743377496.261348/events.out.tfevents.1743377496.79c7fdbf50b7.6238.1 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:769b6306a2c6336554efc36cc87711fdc972836b04518f5a44ddcfa2926bd122
3
+ size 2384
logs/train_controlnet/1743377496.2628558/hparams.yml ADDED
@@ -0,0 +1,50 @@
1
+ adam_beta1: 0.9
2
+ adam_beta2: 0.999
3
+ adam_epsilon: 1.0e-08
4
+ adam_weight_decay: 0.01
5
+ allow_tf32: false
6
+ cache_dir: null
7
+ caption_column: text
8
+ checkpointing_steps: 500
9
+ checkpoints_total_limit: null
10
+ conditioning_image_column: conditioning_image
11
+ controlnet_model_name_or_path: lllyasviel/sd-controlnet-canny
12
+ dataloader_num_workers: 0
13
+ dataset_config_name: null
14
+ dataset_name: brity01/csc2210_1
15
+ enable_xformers_memory_efficient_attention: false
16
+ gradient_accumulation_steps: 4
17
+ gradient_checkpointing: false
18
+ hub_model_id: null
19
+ hub_token: null
20
+ image_column: image
21
+ learning_rate: 1.0e-05
22
+ logging_dir: logs
23
+ lr_num_cycles: 1
24
+ lr_power: 1.0
25
+ lr_scheduler: constant
26
+ lr_warmup_steps: 500
27
+ max_grad_norm: 1.0
28
+ max_train_samples: null
29
+ max_train_steps: 62
30
+ mixed_precision: null
31
+ num_train_epochs: 1
32
+ num_validation_images: 4
33
+ output_dir: ./
34
+ pretrained_model_name_or_path: stable-diffusion-v1-5/stable-diffusion-v1-5
35
+ proportion_empty_prompts: 0
36
+ push_to_hub: false
37
+ report_to: tensorboard
38
+ resolution: 512
39
+ resume_from_checkpoint: null
40
+ revision: null
41
+ scale_lr: false
42
+ seed: null
43
+ set_grads_to_none: false
44
+ tokenizer_name: null
45
+ tracker_project_name: train_controlnet
46
+ train_batch_size: 1
47
+ train_data_dir: null
48
+ use_8bit_adam: false
49
+ validation_steps: 100
50
+ variant: null
logs/train_controlnet/events.out.tfevents.1743377496.79c7fdbf50b7.6238.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f84680b134f5cea5fe92681258df13a8679fc4b10cabf0dc998ff68d47c0c89
3
+ size 4251047
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ accelerate>=0.16.0
2
+ torchvision
3
+ transformers>=4.25.1
4
+ ftfy
5
+ tensorboard
6
+ datasets
requirements_flax.txt ADDED
@@ -0,0 +1,9 @@
1
+ transformers>=4.25.1
2
+ datasets
3
+ flax
4
+ optax
5
+ torch
6
+ torchvision
7
+ ftfy
8
+ tensorboard
9
+ Jinja2
requirements_flux.txt ADDED
@@ -0,0 +1,9 @@
1
+ accelerate>=0.16.0
2
+ torchvision
3
+ transformers>=4.25.1
4
+ ftfy
5
+ tensorboard
6
+ Jinja2
7
+ datasets
8
+ wandb
9
+ SentencePiece
requirements_sd3.txt ADDED
@@ -0,0 +1,8 @@
1
+ accelerate>=0.16.0
2
+ torchvision
3
+ transformers>=4.25.1
4
+ ftfy
5
+ tensorboard
6
+ Jinja2
7
+ datasets
8
+ wandb
requirements_sdxl.txt ADDED
@@ -0,0 +1,8 @@
1
+ accelerate>=0.16.0
2
+ torchvision
3
+ transformers>=4.25.1
4
+ ftfy
5
+ tensorboard
6
+ Jinja2
7
+ datasets
8
+ wandb
test_controlnet.py ADDED
@@ -0,0 +1,184 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import os
18
+ import sys
19
+ import tempfile
20
+
21
+
22
+ sys.path.append("..")
23
+ from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
24
+
25
+
26
+ logging.basicConfig(level=logging.DEBUG)
27
+
28
+ logger = logging.getLogger()
29
+ stream_handler = logging.StreamHandler(sys.stdout)
30
+ logger.addHandler(stream_handler)
31
+
32
+
33
+ class ControlNet(ExamplesTestsAccelerate):
34
+ def test_controlnet_checkpointing_checkpoints_total_limit(self):
35
+ with tempfile.TemporaryDirectory() as tmpdir:
36
+ test_args = f"""
37
+ examples/controlnet/train_controlnet.py
38
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
39
+ --dataset_name=hf-internal-testing/fill10
40
+ --output_dir={tmpdir}
41
+ --resolution=64
42
+ --train_batch_size=1
43
+ --gradient_accumulation_steps=1
44
+ --max_train_steps=6
45
+ --checkpoints_total_limit=2
46
+ --checkpointing_steps=2
47
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
48
+ """.split()
49
+
50
+ run_command(self._launch_args + test_args)
51
+
52
+ self.assertEqual(
53
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
54
+ {"checkpoint-4", "checkpoint-6"},
55
+ )
56
+
57
+ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
58
+ with tempfile.TemporaryDirectory() as tmpdir:
59
+ test_args = f"""
60
+ examples/controlnet/train_controlnet.py
61
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
62
+ --dataset_name=hf-internal-testing/fill10
63
+ --output_dir={tmpdir}
64
+ --resolution=64
65
+ --train_batch_size=1
66
+ --gradient_accumulation_steps=1
67
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
68
+ --max_train_steps=6
69
+ --checkpointing_steps=2
70
+ """.split()
71
+
72
+ run_command(self._launch_args + test_args)
73
+
74
+ self.assertEqual(
75
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
76
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6"},
77
+ )
78
+
79
+ resume_run_args = f"""
80
+ examples/controlnet/train_controlnet.py
81
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
82
+ --dataset_name=hf-internal-testing/fill10
83
+ --output_dir={tmpdir}
84
+ --resolution=64
85
+ --train_batch_size=1
86
+ --gradient_accumulation_steps=1
87
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
88
+ --max_train_steps=8
89
+ --checkpointing_steps=2
90
+ --resume_from_checkpoint=checkpoint-6
91
+ --checkpoints_total_limit=2
92
+ """.split()
93
+
94
+ run_command(self._launch_args + resume_run_args)
95
+
96
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
97
+
98
+
99
+ class ControlNetSDXL(ExamplesTestsAccelerate):
100
+ def test_controlnet_sdxl(self):
101
+ with tempfile.TemporaryDirectory() as tmpdir:
102
+ test_args = f"""
103
+ examples/controlnet/train_controlnet_sdxl.py
104
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
105
+ --dataset_name=hf-internal-testing/fill10
106
+ --output_dir={tmpdir}
107
+ --resolution=64
108
+ --train_batch_size=1
109
+ --gradient_accumulation_steps=1
110
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
111
+ --max_train_steps=4
112
+ --checkpointing_steps=2
113
+ """.split()
114
+
115
+ run_command(self._launch_args + test_args)
116
+
117
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
118
+
119
+
120
+ class ControlNetSD3(ExamplesTestsAccelerate):
121
+ def test_controlnet_sd3(self):
122
+ with tempfile.TemporaryDirectory() as tmpdir:
123
+ test_args = f"""
124
+ examples/controlnet/train_controlnet_sd3.py
125
+ --pretrained_model_name_or_path=DavyMorgan/tiny-sd3-pipe
126
+ --dataset_name=hf-internal-testing/fill10
127
+ --output_dir={tmpdir}
128
+ --resolution=64
129
+ --train_batch_size=1
130
+ --gradient_accumulation_steps=1
131
+ --controlnet_model_name_or_path=DavyMorgan/tiny-controlnet-sd3
132
+ --max_train_steps=4
133
+ --checkpointing_steps=2
134
+ """.split()
135
+
136
+ run_command(self._launch_args + test_args)
137
+
138
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
139
+
140
+
141
+ class ControlNetSD35(ExamplesTestsAccelerate):
142
+ def test_controlnet_sd35(self):
143
+ with tempfile.TemporaryDirectory() as tmpdir:
144
+ test_args = f"""
145
+ examples/controlnet/train_controlnet_sd3.py
146
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-sd35-pipe
147
+ --dataset_name=hf-internal-testing/fill10
148
+ --output_dir={tmpdir}
149
+ --resolution=64
150
+ --train_batch_size=1
151
+ --gradient_accumulation_steps=1
152
+ --controlnet_model_name_or_path=DavyMorgan/tiny-controlnet-sd35
153
+ --max_train_steps=4
154
+ --checkpointing_steps=2
155
+ """.split()
156
+
157
+ run_command(self._launch_args + test_args)
158
+
159
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
160
+
161
+
162
+ class ControlNetflux(ExamplesTestsAccelerate):
163
+ def test_controlnet_flux(self):
164
+ with tempfile.TemporaryDirectory() as tmpdir:
165
+ test_args = f"""
166
+ examples/controlnet/train_controlnet_flux.py
167
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-flux-pipe
168
+ --output_dir={tmpdir}
169
+ --dataset_name=hf-internal-testing/fill10
170
+ --conditioning_image_column=conditioning_image
171
+ --image_column=image
172
+ --caption_column=text
173
+ --resolution=64
174
+ --train_batch_size=1
175
+ --gradient_accumulation_steps=1
176
+ --max_train_steps=4
177
+ --checkpointing_steps=2
178
+ --num_double_layers=1
179
+ --num_single_layers=1
180
+ """.split()
181
+
182
+ run_command(self._launch_args + test_args)
183
+
184
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
train_controlnet.py ADDED
@@ -0,0 +1,1185 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
+
16
+ import argparse
17
+ import contextlib
18
+ import gc
19
+ import logging
20
+ import math
21
+ import os
22
+ import random
23
+ import shutil
24
+ from pathlib import Path
25
+
26
+ import accelerate
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn.functional as F
30
+ import torch.utils.checkpoint
31
+ import transformers
32
+ from accelerate import Accelerator
33
+ from accelerate.logging import get_logger
34
+ from accelerate.utils import ProjectConfiguration, set_seed
35
+ from datasets import load_dataset
36
+ from huggingface_hub import create_repo, upload_folder
37
+ from packaging import version
38
+ from PIL import Image
39
+ from torchvision import transforms
40
+ from tqdm.auto import tqdm
41
+ from transformers import AutoTokenizer, PretrainedConfig
42
+
43
+ import diffusers
44
+ from diffusers import (
45
+ AutoencoderKL,
46
+ ControlNetModel,
47
+ DDPMScheduler,
48
+ StableDiffusionControlNetPipeline,
49
+ UNet2DConditionModel,
50
+ UniPCMultistepScheduler,
51
+ )
52
+ from diffusers.optimization import get_scheduler
53
+ from diffusers.utils import check_min_version, is_wandb_available
54
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
55
+ from diffusers.utils.import_utils import is_xformers_available
56
+ from diffusers.utils.torch_utils import is_compiled_module
57
+
58
+
59
+ if is_wandb_available():
60
+ import wandb
61
+
62
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
63
+ check_min_version("0.33.0.dev0")
64
+
65
+ logger = get_logger(__name__)
66
+
67
+
68
+ def image_grid(imgs, rows, cols):
69
+ assert len(imgs) == rows * cols
70
+
71
+ w, h = imgs[0].size
72
+ grid = Image.new("RGB", size=(cols * w, rows * h))
73
+
74
+ for i, img in enumerate(imgs):
75
+ grid.paste(img, box=(i % cols * w, i // cols * h))
76
+ return grid
77
+
78
+
79
+ def log_validation(
80
+ vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False
81
+ ):
82
+ logger.info("Running validation... ")
83
+
84
+ if not is_final_validation:
85
+ controlnet = accelerator.unwrap_model(controlnet)
86
+ else:
87
+ controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
88
+
89
+ pipeline = StableDiffusionControlNetPipeline.from_pretrained(
90
+ args.pretrained_model_name_or_path,
91
+ vae=vae,
92
+ text_encoder=text_encoder,
93
+ tokenizer=tokenizer,
94
+ unet=unet,
95
+ controlnet=controlnet,
96
+ safety_checker=None,
97
+ revision=args.revision,
98
+ variant=args.variant,
99
+ torch_dtype=weight_dtype,
100
+ )
101
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
102
+ pipeline = pipeline.to(accelerator.device)
103
+ pipeline.set_progress_bar_config(disable=True)
104
+
105
+ if args.enable_xformers_memory_efficient_attention:
106
+ pipeline.enable_xformers_memory_efficient_attention()
107
+
108
+ if args.seed is None:
109
+ generator = None
110
+ else:
111
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
112
+
113
+ if len(args.validation_image) == len(args.validation_prompt):
114
+ validation_images = args.validation_image
115
+ validation_prompts = args.validation_prompt
116
+ elif len(args.validation_image) == 1:
117
+ validation_images = args.validation_image * len(args.validation_prompt)
118
+ validation_prompts = args.validation_prompt
119
+ elif len(args.validation_prompt) == 1:
120
+ validation_images = args.validation_image
121
+ validation_prompts = args.validation_prompt * len(args.validation_image)
122
+ else:
123
+ raise ValueError(
124
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
125
+ )
126
+
127
+ image_logs = []
128
+ inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
129
+
130
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
131
+ validation_image = Image.open(validation_image).convert("RGB")
132
+
133
+ images = []
134
+
135
+ for _ in range(args.num_validation_images):
136
+ with inference_ctx:
137
+ image = pipeline(
138
+ validation_prompt, validation_image, num_inference_steps=20, generator=generator
139
+ ).images[0]
140
+
141
+ images.append(image)
142
+
143
+ image_logs.append(
144
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
145
+ )
146
+
147
+ tracker_key = "test" if is_final_validation else "validation"
148
+ for tracker in accelerator.trackers:
149
+ if tracker.name == "tensorboard":
150
+ for log in image_logs:
151
+ images = log["images"]
152
+ validation_prompt = log["validation_prompt"]
153
+ validation_image = log["validation_image"]
154
+
155
+ formatted_images = [np.asarray(validation_image)]
156
+
157
+ for image in images:
158
+ formatted_images.append(np.asarray(image))
159
+
160
+ formatted_images = np.stack(formatted_images)
161
+
162
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
163
+ elif tracker.name == "wandb":
164
+ formatted_images = []
165
+
166
+ for log in image_logs:
167
+ images = log["images"]
168
+ validation_prompt = log["validation_prompt"]
169
+ validation_image = log["validation_image"]
170
+
171
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
172
+
173
+ for image in images:
174
+ image = wandb.Image(image, caption=validation_prompt)
175
+ formatted_images.append(image)
176
+
177
+ tracker.log({tracker_key: formatted_images})
178
+ else:
179
+ logger.warning(f"image logging not implemented for {tracker.name}")
180
+
181
+ del pipeline
182
+ gc.collect()
183
+ torch.cuda.empty_cache()
184
+
185
+ return image_logs
186
+
187
+
188
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
189
+ text_encoder_config = PretrainedConfig.from_pretrained(
190
+ pretrained_model_name_or_path,
191
+ subfolder="text_encoder",
192
+ revision=revision,
193
+ )
194
+ model_class = text_encoder_config.architectures[0]
195
+
196
+ if model_class == "CLIPTextModel":
197
+ from transformers import CLIPTextModel
198
+
199
+ return CLIPTextModel
200
+ elif model_class == "RobertaSeriesModelWithTransformation":
201
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
202
+
203
+ return RobertaSeriesModelWithTransformation
204
+ else:
205
+ raise ValueError(f"{model_class} is not supported.")
206
+
207
+
208
+ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
209
+ img_str = ""
210
+ if image_logs is not None:
211
+ img_str = "You can find some example images below.\n\n"
212
+ for i, log in enumerate(image_logs):
213
+ images = log["images"]
214
+ validation_prompt = log["validation_prompt"]
215
+ validation_image = log["validation_image"]
216
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
217
+ img_str += f"prompt: {validation_prompt}\n"
218
+ images = [validation_image] + images
219
+ image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
220
+ img_str += f"![images_{i})](./images_{i}.png)\n"
221
+
222
+ model_description = f"""
223
+ # controlnet-{repo_id}
224
+
225
+ These are controlnet weights trained on {base_model} with new type of conditioning.
226
+ {img_str}
227
+ """
228
+ model_card = load_or_create_model_card(
229
+ repo_id_or_path=repo_id,
230
+ from_training=True,
231
+ license="creativeml-openrail-m",
232
+ base_model=base_model,
233
+ model_description=model_description,
234
+ inference=True,
235
+ )
236
+
237
+ tags = [
238
+ "stable-diffusion",
239
+ "stable-diffusion-diffusers",
240
+ "text-to-image",
241
+ "diffusers",
242
+ "controlnet",
243
+ "diffusers-training",
244
+ ]
245
+ model_card = populate_model_card(model_card, tags=tags)
246
+
247
+ model_card.save(os.path.join(repo_folder, "README.md"))
248
+
249
+
250
+ def parse_args(input_args=None):
+     parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
+     parser.add_argument(
+         "--pretrained_model_name_or_path",
+         type=str,
+         default=None,
+         required=True,
+         help="Path to pretrained model or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--controlnet_model_name_or_path",
+         type=str,
+         default=None,
+         help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
+         " If not specified, controlnet weights are initialized from the unet.",
+     )
+     parser.add_argument(
+         "--revision",
+         type=str,
+         default=None,
+         required=False,
+         help="Revision of pretrained model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--variant",
+         type=str,
+         default=None,
+         help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+     )
+     parser.add_argument(
+         "--tokenizer_name",
+         type=str,
+         default=None,
+         help="Pretrained tokenizer name or path if not the same as model_name",
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="controlnet-model",
+         help="The output directory where the model predictions and checkpoints will be written.",
+     )
+     parser.add_argument(
+         "--cache_dir",
+         type=str,
+         default=None,
+         help="The directory where the downloaded models and datasets will be stored.",
+     )
+     parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+     parser.add_argument(
+         "--resolution",
+         type=int,
+         default=512,
+         help=(
+             "The resolution for input images; all the images in the train/validation dataset will be resized to this"
+             " resolution"
+         ),
+     )
+     parser.add_argument(
+         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+     )
+     parser.add_argument("--num_train_epochs", type=int, default=1)
+     parser.add_argument(
+         "--max_train_steps",
+         type=int,
+         default=None,
+         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+     )
+     parser.add_argument(
+         "--checkpointing_steps",
+         type=int,
+         default=500,
+         help=(
+             "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+             "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
+             "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
+             "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step-by-step "
+             "instructions."
+         ),
+     )
+     parser.add_argument(
+         "--checkpoints_total_limit",
+         type=int,
+         default=None,
+         help="Max number of checkpoints to store.",
+     )
+     parser.add_argument(
+         "--resume_from_checkpoint",
+         type=str,
+         default=None,
+         help=(
+             "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+         ),
+     )
+     parser.add_argument(
+         "--gradient_accumulation_steps",
+         type=int,
+         default=1,
+         help="Number of update steps to accumulate before performing a backward/update pass.",
+     )
+     parser.add_argument(
+         "--gradient_checkpointing",
+         action="store_true",
+         help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
+     )
+     parser.add_argument(
+         "--learning_rate",
+         type=float,
+         default=5e-6,
+         help="Initial learning rate (after the potential warmup period) to use.",
+     )
+     parser.add_argument(
+         "--scale_lr",
+         action="store_true",
+         default=False,
+         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+     )
+     parser.add_argument(
+         "--lr_scheduler",
+         type=str,
+         default="constant",
+         help=(
+             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+             ' "constant", "constant_with_warmup"]'
+         ),
+     )
+     parser.add_argument(
+         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+     )
+     parser.add_argument(
+         "--lr_num_cycles",
+         type=int,
+         default=1,
+         help="Number of hard resets of the lr in the cosine_with_restarts scheduler.",
+     )
+     parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+     parser.add_argument(
+         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+     )
+     parser.add_argument(
+         "--dataloader_num_workers",
+         type=int,
+         default=0,
+         help=(
+             "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+         ),
+     )
+     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+     parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
+     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+     parser.add_argument(
+         "--hub_model_id",
+         type=str,
+         default=None,
+         help="The name of the repository to keep in sync with the local `output_dir`.",
+     )
+     parser.add_argument(
+         "--logging_dir",
+         type=str,
+         default="logs",
+         help=(
+             "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+             " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+         ),
+     )
+     parser.add_argument(
+         "--allow_tf32",
+         action="store_true",
+         help=(
+             "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+             " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+         ),
+     )
+     parser.add_argument(
+         "--report_to",
+         type=str,
+         default="tensorboard",
+         help=(
+             'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+             ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+         ),
+     )
+     parser.add_argument(
+         "--mixed_precision",
+         type=str,
+         default=None,
+         choices=["no", "fp16", "bf16"],
+         help=(
+             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+         ),
+     )
+     parser.add_argument(
+         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+     )
+     parser.add_argument(
+         "--set_grads_to_none",
+         action="store_true",
+         help=(
+             "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+             " behaviors, so disable this argument if it causes any problems. More info:"
+             " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+         ),
+     )
+     parser.add_argument(
+         "--dataset_name",
+         type=str,
+         default=None,
+         help=(
+             "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+             " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+             " or to a folder containing files that 🤗 Datasets can understand."
+         ),
+     )
+     parser.add_argument(
+         "--dataset_config_name",
+         type=str,
+         default=None,
+         help="The config of the Dataset; leave as None if there's only one config.",
+     )
+     parser.add_argument(
+         "--train_data_dir",
+         type=str,
+         default=None,
+         help=(
+             "A folder containing the training data. Folder contents must follow the structure described in"
+             " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+             " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+         ),
+     )
+     parser.add_argument(
+         "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+     )
+     parser.add_argument(
+         "--conditioning_image_column",
+         type=str,
+         default="conditioning_image",
+         help="The column of the dataset containing the controlnet conditioning image.",
+     )
+     parser.add_argument(
+         "--caption_column",
+         type=str,
+         default="text",
+         help="The column of the dataset containing a caption or a list of captions.",
+     )
+     parser.add_argument(
+         "--max_train_samples",
+         type=int,
+         default=None,
+         help=(
+             "For debugging purposes or quicker training, truncate the number of training examples to this "
+             "value if set."
+         ),
+     )
+     parser.add_argument(
+         "--proportion_empty_prompts",
+         type=float,
+         default=0,
+         help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+     )
+     parser.add_argument(
+         "--validation_prompt",
+         type=str,
+         default=None,
+         nargs="+",
+         help=(
+             "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+             " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+             " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+         ),
+     )
+     parser.add_argument(
+         "--validation_image",
+         type=str,
+         default=None,
+         nargs="+",
+         help=(
+             "A set of paths to the controlnet conditioning images to be evaluated every `--validation_steps`"
+             " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
+             " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+             " `--validation_image` that will be used with all `--validation_prompt`s."
+         ),
+     )
+     parser.add_argument(
+         "--num_validation_images",
+         type=int,
+         default=4,
+         help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+     )
+     parser.add_argument(
+         "--validation_steps",
+         type=int,
+         default=100,
+         help=(
+             "Run validation every X steps. Validation consists of running the prompt"
+             " `args.validation_prompt` multiple times: `args.num_validation_images`"
+             " and logging the images."
+         ),
+     )
+     parser.add_argument(
+         "--tracker_project_name",
+         type=str,
+         default="train_controlnet",
+         help=(
+             "The `project_name` argument passed to Accelerator.init_trackers. For"
+             " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+         ),
+     )
+
+     if input_args is not None:
+         args = parser.parse_args(input_args)
+     else:
+         args = parser.parse_args()
+
+     if args.dataset_name is None and args.train_data_dir is None:
+         raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
+
+     if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+         raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+     if args.validation_prompt is not None and args.validation_image is None:
+         raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+     if args.validation_prompt is None and args.validation_image is not None:
+         raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+     if (
+         args.validation_image is not None
+         and args.validation_prompt is not None
+         and len(args.validation_image) != 1
+         and len(args.validation_prompt) != 1
+         and len(args.validation_image) != len(args.validation_prompt)
+     ):
+         raise ValueError(
+             "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+             " or the same number of `--validation_prompt`s and `--validation_image`s"
+         )
+
+     if args.resolution % 8 != 0:
+         raise ValueError(
+             "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+         )
+
+     return args
+
+
+ def make_train_dataset(args, tokenizer, accelerator):
+     # Get the datasets: you can either provide your own training and evaluation files (see below)
+     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+     # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+     # download the dataset.
+     if args.dataset_name is not None:
+         # Downloading and loading a dataset from the hub.
+         dataset = load_dataset(
+             args.dataset_name,
+             args.dataset_config_name,
+             cache_dir=args.cache_dir,
+             data_dir=args.train_data_dir,
+         )
+     else:
+         if args.train_data_dir is not None:
+             dataset = load_dataset(
+                 args.train_data_dir,
+                 cache_dir=args.cache_dir,
+             )
+         # See more about loading custom images at
+         # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+
+     # Preprocessing the datasets.
+     # We need to tokenize inputs and targets.
+     column_names = dataset["train"].column_names
+
+     # 6. Get the column names for input/target.
+     if args.image_column is None:
+         image_column = column_names[0]
+         logger.info(f"image column defaulting to {image_column}")
+     else:
+         image_column = args.image_column
+         if image_column not in column_names:
+             raise ValueError(
+                 f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+             )
+
+     if args.caption_column is None:
+         caption_column = column_names[1]
+         logger.info(f"caption column defaulting to {caption_column}")
+     else:
+         caption_column = args.caption_column
+         if caption_column not in column_names:
+             raise ValueError(
+                 f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+             )
+
+     if args.conditioning_image_column is None:
+         conditioning_image_column = column_names[2]
+         logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+     else:
+         conditioning_image_column = args.conditioning_image_column
+         if conditioning_image_column not in column_names:
+             raise ValueError(
+                 f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+             )
+
+     def tokenize_captions(examples, is_train=True):
+         captions = []
+         for caption in examples[caption_column]:
+             if random.random() < args.proportion_empty_prompts:
+                 captions.append("")
+             elif isinstance(caption, str):
+                 captions.append(caption)
+             elif isinstance(caption, (list, np.ndarray)):
+                 # take a random caption if there are multiple
+                 captions.append(random.choice(caption) if is_train else caption[0])
+             else:
+                 raise ValueError(
+                     f"Caption column `{caption_column}` should contain either strings or lists of strings."
+                 )
+         inputs = tokenizer(
+             captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+         )
+         return inputs.input_ids
+
+     image_transforms = transforms.Compose(
+         [
+             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+             transforms.CenterCrop(args.resolution),
+             transforms.ToTensor(),
+             transforms.Normalize([0.5], [0.5]),
+         ]
+     )
+
+     conditioning_image_transforms = transforms.Compose(
+         [
+             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+             transforms.CenterCrop(args.resolution),
+             transforms.ToTensor(),
+         ]
+     )
+
+     def preprocess_train(examples):
+         images = [image.convert("RGB") for image in examples[image_column]]
+         images = [image_transforms(image) for image in images]
+
+         conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
+         conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
+
+         examples["pixel_values"] = images
+         examples["conditioning_pixel_values"] = conditioning_images
+         examples["input_ids"] = tokenize_captions(examples)
+
+         return examples
+
+     with accelerator.main_process_first():
+         if args.max_train_samples is not None:
+             dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+         # Set the training transforms
+         train_dataset = dataset["train"].with_transform(preprocess_train)
+
+     return train_dataset
+
+
+ def collate_fn(examples):
+     pixel_values = torch.stack([example["pixel_values"] for example in examples])
+     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+     conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+     conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+     input_ids = torch.stack([example["input_ids"] for example in examples])
+
+     return {
+         "pixel_values": pixel_values,
+         "conditioning_pixel_values": conditioning_pixel_values,
+         "input_ids": input_ids,
+     }
+
+
+ def main(args):
+     if args.report_to == "wandb" and args.hub_token is not None:
+         raise ValueError(
+             "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+             " Please use `huggingface-cli login` to authenticate with the Hub."
+         )
+
+     logging_dir = Path(args.output_dir, args.logging_dir)
+
+     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+     accelerator = Accelerator(
+         gradient_accumulation_steps=args.gradient_accumulation_steps,
+         mixed_precision=args.mixed_precision,
+         log_with=args.report_to,
+         project_config=accelerator_project_config,
+     )
+
+     # Disable AMP for MPS.
+     if torch.backends.mps.is_available():
+         accelerator.native_amp = False
+
+     # Make one log on every process with the configuration for debugging.
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+     logger.info(accelerator.state, main_process_only=False)
+     if accelerator.is_local_main_process:
+         transformers.utils.logging.set_verbosity_warning()
+         diffusers.utils.logging.set_verbosity_info()
+     else:
+         transformers.utils.logging.set_verbosity_error()
+         diffusers.utils.logging.set_verbosity_error()
+
+     # If passed along, set the training seed now.
+     if args.seed is not None:
+         set_seed(args.seed)
+
+     # Handle the repository creation
+     if accelerator.is_main_process:
+         if args.output_dir is not None:
+             os.makedirs(args.output_dir, exist_ok=True)
+
+         if args.push_to_hub:
+             repo_id = create_repo(
+                 repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+             ).repo_id
+
+     # Load the tokenizer
+     if args.tokenizer_name:
+         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+     elif args.pretrained_model_name_or_path:
+         tokenizer = AutoTokenizer.from_pretrained(
+             args.pretrained_model_name_or_path,
+             subfolder="tokenizer",
+             revision=args.revision,
+             use_fast=False,
+         )
+
+     # import correct text encoder class
+     text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+     # Load scheduler and models
+     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+     text_encoder = text_encoder_cls.from_pretrained(
+         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+     )
+     vae = AutoencoderKL.from_pretrained(
+         args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+     )
+     unet = UNet2DConditionModel.from_pretrained(
+         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+     )
+
+     if args.controlnet_model_name_or_path:
+         logger.info("Loading existing controlnet weights")
+         controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
+     else:
+         logger.info("Initializing controlnet weights from unet")
+         controlnet = ControlNetModel.from_unet(unet)
+
+     # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
+     def unwrap_model(model):
+         model = accelerator.unwrap_model(model)
+         model = model._orig_mod if is_compiled_module(model) else model
+         return model
+
+     # `accelerate` 0.16.0 will have better support for customized saving
+     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+         def save_model_hook(models, weights, output_dir):
+             if accelerator.is_main_process:
+                 i = len(weights) - 1
+
+                 while len(weights) > 0:
+                     weights.pop()
+                     model = models[i]
+
+                     sub_dir = "controlnet"
+                     model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+                     i -= 1
+
+         def load_model_hook(models, input_dir):
+             while len(models) > 0:
+                 # pop models so that they are not loaded again
+                 model = models.pop()
+
+                 # load diffusers style into model
+                 load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
+                 model.register_to_config(**load_model.config)
+
+                 model.load_state_dict(load_model.state_dict())
+                 del load_model
+
+         accelerator.register_save_state_pre_hook(save_model_hook)
+         accelerator.register_load_state_pre_hook(load_model_hook)
+
+     vae.requires_grad_(False)
+     unet.requires_grad_(False)
+     text_encoder.requires_grad_(False)
+     controlnet.train()
+
+     if args.enable_xformers_memory_efficient_attention:
+         if is_xformers_available():
+             import xformers
+
+             xformers_version = version.parse(xformers.__version__)
+             if xformers_version == version.parse("0.0.16"):
+                 logger.warning(
+                     "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                 )
+             unet.enable_xformers_memory_efficient_attention()
+             controlnet.enable_xformers_memory_efficient_attention()
+         else:
+             raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+     if args.gradient_checkpointing:
+         controlnet.enable_gradient_checkpointing()
+
+     # Check that all trainable models are in full precision
+     low_precision_error_string = (
+         " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+         " doing mixed precision training, copy of the weights should still be float32."
+     )
+
+     if unwrap_model(controlnet).dtype != torch.float32:
+         raise ValueError(
+             f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+         )
+
+     # Enable TF32 for faster training on Ampere GPUs,
+     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+     if args.allow_tf32:
+         torch.backends.cuda.matmul.allow_tf32 = True
+
+     if args.scale_lr:
+         args.learning_rate = (
+             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+         )
+
+     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+     if args.use_8bit_adam:
+         try:
+             import bitsandbytes as bnb
+         except ImportError:
+             raise ImportError(
+                 "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+             )
+
+         optimizer_class = bnb.optim.AdamW8bit
+     else:
+         optimizer_class = torch.optim.AdamW
+
+     # Optimizer creation
+     params_to_optimize = controlnet.parameters()
+     optimizer = optimizer_class(
+         params_to_optimize,
+         lr=args.learning_rate,
+         betas=(args.adam_beta1, args.adam_beta2),
+         weight_decay=args.adam_weight_decay,
+         eps=args.adam_epsilon,
+     )
+
+     train_dataset = make_train_dataset(args, tokenizer, accelerator)
+
+     train_dataloader = torch.utils.data.DataLoader(
+         train_dataset,
+         shuffle=True,
+         collate_fn=collate_fn,
+         batch_size=args.train_batch_size,
+         num_workers=args.dataloader_num_workers,
+     )
+
+     # Scheduler and math around the number of training steps.
+     overrode_max_train_steps = False
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+     if args.max_train_steps is None:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+         overrode_max_train_steps = True
+
+     lr_scheduler = get_scheduler(
+         args.lr_scheduler,
+         optimizer=optimizer,
+         num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+         num_training_steps=args.max_train_steps * accelerator.num_processes,
+         num_cycles=args.lr_num_cycles,
+         power=args.lr_power,
+     )
+
+     # Prepare everything with our `accelerator`.
+     controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+         controlnet, optimizer, train_dataloader, lr_scheduler
+     )
+
+     # For mixed precision training we cast the text_encoder and vae weights to half-precision
+     # as these models are only used for inference; keeping weights in full precision is not required.
+     weight_dtype = torch.float32
+     if accelerator.mixed_precision == "fp16":
+         weight_dtype = torch.float16
+     elif accelerator.mixed_precision == "bf16":
+         weight_dtype = torch.bfloat16
+
+     # Move vae, unet and text_encoder to device and cast to weight_dtype
+     vae.to(accelerator.device, dtype=weight_dtype)
+     unet.to(accelerator.device, dtype=weight_dtype)
+     text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+     if overrode_max_train_steps:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+     # Afterwards we recalculate our number of training epochs
+     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+     # We need to initialize the trackers we use, and also store our configuration.
+     # The trackers initialize automatically on the main process.
+     if accelerator.is_main_process:
+         tracker_config = dict(vars(args))
+
+         # tensorboard cannot handle list types for config
+         tracker_config.pop("validation_prompt")
+         tracker_config.pop("validation_image")
+
+         accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+     # Train!
+     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+     logger.info("***** Running training *****")
+     logger.info(f"  Num examples = {len(train_dataset)}")
+     logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
+     logger.info(f"  Num Epochs = {args.num_train_epochs}")
+     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+     logger.info(f"  Total optimization steps = {args.max_train_steps}")
+     global_step = 0
+     first_epoch = 0
+
+     # Potentially load in the weights and states from a previous save
+     if args.resume_from_checkpoint:
+         if args.resume_from_checkpoint != "latest":
+             path = os.path.basename(args.resume_from_checkpoint)
+         else:
+             # Get the most recent checkpoint
+             dirs = os.listdir(args.output_dir)
+             dirs = [d for d in dirs if d.startswith("checkpoint")]
+             dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+             path = dirs[-1] if len(dirs) > 0 else None
+
+         if path is None:
+             accelerator.print(
+                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+             )
+             args.resume_from_checkpoint = None
+             initial_global_step = 0
+         else:
+             accelerator.print(f"Resuming from checkpoint {path}")
+             accelerator.load_state(os.path.join(args.output_dir, path))
+             global_step = int(path.split("-")[1])
+
+             initial_global_step = global_step
+             first_epoch = global_step // num_update_steps_per_epoch
+     else:
+         initial_global_step = 0
+
+     progress_bar = tqdm(
+         range(0, args.max_train_steps),
+         initial=initial_global_step,
+         desc="Steps",
+         # Only show the progress bar once on each machine.
+         disable=not accelerator.is_local_main_process,
+     )
+
+     image_logs = None
+     for epoch in range(first_epoch, args.num_train_epochs):
+         for step, batch in enumerate(train_dataloader):
+             with accelerator.accumulate(controlnet):
+                 # Convert images to latent space
+                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+                 latents = latents * vae.config.scaling_factor
+
+                 # Sample noise that we'll add to the latents
+                 noise = torch.randn_like(latents)
+                 bsz = latents.shape[0]
+                 # Sample a random timestep for each image
+                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                 timesteps = timesteps.long()
+
+                 # Add noise to the latents according to the noise magnitude at each timestep
+                 # (this is the forward diffusion process)
+                 noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to(
+                     dtype=weight_dtype
+                 )
+
+                 # Get the text embedding for conditioning
+                 encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
+
+                 controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+
+                 down_block_res_samples, mid_block_res_sample = controlnet(
+                     noisy_latents,
+                     timesteps,
+                     encoder_hidden_states=encoder_hidden_states,
+                     controlnet_cond=controlnet_image,
+                     return_dict=False,
+                 )
+
+                 # Predict the noise residual
+                 model_pred = unet(
+                     noisy_latents,
+                     timesteps,
+                     encoder_hidden_states=encoder_hidden_states,
+                     down_block_additional_residuals=[
+                         sample.to(dtype=weight_dtype) for sample in down_block_res_samples
+                     ],
+                     mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
+                     return_dict=False,
+                 )[0]
+
+                 # Get the target for loss depending on the prediction type
+                 if noise_scheduler.config.prediction_type == "epsilon":
+                     target = noise
+                 elif noise_scheduler.config.prediction_type == "v_prediction":
+                     target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                 else:
+                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                 accelerator.backward(loss)
+                 if accelerator.sync_gradients:
+                     params_to_clip = controlnet.parameters()
+                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                 optimizer.step()
+                 lr_scheduler.step()
+                 optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+             # Checks if the accelerator has performed an optimization step behind the scenes
1094
+ if accelerator.sync_gradients:
1095
+ progress_bar.update(1)
1096
+ global_step += 1
1097
+
1098
+ if accelerator.is_main_process:
1099
+ if global_step % args.checkpointing_steps == 0:
1100
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1101
+ if args.checkpoints_total_limit is not None:
1102
+ checkpoints = os.listdir(args.output_dir)
1103
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1104
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1105
+
1106
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1107
+ if len(checkpoints) >= args.checkpoints_total_limit:
1108
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1109
+ removing_checkpoints = checkpoints[0:num_to_remove]
1110
+
1111
+ logger.info(
1112
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1113
+ )
1114
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1115
+
1116
+ for removing_checkpoint in removing_checkpoints:
1117
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1118
+ shutil.rmtree(removing_checkpoint)
1119
+
1120
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1121
+ accelerator.save_state(save_path)
1122
+ logger.info(f"Saved state to {save_path}")
1123
+
1124
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1125
+ image_logs = log_validation(
1126
+ vae,
1127
+ text_encoder,
1128
+ tokenizer,
1129
+ unet,
1130
+ controlnet,
1131
+ args,
1132
+ accelerator,
1133
+ weight_dtype,
1134
+ global_step,
1135
+ )
1136
+
1137
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1138
+ progress_bar.set_postfix(**logs)
1139
+ accelerator.log(logs, step=global_step)
1140
+
1141
+ if global_step >= args.max_train_steps:
1142
+ break
1143
+
1144
+ # Create the pipeline using the trained modules and save it.
1145
+ accelerator.wait_for_everyone()
1146
+ if accelerator.is_main_process:
1147
+ controlnet = unwrap_model(controlnet)
1148
+ controlnet.save_pretrained(args.output_dir)
1149
+
1150
+ # Run a final round of validation.
1151
+ image_logs = None
1152
+ if args.validation_prompt is not None:
1153
+ image_logs = log_validation(
1154
+ vae=vae,
1155
+ text_encoder=text_encoder,
1156
+ tokenizer=tokenizer,
1157
+ unet=unet,
1158
+ controlnet=None,
1159
+ args=args,
1160
+ accelerator=accelerator,
1161
+ weight_dtype=weight_dtype,
1162
+ step=global_step,
1163
+ is_final_validation=True,
1164
+ )
1165
+
1166
+ if args.push_to_hub:
1167
+ save_model_card(
1168
+ repo_id,
1169
+ image_logs=image_logs,
1170
+ base_model=args.pretrained_model_name_or_path,
1171
+ repo_folder=args.output_dir,
1172
+ )
1173
+ upload_folder(
1174
+ repo_id=repo_id,
1175
+ folder_path=args.output_dir,
1176
+ commit_message="End of training",
1177
+ ignore_patterns=["step_*", "epoch_*"],
1178
+ )
1179
+
1180
+ accelerator.end_training()
1181
+
1182
+
1183
+ if __name__ == "__main__":
1184
+ args = parse_args()
1185
+ main(args)
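The `--resume_from_checkpoint="latest"` branch above picks the newest checkpoint by parsing the step number out of each `checkpoint-<step>` directory name. A minimal standalone sketch of that selection logic (the directory names below are hypothetical, and the helper name is mine, not the script's):

```python
def latest_checkpoint(dir_names):
    # Keep only "checkpoint-<step>" entries and sort numerically by step,
    # so "checkpoint-10000" sorts after "checkpoint-9000" (a plain string
    # sort would order them the other way).
    dirs = [d for d in dir_names if d.startswith("checkpoint")]
    dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
    return dirs[-1] if len(dirs) > 0 else None

print(latest_checkpoint(["logs", "checkpoint-9000", "checkpoint-500", "checkpoint-10000"]))
# → checkpoint-10000
```

The same numeric sort is reused for `--checkpoints_total_limit` pruning, which removes the oldest entries first.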
train_controlnet_flax.py ADDED
@@ -0,0 +1,1152 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import argparse
+ import logging
+ import math
+ import os
+ import random
+ import time
+ from pathlib import Path
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import optax
+ import torch
+ import torch.utils.checkpoint
+ import transformers
+ from datasets import load_dataset, load_from_disk
+ from flax import jax_utils
+ from flax.core.frozen_dict import unfreeze
+ from flax.training import train_state
+ from flax.training.common_utils import shard
+ from huggingface_hub import create_repo, upload_folder
+ from PIL import Image, PngImagePlugin
+ from torch.utils.data import IterableDataset
+ from torchvision import transforms
+ from tqdm.auto import tqdm
+ from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
+
+ from diffusers import (
+     FlaxAutoencoderKL,
+     FlaxControlNetModel,
+     FlaxDDPMScheduler,
+     FlaxStableDiffusionControlNetPipeline,
+     FlaxUNet2DConditionModel,
+ )
+ from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+
+
+ # To prevent an error that occurs when there are abnormally large compressed data chunks in the png image
+ # see more https://github.com/python-pillow/Pillow/issues/5610
+ LARGE_ENOUGH_NUMBER = 100
+ PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
+
+ if is_wandb_available():
+     import wandb
+
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+ check_min_version("0.33.0.dev0")
+
+ logger = logging.getLogger(__name__)
+
+
+ def log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args, rng, weight_dtype):
+     logger.info("Running validation...")
+
+     pipeline_params = pipeline_params.copy()
+     pipeline_params["controlnet"] = controlnet_params
+
+     num_samples = jax.device_count()
+     prng_seed = jax.random.split(rng, jax.device_count())
+
+     if len(args.validation_image) == len(args.validation_prompt):
+         validation_images = args.validation_image
+         validation_prompts = args.validation_prompt
+     elif len(args.validation_image) == 1:
+         validation_images = args.validation_image * len(args.validation_prompt)
+         validation_prompts = args.validation_prompt
+     elif len(args.validation_prompt) == 1:
+         validation_images = args.validation_image
+         validation_prompts = args.validation_prompt * len(args.validation_image)
+     else:
+         raise ValueError(
+             "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+         )
+
+     image_logs = []
+
+     for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+         prompts = num_samples * [validation_prompt]
+         prompt_ids = pipeline.prepare_text_inputs(prompts)
+         prompt_ids = shard(prompt_ids)
+
+         validation_image = Image.open(validation_image).convert("RGB")
+         processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
+         processed_image = shard(processed_image)
+         images = pipeline(
+             prompt_ids=prompt_ids,
+             image=processed_image,
+             params=pipeline_params,
+             prng_seed=prng_seed,
+             num_inference_steps=50,
+             jit=True,
+         ).images
+
+         images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+         images = pipeline.numpy_to_pil(images)
+
+         image_logs.append(
+             {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+         )
+
+     if args.report_to == "wandb":
+         formatted_images = []
+         for log in image_logs:
+             images = log["images"]
+             validation_prompt = log["validation_prompt"]
+             validation_image = log["validation_image"]
+
+             formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+             for image in images:
+                 image = wandb.Image(image, caption=validation_prompt)
+                 formatted_images.append(image)
+
+         wandb.log({"validation": formatted_images})
+     else:
+         logger.warning(f"image logging not implemented for {args.report_to}")
+
+     return image_logs
+
+
+ def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
+     img_str = ""
+     if image_logs is not None:
+         for i, log in enumerate(image_logs):
+             images = log["images"]
+             validation_prompt = log["validation_prompt"]
+             validation_image = log["validation_image"]
+             validation_image.save(os.path.join(repo_folder, "image_control.png"))
+             img_str += f"prompt: {validation_prompt}\n"
+             images = [validation_image] + images
+             make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+             img_str += f"![images_{i}](./images_{i}.png)\n"
+
+     model_description = f"""
+ # controlnet-{repo_id}
+
+ These are controlnet weights trained on {base_model} with a new type of conditioning. You can find some example images below.\n
+ {img_str}
+ """
+
+     model_card = load_or_create_model_card(
+         repo_id_or_path=repo_id,
+         from_training=True,
+         license="creativeml-openrail-m",
+         base_model=base_model,
+         model_description=model_description,
+         inference=True,
+     )
+
+     tags = [
+         "stable-diffusion",
+         "stable-diffusion-diffusers",
+         "text-to-image",
+         "diffusers",
+         "controlnet",
+         "jax-diffusers-event",
+         "diffusers-training",
+     ]
+     model_card = populate_model_card(model_card, tags=tags)
+
+     model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Simple example of a training script.")
+     parser.add_argument(
+         "--pretrained_model_name_or_path",
+         type=str,
+         required=True,
+         help="Path to pretrained model or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--controlnet_model_name_or_path",
+         type=str,
+         default=None,
+         help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
+         " If not specified controlnet weights are initialized from unet.",
+     )
+     parser.add_argument(
+         "--revision",
+         type=str,
+         default=None,
+         help="Revision of pretrained model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--from_pt",
+         action="store_true",
+         help="Load the pretrained model from a PyTorch checkpoint.",
+     )
+     parser.add_argument(
+         "--controlnet_revision",
+         type=str,
+         default=None,
+         help="Revision of controlnet model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--profile_steps",
+         type=int,
+         default=0,
+         help="How many training steps to profile in the beginning.",
+     )
+     parser.add_argument(
+         "--profile_validation",
+         action="store_true",
+         help="Whether to profile the (last) validation.",
+     )
+     parser.add_argument(
+         "--profile_memory",
+         action="store_true",
+         help="Whether to dump an initial (before training loop) and a final (at program end) memory profile.",
+     )
+     parser.add_argument(
+         "--ccache",
+         type=str,
+         default=None,
+         help="Enables compilation cache.",
+     )
+     parser.add_argument(
+         "--controlnet_from_pt",
+         action="store_true",
+         help="Load the controlnet model from a PyTorch checkpoint.",
+     )
+     parser.add_argument(
+         "--tokenizer_name",
+         type=str,
+         default=None,
+         help="Pretrained tokenizer name or path if not the same as model_name",
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="runs/{timestamp}",
+         help="The output directory where the model predictions and checkpoints will be written. "
+         "Can contain placeholders: {timestamp}.",
+     )
+     parser.add_argument(
+         "--cache_dir",
+         type=str,
+         default=None,
+         help="The directory where the downloaded models and datasets will be stored.",
+     )
+     parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
+     parser.add_argument(
+         "--resolution",
+         type=int,
+         default=512,
+         help=(
+             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+             " resolution"
+         ),
+     )
+     parser.add_argument(
+         "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+     )
+     parser.add_argument("--num_train_epochs", type=int, default=100)
+     parser.add_argument(
+         "--max_train_steps",
+         type=int,
+         default=None,
+         help="Total number of training steps to perform.",
+     )
+     parser.add_argument(
+         "--checkpointing_steps",
+         type=int,
+         default=5000,
+         help=("Save a checkpoint of the training state every X updates."),
+     )
+     parser.add_argument(
+         "--learning_rate",
+         type=float,
+         default=1e-4,
+         help="Initial learning rate (after the potential warmup period) to use.",
+     )
+     parser.add_argument(
+         "--scale_lr",
+         action="store_true",
+         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+     )
+     parser.add_argument(
+         "--lr_scheduler",
+         type=str,
+         default="constant",
+         help=(
+             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+             ' "constant", "constant_with_warmup"]'
+         ),
+     )
+     parser.add_argument(
+         "--snr_gamma",
+         type=float,
+         default=None,
+         help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+         "More details here: https://arxiv.org/abs/2303.09556.",
+     )
+     parser.add_argument(
+         "--dataloader_num_workers",
+         type=int,
+         default=0,
+         help=(
+             "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+         ),
+     )
+     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+     parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+     parser.add_argument(
+         "--hub_model_id",
+         type=str,
+         default=None,
+         help="The name of the repository to keep in sync with the local `output_dir`.",
+     )
+     parser.add_argument(
+         "--logging_steps",
+         type=int,
+         default=100,
+         help=("Log training metrics every X steps to `--report_to`."),
+     )
+     parser.add_argument(
+         "--report_to",
+         type=str,
+         default="wandb",
+         help=('The integration to report the results and logs to. Currently the only supported platform is `"wandb"`.'),
+     )
+     parser.add_argument(
+         "--mixed_precision",
+         type=str,
+         default="no",
+         choices=["no", "fp16", "bf16"],
+         help=(
+             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). "
+             "Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
+         ),
+     )
+     parser.add_argument(
+         "--dataset_name",
+         type=str,
+         default=None,
+         help=(
+             "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+             " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+             " or to a folder containing files that 🤗 Datasets can understand."
+         ),
+     )
+     parser.add_argument("--streaming", action="store_true", help="To stream a large dataset from Hub.")
+     parser.add_argument(
+         "--dataset_config_name",
+         type=str,
+         default=None,
+         help="The config of the Dataset, leave as None if there's only one config.",
+     )
+     parser.add_argument(
+         "--train_data_dir",
+         type=str,
+         default=None,
+         help=(
+             "A folder containing the training dataset. By default it will use the `load_dataset` method to load a custom dataset from the folder."
+             " The folder must contain a dataset script as described here: https://huggingface.co/docs/datasets/dataset_script ."
+             " If the `--load_from_disk` flag is passed, it will use the `load_from_disk` method instead. Ignored if `dataset_name` is specified."
+         ),
+     )
+     parser.add_argument(
+         "--load_from_disk",
+         action="store_true",
+         help=(
+             "If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`."
+             " See more: https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk"
+         ),
+     )
+     parser.add_argument(
+         "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+     )
+     parser.add_argument(
+         "--conditioning_image_column",
+         type=str,
+         default="conditioning_image",
+         help="The column of the dataset containing the controlnet conditioning image.",
+     )
+     parser.add_argument(
+         "--caption_column",
+         type=str,
+         default="text",
+         help="The column of the dataset containing a caption or a list of captions.",
+     )
+     parser.add_argument(
+         "--max_train_samples",
+         type=int,
+         default=None,
+         help=(
+             "For debugging purposes or quicker training, truncate the number of training examples to this "
+             "value if set. Needed if `streaming` is set to True."
+         ),
+     )
+     parser.add_argument(
+         "--proportion_empty_prompts",
+         type=float,
+         default=0,
+         help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+     )
+     parser.add_argument(
+         "--validation_prompt",
+         type=str,
+         default=None,
+         nargs="+",
+         help=(
+             "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+             " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+             " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+         ),
+     )
+     parser.add_argument(
+         "--validation_image",
+         type=str,
+         default=None,
+         nargs="+",
+         help=(
+             "A set of paths to the controlnet conditioning images to be evaluated every `--validation_steps`"
+             " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+             " single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+             " `--validation_image` that will be used with all `--validation_prompt`s."
+         ),
+     )
+     parser.add_argument(
+         "--validation_steps",
+         type=int,
+         default=100,
+         help=(
+             "Run validation every X steps. Validation consists of running the prompt"
+             " `args.validation_prompt` and logging the images."
+         ),
+     )
+     parser.add_argument("--wandb_entity", type=str, default=None, help=("The wandb entity to use (for teams)."))
+     parser.add_argument(
+         "--tracker_project_name",
+         type=str,
+         default="train_controlnet_flax",
+         help=("The `project` argument passed to wandb"),
+     )
+     parser.add_argument(
+         "--gradient_accumulation_steps", type=int, default=1, help="Number of steps to accumulate gradients over"
+     )
+     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+     args = parser.parse_args()
+     args.output_dir = args.output_dir.replace("{timestamp}", time.strftime("%Y%m%d_%H%M%S"))
+
+     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+     if env_local_rank != -1 and env_local_rank != args.local_rank:
+         args.local_rank = env_local_rank
+
+     # Sanity checks
+     if args.dataset_name is None and args.train_data_dir is None:
+         raise ValueError("Need either a dataset name or a training folder.")
+     if args.dataset_name is not None and args.train_data_dir is not None:
+         raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
+
+     if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+         raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+     if args.validation_prompt is not None and args.validation_image is None:
+         raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+     if args.validation_prompt is None and args.validation_image is not None:
+         raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+     if (
+         args.validation_image is not None
+         and args.validation_prompt is not None
+         and len(args.validation_image) != 1
+         and len(args.validation_prompt) != 1
+         and len(args.validation_image) != len(args.validation_prompt)
+     ):
+         raise ValueError(
+             "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+             " or the same number of `--validation_prompt`s and `--validation_image`s"
+         )
+
+     # This idea comes from
+     # https://github.com/borisdayma/dalle-mini/blob/d2be512d4a6a9cda2d63ba04afc33038f98f705f/src/dalle_mini/data.py#L370
+     if args.streaming and args.max_train_samples is None:
+         raise ValueError("You must specify `max_train_samples` when using dataset streaming.")
+
+     return args
+
+
+ def make_train_dataset(args, tokenizer, batch_size=None):
+     # Get the datasets: you can either provide your own training and evaluation files (see below)
+     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+     # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+     # download the dataset.
+     if args.dataset_name is not None:
+         # Downloading and loading a dataset from the hub.
+         dataset = load_dataset(
+             args.dataset_name,
+             args.dataset_config_name,
+             cache_dir=args.cache_dir,
+             streaming=args.streaming,
+         )
+     else:
+         if args.train_data_dir is not None:
+             if args.load_from_disk:
+                 dataset = load_from_disk(
+                     args.train_data_dir,
+                 )
+             else:
+                 dataset = load_dataset(
+                     args.train_data_dir,
+                     cache_dir=args.cache_dir,
+                 )
+         # See more about loading custom images at
+         # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+
+     # Preprocessing the datasets.
+     # We need to tokenize inputs and targets.
+     if isinstance(dataset["train"], IterableDataset):
+         column_names = next(iter(dataset["train"])).keys()
+     else:
+         column_names = dataset["train"].column_names
+
+     # 6. Get the column names for input/target.
+     if args.image_column is None:
+         image_column = column_names[0]
+         logger.info(f"image column defaulting to {image_column}")
+     else:
+         image_column = args.image_column
+         if image_column not in column_names:
+             raise ValueError(
+                 f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+             )
+
+     if args.caption_column is None:
+         caption_column = column_names[1]
+         logger.info(f"caption column defaulting to {caption_column}")
+     else:
+         caption_column = args.caption_column
+         if caption_column not in column_names:
+             raise ValueError(
+                 f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+             )
+
+     if args.conditioning_image_column is None:
+         conditioning_image_column = column_names[2]
+         logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+     else:
+         conditioning_image_column = args.conditioning_image_column
+         if conditioning_image_column not in column_names:
+             raise ValueError(
+                 f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+             )
+
+     def tokenize_captions(examples, is_train=True):
+         captions = []
+         for caption in examples[caption_column]:
+             if random.random() < args.proportion_empty_prompts:
+                 captions.append("")
+             elif isinstance(caption, str):
+                 captions.append(caption)
+             elif isinstance(caption, (list, np.ndarray)):
+                 # take a random caption if there are multiple
+                 captions.append(random.choice(caption) if is_train else caption[0])
+             else:
+                 raise ValueError(
+                     f"Caption column `{caption_column}` should contain either strings or lists of strings."
+                 )
+         inputs = tokenizer(
+             captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+         )
+         return inputs.input_ids
+
+     image_transforms = transforms.Compose(
+         [
+             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+             transforms.CenterCrop(args.resolution),
+             transforms.ToTensor(),
+             transforms.Normalize([0.5], [0.5]),
+         ]
+     )
+
+     conditioning_image_transforms = transforms.Compose(
+         [
+             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+             transforms.CenterCrop(args.resolution),
+             transforms.ToTensor(),
+         ]
+     )
+
+     def preprocess_train(examples):
+         images = [image.convert("RGB") for image in examples[image_column]]
+         images = [image_transforms(image) for image in images]
+
+         conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
+         conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
+
+         examples["pixel_values"] = images
+         examples["conditioning_pixel_values"] = conditioning_images
+         examples["input_ids"] = tokenize_captions(examples)
+
+         return examples
+
+     if jax.process_index() == 0:
+         if args.max_train_samples is not None:
+             if args.streaming:
+                 dataset["train"] = dataset["train"].shuffle(seed=args.seed).take(args.max_train_samples)
+             else:
+                 dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+     # Set the training transforms
+     if args.streaming:
+         train_dataset = dataset["train"].map(
+             preprocess_train,
+             batched=True,
+             batch_size=batch_size,
+             remove_columns=list(dataset["train"].features.keys()),
+         )
+     else:
+         train_dataset = dataset["train"].with_transform(preprocess_train)
+
+     return train_dataset
638
+
639
+
640
+ def collate_fn(examples):
+     pixel_values = torch.stack([example["pixel_values"] for example in examples])
+     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+     conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+     conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+     input_ids = torch.stack([example["input_ids"] for example in examples])
+
+     batch = {
+         "pixel_values": pixel_values,
+         "conditioning_pixel_values": conditioning_pixel_values,
+         "input_ids": input_ids,
+     }
+     batch = {k: v.numpy() for k, v in batch.items()}
+     return batch
+
+
+ def get_params_to_save(params):
+     return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
+
+
+ def main():
+     args = parse_args()
+
+     if args.report_to == "wandb" and args.hub_token is not None:
+         raise ValueError(
+             "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+             " Please use `huggingface-cli login` to authenticate with the Hub."
+         )
+
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+     # Set up logging; we only want one process per machine to log things on the screen.
+     logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+     if jax.process_index() == 0:
+         transformers.utils.logging.set_verbosity_info()
+     else:
+         transformers.utils.logging.set_verbosity_error()
+
+     # wandb init
+     if jax.process_index() == 0 and args.report_to == "wandb":
+         wandb.init(
+             entity=args.wandb_entity,
+             project=args.tracker_project_name,
+             job_type="train",
+             config=args,
+         )
+
+     if args.seed is not None:
+         set_seed(args.seed)
+
+     rng = jax.random.PRNGKey(0)
+
+     # Handle the repository creation
+     if jax.process_index() == 0:
+         if args.output_dir is not None:
+             os.makedirs(args.output_dir, exist_ok=True)
+
+         if args.push_to_hub:
+             repo_id = create_repo(
+                 repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+             ).repo_id
+
+     # Load the tokenizer and add the placeholder token as an additional special token
+     if args.tokenizer_name:
+         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+     elif args.pretrained_model_name_or_path:
+         tokenizer = CLIPTokenizer.from_pretrained(
+             args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+         )
+     else:
+         raise NotImplementedError("No tokenizer specified!")
+
+     # Get the datasets: you can either provide your own training and evaluation files (see below)
+     total_train_batch_size = args.train_batch_size * jax.local_device_count() * args.gradient_accumulation_steps
+     train_dataset = make_train_dataset(args, tokenizer, batch_size=total_train_batch_size)
+
+     train_dataloader = torch.utils.data.DataLoader(
+         train_dataset,
+         shuffle=not args.streaming,
+         collate_fn=collate_fn,
+         batch_size=total_train_batch_size,
+         num_workers=args.dataloader_num_workers,
+         drop_last=True,
+     )
+
+     weight_dtype = jnp.float32
+     if args.mixed_precision == "fp16":
+         weight_dtype = jnp.float16
+     elif args.mixed_precision == "bf16":
+         weight_dtype = jnp.bfloat16
+
+     # Load models and create wrapper for stable diffusion
+     text_encoder = FlaxCLIPTextModel.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="text_encoder",
+         dtype=weight_dtype,
+         revision=args.revision,
+         from_pt=args.from_pt,
+     )
+     vae, vae_params = FlaxAutoencoderKL.from_pretrained(
+         args.pretrained_model_name_or_path,
+         revision=args.revision,
+         subfolder="vae",
+         dtype=weight_dtype,
+         from_pt=args.from_pt,
+     )
+     unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="unet",
+         dtype=weight_dtype,
+         revision=args.revision,
+         from_pt=args.from_pt,
+     )
+
+     if args.controlnet_model_name_or_path:
+         logger.info("Loading existing controlnet weights")
+         controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
+             args.controlnet_model_name_or_path,
+             revision=args.controlnet_revision,
+             from_pt=args.controlnet_from_pt,
+             dtype=jnp.float32,
+         )
+     else:
+         logger.info("Initializing controlnet weights from unet")
+         rng, rng_params = jax.random.split(rng)
+
+         controlnet = FlaxControlNetModel(
+             in_channels=unet.config.in_channels,
+             down_block_types=unet.config.down_block_types,
+             only_cross_attention=unet.config.only_cross_attention,
+             block_out_channels=unet.config.block_out_channels,
+             layers_per_block=unet.config.layers_per_block,
+             attention_head_dim=unet.config.attention_head_dim,
+             cross_attention_dim=unet.config.cross_attention_dim,
+             use_linear_projection=unet.config.use_linear_projection,
+             flip_sin_to_cos=unet.config.flip_sin_to_cos,
+             freq_shift=unet.config.freq_shift,
+         )
+         controlnet_params = controlnet.init_weights(rng=rng_params)
+         controlnet_params = unfreeze(controlnet_params)
+         for key in [
+             "conv_in",
+             "time_embedding",
+             "down_blocks_0",
+             "down_blocks_1",
+             "down_blocks_2",
+             "down_blocks_3",
+             "mid_block",
+         ]:
+             controlnet_params[key] = unet_params[key]
+
+
796
+ pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
797
+ args.pretrained_model_name_or_path,
798
+ tokenizer=tokenizer,
799
+ controlnet=controlnet,
800
+ safety_checker=None,
801
+ dtype=weight_dtype,
802
+ revision=args.revision,
803
+ from_pt=args.from_pt,
804
+ )
805
+ pipeline_params = jax_utils.replicate(pipeline_params)
806
+
807
+ # Optimization
808
+ if args.scale_lr:
809
+ args.learning_rate = args.learning_rate * total_train_batch_size
810
+
811
+ constant_scheduler = optax.constant_schedule(args.learning_rate)
812
+
813
+ adamw = optax.adamw(
814
+ learning_rate=constant_scheduler,
815
+ b1=args.adam_beta1,
816
+ b2=args.adam_beta2,
817
+ eps=args.adam_epsilon,
818
+ weight_decay=args.adam_weight_decay,
819
+ )
820
+
821
+ optimizer = optax.chain(
822
+ optax.clip_by_global_norm(args.max_grad_norm),
823
+ adamw,
824
+ )
825
+
826
+ state = train_state.TrainState.create(apply_fn=controlnet.__call__, params=controlnet_params, tx=optimizer)
827
+
828
+ noise_scheduler, noise_scheduler_state = FlaxDDPMScheduler.from_pretrained(
829
+ args.pretrained_model_name_or_path, subfolder="scheduler"
830
+ )
831
+
832
+ # Initialize our training
833
+ validation_rng, train_rngs = jax.random.split(rng)
834
+ train_rngs = jax.random.split(train_rngs, jax.local_device_count())
835
+
836
+ def compute_snr(timesteps):
837
+ """
838
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
839
+ """
840
+ alphas_cumprod = noise_scheduler_state.common.alphas_cumprod
841
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
842
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
843
+
844
+ alpha = sqrt_alphas_cumprod[timesteps]
845
+ sigma = sqrt_one_minus_alphas_cumprod[timesteps]
846
+ # Compute SNR.
847
+ snr = (alpha / sigma) ** 2
848
+ return snr
849
+
850
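As a self-contained illustration of the SNR computation above, here is a numpy sketch. The linear beta schedule and its endpoints are assumptions standing in for the scheduler's actual config; only the SNR formula itself mirrors `compute_snr`.

```python
import numpy as np

# Illustrative linear beta schedule (the real values come from the
# pretrained DDPM scheduler config, not from these hard-coded numbers).
num_train_timesteps = 1000
betas = np.linspace(1e-4, 2e-2, num_train_timesteps)
alphas_cumprod = np.cumprod(1.0 - betas)

def compute_snr(timesteps):
    # SNR(t) = (sqrt(abar_t) / sqrt(1 - abar_t))^2 = abar_t / (1 - abar_t)
    alpha = alphas_cumprod[timesteps] ** 0.5
    sigma = (1.0 - alphas_cumprod[timesteps]) ** 0.5
    return (alpha / sigma) ** 2

# SNR is highest at early timesteps (almost no noise) and decays toward zero.
snr = compute_snr(np.array([0, 500, 999]))
```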
+     def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng):
+         # reshape batch, add grad_step_dim if gradient_accumulation_steps > 1
+         if args.gradient_accumulation_steps > 1:
+             grad_steps = args.gradient_accumulation_steps
+             batch = jax.tree_map(lambda x: x.reshape((grad_steps, x.shape[0] // grad_steps) + x.shape[1:]), batch)
+
+         def compute_loss(params, minibatch, sample_rng):
+             # Convert images to latent space
+             vae_outputs = vae.apply(
+                 {"params": vae_params}, minibatch["pixel_values"], deterministic=True, method=vae.encode
+             )
+             latents = vae_outputs.latent_dist.sample(sample_rng)
+             # (NHWC) -> (NCHW)
+             latents = jnp.transpose(latents, (0, 3, 1, 2))
+             latents = latents * vae.config.scaling_factor
+
+             # Sample noise that we'll add to the latents
+             noise_rng, timestep_rng = jax.random.split(sample_rng)
+             noise = jax.random.normal(noise_rng, latents.shape)
+             # Sample a random timestep for each image
+             bsz = latents.shape[0]
+             timesteps = jax.random.randint(
+                 timestep_rng,
+                 (bsz,),
+                 0,
+                 noise_scheduler.config.num_train_timesteps,
+             )
+
+             # Add noise to the latents according to the noise magnitude at each timestep
+             # (this is the forward diffusion process)
+             noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
+
+             # Get the text embedding for conditioning
+             encoder_hidden_states = text_encoder(
+                 minibatch["input_ids"],
+                 params=text_encoder_params,
+                 train=False,
+             )[0]
+
+             controlnet_cond = minibatch["conditioning_pixel_values"]
+
+             # Predict the noise residual and compute loss
+             down_block_res_samples, mid_block_res_sample = controlnet.apply(
+                 {"params": params},
+                 noisy_latents,
+                 timesteps,
+                 encoder_hidden_states,
+                 controlnet_cond,
+                 train=True,
+                 return_dict=False,
+             )
+
+             model_pred = unet.apply(
+                 {"params": unet_params},
+                 noisy_latents,
+                 timesteps,
+                 encoder_hidden_states,
+                 down_block_additional_residuals=down_block_res_samples,
+                 mid_block_additional_residual=mid_block_res_sample,
+             ).sample
+
+             # Get the target for loss depending on the prediction type
+             if noise_scheduler.config.prediction_type == "epsilon":
+                 target = noise
+             elif noise_scheduler.config.prediction_type == "v_prediction":
+                 target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
+             else:
+                 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+             loss = (target - model_pred) ** 2
+
+             if args.snr_gamma is not None:
+                 snr = jnp.array(compute_snr(timesteps))
+                 snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma)
+                 if noise_scheduler.config.prediction_type == "epsilon":
+                     snr_loss_weights = snr_loss_weights / snr
+                 elif noise_scheduler.config.prediction_type == "v_prediction":
+                     snr_loss_weights = snr_loss_weights / (snr + 1)
+
+                 loss = loss * snr_loss_weights
+
+             loss = loss.mean()
+
+             return loss
+
+         grad_fn = jax.value_and_grad(compute_loss)
+
+         # get a minibatch (one gradient accumulation slice)
+         def get_minibatch(batch, grad_idx):
+             return jax.tree_util.tree_map(
+                 lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),
+                 batch,
+             )
+
+         def loss_and_grad(grad_idx, train_rng):
+             # create minibatch for the grad step
+             minibatch = get_minibatch(batch, grad_idx) if grad_idx is not None else batch
+             sample_rng, train_rng = jax.random.split(train_rng, 2)
+             loss, grad = grad_fn(state.params, minibatch, sample_rng)
+             return loss, grad, train_rng
+
+         if args.gradient_accumulation_steps == 1:
+             loss, grad, new_train_rng = loss_and_grad(None, train_rng)
+         else:
+             init_loss_grad_rng = (
+                 0.0,  # initial value for cumul_loss
+                 jax.tree_map(jnp.zeros_like, state.params),  # initial value for cumul_grad
+                 train_rng,  # initial value for train_rng
+             )
+
+             def cumul_grad_step(grad_idx, loss_grad_rng):
+                 cumul_loss, cumul_grad, train_rng = loss_grad_rng
+                 loss, grad, new_train_rng = loss_and_grad(grad_idx, train_rng)
+                 cumul_loss, cumul_grad = jax.tree_map(jnp.add, (cumul_loss, cumul_grad), (loss, grad))
+                 return cumul_loss, cumul_grad, new_train_rng
+
+             loss, grad, new_train_rng = jax.lax.fori_loop(
+                 0,
+                 args.gradient_accumulation_steps,
+                 cumul_grad_step,
+                 init_loss_grad_rng,
+             )
+             loss, grad = jax.tree_map(lambda x: x / args.gradient_accumulation_steps, (loss, grad))
+
+         grad = jax.lax.pmean(grad, "batch")
+
+         new_state = state.apply_gradients(grads=grad)
+
+         metrics = {"loss": loss}
+         metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+         def l2(xs):
+             return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))
+
+         metrics["l2_grads"] = l2(jax.tree_util.tree_leaves(grad))
+
+         return new_state, metrics, new_train_rng
+
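The gradient-accumulation reshape at the top of `train_step` can be illustrated in isolation with numpy (shapes only; the real code applies the same lambda to every leaf of a JAX pytree):

```python
import numpy as np

grad_steps = 4      # args.gradient_accumulation_steps
per_step_batch = 2  # examples consumed per accumulation slice

# A "global" batch whose leading dim is grad_steps * per_step_batch,
# e.g. pixel values of shape (8, 3, 64, 64).
x = np.zeros((grad_steps * per_step_batch, 3, 64, 64))

# Same reshape as in train_step: prepend a gradient-accumulation axis.
x = x.reshape((grad_steps, x.shape[0] // grad_steps) + x.shape[1:])

# Each slice x[i] is one minibatch processed in one fori_loop iteration.
```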
+     # Create parallel version of the train step
+     p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+
+     # Replicate the train state on each device
+     state = jax_utils.replicate(state)
+     unet_params = jax_utils.replicate(unet_params)
+     text_encoder_params = jax_utils.replicate(text_encoder.params)
+     vae_params = jax_utils.replicate(vae_params)
+
+     # Train!
+     if args.streaming:
+         dataset_length = args.max_train_samples
+     else:
+         dataset_length = len(train_dataloader)
+     num_update_steps_per_epoch = math.ceil(dataset_length / args.gradient_accumulation_steps)
+
+     # Scheduler and math around the number of training steps.
+     if args.max_train_steps is None:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
+     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+     logger.info("***** Running training *****")
+     logger.info(f" Num examples = {args.max_train_samples if args.streaming else len(train_dataset)}")
+     logger.info(f" Num Epochs = {args.num_train_epochs}")
+     logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+     logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
+     logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}")
+
+     if jax.process_index() == 0 and args.report_to == "wandb":
+         wandb.define_metric("*", step_metric="train/step")
+         wandb.define_metric("train/step", step_metric="walltime")
+         wandb.config.update(
+             {
+                 "num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
+                 "total_train_batch_size": total_train_batch_size,
+                 "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
+                 "num_devices": jax.device_count(),
+                 "controlnet_params": sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(state.params)),
+             }
+         )
+
+     global_step = step0 = 0
+     epochs = tqdm(
+         range(args.num_train_epochs),
+         desc="Epoch ... ",
+         position=0,
+         disable=jax.process_index() > 0,
+     )
+     if args.profile_memory:
+         jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_initial.prof"))
+     t00 = t0 = time.monotonic()
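The epoch/step bookkeeping above is plain ceiling arithmetic; a small standalone check with illustrative numbers:

```python
import math

dataset_length = 10               # dataloader batches per epoch
gradient_accumulation_steps = 4   # args.gradient_accumulation_steps
num_train_epochs = 3              # args.num_train_epochs

# One optimizer update consumes gradient_accumulation_steps dataloader batches.
num_update_steps_per_epoch = math.ceil(dataset_length / gradient_accumulation_steps)

# When --max_train_steps is unset, it is derived from the epoch count,
# and the epoch count is then re-derived by rounding up.
max_train_steps = num_train_epochs * num_update_steps_per_epoch
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
```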
+     for epoch in epochs:
+         # ======================== Training ================================
+
+         train_metrics = []
+         train_metric = None
+
+         steps_per_epoch = (
+             args.max_train_samples // total_train_batch_size
+             if args.streaming or args.max_train_samples
+             else len(train_dataset) // total_train_batch_size
+         )
+         train_step_progress_bar = tqdm(
+             total=steps_per_epoch,
+             desc="Training...",
+             position=1,
+             leave=False,
+             disable=jax.process_index() > 0,
+         )
+         # train
+         for batch in train_dataloader:
+             if args.profile_steps and global_step == 1:
+                 train_metric["loss"].block_until_ready()
+                 jax.profiler.start_trace(args.output_dir)
+             if args.profile_steps and global_step == 1 + args.profile_steps:
+                 train_metric["loss"].block_until_ready()
+                 jax.profiler.stop_trace()
+
+             batch = shard(batch)
+             with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
+                 state, train_metric, train_rngs = p_train_step(
+                     state, unet_params, text_encoder_params, vae_params, batch, train_rngs
+                 )
+             train_metrics.append(train_metric)
+
+             train_step_progress_bar.update(1)
+
+             global_step += 1
+             if global_step >= args.max_train_steps:
+                 break
+
+             if (
+                 args.validation_prompt is not None
+                 and global_step % args.validation_steps == 0
+                 and jax.process_index() == 0
+             ):
+                 _ = log_validation(
+                     pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype
+                 )
+
+             if global_step % args.logging_steps == 0 and jax.process_index() == 0:
+                 if args.report_to == "wandb":
+                     train_metrics = jax_utils.unreplicate(train_metrics)
+                     train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
+                     wandb.log(
+                         {
+                             "walltime": time.monotonic() - t00,
+                             "train/step": global_step,
+                             "train/epoch": global_step / dataset_length,
+                             "train/steps_per_sec": (global_step - step0) / (time.monotonic() - t0),
+                             **{f"train/{k}": v for k, v in train_metrics.items()},
+                         }
+                     )
+                     t0, step0 = time.monotonic(), global_step
+                 train_metrics = []
+             if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
+                 controlnet.save_pretrained(
+                     f"{args.output_dir}/{global_step}",
+                     params=get_params_to_save(state.params),
+                 )
+
+         train_metric = jax_utils.unreplicate(train_metric)
+         train_step_progress_bar.close()
+         epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
+
+     # Final validation & store model.
+     if jax.process_index() == 0:
+         if args.validation_prompt is not None:
+             if args.profile_validation:
+                 jax.profiler.start_trace(args.output_dir)
+             image_logs = log_validation(
+                 pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype
+             )
+             if args.profile_validation:
+                 jax.profiler.stop_trace()
+         else:
+             image_logs = None
+
+         controlnet.save_pretrained(
+             args.output_dir,
+             params=get_params_to_save(state.params),
+         )
+
+         if args.push_to_hub:
+             save_model_card(
+                 repo_id,
+                 image_logs=image_logs,
+                 base_model=args.pretrained_model_name_or_path,
+                 repo_folder=args.output_dir,
+             )
+             upload_folder(
+                 repo_id=repo_id,
+                 folder_path=args.output_dir,
+                 commit_message="End of training",
+                 ignore_patterns=["step_*", "epoch_*"],
+             )
+
+     if args.profile_memory:
+         jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_final.prof"))
+     logger.info("Finished training.")
+
+
+ if __name__ == "__main__":
+     main()
train_controlnet_flux.py ADDED
@@ -0,0 +1,1435 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import argparse
+ import copy
+ import functools
+ import logging
+ import math
+ import os
+ import random
+ import shutil
+ from contextlib import nullcontext
+ from pathlib import Path
+
+ import accelerate
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ import transformers
+ from accelerate import Accelerator
+ from accelerate.logging import get_logger
+ from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
+ from datasets import load_dataset
+ from huggingface_hub import create_repo, upload_folder
+ from packaging import version
+ from PIL import Image
+ from torchvision import transforms
+ from tqdm.auto import tqdm
+ from transformers import (
+     AutoTokenizer,
+     CLIPTextModel,
+     T5EncoderModel,
+ )
+
+ import diffusers
+ from diffusers import (
+     AutoencoderKL,
+     FlowMatchEulerDiscreteScheduler,
+     FluxTransformer2DModel,
+ )
+ from diffusers.models.controlnet_flux import FluxControlNetModel
+ from diffusers.optimization import get_scheduler
+ from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
+ from diffusers.training_utils import compute_density_for_timestep_sampling, free_memory
+ from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+ from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
+ from diffusers.utils.torch_utils import is_compiled_module
+
+
+ if is_wandb_available():
+     import wandb
+
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+ check_min_version("0.33.0.dev0")
+
+ logger = get_logger(__name__)
+ if is_torch_npu_available():
+     torch.npu.config.allow_internal_format = False
+
+
+ def log_validation(
+     vae, flux_transformer, flux_controlnet, args, accelerator, weight_dtype, step, is_final_validation=False
+ ):
+     logger.info("Running validation... ")
+
+     if not is_final_validation:
+         flux_controlnet = accelerator.unwrap_model(flux_controlnet)
+         pipeline = FluxControlNetPipeline.from_pretrained(
+             args.pretrained_model_name_or_path,
+             controlnet=flux_controlnet,
+             transformer=flux_transformer,
+             torch_dtype=torch.bfloat16,
+         )
+     else:
+         flux_controlnet = FluxControlNetModel.from_pretrained(
+             args.output_dir, torch_dtype=torch.bfloat16, variant=args.save_weight_dtype
+         )
+         pipeline = FluxControlNetPipeline.from_pretrained(
+             args.pretrained_model_name_or_path,
+             controlnet=flux_controlnet,
+             transformer=flux_transformer,
+             torch_dtype=torch.bfloat16,
+         )
+
+     pipeline.to(accelerator.device)
+     pipeline.set_progress_bar_config(disable=True)
+
+     if args.enable_xformers_memory_efficient_attention:
+         pipeline.enable_xformers_memory_efficient_attention()
+
+     if args.seed is None:
+         generator = None
+     else:
+         generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+     if len(args.validation_image) == len(args.validation_prompt):
+         validation_images = args.validation_image
+         validation_prompts = args.validation_prompt
+     elif len(args.validation_image) == 1:
+         validation_images = args.validation_image * len(args.validation_prompt)
+         validation_prompts = args.validation_prompt
+     elif len(args.validation_prompt) == 1:
+         validation_images = args.validation_image
+         validation_prompts = args.validation_prompt * len(args.validation_image)
+     else:
+         raise ValueError(
+             "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+         )
+
+     image_logs = []
+     if is_final_validation or torch.backends.mps.is_available():
+         autocast_ctx = nullcontext()
+     else:
+         autocast_ctx = torch.autocast(accelerator.device.type)
+
+     for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+         from diffusers.utils import load_image
+
+         validation_image = load_image(validation_image)
+         # maybe need to inference on 1024 to get a good image
+         validation_image = validation_image.resize((args.resolution, args.resolution))
+
+         images = []
+
+         # pre calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
+         prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+             validation_prompt, prompt_2=validation_prompt
+         )
+         for _ in range(args.num_validation_images):
+             with autocast_ctx:
+                 # need to fix in pipeline_flux_controlnet
+                 image = pipeline(
+                     prompt_embeds=prompt_embeds,
+                     pooled_prompt_embeds=pooled_prompt_embeds,
+                     control_image=validation_image,
+                     num_inference_steps=28,
+                     controlnet_conditioning_scale=0.7,
+                     guidance_scale=3.5,
+                     generator=generator,
+                 ).images[0]
+             image = image.resize((args.resolution, args.resolution))
+             images.append(image)
+         image_logs.append(
+             {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+         )
+
+     tracker_key = "test" if is_final_validation else "validation"
+     for tracker in accelerator.trackers:
+         if tracker.name == "tensorboard":
+             for log in image_logs:
+                 images = log["images"]
+                 validation_prompt = log["validation_prompt"]
+                 validation_image = log["validation_image"]
+
+                 formatted_images = [np.asarray(validation_image)]
+
+                 for image in images:
+                     formatted_images.append(np.asarray(image))
+
+                 formatted_images = np.stack(formatted_images)
+
+                 tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+         elif tracker.name == "wandb":
+             formatted_images = []
+
+             for log in image_logs:
+                 images = log["images"]
+                 validation_prompt = log["validation_prompt"]
+                 validation_image = log["validation_image"]
+
+                 formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+
+                 for image in images:
+                     image = wandb.Image(image, caption=validation_prompt)
+                     formatted_images.append(image)
+
+             tracker.log({tracker_key: formatted_images})
+         else:
+             logger.warning(f"image logging not implemented for {tracker.name}")
+
+     del pipeline
+     free_memory()
+     return image_logs
+
+
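The image/prompt broadcasting in `log_validation` (N:N, 1:N, or N:1) can be sketched in isolation. `pair_validation_inputs` is a hypothetical helper written for this example, not a function in the script:

```python
def pair_validation_inputs(validation_images, validation_prompts):
    """Broadcast a single image across many prompts, or vice versa."""
    if len(validation_images) == len(validation_prompts):
        return validation_images, validation_prompts
    if len(validation_images) == 1:
        return validation_images * len(validation_prompts), validation_prompts
    if len(validation_prompts) == 1:
        return validation_images, validation_prompts * len(validation_images)
    raise ValueError("lengths must match or one side must be a singleton")

# One conditioning image reused for two prompts:
imgs, prompts = pair_validation_inputs(["circle.png"], ["red circle", "cyan circle"])
```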
+ def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
+     img_str = ""
+     if image_logs is not None:
+         img_str = "You can find some example images below.\n\n"
+         for i, log in enumerate(image_logs):
+             images = log["images"]
+             validation_prompt = log["validation_prompt"]
+             validation_image = log["validation_image"]
+             validation_image.save(os.path.join(repo_folder, "image_control.png"))
+             img_str += f"prompt: {validation_prompt}\n"
+             images = [validation_image] + images
+             make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+             img_str += f"![images_{i}](./images_{i}.png)\n"
+
+     model_description = f"""
+ # controlnet-{repo_id}
+
+ These are controlnet weights trained on {base_model} with a new type of conditioning.
+ {img_str}
+
+ ## License
+
+ Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)
+ """
+
+     model_card = load_or_create_model_card(
+         repo_id_or_path=repo_id,
+         from_training=True,
+         license="other",
+         base_model=base_model,
+         model_description=model_description,
+         inference=True,
+     )
+
+     tags = [
+         "flux",
+         "flux-diffusers",
+         "text-to-image",
+         "diffusers",
+         "controlnet",
+         "diffusers-training",
+     ]
+     model_card = populate_model_card(model_card, tags=tags)
+
+     model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+ def parse_args(input_args=None):
248
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
249
+ parser.add_argument(
250
+ "--pretrained_model_name_or_path",
251
+ type=str,
252
+ default=None,
253
+ required=True,
254
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
255
+ )
256
+ parser.add_argument(
257
+ "--pretrained_vae_model_name_or_path",
258
+ type=str,
259
+ default=None,
260
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
261
+ )
262
+ parser.add_argument(
263
+ "--controlnet_model_name_or_path",
264
+ type=str,
265
+ default=None,
266
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
267
+ " If not specified controlnet weights are initialized from unet.",
268
+ )
269
+ parser.add_argument(
270
+ "--variant",
271
+ type=str,
272
+ default=None,
273
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="controlnet-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_h",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_w",
+ type=int,
+ default=0,
+ help=("Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step "
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--use_adafactor",
+ action="store_true",
+ help=(
+ "Adafactor is a stochastic optimization method based on Adam that reduces memory usage while retaining "
+ "the empirical benefits of adaptivity. This is achieved through maintaining a factored representation "
+ "of the squared gradient accumulator across training steps."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the controlnet conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the controlnet conditioning image to be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_double_layers",
+ type=int,
+ default=4,
+ help="Number of double layers in the controlnet (default: 4).",
+ )
+ parser.add_argument(
+ "--num_single_layers",
+ type=int,
+ default=4,
+ help="Number of single layers in the controlnet (default: 4).",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=2,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="flux_train_controlnet",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers. For"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+ parser.add_argument(
+ "--jsonl_for_train",
+ type=str,
+ default=None,
+ help="Path to the jsonl file containing the training data.",
+ )
+
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=3.5,
+ help="The guidance scale used for the transformer.",
+ )
+
+ parser.add_argument(
+ "--save_weight_dtype",
+ type=str,
+ default="fp32",
+ choices=[
+ "fp16",
+ "bf16",
+ "fp32",
+ ],
+ help=("The precision type to use when saving the trained weights."),
+ )
+
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="logit_normal",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('The weighting scheme for timestep sampling and loss weighting; "none" corresponds to uniform sampling and uniform loss.'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="Mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="Std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--enable_model_cpu_offload",
+ action="store_true",
+ help="Enable model CPU offload to save memory.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.jsonl_for_train is None:
+ raise ValueError("Specify either `--dataset_name` or `--jsonl_for_train`")
+
+ if args.dataset_name is not None and args.jsonl_for_train is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--jsonl_for_train`")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+ )
+
+ return args
+
+
+ def get_train_dataset(args, accelerator):
+ dataset = None
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ if args.jsonl_for_train is not None:
+ # load from json
+ dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir)
+ dataset = dataset.flatten_indices()
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"caption column defaulting to {caption_column}")
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.conditioning_image_column is None:
+ conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ raise ValueError(
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ with accelerator.main_process_first():
+ train_dataset = dataset["train"].shuffle(seed=args.seed)
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.select(range(args.max_train_samples))
+ return train_dataset
+
+
+ def prepare_train_dataset(dataset, accelerator):
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ conditioning_image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [
+ (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB"))
+ for image in examples[args.image_column]
+ ]
+ images = [image_transforms(image) for image in images]
+
+ conditioning_images = [
+ (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB"))
+ for image in examples[args.conditioning_image_column]
+ ]
+ conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
+ examples["pixel_values"] = images
+ examples["conditioning_pixel_values"] = conditioning_images
+
+ return examples
+
+ with accelerator.main_process_first():
+ dataset = dataset.with_transform(preprocess_train)
+
+ return dataset
+
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
+
+ pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
+ text_ids = torch.stack([torch.tensor(example["text_ids"]) for example in examples])
+
+ return {
+ "pixel_values": pixel_values,
+ "conditioning_pixel_values": conditioning_pixel_values,
+ "prompt_ids": prompt_ids,
+ "unet_added_conditions": {"pooled_prompt_embeds": pooled_prompt_embeds, "time_ids": text_ids},
+ }
+
+
+ def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_out_dir = Path(args.output_dir, args.logging_dir)
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir))
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS (Apple's Metal Performance Shaders backend for accelerating ML on macOS and iOS devices).
+ if torch.backends.mps.is_available():
+ print("MPS is enabled. Disabling AMP.")
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ # DEBUG, INFO, WARNING, ERROR, CRITICAL
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ # load clip tokenizer
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+ # load t5 tokenizer
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ )
+ # load clip text encoder
+ text_encoder_one = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ # load t5 text encoder
+ text_encoder_two = T5EncoderModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ flux_transformer = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ flux_controlnet = FluxControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
+ else:
+ logger.info("Initializing controlnet weights from transformer")
+ # we can define the num_layers and num_single_layers here
+ flux_controlnet = FluxControlNetModel.from_transformer(
+ flux_transformer,
+ attention_head_dim=flux_transformer.config["attention_head_dim"],
+ num_attention_heads=flux_transformer.config["num_attention_heads"],
+ num_layers=args.num_double_layers,
+ num_single_layers=args.num_single_layers,
+ )
+ logger.info("all models loaded successfully")
+
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="scheduler",
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ vae.requires_grad_(False)
+ flux_transformer.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ flux_controlnet.train()
+
+ # assemble the pipeline (used for prompt encoding and validation)
+ flux_controlnet_pipeline = FluxControlNetPipeline(
+ scheduler=noise_scheduler,
+ vae=vae,
+ text_encoder=text_encoder_one,
+ tokenizer=tokenizer_one,
+ text_encoder_2=text_encoder_two,
+ tokenizer_2=tokenizer_two,
+ transformer=flux_transformer,
+ controlnet=flux_controlnet,
+ )
+ if args.enable_model_cpu_offload:
+ flux_controlnet_pipeline.enable_model_cpu_offload()
+ else:
+ flux_controlnet_pipeline.to(accelerator.device)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ i = len(weights) - 1
+
+ while len(weights) > 0:
+ weights.pop()
+ model = models[i]
+
+ sub_dir = "flux_controlnet"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ i -= 1
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = FluxControlNetModel.from_pretrained(input_dir, subfolder="flux_controlnet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.enable_npu_flash_attention:
+ if is_torch_npu_available():
+ logger.info("npu flash attention enabled.")
+ flux_transformer.enable_npu_flash_attention()
+ else:
+ raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ flux_transformer.enable_xformers_memory_efficient_attention()
+ flux_controlnet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ flux_transformer.enable_gradient_checkpointing()
+ flux_controlnet.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, a copy of the weights should still be in float32."
+ )
+
+ if unwrap_model(flux_controlnet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {unwrap_model(flux_controlnet).dtype}. {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = flux_controlnet.parameters()
+ # use adafactor optimizer to save gpu memory
+ if args.use_adafactor:
+ from transformers import Adafactor
+
+ optimizer = Adafactor(
+ params_to_optimize,
+ lr=args.learning_rate,
+ scale_parameter=False,
+ relative_step=False,
+ # warmup_init=True,
+ weight_decay=args.adam_weight_decay,
+ )
+ else:
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ vae.to(accelerator.device, dtype=weight_dtype)
+ flux_transformer.to(accelerator.device, dtype=weight_dtype)
+
+ def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline, weight_dtype, is_train=True):
+ prompt_batch = batch[args.caption_column]
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ prompt_batch = captions
+ prompt_embeds, pooled_prompt_embeds, text_ids = flux_controlnet_pipeline.encode_prompt(
+ prompt_batch, prompt_2=prompt_batch
+ )
+ prompt_embeds = prompt_embeds.to(dtype=weight_dtype)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=weight_dtype)
+ text_ids = text_ids.to(dtype=weight_dtype)
+
+ # expand text_ids from [512, 3] to [bs, 512, 3]
+ text_ids = text_ids.unsqueeze(0).expand(prompt_embeds.shape[0], -1, -1)
+ return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids}
+
+ train_dataset = get_train_dataset(args, accelerator)
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ flux_controlnet_pipeline=flux_controlnet_pipeline,
+ proportion_empty_prompts=args.proportion_empty_prompts,
+ weight_dtype=weight_dtype,
+ )
+ with accelerator.main_process_first():
+ from datasets.fingerprint import Hasher
+
+ # fingerprint used by the cache for the other processes to load the result
+ # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
+ new_fingerprint = Hasher.hash(args)
+ train_dataset = train_dataset.map(
+ compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
+ )
+
+ del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+ free_memory()
+
+ # Then get the training dataset ready to be passed to the dataloader.
+ train_dataset = prepare_train_dataset(train_dataset, accelerator)
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ lr_scheduler = get_scheduler(
1132
+ args.lr_scheduler,
1133
+ optimizer=optimizer,
1134
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1135
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
1136
+ num_cycles=args.lr_num_cycles,
1137
+ power=args.lr_power,
1138
+ )
1139
+ # Prepare everything with our `accelerator`.
1140
+ flux_controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1141
+ flux_controlnet, optimizer, train_dataloader, lr_scheduler
1142
+ )
1143
+
1144
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1145
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1146
+ if args.max_train_steps is None:
1147
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1148
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
1149
+ logger.warning(
1150
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
1151
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
1152
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
1153
+ )
1154
+ # Afterwards we recalculate our number of training epochs
1155
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1156
+
1157
+ # We need to initialize the trackers we use, and also store our configuration.
1158
+ # The trackers initializes automatically on the main process.
1159
+ if accelerator.is_main_process:
1160
+ tracker_config = dict(vars(args))
1161
+
1162
+ # tensorboard cannot handle list types for config
1163
+ tracker_config.pop("validation_prompt")
1164
+ tracker_config.pop("validation_image")
1165
+
1166
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
1167
+
1168
+ # Train!
1169
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1170
+
1171
+ logger.info("***** Running training *****")
1172
+ logger.info(f" Num examples = {len(train_dataset)}")
1173
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1174
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1175
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1176
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1177
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1178
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1179
+ global_step = 0
1180
+ first_epoch = 0
1181
+
1182
+ # Potentially load in the weights and states from a previous save
1183
+ if args.resume_from_checkpoint:
1184
+ if args.resume_from_checkpoint != "latest":
1185
+ path = os.path.basename(args.resume_from_checkpoint)
1186
+ else:
1187
+ # Get the most recent checkpoint
1188
+ dirs = os.listdir(args.output_dir)
1189
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1190
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1191
+ path = dirs[-1] if len(dirs) > 0 else None
1192
+
1193
+ if path is None:
1194
+ accelerator.print(
1195
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1196
+ )
1197
+ args.resume_from_checkpoint = None
1198
+ initial_global_step = 0
1199
+ else:
1200
+ accelerator.print(f"Resuming from checkpoint {path}")
1201
+ accelerator.load_state(os.path.join(args.output_dir, path))
1202
+ global_step = int(path.split("-")[1])
1203
+
1204
+ initial_global_step = global_step
1205
+ first_epoch = global_step // num_update_steps_per_epoch
1206
+ else:
1207
+ initial_global_step = 0
1208
+
1209
+ progress_bar = tqdm(
1210
+ range(0, args.max_train_steps),
1211
+ initial=initial_global_step,
1212
+ desc="Steps",
1213
+ # Only show the progress bar once on each machine.
1214
+ disable=not accelerator.is_local_main_process,
1215
+ )
1216
+
1217
+    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+        timesteps = timesteps.to(accelerator.device)
+        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+        sigma = sigmas[step_indices].flatten()
+        while len(sigma.shape) < n_dim:
+            sigma = sigma.unsqueeze(-1)
+        return sigma
+
+    image_logs = None
+    for epoch in range(first_epoch, args.num_train_epochs):
+        for step, batch in enumerate(train_dataloader):
+            with accelerator.accumulate(flux_controlnet):
+                # Convert images to latent space
+                # vae encode
+                pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+                pixel_latents_tmp = vae.encode(pixel_values).latent_dist.sample()
+                pixel_latents_tmp = (pixel_latents_tmp - vae.config.shift_factor) * vae.config.scaling_factor
+                pixel_latents = FluxControlNetPipeline._pack_latents(
+                    pixel_latents_tmp,
+                    pixel_values.shape[0],
+                    pixel_latents_tmp.shape[1],
+                    pixel_latents_tmp.shape[2],
+                    pixel_latents_tmp.shape[3],
+                )
+
+                control_values = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+                control_latents = vae.encode(control_values).latent_dist.sample()
+                control_latents = (control_latents - vae.config.shift_factor) * vae.config.scaling_factor
+                control_image = FluxControlNetPipeline._pack_latents(
+                    control_latents,
+                    control_values.shape[0],
+                    control_latents.shape[1],
+                    control_latents.shape[2],
+                    control_latents.shape[3],
+                )
+
+                latent_image_ids = FluxControlNetPipeline._prepare_latent_image_ids(
+                    batch_size=pixel_latents_tmp.shape[0],
+                    height=pixel_latents_tmp.shape[2] // 2,
+                    width=pixel_latents_tmp.shape[3] // 2,
+                    device=pixel_values.device,
+                    dtype=pixel_values.dtype,
+                )
+
+                bsz = pixel_latents.shape[0]
+                noise = torch.randn_like(pixel_latents).to(accelerator.device).to(dtype=weight_dtype)
+                # Sample a random timestep for each image
+                # for weighting schemes where we sample timesteps non-uniformly
+                u = compute_density_for_timestep_sampling(
+                    weighting_scheme=args.weighting_scheme,
+                    batch_size=bsz,
+                    logit_mean=args.logit_mean,
+                    logit_std=args.logit_std,
+                    mode_scale=args.mode_scale,
+                )
+                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+                timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)
+
+                # Add noise according to flow matching.
+                sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype)
+                noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise
+
+                # handle guidance
+                if flux_transformer.config.guidance_embeds:
+                    guidance_vec = torch.full(
+                        (noisy_model_input.shape[0],),
+                        args.guidance_scale,
+                        device=noisy_model_input.device,
+                        dtype=weight_dtype,
+                    )
+                else:
+                    guidance_vec = None
+
+                controlnet_block_samples, controlnet_single_block_samples = flux_controlnet(
+                    hidden_states=noisy_model_input,
+                    controlnet_cond=control_image,
+                    timestep=timesteps / 1000,
+                    guidance=guidance_vec,
+                    pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype),
+                    encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype),
+                    txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype),
+                    img_ids=latent_image_ids,
+                    return_dict=False,
+                )
+
+                noise_pred = flux_transformer(
+                    hidden_states=noisy_model_input,
+                    timestep=timesteps / 1000,
+                    guidance=guidance_vec,
+                    pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype),
+                    encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype),
+                    controlnet_block_samples=[sample.to(dtype=weight_dtype) for sample in controlnet_block_samples]
+                    if controlnet_block_samples is not None
+                    else None,
+                    controlnet_single_block_samples=[
+                        sample.to(dtype=weight_dtype) for sample in controlnet_single_block_samples
+                    ]
+                    if controlnet_single_block_samples is not None
+                    else None,
+                    txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype),
+                    img_ids=latent_image_ids,
+                    return_dict=False,
+                )[0]
+
+                loss = F.mse_loss(noise_pred.float(), (noise - pixel_latents).float(), reduction="mean")
+                accelerator.backward(loss)
+                # Check if the gradient of each model parameter contains NaN
+                for name, param in flux_controlnet.named_parameters():
+                    if param.grad is not None and torch.isnan(param.grad).any():
+                        logger.error(f"Gradient for {name} contains NaN!")
+
+                if accelerator.sync_gradients:
+                    params_to_clip = flux_controlnet.parameters()
+                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
+                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
+                    if global_step % args.checkpointing_steps == 0:
+                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                        if args.checkpoints_total_limit is not None:
+                            checkpoints = os.listdir(args.output_dir)
+                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                            if len(checkpoints) >= args.checkpoints_total_limit:
+                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                removing_checkpoints = checkpoints[0:num_to_remove]
+
+                                logger.info(
+                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                                )
+                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                                for removing_checkpoint in removing_checkpoints:
+                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    shutil.rmtree(removing_checkpoint)
+
+                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                        accelerator.save_state(save_path)
+                        logger.info(f"Saved state to {save_path}")
+
+                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+                        image_logs = log_validation(
+                            vae=vae,
+                            flux_transformer=flux_transformer,
+                            flux_controlnet=flux_controlnet,
+                            args=args,
+                            accelerator=accelerator,
+                            weight_dtype=weight_dtype,
+                            step=global_step,
+                        )
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
+
+            if global_step >= args.max_train_steps:
+                break
+
+    # Create the pipeline using the trained modules and save it.
+    accelerator.wait_for_everyone()
+    if accelerator.is_main_process:
+        flux_controlnet = unwrap_model(flux_controlnet)
+        save_weight_dtype = torch.float32
+        if args.save_weight_dtype == "fp16":
+            save_weight_dtype = torch.float16
+        elif args.save_weight_dtype == "bf16":
+            save_weight_dtype = torch.bfloat16
+        flux_controlnet.to(save_weight_dtype)
+        if args.save_weight_dtype != "fp32":
+            flux_controlnet.save_pretrained(args.output_dir, variant=args.save_weight_dtype)
+        else:
+            flux_controlnet.save_pretrained(args.output_dir)
+        # Run a final round of validation.
+        # Setting `flux_controlnet` to None loads the saved weights automatically from `args.output_dir`.
+        image_logs = None
+        if args.validation_prompt is not None:
+            image_logs = log_validation(
+                vae=vae,
+                flux_transformer=flux_transformer,
+                flux_controlnet=None,
+                args=args,
+                accelerator=accelerator,
+                weight_dtype=weight_dtype,
+                step=global_step,
+                is_final_validation=True,
+            )
+
+        if args.push_to_hub:
+            save_model_card(
+                repo_id,
+                image_logs=image_logs,
+                base_model=args.pretrained_model_name_or_path,
+                repo_folder=args.output_dir,
+            )
+
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
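The training loop above uses a plain flow-matching objective: the noisy input is built as `(1 - sigma) * latents + sigma * noise`, and the network is trained with MSE against the velocity target `noise - latents`. A minimal, self-contained sketch of that loss (the function name and toy tensors are illustrative, not from the script):

```python
import torch


def flow_matching_loss(model_pred, latents, noise):
    # The model is asked to predict the velocity (noise - latents),
    # so the regression target is simply their difference.
    target = noise - latents
    return torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")


latents = torch.zeros(2, 4)
noise = torch.ones(2, 4)
# A perfect velocity prediction gives zero loss.
print(flow_matching_loss(noise - latents, latents, noise).item())  # → 0.0
```

This mirrors the `F.mse_loss(noise_pred.float(), (noise - pixel_latents).float(), reduction="mean")` line in the loop, just with the packing and conditioning stripped away.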
train_controlnet_sd3.py ADDED
@@ -0,0 +1,1428 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import contextlib
+import copy
+import functools
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import numpy as np
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
+
+import diffusers
+from diffusers import (
+    AutoencoderKL,
+    FlowMatchEulerDiscreteScheduler,
+    SD3ControlNetModel,
+    SD3Transformer2DModel,
+    StableDiffusion3ControlNetPipeline,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+    import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.33.0.dev0")
+
+logger = get_logger(__name__)
+
+def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
+    logger.info("Running validation... ")
+
+    if not is_final_validation:
+        controlnet = accelerator.unwrap_model(controlnet)
+    else:
+        controlnet = SD3ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
+
+    pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        controlnet=controlnet,
+        safety_checker=None,
+        revision=args.revision,
+        variant=args.variant,
+        torch_dtype=weight_dtype,
+    )
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    if args.seed is None:
+        generator = None
+    else:
+        generator = torch.manual_seed(args.seed)
+
+    if len(args.validation_image) == len(args.validation_prompt):
+        validation_images = args.validation_image
+        validation_prompts = args.validation_prompt
+    elif len(args.validation_image) == 1:
+        validation_images = args.validation_image * len(args.validation_prompt)
+        validation_prompts = args.validation_prompt
+    elif len(args.validation_prompt) == 1:
+        validation_images = args.validation_image
+        validation_prompts = args.validation_prompt * len(args.validation_image)
+    else:
+        raise ValueError(
+            "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+        )
+
+    image_logs = []
+    inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(accelerator.device.type)
+
+    for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+        validation_image = Image.open(validation_image).convert("RGB")
+
+        images = []
+
+        for _ in range(args.num_validation_images):
+            with inference_ctx:
+                image = pipeline(
+                    validation_prompt, control_image=validation_image, num_inference_steps=20, generator=generator
+                ).images[0]
+
+            images.append(image)
+
+        image_logs.append(
+            {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+        )
+
+    tracker_key = "test" if is_final_validation else "validation"
+    for tracker in accelerator.trackers:
+        if tracker.name == "tensorboard":
+            for log in image_logs:
+                images = log["images"]
+                validation_prompt = log["validation_prompt"]
+                validation_image = log["validation_image"]
+
+                tracker.writer.add_image(
+                    "Controlnet conditioning", np.asarray([validation_image]), step, dataformats="NHWC"
+                )
+
+                formatted_images = []
+                for image in images:
+                    formatted_images.append(np.asarray(image))
+
+                formatted_images = np.stack(formatted_images)
+
+                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+        elif tracker.name == "wandb":
+            formatted_images = []
+
+            for log in image_logs:
+                images = log["images"]
+                validation_prompt = log["validation_prompt"]
+                validation_image = log["validation_image"]
+
+                formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+
+                for image in images:
+                    image = wandb.Image(image, caption=validation_prompt)
+                    formatted_images.append(image)
+
+            tracker.log({tracker_key: formatted_images})
+        else:
+            logger.warning(f"image logging not implemented for {tracker.name}")
+
+    del pipeline
+    free_memory()
+
+    if not is_final_validation:
+        controlnet.to(accelerator.device)
+
+    return image_logs
+
+# Copied from dreambooth sd3 example
+def load_text_encoders(class_one, class_two, class_three):
+    text_encoder_one = class_one.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+    )
+    text_encoder_two = class_two.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+    )
+    text_encoder_three = class_three.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant
+    )
+    return text_encoder_one, text_encoder_two, text_encoder_three
+
+
+# Copied from dreambooth sd3 example
+def import_model_class_from_model_name_or_path(
+    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+    text_encoder_config = PretrainedConfig.from_pretrained(
+        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+    )
+    model_class = text_encoder_config.architectures[0]
+    if model_class == "CLIPTextModelWithProjection":
+        from transformers import CLIPTextModelWithProjection
+
+        return CLIPTextModelWithProjection
+    elif model_class == "T5EncoderModel":
+        from transformers import T5EncoderModel
+
+        return T5EncoderModel
+    else:
+        raise ValueError(f"{model_class} is not supported.")
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
+    img_str = ""
+    if image_logs is not None:
+        img_str = "You can find some example images below.\n\n"
+        for i, log in enumerate(image_logs):
+            images = log["images"]
+            validation_prompt = log["validation_prompt"]
+            validation_image = log["validation_image"]
+            validation_image.save(os.path.join(repo_folder, "image_control.png"))
+            img_str += f"prompt: {validation_prompt}\n"
+            images = [validation_image] + images
+            make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+            img_str += f"![images_{i}](./images_{i}.png)\n"
+
+    model_description = f"""
+# SD3 controlnet-{repo_id}
+
+These are controlnet weights trained on {base_model} with a new type of conditioning.
+The weights were trained using [ControlNet](https://github.com/lllyasviel/ControlNet) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_sd3.md).
+{img_str}
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
+"""
+    model_card = load_or_create_model_card(
+        repo_id_or_path=repo_id,
+        from_training=True,
+        license="openrail++",
+        base_model=base_model,
+        model_description=model_description,
+        inference=True,
+    )
+
+    tags = [
+        "text-to-image",
+        "diffusers-training",
+        "diffusers",
+        "sd3",
+        "sd3-diffusers",
+        "controlnet",
+    ]
+    model_card = populate_model_card(model_card, tags=tags)
+
+    model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+ def parse_args(input_args=None):
251
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
252
+ parser.add_argument(
253
+ "--pretrained_model_name_or_path",
254
+ type=str,
255
+ default=None,
256
+ required=True,
257
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
258
+ )
259
+ parser.add_argument(
260
+ "--controlnet_model_name_or_path",
261
+ type=str,
262
+ default=None,
263
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
264
+ " If not specified controlnet weights are initialized from unet.",
265
+ )
266
+ parser.add_argument(
267
+ "--num_extra_conditioning_channels",
268
+ type=int,
269
+ default=0,
270
+ help="Number of extra conditioning channels for controlnet.",
271
+ )
272
+ parser.add_argument(
273
+ "--revision",
274
+ type=str,
275
+ default=None,
276
+ required=False,
277
+ help="Revision of pretrained model identifier from huggingface.co/models.",
278
+ )
279
+ parser.add_argument(
280
+ "--variant",
281
+ type=str,
282
+ default=None,
283
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
284
+ )
285
+ parser.add_argument(
286
+ "--output_dir",
287
+ type=str,
288
+ default="controlnet-model",
289
+ help="The output directory where the model predictions and checkpoints will be written.",
290
+ )
291
+ parser.add_argument(
292
+ "--cache_dir",
293
+ type=str,
294
+ default=None,
295
+ help="The directory where the downloaded models and datasets will be stored.",
296
+ )
297
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
298
+ parser.add_argument(
299
+ "--resolution",
300
+ type=int,
301
+ default=512,
302
+ help=(
303
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
304
+ " resolution"
305
+ ),
306
+ )
307
+ parser.add_argument(
308
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
309
+ )
310
+ parser.add_argument("--num_train_epochs", type=int, default=1)
311
+ parser.add_argument(
312
+ "--max_train_steps",
313
+ type=int,
314
+ default=None,
315
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
316
+ )
317
+ parser.add_argument(
318
+ "--checkpointing_steps",
319
+ type=int,
320
+ default=500,
321
+ help=(
322
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
+ )
+ parser.add_argument(
+ "--upcast_vae",
+ action="store_true",
+ help="Whether or not to upcast the VAE to fp32.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="logit_normal",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="Mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="Std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--precondition_outputs",
+ type=int,
+ default=1,
+ help="Flag indicating whether to precondition the model outputs, as done in EDM. This affects how the "
+ "model `target` is calculated.",
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the controlnet conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=77,
+ help="Maximum sequence length to use with the T5 text encoder.",
+ )
+ parser.add_argument(
+ "--dataset_preprocess_batch_size", type=int, default=1000, help="Batch size for preprocessing the dataset."
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the controlnet conditioning images to be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running each prompt in"
+ " `args.validation_prompt` `args.num_validation_images` times"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="train_controlnet",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers. For"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
+
+ if args.dataset_name is not None and args.train_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+ )
+
+ return args
+
+
+ def make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, accelerator):
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ if args.train_data_dir is not None:
+ dataset = load_dataset(
+ args.train_data_dir,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"caption column defaulting to {caption_column}")
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.conditioning_image_column is None:
+ conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ raise ValueError(
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ def process_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if random.random() < args.proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ return captions
+
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ conditioning_image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ images = [image_transforms(image) for image in images]
+
+ conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
+ conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
+
+ examples["pixel_values"] = images
+ examples["conditioning_pixel_values"] = conditioning_images
+ examples["prompts"] = process_captions(examples)
+
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ return train_dataset
+
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
+ pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
+
+ return {
+ "pixel_values": pixel_values,
+ "conditioning_pixel_values": conditioning_pixel_values,
+ "prompt_embeds": prompt_embeds,
+ "pooled_prompt_embeds": pooled_prompt_embeds,
+ }
+
+
+ # Copied from dreambooth sd3 example
+ def _encode_prompt_with_t5(
+ text_encoder,
+ tokenizer,
+ max_sequence_length,
+ prompt=None,
+ num_images_per_prompt=1,
+ device=None,
+ ):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+
+ dtype = text_encoder.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+
+ # Copied from dreambooth sd3 example
+ def _encode_prompt_with_clip(
+ text_encoder,
+ tokenizer,
+ prompt: str,
+ device=None,
+ num_images_per_prompt: int = 1,
+ ):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, pooled_prompt_embeds
+
+
+ # Copied from dreambooth sd3 example
+ def encode_prompt(
+ text_encoders,
+ tokenizers,
+ prompt: str,
+ max_sequence_length,
+ device=None,
+ num_images_per_prompt: int = 1,
+ ):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ clip_tokenizers = tokenizers[:2]
+ clip_text_encoders = text_encoders[:2]
+
+ clip_prompt_embeds_list = []
+ clip_pooled_prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders):
+ prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ prompt=prompt,
+ device=device if device is not None else text_encoder.device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ clip_prompt_embeds_list.append(prompt_embeds)
+ clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
+
+ clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)
+
+ t5_prompt_embed = _encode_prompt_with_t5(
+ text_encoders[-1],
+ tokenizers[-1],
+ max_sequence_length,
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device if device is not None else text_encoders[-1].device,
+ )
+
+ clip_prompt_embeds = torch.nn.functional.pad(
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
+ )
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+
+ return prompt_embeds, pooled_prompt_embeds
+
+
+ def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+ tokenizer_two = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ )
+ tokenizer_three = T5TokenizerFast.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_3",
+ revision=args.revision,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+ text_encoder_cls_three = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
+ text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ transformer = SD3Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+ )
+
+ if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ controlnet = SD3ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
+ else:
+ logger.info("Initializing controlnet weights from transformer")
+ controlnet = SD3ControlNetModel.from_transformer(
+ transformer, num_extra_conditioning_channels=args.num_extra_conditioning_channels
+ )
+
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ text_encoder_three.requires_grad_(False)
+ controlnet.train()
+
+ # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ i = len(weights) - 1
+
+ while len(weights) > 0:
+ weights.pop()
+ model = models[i]
+
+ sub_dir = "controlnet"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ i -= 1
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = SD3ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ controlnet.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, a copy of the weights should still be float32."
+ )
+
+ if unwrap_model(controlnet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = controlnet.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae, transformer and text_encoder to device and cast to weight_dtype
+ if args.upcast_vae:
+ vae.to(accelerator.device, dtype=torch.float32)
+ else:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ transformer.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_three.to(accelerator.device, dtype=weight_dtype)
+
+ train_dataset = make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, accelerator)
+
+ tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]
+ text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]
+
+ def compute_text_embeddings(batch, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt = batch["prompts"]
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders, tokenizers, prompt, args.max_sequence_length
+ )
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}
+
+ compute_embeddings_fn = functools.partial(
+ compute_text_embeddings,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ )
+ with accelerator.main_process_first():
+ from datasets.fingerprint import Hasher
+
+ # fingerprint used by the cache for the other processes to load the result
+ # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
+ new_fingerprint = Hasher.hash(args)
+ train_dataset = train_dataset.map(
+ compute_embeddings_fn,
+ batched=True,
+ batch_size=args.dataset_preprocess_batch_size,
+ new_fingerprint=new_fingerprint,
+ )
+
+ del text_encoder_one, text_encoder_two, text_encoder_three
+ del tokenizer_one, tokenizer_two, tokenizer_three
+ free_memory()
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ controlnet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1218
+ path = dirs[-1] if len(dirs) > 0 else None
1219
+
1220
+ if path is None:
1221
+ accelerator.print(
1222
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1223
+ )
1224
+ args.resume_from_checkpoint = None
1225
+ initial_global_step = 0
1226
+ else:
1227
+ accelerator.print(f"Resuming from checkpoint {path}")
1228
+ accelerator.load_state(os.path.join(args.output_dir, path))
1229
+ global_step = int(path.split("-")[1])
1230
+
1231
+ initial_global_step = global_step
1232
+ first_epoch = global_step // num_update_steps_per_epoch
1233
+ else:
1234
+ initial_global_step = 0
1235
+
1236
+ progress_bar = tqdm(
1237
+ range(0, args.max_train_steps),
1238
+ initial=initial_global_step,
1239
+ desc="Steps",
1240
+ # Only show the progress bar once on each machine.
1241
+ disable=not accelerator.is_local_main_process,
1242
+ )
1243
+
1244
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1245
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
1246
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
1247
+ timesteps = timesteps.to(accelerator.device)
1248
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
1249
+
1250
+ sigma = sigmas[step_indices].flatten()
1251
+ while len(sigma.shape) < n_dim:
1252
+ sigma = sigma.unsqueeze(-1)
1253
+ return sigma
1254
+
1255
+ image_logs = None
1256
+ for epoch in range(first_epoch, args.num_train_epochs):
1257
+ for step, batch in enumerate(train_dataloader):
1258
+ with accelerator.accumulate(controlnet):
1259
+ # Convert images to latent space
1260
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1261
+ model_input = vae.encode(pixel_values).latent_dist.sample()
1262
+ model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
1263
+ model_input = model_input.to(dtype=weight_dtype)
1264
+
1265
+ # Sample noise that we'll add to the latents
1266
+ noise = torch.randn_like(model_input)
1267
+ bsz = model_input.shape[0]
1268
+ # Sample a random timestep for each image
1269
+ # for weighting schemes where we sample timesteps non-uniformly
1270
+ u = compute_density_for_timestep_sampling(
1271
+ weighting_scheme=args.weighting_scheme,
1272
+ batch_size=bsz,
1273
+ logit_mean=args.logit_mean,
1274
+ logit_std=args.logit_std,
1275
+ mode_scale=args.mode_scale,
1276
+ )
1277
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
1278
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
1279
+
1280
+ # Add noise according to flow matching.
1281
+ # zt = (1 - texp) * x + texp * z1
1282
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
1283
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
1284
+
1285
+ # Get the text embedding for conditioning
1286
+ prompt_embeds = batch["prompt_embeds"].to(dtype=weight_dtype)
1287
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(dtype=weight_dtype)
1288
+
1289
+ # controlnet(s) inference
1290
+ controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
1291
+ controlnet_image = vae.encode(controlnet_image).latent_dist.sample()
1292
+ controlnet_image = controlnet_image * vae.config.scaling_factor
1293
+
1294
+ control_block_res_samples = controlnet(
1295
+ hidden_states=noisy_model_input,
1296
+ timestep=timesteps,
1297
+ encoder_hidden_states=prompt_embeds,
1298
+ pooled_projections=pooled_prompt_embeds,
1299
+ controlnet_cond=controlnet_image,
1300
+ return_dict=False,
1301
+ )[0]
1302
+ control_block_res_samples = [sample.to(dtype=weight_dtype) for sample in control_block_res_samples]
1303
+
1304
+ # Predict the noise residual
1305
+ model_pred = transformer(
1306
+ hidden_states=noisy_model_input,
1307
+ timestep=timesteps,
1308
+ encoder_hidden_states=prompt_embeds,
1309
+ pooled_projections=pooled_prompt_embeds,
1310
+ block_controlnet_hidden_states=control_block_res_samples,
1311
+ return_dict=False,
1312
+ )[0]
1313
+
1314
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
1315
+ # Preconditioning of the model outputs.
1316
+ if args.precondition_outputs:
1317
+ model_pred = model_pred * (-sigmas) + noisy_model_input
1318
+
1319
+ # these weighting schemes use a uniform timestep sampling
1320
+ # and instead post-weight the loss
1321
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
1322
+
1323
+ # flow matching loss
1324
+ if args.precondition_outputs:
1325
+ target = model_input
1326
+ else:
1327
+ target = noise - model_input
1328
+
1329
+ # Compute regular loss.
1330
+ loss = torch.mean(
1331
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1332
+ 1,
1333
+ )
1334
+ loss = loss.mean()
1335
+
1336
+ accelerator.backward(loss)
1337
+ if accelerator.sync_gradients:
1338
+ params_to_clip = controlnet.parameters()
1339
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1340
+ optimizer.step()
1341
+ lr_scheduler.step()
1342
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
1343
+
1344
+ # Checks if the accelerator has performed an optimization step behind the scenes
1345
+ if accelerator.sync_gradients:
1346
+ progress_bar.update(1)
1347
+ global_step += 1
1348
+
1349
+ if accelerator.is_main_process:
1350
+ if global_step % args.checkpointing_steps == 0:
1351
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1352
+ if args.checkpoints_total_limit is not None:
1353
+ checkpoints = os.listdir(args.output_dir)
1354
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1355
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1356
+
1357
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1358
+ if len(checkpoints) >= args.checkpoints_total_limit:
1359
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1360
+ removing_checkpoints = checkpoints[0:num_to_remove]
1361
+
1362
+ logger.info(
1363
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1364
+ )
1365
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1366
+
1367
+ for removing_checkpoint in removing_checkpoints:
1368
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1369
+ shutil.rmtree(removing_checkpoint)
1370
+
1371
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1372
+ accelerator.save_state(save_path)
1373
+ logger.info(f"Saved state to {save_path}")
1374
+
1375
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1376
+ image_logs = log_validation(
1377
+ controlnet,
1378
+ args,
1379
+ accelerator,
1380
+ weight_dtype,
1381
+ global_step,
1382
+ )
1383
+
1384
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1385
+ progress_bar.set_postfix(**logs)
1386
+ accelerator.log(logs, step=global_step)
1387
+
1388
+ if global_step >= args.max_train_steps:
1389
+ break
1390
+
1391
+ # Create the pipeline using using the trained modules and save it.
1392
+ accelerator.wait_for_everyone()
1393
+ if accelerator.is_main_process:
1394
+ controlnet = unwrap_model(controlnet)
1395
+ controlnet.save_pretrained(args.output_dir)
1396
+
1397
+ # Run a final round of validation.
1398
+ image_logs = None
1399
+ if args.validation_prompt is not None:
1400
+ image_logs = log_validation(
1401
+ controlnet=None,
1402
+ args=args,
1403
+ accelerator=accelerator,
1404
+ weight_dtype=weight_dtype,
1405
+ step=global_step,
1406
+ is_final_validation=True,
1407
+ )
1408
+
1409
+ if args.push_to_hub:
1410
+ save_model_card(
1411
+ repo_id,
1412
+ image_logs=image_logs,
1413
+ base_model=args.pretrained_model_name_or_path,
1414
+ repo_folder=args.output_dir,
1415
+ )
1416
+ upload_folder(
1417
+ repo_id=repo_id,
1418
+ folder_path=args.output_dir,
1419
+ commit_message="End of training",
1420
+ ignore_patterns=["step_*", "epoch_*"],
1421
+ )
1422
+
1423
+ accelerator.end_training()
1424
+
1425
+
1426
+ if __name__ == "__main__":
1427
+ args = parse_args()
1428
+ main(args)
train_controlnet_sdxl.py ADDED
@@ -0,0 +1,1344 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import argparse
+ import functools
+ import gc
+ import logging
+ import math
+ import os
+ import random
+ import shutil
+ from contextlib import nullcontext
+ from pathlib import Path
+
+ import accelerate
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ import transformers
+ from accelerate import Accelerator
+ from accelerate.logging import get_logger
+ from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
+ from datasets import load_dataset
+ from huggingface_hub import create_repo, upload_folder
+ from packaging import version
+ from PIL import Image
+ from torchvision import transforms
+ from tqdm.auto import tqdm
+ from transformers import AutoTokenizer, PretrainedConfig
+
+ import diffusers
+ from diffusers import (
+     AutoencoderKL,
+     ControlNetModel,
+     DDPMScheduler,
+     StableDiffusionXLControlNetPipeline,
+     UNet2DConditionModel,
+     UniPCMultistepScheduler,
+ )
+ from diffusers.optimization import get_scheduler
+ from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+ from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
+ from diffusers.utils.torch_utils import is_compiled_module
+
+
+ if is_wandb_available():
+     import wandb
+
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+ check_min_version("0.33.0.dev0")
+
+ logger = get_logger(__name__)
+ if is_torch_npu_available():
+     torch.npu.config.allow_internal_format = False
+
+
+ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
+     logger.info("Running validation... ")
+
+     if not is_final_validation:
+         controlnet = accelerator.unwrap_model(controlnet)
+         pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
+             args.pretrained_model_name_or_path,
+             vae=vae,
+             unet=unet,
+             controlnet=controlnet,
+             revision=args.revision,
+             variant=args.variant,
+             torch_dtype=weight_dtype,
+         )
+     else:
+         controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
+         if args.pretrained_vae_model_name_or_path is not None:
+             vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path, torch_dtype=weight_dtype)
+         else:
+             vae = AutoencoderKL.from_pretrained(
+                 args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype
+             )
+
+         pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
+             args.pretrained_model_name_or_path,
+             vae=vae,
+             controlnet=controlnet,
+             revision=args.revision,
+             variant=args.variant,
+             torch_dtype=weight_dtype,
+         )
+
+     pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+     pipeline = pipeline.to(accelerator.device)
+     pipeline.set_progress_bar_config(disable=True)
+
+     if args.enable_xformers_memory_efficient_attention:
+         pipeline.enable_xformers_memory_efficient_attention()
+
+     if args.seed is None:
+         generator = None
+     else:
+         generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+     if len(args.validation_image) == len(args.validation_prompt):
+         validation_images = args.validation_image
+         validation_prompts = args.validation_prompt
+     elif len(args.validation_image) == 1:
+         validation_images = args.validation_image * len(args.validation_prompt)
+         validation_prompts = args.validation_prompt
+     elif len(args.validation_prompt) == 1:
+         validation_images = args.validation_image
+         validation_prompts = args.validation_prompt * len(args.validation_image)
+     else:
+         raise ValueError(
+             "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+         )
+
+     image_logs = []
+     if is_final_validation or torch.backends.mps.is_available():
+         autocast_ctx = nullcontext()
+     else:
+         autocast_ctx = torch.autocast(accelerator.device.type)
+
+     for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+         validation_image = Image.open(validation_image).convert("RGB")
+         validation_image = validation_image.resize((args.resolution, args.resolution))
+
+         images = []
+
+         for _ in range(args.num_validation_images):
+             with autocast_ctx:
+                 image = pipeline(
+                     prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
+                 ).images[0]
+             images.append(image)
+
+         image_logs.append(
+             {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+         )
+
+     tracker_key = "test" if is_final_validation else "validation"
+     for tracker in accelerator.trackers:
+         if tracker.name == "tensorboard":
+             for log in image_logs:
+                 images = log["images"]
+                 validation_prompt = log["validation_prompt"]
+                 validation_image = log["validation_image"]
+
+                 formatted_images = [np.asarray(validation_image)]
+
+                 for image in images:
+                     formatted_images.append(np.asarray(image))
+
+                 formatted_images = np.stack(formatted_images)
+
+                 tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+         elif tracker.name == "wandb":
+             formatted_images = []
+
+             for log in image_logs:
+                 images = log["images"]
+                 validation_prompt = log["validation_prompt"]
+                 validation_image = log["validation_image"]
+
+                 formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+
+                 for image in images:
+                     image = wandb.Image(image, caption=validation_prompt)
+                     formatted_images.append(image)
+
+             tracker.log({tracker_key: formatted_images})
+         else:
+             logger.warning(f"image logging not implemented for {tracker.name}")
+
+     del pipeline
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     return image_logs
+
+
+ def import_model_class_from_model_name_or_path(
+     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+ ):
+     text_encoder_config = PretrainedConfig.from_pretrained(
+         pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+     )
+     model_class = text_encoder_config.architectures[0]
+
+     if model_class == "CLIPTextModel":
+         from transformers import CLIPTextModel
+
+         return CLIPTextModel
+     elif model_class == "CLIPTextModelWithProjection":
+         from transformers import CLIPTextModelWithProjection
+
+         return CLIPTextModelWithProjection
+     else:
+         raise ValueError(f"{model_class} is not supported.")
+
+
+ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
+     img_str = ""
+     if image_logs is not None:
+         img_str = "You can find some example images below.\n\n"
+         for i, log in enumerate(image_logs):
+             images = log["images"]
+             validation_prompt = log["validation_prompt"]
+             validation_image = log["validation_image"]
+             validation_image.save(os.path.join(repo_folder, "image_control.png"))
+             img_str += f"prompt: {validation_prompt}\n"
+             images = [validation_image] + images
+             make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+             img_str += f"![images_{i})](./images_{i}.png)\n"
+
+     model_description = f"""
+ # controlnet-{repo_id}
+
+ These are controlnet weights trained on {base_model} with new type of conditioning.
+ {img_str}
+ """
+
+     model_card = load_or_create_model_card(
+         repo_id_or_path=repo_id,
+         from_training=True,
+         license="openrail++",
+         base_model=base_model,
+         model_description=model_description,
+         inference=True,
+     )
+
+     tags = [
+         "stable-diffusion-xl",
+         "stable-diffusion-xl-diffusers",
+         "text-to-image",
+         "diffusers",
+         "controlnet",
+         "diffusers-training",
+     ]
+     model_card = populate_model_card(model_card, tags=tags)
+
+     model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+ def parse_args(input_args=None):
+     parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
+     parser.add_argument(
+         "--pretrained_model_name_or_path",
+         type=str,
+         default=None,
+         required=True,
+         help="Path to pretrained model or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--pretrained_vae_model_name_or_path",
+         type=str,
+         default=None,
+         help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
+     )
+     parser.add_argument(
+         "--controlnet_model_name_or_path",
+         type=str,
+         default=None,
+         help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
+         " If not specified controlnet weights are initialized from unet.",
+     )
+     parser.add_argument(
+         "--variant",
+         type=str,
+         default=None,
+         help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+     )
+     parser.add_argument(
+         "--revision",
+         type=str,
+         default=None,
+         required=False,
+         help="Revision of pretrained model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--tokenizer_name",
+         type=str,
+         default=None,
+         help="Pretrained tokenizer name or path if not the same as model_name",
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="controlnet-model",
+         help="The output directory where the model predictions and checkpoints will be written.",
+     )
+     parser.add_argument(
+         "--cache_dir",
+         type=str,
+         default=None,
+         help="The directory where the downloaded models and datasets will be stored.",
+     )
+     parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+     parser.add_argument(
+         "--resolution",
+         type=int,
+         default=512,
+         help=(
+             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+             " resolution"
+         ),
+     )
+     parser.add_argument(
+         "--crops_coords_top_left_h",
+         type=int,
+         default=0,
+         help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+     )
+     parser.add_argument(
+         "--crops_coords_top_left_w",
+         type=int,
+         default=0,
+         help=("Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+     )
+     parser.add_argument(
+         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+     )
+     parser.add_argument("--num_train_epochs", type=int, default=1)
+     parser.add_argument(
+         "--max_train_steps",
+         type=int,
+         default=None,
+         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+     )
+     parser.add_argument(
+         "--checkpointing_steps",
+         type=int,
+         default=500,
+         help=(
+             "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+             "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
+             "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
+             "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+             " instructions."
+         ),
+     )
+     parser.add_argument(
+         "--checkpoints_total_limit",
+         type=int,
+         default=None,
+         help=("Max number of checkpoints to store."),
+     )
+     parser.add_argument(
+         "--resume_from_checkpoint",
+         type=str,
+         default=None,
+         help=(
+             "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+         ),
+     )
+     parser.add_argument(
+         "--gradient_accumulation_steps",
+         type=int,
+         default=1,
+         help="Number of update steps to accumulate before performing a backward/update pass.",
+     )
+     parser.add_argument(
+         "--gradient_checkpointing",
+         action="store_true",
+         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+     )
+     parser.add_argument(
+         "--learning_rate",
+         type=float,
+         default=5e-6,
+         help="Initial learning rate (after the potential warmup period) to use.",
+     )
+     parser.add_argument(
+         "--scale_lr",
+         action="store_true",
+         default=False,
+         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+     )
+     parser.add_argument(
+         "--lr_scheduler",
+         type=str,
+         default="constant",
+         help=(
+             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+             ' "constant", "constant_with_warmup"]'
+         ),
+     )
+     parser.add_argument(
+         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+     )
+     parser.add_argument(
+         "--lr_num_cycles",
+         type=int,
+         default=1,
+         help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+     )
+     parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+     parser.add_argument(
+         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+     )
+     parser.add_argument(
+         "--dataloader_num_workers",
+         type=int,
+         default=0,
+         help=(
+             "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+         ),
+     )
+     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+     parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+     parser.add_argument(
+         "--hub_model_id",
+         type=str,
+         default=None,
+         help="The name of the repository to keep in sync with the local `output_dir`.",
+     )
+     parser.add_argument(
+         "--logging_dir",
+         type=str,
+         default="logs",
+         help=(
+             "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+             " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+         ),
+     )
+     parser.add_argument(
+         "--allow_tf32",
+         action="store_true",
+         help=(
+             "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+             " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+         ),
+     )
+     parser.add_argument(
+         "--report_to",
+         type=str,
+         default="tensorboard",
+         help=(
+             'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+             ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+         ),
+     )
+     parser.add_argument(
+         "--mixed_precision",
+         type=str,
+         default=None,
+         choices=["no", "fp16", "bf16"],
+         help=(
+             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+             " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+         ),
+     )
+     parser.add_argument(
+         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+     )
+     parser.add_argument(
+         "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+     )
+     parser.add_argument(
+         "--set_grads_to_none",
+         action="store_true",
+         help=(
+             "Save more memory by setting grads to None instead of zero. Be aware, that this changes certain"
+             " behaviors, so disable this argument if it causes any problems. More info:"
+             " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+         ),
+     )
+     parser.add_argument(
+         "--dataset_name",
+         type=str,
+         default=None,
+         help=(
+             "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+             " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+             " or to a folder containing files that 🤗 Datasets can understand."
+         ),
+     )
+     parser.add_argument(
+         "--dataset_config_name",
+         type=str,
+         default=None,
+         help="The config of the Dataset, leave as None if there's only one config.",
+     )
+     parser.add_argument(
+         "--train_data_dir",
+         type=str,
+         default=None,
+         help=(
+             "A folder containing the training data. Folder contents must follow the structure described in"
+             " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+             " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+         ),
+     )
+     parser.add_argument(
+         "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+     )
+     parser.add_argument(
+         "--conditioning_image_column",
+         type=str,
+         default="conditioning_image",
+         help="The column of the dataset containing the controlnet conditioning image.",
+     )
+     parser.add_argument(
+         "--caption_column",
+         type=str,
+         default="text",
+         help="The column of the dataset containing a caption or a list of captions.",
+     )
+     parser.add_argument(
+         "--max_train_samples",
+         type=int,
+         default=None,
+         help=(
+             "For debugging purposes or quicker training, truncate the number of training examples to this "
+             "value if set."
+         ),
+     )
+     parser.add_argument(
+         "--proportion_empty_prompts",
+         type=float,
+         default=0,
+         help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+     )
+     parser.add_argument(
+         "--validation_prompt",
+         type=str,
+         default=None,
+         nargs="+",
+         help=(
+             "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+             " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+             " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+         ),
+     )
+     parser.add_argument(
+         "--validation_image",
+         type=str,
+         default=None,
+         nargs="+",
+         help=(
+             "A set of paths to the controlnet conditioning image to be evaluated every `--validation_steps`"
560
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
561
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
562
+ " `--validation_image` that will be used with all `--validation_prompt`s."
563
+ ),
564
+ )
565
+ parser.add_argument(
566
+ "--num_validation_images",
567
+ type=int,
568
+ default=4,
569
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
570
+ )
571
+ parser.add_argument(
572
+ "--validation_steps",
573
+ type=int,
574
+ default=100,
575
+ help=(
576
+ "Run validation every X steps. Validation consists of running the prompt"
577
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
578
+ " and logging the images."
579
+ ),
580
+ )
581
+ parser.add_argument(
582
+ "--tracker_project_name",
583
+ type=str,
584
+ default="sd_xl_train_controlnet",
585
+ help=(
586
+ "The `project_name` argument passed to Accelerator.init_trackers for"
587
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
588
+ ),
589
+ )
590
+
+    if input_args is not None:
+        args = parser.parse_args(input_args)
+    else:
+        args = parser.parse_args()
+
+    if args.dataset_name is None and args.train_data_dir is None:
+        raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
+
+    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+        raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+    if args.validation_prompt is not None and args.validation_image is None:
+        raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+    if args.validation_prompt is None and args.validation_image is not None:
+        raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+    if (
+        args.validation_image is not None
+        and args.validation_prompt is not None
+        and len(args.validation_image) != 1
+        and len(args.validation_prompt) != 1
+        and len(args.validation_image) != len(args.validation_prompt)
+    ):
+        raise ValueError(
+            "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+            " or the same number of `--validation_prompt`s and `--validation_image`s"
+        )
+
+    if args.resolution % 8 != 0:
+        raise ValueError(
+            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+        )
+
+    return args
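The validation-argument checks above allow three configurations: matching lists of prompts and images, one prompt broadcast over many images, or one image broadcast over many prompts. A minimal sketch of that pairing rule (the function name `validation_pair_ok` is illustrative, not part of the script):

```python
# Sketch of the pairing rule enforced in parse_args: prompts and images must
# both be set or both be unset, and their counts must match unless one of the
# two lists has exactly one entry (which is then broadcast over the other).
def validation_pair_ok(prompts, images):
    if (prompts is None) != (images is None):
        return False  # both must be provided together
    if prompts is None:
        return True  # neither is set: no validation requested
    return len(prompts) == 1 or len(images) == 1 or len(prompts) == len(images)
```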
+
+
+def get_train_dataset(args, accelerator):
+    # Get the datasets: you can either provide your own training and evaluation files (see below)
+    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+    # download the dataset.
+    if args.dataset_name is not None:
+        # Downloading and loading a dataset from the hub.
+        dataset = load_dataset(
+            args.dataset_name,
+            args.dataset_config_name,
+            cache_dir=args.cache_dir,
+            data_dir=args.train_data_dir,
+        )
+    else:
+        if args.train_data_dir is not None:
+            dataset = load_dataset(
+                args.train_data_dir,
+                cache_dir=args.cache_dir,
+            )
+        # See more about loading custom images at
+        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+
+    # Preprocessing the datasets.
+    # We need to tokenize inputs and targets.
+    column_names = dataset["train"].column_names
+
+    # 6. Get the column names for input/target.
+    if args.image_column is None:
+        image_column = column_names[0]
+        logger.info(f"image column defaulting to {image_column}")
+    else:
+        image_column = args.image_column
+        if image_column not in column_names:
+            raise ValueError(
+                f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+            )
+
+    if args.caption_column is None:
+        caption_column = column_names[1]
+        logger.info(f"caption column defaulting to {caption_column}")
+    else:
+        caption_column = args.caption_column
+        if caption_column not in column_names:
+            raise ValueError(
+                f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+            )
+
+    if args.conditioning_image_column is None:
+        conditioning_image_column = column_names[2]
+        logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+    else:
+        conditioning_image_column = args.conditioning_image_column
+        if conditioning_image_column not in column_names:
+            raise ValueError(
+                f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+            )
+
+    with accelerator.main_process_first():
+        train_dataset = dataset["train"].shuffle(seed=args.seed)
+        if args.max_train_samples is not None:
+            train_dataset = train_dataset.select(range(args.max_train_samples))
+    return train_dataset
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
+    prompt_embeds_list = []
+
+    captions = []
+    for caption in prompt_batch:
+        if random.random() < proportion_empty_prompts:
+            captions.append("")
+        elif isinstance(caption, str):
+            captions.append(caption)
+        elif isinstance(caption, (list, np.ndarray)):
+            # take a random caption if there are multiple
+            captions.append(random.choice(caption) if is_train else caption[0])
+
+    with torch.no_grad():
+        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+            text_inputs = tokenizer(
+                captions,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            prompt_embeds = text_encoder(
+                text_input_ids.to(text_encoder.device),
+                output_hidden_states=True,
+            )
+
+            # We are only interested in the pooled output of the final text encoder.
+            pooled_prompt_embeds = prompt_embeds[0]
+            prompt_embeds = prompt_embeds.hidden_states[-2]
+            bs_embed, seq_len, _ = prompt_embeds.shape
+            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+            prompt_embeds_list.append(prompt_embeds)
+
+    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+    return prompt_embeds, pooled_prompt_embeds
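The caption handling in `encode_prompt` above combines caption dropout (for classifier-free guidance training) with multi-caption sampling. A pure-Python sketch of just that selection logic, with an injectable `rng` for reproducibility (the helper name `select_captions` is illustrative):

```python
import random

# Sketch of the caption selection in encode_prompt: each caption is dropped
# (replaced by "") with probability proportion_empty_prompts; list-valued
# captions yield a random choice at train time and the first entry otherwise.
def select_captions(prompt_batch, proportion_empty_prompts, is_train=True, rng=random):
    captions = []
    for caption in prompt_batch:
        if rng.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, tuple)):
            captions.append(rng.choice(caption) if is_train else caption[0])
    return captions
```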
+
+
+def prepare_train_dataset(dataset, accelerator):
+    image_transforms = transforms.Compose(
+        [
+            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
+        ]
+    )
+
+    conditioning_image_transforms = transforms.Compose(
+        [
+            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
+            transforms.ToTensor(),
+        ]
+    )
+
+    def preprocess_train(examples):
+        images = [image.convert("RGB") for image in examples[args.image_column]]
+        images = [image_transforms(image) for image in images]
+
+        conditioning_images = [image.convert("RGB") for image in examples[args.conditioning_image_column]]
+        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
+
+        examples["pixel_values"] = images
+        examples["conditioning_pixel_values"] = conditioning_images
+
+        return examples
+
+    with accelerator.main_process_first():
+        dataset = dataset.with_transform(preprocess_train)
+
+    return dataset
+
+
+def collate_fn(examples):
+    pixel_values = torch.stack([example["pixel_values"] for example in examples])
+    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+    conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+    prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
+
+    add_text_embeds = torch.stack([torch.tensor(example["text_embeds"]) for example in examples])
+    add_time_ids = torch.stack([torch.tensor(example["time_ids"]) for example in examples])
+
+    return {
+        "pixel_values": pixel_values,
+        "conditioning_pixel_values": conditioning_pixel_values,
+        "prompt_ids": prompt_ids,
+        "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids},
+    }
+
+
+def main(args):
+    if args.report_to == "wandb" and args.hub_token is not None:
+        raise ValueError(
+            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+            " Please use `huggingface-cli login` to authenticate with the Hub."
+        )
+
+    logging_dir = Path(args.output_dir, args.logging_dir)
+
+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+    accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision,
+        log_with=args.report_to,
+        project_config=accelerator_project_config,
+    )
+
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    # If passed along, set the training seed now.
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    # Handle the repository creation
+    if accelerator.is_main_process:
+        if args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
+
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
+    # Load the tokenizers
+    tokenizer_one = AutoTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=args.revision,
+        use_fast=False,
+    )
+    tokenizer_two = AutoTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=args.revision,
+        use_fast=False,
+    )
+
+    # import correct text encoder classes
+    text_encoder_cls_one = import_model_class_from_model_name_or_path(
+        args.pretrained_model_name_or_path, args.revision
+    )
+    text_encoder_cls_two = import_model_class_from_model_name_or_path(
+        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+    )
+
+    # Load scheduler and models
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    text_encoder_one = text_encoder_cls_one.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+    )
+    text_encoder_two = text_encoder_cls_two.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+    )
+    vae_path = (
+        args.pretrained_model_name_or_path
+        if args.pretrained_vae_model_name_or_path is None
+        else args.pretrained_vae_model_name_or_path
+    )
+    vae = AutoencoderKL.from_pretrained(
+        vae_path,
+        subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+        revision=args.revision,
+        variant=args.variant,
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+    )
+
+    if args.controlnet_model_name_or_path:
+        logger.info("Loading existing controlnet weights")
+        controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
+    else:
+        logger.info("Initializing controlnet weights from unet")
+        controlnet = ControlNetModel.from_unet(unet)
+
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            if accelerator.is_main_process:
+                i = len(weights) - 1
+
+                while len(weights) > 0:
+                    weights.pop()
+                    model = models[i]
+
+                    sub_dir = "controlnet"
+                    model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+                    i -= 1
+
+        def load_model_hook(models, input_dir):
+            while len(models) > 0:
+                # pop models so that they are not loaded again
+                model = models.pop()
+
+                # load diffusers style into model
+                load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
+                model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
+    text_encoder_one.requires_grad_(False)
+    text_encoder_two.requires_grad_(False)
+    controlnet.train()
+
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            unet.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
+
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            import xformers
+
+            xformers_version = version.parse(xformers.__version__)
+            if xformers_version == version.parse("0.0.16"):
+                logger.warning(
+                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                )
+            unet.enable_xformers_memory_efficient_attention()
+            controlnet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+    if args.gradient_checkpointing:
+        controlnet.enable_gradient_checkpointing()
+        unet.enable_gradient_checkpointing()
+
+    # Check that all trainable models are in full precision
+    low_precision_error_string = (
+        " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+        " doing mixed precision training, a copy of the weights should still be float32."
+    )
+
+    if unwrap_model(controlnet).dtype != torch.float32:
+        raise ValueError(
+            f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+        )
+
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+        )
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+            )
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    # Optimizer creation
+    params_to_optimize = controlnet.parameters()
+    optimizer = optimizer_class(
+        params_to_optimize,
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+    )
+
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move vae, unet and text_encoder to device and cast to weight_dtype.
+    # If no separate VAE is provided, keep the VAE in float32 to avoid NaN losses.
+    if args.pretrained_vae_model_name_or_path is not None:
+        vae.to(accelerator.device, dtype=weight_dtype)
+    else:
+        vae.to(accelerator.device, dtype=torch.float32)
+    unet.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+    # Here, we compute not just the text embeddings but also the additional embeddings
+    # needed for the SD XL UNet to operate.
+    def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True):
+        original_size = (args.resolution, args.resolution)
+        target_size = (args.resolution, args.resolution)
+        crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
+        prompt_batch = batch[args.caption_column]
+
+        prompt_embeds, pooled_prompt_embeds = encode_prompt(
+            prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
+        )
+        add_text_embeds = pooled_prompt_embeds
+
+        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+        add_time_ids = list(original_size + crops_coords_top_left + target_size)
+        add_time_ids = torch.tensor([add_time_ids])
+
+        prompt_embeds = prompt_embeds.to(accelerator.device)
+        add_text_embeds = add_text_embeds.to(accelerator.device)
+        add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
+        add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+        unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+        return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
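The `add_time_ids` built in `compute_embeddings` above are SDXL's micro-conditioning: six integers per sample, formed by concatenating original size, crop coordinates, and target size. A dependency-free sketch of just that construction (the helper name `make_add_time_ids` is illustrative):

```python
# Sketch of the SDXL time-id construction in compute_embeddings: concatenate
# (original_h, original_w, crop_top, crop_left, target_h, target_w) into one
# flat list of six integers, which the script then turns into a tensor.
def make_add_time_ids(resolution, crop_top, crop_left):
    original_size = (resolution, resolution)
    target_size = (resolution, resolution)
    crops_coords_top_left = (crop_top, crop_left)
    return list(original_size + crops_coords_top_left + target_size)
```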
+
+    # Let's first compute all the embeddings so that we can free up the text encoders
+    # from memory.
+    text_encoders = [text_encoder_one, text_encoder_two]
+    tokenizers = [tokenizer_one, tokenizer_two]
+    train_dataset = get_train_dataset(args, accelerator)
+    compute_embeddings_fn = functools.partial(
+        compute_embeddings,
+        text_encoders=text_encoders,
+        tokenizers=tokenizers,
+        proportion_empty_prompts=args.proportion_empty_prompts,
+    )
+    with accelerator.main_process_first():
+        from datasets.fingerprint import Hasher
+
+        # fingerprint used by the cache for the other processes to load the result
+        # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
+        new_fingerprint = Hasher.hash(args)
+        train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
+
+    del text_encoders, tokenizers
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # Then get the training dataset ready to be passed to the dataloader.
+    train_dataset = prepare_train_dataset(train_dataset, accelerator)
+
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        shuffle=True,
+        collate_fn=collate_fn,
+        batch_size=args.train_batch_size,
+        num_workers=args.dataloader_num_workers,
+    )
+
+    # Scheduler and math around the number of training steps.
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+    if args.max_train_steps is None:
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
+        num_cycles=args.lr_num_cycles,
+        power=args.lr_power,
+    )
+
+    # Prepare everything with our `accelerator`.
+    controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        controlnet, optimizer, train_dataloader, lr_scheduler
+    )
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
+    # Afterwards we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
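The step bookkeeping above reduces to two ceiling divisions: optimizer updates per epoch come from the dataloader length and the gradient-accumulation factor, and the epoch count is derived back from `max_train_steps`. A minimal sketch of that arithmetic (the function name `training_schedule` is illustrative):

```python
import math

# Sketch of the recalculation above: updates per epoch are the number of
# batches divided by the accumulation factor (rounded up); the epoch count
# is max_train_steps divided by updates per epoch (rounded up).
def training_schedule(num_batches, grad_accum_steps, max_train_steps):
    updates_per_epoch = math.ceil(num_batches / grad_accum_steps)
    num_epochs = math.ceil(max_train_steps / updates_per_epoch)
    return updates_per_epoch, num_epochs
```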
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers are initialized automatically on the main process.
+    if accelerator.is_main_process:
+        tracker_config = dict(vars(args))
+
+        # tensorboard cannot handle list types for config
+        tracker_config.pop("validation_prompt")
+        tracker_config.pop("validation_image")
+
+        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+    # Train!
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Potentially load in the weights and states from a previous save
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1] if len(dirs) > 0 else None
+
+        if path is None:
+            accelerator.print(
+                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
+            args.resume_from_checkpoint = None
+            initial_global_step = 0
+        else:
+            accelerator.print(f"Resuming from checkpoint {path}")
+            accelerator.load_state(os.path.join(args.output_dir, path))
+            global_step = int(path.split("-")[1])
+
+            initial_global_step = global_step
+            first_epoch = global_step // num_update_steps_per_epoch
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
+
+    image_logs = None
+    for epoch in range(first_epoch, args.num_train_epochs):
+        for step, batch in enumerate(train_dataloader):
+            with accelerator.accumulate(controlnet):
+                # Convert images to latent space
+                if args.pretrained_vae_model_name_or_path is not None:
+                    pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+                else:
+                    pixel_values = batch["pixel_values"]
+                latents = vae.encode(pixel_values).latent_dist.sample()
+                latents = latents * vae.config.scaling_factor
+                if args.pretrained_vae_model_name_or_path is None:
+                    latents = latents.to(weight_dtype)
+
+                # Sample noise that we'll add to the latents
+                noise = torch.randn_like(latents)
+                bsz = latents.shape[0]
+
+                # Sample a random timestep for each image
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = timesteps.long()
+
+                # Add noise to the latents according to the noise magnitude at each timestep
+                # (this is the forward diffusion process)
+                noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to(
+                    dtype=weight_dtype
+                )
+
+                # ControlNet conditioning.
+                controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    noisy_latents,
+                    timesteps,
+                    encoder_hidden_states=batch["prompt_ids"],
+                    added_cond_kwargs=batch["unet_added_conditions"],
+                    controlnet_cond=controlnet_image,
+                    return_dict=False,
+                )
+
+                # Predict the noise residual
+                model_pred = unet(
+                    noisy_latents,
+                    timesteps,
+                    encoder_hidden_states=batch["prompt_ids"],
+                    added_cond_kwargs=batch["unet_added_conditions"],
+                    down_block_additional_residuals=[
+                        sample.to(dtype=weight_dtype) for sample in down_block_res_samples
+                    ],
+                    mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
+                    return_dict=False,
+                )[0]
+
+                # Get the target for loss depending on the prediction type
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
1246
+ accelerator.backward(loss)
1247
+ if accelerator.sync_gradients:
1248
+ params_to_clip = controlnet.parameters()
1249
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1250
+ optimizer.step()
1251
+ lr_scheduler.step()
1252
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
1253
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+     progress_bar.update(1)
+     global_step += 1
+
+     # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+     if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
+         if global_step % args.checkpointing_steps == 0:
+             # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+             if args.checkpoints_total_limit is not None:
+                 checkpoints = os.listdir(args.output_dir)
+                 checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                 checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                 # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                 if len(checkpoints) >= args.checkpoints_total_limit:
+                     num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                     removing_checkpoints = checkpoints[0:num_to_remove]
+
+                     logger.info(
+                         f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                     )
+                     logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                     for removing_checkpoint in removing_checkpoints:
+                         removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                         shutil.rmtree(removing_checkpoint)
+
+             save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+             accelerator.save_state(save_path)
+             logger.info(f"Saved state to {save_path}")
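The pruning logic above keeps at most `checkpoints_total_limit - 1` old checkpoints before saving a new one, deleting the oldest first. Pulled out as a standalone helper (the function name is illustrative, not part of the script), the selection step looks like:

```python
def checkpoints_to_remove(existing_dirs, total_limit):
    # Sort "checkpoint-<step>" dirs by step and return the oldest ones that
    # must be deleted so that, after the upcoming save, at most `total_limit`
    # checkpoints remain on disk.
    ckpts = sorted(
        (d for d in existing_dirs if d.startswith("checkpoint")),
        key=lambda x: int(x.split("-")[1]),
    )
    if total_limit is None or len(ckpts) < total_limit:
        return []
    return ckpts[: len(ckpts) - total_limit + 1]
```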
+
+         if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+             image_logs = log_validation(
+                 vae=vae,
+                 unet=unet,
+                 controlnet=controlnet,
+                 args=args,
+                 accelerator=accelerator,
+                 weight_dtype=weight_dtype,
+                 step=global_step,
+             )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+     break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+     controlnet = unwrap_model(controlnet)
+     controlnet.save_pretrained(args.output_dir)
+
+     # Run a final round of validation.
+     # Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`.
+     image_logs = None
+     if args.validation_prompt is not None:
+         image_logs = log_validation(
+             vae=None,
+             unet=None,
+             controlnet=None,
+             args=args,
+             accelerator=accelerator,
+             weight_dtype=weight_dtype,
+             step=global_step,
+             is_final_validation=True,
+         )
+
+     if args.push_to_hub:
+         save_model_card(
+             repo_id,
+             image_logs=image_logs,
+             base_model=args.pretrained_model_name_or_path,
+             repo_folder=args.output_dir,
+         )
+         upload_folder(
+             repo_id=repo_id,
+             folder_path=args.output_dir,
+             commit_message="End of training",
+             ignore_patterns=["step_*", "epoch_*"],
+         )
+
+ accelerator.end_training()
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     main(args)
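The model card's "How to use" section is still a TODO. A minimal inference sketch for the weights this script saves might look like the following — the repo id `brity/controlnet-2210` is inferred from the model card title, the conditioning-image path is hypothetical, and the base model is the SD 1.5 checkpoint named in the card:

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
from diffusers.utils import load_image

# Load the trained ControlNet weights and attach them to the base pipeline.
controlnet = ControlNetModel.from_pretrained("brity/controlnet-2210", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

control_image = load_image("./conditioning.png")  # hypothetical conditioning image
image = pipe(
    "red circle with blue background",  # one of the card's validation prompts
    image=control_image,
    num_inference_steps=30,
).images[0]
image.save("output.png")
```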