diff --git a/megatron/rl/agent/api.py b/megatron/rl/agent/api.py
index d00c9ee568..3e16f74599 100644
--- a/megatron/rl/agent/api.py
+++ b/megatron/rl/agent/api.py
@@ -3,7 +3,7 @@
 import asyncio
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterable
-from typing import TypeVar
+from typing import Generic, TypeVar
 
 import numpy as np
 from pydantic import BaseModel
@@ -99,7 +99,7 @@ class RewardEvaluationResult(EvaluationResult):
 T = TypeVar('T', bound=EvaluationResult)
 
 
-class EvaluationResponse[T](AgentBaseModel, TypeLookupable):
+class EvaluationResponse(AgentBaseModel, TypeLookupable, Generic[T]):
     env_id: str | None = None
     results: list[T]
 
diff --git a/megatron/rl/agent/remote_agent.py b/megatron/rl/agent/remote_agent.py
new file mode 100644
index 0000000000..6de2cca7ba
--- /dev/null
+++ b/megatron/rl/agent/remote_agent.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+from ..server.agent.fastapi_env_server import FastAPIEnvServer
+from .api import EvaluationAgent, GroupedRolloutGenerator, RolloutGenerator
+
+
+class RemoteAgent(FastAPIEnvServer, RolloutGenerator, GroupedRolloutGenerator, EvaluationAgent):
+    env_id: str = "remote"
+    env_server_host_port: str
diff --git a/megatron/rl/inference/direct_inference.py b/megatron/rl/inference/direct_inference.py
deleted file mode 100644
index b3aee6df09..0000000000
--- a/megatron/rl/inference/direct_inference.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-
-from abc import ABC, abstractmethod
-from itertools import zip_longest
-from typing import Annotated, Any, ClassVar
-
-from pydantic import BeforeValidator, ValidationError
-
-
-def grouper(iterable, n, fillvalue=None):
-    args = [iter(iterable)] * n
-    return zip_longest(*args, fillvalue=fillvalue)
-
-
-from .api import (
-    ChatInferenceInterface,
-    ChatInferenceRequest,
-    ChatInferenceResponse,
-    GroupedInferenceResponse,
-    InferenceInterface,
-    InferenceRequest,
-    InferenceResponse,
-)
-from .chat_templates import ConversationTemplate
-
-
-class DirectInferenceInterface(InferenceInterface, ABC):
-    """Basic inference engines that operate directly on strings can extend this class and implement base_generate.
-
-    This abstract base class then implements necessary LLM interfaces.
-    """
-
-    supports_n: ClassVar[bool] = False
-
-    @abstractmethod
-    async def base_generate(self, request: InferenceRequest) -> list[InferenceResponse]:
-        raise NotImplementedError("Inference Classes must implement the base_generate method.")
-
-    def duplicate_requests(self, request: InferenceRequest, n: int) -> list[InferenceRequest]:
-        return request.model_copy(update={'prompt': request.prompt * n})
-
-    def fold_responses(
-        self, responses: list[InferenceResponse], n: int
-    ) -> list[GroupedInferenceResponse]:
-        return [GroupedInferenceResponse(responses=x) for x in list(grouper(responses, n))]
-
-    async def agenerate(self, request: InferenceRequest) -> list[InferenceResponse]:
-        return await self.base_generate(
-            InferenceRequest.model_validate(request, from_attributes=True)
-        )
-
-    async def agroup_generate(
-        self, request: InferenceRequest, group_size: int
-    ) -> list[GroupedInferenceResponse]:
-        if not self.supports_n:
-            request = self.duplicate_requests(request, group_size)
-
-        generations = await self.agenerate(request)
-
-        generations = self.fold_responses(generations, group_size)
-
-        return generations
-
-
-def ensure_template(value: Any) -> ConversationTemplate:
-    if isinstance(value, ConversationTemplate):
-        return value
-    elif isinstance(value, str):
-        return ConversationTemplate.from_string(value)
-    else:
-        raise ValidationError(f"Invalid conversation template: {value}")
-
-
-class DirectChatInferenceInterface(DirectInferenceInterface, ChatInferenceInterface):
-    """Basic inference engines that operate directly on strings can extend this class and implement base_generate. This class implements necessary chat interfaces."""
-
-    conversation_template: Annotated[ConversationTemplate, BeforeValidator(ensure_template)]
-
-    async def base_generate(self, request: ChatInferenceRequest) -> list[ChatInferenceResponse]:
-        base_generate_results = await super().base_generate(
-            InferenceRequest(
-                prompt=[self.conversation_template.format(messages) for messages in request.prompt],
-                generation_args=request.generation_args,
-            )
-        )
-        chat_message_results = self.conversation_template.parse_response(base_generate_results)
-        return [
-            ChatInferenceResponse(response=chat_message, **response.model_dump())
-            for chat_message, response in zip(chat_message_results, base_generate_results)
-        ]
diff --git a/megatron/rl/inference/megatron.py b/megatron/rl/inference/megatron.py
index 4c739b709c..387c80d611 100644
--- a/megatron/rl/inference/megatron.py
+++ b/megatron/rl/inference/megatron.py
@@ -31,6 +31,7 @@
     ChatInferenceInterface,
     InferenceRequest,
     InferenceResponse,
