parser.go

  1package sqlparser
  2
  3import (
  4	"bufio"
  5	"bytes"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log"
 10	"os"
 11	"strings"
 12	"sync"
 13
 14	"github.com/mfridman/interpolate"
 15)
 16
 17type Direction string
 18
 19const (
 20	DirectionUp   Direction = "up"
 21	DirectionDown Direction = "down"
 22)
 23
 24func FromBool(b bool) Direction {
 25	if b {
 26		return DirectionUp
 27	}
 28	return DirectionDown
 29}
 30
 31func (d Direction) String() string {
 32	return string(d)
 33}
 34
 35func (d Direction) ToBool() bool {
 36	return d == DirectionUp
 37}
 38
 39type parserState int
 40
 41const (
 42	start                   parserState = iota // 0
 43	gooseUp                                    // 1
 44	gooseStatementBeginUp                      // 2
 45	gooseStatementEndUp                        // 3
 46	gooseDown                                  // 4
 47	gooseStatementBeginDown                    // 5
 48	gooseStatementEndDown                      // 6
 49)
 50
 51type stateMachine struct {
 52	state   parserState
 53	verbose bool
 54}
 55
 56func newStateMachine(begin parserState, verbose bool) *stateMachine {
 57	return &stateMachine{
 58		state:   begin,
 59		verbose: verbose,
 60	}
 61}
 62
 63func (s *stateMachine) get() parserState {
 64	return s.state
 65}
 66
 67func (s *stateMachine) set(new parserState) {
 68	s.print("set %d => %d", s.state, new)
 69	s.state = new
 70}
 71
 72const (
 73	grayColor  = "\033[90m"
 74	resetColor = "\033[00m"
 75)
 76
 77func (s *stateMachine) print(msg string, args ...interface{}) {
 78	msg = "StateMachine: " + msg
 79	if s.verbose {
 80		log.Printf(grayColor+msg+resetColor, args...)
 81	}
 82}
 83
 84const scanBufSize = 4 * 1024 * 1024
 85
 86var bufferPool = sync.Pool{
 87	New: func() interface{} {
 88		buf := make([]byte, scanBufSize)
 89		return &buf
 90	},
 91}
 92
 93// Split given SQL script into individual statements and return
 94// SQL statements for given direction (up=true, down=false).
 95//
 96// The base case is to simply split on semicolons, as these
 97// naturally terminate a statement.
 98//
 99// However, more complex cases like pl/pgsql can have semicolons
