@@ -17,7 +17,7 @@ limitations under the License.
17
17
package data
18
18
19
19
import (
20
- "bytes "
20
+ "bufio "
21
21
"fmt"
22
22
"io"
23
23
"net/http"
@@ -27,118 +27,150 @@ import (
27
27
28
28
type SQLScriptLoader struct {}
29
29
30
- func readData (url string ) ([]byte , error ) {
31
- if strings .HasPrefix (url , "http" ) {
32
- client := http.Client {}
33
- res , err := client .Get (url )
30
+ func (SQLScriptLoader ) Load (fileName string ) ([]Query , error ) {
31
+ loader := SQLScriptLoader {}.Loadit (fileName )
32
+ return makeSlice (loader )
33
+ }
34
+
35
+ func (SQLScriptLoader ) Loadit (fileName string ) IteratorLoader {
36
+ var fd * os.File
37
+ var err error
38
+
39
+ if strings .HasPrefix (fileName , "http" ) {
40
+ // Read from URL
41
+ data , err := readData (fileName )
34
42
if err != nil {
35
- return nil , err
43
+ return & errLoader { err }
36
44
}
37
- if res .StatusCode != http .StatusOK {
38
- return nil , fmt .Errorf ("failed to get data from %s, status code %d" , url , res .StatusCode )
45
+ scanner := bufio .NewScanner (strings .NewReader (string (data )))
46
+ return & sqlScriptReaderState {
47
+ logReaderState : logReaderState {
48
+ scanner : scanner ,
49
+ },
50
+ newStmt : true ,
39
51
}
40
- defer res .Body .Close ()
41
- return io .ReadAll (res .Body )
42
52
}
43
- return os .ReadFile (url )
44
- }
45
53
46
- func ( SQLScriptLoader ) Load ( url string ) ([] Query , error ) {
47
- data , err := readData ( url )
54
+ // Read from file
55
+ fd , err = os . OpenFile ( fileName , os . O_RDONLY , 0 )
48
56
if err != nil {
49
- return nil , err
57
+ return & errLoader { err }
50
58
}
51
- seps := bytes .Split (data , []byte ("\n " ))
52
- queries := make ([]Query , 0 , len (seps ))
53
- newStmt := true
54
- for i , v := range seps {
55
- v := bytes .TrimSpace (v )
56
- s := string (v )
57
- // Skip comments and empty lines
58
- switch {
59
- case strings .HasPrefix (s , "#" ):
60
- newStmt = true
61
- continue
62
- case strings .HasPrefix (s , "--" ):
63
- queries = append (queries , Query {Query : s , Line : i + 1 })
64
- newStmt = true
65
- continue
66
- case len (s ) == 0 :
67
- continue
68
- }
69
59
70
- if newStmt {
71
- queries = append (queries , Query {Query : s , Line : i + 1 })
72
- } else {
73
- lastQuery := queries [len (queries )- 1 ]
74
- lastQuery .Query = fmt .Sprintf ("%s\n %s" , lastQuery .Query , s )
75
- queries [len (queries )- 1 ] = lastQuery
76
- }
77
-
78
- // Treat new line as a new statement if line ends with ';'
79
- newStmt = strings .HasSuffix (s , ";" )
60
+ scanner := bufio .NewScanner (fd )
61
+ return & sqlScriptReaderState {
62
+ logReaderState : logReaderState {
63
+ fd : fd ,
64
+ scanner : scanner ,
65
+ lineNumber : 0 ,
66
+ },
67
+ newStmt : true ,
80
68
}
69
+ }
81
70
82
- // Process queries directly without calling ParseQueries
83
- finalQueries := make ([]Query , 0 , len (queries ))
84
- for _ , rs := range queries {
85
- q , err := parseQuery (rs )
86
- if err != nil {
87
- return nil , err
88
- }
89
- if q != nil {
90
- finalQueries = append (finalQueries , * q )
91
- }
92
- }
93
- return finalQueries , nil
71
+ type sqlScriptReaderState struct {
72
+ logReaderState
73
+ prevQuery string
74
+ queryStart int
75
+ newStmt bool
94
76
}
95
77
96
- // Helper function to parse individual queries
97
- func parseQuery (rs Query ) (* Query , error ) {
98
- realS := rs .Query
99
- s := rs .Query
100
- q := Query {Line : rs .Line , Type : Unknown }
78
+ func (s * sqlScriptReaderState ) Next () (Query , bool ) {
79
+ s .mu .Lock ()
80
+ defer s .mu .Unlock ()
101
81
102
- if len ( s ) < 3 {
103
- return nil , nil
82
+ if s . closed {
83
+ return Query {}, false
104
84
}
105
85
106
- switch {
107
- case strings .HasPrefix (s , "#" ):
108
- q .Type = Comment
109
- return & q , nil
110
- case strings .HasPrefix (s , "--" ):
111
- q .Type = CommentWithCommand
112
- if len (s ) > 2 && s [2 ] == ' ' {
113
- s = s [3 :]
86
+ for s .scanner .Scan () {
87
+ line := s .scanner .Text ()
88
+ line = strings .TrimSpace (line )
89
+ s .lineNumber ++
90
+
91
+ // Skip empty lines and comments
92
+ if len (line ) == 0 {
93
+ continue
94
+ }
95
+ switch {
96
+ case strings .HasPrefix (line , "#" ):
97
+ s .newStmt = true
98
+ continue
99
+ case strings .HasPrefix (line , "--" ):
100
+ // Return previous query before processing the comment
101
+ if s .prevQuery != "" {
102
+ query := Query {
103
+ Query : s .prevQuery ,
104
+ Line : s .queryStart ,
105
+ Type : QueryT ,
106
+ }
107
+ s .prevQuery = ""
108
+ s .queryStart = 0
109
+ s .newStmt = true
110
+ // Store current comment line as new query
111
+ s .prevQuery = line
112
+ s .queryStart = s .lineNumber
113
+ return query , true
114
+ } else {
115
+ s .prevQuery = line
116
+ s .queryStart = s .lineNumber
117
+ s .newStmt = true
118
+ continue
119
+ }
120
+ }
121
+
122
+ if s .newStmt {
123
+ s .prevQuery = line
124
+ s .queryStart = s .lineNumber
114
125
} else {
115
- s = s [ 2 :]
126
+ s . prevQuery += " \n " + line
116
127
}
117
- case s [0 ] == '\n' :
118
- q .Type = EmptyLine
119
- return & q , nil
120
- }
121
128
122
- i := findFirstWord (s )
123
- if i > 0 {
124
- q .FirstWord = s [:i ]
129
+ // Check if the line ends with a semicolon
130
+ if strings .HasSuffix (line , ";" ) {
131
+ query := Query {
132
+ Query : s .prevQuery ,
133
+ Line : s .queryStart ,
134
+ Type : QueryT ,
135
+ }
136
+ s .prevQuery = ""
137
+ s .queryStart = 0
138
+ s .newStmt = true
139
+ return query , true
140
+ } else {
141
+ s .newStmt = false
142
+ }
125
143
}
126
- q .Query = s [i :]
127
144
128
- if q .Type == Unknown || q .Type == CommentWithCommand {
129
- if err := q .getQueryType (realS ); err != nil {
130
- return nil , err
145
+ s .closed = true
146
+
147
+ // Return the last query if we have one
148
+ if s .prevQuery != "" {
149
+ query := Query {
150
+ Query : s .prevQuery ,
151
+ Line : s .queryStart ,
152
+ Type : QueryT ,
131
153
}
154
+ s .prevQuery = ""
155
+ return query , true
132
156
}
133
157
134
- return & q , nil
158
+ s .err = s .scanner .Err ()
159
+ return Query {}, false
135
160
}
136
161
137
- // findFirstWord calculates the length of the first word in the string
138
- func findFirstWord (s string ) int {
139
- i := 0
140
- for i < len (s ) && s [i ] != '(' && s [i ] != ' ' && s [i ] != ';' && s [i ] != '\n' {
141
- i ++
162
+ func readData (url string ) ([]byte , error ) {
163
+ if strings .HasPrefix (url , "http" ) {
164
+ client := http.Client {}
165
+ res , err := client .Get (url )
166
+ if err != nil {
167
+ return nil , err
168
+ }
169
+ if res .StatusCode != http .StatusOK {
170
+ return nil , fmt .Errorf ("failed to get data from %s, status code %d" , url , res .StatusCode )
171
+ }
172
+ defer res .Body .Close ()
173
+ return io .ReadAll (res .Body )
142
174
}
143
- return i
175
+ return os . ReadFile ( url )
144
176
}
0 commit comments