1package db
 2
 3import (
 4	"context"
 5	"database/sql"
 6	"errors"
 7	"fmt"
 8
 9	"github.com/charmbracelet/log"
10	"github.com/charmbracelet/soft-serve/server/config"
11	"github.com/jmoiron/sqlx"
12	_ "modernc.org/sqlite" // sqlite driver
13)
14
// DB is a Soft Serve database handle. It embeds *sqlx.DB, so the sqlx and
// database/sql methods are available directly on it, and it carries an
// optional logger.
type DB struct {
	*sqlx.DB
	// logger is only set by Open when config.IsVerbose() reports true;
	// otherwise it is left nil.
	logger *log.Logger
}
20
21// Open opens a database connection.
22func Open(ctx context.Context, driverName string, dsn string) (*DB, error) {
23	db, err := sqlx.ConnectContext(ctx, driverName, dsn)
24	if err != nil {
25		return nil, err
26	}
27
28	d := &DB{
29		DB: db,
30	}
31
32	if config.IsVerbose() {
33		logger := log.FromContext(ctx).WithPrefix("db")
34		d.logger = logger
35	}
36
37	return d, nil
38}
39
40// Close implements db.DB.
41func (d *DB) Close() error {
42	return d.DB.Close()
43}
44
// Tx is a database transaction. It embeds *sqlx.Tx, so the sqlx and
// database/sql transaction methods are available directly on it, and it
// carries the logger of the DB that started it.
type Tx struct {
	*sqlx.Tx
	// logger is copied from the owning DB when the transaction is started
	// (see DB.TransactionContext) and may therefore be nil.
	logger *log.Logger
}
50
51// Transaction implements db.DB.
52func (d *DB) Transaction(fn func(tx *Tx) error) error {
53	return d.TransactionContext(context.Background(), fn)
54}
55
56// TransactionContext implements db.DB.
57func (d *DB) TransactionContext(ctx context.Context, fn func(tx *Tx) error) error {
58	txx, err := d.DB.BeginTxx(ctx, nil)
59	if err != nil {
60		return fmt.Errorf("failed to begin transaction: %w", err)
61	}
62
63	tx := &Tx{txx, d.logger}
64	if err := fn(tx); err != nil {
65		return rollback(tx, err)
66	}
67
68	if err := tx.Commit(); err != nil {
69		if errors.Is(err, sql.ErrTxDone) {
70			// this is ok because whoever did finish the tx should have also written the error already.
71			return nil
72		}
73		return fmt.Errorf("failed to commit transaction: %w", err)
74	}
75
76	return nil
77}
78
79func rollback(tx *Tx, err error) error {
80	if rerr := tx.Rollback(); rerr != nil {
81		if errors.Is(rerr, sql.ErrTxDone) {
82			return err
83		}
84		return fmt.Errorf("failed to rollback: %s: %w", err.Error(), rerr)
85	}
86
87	return err
88}