100// within a statement. For these cases, we provide the explicit annotations
101// 'StatementBegin' and 'StatementEnd' to allow the script to
102// tell us to ignore semicolons.
103func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []string, useTx bool, err error) {
104	scanBufPtr := bufferPool.Get().(*[]byte)
105	scanBuf := *scanBufPtr
106	defer bufferPool.Put(scanBufPtr)
107
108	scanner := bufio.NewScanner(r)
109	scanner.Buffer(scanBuf, scanBufSize)
110
111	stateMachine := newStateMachine(start, debug)
112	useTx = true
113	useEnvsub := false
114
115	var buf bytes.Buffer
116	for scanner.Scan() {
117		line := scanner.Text()
118		if debug {
119			log.Println(line)
120		}
121		if stateMachine.get() == start && strings.TrimSpace(line) == "" {
122			continue
123		}
124
125		// Check for annotations.
126		// All annotations must be in format: "-- +goose [annotation]"
127		if strings.HasPrefix(strings.TrimSpace(line), "--") && strings.Contains(line, "+goose") {
128			var cmd annotation
129
130			cmd, err = extractAnnotation(line)
131			if err != nil {
132				return nil, false, fmt.Errorf("failed to parse annotation line %q: %w", line, err)
133			}
134
135			switch cmd {
136			case annotationUp:
137				switch stateMachine.get() {
138				case start:
139					stateMachine.set(gooseUp)
140				default:
141					return nil, false, fmt.Errorf("duplicate '-- +goose Up' annotations; stateMachine=%d, see https://github.com/pressly/goose#sql-migrations", stateMachine.state)
142				}
143				continue
144
145			case annotationDown:
146				switch stateMachine.get() {
147				case gooseUp, gooseStatementEndUp:
148					// If we hit a down annotation, but the buffer is not empty, we have an unfinished SQL query from a
149					// previous up annotation. This is an error, because we expect the SQL query to be terminated by a semicolon
150					// and the buffer to have been reset.
151					if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 {
152						return nil, false, missingSemicolonError(stateMachine.state, direction, bufferRemaining)
153					}
154					stateMachine.set(gooseDown)
155				default:
156					return nil, false, fmt.Errorf("must start with '-- +goose Up' annotation, stateMachine=%d, see https://github.com/pressly/goose#sql-migrations", stateMachine.state)
157				}
158				continue
159
160			case annotationStatementBegin:
161				switch stateMachine.get() {
162				case gooseUp, gooseStatementEndUp:
163					stateMachine.set(gooseStatementBeginUp)
164				case gooseDown, gooseStatementEndDown:
165					stateMachine.set(gooseStatementBeginDown)
166				default:
167					return nil, false, fmt.Errorf("'-- +goose StatementBegin' must be defined after '-- +goose Up' or '-- +goose Down' annotation, stateMachine=%d, see https://github.com/pressly/goose#sql-migrations", stateMachine.state)
168				}
169				continue
170
171			case annotationStatementEnd:
172				switch stateMachine.get() {
173				case gooseStatementBeginUp:
174					stateMachine.set(gooseStatementEndUp)
175				case gooseStatementBeginDown:
176					stateMachine.set(gooseStatementEndDown)
177				default:
178					return nil, false, errors.New("'-- +goose StatementEnd' must be defined after '-- +goose StatementBegin', see https://github.com/pressly/goose#sql-migrations")
179				}
180
181			case annotationNoTransaction:
182				useTx = false
183				continue
184
185			case annotationEnvsubOn:
186				useEnvsub = true
187				continue
188
189			case annotationEnvsubOff:
190				useEnvsub = false
191				continue
192
193			default:
194				return nil, false, fmt.Errorf("unknown annotation: %q", cmd)
195			}
196		}
197		// Once we've started parsing a statement the buffer is no longer empty,
198		// we keep all comments up until the end of the statement (the buffer will be reset).
199		// All other comments in the file are ignored.
200		if buf.Len() == 0 {
201			// This check ensures leading comments and empty lines prior to a statement are ignored.
202			if strings.HasPrefix(strings.TrimSpace(line), "--") || line == "" {
203				stateMachine.print("ignore comment")
204				continue
205			}
206		}
207		switch stateMachine.get() {
208		case gooseStatementEndDown, gooseStatementEndUp:
209			// Do not include the "+goose StatementEnd" annotation in the final statement.
210		default:
211			if useEnvsub {
212				expanded, err := interpolate.Interpolate(&envWrapper{}, line)
213				if err != nil {
214					return nil, false, fmt.Errorf("variable substitution failed: %w:\n%s", err, line)
215				}
216				line = expanded
217			}
218			// Write SQL line to a buffer.
219			if _, err := buf.WriteString(line + "\n"); err != nil {
220				return nil, false, fmt.Errorf("failed to write to buf: %w", err)
221			}
222		}
223		// Read SQL body one by line, if we're in the right direction.
224		//
225		// 1) basic query with semicolon; 2) psql statement
226		//
227		// Export statement once we hit end of statement.
228		switch stateMachine.get() {
229		case gooseUp, gooseStatementBeginUp, gooseStatementEndUp:
230			if direction == DirectionDown {
231				buf.Reset()
232				stateMachine.print("ignore down")
233				continue
234			}
235		case gooseDown, gooseStatementBeginDown, gooseStatementEndDown:
236			if direction == DirectionUp {
237				buf.Reset()
238				stateMachine.print("ignore up")
239				continue
240			}
241		default:
242			return nil, false, fmt.Errorf("failed to parse migration: unexpected state %d on line %q, see https://github.com/pressly/goose#sql-migrations", stateMachine.state, line)
243		}
244
245		switch stateMachine.get() {
246		case gooseUp:
247			if endsWithSemicolon(line) {
248				stmts = append(stmts, cleanupStatement(buf.String()))
249				buf.Reset()
250				stateMachine.print("store simple Up query")
251			}
252		case gooseDown:
253			if endsWithSemicolon(line) {
254				stmts = append(stmts, cleanupStatement(buf.String()))
255				buf.Reset()
256				stateMachine.print("store simple Down query")
257			}
258		case gooseStatementEndUp:
259			stmts = append(stmts, cleanupStatement(buf.String()))
260			buf.Reset()
261			stateMachine.print("store Up statement")
262			stateMachine.set(gooseUp)
263		case gooseStatementEndDown:
264			stmts = append(stmts, cleanupStatement(buf.String()))
265			buf.Reset()
266			stateMachine.print("store Down statement")
267			stateMachine.set(gooseDown)
268		}
269	}
270	if err := scanner.Err(); err != nil {
271		return nil, false, fmt.Errorf("failed to scan migration: %w", err)
272	}
273	// EOF
274
275	switch stateMachine.get() {
276	case start:
277		return nil, false, errors.New("failed to parse migration: must start with '-- +goose Up' annotation, see https://github.com/pressly/goose#sql-migrations")
278	case gooseStatementBeginUp, gooseStatementBeginDown:
279		return nil, false, errors.New("failed to parse migration: missing '-- +goose StatementEnd' annotation")
280	}
281
282	if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 {
283		return nil, false, missingSemicolonError(stateMachine.state, direction, bufferRemaining)
284	}
285
286	return stmts, useTx, nil
287}
288
// annotation is a goose directive carried by a "-- +goose ..." comment line.
type annotation string

