1package db
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8
9 "github.com/charmbracelet/log"
10 "github.com/charmbracelet/soft-serve/server/config"
11 "github.com/jmoiron/sqlx"
12 _ "modernc.org/sqlite" // sqlite driver
13)
14
15// DB is the interface for a Soft Serve database.
16type DB struct {
17 *sqlx.DB
18 logger *log.Logger
19}
20
21// Open opens a database connection.
22func Open(ctx context.Context, driverName string, dsn string) (*DB, error) {
23 db, err := sqlx.ConnectContext(ctx, driverName, dsn)
24 if err != nil {
25 return nil, err
26 }
27
28 d := &DB{
29 DB: db,
30 }
31
32 if config.IsVerbose() {
33 logger := log.FromContext(ctx).WithPrefix("db")
34 d.logger = logger
35 }
36
37 return d, nil
38}
39
40// Close implements db.DB.
41func (d *DB) Close() error {
42 return d.DB.Close()
43}
44
45// Tx is a database transaction.
46type Tx struct {
47 *sqlx.Tx
48 logger *log.Logger
49}
50
51// Transaction implements db.DB.
52func (d *DB) Transaction(fn func(tx *Tx) error) error {
53 return d.TransactionContext(context.Background(), fn)
54}
55
56// TransactionContext implements db.DB.
57func (d *DB) TransactionContext(ctx context.Context, fn func(tx *Tx) error) error {
58 txx, err := d.DB.BeginTxx(ctx, nil)
59 if err != nil {
60 return fmt.Errorf("failed to begin transaction: %w", err)
61 }
62
63 tx := &Tx{txx, d.logger}
64 if err := fn(tx); err != nil {
65 return rollback(tx, err)
66 }
67
68 if err := tx.Commit(); err != nil {
69 if errors.Is(err, sql.ErrTxDone) {
70 // this is ok because whoever did finish the tx should have also written the error already.
71 return nil
72 }
73 return fmt.Errorf("failed to commit transaction: %w", err)
74 }
75
76 return nil
77}
78
79func rollback(tx *Tx, err error) error {
80 if rerr := tx.Rollback(); rerr != nil {
81 if errors.Is(rerr, sql.ErrTxDone) {
82 return err
83 }
84 return fmt.Errorf("failed to rollback: %s: %w", err.Error(), rerr)
85 }
86
87 return err
88}