1package sqlparser
2
3import (
4 "bufio"
5 "bytes"
6 "errors"
7 "fmt"
8 "io"
9 "log"
10 "os"
11 "strings"
12 "sync"
13
14 "github.com/mfridman/interpolate"
15)
16
// Direction names which half of a migration is being handled: the "up" half
// or the "down" half.
type Direction string

// The two directions known to the parser.
const (
	DirectionUp   Direction = "up"
	DirectionDown Direction = "down"
)

// FromBool maps a boolean onto a Direction: true yields DirectionUp and false
// yields DirectionDown.
func FromBool(b bool) Direction {
	dir := DirectionDown
	if b {
		dir = DirectionUp
	}
	return dir
}

// String returns the underlying string value of d.
func (d Direction) String() string { return string(d) }

// ToBool reports whether d is DirectionUp. Every other value, including
// values that are neither of the declared constants, yields false.
func (d Direction) ToBool() bool {
	switch d {
	case DirectionUp:
		return true
	default:
		return false
	}
}
38
// parserState is the position of the migration parser within a SQL file.
type parserState int

// Parser states. The numeric values are what the "%d" verbs in debug output
// and error messages print.
const (
	start                   parserState = iota // 0: initial state
	gooseUp                                    // 1: entered from start on an Up annotation
	gooseStatementBeginUp                      // 2: StatementBegin seen in the Up section
	gooseStatementEndUp                        // 3: StatementEnd seen in the Up section
	gooseDown                                  // 4: entered on a Down annotation
	gooseStatementBeginDown                    // 5: StatementBegin seen in the Down section
	gooseStatementEndDown                      // 6: StatementEnd seen in the Down section
)

// stateMachine holds the current parser state and, when verbose is set, logs
// transitions and other trace messages.
type stateMachine struct {
	state   parserState
	verbose bool
}

// newStateMachine returns a state machine positioned at begin. When verbose is
// true, calls to print (and therefore set) write to the standard logger.
func newStateMachine(begin parserState, verbose bool) *stateMachine {
	sm := new(stateMachine)
	sm.state, sm.verbose = begin, verbose
	return sm
}

// get returns the current state.
func (s *stateMachine) get() parserState { return s.state }

// set traces the transition (old => new) and then moves to next.
func (s *stateMachine) set(next parserState) {
	s.print("set %d => %d", s.state, next)
	s.state = next
}

// ANSI escape sequences wrapped around trace output.
const (
	grayColor  = "\033[90m"
	resetColor = "\033[00m"
)

// print writes a "StateMachine: "-prefixed message, wrapped in grayColor and
// resetColor, through the standard logger. It does nothing unless the machine
// is verbose. msg is a Printf-style format applied to args.
func (s *stateMachine) print(msg string, args ...interface{}) {
	if !s.verbose {
		return
	}
	log.Printf(grayColor+"StateMachine: "+msg+resetColor, args...)
}
83
// scanBufSize is the length in bytes (4 MiB) of every buffer produced by
// bufferPool.
const scanBufSize = 4 << 20

