-
Notifications
You must be signed in to change notification settings - Fork 28
fix(postgresql): loosen requirements for getPrimaryKeyWhereClause #103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -307,11 +307,11 @@ func (d PostgresDialect) historyTable(schema string) string { | |
| return fmt.Sprintf("%s.%s", EscapeIdentifier(schema), EscapeIdentifier(d.historyTableName)) | ||
| } | ||
|
|
||
| func (d PostgresDialect) saveInsert(schema string, table string, primaryKey map[string]string, blockNum uint64) string { | ||
| func (d PostgresDialect) saveInsert(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string { | ||
| return fmt.Sprintf(`INSERT INTO %s (op,table_name,pk,block_num) values (%s,%s,%s,%d);`, | ||
| d.historyTable(schema), | ||
| escapeStringValue("I"), | ||
| escapeStringValue(table), | ||
| escapeStringValue(table.identifier), | ||
| escapeStringValue(primaryKeyToJSON(primaryKey)), | ||
| blockNum, | ||
| ) | ||
|
|
@@ -321,8 +321,12 @@ func (d PostgresDialect) saveInsert(schema string, table string, primaryKey map[ | |
| with t as (select 'default' id) | ||
| select CASE WHEN block_meta.id is null THEN 'I' ELSE 'U' END AS op, '"public"."block_meta"', 'allo', row_to_json(block_meta),10 from t left join block_meta on block_meta.id='default'; | ||
| */ | ||
| func (d PostgresDialect) saveUpsert(schema string, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { | ||
| schemaAndTable := fmt.Sprintf("%s.%s", EscapeIdentifier(schema), escapedTableName) | ||
| func (d PostgresDialect) saveUpsert(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string { | ||
| schemaAndTable := table.schemaEscaped + "." + table.nameEscaped | ||
| onClause, err := d.getPrimaryKeyWhereClauseTyped(table, primaryKey, table.nameEscaped) | ||
| if err != nil { | ||
| onClause = getPrimaryKeyWhereClause(primaryKey, table.nameEscaped) | ||
| } | ||
|
|
||
| return fmt.Sprintf(` | ||
| WITH t as (select %s) | ||
|
|
@@ -332,30 +336,34 @@ func (d PostgresDialect) saveUpsert(schema string, escapedTableName string, prim | |
| getPrimaryKeyFakeEmptyValues(primaryKey), | ||
| d.historyTable(schema), | ||
|
|
||
| getPrimaryKeyFakeEmptyValuesAssertion(primaryKey, escapedTableName), | ||
| getPrimaryKeyFakeEmptyValuesAssertion(primaryKey, table.nameEscaped), | ||
|
|
||
| escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), escapedTableName, blockNum, | ||
| EscapeIdentifier(schema), escapedTableName, | ||
| getPrimaryKeyWhereClause(primaryKey, escapedTableName), | ||
| escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), table.nameEscaped, blockNum, | ||
| table.schemaEscaped, table.nameEscaped, | ||
| onClause, | ||
| ) | ||
|
|
||
| } | ||
|
|
||
| func (d PostgresDialect) saveUpdate(schema string, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { | ||
| return d.saveRow("U", schema, escapedTableName, primaryKey, blockNum) | ||
| func (d PostgresDialect) saveUpdate(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string { | ||
| return d.saveRow("U", schema, table, primaryKey, blockNum) | ||
| } | ||
|
|
||
| func (d PostgresDialect) saveDelete(schema string, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { | ||
| return d.saveRow("D", schema, escapedTableName, primaryKey, blockNum) | ||
| func (d PostgresDialect) saveDelete(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string { | ||
| return d.saveRow("D", schema, table, primaryKey, blockNum) | ||
| } | ||
|
|
||
| func (d PostgresDialect) saveRow(op, schema, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { | ||
| schemaAndTable := fmt.Sprintf("%s.%s", EscapeIdentifier(schema), escapedTableName) | ||
| func (d PostgresDialect) saveRow(op, schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string { | ||
| schemaAndTable := table.schemaEscaped + "." + table.nameEscaped | ||
| whereClause, err := d.getPrimaryKeyWhereClauseTyped(table, primaryKey, "") | ||
| if err != nil { | ||
| whereClause = getPrimaryKeyWhereClause(primaryKey, "") | ||
| } | ||
|
Comment on lines
+358
to
+361
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this a good way of doing it, can you add a comment explaining the reasoning here, we should also log the error minimally in DEBUG mode if it's expected to log a lot of time. |
||
| return fmt.Sprintf(`INSERT INTO %s (op,table_name,pk,prev_value,block_num) SELECT %s,%s,%s,row_to_json(%s),%d FROM %s.%s WHERE %s;`, | ||
| d.historyTable(schema), | ||
| escapeStringValue(op), escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), escapedTableName, blockNum, | ||
| EscapeIdentifier(schema), escapedTableName, | ||
| getPrimaryKeyWhereClause(primaryKey, ""), | ||
| escapeStringValue(op), escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), table.nameEscaped, blockNum, | ||
| table.schemaEscaped, table.nameEscaped, | ||
| whereClause, | ||
| ) | ||
|
|
||
| } | ||
|
|
@@ -386,7 +394,7 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string, | |
| ) | ||
|
|
||
| if o.reversibleBlockNum != nil { | ||
| return d.saveInsert(schema, o.table.identifier, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil | ||
| return d.saveInsert(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil | ||
| } | ||
| return insertQuery, nil | ||
|
|
||
|
|
@@ -405,7 +413,7 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string, | |
| ) | ||
|
|
||
| if o.reversibleBlockNum != nil { | ||
| return d.saveUpsert(schema, o.table.nameEscaped, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil | ||
| return d.saveUpsert(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil | ||
| } | ||
| return insertQuery, nil | ||
|
|
||
|
|
@@ -415,7 +423,10 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string, | |
| updates[i] = fmt.Sprintf("%s=%s", columns[i], values[i]) | ||
| } | ||
|
|
||
| primaryKeySelector := getPrimaryKeyWhereClause(o.primaryKey, "") | ||
| primaryKeySelector, err := d.getPrimaryKeyWhereClauseTyped(o.table, o.primaryKey, "") | ||
| if err != nil { | ||
| return "", err | ||
| } | ||
|
Comment on lines
+426
to
+429
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this backward compatible? The question is essentially, if there is already |
||
|
|
||
| updateQuery := fmt.Sprintf("UPDATE %s SET %s WHERE %s", | ||
| o.table.identifier, | ||
|
|
@@ -424,18 +435,21 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string, | |
| ) | ||
|
|
||
| if o.reversibleBlockNum != nil { | ||
| return d.saveUpdate(schema, o.table.nameEscaped, o.primaryKey, *o.reversibleBlockNum) + updateQuery, nil | ||
| return d.saveUpdate(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + updateQuery, nil | ||
| } | ||
| return updateQuery, nil | ||
|
|
||
| case OperationTypeDelete: | ||
| primaryKeyWhereClause := getPrimaryKeyWhereClause(o.primaryKey, "") | ||
| primaryKeyWhereClause, err := d.getPrimaryKeyWhereClauseTyped(o.table, o.primaryKey, "") | ||
| if err != nil { | ||
| return "", err | ||
| } | ||
| deleteQuery := fmt.Sprintf("DELETE FROM %s WHERE %s", | ||
| o.table.identifier, | ||
| primaryKeyWhereClause, | ||
| ) | ||
| if o.reversibleBlockNum != nil { | ||
| return d.saveDelete(schema, o.table.nameEscaped, o.primaryKey, *o.reversibleBlockNum) + deleteQuery, nil | ||
| return d.saveDelete(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + deleteQuery, nil | ||
| } | ||
| return deleteQuery, nil | ||
|
|
||
|
|
@@ -466,7 +480,7 @@ func (d *PostgresDialect) prepareColValues(table *TableInfo, colValues map[strin | |
| return nil, nil, fmt.Errorf("cannot find column %q for table %q (valid columns are %q)", columnName, table.identifier, strings.Join(maps.Keys(table.columnsByName), ", ")) | ||
| } | ||
|
|
||
| normalizedValue, err := d.normalizeValueType(value, columnInfo.scanType) | ||
| normalizedValue, err := d.normalizeLiteral(table, columnName, value, "insert") | ||
| if err != nil { | ||
| return nil, nil, fmt.Errorf("getting sql value from table %s for column %q raw value %q: %w", table.identifier, columnName, value, err) | ||
| } | ||
|
|
@@ -535,6 +549,107 @@ func getPrimaryKeyWhereClause(primaryKey map[string]string, escapedTableName str | |
| return strings.Join(reg[:], " AND ") | ||
| } | ||
|
|
||
| // Build a typed WHERE clause using column types for proper literal normalization | ||
| func (d *PostgresDialect) getPrimaryKeyWhereClauseTyped(table *TableInfo, primaryKey map[string]string, escapedTableName string) (string, error) { | ||
| // Avoid any allocation if there is a single primary key | ||
| if len(primaryKey) == 1 { | ||
| for key, value := range primaryKey { | ||
| rhs, err := d.normalizeLiteral(table, key, value, "where") | ||
| if err != nil { | ||
| return "", err | ||
| } | ||
| if escapedTableName == "" { | ||
| return EscapeIdentifier(key) + " = " + rhs, nil | ||
| } | ||
| return escapedTableName + "." + EscapeIdentifier(key) + " = " + rhs, nil | ||
| } | ||
| } | ||
|
|
||
| reg := make([]string, 0, len(primaryKey)) | ||
| for key, value := range primaryKey { | ||
| rhs, err := d.normalizeLiteral(table, key, value, "where") | ||
| if err != nil { | ||
| return "", err | ||
| } | ||
| if escapedTableName == "" { | ||
| reg = append(reg, EscapeIdentifier(key)+" = "+rhs) | ||
| } else { | ||
| reg = append(reg, escapedTableName+"."+EscapeIdentifier(key)+" = "+rhs) | ||
| } | ||
| } | ||
| sort.Strings(reg) | ||
| return strings.Join(reg, " AND "), nil | ||
| } | ||
|
|
||
| // normalizeLiteral centralizes literal formatting based on column database type name and context (insert/where) | ||
| // Returns a SQL RHS expression ready to be embedded in a statement | ||
| func (d *PostgresDialect) normalizeLiteral(table *TableInfo, columnName string, rawValue string, context string) (string, error) { | ||
| colInfo, found := table.columnsByName[columnName] | ||
| if !found { | ||
| return "", fmt.Errorf("cannot find column %q for table %q (valid columns are %q)", columnName, table.identifier, strings.Join(maps.Keys(table.columnsByName), ", ")) | ||
| } | ||
|
|
||
| dt := strings.ToLower(colInfo.databaseTypeName) | ||
| trimmed := strings.TrimSpace(rawValue) | ||
|
|
||
| // Array handling using databaseTypeName | ||
| if strings.HasSuffix(dt, "[]") { | ||
| baseType := strings.TrimSuffix(dt, "[]") | ||
| return canonicalizeArrayLiteral(baseType, trimmed), nil | ||
| } | ||
|
|
||
| // Scalar handling as today using scanType | ||
| return d.normalizeValueType(trimmed, colInfo.scanType) | ||
| } | ||
|
|
||
| // canonicalizeArrayLiteral emits a curly literal with explicit cast to the column's base type array | ||
| // Examples: | ||
| // - empty -> '{ }'::base[] | ||
| // - 'ARRAY[1,2]' -> '{1,2}'::base[] | ||
| // - '{1,2}' or "'{1,2}'::text[]" -> '{1,2}'::base[] | ||
| func canonicalizeArrayLiteral(baseType string, raw string) string { | ||
| // Empty array detection (accept {}, { }, [], ARRAY[]) | ||
| switch raw { | ||
| case "{}", "{ }", "[]", "ARRAY[]", "array[]", "ARRAY []", "array []", "": | ||
| return "'{ }'::" + baseType + "[]" | ||
| } | ||
|
|
||
| upper := strings.ToUpper(raw) | ||
| if strings.HasPrefix(upper, "ARRAY[") { | ||
| // Extract elements inside ARRAY[...] | ||
| end := strings.LastIndex(raw, "]") | ||
| inner := "" | ||
| if end > 6 { // len("ARRAY[") == 6 | ||
| inner = strings.TrimSpace(raw[6:end]) | ||
| } | ||
| return "'{" + inner + "}'::" + baseType + "[]" | ||
| } | ||
|
|
||
| // Strip surrounding single quotes if present | ||
| if strings.HasPrefix(raw, "'") && strings.HasSuffix(raw, "'") && len(raw) >= 2 { | ||
| raw = strings.TrimSuffix(strings.TrimPrefix(raw, "'"), "'") | ||
| } | ||
| // Remove any existing cast suffix like }::type[] | ||
| if idx := strings.LastIndex(raw, "}::"); idx != -1 { | ||
| raw = raw[:idx+1] | ||
| } | ||
|
|
||
| // If contains curly braces, extract content and rebuild | ||
| if l := strings.Index(raw, "{"); l != -1 { | ||
| if r := strings.LastIndex(raw, "}"); r != -1 && r > l { | ||
| inner := strings.TrimSpace(raw[l+1 : r]) | ||
| return "'{" + inner + "}'::" + baseType + "[]" | ||
| } | ||
| } | ||
|
|
||
| // Fallback: treat raw as a comma-separated list of elements | ||
| inner := strings.TrimSpace(raw) | ||
| if strings.HasPrefix(inner, "{") && strings.HasSuffix(inner, "}") && len(inner) >= 2 { | ||
| inner = strings.TrimSuffix(strings.TrimPrefix(inner, "{"), "}") | ||
| } | ||
| return "'{" + inner + "}'::" + baseType + "[]" | ||
| } | ||
|
|
||
| // Format based on type, value returned unescaped | ||
| func (d *PostgresDialect) normalizeValueType(value string, valueType reflect.Type) (string, error) { | ||
| switch valueType.Kind() { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here, add comment as the why it's correct and log error.