1package db
2
3import (
4 "database/sql"
5 "embed"
6 "errors"
7 "fmt"
8 "log/slog"
9 "regexp"
10 "sort"
11 "strconv"
12
13 _ "modernc.org/sqlite"
14)
15
// Running `go generate` on this package invokes sqlc through `go tool`;
// presumably this regenerates the query code from the project's sqlc config.
//
//go:generate go tool github.com/sqlc-dev/sqlc/cmd/sqlc generate

// migrationFS holds every migrations/*.sql file, embedded into the binary at
// build time. RunMigrations lists it and executeMigration reads from it.
//
//go:embed migrations/*.sql
var migrationFS embed.FS
20
// Open opens the SQLite database at path with the "sqlite" driver and
// configures it for a small web app: foreign-key enforcement, a 1s busy
// timeout, and WAL journaling.
//
// foreign_keys and busy_timeout are per-connection settings in SQLite, while
// *sql.DB is a pool that may open many connections. They are therefore passed
// as _pragma DSN parameters (see dsnWithPragmas), so the driver applies them
// to every connection it opens instead of only the one connection that a
// single db.Exec happens to use. journal_mode=WAL is stored in the database
// file itself, so setting it once is enough.
//
// On any error the pool is closed and a wrapped error is returned.
func Open(path string) (*sql.DB, error) {
	db, err := sql.Open("sqlite", dsnWithPragmas(path))
	if err != nil {
		return nil, err
	}
	// Issue the settings explicitly as well. The first Exec forces a real
	// connection, so a bad path or pragma fails here rather than on the first
	// query. It also configures that connection when dsnWithPragmas left the
	// DSN unannotated (empty path). busy_timeout goes first so that switching
	// to WAL, which needs an exclusive lock, can wait instead of failing.
	steps := []struct{ stmt, what string }{
		{"PRAGMA busy_timeout=1000;", "set busy_timeout"},
		{"PRAGMA foreign_keys=ON;", "enable foreign keys"},
		{"PRAGMA journal_mode=wal;", "set WAL"},
	}
	for _, s := range steps {
		if _, err := db.Exec(s.stmt); err != nil {
			_ = db.Close()
			return nil, fmt.Errorf("%s: %w", s.what, err)
		}
	}
	return db, nil
}

// dsnWithPragmas appends modernc.org/sqlite `_pragma` query parameters for the
// per-connection settings (busy_timeout, then foreign_keys) to path. It uses
// "&" instead of "?" when path already carries a query string.
//
// An empty path is returned unchanged. NOTE(review): the driver appears to
// recognise a query string only after a non-empty filename, so "?_pragma=..."
// alone would be taken as a literal file name. Confirm this against the
// pinned driver version.
func dsnWithPragmas(path string) string {
	if path == "" {
		return path
	}
	sep := "?"
	for i := 0; i < len(path); i++ {
		if path[i] == '?' {
			sep = "&"
			break
		}
	}
	return path + sep + "_pragma=busy_timeout(1000)&_pragma=foreign_keys(1)"
}
42
43// RunMigrations executes database migrations in numeric order (NNN-*.sql),
44// similar in spirit to exed's exedb.RunMigrations.
45func RunMigrations(db *sql.DB) error {
46 entries, err := migrationFS.ReadDir("migrations")
47 if err != nil {
48 return fmt.Errorf("read migrations dir: %w", err)
49 }
50 var migrations []string
51 pat := regexp.MustCompile(`^(\d{3})-.*\.sql$`)
52 for _, e := range entries {
53 if e.IsDir() {
54 continue
55 }
56 name := e.Name()
57 if pat.MatchString(name) {
58 migrations = append(migrations, name)
59 }
60 }
61 sort.Strings(migrations)
62
63 executed := make(map[int]bool)
64 var tableName string
65 err = db.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations'").Scan(&tableName)
66 switch {
67 case err == nil:
68 rows, err := db.Query("SELECT migration_number FROM migrations")
69 if err != nil {
70 return fmt.Errorf("query executed migrations: %w", err)
71 }
72 defer rows.Close()
73 for rows.Next() {
74 var n int
75 if err := rows.Scan(&n); err != nil {
76 return fmt.Errorf("scan migration number: %w", err)
77 }
78 executed[n] = true
79 }
80 case errors.Is(err, sql.ErrNoRows):
81 slog.Info("db: migrations table not found; running all migrations")
82 default:
83 return fmt.Errorf("check migrations table: %w", err)
84 }
85
86 for _, m := range migrations {
87 match := pat.FindStringSubmatch(m)
88 if len(match) != 2 {
89 return fmt.Errorf("invalid migration filename: %s", m)
90 }
91 n, err := strconv.Atoi(match[1])
92 if err != nil {
93 return fmt.Errorf("parse migration number %s: %w", m, err)
94 }
95 if executed[n] {
96 continue
97 }
98 if err := executeMigration(db, m); err != nil {
99 return fmt.Errorf("execute %s: %w", m, err)
100 }
101 slog.Info("db: applied migration", "file", m, "number", n)
102 }
103 return nil
104}
105
106func executeMigration(db *sql.DB, filename string) error {
107 content, err := migrationFS.ReadFile("migrations/" + filename)
108 if err != nil {
109 return fmt.Errorf("read %s: %w", filename, err)
110 }
111 if _, err := db.Exec(string(content)); err != nil {
112 return fmt.Errorf("exec %s: %w", filename, err)
113 }
114 return nil
115}