feat(config): add git-bug.remote for defining the default remote (#1460)

Matěj Cepl and William Ahern created

Other way is to have explicit REMOTE argument.

---------

Signed-off-by: Matěj Cepl <mcepl@cepl.eu>
Co-authored-by: William Ahern <william@25thandClement.com>

Change summary

commands/pull.go          | 16 +++++++++++-----
commands/push.go          | 16 +++++++++++-----
repository/config.go      | 11 +++++++++++
repository/config_test.go | 32 ++++++++++++++++++++++++++++++++
4 files changed, 65 insertions(+), 10 deletions(-)

Detailed changes

commands/pull.go 🔗

@@ -8,6 +8,7 @@ import (
 	"github.com/git-bug/git-bug/commands/completion"
 	"github.com/git-bug/git-bug/commands/execenv"
 	"github.com/git-bug/git-bug/entity"
+	"github.com/git-bug/git-bug/repository"
 )
 
 func newPullCommand(env *execenv.Env) *cobra.Command {
@@ -25,13 +26,18 @@ func newPullCommand(env *execenv.Env) *cobra.Command {
 }
 
 func runPull(env *execenv.Env, args []string) error {
-	if len(args) > 1 {
+	var remote string
+	switch {
+	case len(args) > 1:
 		return errors.New("Only pulling from one remote at a time is supported")
-	}
-
-	remote := "origin"
-	if len(args) == 1 {
+	case len(args) == 1:
 		remote = args[0]
+	default:
+		v, err := repository.GetDefaultString("git-bug.remote", env.Repo.AnyConfig(), "origin")
+		if err != nil {
+			return err
+		}
+		remote = v
 	}
 
 	env.Out.Println("Fetching remote ...")

commands/push.go 🔗

@@ -7,6 +7,7 @@ import (
 
 	"github.com/git-bug/git-bug/commands/completion"
 	"github.com/git-bug/git-bug/commands/execenv"
+	"github.com/git-bug/git-bug/repository"
 )
 
 func newPushCommand(env *execenv.Env) *cobra.Command {
@@ -24,13 +25,18 @@ func newPushCommand(env *execenv.Env) *cobra.Command {
 }
 
 func runPush(env *execenv.Env, args []string) error {
-	if len(args) > 1 {
+	var remote string
+	switch {
+	case len(args) > 1:
 		return errors.New("Only pushing to one remote at a time is supported")
-	}
-
-	remote := "origin"
-	if len(args) == 1 {
+	case len(args) == 1:
 		remote = args[0]
+	default:
+		v, err := repository.GetDefaultString("git-bug.remote", env.Repo.AnyConfig(), "origin")
+		if err != nil {
+			return err
+		}
+		remote = v
 	}
 
 	stdout, err := env.Backend.Push(remote)

repository/config.go 🔗

@@ -60,6 +60,17 @@ type ConfigWrite interface {
 	RemoveAll(keyPrefix string) error
 }
 
+func GetDefaultString(key string, cfg ConfigRead, def string) (string, error) {
+	val, err := cfg.ReadString(key)
+	if err == nil {
+		return val, nil
+	} else if errors.Is(err, ErrNoConfigEntry) {
+		return def, nil
+	} else {
+		return "", err
+	}
+}
+
 func ParseTimestamp(s string) (time.Time, error) {
 	timestamp, err := strconv.Atoi(s)
 	if err != nil {

repository/config_test.go 🔗

@@ -4,6 +4,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
 
@@ -52,3 +53,34 @@ func TestMergedConfig(t *testing.T) {
 		"timestamp": "5678",
 	})
 }
+
+func TestGetDefaultString(t *testing.T) {
+	cfg := NewMemConfig()
+
+	// Test with missing key - should return default
+	val, err := GetDefaultString("missing.key", cfg, "default_value")
+	require.NoError(t, err)
+	assert.Equal(t, "default_value", val)
+
+	// Test with existing key - should return actual value
+	require.NoError(t, cfg.StoreString("existing.key", "actual_value"))
+	val, err = GetDefaultString("existing.key", cfg, "default_value")
+	require.NoError(t, err)
+	assert.Equal(t, "actual_value", val)
+
+	// Test with empty string value - should return empty string, not default
+	require.NoError(t, cfg.StoreString("empty.key", ""))
+	val, err = GetDefaultString("empty.key", cfg, "default_value")
+	require.NoError(t, err)
+	assert.Equal(t, "", val)
+
+	// Test the specific git-bug.remote case
+	val, err = GetDefaultString("git-bug.remote", cfg, "origin")
+	require.NoError(t, err)
+	assert.Equal(t, "origin", val)
+
+	require.NoError(t, cfg.StoreString("git-bug.remote", "upstream"))
+	val, err = GetDefaultString("git-bug.remote", cfg, "origin")
+	require.NoError(t, err)
+	assert.Equal(t, "upstream", val)
+}