db.go

 1package db
 2
 3import (
 4	"context"
 5	"database/sql"
 6	"errors"
 7	"fmt"
 8
 9	"github.com/charmbracelet/log/v2"
10	"github.com/charmbracelet/soft-serve/pkg/config"
11	"github.com/jmoiron/sqlx"
12	_ "github.com/lib/pq"  // postgres driver
13	_ "modernc.org/sqlite" // sqlite driver
14)
15
// DB is a Soft Serve database handle. It embeds *sqlx.DB, so the sqlx (and
// database/sql) methods are available directly on a *DB.
type DB struct {
	*sqlx.DB
	// logger is set by Open only when config.IsVerbose() reports true;
	// otherwise it is left nil.
	logger *log.Logger
}
21
22// Open opens a database connection.
23func Open(ctx context.Context, driverName string, dsn string) (*DB, error) {
24	db, err := sqlx.ConnectContext(ctx, driverName, dsn)
25	if err != nil {
26		return nil, err //nolint:wrapcheck
27	}
28
29	d := &DB{
30		DB: db,
31	}
32
33	if config.IsVerbose() {
34		logger := log.FromContext(ctx).WithPrefix("db")
35		d.logger = logger
36	}
37
38	return d, nil
39}
40
41// Close implements db.DB.
42func (d *DB) Close() error {
43	return d.DB.Close() //nolint:wrapcheck
44}
45
// Tx is a database transaction. It embeds *sqlx.Tx and carries a logger;
// transactions started through DB.TransactionContext copy the parent DB's
// logger, which may be nil.
type Tx struct {
	*sqlx.Tx
	logger *log.Logger
}
51
52// Transaction implements db.DB.
53func (d *DB) Transaction(fn func(tx *Tx) error) error {
54	return d.TransactionContext(context.Background(), fn)
55}
56
57// TransactionContext implements db.DB.
58func (d *DB) TransactionContext(ctx context.Context, fn func(tx *Tx) error) error {
59	txx, err := d.BeginTxx(ctx, nil)
60	if err != nil {
61		return fmt.Errorf("failed to begin transaction: %w", err)
62	}
63
64	tx := &Tx{txx, d.logger}
65	if err := fn(tx); err != nil {
66		return rollback(tx, err)
67	}
68
69	if err := tx.Commit(); err != nil {
70		if errors.Is(err, sql.ErrTxDone) {
71			// this is ok because whoever did finish the tx should have also written the error already.
72			return nil
73		}
74		return fmt.Errorf("failed to commit transaction: %w", err)
75	}
76
77	return nil
78}
79
80func rollback(tx *Tx, err error) error {
81	if rerr := tx.Rollback(); rerr != nil {
82		if errors.Is(rerr, sql.ErrTxDone) {
83			return err
84		}
85		return fmt.Errorf("failed to rollback: %s: %w", err.Error(), rerr)
86	}
87
88	return err
89}