0001_create_tables.go

  1package migrate
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"strconv"
  8
  9	"github.com/charmbracelet/soft-serve/pkg/access"
 10	"github.com/charmbracelet/soft-serve/pkg/config"
 11	"github.com/charmbracelet/soft-serve/pkg/db"
 12	"github.com/charmbracelet/soft-serve/pkg/sshutils"
 13)
 14
 15const (
 16	createTablesName    = "create tables"
 17	createTablesVersion = 1
 18	
 19	// Database driver names
 20	driverSQLite3 = "sqlite3"
 21	driverSQLite = "sqlite"
 22	driverPostgres = "postgres"
 23)
 24
 25var createTables = Migration{
 26	Version: createTablesVersion,
 27	Name:    createTablesName,
 28	Migrate: func(ctx context.Context, tx *db.Tx) error {
 29		cfg := config.FromContext(ctx)
 30
 31		insert := "INSERT "
 32
 33		// Alter old tables (if exist)
 34		// This is to support prior versions of Soft Serve v0.6
 35		switch tx.DriverName() {
 36		case driverSQLite3, driverSQLite:
 37			insert += "OR IGNORE "
 38
 39			hasUserTable := hasTable(tx, "user")
 40			if hasUserTable {
 41				if _, err := tx.ExecContext(ctx, "ALTER TABLE user RENAME TO user_old"); err != nil {
 42					return err
 43				}
 44			}
 45
 46			if hasTable(tx, "public_key") {
 47				if _, err := tx.ExecContext(ctx, "ALTER TABLE public_key RENAME TO public_key_old"); err != nil {
 48					return err
 49				}
 50			}
 51
 52			if hasTable(tx, "collab") {
 53				if _, err := tx.ExecContext(ctx, "ALTER TABLE collab RENAME TO collab_old"); err != nil {
 54					return err
 55				}
 56			}
 57
 58			if hasTable(tx, "repo") {
 59				if _, err := tx.ExecContext(ctx, "ALTER TABLE repo RENAME TO repo_old"); err != nil {
 60					return err
 61				}
 62			}
 63		}
 64
 65		if err := migrateUp(ctx, tx, createTablesVersion, createTablesName); err != nil {
 66			return err
 67		}
 68
 69		switch tx.DriverName() {
 70		case driverSQLite3, driverSQLite:
 71
 72			if _, err := tx.ExecContext(ctx, "PRAGMA foreign_keys = OFF"); err != nil {
 73				return err
 74			}
 75
 76			if hasTable(tx, "user_old") {
 77				sqlm := `
 78				INSERT INTO users (id, username, admin, updated_at)
 79					SELECT id, username, admin, updated_at FROM user_old;
 80				`
 81				if _, err := tx.ExecContext(ctx, sqlm); err != nil {
 82					return err
 83				}
 84			}
 85
 86			if hasTable(tx, "public_key_old") {
 87				// Check duplicate keys
 88				pks := []struct {
 89					ID        string `db:"id"`
 90					PublicKey string `db:"public_key"`
 91				}{}
 92				if err := tx.SelectContext(ctx, &pks, "SELECT id, public_key FROM public_key_old"); err != nil {
 93					return err
 94				}
 95
 96				pkss := map[string]struct{}{}
 97				for _, pk := range pks {
 98					if _, ok := pkss[pk.PublicKey]; ok {
 99						return fmt.Errorf("duplicate public key: %q, please remove the duplicate key and try again", pk.PublicKey)
100					}
101					pkss[pk.PublicKey] = struct{}{}
102				}
103
104				sqlm := `
105				INSERT INTO public_keys (id, user_id, public_key, created_at, updated_at)
106					SELECT id, user_id, public_key, created_at, updated_at FROM public_key_old;
107				`
108				if _, err := tx.ExecContext(ctx, sqlm); err != nil {
109					return err
110				}
111			}
112
113			if hasTable(tx, "repo_old") {
114				sqlm := `
115				INSERT INTO repos (id, name, project_name, description, private,mirror, hidden, created_at, updated_at, user_id)
116					SELECT id, name, project_name, description, private, mirror, hidden, created_at, updated_at, (
117						SELECT id FROM users WHERE admin = true ORDER BY id LIMIT 1
118				) FROM repo_old;
119				`
120				if _, err := tx.ExecContext(ctx, sqlm); err != nil {
121					return err
122				}
123			}
124
125			if hasTable(tx, "collab_old") {
126				sqlm := `
127				INSERT INTO collabs (id, user_id, repo_id, access_level, created_at, updated_at)
128					SELECT id, user_id, repo_id, ` + strconv.Itoa(int(access.ReadWriteAccess)) + `, created_at, updated_at FROM collab_old;
129				`
130				if _, err := tx.ExecContext(ctx, sqlm); err != nil {
131					return err
132				}
133			}
134
135			if _, err := tx.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil {
136				return err
137			}
138		}
139
140		// Insert default user
141		insertUser := tx.Rebind(insert + "INTO users (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)")
142		if _, err := tx.ExecContext(ctx, insertUser, "admin", true); err != nil {
143			return err
144		}
145
146		for _, k := range cfg.AdminKeys() {
147			query := insert + "INTO public_keys (user_id, public_key, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)"
148			if tx.DriverName() == driverPostgres {
149				query += " ON CONFLICT DO NOTHING"
150			}
151
152			query = tx.Rebind(query)
153			ak := sshutils.MarshalAuthorizedKey(k)
154			if _, err := tx.ExecContext(ctx, query, 1, ak); err != nil {
155				if errors.Is(db.WrapError(err), db.ErrDuplicateKey) {
156					continue
157				}
158				return err
159			}
160		}
161
162		// Insert default settings
163		insertSettings := insert + "INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)"
164		insertSettings = tx.Rebind(insertSettings)
165		settings := []struct {
166			Key   string
167			Value string
168		}{
169			{"allow_keyless", "true"},
170			{"anon_access", access.ReadOnlyAccess.String()},
171			{"init", "true"},
172		}
173
174		for _, s := range settings {
175			if _, err := tx.ExecContext(ctx, insertSettings, s.Key, s.Value); err != nil {
176				return fmt.Errorf("inserting default settings %q: %w", s.Key, err)
177			}
178		}
179
180		return nil
181	},
182	Rollback: func(ctx context.Context, tx *db.Tx) error {
183		return migrateDown(ctx, tx, createTablesVersion, createTablesName)
184	},
185}