// Canonical spellings of the supported annotations.
const (
	annotationUp             annotation = "Up"
	annotationDown           annotation = "Down"
	annotationStatementBegin annotation = "StatementBegin"
	annotationStatementEnd   annotation = "StatementEnd"
	annotationNoTransaction  annotation = "NO TRANSACTION"
	annotationEnvsubOn       annotation = "ENVSUB ON"
	annotationEnvsubOff      annotation = "ENVSUB OFF"
)

// supportedAnnotations is the set of annotations extractAnnotation accepts.
var supportedAnnotations = map[annotation]struct{}{
	annotationUp:             {},
	annotationDown:           {},
	annotationStatementBegin: {},
	annotationStatementEnd:   {},
	annotationNoTransaction:  {},
	annotationEnvsubOn:       {},
	annotationEnvsubOff:      {},
}

// Sentinel errors returned (possibly wrapped) by extractAnnotation.
var (
	errEmptyAnnotation   = errors.New("empty annotation")
	errInvalidAnnotation = errors.New("invalid annotation")
)

// extractAnnotation returns the annotation named on line.
// All annotations must be in format: "-- +goose [annotation]"
// Allowed annotations: Up, Down, StatementBegin, StatementEnd, NO TRANSACTION, ENVSUB ON, ENVSUB OFF
//
// The name is matched case-insensitively and the canonical constant is
// returned. A line that starts with a space or tab, names more than one
// "+goose" marker, or names an unknown annotation yields an error wrapping
// errInvalidAnnotation; a line with no name yields errEmptyAnnotation.
func extractAnnotation(line string) (annotation, error) {
	// Annotations must start in the first column.
	if len(line) > 0 && (line[0] == ' ' || line[0] == '\t') {
		return "", fmt.Errorf("%q contains leading whitespace: %w", line, errInvalidAnnotation)
	}

	// Drop every "--" in the line, then the first "+goose" marker. Whatever
	// is left should be just the annotation name plus surrounding whitespace.
	rest := strings.Replace(strings.ReplaceAll(line, "--", ""), "+goose", "", 1)
	if strings.Contains(rest, "+goose") {
		return "", fmt.Errorf("%q contains multiple '+goose' annotations: %w", rest, errInvalidAnnotation)
	}

	name := strings.TrimSpace(rest)
	if name == "" {
		return "", errEmptyAnnotation
	}

	for known := range supportedAnnotations {
		if strings.EqualFold(string(known), name) {
			return known, nil
		}
	}
	return "", fmt.Errorf("%q not supported: %w", name, errInvalidAnnotation)
}
352
353func missingSemicolonError(state parserState, direction Direction, s string) error {
354	return fmt.Errorf("failed to parse migration: state %d, direction: %v: unexpected unfinished SQL query: %q: missing semicolon?",
355		state,
356		direction,
357		s,
358	)
359}
360
361type envWrapper struct{}
362
363var _ interpolate.Env = (*envWrapper)(nil)
364
365func (e *envWrapper) Get(key string) (string, bool) {
366	return os.LookupEnv(key)
367}
368
// cleanupStatement returns input with all leading and trailing whitespace
// removed.
func cleanupStatement(input string) string {
	trimmed := strings.TrimSpace(input)
	return trimmed
}
372
// endsWithSemicolon reports whether line ends a SQL statement: it looks at
// the whitespace-separated words of line, stops at the first word that
// starts a double-dash comment ("--"), and checks whether the last word seen
// before that point ends with a semicolon.
//
// So "SELECT 1; -- note" ends a statement, while "SELECT 1 -- note;" and an
// empty or comment-only line do not.
func endsWithSemicolon(line string) bool {
	// strings.Fields splits on the same Unicode whitespace as a
	// bufio.ScanWords scanner, but needs no scan buffer and cannot fail on an
	// over-long word.
	prev := ""
	for _, word := range strings.Fields(line) {
		if strings.HasPrefix(word, "--") {
			break
		}
		prev = word
	}
	return strings.HasSuffix(prev, ";")
}
393	return strings.HasSuffix(prev, ";")
394}