From 824dcade69c993a14770924da023036d55bdd3f0 Mon Sep 17 00:00:00 2001 From: Loc Nguyen Date: Mon, 25 Aug 2025 20:32:09 +0700 Subject: [PATCH 1/2] fix(postgresql): loosen requirements for getPrimaryKeyWhereClause --- db_changes/db/dialect_postgres.go | 43 ++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/db_changes/db/dialect_postgres.go b/db_changes/db/dialect_postgres.go index 2143745..53dc723 100644 --- a/db_changes/db/dialect_postgres.go +++ b/db_changes/db/dialect_postgres.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "reflect" + "regexp" "slices" "sort" "strconv" @@ -513,21 +514,23 @@ func getPrimaryKeyWhereClause(primaryKey map[string]string, escapedTableName str // Avoid any allocation if there is a single primary key if len(primaryKey) == 1 { for key, value := range primaryKey { + formatted := formatWhereValue(value) if escapedTableName == "" { - return EscapeIdentifier(key) + " = " + escapeStringValue(value) + return EscapeIdentifier(key) + " = " + formatted } - return escapedTableName + "." + EscapeIdentifier(key) + " = " + escapeStringValue(value) + return escapedTableName + "." + EscapeIdentifier(key) + " = " + formatted } } reg := make([]string, 0, len(primaryKey)) for key, value := range primaryKey { + formatted := formatWhereValue(value) if escapedTableName == "" { - reg = append(reg, EscapeIdentifier(key)+" = "+escapeStringValue(value)) + reg = append(reg, EscapeIdentifier(key)+" = "+formatted) } else { - reg = append(reg, escapedTableName+"."+EscapeIdentifier(key)+" = "+escapeStringValue(value)) + reg = append(reg, escapedTableName+"."+EscapeIdentifier(key)+" = "+formatted) } } sort.Strings(reg) @@ -535,6 +538,38 @@ func getPrimaryKeyWhereClause(primaryKey map[string]string, escapedTableName str return strings.Join(reg[:], " AND ") } +var arrayConstructorRegex = regexp.MustCompile(`^ARRAY\[(.*)\](::[A-Za-z0-9_]+(\[\])*)?$`) + +// formatWhereValue prepares a primary key value for use on the right-hand side +// of a WHERE equality comparison. It keeps valid SQL expressions unquoted and +// converts ARRAY[...]::type[] constructs to proper array literals ('{...}') so +// that Postgres can type-cast them to the column's array type during comparison. +func formatWhereValue(value string) string { + // Handle ARRAY[...]::type[] (including empty ARRAY[]) + if matches := arrayConstructorRegex.FindStringSubmatch(value); matches != nil { + inner := matches[1] // may be empty + cast := "" + if len(matches) >= 3 { + cast = matches[2] // includes leading :: if present + } + // Convert to a curly-brace array literal, quote it, and preserve explicit cast when provided + return escapeStringValue("{"+inner+"}") + cast + } + + // If already looks like a raw array literal (e.g., {1,2}) without quotes, quote it + if strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}") { + return escapeStringValue(value) + } + + // If value appears already quoted and casted (e.g., '{}''::bigint[]'), keep as-is + if strings.HasPrefix(value, "'{") && strings.Contains(value, "}'::") { + return value + } + + // Default: treat as a plain string literal + return escapeStringValue(value) +} + // Format based on type, value returned unescaped func (d *PostgresDialect) normalizeValueType(value string, valueType reflect.Type) (string, error) { switch valueType.Kind() { From 47b8b843e4b5c94c37888ec51dab5486f0d95642 Mon Sep 17 00:00:00 2001 From: Loc Nguyen Date: Fri, 29 Aug 2025 13:43:11 +0700 Subject: [PATCH 2/2] fix(postgres): generalize PgSQL typing for INSERT/UPDATE clauses --- db_changes/db/dialect_postgres.go | 186 +++++++++++++++++++++--------- 1 file changed, 133 insertions(+), 53 deletions(-) diff --git a/db_changes/db/dialect_postgres.go b/db_changes/db/dialect_postgres.go index 53dc723..84867e9 100644 --- a/db_changes/db/dialect_postgres.go +++ b/db_changes/db/dialect_postgres.go @@ -7,7 +7,6 @@ import ( "encoding/json" "fmt" "reflect" - "regexp" "slices" "sort" "strconv" @@ -308,11 +307,11 @@ func (d PostgresDialect) historyTable(schema string) string { return fmt.Sprintf("%s.%s", EscapeIdentifier(schema), EscapeIdentifier(d.historyTableName)) } -func (d PostgresDialect) saveInsert(schema string, table string, primaryKey map[string]string, blockNum uint64) string { +func (d PostgresDialect) saveInsert(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string { return fmt.Sprintf(`INSERT INTO %s (op,table_name,pk,block_num) values (%s,%s,%s,%d);`, d.historyTable(schema), escapeStringValue("I"), - escapeStringValue(table), + escapeStringValue(table.identifier), escapeStringValue(primaryKeyToJSON(primaryKey)), blockNum, ) @@ -322,8 +321,12 @@ func (d PostgresDialect) saveInsert(schema string, table string, primaryKey map[ with t as (select 'default' id) select CASE WHEN block_meta.id is null THEN 'I' ELSE 'U' END AS op, '"public"."block_meta"', 'allo', row_to_json(block_meta),10 from t left join block_meta on block_meta.id='default'; */ -func (d PostgresDialect) saveUpsert(schema string, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { - schemaAndTable := fmt.Sprintf("%s.%s", EscapeIdentifier(schema), escapedTableName) +func (d PostgresDialect) saveUpsert(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string { + schemaAndTable := table.schemaEscaped + "." + table.nameEscaped + onClause, err := d.getPrimaryKeyWhereClauseTyped(table, primaryKey, table.nameEscaped) + if err != nil { + onClause = getPrimaryKeyWhereClause(primaryKey, table.nameEscaped) + } return fmt.Sprintf(` WITH t as (select %s) @@ -333,30 +336,34 @@ func (d PostgresDialect) saveUpsert(schema string, escapedTableName string, prim getPrimaryKeyFakeEmptyValues(primaryKey), d.historyTable(schema), - getPrimaryKeyFakeEmptyValuesAssertion(primaryKey, escapedTableName), + getPrimaryKeyFakeEmptyValuesAssertion(primaryKey, table.nameEscaped), - escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), escapedTableName, blockNum, - EscapeIdentifier(schema), escapedTableName, - getPrimaryKeyWhereClause(primaryKey, escapedTableName), + escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), table.nameEscaped, blockNum, + table.schemaEscaped, table.nameEscaped, + onClause, ) } -func (d PostgresDialect) saveUpdate(schema string, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { - return d.saveRow("U", schema, escapedTableName, primaryKey, blockNum) +func (d PostgresDialect) saveUpdate(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string { + return d.saveRow("U", schema, table, primaryKey, blockNum) } -func (d PostgresDialect) saveDelete(schema string, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { - return d.saveRow("D", schema, escapedTableName, primaryKey, blockNum) +func (d PostgresDialect) saveDelete(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string { + return d.saveRow("D", schema, table, primaryKey, blockNum) } -func (d PostgresDialect) saveRow(op, schema, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { - schemaAndTable := fmt.Sprintf("%s.%s", EscapeIdentifier(schema), escapedTableName) +func (d PostgresDialect) saveRow(op, schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string { + schemaAndTable := table.schemaEscaped + "." + table.nameEscaped + whereClause, err := d.getPrimaryKeyWhereClauseTyped(table, primaryKey, "") + if err != nil { + whereClause = getPrimaryKeyWhereClause(primaryKey, "") + } return fmt.Sprintf(`INSERT INTO %s (op,table_name,pk,prev_value,block_num) SELECT %s,%s,%s,row_to_json(%s),%d FROM %s.%s WHERE %s;`, d.historyTable(schema), - escapeStringValue(op), escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), escapedTableName, blockNum, - EscapeIdentifier(schema), escapedTableName, - getPrimaryKeyWhereClause(primaryKey, ""), + escapeStringValue(op), escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), table.nameEscaped, blockNum, + table.schemaEscaped, table.nameEscaped, + whereClause, ) } @@ -387,7 +394,7 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string, ) if o.reversibleBlockNum != nil { - return d.saveInsert(schema, o.table.identifier, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil + return d.saveInsert(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil } return insertQuery, nil @@ -406,7 +413,7 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string, ) if o.reversibleBlockNum != nil { - return d.saveUpsert(schema, o.table.nameEscaped, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil + return d.saveUpsert(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil } return insertQuery, nil @@ -416,7 +423,10 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string, updates[i] = fmt.Sprintf("%s=%s", columns[i], values[i]) } - primaryKeySelector := getPrimaryKeyWhereClause(o.primaryKey, "") + primaryKeySelector, err := d.getPrimaryKeyWhereClauseTyped(o.table, o.primaryKey, "") + if err != nil { + return "", err + } updateQuery := fmt.Sprintf("UPDATE %s SET %s WHERE %s", o.table.identifier, @@ -425,18 +435,21 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string, ) if o.reversibleBlockNum != nil { - return d.saveUpdate(schema, o.table.nameEscaped, o.primaryKey, *o.reversibleBlockNum) + updateQuery, nil + return d.saveUpdate(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + updateQuery, nil } return updateQuery, nil case OperationTypeDelete: - primaryKeyWhereClause := getPrimaryKeyWhereClause(o.primaryKey, "") + primaryKeyWhereClause, err := d.getPrimaryKeyWhereClauseTyped(o.table, o.primaryKey, "") + if err != nil { + return "", err + } deleteQuery := fmt.Sprintf("DELETE FROM %s WHERE %s", o.table.identifier, primaryKeyWhereClause, ) if o.reversibleBlockNum != nil { - return d.saveDelete(schema, o.table.nameEscaped, o.primaryKey, *o.reversibleBlockNum) + deleteQuery, nil + return d.saveDelete(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + deleteQuery, nil } return deleteQuery, nil @@ -467,7 +480,7 @@ func (d *PostgresDialect) prepareColValues(table *TableInfo, colValues map[strin return nil, nil, fmt.Errorf("cannot find column %q for table %q (valid columns are %q)", columnName, table.identifier, strings.Join(maps.Keys(table.columnsByName), ", ")) } - normalizedValue, err := d.normalizeValueType(value, columnInfo.scanType) + normalizedValue, err := d.normalizeLiteral(table, columnName, value, "insert") if err != nil { return nil, nil, fmt.Errorf("getting sql value from table %s for column %q raw value %q: %w", table.identifier, columnName, value, err) } @@ -514,23 +527,21 @@ func getPrimaryKeyWhereClause(primaryKey map[string]string, escapedTableName str // Avoid any allocation if there is a single primary key if len(primaryKey) == 1 { for key, value := range primaryKey { - formatted := formatWhereValue(value) if escapedTableName == "" { - return EscapeIdentifier(key) + " = " + formatted + return EscapeIdentifier(key) + " = " + escapeStringValue(value) } - return escapedTableName + "." + EscapeIdentifier(key) + " = " + formatted + return escapedTableName + "." + EscapeIdentifier(key) + " = " + escapeStringValue(value) } } reg := make([]string, 0, len(primaryKey)) for key, value := range primaryKey { - formatted := formatWhereValue(value) if escapedTableName == "" { - reg = append(reg, EscapeIdentifier(key)+" = "+formatted) + reg = append(reg, EscapeIdentifier(key)+" = "+escapeStringValue(value)) } else { - reg = append(reg, escapedTableName+"."+EscapeIdentifier(key)+" = "+formatted) + reg = append(reg, escapedTableName+"."+EscapeIdentifier(key)+" = "+escapeStringValue(value)) } } sort.Strings(reg) @@ -538,36 +549,105 @@ func getPrimaryKeyWhereClause(primaryKey map[string]string, escapedTableName str return strings.Join(reg[:], " AND ") } -var arrayConstructorRegex = regexp.MustCompile(`^ARRAY\[(.*)\](::[A-Za-z0-9_]+(\[\])*)?$`) +// Build a typed WHERE clause using column types for proper literal normalization +func (d *PostgresDialect) getPrimaryKeyWhereClauseTyped(table *TableInfo, primaryKey map[string]string, escapedTableName string) (string, error) { + // Avoid any allocation if there is a single primary key + if len(primaryKey) == 1 { + for key, value := range primaryKey { + rhs, err := d.normalizeLiteral(table, key, value, "where") + if err != nil { + return "", err + } + if escapedTableName == "" { + return EscapeIdentifier(key) + " = " + rhs, nil + } + return escapedTableName + "." + EscapeIdentifier(key) + " = " + rhs, nil + } + } + + reg := make([]string, 0, len(primaryKey)) + for key, value := range primaryKey { + rhs, err := d.normalizeLiteral(table, key, value, "where") + if err != nil { + return "", err + } + if escapedTableName == "" { + reg = append(reg, EscapeIdentifier(key)+" = "+rhs) + } else { + reg = append(reg, escapedTableName+"."+EscapeIdentifier(key)+" = "+rhs) + } + } + sort.Strings(reg) + return strings.Join(reg, " AND "), nil +} + +// normalizeLiteral centralizes literal formatting based on column database type name and context (insert/where) +// Returns a SQL RHS expression ready to be embedded in a statement +func (d *PostgresDialect) normalizeLiteral(table *TableInfo, columnName string, rawValue string, context string) (string, error) { + colInfo, found := table.columnsByName[columnName] + if !found { + return "", fmt.Errorf("cannot find column %q for table %q (valid columns are %q)", columnName, table.identifier, strings.Join(maps.Keys(table.columnsByName), ", ")) + } -// formatWhereValue prepares a primary key value for use on the right-hand side -// of a WHERE equality comparison. It keeps valid SQL expressions unquoted and -// converts ARRAY[...]::type[] constructs to proper array literals ('{...}') so -// that Postgres can type-cast them to the column's array type during comparison. -func formatWhereValue(value string) string { - // Handle ARRAY[...]::type[] (including empty ARRAY[]) - if matches := arrayConstructorRegex.FindStringSubmatch(value); matches != nil { - inner := matches[1] // may be empty - cast := "" - if len(matches) >= 3 { - cast = matches[2] // includes leading :: if present + dt := strings.ToLower(colInfo.databaseTypeName) + trimmed := strings.TrimSpace(rawValue) + + // Array handling using databaseTypeName + if strings.HasSuffix(dt, "[]") { + baseType := strings.TrimSuffix(dt, "[]") + return canonicalizeArrayLiteral(baseType, trimmed), nil + } + + // Scalar handling as today using scanType + return d.normalizeValueType(trimmed, colInfo.scanType) +} + +// canonicalizeArrayLiteral emits a curly literal with explicit cast to the column's base type array +// Examples: +// - empty -> '{ }'::base[] +// - 'ARRAY[1,2]' -> '{1,2}'::base[] +// - '{1,2}' or "'{1,2}'::text[]" -> '{1,2}'::base[] +func canonicalizeArrayLiteral(baseType string, raw string) string { + // Empty array detection (accept {}, { }, [], ARRAY[]) + switch raw { + case "{}", "{ }", "[]", "ARRAY[]", "array[]", "ARRAY []", "array []", "": + return "'{ }'::" + baseType + "[]" + } + + upper := strings.ToUpper(raw) + if strings.HasPrefix(upper, "ARRAY[") { + // Extract elements inside ARRAY[...] + end := strings.LastIndex(raw, "]") + inner := "" + if end > 6 { // len("ARRAY[") == 6 + inner = strings.TrimSpace(raw[6:end]) } - // Convert to a curly-brace array literal, quote it, and preserve explicit cast when provided - return escapeStringValue("{"+inner+"}") + cast + return "'{" + inner + "}'::" + baseType + "[]" } - // If already looks like a raw array literal (e.g., {1,2}) without quotes, quote it - if strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}") { - return escapeStringValue(value) + // Strip surrounding single quotes if present + if strings.HasPrefix(raw, "'") && strings.HasSuffix(raw, "'") && len(raw) >= 2 { + raw = strings.TrimSuffix(strings.TrimPrefix(raw, "'"), "'") + } + // Remove any existing cast suffix like }::type[] + if idx := strings.LastIndex(raw, "}::"); idx != -1 { + raw = raw[:idx+1] } - // If value appears already quoted and casted (e.g., '{}''::bigint[]'), keep as-is - if strings.HasPrefix(value, "'{") && strings.Contains(value, "}'::") { - return value + // If contains curly braces, extract content and rebuild + if l := strings.Index(raw, "{"); l != -1 { + if r := strings.LastIndex(raw, "}"); r != -1 && r > l { + inner := strings.TrimSpace(raw[l+1 : r]) + return "'{" + inner + "}'::" + baseType + "[]" + } } - // Default: treat as a plain string literal - return escapeStringValue(value) + // Fallback: treat raw as a comma-separated list of elements + inner := strings.TrimSpace(raw) + if strings.HasPrefix(inner, "{") && strings.HasSuffix(inner, "}") && len(inner) >= 2 { + inner = strings.TrimSuffix(strings.TrimPrefix(inner, "{"), "}") + } + return "'{" + inner + "}'::" + baseType + "[]" } // Format based on type, value returned unescaped