commands: wire context with signal handling from the root (#1538)

Michael Muré created

Change summary

commands/execenv/env.go         |  5 +++
commands/execenv/env_testing.go |  1 
commands/root.go                |  5 ++-
commands/webui.go               | 43 +++++++++++++---------------------
doc/generate.go                 |  3 +
main.go                         |  8 +++++
misc/completion/generate.go     |  3 +
7 files changed, 35 insertions(+), 33 deletions(-)

Detailed changes

commands/execenv/env.go 🔗

@@ -1,6 +1,7 @@
 package execenv
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -19,6 +20,7 @@ const gitBugNamespace = "git-bug"
 
 // Env is the environment of a command
 type Env struct {
+	Ctx     context.Context
 	Repo    repository.ClockedRepo
 	Backend *cache.RepoCache
 	In      In
@@ -26,8 +28,9 @@ type Env struct {
 	Err     Out
 }
 
-func NewEnv() *Env {
+func NewEnv(ctx context.Context) *Env {
 	return &Env{
+		Ctx:  ctx,
 		Repo: nil,
 		In:   in{Reader: os.Stdin},
 		Out:  out{Writer: os.Stdout},

commands/execenv/env_testing.go 🔗

@@ -93,6 +93,7 @@ func newTestEnv(t *testing.T, isTerminal bool) *Env {
 	})
 
 	return &Env{
+		Ctx:     t.Context(),
 		Repo:    repo,
 		Backend: backend,
 		In:      &TestIn{Buffer: &bytes.Buffer{}, forceIsTerminal: isTerminal},

commands/root.go 🔗

@@ -1,6 +1,7 @@
 package commands
 
 import (
+	"context"
 	"os"
 
 	"github.com/spf13/cobra"
@@ -11,7 +12,7 @@ import (
 	"github.com/git-bug/git-bug/commands/user"
 )
 
-func NewRootCommand(version string) *cobra.Command {
+func NewRootCommand(ctx context.Context, version string) *cobra.Command {
 	cmd := &cobra.Command{
 		Use:   execenv.RootCommandName,
 		Short: "A bug tracker embedded in Git",
@@ -54,7 +55,7 @@ the same git remote you are already using to collaborate with other people.
 		child.GroupID = groupID
 	}
 
-	env := execenv.NewEnv()
+	env := execenv.NewEnv(ctx)
 
 	addCmdWithGroup(bugcmd.NewBugCommand(env), entityGroup)
 	addCmdWithGroup(usercmd.NewUserCommand(env), entityGroup)

commands/webui.go 🔗

@@ -5,14 +5,10 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"log"
 	"net"
 	"net/http"
 	"net/url"
-	"os"
-	"os/signal"
 	"strconv"
-	"syscall"
 	"time"
 
 	"github.com/99designs/gqlgen/graphql/playground"
@@ -134,27 +130,23 @@ func runWebUI(env *execenv.Env, opts webUIOptions) error {
 	}
 
 	done := make(chan bool)
-	quit := make(chan os.Signal, 1)
-
-	// register as handler of the interrupt signal to trigger the teardown
-	signal.Notify(quit, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM, os.Interrupt)
 
 	go func() {
-		<-quit
+		<-env.Ctx.Done()
 		env.Out.Println("shutting down...")
 
-		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
-		defer cancel()
+		ctxTeardown, cancelTeardown := context.WithTimeout(context.Background(), 30*time.Second)
+		defer cancelTeardown()
 
 		srv.SetKeepAlivesEnabled(false)
-		if err := srv.Shutdown(ctx); err != nil {
-			log.Fatalf("Could not gracefully shutdown the WebUI: %v\n", err)
+		if err := srv.Shutdown(ctxTeardown); err != nil {
+			env.Err.Printf("Could not gracefully shutdown the WebUI: %v\n", err)
 		}
 
 		// Teardown
 		err := graphqlHandler.Close()
 		if err != nil {
-			env.Out.Println(err)
+			env.Err.Println(err)
 		}
 
 		close(done)
@@ -163,7 +155,7 @@ func runWebUI(env *execenv.Env, opts webUIOptions) error {
 	env.Out.Printf("Web UI: %s\n", webUiAddr)
 	env.Out.Printf("Graphql API: http://%s/graphql\n", addr)
 	env.Out.Printf("Graphql Playground: http://%s/playground\n", addr)
-	env.Out.Printf("[ Press Ctrl+c to quit ]\n\n")
+	env.Out.Printf("\n[ Press Ctrl+c to quit ]\n\n")
 
 	configOpen, err := env.Repo.AnyConfig().ReadBool(webUIOpenConfigKey)
 	if errors.Is(err, repository.ErrNoConfigEntry) {
@@ -177,26 +169,23 @@ func runWebUI(env *execenv.Env, opts webUIOptions) error {
 
 	if shouldOpen {
 		go func() {
-			maxAttempts := 3
+			const maxAttempts = 3
 			if isUp(toOpen, maxAttempts, 3*time.Second) {
 				err = open.Run(toOpen)
 				if err != nil {
-					env.Out.Println(err)
+					env.Err.Println(err)
 					return
 				}
 
 				env.Out.Printf("opened your default browser to url: %s\n", toOpen)
 				return
-			} else {
-				env.Out.Printf(
-					"uh oh! it appears that the http server hasn't started.\n"+
-						"we failed to reach %s after %d attempts, exiting now.\n",
-					toOpen,
-					maxAttempts,
-				)
-				quit <- syscall.SIGQUIT
-				return
 			}
+
+			env.Err.Printf(
+				"uh oh! it appears that the http server hasn't started.\n"+
+					"we failed to reach %s after %d attempts, exiting now.\n",
+				toOpen, maxAttempts,
+			)
 		}()
 	}
 
@@ -219,7 +208,7 @@ func isUp(url string, maxRetries int, initialDelay time.Duration) bool {
 	for attempt := 1; attempt <= maxRetries; attempt++ {
 		resp, err := client.Head(url)
 		if err == nil {
-			resp.Body.Close()
+			_ = resp.Body.Close()
 			if resp.StatusCode >= 200 && resp.StatusCode < 400 {
 				return true
 			}

doc/generate.go 🔗

@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"fmt"
 	"os"
 	"path/filepath"
@@ -34,7 +35,7 @@ func main() {
 		wg.Add(1)
 		go func(name string, f func(*cobra.Command) error) {
 			defer wg.Done()
-			root := commands.NewRootCommand("")
+			root := commands.NewRootCommand(context.Background(), "")
 			err := f(root)
 			if err != nil {
 				fmt.Printf("  - %s: FATAL\n", name)

main.go 🔗

@@ -4,14 +4,20 @@
 package main
 
 import (
+	"context"
 	"os"
+	"os/signal"
+	"syscall"
 
 	"github.com/git-bug/git-bug/commands"
 )
 
 func main() {
+	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+	defer cancel()
+
 	v, _ := getVersion()
-	root := commands.NewRootCommand(v)
+	root := commands.NewRootCommand(ctx, v)
 	if err := root.Execute(); err != nil {
 		os.Exit(1)
 	}

misc/completion/generate.go 🔗

@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"fmt"
 	"os"
 	"path/filepath"
@@ -26,7 +27,7 @@ func main() {
 		wg.Add(1)
 		go func(name string, f func(*cobra.Command) error) {
 			defer wg.Done()
-			root := commands.NewRootCommand("")
+			root := commands.NewRootCommand(context.Background(), "")
 			err := f(root)
 			if err != nil {
 				fmt.Printf("  - %s: %v\n", name, err)