diff --git a/rollout.py b/rollout.py
new file mode 100644
index 00000000..471fb6ab
--- /dev/null
+++ b/rollout.py
@@ -0,0 +1,84 @@
+from compose_rl.algorithms.online.generation_utils.vllm_utils import init_process_group
+from composer.utils import dist
+import torch
+
+import logging
+
+MODEL_UPDATE_PORT = 29600
+EXPERIENCE_BUFFER_PORT = 29601
+NUM_INFERENCE_ENGINES = 1
+MAX_ITERATIONS = 2
+
+# Set it to 0 to be fully synchronous and on-policy.
+MAX_ASYNC_STEP = 1
+
+logging.basicConfig(
+    format=
+    f'[ROLLOUT] %(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
+)
+log = logging.getLogger(__name__)
+log.setLevel(logging.DEBUG)
+
+if __name__ == "__main__":
+    # Initialize the process groups for communication between the train and rollout agents
+    rank = 1  # TODO: UPDATE TO SUPPORT MULTIPLE INFERENCE ENGINES
+    model_update_group = init_process_group(
+        backend="nccl",
+        init_method=f"tcp://localhost:{MODEL_UPDATE_PORT}",
+        world_size=1 + NUM_INFERENCE_ENGINES,
+        rank=rank,
+        group_name="model_update_group",
+    )
+    experience_buffer_group = init_process_group(
+        backend="gloo",
+        init_method=f"tcp://localhost:{EXPERIENCE_BUFFER_PORT}",
+        world_size=1 + NUM_INFERENCE_ENGINES,
+        rank=rank,
+        group_name="experience_buffer_group",
+    )
+
+    is_ready_to_update = torch.tensor([0]).to('cuda')
+    is_ready_to_update_work = None
+    last_update_iteration = 0
+
+    for i in range(MAX_ITERATIONS):
+        log.info(f"Starting iteration {i + 1}/{MAX_ITERATIONS}")
+
+        if is_ready_to_update_work is None:
+            # Check whether an update to the model weights is available.
+            is_ready_to_update_work = torch.distributed.broadcast(group=model_update_group, src=0, tensor=is_ready_to_update, async_op=True)
+
+        # We always need to update on the first iteration, and whenever we exceed the allowed async staleness.
+        if i == 0 or i - last_update_iteration > MAX_ASYNC_STEP:
+            is_ready_to_update_work.wait()
+
+        # Only consume the ready flag once the async broadcast has actually completed.
+        if is_ready_to_update_work.is_completed() and is_ready_to_update.item() == 1:
+            log.info("Weights are ready to update")
+
+            # Update the model weights
+            log.info("Updating the model weights")
+            weights = torch.tensor([0]).to('cuda')
+            torch.distributed.broadcast(group=model_update_group, src=0, tensor=weights)
+            log.info(f"Updated the weights to {weights}")
+
+            # Reset the update check
+            is_ready_to_update = torch.tensor([0]).to('cuda')
+            is_ready_to_update_work = None
+            last_update_iteration = i
+
+        # TODO: start generating rollouts for the experience buffer
+        log.info("Generating rollouts!")
+
+        # Send the experience buffer to the train agent.
+        # We do not block here; we can continue generating rollouts while the experience buffer is being sent.
+        experience_buffer = torch.tensor([20 + i])
+        experience_buffer_work = torch.distributed.broadcast(group=experience_buffer_group, src=1, tensor=experience_buffer, async_op=True)
+        log.info(f"Sent experience buffer {experience_buffer}")
+
+        log.info(f"Completed iteration {i + 1}/{MAX_ITERATIONS}")
+
+        if i == MAX_ITERATIONS - 1:
+            assert experience_buffer_work is not None
+            log.info("Waiting for the last experience buffer to be received")
+            experience_buffer_work.wait()
diff --git a/test_no_ray.py b/test_no_ray.py
new file mode 100644
index 00000000..09236e97
--- /dev/null
+++ b/test_no_ray.py
@@ -0,0 +1,27 @@
+import subprocess
+import traceback
+
+
+if __name__ == "__main__":
+    # Test on 4 GPUs!
+    # For multi-node, we would need to determine which command to launch on which node.
+    p1 = None
+    p2 = None
+    try:
+        # Launch the train agent with multiple processes for distributed training
+        p1 = subprocess.Popen('CUDA_VISIBLE_DEVICES=0,1 composer -n 2 train.py', shell=True)
+
+        # Launch the rollout agent with a single process for vllm
+        p2 = subprocess.Popen('CUDA_VISIBLE_DEVICES=2,3 python rollout.py', shell=True)
+        p1.wait()
+        p2.wait()
+    except Exception as e:
+        print(e)
+        print(traceback.format_exc())
+        print('Killing training processes')
+    finally:
+        # Guard against a Popen call failing before the variable was bound.
+        if p1 is not None:
+            p1.terminate()
+        if p2 is not None:
+            p2.terminate()
diff --git a/train.py b/train.py
new file mode 100644
index 00000000..835ce2e1
--- /dev/null
+++ b/train.py
@@ -0,0 +1,78 @@
+from composer.utils import dist
+import torch
+
+from compose_rl.algorithms.online.generation_utils.vllm_utils import init_process_group
+
+import logging
+import time
+
+MODEL_UPDATE_PORT = 29600
+EXPERIENCE_BUFFER_PORT = 29601
+NUM_INFERENCE_ENGINES = 1
+MAX_ITERATIONS = 2
+
+
+logging.basicConfig(
+    format=
+    f'[TRAIN] %(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
+)
+log = logging.getLogger(__name__)
+log.setLevel(logging.DEBUG)
+
+if __name__ == "__main__":
+    torch.distributed.init_process_group(backend="nccl")
+    # Pin each training rank to its own GPU so the NCCL collectives below don't all land on cuda:0.
+    torch.cuda.set_device(dist.get_local_rank())
+
+    # Initialize the process groups for communication between the train and rollout agents
+    model_update_group = None
+    experience_buffer_group = None
+    if dist.get_global_rank() == 0:
+        log.info("Initializing model update process group")
+        model_update_group = init_process_group(
+            backend="nccl",
+            init_method=f"tcp://localhost:{MODEL_UPDATE_PORT}",
+            world_size=1 + NUM_INFERENCE_ENGINES,
+            rank=0,
+            group_name="model_update_group",
+        )
+        experience_buffer_group = init_process_group(
+            backend="gloo",
+            init_method=f"tcp://localhost:{EXPERIENCE_BUFFER_PORT}",
+            world_size=1 + NUM_INFERENCE_ENGINES,
+            rank=0,
+            group_name="experience_buffer_group",
+        )
+
+    for i in range(MAX_ITERATIONS):
+        log.info(f"Starting iteration {i + 1}/{MAX_ITERATIONS}")
+
+        # TODO: We shouldn't broadcast if the rollouts are done!
+        if model_update_group is not None:
+            # Let the rollout agent know that we're ready to update the model weights
+            is_ready_to_update = torch.tensor([1]).to('cuda')
+            torch.distributed.broadcast(group=model_update_group, src=0, tensor=is_ready_to_update)
+            log.info(f"Broadcasted is_ready_to_update {is_ready_to_update}")
+
+            # Broadcast the model weights
+            weights = torch.tensor([10 + i]).to('cuda')
+            torch.distributed.broadcast(group=model_update_group, src=0, tensor=weights)
+            log.info(f"Broadcasted model weights {weights}")
+
+        # Get the experience buffer results from the rollout process
+        experience_buffer = torch.tensor([0])
+        if experience_buffer_group is not None:
+            torch.distributed.broadcast(group=experience_buffer_group, src=1, tensor=experience_buffer)
+            log.info(f"Got experience buffer {experience_buffer}")
+
+        # All training ranks should wait until we have the experience buffer results
+        dist.barrier()
+
+        # TODO: distribute the experience results to each of the training ranks
+        # TODO: train the model
+
+        # Simulate "long" training!
+        log.info("Training!")
+        time.sleep(5)
+
+        log.info(f"Completed iteration {i + 1}/{MAX_ITERATIONS}")