// bufferPool hands out reusable scanBufSize-byte slices. Entries are stored as
// *[]byte, so callers type-assert the result of Get to *[]byte and Put the
// same pointer back when done.
var bufferPool = sync.Pool{
	New: func() interface{} {
		scratch := make([]byte, scanBufSize)
		return &scratch
	},
}
92
93// Split given SQL script into individual statements and return
94// SQL statements for given direction (up=true, down=false).
95//
96// The base case is to simply split on semicolons, as these
97// naturally terminate a statement.
98//
99// However, more complex cases like pl/pgsql can have semicolons
100// within a statement. For these cases, we provide the explicit annotations
101// 'StatementBegin' and 'StatementEnd' to allow the script to
102// tell us to ignore semicolons.
103func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []string, useTx bool, err error) {
104 scanBufPtr := bufferPool.Get().(*[]byte)
105 scanBuf := *scanBufPtr
106 defer bufferPool.Put(scanBufPtr)
107
108 scanner := bufio.NewScanner(r)
109 scanner.Buffer(scanBuf, scanBufSize)
110
111 stateMachine := newStateMachine(start, debug)
112 useTx = true
113 useEnvsub := false
114
115 var buf bytes.Buffer
116 for scanner.Scan() {
117 line := scanner.Text()
118 if debug {
119 log.Println(line)
120 }
121 if stateMachine.get() == start && strings.TrimSpace(line) == "" {
122 continue
123 }
124
125 // Check for annotations.
126 // All annotations must be in format: "-- +goose [annotation]"
127 if strings.HasPrefix(strings.TrimSpace(line), "--") && strings.Contains(line, "+goose") {
128 var cmd annotation
129
130 cmd, err = extractAnnotation(line)
131 if err != nil {
132 return nil, false, fmt.Errorf("failed to parse annotation line %q: %w", line, err)
133 }
134
135 switch cmd {
136 case annotationUp:
137 switch stateMachine.get() {
138 case start:
139 stateMachine.set(gooseUp)
140 default:
141 return nil, false, fmt.Errorf("duplicate '-- +goose Up' annotations; stateMachine=%d, see https://github.com/pressly/goose#sql-migrations", stateMachine.state)
142 }
143 continue
144
145 case annotationDown:
146 switch stateMachine.get() {
147 case gooseUp, gooseStatementEndUp:
148 // If we hit a down annotation, but the buffer is not empty, we have an unfinished SQL query from a
149 // previous up annotation. This is an error, because we expect the SQL query to be terminated by a semicolon
150 // and the buffer to have been reset.
151 if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 {
152 return nil, false, missingSemicolonError(stateMachine.state, direction, bufferRemaining)
153 }
154 stateMachine.set(gooseDown)
155 default:
156 return nil, false, fmt.Errorf("must start with '-- +goose Up' annotation, stateMachine=%d, see https://github.com/pressly/goose#sql-migrations", stateMachine.state)
157 }
158 continue
159
160 case annotationStatementBegin:
161 switch stateMachine.get() {
162 case gooseUp, gooseStatementEndUp:
163 stateMachine.set(gooseStatementBeginUp)
164 case gooseDown, gooseStatementEndDown:
165 stateMachine.set(gooseStatementBeginDown)
166 default:
167 return nil, false, fmt.Errorf("'-- +goose StatementBegin' must be defined after '-- +goose Up' or '-- +goose Down' annotation, stateMachine=%d, see https://github.com/pressly/goose#sql-migrations", stateMachine.state)
168 }
169 continue
170
171 case annotationStatementEnd:
172 switch stateMachine.get() {
173 case gooseStatementBeginUp:
174 stateMachine.set(gooseStatementEndUp)
175 case gooseStatementBeginDown:
176 stateMachine.set(gooseStatementEndDown)
177 default:
178 return nil, false, errors.New("'-- +goose StatementEnd' must be defined after '-- +goose StatementBegin', see https://github.com/pressly/goose#sql-migrations")
179 }
180
181 case annotationNoTransaction:
182 useTx = false
183 continue
184
185 case annotationEnvsubOn:
186 useEnvsub = true
187 continue
188
189 case annotationEnvsubOff:
190 useEnvsub = false
191 continue
192
193 default:
194 return nil, false, fmt.Errorf("unknown annotation: %q", cmd)
195 }
196 }
197 // Once we've started parsing a statement the buffer is no longer empty,
198 // we keep all comments up until the end of the statement (the buffer will be reset).
199 // All other comments in the file are ignored.
200 if buf.Len() == 0 {
201 // This check ensures leading comments and empty lines prior to a statement are ignored.
202 if strings.HasPrefix(strings.TrimSpace(line), "--") || line == "" {
203 stateMachine.print("ignore comment")
204 continue
205 }
206 }
207 switch stateMachine.get() {
208 case gooseStatementEndDown, gooseStatementEndUp:
209 // Do not include the "+goose StatementEnd" annotation in the final statement.
210 default:
211 if useEnvsub {
212 expanded, err := interpolate.Interpolate(&envWrapper{}, line)
213 if err != nil {
214 return nil, false, fmt.Errorf("variable substitution failed: %w:\n%s", err, line)
215 }
216 line = expanded
217 }
218 // Write SQL line to a buffer.
219 if _, err := buf.WriteString(line + "\n"); err != nil {
220 return nil, false, fmt.Errorf("failed to write to buf: %w", err)
221 }
222 }
223 // Read SQL body one by line, if we're in the right direction.
224 //
225 // 1) basic query with semicolon; 2) psql statement
226 //
227 // Export statement once we hit end of statement.
228 switch stateMachine.get() {
229 case gooseUp, gooseStatementBeginUp, gooseStatementEndUp:
230 if direction == DirectionDown {
231 buf.Reset()
232 stateMachine.print("ignore down")
233 continue
234 }
235 case gooseDown, gooseStatementBeginDown, gooseStatementEndDown:
236 if direction == DirectionUp {
237 buf.Reset()
238 stateMachine.print("ignore up")
239 continue
240 }
241 default:
242 return nil, false, fmt.Errorf("failed to parse migration: unexpected state %d on line %q, see https://github.com/pressly/goose#sql-migrations", stateMachine.state, line)
243 }
244
245 switch stateMachine.get() {
246 case gooseUp:
247 if endsWithSemicolon(line) {
248 stmts = append(stmts, cleanupStatement(buf.String()))
249 buf.Reset()
250 stateMachine.print("store simple Up query")
251 }
252 case gooseDown:
253 if endsWithSemicolon(line) {
254 stmts = append(stmts, cleanupStatement(buf.String()))
255 buf.Reset()
256 stateMachine.print("store simple Down query")
257 }
258 case gooseStatementEndUp:
259 stmts = append(stmts, cleanupStatement(buf.String()))
260 buf.Reset()
261 stateMachine.print("store Up statement")
262 stateMachine.set(gooseUp)
263 case gooseStatementEndDown:
264 stmts = append(stmts, cleanupStatement(buf.String()))
265 buf.Reset()
266 stateMachine.print("store Down statement")
267 stateMachine.set(gooseDown)
268 }
269 }
270 if err := scanner.Err(); err != nil {
271 return nil, false, fmt.Errorf("failed to scan migration: %w", err)
272 }
273 // EOF
274
275 switch stateMachine.get() {
276 case start:
277 return nil, false, errors.New("failed to parse migration: must start with '-- +goose Up' annotation, see https://github.com/pressly/goose#sql-migrations")
278 case gooseStatementBeginUp, gooseStatementBeginDown:
279 return nil, false, errors.New("failed to parse migration: missing '-- +goose StatementEnd' annotation")
280 }
281
282 if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 {
283 return nil, false, missingSemicolonError(stateMachine.state, direction, bufferRemaining)
284 }
285
286 return stmts, useTx, nil
287}
288
// annotation is a directive that follows the "+goose" marker on a SQL comment
// line.
type annotation string

