38 changes: 38 additions & 0 deletions configs/inference/lora_c_1b_bfloat16.yaml
@@ -0,0 +1,38 @@
# GLOBAL STUFF
experiment_id: stage_c_1b_lora
checkpoint_path: ~/cascade/chk
output_path: ~/cascade/lora_sample
model_version: 1B

# TRAINING PARAMS
lr: 1.0e-4
batch_size: 40
image_size: 768
multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
grad_accum_steps: 4
updates: 10000
backup_every: 1000
save_every: 100
warmup_updates: 1
# use_fsdp: True -> FSDP doesn't work at the moment for LoRA
use_fsdp: False

# GDF
# adaptive_loss_weight: True

# LoRA specific. 'No Defect Train Railcar Wheel'
module_filters: ['.attn']
rank: 4
train_tokens:
# - ['^snail', null] # tokens starting with "snail" ("snail" & "snails") are trained but not reinitialized
- ['[fernando]', '^dog</w>'] # custom token [fernando], initialized as the average of tokens matching '^dog</w>'


# ema_start_iters: 5000
# ema_iters: 100
# ema_beta: 0.9

webdataset_path: file:/home/asutermo/cascade/data/dataset.tar
effnet_checkpoint_path: models/effnet_encoder.safetensors
previewer_checkpoint_path: models/previewer.safetensors
generator_checkpoint_path: models/stage_c_lite_bf16.safetensors
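Since the config references several local checkpoint files, a quick sanity check before launching can help; the sketch below is a hypothetical helper, not part of this PR, and assumes only PyYAML:

import os
import yaml

# Load the new LoRA config and confirm the referenced checkpoints are present.
with open("configs/inference/lora_c_1b_bfloat16.yaml") as f:
    cfg = yaml.safe_load(f)

for key in ("effnet_checkpoint_path", "previewer_checkpoint_path", "generator_checkpoint_path"):
    path = os.path.expanduser(cfg[key])
    print(f"{key}: {path} -> {'found' if os.path.exists(path) else 'MISSING'}")

The config itself is passed to the training script as its first CLI argument, as the entry-point hunks further down show.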
5 changes: 4 additions & 1 deletion core/__init__.py
@@ -34,6 +34,8 @@ class Config(Base):
wandb_project: str = None
wandb_entity: str = None

single_gpu: bool = False

@dataclass() # not frozen, means that fields are mutable
class Info(): # not inheriting from Base, because we don't want to enforce the default fields
wandb_run_id: str = None
@@ -141,6 +143,7 @@ def setup_config(self, config_file_path=None, config_dict=None, training=True) -
return self.Config(training=training)

def setup_ddp(self, experiment_id, single_gpu=False):
self.single_gpu = single_gpu
if not single_gpu:
local_rank = int(os.environ.get("SLURM_LOCALID"))
process_id = int(os.environ.get("SLURM_PROCID"))
@@ -297,7 +300,7 @@ def __call__(self, single_gpu=False):

if self.is_main_node:
print()
print("**STARTIG JOB WITH CONFIG:**")
print("**STARTING JOB WITH CONFIG:**")
print(yaml.dump(self.config.to_dict(), default_flow_style=False))
print("------------------------------------")
print()
3 changes: 2 additions & 1 deletion core/templates/diffusion.py
@@ -218,7 +218,8 @@ def models_to_save(self):
return ['generator', 'generator_ema']

def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
barrier()
if not self.single_gpu:
barrier()
suffix = '' if suffix is None else suffix
self.save_info(self.info, suffix=suffix)
models_dict = models.to_dict()
3 changes: 2 additions & 1 deletion train/__init__.py
@@ -1,4 +1,5 @@
from .train_b import WurstCore as WurstCoreB
from .train_c import WurstCore as WurstCoreC
from .train_c_controlnet import WurstCore as ControlNetCore
from .train_c_lora import WurstCore as LoraCore
from .train_c_lora import WurstCore as LoraCore

3 changes: 2 additions & 1 deletion train/base.py
@@ -310,7 +310,8 @@ def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, op
self.sample(models, data, extras)

def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
barrier()
if not self.single_gpu:
barrier()
suffix = '' if suffix is None else suffix
self.save_info(self.info, suffix=suffix)
models_dict = models.to_dict()
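For context, a minimal sketch of the guard that the two save_checkpoints hunks above introduce, assuming barrier() wraps torch.distributed.barrier(); the helper below is illustrative only, not part of the PR, and additionally checks whether a process group exists:

import torch.distributed as dist

def maybe_barrier(single_gpu: bool) -> None:
    # With a single GPU no process group is initialized, so dist.barrier()
    # would raise; skip the collective in that case.
    if single_gpu or not dist.is_available() or not dist.is_initialized():
        return
    dist.barrier()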
6 changes: 4 additions & 2 deletions train/train_b.py
@@ -294,11 +294,13 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext

if __name__ == '__main__':
print("Launching Script")
device = torch.device(int(os.environ.get('SLURM_LOCALID')) if 'SLURM_LOCALID' in os.environ else "cuda" if torch.cuda.is_available() else "cpu")
warpcore = WurstCore(
config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
device=torch.device(int(os.environ.get("SLURM_LOCALID")))
device=device
)
# core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD

# RUN TRAINING
warpcore()
use_single_gpu = torch.cuda.device_count() == 1
warpcore(single_gpu=use_single_gpu)
8 changes: 6 additions & 2 deletions train/train_c.py
@@ -256,11 +256,15 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext

if __name__ == '__main__':
print("Launching Script")

device = torch.device(int(os.environ.get('SLURM_LOCALID')) if 'SLURM_LOCALID' in os.environ else "cuda" if torch.cuda.is_available() else "cpu")
warpcore = WurstCore(
config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
device=torch.device(int(os.environ.get("SLURM_LOCALID")))
device=device
)
# core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD

# RUN TRAINING
warpcore()
use_single_gpu = torch.cuda.device_count() == 1
warpcore(single_gpu=use_single_gpu)

7 changes: 5 additions & 2 deletions train/train_c_controlnet.py
@@ -372,11 +372,14 @@ def sample(self, models: Models, data: WarpCore.Data, extras: Extras):

if __name__ == '__main__':
print("Launching Script")
device = torch.device(int(os.environ.get('SLURM_LOCALID')) if 'SLURM_LOCALID' in os.environ else "cuda" if torch.cuda.is_available() else "cpu")
warpcore = WurstCore(
config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
device=torch.device(int(os.environ.get("SLURM_LOCALID")))
device=device
)
warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD

# RUN TRAINING
warpcore()
use_single_gpu = torch.cuda.device_count() == 1
warpcore(single_gpu=use_single_gpu)

9 changes: 5 additions & 4 deletions train/train_c_lora.py
@@ -166,7 +166,6 @@ def dummy_context():
yield None

loading_context = dummy_context if self.config.training else init_empty_weights

with loading_context():
# Diffusion models
if self.config.model_version == '3.6B':
@@ -252,7 +251,7 @@ def dummy_context():
)

def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
optimizer = optim.AdamW(models.lora.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
optimizer = self.load_optimizer(optimizer, 'lora_optim',
fsdp_model=models.lora if self.config.use_fsdp else None)
return self.Optimizers(generator=None, lora=optimizer)
@@ -320,11 +319,13 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext

if __name__ == '__main__':
print("Launching Script")
device = torch.device(int(os.environ.get('SLURM_LOCALID')) if 'SLURM_LOCALID' in os.environ else "cuda" if torch.cuda.is_available() else "cpu")
warpcore = WurstCore(
config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
device=torch.device(int(os.environ.get("SLURM_LOCALID")))
device=device
)
warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD

# RUN TRAINING
warpcore()
use_single_gpu = torch.cuda.device_count() == 1
warpcore(single_gpu=use_single_gpu)
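The device-selection and single-GPU detection logic added here is duplicated across all four entry points (train_b, train_c, train_c_controlnet, train_c_lora); a shared helper, purely illustrative and not part of this PR, could capture it once:

import os
import torch

def resolve_device() -> torch.device:
    # Prefer the SLURM-assigned local rank, then any CUDA device, then CPU.
    if 'SLURM_LOCALID' in os.environ:
        return torch.device(int(os.environ['SLURM_LOCALID']))
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def single_gpu_run() -> bool:
    # Same heuristic as above: exactly one visible CUDA device.
    return torch.cuda.device_count() == 1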