1// Package hyper provides a fantasy.Provider that proxies requests to Hyper.
2package hyper
3
4import (
5 "bufio"
6 "bytes"
7 "cmp"
8 "context"
9 _ "embed"
10 "encoding/json"
11 "errors"
12 "fmt"
13 "io"
14 "log/slog"
15 "maps"
16 "net/http"
17 "net/url"
18 "os"
19 "strconv"
20 "strings"
21 "sync"
22 "time"
23
24 "charm.land/fantasy"
25 "charm.land/fantasy/object"
26 "github.com/charmbracelet/catwalk/pkg/catwalk"
27 "github.com/charmbracelet/crush/internal/event"
28)
29
// provider.json is a snapshot of the Hyper provider definition; regenerate it
// with `go generate`, which downloads it from the console API below.
//
//go:generate wget -O provider.json https://console.charm.land/api/v1/provider

// embedded holds the raw contents of provider.json, compiled into the binary.
// Embedded decodes it on first use.
//
//go:embed provider.json
var embedded []byte
34
// Enabled reports whether Hyper is switched on through the environment.
//
// The first non-empty value among $HYPER, $HYPERCRUSH, $HYPER_ENABLE and
// $HYPER_ENABLED (checked in that order) is parsed with strconv.ParseBool;
// if none is set, or the chosen value does not parse, Hyper is disabled.
// The result is computed once and cached.
var Enabled = sync.OnceValue(func() bool {
	for _, key := range [...]string{"HYPER", "HYPERCRUSH", "HYPER_ENABLE", "HYPER_ENABLED"} {
		value := os.Getenv(key)
		if value == "" {
			continue
		}
		on, err := strconv.ParseBool(value)
		return err == nil && on
	}
	return false
})
47
48// Embedded returns the embedded Hyper provider.
49var Embedded = sync.OnceValue(func() catwalk.Provider {
50 var provider catwalk.Provider
51 if err := json.Unmarshal(embedded, &provider); err != nil {
52 slog.Error("could not use embedded provider data", "err", err)
53 }
54 return provider
55})
56
57const (
58 // Name is the default name of this meta provider.
59 Name = "hyper"
60 // defaultBaseURL is the default proxy URL.
61 // TODO: change this to production URL when ready.
62 defaultBaseURL = "https://console.charm.land"
63)
64
65// BaseURL returns the base URL, which is either $HYPER_URL or the default.
66var BaseURL = sync.OnceValue(func() string {
67 return cmp.Or(os.Getenv("HYPER_URL"), defaultBaseURL)
68})
69
// ErrNoCredits is returned when the Hyper proxy answers with HTTP 402
// Payment Required, meaning the account has run out of credits.
var ErrNoCredits = errors.New("you're out of credits")

// options holds the configuration shared by the provider and every language
// model it creates.
type options struct {
	baseURL string            // proxy endpoint; the model ID and action are appended per request
	apiKey  string            // sent verbatim as the Authorization header when non-empty
	name    string            // provider name reported by the provider and its models
	headers map[string]string // extra request headers, set after the default ones
	client  *http.Client      // HTTP client used for every proxy request
}

