update.go

  1package update
  2
  3import (
  4	"archive/tar"
  5	"archive/zip"
  6	"bufio"
  7	"compress/gzip"
  8	"context"
  9	"crypto/sha256"
 10	"encoding/hex"
 11	"encoding/json"
 12	"fmt"
 13	"io"
 14	"net/http"
 15	"os"
 16	"path/filepath"
 17	"regexp"
 18	"runtime"
 19	"strings"
 20	"time"
 21
 22	"github.com/Masterminds/semver/v3"
 23	"github.com/charmbracelet/crush/internal/version"
 24)
 25
 26const (
 27	githubAPIURL     = "https://api.github.com/repos/charmbracelet/crush/releases/latest"
 28	maxBinarySize    = 500 * 1024 * 1024 // 500MB max for extracted binary
 29	maxArchiveSize   = 500 * 1024 * 1024 // 500MB max for downloaded archive
 30	maxChecksumsSize = 1 * 1024 * 1024   // 1MB max for checksums.txt
 31)
 32
 33// binaryName returns the expected binary name for the current platform.
 34func binaryName() string {
 35	if runtime.GOOS == "windows" {
 36		return "crush.exe"
 37	}
 38	return "crush"
 39}
 40
 41// userAgent returns the user agent string for HTTP requests.
 42func userAgent() string {
 43	return "crush/" + version.Version
 44}
 45
 46// Default is the default [Client].
 47var Default Client = &github{}
 48
 49// Info contains information about an available update.
 50type Info struct {
 51	Current string
 52	Latest  string
 53	URL     string
 54	Release *Release
 55}
 56
 57// goInstallRegexp matches pseudo-versions from go install:
 58// v0.0.0-0.20251231235959-06c807842604
 59var goInstallRegexp = regexp.MustCompile(`^v?\d+\.\d+\.\d+-\d+\.\d{14}-[0-9a-f]{12}$`)
 60
 61// gitDescribeRegexp matches git describe versions:
 62// v0.19.0-15-g1a2b3c4d (tag-commits-ghash)
 63var gitDescribeRegexp = regexp.MustCompile(`^v?\d+\.\d+\.\d+-\d+-g[0-9a-f]+$`)
 64
 65// IsDevelopment returns true if the current version appears to be a
 66// development build rather than an official release.
 67func (i Info) IsDevelopment() bool {
 68	return i.Current == "devel" ||
 69		i.Current == "unknown" ||
 70		strings.Contains(i.Current, "dirty") ||
 71		goInstallRegexp.MatchString(i.Current) ||
 72		gitDescribeRegexp.MatchString(i.Current)
 73}
 74
 75// DevelopmentVersionBrief returns a brief message for development versions
 76// suitable for the version flag output.
 77func (i Info) DevelopmentVersionBrief() string {
 78	return fmt.Sprintf(
 79		"\nThis is a development version of Crush. The latest stable release is v%s.\nRun 'crush update' to learn more.\n",
 80		i.Latest,
 81	)
 82}
 83
 84// DevelopmentVersionMessage returns a detailed message for development versions
 85// suitable for CLI command output. If selfUpdateNote is true, includes a note
 86// that self-update is not supported.
 87func (i Info) DevelopmentVersionMessage(selfUpdateNote bool) string {
 88	var b strings.Builder
 89	fmt.Fprintf(&b, "You are running a development version of Crush (%s).\n", i.Current)
 90	if selfUpdateNote {
 91		b.WriteString("Self-update is not supported for development versions.\n")
 92	} else {
 93		fmt.Fprintf(&b, "The latest stable release is v%s.\n", i.Latest)
 94	}
 95	b.WriteString("To install the latest stable version, run:\n")
 96	b.WriteString("  go install github.com/charmbracelet/crush@latest\n")
 97	if i.URL != "" && !selfUpdateNote {
 98		fmt.Fprintf(&b, "Or visit %s to download manually.\n", i.URL)
 99	}
100	return b.String()
101}
102
103// Available returns true if there's an update available.
104// Uses proper semver comparison to handle version ordering correctly.
105// Returns false if either version cannot be parsed.
106// Special case: if current is stable and latest is a prerelease, returns false
107// (we don't offer prerelease updates to stable users).
108func (i Info) Available() bool {
109	current, err := semver.NewVersion(i.Current)
110	if err != nil {
111		return false
112	}
113	latest, err := semver.NewVersion(i.Latest)
114	if err != nil {
115		return false
116	}
117
118	// Don't offer prerelease updates to stable users.
119	if current.Prerelease() == "" && latest.Prerelease() != "" {
120		return false
121	}
122
123	return latest.GreaterThan(current)
124}
125
// InstallMethod represents how Crush was installed.
type InstallMethod int

