provider.go

  1// Package hyper provides a fantasy.Provider that proxies requests to Hyper.
  2package hyper
  3
  4import (
  5	"bufio"
  6	"bytes"
  7	"cmp"
  8	"context"
  9	_ "embed"
 10	"encoding/json"
 11	"errors"
 12	"fmt"
 13	"io"
 14	"log/slog"
 15	"maps"
 16	"net/http"
 17	"net/url"
 18	"os"
 19	"strconv"
 20	"strings"
 21	"sync"
 22	"time"
 23
 24	"charm.land/catwalk/pkg/catwalk"
 25	"charm.land/fantasy"
 26	"charm.land/fantasy/object"
 27	"github.com/charmbracelet/crush/internal/event"
 28)
 29
 30//go:generate wget -O provider.json https://hyper.charm.land/api/v1/provider
 31
// embedded holds the raw contents of provider.json, compiled into the binary.
// Embedded decodes it as JSON into a catwalk.Provider.
//
//go:embed provider.json
var embedded []byte
 34
 35// Enabled returns true if hyper is enabled.
 36var Enabled = sync.OnceValue(func() bool {
 37	b, _ := strconv.ParseBool(
 38		cmp.Or(
 39			os.Getenv("HYPER"),
 40			os.Getenv("HYPERCRUSH"),
 41			os.Getenv("HYPER_ENABLE"),
 42			os.Getenv("HYPER_ENABLED"),
 43		),
 44	)
 45	return b
 46})
 47
 48// Embedded returns the embedded Hyper provider.
 49var Embedded = sync.OnceValue(func() catwalk.Provider {
 50	var provider catwalk.Provider
 51	if err := json.Unmarshal(embedded, &provider); err != nil {
 52		slog.Error("Could not use embedded provider data", "err", err)
 53	}
 54	if e := os.Getenv("HYPER_URL"); e != "" {
 55		provider.APIEndpoint = e + "/api/v1/fantasy"
 56	}
 57	return provider
 58})
 59
 60const (
 61	// Name is the default name of this meta provider.
 62	Name = "hyper"
 63	// defaultBaseURL is the default proxy URL.
 64	defaultBaseURL = "https://hyper.charm.land"
 65)
 66
 67// BaseURL returns the base URL, which is either $HYPER_URL or the default.
 68var BaseURL = sync.OnceValue(func() string {
 69	return cmp.Or(os.Getenv("HYPER_URL"), defaultBaseURL)
 70})
 71
// ErrNoCredits is returned by Stream when the proxy answers with
// 402 Payment Required. Compare against it with errors.Is.
var ErrNoCredits = errors.New("you're out of credits")
 73
// options is the configuration assembled by New from its defaults and the
// supplied Option values.
type options struct {
	// baseURL is the endpoint prefix; doRequest appends the model ID and the
	// operation ("generate" or "stream") to it.
	baseURL string
	// apiKey, when non-empty, is sent verbatim as the Authorization header.
	apiKey  string
	// name is the provider name reported by Name and Provider.
	name    string
	// headers are extra headers set on every request.
	headers map[string]string
	// client executes the HTTP requests.
	client  *http.Client
}
 81
// Option configures the proxy provider by mutating the options that New is
// building. New applies options in the order they are passed.
type Option = func(*options)
 84
 85// New creates a new proxy provider.
 86func New(opts ...Option) (fantasy.Provider, error) {
 87	o := options{
 88		baseURL: BaseURL() + "/api/v1/fantasy",
 89		name:    Name,
 90		headers: map[string]string{
 91			"x-crush-id": event.GetID(),
 92		},
 93		client: &http.Client{Timeout: 0}, // stream-safe
 94	}
 95	for _, opt := range opts {
 96		opt(&o)
 97	}
 98	return &provider{options: o}, nil
 99}
