1package update
2
3import (
4 "archive/tar"
5 "archive/zip"
6 "bufio"
7 "compress/gzip"
8 "context"
9 "crypto/sha256"
10 "encoding/hex"
11 "encoding/json"
12 "fmt"
13 "io"
14 "net/http"
15 "os"
16 "path/filepath"
17 "regexp"
18 "runtime"
19 "strings"
20 "time"
21
22 "github.com/Masterminds/semver/v3"
23 "github.com/charmbracelet/crush/internal/version"
24)
25
// Endpoint and safety limits used by the updater.
const (
	// githubAPIURL is the GitHub REST endpoint queried for the latest crush release.
	githubAPIURL = "https://api.github.com/repos/charmbracelet/crush/releases/latest"

	// maxBinarySize caps the size of the binary extracted from an archive (500 MiB).
	maxBinarySize = 500 << 20
	// maxArchiveSize caps the size of the downloaded release archive (500 MiB).
	maxArchiveSize = 500 << 20
	// maxChecksumsSize caps the size of checksums.txt (1 MiB).
	maxChecksumsSize = 1 << 20
)
32
// binaryName reports the file name of the crush executable inside release
// archives: "crush", with an ".exe" suffix when running on Windows.
func binaryName() string {
	const base = "crush"
	if runtime.GOOS != "windows" {
		return base
	}
	return base + ".exe"
}
40
41// userAgent returns the user agent string for HTTP requests.
42func userAgent() string {
43 return "crush/" + version.Version
44}
45
46// Default is the default [Client].
47var Default Client = &github{}
48
// Info contains information about an available update.
//
// Check fills it in. On success, both Current and Latest have a leading "v"
// stripped. If the release lookup fails, Latest is left equal to Current and
// URL/Release stay empty.
type Info struct {
	// Current is the version of the running binary.
	Current string
	// Latest is the version of the release returned by the Client.
	Latest string
	// URL is the release's web page (Release.HTMLURL).
	URL string
	// Release holds the full release metadata, including its assets.
	Release *Release
}
56
// Version shapes that identify non-release builds.
var (
	// goInstallRegexp matches pseudo-versions produced by go install, e.g.
	// v0.0.0-0.20251231235959-06c807842604 (base-0.timestamp-commit).
	goInstallRegexp = regexp.MustCompile(`^v?\d+\.\d+\.\d+-\d+\.\d{14}-[0-9a-f]{12}$`)

	// gitDescribeRegexp matches git describe output, e.g.
	// v0.19.0-15-g1a2b3c4d (tag-commits-ghash).
	gitDescribeRegexp = regexp.MustCompile(`^v?\d+\.\d+\.\d+-\d+-g[0-9a-f]+$`)
)
64
65// IsDevelopment returns true if the current version appears to be a
66// development build rather than an official release.
67func (i Info) IsDevelopment() bool {
68 return i.Current == "devel" ||
69 i.Current == "unknown" ||
70 strings.Contains(i.Current, "dirty") ||
71 goInstallRegexp.MatchString(i.Current) ||
72 gitDescribeRegexp.MatchString(i.Current)
73}
74
75// DevelopmentVersionBrief returns a brief message for development versions
76// suitable for the version flag output.
77func (i Info) DevelopmentVersionBrief() string {
78 return fmt.Sprintf(
79 "\nThis is a development version of Crush. The latest stable release is v%s.\nRun 'crush update' to learn more.\n",
80 i.Latest,
81 )
82}
83
84// DevelopmentVersionMessage returns a detailed message for development versions
85// suitable for CLI command output. If selfUpdateNote is true, includes a note
86// that self-update is not supported.
87func (i Info) DevelopmentVersionMessage(selfUpdateNote bool) string {
88 var b strings.Builder
89 fmt.Fprintf(&b, "You are running a development version of Crush (%s).\n", i.Current)
90 if selfUpdateNote {
91 b.WriteString("Self-update is not supported for development versions.\n")
92 } else {
93 fmt.Fprintf(&b, "The latest stable release is v%s.\n", i.Latest)
94 }
95 b.WriteString("To install the latest stable version, run:\n")
96 b.WriteString(" go install github.com/charmbracelet/crush@latest\n")
97 if i.URL != "" && !selfUpdateNote {
98 fmt.Fprintf(&b, "Or visit %s to download manually.\n", i.URL)
99 }
100 return b.String()
101}
102
103// Available returns true if there's an update available.
104// Uses proper semver comparison to handle version ordering correctly.
105// Returns false if either version cannot be parsed.
106// Special case: if current is stable and latest is a prerelease, returns false
107// (we don't offer prerelease updates to stable users).
108func (i Info) Available() bool {
109 current, err := semver.NewVersion(i.Current)
110 if err != nil {
111 return false
112 }
113 latest, err := semver.NewVersion(i.Latest)
114 if err != nil {
115 return false
116 }
117
118 // Don't offer prerelease updates to stable users.
119 if current.Prerelease() == "" && latest.Prerelease() != "" {
120 return false
121 }
122
123 return latest.GreaterThan(current)
124}
125
// InstallMethod represents how Crush was installed.
type InstallMethod int

