refactor(update): harden security and fix issues

Amolith created

- Verify checksums.txt against GitHub API digest before trusting
  contents
- Add lockfile to prevent race condition in Windows ApplyPendingUpdate
- Clean up lingering .old files on Windows startup
- Make async update cancellable when TUI exits
- Add context-aware download with cancellation support
- Validate SHA256 checksum format (64 hex chars)
- Deduplicate checksum parsing and dev version messages

Assisted-by: Claude Opus 4.5 via Crush

Change summary

internal/cmd/root.go              |  28 ++++++
internal/cmd/update.go            |  11 --
internal/update/update.go         | 139 +++++++++++++++++++++++++++++---
internal/update/update_test.go    | 122 ++++++++++++++++++++++++++--
internal/update/update_windows.go |  27 +++++
5 files changed, 288 insertions(+), 39 deletions(-)

Detailed changes

internal/cmd/root.go 🔗

@@ -101,8 +101,13 @@ crush -y
 			tea.WithFilter(tui.MouseEventFilter)) // Filter mouse events based on focus state
 		go app.Subscribe(program)
 
+		// Create a cancellable context for the update check that gets cancelled
+		// when the TUI exits.
+		updateCtx, cancelUpdate := context.WithCancel(cmd.Context())
+		defer cancelUpdate()
+
 		// Start async update check unless disabled.
