diff --git a/commands/pull.go b/commands/pull.go index 90f23bfae9ada1cd1c88bb0c33e6c56be282b118..99fae8e86ffe24b820ae23148f3b1d12b205d8e4 100644 --- a/commands/pull.go +++ b/commands/pull.go @@ -8,6 +8,7 @@ import ( "github.com/git-bug/git-bug/commands/completion" "github.com/git-bug/git-bug/commands/execenv" "github.com/git-bug/git-bug/entity" + "github.com/git-bug/git-bug/repository" ) func newPullCommand(env *execenv.Env) *cobra.Command { @@ -25,13 +26,18 @@ func newPullCommand(env *execenv.Env) *cobra.Command { } func runPull(env *execenv.Env, args []string) error { - if len(args) > 1 { + var remote string + switch { + case len(args) > 1: return errors.New("Only pulling from one remote at a time is supported") - } - - remote := "origin" - if len(args) == 1 { + case len(args) == 1: remote = args[0] + default: + v, err := repository.GetDefaultString("git-bug.remote", env.Repo.AnyConfig(), "origin") + if err != nil { + return err + } + remote = v } env.Out.Println("Fetching remote ...") diff --git a/commands/push.go b/commands/push.go index 384925178a4cbc608d5820e9776081788fd35828..7024bd794edd60b1b243c8a86ec317f5593bf16c 100644 --- a/commands/push.go +++ b/commands/push.go @@ -7,6 +7,7 @@ import ( "github.com/git-bug/git-bug/commands/completion" "github.com/git-bug/git-bug/commands/execenv" + "github.com/git-bug/git-bug/repository" ) func newPushCommand(env *execenv.Env) *cobra.Command { @@ -24,13 +25,18 @@ func newPushCommand(env *execenv.Env) *cobra.Command { } func runPush(env *execenv.Env, args []string) error { - if len(args) > 1 { + var remote string + switch { + case len(args) > 1: return errors.New("Only pushing to one remote at a time is supported") - } - - remote := "origin" - if len(args) == 1 { + case len(args) == 1: remote = args[0] + default: + v, err := repository.GetDefaultString("git-bug.remote", env.Repo.AnyConfig(), "origin") + if err != nil { + return err + } + remote = v } stdout, err := env.Backend.Push(remote) diff --git a/repository/config.go b/repository/config.go index 7e1ee6e85ec6e45dfa224e7df7e751f5cfc8af23..109ee8772241f5021bc5946db8814a6bdace948b 100644 --- a/repository/config.go +++ b/repository/config.go @@ -60,6 +60,17 @@ type ConfigWrite interface { RemoveAll(keyPrefix string) error } +func GetDefaultString(key string, cfg ConfigRead, def string) (string, error) { + val, err := cfg.ReadString(key) + if err == nil { + return val, nil + } else if errors.Is(err, ErrNoConfigEntry) { + return def, nil + } else { + return "", err + } +} + func ParseTimestamp(s string) (time.Time, error) { timestamp, err := strconv.Atoi(s) if err != nil { diff --git a/repository/config_test.go b/repository/config_test.go index 2a76354008c0c37b28a8b034b75e9fcb5fea65f5..022c111af64c7f574972487858721c15f5231f8e 100644 --- a/repository/config_test.go +++ b/repository/config_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -52,3 +53,34 @@ func TestMergedConfig(t *testing.T) { "timestamp": "5678", }) } + +func TestGetDefaultString(t *testing.T) { + cfg := NewMemConfig() + + // Test with missing key - should return default + val, err := GetDefaultString("missing.key", cfg, "default_value") + require.NoError(t, err) + assert.Equal(t, "default_value", val) + + // Test with existing key - should return actual value + require.NoError(t, cfg.StoreString("existing.key", "actual_value")) + val, err = GetDefaultString("existing.key", cfg, "default_value") + require.NoError(t, err) + assert.Equal(t, "actual_value", val) + + // Test with empty string value - should return empty string, not default + require.NoError(t, cfg.StoreString("empty.key", "")) + val, err = GetDefaultString("empty.key", cfg, "default_value") + require.NoError(t, err) + assert.Equal(t, "", val) + + // Test the specific git-bug.remote case + val, err = GetDefaultString("git-bug.remote", cfg, "origin") + require.NoError(t, err) + assert.Equal(t, "origin", val) + + require.NoError(t, cfg.StoreString("git-bug.remote", "upstream")) + val, err = GetDefaultString("git-bug.remote", cfg, "origin") + require.NoError(t, err) + assert.Equal(t, "upstream", val) +}