1package migrate
2
3import (
4 "context"
5 "embed"
6 "fmt"
7 "regexp"
8 "strings"
9
10 "github.com/charmbracelet/soft-serve/pkg/db"
11)
12
// sqls holds every *.sql file from this package's directory, embedded into
// the binary at build time. execMigration reads individual migration scripts
// from it by file name.
//
//go:embed *.sql
var sqls embed.FS
15
// migrations lists the schema migrations known to this package.
//
// Keep this in order of execution, oldest to newest: the slice order is the
// order in which the migrations are meant to run, so new migrations belong at
// the end and existing entries must not be reordered.
var migrations = []Migration{
	createTables,
	webhooks,
	migrateLfsObjects,
	createOrgsTeams,
}
23
24func execMigration(ctx context.Context, h db.Handler, version int, name string, down bool) error {
25 direction := "up"
26 if down {
27 direction = "down"
28 }
29
30 driverName := h.DriverName()
31 if driverName == "sqlite3" {
32 driverName = "sqlite"
33 }
34
35 fn := fmt.Sprintf("%04d_%s_%s.%s.sql", version, toSnakeCase(name), driverName, direction)
36 sqlstr, err := sqls.ReadFile(fn)
37 if err != nil {
38 return err
39 }
40
41 if _, err := h.ExecContext(ctx, string(sqlstr)); err != nil {
42 return err
43 }
44
45 return nil
46}
47
48func migrateUp(ctx context.Context, h db.Handler, version int, name string) error {
49 return execMigration(ctx, h, version, name, false)
50}
51
52func migrateDown(ctx context.Context, h db.Handler, version int, name string) error {
53 return execMigration(ctx, h, version, name, true)
54}
55
// Patterns used by toSnakeCase to locate word boundaries in CamelCase text.
var (
	// matchFirstCap matches any character followed by an upper-case letter
	// that starts a run of lower-case letters, e.g. the "e|Tables" boundary
	// in "CreateTables" or the "P|Server" boundary in "HTTPServer".
	matchFirstCap = regexp.MustCompile(`(.)([A-Z][a-z]+)`)
	// matchAllCap matches a lower-case letter or digit immediately followed
	// by an upper-case letter, e.g. the "s|T" boundary in "OrgsTeams".
	matchAllCap = regexp.MustCompile(`([a-z0-9])([A-Z])`)
)

// toSnakeCase converts a migration name into the snake_case form used in
// migration file names.
//
// Hyphens and spaces become underscores, an underscore is inserted at every
// CamelCase word boundary (so "createOrgsTeams" yields "create_orgs_teams"
// and "HTTPServer" yields "http_server"), and the result is lower-cased.
//
// Case boundaries are detected within each underscore-separated segment, so
// a separator followed by a capital letter does not gain a second underscore:
// "Create Tables" yields "create_tables", not "create__tables". Underscores
// already present in str are kept as they are.
func toSnakeCase(str string) string {
	str = strings.ReplaceAll(str, "-", "_")
	str = strings.ReplaceAll(str, " ", "_")

	// Running the regexes over the whole string would let `(.)` match an
	// existing "_" and insert a duplicate separator; work per segment instead.
	parts := strings.Split(str, "_")
	for i, part := range parts {
		part = matchFirstCap.ReplaceAllString(part, "${1}_${2}")
		parts[i] = matchAllCap.ReplaceAllString(part, "${1}_${2}")
	}

	return strings.ToLower(strings.Join(parts, "_"))
}