From fb6013687e5cf8b52a20ebfded476147035085ca Mon Sep 17 00:00:00 2001 From: Amolith Date: Mon, 8 Dec 2025 22:06:29 -0700 Subject: [PATCH] refactor(update): harden security and fix issues - Verify checksums.txt against GitHub API digest before trusting contents - Add lockfile to prevent race condition in Windows ApplyPendingUpdate - Clean up lingering .old files on Windows startup - Make async update cancellable when TUI exits - Add context-aware download with cancellation support - Validate SHA256 checksum format (64 hex chars) - Deduplicate checksum parsing and dev version messages Assisted-by: Claude Opus 4.5 via Crush --- internal/cmd/root.go | 28 +++++- internal/cmd/update.go | 11 +-- internal/update/update.go | 139 ++++++++++++++++++++++++++---- internal/update/update_test.go | 122 +++++++++++++++++++++++--- internal/update/update_windows.go | 27 +++++- 5 files changed, 288 insertions(+), 39 deletions(-) diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 94d17c72703d2e94bbbeb41a00b9b82fb000678c..27e48d415c87cf55dabfe4f72970f57c5f82a017 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -101,8 +101,13 @@ crush -y tea.WithFilter(tui.MouseEventFilter)) // Filter mouse events based on focus state go app.Subscribe(program) + // Create a cancellable context for the update check that gets cancelled + // when the TUI exits. + updateCtx, cancelUpdate := context.WithCancel(cmd.Context()) + defer cancelUpdate() + // Start async update check unless disabled. - go checkForUpdateAsync(cmd.Context(), program) + go checkForUpdateAsync(updateCtx, program) if _, err := program.Run(); err != nil { event.Error(err) @@ -206,7 +211,7 @@ func checkForUpdateSync() string { } if info.IsDevelopment() { - return fmt.Sprintf("\nThis is a development version of Crush. The latest stable release is v%s.\nRun 'crush update' to learn more.\n", info.Latest) + return info.DevelopmentVersionBrief() } return fmt.Sprintf("\nUpdate available: v%s → v%s\nRun 'crush update apply' to install.\n", info.Current, info.Latest) @@ -230,6 +235,11 @@ func checkForUpdateAsync(ctx context.Context, program *tea.Program) { return } + // Check if context was cancelled while checking. + if ctx.Err() != nil { + return + } + // Check install method. method := update.DetectInstallMethod() if !method.CanSelfUpdate() { @@ -253,8 +263,17 @@ func checkForUpdateAsync(ctx context.Context, program *tea.Program) { return } + // Check if context was cancelled before download. + if ctx.Err() != nil { + return + } + binaryPath, err := update.Download(checkCtx, asset, info.Release) if err != nil { + // Don't show error message if context was cancelled (user exited). + if ctx.Err() != nil { + return + } program.Send(tuiutil.InfoMsg{ Type: tuiutil.InfoTypeWarn, Msg: "Update download failed. Run 'crush update' for details.", @@ -264,6 +283,11 @@ func checkForUpdateAsync(ctx context.Context, program *tea.Program) { } defer os.Remove(binaryPath) + // Check if context was cancelled before apply. + if ctx.Err() != nil { + return + } + if err := update.Apply(binaryPath); err != nil { program.Send(tuiutil.InfoMsg{ Type: tuiutil.InfoTypeWarn, diff --git a/internal/cmd/update.go b/internal/cmd/update.go index bf0f75253ea32ded8bebd667019a82b244c8b8b3..2a9e420f23272d8e758fbe99477fcc1c31096489 100644 --- a/internal/cmd/update.go +++ b/internal/cmd/update.go @@ -47,11 +47,7 @@ crush update apply --force } if info.IsDevelopment() { - fmt.Fprintf(os.Stderr, "You are running a development version of Crush (%s).\n", info.Current) - fmt.Fprintf(os.Stderr, "The latest stable release is v%s.\n", info.Latest) - fmt.Fprintf(os.Stderr, "To install the latest stable version, run:\n") - fmt.Fprintf(os.Stderr, " go install github.com/charmbracelet/crush@latest\n") - fmt.Fprintf(os.Stderr, "Or visit %s to download manually.\n", info.URL) + fmt.Fprint(os.Stderr, info.DevelopmentVersionMessage(false)) return nil } @@ -115,10 +111,7 @@ crush update apply --force if info.IsDevelopment() { spinner.Stop() - fmt.Fprintf(os.Stderr, "You are running a development version of Crush (%s).\n", info.Current) - fmt.Fprintf(os.Stderr, "Self-update is not supported for development versions.\n") - fmt.Fprintf(os.Stderr, "To install the latest stable version, run:\n") - fmt.Fprintf(os.Stderr, " go install github.com/charmbracelet/crush@latest\n") + fmt.Fprint(os.Stderr, info.DevelopmentVersionMessage(true)) return nil } diff --git a/internal/update/update.go b/internal/update/update.go index 6f64ec5cfaddf0889246ee18ee5cdfc559dab6b9..334646cdaea6828c8e1493e4f95d964faf4a3522 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -72,6 +72,34 @@ func (i Info) IsDevelopment() bool { gitDescribeRegexp.MatchString(i.Current) } +// DevelopmentVersionBrief returns a brief message for development versions +// suitable for the version flag output. +func (i Info) DevelopmentVersionBrief() string { + return fmt.Sprintf( + "\nThis is a development version of Crush. The latest stable release is v%s.\nRun 'crush update' to learn more.\n", + i.Latest, + ) +} + +// DevelopmentVersionMessage returns a detailed message for development versions +// suitable for CLI command output. If selfUpdateNote is true, includes a note +// that self-update is not supported. +func (i Info) DevelopmentVersionMessage(selfUpdateNote bool) string { + var b strings.Builder + fmt.Fprintf(&b, "You are running a development version of Crush (%s).\n", i.Current) + if selfUpdateNote { + b.WriteString("Self-update is not supported for development versions.\n") + } else { + fmt.Fprintf(&b, "The latest stable release is v%s.\n", i.Latest) + } + b.WriteString("To install the latest stable version, run:\n") + b.WriteString(" go install github.com/charmbracelet/crush@latest\n") + if i.URL != "" && !selfUpdateNote { + fmt.Fprintf(&b, "Or visit %s to download manually.\n", i.URL) + } + return b.String() +} + // Available returns true if there's an update available. // Uses proper semver comparison to handle version ordering correctly. // Returns false if either version cannot be parsed. @@ -210,6 +238,7 @@ type Asset struct { Name string `json:"name"` BrowserDownloadURL string `json:"browser_download_url"` Size int64 `json:"size"` + Digest string `json:"digest"` // Format: "sha256:hexstring" } // Release represents a GitHub release. @@ -357,7 +386,7 @@ func Download(ctx context.Context, asset *Asset, release *Release) (string, erro // Use LimitReader to prevent DoS from oversized downloads (Content-Length can be spoofed). hash := sha256.New() limitedBody := io.LimitReader(resp.Body, maxArchiveSize+1) - written, err := io.Copy(io.MultiWriter(tmpFile, hash), limitedBody) + written, err := copyWithContext(ctx, io.MultiWriter(tmpFile, hash), limitedBody) if err != nil { tmpFile.Close() return "", fmt.Errorf("failed to download: %w", err) @@ -395,8 +424,9 @@ func Download(ctx context.Context, asset *Asset, release *Release) (string, erro return binaryPath, nil } -// downloadChecksums downloads and parses the checksums.txt file. -// Returns a map of filename to sha256 checksum. +// downloadChecksums downloads, verifies, and parses the checksums.txt file. +// The checksums.txt file is verified against GitHub's API-provided digest before +// parsing to ensure integrity. Returns a map of filename to sha256 checksum. func downloadChecksums(ctx context.Context, client *http.Client, asset *Asset) (map[string]string, error) { req, err := http.NewRequestWithContext(ctx, "GET", asset.BrowserDownloadURL, nil) if err != nil { @@ -417,26 +447,56 @@ func downloadChecksums(ctx context.Context, client *http.Client, asset *Asset) ( // Limit the size of checksums.txt to prevent memory exhaustion. limitedBody := io.LimitReader(resp.Body, maxChecksumsSize+1) - checksums := make(map[string]string) - scanner := bufio.NewScanner(limitedBody) - for scanner.Scan() { - line := scanner.Text() - parts := strings.Fields(line) - if len(parts) != 2 { - continue + // Read content while computing hash for verification against GitHub's digest. + hash := sha256.New() + content, err := io.ReadAll(io.TeeReader(limitedBody, hash)) + if err != nil { + return nil, err + } + if int64(len(content)) > maxChecksumsSize { + return nil, fmt.Errorf("checksums.txt exceeds maximum size of %d bytes", maxChecksumsSize) + } + + // Verify against GitHub's API-provided digest if available. + if asset.Digest != "" { + actualSum := hex.EncodeToString(hash.Sum(nil)) + expectedSum := parseDigest(asset.Digest) + if expectedSum == "" { + return nil, fmt.Errorf("invalid digest format from API: %s", asset.Digest) + } + if actualSum != expectedSum { + return nil, fmt.Errorf("checksums.txt digest mismatch: expected %s, got %s", expectedSum, actualSum) } - // parts[0] is checksum, parts[1] is filename. - checksums[parts[1]] = parts[0] } - if err := scanner.Err(); err != nil { - return nil, err + + return parseChecksumLines(string(content)), nil +} + +// parseDigest extracts the hex checksum from a digest string (format: "sha256:hex"). +// Returns empty string if format is invalid. +func parseDigest(digest string) string { + const prefix = "sha256:" + if !strings.HasPrefix(digest, prefix) { + return "" } + hex := strings.TrimPrefix(digest, prefix) + if !isValidSHA256(hex) { + return "" + } + return hex +} - return checksums, nil +// isValidSHA256 validates that a string is a valid SHA256 hex checksum. +func isValidSHA256(s string) bool { + if len(s) != 64 { + return false + } + _, err := hex.DecodeString(s) + return err == nil } // parseChecksumLines parses checksum lines from a string. -// Used for testing the checksum parsing logic. +// Returns a map of filename to sha256 checksum. Invalid checksums are skipped. func parseChecksumLines(content string) map[string]string { checksums := make(map[string]string) scanner := bufio.NewScanner(strings.NewReader(content)) @@ -446,6 +506,11 @@ func parseChecksumLines(content string) map[string]string { if len(parts) != 2 { continue } + // Validate checksum format (64 hex chars for SHA256). + if !isValidSHA256(parts[0]) { + continue + } + // parts[0] is checksum, parts[1] is filename. checksums[parts[1]] = parts[0] } return checksums @@ -634,3 +699,45 @@ func copyFile(src, dst string) error { // Sync to ensure data is persisted before atomic rename. return dstFile.Sync() } + +// copyWithContext copies from src to dst while periodically checking for context +// cancellation. This allows large downloads to be interrupted when the user +// quits the application. +func copyWithContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) { + const bufSize = 32 * 1024 // 32KB chunks + buf := make([]byte, bufSize) + var written int64 + + for { + // Check for cancellation before each chunk. + select { + case <-ctx.Done(): + return written, ctx.Err() + default: + } + + nr, readErr := src.Read(buf) + if nr > 0 { + nw, writeErr := dst.Write(buf[:nr]) + if nw < 0 || nr < nw { + nw = 0 + if writeErr == nil { + writeErr = fmt.Errorf("invalid write result") + } + } + written += int64(nw) + if writeErr != nil { + return written, writeErr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if readErr != nil { + if readErr == io.EOF { + return written, nil + } + return written, readErr + } + } +} diff --git a/internal/update/update_test.go b/internal/update/update_test.go index e886caeed04b175437dad12f04e0aeebfff4de7e..7b47b7ba4d5c98fb72adaf5b129e829cbd92cc59 100644 --- a/internal/update/update_test.go +++ b/internal/update/update_test.go @@ -210,6 +210,10 @@ func TestAvailable(t *testing.T) { func TestParseChecksums(t *testing.T) { t.Parallel() + // Valid SHA256 hashes for testing (64 hex characters). + hash1 := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + hash2 := "f6e5d4c3b2a1f6e5d4c3b2a1f6e5d4c3b2a1f6e5d4c3b2a1f6e5d4c3b2a1f6e5" + tests := []struct { name string input string @@ -217,37 +221,47 @@ func TestParseChecksums(t *testing.T) { }{ { name: "standard format", - input: "abc123 crush_0.19.2_Linux_x86_64.tar.gz\ndef456 crush_0.19.2_Darwin_arm64.tar.gz\n", + input: hash1 + " crush_0.19.2_Linux_x86_64.tar.gz\n" + hash2 + " crush_0.19.2_Darwin_arm64.tar.gz\n", want: map[string]string{ - "crush_0.19.2_Linux_x86_64.tar.gz": "abc123", - "crush_0.19.2_Darwin_arm64.tar.gz": "def456", + "crush_0.19.2_Linux_x86_64.tar.gz": hash1, + "crush_0.19.2_Darwin_arm64.tar.gz": hash2, }, }, { name: "empty lines", - input: "abc123 file1.tar.gz\n\ndef456 file2.tar.gz\n", + input: hash1 + " file1.tar.gz\n\n" + hash2 + " file2.tar.gz\n", want: map[string]string{ - "file1.tar.gz": "abc123", - "file2.tar.gz": "def456", + "file1.tar.gz": hash1, + "file2.tar.gz": hash2, }, }, { name: "extra fields ignored", - input: "abc123 file1.tar.gz extra fields\n", + input: hash1 + " file1.tar.gz extra fields\n", want: map[string]string{}, }, { name: "single field ignored", - input: "abc123\n", + input: hash1 + "\n", want: map[string]string{}, }, { name: "whitespace variations", - input: "abc123\tfile1.tar.gz\n", + input: hash1 + "\tfile1.tar.gz\n", want: map[string]string{ - "file1.tar.gz": "abc123", + "file1.tar.gz": hash1, }, }, + { + name: "invalid checksum length ignored", + input: "abc123 file1.tar.gz\n", + want: map[string]string{}, + }, + { + name: "invalid checksum hex ignored", + input: "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz file1.tar.gz\n", + want: map[string]string{}, + }, } for _, tt := range tests { @@ -617,3 +631,91 @@ func TestDetectInstallMethod_DefaultGoPath(t *testing.T) { method := detectInstallMethod(exePath) require.Equal(t, InstallMethodGoInstall, method) } + +func TestIsValidSHA256(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want bool + }{ + {"valid hash", "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", true}, + {"valid hash uppercase", "A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2", true}, + {"too short", "abc123", false}, + {"too long", "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3", false}, + {"invalid hex", "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz", false}, + {"empty", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, isValidSHA256(tt.input)) + }) + } +} + +func TestParseDigest(t *testing.T) { + t.Parallel() + + validHash := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + + tests := []struct { + name string + input string + want string + }{ + {"valid sha256 digest", "sha256:" + validHash, validHash}, + {"missing prefix", validHash, ""}, + {"wrong prefix", "sha1:" + validHash, ""}, + {"invalid hash after prefix", "sha256:invalid", ""}, + {"empty", "", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, parseDigest(tt.input)) + }) + } +} + +func TestCopyWithContext(t *testing.T) { + t.Parallel() + + t.Run("successful copy", func(t *testing.T) { + t.Parallel() + src := strings.NewReader("hello world") + dst := &strings.Builder{} + + n, err := copyWithContext(context.Background(), dst, src) + require.NoError(t, err) + require.Equal(t, int64(11), n) + require.Equal(t, "hello world", dst.String()) + }) + + t.Run("cancelled context", func(t *testing.T) { + t.Parallel() + // Create a large source that will take multiple reads. + src := strings.NewReader(strings.Repeat("x", 100000)) + dst := &strings.Builder{} + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately. + + _, err := copyWithContext(ctx, dst, src) + require.ErrorIs(t, err, context.Canceled) + }) + + t.Run("empty source", func(t *testing.T) { + t.Parallel() + src := strings.NewReader("") + dst := &strings.Builder{} + + n, err := copyWithContext(context.Background(), dst, src) + require.NoError(t, err) + require.Equal(t, int64(0), n) + }) +} + diff --git a/internal/update/update_windows.go b/internal/update/update_windows.go index af410c9f28a95fe4b194f6a538fc1fae8efbe725..974c722a7fb95c52cec2992798c7f473123c4c64 100644 --- a/internal/update/update_windows.go +++ b/internal/update/update_windows.go @@ -106,6 +106,15 @@ func HasPendingUpdate() bool { // ApplyPendingUpdate applies a pending update that was staged previously. // This should be called early in startup, before the main executable is locked. func ApplyPendingUpdate() error { + exe, err := os.Executable() + if err != nil { + return err + } + + // Clean up any lingering .old file from a previous update. + oldPath := exe + ".old" + _ = os.Remove(oldPath) + pendingPath, err := pendingUpdatePath() if err != nil { return err @@ -115,13 +124,27 @@ func ApplyPendingUpdate() error { return nil // No pending update. } - exe, err := os.Executable() + // Acquire exclusive lock to prevent race conditions with other processes. + lockPath := exe + ".lock" + lock, err := os.OpenFile(lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o644) if err != nil { + if os.IsExist(err) { + return nil // Another process is handling this. + } return err } + defer func() { + lock.Close() + os.Remove(lockPath) + }() + + // Re-check pending update exists after acquiring lock (another process may + // have completed). + if _, err := os.Stat(pendingPath); os.IsNotExist(err) { + return nil + } // Rename current to .old, new to current. - oldPath := exe + ".old" if err := os.Rename(exe, oldPath); err != nil { return err }