1// Package migrate provides database migration functionality.
2package migrate
3
4import (
5 "context"
6 "errors"
7 "fmt"
8 "strconv"
9
10 "github.com/charmbracelet/soft-serve/pkg/access"
11 "github.com/charmbracelet/soft-serve/pkg/config"
12 "github.com/charmbracelet/soft-serve/pkg/db"
13 "github.com/charmbracelet/soft-serve/pkg/sshutils"
14)
15
// Version and name of the initial schema migration; createTables passes both
// to migrateUp and migrateDown.
const (
	createTablesVersion = 1
	createTablesName    = "create tables"
)
20
21var createTables = Migration{
22 Version: createTablesVersion,
23 Name: createTablesName,
24 Migrate: func(ctx context.Context, tx *db.Tx) error {
25 cfg := config.FromContext(ctx)
26
27 insert := "INSERT "
28
29 // Alter old tables (if exist)
30 // This is to support prior versions of Soft Serve v0.6
31 switch tx.DriverName() {
32 case "sqlite3", "sqlite":
33 insert += "OR IGNORE "
34
35 hasUserTable := hasTable(tx, "user")
36 if hasUserTable {
37 if _, err := tx.ExecContext(ctx, "ALTER TABLE user RENAME TO user_old"); err != nil {
38 return err //nolint:wrapcheck
39 }
40 }
41
42 if hasTable(tx, "public_key") {
43 if _, err := tx.ExecContext(ctx, "ALTER TABLE public_key RENAME TO public_key_old"); err != nil {
44 return err //nolint:wrapcheck
45 }
46 }
47
48 if hasTable(tx, "collab") {
49 if _, err := tx.ExecContext(ctx, "ALTER TABLE collab RENAME TO collab_old"); err != nil {
50 return err //nolint:wrapcheck
51 }
52 }
53
54 if hasTable(tx, "repo") {
55 if _, err := tx.ExecContext(ctx, "ALTER TABLE repo RENAME TO repo_old"); err != nil {
56 return err //nolint:wrapcheck
57 }
58 }
59 }
60
61 if err := migrateUp(ctx, tx, createTablesVersion, createTablesName); err != nil {
62 return err
63 }
64
65 switch tx.DriverName() {
66 case "sqlite3", "sqlite":
67
68 if _, err := tx.ExecContext(ctx, "PRAGMA foreign_keys = OFF"); err != nil {
69 return err //nolint:wrapcheck
70 }
71
72 if hasTable(tx, "user_old") {
73 sqlm := `
74 INSERT INTO users (id, username, admin, updated_at)
75 SELECT id, username, admin, updated_at FROM user_old;
76 `
77 if _, err := tx.ExecContext(ctx, sqlm); err != nil {
78 return err //nolint:wrapcheck
79 }
80 }
81
82 if hasTable(tx, "public_key_old") {
83 // Check duplicate keys
84 pks := []struct {
85 ID string `db:"id"`
86 PublicKey string `db:"public_key"`
87 }{}
88 if err := tx.SelectContext(ctx, &pks, "SELECT id, public_key FROM public_key_old"); err != nil {
89 return err //nolint:wrapcheck
90 }
91
92 pkss := map[string]struct{}{}
93 for _, pk := range pks {
94 if _, ok := pkss[pk.PublicKey]; ok {
95 return fmt.Errorf("duplicate public key: %q, please remove the duplicate key and try again", pk.PublicKey)
96 }
97 pkss[pk.PublicKey] = struct{}{}
98 }
99
100 sqlm := `
101 INSERT INTO public_keys (id, user_id, public_key, created_at, updated_at)
102 SELECT id, user_id, public_key, created_at, updated_at FROM public_key_old;
103 `
104 if _, err := tx.ExecContext(ctx, sqlm); err != nil {
105 return err //nolint:wrapcheck
106 }
107 }
108
109 if hasTable(tx, "repo_old") {
110 sqlm := `
111 INSERT INTO repos (id, name, project_name, description, private,mirror, hidden, created_at, updated_at, user_id)
112 SELECT id, name, project_name, description, private, mirror, hidden, created_at, updated_at, (
113 SELECT id FROM users WHERE admin = true ORDER BY id LIMIT 1
114 ) FROM repo_old;
115 `
116 if _, err := tx.ExecContext(ctx, sqlm); err != nil {
117 return err //nolint:wrapcheck
118 }
119 }
120
121 if hasTable(tx, "collab_old") {
122 sqlm := `
123 INSERT INTO collabs (id, user_id, repo_id, access_level, created_at, updated_at)
124 SELECT id, user_id, repo_id, ` + strconv.Itoa(int(access.ReadWriteAccess)) + `, created_at, updated_at FROM collab_old;
125 `
126 if _, err := tx.ExecContext(ctx, sqlm); err != nil {
127 return err //nolint:wrapcheck
128 }
129 }
130
131 if _, err := tx.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil {
132 return err //nolint:wrapcheck
133 }
134 }
135
136 // Insert default user
137 insertUser := tx.Rebind(insert + "INTO users (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)")
138 if _, err := tx.ExecContext(ctx, insertUser, "admin", true); err != nil {
139 return err //nolint:wrapcheck
140 }
141
142 for _, k := range cfg.AdminKeys() {
143 query := insert + "INTO public_keys (user_id, public_key, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)"
144 if tx.DriverName() == "postgres" {
145 query += " ON CONFLICT DO NOTHING"
146 }
147
148 query = tx.Rebind(query)
149 ak := sshutils.MarshalAuthorizedKey(k)
150 if _, err := tx.ExecContext(ctx, query, 1, ak); err != nil {
151 if errors.Is(db.WrapError(err), db.ErrDuplicateKey) {
152 continue
153 }
154 return err //nolint:wrapcheck
155 }
156 }
157
158 // Insert default settings
159 insertSettings := insert + "INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)"
160 insertSettings = tx.Rebind(insertSettings)
161 settings := []struct {
162 Key string
163 Value string
164 }{
165 {"allow_keyless", "true"},
166 {"anon_access", access.ReadOnlyAccess.String()},
167 {"init", "true"},
168 }
169
170 for _, s := range settings {
171 if _, err := tx.ExecContext(ctx, insertSettings, s.Key, s.Value); err != nil {
172 return fmt.Errorf("inserting default settings %q: %w", s.Key, err)
173 }
174 }
175
176 return nil
177 },
178 Rollback: func(ctx context.Context, tx *db.Tx) error {
179 return migrateDown(ctx, tx, createTablesVersion, createTablesName)
180 },
181}