Yes, I will share a minimal reproducible script.
from functools import partial
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import colossalai
from colossalai.booster import Booster
from colossalai.nn.optimizer import HybridAdam
import datasets
from typing import Callable, Optional
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from tqdm import tqdm
from transformers import AutoConfig, GPT2ForSequenceClassification, get_linear_schedule_with_warmup
from transformers import AutoTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
# --- Training configuration ---
TEST_GEMINI = False  # not referenced anywhere in this script; leftover toggle
NUM_EPOCHS = 1  # number of passes over the training set
BATCH_SIZE = 32  # per-rank batch size fed to the dataloader
LEARNING_RATE = 2.4e-5  # base learning rate; scaled by world size in main()
WEIGHT_DECAY = 0.01  # applied to all params except bias / LayerNorm weights
WARMUP_FRACTION = 0.1  # fraction of total steps used for linear LR warmup
def tokenize_batch(batch, tokenizer: Optional["AutoTokenizer"] = None, max_length: int = 512):
    """Collate a list of MRPC samples into a dict of model-ready tensors.

    Args:
        batch: list of dicts, each with "sentence1", "sentence2" and "label".
        tokenizer: a HuggingFace tokenizer. The None default exists only so
            this function can be bound with functools.partial; it must be
            supplied at call time.
        max_length: every sequence is padded/truncated to this length.

    Returns:
        dict containing the tokenizer outputs (input_ids, attention_mask, ...)
        plus a long-dtype "labels" tensor.

    Raises:
        ValueError: if no tokenizer was supplied.
    """
    if tokenizer is None:
        raise ValueError("tokenize_batch requires a tokenizer; bind one with functools.partial")
    # Sentence-pair classification: concatenate the two sentences with a space.
    texts = [sample["sentence1"] + " " + sample["sentence2"] for sample in batch]
    labels = torch.tensor([sample["label"] for sample in batch], dtype=torch.long)
    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    # Materialize a plain dict (BatchEncoding -> dict) and attach the labels.
    data = dict(encoded)
    data["labels"] = labels
    return data
def move_to_cuda(data):
    """Recursively move every tensor inside a nested container to the GPU.

    Handles tensors, dicts, lists and tuples; container structure and type
    are preserved. Any other object is returned unchanged. (The original
    version left tuples untouched, so tensors nested inside tuples were
    silently kept on CPU.)
    """
    if isinstance(data, torch.Tensor):
        return data.cuda()
    if isinstance(data, dict):
        return {k: move_to_cuda(v) for k, v in data.items()}
    if isinstance(data, list):
        return [move_to_cuda(v) for v in data]
    if isinstance(data, tuple):
        # Recurse into tuples as well, preserving the tuple type.
        return tuple(move_to_cuda(v) for v in data)
    return data
def train_epoch(
    epoch: int,
    model: nn.Module,
    optimizer: Optimizer,
    _criterion: Callable,
    lr_scheduler: LRScheduler,
    train_dataloader: DataLoader,
    booster: Booster,
    coordinator: DistCoordinator,
):
    """Run a single training epoch.

    Works in two modes: when pipeline parallelism is active the booster
    drives the schedule via execute_pipeline; otherwise a plain forward /
    backward / step loop is used on each rank.
    """
    pipeline_enabled = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
    on_last_pp_stage = pipeline_enabled and booster.plugin.stage_manager.is_last_stage()
    # Exactly one rank renders the progress bar: the last pipeline stage when
    # pipeline parallelism is on, otherwise the coordinator's master rank.
    show_progress = (pipeline_enabled and on_last_pp_stage) or (not pipeline_enabled and coordinator.is_master())
    num_steps = len(train_dataloader)
    model.train()
    optimizer.zero_grad()
    batch_iter = iter(train_dataloader)
    with tqdm(
        range(num_steps),
        desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]",
        disable=not show_progress,
    ) as progress:
        for step in progress:
            if pipeline_enabled:
                # The booster pulls micro-batches from the iterator itself.
                outputs = booster.execute_pipeline(
                    batch_iter, model, _criterion, optimizer, return_loss=True
                )
                if on_last_pp_stage:
                    # The aggregated loss only materializes on the last stage.
                    progress.set_postfix({"loss": outputs["loss"].item()})
            else:
                batch = move_to_cuda(next(batch_iter))
                loss = _criterion(model(**batch), None)
                booster.backward(loss, optimizer)
                progress.set_postfix({"loss": loss.item()})
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()
            # Repro script only: stop after ~21 steps to fail (or pass) fast.
            if step >= 20:
                print(f"[Rank {coordinator.rank}] Early stop at step {step + 1}")
                break
