[QEff Finetune]: Enable PP+DDP #394

Draft · wants to merge 3 commits into base: main
85 changes: 69 additions & 16 deletions QEfficient/cloud/finetune.py
@@ -5,6 +5,7 @@
#
# -----------------------------------------------------------------------------

import math
import random
import warnings
from typing import Any, Dict, Optional, Union
@@ -18,7 +19,7 @@
import torch.utils.data
from peft import PeftModel, get_peft_model
from torch.optim.lr_scheduler import StepLR
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.utils.config_utils import (
@@ -32,7 +33,7 @@
get_preprocessed_dataset,
)
from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
from QEfficient.utils._utils import login_and_download_hf_lm
from QEfficient.utils._utils import get_num_layers_from_config, login_and_download_hf_lm

# Try importing QAIC-specific module, proceed without it if unavailable
try:
@@ -41,12 +42,37 @@
print(f"Warning: {e}. Proceeding without QAIC modules.")


from transformers import AutoModelForSequenceClassification

# Suppress all warnings
warnings.filterwarnings("ignore")


def get_device_map(rank, num_pp_stages, num_layers):
"""Returns the device map for the model's layers for the given process rank, based on the number of pipeline stages.

Args:
rank (int): process rank
num_pp_stages (int): number of stages in the pipeline
num_layers (int): total number of layers in the model

Returns:
Dict: A dictionary mapping layer names to their corresponding device ids.

Notes:
- This device map structure is verified for Llama models only.
"""
device_map = {
"model.embed_tokens": rank * num_pp_stages,
Contributor: Please add some explanation of why these particular layers are mapped to a particular device. (L64 to L67)

"lm_head": rank * num_pp_stages,
"model.norm": rank * num_pp_stages + (num_pp_stages - 1),
"model.rotary_emb": rank * num_pp_stages + (num_pp_stages - 1),
}
n_layer_per_stage = math.ceil(num_layers / num_pp_stages)
Contributor: Suggestion: Use np.ceil so that no new module will be imported.

for j in range(num_pp_stages):
Contributor: Please add some strong documentation for this double for loop. It is difficult to understand without working through a case; better to add an example and explain with it.

for i in range(n_layer_per_stage * j, min(n_layer_per_stage * (j + 1), num_layers)):
device_map[f"model.layers.{i}"] = rank * num_pp_stages + j
return device_map
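To make the double loop above concrete, here is a worked example of what get_device_map returns; the inputs (rank 1, 2 pipeline stages, a 4-layer model) are chosen purely for illustration:

```python
# Illustration only, calling the get_device_map defined above with small, made-up values.
# Rank 1 owns the contiguous device range
# [rank * num_pp_stages, rank * num_pp_stages + num_pp_stages - 1] = [2, 3].
# math.ceil(4 / 2) = 2 layers per stage: the outer loop walks the pipeline stages and the
# inner loop assigns that stage's slice of decoder layers to the stage's device.
print(get_device_map(rank=1, num_pp_stages=2, num_layers=4))
# {
#   "model.embed_tokens": 2,                   # first device of this rank's range
#   "lm_head": 2,                              # mapped to the same device as embed_tokens
#   "model.norm": 3,                           # final norm and rotary embeddings on the last device
#   "model.rotary_emb": 3,
#   "model.layers.0": 2, "model.layers.1": 2,  # stage j = 0
#   "model.layers.2": 3, "model.layers.3": 3,  # stage j = 1
# }
```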


def setup_distributed_training(train_config: TrainConfig) -> None:
"""Initialize distributed training environment if enabled.

@@ -69,8 +95,13 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"

dist.init_process_group(backend=train_config.dist_backend)
# from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
getattr(torch, torch_device.type).set_device(dist.get_rank())
if train_config.enable_pp:
assert dist.get_world_size() % train_config.num_pp_stages == 0, (
"total available devices should be multiple of number of pipeline stages"
Contributor: "Total" instead of "total", and a full stop at the end.

Contributor: Also, can we inform the user that if dist.get_world_size() // train_config.num_pp_stage == 1 this will be pure PP only, and if dist.get_world_size() // train_config.num_pp_stage > 1 this will actually be PP+DDP? This might be helpful to make our system idiot-proof.

Contributor: Also, we need another assert condition: assert dist.get_world_size() * train_config.num_pp_stage == total_available_devices

)
else:
# from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
getattr(torch, torch_device.type).set_device(dist.get_rank())
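A minimal sketch of the user-facing check the reviewers suggest above; it is not part of this diff, and it takes the assertion message at face value in treating the world size as the total number of available devices:

```python
# Sketch only (reviewer suggestion, not in this diff). Variable names follow the surrounding code.
world_size = dist.get_world_size()
assert world_size % train_config.num_pp_stages == 0, (
    "Total available devices should be a multiple of the number of pipeline stages."
)
dp_replicas = world_size // train_config.num_pp_stages
if dp_replicas == 1:
    print("enable_pp: running pure pipeline parallelism (a single model replica).")
else:
    print(f"enable_pp + enable_ddp: running PP+DDP with {dp_replicas} data-parallel replicas.")
```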


def setup_seeds(seed: int) -> None:
@@ -128,12 +159,29 @@ def load_model_and_tokenizer(
if param.requires_grad:
param.data = param.data.to(torch.float32)
else:
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_path,
use_cache=False,
attn_implementation="sdpa",
torch_dtype=torch.float16,
)
if train_config.enable_pp:
if train_config.enable_ddp:
rank = dist.get_rank()
model_config = AutoConfig.from_pretrained(train_config.model_name)
num_layers = get_num_layers_from_config(model_config)
device_map = get_device_map(rank, train_config.num_pp_stages, num_layers)
else:
device_map = "auto"
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_path,
use_cache=False,
attn_implementation="sdpa",
torch_dtype=torch.float16,
device_map=device_map,
)
print(model.hf_device_map)
else:
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_path,
use_cache=False,
attn_implementation="sdpa",
torch_dtype=torch.float16,
)
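As an optional sanity check after loading with a custom device_map (a sketch, not in this diff, assuming the PP+DDP branch above where rank is defined), one can confirm that every parameter landed inside this rank's device range:

```python
# Sketch only: verify that the sharded model's parameters sit on the devices assigned to
# this rank, i.e. rank * num_pp_stages .. rank * num_pp_stages + num_pp_stages - 1.
expected_devices = set(range(rank * train_config.num_pp_stages, (rank + 1) * train_config.num_pp_stages))
actual_devices = {param.device.index for param in model.parameters()}
assert actual_devices <= expected_devices, f"Unexpected device placement: {actual_devices - expected_devices}"
```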

tokenizer = AutoTokenizer.from_pretrained(
train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
@@ -332,12 +380,17 @@ def main(peft_config_file: str = None, **kwargs) -> None:
f"passed context length is {train_config.context_length} and overall model's context length is "
f"{model.config.max_position_embeddings}"
)

model.to(train_config.device)
optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
if not train_config.enable_pp:
model.to(train_config.device)
optimizer = optim.AdamW(
model.parameters(),
lr=train_config.lr,
weight_decay=train_config.weight_decay,
)
scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
if train_config.enable_ddp:
model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
model = nn.parallel.DistributedDataParallel(model) # , device_ids=[dist.get_rank()])
Contributor: Why did we remove device_ids in the case of DDP? Because we are using device_map now?
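A plausible answer, though the diff itself does not say: once the model is sharded across devices via device_map, it becomes a multi-device module, and PyTorch's DistributedDataParallel requires device_ids to be left unset for multi-device modules. A sketch that keeps both paths explicit:

```python
# Sketch only: keep the original single-device DDP wrapping when PP is off, and drop
# device_ids when the model has been sharded with a device_map (multi-device modules
# must be wrapped with device_ids=None).
if train_config.enable_ddp:
    if train_config.enable_pp:
        model = nn.parallel.DistributedDataParallel(model)
    else:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
```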


results = train(
model,
tokenizer,
2 changes: 2 additions & 0 deletions QEfficient/finetune/configs/training.py
@@ -99,6 +99,8 @@ class TrainConfig:
# profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler

# dist-related
enable_pp: bool = False
Contributor: I think this support is only added for decoder-style models, so this needs to be properly documented. Maybe we can share some numerical guidance as well, e.g. if the user's model is larger than, say, 8B parameters they may need 4 PP stages; if it is larger than 30B, they may need 16 PP stages.

num_pp_stages: int = 1
enable_ddp: bool = False
dist_backend: str = "cpu:gloo,qaic:qccl,cuda:gloo"
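For reference, a hypothetical configuration exercising the new fields (assuming TrainConfig is a dataclass that accepts them as keyword arguments; all other fields keep their defaults):

```python
from QEfficient.finetune.configs.training import TrainConfig

# Hypothetical PP+DDP run: each process drives a 2-stage pipeline, and DDP replicates the
# pipeline across the remaining devices. dist_backend keeps the default shown above.
train_config = TrainConfig(
    enable_pp=True,
    num_pp_stages=2,
    enable_ddp=True,
    dist_backend="cpu:gloo,qaic:qccl,cuda:gloo",
)
```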
