diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml new file mode 100644 index 000000000000..4dfdfa6f1e3a --- /dev/null +++ b/.github/workflows/integration.yml @@ -0,0 +1,55 @@ +name: Integration + +on: + workflow_dispatch: + inputs: + branch: + description: 'Branch to run tests' + required: true + type: string + +jobs: + test: + runs-on: ubuntu-latest + environment: integration + strategy: + matrix: + package: + [ + "./packages/autogen-core", + "./packages/autogen-ext", + "./packages/autogen-agentchat", + ] + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ github.event.inputs.branch }} + - uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + version: "0.5.18" + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Run uv sync + run: | + uv sync --locked --all-extras + echo "PKG_NAME=$(basename '${{ matrix.package }}')" >> $GITHUB_ENV + + working-directory: ./python + - name: Run task + run: | + source ${{ github.workspace }}/python/.venv/bin/activate + poe --directory ${{ matrix.package }} test + working-directory: ./python + + - name: Move coverage file + run: | + mv ${{ matrix.package }}/coverage.xml coverage_${{ env.PKG_NAME }}.xml + working-directory: ./python + + - name: Upload coverage artifact + uses: actions/upload-artifact@v4 + with: + name: coverage-${{ env.PKG_NAME }} + path: ./python/coverage_${{ env.PKG_NAME }}.xml diff --git a/dotnet/src/Microsoft.AutoGen/Contracts/AgentProxy.cs b/dotnet/src/Microsoft.AutoGen/Contracts/AgentProxy.cs index 44ad9b0e10b2..d37d6284b7d0 100644 --- a/dotnet/src/Microsoft.AutoGen/Contracts/AgentProxy.cs +++ b/dotnet/src/Microsoft.AutoGen/Contracts/AgentProxy.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AgentProxy.cs +using System.Text.Json; + namespace Microsoft.AutoGen.Contracts; /// @@ -55,7 +57,7 @@ private T ExecuteAndUnwrap(Func> delegate_) /// /// A dictionary representing the state of the agent. Must be JSON serializable. /// A task representing the asynchronous operation. - public ValueTask LoadStateAsync(IDictionary state) + public ValueTask LoadStateAsync(IDictionary state) { return this.runtime.LoadAgentStateAsync(this.Id, state); } @@ -64,7 +66,7 @@ public ValueTask LoadStateAsync(IDictionary state) /// Saves the state of the agent. The result must be JSON serializable. /// /// A task representing the asynchronous operation, returning a dictionary containing the saved state. - public ValueTask> SaveStateAsync() + public ValueTask> SaveStateAsync() { return this.runtime.SaveAgentStateAsync(this.Id); } diff --git a/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs index 0d84fbe72d37..c4b2e998f1b0 100644 --- a/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // IAgentRuntime.cs -using StateDict = System.Collections.Generic.IDictionary; +using StateDict = System.Collections.Generic.IDictionary; namespace Microsoft.AutoGen.Contracts; diff --git a/dotnet/src/Microsoft.AutoGen/Contracts/ISaveState.cs b/dotnet/src/Microsoft.AutoGen/Contracts/ISaveState.cs index ed6d15d1d8d6..4f98f1fc4842 100644 --- a/dotnet/src/Microsoft.AutoGen/Contracts/ISaveState.cs +++ b/dotnet/src/Microsoft.AutoGen/Contracts/ISaveState.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // ISaveState.cs -using StateDict = System.Collections.Generic.IDictionary; +using StateDict = System.Collections.Generic.IDictionary; namespace Microsoft.AutoGen.Contracts; diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs index 1ff1036016d1..46114884326b 100644 --- a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs @@ -2,6 +2,7 @@ // GrpcAgentRuntime.cs using System.Collections.Concurrent; +using System.Text.Json; using Grpc.Core; using Microsoft.AutoGen.Contracts; using Microsoft.AutoGen.Protobuf; @@ -319,13 +320,13 @@ public async ValueTask PublishMessageAsync(object message, TopicId topic, Contra public ValueTask GetAgentAsync(string agent, string key = "default", bool lazy = true) => this.GetAgentAsync(new Contracts.AgentId(agent, key), lazy); - public async ValueTask> SaveAgentStateAsync(Contracts.AgentId agentId) + public async ValueTask> SaveAgentStateAsync(Contracts.AgentId agentId) { IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId); return await agent.SaveStateAsync(); } - public async ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary state) + public async ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary state) { IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId); await agent.LoadStateAsync(state); @@ -375,37 +376,41 @@ public ValueTask TryGetAgentProxyAsync(Contracts.AgentId agentId) return ValueTask.FromResult(new AgentProxy(agentId, this)); } - public async ValueTask> SaveStateAsync() - { - Dictionary state = new(); - foreach (var agent in this._agentsContainer.LiveAgents) - { - state[agent.Id.ToString()] = await agent.SaveStateAsync(); - } - - return state; - } - - public async ValueTask LoadStateAsync(IDictionary state) + public async ValueTask LoadStateAsync(IDictionary state) { HashSet registeredTypes = this._agentsContainer.RegisteredAgentTypes; foreach (var agentIdStr in state.Keys) { Contracts.AgentId agentId = Contracts.AgentId.FromStr(agentIdStr); - if (state[agentIdStr] is not IDictionary agentStateDict) + + if (state[agentIdStr].ValueKind != JsonValueKind.Object) { - throw new Exception($"Agent state for {agentId} is not a {typeof(IDictionary)}: {state[agentIdStr].GetType()}"); + throw new Exception($"Agent state for {agentId} is not a valid JSON object."); } + var agentState = JsonSerializer.Deserialize>(state[agentIdStr].GetRawText()) + ?? throw new Exception($"Failed to deserialize state for {agentId}."); + if (registeredTypes.Contains(agentId.Type)) { IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId); - await agent.LoadStateAsync(agentStateDict); + await agent.LoadStateAsync(agentState); } } } + public async ValueTask> SaveStateAsync() + { + Dictionary state = new(); + foreach (var agent in this._agentsContainer.LiveAgents) + { + var agentState = await agent.SaveStateAsync(); + state[agent.Id.ToString()] = JsonSerializer.SerializeToElement(agentState); + } + return state; + } + public async ValueTask OnMessageAsync(Message message, CancellationToken cancellation = default) { switch (message.MessageCase) diff --git a/dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs b/dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs index 99ff001ba98a..a3899280fef4 100644 --- a/dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs +++ b/dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Reflection; +using System.Text.Json; using Microsoft.AutoGen.Contracts; using Microsoft.Extensions.Logging; @@ -92,11 +93,11 @@ private Dictionary ReflectInvokers() return null; } - public virtual ValueTask> SaveStateAsync() + public virtual ValueTask> SaveStateAsync() { - return ValueTask.FromResult>(new Dictionary()); + return ValueTask.FromResult>(new Dictionary()); } - public virtual ValueTask LoadStateAsync(IDictionary state) + public virtual ValueTask LoadStateAsync(IDictionary state) { return ValueTask.CompletedTask; } diff --git a/dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs index 69b2d314e550..9acf96e648fc 100644 --- a/dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs @@ -3,6 +3,7 @@ using System.Collections.Concurrent; using System.Diagnostics; +using System.Text.Json; using Microsoft.AutoGen.Contracts; using Microsoft.Extensions.Hosting; @@ -12,7 +13,7 @@ public sealed class InProcessRuntime : IAgentRuntime, IHostedService { public bool DeliverToSelf { get; set; } //= false; - Dictionary agentInstances = new(); + internal Dictionary agentInstances = new(); Dictionary subscriptions = new(); Dictionary>> agentFactories = new(); @@ -152,13 +153,13 @@ public async ValueTask GetAgentMetadataAsync(AgentId agentId) return agent.Metadata; } - public async ValueTask LoadAgentStateAsync(AgentId agentId, IDictionary state) + public async ValueTask LoadAgentStateAsync(AgentId agentId, IDictionary state) { IHostableAgent agent = await this.EnsureAgentAsync(agentId); await agent.LoadStateAsync(state); } - public async ValueTask> SaveAgentStateAsync(AgentId agentId) + public async ValueTask> SaveAgentStateAsync(AgentId agentId) { IHostableAgent agent = await this.EnsureAgentAsync(agentId); return await agent.SaveStateAsync(); @@ -187,16 +188,21 @@ public ValueTask RemoveSubscriptionAsync(string subscriptionId) return ValueTask.CompletedTask; } - public async ValueTask LoadStateAsync(IDictionary state) + public async ValueTask LoadStateAsync(IDictionary state) { foreach (var agentIdStr in state.Keys) { AgentId agentId = AgentId.FromStr(agentIdStr); - if (state[agentIdStr] is not IDictionary agentState) + + if (state[agentIdStr].ValueKind != JsonValueKind.Object) { - throw new Exception($"Agent state for {agentId} is not a {typeof(IDictionary)}: {state[agentIdStr].GetType()}"); + throw new Exception($"Agent state for {agentId} is not a valid JSON object."); } + // Deserialize before using + var agentState = JsonSerializer.Deserialize>(state[agentIdStr].GetRawText()) + ?? throw new Exception($"Failed to deserialize state for {agentId}."); + if (this.agentFactories.ContainsKey(agentId.Type)) { IHostableAgent agent = await this.EnsureAgentAsync(agentId); @@ -205,14 +211,14 @@ public async ValueTask LoadStateAsync(IDictionary state) } } - public async ValueTask> SaveStateAsync() + public async ValueTask> SaveStateAsync() { - Dictionary state = new(); + Dictionary state = new(); foreach (var agentId in this.agentInstances.Keys) { - state[agentId.ToString()] = await this.agentInstances[agentId].SaveStateAsync(); + var agentState = await this.agentInstances[agentId].SaveStateAsync(); + state[agentId.ToString()] = JsonSerializer.SerializeToElement(agentState); } - return state; } diff --git a/dotnet/src/Microsoft.AutoGen/Core/Properties/AssemblyInfo.cs b/dotnet/src/Microsoft.AutoGen/Core/Properties/AssemblyInfo.cs new file mode 100644 index 000000000000..8ff44481719e --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core/Properties/AssemblyInfo.cs @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AssemblyInfo.cs + +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Microsoft.AutoGen.Core.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f1d038d0b85ae392ad72011df91e9343b0b5df1bb8080aa21b9424362d696919e0e9ac3a8bca24e283e10f7a569c6f443e1d4e3ebc84377c87ca5caa562e80f9932bf5ea91b7862b538e13b8ba91c7565cf0e8dfeccfea9c805ae3bda044170ecc7fc6f147aeeac422dd96aeb9eb1f5a5882aa650efe2958f2f8107d2038f2ab")] diff --git a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/GrpcAgentServiceFixture.cs b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/GrpcAgentServiceFixture.cs index 98c47764269d..1ca37809a57e 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/GrpcAgentServiceFixture.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/GrpcAgentServiceFixture.cs @@ -25,8 +25,6 @@ public override async Task OpenChannel(IAsyncStreamReader requestStream throw; } } - public override async Task GetState(AgentId request, ServerCallContext context) => new GetStateResponse { AgentState = new AgentState { AgentId = request } }; - public override async Task SaveState(AgentState request, ServerCallContext context) => new SaveStateResponse { }; public override async Task AddSubscription(AddSubscriptionRequest request, ServerCallContext context) => new AddSubscriptionResponse { }; public override async Task RemoveSubscription(RemoveSubscriptionRequest request, ServerCallContext context) => new RemoveSubscriptionResponse { }; public override async Task GetSubscriptions(GetSubscriptionsRequest request, ServerCallContext context) => new GetSubscriptionsResponse { }; diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs deleted file mode 100644 index 812d47c2d207..000000000000 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// AgentRuntimeTests.cs -using FluentAssertions; -using Microsoft.AutoGen.Contracts; -using Microsoft.Extensions.Logging; -using Xunit; - -namespace Microsoft.AutoGen.Core.Tests; - -[Trait("Category", "UnitV2")] -public class AgentRuntimeTests() -{ - // Agent will not deliver to self will success when runtime.DeliverToSelf is false (default) - [Fact] - public async Task RuntimeAgentPublishToSelfDefaultNoSendTest() - { - var runtime = new InProcessRuntime(); - await runtime.StartAsync(); - - Logger logger = new(new LoggerFactory()); - SubscribedSelfPublishAgent agent = null!; - - await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => - { - agent = new SubscribedSelfPublishAgent(id, runtime, logger); - return ValueTask.FromResult(agent); - }); - - // Ensure the agent is actually created - AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false); - - // Validate agent ID - agentId.Should().Be(agent.Id, "Agent ID should match the registered agent"); - - await runtime.RegisterImplicitAgentSubscriptionsAsync("MyAgent"); - - var topicType = "TestTopic"; - - await runtime.PublishMessageAsync("SelfMessage", new TopicId(topicType)).ConfigureAwait(true); - - await runtime.RunUntilIdleAsync(); - - // Agent has default messages and could not publish to self - agent.Text.Source.Should().Be("DefaultTopic"); - agent.Text.Content.Should().Be("DefaultContent"); - } - - // Agent delivery to self will success when runtime.DeliverToSelf is true - [Fact] - public async Task RuntimeAgentPublishToSelfDeliverToSelfTrueTest() - { - var runtime = new InProcessRuntime(); - runtime.DeliverToSelf = true; - await runtime.StartAsync(); - - Logger logger = new(new LoggerFactory()); - SubscribedSelfPublishAgent agent = null!; - - await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => - { - agent = new SubscribedSelfPublishAgent(id, runtime, logger); - return ValueTask.FromResult(agent); - }); - - // Ensure the agent is actually created - AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false); - - // Validate agent ID - agentId.Should().Be(agent.Id, "Agent ID should match the registered agent"); - - await runtime.RegisterImplicitAgentSubscriptionsAsync("MyAgent"); - - var topicType = "TestTopic"; - - await runtime.PublishMessageAsync("SelfMessage", new TopicId(topicType)).ConfigureAwait(true); - - await runtime.RunUntilIdleAsync(); - - // Agent sucessfully published to self - agent.Text.Source.Should().Be("TestTopic"); - agent.Text.Content.Should().Be("SelfMessage"); - } -} diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs index c091f9eb7478..805fbc87102b 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs @@ -54,7 +54,7 @@ await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => return ValueTask.FromResult(agent); }); - // Ensure the agent is actually created + // Ensure the agent id is registered AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false); // Validate agent ID @@ -146,25 +146,4 @@ await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => Assert.True(agent.ReceivedItems.Count == 1); } - - [Fact] - public async Task AgentShouldSaveStateCorrectlyTest() - { - var runtime = new InProcessRuntime(); - await runtime.StartAsync(); - - Logger logger = new(new LoggerFactory()); - TestAgent agent = new TestAgent(new AgentId("TestType", "TestKey"), runtime, logger); - - var state = await agent.SaveStateAsync(); - - // Ensure state is a dictionary - state.Should().NotBeNull(); - state.Should().BeOfType>(); - state.Should().BeEmpty("Default SaveStateAsync should return an empty dictionary."); - - // Add a sample value and verify it updates correctly - state["testKey"] = "testValue"; - state.Should().ContainKey("testKey").WhoseValue.Should().Be("testValue"); - } } diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/InProcessRuntimeTests.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/InProcessRuntimeTests.cs new file mode 100644 index 000000000000..174f8b7817c2 --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/InProcessRuntimeTests.cs @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// InProcessRuntimeTests.cs +using System.Text.Json; +using FluentAssertions; +using Microsoft.AutoGen.Contracts; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace Microsoft.AutoGen.Core.Tests; + +[Trait("Category", "UnitV2")] +public class InProcessRuntimeTests() +{ + // Agent will not deliver to self will success when runtime.DeliverToSelf is false (default) + [Fact] + public async Task RuntimeAgentPublishToSelfDefaultNoSendTest() + { + var runtime = new InProcessRuntime(); + await runtime.StartAsync(); + + Logger logger = new(new LoggerFactory()); + SubscribedSelfPublishAgent agent = null!; + + await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => + { + agent = new SubscribedSelfPublishAgent(id, runtime, logger); + return ValueTask.FromResult(agent); + }); + + // Ensure the agent is actually created + AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false); + + // Validate agent ID + agentId.Should().Be(agent.Id, "Agent ID should match the registered agent"); + + await runtime.RegisterImplicitAgentSubscriptionsAsync("MyAgent"); + + var topicType = "TestTopic"; + + await runtime.PublishMessageAsync("SelfMessage", new TopicId(topicType)).ConfigureAwait(true); + + await runtime.RunUntilIdleAsync(); + + // Agent has default messages and could not publish to self + agent.Text.Source.Should().Be("DefaultTopic"); + agent.Text.Content.Should().Be("DefaultContent"); + } + + // Agent delivery to self will success when runtime.DeliverToSelf is true + [Fact] + public async Task RuntimeAgentPublishToSelfDeliverToSelfTrueTest() + { + var runtime = new InProcessRuntime(); + runtime.DeliverToSelf = true; + await runtime.StartAsync(); + + Logger logger = new(new LoggerFactory()); + SubscribedSelfPublishAgent agent = null!; + + await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => + { + agent = new SubscribedSelfPublishAgent(id, runtime, logger); + return ValueTask.FromResult(agent); + }); + + // Ensure the agent is actually created + AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false); + + // Validate agent ID + agentId.Should().Be(agent.Id, "Agent ID should match the registered agent"); + + await runtime.RegisterImplicitAgentSubscriptionsAsync("MyAgent"); + + var topicType = "TestTopic"; + + await runtime.PublishMessageAsync("SelfMessage", new TopicId(topicType)).ConfigureAwait(true); + + await runtime.RunUntilIdleAsync(); + + // Agent sucessfully published to self + agent.Text.Source.Should().Be("TestTopic"); + agent.Text.Content.Should().Be("SelfMessage"); + } + + [Fact] + public async Task RuntimeShouldSaveLoadStateCorrectlyTest() + { + // Create a runtime and register an agent + var runtime = new InProcessRuntime(); + await runtime.StartAsync(); + Logger logger = new(new LoggerFactory()); + SubscribedSaveLoadAgent agent = null!; + await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => + { + agent = new SubscribedSaveLoadAgent(id, runtime, logger); + return ValueTask.FromResult(agent); + }); + + // Get agent ID and instantiate agent by publishing + AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: true); + await runtime.RegisterImplicitAgentSubscriptionsAsync("MyAgent"); + var topicType = "TestTopic"; + await runtime.PublishMessageAsync(new TextMessage { Source = topicType, Content = "test" }, new TopicId(topicType)).ConfigureAwait(true); + await runtime.RunUntilIdleAsync(); + agent.ReceivedMessages.Any().Should().BeTrue("Agent should receive messages when subscribed."); + + // Save the state + var savedState = await runtime.SaveStateAsync(); + + // Ensure saved state contains the agent's state + savedState.Should().ContainKey(agentId.ToString()); + + // Ensure the agent's state is stored as a valid JSON object + savedState[agentId.ToString()].ValueKind.Should().Be(JsonValueKind.Object, "Agent state should be stored as a JSON object"); + + // Serialize and Deserialize the state to simulate persistence + string json = JsonSerializer.Serialize(savedState); + json.Should().NotBeNullOrEmpty("Serialized state should not be empty"); + var deserializedState = JsonSerializer.Deserialize>(json) + ?? throw new Exception("Deserialized state is unexpectedly null"); + deserializedState.Should().ContainKey(agentId.ToString()); + + // Start new runtime and restore the state + var newRuntime = new InProcessRuntime(); + await newRuntime.StartAsync(); + await newRuntime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => + { + agent = new SubscribedSaveLoadAgent(id, runtime, logger); + return ValueTask.FromResult(agent); + }); + await newRuntime.RegisterImplicitAgentSubscriptionsAsync("MyAgent"); + + // Show that no agent instances exist in the new runtime + newRuntime.agentInstances.Count.Should().Be(0, "Agent should be registered in the new runtime"); + + // Load the state into the new runtime and show that agent is now instantiated + await newRuntime.LoadStateAsync(deserializedState); + newRuntime.agentInstances.Count.Should().Be(1, "Agent should be registered in the new runtime"); + newRuntime.agentInstances.Should().ContainKey(agentId, "Agent should be loaded into the new runtime"); + } +} diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs index b6dadc833be2..ed87a71053af 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // TestAgent.cs +using System.Text.Json; using Microsoft.AutoGen.Contracts; using Microsoft.Extensions.Logging; @@ -59,7 +60,7 @@ public ValueTask HandleAsync(RpcTextMessage item, MessageContext message /// Key: source /// Value: message /// - private readonly Dictionary _receivedMessages = new(); + protected Dictionary _receivedMessages = new(); public Dictionary ReceivedMessages => _receivedMessages; } @@ -73,6 +74,38 @@ public SubscribedAgent(AgentId id, } } +[TypeSubscription("TestTopic")] +public class SubscribedSaveLoadAgent : TestAgent +{ + public SubscribedSaveLoadAgent(AgentId id, + IAgentRuntime runtime, + Logger? logger = null) : base(id, runtime, logger) + { + } + + public override ValueTask> SaveStateAsync() + { + var jsonSafeDictionary = _receivedMessages.ToDictionary( + kvp => kvp.Key, + kvp => JsonSerializer.SerializeToElement(kvp.Value) // Convert each object to JsonElement + ); + + return ValueTask.FromResult>(jsonSafeDictionary); + } + + public override ValueTask LoadStateAsync(IDictionary state) + { + _receivedMessages.Clear(); + + foreach (var kvp in state) + { + _receivedMessages[kvp.Key] = kvp.Value.Deserialize() ?? throw new Exception($"Failed to deserialize key: {kvp.Key}"); + } + + return ValueTask.CompletedTask; + } +} + /// /// The test agent showing an agent that subscribes to itself. /// diff --git a/protos/agent_worker.proto b/protos/agent_worker.proto index 07375fe97250..52fe809a20c9 100644 --- a/protos/agent_worker.proto +++ b/protos/agent_worker.proto @@ -79,23 +79,6 @@ message GetSubscriptionsResponse { repeated Subscription subscriptions = 1; } -message AgentState { - AgentId agent_id = 1; - string eTag = 2; - oneof data { - bytes binary_data = 3; - string text_data = 4; - google.protobuf.Any proto_data = 5; - } -} - -message GetStateResponse { - AgentState agent_state = 1; -} - -message SaveStateResponse { -} - message Message { oneof message { RpcRequest request = 1; @@ -104,10 +87,46 @@ message Message { } } +message SaveStateRequest { + AgentId agentId = 1; +} + +message SaveStateResponse { + string state = 1; + optional string error = 2; +} + +message LoadStateRequest { + AgentId agentId = 1; + string state = 2; +} +message LoadStateResponse { + optional string error = 1; +} + +message ControlMessage { + // A response message should have the same id as the request message + string rpc_id = 1; + // This is either: + // agentid=AGENT_ID + // clientid=CLIENT_ID + string destination = 2; + // This is either: + // agentid=AGENT_ID + // clientid=CLIENT_ID + // Empty string means the message is a response + optional string respond_to = 3; + // One of: + // SaveStateRequest saveStateRequest = 2; + // SaveStateResponse saveStateResponse = 3; + // LoadStateRequest loadStateRequest = 4; + // LoadStateResponse loadStateResponse = 5; + google.protobuf.Any rpcMessage = 4; +} + service AgentRpc { rpc OpenChannel (stream Message) returns (stream Message); - rpc GetState(AgentId) returns (GetStateResponse); - rpc SaveState(AgentState) returns (SaveStateResponse); + rpc OpenControlChannel (stream ControlMessage) returns (stream ControlMessage); rpc RegisterAgent(RegisterAgentTypeRequest) returns (RegisterAgentTypeResponse); rpc AddSubscription(AddSubscriptionRequest) returns (AddSubscriptionResponse); rpc RemoveSubscription(RemoveSubscriptionRequest) returns (RemoveSubscriptionResponse); diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index bf4bc95946ad..de0ef3247c69 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, List, Mapping, Sequence from autogen_core import Component, ComponentModel -from autogen_core.models import ChatCompletionClient, SystemMessage, UserMessage +from autogen_core.models import ChatCompletionClient, ModelFamily, SystemMessage, UserMessage from pydantic import BaseModel from typing_extensions import Self @@ -110,18 +110,17 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str: message += " [Image]" else: raise ValueError(f"Unexpected message type in selector: {type(msg)}") - history_messages.append(message) + history_messages.append( + message.rstrip() + "\n\n" + ) # Create some consistency for how messages are separated in the transcript history = "\n".join(history_messages) # Construct agent roles, we are using the participant topic type as the agent name. - roles = "\n".join( - [ - f"{topic_type}: {description}".strip() - for topic_type, description in zip( - self._participant_topic_types, self._participant_descriptions, strict=True - ) - ] - ) + # Each agent sould appear on a single line. + roles = "" + for topic_type, description in zip(self._participant_topic_types, self._participant_descriptions, strict=True): + roles += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n" + roles = roles.strip() # Construct agent list to be selected, skip the previous speaker if not allowed. if self._previous_speaker is not None and not self._allow_repeated_speaker: @@ -136,11 +135,20 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str: roles=roles, participants=str(participants), history=history ) select_speaker_messages: List[SystemMessage | UserMessage] - if self._model_client.model_info["family"].startswith("gemini"): - select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="selector")] - else: + if self._model_client.model_info["family"] in [ + ModelFamily.GPT_4, + ModelFamily.GPT_4O, + ModelFamily.GPT_35, + ModelFamily.O1, + ModelFamily.O3, + ]: select_speaker_messages = [SystemMessage(content=select_speaker_prompt)] + else: + # Many other models need a UserMessage to respond to + select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="selector")] + response = await self._model_client.create(messages=select_speaker_messages) + assert isinstance(response.content, str) mentions = self._mentioned_agents(response.content, self._participant_topic_types) if len(mentions) != 1: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py b/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py index 1f80166d32a1..0a95c842ea08 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py @@ -75,8 +75,8 @@ def notify_event_received(self, request_id: str) -> None: self.input_events[request_id] = event -def aprint(output: str, end: str = "\n") -> Awaitable[None]: - return asyncio.to_thread(print, output, end=end) +def aprint(output: str, end: str = "\n", flush: bool = False) -> Awaitable[None]: + return asyncio.to_thread(print, output, end=end, flush=flush) async def Console( @@ -126,7 +126,7 @@ async def Console( f"Total completion tokens: {total_usage.completion_tokens}\n" f"Duration: {duration:.2f} seconds\n" ) - await aprint(output, end="") + await aprint(output, end="", flush=True) # mypy ignore last_processed = message # type: ignore @@ -141,7 +141,7 @@ async def Console( output += f"[Prompt tokens: {message.chat_message.models_usage.prompt_tokens}, Completion tokens: {message.chat_message.models_usage.completion_tokens}]\n" total_usage.completion_tokens += message.chat_message.models_usage.completion_tokens total_usage.prompt_tokens += message.chat_message.models_usage.prompt_tokens - await aprint(output, end="") + await aprint(output, end="", flush=True) # Print summary. if output_stats: @@ -156,7 +156,7 @@ async def Console( f"Total completion tokens: {total_usage.completion_tokens}\n" f"Duration: {duration:.2f} seconds\n" ) - await aprint(output, end="") + await aprint(output, end="", flush=True) # mypy ignore last_processed = message # type: ignore @@ -169,7 +169,7 @@ async def Console( message = cast(AgentEvent | ChatMessage, message) # type: ignore if not streaming_chunks: # Print message sender. - await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n") + await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n", flush=True) if isinstance(message, ModelClientStreamingChunkEvent): await aprint(message.content, end="") streaming_chunks.append(message.content) @@ -177,15 +177,16 @@ async def Console( if streaming_chunks: streaming_chunks.clear() # Chunked messages are already printed, so we just print a newline. - await aprint("", end="\n") + await aprint("", end="\n", flush=True) else: # Print message content. - await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n") + await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n", flush=True) if message.models_usage: if output_stats: await aprint( f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]", end="\n", + flush=True, ) total_usage.completion_tokens += message.models_usage.completion_tokens total_usage.prompt_tokens += message.models_usage.prompt_tokens diff --git a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py index 73c0799fff38..2c62caf2cad0 100644 --- a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py +++ b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py @@ -10,12 +10,10 @@ from typing import ( Any, AsyncGenerator, - BinaryIO, Dict, List, Optional, Sequence, - cast, ) from urllib.parse import quote_plus @@ -31,6 +29,7 @@ AssistantMessage, ChatCompletionClient, LLMMessage, + ModelFamily, RequestUsage, SystemMessage, UserMessage, @@ -42,7 +41,6 @@ from ._events import WebSurferEvent from ._prompts import ( - WEB_SURFER_OCR_PROMPT, WEB_SURFER_QA_PROMPT, WEB_SURFER_QA_SYSTEM_MESSAGE, WEB_SURFER_TOOL_PROMPT_MM, @@ -444,6 +442,22 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> UserCo # Clone the messages, removing old screenshots history: List[LLMMessage] = remove_images(self._chat_history) + # Split the history, removing the last message + if len(history): + user_request = history.pop() + else: + user_request = UserMessage(content="Empty request.", source="user") + + # Truncate the history for smaller models + if self._model_client.model_info["family"] not in [ + ModelFamily.GPT_4O, + ModelFamily.O1, + ModelFamily.O3, + ModelFamily.GPT_4, + ModelFamily.GPT_35, + ]: + history = [] + # Ask the page for interactive elements, then prepare the state-of-mark screenshot rects = await self._playwright_controller.get_interactive_rects(self._page) viewport = await self._playwright_controller.get_visual_viewport(self._page) @@ -499,21 +513,31 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> UserCo other_targets.extend(self._format_target_list(rects_below, rects)) if len(other_targets) > 0: + if len(other_targets) > 30: + other_targets = other_targets[0:30] + other_targets.append("...") other_targets_str = ( - "Additional valid interaction targets (not shown) include:\n" + "\n".join(other_targets) + "\n\n" + "Additional valid interaction targets include (but are not limited to):\n" + + "\n".join(other_targets) + + "\n\n" ) else: other_targets_str = "" + state_description = "Your " + await self._get_state_description() tool_names = "\n".join([t["name"] for t in tools]) + page_title = await self._page.title() + prompt_message = None if self._model_client.model_info["vision"]: text_prompt = WEB_SURFER_TOOL_PROMPT_MM.format( - url=self._page.url, + state_description=state_description, visible_targets=visible_targets, other_targets_str=other_targets_str, focused_hint=focused_hint, tool_names=tool_names, + title=page_title, + url=self._page.url, ).strip() # Scale the screenshot for the MLM, and close the original @@ -522,26 +546,42 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> UserCo if self.to_save_screenshots: scaled_screenshot.save(os.path.join(self.debug_dir, "screenshot_scaled.png")) # type: ignore - # Add the message - history.append(UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.name)) + # Create the message + prompt_message = UserMessage( + content=[re.sub(r"(\n\s*){3,}", "\n\n", text_prompt), AGImage.from_pil(scaled_screenshot)], + source=self.name, + ) else: - visible_text = await self._playwright_controller.get_visible_text(self._page) - text_prompt = WEB_SURFER_TOOL_PROMPT_TEXT.format( - url=self._page.url, + state_description=state_description, visible_targets=visible_targets, other_targets_str=other_targets_str, focused_hint=focused_hint, tool_names=tool_names, - visible_text=visible_text.strip(), + title=page_title, + url=self._page.url, ).strip() - # Add the message - history.append(UserMessage(content=text_prompt, source=self.name)) + # Create the message + prompt_message = UserMessage(content=re.sub(r"(\n\s*){3,}", "\n\n", text_prompt), source=self.name) + + history.append(prompt_message) + history.append(user_request) + + # {history[-2].content if isinstance(history[-2].content, str) else history[-2].content[0]} + # print(f""" + # ================={len(history)}================= + # {history[-2].content} + # ===== + # {history[-1].content} + # =================================================== + # """) + # Make the request response = await self._model_client.create( history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token ) # , "parallel_tool_calls": False}) + self.model_usage.append(response.usage) message = response.content self._last_download = None @@ -716,23 +756,12 @@ async def _execute_tool( metadata_hash = hashlib.md5(page_metadata.encode("utf-8")).hexdigest() if metadata_hash != self._prior_metadata_hash: page_metadata = ( - "\nThe following metadata was extracted from the webpage:\n\n" + page_metadata.strip() + "\n" + "\n\nThe following metadata was extracted from the webpage:\n\n" + page_metadata.strip() + "\n" ) else: page_metadata = "" self._prior_metadata_hash = metadata_hash - # Describe the viewport of the new page in words - viewport = await self._playwright_controller.get_visual_viewport(self._page) - percent_visible = int(viewport["height"] * 100 / viewport["scrollHeight"]) - percent_scrolled = int(viewport["pageTop"] * 100 / viewport["scrollHeight"]) - if percent_scrolled < 1: # Allow some rounding error - position_text = "at the top of the page" - elif percent_scrolled + percent_visible >= 99: # Allow some rounding error - position_text = "at the bottom of the page" - else: - position_text = str(percent_scrolled) + "% down from the top of the page" - new_screenshot = await self._page.screenshot() if self.to_save_screenshots: current_timestamp = "_" + int(time.time()).__str__() @@ -748,25 +777,40 @@ async def _execute_tool( ) ) - ocr_text = ( - await self._get_ocr_text(new_screenshot, cancellation_token=cancellation_token) - if self.use_ocr is True - else await self._playwright_controller.get_visible_text(self._page) - ) - # Return the complete observation - page_title = await self._page.title() - message_content = f"{action_description}\n\n Here is a screenshot of the webpage: [{page_title}]({self._page.url}).\n The viewport shows {percent_visible}% of the webpage, and is positioned {position_text} {page_metadata}\n" - if self.use_ocr: - message_content += f"Automatic OCR of the page screenshot has detected the following text:\n\n{ocr_text}" - else: - message_content += f"The following text is visible in the viewport:\n\n{ocr_text}" + state_description = "The " + await self._get_state_description() + message_content = ( + f"{action_description}\n\n" + state_description + page_metadata + "\nHere is a screenshot of the page." + ) return [ - message_content, + re.sub(r"(\n\s*){3,}", "\n\n", message_content), # Removing blank lines AGImage.from_pil(PIL.Image.open(io.BytesIO(new_screenshot))), ] + async def _get_state_description(self) -> str: + assert self._playwright_controller is not None + assert self._page is not None + + # Describe the viewport of the new page in words + viewport = await self._playwright_controller.get_visual_viewport(self._page) + percent_visible = int(viewport["height"] * 100 / viewport["scrollHeight"]) + percent_scrolled = int(viewport["pageTop"] * 100 / viewport["scrollHeight"]) + if percent_scrolled < 1: # Allow some rounding error + position_text = "at the top of the page" + elif percent_scrolled + percent_visible >= 99: # Allow some rounding error + position_text = "at the bottom of the page" + else: + position_text = str(percent_scrolled) + "% down from the top of the page" + + visible_text = await self._playwright_controller.get_visible_text(self._page) + + # Return the complete observation + page_title = await self._page.title() + message_content = f"web browser is open to the page [{page_title}]({self._page.url}).\nThe viewport shows {percent_visible}% of the webpage, and is positioned {position_text}\n" + message_content += f"The following text is visible in the viewport:\n\n{visible_text}" + return message_content + def _target_name(self, target: str, rects: Dict[str, InteractiveRegion]) -> str | None: try: return rects[target]["aria_name"].strip() @@ -798,38 +842,6 @@ def _format_target_list(self, ids: List[str], rects: Dict[str, InteractiveRegion return targets - async def _get_ocr_text( - self, image: bytes | io.BufferedIOBase | PIL.Image.Image, cancellation_token: Optional[CancellationToken] = None - ) -> str: - scaled_screenshot = None - if isinstance(image, PIL.Image.Image): - scaled_screenshot = image.resize((self.MLM_WIDTH, self.MLM_HEIGHT)) - else: - pil_image = None - if not isinstance(image, io.BufferedIOBase): - pil_image = PIL.Image.open(io.BytesIO(image)) - else: - pil_image = PIL.Image.open(cast(BinaryIO, image)) - scaled_screenshot = pil_image.resize((self.MLM_WIDTH, self.MLM_HEIGHT)) - pil_image.close() - - # Add the multimodal message and make the request - messages: List[LLMMessage] = [] - messages.append( - UserMessage( - content=[ - WEB_SURFER_OCR_PROMPT, - AGImage.from_pil(scaled_screenshot), - ], - source=self.name, - ) - ) - response = await self._model_client.create(messages, cancellation_token=cancellation_token) - self.model_usage.append(response.usage) - scaled_screenshot.close() - assert isinstance(response.content, str) - return response.content - async def _summarize_page( self, question: str | None = None, diff --git a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_prompts.py b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_prompts.py index 59a0a7c95d5e..d1f1885240e2 100644 --- a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_prompts.py +++ b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_prompts.py @@ -1,43 +1,42 @@ WEB_SURFER_TOOL_PROMPT_MM = """ -Consider the following screenshot of a web browser, which is open to the page '{url}'. In this screenshot, interactive elements are outlined in bounding boxes of different colors. Each bounding box has a numeric ID label in the same color. Additional information about each visible label is listed below: +{state_description} + +Consider the following screenshot of the page. In this screenshot, interactive elements are outlined in bounding boxes of different colors. Each bounding box has a numeric ID label in the same color. Additional information about each visible label is listed below: {visible_targets}{other_targets_str}{focused_hint} -You are to respond to the most recent request by selecting an appropriate tool from the following set, or by answering the question directly if possible without tools: +You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible: {tool_names} When deciding between tools, consider if the request can be best addressed by: - - the contents of the current viewport (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element might be most appropriate) - - contents found elsewhere on the full webpage (in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate) - - on some other website entirely (in which case actions like performing a new web search might be the best option) + - the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate) + - contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate + - on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option) + +My request follows: """ WEB_SURFER_TOOL_PROMPT_TEXT = """ -Your web browser is open to the page '{url}'. The following text is visible in the viewport: - -``` -{visible_text} -``` +{state_description} You have also identified the following interactive components: {visible_targets}{other_targets_str}{focused_hint} -You are to respond to the most recent request by selecting an appropriate tool from the following set, or by answering the question directly if possible without tools: +You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible: {tool_names} When deciding between tools, consider if the request can be best addressed by: - - the contents of the current viewport (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element might be most appropriate) - - contents found elsewhere on the full webpage (in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate) - - on some other website entirely (in which case actions like performing a new web search might be the best option) -""" + - the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate) + - contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate + - on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option) -WEB_SURFER_OCR_PROMPT = """ -Please transcribe all visible text on this page, including both main content and the labels of UI elements. +My request follows: """ + WEB_SURFER_QA_SYSTEM_MESSAGE = """ You are a helpful assistant that can summarize long documents to answer question. """ diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py index 84df493bd1e8..daa4ad65101d 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py @@ -135,6 +135,13 @@ async def handle_callback(message: agent_worker_pb2.Message) -> None: # Remove the client id from the agent type to client id mapping. await self._on_client_disconnect(client_id) + async def OpenControlChannel( # type: ignore + self, + request_iterator: AsyncIterator[agent_worker_pb2.ControlMessage], + context: grpc.aio.ServicerContext[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage], + ) -> AsyncIterator[agent_worker_pb2.ControlMessage]: + raise NotImplementedError("Method not implemented.") + async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None: async with self._agent_type_to_client_id_lock: agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id] @@ -288,17 +295,3 @@ async def GetSubscriptions( # type: ignore ) -> agent_worker_pb2.GetSubscriptionsResponse: _client_id = await get_client_id_or_abort(context) raise NotImplementedError("Method not implemented.") - - async def GetState( # type: ignore - self, - request: agent_worker_pb2.AgentId, - context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.GetStateResponse], - ) -> agent_worker_pb2.GetStateResponse: - raise NotImplementedError("Method not implemented!") - - async def SaveState( # type: ignore - self, - request: agent_worker_pb2.AgentState, - context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.SaveStateResponse], - ) -> agent_worker_pb2.SaveStateResponse: - raise NotImplementedError("Method not implemented!") diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.py index b3e0af61f049..54209d2fb284 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.py @@ -26,7 +26,7 @@ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"(\n\x18RegisterAgentTypeRequest\x12\x0c\n\x04type\x18\x01 \x01(\t\"\x1b\n\x19RegisterAgentTypeResponse\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\xa2\x01\n\x0cSubscription\x12\n\n\x02id\x18\x01 \x01(\t\x12\x34\n\x10typeSubscription\x18\x02 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x03 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"D\n\x16\x41\x64\x64SubscriptionRequest\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\x19\n\x17\x41\x64\x64SubscriptionResponse\"\'\n\x19RemoveSubscriptionRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\x1c\n\x1aRemoveSubscriptionResponse\"\x19\n\x17GetSubscriptionsRequest\"G\n\x18GetSubscriptionsResponse\x12+\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x14.agents.Subscription\"\x9d\x01\n\nAgentState\x12!\n\x08\x61gent_id\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0c\n\x04\x65Tag\x18\x02 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x03 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x04 \x01(\tH\x00\x12*\n\nproto_data\x18\x05 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x42\x06\n\x04\x64\x61ta\";\n\x10GetStateResponse\x12\'\n\x0b\x61gent_state\x18\x01 \x01(\x0b\x32\x12.agents.AgentState\"\x13\n\x11SaveStateResponse\"\x99\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x33\n\ncloudEvent\x18\x03 \x01(\x0b\x32\x1d.io.cloudevents.v1.CloudEventH\x00\x42\t\n\x07message2\x90\x04\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12\x35\n\x08GetState\x12\x0f.agents.AgentId\x1a\x18.agents.GetStateResponse\x12:\n\tSaveState\x12\x12.agents.AgentState\x1a\x19.agents.SaveStateResponse\x12T\n\rRegisterAgent\x12 .agents.RegisterAgentTypeRequest\x1a!.agents.RegisterAgentTypeResponse\x12R\n\x0f\x41\x64\x64Subscription\x12\x1e.agents.AddSubscriptionRequest\x1a\x1f.agents.AddSubscriptionResponse\x12[\n\x12RemoveSubscription\x12!.agents.RemoveSubscriptionRequest\x1a\".agents.RemoveSubscriptionResponse\x12U\n\x10GetSubscriptions\x12\x1f.agents.GetSubscriptionsRequest\x1a .agents.GetSubscriptionsResponseB\x1d\xaa\x02\x1aMicrosoft.AutoGen.Protobufb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"(\n\x18RegisterAgentTypeRequest\x12\x0c\n\x04type\x18\x01 \x01(\t\"\x1b\n\x19RegisterAgentTypeResponse\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\xa2\x01\n\x0cSubscription\x12\n\n\x02id\x18\x01 \x01(\t\x12\x34\n\x10typeSubscription\x18\x02 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x03 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"D\n\x16\x41\x64\x64SubscriptionRequest\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\x19\n\x17\x41\x64\x64SubscriptionResponse\"\'\n\x19RemoveSubscriptionRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\x1c\n\x1aRemoveSubscriptionResponse\"\x19\n\x17GetSubscriptionsRequest\"G\n\x18GetSubscriptionsResponse\x12+\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x14.agents.Subscription\"\x99\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x33\n\ncloudEvent\x18\x03 \x01(\x0b\x32\x1d.io.cloudevents.v1.CloudEventH\x00\x42\t\n\x07message\"4\n\x10SaveStateRequest\x12 \n\x07\x61gentId\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\"@\n\x11SaveStateResponse\x12\r\n\x05state\x18\x01 \x01(\t\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"C\n\x10LoadStateRequest\x12 \n\x07\x61gentId\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\r\n\x05state\x18\x02 \x01(\t\"1\n\x11LoadStateResponse\x12\x12\n\x05\x65rror\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\x87\x01\n\x0e\x43ontrolMessage\x12\x0e\n\x06rpc_id\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65stination\x18\x02 \x01(\t\x12\x17\n\nrespond_to\x18\x03 \x01(\tH\x00\x88\x01\x01\x12(\n\nrpcMessage\x18\x04 \x01(\x0b\x32\x14.google.protobuf.AnyB\r\n\x0b_respond_to2\xe7\x03\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12H\n\x12OpenControlChannel\x12\x16.agents.ControlMessage\x1a\x16.agents.ControlMessage(\x01\x30\x01\x12T\n\rRegisterAgent\x12 .agents.RegisterAgentTypeRequest\x1a!.agents.RegisterAgentTypeResponse\x12R\n\x0f\x41\x64\x64Subscription\x12\x1e.agents.AddSubscriptionRequest\x1a\x1f.agents.AddSubscriptionResponse\x12[\n\x12RemoveSubscription\x12!.agents.RemoveSubscriptionRequest\x1a\".agents.RemoveSubscriptionResponse\x12U\n\x10GetSubscriptions\x12\x1f.agents.GetSubscriptionsRequest\x1a .agents.GetSubscriptionsResponseB\x1d\xaa\x02\x1aMicrosoft.AutoGen.Protobufb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -72,14 +72,18 @@ _globals['_GETSUBSCRIPTIONSREQUEST']._serialized_end=1201 _globals['_GETSUBSCRIPTIONSRESPONSE']._serialized_start=1203 _globals['_GETSUBSCRIPTIONSRESPONSE']._serialized_end=1274 - _globals['_AGENTSTATE']._serialized_start=1277 - _globals['_AGENTSTATE']._serialized_end=1434 - _globals['_GETSTATERESPONSE']._serialized_start=1436 - _globals['_GETSTATERESPONSE']._serialized_end=1495 - _globals['_SAVESTATERESPONSE']._serialized_start=1497 - _globals['_SAVESTATERESPONSE']._serialized_end=1516 - _globals['_MESSAGE']._serialized_start=1519 - _globals['_MESSAGE']._serialized_end=1672 - _globals['_AGENTRPC']._serialized_start=1675 - _globals['_AGENTRPC']._serialized_end=2203 + _globals['_MESSAGE']._serialized_start=1277 + _globals['_MESSAGE']._serialized_end=1430 + _globals['_SAVESTATEREQUEST']._serialized_start=1432 + _globals['_SAVESTATEREQUEST']._serialized_end=1484 + _globals['_SAVESTATERESPONSE']._serialized_start=1486 + _globals['_SAVESTATERESPONSE']._serialized_end=1550 + _globals['_LOADSTATEREQUEST']._serialized_start=1552 + _globals['_LOADSTATEREQUEST']._serialized_end=1619 + _globals['_LOADSTATERESPONSE']._serialized_start=1621 + _globals['_LOADSTATERESPONSE']._serialized_end=1670 + _globals['_CONTROLMESSAGE']._serialized_start=1673 + _globals['_CONTROLMESSAGE']._serialized_end=1808 + _globals['_AGENTRPC']._serialized_start=1811 + _globals['_AGENTRPC']._serialized_end=2298 # @@protoc_insertion_point(module_scope) diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.pyi b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.pyi index b37fb5ac2979..a12c53e73a7c 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.pyi +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.pyi @@ -313,85 +313,145 @@ class GetSubscriptionsResponse(google.protobuf.message.Message): global___GetSubscriptionsResponse = GetSubscriptionsResponse @typing.final -class AgentState(google.protobuf.message.Message): +class Message(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - AGENT_ID_FIELD_NUMBER: builtins.int - ETAG_FIELD_NUMBER: builtins.int - BINARY_DATA_FIELD_NUMBER: builtins.int - TEXT_DATA_FIELD_NUMBER: builtins.int - PROTO_DATA_FIELD_NUMBER: builtins.int - eTag: builtins.str - binary_data: builtins.bytes - text_data: builtins.str + REQUEST_FIELD_NUMBER: builtins.int + RESPONSE_FIELD_NUMBER: builtins.int + CLOUDEVENT_FIELD_NUMBER: builtins.int @property - def agent_id(self) -> global___AgentId: ... + def request(self) -> global___RpcRequest: ... + @property + def response(self) -> global___RpcResponse: ... @property - def proto_data(self) -> google.protobuf.any_pb2.Any: ... + def cloudEvent(self) -> cloudevent_pb2.CloudEvent: ... def __init__( self, *, - agent_id: global___AgentId | None = ..., - eTag: builtins.str = ..., - binary_data: builtins.bytes = ..., - text_data: builtins.str = ..., - proto_data: google.protobuf.any_pb2.Any | None = ..., + request: global___RpcRequest | None = ..., + response: global___RpcResponse | None = ..., + cloudEvent: cloudevent_pb2.CloudEvent | None = ..., ) -> None: ... - def HasField(self, field_name: typing.Literal["agent_id", b"agent_id", "binary_data", b"binary_data", "data", b"data", "proto_data", b"proto_data", "text_data", b"text_data"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["agent_id", b"agent_id", "binary_data", b"binary_data", "data", b"data", "eTag", b"eTag", "proto_data", b"proto_data", "text_data", b"text_data"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["data", b"data"]) -> typing.Literal["binary_data", "text_data", "proto_data"] | None: ... + def HasField(self, field_name: typing.Literal["cloudEvent", b"cloudEvent", "message", b"message", "request", b"request", "response", b"response"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["cloudEvent", b"cloudEvent", "message", b"message", "request", b"request", "response", b"response"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "cloudEvent"] | None: ... -global___AgentState = AgentState +global___Message = Message @typing.final -class GetStateResponse(google.protobuf.message.Message): +class SaveStateRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - AGENT_STATE_FIELD_NUMBER: builtins.int + AGENTID_FIELD_NUMBER: builtins.int @property - def agent_state(self) -> global___AgentState: ... + def agentId(self) -> global___AgentId: ... def __init__( self, *, - agent_state: global___AgentState | None = ..., + agentId: global___AgentId | None = ..., ) -> None: ... - def HasField(self, field_name: typing.Literal["agent_state", b"agent_state"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["agent_state", b"agent_state"]) -> None: ... + def HasField(self, field_name: typing.Literal["agentId", b"agentId"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["agentId", b"agentId"]) -> None: ... -global___GetStateResponse = GetStateResponse +global___SaveStateRequest = SaveStateRequest @typing.final class SaveStateResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + STATE_FIELD_NUMBER: builtins.int + ERROR_FIELD_NUMBER: builtins.int + state: builtins.str + error: builtins.str def __init__( self, + *, + state: builtins.str = ..., + error: builtins.str | None = ..., ) -> None: ... + def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error", "state", b"state"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ... global___SaveStateResponse = SaveStateResponse @typing.final -class Message(google.protobuf.message.Message): +class LoadStateRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - REQUEST_FIELD_NUMBER: builtins.int - RESPONSE_FIELD_NUMBER: builtins.int - CLOUDEVENT_FIELD_NUMBER: builtins.int - @property - def request(self) -> global___RpcRequest: ... + AGENTID_FIELD_NUMBER: builtins.int + STATE_FIELD_NUMBER: builtins.int + state: builtins.str @property - def response(self) -> global___RpcResponse: ... + def agentId(self) -> global___AgentId: ... + def __init__( + self, + *, + agentId: global___AgentId | None = ..., + state: builtins.str = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["agentId", b"agentId"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["agentId", b"agentId", "state", b"state"]) -> None: ... + +global___LoadStateRequest = LoadStateRequest + +@typing.final +class LoadStateResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ERROR_FIELD_NUMBER: builtins.int + error: builtins.str + def __init__( + self, + *, + error: builtins.str | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ... + +global___LoadStateResponse = LoadStateResponse + +@typing.final +class ControlMessage(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + RPC_ID_FIELD_NUMBER: builtins.int + DESTINATION_FIELD_NUMBER: builtins.int + RESPOND_TO_FIELD_NUMBER: builtins.int + RPCMESSAGE_FIELD_NUMBER: builtins.int + rpc_id: builtins.str + """A response message should have the same id as the request message""" + destination: builtins.str + """This is either: + agentid=AGENT_ID + clientid=CLIENT_ID + """ + respond_to: builtins.str + """This is either: + agentid=AGENT_ID + clientid=CLIENT_ID + Empty string means the message is a response + """ @property - def cloudEvent(self) -> cloudevent_pb2.CloudEvent: ... + def rpcMessage(self) -> google.protobuf.any_pb2.Any: + """One of: + SaveStateRequest saveStateRequest = 2; + SaveStateResponse saveStateResponse = 3; + LoadStateRequest loadStateRequest = 4; + LoadStateResponse loadStateResponse = 5; + """ + def __init__( self, *, - request: global___RpcRequest | None = ..., - response: global___RpcResponse | None = ..., - cloudEvent: cloudevent_pb2.CloudEvent | None = ..., + rpc_id: builtins.str = ..., + destination: builtins.str = ..., + respond_to: builtins.str | None = ..., + rpcMessage: google.protobuf.any_pb2.Any | None = ..., ) -> None: ... - def HasField(self, field_name: typing.Literal["cloudEvent", b"cloudEvent", "message", b"message", "request", b"request", "response", b"response"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["cloudEvent", b"cloudEvent", "message", b"message", "request", b"request", "response", b"response"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "cloudEvent"] | None: ... + def HasField(self, field_name: typing.Literal["_respond_to", b"_respond_to", "respond_to", b"respond_to", "rpcMessage", b"rpcMessage"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_respond_to", b"_respond_to", "destination", b"destination", "respond_to", b"respond_to", "rpcMessage", b"rpcMessage", "rpc_id", b"rpc_id"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["_respond_to", b"_respond_to"]) -> typing.Literal["respond_to"] | None: ... -global___Message = Message +global___ControlMessage = ControlMessage diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.py index 85fd64f42ccb..4a86f17f04ae 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.py @@ -39,15 +39,10 @@ def __init__(self, channel): request_serializer=agent__worker__pb2.Message.SerializeToString, response_deserializer=agent__worker__pb2.Message.FromString, _registered_method=True) - self.GetState = channel.unary_unary( - '/agents.AgentRpc/GetState', - request_serializer=agent__worker__pb2.AgentId.SerializeToString, - response_deserializer=agent__worker__pb2.GetStateResponse.FromString, - _registered_method=True) - self.SaveState = channel.unary_unary( - '/agents.AgentRpc/SaveState', - request_serializer=agent__worker__pb2.AgentState.SerializeToString, - response_deserializer=agent__worker__pb2.SaveStateResponse.FromString, + self.OpenControlChannel = channel.stream_stream( + '/agents.AgentRpc/OpenControlChannel', + request_serializer=agent__worker__pb2.ControlMessage.SerializeToString, + response_deserializer=agent__worker__pb2.ControlMessage.FromString, _registered_method=True) self.RegisterAgent = channel.unary_unary( '/agents.AgentRpc/RegisterAgent', @@ -80,13 +75,7 @@ def OpenChannel(self, request_iterator, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def GetState(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def SaveState(self, request, context): + def OpenControlChannel(self, request_iterator, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') @@ -124,15 +113,10 @@ def add_AgentRpcServicer_to_server(servicer, server): request_deserializer=agent__worker__pb2.Message.FromString, response_serializer=agent__worker__pb2.Message.SerializeToString, ), - 'GetState': grpc.unary_unary_rpc_method_handler( - servicer.GetState, - request_deserializer=agent__worker__pb2.AgentId.FromString, - response_serializer=agent__worker__pb2.GetStateResponse.SerializeToString, - ), - 'SaveState': grpc.unary_unary_rpc_method_handler( - servicer.SaveState, - request_deserializer=agent__worker__pb2.AgentState.FromString, - response_serializer=agent__worker__pb2.SaveStateResponse.SerializeToString, + 'OpenControlChannel': grpc.stream_stream_rpc_method_handler( + servicer.OpenControlChannel, + request_deserializer=agent__worker__pb2.ControlMessage.FromString, + response_serializer=agent__worker__pb2.ControlMessage.SerializeToString, ), 'RegisterAgent': grpc.unary_unary_rpc_method_handler( servicer.RegisterAgent, @@ -193,34 +177,7 @@ def OpenChannel(request_iterator, _registered_method=True) @staticmethod - def GetState(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/agents.AgentRpc/GetState', - agent__worker__pb2.AgentId.SerializeToString, - agent__worker__pb2.GetStateResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def SaveState(request, + def OpenControlChannel(request_iterator, target, options=(), channel_credentials=None, @@ -230,12 +187,12 @@ def SaveState(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary( - request, + return grpc.experimental.stream_stream( + request_iterator, target, - '/agents.AgentRpc/SaveState', - agent__worker__pb2.AgentState.SerializeToString, - agent__worker__pb2.SaveStateResponse.FromString, + '/agents.AgentRpc/OpenControlChannel', + agent__worker__pb2.ControlMessage.SerializeToString, + agent__worker__pb2.ControlMessage.FromString, options, channel_credentials, insecure, diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.pyi b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.pyi index ce8a7c12ec69..cc4311825112 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.pyi +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.pyi @@ -24,14 +24,9 @@ class AgentRpcStub: agent_worker_pb2.Message, ] - GetState: grpc.UnaryUnaryMultiCallable[ - agent_worker_pb2.AgentId, - agent_worker_pb2.GetStateResponse, - ] - - SaveState: grpc.UnaryUnaryMultiCallable[ - agent_worker_pb2.AgentState, - agent_worker_pb2.SaveStateResponse, + OpenControlChannel: grpc.StreamStreamMultiCallable[ + agent_worker_pb2.ControlMessage, + agent_worker_pb2.ControlMessage, ] RegisterAgent: grpc.UnaryUnaryMultiCallable[ @@ -60,14 +55,9 @@ class AgentRpcAsyncStub: agent_worker_pb2.Message, ] - GetState: grpc.aio.UnaryUnaryMultiCallable[ - agent_worker_pb2.AgentId, - agent_worker_pb2.GetStateResponse, - ] - - SaveState: grpc.aio.UnaryUnaryMultiCallable[ - agent_worker_pb2.AgentState, - agent_worker_pb2.SaveStateResponse, + OpenControlChannel: grpc.aio.StreamStreamMultiCallable[ + agent_worker_pb2.ControlMessage, + agent_worker_pb2.ControlMessage, ] RegisterAgent: grpc.aio.UnaryUnaryMultiCallable[ @@ -99,18 +89,11 @@ class AgentRpcServicer(metaclass=abc.ABCMeta): ) -> typing.Union[collections.abc.Iterator[agent_worker_pb2.Message], collections.abc.AsyncIterator[agent_worker_pb2.Message]]: ... @abc.abstractmethod - def GetState( - self, - request: agent_worker_pb2.AgentId, - context: _ServicerContext, - ) -> typing.Union[agent_worker_pb2.GetStateResponse, collections.abc.Awaitable[agent_worker_pb2.GetStateResponse]]: ... - - @abc.abstractmethod - def SaveState( + def OpenControlChannel( self, - request: agent_worker_pb2.AgentState, + request_iterator: _MaybeAsyncIterator[agent_worker_pb2.ControlMessage], context: _ServicerContext, - ) -> typing.Union[agent_worker_pb2.SaveStateResponse, collections.abc.Awaitable[agent_worker_pb2.SaveStateResponse]]: ... + ) -> typing.Union[collections.abc.Iterator[agent_worker_pb2.ControlMessage], collections.abc.AsyncIterator[agent_worker_pb2.ControlMessage]]: ... @abc.abstractmethod def RegisterAgent( diff --git a/python/packages/autogen-ext/tests/test_websurfer_agent.py b/python/packages/autogen-ext/tests/test_websurfer_agent.py index a2aa33a10931..37423bfe6a50 100644 --- a/python/packages/autogen-ext/tests/test_websurfer_agent.py +++ b/python/packages/autogen-ext/tests/test_websurfer_agent.py @@ -140,7 +140,7 @@ async def test_run_websurfer(monkeypatch: pytest.MonkeyPatch) -> None: result.messages[2] # type: ignore .content[0] # type: ignore .startswith( # type: ignore - "I am waiting a short period of time before taking further action.\n\n Here is a screenshot of the webpage:" + "I am waiting a short period of time before taking further action." ) ) # type: ignore url_after_sleep = agent._page.url # type: ignore