Skip to content

Commit 05a81fc

Browse files
authored
Add ability to filter an account's state changes using different parameters (#344)
1 parent 62d27e7 commit 05a81fc

File tree

11 files changed

+801
-47
lines changed

11 files changed

+801
-47
lines changed

internal/data/statechanges.go

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,47 +18,89 @@ type StateChangeModel struct {
1818
MetricsService metrics.MetricsService
1919
}
2020

21-
// BatchGetByAccountAddress gets the state changes that are associated with the given account addresses.
22-
func (m *StateChangeModel) BatchGetByAccountAddress(ctx context.Context, accountAddress string, columns string, limit *int32, cursor *types.StateChangeCursor, sortOrder SortOrder) ([]*types.StateChangeWithCursor, error) {
21+
// BatchGetByAccountAddress gets the state changes that are associated with the given account address.
22+
// Optional filters: txHash, operationID, category, and reason can be used to further filter results.
23+
func (m *StateChangeModel) BatchGetByAccountAddress(ctx context.Context, accountAddress string, txHash *string, operationID *int64, category *string, reason *string, columns string, limit *int32, cursor *types.StateChangeCursor, sortOrder SortOrder) ([]*types.StateChangeWithCursor, error) {
2324
columns = prepareColumnsWithID(columns, types.StateChange{}, "", "to_id", "state_change_order")
2425
var queryBuilder strings.Builder
26+
args := []interface{}{accountAddress}
27+
argIndex := 2
28+
2529
queryBuilder.WriteString(fmt.Sprintf(`
2630
SELECT %s, to_id as "cursor.cursor_to_id", state_change_order as "cursor.cursor_state_change_order"
27-
FROM state_changes
31+
FROM state_changes
2832
WHERE account_id = $1
2933
`, columns))
3034

35+
// Add transaction hash filter if provided
36+
if txHash != nil {
37+
queryBuilder.WriteString(fmt.Sprintf(" AND tx_hash = $%d", argIndex))
38+
args = append(args, *txHash)
39+
argIndex++
40+
}
41+
42+
// Add operation ID filter if provided
43+
if operationID != nil {
44+
queryBuilder.WriteString(fmt.Sprintf(" AND operation_id = $%d", argIndex))
45+
args = append(args, *operationID)
46+
argIndex++
47+
}
48+
49+
// Add category filter if provided
50+
if category != nil {
51+
queryBuilder.WriteString(fmt.Sprintf(" AND state_change_category = $%d", argIndex))
52+
args = append(args, *category)
53+
argIndex++
54+
}
55+
56+
// Add reason filter if provided
57+
if reason != nil {
58+
queryBuilder.WriteString(fmt.Sprintf(" AND state_change_reason = $%d", argIndex))
59+
args = append(args, *reason)
60+
argIndex++
61+
}
62+
63+
// Add cursor-based pagination using parameterized queries
3164
if cursor != nil {
3265
if sortOrder == DESC {
3366
queryBuilder.WriteString(fmt.Sprintf(`
34-
AND (to_id < %d OR (to_id = %d AND state_change_order < %d))
35-
`, cursor.ToID, cursor.ToID, cursor.StateChangeOrder))
67+
AND (to_id < $%d OR (to_id = $%d AND state_change_order < $%d))
68+
`, argIndex, argIndex, argIndex+1))
69+
args = append(args, cursor.ToID, cursor.StateChangeOrder)
70+
argIndex += 2
3671
} else {
3772
queryBuilder.WriteString(fmt.Sprintf(`
38-
AND (to_id > %d OR (to_id = %d AND state_change_order > %d))
39-
`, cursor.ToID, cursor.ToID, cursor.StateChangeOrder))
73+
AND (to_id > $%d OR (to_id = $%d AND state_change_order > $%d))
74+
`, argIndex, argIndex, argIndex+1))
75+
args = append(args, cursor.ToID, cursor.StateChangeOrder)
76+
argIndex += 2
4077
}
4178
}
4279

80+
// TODO: Extract the ordering code to separate function in utils and use everywhere
81+
// Add ordering
4382
if sortOrder == DESC {
4483
queryBuilder.WriteString(" ORDER BY to_id DESC, state_change_order DESC")
4584
} else {
4685
queryBuilder.WriteString(" ORDER BY to_id ASC, state_change_order ASC")
4786
}
4887

88+
// Add limit using parameterized query
4989
if limit != nil && *limit > 0 {
50-
queryBuilder.WriteString(fmt.Sprintf(" LIMIT %d", *limit))
90+
queryBuilder.WriteString(fmt.Sprintf(" LIMIT $%d", argIndex))
91+
args = append(args, *limit)
5192
}
5293

