Skip to content

Commit ffe6c6b

Browse files
authored
Merge pull request #1 from tomtwinkle/feat/extra-suffix-insert-query
Supports Insert Query with Extra Suffix
2 parents 6d0425b + 0055d49 commit ffe6c6b

File tree

11 files changed

+563
-85
lines changed

11 files changed

+563
-85
lines changed

.github/CODEOWNERS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
* @tomtwinkle

.github/dependabot.yaml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
---
2+
version: 2
3+
updates:
4+
- package-ecosystem: github-actions
5+
directory: "/"
6+
registries:
7+
- github-andpad
8+
schedule:
9+
interval: "daily"
10+
time: "08:00"
11+
timezone: "Asia/Tokyo"
12+
groups:
13+
dependencies:
14+
patterns:
15+
- "*"
16+
open-pull-requests-limit: 10
17+
- package-ecosystem: "gomod"
18+
directory: "/"
19+
schedule:
20+
interval: "daily"
21+
time: "08:00"
22+
timezone: "Asia/Tokyo"
23+
groups:
24+
dependencies:
25+
patterns:
26+
- "*"
27+
open-pull-requests-limit: 10

.github/workflows/test.yaml

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
name: Test
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
pull_request:
8+
types:
9+
- opened
10+
- synchronize
11+
- reopened
12+
13+
jobs:
14+
test:
15+
strategy:
16+
matrix:
17+
go-version: [ 1.24.x ]
18+
os: [ ubuntu-22.04 ]
19+
runs-on: ${{ matrix.os }}
20+
timeout-minutes: 5
21+
steps:
22+
- name: Install Go
23+
uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0
24+
with:
25+
go-version: ${{ matrix.go-version }}
26+
27+
- name: Checkout code
28+
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
29+
30+
- name: Go Module Download
31+
working-directory: ${{ env.testdir }}
32+
run: |
33+
go install gotest.tools/gotestsum@latest
34+
35+
- name: Test
36+
working-directory: ${{ env.testdir }}
37+
timeout-minutes: 3
38+
run: |
39+
gotestsum -- -race ./...

go.mod

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,14 @@ module github.com/tomtwinkle/sqlc-plugin-bulk-go
22

33
go 1.24
44

5-
require github.com/sqlc-dev/plugin-sdk-go v1.23.0
5+
require (
6+
github.com/sqlc-dev/plugin-sdk-go v1.23.0
7+
gotest.tools/v3 v3.5.2
8+
)
69

