Update interactive restore flow with snapshot picker and overwrite prompt

Amolith created

Change summary

cmd/root.go                            | 107 ++++++++++++++++++++++++---
internal/form/form.go                  |  39 +++------
internal/form/snapshots.go             | 104 +++++++++++++++++++++++++++
internal/restic/list_snapshots.go      |  20 ++++-
internal/restic/list_snapshots_test.go |  59 ++++++++++++--
5 files changed, 277 insertions(+), 52 deletions(-)

Detailed changes

cmd/root.go 🔗

@@ -2,6 +2,7 @@ package cmd
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"os"
 	"slices"
@@ -9,6 +10,7 @@ import (
 
 	tea "charm.land/bubbletea/v2"
 	"charm.land/fang/v2"
+	"charm.land/huh/v2/spinner"
 	"github.com/spf13/cobra"
 
 	"git.secluded.site/keld/internal/config"
@@ -265,29 +267,108 @@ func promptForCommand(command string, cfg *config.ResolvedConfig) (map[string][]
 	}
 }
 
-// promptRestore collects the snapshot ID and target directory for restore,
-// skipping prompts for values already present in the resolved config.
+// promptRestore collects the snapshot ID, target directory, and overwrite
+// behavior for restore, skipping prompts for values already present in
+// the resolved config.
+//
+// Snapshot selection uses an interactive picker when the resolved config
+// has a repository configured; otherwise falls back to a text input.
+// Auth/network errors during snapshot listing produce a terse note and
+// fall back to text input. User cancellation aborts entirely.
 func promptRestore(cfg *config.ResolvedConfig) (map[string][]string, error) {
-	hasSnapshotID := len(cfg.Arguments) > 0
-	hasTarget := cfg.HasFlag("target")
+	overrides := make(map[string][]string)
+
+	// Step 1: Snapshot ID.
+	if len(cfg.Arguments) == 0 {
+		snapshotID, err := promptSnapshotID(cfg)
+		if err != nil {
+			return nil, err
+		}
+		if snapshotID == "" {
+			return nil, fmt.Errorf("snapshot ID is required")
+		}
+		overrides[overrideArgumentsKey] = []string{snapshotID}
+	}
+
+	// Step 2: Target directory.
+	if !cfg.HasFlag("target") {
+		target, err := form.TargetDirectory()
+		if err != nil {
+			return nil, fmt.Errorf("target directory: %w", err)
+		}
+		overrides["target"] = []string{target}
+	}
 
-	if hasSnapshotID && hasTarget {
+	// Step 3: Overwrite behavior.
+	if !cfg.HasFlag("overwrite") {
+		overwrite, err := form.SelectOverwrite()
+		if err != nil {
+			return nil, fmt.Errorf("overwrite selection: %w", err)
+		}
+		overrides["overwrite"] = []string{overwrite}
+	}
+
+	if len(overrides) == 0 {
 		return nil, nil
 	}
+	return overrides, nil
+}
 
-	snapshotID, target, err := form.RestoreInputs(hasSnapshotID, hasTarget)
+// promptSnapshotID attempts to show an interactive snapshot picker using
+// the repository from the resolved config. Falls back to a manual text
+// input when:
+//   - No repository is configured (silent fallback)
+//   - Snapshot listing fails (prints a terse note, then text input)
+//   - No snapshots exist (prints a note, then text input)
+//
+// If the user picks "Enter ID manually…" from the picker, switches to
+// the text input. User cancellation (ctrl+c) always aborts entirely.
+func promptSnapshotID(cfg *config.ResolvedConfig) (string, error) {
+	var snapshots []restic.Snapshot
+	var listErr error
+
+	err := spinner.New().
+		Title("Loading snapshots…").
+		Action(func() {
+			snapshots, listErr = restic.ListSnapshots(cfg)
+		}).
+		Run()
 	if err != nil {
-		return nil, fmt.Errorf("restore inputs: %w", err)
+		// Spinner itself failed (unlikely); fall back gracefully.
+		return form.ManualSnapshotID()
 	}
 
-	overrides := make(map[string][]string)
-	if !hasSnapshotID {
-		overrides[overrideArgumentsKey] = []string{snapshotID}
+	if listErr != nil {
+		// Silent fallback for "no repo configured"; terse note for
+		// anything else (auth failure, network error, bad JSON, etc.)
+		if !isNoRepoError(listErr) {
+			fmt.Fprintf(os.Stderr, "Could not list snapshots: %v\n", listErr)
+		}
+		return form.ManualSnapshotID()
 	}
-	if !hasTarget {
-		overrides["target"] = []string{target}
+
+	if len(snapshots) == 0 {
+		fmt.Fprintln(os.Stderr, "No snapshots found in repository.")
+		return form.ManualSnapshotID()
 	}
-	return overrides, nil
+
+	selected, err := form.SelectSnapshot(snapshots)
+	if err != nil {
+		return "", err
+	}
+
+	if form.IsManualEntry(selected) {
+		return form.ManualSnapshotID()
+	}
+
+	return selected, nil
+}
+
+// isNoRepoError checks whether the error from ListSnapshots indicates
+// that no repository was configured (as opposed to an auth/network/etc.
+// failure).
+func isNoRepoError(err error) bool {
+	return errors.Is(err, restic.ErrNoRepo)
 }
 
 // promptBackup collects backup paths when none are configured in the preset.

internal/form/form.go 🔗

@@ -49,34 +49,23 @@ func SelectPreset(presets []string) (string, error) {
 	return selected, nil
 }
 
-// RestoreInputs collects the required inputs for `restic restore`:
-// a snapshot ID and a target directory. Fields whose corresponding
-// "has" parameter is true are skipped (already provided by config).
-func RestoreInputs(hasSnapshotID, hasTarget bool) (snapshotID, target string, err error) {
-	var fields []huh.Field
-	if !hasSnapshotID {
-		fields = append(fields, huh.NewInput().
-			Title("Snapshot ID").
-			Placeholder("e.g. latest or a1b2c3d4").
-			Value(&snapshotID).
-			Validate(notEmpty("snapshot ID")))
-	}
-	if !hasTarget {
-		fields = append(fields, huh.NewInput().
-			Title("Target directory").
-			Placeholder("e.g. /tmp/restore").
-			Value(&target).
-			Validate(notEmpty("target directory")))
-	}
-	if len(fields) == 0 {
-		return "", "", nil
-	}
+// TargetDirectory prompts for a restore target directory.
+func TargetDirectory() (string, error) {
+	var target string
+	form := huh.NewForm(
+		huh.NewGroup(
+			huh.NewInput().
+				Title("Target directory").
+				Placeholder("e.g. /tmp/restore").
+				Value(&target).
+				Validate(notEmpty("target directory")),
+		),
+	)
 
-	form := huh.NewForm(huh.NewGroup(fields...))
 	if err := wrapAbort(form.Run()); err != nil {
-		return "", "", err
+		return "", err
 	}
-	return snapshotID, target, nil
+	return target, nil
 }
 
 // BackupPaths collects one or more paths to back up when none are configured.

internal/form/snapshots.go 🔗

@@ -0,0 +1,104 @@
+package form
+
+import (
+	"charm.land/huh/v2"
+
+	"git.secluded.site/keld/internal/restic"
+)
+
+// manualEntryValue is the sentinel returned by SelectSnapshot when the
+// user chooses to type a snapshot ID manually instead of picking from
+// the list.
+const manualEntryValue = "__manual__"
+
+// IsManualEntry reports whether the value returned by SelectSnapshot
+// indicates the user chose manual entry.
+func IsManualEntry(v string) bool {
+	return v == manualEntryValue
+}
+
+// SelectSnapshot presents an interactive picker for the given snapshots.
+// Each snapshot is shown as a formatted summary line; the selected
+// snapshot's short ID is returned.
+//
+// An "Enter ID manually…" option is always included at the end so users
+// can type arbitrary IDs (including the snapshotID:subfolder syntax).
+//
+// Returns the selected short ID, or the manualEntryValue sentinel
+// (check with IsManualEntry), or ErrAborted if the user cancels.
+func SelectSnapshot(snapshots []restic.Snapshot) (string, error) {
+	opts := make([]huh.Option[string], 0, len(snapshots)+1)
+	for _, s := range snapshots {
+		opts = append(opts, huh.NewOption(restic.FormatSnapshotLine(s), s.ShortID))
+	}
+	opts = append(opts, huh.NewOption("Enter ID manually…", manualEntryValue))
+
+	var selected string
+	form := huh.NewForm(
+		huh.NewGroup(
+			huh.NewSelect[string]().
+				Title("Select a snapshot").
+				Options(opts...).
+				Filtering(true).
+				Value(&selected),
+		),
+	)
+
+	if err := wrapAbort(form.Run()); err != nil {
+		return "", err
+	}
+	return selected, nil
+}
+
+// ManualSnapshotID prompts the user to type a snapshot ID. This is used
+// as the fallback when snapshot listing fails or when the user chooses
+// "Enter ID manually…" from the picker.
+func ManualSnapshotID() (string, error) {
+	var id string
+	form := huh.NewForm(
+		huh.NewGroup(
+			huh.NewInput().
+				Title("Snapshot ID").
+				Description("Supports snapshotID:subfolder syntax, or \"latest\".").
+				Placeholder("e.g. latest or a1b2c3d4 or a1b2c3d4:/home/user").
+				Value(&id).
+				Validate(notEmpty("snapshot ID")),
+		),
+	)
+
+	if err := wrapAbort(form.Run()); err != nil {
+		return "", err
+	}
+	return id, nil
+}
+
+// overwriteOptions defines the choices for --overwrite in the order
+// they're presented to the user.
+var overwriteOptions = []huh.Option[string]{
+	huh.NewOption("if-changed  (recommended — only restore what differs)", "if-changed"),
+	huh.NewOption("if-newer    (only overwrite older files)", "if-newer"),
+	huh.NewOption("never       (skip existing files entirely)", "never"),
+	huh.NewOption("always      (restic default — overwrite everything)", "always"),
+}
+
+// SelectOverwrite presents an interactive picker for the --overwrite
+// behavior. The cursor starts on "if-changed".
+//
+// Returns the selected value (one of: if-changed, if-newer, never,
+// always), or ErrAborted if the user cancels.
+func SelectOverwrite() (string, error) {
+	var selected string
+	form := huh.NewForm(
+		huh.NewGroup(
+			huh.NewSelect[string]().
+				Title("Overwrite existing files?").
+				Options(overwriteOptions...).
+				Value(&selected),
+		),
+	)
+
+	if err := wrapAbort(form.Run()); err != nil {
+		return "", err
+	}
+	return selected, nil
+}

internal/restic/list_snapshots.go 🔗

@@ -2,14 +2,20 @@ package restic
 
 import (
 	"bytes"
+	"errors"
 	"fmt"
 	"maps"
+	"os"
 	"os/exec"
 	"sort"
 
 	"git.secluded.site/keld/internal/config"
 )
 
+// ErrNoRepo is returned by ListSnapshots when no repository is configured
+// via flags, config environ, or process environment.
+var ErrNoRepo = errors.New("no repository configured (need --repo, --repository-file, or RESTIC_REPOSITORY)")
+
 // globalFlags is the set of restic global flags that should be forwarded
 // when running `restic snapshots` to list available snapshots. These are
 // the connection, auth, cache, and TLS flags shared across all restic
@@ -43,9 +49,10 @@ var globalFlags = map[string]bool{
 	"--tls-client-cert":       true,
 }
 
-// hasRepoSource reports whether the resolved config provides a repository
-// location, either via a flag (--repo, -r, --repository-file) or via
-// the RESTIC_REPOSITORY / RESTIC_REPOSITORY_FILE environment variables.
+// hasRepoSource reports whether a repository location is available from
+// any source: CLI flags in the resolved config, the config's environ
+// section, or the process environment (e.g. RESTIC_REPOSITORY already
+// set in the user's shell).
 func hasRepoSource(cfg *config.ResolvedConfig) bool {
 	for _, f := range cfg.Flags {
 		switch f.Name {
@@ -61,6 +68,11 @@ func hasRepoSource(cfg *config.ResolvedConfig) bool {
 			return true
 		}
 	}
+	// Check the process environment as a last resort — the user may
+	// have RESTIC_REPOSITORY set in their shell outside of keld config.
+	if os.Getenv("RESTIC_REPOSITORY") != "" || os.Getenv("RESTIC_REPOSITORY_FILE") != "" {
+		return true
+	}
 	return false
 }
 
@@ -70,7 +82,7 @@ func hasRepoSource(cfg *config.ResolvedConfig) bool {
 // is available.
 func buildSnapshotCmd(cfg *config.ResolvedConfig) ([]string, error) {
 	if !hasRepoSource(cfg) {
-		return nil, fmt.Errorf("no repository configured (need --repo, --repository-file, or RESTIC_REPOSITORY)")
+		return nil, ErrNoRepo
 	}
 
 	argv := []string{executable(), "snapshots", "--json"}

internal/restic/list_snapshots_test.go 🔗

@@ -1,6 +1,7 @@
 package restic
 
 import (
+	"errors"
 	"testing"
 
 	"git.secluded.site/keld/internal/config"
@@ -61,16 +62,6 @@ func TestBuildSnapshotCmd(t *testing.T) {
 				"--no-lock",
 			},
 		},
-		{
-			name: "no repo flag",
-			cfg: &config.ResolvedConfig{
-				Command: "restore",
-				Flags: []config.Flag{
-					{Name: "--target", Value: "/tmp/restore"},
-				},
-			},
-			wantErr: true,
-		},
 		{
 			name: "repo via repository-file",
 			cfg: &config.ResolvedConfig{
@@ -168,6 +159,54 @@ func TestBuildSnapshotCmd(t *testing.T) {
 	}
 }
 
+func TestBuildSnapshotCmdNoRepoSentinel(t *testing.T) {
+	// Not parallel: uses t.Setenv to control process environment.
+	t.Setenv("RESTIC_REPOSITORY", "")
+	t.Setenv("RESTIC_REPOSITORY_FILE", "")
+
+	cfg := &config.ResolvedConfig{
+		Command: "restore",
+		Flags: []config.Flag{
+			{Name: "--target", Value: "/tmp/restore"},
+		},
+	}
+
+	_, err := buildSnapshotCmd(cfg)
+	if !errors.Is(err, ErrNoRepo) {
+		t.Errorf("expected ErrNoRepo, got %v", err)
+	}
+}
+
+func TestBuildSnapshotCmdProcessEnv(t *testing.T) {
+	// Not parallel: uses t.Setenv to control process environment.
+	t.Setenv("RESTIC_REPOSITORY", "/srv/from-env")
+	t.Setenv("RESTIC_REPOSITORY_FILE", "")
+
+	cfg := &config.ResolvedConfig{
+		Command: "restore",
+		Flags: []config.Flag{
+			{Name: "--target", Value: "/tmp/restore"},
+		},
+	}
+
+	argv, err := buildSnapshotCmd(cfg)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// No --repo flag in argv since repo comes from process env;
+	// restic reads it directly.
+	wantArgv := []string{"restic", "snapshots", "--json"}
+	if len(argv) != len(wantArgv) {
+		t.Fatalf("argv: got %v, want %v", argv, wantArgv)
+	}
+	for i := range argv {
+		if argv[i] != wantArgv[i] {
+			t.Errorf("argv[%d]: got %q, want %q", i, argv[i], wantArgv[i])
+		}
+	}
+}
+
 func TestCopyEnviron(t *testing.T) {
 	t.Parallel()