feat: add self-update command with version check notification

Amolith created

Adds `crush update` command to check for updates and `crush update apply`
to download and install the latest version. Also displays update
notifications when running `crush --version`.

Assisted-by: Claude Opus 4.5 via Crush <crush@charm.land>

Change summary

internal/cmd/root.go           |  44 ++
internal/cmd/update.go         | 165 +++++++++++
internal/update/update.go      | 525 +++++++++++++++++++++++++++++++++++
internal/update/update_test.go | 334 ++++++++++++++++++++++
4 files changed, 1,062 insertions(+), 6 deletions(-)

Detailed changes

internal/cmd/root.go 🔗

@@ -11,6 +11,7 @@ import (
 	"path/filepath"
 	"strconv"
 	"strings"
+	"time"
 
 	tea "charm.land/bubbletea/v2"
 	"charm.land/lipgloss/v2"
@@ -22,6 +23,7 @@ import (
 	"github.com/charmbracelet/crush/internal/stringext"
 	termutil "github.com/charmbracelet/crush/internal/term"
 	"github.com/charmbracelet/crush/internal/tui"
+	"github.com/charmbracelet/crush/internal/update"
 	"github.com/charmbracelet/crush/internal/version"
 	"github.com/charmbracelet/fang"
 	uv "github.com/charmbracelet/ultraviolet"
@@ -41,6 +43,7 @@ func init() {
 	rootCmd.AddCommand(
 		runCmd,
 		dirsCmd,
+		updateCmd,
 		updateProvidersCmd,
 		logsCmd,
 		schemaCmd,
@@ -135,13 +138,24 @@ func Execute() {
 	// printing the version, and PreRunE runs after the version is already
 	// handled, so that doesn't work either.
 	// This is the only way I could find that works relatively well.
+	versionTemplate := defaultVersionTemplate
 	if term.IsTerminal(os.Stdout.Fd()) {
 		var b bytes.Buffer
 		w := colorprofile.NewWriter(os.Stdout, os.Environ())
 		w.Forward = &b
 		_, _ = w.WriteString(heartbit.String())
-		rootCmd.SetVersionTemplate(b.String() + "\n" + defaultVersionTemplate)
+		versionTemplate = b.String() + "\n" + defaultVersionTemplate
 	}
+
+	// Check if version flag is present and add update notification if available.
+	if hasVersionFlag() {
+		if updateMsg := checkForUpdateSync(); updateMsg != "" {
+			versionTemplate += updateMsg
+		}
+	}
+
+	rootCmd.SetVersionTemplate(versionTemplate)
+
 	if err := fang.Execute(
 		context.Background(),
 		rootCmd,
@@ -152,6 +166,34 @@ func Execute() {
 	}
 }
 
+// hasVersionFlag checks if the version flag is present in os.Args.
+func hasVersionFlag() bool {
+	for _, arg := range os.Args {
+		if arg == "-v" || arg == "--version" {
+			return true
+		}
+	}
+	return false
+}
+
+// checkForUpdateSync performs a synchronous update check with a short timeout.
+// Returns a formatted update message if an update is available, empty string otherwise.
+func checkForUpdateSync() string {
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	info, err := update.Check(ctx, version.Version, update.Default)
+	if err != nil || !info.Available() {
+		return ""
+	}
+
+	if info.IsDevelopment() {
+		return fmt.Sprintf("\nThis is a development version of Crush. The latest stable release is v%s.\nRun 'crush update' to learn more.\n", info.Latest)
+	}
+
+	return fmt.Sprintf("\nUpdate available: v%s → v%s\nRun 'crush update apply' to install.\n", info.Current, info.Latest)
+}
+
 func setupAppWithProgressBar(cmd *cobra.Command) (*app.App, error) {
 	if termutil.SupportsProgressBar() {
 		_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)

internal/cmd/update.go 🔗

@@ -0,0 +1,165 @@
+package cmd
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"time"
+
+	"charm.land/lipgloss/v2"
+	"github.com/charmbracelet/crush/internal/format"
+	"github.com/charmbracelet/crush/internal/tui/components/anim"
+	"github.com/charmbracelet/crush/internal/tui/styles"
+	"github.com/charmbracelet/crush/internal/update"
+	"github.com/charmbracelet/crush/internal/version"
+	"github.com/charmbracelet/x/term"
+	"github.com/spf13/cobra"
+)
+
+var updateCmd = &cobra.Command{
+	Use:   "update",
+	Short: "Check for and apply updates",
+	Long: `Check if a new version of Crush is available.
+Use 'update apply' to download and install the latest version.`,
+	Example: `
+# Check if an update is available
+crush update
+
+# Apply the update if available
+crush update apply
+  `,
+	RunE: func(cmd *cobra.Command, args []string) error {
+		ctx, cancel := context.WithTimeout(cmd.Context(), 30*time.Second)
+		defer cancel()
+
+		spinner := newUpdateSpinner(ctx, cancel, "Checking for updates")
+		spinner.Start()
+
+		info, err := update.Check(ctx, version.Version, update.Default)
+		spinner.Stop()
+
+		if err != nil {
+			return fmt.Errorf("failed to check for updates: %w", err)
+		}
+
+		if info.IsDevelopment() {
+			fmt.Fprintf(os.Stderr, "You are running a development version of Crush (%s).\n", info.Current)
+			fmt.Fprintf(os.Stderr, "The latest stable release is v%s.\n", info.Latest)
+			fmt.Fprintf(os.Stderr, "Visit %s to learn more.\n", info.URL)
+			return nil
+		}
+
+		if !info.Available() {
+			fmt.Fprintf(os.Stderr, "You are already running the latest version (v%s).\n", info.Current)
+			return nil
+		}
+
+		fmt.Fprintf(os.Stderr, "Update available: v%s → v%s\n", info.Current, info.Latest)
+		fmt.Fprintf(os.Stderr, "Run 'crush update apply' to install the latest version.\n")
+		fmt.Fprintf(os.Stderr, "Or visit %s to download manually.\n", info.URL)
+
+		return nil
+	},
+}
+
+var updateApplyCmd = &cobra.Command{
+	Use:   "apply",
+	Short: "Apply the latest update",
+	Long:  `Download and install the latest version of Crush.`,
+	Example: `
+# Apply the latest update
+crush update apply
+  `,
+	RunE: func(cmd *cobra.Command, args []string) error {
+		ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Minute)
+		defer cancel()
+
+		spinner := newUpdateSpinner(ctx, cancel, "Checking for updates")
+		spinner.Start()
+
+		info, err := update.Check(ctx, version.Version, update.Default)
+		if err != nil {
+			spinner.Stop()
+			return fmt.Errorf("failed to check for updates: %w", err)
+		}
+
+		if info.IsDevelopment() {
+			spinner.Stop()
+			return fmt.Errorf("cannot update development versions automatically")
+		}
+
+		if !info.Available() {
+			spinner.Stop()
+			fmt.Fprintf(os.Stderr, "You are already running the latest version (v%s).\n", info.Current)
+			return nil
+		}
+
+		// Get the latest release with assets.
+		release, err := update.Default.Latest(ctx)
+		if err != nil {
+			spinner.Stop()
+			return fmt.Errorf("failed to fetch release information: %w", err)
+		}
+
+		// Find the appropriate asset for this platform.
+		asset, err := update.FindAsset(release.Assets)
+		if err != nil {
+			spinner.Stop()
+			return fmt.Errorf("failed to find update for your platform: %w", err)
+		}
+
+		spinner.Stop()
+		spinner = newUpdateSpinner(ctx, cancel, fmt.Sprintf("Downloading v%s", info.Latest))
+		spinner.Start()
+
+		// Download the asset.
+		binaryPath, err := update.Download(ctx, asset, release)
+		if err != nil {
+			spinner.Stop()
+			return fmt.Errorf("failed to download update: %w", err)
+		}
+		defer os.Remove(binaryPath)
+
+		spinner.Stop()
+		spinner = newUpdateSpinner(ctx, cancel, "Installing")
+		spinner.Start()
+
+		// Apply the update.
+		if err := update.Apply(binaryPath); err != nil {
+			spinner.Stop()
+			return fmt.Errorf("failed to apply update: %w", err)
+		}
+
+		spinner.Stop()
+
+		fmt.Fprintf(os.Stderr, "Successfully updated to v%s!\n", info.Latest)
+		fmt.Fprintf(os.Stderr, "Run 'crush -v' to verify the new version.\n")
+
+		return nil
+	},
+}
+
+// newUpdateSpinner creates a spinner for update operations.
+func newUpdateSpinner(ctx context.Context, cancel context.CancelFunc, label string) *format.Spinner {
+	t := styles.CurrentTheme()
+
+	// Detect background color for appropriate text color.
+	hasDarkBG := true
+	if term.IsTerminal(os.Stderr.Fd()) {
+		hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, os.Stderr)
+	}
+	defaultFG := lipgloss.LightDark(hasDarkBG)(lipgloss.Color("#fafafa"), t.FgBase)
+
+	return format.NewSpinner(ctx, cancel, anim.Settings{
+		Size:        10,
+		Label:       label,
+		LabelColor:  defaultFG,
+		GradColorA:  t.Primary,
+		GradColorB:  t.Secondary,
+		CycleColors: true,
+	})
+}
+
+func init() {
+	updateCmd.AddCommand(updateApplyCmd)
+}

internal/update/update.go 🔗

@@ -1,21 +1,42 @@
 package update
 
 import (
+	"archive/tar"
+	"archive/zip"
+	"bufio"
+	"compress/gzip"
 	"context"
+	"crypto/sha256"
+	"debug/elf"
+	"debug/macho"
+	"debug/pe"
+	"encoding/hex"
 	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
+	"os"
+	"path/filepath"
 	"regexp"
+	"runtime"
 	"strings"
 	"time"
+
+	"github.com/charmbracelet/crush/internal/version"
 )
 
 const (
-	githubApiUrl = "https://api.github.com/repos/charmbracelet/crush/releases/latest"
-	userAgent    = "crush/1.0"
+	githubApiUrl     = "https://api.github.com/repos/charmbracelet/crush/releases/latest"
+	maxBinarySize    = 500 * 1024 * 1024 // 500MB max for extracted binary
+	maxArchiveSize   = 500 * 1024 * 1024 // 500MB max for downloaded archive
+	maxChecksumsSize = 1 * 1024 * 1024   // 1MB max for checksums.txt
 )
 
+// userAgent returns the user agent string for HTTP requests.
+func userAgent() string {
+	return "crush/" + version.Version
+}
+
 // Default is the default [Client].
 var Default Client = &github{}
 
@@ -72,10 +93,17 @@ func Check(ctx context.Context, current string, client Client) (Info, error) {
 	return info, nil
 }
 
+// Asset represents a GitHub release asset.
+type Asset struct {
+	Name               string `json:"name"`
+	BrowserDownloadURL string `json:"browser_download_url"`
+}
+
 // Release represents a GitHub release.
 type Release struct {
-	TagName string `json:"tag_name"`
-	HTMLURL string `json:"html_url"`
+	TagName string  `json:"tag_name"`
+	HTMLURL string  `json:"html_url"`
+	Assets  []Asset `json:"assets"`
 }
 
 // Client is a client that can get the latest release.
@@ -95,7 +123,7 @@ func (c *github) Latest(ctx context.Context) (*Release, error) {
 	if err != nil {
 		return nil, err
 	}
-	req.Header.Set("User-Agent", userAgent)
+	req.Header.Set("User-Agent", userAgent())
 	req.Header.Set("Accept", "application/vnd.github.v3+json")
 
 	resp, err := client.Do(req)
@@ -116,3 +144,490 @@ func (c *github) Latest(ctx context.Context) (*Release, error) {
 
 	return &release, nil
 }
+
+// FindAsset finds the appropriate asset for the current platform.
+func FindAsset(assets []Asset) (*Asset, error) {
+	// Normalize architecture to match goreleaser naming.
+	arch := runtime.GOARCH
+	switch arch {
+	case "amd64":
+		arch = "x86_64"
+	case "386":
+		arch = "i386"
+	case "arm":
+		arch = "armv7"
+		// arm64 stays as "arm64" in goreleaser naming.
+	}
+
+	// Normalize OS to match goreleaser naming (title case).
+	goos := runtime.GOOS
+	switch goos {
+	case "freebsd":
+		goos = "Freebsd"
+	case "netbsd":
+		goos = "Netbsd"
+	case "openbsd":
+		goos = "Openbsd"
+	default:
+		if len(goos) > 0 {
+			goos = strings.ToUpper(goos[:1]) + goos[1:]
+		}
+	}
+
+	// Look for archive matching our platform.
+	// Pattern: crush_{version}_{OS}_{ARCH}.{tar.gz|zip}
+	for _, asset := range assets {
+		if strings.Contains(asset.Name, goos) && strings.Contains(asset.Name, arch) {
+			// Ensure it's an archive, not a checksum or signature.
+			if strings.HasSuffix(asset.Name, ".tar.gz") || strings.HasSuffix(asset.Name, ".zip") {
+				return &asset, nil
+			}
+		}
+	}
+
+	return nil, fmt.Errorf("no suitable asset found for %s/%s", runtime.GOOS, runtime.GOARCH)
+}
+
+// Download downloads and extracts the crush binary from the given asset.
+// Returns the path to the extracted binary.
+func Download(ctx context.Context, asset *Asset, release *Release) (string, error) {
+	client := &http.Client{
+		Timeout: 5 * time.Minute,
+	}
+
+	// Find and download checksums.txt.
+	var checksumsAsset *Asset
+	for i := range release.Assets {
+		if release.Assets[i].Name == "checksums.txt" {
+			checksumsAsset = &release.Assets[i]
+			break
+		}
+	}
+	if checksumsAsset == nil {
+		return "", fmt.Errorf("checksums.txt not found in release")
+	}
+
+	checksums, err := downloadChecksums(ctx, client, checksumsAsset)
+	if err != nil {
+		return "", fmt.Errorf("failed to download checksums: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, "GET", asset.BrowserDownloadURL, nil)
+	if err != nil {
+		return "", err
+	}
+	req.Header.Set("User-Agent", userAgent())
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return "", fmt.Errorf("download failed with status %d", resp.StatusCode)
+	}
+
+	// Validate Content-Length if provided.
+	if resp.ContentLength > maxArchiveSize {
+		return "", fmt.Errorf("archive size %d exceeds maximum allowed size of %d bytes", resp.ContentLength, maxArchiveSize)
+	}
+
+	// Create temp file for archive.
+	tmpFile, err := os.CreateTemp("", "crush-update-*")
+	if err != nil {
+		return "", fmt.Errorf("failed to create temp file: %w", err)
+	}
+	defer os.Remove(tmpFile.Name())
+
+	// Download to temp file while computing checksum.
+	// Use LimitReader to prevent DoS from oversized downloads (Content-Length can be spoofed).
+	hash := sha256.New()
+	limitedBody := io.LimitReader(resp.Body, maxArchiveSize+1)
+	written, err := io.Copy(io.MultiWriter(tmpFile, hash), limitedBody)
+	if err != nil {
+		tmpFile.Close()
+		return "", fmt.Errorf("failed to download: %w", err)
+	}
+	if written > maxArchiveSize {
+		tmpFile.Close()
+		return "", fmt.Errorf("archive size %d exceeds maximum allowed size of %d bytes", written, maxArchiveSize)
+	}
+	tmpFile.Close()
+
+	// Verify checksum.
+	actualSum := hex.EncodeToString(hash.Sum(nil))
+	expectedSum, ok := checksums[asset.Name]
+	if !ok {
+		return "", fmt.Errorf("no checksum found for %s", asset.Name)
+	}
+	if actualSum != expectedSum {
+		return "", fmt.Errorf("checksum mismatch: expected %s, got %s", expectedSum, actualSum)
+	}
+
+	// Extract binary based on archive type.
+	var binaryPath string
+	if strings.HasSuffix(asset.Name, ".zip") {
+		binaryPath, err = extractZip(tmpFile.Name())
+	} else {
+		binaryPath, err = extractTarGz(tmpFile.Name())
+	}
+	if err != nil {
+		return "", fmt.Errorf("failed to extract: %w", err)
+	}
+
+	// Validate the extracted binary before returning.
+	if err := validateBinary(binaryPath); err != nil {
+		os.Remove(binaryPath)
+		return "", fmt.Errorf("invalid binary: %w", err)
+	}
+
+	return binaryPath, nil
+}
+
+// validateBinary checks that the file at path is a valid executable binary
+// for the current platform using the standard library debug packages.
+func validateBinary(path string) error {
+	switch runtime.GOOS {
+	case "windows":
+		return validatePE(path)
+	case "darwin":
+		return validateMachO(path)
+	default:
+		return validateELF(path)
+	}
+}
+
+func validateELF(path string) error {
+	f, err := elf.Open(path)
+	if err != nil {
+		return fmt.Errorf("not a valid ELF binary: %w", err)
+	}
+	defer f.Close()
+	if f.Type != elf.ET_EXEC && f.Type != elf.ET_DYN {
+		return fmt.Errorf("ELF file is not an executable (type: %v)", f.Type)
+	}
+	return nil
+}
+
+func validateMachO(path string) error {
+	f, err := macho.Open(path)
+	if err != nil {
+		return fmt.Errorf("not a valid Mach-O binary: %w", err)
+	}
+	defer f.Close()
+	if f.Type != macho.TypeExec {
+		return fmt.Errorf("Mach-O file is not an executable (type: %v)", f.Type)
+	}
+	return nil
+}
+
+func validatePE(path string) error {
+	f, err := pe.Open(path)
+	if err != nil {
+		return fmt.Errorf("not a valid PE binary: %w", err)
+	}
+	defer f.Close()
+	// PE files opened successfully are valid executables.
+	return nil
+}
+
+// downloadChecksums downloads and parses the checksums.txt file.
+// Returns a map of filename to sha256 checksum.
+func downloadChecksums(ctx context.Context, client *http.Client, asset *Asset) (map[string]string, error) {
+	req, err := http.NewRequestWithContext(ctx, "GET", asset.BrowserDownloadURL, nil)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("User-Agent", userAgent())
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("failed to download checksums with status %d", resp.StatusCode)
+	}
+
+	// Limit the size of checksums.txt to prevent memory exhaustion.
+	limitedBody := io.LimitReader(resp.Body, maxChecksumsSize+1)
+
+	checksums := make(map[string]string)
+	scanner := bufio.NewScanner(limitedBody)
+	for scanner.Scan() {
+		line := scanner.Text()
+		parts := strings.Fields(line)
+		if len(parts) != 2 {
+			continue
+		}
+		// parts[0] is checksum, parts[1] is filename.
+		checksums[parts[1]] = parts[0]
+	}
+	if err := scanner.Err(); err != nil {
+		return nil, err
+	}
+
+	return checksums, nil
+}
+
+// parseChecksumLines parses checksum lines from a string.
+// Used for testing the checksum parsing logic.
+func parseChecksumLines(content string) map[string]string {
+	checksums := make(map[string]string)
+	scanner := bufio.NewScanner(strings.NewReader(content))
+	for scanner.Scan() {
+		line := scanner.Text()
+		parts := strings.Fields(line)
+		if len(parts) != 2 {
+			continue
+		}
+		checksums[parts[1]] = parts[0]
+	}
+	return checksums
+}
+
+// extractZip extracts the crush binary from a zip archive.
+func extractZip(archivePath string) (string, error) {
+	r, err := zip.OpenReader(archivePath)
+	if err != nil {
+		return "", err
+	}
+	defer r.Close()
+
+	binaryName := "crush"
+	if runtime.GOOS == "windows" {
+		binaryName = "crush.exe"
+	}
+
+	for _, f := range r.File {
+		// Path traversal protection.
+		cleanName := filepath.Clean(f.Name)
+		if strings.Contains(cleanName, "..") {
+			continue
+		}
+
+		// Use exact name matching to avoid matching unintended files.
+		if filepath.Base(f.Name) == binaryName {
+			binaryPath, err := extractZipFile(f)
+			if err != nil {
+				return "", err
+			}
+			return binaryPath, nil
+		}
+	}
+
+	return "", fmt.Errorf("crush binary not found in archive")
+}
+
+// extractZipFile extracts a single file from a zip archive with size limits.
+func extractZipFile(f *zip.File) (string, error) {
+	rc, err := f.Open()
+	if err != nil {
+		return "", err
+	}
+	defer rc.Close()
+
+	tmpBinary, err := os.CreateTemp("", "crush-binary-*")
+	if err != nil {
+		return "", err
+	}
+
+	limitedReader := io.LimitReader(rc, maxBinarySize+1)
+	written, err := io.Copy(tmpBinary, limitedReader)
+	if err != nil {
+		tmpBinary.Close()
+		os.Remove(tmpBinary.Name())
+		return "", err
+	}
+	if written > maxBinarySize {
+		tmpBinary.Close()
+		os.Remove(tmpBinary.Name())
+		return "", fmt.Errorf("binary exceeds maximum size of %d bytes", maxBinarySize)
+	}
+
+	if err := tmpBinary.Chmod(0o755); err != nil {
+		tmpBinary.Close()
+		os.Remove(tmpBinary.Name())
+		return "", err
+	}
+
+	if err := tmpBinary.Close(); err != nil {
+		os.Remove(tmpBinary.Name())
+		return "", err
+	}
+
+	return tmpBinary.Name(), nil
+}
+
+// extractTarGz extracts the crush binary from a tar.gz archive.
+func extractTarGz(archivePath string) (string, error) {
+	f, err := os.Open(archivePath)
+	if err != nil {
+		return "", err
+	}
+	defer f.Close()
+
+	gzr, err := gzip.NewReader(f)
+	if err != nil {
+		return "", err
+	}
+	defer gzr.Close()
+
+	tr := tar.NewReader(gzr)
+
+	binaryName := "crush"
+	if runtime.GOOS == "windows" {
+		binaryName = "crush.exe"
+	}
+
+	for {
+		header, err := tr.Next()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return "", err
+		}
+
+		// Path traversal protection.
+		cleanName := filepath.Clean(header.Name)
+		if strings.Contains(cleanName, "..") {
+			continue
+		}
+
+		// Use exact name matching to avoid matching unintended files.
+		if filepath.Base(header.Name) == binaryName {
+			binaryPath, err := extractTarFile(tr)
+			if err != nil {
+				return "", err
+			}
+			return binaryPath, nil
+		}
+	}
+
+	return "", fmt.Errorf("crush binary not found in archive")
+}
+
+// extractTarFile extracts a single file from a tar reader with size limits.
+func extractTarFile(tr *tar.Reader) (string, error) {
+	tmpBinary, err := os.CreateTemp("", "crush-binary-*")
+	if err != nil {
+		return "", err
+	}
+
+	limitedReader := io.LimitReader(tr, maxBinarySize+1)
+	written, err := io.Copy(tmpBinary, limitedReader)
+	if err != nil {
+		tmpBinary.Close()
+		os.Remove(tmpBinary.Name())
+		return "", err
+	}
+	if written > maxBinarySize {
+		tmpBinary.Close()
+		os.Remove(tmpBinary.Name())
+		return "", fmt.Errorf("binary exceeds maximum size of %d bytes", maxBinarySize)
+	}
+
+	if err := tmpBinary.Chmod(0o755); err != nil {
+		tmpBinary.Close()
+		os.Remove(tmpBinary.Name())
+		return "", err
+	}
+
+	if err := tmpBinary.Close(); err != nil {
+		os.Remove(tmpBinary.Name())
+		return "", err
+	}
+
+	return tmpBinary.Name(), nil
+}
+
+// Apply replaces the current executable with the downloaded binary.
+func Apply(binaryPath string) error {
+	// Get path to current executable.
+	exe, err := os.Executable()
+	if err != nil {
+		return fmt.Errorf("failed to get executable path: %w", err)
+	}
+
+	// Resolve symlinks.
+	exe, err = filepath.EvalSymlinks(exe)
+	if err != nil {
+		return fmt.Errorf("failed to resolve symlinks: %w", err)
+	}
+
+	// Get the directory of the executable.
+	exeDir := filepath.Dir(exe)
+
+	// Check if we have write permissions to the directory.
+	if err := checkWritePermission(exeDir); err != nil {
+		return fmt.Errorf("cannot write to %s: %w (you may need to run with elevated privileges)", exeDir, err)
+	}
+
+	// Copy binary to exe directory first to ensure same filesystem.
+	// os.Rename fails across filesystems, and binaryPath may be in /tmp.
+	localBinary := filepath.Join(exeDir, ".crush-update-new")
+	if err := copyFile(binaryPath, localBinary); err != nil {
+		return fmt.Errorf("failed to copy new binary: %w", err)
+	}
+
+	// Create a backup of the current executable.
+	backupPath := filepath.Join(exeDir, filepath.Base(exe)+".old")
+	if err := os.Rename(exe, backupPath); err != nil {
+		os.Remove(localBinary)
+		return fmt.Errorf("failed to backup current executable: %w", err)
+	}
+
+	// Move new binary to executable location.
+	if err := os.Rename(localBinary, exe); err != nil {
+		// Try to restore backup on failure.
+		_ = os.Rename(backupPath, exe)
+		os.Remove(localBinary)
+		return fmt.Errorf("failed to install new version: %w", err)
+	}
+
+	// Remove backup on success.
+	_ = os.Remove(backupPath)
+
+	return nil
+}
+
+// checkWritePermission checks if we can write to the given directory.
+func checkWritePermission(dir string) error {
+	testFile := filepath.Join(dir, ".crush-update-test")
+	f, err := os.Create(testFile)
+	if err != nil {
+		return err
+	}
+	f.Close()
+	return os.Remove(testFile)
+}
+
+// copyFile copies a file from src to dst, preserving permissions.
+func copyFile(src, dst string) error {
+	srcFile, err := os.Open(src)
+	if err != nil {
+		return err
+	}
+	defer srcFile.Close()
+
+	srcInfo, err := srcFile.Stat()
+	if err != nil {
+		return err
+	}
+
+	dstFile, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, srcInfo.Mode())
+	if err != nil {
+		return err
+	}
+	defer dstFile.Close()
+
+	if _, err = io.Copy(dstFile, srcFile); err != nil {
+		return err
+	}
+
+	// Sync to ensure data is persisted before atomic rename.
+	return dstFile.Sync()
+}

internal/update/update_test.go 🔗

@@ -1,7 +1,14 @@
 package update
 
 import (
+	"archive/tar"
+	"archive/zip"
+	"compress/gzip"
 	"context"
+	"os"
+	"path/filepath"
+	"runtime"
+	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/require"
@@ -46,3 +53,330 @@ func (t testClient) Latest(ctx context.Context) (*Release, error) {
 		HTMLURL: "https://example.org",
 	}, nil
 }
+
+func TestFindAsset(t *testing.T) {
+	t.Parallel()
+
+	// Create test assets matching goreleaser naming.
+	assets := []Asset{
+		{Name: "crush_0.19.2_Linux_x86_64.tar.gz", BrowserDownloadURL: "https://example.com/linux-amd64.tar.gz"},
+		{Name: "crush_0.19.2_Darwin_x86_64.tar.gz", BrowserDownloadURL: "https://example.com/darwin-amd64.tar.gz"},
+		{Name: "crush_0.19.2_Darwin_arm64.tar.gz", BrowserDownloadURL: "https://example.com/darwin-arm64.tar.gz"},
+		{Name: "crush_0.19.2_Windows_x86_64.zip", BrowserDownloadURL: "https://example.com/windows-amd64.zip"},
+		{Name: "crush_0.19.2_Linux_i386.tar.gz", BrowserDownloadURL: "https://example.com/linux-386.tar.gz"},
+		{Name: "checksums.txt", BrowserDownloadURL: "https://example.com/checksums.txt"},
+		{Name: "crush_0.19.2_Linux_x86_64.tar.gz.sig", BrowserDownloadURL: "https://example.com/linux-amd64.tar.gz.sig"},
+	}
+
+	t.Run("finds correct asset for current platform", func(t *testing.T) {
+		t.Parallel()
+		asset, err := FindAsset(assets)
+		require.NoError(t, err)
+		require.NotNil(t, asset)
+
+		// Check that the asset matches our platform.
+		switch runtime.GOOS {
+		case "linux":
+			require.Contains(t, asset.Name, "Linux")
+		case "darwin":
+			require.Contains(t, asset.Name, "Darwin")
+		case "windows":
+			require.Contains(t, asset.Name, "Windows")
+		}
+
+		// Check that it's an archive, not a signature or checksum.
+		require.True(t, strings.HasSuffix(asset.Name, ".tar.gz") || strings.HasSuffix(asset.Name, ".zip"))
+	})
+
+	t.Run("returns error when no matching asset", func(t *testing.T) {
+		t.Parallel()
+		emptyAssets := []Asset{
+			{Name: "checksums.txt", BrowserDownloadURL: "https://example.com/checksums.txt"},
+		}
+		asset, err := FindAsset(emptyAssets)
+		require.Error(t, err)
+		require.Nil(t, asset)
+	})
+}
+
+func TestIsDevelopment(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		version string
+		want    bool
+	}{
+		{"devel version", "devel", true},
+		{"unknown version", "unknown", true},
+		{"dirty version", "0.19.0-dirty", true},
+		{"dirty with suffix", "0.19.0-10-g1234567-dirty", true},
+		{"go install version", "v0.0.0-0.20251231235959-06c807842604", true},
+		{"stable version", "0.19.0", false},
+		{"pre-release beta", "0.19.0-beta.1", false},
+		{"pre-release rc", "0.19.0-rc.1", false},
+		{"pre-release alpha", "0.19.0-alpha.1", false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			info := Info{Current: tt.version, Latest: "0.20.0"}
+			require.Equal(t, tt.want, info.IsDevelopment())
+		})
+	}
+}
+
+func TestAvailable(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		current string
+		latest  string
+		want    bool
+	}{
+		{"same version", "0.19.0", "0.19.0", false},
+		{"newer available", "0.19.0", "0.19.1", true},
+		{"older latest (downgrade)", "0.19.1", "0.19.0", true},
+		{"rc to stable", "0.19.0-rc.1", "0.19.0", true},
+		{"stable to rc", "0.19.0", "0.20.0-rc.1", false},
+		{"alpha to beta", "0.19.0-alpha.1", "0.19.0-beta.1", true},
+		{"beta to rc", "0.19.0-beta.1", "0.19.0-rc.1", true},
+		{"same pre-release", "0.19.0-beta.1", "0.19.0-beta.1", false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			info := Info{Current: tt.current, Latest: tt.latest}
+			require.Equal(t, tt.want, info.Available())
+		})
+	}
+}
+
+func TestParseChecksums(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name  string
+		input string
+		want  map[string]string
+	}{
+		{
+			name:  "standard format",
+			input: "abc123  crush_0.19.2_Linux_x86_64.tar.gz\ndef456  crush_0.19.2_Darwin_arm64.tar.gz\n",
+			want: map[string]string{
+				"crush_0.19.2_Linux_x86_64.tar.gz": "abc123",
+				"crush_0.19.2_Darwin_arm64.tar.gz": "def456",
+			},
+		},
+		{
+			name:  "empty lines",
+			input: "abc123  file1.tar.gz\n\ndef456  file2.tar.gz\n",
+			want: map[string]string{
+				"file1.tar.gz": "abc123",
+				"file2.tar.gz": "def456",
+			},
+		},
+		{
+			name:  "extra fields ignored",
+			input: "abc123  file1.tar.gz  extra  fields\n",
+			want:  map[string]string{},
+		},
+		{
+			name:  "single field ignored",
+			input: "abc123\n",
+			want:  map[string]string{},
+		},
+		{
+			name:  "whitespace variations",
+			input: "abc123\tfile1.tar.gz\n",
+			want: map[string]string{
+				"file1.tar.gz": "abc123",
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			result := parseChecksumLines(tt.input)
+			require.Equal(t, tt.want, result)
+		})
+	}
+}
+
+func TestExtractTarGz(t *testing.T) {
+	t.Parallel()
+
+	t.Run("missing binary", func(t *testing.T) {
+		t.Parallel()
+		// Create a tar.gz with no crush binary.
+		tmpDir := t.TempDir()
+		archivePath := filepath.Join(tmpDir, "test.tar.gz")
+
+		f, err := os.Create(archivePath)
+		require.NoError(t, err)
+
+		gzw := gzip.NewWriter(f)
+		tw := tar.NewWriter(gzw)
+
+		// Add a random file, not crush.
+		content := []byte("not a binary")
+		hdr := &tar.Header{
+			Name: "other-file.txt",
+			Mode: 0o644,
+			Size: int64(len(content)),
+		}
+		require.NoError(t, tw.WriteHeader(hdr))
+		_, err = tw.Write(content)
+		require.NoError(t, err)
+
+		require.NoError(t, tw.Close())
+		require.NoError(t, gzw.Close())
+		require.NoError(t, f.Close())
+
+		_, err = extractTarGz(archivePath)
+		require.Error(t, err)
+		require.Contains(t, err.Error(), "not found")
+	})
+
+	t.Run("path traversal attempt", func(t *testing.T) {
+		t.Parallel()
+		tmpDir := t.TempDir()
+		archivePath := filepath.Join(tmpDir, "test.tar.gz")
+
+		f, err := os.Create(archivePath)
+		require.NoError(t, err)
+
+		gzw := gzip.NewWriter(f)
+		tw := tar.NewWriter(gzw)
+
+		// Add a file with path traversal attempt.
+		content := []byte("malicious")
+		hdr := &tar.Header{
+			Name: "../../../etc/passwd",
+			Mode: 0o644,
+			Size: int64(len(content)),
+		}
+		require.NoError(t, tw.WriteHeader(hdr))
+		_, err = tw.Write(content)
+		require.NoError(t, err)
+
+		require.NoError(t, tw.Close())
+		require.NoError(t, gzw.Close())
+		require.NoError(t, f.Close())
+
+		// Should not extract the malicious file and should fail to find binary.
+		_, err = extractTarGz(archivePath)
+		require.Error(t, err)
+		require.Contains(t, err.Error(), "not found")
+	})
+}
+
+func TestExtractZip(t *testing.T) {
+	t.Parallel()
+
+	t.Run("missing binary", func(t *testing.T) {
+		t.Parallel()
+		tmpDir := t.TempDir()
+		archivePath := filepath.Join(tmpDir, "test.zip")
+
+		f, err := os.Create(archivePath)
+		require.NoError(t, err)
+
+		zw := zip.NewWriter(f)
+
+		// Add a random file, not crush.
+		w, err := zw.Create("other-file.txt")
+		require.NoError(t, err)
+		_, err = w.Write([]byte("not a binary"))
+		require.NoError(t, err)
+
+		require.NoError(t, zw.Close())
+		require.NoError(t, f.Close())
+
+		_, err = extractZip(archivePath)
+		require.Error(t, err)
+		require.Contains(t, err.Error(), "not found")
+	})
+
+	t.Run("path traversal attempt", func(t *testing.T) {
+		t.Parallel()
+		tmpDir := t.TempDir()
+		archivePath := filepath.Join(tmpDir, "test.zip")
+
+		f, err := os.Create(archivePath)
+		require.NoError(t, err)
+
+		zw := zip.NewWriter(f)
+
+		// Add a file with path traversal attempt.
+		w, err := zw.Create("../../../etc/passwd")
+		require.NoError(t, err)
+		_, err = w.Write([]byte("malicious"))
+		require.NoError(t, err)
+
+		require.NoError(t, zw.Close())
+		require.NoError(t, f.Close())
+
+		// Should not extract the malicious file and should fail to find binary.
+		_, err = extractZip(archivePath)
+		require.Error(t, err)
+		require.Contains(t, err.Error(), "not found")
+	})
+}
+
+func TestApply(t *testing.T) {
+	t.Parallel()
+
+	t.Run("read-only directory", func(t *testing.T) {
+		t.Parallel()
+		if runtime.GOOS == "windows" {
+			t.Skip("chmod not reliable on Windows")
+		}
+
+		tmpDir := t.TempDir()
+
+		// Create a fake binary to apply.
+		binaryPath := filepath.Join(tmpDir, "new-binary")
+		require.NoError(t, os.WriteFile(binaryPath, []byte("new"), 0o755))
+
+		// Create a read-only directory.
+		readOnlyDir := filepath.Join(tmpDir, "readonly")
+		require.NoError(t, os.MkdirAll(readOnlyDir, 0o755))
+
+		// Create a fake executable in the read-only dir.
+		exePath := filepath.Join(readOnlyDir, "crush")
+		require.NoError(t, os.WriteFile(exePath, []byte("old"), 0o755))
+
+		// Make the directory read-only.
+		require.NoError(t, os.Chmod(readOnlyDir, 0o555))
+		t.Cleanup(func() {
+			// Restore permissions for cleanup.
+			os.Chmod(readOnlyDir, 0o755)
+		})
+
+		// checkWritePermission should fail.
+		err := checkWritePermission(readOnlyDir)
+		require.Error(t, err)
+	})
+
+	t.Run("successful copy", func(t *testing.T) {
+		t.Parallel()
+		tmpDir := t.TempDir()
+
+		src := filepath.Join(tmpDir, "src")
+		dst := filepath.Join(tmpDir, "dst")
+
+		content := []byte("test content")
+		require.NoError(t, os.WriteFile(src, content, 0o755))
+
+		require.NoError(t, copyFile(src, dst))
+
+		dstContent, err := os.ReadFile(dst)
+		require.NoError(t, err)
+		require.Equal(t, content, dstContent)
+	})
+}