+    LLMChatMessage,
     ReturnsRaw,
     ReturnsTokens,
 )
@@ -43,7 +44,9 @@
 def get_static_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngine:
     """Get the relevant backend for running inference.
 
-    This function will automatically choose the TRTLLMBackend when possible, and default to Mcore backend if the user does not specify any backends. TRTLLMBackend is not implmented yet.
+    This function will automatically choose the TRTLLMBackend when possible,
+    and default to Mcore backend if the user does not specify any backends.
+    TRTLLMBackend is not implmented yet.
 
     Args:
         args (Namespace): The user arguments parsed from command line
@@ -83,7 +86,9 @@ def get_static_inference_engine(args: Namespace, model: MegatronModule) -> Abstr
 def get_dynamic_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngine:
     """Get the relevant backend for running inference.
 
-    This function will automatically choose the TRTLLMBackend when possible, and default to Mcore backend if the user does not specify any backends. TRTLLMBackend is not implmented yet.
+    This function will automatically choose the TRTLLMBackend when possible,
+    and default to Mcore backend if the user does not specify any backends.
+    TRTLLMBackend is not implmented yet.
 
     Args:
         args (Namespace): The user arguments parsed from command line
@@ -154,6 +159,16 @@ class MegatronLocal(InferenceServer, ReturnsTokens, ReturnsRaw):
     _kill_engine: bool = PrivateAttr(False)
 
     async def base_generate(self, request: InferenceRequest):