// Recognized install methods. Values are assigned by iota, so the order of
// this list is significant and must not change.
const (
	InstallMethodUnknown InstallMethod = iota
	InstallMethodBinary                // Direct binary download
	InstallMethodHomebrew
	InstallMethodNPM
	InstallMethodAUR
	InstallMethodNix
	InstallMethodWinget
	InstallMethodScoop
	InstallMethodApt
	InstallMethodYum
	InstallMethodGoInstall
)

// String returns a human-readable name for the install method. Any value
// outside the known set, including InstallMethodUnknown, reports "unknown".
func (m InstallMethod) String() string {
	names := [...]string{
		InstallMethodBinary:    "binary",
		InstallMethodHomebrew:  "Homebrew",
		InstallMethodNPM:       "npm",
		InstallMethodAUR:       "AUR",
		InstallMethodNix:       "Nix",
		InstallMethodWinget:    "winget",
		InstallMethodScoop:     "Scoop",
		InstallMethodApt:       "apt",
		InstallMethodYum:       "yum",
		InstallMethodGoInstall: "go install",
	}
	if m <= InstallMethodUnknown || int(m) >= len(names) {
		return "unknown"
	}
	return names[m]
}

// CanSelfUpdate reports whether Crush may replace its own binary in place.
// This is allowed only for direct binary downloads and for installs whose
// method could not be determined. Package-managed installs are left to
// their package manager.
func (m InstallMethod) CanSelfUpdate() bool {
	switch m {
	case InstallMethodUnknown, InstallMethodBinary:
		return true
	default:
		return false
	}
}

