Detailed changes
@@ -11,6 +11,7 @@ import (
"path/filepath"
"strconv"
"strings"
+ "time"
tea "charm.land/bubbletea/v2"
"charm.land/lipgloss/v2"
@@ -22,6 +23,7 @@ import (
"github.com/charmbracelet/crush/internal/stringext"
termutil "github.com/charmbracelet/crush/internal/term"
"github.com/charmbracelet/crush/internal/tui"
+ "github.com/charmbracelet/crush/internal/update"
"github.com/charmbracelet/crush/internal/version"
"github.com/charmbracelet/fang"
uv "github.com/charmbracelet/ultraviolet"
@@ -41,6 +43,7 @@ func init() {
rootCmd.AddCommand(
runCmd,
dirsCmd,
+ updateCmd,
updateProvidersCmd,
logsCmd,
schemaCmd,
@@ -135,13 +138,24 @@ func Execute() {
// printing the version, and PreRunE runs after the version is already
// handled, so that doesn't work either.
// This is the only way I could find that works relatively well.
+ versionTemplate := defaultVersionTemplate
if term.IsTerminal(os.Stdout.Fd()) {
var b bytes.Buffer
w := colorprofile.NewWriter(os.Stdout, os.Environ())
w.Forward = &b
_, _ = w.WriteString(heartbit.String())
- rootCmd.SetVersionTemplate(b.String() + "\n" + defaultVersionTemplate)
+ versionTemplate = b.String() + "\n" + defaultVersionTemplate
}
+
+ // Check if version flag is present and add update notification if available.
+ if hasVersionFlag() {
+ if updateMsg := checkForUpdateSync(); updateMsg != "" {
+ versionTemplate += updateMsg
+ }
+ }
+
+ rootCmd.SetVersionTemplate(versionTemplate)
+
if err := fang.Execute(
context.Background(),
rootCmd,
@@ -152,6 +166,34 @@ func Execute() {
}
}
+// hasVersionFlag checks if the version flag is present in os.Args.
+func hasVersionFlag() bool {
+ for _, arg := range os.Args {
+ if arg == "-v" || arg == "--version" {
+ return true
+ }
+ }
+ return false
+}
+
+// checkForUpdateSync performs a synchronous update check with a short timeout.
+// Returns a formatted update message if an update is available, empty string otherwise.
+func checkForUpdateSync() string {
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ info, err := update.Check(ctx, version.Version, update.Default)
+ if err != nil || !info.Available() {
+ return ""
+ }
+
+ if info.IsDevelopment() {
+ return fmt.Sprintf("\nThis is a development version of Crush. The latest stable release is v%s.\nRun 'crush update' to learn more.\n", info.Latest)
+ }
+
+ return fmt.Sprintf("\nUpdate available: v%s → v%s\nRun 'crush update apply' to install.\n", info.Current, info.Latest)
+}
+
func setupAppWithProgressBar(cmd *cobra.Command) (*app.App, error) {
if termutil.SupportsProgressBar() {
_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
@@ -0,0 +1,165 @@
+package cmd
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "time"
+
+ "charm.land/lipgloss/v2"
+ "github.com/charmbracelet/crush/internal/format"
+ "github.com/charmbracelet/crush/internal/tui/components/anim"
+ "github.com/charmbracelet/crush/internal/tui/styles"
+ "github.com/charmbracelet/crush/internal/update"
+ "github.com/charmbracelet/crush/internal/version"
+ "github.com/charmbracelet/x/term"
+ "github.com/spf13/cobra"
+)
+
+var updateCmd = &cobra.Command{
+ Use: "update",
+ Short: "Check for and apply updates",
+ Long: `Check if a new version of Crush is available.
+Use 'update apply' to download and install the latest version.`,
+ Example: `
+# Check if an update is available
+crush update
+
+# Apply the update if available
+crush update apply
+ `,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ ctx, cancel := context.WithTimeout(cmd.Context(), 30*time.Second)
+ defer cancel()
+
+ spinner := newUpdateSpinner(ctx, cancel, "Checking for updates")
+ spinner.Start()
+
+ info, err := update.Check(ctx, version.Version, update.Default)
+ spinner.Stop()
+
+ if err != nil {
+ return fmt.Errorf("failed to check for updates: %w", err)
+ }
+
+ if info.IsDevelopment() {
+ fmt.Fprintf(os.Stderr, "You are running a development version of Crush (%s).\n", info.Current)
+ fmt.Fprintf(os.Stderr, "The latest stable release is v%s.\n", info.Latest)
+ fmt.Fprintf(os.Stderr, "Visit %s to learn more.\n", info.URL)
+ return nil
+ }
+
+ if !info.Available() {
+ fmt.Fprintf(os.Stderr, "You are already running the latest version (v%s).\n", info.Current)
+ return nil
+ }
+
+ fmt.Fprintf(os.Stderr, "Update available: v%s → v%s\n", info.Current, info.Latest)
+ fmt.Fprintf(os.Stderr, "Run 'crush update apply' to install the latest version.\n")
+ fmt.Fprintf(os.Stderr, "Or visit %s to download manually.\n", info.URL)
+
+ return nil
+ },
+}
+
+var updateApplyCmd = &cobra.Command{
+ Use: "apply",
+ Short: "Apply the latest update",
+ Long: `Download and install the latest version of Crush.`,
+ Example: `
+# Apply the latest update
+crush update apply
+ `,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Minute)
+ defer cancel()
+
+ spinner := newUpdateSpinner(ctx, cancel, "Checking for updates")
+ spinner.Start()
+
+ info, err := update.Check(ctx, version.Version, update.Default)
+ if err != nil {
+ spinner.Stop()
+ return fmt.Errorf("failed to check for updates: %w", err)
+ }
+
+ if info.IsDevelopment() {
+ spinner.Stop()
+ return fmt.Errorf("cannot update development versions automatically")
+ }
+
+ if !info.Available() {
+ spinner.Stop()
+ fmt.Fprintf(os.Stderr, "You are already running the latest version (v%s).\n", info.Current)
+ return nil
+ }
+
+ // Get the latest release with assets.
+ release, err := update.Default.Latest(ctx)
+ if err != nil {
+ spinner.Stop()
+ return fmt.Errorf("failed to fetch release information: %w", err)
+ }
+
+ // Find the appropriate asset for this platform.
+ asset, err := update.FindAsset(release.Assets)
+ if err != nil {
+ spinner.Stop()
+ return fmt.Errorf("failed to find update for your platform: %w", err)
+ }
+
+ spinner.Stop()
+ spinner = newUpdateSpinner(ctx, cancel, fmt.Sprintf("Downloading v%s", info.Latest))
+ spinner.Start()
+
+ // Download the asset.
+ binaryPath, err := update.Download(ctx, asset, release)
+ if err != nil {
+ spinner.Stop()
+ return fmt.Errorf("failed to download update: %w", err)
+ }
+ defer os.Remove(binaryPath)
+
+ spinner.Stop()
+ spinner = newUpdateSpinner(ctx, cancel, "Installing")
+ spinner.Start()
+
+ // Apply the update.
+ if err := update.Apply(binaryPath); err != nil {
+ spinner.Stop()
+ return fmt.Errorf("failed to apply update: %w", err)
+ }
+
+ spinner.Stop()
+
+ fmt.Fprintf(os.Stderr, "Successfully updated to v%s!\n", info.Latest)
+ fmt.Fprintf(os.Stderr, "Run 'crush -v' to verify the new version.\n")
+
+ return nil
+ },
+}
+
+// newUpdateSpinner creates a spinner for update operations.
+func newUpdateSpinner(ctx context.Context, cancel context.CancelFunc, label string) *format.Spinner {
+ t := styles.CurrentTheme()
+
+ // Detect background color for appropriate text color.
+ hasDarkBG := true
+ if term.IsTerminal(os.Stderr.Fd()) {
+ hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, os.Stderr)
+ }
+ defaultFG := lipgloss.LightDark(hasDarkBG)(lipgloss.Color("#fafafa"), t.FgBase)
+
+ return format.NewSpinner(ctx, cancel, anim.Settings{
+ Size: 10,
+ Label: label,
+ LabelColor: defaultFG,
+ GradColorA: t.Primary,
+ GradColorB: t.Secondary,
+ CycleColors: true,
+ })
+}
+
+func init() {
+ updateCmd.AddCommand(updateApplyCmd)
+}
@@ -1,21 +1,42 @@
package update
import (
+ "archive/tar"
+ "archive/zip"
+ "bufio"
+ "compress/gzip"
"context"
+ "crypto/sha256"
+ "debug/elf"
+ "debug/macho"
+ "debug/pe"
+ "encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
+ "os"
+ "path/filepath"
"regexp"
+ "runtime"
"strings"
"time"
+
+ "github.com/charmbracelet/crush/internal/version"
)
const (
- githubApiUrl = "https://api.github.com/repos/charmbracelet/crush/releases/latest"
- userAgent = "crush/1.0"
+ githubApiUrl = "https://api.github.com/repos/charmbracelet/crush/releases/latest"
+ maxBinarySize = 500 * 1024 * 1024 // 500MB max for extracted binary
+ maxArchiveSize = 500 * 1024 * 1024 // 500MB max for downloaded archive
+ maxChecksumsSize = 1 * 1024 * 1024 // 1MB max for checksums.txt
)
+// userAgent returns the user agent string for HTTP requests.
+func userAgent() string {
+ return "crush/" + version.Version
+}
+
// Default is the default [Client].
var Default Client = &github{}
@@ -72,10 +93,17 @@ func Check(ctx context.Context, current string, client Client) (Info, error) {
return info, nil
}
+// Asset represents a GitHub release asset.
+type Asset struct {
+ Name string `json:"name"`
+ BrowserDownloadURL string `json:"browser_download_url"`
+}
+
// Release represents a GitHub release.
type Release struct {
- TagName string `json:"tag_name"`
- HTMLURL string `json:"html_url"`
+ TagName string `json:"tag_name"`
+ HTMLURL string `json:"html_url"`
+ Assets []Asset `json:"assets"`
}
// Client is a client that can get the latest release.
@@ -95,7 +123,7 @@ func (c *github) Latest(ctx context.Context) (*Release, error) {
if err != nil {
return nil, err
}
- req.Header.Set("User-Agent", userAgent)
+ req.Header.Set("User-Agent", userAgent())
req.Header.Set("Accept", "application/vnd.github.v3+json")
resp, err := client.Do(req)
@@ -116,3 +144,490 @@ func (c *github) Latest(ctx context.Context) (*Release, error) {
return &release, nil
}
+
+// FindAsset finds the appropriate asset for the current platform.
+func FindAsset(assets []Asset) (*Asset, error) {
+ // Normalize architecture to match goreleaser naming.
+ arch := runtime.GOARCH
+ switch arch {
+ case "amd64":
+ arch = "x86_64"
+ case "386":
+ arch = "i386"
+ case "arm":
+ arch = "armv7"
+ // arm64 stays as "arm64" in goreleaser naming.
+ }
+
+ // Normalize OS to match goreleaser naming (title case).
+ goos := runtime.GOOS
+ switch goos {
+ case "freebsd":
+ goos = "Freebsd"
+ case "netbsd":
+ goos = "Netbsd"
+ case "openbsd":
+ goos = "Openbsd"
+ default:
+ if len(goos) > 0 {
+ goos = strings.ToUpper(goos[:1]) + goos[1:]
+ }
+ }
+
+ // Look for archive matching our platform.
+ // Pattern: crush_{version}_{OS}_{ARCH}.{tar.gz|zip}
+ for _, asset := range assets {
+ if strings.Contains(asset.Name, goos) && strings.Contains(asset.Name, arch) {
+ // Ensure it's an archive, not a checksum or signature.
+ if strings.HasSuffix(asset.Name, ".tar.gz") || strings.HasSuffix(asset.Name, ".zip") {
+ return &asset, nil
+ }
+ }
+ }
+
+ return nil, fmt.Errorf("no suitable asset found for %s/%s", runtime.GOOS, runtime.GOARCH)
+}
+
+// Download downloads and extracts the crush binary from the given asset.
+// Returns the path to the extracted binary.
+func Download(ctx context.Context, asset *Asset, release *Release) (string, error) {
+ client := &http.Client{
+ Timeout: 5 * time.Minute,
+ }
+
+ // Find and download checksums.txt.
+ var checksumsAsset *Asset
+ for i := range release.Assets {
+ if release.Assets[i].Name == "checksums.txt" {
+ checksumsAsset = &release.Assets[i]
+ break
+ }
+ }
+ if checksumsAsset == nil {
+ return "", fmt.Errorf("checksums.txt not found in release")
+ }
+
+ checksums, err := downloadChecksums(ctx, client, checksumsAsset)
+ if err != nil {
+ return "", fmt.Errorf("failed to download checksums: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, "GET", asset.BrowserDownloadURL, nil)
+ if err != nil {
+ return "", err
+ }
+ req.Header.Set("User-Agent", userAgent())
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return "", err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return "", fmt.Errorf("download failed with status %d", resp.StatusCode)
+ }
+
+ // Validate Content-Length if provided.
+ if resp.ContentLength > maxArchiveSize {
+ return "", fmt.Errorf("archive size %d exceeds maximum allowed size of %d bytes", resp.ContentLength, maxArchiveSize)
+ }
+
+ // Create temp file for archive.
+ tmpFile, err := os.CreateTemp("", "crush-update-*")
+ if err != nil {
+ return "", fmt.Errorf("failed to create temp file: %w", err)
+ }
+ defer os.Remove(tmpFile.Name())
+
+ // Download to temp file while computing checksum.
+ // Use LimitReader to prevent DoS from oversized downloads (Content-Length can be spoofed).
+ hash := sha256.New()
+ limitedBody := io.LimitReader(resp.Body, maxArchiveSize+1)
+ written, err := io.Copy(io.MultiWriter(tmpFile, hash), limitedBody)
+ if err != nil {
+ tmpFile.Close()
+ return "", fmt.Errorf("failed to download: %w", err)
+ }
+ if written > maxArchiveSize {
+ tmpFile.Close()
+ return "", fmt.Errorf("archive size %d exceeds maximum allowed size of %d bytes", written, maxArchiveSize)
+ }
+ tmpFile.Close()
+
+ // Verify checksum.
+ actualSum := hex.EncodeToString(hash.Sum(nil))
+ expectedSum, ok := checksums[asset.Name]
+ if !ok {
+ return "", fmt.Errorf("no checksum found for %s", asset.Name)
+ }
+ if actualSum != expectedSum {
+ return "", fmt.Errorf("checksum mismatch: expected %s, got %s", expectedSum, actualSum)
+ }
+
+ // Extract binary based on archive type.
+ var binaryPath string
+ if strings.HasSuffix(asset.Name, ".zip") {
+ binaryPath, err = extractZip(tmpFile.Name())
+ } else {
+ binaryPath, err = extractTarGz(tmpFile.Name())
+ }
+ if err != nil {
+ return "", fmt.Errorf("failed to extract: %w", err)
+ }
+
+ // Validate the extracted binary before returning.
+ if err := validateBinary(binaryPath); err != nil {
+ os.Remove(binaryPath)
+ return "", fmt.Errorf("invalid binary: %w", err)
+ }
+
+ return binaryPath, nil
+}
+
+// validateBinary checks that the file at path is a valid executable binary
+// for the current platform using the standard library debug packages.
+func validateBinary(path string) error {
+ switch runtime.GOOS {
+ case "windows":
+ return validatePE(path)
+ case "darwin":
+ return validateMachO(path)
+ default:
+ return validateELF(path)
+ }
+}
+
+func validateELF(path string) error {
+ f, err := elf.Open(path)
+ if err != nil {
+ return fmt.Errorf("not a valid ELF binary: %w", err)
+ }
+ defer f.Close()
+ if f.Type != elf.ET_EXEC && f.Type != elf.ET_DYN {
+ return fmt.Errorf("ELF file is not an executable (type: %v)", f.Type)
+ }
+ return nil
+}
+
+func validateMachO(path string) error {
+ f, err := macho.Open(path)
+ if err != nil {
+ return fmt.Errorf("not a valid Mach-O binary: %w", err)
+ }
+ defer f.Close()
+ if f.Type != macho.TypeExec {
+ return fmt.Errorf("Mach-O file is not an executable (type: %v)", f.Type)
+ }
+ return nil
+}
+
+func validatePE(path string) error {
+ f, err := pe.Open(path)
+ if err != nil {
+ return fmt.Errorf("not a valid PE binary: %w", err)
+ }
+ defer f.Close()
+ // PE files opened successfully are valid executables.
+ return nil
+}
+
+// downloadChecksums downloads and parses the checksums.txt file.
+// Returns a map of filename to sha256 checksum.
+func downloadChecksums(ctx context.Context, client *http.Client, asset *Asset) (map[string]string, error) {
+ req, err := http.NewRequestWithContext(ctx, "GET", asset.BrowserDownloadURL, nil)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("User-Agent", userAgent())
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to download checksums with status %d", resp.StatusCode)
+ }
+
+ // Limit the size of checksums.txt to prevent memory exhaustion.
+ limitedBody := io.LimitReader(resp.Body, maxChecksumsSize+1)
+
+ checksums := make(map[string]string)
+ scanner := bufio.NewScanner(limitedBody)
+ for scanner.Scan() {
+ line := scanner.Text()
+ parts := strings.Fields(line)
+ if len(parts) != 2 {
+ continue
+ }
+ // parts[0] is checksum, parts[1] is filename.
+ checksums[parts[1]] = parts[0]
+ }
+ if err := scanner.Err(); err != nil {
+ return nil, err
+ }
+
+ return checksums, nil
+}
+
+// parseChecksumLines parses checksum lines from a string.
+// Used for testing the checksum parsing logic.
+func parseChecksumLines(content string) map[string]string {
+ checksums := make(map[string]string)
+ scanner := bufio.NewScanner(strings.NewReader(content))
+ for scanner.Scan() {
+ line := scanner.Text()
+ parts := strings.Fields(line)
+ if len(parts) != 2 {
+ continue
+ }
+ checksums[parts[1]] = parts[0]
+ }
+ return checksums
+}
+
+// extractZip extracts the crush binary from a zip archive.
+func extractZip(archivePath string) (string, error) {
+ r, err := zip.OpenReader(archivePath)
+ if err != nil {
+ return "", err
+ }
+ defer r.Close()
+
+ binaryName := "crush"
+ if runtime.GOOS == "windows" {
+ binaryName = "crush.exe"
+ }
+
+ for _, f := range r.File {
+ // Path traversal protection.
+ cleanName := filepath.Clean(f.Name)
+ if strings.Contains(cleanName, "..") {
+ continue
+ }
+
+ // Use exact name matching to avoid matching unintended files.
+ if filepath.Base(f.Name) == binaryName {
+ binaryPath, err := extractZipFile(f)
+ if err != nil {
+ return "", err
+ }
+ return binaryPath, nil
+ }
+ }
+
+ return "", fmt.Errorf("crush binary not found in archive")
+}
+
+// extractZipFile extracts a single file from a zip archive with size limits.
+func extractZipFile(f *zip.File) (string, error) {
+ rc, err := f.Open()
+ if err != nil {
+ return "", err
+ }
+ defer rc.Close()
+
+ tmpBinary, err := os.CreateTemp("", "crush-binary-*")
+ if err != nil {
+ return "", err
+ }
+
+ limitedReader := io.LimitReader(rc, maxBinarySize+1)
+ written, err := io.Copy(tmpBinary, limitedReader)
+ if err != nil {
+ tmpBinary.Close()
+ os.Remove(tmpBinary.Name())
+ return "", err
+ }
+ if written > maxBinarySize {
+ tmpBinary.Close()
+ os.Remove(tmpBinary.Name())
+ return "", fmt.Errorf("binary exceeds maximum size of %d bytes", maxBinarySize)
+ }
+
+ if err := tmpBinary.Chmod(0o755); err != nil {
+ tmpBinary.Close()
+ os.Remove(tmpBinary.Name())
+ return "", err
+ }
+
+ if err := tmpBinary.Close(); err != nil {
+ os.Remove(tmpBinary.Name())
+ return "", err
+ }
+
+ return tmpBinary.Name(), nil
+}
+
+// extractTarGz extracts the crush binary from a tar.gz archive.
+func extractTarGz(archivePath string) (string, error) {
+ f, err := os.Open(archivePath)
+ if err != nil {
+ return "", err
+ }
+ defer f.Close()
+
+ gzr, err := gzip.NewReader(f)
+ if err != nil {
+ return "", err
+ }
+ defer gzr.Close()
+
+ tr := tar.NewReader(gzr)
+
+ binaryName := "crush"
+ if runtime.GOOS == "windows" {
+ binaryName = "crush.exe"
+ }
+
+ for {
+ header, err := tr.Next()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return "", err
+ }
+
+ // Path traversal protection.
+ cleanName := filepath.Clean(header.Name)
+ if strings.Contains(cleanName, "..") {
+ continue
+ }
+
+ // Use exact name matching to avoid matching unintended files.
+ if filepath.Base(header.Name) == binaryName {
+ binaryPath, err := extractTarFile(tr)
+ if err != nil {
+ return "", err
+ }
+ return binaryPath, nil
+ }
+ }
+
+ return "", fmt.Errorf("crush binary not found in archive")
+}
+
+// extractTarFile extracts a single file from a tar reader with size limits.
+func extractTarFile(tr *tar.Reader) (string, error) {
+ tmpBinary, err := os.CreateTemp("", "crush-binary-*")
+ if err != nil {
+ return "", err
+ }
+
+ limitedReader := io.LimitReader(tr, maxBinarySize+1)
+ written, err := io.Copy(tmpBinary, limitedReader)
+ if err != nil {
+ tmpBinary.Close()
+ os.Remove(tmpBinary.Name())
+ return "", err
+ }
+ if written > maxBinarySize {
+ tmpBinary.Close()
+ os.Remove(tmpBinary.Name())
+ return "", fmt.Errorf("binary exceeds maximum size of %d bytes", maxBinarySize)
+ }
+
+ if err := tmpBinary.Chmod(0o755); err != nil {
+ tmpBinary.Close()
+ os.Remove(tmpBinary.Name())
+ return "", err
+ }
+
+ if err := tmpBinary.Close(); err != nil {
+ os.Remove(tmpBinary.Name())
+ return "", err
+ }
+
+ return tmpBinary.Name(), nil
+}
+
+// Apply replaces the current executable with the downloaded binary.
+func Apply(binaryPath string) error {
+ // Get path to current executable.
+ exe, err := os.Executable()
+ if err != nil {
+ return fmt.Errorf("failed to get executable path: %w", err)
+ }
+
+ // Resolve symlinks.
+ exe, err = filepath.EvalSymlinks(exe)
+ if err != nil {
+ return fmt.Errorf("failed to resolve symlinks: %w", err)
+ }
+
+ // Get the directory of the executable.
+ exeDir := filepath.Dir(exe)
+
+ // Check if we have write permissions to the directory.
+ if err := checkWritePermission(exeDir); err != nil {
+ return fmt.Errorf("cannot write to %s: %w (you may need to run with elevated privileges)", exeDir, err)
+ }
+
+ // Copy binary to exe directory first to ensure same filesystem.
+ // os.Rename fails across filesystems, and binaryPath may be in /tmp.
+ localBinary := filepath.Join(exeDir, ".crush-update-new")
+ if err := copyFile(binaryPath, localBinary); err != nil {
+ return fmt.Errorf("failed to copy new binary: %w", err)
+ }
+
+ // Create a backup of the current executable.
+ backupPath := filepath.Join(exeDir, filepath.Base(exe)+".old")
+ if err := os.Rename(exe, backupPath); err != nil {
+ os.Remove(localBinary)
+ return fmt.Errorf("failed to backup current executable: %w", err)
+ }
+
+ // Move new binary to executable location.
+ if err := os.Rename(localBinary, exe); err != nil {
+ // Try to restore backup on failure.
+ _ = os.Rename(backupPath, exe)
+ os.Remove(localBinary)
+ return fmt.Errorf("failed to install new version: %w", err)
+ }
+
+ // Remove backup on success.
+ _ = os.Remove(backupPath)
+
+ return nil
+}
+
+// checkWritePermission checks if we can write to the given directory.
+func checkWritePermission(dir string) error {
+ testFile := filepath.Join(dir, ".crush-update-test")
+ f, err := os.Create(testFile)
+ if err != nil {
+ return err
+ }
+ f.Close()
+ return os.Remove(testFile)
+}
+
+// copyFile copies a file from src to dst, preserving permissions.
+func copyFile(src, dst string) error {
+ srcFile, err := os.Open(src)
+ if err != nil {
+ return err
+ }
+ defer srcFile.Close()
+
+ srcInfo, err := srcFile.Stat()
+ if err != nil {
+ return err
+ }
+
+ dstFile, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, srcInfo.Mode())
+ if err != nil {
+ return err
+ }
+ defer dstFile.Close()
+
+ if _, err = io.Copy(dstFile, srcFile); err != nil {
+ return err
+ }
+
+ // Sync to ensure data is persisted before atomic rename.
+ return dstFile.Sync()
+}
@@ -1,7 +1,14 @@
package update
import (
+ "archive/tar"
+ "archive/zip"
+ "compress/gzip"
"context"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strings"
"testing"
"github.com/stretchr/testify/require"
@@ -46,3 +53,330 @@ func (t testClient) Latest(ctx context.Context) (*Release, error) {
HTMLURL: "https://example.org",
}, nil
}
+
+func TestFindAsset(t *testing.T) {
+ t.Parallel()
+
+ // Create test assets matching goreleaser naming.
+ assets := []Asset{
+ {Name: "crush_0.19.2_Linux_x86_64.tar.gz", BrowserDownloadURL: "https://example.com/linux-amd64.tar.gz"},
+ {Name: "crush_0.19.2_Darwin_x86_64.tar.gz", BrowserDownloadURL: "https://example.com/darwin-amd64.tar.gz"},
+ {Name: "crush_0.19.2_Darwin_arm64.tar.gz", BrowserDownloadURL: "https://example.com/darwin-arm64.tar.gz"},
+ {Name: "crush_0.19.2_Windows_x86_64.zip", BrowserDownloadURL: "https://example.com/windows-amd64.zip"},
+ {Name: "crush_0.19.2_Linux_i386.tar.gz", BrowserDownloadURL: "https://example.com/linux-386.tar.gz"},
+ {Name: "checksums.txt", BrowserDownloadURL: "https://example.com/checksums.txt"},
+ {Name: "crush_0.19.2_Linux_x86_64.tar.gz.sig", BrowserDownloadURL: "https://example.com/linux-amd64.tar.gz.sig"},
+ }
+
+ t.Run("finds correct asset for current platform", func(t *testing.T) {
+ t.Parallel()
+ asset, err := FindAsset(assets)
+ require.NoError(t, err)
+ require.NotNil(t, asset)
+
+ // Check that the asset matches our platform.
+ switch runtime.GOOS {
+ case "linux":
+ require.Contains(t, asset.Name, "Linux")
+ case "darwin":
+ require.Contains(t, asset.Name, "Darwin")
+ case "windows":
+ require.Contains(t, asset.Name, "Windows")
+ }
+
+ // Check that it's an archive, not a signature or checksum.
+ require.True(t, strings.HasSuffix(asset.Name, ".tar.gz") || strings.HasSuffix(asset.Name, ".zip"))
+ })
+
+ t.Run("returns error when no matching asset", func(t *testing.T) {
+ t.Parallel()
+ emptyAssets := []Asset{
+ {Name: "checksums.txt", BrowserDownloadURL: "https://example.com/checksums.txt"},
+ }
+ asset, err := FindAsset(emptyAssets)
+ require.Error(t, err)
+ require.Nil(t, asset)
+ })
+}
+
+func TestIsDevelopment(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ version string
+ want bool
+ }{
+ {"devel version", "devel", true},
+ {"unknown version", "unknown", true},
+ {"dirty version", "0.19.0-dirty", true},
+ {"dirty with suffix", "0.19.0-10-g1234567-dirty", true},
+ {"go install version", "v0.0.0-0.20251231235959-06c807842604", true},
+ {"stable version", "0.19.0", false},
+ {"pre-release beta", "0.19.0-beta.1", false},
+ {"pre-release rc", "0.19.0-rc.1", false},
+ {"pre-release alpha", "0.19.0-alpha.1", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ info := Info{Current: tt.version, Latest: "0.20.0"}
+ require.Equal(t, tt.want, info.IsDevelopment())
+ })
+ }
+}
+
+func TestAvailable(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ current string
+ latest string
+ want bool
+ }{
+ {"same version", "0.19.0", "0.19.0", false},
+ {"newer available", "0.19.0", "0.19.1", true},
+ {"older latest (downgrade)", "0.19.1", "0.19.0", true},
+ {"rc to stable", "0.19.0-rc.1", "0.19.0", true},
+ {"stable to rc", "0.19.0", "0.20.0-rc.1", false},
+ {"alpha to beta", "0.19.0-alpha.1", "0.19.0-beta.1", true},
+ {"beta to rc", "0.19.0-beta.1", "0.19.0-rc.1", true},
+ {"same pre-release", "0.19.0-beta.1", "0.19.0-beta.1", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ info := Info{Current: tt.current, Latest: tt.latest}
+ require.Equal(t, tt.want, info.Available())
+ })
+ }
+}
+
+func TestParseChecksums(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ want map[string]string
+ }{
+ {
+ name: "standard format",
+ input: "abc123 crush_0.19.2_Linux_x86_64.tar.gz\ndef456 crush_0.19.2_Darwin_arm64.tar.gz\n",
+ want: map[string]string{
+ "crush_0.19.2_Linux_x86_64.tar.gz": "abc123",
+ "crush_0.19.2_Darwin_arm64.tar.gz": "def456",
+ },
+ },
+ {
+ name: "empty lines",
+ input: "abc123 file1.tar.gz\n\ndef456 file2.tar.gz\n",
+ want: map[string]string{
+ "file1.tar.gz": "abc123",
+ "file2.tar.gz": "def456",
+ },
+ },
+ {
+ name: "extra fields ignored",
+ input: "abc123 file1.tar.gz extra fields\n",
+ want: map[string]string{},
+ },
+ {
+ name: "single field ignored",
+ input: "abc123\n",
+ want: map[string]string{},
+ },
+ {
+ name: "whitespace variations",
+ input: "abc123\tfile1.tar.gz\n",
+ want: map[string]string{
+ "file1.tar.gz": "abc123",
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ result := parseChecksumLines(tt.input)
+ require.Equal(t, tt.want, result)
+ })
+ }
+}
+
+func TestExtractTarGz(t *testing.T) {
+ t.Parallel()
+
+ t.Run("missing binary", func(t *testing.T) {
+ t.Parallel()
+ // Create a tar.gz with no crush binary.
+ tmpDir := t.TempDir()
+ archivePath := filepath.Join(tmpDir, "test.tar.gz")
+
+ f, err := os.Create(archivePath)
+ require.NoError(t, err)
+
+ gzw := gzip.NewWriter(f)
+ tw := tar.NewWriter(gzw)
+
+ // Add a random file, not crush.
+ content := []byte("not a binary")
+ hdr := &tar.Header{
+ Name: "other-file.txt",
+ Mode: 0o644,
+ Size: int64(len(content)),
+ }
+ require.NoError(t, tw.WriteHeader(hdr))
+ _, err = tw.Write(content)
+ require.NoError(t, err)
+
+ require.NoError(t, tw.Close())
+ require.NoError(t, gzw.Close())
+ require.NoError(t, f.Close())
+
+ _, err = extractTarGz(archivePath)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "not found")
+ })
+
+ t.Run("path traversal attempt", func(t *testing.T) {
+ t.Parallel()
+ tmpDir := t.TempDir()
+ archivePath := filepath.Join(tmpDir, "test.tar.gz")
+
+ f, err := os.Create(archivePath)
+ require.NoError(t, err)
+
+ gzw := gzip.NewWriter(f)
+ tw := tar.NewWriter(gzw)
+
+ // Add a file with path traversal attempt.
+ content := []byte("malicious")
+ hdr := &tar.Header{
+ Name: "../../../etc/passwd",
+ Mode: 0o644,
+ Size: int64(len(content)),
+ }
+ require.NoError(t, tw.WriteHeader(hdr))
+ _, err = tw.Write(content)
+ require.NoError(t, err)
+
+ require.NoError(t, tw.Close())
+ require.NoError(t, gzw.Close())
+ require.NoError(t, f.Close())
+
+ // Should not extract the malicious file and should fail to find binary.
+ _, err = extractTarGz(archivePath)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "not found")
+ })
+}
+
+func TestExtractZip(t *testing.T) {
+ t.Parallel()
+
+ t.Run("missing binary", func(t *testing.T) {
+ t.Parallel()
+ tmpDir := t.TempDir()
+ archivePath := filepath.Join(tmpDir, "test.zip")
+
+ f, err := os.Create(archivePath)
+ require.NoError(t, err)
+
+ zw := zip.NewWriter(f)
+
+ // Add a random file, not crush.
+ w, err := zw.Create("other-file.txt")
+ require.NoError(t, err)
+ _, err = w.Write([]byte("not a binary"))
+ require.NoError(t, err)
+
+ require.NoError(t, zw.Close())
+ require.NoError(t, f.Close())
+
+ _, err = extractZip(archivePath)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "not found")
+ })
+
+ t.Run("path traversal attempt", func(t *testing.T) {
+ t.Parallel()
+ tmpDir := t.TempDir()
+ archivePath := filepath.Join(tmpDir, "test.zip")
+
+ f, err := os.Create(archivePath)
+ require.NoError(t, err)
+
+ zw := zip.NewWriter(f)
+
+ // Add a file with path traversal attempt.
+ w, err := zw.Create("../../../etc/passwd")
+ require.NoError(t, err)
+ _, err = w.Write([]byte("malicious"))
+ require.NoError(t, err)
+
+ require.NoError(t, zw.Close())
+ require.NoError(t, f.Close())
+
+ // Should not extract the malicious file and should fail to find binary.
+ _, err = extractZip(archivePath)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "not found")
+ })
+}
+
+func TestApply(t *testing.T) {
+ t.Parallel()
+
+ t.Run("read-only directory", func(t *testing.T) {
+ t.Parallel()
+ if runtime.GOOS == "windows" {
+ t.Skip("chmod not reliable on Windows")
+ }
+
+ tmpDir := t.TempDir()
+
+ // Create a fake binary to apply.
+ binaryPath := filepath.Join(tmpDir, "new-binary")
+ require.NoError(t, os.WriteFile(binaryPath, []byte("new"), 0o755))
+
+ // Create a read-only directory.
+ readOnlyDir := filepath.Join(tmpDir, "readonly")
+ require.NoError(t, os.MkdirAll(readOnlyDir, 0o755))
+
+ // Create a fake executable in the read-only dir.
+ exePath := filepath.Join(readOnlyDir, "crush")
+ require.NoError(t, os.WriteFile(exePath, []byte("old"), 0o755))
+
+ // Make the directory read-only.
+ require.NoError(t, os.Chmod(readOnlyDir, 0o555))
+ t.Cleanup(func() {
+ // Restore permissions for cleanup.
+ os.Chmod(readOnlyDir, 0o755)
+ })
+
+ // checkWritePermission should fail.
+ err := checkWritePermission(readOnlyDir)
+ require.Error(t, err)
+ })
+
+ t.Run("successful copy", func(t *testing.T) {
+ t.Parallel()
+ tmpDir := t.TempDir()
+
+ src := filepath.Join(tmpDir, "src")
+ dst := filepath.Join(tmpDir, "dst")
+
+ content := []byte("test content")
+ require.NoError(t, os.WriteFile(src, content, 0o755))
+
+ require.NoError(t, copyFile(src, dst))
+
+ dstContent, err := os.ReadFile(dst)
+ require.NoError(t, err)
+ require.Equal(t, content, dstContent)
+ })
+}