1package migrate
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strconv"
8
9 "github.com/charmbracelet/soft-serve/pkg/access"
10 "github.com/charmbracelet/soft-serve/pkg/config"
11 "github.com/charmbracelet/soft-serve/pkg/db"
12 "github.com/charmbracelet/soft-serve/pkg/sshutils"
13)
14
// Name and schema version of the initial "create tables" migration.
const (
	createTablesVersion = 1
	createTablesName    = "create tables"
)

// Database driver names, compared against tx.DriverName() in this file to
// pick driver-specific SQL.
const (
	driverPostgres = "postgres"
	driverSQLite   = "sqlite"
	driverSQLite3  = "sqlite3"
)
24
25var createTables = Migration{
26 Version: createTablesVersion,
27 Name: createTablesName,
28 Migrate: func(ctx context.Context, tx *db.Tx) error {
29 cfg := config.FromContext(ctx)
30
31 insert := "INSERT "
32
33 // Alter old tables (if exist)
34 // This is to support prior versions of Soft Serve v0.6
35 switch tx.DriverName() {
36 case driverSQLite3, driverSQLite:
37 insert += "OR IGNORE "
38
39 hasUserTable := hasTable(tx, "user")
40 if hasUserTable {
41 if _, err := tx.ExecContext(ctx, "ALTER TABLE user RENAME TO user_old"); err != nil {
42 return err
43 }
44 }
45
46 if hasTable(tx, "public_key") {
47 if _, err := tx.ExecContext(ctx, "ALTER TABLE public_key RENAME TO public_key_old"); err != nil {
48 return err
49 }
50 }
51
52 if hasTable(tx, "collab") {
53 if _, err := tx.ExecContext(ctx, "ALTER TABLE collab RENAME TO collab_old"); err != nil {
54 return err
55 }
56 }
57
58 if hasTable(tx, "repo") {
59 if _, err := tx.ExecContext(ctx, "ALTER TABLE repo RENAME TO repo_old"); err != nil {
60 return err
61 }
62 }
63 }
64
65 if err := migrateUp(ctx, tx, createTablesVersion, createTablesName); err != nil {
66 return err
67 }
68
69 switch tx.DriverName() {
70 case driverSQLite3, driverSQLite:
71
72 if _, err := tx.ExecContext(ctx, "PRAGMA foreign_keys = OFF"); err != nil {
73 return err
74 }
75
76 if hasTable(tx, "user_old") {
77 sqlm := `
78 INSERT INTO users (id, username, admin, updated_at)
79 SELECT id, username, admin, updated_at FROM user_old;
80 `
81 if _, err := tx.ExecContext(ctx, sqlm); err != nil {
82 return err
83 }
84 }
85
86 if hasTable(tx, "public_key_old") {
87 // Check duplicate keys
88 pks := []struct {
89 ID string `db:"id"`
90 PublicKey string `db:"public_key"`
91 }{}
92 if err := tx.SelectContext(ctx, &pks, "SELECT id, public_key FROM public_key_old"); err != nil {
93 return err
94 }
95
96 pkss := map[string]struct{}{}
97 for _, pk := range pks {
98 if _, ok := pkss[pk.PublicKey]; ok {
99 return fmt.Errorf("duplicate public key: %q, please remove the duplicate key and try again", pk.PublicKey)
100 }
101 pkss[pk.PublicKey] = struct{}{}
102 }
103
104 sqlm := `
105 INSERT INTO public_keys (id, user_id, public_key, created_at, updated_at)
106 SELECT id, user_id, public_key, created_at, updated_at FROM public_key_old;
107 `
108 if _, err := tx.ExecContext(ctx, sqlm); err != nil {
109 return err
110 }
111 }
112
113 if hasTable(tx, "repo_old") {
114 sqlm := `
115 INSERT INTO repos (id, name, project_name, description, private,mirror, hidden, created_at, updated_at, user_id)
116 SELECT id, name, project_name, description, private, mirror, hidden, created_at, updated_at, (
117 SELECT id FROM users WHERE admin = true ORDER BY id LIMIT 1
118 ) FROM repo_old;
119 `
120 if _, err := tx.ExecContext(ctx, sqlm); err != nil {
121 return err
122 }
123 }
124
125 if hasTable(tx, "collab_old") {
126 sqlm := `
127 INSERT INTO collabs (id, user_id, repo_id, access_level, created_at, updated_at)
128 SELECT id, user_id, repo_id, ` + strconv.Itoa(int(access.ReadWriteAccess)) + `, created_at, updated_at FROM collab_old;
129 `
130 if _, err := tx.ExecContext(ctx, sqlm); err != nil {
131 return err
132 }
133 }
134
135 if _, err := tx.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil {
136 return err
137 }
138 }
139
140 // Insert default user
141 insertUser := tx.Rebind(insert + "INTO users (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)")
142 if _, err := tx.ExecContext(ctx, insertUser, "admin", true); err != nil {
143 return err
144 }
145
146 for _, k := range cfg.AdminKeys() {
147 query := insert + "INTO public_keys (user_id, public_key, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)"
148 if tx.DriverName() == driverPostgres {
149 query += " ON CONFLICT DO NOTHING"
150 }
151
152 query = tx.Rebind(query)
153 ak := sshutils.MarshalAuthorizedKey(k)
154 if _, err := tx.ExecContext(ctx, query, 1, ak); err != nil {
155 if errors.Is(db.WrapError(err), db.ErrDuplicateKey) {
156 continue
157 }
158 return err
159 }
160 }
161
162 // Insert default settings
163 insertSettings := insert + "INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)"
164 insertSettings = tx.Rebind(insertSettings)
165 settings := []struct {
166 Key string
167 Value string
168 }{
169 {"allow_keyless", "true"},
170 {"anon_access", access.ReadOnlyAccess.String()},
171 {"init", "true"},
172 }
173
174 for _, s := range settings {
175 if _, err := tx.ExecContext(ctx, insertSettings, s.Key, s.Value); err != nil {
176 return fmt.Errorf("inserting default settings %q: %w", s.Key, err)
177 }
178 }
179
180 return nil
181 },
182 Rollback: func(ctx context.Context, tx *db.Tx) error {
183 return migrateDown(ctx, tx, createTablesVersion, createTablesName)
184 },
185}