-
Notifications
You must be signed in to change notification settings - Fork 50
[QEff Finetune]: Enable PP+DDP #394
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
c213173
PP+DDP for 70B
quic-mamta 406a869
Merge branch 'quic:main' into pp_ddp
quic-mamta 9c3a460
Merge branch 'main' into pp_ddp
quic-mamta fff53ba
Update finetune.py
quic-mamta ba3e45a
Merge branch 'quic:main' into pp_ddp
quic-mamta 5ca0910
Merge branch 'main' into pp_ddp
quic-mamta b38a8f4
ignore params and use deterministic behaviour
mamtsing 005d07c
Merge branch 'quic:main' into pp_ddp
quic-mamta 0f2c5a9
Merge branch 'main' into pp_ddp
quic-mamta 76b953a
Merge branch 'main' into pp_ddp
mamtsing 95efa6e
minor refactoring
mamtsing db910fb
Merge branch 'main' into pp_ddp
quic-mamta d3e3029
Merge branch 'main' into pp_ddp
quic-mamta 27382ea
address comments
mamtsing 4d8d470
add enable_pp flag
mamtsing 3b9d71c
Merge branch 'main' into pp_ddp
quic-mamta 9d4e7c6
fix num_pp_stages
mamtsing File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# ----------------------------------------------------------------------------- | ||
# | ||
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# ----------------------------------------------------------------------------- | ||
|
||
|
||
import numpy as np | ||
import torch | ||
from transformers import AutoConfig | ||
|
||
from QEfficient.finetune.utils.helper import get_rank | ||
from QEfficient.utils._utils import get_num_layers_from_config | ||
|
||
|
||
def get_device_map(train_config): | ||
"""Returns device map for the given model. | ||
|
||
Args: | ||
train_config (TrainConfig): Training configuration object contaning model name and number of pipeline stages etc. | ||
|
||
Returns: | ||
Dict: A dictionary of layers and corresponding device id. | ||
""" | ||
torch_device = torch.device(train_config.device) | ||
num_available_devices = getattr(torch, torch_device.type).device_count() | ||
if train_config.enable_pp: | ||
if train_config.enable_ddp: | ||
device_map = custom_device_map(train_config) | ||
elif train_config.num_pp_stages < num_available_devices: | ||
device_map = custom_device_map(train_config) | ||
elif train_config.num_pp_stages == num_available_devices: | ||
device_map = "auto" | ||
else: | ||
device_map = None | ||
|
||
return device_map | ||
|
||
|
||
def custom_device_map(train_config): | ||
"""Returns custom device map for model layers based number of pipeline stages and given process rank. | ||
|
||
Args: | ||
train_config (TrainConfig): Training configuration object contaning model name and number of pipeline stages etc. | ||
|
||
Returns: | ||
Dict: A dictionary of layers and corresponding device id. | ||
|
||
Notes: | ||
- This device map structure is verified for llama models only. | ||
|
||
Example: | ||
Configuration for meta-llama/Llama-3.2-1B | ||
Total devices: 4 (2x PP x 2x DDP) | ||
|
||
PP (Pipeline Parallelism): Each copy of the model is split into 2 stages | ||
DDP (Distributed Data Parallel): 2 model copies run in parallel | ||
|
||
|--------------------------------------------------------------------------| | ||
| Process Rank | Assigned Device IDs | Model Component | | ||
|--------------------------------------------------------------------------| | ||
| Rank 0 | 0 | model.embed_tokens | | ||
| | | model.lm_head | | ||
| | | model.layers.0 - model.layers.7 | | ||
|--------------------------------------------------------------------------| | ||
| Rank 0 | 1 | model.norm | | ||
| | | model.rotary_emb | | ||
| | | model.layers.8 - model.layers.15 | | ||
|--------------------------------------------------------------------------| | ||
| Rank 1 | 2 | model.embed_tokens | | ||
| | | model.lm_head | | ||
| | | model.layers.0 - model.layers.7 | | ||
|--------------------------------------------------------------------------| | ||
| Rank 1 | 3 | model.norm | | ||
| | | model.rotary_emb | | ||
| | | model.layers.8 - model.layers.15 | | ||
|--------------------------------------------------------------------------| | ||
""" | ||
|
||
model_config = AutoConfig.from_pretrained(train_config.model_name) | ||
num_layers = get_num_layers_from_config(model_config) | ||
num_pp_stages = train_config.num_pp_stages | ||
rank = get_rank() | ||
first_device = rank * num_pp_stages | ||
last_device = rank * num_pp_stages + (num_pp_stages - 1) | ||
|
||
if model_config.tie_word_embeddings: | ||
lm_head_device = first_device | ||
else: | ||
lm_head_device = last_device | ||
|
||
device_map = { | ||
"model.embed_tokens": first_device, | ||
"lm_head": lm_head_device, | ||
"model.norm": last_device, | ||
"model.rotary_emb": last_device, | ||
} | ||
n_layer_per_stage = np.ceil(num_layers / num_pp_stages) | ||
|
||
pp_stage_ids = np.arange(num_pp_stages) | ||
pp_device_map = np.repeat(pp_stage_ids, n_layer_per_stage) | ||
|
||
for i in range(num_layers): | ||
device_map[f"model.layers.{i}"] = pp_device_map[i] + rank * num_pp_stages | ||
|
||
return device_map |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.