100
101// WithBaseURL sets the proxy base URL (e.g. http://localhost:8080).
102func WithBaseURL(url string) Option { return func(o *options) { o.baseURL = url } }
103
104// WithName sets the provider name.
105func WithName(name string) Option { return func(o *options) { o.name = name } }
106
107// WithHeaders sets extra headers sent to the proxy.
108func WithHeaders(headers map[string]string) Option {
109	return func(o *options) {
110		maps.Copy(o.headers, headers)
111	}
112}
113
114// WithHTTPClient sets custom HTTP client.
115func WithHTTPClient(c *http.Client) Option { return func(o *options) { o.client = c } }
116
117// WithAPIKey sets the API key.
118func WithAPIKey(key string) Option {
119	return func(o *options) {
120		o.apiKey = key
121	}
122}
123
// provider is the value returned by New; it carries the resolved options that
// every language model it creates is configured with.
type provider struct{ options options }
125
126func (p *provider) Name() string { return p.options.name }
127
128// LanguageModel implements fantasy.Provider.
129func (p *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) {
130	if modelID == "" {
131		return nil, errors.New("missing model id")
132	}
133	return &languageModel{modelID: modelID, provider: p.options.name, opts: p.options}, nil
134}
135
// languageModel is a single model reached through the proxy. It is created by
// provider.LanguageModel.
type languageModel struct {
	// provider is the provider name reported by Provider.
	provider string
	// modelID identifies the model and becomes a path segment of every
	// request URL.
	modelID  string
	// opts is the provider configuration used to build and send requests.
	opts     options
}
141
142// GenerateObject implements fantasy.LanguageModel.
143func (m *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
144	return object.GenerateWithTool(ctx, m, call)
145}
146
147// StreamObject implements fantasy.LanguageModel.
148func (m *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
149	return object.StreamWithTool(ctx, m, call)
150}
151
152func (m *languageModel) Provider() string { return m.provider }
153func (m *languageModel) Model() string    { return m.modelID }
154
155// Generate implements fantasy.LanguageModel by calling the proxy JSON endpoint.
156func (m *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
157	resp, err := m.doRequest(ctx, false, call)
158	if err != nil {
159		return nil, err
160	}
161	defer func() { _ = resp.Body.Close() }()
162	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
163		b, _ := ioReadAllLimit(resp.Body, 64*1024)
164		return nil, fmt.Errorf("proxy generate error: %s", strings.TrimSpace(string(b)))
165	}
166	var out fantasy.Response
167	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
168		return nil, err
169	}
170	return &out, nil
171}
172
173// Stream implements fantasy.LanguageModel using SSE from the proxy.
174func (m *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
175	// Prefer explicit /stream endpoint
176	resp, err := m.doRequest(ctx, true, call)
177	if err != nil {
178		return nil, err
179	}
180	switch resp.StatusCode {
181	case http.StatusTooManyRequests:
182		_ = resp.Body.Close()
183		return nil, toProviderError(resp, retryAfter(resp))
184	case http.StatusUnauthorized:
185		_ = resp.Body.Close()
186		return nil, toProviderError(resp, "")
187	case http.StatusPaymentRequired:
188		_ = resp.Body.Close()
189		return nil, ErrNoCredits
190	}
191
192	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
193		defer func() { _ = resp.Body.Close() }()
194		b, _ := ioReadAllLimit(resp.Body, 64*1024)
195		return nil, &fantasy.ProviderError{
196			Title:      "Stream Error",
197			Message:    strings.TrimSpace(string(b)),
198			StatusCode: resp.StatusCode,
199		}
200	}
201
202	return func(yield func(fantasy.StreamPart) bool) {
203		defer func() { _ = resp.Body.Close() }()
204		scanner := bufio.NewScanner(resp.Body)
205		buf := make([]byte, 0, 64*1024)
206		scanner.Buffer(buf, 4*1024*1024)
207
208		var (
209			event     string
210			dataBuf   bytes.Buffer
211			sawFinish bool
212			dispatch  = func() bool {
213				if dataBuf.Len() == 0 || event == "" {
214					dataBuf.Reset()
215					event = ""
216					return true
217				}
218				var part fantasy.StreamPart
219				if err := json.Unmarshal(dataBuf.Bytes(), &part); err != nil {
220					return yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
221				}
222				if part.Type == fantasy.StreamPartTypeFinish {
223					sawFinish = true
224				}
225				ok := yield(part)
226				dataBuf.Reset()
227				event = ""
228				return ok
229			}
230		)
231
232		for scanner.Scan() {
233			line := scanner.Text()
234			if line == "" { // event boundary
235				if !dispatch() {
236					return
237				}
238				continue
239			}
240			if strings.HasPrefix(line, ":") { // comment / ping
241				continue
242			}
243			if strings.HasPrefix(line, "event: ") {
244				event = strings.TrimSpace(line[len("event: "):])
245				continue
246			}
247			if strings.HasPrefix(line, "data: ") {
248				if dataBuf.Len() > 0 {
249					dataBuf.WriteByte('\n')
250				}
251				dataBuf.WriteString(line[len("data: "):])
252				continue
253			}
254		}
255		if err := scanner.Err(); err != nil {
256			if sawFinish && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) {
257				// If we already saw an explicit finish event, treat cancellation as a no-op.
258			} else {
259				_ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
260				return
261			}
262		}
263		if err := ctx.Err(); err != nil && !sawFinish {
264			_ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
265			return
266		}
267		// flush any pending data
268		_ = dispatch()
269		if !sawFinish {
270			_ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish})
271		}
272	}, nil
273}
274
275func (m *languageModel) doRequest(ctx context.Context, stream bool, call fantasy.Call) (*http.Response, error) {
276	addr, err := url.Parse(m.opts.baseURL)
277	if err != nil {
278		return nil, err
279	}
280	addr = addr.JoinPath(m.modelID)
281	if stream {
282		addr = addr.JoinPath("stream")
283	} else {
284		addr = addr.JoinPath("generate")
285	}
286
287	body, err := json.Marshal(call)
288	if err != nil {
289		return nil, err
290	}
291
292	req, err := http.NewRequestWithContext(ctx, http.MethodPost, addr.String(), bytes.NewReader(body))
293	if err != nil {
294		return nil, err
295	}
296	req.Header.Set("Content-Type", "application/json")
297	if stream {
298		req.Header.Set("Accept", "text/event-stream")
299	} else {
300		req.Header.Set("Accept", "application/json")
301	}
302	for k, v := range m.opts.headers {
303		req.Header.Set(k, v)
304	}
305
306	if m.opts.apiKey != "" {
307		req.Header.Set("Authorization", m.opts.apiKey)
308	}
309	return m.opts.client.Do(req)
310}
311
// ioReadAllLimit reads from r until EOF or until n bytes have been read,
// whichever comes first, and returns what was read together with any read
// error other than EOF. A non-positive n selects a limit of 1 MiB.
func ioReadAllLimit(r io.Reader, n int64) ([]byte, error) {
	limit := n
	if limit <= 0 {
		limit = 1 << 20
	}
	var out bytes.Buffer
	_, err := io.Copy(&out, io.LimitReader(r, limit))
	return out.Bytes(), err
}
322
323func toProviderError(resp *http.Response, message string) error {
324	return &fantasy.ProviderError{
325		Title:      fantasy.ErrorTitleForStatusCode(resp.StatusCode),
326		Message:    message,
327		StatusCode: resp.StatusCode,
328	}
329}
330
// retryAfter turns the response's Retry-After header into a short,
// human-readable hint.
//
// Both forms allowed by HTTP are understood: a positive number of seconds and
// an HTTP-date in the future each yield "Try again in <duration>". A missing,
// malformed, non-positive, or already-elapsed value yields "Try again later".
func retryAfter(resp *http.Response) string {
	const fallback = "Try again later"

	raw := resp.Header.Get("Retry-After")
	if raw == "" {
		return fallback
	}
	if secs, err := strconv.Atoi(raw); err == nil {
		if secs <= 0 {
			return fallback
		}
		return "Try again in " + (time.Duration(secs) * time.Second).String()
	}
	// Not an integer: try the HTTP-date form and report the time remaining,
	// rounded to whole seconds to keep the hint readable.
	if at, err := http.ParseTime(raw); err == nil {
		if wait := time.Until(at).Round(time.Second); wait > 0 {
			return "Try again in " + wait.String()
		}
	}
	return fallback
}