5394
query := queryBuilder.String()
5495

96+
// For backward pagination, wrap query to reverse the final order
5597
if sortOrder == DESC {
5698
query = fmt.Sprintf(`SELECT * FROM (%s) AS statechanges ORDER BY to_id ASC, state_change_order ASC`, query)
5799
}
58100

59101
var stateChanges []*types.StateChangeWithCursor
60102
start := time.Now()
61-
err := m.DB.SelectContext(ctx, &stateChanges, query, accountAddress)
103+
err := m.DB.SelectContext(ctx, &stateChanges, query, args...)
62104
duration := time.Since(start).Seconds()
63105
m.MetricsService.ObserveDBQueryDuration("SELECT", "state_changes", duration)
64106
if err != nil {

internal/data/statechanges_test.go

Lines changed: 234 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,22 +215,254 @@ func TestStateChangeModel_BatchGetByAccountAddress(t *testing.T) {
215215
}
216216

217217
// Test BatchGetByAccount for address1
218-
stateChanges, err := m.BatchGetByAccountAddress(ctx, address1, "", nil, nil, ASC)
218+
stateChanges, err := m.BatchGetByAccountAddress(ctx, address1, nil, nil, nil, nil, "", nil, nil, ASC)
219219
require.NoError(t, err)
220220
assert.Len(t, stateChanges, 2)
221221
for _, sc := range stateChanges {
222222
assert.Equal(t, address1, sc.AccountID)
223223
}
224224

225225
// Test BatchGetByAccount for address2
226-
stateChanges, err = m.BatchGetByAccountAddress(ctx, address2, "", nil, nil, ASC)
226+
stateChanges, err = m.BatchGetByAccountAddress(ctx, address2, nil, nil, nil, nil, "", nil, nil, ASC)
227227
require.NoError(t, err)
228228
assert.Len(t, stateChanges, 1)
229229
for _, sc := range stateChanges {
230230
assert.Equal(t, address2, sc.AccountID)
231231
}
232232
}
233233

