input: better reusable prompt functions

Michael Muré created

Change summary

bridge/github/config.go | 139 +++---------------------------------------
commands/user_create.go |   6 
input/prompt.go         | 106 +++++++++++++++++++++++++++++---
3 files changed, 110 insertions(+), 141 deletions(-)

Detailed changes

bridge/github/config.go 🔗

@@ -14,21 +14,19 @@ import (
 	"sort"
 	"strconv"
 	"strings"
-	"syscall"
 	"time"
 
 	text "github.com/MichaelMure/go-term-text"
 	"github.com/pkg/errors"
-	"golang.org/x/crypto/ssh/terminal"
 
 	"github.com/MichaelMure/git-bug/bridge/core"
 	"github.com/MichaelMure/git-bug/bridge/core/auth"
 	"github.com/MichaelMure/git-bug/cache"
 	"github.com/MichaelMure/git-bug/entity"
 	"github.com/MichaelMure/git-bug/identity"
+	"github.com/MichaelMure/git-bug/input"
 	"github.com/MichaelMure/git-bug/repository"
 	"github.com/MichaelMure/git-bug/util/colors"
-	"github.com/MichaelMure/git-bug/util/interrupt"
 )
 
 const (
@@ -320,21 +318,14 @@ func promptToken() (string, error) {
 		panic("regexp compile:" + err.Error())
 	}
 
-	for {
-		fmt.Print("Enter token: ")
-
-		line, err := bufio.NewReader(os.Stdin).ReadString('\n')
-		if err != nil {
-			return "", err
+	validator := func(name string, value string) (complaint string, err error) {
+		if re.MatchString(value) {
+			return "", nil
 		}
-
-		token := strings.TrimSpace(line)
-		if re.MatchString(token) {
-			return token, nil
-		}
-
-		fmt.Println("token has incorrect format")
+		return "token has incorrect format", nil
 	}
+
+	return input.Prompt("Enter token", "token", "", input.Required, validator)
 }
 
 func loginAndRequestToken(owner, project string) (string, error) {
@@ -348,17 +339,18 @@ func loginAndRequestToken(owner, project string) (string, error) {
 	fmt.Println()
 
 	// prompt project visibility to know the token scope needed for the repository
-	isPublic, err := promptProjectVisibility()
+	i, err := input.PromptChoice("repository visibility", []string{"public", "private"})
 	if err != nil {
 		return "", err
 	}
+	isPublic := i == 0
 
 	username, err := promptUsername()
 	if err != nil {
 		return "", err
 	}
 
-	password, err := promptPassword()
+	password, err := input.PromptPassword("Password", "password", input.Required)
 	if err != nil {
 		return "", err
 	}
@@ -387,12 +379,12 @@ func loginAndRequestToken(owner, project string) (string, error) {
 	// Handle 2FA is needed
 	OTPHeader := resp.Header.Get("X-GitHub-OTP")
 	if resp.StatusCode == http.StatusUnauthorized && OTPHeader != "" {
-		otpCode, err := prompt2FA()
+		otpCode, err := input.PromptPassword("Two-factor authentication code", "code", input.Required)
 		if err != nil {
 			return "", err
 		}
 
-		resp, err = requestTokenWith2FA(note, username, password, otpCode, scope)
+		resp, err = requestTokenWith2FA(note, login, password, otpCode, scope)
 		if err != nil {
 			return "", err
 		}
@@ -408,29 +400,6 @@ func loginAndRequestToken(owner, project string) (string, error) {
 	return "", fmt.Errorf("error creating token %v: %v", resp.StatusCode, string(b))
 }
 
-func promptUsername() (string, error) {
-	for {
-		fmt.Print("username: ")
-
-		line, err := bufio.NewReader(os.Stdin).ReadString('\n')
-		if err != nil {
-			return "", err
-		}
-
-		line = strings.TrimSpace(line)
-
-		ok, err := validateUsername(line)
-		if err != nil {
-			return "", err
-		}
-		if ok {
-			return line, nil
-		}
-
-		fmt.Println("invalid username")
-	}
-}
-
 func promptURL(repo repository.RepoCommon) (string, string, error) {
 	// remote suggestions
 	remotes, err := repo.GetRemotes()
@@ -585,87 +554,3 @@ func validateProject(owner, project string, token *auth.Token) (bool, error) {
 
 	return resp.StatusCode == http.StatusOK, nil
 }
-
-func promptPassword() (string, error) {
-	termState, err := terminal.GetState(int(syscall.Stdin))
-	if err != nil {
-		return "", err
-	}
-
-	cancel := interrupt.RegisterCleaner(func() error {
-		return terminal.Restore(int(syscall.Stdin), termState)
-	})
-	defer cancel()
-
-	for {
-		fmt.Print("password: ")
-
-		bytePassword, err := terminal.ReadPassword(int(syscall.Stdin))
-		// new line for coherent formatting, ReadPassword clip the normal new line
-		// entered by the user
-		fmt.Println()
-
-		if err != nil {
-			return "", err
-		}
-
-		if len(bytePassword) > 0 {
-			return string(bytePassword), nil
-		}
-
-		fmt.Println("password is empty")
-	}
-}
-
-func prompt2FA() (string, error) {
-	termState, err := terminal.GetState(int(syscall.Stdin))
-	if err != nil {
-		return "", err
-	}
-
-	cancel := interrupt.RegisterCleaner(func() error {
-		return terminal.Restore(int(syscall.Stdin), termState)
-	})
-	defer cancel()
-
-	for {
-		fmt.Print("two-factor authentication code: ")
-
-		byte2fa, err := terminal.ReadPassword(int(syscall.Stdin))
-		fmt.Println()
-		if err != nil {
-			return "", err
-		}
-
-		if len(byte2fa) > 0 {
-			return string(byte2fa), nil
-		}
-
-		fmt.Println("code is empty")
-	}
-}
-
-func promptProjectVisibility() (bool, error) {
-	for {
-		fmt.Println("[1]: public")
-		fmt.Println("[2]: private")
-		fmt.Print("repository visibility: ")
-
-		line, err := bufio.NewReader(os.Stdin).ReadString('\n')
-		fmt.Println()
-		if err != nil {
-			return false, err
-		}
-
-		line = strings.TrimSpace(line)
-
-		index, err := strconv.Atoi(line)
-		if err != nil || (index != 1 && index != 2) {
-			fmt.Println("invalid input")
-			continue
-		}
-
-		// return true for public repositories, false for private
-		return index == 1, nil
-	}
-}

commands/user_create.go 🔗

@@ -23,7 +23,7 @@ func runUserCreate(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	name, err := input.PromptValueRequired("Name", preName)
+	name, err := input.Prompt("Name", "name", preName, input.Required)
 	if err != nil {
 		return err
 	}
@@ -33,12 +33,12 @@ func runUserCreate(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	email, err := input.PromptValueRequired("Email", preEmail)
+	email, err := input.Prompt("Email", "email", preEmail, input.Required)
 	if err != nil {
 		return err
 	}
 
-	login, err := input.PromptValue("Avatar URL", "")
+	login, err := input.Prompt("Avatar URL", "avatar", "")
 	if err != nil {
 		return err
 	}

input/prompt.go 🔗

@@ -4,23 +4,36 @@ import (
 	"bufio"
 	"fmt"
 	"os"
+	"strconv"
 	"strings"
+	"syscall"
+
+	"golang.org/x/crypto/ssh/terminal"
+
+	"github.com/MichaelMure/git-bug/util/interrupt"
 )
 
-func PromptValue(name string, preValue string) (string, error) {
-	return promptValue(name, preValue, false)
+// PromptValidator is a validator for a user entry
+type PromptValidator func(name string, value string) (complaint string, err error)
+
+// Required is a validator preventing a "" value
+func Required(name string, value string) (string, error) {
+	if value == "" {
+		return fmt.Sprintf("%s is empty", name), nil
+	}
+	return "", nil
 }
 
-func PromptValueRequired(name string, preValue string) (string, error) {
-	return promptValue(name, preValue, true)
+func Prompt(prompt, name string, validators ...PromptValidator) (string, error) {
+	return PromptDefault(prompt, name, "", validators...)
 }
 
-func promptValue(name string, preValue string, required bool) (string, error) {
+func PromptDefault(prompt, name, preValue string, validators ...PromptValidator) (string, error) {
 	for {
 		if preValue != "" {
-			_, _ = fmt.Fprintf(os.Stderr, "%s [%s]: ", name, preValue)
+			_, _ = fmt.Fprintf(os.Stderr, "%s [%s]: ", prompt, preValue)
 		} else {
-			_, _ = fmt.Fprintf(os.Stderr, "%s: ", name)
+			_, _ = fmt.Fprintf(os.Stderr, "%s: ", prompt)
 		}
 
 		line, err := bufio.NewReader(os.Stdin).ReadString('\n')
@@ -31,14 +44,85 @@ func promptValue(name string, preValue string, required bool) (string, error) {
 		line = strings.TrimSpace(line)
 
 		if preValue != "" && line == "" {
-			return preValue, nil
+			line = preValue
 		}
 
-		if required && line == "" {
-			_, _ = fmt.Fprintf(os.Stderr, "%s is empty\n", name)
-			continue
+		for _, validator := range validators {
+			complaint, err := validator(name, line)
+			if err != nil {
+				return "", err
+			}
+			if complaint != "" {
+				_, _ = fmt.Fprintln(os.Stderr, complaint)
+				continue
+			}
 		}
 
 		return line, nil
 	}
 }
+
+func PromptPassword(prompt, name string, validators ...PromptValidator) (string, error) {
+	termState, err := terminal.GetState(syscall.Stdin)
+	if err != nil {
+		return "", err
+	}
+
+	cancel := interrupt.RegisterCleaner(func() error {
+		return terminal.Restore(syscall.Stdin, termState)
+	})
+	defer cancel()
+
+	for {
+		_, _ = fmt.Fprintf(os.Stderr, "%s: ", prompt)
+
+		bytePassword, err := terminal.ReadPassword(syscall.Stdin)
+		// new line for coherent formatting, ReadPassword clip the normal new line
+		// entered by the user
+		fmt.Println()
+
+		if err != nil {
+			return "", err
+		}
+
+		pass := string(bytePassword)
+
+		for _, validator := range validators {
+			complaint, err := validator(name, pass)
+			if err != nil {
+				return "", err
+			}
+			if complaint != "" {
+				_, _ = fmt.Fprintln(os.Stderr, complaint)
+				continue
+			}
+		}
+
+		return pass, nil
+	}
+}
+
+func PromptChoice(prompt string, choices []string) (int, error) {
+	for {
+		for i, choice := range choices {
+			_, _ = fmt.Fprintf(os.Stderr, "[%d]: %s\n", i+1, choice)
+		}
+		_, _ = fmt.Fprintf(os.Stderr, "%s: ", prompt)
+
+		line, err := bufio.NewReader(os.Stdin).ReadString('\n')
+		fmt.Println()
+		if err != nil {
+			return 0, err
+		}
+
+		line = strings.TrimSpace(line)
+
+		index, err := strconv.Atoi(line)
+		if err != nil || index < 1 || index > len(choices) {
+			fmt.Println("invalid input")
+			continue
+		}
+
+		return index, nil
+	}
+}