1// Package hyper provides a fantasy.Provider that proxies requests to Hyper.
2package hyper
3
4import (
5 "bufio"
6 "bytes"
7 "cmp"
8 "context"
9 _ "embed"
10 "encoding/json"
11 "errors"
12 "fmt"
13 "io"
14 "log/slog"
15 "maps"
16 "net/http"
17 "net/url"
18 "os"
19 "strconv"
20 "strings"
21 "sync"
22 "time"
23
24 "charm.land/catwalk/pkg/catwalk"
25 "charm.land/fantasy"
26 "charm.land/fantasy/object"
27 "github.com/charmbracelet/crush/internal/event"
28)
29
30//go:generate wget -O provider.json https://hyper.charm.land/api/v1/provider
31
// embedded holds the raw contents of provider.json, compiled into the binary
// through the go:embed directive below. It is decoded by Embedded.
//
//go:embed provider.json
var embedded []byte
34
35// Enabled returns true if hyper is enabled.
36var Enabled = sync.OnceValue(func() bool {
37 b, _ := strconv.ParseBool(
38 cmp.Or(
39 os.Getenv("HYPER"),
40 os.Getenv("HYPERCRUSH"),
41 os.Getenv("HYPER_ENABLE"),
42 os.Getenv("HYPER_ENABLED"),
43 ),
44 )
45 return b
46})
47
48// Embedded returns the embedded Hyper provider.
49var Embedded = sync.OnceValue(func() catwalk.Provider {
50 var provider catwalk.Provider
51 if err := json.Unmarshal(embedded, &provider); err != nil {
52 slog.Error("Could not use embedded provider data", "err", err)
53 }
54 if e := os.Getenv("HYPER_URL"); e != "" {
55 provider.APIEndpoint = e + "/api/v1/fantasy"
56 }
57 return provider
58})
59
const (
	// Name is the default name of this meta provider. New uses it unless
	// WithName overrides it.
	Name = "hyper"
	// defaultBaseURL is the default proxy URL, used by BaseURL when
	// $HYPER_URL is unset or empty.
	defaultBaseURL = "https://hyper.charm.land"
)
66
67// BaseURL returns the base URL, which is either $HYPER_URL or the default.
68var BaseURL = sync.OnceValue(func() string {
69 return cmp.Or(os.Getenv("HYPER_URL"), defaultBaseURL)
70})
71
// ErrNoCredits is returned by Stream when the proxy answers with
// 402 Payment Required.
var ErrNoCredits = errors.New("you're out of credits")
73
// options holds the resolved configuration of a provider. It is filled with
// defaults by New and then adjusted by the supplied Option functions.
type options struct {
	// baseURL is the endpoint prefix; doRequest appends the model ID and
	// either "generate" or "stream" to it.
	baseURL string
	// apiKey, when non-empty, is sent verbatim as the Authorization header.
	apiKey string
	// name is the provider name reported by Name and Provider.
	name string
	// headers are extra headers set on every request.
	headers map[string]string
	// client is the HTTP client used to send requests.
	client *http.Client
}
81
// Option configures the proxy provider. New applies options in the order
// they are given, after the defaults have been set.
type Option = func(*options)
84
85// New creates a new proxy provider.
86func New(opts ...Option) (fantasy.Provider, error) {
87 o := options{
88 baseURL: BaseURL() + "/api/v1/fantasy",
89 name: Name,
90 headers: map[string]string{
91 "x-crush-id": event.GetID(),
92 },
93 client: &http.Client{Timeout: 0}, // stream-safe
94 }
95 for _, opt := range opts {
96 opt(&o)
97 }
98 return &provider{options: o}, nil
99}
100
101// WithBaseURL sets the proxy base URL (e.g. http://localhost:8080).
102func WithBaseURL(url string) Option { return func(o *options) { o.baseURL = url } }
103
104// WithName sets the provider name.
105func WithName(name string) Option { return func(o *options) { o.name = name } }
106
107// WithHeaders sets extra headers sent to the proxy.
108func WithHeaders(headers map[string]string) Option {
109 return func(o *options) {
110 maps.Copy(o.headers, headers)
111 }
112}
113
114// WithHTTPClient sets custom HTTP client.
115func WithHTTPClient(c *http.Client) Option { return func(o *options) { o.client = c } }
116
117// WithAPIKey sets the API key.
118func WithAPIKey(key string) Option {
119 return func(o *options) {
120 o.apiKey = key
121 }
122}
123
// provider is the fantasy.Provider returned by New. It only carries the
// resolved options, which are handed to every language model it creates.
type provider struct{ options options }
125
126func (p *provider) Name() string { return p.options.name }
127
128// LanguageModel implements fantasy.Provider.
129func (p *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) {
130 if modelID == "" {
131 return nil, errors.New("missing model id")
132 }
133 return &languageModel{modelID: modelID, provider: p.options.name, opts: p.options}, nil
134}
135
// languageModel is the fantasy.LanguageModel created by provider.LanguageModel.
// It forwards Generate and Stream calls for one model ID to the proxy.
type languageModel struct {
	provider string  // provider name, reported by Provider
	modelID  string  // model identifier, reported by Model and used as a URL path segment
	opts     options // provider options, copied by value when the model is created
}
141
// GenerateObject implements fantasy.LanguageModel by delegating to
// object.GenerateWithTool with this model.
func (m *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
	return object.GenerateWithTool(ctx, m, call)
}
146
// StreamObject implements fantasy.LanguageModel by delegating to
// object.StreamWithTool with this model.
func (m *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	return object.StreamWithTool(ctx, m, call)
}
151
// Provider returns the name of the provider this model was created by.
func (m *languageModel) Provider() string { return m.provider }
// Model returns the model ID this language model is bound to.
func (m *languageModel) Model() string { return m.modelID }
154
155// Generate implements fantasy.LanguageModel by calling the proxy JSON endpoint.
156func (m *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
157 resp, err := m.doRequest(ctx, false, call)
158 if err != nil {
159 return nil, err
160 }
161 defer func() { _ = resp.Body.Close() }()
162 if resp.StatusCode < 200 || resp.StatusCode >= 300 {
163 b, _ := ioReadAllLimit(resp.Body, 64*1024)
164 return nil, fmt.Errorf("proxy generate error: %s", strings.TrimSpace(string(b)))
165 }
166 var out fantasy.Response
167 if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
168 return nil, err
169 }
170 return &out, nil
171}
172
173// Stream implements fantasy.LanguageModel using SSE from the proxy.
174func (m *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
175 // Prefer explicit /stream endpoint
176 resp, err := m.doRequest(ctx, true, call)
177 if err != nil {
178 return nil, err
179 }
180 switch resp.StatusCode {
181 case http.StatusTooManyRequests:
182 _ = resp.Body.Close()
183 return nil, toProviderError(resp, retryAfter(resp))
184 case http.StatusUnauthorized:
185 _ = resp.Body.Close()
186 return nil, toProviderError(resp, "")
187 case http.StatusPaymentRequired:
188 _ = resp.Body.Close()
189 return nil, ErrNoCredits
190 }
191
192 if resp.StatusCode < 200 || resp.StatusCode >= 300 {
193 defer func() { _ = resp.Body.Close() }()
194 b, _ := ioReadAllLimit(resp.Body, 64*1024)
195 return nil, &fantasy.ProviderError{
196 Title: "Stream Error",
197 Message: strings.TrimSpace(string(b)),
198 StatusCode: resp.StatusCode,
199 }
200 }
201
202 return func(yield func(fantasy.StreamPart) bool) {
203 defer func() { _ = resp.Body.Close() }()
204 scanner := bufio.NewScanner(resp.Body)
205 buf := make([]byte, 0, 64*1024)
206 scanner.Buffer(buf, 4*1024*1024)
207
208 var (
209 event string
210 dataBuf bytes.Buffer
211 sawFinish bool
212 dispatch = func() bool {
213 if dataBuf.Len() == 0 || event == "" {
214 dataBuf.Reset()
215 event = ""
216 return true
217 }
218 var part fantasy.StreamPart
219 if err := json.Unmarshal(dataBuf.Bytes(), &part); err != nil {
220 return yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
221 }
222 if part.Type == fantasy.StreamPartTypeFinish {
223 sawFinish = true
224 }
225 ok := yield(part)
226 dataBuf.Reset()
227 event = ""
228 return ok
229 }
230 )
231
232 for scanner.Scan() {
233 line := scanner.Text()
234 if line == "" { // event boundary
235 if !dispatch() {
236 return
237 }
238 continue
239 }
240 if strings.HasPrefix(line, ":") { // comment / ping
241 continue
242 }
243 if strings.HasPrefix(line, "event: ") {
244 event = strings.TrimSpace(line[len("event: "):])
245 continue
246 }
247 if strings.HasPrefix(line, "data: ") {
248 if dataBuf.Len() > 0 {
249 dataBuf.WriteByte('\n')
250 }
251 dataBuf.WriteString(line[len("data: "):])
252 continue
253 }
254 }
255 if err := scanner.Err(); err != nil {
256 if sawFinish && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) {
257 // If we already saw an explicit finish event, treat cancellation as a no-op.
258 } else {
259 _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
260 return
261 }
262 }
263 if err := ctx.Err(); err != nil && !sawFinish {
264 _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
265 return
266 }
267 // flush any pending data
268 _ = dispatch()
269 if !sawFinish {
270 _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish})
271 }
272 }, nil
273}
274
275func (m *languageModel) doRequest(ctx context.Context, stream bool, call fantasy.Call) (*http.Response, error) {
276 addr, err := url.Parse(m.opts.baseURL)
277 if err != nil {
278 return nil, err
279 }
280 addr = addr.JoinPath(m.modelID)
281 if stream {
282 addr = addr.JoinPath("stream")
283 } else {
284 addr = addr.JoinPath("generate")
285 }
286
287 body, err := json.Marshal(call)
288 if err != nil {
289 return nil, err
290 }
291
292 req, err := http.NewRequestWithContext(ctx, http.MethodPost, addr.String(), bytes.NewReader(body))
293 if err != nil {
294 return nil, err
295 }
296 req.Header.Set("Content-Type", "application/json")
297 if stream {
298 req.Header.Set("Accept", "text/event-stream")
299 } else {
300 req.Header.Set("Accept", "application/json")
301 }
302 for k, v := range m.opts.headers {
303 req.Header.Set(k, v)
304 }
305
306 if m.opts.apiKey != "" {
307 req.Header.Set("Authorization", m.opts.apiKey)
308 }
309 return m.opts.client.Do(req)
310}
311
// ioReadAllLimit reads from r until EOF or until n bytes have been read,
// whichever comes first, and returns what was read. A non-positive n selects
// a default limit of 1 MiB. Reaching the limit is not an error; a read error
// is returned together with the bytes read before it.
func ioReadAllLimit(r io.Reader, n int64) ([]byte, error) {
	limit := n
	if limit <= 0 {
		limit = 1 << 20
	}
	return io.ReadAll(io.LimitReader(r, limit))
}
322
323func toProviderError(resp *http.Response, message string) error {
324 return &fantasy.ProviderError{
325 Title: fantasy.ErrorTitleForStatusCode(resp.StatusCode),
326 Message: message,
327 StatusCode: resp.StatusCode,
328 }
329}
330
// retryAfter turns the response's Retry-After header into a short
// human-readable hint.
//
// Both forms of the header are understood: a positive number of seconds
// ("120") and an HTTP-date in the future. Either produces "Try again in
// <duration>". A missing, malformed, zero, negative or already-elapsed value
// produces "Try again later".
func retryAfter(resp *http.Response) string {
	const fallback = "Try again later"

	raw := strings.TrimSpace(resp.Header.Get("Retry-After"))
	if raw == "" {
		return fallback
	}

	// delay-seconds form.
	if secs, err := strconv.Atoi(raw); err == nil {
		if secs <= 0 {
			return fallback
		}
		d := time.Duration(secs) * time.Second
		return "Try again in " + d.String()
	}

	// HTTP-date form: report the time remaining, rounded to whole seconds.
	if when, err := http.ParseTime(raw); err == nil {
		if d := time.Until(when).Round(time.Second); d > 0 {
			return "Try again in " + d.String()
		}
	}
	return fallback
}