// Option configures the proxy provider.
type Option = func(*options)
82
83// New creates a new proxy provider.
84func New(opts ...Option) (fantasy.Provider, error) {
85 o := options{
86 baseURL: BaseURL() + "/api/v1/fantasy",
87 name: Name,
88 headers: map[string]string{
89 "x-crush-id": event.GetID(),
90 },
91 client: &http.Client{Timeout: 0}, // stream-safe
92 }
93 for _, opt := range opts {
94 opt(&o)
95 }
96 return &provider{options: o}, nil
97}
98
99// WithBaseURL sets the proxy base URL (e.g. http://localhost:8080).
100func WithBaseURL(url string) Option { return func(o *options) { o.baseURL = url } }
101
102// WithName sets the provider name.
103func WithName(name string) Option { return func(o *options) { o.name = name } }
104
105// WithHeaders sets extra headers sent to the proxy.
106func WithHeaders(headers map[string]string) Option {
107 return func(o *options) {
108 maps.Copy(o.headers, headers)
109 }
110}
111
112// WithHTTPClient sets custom HTTP client.
113func WithHTTPClient(c *http.Client) Option { return func(o *options) { o.client = c } }
114
115// WithAPIKey sets the API key.
116func WithAPIKey(key string) Option {
117 return func(o *options) {
118 o.apiKey = key
119 }
120}
121
122type provider struct{ options options }
123
124func (p *provider) Name() string { return p.options.name }
125
126// LanguageModel implements fantasy.Provider.
127func (p *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) {
128 if modelID == "" {
129 return nil, errors.New("missing model id")
130 }
131 return &languageModel{modelID: modelID, provider: p.options.name, opts: p.options}, nil
132}
133
134type languageModel struct {
135 provider string
136 modelID string
137 opts options
138}
139
140// GenerateObject implements fantasy.LanguageModel.
141func (m *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
142 return object.GenerateWithTool(ctx, m, call)
143}
144
145// StreamObject implements fantasy.LanguageModel.
146func (m *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
147 return object.StreamWithTool(ctx, m, call)
148}
149
150func (m *languageModel) Provider() string { return m.provider }
151func (m *languageModel) Model() string { return m.modelID }
152
153// Generate implements fantasy.LanguageModel by calling the proxy JSON endpoint.
154func (m *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
155 resp, err := m.doRequest(ctx, false, call)
156 if err != nil {
157 return nil, err
158 }
159 defer func() { _ = resp.Body.Close() }()
160 if resp.StatusCode < 200 || resp.StatusCode >= 300 {
161 b, _ := ioReadAllLimit(resp.Body, 64*1024)
162 return nil, fmt.Errorf("proxy generate error: %s", strings.TrimSpace(string(b)))
163 }
164 var out fantasy.Response
165 if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
166 return nil, err
167 }
168 return &out, nil
169}
170
171// Stream implements fantasy.LanguageModel using SSE from the proxy.
172func (m *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
173 // Prefer explicit /stream endpoint
174 resp, err := m.doRequest(ctx, true, call)
175 if err != nil {
176 return nil, err
177 }
178 switch resp.StatusCode {
179 case http.StatusTooManyRequests:
180 _ = resp.Body.Close()
181 return nil, toProviderError(resp, retryAfter(resp))
182 case http.StatusUnauthorized:
183 _ = resp.Body.Close()
184 return nil, toProviderError(resp, "")
185 case http.StatusPaymentRequired:
186 _ = resp.Body.Close()
187 return nil, ErrNoCredits
188 }
189
190 if resp.StatusCode < 200 || resp.StatusCode >= 300 {
191 defer func() { _ = resp.Body.Close() }()
192 b, _ := ioReadAllLimit(resp.Body, 64*1024)
193 return nil, &fantasy.ProviderError{
194 Title: "Stream Error",
195 Message: strings.TrimSpace(string(b)),
196 StatusCode: resp.StatusCode,
197 }
198 }
199
200 return func(yield func(fantasy.StreamPart) bool) {
201 defer func() { _ = resp.Body.Close() }()
202 scanner := bufio.NewScanner(resp.Body)
203 buf := make([]byte, 0, 64*1024)
204 scanner.Buffer(buf, 4*1024*1024)
205
206 var (
207 event string
208 dataBuf bytes.Buffer
209 sawFinish bool
210 dispatch = func() bool {
211 if dataBuf.Len() == 0 || event == "" {
212 dataBuf.Reset()
213 event = ""
214 return true
215 }
216 var part fantasy.StreamPart
217 if err := json.Unmarshal(dataBuf.Bytes(), &part); err != nil {
218 return yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
219 }
220 if part.Type == fantasy.StreamPartTypeFinish {
221 sawFinish = true
222 }
223 ok := yield(part)
224 dataBuf.Reset()
225 event = ""
226 return ok
227 }
228 )
229
230 for scanner.Scan() {
231 line := scanner.Text()
232 if line == "" { // event boundary
233 if !dispatch() {
234 return
235 }
236 continue
237 }
238 if strings.HasPrefix(line, ":") { // comment / ping
239 continue
240 }
241 if strings.HasPrefix(line, "event: ") {
242 event = strings.TrimSpace(line[len("event: "):])
243 continue
244 }
245 if strings.HasPrefix(line, "data: ") {
246 if dataBuf.Len() > 0 {
247 dataBuf.WriteByte('\n')
248 }
249 dataBuf.WriteString(line[len("data: "):])
250 continue
251 }
252 }
253 if err := scanner.Err(); err != nil &&
254 !errors.Is(err, context.Canceled) &&
255 !errors.Is(err, context.DeadlineExceeded) {
256 yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
257 return
258 }
259 // flush any pending data
260 _ = dispatch()
261 if !sawFinish {
262 _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish})
263 }
264 }, nil
265}
266
267func (m *languageModel) doRequest(ctx context.Context, stream bool, call fantasy.Call) (*http.Response, error) {
268 addr, err := url.Parse(m.opts.baseURL)
269 if err != nil {
270 return nil, err
271 }
272 addr = addr.JoinPath(m.modelID)
273 if stream {
274 addr = addr.JoinPath("stream")
275 } else {
276 addr = addr.JoinPath("generate")
277 }
278
279 body, err := json.Marshal(call)
280 if err != nil {
281 return nil, err
282 }
283
284 req, err := http.NewRequestWithContext(ctx, http.MethodPost, addr.String(), bytes.NewReader(body))
285 if err != nil {
286 return nil, err
287 }
288 req.Header.Set("Content-Type", "application/json")
289 if stream {
290 req.Header.Set("Accept", "text/event-stream")
291 } else {
292 req.Header.Set("Accept", "application/json")
293 }
294 for k, v := range m.opts.headers {
295 req.Header.Set(k, v)
296 }
297
298 if m.opts.apiKey != "" {
299 req.Header.Set("Authorization", m.opts.apiKey)
300 }
301 return m.opts.client.Do(req)
302}
303
// ioReadAllLimit reads from r until EOF or until n bytes have been
// consumed, whichever comes first; a non-positive n means a 1 MiB cap.
// Like io.ReadAll it returns what was read together with any non-EOF error.
func ioReadAllLimit(r io.Reader, n int64) ([]byte, error) {
	limit := n
	if limit <= 0 {
		limit = 1 << 20
	}
	return io.ReadAll(io.LimitReader(r, limit))
}
314
315func toProviderError(resp *http.Response, message string) error {
316 return &fantasy.ProviderError{
317 Title: fantasy.ErrorTitleForStatusCode(resp.StatusCode),
318 Message: message,
319 StatusCode: resp.StatusCode,
320 }
321}
322
// retryAfter builds a human-readable retry hint from the response's
// Retry-After header. Both forms allowed by RFC 9110 are understood: a
// number of delay-seconds and an HTTP-date. If the header is missing,
// malformed, out of range, or not in the future, a generic hint is returned.
func retryAfter(resp *http.Response) string {
	const fallback = "Try again later"
	// Largest seconds value that still fits in a time.Duration.
	const maxSecs = int64(1<<63-1) / int64(time.Second)

	v := strings.TrimSpace(resp.Header.Get("Retry-After"))
	var wait time.Duration
	if secs, err := strconv.Atoi(v); err == nil {
		if secs > 0 && int64(secs) <= maxSecs {
			wait = time.Duration(secs) * time.Second
		}
	} else if t, err := http.ParseTime(v); err == nil {
		wait = time.Until(t).Round(time.Second)
	}
	if wait <= 0 {
		return fallback
	}
	return "Try again in " + wait.String()
}