1package migrate
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"strconv"
  8
  9	"github.com/charmbracelet/soft-serve/server/access"
 10	"github.com/charmbracelet/soft-serve/server/config"
 11	"github.com/charmbracelet/soft-serve/server/db"
 12	"github.com/charmbracelet/soft-serve/server/sshutils"
 13)
 14
 15const (
 16	createTablesName    = "create tables"
 17	createTablesVersion = 1
 18)
 19
 20var createTables = Migration{
 21	Version: createTablesVersion,
 22	Name:    createTablesName,
 23	Migrate: func(ctx context.Context, tx *db.Tx) error {
 24		cfg := config.FromContext(ctx)
 25
 26		insert := "INSERT "
 27
 28		// Alter old tables (if exist)
 29		// This is to support prior versions of Soft Serve v0.6
 30		switch tx.DriverName() {
 31		case "sqlite3", "sqlite":
 32			insert += "OR IGNORE "
 33
 34			hasUserTable := hasTable(tx, "user")
 35			if hasUserTable {
 36				if _, err := tx.ExecContext(ctx, "ALTER TABLE user RENAME TO user_old"); err != nil {
 37					return err
 38				}
 39			}
 40
 41			if hasTable(tx, "public_key") {
 42				if _, err := tx.ExecContext(ctx, "ALTER TABLE public_key RENAME TO public_key_old"); err != nil {
 43					return err
 44				}
 45			}
 46
 47			if hasTable(tx, "collab") {
 48				if _, err := tx.ExecContext(ctx, "ALTER TABLE collab RENAME TO collab_old"); err != nil {
 49					return err
 50				}
 51			}
 52
 53			if hasTable(tx, "repo") {
 54				if _, err := tx.ExecContext(ctx, "ALTER TABLE repo RENAME TO repo_old"); err != nil {
 55					return err
 56				}
 57			}
 58		}
 59
 60		if err := migrateUp(ctx, tx, createTablesVersion, createTablesName); err != nil {
 61			return err
 62		}
 63
 64		switch tx.DriverName() {
 65		case "sqlite3", "sqlite":
 66
 67			if _, err := tx.ExecContext(ctx, "PRAGMA foreign_keys = OFF"); err != nil {
 68				return err
 69			}
 70
 71			if hasTable(tx, "user_old") {
 72				sqlm := `
 73				INSERT INTO users (id, username, admin, updated_at)
 74					SELECT id, username, admin, updated_at FROM user_old;
 75				`
 76				if _, err := tx.ExecContext(ctx, sqlm); err != nil {
 77					return err
 78				}
 79			}
 80
 81			if hasTable(tx, "public_key_old") {
 82				// Check duplicate keys
 83				pks := []struct {
 84					ID        string `db:"id"`
 85					PublicKey string `db:"public_key"`
 86				}{}
 87				if err := tx.SelectContext(ctx, &pks, "SELECT id, public_key FROM public_key_old"); err != nil {
 88					return err
 89				}
 90
 91				pkss := map[string]struct{}{}
 92				for _, pk := range pks {
 93					if _, ok := pkss[pk.PublicKey]; ok {
 94						return fmt.Errorf("duplicate public key: %q, please remove the duplicate key and try again", pk.PublicKey)
 95					}
 96					pkss[pk.PublicKey] = struct{}{}
 97				}
 98
 99				sqlm := `
100				INSERT INTO public_keys (id, user_id, public_key, created_at, updated_at)
101					SELECT id, user_id, public_key, created_at, updated_at FROM public_key_old;
102				`
103				if _, err := tx.ExecContext(ctx, sqlm); err != nil {
104					return err
105				}
106			}
107
108			if hasTable(tx, "repo_old") {
109				sqlm := `
110				INSERT INTO repos (id, name, project_name, description, private,mirror, hidden, created_at, updated_at, user_id)
111					SELECT id, name, project_name, description, private, mirror, hidden, created_at, updated_at, (
112						SELECT id FROM users WHERE admin = true ORDER BY id LIMIT 1
113				) FROM repo_old;
114				`
115				if _, err := tx.ExecContext(ctx, sqlm); err != nil {
116					return err
117				}
118			}
119
120			if hasTable(tx, "collab_old") {
121				sqlm := `
122				INSERT INTO collabs (id, user_id, repo_id, access_level, created_at, updated_at)
123					SELECT id, user_id, repo_id, ` + strconv.Itoa(int(access.ReadWriteAccess)) + `, created_at, updated_at FROM collab_old;
124				`
125				if _, err := tx.ExecContext(ctx, sqlm); err != nil {
126					return err
127				}
128			}
129
130			if _, err := tx.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil {
131				return err
132			}
133		}
134
135		// Insert default user
136		insertUser := tx.Rebind(insert + "INTO users (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)")
137		if _, err := tx.ExecContext(ctx, insertUser, "admin", true); err != nil {
138			return err
139		}
140
141		for _, k := range cfg.AdminKeys() {
142			query := insert + "INTO public_keys (user_id, public_key, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)"
143			if tx.DriverName() == "postgres" {
144				query += " ON CONFLICT DO NOTHING"
145			}
146
147			query = tx.Rebind(query)
148			ak := sshutils.MarshalAuthorizedKey(k)
149			if _, err := tx.ExecContext(ctx, query, 1, ak); err != nil {
150				if errors.Is(db.WrapError(err), db.ErrDuplicateKey) {
151					continue
152				}
153				return err
154			}
155		}
156
157		// Insert default settings
158		insertSettings := insert + "INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)"
159		insertSettings = tx.Rebind(insertSettings)
160		settings := []struct {
161			Key   string
162			Value string
163		}{
164			{"allow_keyless", "true"},
165			{"anon_access", access.ReadOnlyAccess.String()},
166			{"init", "true"},
167		}
168
169		for _, s := range settings {
170			if _, err := tx.ExecContext(ctx, insertSettings, s.Key, s.Value); err != nil {
171				return fmt.Errorf("inserting default settings %q: %w", s.Key, err)
172			}
173		}
174
175		return nil
176	},
177	Rollback: func(ctx context.Context, tx *db.Tx) error {
178		return migrateDown(ctx, tx, createTablesVersion, createTablesName)
179	},
180}