db.go

  1package sqlite
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"errors"
  7	"fmt"
  8
  9	"github.com/charmbracelet/soft-serve/server/backend"
 10	"github.com/charmbracelet/soft-serve/server/sshutils"
 11	"github.com/jmoiron/sqlx"
 12	"modernc.org/sqlite"
 13	sqlite3 "modernc.org/sqlite/lib"
 14)
 15
 16// Close closes the database.
 17func (d *SqliteBackend) Close() error {
 18	return d.db.Close()
 19}
 20
 21// init creates the database.
 22func (d *SqliteBackend) init() error {
 23	return wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
 24		if _, err := tx.Exec(sqlCreateSettingsTable); err != nil {
 25			return err
 26		}
 27		if _, err := tx.Exec(sqlCreateUserTable); err != nil {
 28			return err
 29		}
 30		if _, err := tx.Exec(sqlCreatePublicKeyTable); err != nil {
 31			return err
 32		}
 33		if _, err := tx.Exec(sqlCreateRepoTable); err != nil {
 34			return err
 35		}
 36		if _, err := tx.Exec(sqlCreateCollabTable); err != nil {
 37			return err
 38		}
 39
 40		// Set default settings.
 41		if _, err := tx.Exec("INSERT OR IGNORE INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", "allow_keyless", true); err != nil {
 42			return err
 43		}
 44		if _, err := tx.Exec("INSERT OR IGNORE INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", "anon_access", backend.ReadOnlyAccess.String()); err != nil {
 45			return err
 46		}
 47
 48		var init bool
 49		if err := tx.Get(&init, "SELECT value FROM settings WHERE key = 'init'"); err != nil && !errors.Is(err, sql.ErrNoRows) {
 50			return err
 51		}
 52
 53		// Create default user.
 54		if !init {
 55			r, err := tx.Exec("INSERT OR IGNORE INTO user (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP);", "admin", true)
 56			if err != nil {
 57				return err
 58			}
 59			userID, err := r.LastInsertId()
 60			if err != nil {
 61				return err
 62			}
 63
 64			// Add initial keys
 65			// Don't use cfg.AdminKeys since it also includes the internal key
 66			// used for internal api access.
 67			for _, k := range d.cfg.InitialAdminKeys {
 68				pk, _, err := sshutils.ParseAuthorizedKey(k)
 69				if err != nil {
 70					d.logger.Error("error parsing initial admin key, skipping", "key", k, "err", err)
 71					continue
 72				}
 73
 74				stmt, err := tx.Prepare(`INSERT INTO public_key (user_id, public_key, updated_at)
 75					VALUES (?, ?, CURRENT_TIMESTAMP);`)
 76				if err != nil {
 77					return err
 78				}
 79
 80				defer stmt.Close() // nolint: errcheck
 81				if _, err := stmt.Exec(userID, sshutils.MarshalAuthorizedKey(pk)); err != nil {
 82					return err
 83				}
 84			}
 85		}
 86
 87		// set init flag
 88		if _, err := tx.Exec("INSERT OR IGNORE INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", "init", true); err != nil {
 89			return err
 90		}
 91
 92		return nil
 93	})
 94}
 95
 96func wrapDbErr(err error) error {
 97	if err != nil {
 98		if errors.Is(err, sql.ErrNoRows) {
 99			return ErrNoRecord
100		}
101		if liteErr, ok := err.(*sqlite.Error); ok {
102			code := liteErr.Code()
103			if code == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY ||
104				code == sqlite3.SQLITE_CONSTRAINT_UNIQUE {
105				return ErrDuplicateKey
106			}
107		}
108	}
109	return err
110}
111
112func wrapTx(db *sqlx.DB, ctx context.Context, fn func(tx *sqlx.Tx) error) error {
113	tx, err := db.BeginTxx(ctx, nil)
114	if err != nil {
115		return fmt.Errorf("failed to begin transaction: %w", err)
116	}
117
118	if err := fn(tx); err != nil {
119		return rollback(tx, err)
120	}
121
122	if err := tx.Commit(); err != nil {
123		if errors.Is(err, sql.ErrTxDone) {
124			// this is ok because whoever did finish the tx should have also written the error already.
125			return nil
126		}
127		return fmt.Errorf("failed to commit transaction: %w", err)
128	}
129
130	return nil
131}
132
133func rollback(tx *sqlx.Tx, err error) error {
134	if rerr := tx.Rollback(); rerr != nil {
135		if errors.Is(rerr, sql.ErrTxDone) {
136			return err
137		}
138		return fmt.Errorf("failed to rollback: %s: %w", err.Error(), rerr)
139	}
140
141	return err
142}