// Known install methods. Values come from iota, so their order matters;
// InstallMethodUnknown is the zero value.
const (
	InstallMethodUnknown InstallMethod = iota
	InstallMethodBinary                // Direct binary download
	InstallMethodHomebrew
	InstallMethodNPM
	InstallMethodAUR
	InstallMethodNix
	InstallMethodWinget
	InstallMethodScoop
	InstallMethodApt
	InstallMethodYum
	InstallMethodGoInstall
)

// String returns a human-readable name for the install method. It returns
// "unknown" for InstallMethodUnknown and for any value outside the known set.
func (m InstallMethod) String() string {
	names := [...]string{
		InstallMethodUnknown:   "unknown",
		InstallMethodBinary:    "binary",
		InstallMethodHomebrew:  "Homebrew",
		InstallMethodNPM:       "npm",
		InstallMethodAUR:       "AUR",
		InstallMethodNix:       "Nix",
		InstallMethodWinget:    "winget",
		InstallMethodScoop:     "Scoop",
		InstallMethodApt:       "apt",
		InstallMethodYum:       "yum",
		InstallMethodGoInstall: "go install",
	}
	if m < 0 || int(m) >= len(names) {
		return "unknown"
	}
	return names[m]
}

// CanSelfUpdate reports whether this install method supports replacing the
// binary in place. Only direct binary downloads and unknown installs qualify.
func (m InstallMethod) CanSelfUpdate() bool {
	switch m {
	case InstallMethodBinary, InstallMethodUnknown:
		return true
	default:
		return false
	}
}

