From 8dbe830713096bb4d25c2bb86fd0128d49780453 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Mon, 10 Apr 2023 12:40:39 -0400 Subject: [PATCH] feat: use context --- cmd/soft/migrate_config.go | 5 ++-- cmd/soft/root.go | 14 +++++++++- cmd/soft/serve.go | 7 ++--- examples/setuid/main.go | 5 ++-- server/backend/sqlite/sqlite.go | 48 +++++++++++++++++---------------- server/config/config.go | 4 --- server/cron/cron.go | 4 +-- server/daemon_test.go | 4 ++- server/server.go | 24 +++++++++-------- server/server_test.go | 5 ++-- server/session_test.go | 4 ++- 11 files changed, 72 insertions(+), 52 deletions(-) diff --git a/cmd/soft/migrate_config.go b/cmd/soft/migrate_config.go index e4731ab608bb1252b9786d0a995175157bc58387..ea03f751367c3430b18467f88a6f3ab6175f140a 100644 --- a/cmd/soft/migrate_config.go +++ b/cmd/soft/migrate_config.go @@ -23,12 +23,13 @@ var ( migrateConfig = &cobra.Command{ Use: "migrate-config", Short: "Migrate config to new format", - RunE: func(_ *cobra.Command, _ []string) error { + RunE: func(cmd *cobra.Command, _ []string) error { keyPath := os.Getenv("SOFT_SERVE_KEY_PATH") reposPath := os.Getenv("SOFT_SERVE_REPO_PATH") bindAddr := os.Getenv("SOFT_SERVE_BIND_ADDRESS") + ctx := cmd.Context() cfg := config.DefaultConfig() - sb, err := sqlite.NewSqliteBackend(cfg) + sb, err := sqlite.NewSqliteBackend(ctx, cfg) if err != nil { return fmt.Errorf("failed to create sqlite backend: %w", err) } diff --git a/cmd/soft/root.go b/cmd/soft/root.go index af4c56c3f97d518e569e9b63aede54a2106808a6..44b6e8ab84f9631fa0fcfd069a90e17b60183f86 100644 --- a/cmd/soft/root.go +++ b/cmd/soft/root.go @@ -1,9 +1,11 @@ package main import ( + "context" "os" "runtime/debug" + "github.com/charmbracelet/log" _ "github.com/charmbracelet/soft-serve/log" "github.com/spf13/cobra" ) @@ -49,7 +51,17 @@ func init() { } func main() { - if err := rootCmd.Execute(); err != nil { + logger := log.NewWithOptions(os.Stderr, log.Options{ + ReportTimestamp: true, + TimeFormat: "2006-01-02", + }) + if os.Getenv("SOFT_SERVE_DEBUG") == "true" { + logger.SetLevel(log.DebugLevel) + } + + ctx := context.Background() + ctx = log.WithContext(ctx, logger) + if err := rootCmd.ExecuteContext(ctx); err != nil { os.Exit(1) } } diff --git a/cmd/soft/serve.go b/cmd/soft/serve.go index 5841f2cdd36113eab15b5b07fd0b5084e82f4012..f0d3a42151137795ef3d650b75fc50f60644f876 100644 --- a/cmd/soft/serve.go +++ b/cmd/soft/serve.go @@ -7,6 +7,7 @@ import ( "syscall" "time" + _ "github.com/charmbracelet/soft-serve/log" "github.com/charmbracelet/soft-serve/server" "github.com/charmbracelet/soft-serve/server/config" "github.com/spf13/cobra" @@ -19,19 +20,19 @@ var ( Long: "Start the server", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() cfg := config.DefaultConfig() - s, err := server.NewServer(cfg) + s, err := server.NewServer(ctx, cfg) if err != nil { return err } - ctx := cmd.Context() done := make(chan os.Signal, 1) lch := make(chan error, 1) go func() { defer close(lch) defer close(done) - lch <- s.Start(ctx) + lch <- s.Start() }() signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) diff --git a/examples/setuid/main.go b/examples/setuid/main.go index 8999dbcc25fe9f2c379c5d46ddd2f998da524067..e9cc1f751a431c961e36d050fa62f809479bd8c1 100644 --- a/examples/setuid/main.go +++ b/examples/setuid/main.go @@ -44,9 +44,10 @@ func main() { if err := syscall.Setuid(*uid); err != nil { log.Fatal("Setuid error", "err", err) } + ctx := context.Background() cfg := config.DefaultConfig() cfg.SSH.ListenAddr = fmt.Sprintf(":%d", *port) - s, err := server.NewServer(cfg) + s, err := server.NewServer(ctx, cfg) if err != nil { log.Fatal(err) } @@ -64,7 +65,7 @@ func main() { <-done log.Print("Stopping SSH server", "addr", cfg.SSH.ListenAddr) - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer func() { cancel() }() if err := s.Shutdown(ctx); err != nil { log.Fatal(err) diff --git a/server/backend/sqlite/sqlite.go b/server/backend/sqlite/sqlite.go index d3dcb81ade248ca0ecb7fa688978d37a0413e8fc..fe73c57238cba7c6c0abdbdb90e6598daaa9e982 100644 --- a/server/backend/sqlite/sqlite.go +++ b/server/backend/sqlite/sqlite.go @@ -26,6 +26,7 @@ var ( // backend. type SqliteBackend struct { cfg *config.Config + ctx context.Context dp string db *sqlx.DB } @@ -37,7 +38,7 @@ func (d *SqliteBackend) reposPath() string { } // NewSqliteBackend creates a new SqliteBackend. -func NewSqliteBackend(cfg *config.Config) (*SqliteBackend, error) { +func NewSqliteBackend(ctx context.Context, cfg *config.Config) (*SqliteBackend, error) { dataPath := cfg.DataPath if err := os.MkdirAll(dataPath, 0755); err != nil { return nil, err @@ -51,6 +52,7 @@ func NewSqliteBackend(cfg *config.Config) (*SqliteBackend, error) { d := &SqliteBackend{ cfg: cfg, + ctx: ctx, dp: dataPath, db: db, } @@ -71,7 +73,7 @@ func NewSqliteBackend(cfg *config.Config) (*SqliteBackend, error) { // It implements backend.Backend. func (d *SqliteBackend) AllowKeyless() bool { var allow bool - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { return tx.Get(&allow, "SELECT value FROM settings WHERE key = ?;", "allow_keyless") }); err != nil { return false @@ -85,7 +87,7 @@ func (d *SqliteBackend) AllowKeyless() bool { // It implements backend.Backend. func (d *SqliteBackend) AnonAccess() backend.AccessLevel { var level string - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { return tx.Get(&level, "SELECT value FROM settings WHERE key = ?;", "anon_access") }); err != nil { return backend.NoAccess @@ -99,7 +101,7 @@ func (d *SqliteBackend) AnonAccess() backend.AccessLevel { // It implements backend.Backend. func (d *SqliteBackend) SetAllowKeyless(allow bool) error { return wrapDbErr( - wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { _, err := tx.Exec("UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?;", allow, "allow_keyless") return err }), @@ -111,7 +113,7 @@ func (d *SqliteBackend) SetAllowKeyless(allow bool) error { // It implements backend.Backend. func (d *SqliteBackend) SetAnonAccess(level backend.AccessLevel) error { return wrapDbErr( - wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { _, err := tx.Exec("UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?;", level.String(), "anon_access") return err }), @@ -147,7 +149,7 @@ func (d *SqliteBackend) CreateRepository(name string, opts backend.RepositoryOpt return nil, err } - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { _, err := tx.Exec(`INSERT INTO repo (name, project_name, description, private, mirror, hidden, updated_at) VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP);`, name, opts.ProjectName, opts.Description, opts.Private, opts.Mirror, opts.Hidden) @@ -210,7 +212,7 @@ func (d *SqliteBackend) DeleteRepository(name string) error { return os.ErrNotExist } - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { _, err := tx.Exec("DELETE FROM repo WHERE name = ?;", name) return err }); err != nil { @@ -245,7 +247,7 @@ func (d *SqliteBackend) RenameRepository(oldName string, newName string) error { return fmt.Errorf("repository %s already exists", newName) } - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { _, err := tx.Exec("UPDATE repo SET name = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?;", newName, oldName) return err }); err != nil { @@ -260,7 +262,7 @@ func (d *SqliteBackend) RenameRepository(oldName string, newName string) error { // It implements backend.Backend. func (d *SqliteBackend) Repositories() ([]backend.Repository, error) { repos := make([]backend.Repository, 0) - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { rows, err := tx.Query("SELECT name FROM repo") if err != nil { return err @@ -299,7 +301,7 @@ func (d *SqliteBackend) Repository(repo string) (backend.Repository, error) { } var count int - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { return tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo) }); err != nil { return nil, wrapDbErr(err) @@ -323,7 +325,7 @@ func (d *SqliteBackend) Repository(repo string) (backend.Repository, error) { func (d *SqliteBackend) Description(repo string) (string, error) { repo = utils.SanitizeRepo(repo) var desc string - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { return tx.Get(&desc, "SELECT description FROM repo WHERE name = ?", repo) }); err != nil { return "", wrapDbErr(err) @@ -338,7 +340,7 @@ func (d *SqliteBackend) Description(repo string) (string, error) { func (d *SqliteBackend) IsMirror(repo string) (bool, error) { repo = utils.SanitizeRepo(repo) var mirror bool - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { return tx.Get(&mirror, "SELECT mirror FROM repo WHERE name = ?", repo) }); err != nil { return false, wrapDbErr(err) @@ -353,7 +355,7 @@ func (d *SqliteBackend) IsMirror(repo string) (bool, error) { func (d *SqliteBackend) IsPrivate(repo string) (bool, error) { repo = utils.SanitizeRepo(repo) var private bool - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { return tx.Get(&private, "SELECT private FROM repo WHERE name = ?", repo) }); err != nil { return false, wrapDbErr(err) @@ -368,7 +370,7 @@ func (d *SqliteBackend) IsPrivate(repo string) (bool, error) { func (d *SqliteBackend) IsHidden(repo string) (bool, error) { repo = utils.SanitizeRepo(repo) var hidden bool - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { return tx.Get(&hidden, "SELECT hidden FROM repo WHERE name = ?", repo) }); err != nil { return false, wrapDbErr(err) @@ -382,7 +384,7 @@ func (d *SqliteBackend) IsHidden(repo string) (bool, error) { // It implements backend.Backend. func (d *SqliteBackend) SetHidden(repo string, hidden bool) error { repo = utils.SanitizeRepo(repo) - return wrapDbErr(wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + return wrapDbErr(wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { _, err := tx.Exec("UPDATE repo SET hidden = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?;", hidden, repo) return err })) @@ -394,7 +396,7 @@ func (d *SqliteBackend) SetHidden(repo string, hidden bool) error { func (d *SqliteBackend) ProjectName(repo string) (string, error) { repo = utils.SanitizeRepo(repo) var name string - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { return tx.Get(&name, "SELECT project_name FROM repo WHERE name = ?", repo) }); err != nil { return "", wrapDbErr(err) @@ -408,7 +410,7 @@ func (d *SqliteBackend) ProjectName(repo string) (string, error) { // It implements backend.Backend. func (d *SqliteBackend) SetDescription(repo string, desc string) error { repo = utils.SanitizeRepo(repo) - return wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + return wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { _, err := tx.Exec("UPDATE repo SET description = ? WHERE name = ?", desc, repo) return err }) @@ -420,7 +422,7 @@ func (d *SqliteBackend) SetDescription(repo string, desc string) error { func (d *SqliteBackend) SetPrivate(repo string, private bool) error { repo = utils.SanitizeRepo(repo) return wrapDbErr( - wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { _, err := tx.Exec("UPDATE repo SET private = ? WHERE name = ?", private, repo) return err }), @@ -433,7 +435,7 @@ func (d *SqliteBackend) SetPrivate(repo string, private bool) error { func (d *SqliteBackend) SetProjectName(repo string, name string) error { repo = utils.SanitizeRepo(repo) return wrapDbErr( - wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { _, err := tx.Exec("UPDATE repo SET project_name = ? WHERE name = ?", name, repo) return err }), @@ -450,7 +452,7 @@ func (d *SqliteBackend) AddCollaborator(repo string, username string) error { } repo = utils.SanitizeRepo(repo) - return wrapDbErr(wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + return wrapDbErr(wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { _, err := tx.Exec(`INSERT INTO collab (user_id, repo_id, updated_at) VALUES ( (SELECT id FROM user WHERE username = ?), @@ -468,7 +470,7 @@ func (d *SqliteBackend) AddCollaborator(repo string, username string) error { func (d *SqliteBackend) Collaborators(repo string) ([]string, error) { repo = utils.SanitizeRepo(repo) var users []string - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { return tx.Select(&users, `SELECT name FROM user INNER JOIN collab ON user.id = collab.user_id INNER JOIN repo ON repo.id = collab.repo_id @@ -486,7 +488,7 @@ func (d *SqliteBackend) Collaborators(repo string) ([]string, error) { func (d *SqliteBackend) IsCollaborator(repo string, username string) (bool, error) { repo = utils.SanitizeRepo(repo) var count int - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { return tx.Get(&count, `SELECT COUNT(*) FROM user INNER JOIN collab ON user.id = collab.user_id INNER JOIN repo ON repo.id = collab.repo_id @@ -504,7 +506,7 @@ func (d *SqliteBackend) IsCollaborator(repo string, username string) (bool, erro func (d *SqliteBackend) RemoveCollaborator(repo string, username string) error { repo = utils.SanitizeRepo(repo) return wrapDbErr( - wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { + wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { _, err := tx.Exec(`DELETE FROM collab WHERE user_id = (SELECT id FROM user WHERE username = ?) AND repo_id = (SELECT id FROM repo WHERE name = ?)`, username, repo) diff --git a/server/config/config.go b/server/config/config.go index ea4e42408ab4b8c26c7f9fc62b65d00290dd1005..032f566ba50155d1d283423616e3a18dd1900a41 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -10,10 +10,6 @@ import ( "gopkg.in/yaml.v3" ) -var ( - logger = log.WithPrefix("server.config") -) - // SSHConfig is the configuration for the SSH server. type SSHConfig struct { // ListenAddr is the address on which the SSH server will listen. diff --git a/server/cron/cron.go b/server/cron/cron.go index b86b0eacf6b96c6d6e52c86a7de984efac504dbb..27e052dee20e8aeaa4dc77d2530a4f4ab87ef02d 100644 --- a/server/cron/cron.go +++ b/server/cron/cron.go @@ -38,8 +38,8 @@ func (l cronLogger) Error(err error, msg string, keysAndValues ...interface{}) { } // NewCronScheduler returns a new Cron. -func NewCronScheduler() *CronScheduler { - logger := cronLogger{log.WithPrefix("server.cron")} +func NewCronScheduler(ctx context.Context) *CronScheduler { + logger := cronLogger{log.FromContext(ctx).WithPrefix("server.cron")} return &CronScheduler{ Cron: cron.New(cron.WithLogger(logger)), } diff --git a/server/daemon_test.go b/server/daemon_test.go index 4b5899ddd33341be670bb1426071e270713ca4c9..765e62c9a84cc7240529768e04744a982f46e573 100644 --- a/server/daemon_test.go +++ b/server/daemon_test.go @@ -2,6 +2,7 @@ package server import ( "bytes" + "context" "errors" "fmt" "io" @@ -34,7 +35,8 @@ func TestMain(m *testing.M) { if err != nil { log.Fatal(err) } - fb, err := sqlite.NewSqliteBackend(cfg) + ctx := context.TODO() + fb, err := sqlite.NewSqliteBackend(ctx, cfg) if err != nil { log.Fatal(err) } diff --git a/server/server.go b/server/server.go index beee1fa744ecc660e0274d5858750ef1e26aa945..b4645f3cab1262667c37a07bd6e2c15763085f8e 100644 --- a/server/server.go +++ b/server/server.go @@ -30,6 +30,7 @@ type Server struct { Cron *cron.CronScheduler Config *config.Config Backend backend.Backend + ctx context.Context } // NewServer returns a new *ssh.Server configured to serve Soft Serve. The SSH @@ -37,10 +38,10 @@ type Server struct { // key can be provided with authKey. If authKey is provided, access will be // restricted to that key. If authKey is not provided, the server will be // publicly writable until configured otherwise by cloning the `config` repo. -func NewServer(cfg *config.Config) (*Server, error) { +func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) { var err error if cfg.Backend == nil { - sb, err := sqlite.NewSqliteBackend(cfg) + sb, err := sqlite.NewSqliteBackend(ctx, cfg) if err != nil { logger.Fatal(err) } @@ -71,9 +72,10 @@ func NewServer(cfg *config.Config) (*Server, error) { } srv := &Server{ - Cron: cron.NewCronScheduler(), + Cron: cron.NewCronScheduler(ctx), Config: cfg, Backend: cfg.Backend, + ctx: ctx, } // Add cron jobs. @@ -117,39 +119,39 @@ func start(ctx context.Context, fn func() error) error { } // Start starts the SSH server. -func (s *Server) Start(ctx context.Context) error { - var errg *errgroup.Group - errg, ctx = errgroup.WithContext(ctx) +func (s *Server) Start() error { + logger := log.FromContext(s.ctx).WithPrefix("server") + errg, ctx := errgroup.WithContext(s.ctx) errg.Go(func() error { - log.Print("Starting Git daemon", "addr", s.Config.Git.ListenAddr) + logger.Print("Starting Git daemon", "addr", s.Config.Git.ListenAddr) if err := start(ctx, s.GitDaemon.Start); !errors.Is(err, ErrServerClosed) { return err } return nil }) errg.Go(func() error { - log.Print("Starting HTTP server", "addr", s.Config.HTTP.ListenAddr) + logger.Print("Starting HTTP server", "addr", s.Config.HTTP.ListenAddr) if err := start(ctx, s.HTTPServer.ListenAndServe); !errors.Is(err, http.ErrServerClosed) { return err } return nil }) errg.Go(func() error { - log.Print("Starting SSH server", "addr", s.Config.SSH.ListenAddr) + logger.Print("Starting SSH server", "addr", s.Config.SSH.ListenAddr) if err := start(ctx, s.SSHServer.ListenAndServe); !errors.Is(err, ssh.ErrServerClosed) { return err } return nil }) errg.Go(func() error { - log.Print("Starting Stats server", "addr", s.Config.Stats.ListenAddr) + logger.Print("Starting Stats server", "addr", s.Config.Stats.ListenAddr) if err := start(ctx, s.StatsServer.ListenAndServe); !errors.Is(err, http.ErrServerClosed) { return err } return nil }) errg.Go(func() error { - log.Print("Starting cron scheduler") + logger.Print("Starting cron scheduler") s.Cron.Start() return nil }) diff --git a/server/server_test.go b/server/server_test.go index 07066d976412aacb4f1fb174609587a4731edac2..3992794e936426dcce2a477921552811666512bc 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -33,13 +33,14 @@ func setupServer(tb testing.TB) (*Server, *config.Config, string) { tb.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", randomPort())) cfg := config.DefaultConfig() tb.Log("configuring server") - s, err := NewServer(cfg) + ctx := context.TODO() + s, err := NewServer(ctx, cfg) if err != nil { tb.Fatal(err) } go func() { tb.Log("starting server") - s.Start(context.TODO()) + s.Start() }() tb.Cleanup(func() { s.Close() diff --git a/server/session_test.go b/server/session_test.go index 56c0b59efc3872bf3e2b2c22f2cb335af61a3720..673e85a6099668009d6499ee4bfe70c812507214 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "errors" "fmt" "log" @@ -55,8 +56,9 @@ func setup(tb testing.TB) (*gossh.Session, func() error) { is.NoErr(os.Unsetenv("SOFT_SERVE_SSH_LISTEN_ADDR")) is.NoErr(os.RemoveAll(dp)) }) + ctx := context.TODO() cfg := config.DefaultConfig() - fb, err := sqlite.NewSqliteBackend(cfg) + fb, err := sqlite.NewSqliteBackend(ctx, cfg) if err != nil { log.Fatal(err) }