Skip to content

Commit 266b692

Browse files
authored
Fix query using table alias (#73)
1 parent 177dd60 commit 266b692

File tree

3 files changed

+62
-21
lines changed

3 files changed

+62
-21
lines changed

builder/buffer.go

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,35 +131,40 @@ func (b Buffer) escape(table, value string) string {
131131
return escapedValue.(string)
132132
}
133133

134-
var escaped_table string
134+
table, alias := extractAlias(table)
135+
var escapedTable string
135136
if table != "" {
136-
if i := strings.Index(strings.ToLower(table), " as "); i > -1 {
137-
return b.escape(table[:i], "") + " AS " + b.Quoter.ID(table[i+4:])
137+
if table != alias {
138+
if value == "" {
139+
return b.escape(table, "") + " AS " + b.Quoter.ID(alias)
140+
} else {
141+
escapedTable = b.Quoter.ID(alias)
142+
}
138143
}
139144
if b.AllowTableSchema && strings.IndexByte(table, '.') >= 0 {
140145
parts := strings.Split(table, ".")
141146
for i, part := range parts {
142147
part = strings.TrimSpace(part)
143148
parts[i] = b.Quoter.ID(part)
144149
}
145-
escaped_table = strings.Join(parts, ".")
150+
escapedTable = strings.Join(parts, ".")
146151
} else {
147-
escaped_table = b.Quoter.ID(strings.ReplaceAll(table, ".", "_"))
152+
escapedTable = b.Quoter.ID(strings.ReplaceAll(table, ".", "_"))
148153
}
149154
}
150155

151156
if value == "" {
152-
escapedValue = escaped_table
157+
escapedValue = escapedTable
153158
} else if value == "*" {
154-
escapedValue = escaped_table + ".*"
159+
escapedValue = escapedTable + ".*"
155160
} else if len(value) > 0 && value[0] == UnescapeCharacter {
156161
escapedValue = value[1:]
157162
} else if _, err := strconv.Atoi(value); err == nil {
158163
escapedValue = value
159164
} else if i := strings.Index(strings.ToLower(value), " as "); i > -1 {
160-
escapedValue = b.escape(table, value[:i]) + " AS " + b.Quoter.ID(value[i+4:])
165+
escapedValue = b.escape(alias, value[:i]) + " AS " + b.Quoter.ID(value[i+4:])
161166
} else if start, end := strings.IndexRune(value, '('), strings.IndexRune(value, ')'); start >= 0 && end >= 0 && end > start {
162-
escapedValue = value[:start+1] + b.escape(table, value[start+1:end]) + value[end:]
167+
escapedValue = value[:start+1] + b.escape(alias, value[start+1:end]) + value[end:]
163168
} else {
164169
parts := strings.Split(value, ".")
165170
for i, part := range parts {
@@ -171,7 +176,7 @@ func (b Buffer) escape(table, value string) string {
171176
}
172177
result := strings.Join(parts, ".")
173178
if len(parts) == 1 && table != "" {
174-
result = escaped_table + "." + result
179+
result = escapedTable + "." + result
175180
}
176181
escapedValue = result
177182
}
@@ -228,3 +233,13 @@ func (bf BufferFactory) Create() Buffer {
228233
BoolFalseValue: bf.BoolFalseValue,
229234
}
230235
}
236+
237+
// extract alias in the form of table as alias
238+
// if no alias, table will be returned as alias
239+
func extractAlias(input string) (string, string) {
240+
if i := strings.Index(strings.ToLower(input), " as "); i > -1 {
241+
return input[:i], input[i+4:]
242+
}
243+
244+
return input, input
245+
}

builder/query.go

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,20 +105,16 @@ func (q Query) WriteJoin(buffer *Buffer, table string, joins []rel.JoinQuery) {
105105

106106
for _, join := range joins {
107107
var (
108-
from = join.From
109-
to = join.To
108+
_, sAlias = extractAlias(table)
109+
jTable, jAlias = extractAlias(join.Table)
110+
from = join.From
111+
to = join.To
110112
)
111113

112-
jtable := join.Table
113-
// If join table has alias use that for filter conditions
114-
if i := strings.Index(strings.ToLower(jtable), " as "); i > -1 {
115-
jtable = jtable[i+4:]
116-
}
117-
118114
// TODO: move this to core functionality, and infer join condition using assoc data.
119115
if join.Arguments == nil && (join.From == "" || join.To == "") {
120-
from = table + "." + strings.TrimSuffix(join.Table, "s") + "_id"
121-
to = jtable + ".id"
116+
from = sAlias + "." + strings.TrimSuffix(jTable, "s") + "_id"
117+
to = jAlias + ".id"
122118
}
123119

124120
buffer.WriteByte(' ')
@@ -133,7 +129,7 @@ func (q Query) WriteJoin(buffer *Buffer, table string, joins []rel.JoinQuery) {
133129
buffer.WriteEscape(to)
134130
if !join.Filter.None() {
135131
buffer.WriteString(" AND ")
136-
q.Filter.Write(buffer, jtable, join.Filter, q)
132+
q.Filter.Write(buffer, join.Table, join.Filter, q)
137133
}
138134
}
139135

builder/query_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,36 @@ func TestQuery_Build(t *testing.T) {
105105
result: "SELECT `users`.* FROM `users` FOR UPDATE;",
106106
query: rel.From("users").Lock("FOR UPDATE"),
107107
},
108+
{
109+
result: "SELECT `c`.`id`,`c`.`name` FROM `contacts` AS `c`;",
110+
query: rel.Select("c.id", "c.name").From("contacts as c"),
111+
},
112+
{
113+
result: "SELECT MAX(`c`.`id`) FROM `contacts` AS `c`;",
114+
query: rel.Select("MAX(id)").From("contacts as c"),
115+
},
116+
{
117+
result: "SELECT MAX(`c`.`id`) FROM `contacts` AS `c`;",
118+
query: rel.Select("MAX(c.id)").From("contacts as c"),
119+
},
120+
{
121+
result: "SELECT MAX(`c`.`id`) AS `max_id` FROM `contacts` AS `c`;",
122+
query: rel.Select("MAX(id) as max_id").From("contacts as c"),
123+
},
124+
{
125+
result: "SELECT MAX(`c`.`id`) AS `max_id` FROM `contacts` AS `c`;",
126+
query: rel.Select("MAX(c.id) as max_id").From("contacts as c"),
127+
},
128+
{
129+
result: "SELECT `c`.`id`,`c`.`name` FROM `contacts` AS `c` JOIN `users` AS `u` ON `c`.`user_id`=`u`.`id` WHERE `u`.`active`=?;",
130+
args: []any{true},
131+
query: rel.Select("c.id", "c.name").From("contacts as c").Join("users as u").Where(rel.Eq("u.active", true)),
132+
},
133+
{
134+
result: "SELECT `c`.`id`,`c`.`name` FROM `contacts` AS `c` JOIN `users` AS `u` ON `u`.`id`=`c`.`user_id` WHERE `u`.`active`=?;",
135+
args: []any{true},
136+
query: rel.Select("c.id", "c.name").From("contacts as c").JoinOn("users as u", "u.id", "c.user_id").Where(rel.Eq("u.active", true)),
137+
},
108138
}
109139

110140
for _, test := range tests {

0 commit comments

Comments
 (0)