Merge pull request #210 from MichaelMure/cleaner-with-cancel

Michael Muré created

interrupt: allow to cancel a cleaner

Change summary

util/interrupt/cleaner.go      | 95 ++++++++++++++++++++++++-----------
util/interrupt/cleaner_test.go | 60 +++++++++++++++------
2 files changed, 107 insertions(+), 48 deletions(-)

Detailed changes

util/interrupt/cleaner.go 🔗

@@ -4,45 +4,80 @@ import (
 	"fmt"
 	"os"
 	"os/signal"
+	"sync"
 	"syscall"
 )
 
-// Cleaner type refers to a function with no inputs that returns an error
-type Cleaner func() error
-
-var cleaners []Cleaner
-var active = false
-
-// RegisterCleaner is responsible for registering a cleaner function. When a function is registered, the Signal watcher is started in a goroutine.
-func RegisterCleaner(f ...Cleaner) {
-	for _, fn := range f {
-		cleaners = append([]Cleaner{fn}, cleaners...)
-		if !active {
-			active = true
-			go func() {
-				ch := make(chan os.Signal, 1)
-				signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
-				<-ch
-				// Prevent un-terminated ^C character in terminal
-				fmt.Println()
-				errl := clean()
-				for _, err := range errl {
-					fmt.Println(err)
-				}
-				os.Exit(1)
-			}()
-		}
+// CleanerFunc is a function to be executed when an interrupt trigger
+type CleanerFunc func() error
+
+// CancelFunc, if called, will disable the associated cleaner.
+// This allow to create temporary cleaner. Be mindful though to not
+// create too much of them as they are just disabled, not removed from
+// memory.
+type CancelFunc func()
+
+type wrapper struct {
+	f        CleanerFunc
+	disabled bool
+}
+
+var mu sync.Mutex
+var cleaners []*wrapper
+var handlerCreated = false
+
+// RegisterCleaner is responsible for registering a cleaner function.
+// When a function is registered, the Signal watcher is started in a goroutine.
+func RegisterCleaner(cleaner CleanerFunc) CancelFunc {
+	mu.Lock()
+	defer mu.Unlock()
+
+	w := &wrapper{f: cleaner}
+
+	cancel := func() {
+		mu.Lock()
+		defer mu.Unlock()
+		w.disabled = true
 	}
+
+	// prepend to later execute then in reverse order
+	cleaners = append([]*wrapper{w}, cleaners...)
+
+	if handlerCreated {
+		return cancel
+	}
+
+	handlerCreated = true
+	go func() {
+		ch := make(chan os.Signal, 1)
+		signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
+		<-ch
+		// Prevent un-terminated ^C character in terminal
+		fmt.Println()
+		errl := clean()
+		for _, err := range errl {
+			_, _ = fmt.Fprintln(os.Stderr, err)
+		}
+		os.Exit(1)
+	}()
+
+	return cancel
 }
 
 // clean invokes all registered cleanup functions, and returns a list of errors, if they exist.
-func clean() (errorlist []error) {
-	for _, f := range cleaners {
-		err := f()
+func clean() (errorList []error) {
+	mu.Lock()
+	defer mu.Unlock()
+
+	for _, cleaner := range cleaners {
+		if cleaner.disabled {
+			continue
+		}
+		err := cleaner.f()
 		if err != nil {
-			errorlist = append(errorlist, err)
+			errorList = append(errorList, err)
 		}
 	}
-	cleaners = []Cleaner{}
+	cleaners = []*wrapper{}
 	return
 }

util/interrupt/cleaner_test.go 🔗

@@ -3,48 +3,72 @@ package interrupt
 import (
 	"errors"
 	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 // TestRegisterAndErrorAtCleaning tests if the registered order was kept by checking the returned errors
 func TestRegisterAndErrorAtCleaning(t *testing.T) {
-	active = true // this prevents goroutine from being started during the tests
+	handlerCreated = true // this prevents goroutine from being started during the tests
 
-	f := func() error {
-		return errors.New("X")
+	f1 := func() error {
+		return errors.New("1")
 	}
 	f2 := func() error {
-		return errors.New("Y")
+		return errors.New("2")
 	}
 	f3 := func() error {
 		return nil
 	}
-	RegisterCleaner(f)
-	RegisterCleaner(f2, f3)
-	// count := 0
+
+	RegisterCleaner(f1)
+	RegisterCleaner(f2)
+	RegisterCleaner(f3)
 
 	errl := clean()
-	if len(errl) != 2 {
-		t.Fatalf("unexpected error count")
-	}
-	if errl[0].Error() != "Y" && errl[1].Error() != "X" {
-		t.Fatalf("unexpected error order")
 
-	}
+	require.Len(t, errl, 2)
+
+	// cleaners should execute in the reverse order they have been defined
+	assert.Equal(t, "2", errl[0].Error())
+	assert.Equal(t, "1", errl[1].Error())
 }
 
 func TestRegisterAndClean(t *testing.T) {
-	active = true // this prevents goroutine from being started during the tests
+	handlerCreated = true // this prevents goroutine from being started during the tests
 
-	f := func() error {
+	f1 := func() error {
 		return nil
 	}
 	f2 := func() error {
 		return nil
 	}
-	RegisterCleaner(f, f2)
+
+	RegisterCleaner(f1)
+	RegisterCleaner(f2)
 
 	errl := clean()
-	if len(errl) != 0 {
-		t.Fatalf("unexpected error count")
+	assert.Len(t, errl, 0)
+}
+
+func TestCancel(t *testing.T) {
+	handlerCreated = true // this prevents goroutine from being started during the tests
+
+	f1 := func() error {
+		return errors.New("1")
+	}
+	f2 := func() error {
+		return errors.New("2")
 	}
+
+	cancel1 := RegisterCleaner(f1)
+	RegisterCleaner(f2)
+
+	cancel1()
+
+	errl := clean()
+	require.Len(t, errl, 1)
+
+	assert.Equal(t, "2", errl[0].Error())
 }