// UpdateInstructions returns the shell command a user should run to update
// Crush for this install method. It returns "" when no package-manager
// command applies (unknown, direct binary, or out-of-range values).
func (m InstallMethod) UpdateInstructions() string {
	commands := [...]string{
		InstallMethodHomebrew:  "brew upgrade charmbracelet/tap/crush",
		InstallMethodNPM:       "npm update -g @charmland/crush",
		InstallMethodAUR:       "yay -Syu crush-bin # or your preferred AUR helper",
		InstallMethodNix:       "nix flake update # or update your NUR channel",
		InstallMethodWinget:    "winget upgrade charmbracelet.crush",
		InstallMethodScoop:     "scoop update crush",
		InstallMethodApt:       "sudo apt update && sudo apt upgrade crush",
		InstallMethodYum:       "sudo yum update crush",
		InstallMethodGoInstall: "go install github.com/charmbracelet/crush@latest",
	}
	if m < 0 || int(m) >= len(commands) {
		return ""
	}
	return commands[m]
}
201
202// DetectInstallMethod attempts to determine how Crush was installed.
203// This is implemented per-platform in update_darwin.go, update_unix.go,
204// and update_windows.go.
205func DetectInstallMethod() InstallMethod {
206 exe, err := os.Executable()
207 if err != nil {
208 return InstallMethodUnknown
209 }
210 exe, err = filepath.EvalSymlinks(exe)
211 if err != nil {
212 return InstallMethodUnknown
213 }
214 return detectInstallMethod(exe)
215}
216
217// Check checks if a new version is available.
218func Check(ctx context.Context, current string, client Client) (Info, error) {
219 info := Info{
220 Current: current,
221 Latest: current,
222 }
223
224 release, err := client.Latest(ctx)
225 if err != nil {
226 return info, fmt.Errorf("failed to fetch latest release: %w", err)
227 }
228
229 info.Latest = strings.TrimPrefix(release.TagName, "v")
230 info.Current = strings.TrimPrefix(info.Current, "v")
231 info.URL = release.HTMLURL
232 info.Release = release
233 return info, nil
234}
235
// Asset represents a GitHub release asset, as decoded from the "assets" array
// of a GitHub release JSON document.
type Asset struct {
	// Name is the asset's file name, e.g. "checksums.txt" or an archive
	// named like crush_{version}_{OS}_{ARCH}.{tar.gz|zip}.
	Name string `json:"name"`
	// BrowserDownloadURL is the URL the asset is downloaded from.
	BrowserDownloadURL string `json:"browser_download_url"`
	// Size is the asset size in bytes as reported by the API; Download
	// checks it against maxArchiveSize before downloading.
	Size int64 `json:"size"`
	// Digest is GitHub's checksum of the asset. Format: "sha256:hexstring".
	// It may be empty, in which case no API-digest check is done.
	Digest string `json:"digest"`
}
243
// Release represents a GitHub release, as decoded from the GitHub releases API
// JSON response.
type Release struct {
	// TagName is the release's git tag, e.g. "v0.19.0"; Check strips the
	// leading "v" to obtain the version.
	TagName string `json:"tag_name"`
	// HTMLURL is the release's web page.
	HTMLURL string `json:"html_url"`
	// Assets lists the files attached to the release: platform archives,
	// checksums.txt and, optionally, checksums.txt.sigstore.json.
	Assets []Asset `json:"assets"`
}
250
// Client is a client that can get the latest release.
//
// Default is the GitHub-backed implementation; Check accepts any Client, so
// callers can substitute their own.
type Client interface {
	// Latest returns the most recent release, or an error if it could not be
	// fetched. ctx comes from the caller; the GitHub implementation attaches
	// it to its HTTP request.
	Latest(ctx context.Context) (*Release, error)
}
255
256type github struct{}
257
258// Latest implements [Client].
259func (c *github) Latest(ctx context.Context) (*Release, error) {
260 client := &http.Client{
261 Timeout: 30 * time.Second,
262 }
263
264 req, err := http.NewRequestWithContext(ctx, "GET", githubAPIURL, nil)
265 if err != nil {
266 return nil, err
267 }
268 req.Header.Set("User-Agent", userAgent())
269 req.Header.Set("Accept", "application/vnd.github.v3+json")
270
271 resp, err := client.Do(req)
272 if err != nil {
273 return nil, err
274 }
275 defer resp.Body.Close()
276
277 if resp.StatusCode != http.StatusOK {
278 body, _ := io.ReadAll(resp.Body)
279 return nil, fmt.Errorf("github api returned status %d: %s", resp.StatusCode, string(body))
280 }
281
282 var release Release
283 if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
284 return nil, err
285 }
286
287 return &release, nil
288}
289
290// FindAsset finds the appropriate asset for the current platform.
291func FindAsset(assets []Asset) (*Asset, error) {
292 // Normalize architecture to match goreleaser naming.
293 arch := runtime.GOARCH
294 switch arch {
295 case "amd64":
296 arch = "x86_64"
297 case "386":
298 arch = "i386"
299 case "arm":
300 arch = "armv7"
301 // arm64 stays as "arm64" in goreleaser naming.
302 }
303
304 // Normalize OS to match goreleaser naming (title case).
305 goos := runtime.GOOS
306 switch goos {
307 case "freebsd":
308 goos = "Freebsd"
309 case "netbsd":
310 goos = "Netbsd"
311 case "openbsd":
312 goos = "Openbsd"
313 default:
314 if len(goos) > 0 {
315 goos = strings.ToUpper(goos[:1]) + goos[1:]
316 }
317 }
318
319 // Look for archive matching our platform.
320 // Pattern: crush_{version}_{OS}_{ARCH}.{tar.gz|zip}
321 for _, asset := range assets {
322 if strings.Contains(asset.Name, goos) && strings.Contains(asset.Name, arch) {
323 // Ensure it's an archive, not a checksum or signature.
324 if strings.HasSuffix(asset.Name, ".tar.gz") || strings.HasSuffix(asset.Name, ".zip") {
325 return &asset, nil
326 }
327 }
328 }
329
330 return nil, fmt.Errorf("no suitable asset found for %s/%s", runtime.GOOS, runtime.GOARCH)
331}
332
333// Download downloads and extracts the crush binary from the given asset.
334// Returns the path to the extracted binary.
335func Download(ctx context.Context, asset *Asset, release *Release) (string, error) {
336 client := &http.Client{}
337
338 // Find checksums.txt and its sigstore bundle.
339 var checksumsAsset, sigstoreBundleAsset *Asset
340 for i := range release.Assets {
341 switch release.Assets[i].Name {
342 case "checksums.txt":
343 checksumsAsset = &release.Assets[i]
344 case "checksums.txt.sigstore.json":
345 sigstoreBundleAsset = &release.Assets[i]
346 }
347 }
348 if checksumsAsset == nil {
349 return "", fmt.Errorf("checksums.txt not found in release")
350 }
351
352 checksums, checksumsContent, err := downloadChecksums(ctx, client, checksumsAsset)
353 if err != nil {
354 return "", fmt.Errorf("failed to download checksums: %w", err)
355 }
356
357 // Verify checksums.txt using sigstore bundle if available.
358 if sigstoreBundleAsset != nil {
359 if err := verifySigstoreBundle(ctx, client, checksumsContent, sigstoreBundleAsset); err != nil {
360 return "", fmt.Errorf("checksums.txt signature verification failed: %w", err)
361 }
362 }
363
364 // Validate asset size from API before downloading.
365 if asset.Size > maxArchiveSize {
366 return "", fmt.Errorf("asset size %d exceeds maximum allowed size of %d bytes", asset.Size, maxArchiveSize)
367 }
368
369 req, err := http.NewRequestWithContext(ctx, "GET", asset.BrowserDownloadURL, nil)
370 if err != nil {
371 return "", err
372 }
373 req.Header.Set("User-Agent", userAgent())
374
375 resp, err := client.Do(req)
376 if err != nil {
377 return "", err
378 }
379 defer resp.Body.Close()
380
381 if resp.StatusCode != http.StatusOK {
382 return "", fmt.Errorf("download failed with status %d", resp.StatusCode)
383 }
384
385 // Validate Content-Length if provided (defense in depth, asset.Size already checked).
386 if resp.ContentLength > maxArchiveSize {
387 return "", fmt.Errorf("archive size %d exceeds maximum allowed size of %d bytes", resp.ContentLength, maxArchiveSize)
388 }
389
390 // Create temp file for archive.
391 tmpFile, err := os.CreateTemp("", "crush-update-*")
392 if err != nil {
393 return "", fmt.Errorf("failed to create temp file: %w", err)
394 }
395 tmpFileName := tmpFile.Name()
396 defer os.Remove(tmpFileName)
397
398 // Download to temp file while computing checksum.
399 // Use LimitReader to prevent DoS from oversized downloads (Content-Length can be spoofed).
400 hash := sha256.New()
401 limitedBody := io.LimitReader(resp.Body, maxArchiveSize+1)
402 written, err := copyWithContext(ctx, io.MultiWriter(tmpFile, hash), limitedBody)
403 tmpFile.Close() // Close immediately after writing, before size check.
404 if err != nil {
405 return "", fmt.Errorf("failed to download: %w", err)
406 }
407 if written > maxArchiveSize {
408 return "", fmt.Errorf("archive size %d exceeds maximum allowed size of %d bytes", written, maxArchiveSize)
409 }
410
411 // Verify checksum.
412 actualSum := hex.EncodeToString(hash.Sum(nil))
413 expectedSum, ok := checksums[asset.Name]
414 if !ok {
415 return "", fmt.Errorf("no checksum found for %s", asset.Name)
416 }
417 if actualSum != expectedSum {
418 return "", fmt.Errorf("checksum mismatch: expected %s, got %s", expectedSum, actualSum)
419 }
420
421 // Extract binary based on archive type.
422 var binaryPath string
423 if strings.HasSuffix(asset.Name, ".zip") {
424 binaryPath, err = extractZip(tmpFileName)
425 } else {
426 binaryPath, err = extractTarGz(tmpFileName)
427 }
428 if err != nil {
429 if binaryPath != "" {
430 os.Remove(binaryPath)
431 }
432 return "", fmt.Errorf("failed to extract: %w", err)
433 }
434
435 return binaryPath, nil
436}
437
438// downloadChecksums downloads, verifies, and parses the checksums.txt file.
439// The checksums.txt file is verified against GitHub's API-provided digest before
440// parsing to ensure integrity. Returns a map of filename to sha256 checksum,
441// and the raw content bytes for sigstore verification.
442func downloadChecksums(ctx context.Context, client *http.Client, asset *Asset) (map[string]string, []byte, error) {
443 req, err := http.NewRequestWithContext(ctx, "GET", asset.BrowserDownloadURL, nil)
444 if err != nil {
445 return nil, nil, err
446 }
447 req.Header.Set("User-Agent", userAgent())
448
449 resp, err := client.Do(req)
450 if err != nil {
451 return nil, nil, err
452 }
453 defer resp.Body.Close()
454
455 if resp.StatusCode != http.StatusOK {
456 return nil, nil, fmt.Errorf("failed to download checksums with status %d", resp.StatusCode)
457 }
458
459 // Limit the size of checksums.txt to prevent memory exhaustion.
460 limitedBody := io.LimitReader(resp.Body, maxChecksumsSize+1)
461
462 // Read content while computing hash for verification against GitHub's digest.
463 hash := sha256.New()
464 content, err := io.ReadAll(io.TeeReader(limitedBody, hash))
465 if err != nil {
466 return nil, nil, err
467 }
468 if int64(len(content)) > maxChecksumsSize {
469 return nil, nil, fmt.Errorf("checksums.txt exceeds maximum size of %d bytes", maxChecksumsSize)
470 }
471
472 // Verify against GitHub's API-provided digest if available.
473 if asset.Digest != "" {
474 actualSum := hex.EncodeToString(hash.Sum(nil))
475 expectedSum := parseDigest(asset.Digest)
476 if expectedSum == "" {
477 return nil, nil, fmt.Errorf("invalid digest format from API: %s", asset.Digest)
478 }
479 if actualSum != expectedSum {
480 return nil, nil, fmt.Errorf("checksums.txt digest mismatch: expected %s, got %s", expectedSum, actualSum)
481 }
482 }
483
484 return parseChecksumLines(string(content)), content, nil
485}
486
// parseDigest extracts the hex checksum from a digest string of the form
// "sha256:<hex>". The result is lower-cased so it compares equal to
// hex.EncodeToString output. It returns "" when the prefix is not "sha256:"
// or the remainder is not a valid SHA-256 hex string.
func parseDigest(digest string) string {
	const prefix = "sha256:"
	if !strings.HasPrefix(digest, prefix) {
		return ""
	}
	// Named sum (not hex) to avoid shadowing the encoding/hex package.
	sum := strings.TrimPrefix(digest, prefix)
	if !isValidSHA256(sum) {
		return ""
	}
	return strings.ToLower(sum)
}