// UpdateInstructions returns the shell command a user should run to update
// Crush for this install method. It returns "" for methods without a package
// manager command (binary, unknown, or unrecognised values).
func (m InstallMethod) UpdateInstructions() string {
	commands := [...]string{
		InstallMethodHomebrew:  "brew upgrade charmbracelet/tap/crush",
		InstallMethodNPM:       "npm update -g @charmland/crush",
		InstallMethodAUR:       "yay -Syu crush-bin  # or your preferred AUR helper",
		InstallMethodNix:       "nix flake update  # or update your NUR channel",
		InstallMethodWinget:    "winget upgrade charmbracelet.crush",
		InstallMethodScoop:     "scoop update crush",
		InstallMethodApt:       "sudo apt update && sudo apt upgrade crush",
		InstallMethodYum:       "sudo yum update crush",
		InstallMethodGoInstall: "go install github.com/charmbracelet/crush@latest",
	}
	if m < 0 || int(m) >= len(commands) {
		return ""
	}
	return commands[m]
}
201
202// DetectInstallMethod attempts to determine how Crush was installed.
203// This is implemented per-platform in update_darwin.go, update_unix.go,
204// and update_windows.go.
205func DetectInstallMethod() InstallMethod {
206	exe, err := os.Executable()
207	if err != nil {
208		return InstallMethodUnknown
209	}
210	exe, err = filepath.EvalSymlinks(exe)
211	if err != nil {
212		return InstallMethodUnknown
213	}
214	return detectInstallMethod(exe)
215}
216
217// Check checks if a new version is available.
218func Check(ctx context.Context, current string, client Client) (Info, error) {
219	info := Info{
220		Current: current,
221		Latest:  current,
222	}
223
224	release, err := client.Latest(ctx)
225	if err != nil {
226		return info, fmt.Errorf("failed to fetch latest release: %w", err)
227	}
228
229	info.Latest = strings.TrimPrefix(release.TagName, "v")
230	info.Current = strings.TrimPrefix(info.Current, "v")
231	info.URL = release.HTMLURL
232	info.Release = release
233	return info, nil
234}
235
type (
	// Asset is a single downloadable file attached to a GitHub release.
	Asset struct {
		Name               string `json:"name"`
		BrowserDownloadURL string `json:"browser_download_url"`
		Size               int64  `json:"size"`
		Digest             string `json:"digest"` // Format: "sha256:hexstring"
	}

	// Release is the subset of a GitHub release payload used by the updater.
	Release struct {
		TagName string  `json:"tag_name"`
		HTMLURL string  `json:"html_url"`
		Assets  []Asset `json:"assets"`
	}

	// Client is anything that can report the latest release.
	Client interface {
		Latest(ctx context.Context) (*Release, error)
	}

	// github is the Client backed by the GitHub REST API.
	github struct{}
)
257
258// Latest implements [Client].
259func (c *github) Latest(ctx context.Context) (*Release, error) {
260	client := &http.Client{
261		Timeout: 30 * time.Second,
262	}
263
264	req, err := http.NewRequestWithContext(ctx, "GET", githubAPIURL, nil)
265	if err != nil {
266		return nil, err
267	}
268	req.Header.Set("User-Agent", userAgent())
269	req.Header.Set("Accept", "application/vnd.github.v3+json")
270
271	resp, err := client.Do(req)
272	if err != nil {
273		return nil, err
274	}
275	defer resp.Body.Close()
276
277	if resp.StatusCode != http.StatusOK {
278		body, _ := io.ReadAll(resp.Body)
279		return nil, fmt.Errorf("github api returned status %d: %s", resp.StatusCode, string(body))
280	}
281
282	var release Release
283	if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
284		return nil, err
285	}
286
287	return &release, nil
288}
289
290// FindAsset finds the appropriate asset for the current platform.
291func FindAsset(assets []Asset) (*Asset, error) {
292	// Normalize architecture to match goreleaser naming.
293	arch := runtime.GOARCH
294	switch arch {
295	case "amd64":
296		arch = "x86_64"
297	case "386":
298		arch = "i386"
299	case "arm":
300		arch = "armv7"
301		// arm64 stays as "arm64" in goreleaser naming.
302	}
303
304	// Normalize OS to match goreleaser naming (title case).
305	goos := runtime.GOOS
306	switch goos {
307	case "freebsd":
308		goos = "Freebsd"
309	case "netbsd":
310		goos = "Netbsd"
311	case "openbsd":
312		goos = "Openbsd"
313	default:
314		if len(goos) > 0 {
315			goos = strings.ToUpper(goos[:1]) + goos[1:]
316		}
317	}
318
319	// Look for archive matching our platform.
320	// Pattern: crush_{version}_{OS}_{ARCH}.{tar.gz|zip}
321	for _, asset := range assets {
322		if strings.Contains(asset.Name, goos) && strings.Contains(asset.Name, arch) {
323			// Ensure it's an archive, not a checksum or signature.
324			if strings.HasSuffix(asset.Name, ".tar.gz") || strings.HasSuffix(asset.Name, ".zip") {
325				return &asset, nil
326			}
327		}
328	}
329
330	return nil, fmt.Errorf("no suitable asset found for %s/%s", runtime.GOOS, runtime.GOARCH)
331}
332
333// Download downloads and extracts the crush binary from the given asset.
334// Returns the path to the extracted binary.
335func Download(ctx context.Context, asset *Asset, release *Release) (string, error) {
336	client := &http.Client{}
337
338	// Find checksums.txt and its sigstore bundle.
339	var checksumsAsset, sigstoreBundleAsset *Asset
340	for i := range release.Assets {
341		switch release.Assets[i].Name {
342		case "checksums.txt":
343			checksumsAsset = &release.Assets[i]
344		case "checksums.txt.sigstore.json":
345			sigstoreBundleAsset = &release.Assets[i]
346		}
347	}
348	if checksumsAsset == nil {
349		return "", fmt.Errorf("checksums.txt not found in release")
350	}
351
352	checksums, checksumsContent, err := downloadChecksums(ctx, client, checksumsAsset)
353	if err != nil {
354		return "", fmt.Errorf("failed to download checksums: %w", err)
355	}
356
357	// Verify checksums.txt using sigstore bundle if available.
358	if sigstoreBundleAsset != nil {
359		if err := verifySigstoreBundle(ctx, client, checksumsContent, sigstoreBundleAsset); err != nil {
360			return "", fmt.Errorf("checksums.txt signature verification failed: %w", err)
361		}
362	}
363
364	// Validate asset size from API before downloading.
365	if asset.Size > maxArchiveSize {
366		return "", fmt.Errorf("asset size %d exceeds maximum allowed size of %d bytes", asset.Size, maxArchiveSize)
367	}
368
369	req, err := http.NewRequestWithContext(ctx, "GET", asset.BrowserDownloadURL, nil)
370	if err != nil {
371		return "", err
372	}
373	req.Header.Set("User-Agent", userAgent())
374
375	resp, err := client.Do(req)
376	if err != nil {
377		return "", err
378	}
379	defer resp.Body.Close()
380
381	if resp.StatusCode != http.StatusOK {
382		return "", fmt.Errorf("download failed with status %d", resp.StatusCode)
383	}
384
385	// Validate Content-Length if provided (defense in depth, asset.Size already checked).
386	if resp.ContentLength > maxArchiveSize {
387		return "", fmt.Errorf("archive size %d exceeds maximum allowed size of %d bytes", resp.ContentLength, maxArchiveSize)
388	}
389
390	// Create temp file for archive.
391	tmpFile, err := os.CreateTemp("", "crush-update-*")
392	if err != nil {
393		return "", fmt.Errorf("failed to create temp file: %w", err)
394	}
395	tmpFileName := tmpFile.Name()
396	defer os.Remove(tmpFileName)
397
398	// Download to temp file while computing checksum.
399	// Use LimitReader to prevent DoS from oversized downloads (Content-Length can be spoofed).
400	hash := sha256.New()
401	limitedBody := io.LimitReader(resp.Body, maxArchiveSize+1)
402	written, err := copyWithContext(ctx, io.MultiWriter(tmpFile, hash), limitedBody)
403	tmpFile.Close() // Close immediately after writing, before size check.
404	if err != nil {
405		return "", fmt.Errorf("failed to download: %w", err)
406	}
407	if written > maxArchiveSize {
408		return "", fmt.Errorf("archive size %d exceeds maximum allowed size of %d bytes", written, maxArchiveSize)
409	}
410
411	// Verify checksum.
412	actualSum := hex.EncodeToString(hash.Sum(nil))
413	expectedSum, ok := checksums[asset.Name]
414	if !ok {
415		return "", fmt.Errorf("no checksum found for %s", asset.Name)
416	}
417	if actualSum != expectedSum {
418		return "", fmt.Errorf("checksum mismatch: expected %s, got %s", expectedSum, actualSum)
419	}
420
421	// Extract binary based on archive type.
422	var binaryPath string
423	if strings.HasSuffix(asset.Name, ".zip") {
424		binaryPath, err = extractZip(tmpFileName)
425	} else {
426		binaryPath, err = extractTarGz(tmpFileName)
427	}
428	if err != nil {
429		if binaryPath != "" {
430			os.Remove(binaryPath)
431		}
432		return "", fmt.Errorf("failed to extract: %w", err)
433	}
434
435	return binaryPath, nil
436}
437
438// downloadChecksums downloads, verifies, and parses the checksums.txt file.
439// The checksums.txt file is verified against GitHub's API-provided digest before
440// parsing to ensure integrity. Returns a map of filename to sha256 checksum,
441// and the raw content bytes for sigstore verification.
442func downloadChecksums(ctx context.Context, client *http.Client, asset *Asset) (map[string]string, []byte, error) {
443	req, err := http.NewRequestWithContext(ctx, "GET", asset.BrowserDownloadURL, nil)
444	if err != nil {
445		return nil, nil, err
446	}
447	req.Header.Set("User-Agent", userAgent())
448
449	resp, err := client.Do(req)
450	if err != nil {
451		return nil, nil, err
452	}
453	defer resp.Body.Close()
454
455	if resp.StatusCode != http.StatusOK {
456		return nil, nil, fmt.Errorf("failed to download checksums with status %d", resp.StatusCode)
457	}
458
459	// Limit the size of checksums.txt to prevent memory exhaustion.
460	limitedBody := io.LimitReader(resp.Body, maxChecksumsSize+1)
461
462	// Read content while computing hash for verification against GitHub's digest.
463	hash := sha256.New()
464	content, err := io.ReadAll(io.TeeReader(limitedBody, hash))
465	if err != nil {
466		return nil, nil, err
467	}
468	if int64(len(content)) > maxChecksumsSize {
469		return nil, nil, fmt.Errorf("checksums.txt exceeds maximum size of %d bytes", maxChecksumsSize)
470	}
471
472	// Verify against GitHub's API-provided digest if available.
473	if asset.Digest != "" {
474		actualSum := hex.EncodeToString(hash.Sum(nil))
475		expectedSum := parseDigest(asset.Digest)
476		if expectedSum == "" {
477			return nil, nil, fmt.Errorf("invalid digest format from API: %s", asset.Digest)
478		}
479		if actualSum != expectedSum {
480			return nil, nil, fmt.Errorf("checksums.txt digest mismatch: expected %s, got %s", expectedSum, actualSum)
481		}
482	}
483
484	return parseChecksumLines(string(content)), content, nil
485}
486
// parseDigest extracts the hex checksum from a digest string of the form
// "sha256:<hex>". The result is lowercased so it compares equal to
// hex.EncodeToString output. It returns "" if the prefix is missing or the
// remainder is not a valid SHA-256 hex checksum.
func parseDigest(digest string) string {
	const prefix = "sha256:"
	if !strings.HasPrefix(digest, prefix) {
		return ""
	}
	sum := strings.TrimPrefix(digest, prefix)
	if !isValidSHA256(sum) {
		return ""
	}
	return strings.ToLower(sum)
}

