0001_create_tables.go

  1// Package migrate provides database migration functionality.
  2package migrate
  3
  4import (
  5	"context"
  6	"errors"
  7	"fmt"
  8	"strconv"
  9
 10	"github.com/charmbracelet/soft-serve/pkg/access"
 11	"github.com/charmbracelet/soft-serve/pkg/config"
 12	"github.com/charmbracelet/soft-serve/pkg/db"
 13	"github.com/charmbracelet/soft-serve/pkg/sshutils"
 14)
 15
 // createTablesVersion is the Version of the createTables migration; it is
// also the version passed to migrateUp and migrateDown for this migration.
const createTablesVersion = 1

// createTablesName is the Name of the createTables migration, passed to
// migrateUp and migrateDown alongside createTablesVersion.
const createTablesName = "create tables"
 20
 21var createTables = Migration{
 22	Version: createTablesVersion,
 23	Name:    createTablesName,
 24	Migrate: func(ctx context.Context, tx *db.Tx) error {
 25		cfg := config.FromContext(ctx)
 26
 27		insert := "INSERT "
 28
 29		// Alter old tables (if exist)
 30		// This is to support prior versions of Soft Serve v0.6
 31		switch tx.DriverName() {
 32		case "sqlite3", "sqlite":
 33			insert += "OR IGNORE "
 34
 35			hasUserTable := hasTable(tx, "user")
 36			if hasUserTable {
 37				if _, err := tx.ExecContext(ctx, "ALTER TABLE user RENAME TO user_old"); err != nil {
 38					return err //nolint:wrapcheck
 39				}
 40			}
 41
 42			if hasTable(tx, "public_key") {
 43				if _, err := tx.ExecContext(ctx, "ALTER TABLE public_key RENAME TO public_key_old"); err != nil {
 44					return err //nolint:wrapcheck
 45				}
 46			}
 47
 48			if hasTable(tx, "collab") {
 49				if _, err := tx.ExecContext(ctx, "ALTER TABLE collab RENAME TO collab_old"); err != nil {
 50					return err //nolint:wrapcheck
 51				}
 52			}
 53
 54			if hasTable(tx, "repo") {
 55				if _, err := tx.ExecContext(ctx, "ALTER TABLE repo RENAME TO repo_old"); err != nil {
 56					return err //nolint:wrapcheck
 57				}
 58			}
 59		}
 60
 61		if err := migrateUp(ctx, tx, createTablesVersion, createTablesName); err != nil {
 62			return err
 63		}
 64
 65		switch tx.DriverName() {
 66		case "sqlite3", "sqlite":
 67
 68			if _, err := tx.ExecContext(ctx, "PRAGMA foreign_keys = OFF"); err != nil {
 69				return err //nolint:wrapcheck
 70			}
 71
 72			if hasTable(tx, "user_old") {
 73				sqlm := `
 74				INSERT INTO users (id, username, admin, updated_at)
 75					SELECT id, username, admin, updated_at FROM user_old;
 76				`
 77				if _, err := tx.ExecContext(ctx, sqlm); err != nil {
 78					return err //nolint:wrapcheck
 79				}
 80			}
 81
 82			if hasTable(tx, "public_key_old") {
 83				// Check duplicate keys
 84				pks := []struct {
 85					ID        string `db:"id"`
 86					PublicKey string `db:"public_key"`
 87				}{}
 88				if err := tx.SelectContext(ctx, &pks, "SELECT id, public_key FROM public_key_old"); err != nil {
 89					return err //nolint:wrapcheck
 90				}
 91
 92				pkss := map[string]struct{}{}
 93				for _, pk := range pks {
 94					if _, ok := pkss[pk.PublicKey]; ok {
 95						return fmt.Errorf("duplicate public key: %q, please remove the duplicate key and try again", pk.PublicKey)
 96					}
 97					pkss[pk.PublicKey] = struct{}{}
 98				}
 99
100				sqlm := `
101				INSERT INTO public_keys (id, user_id, public_key, created_at, updated_at)
102					SELECT id, user_id, public_key, created_at, updated_at FROM public_key_old;
103				`
104				if _, err := tx.ExecContext(ctx, sqlm); err != nil {
105					return err //nolint:wrapcheck
106				}
107			}
108
109			if hasTable(tx, "repo_old") {
110				sqlm := `
111				INSERT INTO repos (id, name, project_name, description, private,mirror, hidden, created_at, updated_at, user_id)
112					SELECT id, name, project_name, description, private, mirror, hidden, created_at, updated_at, (
113						SELECT id FROM users WHERE admin = true ORDER BY id LIMIT 1
114				) FROM repo_old;
115				`
116				if _, err := tx.ExecContext(ctx, sqlm); err != nil {
117					return err //nolint:wrapcheck
118				}
119			}
120
121			if hasTable(tx, "collab_old") {
122				sqlm := `
123				INSERT INTO collabs (id, user_id, repo_id, access_level, created_at, updated_at)
124					SELECT id, user_id, repo_id, ` + strconv.Itoa(int(access.ReadWriteAccess)) + `, created_at, updated_at FROM collab_old;
125				`
126				if _, err := tx.ExecContext(ctx, sqlm); err != nil {
127					return err //nolint:wrapcheck
128				}
129			}
130
131			if _, err := tx.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil {
132				return err //nolint:wrapcheck
133			}
134		}
135
136		// Insert default user
137		insertUser := tx.Rebind(insert + "INTO users (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)")
138		if _, err := tx.ExecContext(ctx, insertUser, "admin", true); err != nil {
139			return err //nolint:wrapcheck
140		}
141
142		for _, k := range cfg.AdminKeys() {
143			query := insert + "INTO public_keys (user_id, public_key, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)"
144			if tx.DriverName() == "postgres" {
145				query += " ON CONFLICT DO NOTHING"
146			}
147
148			query = tx.Rebind(query)
149			ak := sshutils.MarshalAuthorizedKey(k)
150			if _, err := tx.ExecContext(ctx, query, 1, ak); err != nil {
151				if errors.Is(db.WrapError(err), db.ErrDuplicateKey) {
152					continue
153				}
154				return err //nolint:wrapcheck
155			}
156		}
157
158		// Insert default settings
159		insertSettings := insert + "INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)"
160		insertSettings = tx.Rebind(insertSettings)
161		settings := []struct {
162			Key   string
163			Value string
164		}{
165			{"allow_keyless", "true"},
166			{"anon_access", access.ReadOnlyAccess.String()},
167			{"init", "true"},
168		}
169
170		for _, s := range settings {
171			if _, err := tx.ExecContext(ctx, insertSettings, s.Key, s.Value); err != nil {
172				return fmt.Errorf("inserting default settings %q: %w", s.Key, err)
173			}
174		}
175
176		return nil
177	},
178	Rollback: func(ctx context.Context, tx *db.Tx) error {
179		return migrateDown(ctx, tx, createTablesVersion, createTablesName)
180	},
181}