710
require (
11+
github.com/google/go-cmp v0.6.0 // indirect
12+
github.com/sqlc-dev/sqlc-gen-go v1.5.0 // indirect
813
golang.org/x/net v0.40.0 // indirect
914
golang.org/x/sys v0.33.0 // indirect
1015
golang.org/x/text v0.25.0 // indirect

go.sum

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
1010
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
1111
github.com/sqlc-dev/plugin-sdk-go v1.23.0 h1:iSeJhnXPlbDXlbzUEebw/DxsGzE9rdDJArl8Hvt0RMM=
1212
github.com/sqlc-dev/plugin-sdk-go v1.23.0/go.mod h1:I1r4THOfyETD+LI2gogN2LX8wCjwUZrgy/NU4In3llA=
13+
github.com/sqlc-dev/sqlc-gen-go v1.5.0 h1:nznEyqQ/Y0puwtahhDofxGwni2Dz+tSE6CWgiZpcPc8=
14+
github.com/sqlc-dev/sqlc-gen-go v1.5.0/go.mod h1:e2WgEv8ZGycpVM57hJNjq4/IpClsCITjv9Nye6AUyjQ=
1315
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
1416
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
1517
go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY=
@@ -34,3 +36,5 @@ google.golang.org/grpc v1.72.2 h1:TdbGzwb82ty4OusHWepvFWGLgIbNo1/SUynEN0ssqv8=
3436
google.golang.org/grpc v1.72.2/go.mod h1:wH5Aktxcg25y1I3w7H69nHfXdOG3UiadoBtjh3izSDM=
3537
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
3638
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
39+
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
40+
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=

main.go

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@ import (
88
"github.com/sqlc-dev/plugin-sdk-go/plugin"
99
)
1010

11-
const generateFileName = "bulk.sql.go"
11+
const (
12+
generateFileName = "bulk.sql.go"
13+
14+
sourceTemplateFuncPath = "templates/template.go"
15+
sourceTemplateFunc1 = "extractFieldValues"
16+
sourceTemplateFunc2 = "buildBulkInsertQuery"
17+
)
1218

1319
func main() {
1420
codegen.Run(Generate)
@@ -37,14 +43,31 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
3743
func generate(
3844
ctx context.Context, req *plugin.GenerateRequest, opts *Options, structs BulkInserts,
3945
) (*plugin.GenerateResponse, error) {
46+
extractFieldValuesFn, err := parseGoCode(sourceTemplateFuncPath, sourceTemplateFunc1)
47+
if err != nil {
48+
return nil, fmt.Errorf("failed to parse function %s: %w", sourceTemplateFunc1, err)
49+
}
50+
buildBulkInsertQueryFn, err := parseGoCode(sourceTemplateFuncPath, sourceTemplateFunc2)
51+
if err != nil {
52+
return nil, fmt.Errorf("failed to parse function %s: %w", sourceTemplateFunc2, err)
53+
}
54+
4055
tmpl := struct {
41-
Package string
42-
SqlcVersion string
43-
BulkInsert []BulkInsert
56+
Package string
57+
SqlcVersion string
58+
BulkInsert []BulkInsert
59+
ExtractFnName string
60+
ExtractFn string
61+
BuildFnName string
62+
BuildFn string
4463
}{
45-
Package: opts.Package,
46-
SqlcVersion: req.GetSqlcVersion(),
47-
BulkInsert: structs,
64+
Package: opts.Package,
65+
SqlcVersion: req.GetSqlcVersion(),
66+
BulkInsert: structs,
67+
ExtractFnName: sourceTemplateFunc1,
68+
ExtractFn: string(extractFieldValuesFn),
69+
BuildFnName: sourceTemplateFunc2,
70+
BuildFn: string(buildBulkInsertQueryFn),
4871
}
4972

5073
code, err := executeTemplate(ctx, "bulkInsertFile", tmpl)

main_test.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package main_test
2+
3+
import (
4+
"bytes"
5+
"go/ast"
6+
"go/importer"
7+
"go/parser"
8+
"go/token"
9+
"go/types"
10+
"strings"
11+
"testing"
12+
13+
"github.com/sqlc-dev/plugin-sdk-go/plugin"
14+
sqlcpluginbulkgo "github.com/tomtwinkle/sqlc-plugin-bulk-go"
15+
"gotest.tools/v3/assert"
16+
)
17+
18+
func TestGenerate(t *testing.T) {
19+
t.Parallel()
20+
21+
type Args struct {
22+
req *plugin.GenerateRequest
23+
}
24+
type Expected struct {
25+
err error
26+
}
27+
28+
tests := map[string]struct {
29+
arrange func(*testing.T) (Args, Expected)
30+
}{
31+
"valid:Normal INSERT Query": {
32+
arrange: func(t *testing.T) (Args, Expected) {
33+
req := &plugin.GenerateRequest{
34+
SqlcVersion: "1.0.0",
35+
PluginOptions: []byte(`{"package": "sqlc"}`),
36+
Queries: []*plugin.Query{
37+
{
38+
Name: "InsertUser",
39+
Text: "INSERT INTO users (id, name) VALUES (?, ?)",
40+
Params: []*plugin.Parameter{
41+
{Column: &plugin.Column{Name: "id"}},
42+
{Column: &plugin.Column{Name: "name"}},
43+
},
44+
},
45+
},
46+
}
47+
return Args{req: req}, Expected{err: nil}
48+
},
49+
},
50+
"valid:Extra suffix INSERT Query": {
51+
arrange: func(t *testing.T) (Args, Expected) {
52+
req := &plugin.GenerateRequest{
53+
SqlcVersion: "1.0.0",
54+
PluginOptions: []byte(`{"package": "sqlc"}`),
55+
Queries: []*plugin.Query{
56+
{
57+
Name: "InsertUser",
58+
Text: "INSERT INTO users (id, name) VALUES (?, ?) ON DUPLICATE KEY UPDATE id = id",
59+
Params: []*plugin.Parameter{
60+
{Column: &plugin.Column{Name: "id"}},
61+
{Column: &plugin.Column{Name: "name"}},
62+
},
63+
},
64+
},
65+
}
66+
return Args{req: req}, Expected{err: nil}
67+
},
68+
},
69+
}
70+
71+
for name, tc := range tests {
72+
t.Run(name, func(t *testing.T) {
73+
t.Parallel()
74+
args, want := tc.arrange(t)
75+
76+
got, err := sqlcpluginbulkgo.Generate(t.Context(), args.req)
77+
if want.err != nil {
78+
assert.ErrorContains(t, err, want.err.Error())
79+
return
80+
}
81+
assert.NilError(t, err)
82+
assert.Assert(t, len(got.Files) > 0, "Expected at least one generated file, got none")
83+
84+
// TODO: Need the code generated by sqlc-gen-go
85+
// assertGeneratedCodeIsValid(t, got.Files)
86+
t.Log(string(bytes.Replace(got.Files[0].Contents, []byte("\\n"), []byte("\n"), -1)))
87+
})
88+
}
89+
}
90+
91+
func assertGeneratedCodeIsValid(t *testing.T, files []*plugin.File) {
92+
t.Helper()
93+
94+
fset := token.NewFileSet()
95+
var parsedFiles []*ast.File
96+
97+
for _, file := range files {
98+
if !strings.HasSuffix(file.Name, ".go") {
99+
continue
100+
}
101+
102+
node, err := parser.ParseFile(fset, file.Name, file.Contents, parser.ParseComments)
103+
if err != nil {
104+
t.Fatalf("Failed to parse file %s: %v", file.Name, err)
105+
}
106+
parsedFiles = append(parsedFiles, node)
107+
}
108+
if len(parsedFiles) == 0 {
109+
return
110+
}
111+
conf := types.Config{
112+
Importer: importer.Default(),
113+
Error: func(err error) {
114+
t.Fatalf("Failed to type check generated code: %v", err)
115+
},
116+
}
117+
pkgPath := parsedFiles[0].Name.Name
118+
_, err := conf.Check(pkgPath, fset, parsedFiles, nil)
119+
if err != nil {
120+
t.Fatalf("Type checking failed for generated code: %v", err)
121+
}
122+
}

parser.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"go/ast"
6+
"go/parser"
7+
"go/token"
8+
"os"
9+
)
10+
11+
// parseGoCode parses a Go source file and extracts the code of a specific function by its name.
12+
// It returns the function code as a byte slice or an error if the function is not found.
13+
func parseGoCode(sourceFile string, targetFuncName string) ([]byte, error) {
14+
srcBytes, err := os.ReadFile(sourceFile)
15+
if err != nil {
16+
return nil, err
17+
}
18+
19+
// Parsing files and building ASTs with go/parser
20+
fset := token.NewFileSet()
21+
node, err := parser.ParseFile(fset, sourceFile, srcBytes, parser.ParseComments)
22+
if err != nil {
23+
return nil, err
24+
}
25+
26+
var funcCode []byte
27+
// Scanning AST top-level declarations
28+
ast.Inspect(node, func(n ast.Node) bool {
29+
// Find function declarations (FuncDecl)
30+
fn, ok := n.(*ast.FuncDecl)
31+
if !ok {
32+
return true // If it is not a function declaration, continue the search.
33+
}
34+
35+
// Check if it matches the desired function name
36+
if fn.Name.Name == targetFuncName {
37+
// Get the start and end positions of the function
38+
start := fset.Position(fn.Pos()).Offset
39+
end := fset.Position(fn.End()).Offset
40+
41+
// Cut out the function part as a string from the original source code
42+
funcCode = srcBytes[start:end]
43+
return false // Finish the search because the desired function has been found.
44+
}
45+
return true
46+
})
47+
48+
if len(funcCode) == 0 {
49+
return nil, fmt.Errorf("function '%s' not found", targetFuncName)
50+
}
51+
return funcCode, nil
52+
}

0 commit comments

Comments
 (0)