provider.go

  1// Package hyper provides a fantasy.Provider that proxies requests to Hyper.
  2package hyper
  3
  4import (
  5	"bufio"
  6	"bytes"
  7	"cmp"
  8	"context"
  9	_ "embed"
 10	"encoding/json"
 11	"errors"
 12	"fmt"
 13	"io"
 14	"log/slog"
 15	"maps"
 16	"net/http"
 17	"net/url"
 18	"os"
 19	"strconv"
 20	"strings"
 21	"sync"
 22	"time"
 23
 24	"charm.land/catwalk/pkg/catwalk"
 25	"charm.land/fantasy"
 26	"charm.land/fantasy/object"
 27	"github.com/charmbracelet/crush/internal/event"
 28)
 29
 30//go:generate wget -O provider.json https://hyper.charm.land/api/v1/provider
 31
// embedded holds the raw contents of provider.json, compiled into the binary
// by the go:embed directive below. Embedded decodes it on first use.
//
//go:embed provider.json
var embedded []byte
 34
 35// Enabled returns true if hyper is enabled.
 36var Enabled = sync.OnceValue(func() bool {
 37	b, _ := strconv.ParseBool(
 38		cmp.Or(
 39			os.Getenv("HYPER"),
 40			os.Getenv("HYPERCRUSH"),
 41			os.Getenv("HYPER_ENABLE"),
 42			os.Getenv("HYPER_ENABLED"),
 43		),
 44	)
 45	return b
 46})
 47
 48// Embedded returns the embedded Hyper provider.
 49var Embedded = sync.OnceValue(func() catwalk.Provider {
 50	var provider catwalk.Provider
 51	if err := json.Unmarshal(embedded, &provider); err != nil {
 52		slog.Error("Could not use embedded provider data", "err", err)
 53	}
 54	if e := os.Getenv("HYPER_URL"); e != "" {
 55		provider.APIEndpoint = e + "/api/v1/fantasy"
 56	}
 57	return provider
 58})
 59
const (
	// Name is the default name of this meta provider. New uses it unless
	// WithName overrides it.
	Name = "hyper"
	// defaultBaseURL is the default proxy URL, used by BaseURL when
	// $HYPER_URL is not set.
	defaultBaseURL = "https://hyper.charm.land"
)
 66
 67// BaseURL returns the base URL, which is either $HYPER_URL or the default.
 68var BaseURL = sync.OnceValue(func() string {
 69	return cmp.Or(os.Getenv("HYPER_URL"), defaultBaseURL)
 70})
 71
// Sentinel errors returned by Stream for specific proxy responses:
// ErrNoCredits for 402 Payment Required and ErrUnauthorized for
// 401 Unauthorized. Compare against them with errors.Is.
var (
	ErrNoCredits    = errors.New("you're out of credits")
	ErrUnauthorized = errors.New("unauthorized")
)
 76
// options holds the provider configuration that New assembles from its
// defaults and the Option values passed to it.
type options struct {
	baseURL string            // URL the model ID and endpoint name are joined onto
	apiKey  string            // sent verbatim as the Authorization header when non-empty
	name    string            // provider name reported by Name and Provider
	headers map[string]string // extra headers set on every request
	client  *http.Client      // HTTP client used to issue requests
}
 84
