versioncheck.go

  1package server
  2
  3import (
  4	"bufio"
  5	"bytes"
  6	"context"
  7	"crypto/sha256"
  8	"encoding/hex"
  9	"encoding/json"
 10	"errors"
 11	"fmt"
 12	"io"
 13	"io/fs"
 14	"net/http"
 15	"os"
 16	"os/exec"
 17	"runtime"
 18	"strings"
 19	"sync"
 20	"time"
 21
 22	"github.com/fynelabs/selfupdate"
 23
 24	"shelley.exe.dev/version"
 25)
 26
 27// VersionChecker checks for new versions of Shelley from GitHub releases.
 28type VersionChecker struct {
 29	mu          sync.Mutex
 30	lastCheck   time.Time
 31	cachedInfo  *VersionInfo
 32	skipCheck   bool
 33	githubOwner string
 34	githubRepo  string
 35}
 36
 37// VersionInfo contains version check results.
 38type VersionInfo struct {
 39	CurrentVersion      string       `json:"current_version"`
 40	CurrentTag          string       `json:"current_tag,omitempty"`
 41	CurrentCommit       string       `json:"current_commit,omitempty"`
 42	CurrentCommitTime   string       `json:"current_commit_time,omitempty"`
 43	LatestVersion       string       `json:"latest_version,omitempty"`
 44	LatestTag           string       `json:"latest_tag,omitempty"`
 45	PublishedAt         time.Time    `json:"published_at,omitempty"`
 46	HasUpdate           bool         `json:"has_update"`    // True if minor version is newer (for showing upgrade button)
 47	ShouldNotify        bool         `json:"should_notify"` // True if should show red dot (newer + 5 days old)
 48	DownloadURL         string       `json:"download_url,omitempty"`
 49	ExecutablePath      string       `json:"executable_path,omitempty"`
 50	Commits             []CommitInfo `json:"commits,omitempty"`
 51	CheckedAt           time.Time    `json:"checked_at"`
 52	Error               string       `json:"error,omitempty"`
 53	RunningUnderSystemd bool         `json:"running_under_systemd"` // True if INVOCATION_ID env var is set (systemd)
 54	ReleaseInfo         *ReleaseInfo `json:"-"`                     // Internal, not exposed to JSON
 55}
 56
 57// CommitInfo represents a commit in the changelog.
 58type CommitInfo struct {
 59	SHA     string    `json:"sha"`
 60	Message string    `json:"message"`
 61	Author  string    `json:"author"`
 62	Date    time.Time `json:"date"`
 63}
 64
 65// ReleaseInfo represents release metadata.
 66type ReleaseInfo struct {
 67	TagName      string            `json:"tag_name"`
 68	Version      string            `json:"version"`
 69	Commit       string            `json:"commit"`
 70	CommitFull   string            `json:"commit_full"`
 71	CommitTime   string            `json:"commit_time"`
 72	PublishedAt  string            `json:"published_at"`
 73	DownloadURLs map[string]string `json:"download_urls"`
 74	ChecksumsURL string            `json:"checksums_url"`
 75}
 76
 77// StaticCommitInfo represents a commit from commits.json.
 78type StaticCommitInfo struct {
 79	SHA     string `json:"sha"`
 80	Subject string `json:"subject"`
 81}
 82
 83const (
 84	// staticMetadataURL is the base URL for version metadata on GitHub Pages.
 85	// This avoids GitHub API rate limits.
 86	staticMetadataURL = "https://boldsoftware.github.io/shelley"
 87)
 88
 89// NewVersionChecker creates a new version checker.
 90func NewVersionChecker() *VersionChecker {
 91	skipCheck := os.Getenv("SHELLEY_SKIP_VERSION_CHECK") == "true"
 92	return &VersionChecker{
 93		skipCheck:   skipCheck,
 94		githubOwner: "boldsoftware",
 95		githubRepo:  "shelley",
 96	}
 97}
 98
 99// Check checks for a new version, using the cache if still valid.