def main():
    """Launch distributed GPT-2 fine-tuning on GLUE/MRPC with ColossalAI.

    Initializes the distributed environment, builds the HybridParallelPlugin
    dataloader/model/optimizer/scheduler, boosts them, and trains.
    """
    colossalai.launch_from_torch(seed=42)
    coordinator = DistCoordinator()
    plugin = HybridParallelPlugin(
        tp_size=1,
        pp_size=1,
        sp_size=1,
        enable_all_optimization=True,
    )
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token
    dataset = datasets.load_dataset("glue", "mrpc")
    train_dataloader = plugin.prepare_dataloader(
        dataset["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_length=512),
    )
    config = AutoConfig.from_pretrained("gpt2", num_labels=2)
    config.pad_token_id = tokenizer.pad_token_id
    model = GPT2ForSequenceClassification.from_pretrained("gpt2", config=config).cuda()
    # Linear LR scaling with the number of data-parallel ranks.
    lr = LEARNING_RATE * coordinator.world_size
    # Standard transformer fine-tuning split: no weight decay on biases/LayerNorm.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": WEIGHT_DECAY,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)
    total_steps = len(train_dataloader) * NUM_EPOCHS
    num_warmup_steps = int(WARMUP_FRACTION * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps,
    )

    def _criterion(outputs, inputs):
        # HF models compute the loss internally when "labels" is passed.
        return outputs.loss

    # BUG FIX: the original script constructed a second HybridAdam here,
    # rebinding `optimizer` to a fresh instance while `lr_scheduler` still
    # referenced the first one — scheduler and optimizer would disagree.
    # The duplicate construction has been removed.
    booster = Booster(plugin=plugin)
    model, optimizer, _criterion, _, lr_scheduler = booster.boost(
        model, optimizer, criterion=_criterion, lr_scheduler=lr_scheduler
    )
    for epoch in range(NUM_EPOCHS):
        train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
# Entry point: the script is launched under torchrun (one process per GPU).
if __name__ == "__main__":
    main()
W1107 11:51:39.894366 2961155 site-packages/torch/distributed/run.py:793]
W1107 11:51:39.894366 2961155 site-packages/torch/distributed/run.py:793] *****************************************
W1107 11:51:39.894366 2961155 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1107 11:51:39.894366 2961155 site-packages/torch/distributed/run.py:793] *****************************************
[11/07/25 11:51:44] INFO colossalai - colossalai - INFO:
/home/yanzhen/miniconda3/envs/colossal/lib/python3.
9/site-packages/colossalai/initialize.py:75 launch
INFO colossalai - colossalai - INFO: Distributed
environment is initialized, world size: 4
Using the latest cached version of the dataset since glue couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'mrpc' at /home/yanzhen/.cache/huggingface/datasets/glue/mrpc/0.0.0/bcdcba79d07bc864c1c254ccfcedcce55bcc9a8c (last modified on Wed Oct 29 20:32:03 2025).
Using the latest cached version of the dataset since glue couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'mrpc' at /home/yanzhen/.cache/huggingface/datasets/glue/mrpc/0.0.0/bcdcba79d07bc864c1c254ccfcedcce55bcc9a8c (last modified on Wed Oct 29 20:32:03 2025).
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using the latest cached version of the dataset since glue couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'mrpc' at /home/yanzhen/.cache/huggingface/datasets/glue/mrpc/0.0.0/bcdcba79d07bc864c1c254ccfcedcce55bcc9a8c (last modified on Wed Oct 29 20:32:03 2025).
Using the latest cached version of the dataset since glue couldn't be found on the Hugging Face Hub
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Found the latest cached dataset configuration 'mrpc' at /home/yanzhen/.cache/huggingface/datasets/glue/mrpc/0.0.0/bcdcba79d07bc864c1c254ccfcedcce55bcc9a8c (last modified on Wed Oct 29 20:32:03 2025).
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch [1/1]: 0%| | 0/28 [00:00<?, ?it/s]
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/yanzhen/distributed_test/colossalAI/test/bug4.py", line 161, in <module>
[rank0]: main()
[rank0]: File "/home/yanzhen/distributed_test/colossalAI/test/bug4.py", line 157, in main
[rank0]: train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
[rank0]: File "/home/yanzhen/distributed_test/colossalAI/test/bug4.py", line 91, in train_epoch
[rank0]: outputs = model(**data)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/booster/plugin/hybrid_parallel_plugin.py", line 222, in forward
[rank0]: return super().forward(*args, **kwargs)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/interface/model.py", line 127, in forward
[rank0]: return self.module(*args, **kwargs)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1643, in forward
[rank0]: else self._run_ddp_forward(*inputs, **kwargs)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
[rank0]: return self.module(*inputs, **kwargs) # type: ignore[index]
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1375, in forward
[rank0]: transformer_outputs = self.transformer(
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 922, in forward
[rank0]: outputs = block(
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 404, in forward
[rank0]: attn_outputs = self.attn(
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/shardformer/modeling/gpt2.py", line 876, in forward
[rank0]: attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
[rank0]: TypeError: colossalai.shardformer.layer.attn.ColoAttention.attention() argument after ** must be a mapping, not Tensor
[rank1]: Traceback (most recent call last):
[rank1]: File "/home/yanzhen/distributed_test/colossalAI/test/bug4.py", line 161, in <module>
[rank1]: main()
[rank1]: File "/home/yanzhen/distributed_test/colossalAI/test/bug4.py", line 157, in main
[rank1]: train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
[rank1]: File "/home/yanzhen/distributed_test/colossalAI/test/bug4.py", line 91, in train_epoch
[rank1]: outputs = model(**data)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/booster/plugin/hybrid_parallel_plugin.py", line 222, in forward
[rank1]: return super().forward(*args, **kwargs)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/interface/model.py", line 127, in forward
[rank1]: return self.module(*args, **kwargs)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1643, in forward
[rank1]: else self._run_ddp_forward(*inputs, **kwargs)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
[rank1]: return self.module(*inputs, **kwargs) # type: ignore[index]
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1375, in forward
[rank1]: transformer_outputs = self.transformer(
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 922, in forward
[rank1]: outputs = block(
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 404, in forward
[rank1]: attn_outputs = self.attn(
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/shardformer/modeling/gpt2.py", line 876, in forward
[rank1]: attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
[rank1]: TypeError: colossalai.shardformer.layer.attn.ColoAttention.attention() argument after ** must be a mapping, not Tensor
[rank2]: Traceback (most recent call last):
[rank2]: File "/home/yanzhen/distributed_test/colossalAI/test/bug4.py", line 161, in <module>
[rank2]: main()
[rank2]: File "/home/yanzhen/distributed_test/colossalAI/test/bug4.py", line 157, in main
[rank2]: train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
[rank2]: File "/home/yanzhen/distributed_test/colossalAI/test/bug4.py", line 91, in train_epoch
[rank2]: outputs = model(**data)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/booster/plugin/hybrid_parallel_plugin.py", line 222, in forward
[rank2]: return super().forward(*args, **kwargs)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/interface/model.py", line 127, in forward
[rank2]: return self.module(*args, **kwargs)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1643, in forward
[rank2]: else self._run_ddp_forward(*inputs, **kwargs)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
[rank2]: return self.module(*inputs, **kwargs) # type: ignore[index]
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1375, in forward
[rank2]: transformer_outputs = self.transformer(
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 922, in forward
[rank2]: outputs = block(
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 404, in forward
[rank2]: attn_outputs = self.attn(
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/shardformer/modeling/gpt2.py", line 876, in forward
[rank2]: attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
[rank2]: TypeError: colossalai.shardformer.layer.attn.ColoAttention.attention() argument after ** must be a mapping, not Tensor
[rank3]: Traceback (most recent call last):
[rank3]: File "/home/yanzhen/distributed_test/colossalAI/test/bug4.py", line 161, in <module>
[rank3]: main()
[rank3]: File "/home/yanzhen/distributed_test/colossalAI/test/bug4.py", line 157, in main
[rank3]: train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
[rank3]: File "/home/yanzhen/distributed_test/colossalAI/test/bug4.py", line 91, in train_epoch
[rank3]: outputs = model(**data)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank3]: return self._call_impl(*args, **kwargs)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank3]: return forward_call(*args, **kwargs)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/booster/plugin/hybrid_parallel_plugin.py", line 222, in forward
[rank3]: return super().forward(*args, **kwargs)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/interface/model.py", line 127, in forward
[rank3]: return self.module(*args, **kwargs)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank3]: return self._call_impl(*args, **kwargs)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank3]: return forward_call(*args, **kwargs)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1643, in forward
[rank3]: else self._run_ddp_forward(*inputs, **kwargs)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
[rank3]: return self.module(*inputs, **kwargs) # type: ignore[index]
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank3]: return self._call_impl(*args, **kwargs)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank3]: return forward_call(*args, **kwargs)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1375, in forward
[rank3]: transformer_outputs = self.transformer(
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank3]: return self._call_impl(*args, **kwargs)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank3]: return forward_call(*args, **kwargs)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 922, in forward
[rank3]: outputs = block(
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank3]: return self._call_impl(*args, **kwargs)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank3]: return forward_call(*args, **kwargs)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 404, in forward
[rank3]: attn_outputs = self.attn(
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank3]: return self._call_impl(*args, **kwargs)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank3]: return forward_call(*args, **kwargs)
[rank3]: File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/shardformer/modeling/gpt2.py", line 876, in forward
[rank3]: attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
[rank3]: TypeError: colossalai.shardformer.layer.attn.ColoAttention.attention() argument after ** must be a mapping, not Tensor
[rank0]:[W1107 11:51:47.092202066 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())
W1107 11:51:48.085337 2961155 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2961242 closing signal SIGTERM
W1107 11:51:48.087910 2961155 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2961243 closing signal SIGTERM
E1107 11:51:48.304292 2961155 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 0 (pid: 2961240) of binary: /home/yanzhen/miniconda3/envs/colossal/bin/python3.9
Traceback (most recent call last):
File "/home/yanzhen/miniconda3/envs/colossal/bin/torchrun", line 7, in <module>
sys.exit(main())
File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
return f(*args, **kwargs)
File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/distributed/run.py", line 919, in main
run(args)
File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/distributed/run.py", line 910, in run
elastic_launch(
File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
bug4.py FAILED
------------------------------------------------------------
Failures:
[1]:
time : 2025-11-07_11:51:48
host : ubuntu
rank : 1 (local_rank: 1)
exitcode : 1 (pid: 2961241)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2025-11-07_11:51:48
host : ubuntu
rank : 0 (local_rank: 0)
exitcode : 1 (pid: 2961240)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
Error: failed to run torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=127.0.0.1 --master_port=29505 bug4.py on 127.0.0.1, is localhost: True, exception: Encountered a bad command exit code!
Command: 'cd /home/yanzhen/distributed_test/colossalAI/test && export SHELL="/bin/bash" COLORTERM="truecolor" VSCODE_DEBUGPY_ADAPTER_ENDPOINTS="/home/yanzhen/.vscode-server/extensions/ms-python.debugpy-2025.14.1/.noConfigDebugAdapterEndpoints/endpoint-8ca95acfe78cb59c.txt" TERM_PROGRAM_VERSION="1.105.1" CONDA_EXE="/home/yanzhen/miniconda3/bin/conda" NCCL_P2P_DISABLE="1" LC_ADDRESS="zh_CN.UTF-8" LC_NAME="zh_CN.UTF-8" PYDEVD_DISABLE_FILE_VALIDATION="1" LC_MONETARY="zh_CN.UTF-8" PWD="/home/yanzhen/distributed_test/colossalAI/test" LOGNAME="yanzhen" XDG_SESSION_TYPE="tty" CONDA_PREFIX="/home/yanzhen/miniconda3/envs/colossal" BUNDLED_DEBUGPY_PATH="/home/yanzhen/.vscode-server/extensions/ms-python.debugpy-2025.14.1/bundled/libs/debugpy" VSCODE_GIT_ASKPASS_NODE="/home/yanzhen/.vscode-server/cli/servers/Stable-7d842fb85a0275a4a8e4d7e040d2625abbf7f084/server/node" MOTD_SHOWN="pam" HOME="/home/yanzhen" LC_PAPER="zh_CN.UTF-8" LANG="en_US.UTF-8" LS_COLORS="rs=0:di=01;34:ln=01;36:mh=00:pi=40;33:so=01;35:do=01;35:bd=40;33;01:cd=40;33;01:or=40;31;01:mi=00:su=37;41:sg=30;43:ca=30;41:tw=30;42:ow=34;42:st=37;44:ex=01;32:*.tar=01;31:*.tgz=01;31:*.arc=01;31:*.arj=01;31:*.taz=01;31:*.lha=01;31:*.lz4=01;31:*.lzh=01;31:*.lzma=01;31:*.tlz=01;31:*.txz=01;31:*.tzo=01;31:*.t7z=01;31:*.zip=01;31:*.z=01;31:*.dz=01;31:*.gz=01;31:*.lrz=01;31:*.lz=01;31:*.lzo=01;31:*.xz=01;31:*.zst=01;31:*.tzst=01;31:*.bz2=01;31:*.bz=01;31:*.tbz=01;31:*.tbz2=01;31:*.tz=01;31:*.deb=01;31:*.rpm=01;31:*.jar=01;31:*.war=01;31:*.ear=01;31:*.sar=01;31:*.rar=01;31:*.alz=01;31:*.ace=01;31:*.zoo=01;31:*.cpio=01;31:*.7z=01;31:*.rz=01;31:*.cab=01;31:*.wim=01;31:*.swm=01;31:*.dwm=01;31:*.esd=01;31:*.jpg=01;35:*.jpeg=01;35:*.mjpg=01;35:*.mjpeg=01;35:*.gif=01;35:*.bmp=01;35:*.pbm=01;35:*.pgm=01;35:*.ppm=01;35:*.tga=01;35:*.xbm=01;35:*.xpm=01;35:*.tif=01;35:*.tiff=01;35:*.png=01;35:*.svg=01;35:*.svgz=01;35:*.mng=01;35:*.pcx=01;35:*.mov=01;35:*.mpg=01;35:*.mpeg=01;35:*.m2v=01;35:*.mkv=01;35:*.webm=01;35:*.webp=01;35:*.ogm=01;35:*
.mp4=01;35:*.m4v=01;35:*.mp4v=01;35:*.vob=01;35:*.qt=01;35:*.nuv=01;35:*.wmv=01;35:*.asf=01;35:*.rm=01;35:*.rmvb=01;35:*.flc=01;35:*.avi=01;35:*.fli=01;35:*.flv=01;35:*.gl=01;35:*.dl=01;35:*.xcf=01;35:*.xwd=01;35:*.yuv=01;35:*.cgm=01;35:*.emf=01;35:*.ogv=01;35:*.ogx=01;35:*.aac=00;36:*.au=00;36:*.flac=00;36:*.m4a=00;36:*.mid=00;36:*.midi=00;36:*.mka=00;36:*.mp3=00;36:*.mpc=00;36:*.ogg=00;36:*.ra=00;36:*.wav=00;36:*.oga=00;36:*.opus=00;36:*.spx=00;36:*.xspf=00;36:" PYTHONSTARTUP="/home/yanzhen/.vscode-server/data/User/workspaceStorage/0d3e22743b5008777912953212595ae2/ms-python.python/pythonrc.py" SSL_CERT_DIR="/usr/lib/ssl/certs" CONDA_PROMPT_MODIFIER="(colossal) " GIT_ASKPASS="/home/yanzhen/.vscode-server/cli/servers/Stable-7d842fb85a0275a4a8e4d7e040d2625abbf7f084/server/extensions/git/dist/askpass.sh" SSH_CONNECTION="192.168.1.29 39642 192.168.102.133 18022" USE_MODELSCOPE_HUB="1" VSCODE_PYTHON_AUTOACTIVATE_GUARD="1" _CONDA_EXE="/home/yanzhen/miniconda3/bin/conda" LESSCLOSE="/usr/bin/lesspipe %s %s" _CONDA_ROOT="/home/yanzhen/miniconda3" XDG_SESSION_CLASS="user" TERM="xterm-256color" LC_IDENTIFICATION="zh_CN.UTF-8" PYTHON_BASIC_REPL="1" LESSOPEN="| /usr/bin/lesspipe %s" USER="yanzhen" VSCODE_GIT_IPC_HANDLE="/run/user/1006/vscode-git-760712a092.sock" CONDA_SHLVL="2" SHLVL="1" LC_TELEPHONE="zh_CN.UTF-8" LC_MEASUREMENT="zh_CN.UTF-8" XDG_SESSION_ID="6320" CONDA_PYTHON_EXE="/home/yanzhen/miniconda3/bin/python" LD_LIBRARY_PATH="/home/yanzhen/.tensornvme/lib:/usr/local/cuda-12.4/lib64:/home/yanzhen/.tensornvme/lib:/usr/local/cuda-12.4/lib64:" XDG_RUNTIME_DIR="/run/user/1006" SSL_CERT_FILE="/usr/lib/ssl/cert.pem" SSH_CLIENT="192.168.1.29 39642 18022" CONDA_DEFAULT_ENV="colossal" DEBUGINFOD_URLS="https://debuginfod.ubuntu.com " LC_TIME="zh_CN.UTF-8" VSCODE_GIT_ASKPASS_MAIN="/home/yanzhen/.vscode-server/cli/servers/Stable-7d842fb85a0275a4a8e4d7e040d2625abbf7f084/server/extensions/git/dist/askpass-main.js" CUDA_HOME="/usr/local/cuda-12.4" 
XDG_DATA_DIRS="/usr/share/gnome:/usr/local/share:/usr/share:/var/lib/snapd/desktop" BROWSER="/home/yanzhen/.vscode-server/cli/servers/Stable-7d842fb85a0275a4a8e4d7e040d2625abbf7f084/server/bin/helpers/browser.sh" PATH="/usr/local/cuda-12.4/bin:/home/yanzhen/.vscode-server/cli/servers/Stable-7d842fb85a0275a4a8e4d7e040d2625abbf7f084/server/bin/remote-cli:/home/yanzhen/.local/bin:/home/yanzhen/miniconda3/envs/colossal/bin:/home/yanzhen/miniconda3/condabin:/usr/local/cuda-12.4/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/home/yanzhen/miniconda3/bin:/home/yanzhen/.vscode-server/extensions/ms-python.debugpy-2025.14.1/bundled/scripts/noConfigScripts:/home/yanzhen/.vscode-server/data/User/globalStorage/github.copilot-chat/debugCommand:/home/yanzhen/miniconda3/bin" DBUS_SESSION_BUS_ADDRESS="unix:path=/run/user/1006/bus" CONDA_PREFIX_1="/home/yanzhen/miniconda3" LC_NUMERIC="zh_CN.UTF-8" TERM_PROGRAM="vscode" VSCODE_IPC_HOOK_CLI="/run/user/1006/vscode-ipc-d6a0f812-564d-488a-8d89-54d8df9c7838.sock" OLDPWD="/home/yanzhen/distributed_test" _="/home/yanzhen/miniconda3/envs/colossal/bin/colossalai" CUDA_DEVICE_MAX_CONNECTIONS="1" && torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=127.0.0.1 --master_port=29505 bug4.py'
Exit code: 1
Stdout: already printed
Stderr: already printed
====== Training on All Nodes =====
127.0.0.1: failure
====== Stopping All Nodes =====
127.0.0.1: finish
Is there an existing issue for this bug?
The bug has not been fixed in the latest main branch
Do you feel comfortable sharing a concise (minimal) script that reproduces the error? :)
Yes, I will share a minimal reproducible script.
🐛 Describe the bug
When I run the training example from https://colossalai.org/zh-Hans/docs/advanced_tutorials/train_gpt_using_hybrid_parallelism with the following configuration, I encounter the following error:
`TypeError: colossalai.shardformer.layer.attn.ColoAttention.attention() argument after ** must be a mapping, not Tensor`. However, when I set
`enable_all_optimization` to `False` in the configuration, the error disappears. Therefore, it seems that the issue may be caused by a bug related to the
`enable_all_optimization` option. Below is the Python training script
`main.py` that can reproduce the error. Running the following command:
Produces the following error log:
Environment