feat: use context

Ayman Bagabas created

Change summary

cmd/soft/migrate_config.go      |  5 ++-
cmd/soft/root.go                | 14 +++++++++
cmd/soft/serve.go               |  7 ++--
examples/setuid/main.go         |  5 ++-
server/backend/sqlite/sqlite.go | 48 ++++++++++++++++++----------------
server/config/config.go         |  4 --
server/cron/cron.go             |  4 +-
server/daemon_test.go           |  4 ++
server/server.go                | 24 +++++++++--------
server/server_test.go           |  5 ++-
server/session_test.go          |  4 ++
11 files changed, 72 insertions(+), 52 deletions(-)

Detailed changes

cmd/soft/migrate_config.go 🔗

@@ -23,12 +23,13 @@ var (
 	migrateConfig = &cobra.Command{
 		Use:   "migrate-config",
 		Short: "Migrate config to new format",
-		RunE: func(_ *cobra.Command, _ []string) error {
+		RunE: func(cmd *cobra.Command, _ []string) error {
 			keyPath := os.Getenv("SOFT_SERVE_KEY_PATH")
 			reposPath := os.Getenv("SOFT_SERVE_REPO_PATH")
 			bindAddr := os.Getenv("SOFT_SERVE_BIND_ADDRESS")
+			ctx := cmd.Context()
 			cfg := config.DefaultConfig()
-			sb, err := sqlite.NewSqliteBackend(cfg)
+			sb, err := sqlite.NewSqliteBackend(ctx, cfg)
 			if err != nil {
 				return fmt.Errorf("failed to create sqlite backend: %w", err)
 			}

cmd/soft/root.go 🔗

@@ -1,9 +1,11 @@
 package main
 
 import (
+	"context"
 	"os"
 	"runtime/debug"
 
+	"github.com/charmbracelet/log"
 	_ "github.com/charmbracelet/soft-serve/log"
 	"github.com/spf13/cobra"
 )
@@ -49,7 +51,17 @@ func init() {
 }
 
 func main() {
-	if err := rootCmd.Execute(); err != nil {
+	logger := log.NewWithOptions(os.Stderr, log.Options{
+		ReportTimestamp: true,
+		TimeFormat:      "2006-01-02",
+	})
+	if os.Getenv("SOFT_SERVE_DEBUG") == "true" {
+		logger.SetLevel(log.DebugLevel)
+	}
+
+	ctx := context.Background()
+	ctx = log.WithContext(ctx, logger)
+	if err := rootCmd.ExecuteContext(ctx); err != nil {
 		os.Exit(1)
 	}
 }

cmd/soft/serve.go 🔗

@@ -7,6 +7,7 @@ import (
 	"syscall"
 	"time"
 
+	_ "github.com/charmbracelet/soft-serve/log"
 	"github.com/charmbracelet/soft-serve/server"
 	"github.com/charmbracelet/soft-serve/server/config"
 	"github.com/spf13/cobra"
@@ -19,19 +20,19 @@ var (
 		Long:  "Start the server",
 		Args:  cobra.NoArgs,
 		RunE: func(cmd *cobra.Command, args []string) error {
+			ctx := cmd.Context()
 			cfg := config.DefaultConfig()
-			s, err := server.NewServer(cfg)
+			s, err := server.NewServer(ctx, cfg)
 			if err != nil {
 				return err
 			}
 
-			ctx := cmd.Context()
 			done := make(chan os.Signal, 1)
 			lch := make(chan error, 1)
 			go func() {
 				defer close(lch)
 				defer close(done)
-				lch <- s.Start(ctx)
+				lch <- s.Start()
 			}()
 
 			signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)

examples/setuid/main.go 🔗

@@ -44,9 +44,10 @@ func main() {
 	if err := syscall.Setuid(*uid); err != nil {
 		log.Fatal("Setuid error", "err", err)
 	}
+	ctx := context.Background()
 	cfg := config.DefaultConfig()
 	cfg.SSH.ListenAddr = fmt.Sprintf(":%d", *port)
-	s, err := server.NewServer(cfg)
+	s, err := server.NewServer(ctx, cfg)
 	if err != nil {
 		log.Fatal(err)
 	}
@@ -64,7 +65,7 @@ func main() {
 	<-done
 
 	log.Print("Stopping SSH server", "addr", cfg.SSH.ListenAddr)
-	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
 	defer func() { cancel() }()
 	if err := s.Shutdown(ctx); err != nil {
 		log.Fatal(err)

server/backend/sqlite/sqlite.go 🔗

@@ -26,6 +26,7 @@ var (
 // backend.
 type SqliteBackend struct {
 	cfg *config.Config
+	ctx context.Context
 	dp  string
 	db  *sqlx.DB
 }
@@ -37,7 +38,7 @@ func (d *SqliteBackend) reposPath() string {
 }
 
 // NewSqliteBackend creates a new SqliteBackend.
-func NewSqliteBackend(cfg *config.Config) (*SqliteBackend, error) {
+func NewSqliteBackend(ctx context.Context, cfg *config.Config) (*SqliteBackend, error) {
 	dataPath := cfg.DataPath
 	if err := os.MkdirAll(dataPath, 0755); err != nil {
 		return nil, err
@@ -51,6 +52,7 @@ func NewSqliteBackend(cfg *config.Config) (*SqliteBackend, error) {
 
 	d := &SqliteBackend{
 		cfg: cfg,
+		ctx: ctx,
 		dp:  dataPath,
 		db:  db,
 	}
@@ -71,7 +73,7 @@ func NewSqliteBackend(cfg *config.Config) (*SqliteBackend, error) {
 // It implements backend.Backend.
 func (d *SqliteBackend) AllowKeyless() bool {
 	var allow bool
-	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		return tx.Get(&allow, "SELECT value FROM settings WHERE key = ?;", "allow_keyless")
 	}); err != nil {
 		return false
@@ -85,7 +87,7 @@ func (d *SqliteBackend) AllowKeyless() bool {
 // It implements backend.Backend.
 func (d *SqliteBackend) AnonAccess() backend.AccessLevel {
 	var level string
-	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		return tx.Get(&level, "SELECT value FROM settings WHERE key = ?;", "anon_access")
 	}); err != nil {
 		return backend.NoAccess
@@ -99,7 +101,7 @@ func (d *SqliteBackend) AnonAccess() backend.AccessLevel {
 // It implements backend.Backend.
 func (d *SqliteBackend) SetAllowKeyless(allow bool) error {
 	return wrapDbErr(
-		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+		wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 			_, err := tx.Exec("UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?;", allow, "allow_keyless")
 			return err
 		}),
@@ -111,7 +113,7 @@ func (d *SqliteBackend) SetAllowKeyless(allow bool) error {
 // It implements backend.Backend.
 func (d *SqliteBackend) SetAnonAccess(level backend.AccessLevel) error {
 	return wrapDbErr(
-		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+		wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 			_, err := tx.Exec("UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?;", level.String(), "anon_access")
 			return err
 		}),
@@ -147,7 +149,7 @@ func (d *SqliteBackend) CreateRepository(name string, opts backend.RepositoryOpt
 		return nil, err
 	}
 
-	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		_, err := tx.Exec(`INSERT INTO repo (name, project_name, description, private, mirror, hidden, updated_at)
 			VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP);`,
 			name, opts.ProjectName, opts.Description, opts.Private, opts.Mirror, opts.Hidden)
@@ -210,7 +212,7 @@ func (d *SqliteBackend) DeleteRepository(name string) error {
 		return os.ErrNotExist
 	}
 
-	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		_, err := tx.Exec("DELETE FROM repo WHERE name = ?;", name)
 		return err
 	}); err != nil {
@@ -245,7 +247,7 @@ func (d *SqliteBackend) RenameRepository(oldName string, newName string) error {
 		return fmt.Errorf("repository %s already exists", newName)
 	}
 
-	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		_, err := tx.Exec("UPDATE repo SET name = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?;", newName, oldName)
 		return err
 	}); err != nil {
@@ -260,7 +262,7 @@ func (d *SqliteBackend) RenameRepository(oldName string, newName string) error {
 // It implements backend.Backend.
 func (d *SqliteBackend) Repositories() ([]backend.Repository, error) {
 	repos := make([]backend.Repository, 0)
-	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		rows, err := tx.Query("SELECT name FROM repo")
 		if err != nil {
 			return err
@@ -299,7 +301,7 @@ func (d *SqliteBackend) Repository(repo string) (backend.Repository, error) {
 	}
 
 	var count int
-	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		return tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo)
 	}); err != nil {
 		return nil, wrapDbErr(err)
@@ -323,7 +325,7 @@ func (d *SqliteBackend) Repository(repo string) (backend.Repository, error) {
 func (d *SqliteBackend) Description(repo string) (string, error) {
 	repo = utils.SanitizeRepo(repo)
 	var desc string
-	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		return tx.Get(&desc, "SELECT description FROM repo WHERE name = ?", repo)
 	}); err != nil {
 		return "", wrapDbErr(err)
@@ -338,7 +340,7 @@ func (d *SqliteBackend) Description(repo string) (string, error) {
 func (d *SqliteBackend) IsMirror(repo string) (bool, error) {
 	repo = utils.SanitizeRepo(repo)
 	var mirror bool
-	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		return tx.Get(&mirror, "SELECT mirror FROM repo WHERE name = ?", repo)
 	}); err != nil {
 		return false, wrapDbErr(err)
@@ -353,7 +355,7 @@ func (d *SqliteBackend) IsMirror(repo string) (bool, error) {
 func (d *SqliteBackend) IsPrivate(repo string) (bool, error) {
 	repo = utils.SanitizeRepo(repo)
 	var private bool
-	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		return tx.Get(&private, "SELECT private FROM repo WHERE name = ?", repo)
 	}); err != nil {
 		return false, wrapDbErr(err)
@@ -368,7 +370,7 @@ func (d *SqliteBackend) IsPrivate(repo string) (bool, error) {
 func (d *SqliteBackend) IsHidden(repo string) (bool, error) {
 	repo = utils.SanitizeRepo(repo)
 	var hidden bool
-	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		return tx.Get(&hidden, "SELECT hidden FROM repo WHERE name = ?", repo)
 	}); err != nil {
 		return false, wrapDbErr(err)
@@ -382,7 +384,7 @@ func (d *SqliteBackend) IsHidden(repo string) (bool, error) {
 // It implements backend.Backend.
 func (d *SqliteBackend) SetHidden(repo string, hidden bool) error {
 	repo = utils.SanitizeRepo(repo)
-	return wrapDbErr(wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	return wrapDbErr(wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		_, err := tx.Exec("UPDATE repo SET hidden = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?;", hidden, repo)
 		return err
 	}))
@@ -394,7 +396,7 @@ func (d *SqliteBackend) SetHidden(repo string, hidden bool) error {
 func (d *SqliteBackend) ProjectName(repo string) (string, error) {
 	repo = utils.SanitizeRepo(repo)
 	var name string
-	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		return tx.Get(&name, "SELECT project_name FROM repo WHERE name = ?", repo)
 	}); err != nil {
 		return "", wrapDbErr(err)
@@ -408,7 +410,7 @@ func (d *SqliteBackend) ProjectName(repo string) (string, error) {
 // It implements backend.Backend.
 func (d *SqliteBackend) SetDescription(repo string, desc string) error {
 	repo = utils.SanitizeRepo(repo)
-	return wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	return wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		_, err := tx.Exec("UPDATE repo SET description = ? WHERE name = ?", desc, repo)
 		return err
 	})
@@ -420,7 +422,7 @@ func (d *SqliteBackend) SetDescription(repo string, desc string) error {
 func (d *SqliteBackend) SetPrivate(repo string, private bool) error {
 	repo = utils.SanitizeRepo(repo)
 	return wrapDbErr(
-		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+		wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 			_, err := tx.Exec("UPDATE repo SET private = ? WHERE name = ?", private, repo)
 			return err
 		}),
@@ -433,7 +435,7 @@ func (d *SqliteBackend) SetPrivate(repo string, private bool) error {
 func (d *SqliteBackend) SetProjectName(repo string, name string) error {
 	repo = utils.SanitizeRepo(repo)
 	return wrapDbErr(
-		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+		wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 			_, err := tx.Exec("UPDATE repo SET project_name = ? WHERE name = ?", name, repo)
 			return err
 		}),
@@ -450,7 +452,7 @@ func (d *SqliteBackend) AddCollaborator(repo string, username string) error {
 	}
 
 	repo = utils.SanitizeRepo(repo)
-	return wrapDbErr(wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	return wrapDbErr(wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		_, err := tx.Exec(`INSERT INTO collab (user_id, repo_id, updated_at)
 			VALUES (
 			(SELECT id FROM user WHERE username = ?),
@@ -468,7 +470,7 @@ func (d *SqliteBackend) AddCollaborator(repo string, username string) error {
 func (d *SqliteBackend) Collaborators(repo string) ([]string, error) {
 	repo = utils.SanitizeRepo(repo)
 	var users []string
-	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		return tx.Select(&users, `SELECT name FROM user
 			INNER JOIN collab ON user.id = collab.user_id
 			INNER JOIN repo ON repo.id = collab.repo_id
@@ -486,7 +488,7 @@ func (d *SqliteBackend) Collaborators(repo string) ([]string, error) {
 func (d *SqliteBackend) IsCollaborator(repo string, username string) (bool, error) {
 	repo = utils.SanitizeRepo(repo)
 	var count int
-	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+	if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		return tx.Get(&count, `SELECT COUNT(*) FROM user
 			INNER JOIN collab ON user.id = collab.user_id
 			INNER JOIN repo ON repo.id = collab.repo_id
@@ -504,7 +506,7 @@ func (d *SqliteBackend) IsCollaborator(repo string, username string) (bool, erro
 func (d *SqliteBackend) RemoveCollaborator(repo string, username string) error {
 	repo = utils.SanitizeRepo(repo)
 	return wrapDbErr(
-		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
+		wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 			_, err := tx.Exec(`DELETE FROM collab
 			WHERE user_id = (SELECT id FROM user WHERE username = ?)
 			AND repo_id = (SELECT id FROM repo WHERE name = ?)`, username, repo)

server/config/config.go 🔗

@@ -10,10 +10,6 @@ import (
 	"gopkg.in/yaml.v3"
 )
 
-var (
-	logger = log.WithPrefix("server.config")
-)
-
 // SSHConfig is the configuration for the SSH server.
 type SSHConfig struct {
 	// ListenAddr is the address on which the SSH server will listen.

server/cron/cron.go 🔗

@@ -38,8 +38,8 @@ func (l cronLogger) Error(err error, msg string, keysAndValues ...interface{}) {
 }
 
 // NewCronScheduler returns a new Cron.
-func NewCronScheduler() *CronScheduler {
-	logger := cronLogger{log.WithPrefix("server.cron")}
+func NewCronScheduler(ctx context.Context) *CronScheduler {
+	logger := cronLogger{log.FromContext(ctx).WithPrefix("server.cron")}
 	return &CronScheduler{
 		Cron: cron.New(cron.WithLogger(logger)),
 	}

server/daemon_test.go 🔗

@@ -2,6 +2,7 @@ package server
 
 import (
 	"bytes"
+	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -34,7 +35,8 @@ func TestMain(m *testing.M) {
 	if err != nil {
 		log.Fatal(err)
 	}
-	fb, err := sqlite.NewSqliteBackend(cfg)
+	ctx := context.TODO()
+	fb, err := sqlite.NewSqliteBackend(ctx, cfg)
 	if err != nil {
 		log.Fatal(err)
 	}

server/server.go 🔗

@@ -30,6 +30,7 @@ type Server struct {
 	Cron        *cron.CronScheduler
 	Config      *config.Config
 	Backend     backend.Backend
+	ctx         context.Context
 }
 
 // NewServer returns a new *ssh.Server configured to serve Soft Serve. The SSH
@@ -37,10 +38,10 @@ type Server struct {
 // key can be provided with authKey. If authKey is provided, access will be
 // restricted to that key. If authKey is not provided, the server will be
 // publicly writable until configured otherwise by cloning the `config` repo.
-func NewServer(cfg *config.Config) (*Server, error) {
+func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) {
 	var err error
 	if cfg.Backend == nil {
-		sb, err := sqlite.NewSqliteBackend(cfg)
+		sb, err := sqlite.NewSqliteBackend(ctx, cfg)
 		if err != nil {
 			logger.Fatal(err)
 		}
@@ -71,9 +72,10 @@ func NewServer(cfg *config.Config) (*Server, error) {
 	}
 
 	srv := &Server{
-		Cron:    cron.NewCronScheduler(),
+		Cron:    cron.NewCronScheduler(ctx),
 		Config:  cfg,
 		Backend: cfg.Backend,
+		ctx:     ctx,
 	}
 
 	// Add cron jobs.
@@ -117,39 +119,39 @@ func start(ctx context.Context, fn func() error) error {
 }
 
 // Start starts the SSH server.
-func (s *Server) Start(ctx context.Context) error {
-	var errg *errgroup.Group
-	errg, ctx = errgroup.WithContext(ctx)
+func (s *Server) Start() error {
+	logger := log.FromContext(s.ctx).WithPrefix("server")
+	errg, ctx := errgroup.WithContext(s.ctx)
 	errg.Go(func() error {
-		log.Print("Starting Git daemon", "addr", s.Config.Git.ListenAddr)
+		logger.Print("Starting Git daemon", "addr", s.Config.Git.ListenAddr)
 		if err := start(ctx, s.GitDaemon.Start); !errors.Is(err, ErrServerClosed) {
 			return err
 		}
 		return nil
 	})
 	errg.Go(func() error {
-		log.Print("Starting HTTP server", "addr", s.Config.HTTP.ListenAddr)
+		logger.Print("Starting HTTP server", "addr", s.Config.HTTP.ListenAddr)
 		if err := start(ctx, s.HTTPServer.ListenAndServe); !errors.Is(err, http.ErrServerClosed) {
 			return err
 		}
 		return nil
 	})
 	errg.Go(func() error {
-		log.Print("Starting SSH server", "addr", s.Config.SSH.ListenAddr)
+		logger.Print("Starting SSH server", "addr", s.Config.SSH.ListenAddr)
 		if err := start(ctx, s.SSHServer.ListenAndServe); !errors.Is(err, ssh.ErrServerClosed) {
 			return err
 		}
 		return nil
 	})
 	errg.Go(func() error {
-		log.Print("Starting Stats server", "addr", s.Config.Stats.ListenAddr)
+		logger.Print("Starting Stats server", "addr", s.Config.Stats.ListenAddr)
 		if err := start(ctx, s.StatsServer.ListenAndServe); !errors.Is(err, http.ErrServerClosed) {
 			return err
 		}
 		return nil
 	})
 	errg.Go(func() error {
-		log.Print("Starting cron scheduler")
+		logger.Print("Starting cron scheduler")
 		s.Cron.Start()
 		return nil
 	})

server/server_test.go 🔗

@@ -33,13 +33,14 @@ func setupServer(tb testing.TB) (*Server, *config.Config, string) {
 	tb.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", randomPort()))
 	cfg := config.DefaultConfig()
 	tb.Log("configuring server")
-	s, err := NewServer(cfg)
+	ctx := context.TODO()
+	s, err := NewServer(ctx, cfg)
 	if err != nil {
 		tb.Fatal(err)
 	}
 	go func() {
 		tb.Log("starting server")
-		s.Start(context.TODO())
+		s.Start()
 	}()
 	tb.Cleanup(func() {
 		s.Close()

server/session_test.go 🔗

@@ -1,6 +1,7 @@
 package server
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"log"
@@ -55,8 +56,9 @@ func setup(tb testing.TB) (*gossh.Session, func() error) {
 		is.NoErr(os.Unsetenv("SOFT_SERVE_SSH_LISTEN_ADDR"))
 		is.NoErr(os.RemoveAll(dp))
 	})
+	ctx := context.TODO()
 	cfg := config.DefaultConfig()
-	fb, err := sqlite.NewSqliteBackend(cfg)
+	fb, err := sqlite.NewSqliteBackend(ctx, cfg)
 	if err != nil {
 		log.Fatal(err)
 	}