+
+        if any(isinstance(p, LLMChatMessage) for p in request.prompt):
+            raise ValueError(
+                "MegatronLocal does not support chat requests."
+                "Use MegatronChatLocal to apply chat templating."
+            )
+        assert all(
+            isinstance(p, str) for p in request.prompt
+        ), "MegatronLocal only supports string prompts."
+
         tokenizer = get_tokenizer()
 
         sampling_params = SamplingParams(
diff --git a/megatron/rl/server/agent/__init__.py b/megatron/rl/server/agent/__init__.py
new file mode 100644
index 0000000000..b9a9591fa6
--- /dev/null
+++ b/megatron/rl/server/agent/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
diff --git a/megatron/rl/server/agent/fastapi_env_server.py b/megatron/rl/server/agent/fastapi_env_server.py
new file mode 100644
index 0000000000..c378c5625e
--- /dev/null
+++ b/megatron/rl/server/agent/fastapi_env_server.py
@@ -0,0 +1,197 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+import asyncio
+import socket
+from typing import AsyncGenerator
+
+import httpx
+import yaml
+from fastapi import FastAPI
+from pydantic import Field, PrivateAttr
+from typing_extensions import Self
+from uvicorn import Config, Server
+from uvicorn.config import LOGGING_CONFIG
+
+LOGGING_CONFIG['root'] = {"handlers": ["default"], "level": "INFO"}
+
+from ... import import_class, inference
+from ...agent.api import (
+    Agent,
+    ContrastiveRollout,
+    ContrastiveRolloutGenerator,
+    EvaluationAgent,
+    EvaluationRequest,
+    EvaluationResponse,
+    GroupedRolloutGenerator,
+    GroupedRolloutRequest,
+    RolloutGenerator,
+    RolloutRequest,
+    TokenRollout,
+)
+from ...server.api import (
+    EnvironmentServer,
+    InferenceServer,
+    RemoteEvaluationRequest,
+    RemoteGroupedRolloutRequest,
+    RemoteRolloutRequest,
+)
+from .. import agent
+from ..api import EnvironmentServer, InferenceServer, RemoteEvaluationRequest, RemoteRolloutRequest
+
+
+@EnvironmentServer.register_subclass
+class FastAPIEnvServer(EnvironmentServer):
+    server_type: str = Field('FastAPIEnvServer', frozen=True, Literal=True)
+    env_server_host_port: str
+    _server_task: asyncio.Task = PrivateAttr(None)
+
+    @classmethod
+    async def launch(cls, env_cls: type[Agent], cls_args: dict, port: int, **kwargs) -> Self:
+
+        app = FastAPI()
+
+        if issubclass(env_cls, GroupedRolloutGenerator):
+
+            @app.post("/grouped_rollouts/")
+            async def grouped_rollouts(
+                request: RemoteGroupedRolloutRequest,
+            ) -> list[list[TokenRollout]]:
+                env = env_cls(**cls_args)
+                request.inference_interface = request.inference_interface.unwrap()
+                return await env.get_grouped_rollouts(request)
+
+        if issubclass(env_cls, ContrastiveRolloutGenerator):
+
+            @app.post("/contrastive_rollouts/")
+            async def contrastive_rollouts(
+                request: RemoteRolloutRequest,
+            ) -> list[ContrastiveRollout]:
+                env = env_cls(**cls_args)
+                request.inference_interface = request.inference_interface.unwrap()
+                return await env.get_contrastive_rollouts(request)
+
+        if issubclass(env_cls, RolloutGenerator):
+
+            @app.post("/rollouts/")
+            async def rollouts(request: RemoteRolloutRequest) -> list[TokenRollout]:
+                env = env_cls(**cls_args)
+                request.inference_interface = request.inference_interface.unwrap()
+                return await env.get_reward_rollouts(request)
+
+        if issubclass(env_cls, EvaluationAgent):
+
+            @app.post("/evaluation/")
+            async def run_evaluation(request: RemoteEvaluationRequest):
+                env = env_cls(**cls_args)
+                request.inference_interface = request.inference_interface.unwrap()
+                return await env.run_evaluation(request)
+
+        loop = asyncio.get_event_loop()
+        config = Config(app=app, loop=loop, host='0.0.0.0', port=port)
+        server = Server(config)
+        server_task = loop.create_task(server.serve())
+
+        ip = socket.gethostbyname(socket.gethostname())
+
+        launched_server = cls(env_server_host_port=f"{ip}:{config.port}", **kwargs)
+        launched_server._server_task = server_task
+
+        return launched_server
+
+    def kill(self):
+        return self._server_task.cancel()
+
+    async def get_contrastive_rollouts(self, request: RolloutRequest) -> list[ContrastiveRollout]:
+        assert isinstance(
+            request.inference_interface, InferenceServer
+        ), "Rollout requests to remote server must contain an InferenceServer object"
+        payload = request.model_dump()
+        payload["inference_interface"] = request.inference_interface.model_dump()
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"http://{self.env_server_host_port}/contrastive_rollouts/",
+                json=payload,
+                timeout=None,
+            )
+        rollouts = [ContrastiveRollout.model_validate(r) for r in response.json()]
+        return rollouts
+
+    async def group_rollout(self, request: GroupedRolloutRequest):
+        assert (
+            False
+        ), "Calling group_rollout on FastAPIEnvServer is not supported, use get_grouped_rollouts"
+
+    async def get_grouped_rollouts(
+        self, request: GroupedRolloutRequest
+    ) -> AsyncGenerator[list[TokenRollout], None]:
+        assert isinstance(
+            request.inference_interface, InferenceServer
+        ), "Rollout requests to remote server must contain an InferenceServer object"
+        assert request.num_groups != -1, "FastAPIEnvServer does not support group rollout streaming"
+        payload = request.model_dump()
+        payload["inference_interface"] = request.inference_interface.model_dump()
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"http://{self.env_server_host_port}/grouped_rollouts/", json=payload, timeout=None
+            )
+        rollouts = [[TokenRollout.model_validate(r) for r in group] for group in response.json()]
+        for rollout in rollouts:
+            yield rollout
+
+    async def rollout(self, request: RolloutRequest) -> TokenRollout:
+        assert (
+            False
+        ), "Calling rollout on FastAPIEnvServer is not supported, use get_reward_rollouts"
+
+    async def get_reward_rollouts(self, request: RolloutRequest) -> list[TokenRollout]:
+        assert isinstance(
+            request.inference_interface, InferenceServer
+        ), "Rollout requests to remote server must contain an InferenceServer object"
+        payload = request.model_dump()
+        payload["inference_interface"] = request.inference_interface.model_dump()
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"http://{self.env_server_host_port}/rollouts/", json=payload, timeout=None
+            )
+        rollouts = [TokenRollout.model_validate(r) for r in response.json()]
+        return rollouts
+
+    async def run_evaluation(self, request: EvaluationRequest) -> EvaluationResponse:
+        assert isinstance(
+            request.inference_interface, InferenceServer
+        ), "Evaluation requests to remote server must contain an InferenceServer object"
+        payload = request.model_dump()
+        payload["inference_interface"] = request.inference_interface.model_dump()
+        async with httpx.AsyncClient(timeout=None) as client:
+            response = await client.post(
+                f"http://{self.env_server_host_port}/evaluation/", json=payload, timeout=None
+            )
+        response = EvaluationResponse.model_validate(response.json()).unwrap()
+        return response
+
+
+def run(agent_cls: type[Agent], cls_args: dict, port: int):
+    loop = asyncio.new_event_loop()
+
+    async def run_server():
+        server: FastAPIEnvServer = await FastAPIEnvServer.launch(
+            env_cls=agent_cls, cls_args=cls_args, port=port
+        )
+        print(server.model_dump(exclude={'_server_task'}))
+        await server._server_task
+
+    loop.run_until_complete(run_server())
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--env-config", type=str, required=True)
+    parser.add_argument("--port", type=int, default=8000)
+    args = parser.parse_args()
+    with open(args.env_config, 'r') as f:
+        config = yaml.safe_load(f)[0]
+    agent_cls = import_class(config['agent_type'])
+    cls_args = config['agent_args']
+    run(agent_cls, cls_args, port=args.port)