Skip to content

Commit cf8c7d0

Browse files
authored
Merge branch 'main' into dotnet_unit
2 parents 28f953f + fca1de9 commit cf8c7d0

File tree

15 files changed

+679
-9
lines changed

15 files changed

+679
-9
lines changed

dotnet/src/Microsoft.AutoGen/Contracts/AgentId.cs

+14
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System.Diagnostics;
55
using System.Diagnostics.CodeAnalysis;
6+
using System.Text.RegularExpressions;
67

78
namespace Microsoft.AutoGen.Contracts;
89

@@ -16,6 +17,9 @@ namespace Microsoft.AutoGen.Contracts;
1617
[DebuggerDisplay($"AgentId(type=\"{nameof(Type)}\", key=\"{nameof(Key)}\")")]
1718
public struct AgentId
1819
{
20+
private static readonly Regex TypeRegex = new(@"^[a-zA-Z_][a-zA-Z0-9_]*$", RegexOptions.Compiled);
21+
private static readonly Regex KeyRegex = new(@"^[\x20-\x7E]+$", RegexOptions.Compiled); // ASCII 32-126
22+
1923
/// <summary>
2024
/// An identifier that associates an agent with a specific factory function.
2125
/// Strings may only be composed of alphanumeric letters (a-z) and (0-9), or underscores (_).
@@ -35,6 +39,16 @@ public struct AgentId
3539
/// <param name="key">Agent instance identifier.</param>
3640
public AgentId(string type, string key)
3741
{
42+
if (string.IsNullOrWhiteSpace(type) || !TypeRegex.IsMatch(type))
43+
{
44+
throw new ArgumentException($"Invalid AgentId type: '{type}'. Must be alphanumeric (a-z, 0-9, _) and cannot start with a number or contain spaces.");
45+
}
46+
47+
if (string.IsNullOrWhiteSpace(key) || !KeyRegex.IsMatch(key))
48+
{
49+
throw new ArgumentException($"Invalid AgentId key: '{key}'. Must only contain ASCII characters 32-126.");
50+
}
51+
3852
Type = type;
3953
Key = key;
4054
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// AgentIdTests.cs
3+
using FluentAssertions;
4+
using Microsoft.AutoGen.Contracts;
5+
using Xunit;
6+
7+
namespace Microsoft.AutoGen.Core.Tests;
8+
9+
public class AgentIdTests()
10+
{
11+
[Fact]
12+
public void AgentIdShouldInitializeCorrectlyTest()
13+
{
14+
var agentId = new AgentId("TestType", "TestKey");
15+
16+
agentId.Type.Should().Be("TestType");
17+
agentId.Key.Should().Be("TestKey");
18+
}
19+
20+
[Fact]
21+
public void AgentIdShouldConvertFromTupleTest()
22+
{
23+
var agentTuple = ("TupleType", "TupleKey");
24+
var agentId = new AgentId(agentTuple);
25+
26+
agentId.Type.Should().Be("TupleType");
27+
agentId.Key.Should().Be("TupleKey");
28+
}
29+
30+
[Fact]
31+
public void AgentIdShouldParseFromStringTest()
32+
{
33+
var agentId = AgentId.FromStr("ParsedType/ParsedKey");
34+
35+
agentId.Type.Should().Be("ParsedType");
36+
agentId.Key.Should().Be("ParsedKey");
37+
}
38+
39+
[Fact]
40+
public void AgentIdShouldCompareEqualityCorrectlyTest()
41+
{
42+
var agentId1 = new AgentId("SameType", "SameKey");
43+
var agentId2 = new AgentId("SameType", "SameKey");
44+
var agentId3 = new AgentId("DifferentType", "DifferentKey");
45+
46+
agentId1.Should().Be(agentId2);
47+
agentId1.Should().NotBe(agentId3);
48+
(agentId1 == agentId2).Should().BeTrue();
49+
(agentId1 != agentId3).Should().BeTrue();
50+
}
51+
52+
[Fact]
53+
public void AgentIdShouldGenerateCorrectHashCodeTest()
54+
{
55+
var agentId1 = new AgentId("HashType", "HashKey");
56+
var agentId2 = new AgentId("HashType", "HashKey");
57+
var agentId3 = new AgentId("DifferentType", "DifferentKey");
58+
59+
agentId1.GetHashCode().Should().Be(agentId2.GetHashCode());
60+
agentId1.GetHashCode().Should().NotBe(agentId3.GetHashCode());
61+
}
62+
63+
[Fact]
64+
public void AgentIdShouldConvertExplicitlyFromStringTest()
65+
{
66+
var agentId = (AgentId)"ConvertedType/ConvertedKey";
67+
68+
agentId.Type.Should().Be("ConvertedType");
69+
agentId.Key.Should().Be("ConvertedKey");
70+
}
71+
72+
[Fact]
73+
public void AgentIdShouldReturnCorrectToStringTest()
74+
{
75+
var agentId = new AgentId("ToStringType", "ToStringKey");
76+
77+
agentId.ToString().Should().Be("ToStringType/ToStringKey");
78+
}
79+
80+
[Fact]
81+
public void AgentIdShouldCompareInequalityCorrectlyTest()
82+
{
83+
var agentId1 = new AgentId("Type1", "Key1");
84+
var agentId2 = new AgentId("Type2", "Key2");
85+
86+
(agentId1 != agentId2).Should().BeTrue();
87+
}
88+
89+
[Fact]
90+
public void AgentIdShouldRejectInvalidNamesTest()
91+
{
92+
// Invalid: 'Type' cannot start with a number and must only contain a-z, 0-9, or underscores.
93+
Action invalidType = () => new AgentId("123InvalidType", "ValidKey");
94+
invalidType.Should().Throw<ArgumentException>("Agent type cannot start with a number and must only contain alphanumeric letters or underscores.");
95+
96+
Action invalidTypeWithSpaces = () => new AgentId("Invalid Type", "ValidKey");
97+
invalidTypeWithSpaces.Should().Throw<ArgumentException>("Agent type cannot contain spaces.");
98+
99+
Action invalidTypeWithSpecialChars = () => new AgentId("Invalid@Type", "ValidKey");
100+
invalidTypeWithSpecialChars.Should().Throw<ArgumentException>("Agent type cannot contain special characters.");
101+
102+
// Invalid: 'Key' must contain only ASCII characters 32 (space) to 126 (~).
103+
Action invalidKey = () => new AgentId("ValidType", "InvalidKey💀");
104+
invalidKey.Should().Throw<ArgumentException>("Agent key must only contain ASCII characters between 32 (space) and 126 (~).");
105+
106+
Action validCase = () => new AgentId("Valid_Type", "Valid_Key_123");
107+
validCase.Should().NotThrow("This is a correctly formatted AgentId.");
108+
}
109+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// AgentMetaDataTests.cs
3+
using FluentAssertions;
4+
using Microsoft.AutoGen.Contracts;
5+
using Xunit;
6+
7+
namespace Microsoft.AutoGen.Core.Tests;
8+
9+
public class AgentMetadataTests()
10+
{
11+
[Fact]
12+
public void AgentMetadataShouldInitializeCorrectlyTest()
13+
{
14+
var metadata = new AgentMetadata("TestType", "TestKey", "TestDescription");
15+
16+
metadata.Type.Should().Be("TestType");
17+
metadata.Key.Should().Be("TestKey");
18+
metadata.Description.Should().Be("TestDescription");
19+
}
20+
}

python/packages/autogen-core/src/autogen_core/models/_model_client.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ class ModelFamily:
2222
O1 = "o1"
2323
GPT_4 = "gpt-4"
2424
GPT_35 = "gpt-35"
25+
R1 = "r1"
2526
UNKNOWN = "unknown"
2627

27-
ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "unknown"]
28+
ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "r1", "unknown"]
2829

2930
def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
3031
raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")

python/packages/autogen-core/src/autogen_core/models/_types.py

+35
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,25 @@
88

99

1010
class SystemMessage(BaseModel):
11+
"""System message contains instructions for the model coming from the developer.
12+
13+
.. note::
14+
15+
Open AI is moving away from using 'system' role in favor of 'developer' role.
16+
See `Model Spec <https://cdn.openai.com/spec/model-spec-2024-05-08.html#definitions>`_ for more details.
17+
However, the 'system' role is still allowed in their API and will be automatically converted to 'developer' role
18+
on the server side.
19+
So, you can use `SystemMessage` for developer messages.
20+
21+
"""
22+
1123
content: str
1224
type: Literal["SystemMessage"] = "SystemMessage"
1325

1426

1527
class UserMessage(BaseModel):
28+
"""User message contains input from end users, or a catch-all for data provided to the model."""
29+
1630
content: Union[str, List[Union[str, Image]]]
1731

1832
# Name of the agent that sent this message
@@ -22,6 +36,8 @@ class UserMessage(BaseModel):
2236

2337

2438
class AssistantMessage(BaseModel):
39+
"""Assistant message are sampled from the language model."""
40+
2541
content: Union[str, List[FunctionCall]]
2642

2743
# Name of the agent that sent this message
@@ -31,11 +47,15 @@ class AssistantMessage(BaseModel):
3147

3248

3349
class FunctionExecutionResult(BaseModel):
50+
"""Function execution result contains the output of a function call."""
51+
3452
content: str
3553
call_id: str
3654

3755

3856
class FunctionExecutionResultMessage(BaseModel):
57+
"""Function execution result message contains the output of multiple function calls."""
58+
3959
content: List[FunctionExecutionResult]
4060

4161
type: Literal["FunctionExecutionResultMessage"] = "FunctionExecutionResultMessage"
@@ -69,8 +89,23 @@ class ChatCompletionTokenLogprob(BaseModel):
6989

7090

7191
class CreateResult(BaseModel):
92+
"""Create result contains the output of a model completion."""
93+
7294
finish_reason: FinishReasons
95+
"""The reason the model finished generating the completion."""
96+
7397
content: Union[str, List[FunctionCall]]
98+
"""The output of the model completion."""
99+
74100
usage: RequestUsage
101+
"""The usage of tokens in the prompt and completion."""
102+
75103
cached: bool
104+
"""Whether the completion was generated from a cached response."""
105+
76106
logprobs: Optional[List[ChatCompletionTokenLogprob] | None] = None
107+
"""The logprobs of the tokens in the completion."""
108+
109+
thought: Optional[str] = None
110+
"""The reasoning text for the completion if available. Used for reasoning models
111+
and additional text content besides function calls."""

python/packages/autogen-ext/pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ dev = [
120120
"autogen_test_utils",
121121
"langchain-experimental",
122122
"pandas-stubs>=2.2.3.241126",
123+
"httpx>=0.28.1",
123124
]
124125

125126
[tool.ruff]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import warnings
2+
from typing import Tuple
3+
4+
5+
def parse_r1_content(content: str) -> Tuple[str | None, str]:
6+
"""Parse the content of an R1-style message that contains a `<think>...</think>` field."""
7+
# Find the start and end of the think field
8+
think_start = content.find("<think>")
9+
think_end = content.find("</think>")
10+
11+
if think_start == -1 or think_end == -1:
12+
warnings.warn(
13+
"Could not find <think>..</think> field in model response content. " "No thought was extracted.",
14+
UserWarning,
15+
stacklevel=2,
16+
)
17+
return None, content
18+
19+
if think_end < think_start:
20+
warnings.warn(
21+
"Found </think> before <think> in model response content. " "No thought was extracted.",
22+
UserWarning,
23+
stacklevel=2,
24+
)
25+
return None, content
26+
27+
# Extract the think field
28+
thought = content[think_start + len("<think>") : think_end].strip()
29+
30+
# Extract the rest of the content, skipping the think field.
31+
content = content[think_end + len("</think>") :].strip()
32+
33+
return thought, content

python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py

+15
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
FinishReasons,
1313
FunctionExecutionResultMessage,
1414
LLMMessage,
15+
ModelFamily,
1516
ModelInfo,
1617
RequestUsage,
1718
SystemMessage,
@@ -55,6 +56,8 @@
5556
AzureAIChatCompletionClientConfig,
5657
)
5758

59+
from .._utils.parse_r1_content import parse_r1_content
60+
5861
create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
5962
AzureMessage = Union[AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage]
6063

@@ -354,11 +357,17 @@ async def create(
354357
finish_reason = choice.finish_reason # type: ignore
355358
content = choice.message.content or ""
356359

360+
if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
361+
thought, content = parse_r1_content(content)
362+
else:
363+
thought = None
364+
357365
response = CreateResult(
358366
finish_reason=finish_reason, # type: ignore
359367
content=content,
360368
usage=usage,
361369
cached=False,
370+
thought=thought,
362371
)
363372

364373
self.add_usage(usage)
@@ -464,11 +473,17 @@ async def create_stream(
464473
prompt_tokens=prompt_tokens,
465474
)
466475

476+
if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
477+
thought, content = parse_r1_content(content)
478+
else:
479+
thought = None
480+
467481
result = CreateResult(
468482
finish_reason=finish_reason,
469483
content=content,
470484
usage=usage,
471485
cached=False,
486+
thought=thought,
472487
)
473488

474489
self.add_usage(usage)

0 commit comments

Comments
 (0)