diff --git a/cmd/root.go b/cmd/root.go index 7832da982616727b5e13bc272ebbfccc9ec704cc..2b5f79cf0337c386196d783ad9d18e2e1380aa5b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -77,16 +77,16 @@ to assist developers in writing, debugging, and understanding code directly from return err } + // Create main context for the application + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Connect DB, this will also run migrations - conn, err := db.Connect() + conn, err := db.Connect(ctx) if err != nil { return err } - // Create main context for the application - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - app, err := app.New(ctx, conn) if err != nil { logging.Error("Failed to create app: %v", err) diff --git a/internal/db/connect.go b/internal/db/connect.go index ed48ddcba8fea094c815b009dcaa5ce1cc354d0c..9212ce1f097e6877a9ce9b368e77d76e739b673f 100644 --- a/internal/db/connect.go +++ b/internal/db/connect.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "fmt" "os" @@ -15,7 +16,7 @@ import ( "github.com/pressly/goose/v3" ) -func Connect() (*sql.DB, error) { +func Connect(ctx context.Context) (*sql.DB, error) { dataDir := config.Get().Data.Directory if dataDir == "" { return nil, fmt.Errorf("data.dir is not set") @@ -31,7 +32,7 @@ func Connect() (*sql.DB, error) { } // Verify connection - if err = db.Ping(); err != nil { + if err = db.PingContext(ctx); err != nil { db.Close() return nil, fmt.Errorf("failed to connect to database: %w", err) } @@ -46,7 +47,7 @@ func Connect() (*sql.DB, error) { } for _, pragma := range pragmas { - if _, err = db.Exec(pragma); err != nil { + if _, err = db.ExecContext(ctx, pragma); err != nil { logging.Error("Failed to set pragma", pragma, err) } else { logging.Debug("Set pragma", "pragma", pragma) diff --git a/internal/history/file.go b/internal/history/file.go index d8fe6088626be28262f06485c07c95693ddfd219..7317f012fd83b31990bbc701261eed91794a52a5 100644 --- a/internal/history/file.go +++ b/internal/history/file.go @@ -83,7 +83,7 @@ func (s *service) createWithVersion(ctx context.Context, sessionID, path, conten // Retry loop for transaction conflicts for attempt := range maxRetries { // Start a transaction - tx, txErr := s.db.Begin() + tx, txErr := s.db.BeginTx(ctx, nil) if txErr != nil { return File{}, fmt.Errorf("failed to begin transaction: %w", txErr) }