Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/dialect/mssql/mssql-query-compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ export class MssqlQueryCompiler extends DefaultQueryCompiler {
return `@${this.numParameters}`
}

protected override getExistingParameterPlaceholder(
parameter: unknown,
): string | undefined {
const parameterIndex = this.getExistingParameterIndex(parameter)
if (parameterIndex === undefined) {
return undefined
}
return `@${parameterIndex + 1}`
}

protected override visitOffset(node: OffsetNode): void {
super.visitOffset(node)
this.append(' rows')
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/mysql/mysql-query-compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ export class MysqlQueryCompiler extends DefaultQueryCompiler {
return '?'
}

protected override getExistingParameterPlaceholder(_parameter: unknown): undefined {
return undefined
}

protected override getLeftExplainOptionsWrapper(): string {
return ''
}
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/sqlite/sqlite-query-compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ export class SqliteQueryCompiler extends DefaultQueryCompiler {
return '?'
}

protected override getExistingParameterPlaceholder(_parameter: unknown): undefined {
return undefined
}

protected override getLeftExplainOptionsWrapper(): string {
return ''
}
Expand Down
26 changes: 24 additions & 2 deletions src/query-compiler/default-query-compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1762,8 +1762,12 @@ export class DefaultQueryCompiler
}

protected appendValue(parameter: unknown): void {
this.addParameter(parameter)
this.append(this.getCurrentParameterPlaceholder())
let parameterPlaceholder = this.getExistingParameterPlaceholder(parameter)
if (parameterPlaceholder === undefined) {
this.addParameter(parameter)
parameterPlaceholder = this.getCurrentParameterPlaceholder()
}
this.append(parameterPlaceholder)
}

protected getLeftIdentifierWrapper(): string {
Expand All @@ -1774,6 +1778,24 @@ export class DefaultQueryCompiler
return '"'
}

protected getExistingParameterIndex(parameter: unknown): number | undefined {
const parameterIndex = this.#parameters.indexOf(parameter)
if (parameterIndex < 0) {
return undefined
}
return parameterIndex
}

protected getExistingParameterPlaceholder(
parameter: unknown,
): string | undefined {
const parameterIndex = this.getExistingParameterIndex(parameter)
if (parameterIndex === undefined) {
return undefined
}
return '$' + (parameterIndex + 1)
}

protected getCurrentParameterPlaceholder(): string {
return '$' + this.numParameters
}
Expand Down
30 changes: 30 additions & 0 deletions test/node/src/raw-sql.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,36 @@ for (const dialect of DIALECTS) {
await query.execute()
})

it('substitutions should reuse identical parameters in supported dialects', async () => {
const query = ctx.db
.selectFrom('person')
.selectAll()
.where(
sql<boolean>`first_name between ${'A'} and ${'B'} and first_name <> ${'A'}`,
)

testSql(query, dialect, {
postgres: {
sql: 'select * from "person" where first_name between $1 and $2 and first_name <> $1',
parameters: ['A', 'B'],
},
mysql: {
sql: 'select * from `person` where first_name between ? and ? and first_name <> ?',
parameters: ['A', 'B', 'A'],
},
mssql: {
sql: 'select * from "person" where first_name between @1 and @2 and first_name <> @1',
parameters: ['A', 'B'],
},
sqlite: {
sql: 'select * from "person" where first_name between ? and ? and first_name <> ?',
parameters: ['A', 'B', 'A'],
},
})

await query.execute()
})

it('substitutions should accept queries', async () => {
const compiler = new DefaultQueryCompiler()

Expand Down