Commit 5c6c5ed (1 parent: 3ac936c)

chore: Standardize prompt handling in PydanticAI provider

3 files changed: +21 -32 lines


src/llmling_agent_providers/pydanticai/provider.py (+10 -28)
@@ -20,9 +20,7 @@
 from llmling_agent.common_types import EndStrategy, ModelProtocol
 from llmling_agent.log import get_logger
 from llmling_agent.messaging.messages import ChatMessage, TokenCost
-from llmling_agent.models.content import BaseContent
 from llmling_agent.observability import track_action
-from llmling_agent.prompts.convert import format_prompts
 from llmling_agent.tasks.exceptions import (
     ChainAbortedError,
     RunAbortedError,
@@ -31,6 +29,7 @@
 from llmling_agent.utils.inspection import execute, has_argument_type
 from llmling_agent_providers.base import AgentLLMProvider, ProviderResponse, UsageLimits
 from llmling_agent_providers.pydanticai.utils import (
+    convert_prompts_to_user_content,
     format_part,
     get_tool_calls,
     to_model_message,
@@ -252,25 +251,15 @@ async def generate_response(
             use_model = infer_model(use_model)
             self.model_changed.emit(use_model)
         try:
-            text_prompts = [p for p in prompts if isinstance(p, str)]
-            content_prompts = [p for p in prompts if isinstance(p, BaseContent)]
-
-            # Get normal text prompt
-            prompt = await format_prompts(text_prompts)
-
-            # Convert Content objects to ModelMessages
-            if content_prompts:
-                prompts_msgs = [
-                    ChatMessage(role="user", content=p) for p in content_prompts
-                ]
-                message_history = [*message_history, *prompts_msgs]
+            # Convert prompts to pydantic-ai format
+            converted_prompts = await convert_prompts_to_user_content(prompts)

             # Run with complete history
             to_use = model or self.model
             to_use = infer_model(to_use) if isinstance(to_use, str) else to_use
             limits = asdict(usage_limits) if usage_limits else {}
             result: AgentRunResult = await agent.run(
-                prompt,
+                converted_prompts,  # Pass converted prompts
                 deps=self._context,  # type: ignore
                 message_history=[to_model_message(m) for m in message_history],
                 model=to_use,  # type: ignore
@@ -292,10 +281,11 @@ async def generate_response(
                 use_model.model_name if isinstance(use_model, Model) else str(use_model)
             )
             usage = result.usage()
-            cost_str = prompt + str(content_prompts)  # dirty
+            # Create input content representation for cost calculations
+            cost_input = "\n".join(str(p) for p in prompts)
             cost_info = (
                 await TokenCost.from_usage(
-                    usage, resolved_model, cost_str, str(result.data)
+                    usage, resolved_model, cost_input, str(result.data)
                 )
                 if resolved_model and usage
                 else None
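
The cost_input line replaces the old cost_str concatenation (flagged "dirty" in the code itself): every prompt is stringified and joined, so token-cost estimation always sees one flat string, with non-text content represented only by its str() form. A standalone sketch of the idea (the ImageContent dataclass below is illustrative, not the library's actual Content type):

    from dataclasses import dataclass

    @dataclass
    class ImageContent:
        """Stand-in for a non-text Content object (hypothetical, for illustration)."""
        url: str

    prompts = ["Describe this image:", ImageContent(url="https://example.com/cat.png")]
    # Join the str() form of every prompt into one flat string for cost estimation
    cost_input = "\n".join(str(p) for p in prompts)
    # -> "Describe this image:\nImageContent(url='https://example.com/cat.png')"

The estimate for multi-modal input stays rough by design, but it no longer depends on the now-removed prompt/content_prompts split.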
@@ -376,22 +366,14 @@ async def stream_response(  # type: ignore[override]
         if model:
             self.model_changed.emit(use_model)

-        text_prompts = [p for p in prompts if isinstance(p, str)]
-        content_prompts = [p for p in prompts if isinstance(p, BaseContent)]
-
-        # Get normal text prompt
-        prompt = await format_prompts(text_prompts)
-
-        # Convert Content objects to ChatMessages
-        if content_prompts:
-            prompts_msgs = [ChatMessage(role="user", content=p) for p in content_prompts]
-            message_history = [*message_history, *prompts_msgs]
+        # Convert prompts to pydantic-ai format
+        converted_prompts = await convert_prompts_to_user_content(prompts)

         # Convert all messages to pydantic-ai format
         model_messages = [to_model_message(m) for m in message_history]

         async with agent.run_stream(
-            prompt,
+            converted_prompts,
             deps=self._context,
             message_history=model_messages,
             model=model or self.model,  # type: ignore
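
With both hunks applied, generate_response and stream_response reduce to the same two steps: convert once, then pass the list straight through. pydantic-ai accepts a sequence of user-content parts as the user prompt, so there is no more pre-merging into a single string or splicing Content objects into message_history. A rough sketch of the shared shape (simplified; not the provider's exact code):

    async def run_with_converted_prompts(agent, prompts, message_history, model, context):
        # Single conversion step shared by both code paths: strings are formatted,
        # Content objects become pydantic-ai UserContent, input order is preserved.
        converted = await convert_prompts_to_user_content(prompts)
        history = [to_model_message(m) for m in message_history]
        # pydantic-ai's Agent.run accepts str | Sequence[UserContent] as the prompt
        return await agent.run(converted, deps=context, message_history=history, model=model)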

src/llmling_agent_providers/pydanticai/utils.py (+10 -3)
@@ -241,7 +241,7 @@ def to_model_message(message: ChatMessage[str | Content]) -> ModelMessage:
     return ModelRequest(parts=[part_cls(content=message.content)])


-def convert_prompts_to_user_content(
+async def convert_prompts_to_user_content(
     prompts: Sequence[str | Content],
 ) -> list[str | PydanticUserContent]:
     """Convert our prompts to pydantic-ai compatible format.
@@ -252,13 +252,20 @@ def convert_prompts_to_user_content(
     Returns:
         List of strings and pydantic-ai UserContent objects
     """
+    from llmling_agent.prompts.convert import format_prompts
     from llmling_agent_providers.pydanticai.convert_content import content_to_pydantic_ai

+    # Special case: if we only have string prompts, format them together
+    # if all(isinstance(p, str) for p in prompts):
+    #     formatted = await format_prompts(prompts)
+    #     return [formatted]
+
+    # Otherwise, process each item individually in order
     result = []
     for p in prompts:
         if isinstance(p, str):
-            # Simply pass through string prompts directly
-            result.append(p)
+            formatted = await format_prompts([p])
+            result.append(formatted)
         elif p_content := content_to_pydantic_ai(p):
             result.append(p_content)
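
Note the behavioral change on the string branch: instead of passing strings through untouched, each one is now rendered via format_prompts, so every prompt type goes through the same pipeline (hence the function becoming async) and the result list preserves input order. A hedged usage sketch — ImageURLContent and the returned ImageUrl are assumptions about the surrounding types, not taken from this commit:

    from llmling_agent_providers.pydanticai.utils import convert_prompts_to_user_content
    # ImageURLContent is an assumed example type; BaseContent lives in this module
    from llmling_agent.models.content import ImageURLContent

    converted = await convert_prompts_to_user_content([
        "What is in this picture?",                          # rendered via format_prompts([p])
        ImageURLContent(url="https://example.com/cat.png"),  # mapped by content_to_pydantic_ai
    ])
    # Expected shape: ["What is in this picture?", ImageUrl(url="https://example.com/cat.png")]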

tests/test_agent.py (+1 -1)
@@ -73,7 +73,7 @@ async def test_agent_streaming_pydanticai_history(test_agent: Agent[None]):
     # Check prompt message
     assert isinstance(new_messages[0], ModelRequest)
     assert isinstance(new_messages[0].parts[0], UserPromptPart)
-    assert new_messages[0].parts[0].content == SIMPLE_PROMPT
+    assert new_messages[0].parts[0].content == [SIMPLE_PROMPT]

     # Check response message
     assert isinstance(new_messages[1], ModelResponse)
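
The updated assertion follows from the provider change above: the user prompt now reaches pydantic-ai as a list of parts, and UserPromptPart.content keeps whatever shape it was given, so a single string prompt arrives wrapped in a one-element list. A minimal sketch of the shape being pinned down (assuming pydantic-ai's public message types):

    from pydantic_ai.messages import ModelRequest, UserPromptPart

    SIMPLE_PROMPT = "Hello!"
    # A sequence-valued prompt stays a list on the recorded message part
    request = ModelRequest(parts=[UserPromptPart(content=[SIMPLE_PROMPT])])
    assert request.parts[0].content == [SIMPLE_PROMPT]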