-		go checkForUpdateAsync(cmd.Context(), program)
+		go checkForUpdateAsync(updateCtx, program)
 
 		if _, err := program.Run(); err != nil {
 			event.Error(err)
@@ -206,7 +211,7 @@ func checkForUpdateSync() string {
 	}
 
 	if info.IsDevelopment() {
-		return fmt.Sprintf("\nThis is a development version of Crush. The latest stable release is v%s.\nRun 'crush update' to learn more.\n", info.Latest)
+		return info.DevelopmentVersionBrief()
 	}
 
 	return fmt.Sprintf("\nUpdate available: v%s → v%s\nRun 'crush update apply' to install.\n", info.Current, info.Latest)
@@ -230,6 +235,11 @@ func checkForUpdateAsync(ctx context.Context, program *tea.Program) {
 		return
 	}
 
+	// Check if context was cancelled while checking.
+	if ctx.Err() != nil {
+		return
+	}
+
 	// Check install method.
 	method := update.DetectInstallMethod()
 	if !method.CanSelfUpdate() {
@@ -253,8 +263,17 @@ func checkForUpdateAsync(ctx context.Context, program *tea.Program) {
 		return
 	}
 
+	// Check if context was cancelled before download.
+	if ctx.Err() != nil {
+		return
+	}
+
 	binaryPath, err := update.Download(checkCtx, asset, info.Release)
 	if err != nil {
+		// Don't show error message if context was cancelled (user exited).
+		if ctx.Err() != nil {
+			return
+		}
 		program.Send(tuiutil.InfoMsg{
 			Type: tuiutil.InfoTypeWarn,
 			Msg:  "Update download failed. Run 'crush update' for details.",
@@ -264,6 +283,11 @@ func checkForUpdateAsync(ctx context.Context, program *tea.Program) {
 	}
 	defer os.Remove(binaryPath)
 
+	// Check if context was cancelled before apply.
+	if ctx.Err() != nil {
+		return
+	}
+
 	if err := update.Apply(binaryPath); err != nil {
 		program.Send(tuiutil.InfoMsg{
 			Type: tuiutil.InfoTypeWarn,

internal/cmd/update.go 🔗

@@ -47,11 +47,7 @@ crush update apply --force
 		}
 
 		if info.IsDevelopment() {
-			fmt.Fprintf(os.Stderr, "You are running a development version of Crush (%s).\n", info.Current)
-			fmt.Fprintf(os.Stderr, "The latest stable release is v%s.\n", info.Latest)
-			fmt.Fprintf(os.Stderr, "To install the latest stable version, run:\n")
-			fmt.Fprintf(os.Stderr, "  go install github.com/charmbracelet/crush@latest\n")
-			fmt.Fprintf(os.Stderr, "Or visit %s to download manually.\n", info.URL)
+			fmt.Fprint(os.Stderr, info.DevelopmentVersionMessage(false))
 			return nil
 		}
 
@@ -115,10 +111,7 @@ crush update apply --force
 
 		if info.IsDevelopment() {
 			spinner.Stop()
-			fmt.Fprintf(os.Stderr, "You are running a development version of Crush (%s).\n", info.Current)
-			fmt.Fprintf(os.Stderr, "Self-update is not supported for development versions.\n")
-			fmt.Fprintf(os.Stderr, "To install the latest stable version, run:\n")
-			fmt.Fprintf(os.Stderr, "  go install github.com/charmbracelet/crush@latest\n")
+			fmt.Fprint(os.Stderr, info.DevelopmentVersionMessage(true))
 			return nil
 		}
 

internal/update/update.go 🔗

@@ -72,6 +72,34 @@ func (i Info) IsDevelopment() bool {
 		gitDescribeRegexp.MatchString(i.Current)
 }
 
+// DevelopmentVersionBrief returns a brief message for development versions
+// suitable for the version flag output.
+func (i Info) DevelopmentVersionBrief() string {
+	return fmt.Sprintf(
+		"\nThis is a development version of Crush. The latest stable release is v%s.\nRun 'crush update' to learn more.\n",
+		i.Latest,
+	)
+}
+
+// DevelopmentVersionMessage returns a detailed message for development versions
+// suitable for CLI command output. If selfUpdateNote is true, includes a note
+// that self-update is not supported.
+func (i Info) DevelopmentVersionMessage(selfUpdateNote bool) string {
+	var b strings.Builder
+	fmt.Fprintf(&b, "You are running a development version of Crush (%s).\n", i.Current)
+	if selfUpdateNote {
+		b.WriteString("Self-update is not supported for development versions.\n")
+	} else {
+		fmt.Fprintf(&b, "The latest stable release is v%s.\n", i.Latest)
+	}
+	b.WriteString("To install the latest stable version, run:\n")
+	b.WriteString("  go install github.com/charmbracelet/crush@latest\n")
+	if i.URL != "" && !selfUpdateNote {
+		fmt.Fprintf(&b, "Or visit %s to download manually.\n", i.URL)
+	}
+	return b.String()
+}
+
 // Available returns true if there's an update available.
 // Uses proper semver comparison to handle version ordering correctly.
 // Returns false if either version cannot be parsed.
@@ -210,6 +238,7 @@ type Asset struct {
 	Name               string `json:"name"`
 	BrowserDownloadURL string `json:"browser_download_url"`
 	Size               int64  `json:"size"`
+	Digest             string `json:"digest"` // Format: "sha256:hexstring"
 }
 
 // Release represents a GitHub release.
@@ -357,7 +386,7 @@ func Download(ctx context.Context, asset *Asset, release *Release) (string, erro
 	// Use LimitReader to prevent DoS from oversized downloads (Content-Length can be spoofed).
 	hash := sha256.New()
 	limitedBody := io.LimitReader(resp.Body, maxArchiveSize+1)
-	written, err := io.Copy(io.MultiWriter(tmpFile, hash), limitedBody)
+	written, err := copyWithContext(ctx, io.MultiWriter(tmpFile, hash), limitedBody)
 	if err != nil {
 		tmpFile.Close()
 		return "", fmt.Errorf("failed to download: %w", err)
@@ -395,8 +424,9 @@ func Download(ctx context.Context, asset *Asset, release *Release) (string, erro
 	return binaryPath, nil
 }
 
-// downloadChecksums downloads and parses the checksums.txt file.
-// Returns a map of filename to sha256 checksum.
+// downloadChecksums downloads, verifies, and parses the checksums.txt file.
+// The checksums.txt file is verified against GitHub's API-provided digest before
+// parsing to ensure integrity. Returns a map of filename to sha256 checksum.
 func downloadChecksums(ctx context.Context, client *http.Client, asset *Asset) (map[string]string, error) {
 	req, err := http.NewRequestWithContext(ctx, "GET", asset.BrowserDownloadURL, nil)
 	if err != nil {
@@ -417,26 +447,56 @@ func downloadChecksums(ctx context.Context, client *http.Client, asset *Asset) (
 	// Limit the size of checksums.txt to prevent memory exhaustion.
 	limitedBody := io.LimitReader(resp.Body, maxChecksumsSize+1)
 
-	checksums := make(map[string]string)
-	scanner := bufio.NewScanner(limitedBody)
-	for scanner.Scan() {
-		line := scanner.Text()
-		parts := strings.Fields(line)
-		if len(parts) != 2 {
-			continue
+	// Read content while computing hash for verification against GitHub's digest.
+	hash := sha256.New()
+	content, err := io.ReadAll(io.TeeReader(limitedBody, hash))
+	if err != nil {
+		return nil, err
+	}
+	if int64(len(content)) > maxChecksumsSize {
+		return nil, fmt.Errorf("checksums.txt exceeds maximum size of %d bytes", maxChecksumsSize)
+	}
+
+	// Verify against GitHub's API-provided digest if available.
+	if asset.Digest != "" {
+		actualSum := hex.EncodeToString(hash.Sum(nil))
+		expectedSum := parseDigest(asset.Digest)
+		if expectedSum == "" {
+			return nil, fmt.Errorf("invalid digest format from API: %s", asset.Digest)
+		}
+		if actualSum != expectedSum {
+			return nil, fmt.Errorf("checksums.txt digest mismatch: expected %s, got %s", expectedSum, actualSum)
 		}
-		// parts[0] is checksum, parts[1] is filename.
-		checksums[parts[1]] = parts[0]
 	}
-	if err := scanner.Err(); err != nil {
-		return nil, err
+
+	return parseChecksumLines(string(content)), nil
+}
+
+// parseDigest extracts the hex checksum from a digest string (format: "sha256:hex").
+// Returns empty string if format is invalid.
+func parseDigest(digest string) string {
+	const prefix = "sha256:"
+	if !strings.HasPrefix(digest, prefix) {
+		return ""
 	}
+	hex := strings.TrimPrefix(digest, prefix)
+	if !isValidSHA256(hex) {
+		return ""
+	}
+	return hex
+}
 
-	return checksums, nil
+// isValidSHA256 validates that a string is a valid SHA256 hex checksum.
+func isValidSHA256(s string) bool {
+	if len(s) != 64 {
+		return false
+	}
+	_, err := hex.DecodeString(s)
+	return err == nil
 }
 
 // parseChecksumLines parses checksum lines from a string.
-// Used for testing the checksum parsing logic.
+// Returns a map of filename to sha256 checksum. Invalid checksums are skipped.
 func parseChecksumLines(content string) map[string]string {
 	checksums := make(map[string]string)
 	scanner := bufio.NewScanner(strings.NewReader(content))
@@ -446,6 +506,11 @@ func parseChecksumLines(content string) map[string]string {
 		if len(parts) != 2 {
 			continue
 		}
+		// Validate checksum format (64 hex chars for SHA256).
+		if !isValidSHA256(parts[0]) {
+			continue
+		}
+		// parts[0] is checksum, parts[1] is filename.
 		checksums[parts[1]] = parts[0]
 	}
 	return checksums
@@ -634,3 +699,45 @@ func copyFile(src, dst string) error {
 	// Sync to ensure data is persisted before atomic rename.
 	return dstFile.Sync()
 }
+
+// copyWithContext copies from src to dst while periodically checking for context
+// cancellation. This allows large downloads to be interrupted when the user
+// quits the application.
+func copyWithContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
+	const bufSize = 32 * 1024 // 32KB chunks
+	buf := make([]byte, bufSize)
+	var written int64
+
+	for {
+		// Check for cancellation before each chunk.
+		select {
+		case <-ctx.Done():
+			return written, ctx.Err()
+		default:
+		}
+
+		nr, readErr := src.Read(buf)
+		if nr > 0 {
+			nw, writeErr := dst.Write(buf[:nr])
+			if nw < 0 || nr < nw {
+				nw = 0
+				if writeErr == nil {
+					writeErr = fmt.Errorf("invalid write result")
+				}
+			}
+			written += int64(nw)
+			if writeErr != nil {
+				return written, writeErr
+			}
+			if nr != nw {
+				return written, io.ErrShortWrite
+			}
+		}
+		if readErr != nil {
+			if readErr == io.EOF {
+				return written, nil
+			}
+			return written, readErr
+		}
+	}
+}

internal/update/update_test.go 🔗

@@ -210,6 +210,10 @@ func TestAvailable(t *testing.T) {
 func TestParseChecksums(t *testing.T) {
 	t.Parallel()
 
+	// Valid SHA256 hashes for testing (64 hex characters).
+	hash1 := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
+	hash2 := "f6e5d4c3b2a1f6e5d4c3b2a1f6e5d4c3b2a1f6e5d4c3b2a1f6e5d4c3b2a1f6e5"
+
 	tests := []struct {
 		name  string
 		input string
@@ -217,37 +221,47 @@ func TestParseChecksums(t *testing.T) {
 	}{
 		{
 			name:  "standard format",
-			input: "abc123  crush_0.19.2_Linux_x86_64.tar.gz\ndef456  crush_0.19.2_Darwin_arm64.tar.gz\n",
+			input: hash1 + "  crush_0.19.2_Linux_x86_64.tar.gz\n" + hash2 + "  crush_0.19.2_Darwin_arm64.tar.gz\n",
 			want: map[string]string{
-				"crush_0.19.2_Linux_x86_64.tar.gz": "abc123",
-				"crush_0.19.2_Darwin_arm64.tar.gz": "def456",
+				"crush_0.19.2_Linux_x86_64.tar.gz": hash1,
+				"crush_0.19.2_Darwin_arm64.tar.gz": hash2,
 			},
 		},
 		{
 			name:  "empty lines",
-			input: "abc123  file1.tar.gz\n\ndef456  file2.tar.gz\n",
+			input: hash1 + "  file1.tar.gz\n\n" + hash2 + "  file2.tar.gz\n",
 			want: map[string]string{
-				"file1.tar.gz": "abc123",
-				"file2.tar.gz": "def456",
+				"file1.tar.gz": hash1,
+				"file2.tar.gz": hash2,
 			},
 		},
 		{
 			name:  "extra fields ignored",
-			input: "abc123  file1.tar.gz  extra  fields\n",
+			input: hash1 + "  file1.tar.gz  extra  fields\n",
 			want:  map[string]string{},
 		},
 		{
 			name:  "single field ignored",
-			input: "abc123\n",
+			input: hash1 + "\n",
 			want:  map[string]string{},
 		},
 		{
 			name:  "whitespace variations",
-			input: "abc123\tfile1.tar.gz\n",
+			input: hash1 + "\tfile1.tar.gz\n",
 			want: map[string]string{
-				"file1.tar.gz": "abc123",
+				"file1.tar.gz": hash1,
 			},
 		},
+		{
+			name:  "invalid checksum length ignored",
+			input: "abc123  file1.tar.gz\n",
+			want:  map[string]string{},
+		},
+		{
+			name:  "invalid checksum hex ignored",
+			input: "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz  file1.tar.gz\n",
+			want:  map[string]string{},
+		},
 	}
 
 	for _, tt := range tests {
@@ -617,3 +631,91 @@ func TestDetectInstallMethod_DefaultGoPath(t *testing.T) {
 	method := detectInstallMethod(exePath)
 	require.Equal(t, InstallMethodGoInstall, method)
 }
+
+func TestIsValidSHA256(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name  string
+		input string
+		want  bool
+	}{
+		{"valid hash", "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", true},
+		{"valid hash uppercase", "A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2", true},
+		{"too short", "abc123", false},
+		{"too long", "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3", false},
+		{"invalid hex", "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz", false},
+		{"empty", "", false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			require.Equal(t, tt.want, isValidSHA256(tt.input))
+		})
+	}
+}
+
+func TestParseDigest(t *testing.T) {
+	t.Parallel()
+
+	validHash := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
+
+	tests := []struct {
+		name  string
+		input string
+		want  string
+	}{
+		{"valid sha256 digest", "sha256:" + validHash, validHash},
+		{"missing prefix", validHash, ""},
+		{"wrong prefix", "sha1:" + validHash, ""},
+		{"invalid hash after prefix", "sha256:invalid", ""},
+		{"empty", "", ""},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			require.Equal(t, tt.want, parseDigest(tt.input))
+		})
+	}
+}
+
+func TestCopyWithContext(t *testing.T) {
+	t.Parallel()
+
+	t.Run("successful copy", func(t *testing.T) {
+		t.Parallel()
+		src := strings.NewReader("hello world")
+		dst := &strings.Builder{}
+
+		n, err := copyWithContext(context.Background(), dst, src)
+		require.NoError(t, err)
+		require.Equal(t, int64(11), n)
+		require.Equal(t, "hello world", dst.String())
+	})
+
+	t.Run("cancelled context", func(t *testing.T) {
+		t.Parallel()
+		// Create a large source that will take multiple reads.
+		src := strings.NewReader(strings.Repeat("x", 100000))
+		dst := &strings.Builder{}
+
+		ctx, cancel := context.WithCancel(context.Background())
+		cancel() // Cancel immediately.
+
+		_, err := copyWithContext(ctx, dst, src)
+		require.ErrorIs(t, err, context.Canceled)
+	})
+
+	t.Run("empty source", func(t *testing.T) {
+		t.Parallel()
+		src := strings.NewReader("")
+		dst := &strings.Builder{}
+
+		n, err := copyWithContext(context.Background(), dst, src)
+		require.NoError(t, err)
+		require.Equal(t, int64(0), n)
+	})
+}
+

internal/update/update_windows.go 🔗

@@ -106,6 +106,15 @@ func HasPendingUpdate() bool {
 // ApplyPendingUpdate applies a pending update that was staged previously.
 // This should be called early in startup, before the main executable is locked.
 func ApplyPendingUpdate() error {
+	exe, err := os.Executable()
+	if err != nil {
+		return err
+	}
+
+	// Clean up any lingering .old file from a previous update.
+	oldPath := exe + ".old"
+	_ = os.Remove(oldPath)
+
 	pendingPath, err := pendingUpdatePath()
 	if err != nil {
 		return err
@@ -115,13 +124,27 @@ func ApplyPendingUpdate() error {
 		return nil // No pending update.
 	}
 
-	exe, err := os.Executable()
+	// Acquire exclusive lock to prevent race conditions with other processes.
+	lockPath := exe + ".lock"
+	lock, err := os.OpenFile(lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o644)
 	if err != nil {
+		if os.IsExist(err) {
+			return nil // Another process is handling this.
+		}
 		return err
 	}
+	defer func() {
+		lock.Close()
+		os.Remove(lockPath)
+	}()
+
+	// Re-check pending update exists after acquiring lock (another process may
+	// have completed).
+	if _, err := os.Stat(pendingPath); os.IsNotExist(err) {
+		return nil
+	}
 
 	// Rename current to .old, new to current.
-	oldPath := exe + ".old"
 	if err := os.Rename(exe, oldPath); err != nil {
 		return err
 	}