diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 17c23aeabfd6aaca438d30a0d59cf014c3134f4b..64f81423a5e109dbfcc4cee2238c69dc6da03f54 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -11,6 +11,7 @@ import ( "path/filepath" "strconv" "strings" + "time" tea "charm.land/bubbletea/v2" "charm.land/lipgloss/v2" @@ -22,6 +23,7 @@ import ( "github.com/charmbracelet/crush/internal/stringext" termutil "github.com/charmbracelet/crush/internal/term" "github.com/charmbracelet/crush/internal/tui" + "github.com/charmbracelet/crush/internal/update" "github.com/charmbracelet/crush/internal/version" "github.com/charmbracelet/fang" uv "github.com/charmbracelet/ultraviolet" @@ -41,6 +43,7 @@ func init() { rootCmd.AddCommand( runCmd, dirsCmd, + updateCmd, updateProvidersCmd, logsCmd, schemaCmd, @@ -135,13 +138,24 @@ func Execute() { // printing the version, and PreRunE runs after the version is already // handled, so that doesn't work either. // This is the only way I could find that works relatively well. + versionTemplate := defaultVersionTemplate if term.IsTerminal(os.Stdout.Fd()) { var b bytes.Buffer w := colorprofile.NewWriter(os.Stdout, os.Environ()) w.Forward = &b _, _ = w.WriteString(heartbit.String()) - rootCmd.SetVersionTemplate(b.String() + "\n" + defaultVersionTemplate) + versionTemplate = b.String() + "\n" + defaultVersionTemplate } + + // Check if version flag is present and add update notification if available. + if hasVersionFlag() { + if updateMsg := checkForUpdateSync(); updateMsg != "" { + versionTemplate += updateMsg + } + } + + rootCmd.SetVersionTemplate(versionTemplate) + if err := fang.Execute( context.Background(), rootCmd, @@ -152,6 +166,34 @@ func Execute() { } } +// hasVersionFlag checks if the version flag is present in os.Args. +func hasVersionFlag() bool { + for _, arg := range os.Args { + if arg == "-v" || arg == "--version" { + return true + } + } + return false +} + +// checkForUpdateSync performs a synchronous update check with a short timeout. +// Returns a formatted update message if an update is available, empty string otherwise. +func checkForUpdateSync() string { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + info, err := update.Check(ctx, version.Version, update.Default) + if err != nil || !info.Available() { + return "" + } + + if info.IsDevelopment() { + return fmt.Sprintf("\nThis is a development version of Crush. The latest stable release is v%s.\nRun 'crush update' to learn more.\n", info.Latest) + } + + return fmt.Sprintf("\nUpdate available: v%s → v%s\nRun 'crush update apply' to install.\n", info.Current, info.Latest) +} + func setupAppWithProgressBar(cmd *cobra.Command) (*app.App, error) { if termutil.SupportsProgressBar() { _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar) diff --git a/internal/cmd/update.go b/internal/cmd/update.go new file mode 100644 index 0000000000000000000000000000000000000000..0fcd611567feff0e07b3ffb9dc6dfdbe98341390 --- /dev/null +++ b/internal/cmd/update.go @@ -0,0 +1,165 @@ +package cmd + +import ( + "context" + "fmt" + "os" + "time" + + "charm.land/lipgloss/v2" + "github.com/charmbracelet/crush/internal/format" + "github.com/charmbracelet/crush/internal/tui/components/anim" + "github.com/charmbracelet/crush/internal/tui/styles" + "github.com/charmbracelet/crush/internal/update" + "github.com/charmbracelet/crush/internal/version" + "github.com/charmbracelet/x/term" + "github.com/spf13/cobra" +) + +var updateCmd = &cobra.Command{ + Use: "update", + Short: "Check for and apply updates", + Long: `Check if a new version of Crush is available. +Use 'update apply' to download and install the latest version.`, + Example: ` +# Check if an update is available +crush update + +# Apply the update if available +crush update apply + `, + RunE: func(cmd *cobra.Command, args []string) error { + ctx, cancel := context.WithTimeout(cmd.Context(), 30*time.Second) + defer cancel() + + spinner := newUpdateSpinner(ctx, cancel, "Checking for updates") + spinner.Start() + + info, err := update.Check(ctx, version.Version, update.Default) + spinner.Stop() + + if err != nil { + return fmt.Errorf("failed to check for updates: %w", err) + } + + if info.IsDevelopment() { + fmt.Fprintf(os.Stderr, "You are running a development version of Crush (%s).\n", info.Current) + fmt.Fprintf(os.Stderr, "The latest stable release is v%s.\n", info.Latest) + fmt.Fprintf(os.Stderr, "Visit %s to learn more.\n", info.URL) + return nil + } + + if !info.Available() { + fmt.Fprintf(os.Stderr, "You are already running the latest version (v%s).\n", info.Current) + return nil + } + + fmt.Fprintf(os.Stderr, "Update available: v%s → v%s\n", info.Current, info.Latest) + fmt.Fprintf(os.Stderr, "Run 'crush update apply' to install the latest version.\n") + fmt.Fprintf(os.Stderr, "Or visit %s to download manually.\n", info.URL) + + return nil + }, +} + +var updateApplyCmd = &cobra.Command{ + Use: "apply", + Short: "Apply the latest update", + Long: `Download and install the latest version of Crush.`, + Example: ` +# Apply the latest update +crush update apply + `, + RunE: func(cmd *cobra.Command, args []string) error { + ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Minute) + defer cancel() + + spinner := newUpdateSpinner(ctx, cancel, "Checking for updates") + spinner.Start() + + info, err := update.Check(ctx, version.Version, update.Default) + if err != nil { + spinner.Stop() + return fmt.Errorf("failed to check for updates: %w", err) + } + + if info.IsDevelopment() { + spinner.Stop() + return fmt.Errorf("cannot update development versions automatically") + } + + if !info.Available() { + spinner.Stop() + fmt.Fprintf(os.Stderr, "You are already running the latest version (v%s).\n", info.Current) + return nil + } + + // Get the latest release with assets. + release, err := update.Default.Latest(ctx) + if err != nil { + spinner.Stop() + return fmt.Errorf("failed to fetch release information: %w", err) + } + + // Find the appropriate asset for this platform. + asset, err := update.FindAsset(release.Assets) + if err != nil { + spinner.Stop() + return fmt.Errorf("failed to find update for your platform: %w", err) + } + + spinner.Stop() + spinner = newUpdateSpinner(ctx, cancel, fmt.Sprintf("Downloading v%s", info.Latest)) + spinner.Start() + + // Download the asset. + binaryPath, err := update.Download(ctx, asset, release) + if err != nil { + spinner.Stop() + return fmt.Errorf("failed to download update: %w", err) + } + defer os.Remove(binaryPath) + + spinner.Stop() + spinner = newUpdateSpinner(ctx, cancel, "Installing") + spinner.Start() + + // Apply the update. + if err := update.Apply(binaryPath); err != nil { + spinner.Stop() + return fmt.Errorf("failed to apply update: %w", err) + } + + spinner.Stop() + + fmt.Fprintf(os.Stderr, "Successfully updated to v%s!\n", info.Latest) + fmt.Fprintf(os.Stderr, "Run 'crush -v' to verify the new version.\n") + + return nil + }, +} + +// newUpdateSpinner creates a spinner for update operations. +func newUpdateSpinner(ctx context.Context, cancel context.CancelFunc, label string) *format.Spinner { + t := styles.CurrentTheme() + + // Detect background color for appropriate text color. + hasDarkBG := true + if term.IsTerminal(os.Stderr.Fd()) { + hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, os.Stderr) + } + defaultFG := lipgloss.LightDark(hasDarkBG)(lipgloss.Color("#fafafa"), t.FgBase) + + return format.NewSpinner(ctx, cancel, anim.Settings{ + Size: 10, + Label: label, + LabelColor: defaultFG, + GradColorA: t.Primary, + GradColorB: t.Secondary, + CycleColors: true, + }) +} + +func init() { + updateCmd.AddCommand(updateApplyCmd) +} diff --git a/internal/update/update.go b/internal/update/update.go index a813fe3516dc28233e3df01c77d4d62d4d97db18..42b440df8760a77268d53250569a268125dfde8e 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -1,21 +1,42 @@ package update import ( + "archive/tar" + "archive/zip" + "bufio" + "compress/gzip" "context" + "crypto/sha256" + "debug/elf" + "debug/macho" + "debug/pe" + "encoding/hex" "encoding/json" "fmt" "io" "net/http" + "os" + "path/filepath" "regexp" + "runtime" "strings" "time" + + "github.com/charmbracelet/crush/internal/version" ) const ( - githubApiUrl = "https://api.github.com/repos/charmbracelet/crush/releases/latest" - userAgent = "crush/1.0" + githubApiUrl = "https://api.github.com/repos/charmbracelet/crush/releases/latest" + maxBinarySize = 500 * 1024 * 1024 // 500MB max for extracted binary + maxArchiveSize = 500 * 1024 * 1024 // 500MB max for downloaded archive + maxChecksumsSize = 1 * 1024 * 1024 // 1MB max for checksums.txt ) +// userAgent returns the user agent string for HTTP requests. +func userAgent() string { + return "crush/" + version.Version +} + // Default is the default [Client]. var Default Client = &github{} @@ -72,10 +93,17 @@ func Check(ctx context.Context, current string, client Client) (Info, error) { return info, nil } +// Asset represents a GitHub release asset. +type Asset struct { + Name string `json:"name"` + BrowserDownloadURL string `json:"browser_download_url"` +} + // Release represents a GitHub release. type Release struct { - TagName string `json:"tag_name"` - HTMLURL string `json:"html_url"` + TagName string `json:"tag_name"` + HTMLURL string `json:"html_url"` + Assets []Asset `json:"assets"` } // Client is a client that can get the latest release. @@ -95,7 +123,7 @@ func (c *github) Latest(ctx context.Context) (*Release, error) { if err != nil { return nil, err } - req.Header.Set("User-Agent", userAgent) + req.Header.Set("User-Agent", userAgent()) req.Header.Set("Accept", "application/vnd.github.v3+json") resp, err := client.Do(req) @@ -116,3 +144,490 @@ func (c *github) Latest(ctx context.Context) (*Release, error) { return &release, nil } + +// FindAsset finds the appropriate asset for the current platform. +func FindAsset(assets []Asset) (*Asset, error) { + // Normalize architecture to match goreleaser naming. + arch := runtime.GOARCH + switch arch { + case "amd64": + arch = "x86_64" + case "386": + arch = "i386" + case "arm": + arch = "armv7" + // arm64 stays as "arm64" in goreleaser naming. + } + + // Normalize OS to match goreleaser naming (title case). + goos := runtime.GOOS + switch goos { + case "freebsd": + goos = "Freebsd" + case "netbsd": + goos = "Netbsd" + case "openbsd": + goos = "Openbsd" + default: + if len(goos) > 0 { + goos = strings.ToUpper(goos[:1]) + goos[1:] + } + } + + // Look for archive matching our platform. + // Pattern: crush_{version}_{OS}_{ARCH}.{tar.gz|zip} + for _, asset := range assets { + if strings.Contains(asset.Name, goos) && strings.Contains(asset.Name, arch) { + // Ensure it's an archive, not a checksum or signature. + if strings.HasSuffix(asset.Name, ".tar.gz") || strings.HasSuffix(asset.Name, ".zip") { + return &asset, nil + } + } + } + + return nil, fmt.Errorf("no suitable asset found for %s/%s", runtime.GOOS, runtime.GOARCH) +} + +// Download downloads and extracts the crush binary from the given asset. +// Returns the path to the extracted binary. +func Download(ctx context.Context, asset *Asset, release *Release) (string, error) { + client := &http.Client{ + Timeout: 5 * time.Minute, + } + + // Find and download checksums.txt. + var checksumsAsset *Asset + for i := range release.Assets { + if release.Assets[i].Name == "checksums.txt" { + checksumsAsset = &release.Assets[i] + break + } + } + if checksumsAsset == nil { + return "", fmt.Errorf("checksums.txt not found in release") + } + + checksums, err := downloadChecksums(ctx, client, checksumsAsset) + if err != nil { + return "", fmt.Errorf("failed to download checksums: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "GET", asset.BrowserDownloadURL, nil) + if err != nil { + return "", err + } + req.Header.Set("User-Agent", userAgent()) + + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("download failed with status %d", resp.StatusCode) + } + + // Validate Content-Length if provided. + if resp.ContentLength > maxArchiveSize { + return "", fmt.Errorf("archive size %d exceeds maximum allowed size of %d bytes", resp.ContentLength, maxArchiveSize) + } + + // Create temp file for archive. + tmpFile, err := os.CreateTemp("", "crush-update-*") + if err != nil { + return "", fmt.Errorf("failed to create temp file: %w", err) + } + defer os.Remove(tmpFile.Name()) + + // Download to temp file while computing checksum. + // Use LimitReader to prevent DoS from oversized downloads (Content-Length can be spoofed). + hash := sha256.New() + limitedBody := io.LimitReader(resp.Body, maxArchiveSize+1) + written, err := io.Copy(io.MultiWriter(tmpFile, hash), limitedBody) + if err != nil { + tmpFile.Close() + return "", fmt.Errorf("failed to download: %w", err) + } + if written > maxArchiveSize { + tmpFile.Close() + return "", fmt.Errorf("archive size %d exceeds maximum allowed size of %d bytes", written, maxArchiveSize) + } + tmpFile.Close() + + // Verify checksum. + actualSum := hex.EncodeToString(hash.Sum(nil)) + expectedSum, ok := checksums[asset.Name] + if !ok { + return "", fmt.Errorf("no checksum found for %s", asset.Name) + } + if actualSum != expectedSum { + return "", fmt.Errorf("checksum mismatch: expected %s, got %s", expectedSum, actualSum) + } + + // Extract binary based on archive type. + var binaryPath string + if strings.HasSuffix(asset.Name, ".zip") { + binaryPath, err = extractZip(tmpFile.Name()) + } else { + binaryPath, err = extractTarGz(tmpFile.Name()) + } + if err != nil { + return "", fmt.Errorf("failed to extract: %w", err) + } + + // Validate the extracted binary before returning. + if err := validateBinary(binaryPath); err != nil { + os.Remove(binaryPath) + return "", fmt.Errorf("invalid binary: %w", err) + } + + return binaryPath, nil +} + +// validateBinary checks that the file at path is a valid executable binary +// for the current platform using the standard library debug packages. +func validateBinary(path string) error { + switch runtime.GOOS { + case "windows": + return validatePE(path) + case "darwin": + return validateMachO(path) + default: + return validateELF(path) + } +} + +func validateELF(path string) error { + f, err := elf.Open(path) + if err != nil { + return fmt.Errorf("not a valid ELF binary: %w", err) + } + defer f.Close() + if f.Type != elf.ET_EXEC && f.Type != elf.ET_DYN { + return fmt.Errorf("ELF file is not an executable (type: %v)", f.Type) + } + return nil +} + +func validateMachO(path string) error { + f, err := macho.Open(path) + if err != nil { + return fmt.Errorf("not a valid Mach-O binary: %w", err) + } + defer f.Close() + if f.Type != macho.TypeExec { + return fmt.Errorf("Mach-O file is not an executable (type: %v)", f.Type) + } + return nil +} + +func validatePE(path string) error { + f, err := pe.Open(path) + if err != nil { + return fmt.Errorf("not a valid PE binary: %w", err) + } + defer f.Close() + // PE files opened successfully are valid executables. + return nil +} + +// downloadChecksums downloads and parses the checksums.txt file. +// Returns a map of filename to sha256 checksum. +func downloadChecksums(ctx context.Context, client *http.Client, asset *Asset) (map[string]string, error) { + req, err := http.NewRequestWithContext(ctx, "GET", asset.BrowserDownloadURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", userAgent()) + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to download checksums with status %d", resp.StatusCode) + } + + // Limit the size of checksums.txt to prevent memory exhaustion. + limitedBody := io.LimitReader(resp.Body, maxChecksumsSize+1) + + checksums := make(map[string]string) + scanner := bufio.NewScanner(limitedBody) + for scanner.Scan() { + line := scanner.Text() + parts := strings.Fields(line) + if len(parts) != 2 { + continue + } + // parts[0] is checksum, parts[1] is filename. + checksums[parts[1]] = parts[0] + } + if err := scanner.Err(); err != nil { + return nil, err + } + + return checksums, nil +} + +// parseChecksumLines parses checksum lines from a string. +// Used for testing the checksum parsing logic. +func parseChecksumLines(content string) map[string]string { + checksums := make(map[string]string) + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + line := scanner.Text() + parts := strings.Fields(line) + if len(parts) != 2 { + continue + } + checksums[parts[1]] = parts[0] + } + return checksums +} + +// extractZip extracts the crush binary from a zip archive. +func extractZip(archivePath string) (string, error) { + r, err := zip.OpenReader(archivePath) + if err != nil { + return "", err + } + defer r.Close() + + binaryName := "crush" + if runtime.GOOS == "windows" { + binaryName = "crush.exe" + } + + for _, f := range r.File { + // Path traversal protection. + cleanName := filepath.Clean(f.Name) + if strings.Contains(cleanName, "..") { + continue + } + + // Use exact name matching to avoid matching unintended files. + if filepath.Base(f.Name) == binaryName { + binaryPath, err := extractZipFile(f) + if err != nil { + return "", err + } + return binaryPath, nil + } + } + + return "", fmt.Errorf("crush binary not found in archive") +} + +// extractZipFile extracts a single file from a zip archive with size limits. +func extractZipFile(f *zip.File) (string, error) { + rc, err := f.Open() + if err != nil { + return "", err + } + defer rc.Close() + + tmpBinary, err := os.CreateTemp("", "crush-binary-*") + if err != nil { + return "", err + } + + limitedReader := io.LimitReader(rc, maxBinarySize+1) + written, err := io.Copy(tmpBinary, limitedReader) + if err != nil { + tmpBinary.Close() + os.Remove(tmpBinary.Name()) + return "", err + } + if written > maxBinarySize { + tmpBinary.Close() + os.Remove(tmpBinary.Name()) + return "", fmt.Errorf("binary exceeds maximum size of %d bytes", maxBinarySize) + } + + if err := tmpBinary.Chmod(0o755); err != nil { + tmpBinary.Close() + os.Remove(tmpBinary.Name()) + return "", err + } + + if err := tmpBinary.Close(); err != nil { + os.Remove(tmpBinary.Name()) + return "", err + } + + return tmpBinary.Name(), nil +} + +// extractTarGz extracts the crush binary from a tar.gz archive. +func extractTarGz(archivePath string) (string, error) { + f, err := os.Open(archivePath) + if err != nil { + return "", err + } + defer f.Close() + + gzr, err := gzip.NewReader(f) + if err != nil { + return "", err + } + defer gzr.Close() + + tr := tar.NewReader(gzr) + + binaryName := "crush" + if runtime.GOOS == "windows" { + binaryName = "crush.exe" + } + + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return "", err + } + + // Path traversal protection. + cleanName := filepath.Clean(header.Name) + if strings.Contains(cleanName, "..") { + continue + } + + // Use exact name matching to avoid matching unintended files. + if filepath.Base(header.Name) == binaryName { + binaryPath, err := extractTarFile(tr) + if err != nil { + return "", err + } + return binaryPath, nil + } + } + + return "", fmt.Errorf("crush binary not found in archive") +} + +// extractTarFile extracts a single file from a tar reader with size limits. +func extractTarFile(tr *tar.Reader) (string, error) { + tmpBinary, err := os.CreateTemp("", "crush-binary-*") + if err != nil { + return "", err + } + + limitedReader := io.LimitReader(tr, maxBinarySize+1) + written, err := io.Copy(tmpBinary, limitedReader) + if err != nil { + tmpBinary.Close() + os.Remove(tmpBinary.Name()) + return "", err + } + if written > maxBinarySize { + tmpBinary.Close() + os.Remove(tmpBinary.Name()) + return "", fmt.Errorf("binary exceeds maximum size of %d bytes", maxBinarySize) + } + + if err := tmpBinary.Chmod(0o755); err != nil { + tmpBinary.Close() + os.Remove(tmpBinary.Name()) + return "", err + } + + if err := tmpBinary.Close(); err != nil { + os.Remove(tmpBinary.Name()) + return "", err + } + + return tmpBinary.Name(), nil +} + +// Apply replaces the current executable with the downloaded binary. +func Apply(binaryPath string) error { + // Get path to current executable. + exe, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to get executable path: %w", err) + } + + // Resolve symlinks. + exe, err = filepath.EvalSymlinks(exe) + if err != nil { + return fmt.Errorf("failed to resolve symlinks: %w", err) + } + + // Get the directory of the executable. + exeDir := filepath.Dir(exe) + + // Check if we have write permissions to the directory. + if err := checkWritePermission(exeDir); err != nil { + return fmt.Errorf("cannot write to %s: %w (you may need to run with elevated privileges)", exeDir, err) + } + + // Copy binary to exe directory first to ensure same filesystem. + // os.Rename fails across filesystems, and binaryPath may be in /tmp. + localBinary := filepath.Join(exeDir, ".crush-update-new") + if err := copyFile(binaryPath, localBinary); err != nil { + return fmt.Errorf("failed to copy new binary: %w", err) + } + + // Create a backup of the current executable. + backupPath := filepath.Join(exeDir, filepath.Base(exe)+".old") + if err := os.Rename(exe, backupPath); err != nil { + os.Remove(localBinary) + return fmt.Errorf("failed to backup current executable: %w", err) + } + + // Move new binary to executable location. + if err := os.Rename(localBinary, exe); err != nil { + // Try to restore backup on failure. + _ = os.Rename(backupPath, exe) + os.Remove(localBinary) + return fmt.Errorf("failed to install new version: %w", err) + } + + // Remove backup on success. + _ = os.Remove(backupPath) + + return nil +} + +// checkWritePermission checks if we can write to the given directory. +func checkWritePermission(dir string) error { + testFile := filepath.Join(dir, ".crush-update-test") + f, err := os.Create(testFile) + if err != nil { + return err + } + f.Close() + return os.Remove(testFile) +} + +// copyFile copies a file from src to dst, preserving permissions. +func copyFile(src, dst string) error { + srcFile, err := os.Open(src) + if err != nil { + return err + } + defer srcFile.Close() + + srcInfo, err := srcFile.Stat() + if err != nil { + return err + } + + dstFile, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, srcInfo.Mode()) + if err != nil { + return err + } + defer dstFile.Close() + + if _, err = io.Copy(dstFile, srcFile); err != nil { + return err + } + + // Sync to ensure data is persisted before atomic rename. + return dstFile.Sync() +} diff --git a/internal/update/update_test.go b/internal/update/update_test.go index 87e3849eb5a9ddc06b1e22c15c0bdde0b7739085..c02cae91ca1fe6c7ba7cbcbd16db8112215c81ca 100644 --- a/internal/update/update_test.go +++ b/internal/update/update_test.go @@ -1,7 +1,14 @@ package update import ( + "archive/tar" + "archive/zip" + "compress/gzip" "context" + "os" + "path/filepath" + "runtime" + "strings" "testing" "github.com/stretchr/testify/require" @@ -46,3 +53,330 @@ func (t testClient) Latest(ctx context.Context) (*Release, error) { HTMLURL: "https://example.org", }, nil } + +func TestFindAsset(t *testing.T) { + t.Parallel() + + // Create test assets matching goreleaser naming. + assets := []Asset{ + {Name: "crush_0.19.2_Linux_x86_64.tar.gz", BrowserDownloadURL: "https://example.com/linux-amd64.tar.gz"}, + {Name: "crush_0.19.2_Darwin_x86_64.tar.gz", BrowserDownloadURL: "https://example.com/darwin-amd64.tar.gz"}, + {Name: "crush_0.19.2_Darwin_arm64.tar.gz", BrowserDownloadURL: "https://example.com/darwin-arm64.tar.gz"}, + {Name: "crush_0.19.2_Windows_x86_64.zip", BrowserDownloadURL: "https://example.com/windows-amd64.zip"}, + {Name: "crush_0.19.2_Linux_i386.tar.gz", BrowserDownloadURL: "https://example.com/linux-386.tar.gz"}, + {Name: "checksums.txt", BrowserDownloadURL: "https://example.com/checksums.txt"}, + {Name: "crush_0.19.2_Linux_x86_64.tar.gz.sig", BrowserDownloadURL: "https://example.com/linux-amd64.tar.gz.sig"}, + } + + t.Run("finds correct asset for current platform", func(t *testing.T) { + t.Parallel() + asset, err := FindAsset(assets) + require.NoError(t, err) + require.NotNil(t, asset) + + // Check that the asset matches our platform. + switch runtime.GOOS { + case "linux": + require.Contains(t, asset.Name, "Linux") + case "darwin": + require.Contains(t, asset.Name, "Darwin") + case "windows": + require.Contains(t, asset.Name, "Windows") + } + + // Check that it's an archive, not a signature or checksum. + require.True(t, strings.HasSuffix(asset.Name, ".tar.gz") || strings.HasSuffix(asset.Name, ".zip")) + }) + + t.Run("returns error when no matching asset", func(t *testing.T) { + t.Parallel() + emptyAssets := []Asset{ + {Name: "checksums.txt", BrowserDownloadURL: "https://example.com/checksums.txt"}, + } + asset, err := FindAsset(emptyAssets) + require.Error(t, err) + require.Nil(t, asset) + }) +} + +func TestIsDevelopment(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + version string + want bool + }{ + {"devel version", "devel", true}, + {"unknown version", "unknown", true}, + {"dirty version", "0.19.0-dirty", true}, + {"dirty with suffix", "0.19.0-10-g1234567-dirty", true}, + {"go install version", "v0.0.0-0.20251231235959-06c807842604", true}, + {"stable version", "0.19.0", false}, + {"pre-release beta", "0.19.0-beta.1", false}, + {"pre-release rc", "0.19.0-rc.1", false}, + {"pre-release alpha", "0.19.0-alpha.1", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + info := Info{Current: tt.version, Latest: "0.20.0"} + require.Equal(t, tt.want, info.IsDevelopment()) + }) + } +} + +func TestAvailable(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + current string + latest string + want bool + }{ + {"same version", "0.19.0", "0.19.0", false}, + {"newer available", "0.19.0", "0.19.1", true}, + {"older latest (downgrade)", "0.19.1", "0.19.0", true}, + {"rc to stable", "0.19.0-rc.1", "0.19.0", true}, + {"stable to rc", "0.19.0", "0.20.0-rc.1", false}, + {"alpha to beta", "0.19.0-alpha.1", "0.19.0-beta.1", true}, + {"beta to rc", "0.19.0-beta.1", "0.19.0-rc.1", true}, + {"same pre-release", "0.19.0-beta.1", "0.19.0-beta.1", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + info := Info{Current: tt.current, Latest: tt.latest} + require.Equal(t, tt.want, info.Available()) + }) + } +} + +func TestParseChecksums(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want map[string]string + }{ + { + name: "standard format", + input: "abc123 crush_0.19.2_Linux_x86_64.tar.gz\ndef456 crush_0.19.2_Darwin_arm64.tar.gz\n", + want: map[string]string{ + "crush_0.19.2_Linux_x86_64.tar.gz": "abc123", + "crush_0.19.2_Darwin_arm64.tar.gz": "def456", + }, + }, + { + name: "empty lines", + input: "abc123 file1.tar.gz\n\ndef456 file2.tar.gz\n", + want: map[string]string{ + "file1.tar.gz": "abc123", + "file2.tar.gz": "def456", + }, + }, + { + name: "extra fields ignored", + input: "abc123 file1.tar.gz extra fields\n", + want: map[string]string{}, + }, + { + name: "single field ignored", + input: "abc123\n", + want: map[string]string{}, + }, + { + name: "whitespace variations", + input: "abc123\tfile1.tar.gz\n", + want: map[string]string{ + "file1.tar.gz": "abc123", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := parseChecksumLines(tt.input) + require.Equal(t, tt.want, result) + }) + } +} + +func TestExtractTarGz(t *testing.T) { + t.Parallel() + + t.Run("missing binary", func(t *testing.T) { + t.Parallel() + // Create a tar.gz with no crush binary. + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "test.tar.gz") + + f, err := os.Create(archivePath) + require.NoError(t, err) + + gzw := gzip.NewWriter(f) + tw := tar.NewWriter(gzw) + + // Add a random file, not crush. + content := []byte("not a binary") + hdr := &tar.Header{ + Name: "other-file.txt", + Mode: 0o644, + Size: int64(len(content)), + } + require.NoError(t, tw.WriteHeader(hdr)) + _, err = tw.Write(content) + require.NoError(t, err) + + require.NoError(t, tw.Close()) + require.NoError(t, gzw.Close()) + require.NoError(t, f.Close()) + + _, err = extractTarGz(archivePath) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) + + t.Run("path traversal attempt", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "test.tar.gz") + + f, err := os.Create(archivePath) + require.NoError(t, err) + + gzw := gzip.NewWriter(f) + tw := tar.NewWriter(gzw) + + // Add a file with path traversal attempt. + content := []byte("malicious") + hdr := &tar.Header{ + Name: "../../../etc/passwd", + Mode: 0o644, + Size: int64(len(content)), + } + require.NoError(t, tw.WriteHeader(hdr)) + _, err = tw.Write(content) + require.NoError(t, err) + + require.NoError(t, tw.Close()) + require.NoError(t, gzw.Close()) + require.NoError(t, f.Close()) + + // Should not extract the malicious file and should fail to find binary. + _, err = extractTarGz(archivePath) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) +} + +func TestExtractZip(t *testing.T) { + t.Parallel() + + t.Run("missing binary", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "test.zip") + + f, err := os.Create(archivePath) + require.NoError(t, err) + + zw := zip.NewWriter(f) + + // Add a random file, not crush. + w, err := zw.Create("other-file.txt") + require.NoError(t, err) + _, err = w.Write([]byte("not a binary")) + require.NoError(t, err) + + require.NoError(t, zw.Close()) + require.NoError(t, f.Close()) + + _, err = extractZip(archivePath) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) + + t.Run("path traversal attempt", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "test.zip") + + f, err := os.Create(archivePath) + require.NoError(t, err) + + zw := zip.NewWriter(f) + + // Add a file with path traversal attempt. + w, err := zw.Create("../../../etc/passwd") + require.NoError(t, err) + _, err = w.Write([]byte("malicious")) + require.NoError(t, err) + + require.NoError(t, zw.Close()) + require.NoError(t, f.Close()) + + // Should not extract the malicious file and should fail to find binary. + _, err = extractZip(archivePath) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) +} + +func TestApply(t *testing.T) { + t.Parallel() + + t.Run("read-only directory", func(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("chmod not reliable on Windows") + } + + tmpDir := t.TempDir() + + // Create a fake binary to apply. + binaryPath := filepath.Join(tmpDir, "new-binary") + require.NoError(t, os.WriteFile(binaryPath, []byte("new"), 0o755)) + + // Create a read-only directory. + readOnlyDir := filepath.Join(tmpDir, "readonly") + require.NoError(t, os.MkdirAll(readOnlyDir, 0o755)) + + // Create a fake executable in the read-only dir. + exePath := filepath.Join(readOnlyDir, "crush") + require.NoError(t, os.WriteFile(exePath, []byte("old"), 0o755)) + + // Make the directory read-only. + require.NoError(t, os.Chmod(readOnlyDir, 0o555)) + t.Cleanup(func() { + // Restore permissions for cleanup. + os.Chmod(readOnlyDir, 0o755) + }) + + // checkWritePermission should fail. + err := checkWritePermission(readOnlyDir) + require.Error(t, err) + }) + + t.Run("successful copy", func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + + src := filepath.Join(tmpDir, "src") + dst := filepath.Join(tmpDir, "dst") + + content := []byte("test content") + require.NoError(t, os.WriteFile(src, content, 0o755)) + + require.NoError(t, copyFile(src, dst)) + + dstContent, err := os.ReadFile(dst) + require.NoError(t, err) + require.Equal(t, content, dstContent) + }) +}