Detailed changes
@@ -101,8 +101,13 @@ crush -y
tea.WithFilter(tui.MouseEventFilter)) // Filter mouse events based on focus state
go app.Subscribe(program)
+ // Create a cancellable context for the update check that gets cancelled
+ // when the TUI exits.
+ updateCtx, cancelUpdate := context.WithCancel(cmd.Context())
+ defer cancelUpdate()
+
// Start async update check unless disabled.
- go checkForUpdateAsync(cmd.Context(), program)
+ go checkForUpdateAsync(updateCtx, program)
if _, err := program.Run(); err != nil {
event.Error(err)
@@ -206,7 +211,7 @@ func checkForUpdateSync() string {
}
if info.IsDevelopment() {
- return fmt.Sprintf("\nThis is a development version of Crush. The latest stable release is v%s.\nRun 'crush update' to learn more.\n", info.Latest)
+ return info.DevelopmentVersionBrief()
}
return fmt.Sprintf("\nUpdate available: v%s → v%s\nRun 'crush update apply' to install.\n", info.Current, info.Latest)
@@ -230,6 +235,11 @@ func checkForUpdateAsync(ctx context.Context, program *tea.Program) {
return
}
+ // Check if context was cancelled while checking.
+ if ctx.Err() != nil {
+ return
+ }
+
// Check install method.
method := update.DetectInstallMethod()
if !method.CanSelfUpdate() {
@@ -253,8 +263,17 @@ func checkForUpdateAsync(ctx context.Context, program *tea.Program) {
return
}
+ // Check if context was cancelled before download.
+ if ctx.Err() != nil {
+ return
+ }
+
binaryPath, err := update.Download(checkCtx, asset, info.Release)
if err != nil {
+ // Don't show error message if context was cancelled (user exited).
+ if ctx.Err() != nil {
+ return
+ }
program.Send(tuiutil.InfoMsg{
Type: tuiutil.InfoTypeWarn,
Msg: "Update download failed. Run 'crush update' for details.",
@@ -264,6 +283,11 @@ func checkForUpdateAsync(ctx context.Context, program *tea.Program) {
}
defer os.Remove(binaryPath)
+ // Check if context was cancelled before apply.
+ if ctx.Err() != nil {
+ return
+ }
+
if err := update.Apply(binaryPath); err != nil {
program.Send(tuiutil.InfoMsg{
Type: tuiutil.InfoTypeWarn,
@@ -47,11 +47,7 @@ crush update apply --force
}
if info.IsDevelopment() {
- fmt.Fprintf(os.Stderr, "You are running a development version of Crush (%s).\n", info.Current)
- fmt.Fprintf(os.Stderr, "The latest stable release is v%s.\n", info.Latest)
- fmt.Fprintf(os.Stderr, "To install the latest stable version, run:\n")
- fmt.Fprintf(os.Stderr, " go install github.com/charmbracelet/crush@latest\n")
- fmt.Fprintf(os.Stderr, "Or visit %s to download manually.\n", info.URL)
+ fmt.Fprint(os.Stderr, info.DevelopmentVersionMessage(false))
return nil
}
@@ -115,10 +111,7 @@ crush update apply --force
if info.IsDevelopment() {
spinner.Stop()
- fmt.Fprintf(os.Stderr, "You are running a development version of Crush (%s).\n", info.Current)
- fmt.Fprintf(os.Stderr, "Self-update is not supported for development versions.\n")
- fmt.Fprintf(os.Stderr, "To install the latest stable version, run:\n")
- fmt.Fprintf(os.Stderr, " go install github.com/charmbracelet/crush@latest\n")
+ fmt.Fprint(os.Stderr, info.DevelopmentVersionMessage(true))
return nil
}
@@ -72,6 +72,34 @@ func (i Info) IsDevelopment() bool {
gitDescribeRegexp.MatchString(i.Current)
}
+// DevelopmentVersionBrief returns a brief message for development versions
+// suitable for the version flag output.
+func (i Info) DevelopmentVersionBrief() string {
+ return fmt.Sprintf(
+ "\nThis is a development version of Crush. The latest stable release is v%s.\nRun 'crush update' to learn more.\n",
+ i.Latest,
+ )
+}
+
+// DevelopmentVersionMessage returns a detailed message for development versions
+// suitable for CLI command output. If selfUpdateNote is true, includes a note
+// that self-update is not supported.
+func (i Info) DevelopmentVersionMessage(selfUpdateNote bool) string {
+ var b strings.Builder
+ fmt.Fprintf(&b, "You are running a development version of Crush (%s).\n", i.Current)
+ if selfUpdateNote {
+ b.WriteString("Self-update is not supported for development versions.\n")
+ } else {
+ fmt.Fprintf(&b, "The latest stable release is v%s.\n", i.Latest)
+ }
+ b.WriteString("To install the latest stable version, run:\n")
+ b.WriteString(" go install github.com/charmbracelet/crush@latest\n")
+ if i.URL != "" && !selfUpdateNote {
+ fmt.Fprintf(&b, "Or visit %s to download manually.\n", i.URL)
+ }
+ return b.String()
+}
+
// Available returns true if there's an update available.
// Uses proper semver comparison to handle version ordering correctly.
// Returns false if either version cannot be parsed.
@@ -210,6 +238,7 @@ type Asset struct {
Name string `json:"name"`
BrowserDownloadURL string `json:"browser_download_url"`
Size int64 `json:"size"`
+ Digest string `json:"digest"` // Format: "sha256:hexstring"
}
// Release represents a GitHub release.
@@ -357,7 +386,7 @@ func Download(ctx context.Context, asset *Asset, release *Release) (string, erro
// Use LimitReader to prevent DoS from oversized downloads (Content-Length can be spoofed).
hash := sha256.New()
limitedBody := io.LimitReader(resp.Body, maxArchiveSize+1)
- written, err := io.Copy(io.MultiWriter(tmpFile, hash), limitedBody)
+ written, err := copyWithContext(ctx, io.MultiWriter(tmpFile, hash), limitedBody)
if err != nil {
tmpFile.Close()
return "", fmt.Errorf("failed to download: %w", err)
@@ -395,8 +424,9 @@ func Download(ctx context.Context, asset *Asset, release *Release) (string, erro
return binaryPath, nil
}
-// downloadChecksums downloads and parses the checksums.txt file.
-// Returns a map of filename to sha256 checksum.
+// downloadChecksums downloads, verifies, and parses the checksums.txt file.
+// The checksums.txt file is verified against GitHub's API-provided digest before
+// parsing to ensure integrity. Returns a map of filename to sha256 checksum.
func downloadChecksums(ctx context.Context, client *http.Client, asset *Asset) (map[string]string, error) {
req, err := http.NewRequestWithContext(ctx, "GET", asset.BrowserDownloadURL, nil)
if err != nil {
@@ -417,26 +447,56 @@ func downloadChecksums(ctx context.Context, client *http.Client, asset *Asset) (
// Limit the size of checksums.txt to prevent memory exhaustion.
limitedBody := io.LimitReader(resp.Body, maxChecksumsSize+1)
- checksums := make(map[string]string)
- scanner := bufio.NewScanner(limitedBody)
- for scanner.Scan() {
- line := scanner.Text()
- parts := strings.Fields(line)
- if len(parts) != 2 {
- continue
+ // Read content while computing hash for verification against GitHub's digest.
+ hash := sha256.New()
+ content, err := io.ReadAll(io.TeeReader(limitedBody, hash))
+ if err != nil {
+ return nil, err
+ }
+ if int64(len(content)) > maxChecksumsSize {
+ return nil, fmt.Errorf("checksums.txt exceeds maximum size of %d bytes", maxChecksumsSize)
+ }
+
+ // Verify against GitHub's API-provided digest if available.
+ if asset.Digest != "" {
+ actualSum := hex.EncodeToString(hash.Sum(nil))
+ expectedSum := parseDigest(asset.Digest)
+ if expectedSum == "" {
+ return nil, fmt.Errorf("invalid digest format from API: %s", asset.Digest)
+ }
+ if actualSum != expectedSum {
+ return nil, fmt.Errorf("checksums.txt digest mismatch: expected %s, got %s", expectedSum, actualSum)
}
- // parts[0] is checksum, parts[1] is filename.
- checksums[parts[1]] = parts[0]
}
- if err := scanner.Err(); err != nil {
- return nil, err
+
+ return parseChecksumLines(string(content)), nil
+}
+
+// parseDigest extracts the hex checksum from a digest string (format: "sha256:hex").
+// Returns empty string if format is invalid.
+func parseDigest(digest string) string {
+ const prefix = "sha256:"
+ if !strings.HasPrefix(digest, prefix) {
+ return ""
}
+ hex := strings.TrimPrefix(digest, prefix)
+ if !isValidSHA256(hex) {
+ return ""
+ }
+ return hex
+}
- return checksums, nil
+// isValidSHA256 validates that a string is a valid SHA256 hex checksum.
+func isValidSHA256(s string) bool {
+ if len(s) != 64 {
+ return false
+ }
+ _, err := hex.DecodeString(s)
+ return err == nil
}
// parseChecksumLines parses checksum lines from a string.
-// Used for testing the checksum parsing logic.
+// Returns a map of filename to sha256 checksum. Invalid checksums are skipped.
func parseChecksumLines(content string) map[string]string {
checksums := make(map[string]string)
scanner := bufio.NewScanner(strings.NewReader(content))
@@ -446,6 +506,11 @@ func parseChecksumLines(content string) map[string]string {
if len(parts) != 2 {
continue
}
+ // Validate checksum format (64 hex chars for SHA256).
+ if !isValidSHA256(parts[0]) {
+ continue
+ }
+ // parts[0] is checksum, parts[1] is filename.
checksums[parts[1]] = parts[0]
}
return checksums
@@ -634,3 +699,45 @@ func copyFile(src, dst string) error {
// Sync to ensure data is persisted before atomic rename.
return dstFile.Sync()
}
+
+// copyWithContext copies from src to dst while periodically checking for context
+// cancellation. This allows large downloads to be interrupted when the user
+// quits the application.
+func copyWithContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
+ const bufSize = 32 * 1024 // 32KB chunks
+ buf := make([]byte, bufSize)
+ var written int64
+
+ for {
+ // Check for cancellation before each chunk.
+ select {
+ case <-ctx.Done():
+ return written, ctx.Err()
+ default:
+ }
+
+ nr, readErr := src.Read(buf)
+ if nr > 0 {
+ nw, writeErr := dst.Write(buf[:nr])
+ if nw < 0 || nr < nw {
+ nw = 0
+ if writeErr == nil {
+ writeErr = fmt.Errorf("invalid write result")
+ }
+ }
+ written += int64(nw)
+ if writeErr != nil {
+ return written, writeErr
+ }
+ if nr != nw {
+ return written, io.ErrShortWrite
+ }
+ }
+ if readErr != nil {
+ if readErr == io.EOF {
+ return written, nil
+ }
+ return written, readErr
+ }
+ }
+}
@@ -210,6 +210,10 @@ func TestAvailable(t *testing.T) {
func TestParseChecksums(t *testing.T) {
t.Parallel()
+ // Valid SHA256 hashes for testing (64 hex characters).
+ hash1 := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
+ hash2 := "f6e5d4c3b2a1f6e5d4c3b2a1f6e5d4c3b2a1f6e5d4c3b2a1f6e5d4c3b2a1f6e5"
+
tests := []struct {
name string
input string
@@ -217,37 +221,47 @@ func TestParseChecksums(t *testing.T) {
}{
{
name: "standard format",
- input: "abc123 crush_0.19.2_Linux_x86_64.tar.gz\ndef456 crush_0.19.2_Darwin_arm64.tar.gz\n",
+ input: hash1 + " crush_0.19.2_Linux_x86_64.tar.gz\n" + hash2 + " crush_0.19.2_Darwin_arm64.tar.gz\n",
want: map[string]string{
- "crush_0.19.2_Linux_x86_64.tar.gz": "abc123",
- "crush_0.19.2_Darwin_arm64.tar.gz": "def456",
+ "crush_0.19.2_Linux_x86_64.tar.gz": hash1,
+ "crush_0.19.2_Darwin_arm64.tar.gz": hash2,
},
},
{
name: "empty lines",
- input: "abc123 file1.tar.gz\n\ndef456 file2.tar.gz\n",
+ input: hash1 + " file1.tar.gz\n\n" + hash2 + " file2.tar.gz\n",
want: map[string]string{
- "file1.tar.gz": "abc123",
- "file2.tar.gz": "def456",
+ "file1.tar.gz": hash1,
+ "file2.tar.gz": hash2,
},
},
{
name: "extra fields ignored",
- input: "abc123 file1.tar.gz extra fields\n",
+ input: hash1 + " file1.tar.gz extra fields\n",
want: map[string]string{},
},
{
name: "single field ignored",
- input: "abc123\n",
+ input: hash1 + "\n",
want: map[string]string{},
},
{
name: "whitespace variations",
- input: "abc123\tfile1.tar.gz\n",
+ input: hash1 + "\tfile1.tar.gz\n",
want: map[string]string{
- "file1.tar.gz": "abc123",
+ "file1.tar.gz": hash1,
},
},
+ {
+ name: "invalid checksum length ignored",
+ input: "abc123 file1.tar.gz\n",
+ want: map[string]string{},
+ },
+ {
+ name: "invalid checksum hex ignored",
+ input: "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz file1.tar.gz\n",
+ want: map[string]string{},
+ },
}
for _, tt := range tests {
@@ -617,3 +631,91 @@ func TestDetectInstallMethod_DefaultGoPath(t *testing.T) {
method := detectInstallMethod(exePath)
require.Equal(t, InstallMethodGoInstall, method)
}
+
+func TestIsValidSHA256(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ want bool
+ }{
+ {"valid hash", "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", true},
+ {"valid hash uppercase", "A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2", true},
+ {"too short", "abc123", false},
+ {"too long", "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3", false},
+ {"invalid hex", "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz", false},
+ {"empty", "", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.want, isValidSHA256(tt.input))
+ })
+ }
+}
+
+func TestParseDigest(t *testing.T) {
+ t.Parallel()
+
+ validHash := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
+
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {"valid sha256 digest", "sha256:" + validHash, validHash},
+ {"missing prefix", validHash, ""},
+ {"wrong prefix", "sha1:" + validHash, ""},
+ {"invalid hash after prefix", "sha256:invalid", ""},
+ {"empty", "", ""},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.want, parseDigest(tt.input))
+ })
+ }
+}
+
+func TestCopyWithContext(t *testing.T) {
+ t.Parallel()
+
+ t.Run("successful copy", func(t *testing.T) {
+ t.Parallel()
+ src := strings.NewReader("hello world")
+ dst := &strings.Builder{}
+
+ n, err := copyWithContext(context.Background(), dst, src)
+ require.NoError(t, err)
+ require.Equal(t, int64(11), n)
+ require.Equal(t, "hello world", dst.String())
+ })
+
+ t.Run("cancelled context", func(t *testing.T) {
+ t.Parallel()
+ // Create a large source that will take multiple reads.
+ src := strings.NewReader(strings.Repeat("x", 100000))
+ dst := &strings.Builder{}
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel() // Cancel immediately.
+
+ _, err := copyWithContext(ctx, dst, src)
+ require.ErrorIs(t, err, context.Canceled)
+ })
+
+ t.Run("empty source", func(t *testing.T) {
+ t.Parallel()
+ src := strings.NewReader("")
+ dst := &strings.Builder{}
+
+ n, err := copyWithContext(context.Background(), dst, src)
+ require.NoError(t, err)
+ require.Equal(t, int64(0), n)
+ })
+}
+
@@ -106,6 +106,15 @@ func HasPendingUpdate() bool {
// ApplyPendingUpdate applies a pending update that was staged previously.
// This should be called early in startup, before the main executable is locked.
func ApplyPendingUpdate() error {
+ exe, err := os.Executable()
+ if err != nil {
+ return err
+ }
+
+ // Clean up any lingering .old file from a previous update.
+ oldPath := exe + ".old"
+ _ = os.Remove(oldPath)
+
pendingPath, err := pendingUpdatePath()
if err != nil {
return err
@@ -115,13 +124,27 @@ func ApplyPendingUpdate() error {
return nil // No pending update.
}
- exe, err := os.Executable()
+ // Acquire exclusive lock to prevent race conditions with other processes.
+ lockPath := exe + ".lock"
+ lock, err := os.OpenFile(lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o644)
if err != nil {
+ if os.IsExist(err) {
+ return nil // Another process is handling this.
+ }
return err
}
+ defer func() {
+ lock.Close()
+ os.Remove(lockPath)
+ }()
+
+ // Re-check pending update exists after acquiring lock (another process may
+ // have completed).
+ if _, err := os.Stat(pendingPath); os.IsNotExist(err) {
+ return nil
+ }
// Rename current to .old, new to current.
- oldPath := exe + ".old"
if err := os.Rename(exe, oldPath); err != nil {
return err
}