89 changes: 89 additions & 0 deletions rollout.py
@@ -0,0 +1,89 @@
from compose_rl.algorithms.online.generation_utils.vllm_utils import init_process_group
from composer.utils import dist
import torch

import logging

MODEL_UPDATE_PORT = 29600
EXPERIENCE_BUFFER_PORT = 29601
NUM_INFERENCE_ENGINES = 1
MAX_ITERATIONS = 2

# Maximum number of iterations the rollout agent may run past the last weight
# update before it blocks and waits for new weights.
# Set it to 0 to be fully synchronous and on-policy.
MAX_ASYNC_STEP = 1

logging.basicConfig(
format=
f'[ROLLOUT] %(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
)
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

if __name__ == "__main__":
    # Initialize the process groups for communication between the train and rollout agents.
    rank = 1  # TODO: UPDATE TO SUPPORT MULTIPLE INFERENCE ENGINES (see the sketch at the end of this file)
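    # Two side channels to the train agent:
    #   - model_update_group (NCCL): receives updated weights as CUDA tensors from rank 0.
    #   - experience_buffer_group (gloo): sends CPU experience tensors back to the trainer.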
model_update_group = init_process_group(
backend="nccl",
init_method=f"tcp://localhost:{MODEL_UPDATE_PORT}",
world_size=1 + NUM_INFERENCE_ENGINES,
rank=rank,
group_name="model_update_group",
)
experience_buffer_group = init_process_group(
backend="gloo",
init_method=f"tcp://localhost:{EXPERIENCE_BUFFER_PORT}",
world_size=1 + NUM_INFERENCE_ENGINES,
rank=rank,
group_name="experience_buffer_group",
)

is_ready_to_update = torch.tensor([0]).to('cuda')
is_ready_to_update_work = None
last_update_iteration = 0


for i in range(MAX_ITERATIONS):
log.info(f"Starting iteration {i + 1}/{MAX_ITERATIONS}")

        if is_ready_to_update_work is None:
            # Check whether an update to the model weights is available.
            is_ready_to_update_work = torch.distributed.broadcast(
                group=model_update_group, src=0, tensor=is_ready_to_update, async_op=True)

        # Block for a weight update on the first iteration, or once we have drifted
        # more than MAX_ASYNC_STEP iterations past the last update.
        if i == 0 or i - last_update_iteration > MAX_ASYNC_STEP:
            is_ready_to_update_work.wait()

        # Only read the flag once the async broadcast has completed; otherwise the
        # tensor may still be in flight.
        if is_ready_to_update_work.is_completed() and is_ready_to_update.item() == 1:
            log.info("Weights are ready to update")

            # Receive the updated model weights from the train agent.
            log.info("Updating the model weights")
            weights = torch.tensor([0]).to('cuda')
            torch.distributed.broadcast(group=model_update_group, src=0, tensor=weights)
            log.info(f"Updated the weights to {weights}")

            # Reset the update check.
            is_ready_to_update = torch.tensor([0]).to('cuda')
            is_ready_to_update_work = None
            last_update_iteration = i

# TODO: start generating rollouts for the experience buffer
log.info("Generating rollouts!")

        # Send the experience buffer to the train agent. We do not block here, so
        # we can keep generating rollouts while the buffer is being sent.
        experience_buffer = torch.tensor([20 + i])
        experience_buffer_work = torch.distributed.broadcast(
            group=experience_buffer_group, src=1, tensor=experience_buffer, async_op=True)
        log.info(f"Sent experience buffer {experience_buffer}")

log.info(f"Completed iteration {i + 1}/{MAX_ITERATIONS}")

if i == MAX_ITERATIONS - 1:
assert experience_buffer_work is not None
log.info(f"Waiting for the last experience buffer to be received")
experience_buffer_work.wait()
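
A possible way to address the multi-inference-engine TODO above, sketched under the assumption that the launcher sets a (hypothetical) ENGINE_INDEX environment variable per rollout process: the train agent keeps rank 0, engine k takes rank 1 + k, and the world size of 1 + NUM_INFERENCE_ENGINES used above is unchanged.

import os

# Hypothetical: ENGINE_INDEX is set by the launcher, one value per rollout
# process, in the range [0, NUM_INFERENCE_ENGINES).
engine_index = int(os.environ.get("ENGINE_INDEX", "0"))
rank = 1 + engine_index  # the train agent owns rank 0 in both groups

model_update_group = init_process_group(
    backend="nccl",
    init_method=f"tcp://localhost:{MODEL_UPDATE_PORT}",
    world_size=1 + NUM_INFERENCE_ENGINES,
    rank=rank,
    group_name="model_update_group",
)

The experience-buffer broadcast from src=1 would also need to change with multiple engines (for example, one receive per engine on the train side), which this sketch does not cover.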




23 changes: 23 additions & 0 deletions test_no_ray.py
@@ -0,0 +1,23 @@
import subprocess
import traceback


if __name__ == "__main__":
    # Test on 4 GPUs: 2 for the train agent and 2 visible to the rollout agent.
    # For multinode, we would need to decide which command to launch on which node.
    p1 = None
    p2 = None
    try:
        # Launch the train agent with multiple processes for distributed training.
        p1 = subprocess.Popen('CUDA_VISIBLE_DEVICES=0,1 composer -n 2 train.py', shell=True)

        # Launch the rollout agent with a single process for vLLM.
        p2 = subprocess.Popen('CUDA_VISIBLE_DEVICES=2,3 python rollout.py', shell=True)
        p1.wait()
        p2.wait()
    except Exception as e:
        print(e)
        print(traceback.format_exc())
        print('Killing training processes')
    finally:
        # Guard against Popen failing before p1/p2 were assigned.
        for p in (p1, p2):
            if p is not None:
                p.terminate()
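
Waiting on p1 and then p2 means a crash in the rollout agent goes unnoticed until the train agent exits on its own. A small polling loop, sketched below with the same launch commands, would let a failure in either agent tear down the other:

import subprocess
import time

procs = [
    subprocess.Popen('CUDA_VISIBLE_DEVICES=0,1 composer -n 2 train.py', shell=True),
    subprocess.Popen('CUDA_VISIBLE_DEVICES=2,3 python rollout.py', shell=True),
]
try:
    # Poll until both agents have exited; fail fast if either one crashes.
    while any(p.poll() is None for p in procs):
        if any(p.returncode not in (None, 0) for p in procs):
            raise RuntimeError('an agent exited with a non-zero return code')
        time.sleep(1)
finally:
    for p in procs:
        if p.poll() is None:
            p.terminate()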

78 changes: 78 additions & 0 deletions train.py
@@ -0,0 +1,78 @@
from composer.utils import dist
import torch

from compose_rl.algorithms.online.generation_utils.vllm_utils import init_process_group

import logging
import time

MODEL_UPDATE_PORT = 29600
EXPERIENCE_BUFFER_PORT = 29601
NUM_INFERENCE_ENGINES = 1
MAX_ITERATIONS = 2


logging.basicConfig(
format=
f'[TRAIN] %(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
)
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

if __name__ == "__main__":
torch.distributed.init_process_group(backend="nccl")

    # Initialize the process groups for communication between the train and rollout agents.
    # Only global rank 0 participates in these cross-agent side channels.
    model_update_group = None
    experience_buffer_group = None
if dist.get_global_rank() == 0:
log.info("Initializing model update process group")
model_update_group = init_process_group(
backend="nccl",
init_method=f"tcp://localhost:{MODEL_UPDATE_PORT}",
world_size=1 + NUM_INFERENCE_ENGINES,
rank=0,
group_name="model_update_group",
)
experience_buffer_group = init_process_group(
backend="gloo",
init_method=f"tcp://localhost:{EXPERIENCE_BUFFER_PORT}",
world_size=1 + NUM_INFERENCE_ENGINES,
rank=0,
group_name="experience_buffer_group",
)
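    # Per-iteration protocol (the cross-agent steps run on global rank 0 only):
    #   1. Tell the rollout agent that new weights are ready, then broadcast them (NCCL).
    #   2. Receive the experience buffer produced by the rollout agent (gloo).
    #   3. Barrier so every training rank waits for the buffer, then train.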

for i in range(MAX_ITERATIONS):
log.info(f"Starting iteration {i + 1}/{MAX_ITERATIONS}")

        # TODO: We shouldn't broadcast if the rollout agent has already finished!
        if model_update_group is not None:
            # Let the rollout agent know that updated model weights are ready.
            is_ready_to_update = torch.tensor([1]).to('cuda')
            torch.distributed.broadcast(group=model_update_group, src=0, tensor=is_ready_to_update)
            log.info(f"Broadcasted is_ready_to_update {is_ready_to_update}")

            # Broadcast the model weights (a placeholder tensor for now).
            weights = torch.tensor([10 + i]).to('cuda')
            torch.distributed.broadcast(group=model_update_group, src=0, tensor=weights)
            log.info(f"Broadcasted model weights {weights}")

        # Get the experience buffer results from the rollout process.
        experience_buffer = torch.tensor([0])
        if experience_buffer_group is not None:
            torch.distributed.broadcast(group=experience_buffer_group, src=1, tensor=experience_buffer)
            log.info(f"Got experience buffer {experience_buffer}")

        # All training ranks should wait until rank 0 has the experience buffer results.
        dist.barrier()

        # TODO: distribute the experience results to each of the training ranks (see the sketch at the end of this file)
        # TODO: train the model

        # Simulate "long" training!
        log.info("Training!")
        time.sleep(5)


log.info(f"Completed iteration {i + 1}/{MAX_ITERATIONS}")