1package config
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "log/slog"
9 "net/http"
10 "sync"
11 "sync/atomic"
12 "time"
13
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15 "github.com/charmbracelet/crush/internal/agent/hyper"
16 xetag "github.com/charmbracelet/x/etag"
17)
18
19type hyperClient interface {
20 Get(context.Context, string) (catwalk.Provider, error)
21}
22
23var _ syncer[catwalk.Provider] = (*hyperSync)(nil)
24
25type hyperSync struct {
26 once sync.Once
27 result catwalk.Provider
28 cache cache[catwalk.Provider]
29 client hyperClient
30 autoupdate bool
31 init atomic.Bool
32}
33
34func (s *hyperSync) Init(client hyperClient, path string, autoupdate bool) {
35 s.client = client
36 s.cache = newCache[catwalk.Provider](path)
37 s.autoupdate = autoupdate
38 s.init.Store(true)
39}
40
41func (s *hyperSync) Get(ctx context.Context) (catwalk.Provider, error) {
42 if !s.init.Load() {
43 panic("called Get before Init")
44 }
45
46 var throwErr error
47 s.once.Do(func() {
48 if !s.autoupdate {
49 slog.Info("Using embedded Hyper provider")
50 s.result = hyper.Embedded()
51 return
52 }
53
54 cached, etag, cachedErr := s.cache.Get()
55 if cached.ID == "" || cachedErr != nil {
56 // if cached file is empty, default to embedded provider
57 cached = hyper.Embedded()
58 }
59
60 slog.Info("Fetching Hyper provider")
61 result, err := s.client.Get(ctx, etag)
62 if errors.Is(err, context.DeadlineExceeded) {
63 slog.Warn("Hyper provider not updated in time")
64 s.result = cached
65 return
66 }
67 if errors.Is(err, catwalk.ErrNotModified) {
68 slog.Info("Hyper provider not modified")
69 s.result = cached
70 return
71 }
72 if len(result.Models) == 0 {
73 slog.Warn("Hyper did not return any models")
74 s.result = cached
75 return
76 }
77
78 s.result = result
79 throwErr = s.cache.Store(result)
80 })
81 return s.result, throwErr
82}
83
84var _ hyperClient = realHyperClient{}
85
86type realHyperClient struct {
87 baseURL string
88}
89
90// Get implements hyperClient.
91func (r realHyperClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) {
92 var result catwalk.Provider
93 req, err := http.NewRequestWithContext(
94 ctx,
95 http.MethodGet,
96 r.baseURL+"/api/v1/provider",
97 nil,
98 )
99 if err != nil {
100 return result, fmt.Errorf("could not create request: %w", err)
101 }
102 xetag.Request(req, etag)
103
104 client := &http.Client{Timeout: 30 * time.Second}
105 resp, err := client.Do(req)
106 if err != nil {
107 return result, fmt.Errorf("failed to make request: %w", err)
108 }
109 defer resp.Body.Close() //nolint:errcheck
110
111 if resp.StatusCode == http.StatusNotModified {
112 return result, catwalk.ErrNotModified
113 }
114
115 if resp.StatusCode != http.StatusOK {
116 return result, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
117 }
118
119 if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
120 return result, fmt.Errorf("failed to decode response: %w", err)
121 }
122
123 return result, nil
124}