// Canonical spellings of the supported annotations. Matching against input is
// case-insensitive; these exact values are what extractAnnotation returns.
const (
	annotationUp             annotation = "Up"
	annotationDown           annotation = "Down"
	annotationStatementBegin annotation = "StatementBegin"
	annotationStatementEnd   annotation = "StatementEnd"
	annotationNoTransaction  annotation = "NO TRANSACTION"
	annotationEnvsubOn       annotation = "ENVSUB ON"
	annotationEnvsubOff      annotation = "ENVSUB OFF"
)

// supportedAnnotations is the set of annotations extractAnnotation accepts.
var supportedAnnotations = map[annotation]struct{}{
	annotationUp:             {},
	annotationDown:           {},
	annotationStatementBegin: {},
	annotationStatementEnd:   {},
	annotationNoTransaction:  {},
	annotationEnvsubOn:       {},
	annotationEnvsubOff:      {},
}

// Sentinel errors returned (possibly wrapped) by extractAnnotation.
var (
	errEmptyAnnotation   = errors.New("empty annotation")
	errInvalidAnnotation = errors.New("invalid annotation")
)

// extractAnnotation returns the annotation named on line.
// All annotations must be in format: "-- +goose [annotation]"
// Allowed annotations: Up, Down, StatementBegin, StatementEnd, NO TRANSACTION, ENVSUB ON, ENVSUB OFF
//
// The name is compared case-insensitively and the canonical constant is
// returned. The error wraps errInvalidAnnotation when the line starts with a
// space or tab, contains more than one "+goose" marker, or names an
// unsupported annotation; it is errEmptyAnnotation when nothing follows the
// marker.
func extractAnnotation(line string) (annotation, error) {
	// Annotations must start in the first column.
	if len(line) > 0 && (line[0] == ' ' || line[0] == '\t') {
		return "", fmt.Errorf("%q contains leading whitespace: %w", line, errInvalidAnnotation)
	}

	// Strip every "--" sequence (not only the leading one), then the first
	// "+goose" marker. A marker that survives means the line had several.
	rest := strings.Replace(strings.ReplaceAll(line, "--", ""), "+goose", "", 1)
	if strings.Contains(rest, "+goose") {
		return "", fmt.Errorf("%q contains multiple '+goose' annotations: %w", rest, errInvalidAnnotation)
	}

	name := strings.TrimSpace(rest)
	if name == "" {
		return "", errEmptyAnnotation
	}

	for known := range supportedAnnotations {
		if strings.EqualFold(string(known), name) {
			return known, nil
		}
	}

	return "", fmt.Errorf("%q not supported: %w", name, errInvalidAnnotation)
}
352
353func missingSemicolonError(state parserState, direction Direction, s string) error {
354 return fmt.Errorf("failed to parse migration: state %d, direction: %v: unexpected unfinished SQL query: %q: missing semicolon?",
355 state,
356 direction,
357 s,
358 )
359}
360
361type envWrapper struct{}
362
363var _ interpolate.Env = (*envWrapper)(nil)
364
365func (e *envWrapper) Get(key string) (string, bool) {
366 return os.LookupEnv(key)
367}
368
// cleanupStatement returns input with leading and trailing white space
// removed.
func cleanupStatement(input string) string {
	trimmed := strings.TrimSpace(input)
	return trimmed
}
372
// endsWithSemicolon reports whether the SQL on line is terminated by a
// semicolon, disregarding a trailing double-dash comment.
//
// The line is split into white-space-separated words. Scanning stops at the
// first word that begins with "--", and the result is true when the last word
// before that point ends with ";". An empty line, a white-space-only line, or
// a line that starts with a comment yields false.
//
// strings.Fields splits on the same Unicode white space as bufio.ScanWords,
// so no scanner or scratch buffer is needed and there is no scan error that
// could be silently dropped.
func endsWithSemicolon(line string) bool {
	last := ""
	for _, word := range strings.Fields(line) {
		if strings.HasPrefix(word, "--") {
			// Everything from here on is a comment.
			break
		}
		last = word
	}
	return strings.HasSuffix(last, ";")
}