Skip to content

Commit

Permalink
fix: AfterQuery using safer right trim while clearing from clause's j…
Browse files Browse the repository at this point in the history
…oin added as part of #7027 (#7153)

Co-authored-by: Abhijeet Bhowmik <[email protected]>
  • Loading branch information
bhowmik-abhijeet and Abhijeet Bhowmik authored Aug 22, 2024
1 parent 0dbfda5 commit 0daaf17
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 1 deletion.
2 changes: 1 addition & 1 deletion callbacks/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ func AfterQuery(db *gorm.DB) {
// clear the joins after query because preload need it
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause := db.Statement.Clauses["FROM"]
fromClause.Expression = clause.From{Tables: v.Tables, Joins: v.Joins[:len(v.Joins)-len(db.Statement.Joins)]} // keep the original From Joins
fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins
db.Statement.Clauses["FROM"] = fromClause
}
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
Expand Down
11 changes: 11 additions & 0 deletions utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,14 @@ func SplitNestedRelationName(name string) []string {
func JoinNestedRelationNames(relationNames []string) string {
return strings.Join(relationNames, nestedRelationSplit)
}

// RTrimSlice Right trims the given slice by given length
func RTrimSlice[T any](v []T, trimLen int) []T {
if trimLen >= len(v) { // trimLen greater than slice len means fully sliced
return v[:0]
}
if trimLen < 0 { // negative trimLen is ignored
return v[:]
}
return v[:len(v)-trimLen]
}
61 changes: 61 additions & 0 deletions utils/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,64 @@ func TestToString(t *testing.T) {
})
}
}

func TestRTrimSlice(t *testing.T) {
tests := []struct {
name string
input []int
trimLen int
expected []int
}{
{
name: "Trim two elements from end",
input: []int{1, 2, 3, 4, 5},
trimLen: 2,
expected: []int{1, 2, 3},
},
{
name: "Trim entire slice",
input: []int{1, 2, 3},
trimLen: 3,
expected: []int{},
},
{
name: "Trim length greater than slice length",
input: []int{1, 2, 3},
trimLen: 5,
expected: []int{},
},
{
name: "Zero trim length",
input: []int{1, 2, 3},
trimLen: 0,
expected: []int{1, 2, 3},
},
{
name: "Trim one element from end",
input: []int{1, 2, 3},
trimLen: 1,
expected: []int{1, 2},
},
{
name: "Empty slice",
input: []int{},
trimLen: 2,
expected: []int{},
},
{
name: "Negative trim length (should be treated as zero)",
input: []int{1, 2, 3},
trimLen: -1,
expected: []int{1, 2, 3},
},
}

for _, testcase := range tests {
t.Run(testcase.name, func(t *testing.T) {
result := RTrimSlice(testcase.input, testcase.trimLen)
if !AssertEqual(result, testcase.expected) {
t.Errorf("RTrimSlice(%v, %d) = %v; want %v", testcase.input, testcase.trimLen, result, testcase.expected)
}
})
}
}

0 comments on commit 0daaf17

Please sign in to comment.