
Commit 5f64acf

aleien95 and lilin3 authored

support sp in vlm (#71)

* feat: add multimodal sequence parallelism support

* refactor: separate SFT/DPO training scripts and optimize data handling
  - Remove the data-printing functions from the SFT and DPO trainers for better performance
  - Replace 360-example-vl.sh with separate SFT and DPO training scripts
  - Add an SFT visual-language demo dataset (data/sft-vl-demo/)
  - Update the dataset configuration to support the new data structure

* refactor: restructure the multimodal model forward functions and tidy up code style
  - Add a multimodal_forwards module to centrally manage multimodal forward logic
  - Extract and optimize the forward implementations for Qwen2-VL and Qwen2.5-VL
  - Improve the structure of the sequence_parallel code

* feat(vl): update the VL training scripts and clean up demo data

* refactor: improve readability of the sequence parallel attention check

Co-authored-by: lilin3 <lilin3@360.cn>
1 parent 5ce4778 commit 5f64acf

File tree

16 files changed: +608 −20 lines changed


360-example-vl-dpo.sh

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+#!/bin/bash
+
+set -x
+
+# Environment setup
+export DS_SKIP_CUDA_CHECK=1
+export DISABLE_VERSION_CHECK=1
+export FORCE_TORCHRUN=1
+export CUDA_LAUNCH_BLOCKING=1
+
+# Parameters
+MODEL_PATH=""
+MODEL_SIZE="7B"
+DATA_NAME="dpo-vl-demo"
+NUM_NODES=1
+NUM_GPUS=8
+CUTOFF_LEN=20000
+LEARNING_RATE=1e-6
+PER_DEVICE_BATCH_SIZE=1
+GRADIENT_ACCUMULATION_STEPS=16
+
+# Output directory
+model_saved_name="demo-qwen25vl-${MODEL_SIZE}-len_${CUTOFF_LEN}-lr_${LEARNING_RATE}-data_${DATA_NAME}"
+OUTPUT_DIR="./output/${MODEL_SIZE}/${model_saved_name}"
+tensorboard_dir="${OUTPUT_DIR}/runs_${MODEL_SIZE}/${model_saved_name}"
+
+# Create directories
+mkdir -p ${OUTPUT_DIR}
+mkdir -p ${tensorboard_dir}
+
+# DPO Training
+deepspeed --hostfile=/etc/mpi.host src/train.py \
+    --stage dpo \
+    --do_train \
+    --model_name_or_path ${MODEL_PATH} \
+    --dataset ${DATA_NAME} \
+    --dataset_dir ./data \
+    --template qwen2_vl \
+    --finetuning_type full \
+    --freeze_vision_tower True \
+    --train_mm_proj_only False \
+    --image_resolution 1048576 \
+    --video_resolution 16384 \
+    --pref_beta 0.1 \
+    --pref_ftx 0.0 \
+    --output_dir ${OUTPUT_DIR} \
+    --overwrite_cache \
+    --overwrite_output_dir True \
+    --cutoff_len ${CUTOFF_LEN} \
+    --preprocessing_num_workers 128 \
+    --per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \
+    --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} \
+    --learning_rate ${LEARNING_RATE} \
+    --lr_scheduler_type cosine_with_min_lr \
+    --lr_scheduler_kwargs "{\"min_lr_rate\": 0.1}" \
+    --num_train_epochs 1 \
+    --warmup_ratio 0.05 \
+    --logging_steps 1 \
+    --logging_dir "./output/runs_${MODEL_SIZE}/${model_saved_name}" \
+    --save_strategy epoch \
+    --plot_loss True \
+    --deepspeed examples/deepspeed/ds_z2_config.json \
+    --use_unsloth_gc True \
+    --bf16 \
+    --flash_attn fa2 \
+    --sequence_parallel_size 8 \
+    --ddp_timeout 180000000 \
+    --report_to tensorboard

360-example-vl-sft.sh

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+#!/bin/bash
+
+set -x
+
+# Environment setup
+export DS_SKIP_CUDA_CHECK=1
+export DISABLE_VERSION_CHECK=1
+export FORCE_TORCHRUN=1
+export CUDA_LAUNCH_BLOCKING=1
+
+# Parameters
+MODEL_PATH=""
+MODEL_SIZE="7B"
+DATA_NAME="sft-vl-demo"
+NUM_NODES=1
+NUM_GPUS=8
+CUTOFF_LEN=20000
+LEARNING_RATE=6e-5
+PER_DEVICE_BATCH_SIZE=1
+GRADIENT_ACCUMULATION_STEPS=16
+
+# Output directory
+model_saved_name="demo-qwen25vl-${MODEL_SIZE}-len_${CUTOFF_LEN}-lr_${LEARNING_RATE}-data_${DATA_NAME}"
+OUTPUT_DIR="./output/${MODEL_SIZE}/${model_saved_name}"
+tensorboard_dir="${OUTPUT_DIR}/runs_${MODEL_SIZE}/${model_saved_name}"
+
+# Create directories
+mkdir -p ${OUTPUT_DIR}
+mkdir -p ${tensorboard_dir}
+
+# SFT Training
+deepspeed --hostfile=/etc/mpi.host src/train.py \
+    --stage sft \
+    --do_train \
+    --model_name_or_path ${MODEL_PATH} \
+    --dataset ${DATA_NAME} \
+    --dataset_dir ./data \
+    --template qwen2_vl \
+    --finetuning_type full \
+    --freeze_vision_tower True \
+    --train_mm_proj_only False \
+    --image_resolution 1048576 \
+    --video_resolution 16384 \
+    --output_dir ${OUTPUT_DIR} \
+    --overwrite_cache \
+    --overwrite_output_dir True \
+    --cutoff_len ${CUTOFF_LEN} \
+    --preprocessing_num_workers 128 \
+    --per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \
+    --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} \
+    --learning_rate ${LEARNING_RATE} \
+    --lr_scheduler_type cosine_with_min_lr \
+    --lr_scheduler_kwargs "{\"min_lr_rate\": 0.1}" \
+    --num_train_epochs 1 \
+    --warmup_ratio 0.05 \
+    --logging_steps 1 \
+    --logging_dir "./output/runs_${MODEL_SIZE}/${model_saved_name}" \
+    --save_strategy epoch \
+    --plot_loss True \
+    --deepspeed examples/deepspeed/ds_z2_config.json \
+    --use_unsloth_gc True \
+    --bf16 \
+    --flash_attn fa2 \
+    --sequence_parallel_size 8 \
+    --ddp_timeout 180000000 \
+    --report_to tensorboard

data/dataset_info.json

Lines changed: 27 additions & 0 deletions
@@ -624,5 +624,32 @@
       "prompt": "content"
     },
     "folder": "python"
+  },
+  "sft-vl-demo": {
+    "file_name": "sft-vl-demo/train.jsonl",
+    "formatting": "sharegpt",
+    "columns": {
+      "messages": "conversations",
+      "images": "images",
+      "videos": "videos"
+    },
+    "tags": {
+      "role_tag": "from",
+      "content_tag": "value",
+      "user_tag": "human",
+      "assistant_tag": "assistant"
+    }
+  },
+  "dpo-vl-demo": {
+    "file_name": "dpo-vl-demo/train.jsonl",
+    "ranking": true,
+    "formatting": "sharegpt",
+    "columns": {
+      "messages": "conversations",
+      "chosen": "chosen",
+      "rejected": "rejected",
+      "images": "images",
+      "videos": "videos"
+    }
+  }
   }
 }
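For reference, a record in each demo file would look roughly like the following. This is a hypothetical sketch derived only from the column and tag mappings above (the actual demo data is not shown in this diff), written as a small Python script that emits one JSONL line per record:

import json

# Hypothetical sft-vl-demo/train.jsonl record: sharegpt formatting with
# from/value role tags, per the "columns" and "tags" mappings above.
sft_record = {
    "conversations": [
        {"from": "human", "value": "<image>What is shown in this picture?"},
        {"from": "assistant", "value": "A short description of the image."},
    ],
    "images": ["sft-vl-demo/images/0001.jpg"],
    "videos": [],
}

# Hypothetical dpo-vl-demo/train.jsonl record: "ranking": true pairs a shared
# prompt with chosen/rejected responses. The dpo entry defines no "tags", so
# the loader's default role tags apply; "assistant" here is illustrative.
dpo_record = {
    "conversations": [{"from": "human", "value": "<image>Describe the scene."}],
    "chosen": {"from": "assistant", "value": "A grounded, detailed description."},
    "rejected": {"from": "assistant", "value": "A vague or off-topic answer."},
    "images": ["dpo-vl-demo/images/0001.jpg"],
    "videos": [],
}

for record in (sft_record, dpo_record):
    print(json.dumps(record, ensure_ascii=False))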

src/llamafactory/data/collator.py

Lines changed: 2 additions & 0 deletions
@@ -154,6 +154,8 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
                 if self.require_position_ids:
                     # if required, it will be padded to cutoff_len in preprocessing
                     target_feature["position_ids"] = feature[f"{key}_position_ids"]
+                if "image_position_maps" in feature:
+                    target_feature["image_position_maps"] = feature["image_position_maps"]
                 concatenated_features.append(target_feature)

         return super().__call__(concatenated_features)
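What the new branch does: after sp_split, each pairwise feature can carry an image_position_maps entry alongside its chosen_*/rejected_* fields, and the collator now forwards that entry into every concatenated feature. A minimal sketch with hypothetical feature dicts (the real collator also pads and tensorizes):

# One pairwise feature as it might arrive from preprocessing; values are toy.
feature = {
    "chosen_input_ids": [1, 2, 3],
    "rejected_input_ids": [1, 2, 4],
    "image_position_maps": [-1, 0, 1],  # -1 = text token, n = n-th image token
}

concatenated_features = []
for key in ("chosen", "rejected"):
    target_feature = {"input_ids": feature[f"{key}_input_ids"]}
    if "image_position_maps" in feature:
        target_feature["image_position_maps"] = feature["image_position_maps"]
    concatenated_features.append(target_feature)

print(concatenated_features)  # both halves carry the same image map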

src/llamafactory/data/preprocess.py

Lines changed: 2 additions & 2 deletions
@@ -119,9 +119,9 @@ def get_sequence_parallel_preprocess(
     tokenizer: "PreTrainedTokenizer",
 ) -> Tuple[Callable, Callable]:
     if stage == "pad":
-        preprocess_func = partial(pad_sequence, data_args=data_args, tokenizer=tokenizer)
+        preprocess_func = partial(pad_sequence, data_args=data_args, tokenizer=tokenizer, model_args=model_args)
     elif stage == "split":
-        preprocess_func = partial(sp_split, model_args=model_args)
+        preprocess_func = partial(sp_split, model_args=model_args, tokenizer=tokenizer)
     else:
         raise NotImplementedError(f"Unexpected stage in sequence_parallel_preprocess: {stage}")

src/llamafactory/data/processors/sequence_parallel.py

Lines changed: 58 additions & 8 deletions
@@ -2,12 +2,28 @@
 from ..data_utils import preprocess_sp_dataset


-def pad_sequence(examples, data_args, tokenizer):
+def get_max_lengths(examples):
+    valid_lists = []
+    for key, value in examples.items():
+        if key.endswith('input_ids') and value is not None:
+            valid_lists.append(value)
+
+    if not valid_lists:
+        return []
+
+    max_lengths = [max(len(lst) if lst is not None else 0 for lst in group)
+                   for group in zip(*valid_lists)]
+
+    return max_lengths
+
+
+def pad_sequence(examples, data_args, tokenizer, model_args):
     max_length = data_args.cutoff_len
     input_pad_token_id = tokenizer.pad_token_id
     assert data_args.ignore_pad_token_for_loss
     label_pad_token_id = IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id

+    max_input_ids_length_list = get_max_lengths(examples)
     for k, v in examples.items():
         if k.endswith("input_ids"):
             pad_token_id = input_pad_token_id
@@ -25,25 +41,59 @@ def pad_sequence(examples, data_args, tokenizer):
             continue  # TODO: haven't tested multi-modal yet
         else:
             raise NotImplementedError(f"Unexpected dataset key: {k}")
+
         for i in range(len(v)):
-            v[i].extend([pad_token_id] * (max_length - len(v[i])))
+            tmp_sp_len = max_input_ids_length_list[i] // model_args.sequence_parallel_size
+            closest_cutoff_len = int(tmp_sp_len + (8 - tmp_sp_len % 8)) * model_args.sequence_parallel_size
+            max_length = min(closest_cutoff_len, data_args.cutoff_len)
+
+            v[i].extend([pad_token_id] * (max_length - len(v[i])))
         examples[k] = v

     return examples


+def create_image_position_info(seq_ids, image_token_id):
+    """Create image-position info for the whole sequence."""
+    info = []
+    global_image_pos = 0  # globally consecutive image-position counter
+
+    for token_id in seq_ids:
+        if token_id == image_token_id:
+            info.append(global_image_pos)
+            global_image_pos += 1
+        else:
+            info.append(-1)
+    return info
+
+
 # sp for Sequence Parallel
-def sp_split(examples, model_args):
+def sp_split(examples, model_args, tokenizer):
+    all_image_position_maps = list()
+    new_examples = dict()
+
     for k, v in examples.items():
         chunks = list()
         for row in v:
-            if k.endswith("attention_mask"):
-                chunks.extend([row] * model_args.sequence_parallel_size)
-            elif row is None:
+            if row is None:
                 chunks.extend([None] * model_args.sequence_parallel_size)
+            elif k in ['images']:
+                chunks.extend([row] * model_args.sequence_parallel_size)
             else:
                 chunks.extend(
                     preprocess_sp_dataset(row, model_args.sequence_parallel_size, model_args.sequence_parallel_mode)
                 )
-        examples[k] = chunks
-    return examples
+            if k.endswith("input_ids") and len(all_image_position_maps) < (len(v) * model_args.sequence_parallel_size):
+                image_position_info = create_image_position_info(row, tokenizer.image_token_id)
+                all_image_position_maps.extend(
+                    preprocess_sp_dataset(image_position_info, model_args.sequence_parallel_size, model_args.sequence_parallel_mode)
+                )
        new_examples[k] = chunks
+
+    if len(all_image_position_maps) > 0:
+        new_examples['image_position_maps'] = all_image_position_maps
+        for index in range(len(new_examples['images'])):
+            if all(image_position == -1 for image_position in new_examples['image_position_maps'][index]):
+                new_examples['images'][index] = None
+
+    return new_examples
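Two details above are worth unpacking. First, pad_sequence no longer pads every sample to cutoff_len: for each batch position it pads to the smallest length that keeps every sequence-parallel shard a multiple of 8 (as written, the rounding adds a full 8 even when the shard length is already aligned), capped at cutoff_len. Second, create_image_position_info tags every image placeholder token with a running index, so after sp_split each shard knows which slice of the image features it owns; shards whose map is all -1 get their images entry set to None. A standalone sketch with toy values:

def closest_padded_len(max_len, sp_size, cutoff_len):
    # Mirrors the arithmetic in pad_sequence above.
    tmp_sp_len = max_len // sp_size
    closest = int(tmp_sp_len + (8 - tmp_sp_len % 8)) * sp_size
    return min(closest, cutoff_len)

print(closest_padded_len(1000, 8, 20000))  # 1024 -> 128 tokens per shard
print(closest_padded_len(1024, 8, 20000))  # 1088 -> rounds up even when already aligned

def create_image_position_info(seq_ids, image_token_id):
    info, pos = [], 0
    for token_id in seq_ids:
        if token_id == image_token_id:
            info.append(pos)
            pos += 1
        else:
            info.append(-1)
    return info

IMG = 151655  # illustrative image token id
print(create_image_position_info([1, IMG, IMG, 2, IMG], IMG))  # [-1, 0, 1, -1, 2]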

src/llamafactory/hparams/parser.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@
 logger = logging.get_logger(__name__)


-check_dependencies()
+# check_dependencies()


 _TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]

src/llamafactory/model/loader.py

Lines changed: 7 additions & 2 deletions
@@ -87,6 +87,9 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     except Exception as e:
         raise OSError("Failed to load tokenizer.") from e

+    if hasattr(config, 'image_token_id'):
+        tokenizer.image_token_id = config.image_token_id
+
     if model_args.new_special_tokens is not None:
         num_added_tokens = tokenizer.add_special_tokens(
             dict(additional_special_tokens=model_args.new_special_tokens),
@@ -144,7 +147,7 @@
         config.attention_dropout = 0.0

     apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
-    sequence_parallel_group = apply_sequence_parallel(model_args, full_determinism)  # monkey patching, similar to liger_kernel
+    sequence_parallel_group = apply_sequence_parallel(model_args, config, full_determinism)  # monkey patching, similar to liger_kernel

     model = None
     lazy_load = False
@@ -157,7 +160,9 @@
     if model is None and not lazy_load:
         init_kwargs["config"] = config
         init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
-        if sequence_parallel_group is not None and is_transformers_version_greater_than("4.51.0"):
+        if (sequence_parallel_group is not None
+                and is_transformers_version_greater_than("4.51.0")
+                and config.model_type not in ['qwen2_vl', 'qwen2_5_vl']):
             init_kwargs["attn_implementation"] = "sequence_parallel_attention"

     if model_args.mixture_of_depths == "load":
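The widened check keeps Qwen2-VL and Qwen2.5-VL away from the generic sequence_parallel_attention implementation, since those models instead receive patched forwards (see the multimodal_forwards module below). A condensed, hypothetical restatement of the gate:

VL_MODEL_TYPES = {"qwen2_vl", "qwen2_5_vl"}  # handled via patched forwards instead

def wants_sp_attention(sp_group_active, transformers_new_enough, model_type):
    # Hypothetical helper mirroring the condition in load_model(); not in the repo.
    return sp_group_active and transformers_new_enough and model_type not in VL_MODEL_TYPES

assert wants_sp_attention(True, True, "qwen2") is True
assert wants_sp_attention(True, True, "qwen2_5_vl") is False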

src/llamafactory/model/model_utils/checkpointing.py

Lines changed: 3 additions & 1 deletion
@@ -68,7 +68,9 @@ def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
         hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
         hidden_states.requires_grad_(True)
         with torch.enable_grad():
-            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+            # (output,) = ctx.forward_function(hidden_states, *ctx.args)
+            outputs = ctx.forward_function(hidden_states, *ctx.args)
+            output = outputs[0] if isinstance(outputs, tuple) else outputs

         torch.autograd.backward(output, grad_output)
         return (None, hidden_states.grad) + (None,) * len(ctx.args)
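This matters because the patched VL forwards may return a tuple rather than a single tensor, which broke the old one-element unpacking in the unsloth-style checkpointing backward. A minimal standalone sketch of the normalization, with a hypothetical forward function:

import torch

def forward_fn(hidden_states):
    # Hypothetical checkpointed forward that returns extra outputs.
    return hidden_states * 2, {"aux": hidden_states.sum()}

hidden_states = torch.randn(4, requires_grad=True)
outputs = forward_fn(hidden_states)
# "(output,) = outputs" fails whenever the tuple has more than one element;
# the isinstance check accepts both tuple and bare-tensor returns.
output = outputs[0] if isinstance(outputs, tuple) else outputs
torch.autograd.backward(output, torch.ones_like(output))
print(hidden_states.grad)  # tensor([2., 2., 2., 2.])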
src/llamafactory/model/multimodal_forwards/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+#!/usr/bin/python
+# -*- coding: utf8 -*-
+
+
+from .qwen2_vl_forward import patched_qwen2_vl_forward
+from .qwen2_5_vl_forward import patched_qwen2_5_vl_forward
+
+__all__ = ['patched_qwen2_vl_forward', 'patched_qwen2_5_vl_forward']
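How these exports get wired in is not shown in this excerpt. A hypothetical dispatch sketch consistent with the loader change above (apply_sequence_parallel now receives config, so it can pick a patched forward by model type):

from llamafactory.model.multimodal_forwards import (
    patched_qwen2_5_vl_forward,
    patched_qwen2_vl_forward,
)

# Hypothetical mapping; the actual patching happens inside apply_sequence_parallel.
PATCHED_FORWARDS = {
    "qwen2_vl": patched_qwen2_vl_forward,
    "qwen2_5_vl": patched_qwen2_5_vl_forward,
}

def pick_patched_forward(model_type):
    """Return the sequence-parallel forward for supported VL models, else None."""
    return PATCHED_FORWARDS.get(model_type)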
