1// Package hyper provides a fantasy.Provider that proxies requests to Hyper.
2package hyper
3
4import (
5 "bufio"
6 "bytes"
7 "cmp"
8 "context"
9 _ "embed"
10 "encoding/json"
11 "errors"
12 "fmt"
13 "io"
14 "log/slog"
15 "maps"
16 "net/http"
17 "net/url"
18 "os"
19 "strconv"
20 "strings"
21 "sync"
22 "time"
23
24 "charm.land/catwalk/pkg/catwalk"
25 "charm.land/fantasy"
26 "charm.land/fantasy/object"
27 "github.com/charmbracelet/crush/internal/event"
28)
29
//go:generate wget -O provider.json https://hyper.charm.land/api/v1/provider

// embedded holds the raw contents of provider.json, compiled into the binary
// via go:embed. Running `go generate` refreshes the file by downloading it
// from https://hyper.charm.land/api/v1/provider (see the directive above).
//
//go:embed provider.json
var embedded []byte
34
35// Enabled returns true if hyper is enabled.
36var Enabled = sync.OnceValue(func() bool {
37 b, _ := strconv.ParseBool(
38 cmp.Or(
39 os.Getenv("HYPER"),
40 os.Getenv("HYPERCRUSH"),
41 os.Getenv("HYPER_ENABLE"),
42 os.Getenv("HYPER_ENABLED"),
43 ),
44 )
45 return b
46})
47
48// Embedded returns the embedded Hyper provider.
49var Embedded = sync.OnceValue(func() catwalk.Provider {
50 var provider catwalk.Provider
51 if err := json.Unmarshal(embedded, &provider); err != nil {
52 slog.Error("Could not use embedded provider data", "err", err)
53 }
54 if e := os.Getenv("HYPER_URL"); e != "" {
55 provider.APIEndpoint = e + "/api/v1/fantasy"
56 }
57 return provider
58})
59
60const (
61 // Name is the default name of this meta provider.
62 Name = "hyper"
63 // defaultBaseURL is the default proxy URL.
64 defaultBaseURL = "https://hyper.charm.land"
65)
66
67// BaseURL returns the base URL, which is either $HYPER_URL or the default.
68var BaseURL = sync.OnceValue(func() string {
69 return cmp.Or(os.Getenv("HYPER_URL"), defaultBaseURL)
70})
71
// ErrNoCredits reports that the account is out of credits (the proxy answered
// 402 Payment Required).
var ErrNoCredits = errors.New("you're out of credits")

// ErrUnauthorized reports that the proxy rejected the credentials (401
// Unauthorized).
var ErrUnauthorized = errors.New("unauthorized")
76
// options holds the provider configuration that New assembles from Option
// values.
type options struct {
	// baseURL is the fantasy API root; the model ID and endpoint name are
	// appended to it per request. apiKey, when non-empty, is sent verbatim as
	// the Authorization header. name is the provider name reported by Name.
	baseURL, apiKey, name string
	// headers are extra request headers applied to every call.
	headers map[string]string
	// client performs the HTTP requests.
	client *http.Client
}

