Skip to content

Commit 1af32ad

Browse files
committed
chore: Fix complex test mocking in UsageTrackingMiddlewareTests
Fixes #678
1 parent 2bb5f62 commit 1af32ad

11 files changed

Lines changed: 2231 additions & 806 deletions
Lines changed: 346 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,346 @@
1+
using Microsoft.Extensions.Logging;
2+
using Moq;
3+
using Xunit;
4+
using ConduitLLM.Configuration.DTOs;
5+
using ConduitLLM.Configuration.Entities;
6+
using ConduitLLM.Configuration.Interfaces;
7+
using ConduitLLM.Core.Interfaces;
8+
using ConduitLLM.Core.Models;
9+
using ConduitLLM.Gateway.Middleware;
10+
using IVirtualKeyService = ConduitLLM.Core.Interfaces.IVirtualKeyService;
11+
12+
namespace ConduitLLM.Tests.Http.Middleware.Assertions
13+
{
14+
/// <summary>
15+
/// Assertion helpers for UsageTrackingMiddleware tests.
16+
/// Provides fluent, readable assertions for common verification patterns.
17+
/// </summary>
18+
public static class UsageTrackingAssertions
19+
{
20+
/// <summary>
21+
/// Verifies that cost was calculated for the specified model with expected usage.
22+
/// </summary>
23+
/// <param name="costService">The mocked cost service.</param>
24+
/// <param name="expectedModel">The expected model name.</param>
25+
/// <param name="expectedPromptTokens">Optional expected prompt token count.</param>
26+
/// <param name="expectedCompletionTokens">Optional expected completion token count.</param>
27+
public static void VerifyCostCalculated(
28+
Mock<ICostCalculationService> costService,
29+
string expectedModel,
30+
int? expectedPromptTokens = null,
31+
int? expectedCompletionTokens = null)
32+
{
33+
costService.Verify(x => x.CalculateCostAsync(
34+
expectedModel,
35+
It.Is<Usage>(u =>
36+
(!expectedPromptTokens.HasValue || u.PromptTokens == expectedPromptTokens) &&
37+
(!expectedCompletionTokens.HasValue || u.CompletionTokens == expectedCompletionTokens)),
38+
It.IsAny<CancellationToken>()), Times.Once);
39+
}
40+
41+
/// <summary>
42+
/// Verifies that cost was calculated for any model.
43+
/// </summary>
44+
/// <param name="costService">The mocked cost service.</param>
45+
public static void VerifyCostCalculatedOnce(Mock<ICostCalculationService> costService)
46+
{
47+
costService.Verify(
48+
x => x.CalculateCostAsync(It.IsAny<string>(), It.IsAny<Usage>(), It.IsAny<CancellationToken>()),
49+
Times.Once);
50+
}
51+
52+
/// <summary>
53+
/// Verifies that no cost calculation was performed.
54+
/// </summary>
55+
/// <param name="costService">The mocked cost service.</param>
56+
public static void VerifyNoCostCalculation(Mock<ICostCalculationService> costService)
57+
{
58+
costService.Verify(
59+
x => x.CalculateCostAsync(It.IsAny<string>(), It.IsAny<Usage>(), It.IsAny<CancellationToken>()),
60+
Times.Never);
61+
}
62+
63+
/// <summary>
64+
/// Verifies that spend was queued via the batch service.
65+
/// </summary>
66+
/// <param name="batchService">The mocked batch spend service.</param>
67+
/// <param name="expectedVirtualKeyId">The expected virtual key ID.</param>
68+
/// <param name="expectedCost">The expected cost amount.</param>
69+
public static void VerifySpendQueued(
70+
Mock<IBatchSpendUpdateService> batchService,
71+
int expectedVirtualKeyId,
72+
decimal expectedCost)
73+
{
74+
batchService.Verify(
75+
x => x.QueueSpendUpdate(expectedVirtualKeyId, expectedCost),
76+
Times.Once);
77+
}
78+
79+
/// <summary>
80+
/// Verifies that spend was queued for any amount.
81+
/// </summary>
82+
/// <param name="batchService">The mocked batch spend service.</param>
83+
/// <param name="expectedVirtualKeyId">The expected virtual key ID.</param>
84+
public static void VerifySpendQueuedAny(
85+
Mock<IBatchSpendUpdateService> batchService,
86+
int expectedVirtualKeyId)
87+
{
88+
batchService.Verify(
89+
x => x.QueueSpendUpdate(expectedVirtualKeyId, It.IsAny<decimal>()),
90+
Times.Once);
91+
}
92+
93+
/// <summary>
94+
/// Verifies that no spend updates occurred.
95+
/// </summary>
96+
/// <param name="batchService">The mocked batch spend service.</param>
97+
/// <param name="virtualKeyService">The mocked virtual key service.</param>
98+
public static void VerifyNoSpendUpdate(
99+
Mock<IBatchSpendUpdateService> batchService,
100+
Mock<IVirtualKeyService> virtualKeyService)
101+
{
102+
batchService.Verify(
103+
x => x.QueueSpendUpdate(It.IsAny<int>(), It.IsAny<decimal>()),
104+
Times.Never);
105+
virtualKeyService.Verify(
106+
x => x.UpdateSpendAsync(It.IsAny<int>(), It.IsAny<decimal>()),
107+
Times.Never);
108+
}
109+
110+
/// <summary>
111+
/// Verifies that direct spend update was called (fallback when batch service is unhealthy).
112+
/// </summary>
113+
/// <param name="virtualKeyService">The mocked virtual key service.</param>
114+
/// <param name="expectedVirtualKeyId">The expected virtual key ID.</param>
115+
/// <param name="expectedCost">The expected cost amount.</param>
116+
public static void VerifyDirectSpendUpdate(
117+
Mock<IVirtualKeyService> virtualKeyService,
118+
int expectedVirtualKeyId,
119+
decimal expectedCost)
120+
{
121+
virtualKeyService.Verify(
122+
x => x.UpdateSpendAsync(expectedVirtualKeyId, expectedCost),
123+
Times.Once);
124+
}
125+
126+
/// <summary>
127+
/// Verifies request was logged with expected properties.
128+
/// </summary>
129+
/// <param name="requestLogService">The mocked request log service.</param>
130+
/// <param name="assertions">Action to perform assertions on the captured DTO.</param>
131+
public static void VerifyRequestLogged(
132+
Mock<IRequestLogService> requestLogService,
133+
Action<LogRequestDto> assertions)
134+
{
135+
LogRequestDto? capturedDto = null;
136+
requestLogService.Verify(x => x.LogRequestAsync(It.IsAny<LogRequestDto>()), Times.Once);
137+
138+
// Extract the captured DTO from the invocations
139+
var invocation = requestLogService.Invocations
140+
.FirstOrDefault(i => i.Method.Name == "LogRequestAsync");
141+
if (invocation != null && invocation.Arguments.Count > 0)
142+
{
143+
capturedDto = invocation.Arguments[0] as LogRequestDto;
144+
}
145+
146+
Assert.NotNull(capturedDto);
147+
assertions(capturedDto!);
148+
}
149+
150+
/// <summary>
151+
/// Verifies request was logged for a specific virtual key and model.
152+
/// </summary>
153+
/// <param name="requestLogService">The mocked request log service.</param>
154+
/// <param name="expectedVirtualKeyId">The expected virtual key ID.</param>
155+
/// <param name="expectedModel">The expected model name.</param>
156+
/// <param name="expectedRequestType">The expected request type.</param>
157+
public static void VerifyRequestLogged(
158+
Mock<IRequestLogService> requestLogService,
159+
int expectedVirtualKeyId,
160+
string expectedModel,
161+
string expectedRequestType = "chat")
162+
{
163+
requestLogService.Verify(x => x.LogRequestAsync(It.Is<LogRequestDto>(dto =>
164+
dto.VirtualKeyId == expectedVirtualKeyId &&
165+
dto.ModelName == expectedModel &&
166+
dto.RequestType == expectedRequestType)), Times.Once);
167+
}
168+
169+
/// <summary>
170+
/// Verifies that no request was logged.
171+
/// </summary>
172+
/// <param name="requestLogService">The mocked request log service.</param>
173+
public static void VerifyNoRequestLogged(Mock<IRequestLogService> requestLogService)
174+
{
175+
requestLogService.Verify(
176+
x => x.LogRequestAsync(It.IsAny<LogRequestDto>()),
177+
Times.Never);
178+
}
179+
180+
/// <summary>
181+
/// Verifies that cached tokens were tracked correctly for Anthropic responses.
182+
/// </summary>
183+
/// <param name="costService">The mocked cost service.</param>
184+
/// <param name="model">The expected model name.</param>
185+
/// <param name="expectedCacheCreation">Expected cache creation tokens.</param>
186+
/// <param name="expectedCacheRead">Expected cache read tokens.</param>
187+
public static void VerifyAnthropicCaching(
188+
Mock<ICostCalculationService> costService,
189+
string model,
190+
int expectedCacheCreation,
191+
int expectedCacheRead)
192+
{
193+
costService.Verify(x => x.CalculateCostAsync(
194+
model,
195+
It.Is<Usage>(u =>
196+
u.CachedWriteTokens == expectedCacheCreation &&
197+
u.CachedInputTokens == expectedCacheRead),
198+
It.IsAny<CancellationToken>()), Times.Once);
199+
}
200+
201+
/// <summary>
202+
/// Verifies image usage was tracked correctly.
203+
/// </summary>
204+
/// <param name="costService">The mocked cost service.</param>
205+
/// <param name="model">The expected model name.</param>
206+
/// <param name="expectedImageCount">The expected number of images.</param>
207+
/// <param name="expectedQuality">Optional expected image quality.</param>
208+
/// <param name="expectedSize">Optional expected image size.</param>
209+
public static void VerifyImageUsage(
210+
Mock<ICostCalculationService> costService,
211+
string model,
212+
int expectedImageCount,
213+
string? expectedQuality = null,
214+
string? expectedSize = null)
215+
{
216+
costService.Verify(x => x.CalculateCostAsync(
217+
model,
218+
It.Is<Usage>(u =>
219+
u.ImageCount == expectedImageCount &&
220+
(expectedQuality == null || u.ImageQuality == expectedQuality) &&
221+
(expectedSize == null || u.ImageResolution == expectedSize)),
222+
It.IsAny<CancellationToken>()), Times.Once);
223+
}
224+
225+
/// <summary>
226+
/// Verifies billing audit event was captured with expected type.
227+
/// </summary>
228+
/// <param name="events">The list of captured billing events.</param>
229+
/// <param name="expectedType">The expected event type.</param>
230+
/// <param name="additionalAssertions">Optional additional assertions.</param>
231+
public static void VerifyBillingEvent(
232+
List<BillingAuditEvent> events,
233+
BillingAuditEventType expectedType,
234+
Action<BillingAuditEvent>? additionalAssertions = null)
235+
{
236+
var matchingEvent = events.FirstOrDefault(e => e.EventType == expectedType);
237+
Assert.NotNull(matchingEvent);
238+
additionalAssertions?.Invoke(matchingEvent);
239+
}
240+
241+
/// <summary>
242+
/// Verifies that a billing event with tool usage was captured.
243+
/// </summary>
244+
/// <param name="events">The list of captured billing events.</param>
245+
/// <param name="expectedToolName">The expected tool name in the JSON.</param>
246+
/// <param name="expectedToolCost">The expected tool usage cost.</param>
247+
public static void VerifyToolUsageBillingEvent(
248+
List<BillingAuditEvent> events,
249+
string expectedToolName,
250+
decimal expectedToolCost)
251+
{
252+
var evt = events.FirstOrDefault(e => e.EventType == BillingAuditEventType.ToolUsageTracked);
253+
Assert.NotNull(evt);
254+
Assert.NotNull(evt.ToolUsageJson);
255+
Assert.Contains(expectedToolName, evt.ToolUsageJson);
256+
Assert.Equal(expectedToolCost, evt.ToolUsageCost);
257+
}
258+
259+
/// <summary>
260+
/// Verifies that no billing events were captured.
261+
/// </summary>
262+
/// <param name="events">The list of captured billing events.</param>
263+
public static void VerifyNoBillingEvents(List<BillingAuditEvent> events)
264+
{
265+
Assert.Empty(events);
266+
}
267+
268+
/// <summary>
269+
/// Verifies that exactly one billing event was captured.
270+
/// </summary>
271+
/// <param name="events">The list of captured billing events.</param>
272+
/// <returns>The single captured event for further assertions.</returns>
273+
public static BillingAuditEvent VerifySingleBillingEvent(List<BillingAuditEvent> events)
274+
{
275+
Assert.Single(events);
276+
return events[0];
277+
}
278+
279+
/// <summary>
280+
/// Verifies that a specific log message was emitted.
281+
/// </summary>
282+
/// <param name="logger">The mocked logger.</param>
283+
/// <param name="level">The expected log level.</param>
284+
/// <param name="containsMessage">Text that should be in the log message.</param>
285+
public static void VerifyLogMessage(
286+
Mock<ILogger<UsageTrackingMiddleware>> logger,
287+
LogLevel level,
288+
string containsMessage)
289+
{
290+
logger.Verify(
291+
x => x.Log(
292+
level,
293+
It.IsAny<EventId>(),
294+
It.Is<It.IsAnyType>((v, t) => v != null && v.ToString()!.Contains(containsMessage)),
295+
It.IsAny<Exception?>(),
296+
It.IsAny<Func<It.IsAnyType, Exception?, string>>()),
297+
Times.AtLeastOnce);
298+
}
299+
300+
/// <summary>
301+
/// Verifies that a debug log message was emitted.
302+
/// </summary>
303+
/// <param name="logger">The mocked logger.</param>
304+
/// <param name="containsMessage">Text that should be in the log message.</param>
305+
public static void VerifyDebugLog(
306+
Mock<ILogger<UsageTrackingMiddleware>> logger,
307+
string containsMessage)
308+
{
309+
VerifyLogMessage(logger, LogLevel.Debug, containsMessage);
310+
}
311+
312+
/// <summary>
313+
/// Verifies that a warning log message was emitted.
314+
/// </summary>
315+
/// <param name="logger">The mocked logger.</param>
316+
/// <param name="containsMessage">Text that should be in the log message.</param>
317+
public static void VerifyWarningLog(
318+
Mock<ILogger<UsageTrackingMiddleware>> logger,
319+
string containsMessage)
320+
{
321+
VerifyLogMessage(logger, LogLevel.Warning, containsMessage);
322+
}
323+
324+
/// <summary>
325+
/// Verifies the full usage tracking flow for a successful request.
326+
/// </summary>
327+
/// <param name="costService">The mocked cost service.</param>
328+
/// <param name="batchService">The mocked batch spend service.</param>
329+
/// <param name="requestLogService">The mocked request log service.</param>
330+
/// <param name="model">The expected model name.</param>
331+
/// <param name="virtualKeyId">The expected virtual key ID.</param>
332+
/// <param name="cost">The expected cost.</param>
333+
public static void VerifySuccessfulUsageTracking(
334+
Mock<ICostCalculationService> costService,
335+
Mock<IBatchSpendUpdateService> batchService,
336+
Mock<IRequestLogService> requestLogService,
337+
string model,
338+
int virtualKeyId,
339+
decimal cost)
340+
{
341+
VerifyCostCalculated(costService, model);
342+
VerifySpendQueued(batchService, virtualKeyId, cost);
343+
VerifyRequestLogged(requestLogService, virtualKeyId, model);
344+
}
345+
}
346+
}

0 commit comments

Comments
 (0)