234+
func TestStateChangeModel_BatchGetByAccountAddress_WithFilters(t *testing.T) {
235+
dbt := dbtest.Open(t)
236+
defer dbt.Close()
237+
dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
238+
require.NoError(t, err)
239+
defer dbConnectionPool.Close()
240+
241+
ctx := context.Background()
242+
now := time.Now()
243+
244+
// Create test account
245+
address := keypair.MustRandom().Address()
246+
_, err = dbConnectionPool.ExecContext(ctx, "INSERT INTO accounts (stellar_address) VALUES ($1)", address)
247+
require.NoError(t, err)
248+
249+
// Create test transactions
250+
_, err = dbConnectionPool.ExecContext(ctx, `
251+
INSERT INTO transactions (hash, to_id, envelope_xdr, result_xdr, meta_xdr, ledger_number, ledger_created_at)
252+
VALUES
253+
('tx1', 1, 'env1', 'res1', 'meta1', 1, $1),
254+
('tx2', 2, 'env2', 'res2', 'meta2', 2, $1),
255+
('tx3', 3, 'env3', 'res3', 'meta3', 3, $1)
256+
`, now)
257+
require.NoError(t, err)
258+
259+
// Create test state changes with different operation IDs, transaction hashes, categories, and reasons
260+
_, err = dbConnectionPool.ExecContext(ctx, `
261+
INSERT INTO state_changes (to_id, state_change_order, state_change_category, state_change_reason, ledger_created_at, ledger_number, account_id, operation_id, tx_hash)
262+
VALUES
263+
(1, 1, 'BALANCE', 'CREDIT', $1, 1, $2, 123, 'tx1'),
264+
(2, 1, 'BALANCE', 'DEBIT', $1, 2, $2, 456, 'tx2'),
265+
(3, 1, 'SIGNER', 'ADD', $1, 3, $2, 789, 'tx3'),
266+
(4, 1, 'BALANCE', 'DEBIT', $1, 4, $2, 123, 'tx1'),
267+
(5, 1, 'SIGNER', 'ADD', $1, 5, $2, 999, 'tx2')
268+
`, now, address)
269+
require.NoError(t, err)
270+
271+
t.Run("filter by transaction hash only", func(t *testing.T) {
272+
mockMetricsService := metrics.NewMockMetricsService()
273+
mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "state_changes", mock.Anything).Return().Once()
274+
mockMetricsService.On("IncDBQuery", "SELECT", "state_changes").Return().Once()
275+
defer mockMetricsService.AssertExpectations(t)
276+
277+
m := &StateChangeModel{
278+
DB: dbConnectionPool,
279+
MetricsService: mockMetricsService,
280+
}
281+
282+
txHash := "tx1"
283+
stateChanges, err := m.BatchGetByAccountAddress(ctx, address, &txHash, nil, nil, nil, "", nil, nil, ASC)
284+
require.NoError(t, err)
285+
assert.Len(t, stateChanges, 2)
286+
for _, sc := range stateChanges {
287+
assert.Equal(t, "tx1", sc.TxHash)
288+
assert.Equal(t, address, sc.AccountID)
289+
}
290+
})
291+
292+
t.Run("filter by operation ID only", func(t *testing.T) {
293+
mockMetricsService := metrics.NewMockMetricsService()
294+
mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "state_changes", mock.Anything).Return().Once()
295+
mockMetricsService.On("IncDBQuery", "SELECT", "state_changes").Return().Once()
296+
defer mockMetricsService.AssertExpectations(t)
297+
298+
m := &StateChangeModel{
299+
DB: dbConnectionPool,
300+
MetricsService: mockMetricsService,
301+
}
302+
303+
operationID := int64(123)
304+
stateChanges, err := m.BatchGetByAccountAddress(ctx, address, nil, &operationID, nil, nil, "", nil, nil, ASC)
305+
require.NoError(t, err)
306+
assert.Len(t, stateChanges, 2)
307+
for _, sc := range stateChanges {
308+
assert.Equal(t, int64(123), sc.OperationID)
309+
assert.Equal(t, address, sc.AccountID)
310+
}
311+
})
312+
313+
t.Run("filter by both transaction hash and operation ID", func(t *testing.T) {
314+
mockMetricsService := metrics.NewMockMetricsService()
315+
mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "state_changes", mock.Anything).Return().Once()
316+
mockMetricsService.On("IncDBQuery", "SELECT", "state_changes").Return().Once()
317+
defer mockMetricsService.AssertExpectations(t)
318+
319+
m := &StateChangeModel{
320+
DB: dbConnectionPool,
321+
MetricsService: mockMetricsService,
322+
}
323+
324+
txHash := "tx1"
325+
operationID := int64(123)
326+
stateChanges, err := m.BatchGetByAccountAddress(ctx, address, &txHash, &operationID, nil, nil, "", nil, nil, ASC)
327+
require.NoError(t, err)
328+
// Should get only state changes that match BOTH filters
329+
assert.Len(t, stateChanges, 2)
330+
for _, sc := range stateChanges {
331+
assert.Equal(t, "tx1", sc.TxHash)
332+
assert.Equal(t, int64(123), sc.OperationID)
333+
assert.Equal(t, address, sc.AccountID)
334+
}
335+
})
336+
337+
t.Run("filter by category only", func(t *testing.T) {
338+
mockMetricsService := metrics.NewMockMetricsService()
339+
mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "state_changes", mock.Anything).Return().Once()
340+
mockMetricsService.On("IncDBQuery", "SELECT", "state_changes").Return().Once()
341+
defer mockMetricsService.AssertExpectations(t)
342+
343+
m := &StateChangeModel{
344+
DB: dbConnectionPool,
345+
MetricsService: mockMetricsService,
346+
}
347+
348+
category := "BALANCE"
349+
stateChanges, err := m.BatchGetByAccountAddress(ctx, address, nil, nil, &category, nil, "", nil, nil, ASC)
350+
require.NoError(t, err)
351+
assert.Len(t, stateChanges, 3)
352+
for _, sc := range stateChanges {
353+
assert.Equal(t, types.StateChangeCategoryBalance, sc.StateChangeCategory)
354+
assert.Equal(t, address, sc.AccountID)
355+
}
356+
})
357+
358+
t.Run("filter by reason only", func(t *testing.T) {
359+
mockMetricsService := metrics.NewMockMetricsService()
360+
mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "state_changes", mock.Anything).Return().Once()
361+
mockMetricsService.On("IncDBQuery", "SELECT", "state_changes").Return().Once()
362+
defer mockMetricsService.AssertExpectations(t)
363+
364+
m := &StateChangeModel{
365+
DB: dbConnectionPool,
366+
MetricsService: mockMetricsService,
367+
}
368+
369+
reason := "ADD"
370+
stateChanges, err := m.BatchGetByAccountAddress(ctx, address, nil, nil, nil, &reason, "", nil, nil, ASC)
371+
require.NoError(t, err)
372+
assert.Len(t, stateChanges, 2)
373+
for _, sc := range stateChanges {
374+
assert.Equal(t, types.StateChangeReasonAdd, *sc.StateChangeReason)
375+
assert.Equal(t, address, sc.AccountID)
376+
}
377+
})
378+
379+
t.Run("filter by both category and reason", func(t *testing.T) {
380+
mockMetricsService := metrics.NewMockMetricsService()
381+
mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "state_changes", mock.Anything).Return().Once()
382+
mockMetricsService.On("IncDBQuery", "SELECT", "state_changes").Return().Once()
383+
defer mockMetricsService.AssertExpectations(t)
384+
385+
m := &StateChangeModel{
386+
DB: dbConnectionPool,
387+
MetricsService: mockMetricsService,
388+
}
389+
390+
category := "SIGNER"
391+
reason := "ADD"
392+
stateChanges, err := m.BatchGetByAccountAddress(ctx, address, nil, nil, &category, &reason, "", nil, nil, ASC)
393+
require.NoError(t, err)
394+
assert.Len(t, stateChanges, 2)
395+
for _, sc := range stateChanges {
396+
assert.Equal(t, types.StateChangeCategorySigner, sc.StateChangeCategory)
397+
assert.Equal(t, types.StateChangeReasonAdd, *sc.StateChangeReason)
398+
assert.Equal(t, address, sc.AccountID)
399+
}
400+
})
401+
402+
t.Run("filter with all filters - txHash, operationID, category, reason", func(t *testing.T) {
403+
mockMetricsService := metrics.NewMockMetricsService()
404+
mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "state_changes", mock.Anything).Return().Once()
405+
mockMetricsService.On("IncDBQuery", "SELECT", "state_changes").Return().Once()
406+
defer mockMetricsService.AssertExpectations(t)
407+
408+
m := &StateChangeModel{
409+
DB: dbConnectionPool,
410+
MetricsService: mockMetricsService,
411+
}
412+
413+
txHash := "tx1"
414+
operationID := int64(123)
415+
category := "BALANCE"
416+
reason := "CREDIT"
417+
stateChanges, err := m.BatchGetByAccountAddress(ctx, address, &txHash, &operationID, &category, &reason, "", nil, nil, ASC)
418+
require.NoError(t, err)
419+
assert.Len(t, stateChanges, 1)
420+
for _, sc := range stateChanges {
421+
assert.Equal(t, "tx1", sc.TxHash)
422+
assert.Equal(t, int64(123), sc.OperationID)
423+
assert.Equal(t, types.StateChangeCategoryBalance, sc.StateChangeCategory)
424+
assert.Equal(t, types.StateChangeReasonCredit, *sc.StateChangeReason)
425+
assert.Equal(t, address, sc.AccountID)
426+
}
427+
})
428+
429+
t.Run("filter with no matching results", func(t *testing.T) {
430+
mockMetricsService := metrics.NewMockMetricsService()
431+
mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "state_changes", mock.Anything).Return().Once()
432+
mockMetricsService.On("IncDBQuery", "SELECT", "state_changes").Return().Once()
433+
defer mockMetricsService.AssertExpectations(t)
434+
435+
m := &StateChangeModel{
436+
DB: dbConnectionPool,
437+
MetricsService: mockMetricsService,
438+
}
439+
440+
txHash := "nonexistent"
441+
stateChanges, err := m.BatchGetByAccountAddress(ctx, address, &txHash, nil, nil, nil, "", nil, nil, ASC)
442+
require.NoError(t, err)
443+
assert.Empty(t, stateChanges)
444+
})
445+
446+
t.Run("filter with pagination", func(t *testing.T) {
447+
mockMetricsService := metrics.NewMockMetricsService()
448+
mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "state_changes", mock.Anything).Return().Once()
449+
mockMetricsService.On("IncDBQuery", "SELECT", "state_changes").Return().Once()
450+
defer mockMetricsService.AssertExpectations(t)
451+
452+
m := &StateChangeModel{
453+
DB: dbConnectionPool,
454+
MetricsService: mockMetricsService,
455+
}
456+
457+
txHash := "tx1"
458+
limit := int32(1)
459+
stateChanges, err := m.BatchGetByAccountAddress(ctx, address, &txHash, nil, nil, nil, "", &limit, nil, ASC)
460+
require.NoError(t, err)
461+
assert.Len(t, stateChanges, 1)
462+
assert.Equal(t, "tx1", stateChanges[0].TxHash)
463+
})
464+
}
465+
234466
func TestStateChangeModel_GetAll(t *testing.T) {
235467
dbt := dbtest.Open(t)
236468
defer dbt.Close()

0 commit comments

Comments
 (0)