provider.go

  1// Package hyper provides a fantasy.Provider that proxies requests to Hyper.
  2package hyper
  3
  4import (
  5	"bufio"
  6	"bytes"
  7	"cmp"
  8	"context"
  9	_ "embed"
 10	"encoding/json"
 11	"errors"
 12	"fmt"
 13	"io"
 14	"log/slog"
 15	"maps"
 16	"net/http"
 17	"net/url"
 18	"os"
 19	"strconv"
 20	"strings"
 21	"sync"
 22	"time"
 23
 24	"charm.land/fantasy"
 25	"charm.land/fantasy/object"
 26	"github.com/charmbracelet/catwalk/pkg/catwalk"
 27	"github.com/charmbracelet/crush/internal/event"
 28)
 29
 30//go:generate wget -O provider.json https://console.charm.land/api/v1/provider
 31
// embedded holds the raw bytes of provider.json, compiled into the binary.
// Embedded decodes it on first use.
//
//go:embed provider.json
var embedded []byte
 34
 35// Enabled returns true if hyper is enabled.
 36var Enabled = sync.OnceValue(func() bool {
 37	b, _ := strconv.ParseBool(
 38		cmp.Or(
 39			os.Getenv("HYPER"),
 40			os.Getenv("HYPERCRUSH"),
 41			os.Getenv("HYPER_ENABLE"),
 42			os.Getenv("HYPER_ENABLED"),
 43		),
 44	)
 45	return b
 46})
 47
 48// Embedded returns the embedded Hyper provider.
 49var Embedded = sync.OnceValue(func() catwalk.Provider {
 50	var provider catwalk.Provider
 51	if err := json.Unmarshal(embedded, &provider); err != nil {
 52		slog.Error("Could not use embedded provider data", "err", err)
 53	}
 54	if e := os.Getenv("HYPER_URL"); e != "" {
 55		provider.APIEndpoint = e + "/api/v1/fantasy"
 56	}
 57	return provider
 58})
 59
const (
	// Name is the default name of this meta provider.
	Name = "hyper"
	// defaultBaseURL is the default proxy URL. BaseURL falls back to it when
	// $HYPER_URL is unset or empty.
	// TODO: change this to production URL when ready.
	defaultBaseURL = "https://console.charm.land"
)
 67
 68// BaseURL returns the base URL, which is either $HYPER_URL or the default.
 69var BaseURL = sync.OnceValue(func() string {
 70	return cmp.Or(os.Getenv("HYPER_URL"), defaultBaseURL)
 71})
 72
// ErrNoCredits is returned when the proxy responds with HTTP 402 (Payment
// Required). Compare against it with errors.Is.
var ErrNoCredits = errors.New("you're out of credits")
 74
// options is the resolved configuration of a provider. New fills in the
// defaults and then lets each Option modify it.
type options struct {
	// baseURL is the URL prefix that "<model>/generate" or "<model>/stream"
	// is joined onto when a request is built.
	baseURL string
	// apiKey, when non-empty, is sent verbatim as the Authorization header.
	apiKey  string
	// name is the provider name reported by the provider and its models.
	name    string
	// headers are set on every outgoing request.
	headers map[string]string
	// client performs the HTTP requests.
	client  *http.Client
}
 82
