1package sqlite
 2
 3import (
 4	"context"
 5	"database/sql"
 6	"errors"
 7	"fmt"
 8
 9	"github.com/charmbracelet/soft-serve/server/backend"
10	"github.com/jmoiron/sqlx"
11	"modernc.org/sqlite"
12	sqlite3 "modernc.org/sqlite/lib"
13)
14
15// Close closes the database.
16func (d *SqliteBackend) Close() error {
17	return d.db.Close()
18}
19
20// init creates the database.
21func (d *SqliteBackend) init() error {
22	return wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
23		if _, err := tx.Exec(sqlCreateSettingsTable); err != nil {
24			return err
25		}
26		if _, err := tx.Exec(sqlCreateUserTable); err != nil {
27			return err
28		}
29		if _, err := tx.Exec(sqlCreatePublicKeyTable); err != nil {
30			return err
31		}
32		if _, err := tx.Exec(sqlCreateRepoTable); err != nil {
33			return err
34		}
35		if _, err := tx.Exec(sqlCreateCollabTable); err != nil {
36			return err
37		}
38
39		// Set default settings.
40		if _, err := tx.Exec("INSERT OR IGNORE INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", "allow_keyless", true); err != nil {
41			return err
42		}
43		if _, err := tx.Exec("INSERT OR IGNORE INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", "anon_access", backend.ReadOnlyAccess.String()); err != nil {
44			return err
45		}
46
47		return nil
48	})
49}
50
51func wrapDbErr(err error) error {
52	if err != nil {
53		if errors.Is(err, sql.ErrNoRows) {
54			return ErrNoRecord
55		}
56		if liteErr, ok := err.(*sqlite.Error); ok {
57			code := liteErr.Code()
58			if code == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY ||
59				code == sqlite3.SQLITE_CONSTRAINT_UNIQUE {
60				return ErrDuplicateKey
61			}
62		}
63	}
64	return err
65}
66
67func wrapTx(db *sqlx.DB, ctx context.Context, fn func(tx *sqlx.Tx) error) error {
68	tx, err := db.BeginTxx(ctx, nil)
69	if err != nil {
70		return fmt.Errorf("failed to begin transaction: %w", err)
71	}
72
73	if err := fn(tx); err != nil {
74		return rollback(tx, err)
75	}
76
77	if err := tx.Commit(); err != nil {
78		if errors.Is(err, sql.ErrTxDone) {
79			// this is ok because whoever did finish the tx should have also written the error already.
80			return nil
81		}
82		return fmt.Errorf("failed to commit transaction: %w", err)
83	}
84
85	return nil
86}
87
88func rollback(tx *sqlx.Tx, err error) error {
89	if rerr := tx.Rollback(); rerr != nil {
90		if errors.Is(rerr, sql.ErrTxDone) {
91			return err
92		}
93		return fmt.Errorf("failed to rollback: %s: %w", err.Error(), rerr)
94	}
95
96	return err
97}