// isValidSHA256 reports whether s is a SHA-256 checksum written as exactly
// 64 hex digits, in either case.
func isValidSHA256(s string) bool {
	if len(s) != 64 {
		return false
	}
	_, err := hex.DecodeString(s)
	return err == nil
}

// parseChecksumLines parses sha256sum-style "<checksum> <filename>" lines.
//
// It returns a map of file name to lowercase SHA-256 checksum. Lines that do
// not have exactly two fields or whose checksum is invalid are skipped. A
// leading '*' on the file name (sha256sum binary-mode marker) is removed.
func parseChecksumLines(content string) map[string]string {
	checksums := make(map[string]string)
	scanner := bufio.NewScanner(strings.NewReader(content))
	// The default 64 KiB token limit would make Scan stop early and silently
	// drop every later entry. Allowing a token as long as the whole input
	// guarantees every line fits.
	scanner.Buffer(nil, len(content)+1)
	for scanner.Scan() {
		parts := strings.Fields(scanner.Text())
		if len(parts) != 2 || !isValidSHA256(parts[0]) {
			continue
		}
		name := strings.TrimPrefix(parts[1], "*")
		checksums[name] = strings.ToLower(parts[0])
	}
	return checksums
}
530
531// extractZip extracts the crush binary from a zip archive.
532func extractZip(archivePath string) (string, error) {
533	r, err := zip.OpenReader(archivePath)
534	if err != nil {
535		return "", err
536	}
537	defer r.Close()
538
539	for _, f := range r.File {
540		// Path traversal protection: reject paths with ".." or absolute paths.
541		cleanName := filepath.Clean(f.Name)
542		if strings.Contains(cleanName, "..") || filepath.IsAbs(cleanName) {
543			continue
544		}
545
546		// Use exact name matching to find only the binary we need.
547		if filepath.Base(f.Name) == binaryName() {
548			return extractZipFile(f)
549		}
550	}
551
552	return "", fmt.Errorf("crush binary not found in archive")
553}
554
555// extractZipFile extracts a single file from a zip archive with size limits.
556func extractZipFile(f *zip.File) (string, error) {
557	rc, err := f.Open()
558	if err != nil {
559		return "", err
560	}
561	defer rc.Close()
562
563	tmpBinary, err := os.CreateTemp("", "crush-binary-*")
564	if err != nil {
565		return "", err
566	}
567
568	limitedReader := io.LimitReader(rc, maxBinarySize+1)
569	written, err := io.Copy(tmpBinary, limitedReader)
570	if err != nil {
571		tmpBinary.Close()
572		os.Remove(tmpBinary.Name())
573		return "", err
574	}
575	if written > maxBinarySize {
576		tmpBinary.Close()
577		os.Remove(tmpBinary.Name())
578		return "", fmt.Errorf("binary exceeds maximum size of %d bytes", maxBinarySize)
579	}
580
581	if err := tmpBinary.Chmod(0o755); err != nil {
582		tmpBinary.Close()
583		os.Remove(tmpBinary.Name())
584		return "", err
585	}
586
587	if err := tmpBinary.Close(); err != nil {
588		os.Remove(tmpBinary.Name())
589		return "", err
590	}
591
592	return tmpBinary.Name(), nil
593}
594
595// extractTarGz extracts the crush binary from a tar.gz archive.
596func extractTarGz(archivePath string) (string, error) {
597	f, err := os.Open(archivePath)
598	if err != nil {
599		return "", err
600	}
601	defer f.Close()
602
603	gzr, err := gzip.NewReader(f)
604	if err != nil {
605		return "", err
606	}
607	defer gzr.Close()
608
609	tr := tar.NewReader(gzr)
610
611	for {
612		header, err := tr.Next()
613		if err == io.EOF {
614			break
615		}
616		if err != nil {
617			return "", err
618		}
619
620		// Only process regular files, skip directories/symlinks/etc.
621		if header.Typeflag != tar.TypeReg {
622			continue
623		}
624
625		// Path traversal protection: reject paths with ".." or absolute paths.
626		cleanName := filepath.Clean(header.Name)
627		if strings.Contains(cleanName, "..") || filepath.IsAbs(cleanName) {
628			continue
629		}
630
631		// Use exact name matching to find only the binary we need.
632		if filepath.Base(header.Name) == binaryName() {
633			return extractTarFile(tr)
634		}
635	}
636
637	return "", fmt.Errorf("crush binary not found in archive")
638}
639
640// extractTarFile extracts a single file from a tar reader with size limits.
641func extractTarFile(tr *tar.Reader) (string, error) {
642	tmpBinary, err := os.CreateTemp("", "crush-binary-*")
643	if err != nil {
644		return "", err
645	}
646
647	limitedReader := io.LimitReader(tr, maxBinarySize+1)
648	written, err := io.Copy(tmpBinary, limitedReader)
649	if err != nil {
650		tmpBinary.Close()
651		os.Remove(tmpBinary.Name())
652		return "", err
653	}
654	if written > maxBinarySize {
655		tmpBinary.Close()
656		os.Remove(tmpBinary.Name())
657		return "", fmt.Errorf("binary exceeds maximum size of %d bytes", maxBinarySize)
658	}
659
660	if err := tmpBinary.Chmod(0o755); err != nil {
661		tmpBinary.Close()
662		os.Remove(tmpBinary.Name())
663		return "", err
664	}
665
666	if err := tmpBinary.Close(); err != nil {
667		os.Remove(tmpBinary.Name())
668		return "", err
669	}
670
671	return tmpBinary.Name(), nil
672}
673
// checkWritePermission reports whether files can be created in dir. It
// creates and then deletes a probe file.
//
// The probe gets a unique name from os.CreateTemp, so an existing file
// called ".crush-update-test" is never truncated or deleted. Concurrent
// checks on the same directory also cannot interfere with each other.
func checkWritePermission(dir string) error {
	f, err := os.CreateTemp(dir, ".crush-update-test-*")
	if err != nil {
		return err
	}
	name := f.Name()
	closeErr := f.Close()
	if err := os.Remove(name); err != nil {
		return err
	}
	return closeErr
}
684
// copyFile copies the contents of src to dst and syncs dst to stable storage,
// so a later rename of dst is durable.
//
// dst is created or truncated. If it is newly created, it gets src's mode
// bits (subject to the process umask); an existing dst keeps its own mode.
// Errors from closing dst are reported, because a failed Close can indicate
// data that was never written.
func copyFile(src, dst string) (err error) {
	srcFile, err := os.Open(src)
	if err != nil {
		return err
	}
	defer srcFile.Close()

	srcInfo, err := srcFile.Stat()
	if err != nil {
		return err
	}

	dstFile, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, srcInfo.Mode())
	if err != nil {
		return err
	}
	defer func() {
		// Report the Close error unless an earlier error is already being returned.
		if closeErr := dstFile.Close(); err == nil {
			err = closeErr
		}
	}()

	if _, err = io.Copy(dstFile, srcFile); err != nil {
		return err
	}

	// Sync to ensure data is persisted before atomic rename.
	return dstFile.Sync()
}
711
// copyWithContext copies src into dst in 32KB chunks, like io.Copy. Before
// each chunk it checks whether ctx is done, so large downloads can be
// interrupted (for example when the user quits). It returns the number of
// bytes written and the first read, write, or context error. io.EOF from src
// counts as success.
func copyWithContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
	chunk := make([]byte, 32*1024)
	var total int64

	for {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return total, ctxErr
		}

		n, readErr := src.Read(chunk)
		if n > 0 {
			m, writeErr := dst.Write(chunk[:n])
			// Guard against writers reporting impossible byte counts.
			if m < 0 || m > n {
				m = 0
				if writeErr == nil {
					writeErr = fmt.Errorf("invalid write result")
				}
			}
			total += int64(m)
			switch {
			case writeErr != nil:
				return total, writeErr
			case m != n:
				return total, io.ErrShortWrite
			}
		}

		switch {
		case readErr == io.EOF:
			return total, nil
		case readErr != nil:
			return total, readErr
		}
	}
}