6 changes: 5 additions & 1 deletion .pre-commit-config.yaml
@@ -1,7 +1,11 @@
ci:
autofix_prs: false
autoupdate_schedule: monthly

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.7.0
rev: v0.14.11
hooks:
# Run the linter.
- id: ruff
4 changes: 1 addition & 3 deletions applications/calc_global_solar.py
@@ -53,9 +53,7 @@ def main():
heights = static_ds[args.geo].values / 9.81
grid_points = np.vstack([lon_grid.ravel(), lat_grid.ravel(), heights.ravel()]).T
split_indices = np.round(np.linspace(0, grid_points.shape[0], size + 1)).astype(int)
grid_points_sub = [
grid_points[split_indices[s] : split_indices[s + 1]] for s in range(split_indices.size - 1)
]
grid_points_sub = [grid_points[split_indices[s] : split_indices[s + 1]] for s in range(split_indices.size - 1)]
toa_radiation = get_toa_radiation(args.start, args.end, step_freq=args.step, sub_freq=args.sub)
toa_radiation = comm.bcast(toa_radiation, root=0)
rank_points = comm.scatter(grid_points_sub, root=0)
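
The comprehension collapsed in this hunk splits the flattened lon/lat/height grid into near-equal, contiguous chunks, one per MPI rank, which comm.scatter then hands out. Below is a minimal numpy-only sketch of that partitioning, not part of the PR, with toy stand-ins for the grid and the world size (in the script these come from the static dataset and, presumably, the MPI communicator); np.array_split appears only as a comparison, not something the script uses.

import numpy as np

# Toy stand-ins: 10 grid points with (lon, lat, height) columns, 4 MPI ranks.
grid_points = np.arange(30, dtype=float).reshape(10, 3)
size = 4

# Same partitioning as the diff: near-equal, contiguous chunks, one per rank.
split_indices = np.round(np.linspace(0, grid_points.shape[0], size + 1)).astype(int)
grid_points_sub = [grid_points[split_indices[s] : split_indices[s + 1]] for s in range(split_indices.size - 1)]

# np.array_split gives the same kind of near-equal split in one call
# (chunk boundaries can differ slightly from the linspace version).
assert len(np.array_split(grid_points, size)) == len(grid_points_sub) == size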
202 changes: 42 additions & 160 deletions applications/deprecated/train.py

Large diffs are not rendered by default.

128 changes: 27 additions & 101 deletions applications/deprecated/train_multistep.py
@@ -67,26 +67,10 @@ def load_model_states_and_optimizer(conf, model, device):
amp = conf["trainer"]["amp"]

# load weights / states flags
load_weights = (
False
if "load_weights" not in conf["trainer"]
else conf["trainer"]["load_weights"]
)
load_optimizer_conf = (
False
if "load_optimizer" not in conf["trainer"]
else conf["trainer"]["load_optimizer"]
)
load_scaler_conf = (
False
if "load_scaler" not in conf["trainer"]
else conf["trainer"]["load_scaler"]
)
load_scheduler_conf = (
False
if "load_scheduler" not in conf["trainer"]
else conf["trainer"]["load_scheduler"]
)
load_weights = False if "load_weights" not in conf["trainer"] else conf["trainer"]["load_weights"]
load_optimizer_conf = False if "load_optimizer" not in conf["trainer"] else conf["trainer"]["load_optimizer"]
load_scaler_conf = False if "load_scaler" not in conf["trainer"] else conf["trainer"]["load_scaler"]
load_scheduler_conf = False if "load_scheduler" not in conf["trainer"] else conf["trainer"]["load_scheduler"]
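
The four flag reads just above all use the "False if key not in conf["trainer"] else conf["trainer"][key]" shape, as does the update_learning_rate check later in this file. A minimal sketch, not part of the PR and using a hypothetical conf, showing that the pattern matches a plain dict.get with a default:

# Hypothetical trainer config; only "load_weights" is set explicitly.
conf = {"trainer": {"load_weights": True}}

# Pattern used in the diff above:
load_weights = False if "load_weights" not in conf["trainer"] else conf["trainer"]["load_weights"]
load_optimizer_conf = False if "load_optimizer" not in conf["trainer"] else conf["trainer"]["load_optimizer"]

# Equivalent reads with .get() and a default:
assert load_weights == conf["trainer"].get("load_weights", False)
assert load_optimizer_conf == conf["trainer"].get("load_optimizer", False)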

# Load an optimizer, gradient scaler, and learning rate scheduler, the optimizer must come after wrapping model using FSDP
if not load_weights: # Loaded after loading model weights when reloading
@@ -99,16 +83,10 @@ def load_model_states_and_optimizer(conf, model, device):
if conf["trainer"]["mode"] == "fsdp":
optimizer = FSDPOptimizerWrapper(optimizer, model)
scheduler = load_scheduler(optimizer, conf)
scaler = (
ShardedGradScaler(enabled=amp)
if conf["trainer"]["mode"] == "fsdp"
else GradScaler(enabled=amp)
)
scaler = ShardedGradScaler(enabled=amp) if conf["trainer"]["mode"] == "fsdp" else GradScaler(enabled=amp)

# Multi-step training case -- when starting, only load the model weights (then after load all states)
elif load_weights and not (
load_optimizer_conf or load_scaler_conf or load_scheduler_conf
):
elif load_weights and not (load_optimizer_conf or load_scaler_conf or load_scheduler_conf):
optimizer = torch.optim.AdamW(
model.parameters(),
lr=learning_rate,
@@ -117,9 +95,7 @@ def load_model_states_and_optimizer(conf, model, device):
)
# FSDP checkpoint settings
if conf["trainer"]["mode"] == "fsdp":
logging.info(
f"Loading FSDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
logging.info(f"Loading FSDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}")
optimizer = torch.optim.AdamW(
model.parameters(),
lr=learning_rate,
@@ -128,30 +104,20 @@ def load_model_states_and_optimizer(conf, model, device):
)
optimizer = FSDPOptimizerWrapper(optimizer, model)
checkpoint_io = TorchFSDPCheckpointIO()
checkpoint_io.load_unsharded_model(
model, os.path.join(save_loc, "model_checkpoint.pt")
)
checkpoint_io.load_unsharded_model(model, os.path.join(save_loc, "model_checkpoint.pt"))
else:
# DDP settings
ckpt = os.path.join(save_loc, "checkpoint.pt")
checkpoint = torch.load(ckpt, map_location=device)
if conf["trainer"]["mode"] == "ddp":
logging.info(
f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
logging.info(f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}")
model.module.load_state_dict(checkpoint["model_state_dict"])
else:
logging.info(
f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
logging.info(f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}")
model.load_state_dict(checkpoint["model_state_dict"])
# Load the learning rate scheduler and mixed precision grad scaler
scheduler = load_scheduler(optimizer, conf)
scaler = (
ShardedGradScaler(enabled=amp)
if conf["trainer"]["mode"] == "fsdp"
else GradScaler(enabled=amp)
)
scaler = ShardedGradScaler(enabled=amp) if conf["trainer"]["mode"] == "fsdp" else GradScaler(enabled=amp)

# load optimizer and grad scaler states
else:
@@ -160,9 +126,7 @@ def load_model_states_and_optimizer(conf, model, device):

# FSDP checkpoint settings
if conf["trainer"]["mode"] == "fsdp":
logging.info(
f"Loading FSDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
logging.info(f"Loading FSDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}")
optimizer = torch.optim.AdamW(
model.parameters(),
lr=learning_rate,
@@ -171,47 +135,29 @@ def load_model_states_and_optimizer(conf, model, device):
)
optimizer = FSDPOptimizerWrapper(optimizer, model)
checkpoint_io = TorchFSDPCheckpointIO()
checkpoint_io.load_unsharded_model(
model, os.path.join(save_loc, "model_checkpoint.pt")
)
if (
"load_optimizer" in conf["trainer"]
and conf["trainer"]["load_optimizer"]
):
checkpoint_io.load_unsharded_optimizer(
optimizer, os.path.join(save_loc, "optimizer_checkpoint.pt")
)
checkpoint_io.load_unsharded_model(model, os.path.join(save_loc, "model_checkpoint.pt"))
if "load_optimizer" in conf["trainer"] and conf["trainer"]["load_optimizer"]:
checkpoint_io.load_unsharded_optimizer(optimizer, os.path.join(save_loc, "optimizer_checkpoint.pt"))

else:
# DDP settings
if conf["trainer"]["mode"] == "ddp":
logging.info(
f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
logging.info(f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}")
model.module.load_state_dict(checkpoint["model_state_dict"])
else:
logging.info(
f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
logging.info(f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer = torch.optim.AdamW(
model.parameters(),
lr=learning_rate,
weight_decay=weight_decay,
betas=(0.9, 0.95),
)
if (
"load_optimizer" in conf["trainer"]
and conf["trainer"]["load_optimizer"]
):
if "load_optimizer" in conf["trainer"] and conf["trainer"]["load_optimizer"]:
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

scheduler = load_scheduler(optimizer, conf)
scaler = (
ShardedGradScaler(enabled=amp)
if conf["trainer"]["mode"] == "fsdp"
else GradScaler(enabled=amp)
)
scaler = ShardedGradScaler(enabled=amp) if conf["trainer"]["mode"] == "fsdp" else GradScaler(enabled=amp)

# Update the config file to the current epoch
if "reload_epoch" in conf["trainer"] and conf["trainer"]["reload_epoch"]:
@@ -226,11 +172,7 @@ def load_model_states_and_optimizer(conf, model, device):
scaler.load_state_dict(checkpoint["scaler_state_dict"])

# Enable updating the lr if not using a policy
if (
conf["trainer"]["update_learning_rate"]
if "update_learning_rate" in conf["trainer"]
else False
):
if conf["trainer"]["update_learning_rate"] if "update_learning_rate" in conf["trainer"] else False:
for param_group in optimizer.param_groups:
param_group["lr"] = learning_rate

@@ -259,11 +201,7 @@ def main(rank, world_size, conf, backend, trial=False):
setup(rank, world_size, conf["trainer"]["mode"], backend)

# infer device id from rank
device = (
torch.device(f"cuda:{rank % torch.cuda.device_count()}")
if torch.cuda.is_available()
else torch.device("cpu")
)
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") if torch.cuda.is_available() else torch.device("cpu")
torch.cuda.set_device(rank % torch.cuda.device_count())
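
The collapsed device line above maps each global rank to a local GPU index with rank % torch.cuda.device_count(), falling back to CPU when no GPU is visible; torch.cuda.set_device then pins the process to that GPU. A small illustration of the mapping under a hypothetical two-node, four-GPU-per-node layout (the numbers are made up, not taken from the repo):

# Hypothetical layout: 2 nodes x 4 GPUs each -> world_size 8, global ranks 0..7.
# torch.cuda.device_count() returns the per-node GPU count, so rank % count is the local GPU.
gpus_per_node = 4
for rank in range(8):
    print(f"rank {rank} -> cuda:{rank % gpus_per_node}")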

# Config settings
@@ -275,12 +213,8 @@ def main(rank, world_size, conf, backend, trial=False):
valid_dataset = load_dataset(conf, rank=rank, world_size=world_size, is_train=False)

# Load the dataloader
train_loader = load_dataloader(
conf, train_dataset, rank=rank, world_size=world_size, is_train=True
)
valid_loader = load_dataloader(
conf, valid_dataset, rank=rank, world_size=world_size, is_train=False
)
train_loader = load_dataloader(conf, train_dataset, rank=rank, world_size=world_size, is_train=True)
valid_loader = load_dataloader(conf, valid_dataset, rank=rank, world_size=world_size, is_train=False)

# model
m = load_model(conf)
@@ -296,9 +230,7 @@ def main(rank, world_size, conf, backend, trial=False):
model = distributed_model_wrapper(conf, m, device)

# Load model weights (if any), an optimizer, scheduler, and gradient scaler
conf, model, optimizer, scheduler, scaler = load_model_states_and_optimizer(
conf, model, device
)
conf, model, optimizer, scheduler, scaler = load_model_states_and_optimizer(conf, model, device)

# Train and validation losses
train_criterion = VariableTotalLoss2D(conf)
@@ -371,14 +303,10 @@ def train(self, trial, conf):

except Exception as E:
if "CUDA" in str(E) or "non-singleton" in str(E):
logging.warning(
f"Pruning trial {trial.number} due to CUDA memory overflow: {str(E)}."
)
logging.warning(f"Pruning trial {trial.number} due to CUDA memory overflow: {str(E)}.")
raise optuna.TrialPruned()
elif "non-singleton" in str(E):
logging.warning(
f"Pruning trial {trial.number} due to shape mismatch: {str(E)}."
)
logging.warning(f"Pruning trial {trial.number} due to shape mismatch: {str(E)}.")
raise optuna.TrialPruned()
else:
logging.warning(f"Trial {trial.number} failed due to error: {str(E)}.")
@@ -442,9 +370,7 @@ def train(self, trial, conf):

# ======================================================== #
# handling config args
conf = credit_main_parser(
conf, parse_training=True, parse_predict=False, print_summary=False
)
conf = credit_main_parser(conf, parse_training=True, parse_predict=False, print_summary=False)
training_data_check(conf, print_summary=False)
# ======================================================== #
