fix(server): start with context

Ayman Bagabas created

Change summary

cmd/soft/root.go      | 14 ++++++--------
cmd/soft/serve.go     |  5 +++--
server/config/file.go |  6 +++---
server/server.go      | 28 ++++++++++++++++++++++------
server/server_test.go |  3 ++-
5 files changed, 36 insertions(+), 20 deletions(-)

Detailed changes

cmd/soft/root.go 🔗

@@ -1,9 +1,9 @@
 package main
 
 import (
+	"os"
 	"runtime/debug"
 
-	"github.com/charmbracelet/log"
 	_ "github.com/charmbracelet/soft-serve/log"
 	"github.com/spf13/cobra"
 )
@@ -18,12 +18,10 @@ var (
 	CommitSHA = ""
 
 	rootCmd = &cobra.Command{
-		Use:   "soft",
-		Short: "A self-hostable Git server for the command line",
-		Long:  "Soft Serve is a self-hostable Git server for the command line.",
-		RunE: func(cmd *cobra.Command, args []string) error {
-			return cmd.Help()
-		},
+		Use:          "soft",
+		Short:        "A self-hostable Git server for the command line",
+		Long:         "Soft Serve is a self-hostable Git server for the command line.",
+		SilenceUsage: true,
 	}
 )
 
@@ -52,6 +50,6 @@ func init() {
 
 func main() {
 	if err := rootCmd.Execute(); err != nil {
-		log.Fatal(err)
+		os.Exit(1)
 	}
 }

cmd/soft/serve.go 🔗

@@ -25,18 +25,19 @@ var (
 				return err
 			}
 
+			ctx := cmd.Context()
 			done := make(chan os.Signal, 1)
 			lch := make(chan error, 1)
 			go func() {
 				defer close(lch)
 				defer close(done)
-				lch <- s.Start()
+				lch <- s.Start(ctx)
 			}()
 
 			signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
 			<-done
 
-			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+			ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
 			defer cancel()
 			if err := s.Shutdown(ctx); err != nil {
 				return err

server/config/file.go 🔗

@@ -59,10 +59,10 @@ http:
   # The address on which the HTTP server will listen.
   listen_addr: "{{ .HTTP.ListenAddr }}"
 
-  # The relative path to the TLS private key.
+  # The path to the TLS private key.
   tls_key_path: "{{ .HTTP.TLSKeyPath }}"
 
-  # The relative path to the TLS certificate.
+  # The path to the TLS certificate.
   tls_cert_path: "{{ .HTTP.TLSCertPath }}"
 
   # The public URL of the HTTP server.
@@ -79,6 +79,6 @@ stats:
 
 func newConfigFile(cfg *Config) string {
 	var b bytes.Buffer
-	configFileTmpl.Execute(&b, cfg)
+	configFileTmpl.Execute(&b, cfg) // nolint: errcheck
 	return b.String()
 }

server/server.go 🔗

@@ -2,6 +2,7 @@ package server
 
 import (
 	"context"
+	"errors"
 	"net/http"
 	"path/filepath"
 
@@ -101,33 +102,48 @@ func NewServer(cfg *config.Config) (*Server, error) {
 	return srv, nil
 }
 
+func start(ctx context.Context, fn func() error) error {
+	errc := make(chan error, 1)
+	go func() {
+		errc <- fn()
+	}()
+
+	select {
+	case err := <-errc:
+		return err
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
 // Start starts the SSH server.
-func (s *Server) Start() error {
-	var errg errgroup.Group
+func (s *Server) Start(ctx context.Context) error {
+	var errg *errgroup.Group
+	errg, ctx = errgroup.WithContext(ctx)
 	errg.Go(func() error {
 		log.Print("Starting Git daemon", "addr", s.Config.Git.ListenAddr)
-		if err := s.GitDaemon.Start(); err != ErrServerClosed {
+		if err := start(ctx, s.GitDaemon.Start); !errors.Is(err, ErrServerClosed) {
 			return err
 		}
 		return nil
 	})
 	errg.Go(func() error {
 		log.Print("Starting HTTP server", "addr", s.Config.HTTP.ListenAddr)
-		if err := s.HTTPServer.ListenAndServe(); err != http.ErrServerClosed {
+		if err := start(ctx, s.HTTPServer.ListenAndServe); !errors.Is(err, http.ErrServerClosed) {
 			return err
 		}
 		return nil
 	})
 	errg.Go(func() error {
 		log.Print("Starting SSH server", "addr", s.Config.SSH.ListenAddr)
-		if err := s.SSHServer.ListenAndServe(); err != ssh.ErrServerClosed {
+		if err := start(ctx, s.SSHServer.ListenAndServe); !errors.Is(err, ssh.ErrServerClosed) {
 			return err
 		}
 		return nil
 	})
 	errg.Go(func() error {
 		log.Print("Starting Stats server", "addr", s.Config.Stats.ListenAddr)
-		if err := s.StatsServer.ListenAndServe(); err != http.ErrServerClosed {
+		if err := start(ctx, s.StatsServer.ListenAndServe); !errors.Is(err, http.ErrServerClosed) {
 			return err
 		}
 		return nil

server/server_test.go 🔗

@@ -1,6 +1,7 @@
 package server
 
 import (
+	"context"
 	"fmt"
 	"net"
 	"path/filepath"
@@ -38,7 +39,7 @@ func setupServer(tb testing.TB) (*Server, *config.Config, string) {
 	}
 	go func() {
 		tb.Log("starting server")
-		s.Start()
+		s.Start(context.TODO())
 	}()
 	tb.Cleanup(func() {
 		s.Close()