// isValidSHA256 reports whether s is a 64-character hex string (either case),
// i.e. the textual form of a SHA-256 digest.
func isValidSHA256(s string) bool {
	if len(s) != 64 {
		return false
	}
	_, err := hex.DecodeString(s)
	return err == nil
}

// parseChecksumLines parses sha256sum-style content ("<hex>  <filename>" per
// line) into a map from filename to lower-cased hex checksum.
//
// A leading '*' on the filename (the GNU coreutils binary-mode marker) is
// removed. A line is skipped when it does not have exactly two
// whitespace-separated fields, or when its first field is not a valid
// SHA-256 hex string. If a filename appears more than once, the last entry
// wins.
func parseChecksumLines(content string) map[string]string {
	checksums := make(map[string]string)
	scanner := bufio.NewScanner(strings.NewReader(content))
	// Allow a line as long as the whole input. Otherwise one line over
	// bufio.MaxScanTokenSize would silently stop parsing early.
	scanner.Buffer(make([]byte, 0, 4096), len(content)+1)
	for scanner.Scan() {
		fields := strings.Fields(scanner.Text())
		if len(fields) != 2 || !isValidSHA256(fields[0]) {
			continue
		}
		name := strings.TrimPrefix(fields[1], "*")
		if name == "" {
			continue
		}
		checksums[name] = strings.ToLower(fields[0])
	}
	return checksums
}
530
531// extractZip extracts the crush binary from a zip archive.
532func extractZip(archivePath string) (string, error) {
533 r, err := zip.OpenReader(archivePath)
534 if err != nil {
535 return "", err
536 }
537 defer r.Close()
538
539 for _, f := range r.File {
540 // Path traversal protection: reject paths with ".." or absolute paths.
541 cleanName := filepath.Clean(f.Name)
542 if strings.Contains(cleanName, "..") || filepath.IsAbs(cleanName) {
543 continue
544 }
545
546 // Use exact name matching to find only the binary we need.
547 if filepath.Base(f.Name) == binaryName() {
548 return extractZipFile(f)
549 }
550 }
551
552 return "", fmt.Errorf("crush binary not found in archive")
553}
554
555// extractZipFile extracts a single file from a zip archive with size limits.
556func extractZipFile(f *zip.File) (string, error) {
557 rc, err := f.Open()
558 if err != nil {
559 return "", err
560 }
561 defer rc.Close()
562
563 tmpBinary, err := os.CreateTemp("", "crush-binary-*")
564 if err != nil {
565 return "", err
566 }
567
568 limitedReader := io.LimitReader(rc, maxBinarySize+1)
569 written, err := io.Copy(tmpBinary, limitedReader)
570 if err != nil {
571 tmpBinary.Close()
572 os.Remove(tmpBinary.Name())
573 return "", err
574 }
575 if written > maxBinarySize {
576 tmpBinary.Close()
577 os.Remove(tmpBinary.Name())
578 return "", fmt.Errorf("binary exceeds maximum size of %d bytes", maxBinarySize)
579 }
580
581 if err := tmpBinary.Chmod(0o755); err != nil {
582 tmpBinary.Close()
583 os.Remove(tmpBinary.Name())
584 return "", err
585 }
586
587 if err := tmpBinary.Close(); err != nil {
588 os.Remove(tmpBinary.Name())
589 return "", err
590 }
591
592 return tmpBinary.Name(), nil
593}
594
595// extractTarGz extracts the crush binary from a tar.gz archive.
596func extractTarGz(archivePath string) (string, error) {
597 f, err := os.Open(archivePath)
598 if err != nil {
599 return "", err
600 }
601 defer f.Close()
602
603 gzr, err := gzip.NewReader(f)
604 if err != nil {
605 return "", err
606 }
607 defer gzr.Close()
608
609 tr := tar.NewReader(gzr)
610
611 for {
612 header, err := tr.Next()
613 if err == io.EOF {
614 break
615 }
616 if err != nil {
617 return "", err
618 }
619
620 // Only process regular files, skip directories/symlinks/etc.
621 if header.Typeflag != tar.TypeReg {
622 continue
623 }
624
625 // Path traversal protection: reject paths with ".." or absolute paths.
626 cleanName := filepath.Clean(header.Name)
627 if strings.Contains(cleanName, "..") || filepath.IsAbs(cleanName) {
628 continue
629 }
630
631 // Use exact name matching to find only the binary we need.
632 if filepath.Base(header.Name) == binaryName() {
633 return extractTarFile(tr)
634 }
635 }
636
637 return "", fmt.Errorf("crush binary not found in archive")
638}
639
640// extractTarFile extracts a single file from a tar reader with size limits.
641func extractTarFile(tr *tar.Reader) (string, error) {
642 tmpBinary, err := os.CreateTemp("", "crush-binary-*")
643 if err != nil {
644 return "", err
645 }
646
647 limitedReader := io.LimitReader(tr, maxBinarySize+1)
648 written, err := io.Copy(tmpBinary, limitedReader)
649 if err != nil {
650 tmpBinary.Close()
651 os.Remove(tmpBinary.Name())
652 return "", err
653 }
654 if written > maxBinarySize {
655 tmpBinary.Close()
656 os.Remove(tmpBinary.Name())
657 return "", fmt.Errorf("binary exceeds maximum size of %d bytes", maxBinarySize)
658 }
659
660 if err := tmpBinary.Chmod(0o755); err != nil {
661 tmpBinary.Close()
662 os.Remove(tmpBinary.Name())
663 return "", err
664 }
665
666 if err := tmpBinary.Close(); err != nil {
667 os.Remove(tmpBinary.Name())
668 return "", err
669 }
670
671 return tmpBinary.Name(), nil
672}
673
// checkWritePermission reports whether new files can be created in dir. It
// creates a uniquely named probe file there, closes it, and deletes it.
//
// os.CreateTemp picks a fresh random name and will not open an existing
// path. A pre-existing entry in dir, such as a symlink planted under a
// predictable name, is therefore never followed or truncated. A nil return
// means the probe was created, closed, and removed.
func checkWritePermission(dir string) error {
	f, err := os.CreateTemp(dir, ".crush-update-test-*")
	if err != nil {
		return err
	}
	name := f.Name()
	// Close before removing: an open file cannot be deleted on Windows.
	closeErr := f.Close()
	if err := os.Remove(name); err != nil {
		return err
	}
	return closeErr
}
684
// copyFile copies the contents of src to dst.
//
// If dst does not exist, it is created with src's mode (subject to the
// process umask). An existing dst is truncated and keeps its own mode. The
// data is fsynced and dst is closed before returning, and a failure in
// either step is reported. This way a subsequent atomic rename of dst only
// happens once the bytes are known to be on disk.
func copyFile(src, dst string) error {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer in.Close()

	info, err := in.Stat()
	if err != nil {
		return err
	}

	out, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, info.Mode())
	if err != nil {
		return err
	}

	if _, err := io.Copy(out, in); err != nil {
		out.Close()
		return err
	}

	// Sync to ensure data is persisted before atomic rename.
	if err := out.Sync(); err != nil {
		out.Close()
		return err
	}

	// Unlike a deferred Close, this surfaces errors reported at close time.
	return out.Close()
}
711
// copyWithContext streams src into dst in 32 KiB chunks and returns the
// number of bytes written.
//
// Before each chunk it checks ctx. If ctx is done, it stops with ctx.Err(),
// so a long download can be abandoned promptly (e.g. when the user quits).
// A Read that is already blocked is interrupted only if src itself honors
// cancellation. The copy ends successfully at io.EOF. It ends with an error
// on any other read error, any write error, or a short write
// (io.ErrShortWrite).
func copyWithContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
	chunk := make([]byte, 32*1024)
	var total int64

	for {
		if err := ctx.Err(); err != nil {
			return total, err
		}

		n, rerr := src.Read(chunk)
		if n > 0 {
			m, werr := dst.Write(chunk[:n])
			// Guard against io.Writer implementations that report an
			// impossible byte count.
			if m < 0 || m > n {
				m = 0
				if werr == nil {
					werr = fmt.Errorf("invalid write result")
				}
			}
			total += int64(m)
			switch {
			case werr != nil:
				return total, werr
			case m != n:
				return total, io.ErrShortWrite
			}
		}

		switch {
		case rerr == io.EOF:
			return total, nil
		case rerr != nil:
			return total, rerr
		}
	}
}