hyper.go

  1package config
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"net/http"
 10	"sync"
 11	"sync/atomic"
 12	"time"
 13
 14	"github.com/charmbracelet/catwalk/pkg/catwalk"
 15	"github.com/charmbracelet/crush/internal/agent/hyper"
 16	xetag "github.com/charmbracelet/x/etag"
 17)
 18
 19type hyperClient interface {
 20	Get(context.Context, string) (catwalk.Provider, error)
 21}
 22
 23var _ syncer[catwalk.Provider] = (*hyperSync)(nil)
 24
 25type hyperSync struct {
 26	once       sync.Once
 27	result     catwalk.Provider
 28	cache      cache[catwalk.Provider]
 29	client     hyperClient
 30	autoupdate bool
 31	init       atomic.Bool
 32}
 33
 34func (s *hyperSync) Init(client hyperClient, path string, autoupdate bool) {
 35	s.client = client
 36	s.cache = newCache[catwalk.Provider](path)
 37	s.autoupdate = autoupdate
 38	s.init.Store(true)
 39}
 40
 41func (s *hyperSync) Get(ctx context.Context) (catwalk.Provider, error) {
 42	if !s.init.Load() {
 43		panic("called Get before Init")
 44	}
 45
 46	var throwErr error
 47	s.once.Do(func() {
 48		if !s.autoupdate {
 49			slog.Info("Using embedded Hyper provider")
 50			s.result = hyper.Embedded()
 51			return
 52		}
 53
 54		cached, etag, cachedErr := s.cache.Get()
 55		if cached.ID == "" || cachedErr != nil {
 56			// if cached file is empty, default to embedded provider
 57			cached = hyper.Embedded()
 58		}
 59
 60		slog.Info("Fetching Hyper provider")
 61		result, err := s.client.Get(ctx, etag)
 62		if errors.Is(err, context.DeadlineExceeded) {
 63			slog.Warn("Hyper provider not updated in time")
 64			s.result = cached
 65			return
 66		}
 67		if errors.Is(err, catwalk.ErrNotModified) {
 68			slog.Info("Hyper provider not modified")
 69			s.result = cached
 70			return
 71		}
 72		if len(result.Models) == 0 {
 73			slog.Warn("Hyper did not return any models")
 74			s.result = cached
 75			return
 76		}
 77
 78		s.result = result
 79		throwErr = s.cache.Store(result)
 80	})
 81	return s.result, throwErr
 82}
 83
 84var _ hyperClient = realHyperClient{}
 85
 86type realHyperClient struct {
 87	baseURL string
 88}
 89
 90// Get implements hyperClient.
 91func (r realHyperClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) {
 92	var result catwalk.Provider
 93	req, err := http.NewRequestWithContext(
 94		ctx,
 95		http.MethodGet,
 96		r.baseURL+"/api/v1/provider",
 97		nil,
 98	)
 99	if err != nil {
100		return result, fmt.Errorf("could not create request: %w", err)
101	}
102	xetag.Request(req, etag)
103
104	client := &http.Client{Timeout: 30 * time.Second}
105	resp, err := client.Do(req)
106	if err != nil {
107		return result, fmt.Errorf("failed to make request: %w", err)
108	}
109	defer resp.Body.Close() //nolint:errcheck
110
111	if resp.StatusCode == http.StatusNotModified {
112		return result, catwalk.ErrNotModified
113	}
114
115	if resp.StatusCode != http.StatusOK {
116		return result, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
117	}
118
119	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
120		return result, fmt.Errorf("failed to decode response: %w", err)
121	}
122
123	return result, nil
124}