Commit 302c272

Enabled 12GB VRAM training via optional activation checkpointing
1 parent 5204e0a commit 302c272

2 files changed (+12 −0): README.md, attribute_control/model/model.py


Diff for: README.md (+6)
@@ -47,6 +47,12 @@ python learn_delta.py device=cuda:0 model=sdxl prompts=people/age
 ```
 This will save the delta at `outputs/learn_delta/people/age/runs/<date>/<time>/checkpoints/delta.pt`, which you can then directly use as shown in the example notebooks.
 
+This will typically require slightly more than 24GB of VRAM for training (26GB when training on an A100 as of June 13th 2024, although this will likely change with newer versions of diffusers and PyTorch). If you want to train on smaller hardware, you can enable gradient checkpointing (typically called activation checkpointing, but we'll stick to diffusers terminology here) by launching the training as
+```shell
+python learn_delta.py device=cuda:0 model=sdxl prompts=people/age model.compile=False +model.gradient_checkpointing=True
+```
+In our experiments, this enabled training deltas with an 11.5GB VRAM budget, at the cost of slower training.
+
 #### Naive CLIP Difference Method
 
 The simplest method to obtain deltas is the naive CLIP difference-based method. With it, you can obtain a delta in a few seconds on a decent GPU. It is substantially worse than the proper learned method though.
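As the model.py diff below shows, this new override maps to diffusers' built-in checkpointing switch on the U-Net. A minimal standalone sketch of that mechanism follows; the model id and dtype are illustrative assumptions, not taken from this repository:

```python
# Minimal sketch of what the flag enables under the hood; not the repository's training code.
import torch
from diffusers import StableDiffusionXLPipeline

# Assumed model id and dtype, purely for illustration.
pipe = StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16
).to('cuda:0')

# Recompute intermediate U-Net activations during the backward pass instead of
# storing them, trading extra compute for a much smaller activation footprint.
pipe.unet.enable_gradient_checkpointing()

# Combining this with a compiled U-Net is typically problematic, which is why the
# command above also passes model.compile=False.
```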

Diff for: attribute_control/model/model.py (+6)
@@ -3,6 +3,7 @@
 from abc import ABC, abstractmethod, abstractproperty
 from typing import Union, Tuple, Dict, Optional, List, Any
 from pydoc import locate
+import warnings
 
 import torch
 from torch import nn
@@ -183,6 +184,7 @@ def __init__(
         pipe_kwargs: dict = { },
         device: Union[str, torch.device] = 'cuda:0',
         compile: bool = False,
+        gradient_checkpointing: bool = False,
     ) -> None:
         super().__init__(pipeline_type=pipeline_type, model_name=model_name, num_inference_steps=num_inference_steps, pipe_kwargs=pipe_kwargs, device=device, compile=compile)
 
@@ -191,6 +193,10 @@ def __init__(
         d_v_major, d_v_minor, *_ = diffusers.__version__.split('.')
         if int(d_v_major) > 0 or int(d_v_minor) >= 25:
             self.pipe.fuse_qkv_projections()
+        if gradient_checkpointing:
+            if compile:
+                warnings.warn('Gradient checkpointing is typically not compatible with compiling the U-Net. This will likely lead to a crash.')
+            self.pipe.unet.enable_gradient_checkpointing()
         if compile:
             assert int(d_v_major) > 0 or int(d_v_minor) >= 25, 'Use at least diffusers==0.25 to enable proper functionality of torch.compile().'
             self.pipe.unet.to(memory_format=torch.channels_last)
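To sanity-check the memory effect of the new flag, one could compare peak allocated VRAM around a single forward/backward pass with and without checkpointing. A rough sketch, where `training_step` is a hypothetical callable and the quoted numbers come from the README text above:

```python
# Rough sketch for comparing peak VRAM with and without gradient checkpointing.
# `training_step` is a hypothetical callable that runs one forward/backward pass.
import torch

def peak_vram_gb(training_step, device: str = 'cuda:0') -> float:
    torch.cuda.reset_peak_memory_stats(device)
    training_step()
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device) / 1024 ** 3

# Example (numbers from the README above; they depend on hardware and library versions):
# peak_vram_gb(step_without_checkpointing)  # ~26 GB on an A100
# peak_vram_gb(step_with_checkpointing)     # ~11.5 GB
```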
