1// Package hyper provides a fantasy.Provider that proxies requests to Hyper.
2package hyper
3
4import (
5 "bufio"
6 "bytes"
7 "cmp"
8 "context"
9 _ "embed"
10 "encoding/json"
11 "errors"
12 "fmt"
13 "io"
14 "log/slog"
15 "maps"
16 "net/http"
17 "net/url"
18 "os"
19 "strconv"
20 "strings"
21 "sync"
22 "time"
23
24 "charm.land/fantasy"
25 "charm.land/fantasy/object"
26 "github.com/charmbracelet/catwalk/pkg/catwalk"
27 "github.com/charmbracelet/crush/internal/event"
28)
29
30//go:generate wget -O provider.json https://console.charm.land/api/v1/provider
31
// embedded holds the raw bytes of provider.json, compiled into the binary by
// the go:embed directive below. Embedded decodes it on first use.
//
//go:embed provider.json
var embedded []byte
34
35// Enabled returns true if hyper is enabled.
36var Enabled = sync.OnceValue(func() bool {
37 b, _ := strconv.ParseBool(
38 cmp.Or(
39 os.Getenv("HYPER"),
40 os.Getenv("HYPERCRUSH"),
41 os.Getenv("HYPER_ENABLE"),
42 os.Getenv("HYPER_ENABLED"),
43 ),
44 )
45 return b
46})
47
48// Embedded returns the embedded Hyper provider.
49var Embedded = sync.OnceValue(func() catwalk.Provider {
50 var provider catwalk.Provider
51 if err := json.Unmarshal(embedded, &provider); err != nil {
52 slog.Error("Could not use embedded provider data", "err", err)
53 }
54 if e := os.Getenv("HYPER_URL"); e != "" {
55 provider.APIEndpoint = e + "/api/v1/fantasy"
56 }
57 return provider
58})
59
const (
	// Name is the default name of this meta provider. New uses it unless
	// WithName supplies a different one.
	Name = "hyper"
	// defaultBaseURL is the default proxy URL, returned by BaseURL when
	// $HYPER_URL is not set.
	// TODO: change this to production URL when ready.
	defaultBaseURL = "https://console.charm.land"
)
67
68// BaseURL returns the base URL, which is either $HYPER_URL or the default.
69var BaseURL = sync.OnceValue(func() string {
70 return cmp.Or(os.Getenv("HYPER_URL"), defaultBaseURL)
71})
72
// ErrNoCredits is returned by Stream when the proxy answers with
// 402 Payment Required. Compare with errors.Is.
var ErrNoCredits = errors.New("you're out of credits")
74
// options holds the configuration assembled by New from its defaults and the
// supplied Option values. The struct is copied (shallowly) into every
// language model the provider creates.
type options struct {
	// baseURL is the endpoint prefix; the model ID and the action
	// ("generate" or "stream") are appended to it as path segments.
	baseURL string
	// apiKey, when non-empty, is sent verbatim in the Authorization header.
	apiKey string
	// name is the provider name reported by Name and Provider.
	name string
	// headers are extra headers set on every request.
	headers map[string]string
	// client performs the HTTP requests.
	client *http.Client
}
82
// Option configures the proxy provider. Options are applied by New, in the
// order given, on top of the defaults.
type Option = func(*options)
85
86// New creates a new proxy provider.
87func New(opts ...Option) (fantasy.Provider, error) {
88 o := options{
89 baseURL: BaseURL() + "/api/v1/fantasy",
90 name: Name,
91 headers: map[string]string{
92 "x-crush-id": event.GetID(),
93 },
94 client: &http.Client{Timeout: 0}, // stream-safe
95 }
96 for _, opt := range opts {
97 opt(&o)
98 }
99 return &provider{options: o}, nil
100}
101
102// WithBaseURL sets the proxy base URL (e.g. http://localhost:8080).
103func WithBaseURL(url string) Option { return func(o *options) { o.baseURL = url } }
104
105// WithName sets the provider name.
106func WithName(name string) Option { return func(o *options) { o.name = name } }
107
108// WithHeaders sets extra headers sent to the proxy.
109func WithHeaders(headers map[string]string) Option {
110 return func(o *options) {
111 maps.Copy(o.headers, headers)
112 }
113}
114
115// WithHTTPClient sets custom HTTP client.
116func WithHTTPClient(c *http.Client) Option { return func(o *options) { o.client = c } }
117
118// WithAPIKey sets the API key.
119func WithAPIKey(key string) Option {
120 return func(o *options) {
121 o.apiKey = key
122 }
123}
124
// provider is the value returned by New; it carries the resolved options and
// hands them to the language models it creates.
type provider struct{ options options }
126
127func (p *provider) Name() string { return p.options.name }
128
129// LanguageModel implements fantasy.Provider.
130func (p *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) {
131 if modelID == "" {
132 return nil, errors.New("missing model id")
133 }
134 return &languageModel{modelID: modelID, provider: p.options.name, opts: p.options}, nil
135}
136
// languageModel is a single model served through the proxy.
type languageModel struct {
	// provider is the provider name returned by Provider.
	provider string
	// modelID identifies the model; it becomes a path segment of the
	// request URL.
	modelID string
	// opts is the provider configuration used to build and send requests.
	opts options
}
142
// GenerateObject implements fantasy.LanguageModel. It delegates to
// object.GenerateWithTool, passing this model as the language model to use.
// NOTE(review): presumably the helper obtains the object through a tool
// call, as its name suggests — see the fantasy/object package to confirm.
func (m *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
	return object.GenerateWithTool(ctx, m, call)
}
147
// StreamObject implements fantasy.LanguageModel. It delegates to
// object.StreamWithTool, passing this model as the language model to use.
// NOTE(review): presumably the streaming counterpart of GenerateWithTool —
// see the fantasy/object package to confirm.
func (m *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	return object.StreamWithTool(ctx, m, call)
}
152
153func (m *languageModel) Provider() string { return m.provider }
154func (m *languageModel) Model() string { return m.modelID }
155
156// Generate implements fantasy.LanguageModel by calling the proxy JSON endpoint.
157func (m *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
158 resp, err := m.doRequest(ctx, false, call)
159 if err != nil {
160 return nil, err
161 }
162 defer func() { _ = resp.Body.Close() }()
163 if resp.StatusCode < 200 || resp.StatusCode >= 300 {
164 b, _ := ioReadAllLimit(resp.Body, 64*1024)
165 return nil, fmt.Errorf("proxy generate error: %s", strings.TrimSpace(string(b)))
166 }
167 var out fantasy.Response
168 if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
169 return nil, err
170 }
171 return &out, nil
172}
173
174// Stream implements fantasy.LanguageModel using SSE from the proxy.
175func (m *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
176 // Prefer explicit /stream endpoint
177 resp, err := m.doRequest(ctx, true, call)
178 if err != nil {
179 return nil, err
180 }
181 switch resp.StatusCode {
182 case http.StatusTooManyRequests:
183 _ = resp.Body.Close()
184 return nil, toProviderError(resp, retryAfter(resp))
185 case http.StatusUnauthorized:
186 _ = resp.Body.Close()
187 return nil, toProviderError(resp, "")
188 case http.StatusPaymentRequired:
189 _ = resp.Body.Close()
190 return nil, ErrNoCredits
191 }
192
193 if resp.StatusCode < 200 || resp.StatusCode >= 300 {
194 defer func() { _ = resp.Body.Close() }()
195 b, _ := ioReadAllLimit(resp.Body, 64*1024)
196 return nil, &fantasy.ProviderError{
197 Title: "Stream Error",
198 Message: strings.TrimSpace(string(b)),
199 StatusCode: resp.StatusCode,
200 }
201 }
202
203 return func(yield func(fantasy.StreamPart) bool) {
204 defer func() { _ = resp.Body.Close() }()
205 scanner := bufio.NewScanner(resp.Body)
206 buf := make([]byte, 0, 64*1024)
207 scanner.Buffer(buf, 4*1024*1024)
208
209 var (
210 event string
211 dataBuf bytes.Buffer
212 sawFinish bool
213 dispatch = func() bool {
214 if dataBuf.Len() == 0 || event == "" {
215 dataBuf.Reset()
216 event = ""
217 return true
218 }
219 var part fantasy.StreamPart
220 if err := json.Unmarshal(dataBuf.Bytes(), &part); err != nil {
221 return yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
222 }
223 if part.Type == fantasy.StreamPartTypeFinish {
224 sawFinish = true
225 }
226 ok := yield(part)
227 dataBuf.Reset()
228 event = ""
229 return ok
230 }
231 )
232
233 for scanner.Scan() {
234 line := scanner.Text()
235 if line == "" { // event boundary
236 if !dispatch() {
237 return
238 }
239 continue
240 }
241 if strings.HasPrefix(line, ":") { // comment / ping
242 continue
243 }
244 if strings.HasPrefix(line, "event: ") {
245 event = strings.TrimSpace(line[len("event: "):])
246 continue
247 }
248 if strings.HasPrefix(line, "data: ") {
249 if dataBuf.Len() > 0 {
250 dataBuf.WriteByte('\n')
251 }
252 dataBuf.WriteString(line[len("data: "):])
253 continue
254 }
255 }
256 if err := scanner.Err(); err != nil &&
257 !errors.Is(err, context.Canceled) &&
258 !errors.Is(err, context.DeadlineExceeded) {
259 yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
260 return
261 }
262 // flush any pending data
263 _ = dispatch()
264 if !sawFinish {
265 _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish})
266 }
267 }, nil
268}
269
270func (m *languageModel) doRequest(ctx context.Context, stream bool, call fantasy.Call) (*http.Response, error) {
271 addr, err := url.Parse(m.opts.baseURL)
272 if err != nil {
273 return nil, err
274 }
275 addr = addr.JoinPath(m.modelID)
276 if stream {
277 addr = addr.JoinPath("stream")
278 } else {
279 addr = addr.JoinPath("generate")
280 }
281
282 body, err := json.Marshal(call)
283 if err != nil {
284 return nil, err
285 }
286
287 req, err := http.NewRequestWithContext(ctx, http.MethodPost, addr.String(), bytes.NewReader(body))
288 if err != nil {
289 return nil, err
290 }
291 req.Header.Set("Content-Type", "application/json")
292 if stream {
293 req.Header.Set("Accept", "text/event-stream")
294 } else {
295 req.Header.Set("Accept", "application/json")
296 }
297 for k, v := range m.opts.headers {
298 req.Header.Set(k, v)
299 }
300
301 if m.opts.apiKey != "" {
302 req.Header.Set("Authorization", m.opts.apiKey)
303 }
304 return m.opts.client.Do(req)
305}
306
// ioReadAllLimit reads from r until EOF or until n bytes have been read,
// whichever comes first, and returns the bytes read. A non-positive n selects
// a default cap of 1 MiB. Reaching the cap is not an error; a read error is
// returned together with the data read before it.
func ioReadAllLimit(r io.Reader, n int64) ([]byte, error) {
	limit := n
	if limit <= 0 {
		limit = 1 << 20
	}
	buf := new(bytes.Buffer)
	_, err := buf.ReadFrom(io.LimitReader(r, limit))
	return buf.Bytes(), err
}
317
318func toProviderError(resp *http.Response, message string) error {
319 return &fantasy.ProviderError{
320 Title: fantasy.ErrorTitleForStatusCode(resp.StatusCode),
321 Message: message,
322 StatusCode: resp.StatusCode,
323 }
324}
325
// retryAfter turns the response's Retry-After header into a short,
// human-readable hint.
//
// Both forms the header may take are understood: a whole number of seconds
// ("120") and an HTTP-date ("Wed, 21 Oct 2015 07:28:00 GMT"), in which case
// the remaining time is rounded to the nearest second. When the header is
// missing, malformed, non-positive, or names a moment that has already
// passed, the generic "Try again later" is returned.
func retryAfter(resp *http.Response) string {
	const fallback = "Try again later"

	raw := strings.TrimSpace(resp.Header.Get("Retry-After"))
	if raw == "" {
		return fallback
	}

	var wait time.Duration
	if secs, err := strconv.Atoi(raw); err == nil {
		wait = time.Duration(secs) * time.Second
	} else if at, err := http.ParseTime(raw); err == nil {
		wait = time.Until(at).Round(time.Second)
	}
	if wait <= 0 {
		return fallback
	}
	return "Try again in " + wait.String()
}