22 changes: 20 additions & 2 deletions xtuner/v1/ray/dataflow/flow.py
@@ -275,7 +275,10 @@ async def concurrent_task_runner(self):
waiting_tasks.add(task)

_, pending_tasks = await asyncio.wait(waiting_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED)
self.finished_samples_count = await self.replay_buffer.get_finished_samples.remote()
if self._is_async_run is False:
self.finished_samples_count = await self.replay_buffer.get_finished_samples.remote()
else:
self.finished_samples_count = await self.replay_buffer.get_total_finished_samples.remote()
waiting_tasks = pending_tasks

pbar.n = self.finished_samples_count
@@ -313,6 +316,16 @@ async def concurrent_task_runner(self):
self.logger.info(
f"Data generation completed. Replay Buffer Stats: {replay_buffer_stats}, Rollout Stats: {rollout_stats}"
)
await self.replay_buffer.reset_finished_samples.remote() # type: ignore[attr-defined]

    # In the fully-async case, finished_samples_count changes as soon as finished data is fetched, which would prevent the loop from ever exiting.
async def get_finished_samples_count(self):
"""Gets the number of finished samples."""
return await self.replay_buffer.get_finished_samples.remote()

async def get_async_data(self, target_batch_size):
"""Gets all data from the replay buffer."""
return await self.replay_buffer.get_samples.remote(target_batch_size)

@ray_method
async def pause(self, timeout: float = 60.0):
@@ -362,14 +375,19 @@ async def run(
Returns:
List[RLDataFlowItem]: A list of collected training samples.
"""
if num is None:
self._is_async_run = False
else:
self._is_async_run = True
self._reset_internal_states(
global_batch_size=num,
sample_params=sample_params,
extra_params=extra_params,
enable_partial_rollout=enable_partial_rollout,
)
await self.concurrent_task_runner()
return await self.replay_buffer.get_samples.remote(self.target_batch_size) # type: ignore[attr-defined]
if self._is_async_run is False:
return await self.replay_buffer.get_samples.remote(self.target_batch_size) # type: ignore[attr-defined]

def get_replaybuffer_status(self):
return ray.get(self.replay_buffer.status.remote())
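In the fully-async path `run()` returns nothing, so the caller is expected to poll the buffer through the two new accessors and drain it explicitly. A minimal sketch of that contract, assuming `flow` is the DataFlow instance (with a Ray actor handle the same calls would go through `.remote()`):

```python
import asyncio

async def collect_async_batch(flow, target_batch_size: int):
    # Passing num switches run() into the fully-async path (_is_async_run = True),
    # so run() itself returns no samples.
    runner = asyncio.create_task(flow.run(num=target_batch_size))

    # Poll the per-step finished count, then pull the batch out of the replay buffer.
    while await flow.get_finished_samples_count() < target_batch_size:
        await asyncio.sleep(1.0)

    batch = await flow.get_async_data(target_batch_size)
    await runner
    return batch
```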
11 changes: 11 additions & 0 deletions xtuner/v1/ray/dataflow/replay_buffer.py
@@ -621,6 +621,7 @@ def __init__(
self.storage,
)
self.post_processor_func = config.postprocessor_func
self._prev_total_finished_samples = 0
self.logger = get_logger(log_dir=config.worker_log_dir, tag="ReplayBuffer")

def get_train_dataset_length(self):
@@ -743,6 +744,16 @@ def get_finished_samples(self):
def get_unfinished_samples(self):
"""Returns the number of unfinished sample groups in the storage."""
return self.storage.get_unfinished_samples()

    # Returns the number of samples finished during the current step.
def get_total_finished_samples(self):
"""Returns the total number of sample groups in the storage."""
current_total_finished_samples = len(self.storage._root2actions) - self._prev_total_finished_samples
return current_total_finished_samples

def reset_finished_samples(self):
"""Sets the total number of sample groups in the storage to zero."""
self._prev_total_finished_samples = len(self.storage._root2actions)

def clear(self):
"""Clears the replay buffer storage."""
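The per-step count is derived by snapshotting the storage size at the end of each step and subtracting that baseline from the current size. A self-contained illustration of the same pattern (a toy class, not the actual ReplayBuffer):

```python
class StepCounter:
    """Counts items finished since the last reset, mirroring the
    _prev_total_finished_samples baseline kept by the replay buffer."""

    def __init__(self):
        self._items = []      # stands in for storage._root2actions
        self._prev_total = 0  # baseline recorded at the last reset

    def add(self, item):
        self._items.append(item)

    def finished_this_step(self):
        # get_total_finished_samples(): current size minus the baseline.
        return len(self._items) - self._prev_total

    def reset(self):
        # reset_finished_samples(): snapshot the current size as the new baseline.
        self._prev_total = len(self._items)


counter = StepCounter()
counter.add("a")
counter.add("b")
assert counter.finished_this_step() == 2
counter.reset()
counter.add("c")
assert counter.finished_this_step() == 1
```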
1 change: 1 addition & 0 deletions xtuner/v1/ray/rollout/controller.py
@@ -220,6 +220,7 @@ def get_rollout_info(self):
server_url_dict=self.worker_server_urls_map,
rollout_config=self.config,
worker_server_urls_status=worker_server_urls_status,
rollout_workers=[info.actor for info in self.workers_info.values()],
)

def init_workers(self):
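Exposing the worker actor handles through `rollout_workers` lets consumers of `get_rollout_info()` address the rollout workers directly, e.g. to drive the weight-update RPCs added in `sglang.py` below. A hypothetical consumer, assuming the handles are Ray actors and `rollout_info` is the object returned above:

```python
import ray

def broadcast_weight_update(rollout_info, names, dtypes, shapes, group_name="weight_sync"):
    # rollout_workers holds the actor handles added by this change.
    refs = [
        worker.update_weights_from_distributed.remote(
            names, dtypes, shapes, group_name, flush_cache=True
        )
        for worker in rollout_info.rollout_workers
    ]
    return ray.get(refs)
```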
37 changes: 37 additions & 0 deletions xtuner/v1/ray/rollout/sglang.py
@@ -8,6 +8,7 @@
from xtuner.v1.ray.config import RolloutConfig

from .worker import RolloutWorker
import logging


class SGLangWorker(RolloutWorker):
@@ -23,6 +24,12 @@ def __init__(
super().__init__(config, rank, master_addr, master_port, world_size, accelerator)
from sglang.srt.entrypoints.http_server import launch_server

        # Disable client-side HTTP request logging
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("httpcore.connection").setLevel(logging.WARNING)
logging.getLogger("httpcore.http11").setLevel(logging.WARNING)

self.server_func = launch_server
self.endpoints["health_generate"] = "health_generate"
self.endpoints["generate"] = "generate"
@@ -31,6 +38,36 @@ def __init__(
self.api_keys = self.config.api_key
self.model_name = self.config.model_name
self.enable_return_routed_experts = self.config.enable_return_routed_experts

def init_weights_update_group(self, master_address, master_port, rank_offset, world_size, group_name, backend):
return self._make_request(
"init_weights_update_group",
{
"master_address": master_address,
"master_port": master_port,
"rank_offset": rank_offset,
"world_size": world_size,
"group_name": group_name,
"backend": backend,
},
)

def update_weights_from_distributed(
self, names, dtypes, shapes, group_name, flush_cache=False, weight_version: str | None = None
):
payload = {
"names": names,
"dtypes": [str(dtype).replace("torch.", "") for dtype in dtypes],
"shapes": shapes,
"group_name": group_name,
"flush_cache": flush_cache,
}
if weight_version is not None:
payload["weight_version"] = weight_version
return self._make_request(
"update_weights_from_distributed",
payload,
)

async def _create_request(
self,
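The two new methods wrap SGLang's HTTP endpoints for distributed weight updates: `init_weights_update_group` is called once so the rollout server joins the update process group, after which `update_weights_from_distributed` announces each batch of tensors to be received over that group. A minimal sketch of the call sequence (the worker handle, addresses, ranks, and group name are assumed for illustration):

```python
def sync_weights(worker, named_tensors, master_address="127.0.0.1", master_port=29600):
    # One-time setup: the rollout worker joins the weight-update process group.
    worker.init_weights_update_group(
        master_address=master_address,
        master_port=master_port,
        rank_offset=1,   # rollout ranks follow the trainer rank in this toy setup
        world_size=2,    # trainer + one rollout worker
        group_name="weight_sync",
        backend="nccl",
    )

    # Per update: only metadata goes over HTTP; dtype strings are normalized
    # ("torch.bfloat16" -> "bfloat16") inside update_weights_from_distributed.
    names = [name for name, _ in named_tensors]
    dtypes = [tensor.dtype for _, tensor in named_tensors]
    shapes = [list(tensor.shape) for _, tensor in named_tensors]
    return worker.update_weights_from_distributed(
        names, dtypes, shapes, group_name="weight_sync", flush_cache=True
    )
```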