Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 139 additions & 24 deletions db_changes/db/dialect_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,11 @@ func (d PostgresDialect) historyTable(schema string) string {
return fmt.Sprintf("%s.%s", EscapeIdentifier(schema), EscapeIdentifier(d.historyTableName))
}

func (d PostgresDialect) saveInsert(schema string, table string, primaryKey map[string]string, blockNum uint64) string {
func (d PostgresDialect) saveInsert(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string {
return fmt.Sprintf(`INSERT INTO %s (op,table_name,pk,block_num) values (%s,%s,%s,%d);`,
d.historyTable(schema),
escapeStringValue("I"),
escapeStringValue(table),
escapeStringValue(table.identifier),
escapeStringValue(primaryKeyToJSON(primaryKey)),
blockNum,
)
Expand All @@ -321,8 +321,12 @@ func (d PostgresDialect) saveInsert(schema string, table string, primaryKey map[
with t as (select 'default' id)
select CASE WHEN block_meta.id is null THEN 'I' ELSE 'U' END AS op, '"public"."block_meta"', 'allo', row_to_json(block_meta),10 from t left join block_meta on block_meta.id='default';
*/
func (d PostgresDialect) saveUpsert(schema string, escapedTableName string, primaryKey map[string]string, blockNum uint64) string {
schemaAndTable := fmt.Sprintf("%s.%s", EscapeIdentifier(schema), escapedTableName)
func (d PostgresDialect) saveUpsert(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string {
schemaAndTable := table.schemaEscaped + "." + table.nameEscaped
onClause, err := d.getPrimaryKeyWhereClauseTyped(table, primaryKey, table.nameEscaped)
if err != nil {
onClause = getPrimaryKeyWhereClause(primaryKey, table.nameEscaped)
}
Comment on lines +327 to +329
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, add comment as the why it's correct and log error.


return fmt.Sprintf(`
WITH t as (select %s)
Expand All @@ -332,30 +336,34 @@ func (d PostgresDialect) saveUpsert(schema string, escapedTableName string, prim
getPrimaryKeyFakeEmptyValues(primaryKey),
d.historyTable(schema),

getPrimaryKeyFakeEmptyValuesAssertion(primaryKey, escapedTableName),
getPrimaryKeyFakeEmptyValuesAssertion(primaryKey, table.nameEscaped),

escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), escapedTableName, blockNum,
EscapeIdentifier(schema), escapedTableName,
getPrimaryKeyWhereClause(primaryKey, escapedTableName),
escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), table.nameEscaped, blockNum,
table.schemaEscaped, table.nameEscaped,
onClause,
)

}

func (d PostgresDialect) saveUpdate(schema string, escapedTableName string, primaryKey map[string]string, blockNum uint64) string {
return d.saveRow("U", schema, escapedTableName, primaryKey, blockNum)
func (d PostgresDialect) saveUpdate(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string {
return d.saveRow("U", schema, table, primaryKey, blockNum)
}

func (d PostgresDialect) saveDelete(schema string, escapedTableName string, primaryKey map[string]string, blockNum uint64) string {
return d.saveRow("D", schema, escapedTableName, primaryKey, blockNum)
func (d PostgresDialect) saveDelete(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string {
return d.saveRow("D", schema, table, primaryKey, blockNum)
}

func (d PostgresDialect) saveRow(op, schema, escapedTableName string, primaryKey map[string]string, blockNum uint64) string {
schemaAndTable := fmt.Sprintf("%s.%s", EscapeIdentifier(schema), escapedTableName)
func (d PostgresDialect) saveRow(op, schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string {
schemaAndTable := table.schemaEscaped + "." + table.nameEscaped
whereClause, err := d.getPrimaryKeyWhereClauseTyped(table, primaryKey, "")
if err != nil {
whereClause = getPrimaryKeyWhereClause(primaryKey, "")
}
Comment on lines +358 to +361
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this a good way of doing it, can you add a comment explaining the reasoning here, we should also log the error minimally in DEBUG mode if it's expected to log a lot of time.

return fmt.Sprintf(`INSERT INTO %s (op,table_name,pk,prev_value,block_num) SELECT %s,%s,%s,row_to_json(%s),%d FROM %s.%s WHERE %s;`,
d.historyTable(schema),
escapeStringValue(op), escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), escapedTableName, blockNum,
EscapeIdentifier(schema), escapedTableName,
getPrimaryKeyWhereClause(primaryKey, ""),
escapeStringValue(op), escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), table.nameEscaped, blockNum,
table.schemaEscaped, table.nameEscaped,
whereClause,
)

}
Expand Down Expand Up @@ -386,7 +394,7 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string,
)

if o.reversibleBlockNum != nil {
return d.saveInsert(schema, o.table.identifier, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil
return d.saveInsert(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil
}
return insertQuery, nil

Expand All @@ -405,7 +413,7 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string,
)

if o.reversibleBlockNum != nil {
return d.saveUpsert(schema, o.table.nameEscaped, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil
return d.saveUpsert(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil
}
return insertQuery, nil

Expand All @@ -415,7 +423,10 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string,
updates[i] = fmt.Sprintf("%s=%s", columns[i], values[i])
}

primaryKeySelector := getPrimaryKeyWhereClause(o.primaryKey, "")
primaryKeySelector, err := d.getPrimaryKeyWhereClauseTyped(o.table, o.primaryKey, "")
if err != nil {
return "", err
}
Comment on lines +426 to +429
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this backward compatible? The question is essentially, if there is already substreams_history entries`, are they going to be read correctly?


updateQuery := fmt.Sprintf("UPDATE %s SET %s WHERE %s",
o.table.identifier,
Expand All @@ -424,18 +435,21 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string,
)

if o.reversibleBlockNum != nil {
return d.saveUpdate(schema, o.table.nameEscaped, o.primaryKey, *o.reversibleBlockNum) + updateQuery, nil
return d.saveUpdate(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + updateQuery, nil
}
return updateQuery, nil

case OperationTypeDelete:
primaryKeyWhereClause := getPrimaryKeyWhereClause(o.primaryKey, "")
primaryKeyWhereClause, err := d.getPrimaryKeyWhereClauseTyped(o.table, o.primaryKey, "")
if err != nil {
return "", err
}
deleteQuery := fmt.Sprintf("DELETE FROM %s WHERE %s",
o.table.identifier,
primaryKeyWhereClause,
)
if o.reversibleBlockNum != nil {
return d.saveDelete(schema, o.table.nameEscaped, o.primaryKey, *o.reversibleBlockNum) + deleteQuery, nil
return d.saveDelete(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + deleteQuery, nil
}
return deleteQuery, nil

Expand Down Expand Up @@ -466,7 +480,7 @@ func (d *PostgresDialect) prepareColValues(table *TableInfo, colValues map[strin
return nil, nil, fmt.Errorf("cannot find column %q for table %q (valid columns are %q)", columnName, table.identifier, strings.Join(maps.Keys(table.columnsByName), ", "))
}

normalizedValue, err := d.normalizeValueType(value, columnInfo.scanType)
normalizedValue, err := d.normalizeLiteral(table, columnName, value, "insert")
if err != nil {
return nil, nil, fmt.Errorf("getting sql value from table %s for column %q raw value %q: %w", table.identifier, columnName, value, err)
}
Expand Down Expand Up @@ -535,6 +549,107 @@ func getPrimaryKeyWhereClause(primaryKey map[string]string, escapedTableName str
return strings.Join(reg[:], " AND ")
}

// Build a typed WHERE clause using column types for proper literal normalization
func (d *PostgresDialect) getPrimaryKeyWhereClauseTyped(table *TableInfo, primaryKey map[string]string, escapedTableName string) (string, error) {
// Avoid any allocation if there is a single primary key
if len(primaryKey) == 1 {
for key, value := range primaryKey {
rhs, err := d.normalizeLiteral(table, key, value, "where")
if err != nil {
return "", err
}
if escapedTableName == "" {
return EscapeIdentifier(key) + " = " + rhs, nil
}
return escapedTableName + "." + EscapeIdentifier(key) + " = " + rhs, nil
}
}

reg := make([]string, 0, len(primaryKey))
for key, value := range primaryKey {
rhs, err := d.normalizeLiteral(table, key, value, "where")
if err != nil {
return "", err
}
if escapedTableName == "" {
reg = append(reg, EscapeIdentifier(key)+" = "+rhs)
} else {
reg = append(reg, escapedTableName+"."+EscapeIdentifier(key)+" = "+rhs)
}
}
sort.Strings(reg)
return strings.Join(reg, " AND "), nil
}

// normalizeLiteral centralizes literal formatting based on column database type name and context (insert/where)
// Returns a SQL RHS expression ready to be embedded in a statement
func (d *PostgresDialect) normalizeLiteral(table *TableInfo, columnName string, rawValue string, context string) (string, error) {
colInfo, found := table.columnsByName[columnName]
if !found {
return "", fmt.Errorf("cannot find column %q for table %q (valid columns are %q)", columnName, table.identifier, strings.Join(maps.Keys(table.columnsByName), ", "))
}

dt := strings.ToLower(colInfo.databaseTypeName)
trimmed := strings.TrimSpace(rawValue)

// Array handling using databaseTypeName
if strings.HasSuffix(dt, "[]") {
baseType := strings.TrimSuffix(dt, "[]")
return canonicalizeArrayLiteral(baseType, trimmed), nil
}

// Scalar handling as today using scanType
return d.normalizeValueType(trimmed, colInfo.scanType)
}

// canonicalizeArrayLiteral emits a curly literal with explicit cast to the column's base type array
// Examples:
// - empty -> '{ }'::base[]
// - 'ARRAY[1,2]' -> '{1,2}'::base[]
// - '{1,2}' or "'{1,2}'::text[]" -> '{1,2}'::base[]
func canonicalizeArrayLiteral(baseType string, raw string) string {
// Empty array detection (accept {}, { }, [], ARRAY[])
switch raw {
case "{}", "{ }", "[]", "ARRAY[]", "array[]", "ARRAY []", "array []", "":
return "'{ }'::" + baseType + "[]"
}

upper := strings.ToUpper(raw)
if strings.HasPrefix(upper, "ARRAY[") {
// Extract elements inside ARRAY[...]
end := strings.LastIndex(raw, "]")
inner := ""
if end > 6 { // len("ARRAY[") == 6
inner = strings.TrimSpace(raw[6:end])
}
return "'{" + inner + "}'::" + baseType + "[]"
}

// Strip surrounding single quotes if present
if strings.HasPrefix(raw, "'") && strings.HasSuffix(raw, "'") && len(raw) >= 2 {
raw = strings.TrimSuffix(strings.TrimPrefix(raw, "'"), "'")
}
// Remove any existing cast suffix like }::type[]
if idx := strings.LastIndex(raw, "}::"); idx != -1 {
raw = raw[:idx+1]
}

// If contains curly braces, extract content and rebuild
if l := strings.Index(raw, "{"); l != -1 {
if r := strings.LastIndex(raw, "}"); r != -1 && r > l {
inner := strings.TrimSpace(raw[l+1 : r])
return "'{" + inner + "}'::" + baseType + "[]"
}
}

// Fallback: treat raw as a comma-separated list of elements
inner := strings.TrimSpace(raw)
if strings.HasPrefix(inner, "{") && strings.HasSuffix(inner, "}") && len(inner) >= 2 {
inner = strings.TrimSuffix(strings.TrimPrefix(inner, "{"), "}")
}
return "'{" + inner + "}'::" + baseType + "[]"
}

// Format based on type, value returned unescaped
func (d *PostgresDialect) normalizeValueType(value string, valueType reflect.Type) (string, error) {
switch valueType.Kind() {
Expand Down
Loading