// Option configures the proxy provider.
type Option = func(o *options)
87
88// New creates a new proxy provider.
89func New(opts ...Option) (fantasy.Provider, error) {
90 o := options{
91 baseURL: BaseURL() + "/api/v1/fantasy",
92 name: Name,
93 headers: map[string]string{
94 "x-crush-id": event.GetID(),
95 },
96 client: &http.Client{Timeout: 0}, // stream-safe
97 }
98 for _, opt := range opts {
99 opt(&o)
100 }
101 return &provider{options: o}, nil
102}
103
104// WithBaseURL sets the proxy base URL (e.g. http://localhost:8080).
105func WithBaseURL(url string) Option { return func(o *options) { o.baseURL = url } }
106
107// WithName sets the provider name.
108func WithName(name string) Option { return func(o *options) { o.name = name } }
109
110// WithHeaders sets extra headers sent to the proxy.
111func WithHeaders(headers map[string]string) Option {
112 return func(o *options) {
113 maps.Copy(o.headers, headers)
114 }
115}
116
117// WithHTTPClient sets custom HTTP client.
118func WithHTTPClient(c *http.Client) Option { return func(o *options) { o.client = c } }
119
120// WithAPIKey sets the API key.
121func WithAPIKey(key string) Option {
122 return func(o *options) {
123 o.apiKey = key
124 }
125}
126
127type provider struct{ options options }
128
129func (p *provider) Name() string { return p.options.name }
130
131// LanguageModel implements fantasy.Provider.
132func (p *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) {
133 if modelID == "" {
134 return nil, errors.New("missing model id")
135 }
136 return &languageModel{modelID: modelID, provider: p.options.name, opts: p.options}, nil
137}
138
139type languageModel struct {
140 provider string
141 modelID string
142 opts options
143}
144
145// GenerateObject implements fantasy.LanguageModel.
146func (m *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
147 return object.GenerateWithTool(ctx, m, call)
148}
149
150// StreamObject implements fantasy.LanguageModel.
151func (m *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
152 return object.StreamWithTool(ctx, m, call)
153}
154
155func (m *languageModel) Provider() string { return m.provider }
156func (m *languageModel) Model() string { return m.modelID }
157
158// Generate implements fantasy.LanguageModel by calling the proxy JSON endpoint.
159func (m *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
160 resp, err := m.doRequest(ctx, false, call)
161 if err != nil {
162 return nil, err
163 }
164 defer func() { _ = resp.Body.Close() }()
165 if resp.StatusCode < 200 || resp.StatusCode >= 300 {
166 b, _ := ioReadAllLimit(resp.Body, 64*1024)
167 return nil, fmt.Errorf("proxy generate error: %s", strings.TrimSpace(string(b)))
168 }
169 var out fantasy.Response
170 if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
171 return nil, err
172 }
173 return &out, nil
174}
175
176// Stream implements fantasy.LanguageModel using SSE from the proxy.
177func (m *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
178 // Prefer explicit /stream endpoint
179 resp, err := m.doRequest(ctx, true, call)
180 if err != nil {
181 return nil, err
182 }
183 switch resp.StatusCode {
184 case http.StatusTooManyRequests:
185 _ = resp.Body.Close()
186 return nil, toProviderError(resp, retryAfter(resp))
187 case http.StatusUnauthorized:
188 _ = resp.Body.Close()
189 return nil, ErrUnauthorized
190 case http.StatusPaymentRequired:
191 _ = resp.Body.Close()
192 return nil, ErrNoCredits
193 }
194
195 if resp.StatusCode < 200 || resp.StatusCode >= 300 {
196 defer func() { _ = resp.Body.Close() }()
197 b, _ := ioReadAllLimit(resp.Body, 64*1024)
198 return nil, &fantasy.ProviderError{
199 Title: "Stream Error",
200 Message: strings.TrimSpace(string(b)),
201 StatusCode: resp.StatusCode,
202 }
203 }
204
205 return func(yield func(fantasy.StreamPart) bool) {
206 defer func() { _ = resp.Body.Close() }()
207 scanner := bufio.NewScanner(resp.Body)
208 buf := make([]byte, 0, 64*1024)
209 scanner.Buffer(buf, 4*1024*1024)
210
211 var (
212 event string
213 dataBuf bytes.Buffer
214 sawFinish bool
215 dispatch = func() bool {
216 if dataBuf.Len() == 0 || event == "" {
217 dataBuf.Reset()
218 event = ""
219 return true
220 }
221 var part fantasy.StreamPart
222 if err := json.Unmarshal(dataBuf.Bytes(), &part); err != nil {
223 return yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
224 }
225 if part.Type == fantasy.StreamPartTypeFinish {
226 sawFinish = true
227 }
228 ok := yield(part)
229 dataBuf.Reset()
230 event = ""
231 return ok
232 }
233 )
234
235 for scanner.Scan() {
236 line := scanner.Text()
237 if line == "" { // event boundary
238 if !dispatch() {
239 return
240 }
241 continue
242 }
243 if strings.HasPrefix(line, ":") { // comment / ping
244 continue
245 }
246 if strings.HasPrefix(line, "event: ") {
247 event = strings.TrimSpace(line[len("event: "):])
248 continue
249 }
250 if strings.HasPrefix(line, "data: ") {
251 if dataBuf.Len() > 0 {
252 dataBuf.WriteByte('\n')
253 }
254 dataBuf.WriteString(line[len("data: "):])
255 continue
256 }
257 }
258 if err := scanner.Err(); err != nil {
259 if sawFinish && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) {
260 // If we already saw an explicit finish event, treat cancellation as a no-op.
261 } else {
262 _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
263 return
264 }
265 }
266 if err := ctx.Err(); err != nil && !sawFinish {
267 _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
268 return
269 }
270 // flush any pending data
271 _ = dispatch()
272 if !sawFinish {
273 _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish})
274 }
275 }, nil
276}
277
278func (m *languageModel) doRequest(ctx context.Context, stream bool, call fantasy.Call) (*http.Response, error) {
279 addr, err := url.Parse(m.opts.baseURL)
280 if err != nil {
281 return nil, err
282 }
283 addr = addr.JoinPath(m.modelID)
284 if stream {
285 addr = addr.JoinPath("stream")
286 } else {
287 addr = addr.JoinPath("generate")
288 }
289
290 body, err := json.Marshal(call)
291 if err != nil {
292 return nil, err
293 }
294
295 req, err := http.NewRequestWithContext(ctx, http.MethodPost, addr.String(), bytes.NewReader(body))
296 if err != nil {
297 return nil, err
298 }
299 req.Header.Set("Content-Type", "application/json")
300 if stream {
301 req.Header.Set("Accept", "text/event-stream")
302 } else {
303 req.Header.Set("Accept", "application/json")
304 }
305 for k, v := range m.opts.headers {
306 req.Header.Set(k, v)
307 }
308
309 if m.opts.apiKey != "" {
310 req.Header.Set("Authorization", m.opts.apiKey)
311 }
312 return m.opts.client.Do(req)
313}
314
// ioReadAllLimit reads at most n bytes from r and returns what was read. A
// non-positive n falls back to a 1 MiB cap. Reaching end of input is not
// reported as an error.
func ioReadAllLimit(r io.Reader, n int64) ([]byte, error) {
	limit := n
	if limit <= 0 {
		limit = 1 << 20
	}
	var out bytes.Buffer
	_, err := out.ReadFrom(io.LimitReader(r, limit))
	return out.Bytes(), err
}
325
326func toProviderError(resp *http.Response, message string) error {
327 return &fantasy.ProviderError{
328 Title: fantasy.ErrorTitleForStatusCode(resp.StatusCode),
329 Message: message,
330 StatusCode: resp.StatusCode,
331 }
332}
333
// retryAfter turns resp's Retry-After header into a human-readable hint.
//
// RFC 9110 allows either delay-seconds or an HTTP-date. A positive number of
// seconds, or a date in the future (rounded to whole seconds), yields
// "Try again in <duration>". A missing, invalid, zero, negative or past value
// yields "Try again later".
func retryAfter(resp *http.Response) string {
	value := resp.Header.Get("Retry-After")
	if secs, err := strconv.Atoi(value); err == nil {
		if secs > 0 {
			return "Try again in " + (time.Duration(secs) * time.Second).String()
		}
		return "Try again later"
	}
	if at, err := http.ParseTime(value); err == nil {
		if d := time.Until(at).Round(time.Second); d > 0 {
			return "Try again in " + d.String()
		}
	}
	return "Try again later"
}