// Option configures the proxy provider. New applies the options it is given
// in order, after the defaults have been set, so later options win.
type Option = func(*options)
 85
 86// New creates a new proxy provider.
 87func New(opts ...Option) (fantasy.Provider, error) {
 88	o := options{
 89		baseURL: BaseURL() + "/api/v1/fantasy",
 90		name:    Name,
 91		headers: map[string]string{
 92			"x-crush-id": event.GetID(),
 93		},
 94		client: &http.Client{Timeout: 0}, // stream-safe
 95	}
 96	for _, opt := range opts {
 97		opt(&o)
 98	}
 99	return &provider{options: o}, nil
100}
101
102// WithBaseURL sets the proxy base URL (e.g. http://localhost:8080).
103func WithBaseURL(url string) Option { return func(o *options) { o.baseURL = url } }
104
105// WithName sets the provider name.
106func WithName(name string) Option { return func(o *options) { o.name = name } }
107
108// WithHeaders sets extra headers sent to the proxy.
109func WithHeaders(headers map[string]string) Option {
110	return func(o *options) {
111		maps.Copy(o.headers, headers)
112	}
113}
114
115// WithHTTPClient sets custom HTTP client.
116func WithHTTPClient(c *http.Client) Option { return func(o *options) { o.client = c } }
117
118// WithAPIKey sets the API key.
119func WithAPIKey(key string) Option {
120	return func(o *options) {
121		o.apiKey = key
122	}
123}
124
// provider is the value returned by New. It only carries the resolved
// options, which are handed on to every language model it creates.
type provider struct{ options options }
126
127func (p *provider) Name() string { return p.options.name }
128
129// LanguageModel implements fantasy.Provider.
130func (p *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) {
131	if modelID == "" {
132		return nil, errors.New("missing model id")
133	}
134	return &languageModel{modelID: modelID, provider: p.options.name, opts: p.options}, nil
135}
136
// languageModel is a single model reached through the proxy.
type languageModel struct {
	// provider is the name of the provider that created the model.
	provider string
	// modelID identifies the model; it becomes a path segment of the
	// request URL.
	modelID  string
	// opts is a copy of the provider's options taken when the model was
	// created.
	opts     options
}
142
// GenerateObject implements fantasy.LanguageModel.
//
// It delegates to object.GenerateWithTool, passing m as the model to run the
// call against.
func (m *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
	return object.GenerateWithTool(ctx, m, call)
}
147
// StreamObject implements fantasy.LanguageModel.
//
// It delegates to object.StreamWithTool, passing m as the model to run the
// call against.
func (m *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	return object.StreamWithTool(ctx, m, call)
}
152
153func (m *languageModel) Provider() string { return m.provider }
154func (m *languageModel) Model() string    { return m.modelID }
155
156// Generate implements fantasy.LanguageModel by calling the proxy JSON endpoint.
157func (m *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
158	resp, err := m.doRequest(ctx, false, call)
159	if err != nil {
160		return nil, err
161	}
162	defer func() { _ = resp.Body.Close() }()
163	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
164		b, _ := ioReadAllLimit(resp.Body, 64*1024)
165		return nil, fmt.Errorf("proxy generate error: %s", strings.TrimSpace(string(b)))
166	}
167	var out fantasy.Response
168	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
169		return nil, err
170	}
171	return &out, nil
172}
173
174// Stream implements fantasy.LanguageModel using SSE from the proxy.
175func (m *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
176	// Prefer explicit /stream endpoint
177	resp, err := m.doRequest(ctx, true, call)
178	if err != nil {
179		return nil, err
180	}
181	switch resp.StatusCode {
182	case http.StatusTooManyRequests:
183		_ = resp.Body.Close()
184		return nil, toProviderError(resp, retryAfter(resp))
185	case http.StatusUnauthorized:
186		_ = resp.Body.Close()
187		return nil, toProviderError(resp, "")
188	case http.StatusPaymentRequired:
189		_ = resp.Body.Close()
190		return nil, ErrNoCredits
191	}
192
193	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
194		defer func() { _ = resp.Body.Close() }()
195		b, _ := ioReadAllLimit(resp.Body, 64*1024)
196		return nil, &fantasy.ProviderError{
197			Title:      "Stream Error",
198			Message:    strings.TrimSpace(string(b)),
199			StatusCode: resp.StatusCode,
200		}
201	}
202
203	return func(yield func(fantasy.StreamPart) bool) {
204		defer func() { _ = resp.Body.Close() }()
205		scanner := bufio.NewScanner(resp.Body)
206		buf := make([]byte, 0, 64*1024)
207		scanner.Buffer(buf, 4*1024*1024)
208
209		var (
210			event     string
211			dataBuf   bytes.Buffer
212			sawFinish bool
213			dispatch  = func() bool {
214				if dataBuf.Len() == 0 || event == "" {
215					dataBuf.Reset()
216					event = ""
217					return true
218				}
219				var part fantasy.StreamPart
220				if err := json.Unmarshal(dataBuf.Bytes(), &part); err != nil {
221					return yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
222				}
223				if part.Type == fantasy.StreamPartTypeFinish {
224					sawFinish = true
225				}
226				ok := yield(part)
227				dataBuf.Reset()
228				event = ""
229				return ok
230			}
231		)
232
233		for scanner.Scan() {
234			line := scanner.Text()
235			if line == "" { // event boundary
236				if !dispatch() {
237					return
238				}
239				continue
240			}
241			if strings.HasPrefix(line, ":") { // comment / ping
242				continue
243			}
244			if strings.HasPrefix(line, "event: ") {
245				event = strings.TrimSpace(line[len("event: "):])
246				continue
247			}
248			if strings.HasPrefix(line, "data: ") {
249				if dataBuf.Len() > 0 {
250					dataBuf.WriteByte('\n')
251				}
252				dataBuf.WriteString(line[len("data: "):])
253				continue
254			}
255		}
256		if err := scanner.Err(); err != nil &&
257			!errors.Is(err, context.Canceled) &&
258			!errors.Is(err, context.DeadlineExceeded) {
259			yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
260			return
261		}
262		// flush any pending data
263		_ = dispatch()
264		if !sawFinish {
265			_ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish})
266		}
267	}, nil
268}
269
270func (m *languageModel) doRequest(ctx context.Context, stream bool, call fantasy.Call) (*http.Response, error) {
271	addr, err := url.Parse(m.opts.baseURL)
272	if err != nil {
273		return nil, err
274	}
275	addr = addr.JoinPath(m.modelID)
276	if stream {
277		addr = addr.JoinPath("stream")
278	} else {
279		addr = addr.JoinPath("generate")
280	}
281
282	body, err := json.Marshal(call)
283	if err != nil {
284		return nil, err
285	}
286
287	req, err := http.NewRequestWithContext(ctx, http.MethodPost, addr.String(), bytes.NewReader(body))
288	if err != nil {
289		return nil, err
290	}
291	req.Header.Set("Content-Type", "application/json")
292	if stream {
293		req.Header.Set("Accept", "text/event-stream")
294	} else {
295		req.Header.Set("Accept", "application/json")
296	}
297	for k, v := range m.opts.headers {
298		req.Header.Set(k, v)
299	}
300
301	if m.opts.apiKey != "" {
302		req.Header.Set("Authorization", m.opts.apiKey)
303	}
304	return m.opts.client.Do(req)
305}
306
// ioReadAllLimit reads from r until EOF or until n bytes have been read,
// whichever comes first, and returns what was read. A non-positive n selects
// a default limit of 1 MiB. Bytes read before an error are returned together
// with that error.
func ioReadAllLimit(r io.Reader, n int64) ([]byte, error) {
	const defaultLimit = 1 << 20

	limit := n
	if limit <= 0 {
		limit = defaultLimit
	}

	out := new(bytes.Buffer)
	_, err := io.Copy(out, io.LimitReader(r, limit))
	return out.Bytes(), err
}
317
318func toProviderError(resp *http.Response, message string) error {
319	return &fantasy.ProviderError{
320		Title:      fantasy.ErrorTitleForStatusCode(resp.StatusCode),
321		Message:    message,
322		StatusCode: resp.StatusCode,
323	}
324}
325
// retryAfter turns the response's Retry-After header into a short,
// human-readable hint.
//
// Both forms the header may take are understood: a whole number of seconds
// and an HTTP-date. When the header is missing, malformed, or does not point
// into the future, the generic "Try again later" is returned.
func retryAfter(resp *http.Response) string {
	const fallback = "Try again later"

	value := resp.Header.Get("Retry-After")
	if value == "" {
		return fallback
	}

	// Delay in seconds, e.g. "Retry-After: 120".
	if secs, err := strconv.Atoi(value); err == nil {
		if secs > 0 {
			d := time.Duration(secs) * time.Second
			return "Try again in " + d.String()
		}
		return fallback
	}

	// Absolute date, e.g. "Retry-After: Wed, 21 Oct 2015 07:28:00 GMT".
	if at, err := http.ParseTime(value); err == nil {
		// Round to whole seconds so the hint stays readable.
		if wait := time.Until(at).Round(time.Second); wait > 0 {
			return "Try again in " + wait.String()
		}
	}
	return fallback
}