[feat] grpo latest support npu #6242

Open
Wants to merge 33 commits into base branch grpo-latest-npu
Changes shown from 2 commits

Commits (33)
ef1c1cb
fix inference rebatching bug
YeAnbang Feb 20, 2025
01f84de
fix num_train_step update
YeAnbang Feb 20, 2025
bc66524
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2025
ccc512a
[misc] update torch version (#6206)
ver217 Feb 24, 2025
bcb5b60
[hotfix] fix lora load (#6231)
ver217 Mar 1, 2025
a7e3bec
[release] update version (#6236)
ver217 Mar 3, 2025
265d430
[feat] add ops test to adapt npu
duanjunwen Mar 11, 2025
e61bb0a
[feat] test loss func & assert close
duanjunwen Mar 11, 2025
704866a
detach
Mar 11, 2025
bc6e14a
[feat] support compare tools on npu
duanjunwen Mar 11, 2025
6930c7c
[fix] fix qwen policy, now use gather output as logits
duanjunwen Mar 12, 2025
4b24a03
[fix] fix qwen lmhead, now gather output for logints
duanjunwen Mar 12, 2025
05ca507
[feat] fix qwen Linear_Col --> VocalHead
duanjunwen Mar 13, 2025
2305f93
[fix] fix
duanjunwen Mar 13, 2025
03ce3c5
[fix] fix qwen VocabParallelLMHead1D and gather output
duanjunwen Mar 13, 2025
131eece
fix tp bug
Mar 13, 2025
afddfde
fix consumer
Mar 13, 2025
b835d1b
fix tp bug
Mar 13, 2025
137ec17
fix consumer
Mar 13, 2025
a9cf3aa
Merge branch 'grpo-latest' into grpo-latest-npu
duanjunwen Mar 13, 2025
4702d57
convert to 8 generation
Mar 13, 2025
45ac6c6
print results
Mar 13, 2025
57b49da
setup update
Mar 13, 2025
bc0171d
fix transformers backend
YeAnbang Mar 14, 2025
7b3c310
Merge branch 'hpcaitech:grpo-latest' into grpo-latest
duanjunwen Mar 17, 2025
dcf3f9b
[fix] fix qwen VocabParallelLMHead1D and gather output
duanjunwen Mar 13, 2025
d90bf57
Merge branch 'grpo-latest' of github.com:duanjunwen/ColossalAI into g…
duanjunwen Mar 18, 2025
a53d4cd
Merge branch 'grpo-latest' into grpo-latest-npu
duanjunwen Mar 18, 2025
7795d4c
[Feature] Support Distributed LogProb for GRPO Training (#6247)
duanjunwen Mar 18, 2025
283a479
Merge branch 'hpcaitech:grpo-latest' into grpo-latest
duanjunwen Mar 19, 2025
4712ecc
Merge branch 'grpo-latest' of github.com:duanjunwen/ColossalAI into g…
duanjunwen Mar 19, 2025
d3fd485
Merge branch 'grpo-latest' into grpo-latest-npu
duanjunwen Mar 19, 2025
3f4818c
[feat] support hybrid test
duanjunwen Mar 19, 2025
2 changes: 0 additions & 2 deletions applications/ColossalChat/coati/distributed/consumer.py
@@ -73,8 +73,6 @@ def setup(self) -> None:
)
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
plugin_config["microbatch_size"] = self.microbatch_size
if self.plugin_config.get("tp_size", 1) > 1:
plugin_config["parallel_output"] = False
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)
self.booster = Booster(plugin=self.plugin)
Changes to a second file (file header not captured in this view):
@@ -120,14 +120,18 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
action_log_probs = calc_action_log_probs(
policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
)

with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
reference_action_log_probs = calc_action_log_probs(
reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
)

per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
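The two log-prob tensors computed above feed into a per-token KL term. As background only (a sketch, not a reconstruction of the lines the diff view truncates after this hunk), the low-variance "k3" estimator commonly used for this in GRPO-style training looks like:

```python
import torch

# Background sketch only: the "k3" KL estimator, exp(delta) - delta - 1 with
# delta = log pi_ref(a_t) - log pi_theta(a_t). It is non-negative and its
# expectation under pi_theta equals KL(pi_theta || pi_ref).
torch.manual_seed(0)
actions = torch.zeros(2, 1, dtype=torch.long)
logp = torch.log_softmax(torch.randn(2, 5), dim=-1).gather(-1, actions)
ref_logp = torch.log_softmax(torch.randn(2, 5), dim=-1).gather(-1, actions)
delta = ref_logp - logp
per_token_kl = torch.exp(delta) - delta - 1
print(bool((per_token_kl >= 0).all()))  # True: the estimator is non-negative
```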
20 changes: 17 additions & 3 deletions applications/ColossalChat/coati/distributed/utils.py
@@ -2,6 +2,8 @@

import torch

from colossalai.shardformer.layer.loss import dist_log_prob


def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
batches = []
@@ -66,18 +68,30 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
return per_label_logps.squeeze(-1)


def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
def calc_action_log_probs(
logits: torch.Tensor,
sequences: torch.LongTensor,
num_actions: int,
shard_config,
vocab_size: int = None,
) -> torch.Tensor:
"""Calculate action log probs.

Args:
output (torch.Tensor): Output tensor of Actor.forward.logits.
logits (torch.Tensor): Output tensor of Actor.forward.logits.
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.
shard_config (ShardConfig): Shard configuration, used to choose between the tensor-parallel and the plain log-softmax code paths.
vocab_size (int, optional): Global vocabulary size; if None, it is inferred from the sharded logits.

Returns:
torch.Tensor: Action log probs.
"""
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
# dist_log_prob expects (labels, logits): labels [B, S], logits [B, S, vocab_size] (or the local vocab shard under TP)
log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
log_probs = log_probs.squeeze(-1)
return log_probs[:, -num_actions:]
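For readers who want the semantics without the sharded machinery, here is a minimal, non-distributed sketch of what calc_action_log_probs now computes; the helper name and shapes below are illustrative, not part of the diff, and dist_log_prob reduces to this path when the logits are already gathered:

```python
import torch
import torch.nn.functional as F

def calc_action_log_probs_reference(logits, sequences, num_actions):
    # Position t predicts token t+1, hence the one-token shift on both tensors.
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    per_token = log_probs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
    # Keep only the trailing generated (action) tokens.
    return per_token[:, -num_actions:]

logits = torch.randn(2, 10, 32)          # [batch, seq, vocab]
sequences = torch.randint(32, (2, 10))   # [batch, seq]
print(calc_action_log_probs_reference(logits, sequences, num_actions=4).shape)  # torch.Size([2, 4])
```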


4 changes: 3 additions & 1 deletion colossalai/shardformer/layer/__init__.py
@@ -3,7 +3,7 @@
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d, dist_cross_entropy
from .loss import cross_entropy_1d, dist_cross_entropy, dist_log_prob, dist_log_prob_1d
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import (
@@ -28,6 +28,8 @@
"DropoutForReplicatedInput",
"cross_entropy_1d",
"dist_cross_entropy",
"dist_log_prob_1d",
"dist_log_prob",
"BaseLayerNorm",
"LayerNorm",
"RMSNorm",
150 changes: 149 additions & 1 deletion colossalai/shardformer/layer/loss.py
@@ -3,13 +3,21 @@
from torch.autograd import Function
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from torch.nn.functional import log_softmax

from colossalai.shardformer.layer._operation import reduce_forward
from colossalai.shardformer.shard import ShardConfig

from .utils import is_share_sp_tp

__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
__all__ = [
"DistCrossEntropy",
"cross_entropy_1d",
"dist_cross_entropy",
"DistLogProb",
"dist_log_prob_1d",
"dist_log_prob",
]

_IGNORE_IDX = -100

@@ -137,6 +145,98 @@ def backward(ctx, grad_output):
return grad_logits, None, None, None, None, None, None


class DistLogProb(Function):
r"""
Override forward and backward to compute log probabilities directly on vocab-sharded logits, before any gather.

Args:
Function (:class:`torch.autograd.Function`): default
"""

@staticmethod
def forward(
ctx,
vocab_logits: torch.Tensor,
target: torch.Tensor,
process_group: ProcessGroup,
vocab_size: int,
dtype=torch.float32,
):

##################
# Step 1: Find the global maximum value of the logits
##################
logits_max = torch.max(vocab_logits, dim=-1)[0]
handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True)

##################
# Step 2: Build the local mask; it will be used to select log-prob values in Step 4.
# For acceleration, we overlap Step 2 and Step 3.
##################
rank = dist.get_rank(group=process_group)
world_size = dist.get_world_size(group=process_group)
if vocab_size is None:
partition_vocab_size = vocab_logits.size()[-1]
global_vocab_size = partition_vocab_size * world_size
else:
global_vocab_size = vocab_size
partition_vocab_size = global_vocab_size // world_size
# lower and upper vocab-index bounds of this rank's logits shard
delta = (global_vocab_size + world_size - 1) // world_size
down_threshold = rank * delta
up_threshold = down_threshold + delta
if up_threshold > global_vocab_size:
up_threshold = global_vocab_size
# mask
mask = (target < down_threshold) | (target >= up_threshold)
masked_target = target.clone() - down_threshold
masked_target[mask] = 0
masked_target_1d = masked_target.view(-1).contiguous()
handle.wait()

##################
# Step 3: Compute the global sum of exp(logits)
##################
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
exp_logits = torch.exp(vocab_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) # local summation exp logits
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)

##################
# Step 4: Compute the local log probs: take log_softmax, then select values via the local mask
##################
log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1)) # cal log_softmax
log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1))
log_probs[mask.unsqueeze(-1)] = 0 # set masked val to zero
dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group)

ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits)
ctx.dtype = dtype
return log_probs

@staticmethod
def backward(ctx, grad_output):
exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors
##################
# Step 1: Recover the global softmax values
##################
softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1)

##################
# Step 2: Update the softmax values at the local target indices
##################
partition_vocab_size = softmax_logits.shape[-1]
softmax_logits_2d = softmax_logits.view(-1, partition_vocab_size)
update = 1.0 - mask.view(-1).float().to(ctx.dtype)
softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update

##################
# Step 3: Compute grad_logits, the gradient of the loss with respect to the logits
##################
grad_logits = -softmax_logits.mul_(grad_output)
return grad_logits, None, None, None, None, None, None


def cross_entropy_1d(
vocab_logits: torch.Tensor,
labels: torch.Tensor,
@@ -149,6 +249,16 @@ def cross_entropy_1d(
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)


def dist_log_prob_1d(
vocab_logits: torch.Tensor,
labels: torch.Tensor,
process_group: ProcessGroup = None,
vocab_size: int = None,
dtype: torch.dtype = None,
) -> torch.Tensor:
return DistLogProb.apply(vocab_logits, labels, process_group, vocab_size, dtype)


def dist_cross_entropy(
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
logits: torch.Tensor, # [B, S, Vocab_size]
@@ -243,3 +353,41 @@ def dist_cross_entropy(
loss, num_nonzero = loss[0], loss[1].detach()
loss = (loss / num_nonzero).squeeze()
return loss


def dist_log_prob(
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
logits: torch.Tensor, # [B, S, Vocab_size]
shard_config: ShardConfig,
vocab_size: int,
dtype: torch.dtype,
seq_dim: int = 1,
) -> torch.Tensor:
"""
Helper to compute log prob for most shardformer models supporting PP, TP.
"""
# Split labels if not gather output
parallel_output = shard_config.parallel_output
is_tp = shard_config.enable_tensor_parallelism

# TODO: support sequence parallelism (SP)
labels = labels[..., 1:]
logits = logits[..., :-1, :]
labels = labels.contiguous()
logits = logits.contiguous()
assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"

# Flatten the tokens
if is_tp and parallel_output:
log_prob = dist_log_prob_1d(
logits,
labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=vocab_size,
dtype=dtype,
)
else:
log_prob = log_softmax(logits, dim=-1)
log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))

return log_prob
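To make the collective pattern in DistLogProb.forward concrete, here is a minimal single-process sketch (not part of the PR) that mimics the max all-reduce, the sum-of-exponentials all-reduce, and the masked gather-plus-sum on an artificially chunked vocabulary, and checks the result against a plain log_softmax:

```python
import torch
import torch.nn.functional as F

# Single-process simulation of DistLogProb.forward: each "rank" owns one vocab
# shard; the two all-reduces are replaced by explicit max/sum over shards.
torch.manual_seed(0)
B, S, V, world_size = 2, 4, 8, 2
full_logits = torch.randn(B, S, V)
targets = torch.randint(V, (B, S))
shards = full_logits.chunk(world_size, dim=-1)

# Step 1: global max of the logits (all-reduce MAX in the real code).
global_max = torch.stack([s.max(dim=-1)[0] for s in shards]).max(dim=0)[0]
# Step 3: global sum of exp(logits - max) (all-reduce SUM in the real code).
sum_exp = sum(torch.exp(s - global_max.unsqueeze(-1)).sum(dim=-1) for s in shards)

log_probs = torch.zeros(B, S)
delta = (V + world_size - 1) // world_size   # shard width, as in the forward pass
for rank, shard in enumerate(shards):
    lo, hi = rank * delta, min((rank + 1) * delta, V)
    mask = (targets < lo) | (targets >= hi)              # targets owned by other ranks
    local_target = (targets - lo).masked_fill(mask, 0)
    local_logp = shard - global_max.unsqueeze(-1) - torch.log(sum_exp).unsqueeze(-1)
    picked = local_logp.gather(-1, local_target.unsqueeze(-1)).squeeze(-1)
    log_probs += picked.masked_fill(mask, 0.0)           # all-reduce SUM over ranks

expected = F.log_softmax(full_logits, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
print("sharded log probs match:", torch.allclose(log_probs, expected, atol=1e-5))
```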
1 change: 0 additions & 1 deletion colossalai/shardformer/modeling/qwen2.py
@@ -832,7 +832,6 @@ def forward(
loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)

if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
52 changes: 52 additions & 0 deletions tests/test_shardformer/test_layer/test_dist_log_prob.py
@@ -0,0 +1,52 @@
import pytest
import torch
from coati.distributed.utils import log_probs_from_logits

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer import dist_log_prob_1d
from colossalai.testing import rerun_if_address_is_in_use, spawn

CONFIG = dict(
parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")),
)


def check_dist_log_prob(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")

# prepare data
pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
labels = torch.randint(8, (2, 4)).cuda()

logprob = log_probs_from_logits(pred, labels)

pred.retain_grad()
logprob.mean().backward()

dist_pred = pred.clone().chunk(world_size, -1)[rank].detach()
dist_pred.requires_grad = True
dist_logprob = dist_log_prob_1d(dist_pred, labels)

dist_pred.retain_grad()
dist_logprob.squeeze(-1).mean().backward()

assert torch.allclose(
logprob, dist_logprob.squeeze(-1), atol=1e-5
), f"dist cross entropy logprob is not equal to orgin logprob\n{logprob}\n{dist_logprob.squeeze(-1)}"

pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach()
assert torch.allclose(
pred_grad_partial, dist_pred.grad
), f"dist grad is not equal to orgin grad\n{pred.grad}\n{dist_pred.grad}"


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_log_prob():
spawn(check_dist_log_prob, 2)


if __name__ == "__main__":
test_dist_log_prob()