|
| 1 | +package store |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "database/sql" |
| 6 | + "fmt" |
| 7 | + "strings" |
| 8 | + "time" |
| 9 | + |
| 10 | + "github.com/MixinNetwork/safe/common" |
| 11 | + "github.com/MixinNetwork/safe/mtg" |
| 12 | +) |
| 13 | + |
| 14 | +type Notification struct { |
| 15 | + TraceId string |
| 16 | + OpponentId string |
| 17 | + State byte |
| 18 | + CreatedAt time.Time |
| 19 | + UpdatedAt time.Time |
| 20 | +} |
| 21 | + |
| 22 | +var notificationCols = []string{"trace_id", "opponent_id", "state", "created_at", "updated_at"} |
| 23 | + |
| 24 | +func notificationFromRow(row Row) (*Notification, error) { |
| 25 | + var n Notification |
| 26 | + err := row.Scan(&n.TraceId, &n.OpponentId, &n.State, &n.CreatedAt, &n.UpdatedAt) |
| 27 | + if err == sql.ErrNoRows { |
| 28 | + return nil, nil |
| 29 | + } |
| 30 | + return &n, err |
| 31 | +} |
| 32 | + |
| 33 | +func (s *SQLite3Store) writeNotifications(ctx context.Context, tx *sql.Tx, txs []*mtg.Transaction) error { |
| 34 | + now := time.Now().UTC() |
| 35 | + |
| 36 | + for _, t := range txs { |
| 37 | + if len(t.Receivers) != 1 { |
| 38 | + continue |
| 39 | + } |
| 40 | + vals := []any{t.TraceId, t.Receivers[0], common.RequestStateInitial, now, now} |
| 41 | + err := s.execOne(ctx, tx, buildInsertionSQL("tx_notifications", notificationCols), vals...) |
| 42 | + if err != nil { |
| 43 | + return fmt.Errorf("INSERT tx_notifications %v", err) |
| 44 | + } |
| 45 | + } |
| 46 | + |
| 47 | + return nil |
| 48 | +} |
| 49 | + |
| 50 | +func (s *SQLite3Store) ListInitialNotifications(ctx context.Context) ([]*Notification, error) { |
| 51 | + s.mutex.RLock() |
| 52 | + defer s.mutex.RUnlock() |
| 53 | + |
| 54 | + query := fmt.Sprintf("SELECT %s FROM tx_notifications WHERE state=? ORDER BY created_at ASC LIMIT 20", strings.Join(notificationCols, ",")) |
| 55 | + rows, err := s.db.QueryContext(ctx, query, common.RequestStateInitial) |
| 56 | + if err != nil { |
| 57 | + return nil, err |
| 58 | + } |
| 59 | + defer rows.Close() |
| 60 | + |
| 61 | + var ns []*Notification |
| 62 | + for rows.Next() { |
| 63 | + n, err := notificationFromRow(rows) |
| 64 | + if err != nil { |
| 65 | + return nil, err |
| 66 | + } |
| 67 | + ns = append(ns, n) |
| 68 | + } |
| 69 | + return ns, nil |
| 70 | +} |
| 71 | + |
| 72 | +func (s *SQLite3Store) MarkNotificationDone(ctx context.Context, traceId string) error { |
| 73 | + s.mutex.Lock() |
| 74 | + defer s.mutex.Unlock() |
| 75 | + |
| 76 | + tx, err := s.db.BeginTx(ctx, nil) |
| 77 | + if err != nil { |
| 78 | + return err |
| 79 | + } |
| 80 | + defer common.Rollback(tx) |
| 81 | + |
| 82 | + query := "UPDATE tx_notifications SET state=?,updated_at=? WHERE trace_id=? AND state=?" |
| 83 | + err = s.execOne(ctx, tx, query, common.RequestStateDone, time.Now().UTC(), traceId, common.RequestStateInitial) |
| 84 | + if err != nil { |
| 85 | + return fmt.Errorf("UPDATE tx_notifications %v", err) |
| 86 | + } |
| 87 | + |
| 88 | + return tx.Commit() |
| 89 | +} |
0 commit comments