Skip to content

Commit 5942e1e

Browse files
committed
Add support Upsert to model query
1 parent 4d1628e commit 5942e1e

File tree

3 files changed

+137
-2
lines changed

3 files changed

+137
-2
lines changed

builder_mysql.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,17 @@ func (b *MysqlBuilder) Upsert(table string, cols Params, constraints ...string)
7373

7474
names := []string{}
7575
for name := range cols {
76-
names = append(names, name)
76+
found := false
77+
for _, pkName := range constraints {
78+
if pkName == name {
79+
found = true
80+
}
81+
}
82+
if !found {
83+
names = append(names, name)
84+
}
7785
}
86+
7887
sort.Strings(names)
7988

8089
lines := []string{}
@@ -91,7 +100,7 @@ func (b *MysqlBuilder) Upsert(table string, cols Params, constraints ...string)
91100

92101
q.sql += " ON DUPLICATE KEY UPDATE " + strings.Join(lines, ", ")
93102

94-
return q
103+
return b.NewQuery(q.sql).Bind(q.params)
95104
}
96105

97106
var mysqlColumnRegexp = regexp.MustCompile("(?m)^\\s*[`\"](.*?)[`\"]\\s+(.*?),?$")

model_query.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,27 @@ func (q *ModelQuery) Delete() error {
172172
_, err := q.builder.Delete(q.model.tableName, HashExp(pk)).WithContext(q.ctx).Execute()
173173
return err
174174
}
175+
176+
// Upsert creates a Query that represents an UPSERT SQL statement.
177+
// Upsert inserts a row into the table if the primary key or unique index is not found.
178+
// Otherwise it will update the row with the new values.
179+
// The keys of cols are the column names, while the values of cols are the corresponding column
180+
// values to be inserted.
181+
func (q *ModelQuery) Upsert(attrs ...string) error {
182+
if q.lastError != nil {
183+
return q.lastError
184+
}
185+
pk := q.model.pk()
186+
if len(pk) == 0 {
187+
return MissingPKError
188+
}
189+
var pks []string
190+
191+
cols := q.model.columns(attrs, q.exclude)
192+
for name := range pk {
193+
cols[name] = pk[name]
194+
pks = append(pks, name)
195+
}
196+
_, err := q.builder.Upsert(q.model.tableName, Params(cols), pks...).WithContext(q.ctx).Execute()
197+
return err
198+
}

model_query_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,105 @@ func TestModelQuery_Delete(t *testing.T) {
236236
var a int
237237
assert.NotNil(t, db.Model(&a).Delete())
238238
}
239+
240+
func TestModelQuery_Upsert(t *testing.T) {
241+
db := getPreparedDB()
242+
defer db.Close()
243+
244+
id := 2
245+
name := "test"
246+
email := "[email protected]"
247+
{
248+
// updating normally
249+
customer := Customer{
250+
ID: id,
251+
Name: name,
252+
Email: email,
253+
}
254+
err := db.Model(&customer).Upsert()
255+
if assert.Nil(t, err) {
256+
var c Customer
257+
db.Select().From("customer").Where(HashExp{"ID": id}).One(&c)
258+
assert.Equal(t, name, c.Name)
259+
assert.Equal(t, email, c.Email)
260+
assert.Equal(t, 0, c.Status)
261+
}
262+
}
263+
264+
{
265+
// updating without primary keys
266+
item2 := Item{
267+
Name: name,
268+
}
269+
err := db.Model(&item2).Upsert()
270+
assert.Equal(t, MissingPKError, err)
271+
}
272+
273+
{
274+
// updating all fields
275+
customer := CustomerPtr{
276+
ID: &id,
277+
Name: name,
278+
Email: &email,
279+
}
280+
err := db.Model(&customer).Upsert()
281+
if assert.Nil(t, err) {
282+
assert.Equal(t, id, *customer.ID)
283+
var c CustomerPtr
284+
db.Select().From("customer").Where(HashExp{"ID": id}).One(&c)
285+
assert.Equal(t, name, c.Name)
286+
if assert.NotNil(t, c.Email) {
287+
assert.Equal(t, email, *c.Email)
288+
}
289+
assert.Nil(t, c.Status)
290+
}
291+
}
292+
293+
{
294+
// updating selected fields only
295+
id = 3
296+
customer := CustomerPtr{
297+
ID: &id,
298+
Name: name,
299+
Email: &email,
300+
}
301+
err := db.Model(&customer).Upsert("Name", "Email")
302+
if assert.Nil(t, err) {
303+
assert.Equal(t, id, *customer.ID)
304+
var c CustomerPtr
305+
db.Select().From("customer").Where(HashExp{"ID": id}).One(&c)
306+
assert.Equal(t, name, c.Name)
307+
if assert.NotNil(t, c.Email) {
308+
assert.Equal(t, email, *c.Email)
309+
}
310+
if assert.NotNil(t, c.Status) {
311+
assert.Equal(t, 2, *c.Status)
312+
}
313+
}
314+
}
315+
316+
{
317+
// inserting normally
318+
customer := Customer{
319+
ID: 5,
320+
Name: name,
321+
Email: email,
322+
}
323+
err := db.Model(&customer).Upsert()
324+
if assert.Nil(t, err) {
325+
assert.Equal(t, 5, customer.ID)
326+
var c Customer
327+
db.Select().From("customer").Where(HashExp{"ID": 5}).One(&c)
328+
assert.Equal(t, name, c.Name)
329+
assert.Equal(t, email, c.Email)
330+
assert.Equal(t, 0, c.Status)
331+
assert.False(t, c.Address.Valid)
332+
}
333+
}
334+
335+
{
336+
// updating non-struct
337+
var a int
338+
assert.NotNil(t, db.Model(&a).Upsert())
339+
}
340+
}

0 commit comments

Comments
 (0)