2 changes: 1 addition & 1 deletion .github/workflows/pr-cpu.yaml
@@ -23,7 +23,7 @@
matrix:
include:
- name: "cpu-2.7.0"
container: mosaicml/pytorch:2.7.0_cpu-python3.12-ubuntu22.04
container: mosaicml/dle:nightly-latest # Update after the next release of llm-foundry (mosaicml/llm-foundry:2.7.0_cpu-python3.12-ubuntu22.04)

Check failure on line 26 in .github/workflows/pr-cpu.yaml (GitHub Actions / code-quality (3.11, [dev])): 26:121 [line-length] line too long (152 > 120 characters)
markers: "not gpu and not only_release"
pip_deps: "[cpu]"
pytest_command: "coverage run -m pytest"
4 changes: 2 additions & 2 deletions .github/workflows/pr-gpu.yaml
@@ -23,7 +23,7 @@
matrix:
include:
- name: "gpu-2.7.0-1"
container: mosaicml/llm-foundry:2.7.0_cu128-latest
container: mosaicml/dle:nightly-latest # Update after the next release of llm-foundry (mosaicml/llm-foundry:2.7.0_cu128-latest)

Check failure on line 26 in .github/workflows/pr-gpu.yaml (GitHub Actions / code-quality (3.11, [dev])): 26:121 [line-length] line too long (138 > 120 characters)
markers: "gpu"
pip_deps: "[gpu]"
pytest_command: "coverage run -m pytest"
@@ -52,7 +52,7 @@
matrix:
include:
- name: "gpu-2.7.0-2"
container: mosaicml/llm-foundry:2.7.0_cu128-latest
container: mosaicml/dle:nightly-latest # Update after the next release of llm-foundry (mosaicml/llm-foundry:2.7.0_cu128-latest)

Check failure on line 55 in .github/workflows/pr-gpu.yaml (GitHub Actions / code-quality (3.11, [dev])): 55:121 [line-length] line too long (138 > 120 characters)
markers: "gpu"
pip_deps: "[gpu]"
pytest_command: "coverage run -m pytest"
204 changes: 111 additions & 93 deletions compose_rl/algorithms/online/generation_utils/vllm_utils.py
@@ -28,6 +28,7 @@
import torch
import torch.distributed
import torch.nn as nn
from composer.distributed.shared_utils import get_summon_params_fn
from composer.utils import dist
from ray.exceptions import GetTimeoutError
from ray.util.placement_group import placement_group
Expand All @@ -42,7 +43,9 @@
default_pg_timeout,
rendezvous,
)
from torch.distributed.fsdp import FSDPModule
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor import DTensor

from compose_rl.algorithms.online.generation_utils.vllm_actor import LLMRayActor
from compose_rl.algorithms.online.model_methods import (
@@ -277,33 +280,6 @@ def create_vllm_engines(
return vllm_engines


def build_param_fullnames(top_module: nn.Module) -> dict:
"""Builds a mapping of parameter objects to their fully-qualified names.

Traverses the entire model from the top level and map each parameter
object to its fully-qualified name (e.g.,
"lm_backbone.layer1.mlp.down_proj.weight").

Args:
top_module (nn.Module): The top-level module to traverse.
"""
param2fullname = {}

def _dfs(current_module: nn.Module, prefix: str = ''):
# Get local parameters (without recursing into children).
for local_name, param in current_module.named_parameters(recurse=False):
full_name = f'{prefix}.{local_name}' if prefix else local_name
param2fullname[param] = full_name

# Recurse on child modules.
for child_name, child_module in current_module.named_children():
child_prefix = f'{prefix}.{child_name}' if prefix else child_name
_dfs(child_module, prefix=child_prefix)

_dfs(top_module)
return param2fullname


def simplify_param_path(path: str) -> str:
"""Simplifies the parameter path by removing unnecessary parts.

@@ -333,15 +309,15 @@ def simplify_param_path(path: str) -> str:


def is_fsdp_leaf(module: nn.Module) -> bool:
"""Check if the module is a leaf in the FSDP hierarchy.
"""Check if the module is a leaf in the FSDP(1/2) hierarchy.

Args:
module (nn.Module): The torch module to check
"""
if not isinstance(module, FSDP):
if not isinstance(module, (FSDP, FSDPModule)):
return False
for subm in module.modules():
if subm is not module and isinstance(subm, FSDP):
if subm is not module and isinstance(subm, (FSDP, FSDPModule)):
return False
return True

@@ -377,6 +353,19 @@ def should_update_torch_module(
return False


def get_name_for_param(model: nn.Module, param: torch.Tensor) -> str:
"""Get the full name of a parameter in the model.

Args:
model (nn.Module): The model that contains the parameter
param (torch.Tensor): The parameter to get the name for
"""
for name, p in model.named_parameters():
if p is param:
return name
raise ValueError(f'Parameter {param} not found in model {model}')
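A quick standalone illustration of the identity-based reverse lookup that `get_name_for_param` performs. This is a sketch on a toy module with hypothetical names (`name_for_param`, `toy`); no FSDP or distributed setup is involved:

```python
import torch
import torch.nn as nn


def name_for_param(model: nn.Module, param: torch.Tensor) -> str:
    # Same linear scan / identity match as get_name_for_param above.
    for name, p in model.named_parameters():
        if p is param:
            return name
    raise ValueError('Parameter not found in model')


toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
print(name_for_param(toy, toy[1].bias))  # -> '1.bias'
```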


def broadcast_to_vllm(
model: nn.Module,
vllm_engines: list,
@@ -399,19 +388,21 @@ def broadcast_to_vllm(
torch.cuda.empty_cache()
if loss_type == OnPolicyEnum.PPO:
# Extract the lm_backbone params from the model
count, num_params = 0, len(
num_params = len(
list(model.model.lm_backbone.named_parameters()), # type: ignore
)
elif loss_type in ALGORITHM_TYPE.CRITIC_FREE:
# Directly use the model params
count, num_params = 0, len(
num_params = len(
list(model.model.named_parameters()), # type: ignore
)
else:
raise ValueError(
f'Unsupported loss type: {loss_type}. Supported types are: ppo, grpo',
)
count = 0

# Reset prefix caching if enabled
refss = []
cache_reset_refss = []
if enable_prefix_caching and dist.get_global_rank() == 0:
@@ -430,8 +421,6 @@
]
seen_fsdp_modules = set()
seen_updated_parsed_names = set()
count = 0
param_2_full_name = build_param_fullnames(model)

with torch.no_grad():
# Adding a dummy forwards call.
@@ -454,67 +443,96 @@
start_time = time.time()
update_time = 0

# Getting the correct summon_full_params function based on whether
# the model is FSDP1 vs FSDP2.
summon_full_params = get_summon_params_fn(model)

for module_name, module in model.named_modules():
-    if isinstance(module, FSDP):
-        # This is needed otherwise FSDP will materialize parameters of size 0.
-        # So just for the joint actor critic models we have to actually skip this module.
-        if module_name == 'model' and loss_type == OnPolicyEnum.PPO:
-            continue
-
-        # Only update if we haven't updated this module before
-        if module not in seen_fsdp_modules:
-            seen_fsdp_modules.add(module)
-
-            # Materializes parameters for this specific FSDP module
-            with FSDP.summon_full_params(
-                module,
-                writeback=False,
-                rank0_only=True,
-                recurse=False,
-            ):
-                for _, param in module.named_parameters(recurse=True):
-                    if dist.get_global_rank() == 0:
-                        full_name = param_2_full_name[param]
-                        parsed_name = simplify_param_path(full_name)
-
-                        if 'critic_head' in parsed_name:
-                            log.info('Critic head found, skipping sending')
-                            continue
-
-                        update = should_update_torch_module(
-                            parsed_name,
-                            full_name,
-                            module,
-                            loss_type,
-                            valid_non_leaf_module_names,
-                        )
-
-                        # We've already updated this module before,
-                        if parsed_name in seen_updated_parsed_names:
-                            continue
-
-                        # Usually if we have to skip a module, it's because we cannot
-                        if update:
-                            start_update_time = time.time()
-                            seen_updated_parsed_names.add(parsed_name)
-
-                            count += 1
-                            shape = param.shape
-                            refs = [
-                                engine.update_weight.remote(
-                                    parsed_name,
-                                    dtype=param.dtype,
-                                    shape=shape,
-                                    empty_cache=(count == num_params),
-                                ) for engine in vllm_engines
-                            ]
-                            refss.extend(refs)
-                            torch.distributed.broadcast(
-                                param.data,
-                                0,
-                                group=model_update_group,
-                            )
-                            update_time += time.time() - start_update_time
+    # Skip non-FSDP modules
+    if not isinstance(module, (FSDP, FSDPModule)):
+        continue
+
+    # This is needed otherwise FSDP will materialize parameters of size 0.
+    # So just for the joint actor critic models we have to actually skip this module.
+    if module_name == 'model' and loss_type == OnPolicyEnum.PPO:
+        continue
+
+    # Only update if we haven't updated this module before
+    if module in seen_fsdp_modules:
+        continue
+    seen_fsdp_modules.add(module)
+
+    # Materializes parameters for this specific FSDP module only BUT THIS
+    # INCLUDES any parameters from submodules that are not FSDP-wrapped themselves.
+    # We don't want to materialize the entire model to avoid potential OOM.
+    # View NestedFSDPModel in the Composer repo and the related test in
+    # for an example of why this logic is needed.
+    with summon_full_params(
+        module,
+        writeback=False,
+        rank0_only=False,
+        recurse=False,
+    ):
+        # Note: For the following module.named_parameters(), we have to use recurse=True
+        # since the following case is possible where we still need NonFSDP_Child's params
+        # FSDP_Module
+        # |- direct_param (found with recurse=False)
+        # |- NonFSDP_Child
+        # |  |- child_param (missed with recurse=False)
+        for _, param in module.named_parameters(recurse=True):
+            # Only distribute on rank 0
+            if not dist.get_global_rank() == 0:
+                continue
+
+            # Skip DTensor params at this level since they were not summoned
+            # and we only want to broadcast the summoned parameters.
+            # Encountering this conditional implies that a FSDP-wrapped submodule
+            # exists and will later be summoned to materialize this parameter.
+            if isinstance(param, DTensor):
+                continue
+
+            full_name = get_name_for_param(model, param)
+            parsed_name = simplify_param_path(full_name)
+
+            if parsed_name in seen_updated_parsed_names:
+                continue
+
+            if 'critic_head' in parsed_name:
+                log.info('Critic head found, skipping sending')
+                continue
+
+            update = should_update_torch_module(
+                parsed_name,
+                full_name,
+                module,
+                loss_type,
+                valid_non_leaf_module_names,
+            )
+
+            if not update:
+                continue
+
+            start_update_time = time.time()
+            seen_updated_parsed_names.add(parsed_name)
+
+            count += 1
+            shape = param.shape
+            refs = [
+                engine.update_weight.remote(
+                    parsed_name,
+                    dtype=param.dtype,
+                    shape=shape,
+                    empty_cache=(count == num_params),
+                ) for engine in vllm_engines
+            ]
+            refss.extend(refs)
+
+            torch.distributed.broadcast(
+                param.data,
+                0,
+                group=model_update_group,
+            )
+            update_time += time.time() - start_update_time

# Issue (#67): Note this code will likely need to be updated for PEFT for efficiency reasons.
if dist.get_global_rank() == 0:
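The comment in the new broadcast loop about `NonFSDP_Child` is easy to verify in isolation. The sketch below uses plain `nn.Module` stand-ins (the class names are hypothetical and there is no FSDP wrapping); it only demonstrates why `named_parameters(recurse=True)` is needed to reach parameters owned by a non-wrapped child:

```python
import torch
import torch.nn as nn


class NonFSDPChild(nn.Module):
    """Stand-in for a submodule that is not FSDP-wrapped itself."""

    def __init__(self) -> None:
        super().__init__()
        self.child_param = nn.Parameter(torch.zeros(4))


class WrappedModule(nn.Module):
    """Stand-in for the module that FSDP would wrap directly."""

    def __init__(self) -> None:
        super().__init__()
        self.direct_param = nn.Parameter(torch.zeros(4))
        self.inner = NonFSDPChild()


m = WrappedModule()
# recurse=False only yields the module's own parameters and misses the child's.
print([name for name, _ in m.named_parameters(recurse=False)])
# -> ['direct_param']
# recurse=True also yields parameters owned by non-wrapped children.
print([name for name, _ in m.named_parameters(recurse=True)])
# -> ['direct_param', 'inner.child_param']
```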
8 changes: 4 additions & 4 deletions compose_rl/algorithms/online/hf_utils.py
@@ -7,6 +7,7 @@

import torch
import torch.nn as nn
from composer.distributed.shared_utils import get_summon_params_fn
from composer.utils import is_model_fsdp
from transformers import (
AutoConfig,
@@ -93,14 +94,13 @@ def generate(
**kwargs: Any,
):
if is_model_fsdp(self.lm_backbone):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Note: We need to use the FSDP.summon_full_params context manager here because the generate function
# Note: We need to use the summon_full_params context manager here because the generate function
# does not seem to gather the weights for the LM head. This solution works because the tied weights of the LM head
# are in the root FSDP module, and are summoned by the below context manager. See https://github.com/pytorch/pytorch/issues/100069
# for more info.
# Note: We use recurse=False here so that we only summon full params for the LM head, not the entire model.
with FSDP.summon_full_params(
summon_full_params = get_summon_params_fn(self.lm_backbone)
with summon_full_params(
self.lm_backbone,
writeback=False,
recurse=False,
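Swapping `FSDP.summon_full_params` for Composer's `get_summon_params_fn` is what lets the same generate path serve both FSDP1- and FSDP2-wrapped models. The sketch below is only a rough illustration of that dispatch idea under stated assumptions (`pick_summon_fn` is a hypothetical name, and the FSDP2 branch leans on `FSDPModule.unshard()`/`reshard()`); it is not Composer's actual implementation:

```python
from contextlib import contextmanager

import torch.nn as nn
from torch.distributed.fsdp import FSDPModule
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def pick_summon_fn(model: nn.Module):
    """Hypothetical helper: return a summon-style context manager for FSDP1 or FSDP2."""
    modules = list(model.modules())
    if any(isinstance(m, FSDP) for m in modules):
        # FSDP1 already ships a summon_full_params context manager.
        return FSDP.summon_full_params
    if any(isinstance(m, FSDPModule) for m in modules):

        @contextmanager
        def _summon(module: nn.Module, **_ignored_fsdp1_kwargs):
            # FSDP2 has no summon_full_params; unshard()/reshard() gather and then
            # free the full parameters. The FSDP1-style kwargs (writeback,
            # rank0_only, recurse) are accepted but ignored in this simplified sketch.
            module.unshard()
            try:
                yield
            finally:
                module.reshard()

        return _summon
    raise ValueError('Model does not appear to be FSDP-wrapped.')
```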
8 changes: 4 additions & 4 deletions compose_rl/algorithms/online/model.py
@@ -8,6 +8,7 @@
from typing import Any, MutableMapping, Optional, Union

import torch
from composer.distributed.shared_utils import get_summon_params_fn
from composer.models import HuggingFaceModel
from composer.utils import dist, is_model_fsdp
from llmfoundry.models import ComposerHFCausalLM
@@ -186,14 +187,13 @@ def generate(self, input_ids: torch.Tensor, *args: Any, **kwargs: Any):
# Note: it seems as if we need to summon FSDP parameters here to ensure that we don't break
# the standard actor critic forward pass.
if is_model_fsdp(self.model):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Note: We need to use the FSDP.summon_full_params context manager here because the generate function
# Note: We need to use the summon_full_params context manager here because the generate function
# does not seem to gather the weights for the LM head. This solution works because the tied weights of the LM head
# are in the root FSDP module, and are summoned by the below context manager. See https://github.com/pytorch/pytorch/issues/100069
# for more info.
# Note: We use recurse=False here so that we only summon full params for the LM head, not the entire model.
with FSDP.summon_full_params(
summon_full_params = get_summon_params_fn(self.model)
with summon_full_params(
self.model,
writeback=False,
recurse=False,
6 changes: 3 additions & 3 deletions compose_rl/algorithms/reward_modeling/model.py
@@ -9,6 +9,7 @@
from typing import Any, Mapping, MutableMapping, Optional, Union

import torch
from composer.distributed.shared_utils import get_summon_params_fn
from composer.utils import is_model_fsdp
from llmfoundry.models import ComposerHFCausalLM, ComposerMPTCausalLM

@@ -241,7 +242,6 @@ def loss(self, outputs: SequenceClassifierOutput,


class ComposerHFCausalClassifierRewardModel(ComposerHFCausalLM, RewardModel):

default_train_metrics: tuple = ()
default_eval_metrics: tuple = ()

@@ -292,9 +292,9 @@ def mask_last_embed_except_eos(

context_manager = nullcontext
if is_model_fsdp(self.model):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
summon_full_params = get_summon_params_fn(self.model)
context_manager = partial(
FSDP.summon_full_params,
summon_full_params,
self.model,
writeback=True,
recurse=False,
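The `context_manager = nullcontext` default plus the `functools.partial` binding shown above is a small pattern for making the unshard step optional. A minimal runnable sketch of that pattern with stand-in names (`fake_summon` and `is_wrapped` are placeholders, not real compose_rl or Composer symbols):

```python
from contextlib import nullcontext
from functools import partial

import torch.nn as nn


def fake_summon(model: nn.Module, writeback: bool = False, recurse: bool = False):
    # Placeholder for an FSDP summon-style context manager.
    return nullcontext()


model = nn.Linear(4, 4)   # stand-in for self.model
is_wrapped = False        # stand-in for is_model_fsdp(self.model)

context_manager = nullcontext
if is_wrapped:
    context_manager = partial(fake_summon, model, writeback=True, recurse=False)

with context_manager():
    # Parameter reads/writes that need the full (unsharded) weights would go here.
    _ = model.weight.sum()
```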