1package migrate
 2
 3import (
 4	"context"
 5	"embed"
 6	"fmt"
 7	"regexp"
 8	"strings"
 9
10	"github.com/charmbracelet/soft-serve/server/db"
11)
12
// sqls holds every *.sql file in this package's directory, embedded into the
// binary at build time. Migration scripts are looked up in it by file name.
//
//go:embed *.sql
var sqls embed.FS

// migrations lists the migrations known to this package.
// Keep this in order of execution, oldest to newest.
// NOTE(review): the runner that consumes this slice is defined elsewhere;
// presumably it applies entries in slice order — confirm before reordering.
var migrations = []Migration{
	createTables,
	createLFSTables,
	passwordTokens,
	repoOwner,
}
23
24func execMigration(ctx context.Context, tx *db.Tx, version int, name string, down bool) error {
25	direction := "up"
26	if down {
27		direction = "down"
28	}
29
30	driverName := tx.DriverName()
31	if driverName == "sqlite3" {
32		driverName = "sqlite"
33	}
34
35	fn := fmt.Sprintf("%04d_%s_%s.%s.sql", version, toSnakeCase(name), driverName, direction)
36	sqlstr, err := sqls.ReadFile(fn)
37	if err != nil {
38		return err
39	}
40
41	if _, err := tx.ExecContext(ctx, string(sqlstr)); err != nil {
42		return err
43	}
44
45	return nil
46}
47
48func migrateUp(ctx context.Context, tx *db.Tx, version int, name string) error {
49	return execMigration(ctx, tx, version, name, false)
50}
51
52func migrateDown(ctx context.Context, tx *db.Tx, version int, name string) error {
53	return execMigration(ctx, tx, version, name, true)
54}
55
var (
	// matchFirstCap finds a capitalized word ("Tables") preceded by any
	// character other than an underscore, so a boundary is inserted only
	// where one does not already exist.
	matchFirstCap = regexp.MustCompile(`([^_])([A-Z][a-z]+)`)
	// matchAllCap finds a lowercase letter or digit followed by an
	// uppercase letter ("eL" in "createLFS").
	matchAllCap = regexp.MustCompile(`([a-z0-9])([A-Z])`)
)

// toSnakeCase converts a migration name to snake_case for use in script
// file names.
//
// Hyphens and spaces become underscores, an underscore is inserted at each
// lower-to-upper case boundary and before a capitalized word that follows
// an acronym, and the result is lowercased. For example:
//
//	"create lfs tables" -> "create_lfs_tables"
//	"createLFSTables"   -> "create_lfs_tables"
//	"Create Tables"     -> "create_tables"
//
// Separators are replaced first, so matchFirstCap must not match a leading
// "_". Otherwise "Create Tables" would become "create__tables".
func toSnakeCase(str string) string {
	str = strings.ReplaceAll(str, "-", "_")
	str = strings.ReplaceAll(str, " ", "_")
	snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}")
	snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}")
	return strings.ToLower(snake)
}