1package db
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8
9 "github.com/charmbracelet/log/v2"
10 "github.com/charmbracelet/soft-serve/pkg/config"
11 "github.com/jmoiron/sqlx"
12 _ "github.com/lib/pq" // postgres driver
13 _ "modernc.org/sqlite" // sqlite driver
14)
15
// DB is a Soft Serve database handle. It is a concrete struct (not an
// interface) that embeds *sqlx.DB, so every sqlx.DB method is available
// directly on a *DB.
type DB struct {
	*sqlx.DB
	// logger is attached by Open only when config.IsVerbose() is true and is
	// nil otherwise; it is copied into every Tx started by TransactionContext.
	// NOTE(review): presumably used for query tracing elsewhere in the
	// package — confirm.
	logger *log.Logger
}
21
22// Open opens a database connection.
23func Open(ctx context.Context, driverName string, dsn string) (*DB, error) {
24 db, err := sqlx.ConnectContext(ctx, driverName, dsn)
25 if err != nil {
26 return nil, err //nolint:wrapcheck
27 }
28
29 d := &DB{
30 DB: db,
31 }
32
33 if config.IsVerbose() {
34 logger := log.FromContext(ctx).WithPrefix("db")
35 d.logger = logger
36 }
37
38 return d, nil
39}
40
41// Close implements db.DB.
42func (d *DB) Close() error {
43 return d.DB.Close() //nolint:wrapcheck
44}
45
// Tx is a database transaction: an embedded *sqlx.Tx paired with the logger
// of the DB that began it (nil unless verbose logging is enabled).
//
// Field order matters: TransactionContext constructs Tx with a positional
// composite literal.
type Tx struct {
	*sqlx.Tx
	logger *log.Logger
}
51
52// Transaction implements db.DB.
53func (d *DB) Transaction(fn func(tx *Tx) error) error {
54 return d.TransactionContext(context.Background(), fn)
55}
56
57// TransactionContext implements db.DB.
58func (d *DB) TransactionContext(ctx context.Context, fn func(tx *Tx) error) error {
59 txx, err := d.BeginTxx(ctx, nil)
60 if err != nil {
61 return fmt.Errorf("failed to begin transaction: %w", err)
62 }
63
64 tx := &Tx{txx, d.logger}
65 if err := fn(tx); err != nil {
66 return rollback(tx, err)
67 }
68
69 if err := tx.Commit(); err != nil {
70 if errors.Is(err, sql.ErrTxDone) {
71 // this is ok because whoever did finish the tx should have also written the error already.
72 return nil
73 }
74 return fmt.Errorf("failed to commit transaction: %w", err)
75 }
76
77 return nil
78}
79
80func rollback(tx *Tx, err error) error {
81 if rerr := tx.Rollback(); rerr != nil {
82 if errors.Is(rerr, sql.ErrTxDone) {
83 return err
84 }
85 return fmt.Errorf("failed to rollback: %s: %w", err.Error(), rerr)
86 }
87
88 return err
89}