1package migrate
2
3import (
4 "context"
5 "embed"
6 "fmt"
7 "regexp"
8 "strings"
9
10 "github.com/charmbracelet/soft-serve/server/db"
11)
12
// sqls holds the *.sql migration scripts embedded into the binary at build
// time. execMigration looks each script up by a generated file name
// (version, snake_cased name, driver name and direction), so a migration
// needs a matching file here for every driver/direction it is run with.
//
//go:embed *.sql
var sqls embed.FS

// migrations lists every known migration.
// Keep this in order of execution, oldest to newest.
// NOTE(review): presumably the caller applies these in slice order and
// reverses them for rollbacks — confirm against the code that iterates it.
var migrations = []Migration{
	createTables,
	webhooks,
}
21
22func execMigration(ctx context.Context, tx *db.Tx, version int, name string, down bool) error {
23 direction := "up"
24 if down {
25 direction = "down"
26 }
27
28 driverName := tx.DriverName()
29 if driverName == "sqlite3" {
30 driverName = "sqlite"
31 }
32
33 fn := fmt.Sprintf("%04d_%s_%s.%s.sql", version, toSnakeCase(name), driverName, direction)
34 sqlstr, err := sqls.ReadFile(fn)
35 if err != nil {
36 return err
37 }
38
39 if _, err := tx.ExecContext(ctx, string(sqlstr)); err != nil {
40 return err
41 }
42
43 return nil
44}
45
46func migrateUp(ctx context.Context, tx *db.Tx, version int, name string) error {
47 return execMigration(ctx, tx, version, name, false)
48}
49
50func migrateDown(ctx context.Context, tx *db.Tx, version int, name string) error {
51 return execMigration(ctx, tx, version, name, true)
52}
53
// matchFirstCap matches a capitalized word (e.g. "Tables") together with the
// single character before it, so toSnakeCase can put an underscore between
// them. That preceding character must not already be an underscore.
// Otherwise a space or hyphen that toSnakeCase has just turned into "_" gets
// doubled ("Create Tables" would become "create__tables").
var matchFirstCap = regexp.MustCompile(`([^_])([A-Z][a-z]+)`)

// matchAllCap matches a lowercase letter or digit followed by an uppercase
// letter (e.g. "rI" in "userID"), so toSnakeCase can split at that boundary.
var matchAllCap = regexp.MustCompile(`([a-z0-9])([A-Z])`)

// toSnakeCase converts a migration name such as "create tables",
// "create-tables" or "CreateTables" into the snake_case form used in
// migration file names ("create_tables"). It works in three steps:
//   - hyphens and spaces become underscores;
//   - CamelCase word boundaries are split with an underscore;
//   - the result is lowercased.
func toSnakeCase(str string) string {
	str = strings.ReplaceAll(str, "-", "_")
	str = strings.ReplaceAll(str, " ", "_")
	snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}")
	snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}")
	return strings.ToLower(snake)
}