diff --git a/cmd/soft/root.go b/cmd/soft/root.go index 5fc9e4588e6dcc9fb57169b6ea1a857c1d899c44..af4c56c3f97d518e569e9b63aede54a2106808a6 100644 --- a/cmd/soft/root.go +++ b/cmd/soft/root.go @@ -1,9 +1,9 @@ package main import ( + "os" "runtime/debug" - "github.com/charmbracelet/log" _ "github.com/charmbracelet/soft-serve/log" "github.com/spf13/cobra" ) @@ -18,12 +18,10 @@ var ( CommitSHA = "" rootCmd = &cobra.Command{ - Use: "soft", - Short: "A self-hostable Git server for the command line", - Long: "Soft Serve is a self-hostable Git server for the command line.", - RunE: func(cmd *cobra.Command, args []string) error { - return cmd.Help() - }, + Use: "soft", + Short: "A self-hostable Git server for the command line", + Long: "Soft Serve is a self-hostable Git server for the command line.", + SilenceUsage: true, } ) @@ -52,6 +50,6 @@ func init() { func main() { if err := rootCmd.Execute(); err != nil { - log.Fatal(err) + os.Exit(1) } } diff --git a/cmd/soft/serve.go b/cmd/soft/serve.go index 332796394866e8b1fe6275985ccf771a589aa467..5841f2cdd36113eab15b5b07fd0b5084e82f4012 100644 --- a/cmd/soft/serve.go +++ b/cmd/soft/serve.go @@ -25,18 +25,19 @@ var ( return err } + ctx := cmd.Context() done := make(chan os.Signal, 1) lch := make(chan error, 1) go func() { defer close(lch) defer close(done) - lch <- s.Start() + lch <- s.Start(ctx) }() signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) <-done - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() if err := s.Shutdown(ctx); err != nil { return err diff --git a/server/config/file.go b/server/config/file.go index f85c7247b564ce3d2c1f9a8ae65f70ff0849bccb..7079ff3033b4163c7ac07257e90750c722e7a928 100644 --- a/server/config/file.go +++ b/server/config/file.go @@ -59,10 +59,10 @@ http: # The address on which the HTTP server will listen. listen_addr: "{{ .HTTP.ListenAddr }}" - # The relative path to the TLS private key. + # The path to the TLS private key. tls_key_path: "{{ .HTTP.TLSKeyPath }}" - # The relative path to the TLS certificate. + # The path to the TLS certificate. tls_cert_path: "{{ .HTTP.TLSCertPath }}" # The public URL of the HTTP server. @@ -79,6 +79,6 @@ stats: func newConfigFile(cfg *Config) string { var b bytes.Buffer - configFileTmpl.Execute(&b, cfg) + configFileTmpl.Execute(&b, cfg) // nolint: errcheck return b.String() } diff --git a/server/server.go b/server/server.go index de227c54287853bf5cbb73a7bdeb1e68b52809d9..992e0c7c13c0a656afd08a3c764481fe393d9b5a 100644 --- a/server/server.go +++ b/server/server.go @@ -2,6 +2,7 @@ package server import ( "context" + "errors" "net/http" "path/filepath" @@ -101,33 +102,48 @@ func NewServer(cfg *config.Config) (*Server, error) { return srv, nil } +func start(ctx context.Context, fn func() error) error { + errc := make(chan error, 1) + go func() { + errc <- fn() + }() + + select { + case err := <-errc: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + // Start starts the SSH server. -func (s *Server) Start() error { - var errg errgroup.Group +func (s *Server) Start(ctx context.Context) error { + var errg *errgroup.Group + errg, ctx = errgroup.WithContext(ctx) errg.Go(func() error { log.Print("Starting Git daemon", "addr", s.Config.Git.ListenAddr) - if err := s.GitDaemon.Start(); err != ErrServerClosed { + if err := start(ctx, s.GitDaemon.Start); !errors.Is(err, ErrServerClosed) { return err } return nil }) errg.Go(func() error { log.Print("Starting HTTP server", "addr", s.Config.HTTP.ListenAddr) - if err := s.HTTPServer.ListenAndServe(); err != http.ErrServerClosed { + if err := start(ctx, s.HTTPServer.ListenAndServe); !errors.Is(err, http.ErrServerClosed) { return err } return nil }) errg.Go(func() error { log.Print("Starting SSH server", "addr", s.Config.SSH.ListenAddr) - if err := s.SSHServer.ListenAndServe(); err != ssh.ErrServerClosed { + if err := start(ctx, s.SSHServer.ListenAndServe); !errors.Is(err, ssh.ErrServerClosed) { return err } return nil }) errg.Go(func() error { log.Print("Starting Stats server", "addr", s.Config.Stats.ListenAddr) - if err := s.StatsServer.ListenAndServe(); err != http.ErrServerClosed { + if err := start(ctx, s.StatsServer.ListenAndServe); !errors.Is(err, http.ErrServerClosed) { return err } return nil diff --git a/server/server_test.go b/server/server_test.go index 5505bec565280f8ed92c0487eed52c31c26c9910..07066d976412aacb4f1fb174609587a4731edac2 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "net" "path/filepath" @@ -38,7 +39,7 @@ func setupServer(tb testing.TB) (*Server, *config.Config, string) { } go func() { tb.Log("starting server") - s.Start() + s.Start(context.TODO()) }() tb.Cleanup(func() { s.Close()