// Option configures the proxy provider. Options are passed to New, which
// applies them in order on top of its defaults.
type Option = func(*options)
 87
 88// New creates a new proxy provider.
 89func New(opts ...Option) (fantasy.Provider, error) {
 90	o := options{
 91		baseURL: BaseURL() + "/api/v1/fantasy",
 92		name:    Name,
 93		headers: map[string]string{
 94			"x-crush-id": event.GetID(),
 95		},
 96		client: &http.Client{Timeout: 0}, // stream-safe
 97	}
 98	for _, opt := range opts {
 99		opt(&o)
100	}
101	return &provider{options: o}, nil
102}
103
104// WithBaseURL sets the proxy base URL (e.g. http://localhost:8080).
105func WithBaseURL(url string) Option { return func(o *options) { o.baseURL = url } }
106
107// WithName sets the provider name.
108func WithName(name string) Option { return func(o *options) { o.name = name } }
109
110// WithHeaders sets extra headers sent to the proxy.
111func WithHeaders(headers map[string]string) Option {
112	return func(o *options) {
113		maps.Copy(o.headers, headers)
114	}
115}
116
117// WithHTTPClient sets custom HTTP client.
118func WithHTTPClient(c *http.Client) Option { return func(o *options) { o.client = c } }
119
120// WithAPIKey sets the API key.
121func WithAPIKey(key string) Option {
122	return func(o *options) {
123		o.apiKey = key
124	}
125}
126
// provider is the value New returns as a fantasy.Provider; it carries the
// resolved options that every language model it creates will use.
type provider struct{ options options }
128
// Name returns the configured provider name (Name unless overridden with
// WithName).
func (p *provider) Name() string { return p.options.name }
130
131// LanguageModel implements fantasy.Provider.
132func (p *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) {
133	if modelID == "" {
134		return nil, errors.New("missing model id")
135	}
136	return &languageModel{modelID: modelID, provider: p.options.name, opts: p.options}, nil
137}
138
// languageModel is a single model served through the proxy. It is created by
// provider.LanguageModel.
type languageModel struct {
	provider string  // provider name, reported by Provider
	modelID  string  // model identifier, reported by Model and used in the request path
	opts     options // copy of the provider's options used to build requests
}
144
// GenerateObject implements fantasy.LanguageModel.
// It delegates to object.GenerateWithTool, passing this model as the
// language model to use.
func (m *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
	return object.GenerateWithTool(ctx, m, call)
}
149
// StreamObject implements fantasy.LanguageModel.
// It delegates to object.StreamWithTool, passing this model as the language
// model to use.
func (m *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	return object.StreamWithTool(ctx, m, call)
}
154
// Provider and Model report the provider name and the model ID this language
// model was created with.
func (m *languageModel) Provider() string { return m.provider }
func (m *languageModel) Model() string    { return m.modelID }
157
158// Generate implements fantasy.LanguageModel by calling the proxy JSON endpoint.
159func (m *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
160	resp, err := m.doRequest(ctx, false, call)
161	if err != nil {
162		return nil, err
163	}
164	defer func() { _ = resp.Body.Close() }()
165	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
166		b, _ := ioReadAllLimit(resp.Body, 64*1024)
167		return nil, fmt.Errorf("proxy generate error: %s", strings.TrimSpace(string(b)))
168	}
169	var out fantasy.Response
170	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
171		return nil, err
172	}
173	return &out, nil
174}
175
176// Stream implements fantasy.LanguageModel using SSE from the proxy.
177func (m *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
178	// Prefer explicit /stream endpoint
179	resp, err := m.doRequest(ctx, true, call)
180	if err != nil {
181		return nil, err
182	}
183	switch resp.StatusCode {
184	case http.StatusTooManyRequests:
185		_ = resp.Body.Close()
186		return nil, toProviderError(resp, retryAfter(resp))
187	case http.StatusUnauthorized:
188		_ = resp.Body.Close()
189		return nil, ErrUnauthorized
190	case http.StatusPaymentRequired:
191		_ = resp.Body.Close()
192		return nil, ErrNoCredits
193	}
194
195	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
196		defer func() { _ = resp.Body.Close() }()
197		b, _ := ioReadAllLimit(resp.Body, 64*1024)
198		return nil, &fantasy.ProviderError{
199			Title:      "Stream Error",
200			Message:    strings.TrimSpace(string(b)),
201			StatusCode: resp.StatusCode,
202		}
203	}
204
205	return func(yield func(fantasy.StreamPart) bool) {
206		defer func() { _ = resp.Body.Close() }()
207		scanner := bufio.NewScanner(resp.Body)
208		buf := make([]byte, 0, 64*1024)
209		scanner.Buffer(buf, 4*1024*1024)
210
211		var (
212			event     string
213			dataBuf   bytes.Buffer
214			sawFinish bool
215			dispatch  = func() bool {
216				if dataBuf.Len() == 0 || event == "" {
217					dataBuf.Reset()
218					event = ""
219					return true
220				}
221				var part fantasy.StreamPart
222				if err := json.Unmarshal(dataBuf.Bytes(), &part); err != nil {
223					return yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
224				}
225				if part.Type == fantasy.StreamPartTypeFinish {
226					sawFinish = true
227				}
228				ok := yield(part)
229				dataBuf.Reset()
230				event = ""
231				return ok
232			}
233		)
234
235		for scanner.Scan() {
236			line := scanner.Text()
237			if line == "" { // event boundary
238				if !dispatch() {
239					return
240				}
241				continue
242			}
243			if strings.HasPrefix(line, ":") { // comment / ping
244				continue
245			}
246			if strings.HasPrefix(line, "event: ") {
247				event = strings.TrimSpace(line[len("event: "):])
248				continue
249			}
250			if strings.HasPrefix(line, "data: ") {
251				if dataBuf.Len() > 0 {
252					dataBuf.WriteByte('\n')
253				}
254				dataBuf.WriteString(line[len("data: "):])
255				continue
256			}
257		}
258		if err := scanner.Err(); err != nil {
259			if sawFinish && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) {
260				// If we already saw an explicit finish event, treat cancellation as a no-op.
261			} else {
262				_ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
263				return
264			}
265		}
266		if err := ctx.Err(); err != nil && !sawFinish {
267			_ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
268			return
269		}
270		// flush any pending data
271		_ = dispatch()
272		if !sawFinish {
273			_ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish})
274		}
275	}, nil
276}
277
278func (m *languageModel) doRequest(ctx context.Context, stream bool, call fantasy.Call) (*http.Response, error) {
279	addr, err := url.Parse(m.opts.baseURL)
280	if err != nil {
281		return nil, err
282	}
283	addr = addr.JoinPath(m.modelID)
284	if stream {
285		addr = addr.JoinPath("stream")
286	} else {
287		addr = addr.JoinPath("generate")
288	}
289
290	body, err := json.Marshal(call)
291	if err != nil {
292		return nil, err
293	}
294
295	req, err := http.NewRequestWithContext(ctx, http.MethodPost, addr.String(), bytes.NewReader(body))
296	if err != nil {
297		return nil, err
298	}
299	req.Header.Set("Content-Type", "application/json")
300	if stream {
301		req.Header.Set("Accept", "text/event-stream")
302	} else {
303		req.Header.Set("Accept", "application/json")
304	}
305	for k, v := range m.opts.headers {
306		req.Header.Set(k, v)
307	}
308
309	if m.opts.apiKey != "" {
310		req.Header.Set("Authorization", m.opts.apiKey)
311	}
312	return m.opts.client.Do(req)
313}
314
// ioReadAllLimit reads from r until EOF or until n bytes have been consumed,
// whichever comes first, and returns the bytes read together with any read
// error other than EOF. A non-positive n selects a default cap of 1 MiB.
func ioReadAllLimit(r io.Reader, n int64) ([]byte, error) {
	limit := n
	if limit <= 0 {
		limit = 1 << 20
	}
	buf := new(bytes.Buffer)
	_, err := buf.ReadFrom(io.LimitReader(r, limit))
	return buf.Bytes(), err
}
325
326func toProviderError(resp *http.Response, message string) error {
327	return &fantasy.ProviderError{
328		Title:      fantasy.ErrorTitleForStatusCode(resp.StatusCode),
329		Message:    message,
330		StatusCode: resp.StatusCode,
331	}
332}
333
// retryAfter turns the response's Retry-After header into a user-facing hint.
//
// Both forms of the header are understood: a number of seconds ("120") and
// an HTTP-date ("Wed, 21 Oct 2015 07:28:00 GMT"). Surrounding whitespace is
// ignored. It returns "Try again in <duration>" when the header names a
// positive delay or a time in the future, and "Try again later" when the
// header is missing, malformed, non-positive, or already in the past.
func retryAfter(resp *http.Response) string {
	const fallback = "Try again later"

	raw := strings.TrimSpace(resp.Header.Get("Retry-After"))
	if raw == "" {
		return fallback
	}

	// Delay in seconds.
	if secs, err := strconv.Atoi(raw); err == nil {
		if secs <= 0 {
			return fallback
		}
		return "Try again in " + (time.Duration(secs) * time.Second).String()
	}

	// Absolute HTTP-date; rounded to whole seconds to keep the hint readable.
	if at, err := http.ParseTime(raw); err == nil {
		if wait := time.Until(at).Round(time.Second); wait > 0 {
			return "Try again in " + wait.String()
		}
	}
	return fallback
}