100func (vc *VersionChecker) Check(ctx context.Context, forceRefresh bool) (*VersionInfo, error) {
101	if vc.skipCheck {
102		info := version.GetInfo()
103		return &VersionInfo{
104			CurrentVersion:      info.Version,
105			CurrentTag:          info.Tag,
106			CurrentCommit:       info.Commit,
107			HasUpdate:           false,
108			CheckedAt:           time.Now(),
109			RunningUnderSystemd: os.Getenv("INVOCATION_ID") != "",
110		}, nil
111	}
112
113	vc.mu.Lock()
114	defer vc.mu.Unlock()
115
116	// Return cached info if still valid (6 hours) and not forcing refresh
117	if !forceRefresh && vc.cachedInfo != nil && time.Since(vc.lastCheck) < 6*time.Hour {
118		return vc.cachedInfo, nil
119	}
120
121	info, err := vc.fetchVersionInfo(ctx)
122	if err != nil {
123		// On error, return current version info with error
124		currentInfo := version.GetInfo()
125		return &VersionInfo{
126			CurrentVersion:      currentInfo.Version,
127			CurrentTag:          currentInfo.Tag,
128			CurrentCommit:       currentInfo.Commit,
129			HasUpdate:           false,
130			CheckedAt:           time.Now(),
131			Error:               err.Error(),
132			RunningUnderSystemd: os.Getenv("INVOCATION_ID") != "",
133		}, nil
134	}
135
136	vc.cachedInfo = info
137	vc.lastCheck = time.Now()
138	return info, nil
139}
140
141// fetchVersionInfo fetches the latest release info from GitHub Pages.
142func (vc *VersionChecker) fetchVersionInfo(ctx context.Context) (*VersionInfo, error) {
143	currentInfo := version.GetInfo()
144	execPath, _ := os.Executable()
145	info := &VersionInfo{
146		CurrentVersion:      currentInfo.Version,
147		CurrentTag:          currentInfo.Tag,
148		CurrentCommit:       currentInfo.Commit,
149		CurrentCommitTime:   currentInfo.CommitTime,
150		ExecutablePath:      execPath,
151		CheckedAt:           time.Now(),
152		RunningUnderSystemd: os.Getenv("INVOCATION_ID") != "",
153	}
154
155	// Fetch latest release from static metadata
156	latestRelease, err := vc.fetchLatestRelease(ctx)
157	if err != nil {
158		return nil, fmt.Errorf("failed to fetch latest release: %w", err)
159	}
160
161	info.LatestTag = latestRelease.TagName
162	info.LatestVersion = latestRelease.TagName
163	info.ReleaseInfo = latestRelease
164
165	// Parse the published_at time
166	if publishedAt, err := time.Parse(time.RFC3339, latestRelease.PublishedAt); err == nil {
167		info.PublishedAt = publishedAt
168	}
169
170	// Find the download URL for the current platform
171	info.DownloadURL = vc.findDownloadURL(latestRelease)
172
173	// Check if latest has a newer minor version
174	info.HasUpdate = vc.isNewerMinor(currentInfo.Tag, latestRelease.TagName)
175
176	// For ShouldNotify, compare commit times if we have an update
177	if info.HasUpdate && currentInfo.CommitTime != "" {
178		currentTime, err1 := time.Parse(time.RFC3339, currentInfo.CommitTime)
179		latestTime, err2 := time.Parse(time.RFC3339, latestRelease.CommitTime)
180		if err1 == nil && err2 == nil {
181			// Show notification if the latest release is 5+ days newer than current
182			timeBetween := latestTime.Sub(currentTime)
183			info.ShouldNotify = timeBetween >= 5*24*time.Hour
184		} else {
185			// Can't parse times, just notify if there's an update
186			info.ShouldNotify = true
187		}
188	}
189
190	return info, nil
191}
192
193// FetchChangelog fetches the commits between current and latest versions.
194func (vc *VersionChecker) FetchChangelog(ctx context.Context, currentTag, latestTag string) ([]CommitInfo, error) {
195	if currentTag == "" || latestTag == "" {
196		return nil, nil
197	}
198
199	url := staticMetadataURL + "/commits.json"
200
201	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
202	if err != nil {
203		return nil, err
204	}
205	req.Header.Set("User-Agent", "Shelley-VersionChecker")
206
207	resp, err := http.DefaultClient.Do(req)
208	if err != nil {
209		return nil, err
210	}
211	defer resp.Body.Close()
212
213	if resp.StatusCode != http.StatusOK {
214		return nil, fmt.Errorf("static commits returned status %d", resp.StatusCode)
215	}
216
217	var staticCommits []StaticCommitInfo
218	if err := json.NewDecoder(resp.Body).Decode(&staticCommits); err != nil {
219		return nil, err
220	}
221
222	// Extract short SHAs from tags (tags are v0.X.YSHA where SHA is octal-encoded)
223	currentSHA := extractSHAFromTag(currentTag)
224	latestSHA := extractSHAFromTag(latestTag)
225
226	if currentSHA == "" || latestSHA == "" {
227		return nil, fmt.Errorf("could not extract SHAs from tags")
228	}
229
230	// Find the range of commits between current and latest
231	var commits []CommitInfo
232	var foundLatest, foundCurrent bool
233
234	for _, c := range staticCommits {
235		if c.SHA == latestSHA {
236			foundLatest = true
237		}
238		if foundLatest && !foundCurrent {
239			commits = append(commits, CommitInfo{
240				SHA:     c.SHA,
241				Message: c.Subject,
242			})
243		}
244		if c.SHA == currentSHA {
245			foundCurrent = true
246			break
247		}
248	}
249
250	// If we didn't find both SHAs, the commits might be too old (outside 500 range)
251	if !foundLatest || !foundCurrent {
252		return nil, fmt.Errorf("commits not found in static list")
253	}
254
255	// Remove the current commit itself from the list (we want commits after current)
256	if len(commits) > 0 && commits[len(commits)-1].SHA == currentSHA {
257		commits = commits[:len(commits)-1]
258	}
259
260	return commits, nil
261}
262
// extractSHAFromTag decodes the short commit SHA embedded in a version tag.
//
// Tags are formatted as v0.COUNT.9OCTAL (for example v0.178.9XXXXX): the last
// dot-separated component is the digit 9 followed by the SHA's value written
// in octal. The result is that value in lowercase hex, zero-padded to at least
// six characters.
//
// It returns "" when the tag does not start with 'v', has no dot, when the
// last component does not start with '9' or has no payload, when the payload
// contains a non-octal digit, or when the payload does not fit in 64 bits.
func extractSHAFromTag(tag string) string {
	if len(tag) < 3 || tag[0] != 'v' {
		return ""
	}

	lastDot := strings.LastIndexByte(tag, '.')
	if lastDot == -1 {
		return ""
	}

	// The patch component is "9" followed by at least one octal digit.
	patch := tag[lastDot+1:]
	if len(patch) < 2 || patch[0] != '9' {
		return ""
	}

	var sha uint64
	for _, c := range patch[1:] {
		if c < '0' || c > '7' {
			return ""
		}
		// Appending another octal digit shifts left by three bits; refuse
		// if that would push set bits off the top instead of silently
		// wrapping into a bogus SHA.
		if sha>>61 != 0 {
			return ""
		}
		sha = sha<<3 | uint64(c-'0')
	}

	return fmt.Sprintf("%06x", sha)
}
302
303// fetchLatestRelease fetches the latest release info from GitHub Pages.
304func (vc *VersionChecker) fetchLatestRelease(ctx context.Context) (*ReleaseInfo, error) {
305	url := staticMetadataURL + "/release.json"
306
307	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
308	if err != nil {
309		return nil, err
310	}
311	req.Header.Set("User-Agent", "Shelley-VersionChecker")
312
313	resp, err := http.DefaultClient.Do(req)
314	if err != nil {
315		return nil, err
316	}
317	defer resp.Body.Close()
318
319	if resp.StatusCode != http.StatusOK {
320		return nil, fmt.Errorf("failed to fetch release info: status %d", resp.StatusCode)
321	}
322
323	var release ReleaseInfo
324	if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
325		return nil, err
326	}
327
328	return &release, nil
329}
330
331// findDownloadURL finds the appropriate download URL for the current platform.
332func (vc *VersionChecker) findDownloadURL(release *ReleaseInfo) string {
333	key := fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH)
334	if url, ok := release.DownloadURLs[key]; ok {
335		return url
336	}
337	return ""
338}
339
340// isNewerMinor checks if latest has a higher minor version than current.
341func (vc *VersionChecker) isNewerMinor(currentTag, latestTag string) bool {
342	currentMinor := parseMinorVersion(currentTag)
343	latestMinor := parseMinorVersion(latestTag)
344	return latestMinor > currentMinor
345}
346
// parseMinorVersion returns X from a tag of the form v0.X.Y.
//
// It returns 0 when the tag is shorter than two bytes, does not start with
// 'v', or contains no dot. Parsing of X stops at the first non-digit, so a
// minor component with no leading digits also yields 0.
func parseMinorVersion(tag string) int {
	if len(tag) < 2 || tag[0] != 'v' {
		return 0
	}

	// Walk past the major component to the first dot after the 'v'.
	pos := 1
	for pos < len(tag) && tag[pos] != '.' {
		pos++
	}
	if pos == len(tag) {
		return 0
	}

	// Accumulate the run of decimal digits that follows the dot.
	minor := 0
	for _, ch := range []byte(tag[pos+1:]) {
		if ch < '0' || ch > '9' {
			break
		}
		minor = minor*10 + int(ch-'0')
	}
	return minor
}
383
384// DoUpgrade downloads and applies the update with checksum verification.
385func (vc *VersionChecker) DoUpgrade(ctx context.Context) error {
386	if vc.skipCheck {
387		return fmt.Errorf("version checking is disabled")
388	}
389
390	// Get cached info or fetch fresh
391	info, err := vc.Check(ctx, false)
392	if err != nil {
393		return fmt.Errorf("failed to check version: %w", err)
394	}
395
396	if !info.HasUpdate {
397		return fmt.Errorf("no update available")
398	}
399
400	if info.DownloadURL == "" {
401		return fmt.Errorf("no download URL for %s/%s", runtime.GOOS, runtime.GOARCH)
402	}
403
404	if info.ReleaseInfo == nil {
405		return fmt.Errorf("no release info available")
406	}
407
408	// Find and download checksums.txt
409	expectedChecksum, err := vc.fetchExpectedChecksum(ctx, info.ReleaseInfo)
410	if err != nil {
411		return fmt.Errorf("failed to fetch checksum: %w", err)
412	}
413
414	// Download the binary
415	resp, err := http.Get(info.DownloadURL)
416	if err != nil {
417		return fmt.Errorf("failed to download update: %w", err)
418	}
419	defer resp.Body.Close()
420
421	if resp.StatusCode != http.StatusOK {
422		return fmt.Errorf("download returned status %d", resp.StatusCode)
423	}
424
425	// Read the entire binary to verify checksum before applying
426	binaryData, err := io.ReadAll(resp.Body)
427	if err != nil {
428		return fmt.Errorf("failed to read update: %w", err)
429	}
430
431	// Verify checksum
432	actualChecksum := sha256.Sum256(binaryData)
433	actualChecksumHex := hex.EncodeToString(actualChecksum[:])
434
435	if actualChecksumHex != expectedChecksum {
436		return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksumHex)
437	}
438
439	// Apply the update
440	err = selfupdate.Apply(bytes.NewReader(binaryData), selfupdate.Options{})
441	if err == nil {
442		return nil
443	}
444
445	// Check if the error is permission-related and sudo is available
446	if !isPermissionError(err) {
447		return fmt.Errorf("failed to apply update: %w", err)
448	}
449
450	if !isSudoAvailable() {
451		return fmt.Errorf("failed to apply update (no write permission and sudo not available): %w", err)
452	}
453
454	// Fall back to sudo-based upgrade
455	return vc.doSudoUpgrade(binaryData)
456}
457
// isPermissionError reports whether err, or any error it wraps, is a
// permission-denied error.
func isPermissionError(err error) bool {
	for _, denied := range []error{fs.ErrPermission, os.ErrPermission} {
		if errors.Is(err, denied) {
			return true
		}
	}
	return false
}
462
463// doSudoUpgrade performs the upgrade using sudo when the binary isn't writable.
464func (vc *VersionChecker) doSudoUpgrade(binaryData []byte) error {
465	// Get the path to the current executable
466	exePath, err := os.Executable()
467	if err != nil {
468		return fmt.Errorf("failed to get executable path: %w", err)
469	}
470
471	// Write the new binary to a temp file
472	tmpFile, err := os.CreateTemp("", "shelley-upgrade-*")
473	if err != nil {
474		return fmt.Errorf("failed to create temp file: %w", err)
475	}
476	tmpPath := tmpFile.Name()
477	defer os.Remove(tmpPath)
478
479	if _, err := tmpFile.Write(binaryData); err != nil {
480		tmpFile.Close()
481		return fmt.Errorf("failed to write temp file: %w", err)
482	}
483	if err := tmpFile.Close(); err != nil {
484		return fmt.Errorf("failed to close temp file: %w", err)
485	}
486
487	// Make the temp file executable
488	if err := os.Chmod(tmpPath, 0o755); err != nil {
489		return fmt.Errorf("failed to chmod temp file: %w", err)
490	}
491
492	// Use sudo to install the new binary. We can't cp over a running binary ("Text file busy"),
493	// so we cp to a .new file and then mv (which is atomic and works on running binaries).
494	newPath := exePath + ".new"
495	oldPath := exePath + ".old"
496
497	// Copy new binary to .new location
498	cmd := exec.Command("sudo", "cp", tmpPath, newPath)
499	if output, err := cmd.CombinedOutput(); err != nil {
500		return fmt.Errorf("failed to copy new binary: %w: %s", err, output)
501	}
502
503	// Copy ownership and permissions from original
504	cmd = exec.Command("sudo", "chown", "--reference="+exePath, newPath)
505	if output, err := cmd.CombinedOutput(); err != nil {
506		exec.Command("sudo", "rm", "-f", newPath).Run()
507		return fmt.Errorf("failed to set ownership: %w: %s", err, output)
508	}
509
510	cmd = exec.Command("sudo", "chmod", "--reference="+exePath, newPath)
511	if output, err := cmd.CombinedOutput(); err != nil {
512		exec.Command("sudo", "rm", "-f", newPath).Run()
513		return fmt.Errorf("failed to set permissions: %w: %s", err, output)
514	}
515
516	// Rename old binary to .old (backup)
517	cmd = exec.Command("sudo", "mv", exePath, oldPath)
518	if output, err := cmd.CombinedOutput(); err != nil {
519		exec.Command("sudo", "rm", "-f", newPath).Run()
520		return fmt.Errorf("failed to backup old binary: %w: %s", err, output)
521	}
522
523	// Rename .new to target (atomic replacement)
524	cmd = exec.Command("sudo", "mv", newPath, exePath)
525	if output, err := cmd.CombinedOutput(); err != nil {
526		// Try to restore the old binary
527		exec.Command("sudo", "mv", oldPath, exePath).Run()
528		return fmt.Errorf("failed to install new binary: %w: %s", err, output)
529	}
530
531	// Remove the backup
532	cmd = exec.Command("sudo", "rm", "-f", oldPath)
533	cmd.Run() // Best effort, ignore errors
534
535	return nil
536}
537
538// fetchExpectedChecksum downloads checksums.txt and extracts the expected checksum for our binary.
539func (vc *VersionChecker) fetchExpectedChecksum(ctx context.Context, release *ReleaseInfo) (string, error) {
540	checksumURL := release.ChecksumsURL
541	if checksumURL == "" {
542		return "", fmt.Errorf("checksums.txt URL not found in release")
543	}
544
545	// Download checksums.txt
546	req, err := http.NewRequestWithContext(ctx, "GET", checksumURL, nil)
547	if err != nil {
548		return "", err
549	}
550
551	resp, err := http.DefaultClient.Do(req)
552	if err != nil {
553		return "", err
554	}
555	defer resp.Body.Close()
556
557	if resp.StatusCode != http.StatusOK {
558		return "", fmt.Errorf("failed to download checksums: status %d", resp.StatusCode)
559	}
560
561	// Parse checksums.txt (format: "checksum  filename")
562	expectedBinaryName := fmt.Sprintf("shelley_%s_%s", runtime.GOOS, runtime.GOARCH)
563
564	scanner := bufio.NewScanner(resp.Body)
565	for scanner.Scan() {
566		line := scanner.Text()
567		parts := strings.Fields(line)
568		if len(parts) >= 2 {
569			checksum := parts[0]
570			filename := parts[1]
571			if filename == expectedBinaryName {
572				return checksum, nil
573			}
574		}
575	}
576
577	if err := scanner.Err(); err != nil {
578		return "", fmt.Errorf("error reading checksums: %w", err)
579	}
580
581	return "", fmt.Errorf("checksum for %s not found in checksums.txt", expectedBinaryName)
582}