[QEff Finetune]: Enable PP+DDP #394

Merged · merged 17 commits on Jul 18, 2025
64 changes: 45 additions & 19 deletions QEfficient/cloud/finetune.py
@@ -26,16 +26,12 @@
generate_peft_config,
update_config,
)
from QEfficient.finetune.utils.dataset_utils import get_dataloader
from QEfficient.finetune.utils.helper import Task_Mode
from QEfficient.finetune.utils.dataset_utils import get_dataloader, get_longest_seq_length
from QEfficient.finetune.utils.device_map import get_device_map
from QEfficient.finetune.utils.helper import Task_Mode, get_world_size
from QEfficient.finetune.utils.logging_utils import logger
from QEfficient.finetune.utils.parser import get_finetune_parser
from QEfficient.finetune.utils.train_utils import (
get_longest_seq_length,
print_model_size,
print_trainable_parameters,
train,
)
from QEfficient.finetune.utils.train_utils import print_model_size, print_trainable_parameters, train
from QEfficient.utils._utils import hf_download

# Try importing QAIC-specific module, proceed without it if unavailable
@@ -63,17 +59,27 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
Raises:
AssertionError: If device is CPU or includes an index with DDP enabled.
"""

torch_device = torch.device(train_config.device)
num_available_devices = getattr(torch, torch_device.type).device_count()
assert get_world_size() * train_config.num_pp_stages <= num_available_devices, (
"Number of devices required should be less than or equal to total available devices."
)
if train_config.enable_pp:
assert train_config.num_pp_stages > 1, (
f"For pipeline parallelism, num_pp_stages should be greater than 1. Got {train_config.num_pp_stages}"
)

if not train_config.enable_ddp:
return

torch_device = torch.device(train_config.device)
assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"

dist_backend_map = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
dist.init_process_group(backend=dist_backend_map[torch_device.type])
# from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
getattr(torch, torch_device.type).set_device(dist.get_rank())
if not train_config.enable_pp:
# from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
getattr(torch, torch_device.type).set_device(dist.get_rank())


def setup_seeds(seed: int) -> None:
@@ -85,6 +91,10 @@ def setup_seeds(seed: int) -> None:
Notes:
- Sets seeds for PyTorch, Python's random module, and NumPy.
"""
torch.use_deterministic_algorithms(True)
# With this flag, PP+DDP works only for meta-llama/Llama-3.2-1B and mistralai/Mistral-7B-Instruct-v0.3,
# and throws an error while loading the model for meta-llama/Llama-3.1-8B and larger models.

torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
@@ -96,7 +106,7 @@ def load_model_and_tokenizer(
"""Load the pre-trained model and tokenizer from Hugging Face.

Args:
config (TrainConfig): Training configuration object containing model and tokenizer names.
train_config (TrainConfig): Training configuration object containing model and tokenizer names.
dataset_config (Any): A dataclass object representing dataset configuration.
kwargs: Additional arguments to override PEFT config.

@@ -112,7 +122,10 @@
- Sets pad_token_id to eos_token_id if not defined in the tokenizer.
"""
logger.log_rank_zero(f"Loading HuggingFace model for {train_config.model_name}")
pretrained_model_path = hf_download(train_config.model_name)
pretrained_model_path = hf_download(
train_config.model_name,
ignore_patterns=["*.txt", "*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.msgpack", "*.h5", "*.pth"],
)
if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
model = AutoModelForSequenceClassification.from_pretrained(
pretrained_model_path,
@@ -131,13 +144,14 @@
if param.requires_grad:
param.data = param.data.to(torch.float32)
else:
device_map = get_device_map(train_config)
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_path,
use_cache=False,
attn_implementation="sdpa",
torch_dtype=torch.float16,
device_map=device_map,
)

tokenizer = AutoTokenizer.from_pretrained(
train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
)
@@ -290,11 +304,24 @@ def main(**kwargs) -> None:
f"passed context length is {train_config.context_length} and overall model's context length is "
f"{model.config.max_position_embeddings}"
)
model.to(train_config.device)
optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
if not train_config.enable_pp:
model.to(train_config.device)
optimizer = optim.AdamW(
model.parameters(),
lr=train_config.lr,
weight_decay=train_config.weight_decay,
)
scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
if train_config.enable_ddp:
model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
ignore_names = set()
for name, param in model.named_parameters():
if not param.requires_grad:
ignore_names.add(name)
# Adding params to the ignore list makes DDP skip them during synchronization,
# which further reduces tensor exchange across devices.
torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(model, ignore_names)
model = nn.parallel.DistributedDataParallel(model)

results = train(
model,
tokenizer,
@@ -303,7 +330,6 @@
optimizer,
scheduler,
train_config,
dist.get_rank() if train_config.enable_ddp else None,
)
if train_config.enable_ddp:
dist.destroy_process_group()
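
Note: the frozen-parameter exclusion added in main() can be illustrated in isolation. The sketch below uses a toy two-layer model (purely hypothetical) and the same PyTorch hook the PR calls; parameters with requires_grad=False are registered on the ignore list so DDP never buckets or all-reduces them.

```python
import torch.nn as nn

# Toy model standing in for a PEFT-wrapped model: the first layer is frozen.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
for p in model[0].parameters():
    p.requires_grad = False

# Collect names of frozen parameters, as done in main() above.
ignore_names = {name for name, p in model.named_parameters() if not p.requires_grad}

# DDP will skip these tensors when building gradient buckets, so they are
# never synchronized across ranks.
nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(model, ignore_names)

# ddp_model = nn.parallel.DistributedDataParallel(model)  # needs an initialized process group
```
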
11 changes: 8 additions & 3 deletions QEfficient/finetune/configs/training.py
@@ -47,11 +47,13 @@ class TrainConfig:
save_metrics (bool): Save training metrics (default: True).
intermediate_step_save (int): Steps between intermediate saves (default: 1000).
batching_strategy (str): Batching strategy (default: "packing").
enable_sorting_for_ddp (bool): Sort data for DDP (default: True).
convergence_counter (int): Steps to check convergence (default: 5).
convergence_loss (float): Loss threshold for convergence (default: 1e-4).
use_profiler (bool): Enable profiling (default: False).
enable_pp (bool): Enable training with pipeline parallelism (default: False).
num_pp_stages (int): Number of stages into which the model is split layerwise when training with pipeline parallelism (default: 1).
enable_ddp (bool): Enable distributed data parallel (default: False).
enable_sorting_for_ddp (bool): Sort data for DDP (default: True).
opByOpVerifier (bool): Enable operation-by-operation verification (default: False).
dump_logs (bool): Whether to dump logs (default: True).
log_level (str): logging level (default: logging.INFO)
@@ -87,8 +89,6 @@ class TrainConfig:
save_metrics: bool = True # saves training metrics to a json file for later plotting
intermediate_step_save: int = 1000
batching_strategy: str = Batching_Strategy.PADDING.value
enable_ddp: bool = False
enable_sorting_for_ddp: bool = True
convergence_counter: int = 5 # its value should be >= 1, stop fine tuning when loss <= convergence_loss (defined below) for #convergence_counter steps
convergence_loss: float = (
1e-4 # if loss value is <= convergence_loss for #convergence_counter consecutive steps, fine tuning stops
@@ -100,6 +100,11 @@
use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
# profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler

# dist-related
enable_pp: bool = False
num_pp_stages: int = 1
enable_ddp: bool = False
enable_sorting_for_ddp: bool = True
opByOpVerifier: bool = False

dump_logs: bool = True
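
As a quick illustration of the new fields, a 2-stage pipeline-parallel run replicated across 2 DDP ranks could be configured as sketched below (hypothetical values; all other TrainConfig fields are assumed to keep their defaults).

```python
from QEfficient.finetune.configs.training import TrainConfig

# Hypothetical PP+DDP configuration; remaining fields keep their defaults.
train_config = TrainConfig(
    enable_pp=True,    # split each model replica layerwise across devices
    num_pp_stages=2,   # two pipeline stages per replica
    enable_ddp=True,   # replicate the pipelined model across DDP ranks
)

# Launched with 2 DDP processes, this needs 2 * 2 = 4 devices, which is what
# setup_distributed_training() asserts against the available device count.
```
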
14 changes: 12 additions & 2 deletions QEfficient/finetune/utils/dataset_utils.py
@@ -4,7 +4,9 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import logging
from typing import Dict, List, Tuple

import datasets
import torch
@@ -13,7 +15,7 @@

from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
from QEfficient.finetune.dataset.dataset_config import DATALOADER_COLLATE_FUNC, DATASET_PREPROC
from QEfficient.finetune.utils.helper import get_num_ddp_devices
from QEfficient.finetune.utils.helper import get_world_size
from QEfficient.finetune.utils.logging_utils import logger


@@ -68,7 +70,7 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, split):


def padding_dataset(train_config, dataset, batch_size):
num_replicas = get_num_ddp_devices()
num_replicas = get_world_size()
remainder = len(dataset) % (num_replicas * batch_size)
if remainder == 0:
return dataset
@@ -125,3 +127,11 @@ def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"
**dl_kwargs,
)
return dataloader


def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
# find out the minimum max_seq_length required during fine-tuning (saves memory!)
lengths = [len(d["input_ids"]) for d in data]
longest_seq_length = max(lengths)
longest_seq_ix = lengths.index(longest_seq_length)
return longest_seq_length, longest_seq_ix
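
The switch from get_num_ddp_devices() to get_world_size() keeps padding_dataset() padding the dataset to a multiple of world_size * batch_size, so every rank receives the same number of full batches. A small worked example with made-up numbers:

```python
# Illustrative numbers only: 1000 samples, 2 DDP ranks, per-device batch size 16.
dataset_len, world_size, batch_size = 1000, 2, 16

remainder = dataset_len % (world_size * batch_size)  # 1000 % 32 == 8
num_pad_samples = 0 if remainder == 0 else world_size * batch_size - remainder
print(num_pad_samples)  # 24 dummy samples appended so 1024 splits evenly across ranks
```
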
107 changes: 107 additions & 0 deletions QEfficient/finetune/utils/device_map.py
@@ -0,0 +1,107 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------


import numpy as np
import torch
from transformers import AutoConfig

from QEfficient.finetune.utils.helper import get_rank
from QEfficient.utils._utils import get_num_layers_from_config


def get_device_map(train_config):
"""Returns device map for the given model.

Args:
train_config (TrainConfig): Training configuration object containing the model name, number of pipeline stages, etc.

Returns:
Dict: A dictionary mapping model layers to device ids; "auto" when the number of pipeline stages equals the number of available devices, or None when pipeline parallelism is disabled.
"""
torch_device = torch.device(train_config.device)
num_available_devices = getattr(torch, torch_device.type).device_count()
if train_config.enable_pp:
if train_config.enable_ddp:
device_map = custom_device_map(train_config)
elif train_config.num_pp_stages < num_available_devices:
device_map = custom_device_map(train_config)
elif train_config.num_pp_stages == num_available_devices:
device_map = "auto"
else:
device_map = None

return device_map


def custom_device_map(train_config):
"""Returns custom device map for model layers based number of pipeline stages and given process rank.

Args:
train_config (TrainConfig): Training configuration object containing the model name, number of pipeline stages, etc.

Returns:
Dict: A dictionary mapping model layers to device ids.

Notes:
- This device map structure is verified for llama models only.

Example:
Configuration for meta-llama/Llama-3.2-1B
Total devices: 4 (2x PP x 2x DDP)

PP (Pipeline Parallelism): Each copy of the model is split into 2 stages
DDP (Distributed Data Parallel): 2 model copies run in parallel

|--------------------------------------------------------------------------|
| Process Rank | Assigned Device IDs | Model Component |
|--------------------------------------------------------------------------|
| Rank 0 | 0 | model.embed_tokens |
| | | model.lm_head |
| | | model.layers.0 - model.layers.7 |
|--------------------------------------------------------------------------|
| Rank 0 | 1 | model.norm |
| | | model.rotary_emb |
| | | model.layers.8 - model.layers.15 |
|--------------------------------------------------------------------------|
| Rank 1 | 2 | model.embed_tokens |
| | | model.lm_head |
| | | model.layers.0 - model.layers.7 |
|--------------------------------------------------------------------------|
| Rank 1 | 3 | model.norm |
| | | model.rotary_emb |
| | | model.layers.8 - model.layers.15 |
|--------------------------------------------------------------------------|
"""

model_config = AutoConfig.from_pretrained(train_config.model_name)
num_layers = get_num_layers_from_config(model_config)
num_pp_stages = train_config.num_pp_stages
rank = get_rank()
first_device = rank * num_pp_stages
last_device = rank * num_pp_stages + (num_pp_stages - 1)

if model_config.tie_word_embeddings:
lm_head_device = first_device
else:
lm_head_device = last_device

device_map = {
"model.embed_tokens": first_device,
"lm_head": lm_head_device,
"model.norm": last_device,
"model.rotary_emb": last_device,
}
n_layer_per_stage = np.ceil(num_layers / num_pp_stages)

pp_stage_ids = np.arange(num_pp_stages)
pp_device_map = np.repeat(pp_stage_ids, n_layer_per_stage)

for i in range(num_layers):
device_map[f"model.layers.{i}"] = pp_device_map[i] + rank * num_pp_stages

return device_map
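
To make the mapping concrete, here is a standalone sketch that reproduces the docstring's example (16 layers, num_pp_stages = 2, DDP rank 1, tied word embeddings); the arithmetic mirrors custom_device_map() but runs without a TrainConfig or a model download.

```python
import numpy as np

# Hypothetical values matching the docstring table: rank 1 owns devices 2 and 3.
num_layers, num_pp_stages, rank = 16, 2, 1

first_device = rank * num_pp_stages               # 2
last_device = first_device + (num_pp_stages - 1)  # 3

n_layer_per_stage = int(np.ceil(num_layers / num_pp_stages))  # 8 layers per stage
pp_device_map = np.repeat(np.arange(num_pp_stages), n_layer_per_stage)
# -> [0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1], then offset by rank * num_pp_stages

device_map = {f"model.layers.{i}": int(pp_device_map[i]) + first_device for i in range(num_layers)}
device_map.update({
    "model.embed_tokens": first_device,
    "lm_head": first_device,   # tied embeddings in this example, so lm_head follows embed_tokens
    "model.norm": last_device,
    "model.rotary_emb": last_device,
})
# layers 0-7 land on device 2, layers 8-15 on device 3, matching the table above.
```
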
37 changes: 34 additions & 3 deletions QEfficient/finetune/utils/helper.py
@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import json
import os
from contextlib import nullcontext
from enum import Enum
@@ -37,14 +39,18 @@ class Task_Mode(str, Enum):


def enum_names(enum_cls):
return [member.value for member in enum_cls]
return [x.value for x in enum_cls]


def get_rank():
return int(os.getenv("LOCAL_RANK", 0))


def is_rank_zero():
return int(os.getenv("LOCAL_RANK", 0)) == 0
return get_rank() == 0


def get_num_ddp_devices():
def get_world_size():
return int(os.getenv("WORLD_SIZE", 1))


@@ -77,3 +83,28 @@ def get_op_verifier_ctx(
filter_config=filter_config,
dump_root_dir=dump_dir,
)


def save_to_json(
output_filename,
train_step_loss,
train_epoch_loss,
train_step_metric,
train_epoch_metric,
val_step_loss,
val_epoch_loss,
val_step_metric,
val_epoch_metric,
):
metrics_data = {
"train_step_loss": train_step_loss,
"train_epoch_loss": train_epoch_loss,
"train_step_metric": train_step_metric,
"train_epoch_metric": train_epoch_metric,
"val_step_loss": val_step_loss,
"val_epoch_loss": val_epoch_loss,
"val_step_metric": val_step_metric,
"val_epoch_metric": val_epoch_metric,
}
with open(output_filename, "w") as f:
json.dump(metrics_data, f)
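
get_rank() and get_world_size() simply read the LOCAL_RANK and WORLD_SIZE environment variables that a distributed launcher such as torchrun sets for each process. A minimal illustration with hypothetical values:

```python
import os

# In a real run these are set by the launcher (e.g. torchrun --nproc_per_node=2),
# one pair per spawned process; the values here are illustrative.
os.environ.setdefault("LOCAL_RANK", "1")
os.environ.setdefault("WORLD_SIZE", "2")

print(int(os.getenv("LOCAL_RANK", 0)))  # 1 -> this process's rank on the node
print(int(os.getenv("WORLD_SIZE", 1)))  # 2 -> total number of processes
```
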