From ea86101bf6c61ee05ca3e887ea65d2f5ff9e68d7 Mon Sep 17 00:00:00 2001 From: Andrey Nering Date: Fri, 17 Apr 2026 13:42:01 -0300 Subject: [PATCH] feat(hyper): use openai-compatible endpoint for hyper (#2640) --- internal/agent/agent.go | 6 +- internal/agent/coordinator.go | 25 +-- internal/agent/hyper/provider.go | 285 ------------------------------- 3 files changed, 12 insertions(+), 304 deletions(-) diff --git a/internal/agent/agent.go b/internal/agent/agent.go index b3249e501f8e4a31e0199bc87014d8f6aa69979f..d62ef16bc1b4c7380e40ba05d7718d5004e360bd 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -15,6 +15,7 @@ import ( "errors" "fmt" "log/slog" + "net/http" "os" "regexp" "strconv" @@ -461,6 +462,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second)) if err != nil { + isHyper := largeModel.ModelCfg.Provider == hyper.Name isCancelErr := errors.Is(err, context.Canceled) isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied) if currentAssistant == nil { @@ -532,7 +534,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "") } else if isPermissionErr { currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "") - } else if errors.Is(err, hyper.ErrUnauthorized) { + } else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized { currentAssistant.AddFinish(message.FinishReasonError, "Unauthorized", `Please re-authenticate with Hyper. You can also run "crush auth" to re-authenticate.`) if a.notify != nil { a.notify.Publish(pubsub.CreatedEvent, notify.Notification{ @@ -542,7 +544,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy ProviderID: largeModel.ModelCfg.Provider, }) } - } else if errors.Is(err, hyper.ErrNoCredits) { + } else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusPaymentRequired { url := hyper.BaseURL() link := linkStyle.Hyperlink(url, "id=hyper").Render(url) currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link) diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 1130438a61217b3491bd21e388376f303631e9ec..72723a2ff636ba8149138fa91f1310b669dd1023 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -23,6 +23,7 @@ import ( "github.com/charmbracelet/crush/internal/agent/prompt" "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/event" "github.com/charmbracelet/crush/internal/filetracker" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/home" @@ -797,17 +798,6 @@ func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, optio return google.New(opts...) } -func (c *coordinator) buildHyperProvider(apiKey string) (fantasy.Provider, error) { - opts := []hyper.Option{ - hyper.WithAPIKey(apiKey), - } - if c.cfg.Config().Options.Debug { - httpClient := log.NewHTTPClient() - opts = append(opts, hyper.WithHTTPClient(httpClient)) - } - return hyper.New(opts...) -} - func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool { if model.Think { return true @@ -851,16 +841,18 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con return c.buildGoogleProvider(baseURL, apiKey, headers) case "google-vertex": return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams) - case openaicompat.Name: - if providerCfg.ID == string(catwalk.InferenceProviderZAI) { + case openaicompat.Name, hyper.Name: + switch providerCfg.ID { + case hyper.Name: + baseURL = hyper.BaseURL() + "/v1" + headers["x-crush-id"] = event.GetID() + case string(catwalk.InferenceProviderZAI): if providerCfg.ExtraBody == nil { providerCfg.ExtraBody = map[string]any{} } providerCfg.ExtraBody["tool_stream"] = true } return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent) - case hyper.Name: - return c.buildHyperProvider(apiKey) default: return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type) } @@ -940,8 +932,7 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error { func (c *coordinator) isUnauthorized(err error) bool { var providerErr *fantasy.ProviderError - return (errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized) || - errors.Is(err, hyper.ErrUnauthorized) + return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized } func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error { diff --git a/internal/agent/hyper/provider.go b/internal/agent/hyper/provider.go index 5c8f7c3e4eced6af66113241f8137e744fda2463..d6be84a3dcaa6c08b30ce23f945458c03637a4b4 100644 --- a/internal/agent/hyper/provider.go +++ b/internal/agent/hyper/provider.go @@ -2,29 +2,15 @@ package hyper import ( - "bufio" - "bytes" "cmp" - "context" _ "embed" "encoding/json" - "errors" - "fmt" - "io" "log/slog" - "maps" - "net/http" - "net/url" "os" "strconv" - "strings" "sync" - "time" "charm.land/catwalk/pkg/catwalk" - "charm.land/fantasy" - "charm.land/fantasy/object" - "github.com/charmbracelet/crush/internal/event" ) //go:generate wget -O provider.json https://hyper.charm.land/api/v1/provider @@ -68,274 +54,3 @@ const ( var BaseURL = sync.OnceValue(func() string { return cmp.Or(os.Getenv("HYPER_URL"), defaultBaseURL) }) - -var ( - ErrNoCredits = errors.New("you're out of credits") - ErrUnauthorized = errors.New("unauthorized") -) - -type options struct { - baseURL string - apiKey string - name string - headers map[string]string - client *http.Client -} - -// Option configures the proxy provider. -type Option = func(*options) - -// New creates a new proxy provider. -func New(opts ...Option) (fantasy.Provider, error) { - o := options{ - baseURL: BaseURL() + "/api/v1/fantasy", - name: Name, - headers: map[string]string{ - "x-crush-id": event.GetID(), - }, - client: &http.Client{Timeout: 0}, // stream-safe - } - for _, opt := range opts { - opt(&o) - } - return &provider{options: o}, nil -} - -// WithBaseURL sets the proxy base URL (e.g. http://localhost:8080). -func WithBaseURL(url string) Option { return func(o *options) { o.baseURL = url } } - -// WithName sets the provider name. -func WithName(name string) Option { return func(o *options) { o.name = name } } - -// WithHeaders sets extra headers sent to the proxy. -func WithHeaders(headers map[string]string) Option { - return func(o *options) { - maps.Copy(o.headers, headers) - } -} - -// WithHTTPClient sets custom HTTP client. -func WithHTTPClient(c *http.Client) Option { return func(o *options) { o.client = c } } - -// WithAPIKey sets the API key. -func WithAPIKey(key string) Option { - return func(o *options) { - o.apiKey = key - } -} - -type provider struct{ options options } - -func (p *provider) Name() string { return p.options.name } - -// LanguageModel implements fantasy.Provider. -func (p *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) { - if modelID == "" { - return nil, errors.New("missing model id") - } - return &languageModel{modelID: modelID, provider: p.options.name, opts: p.options}, nil -} - -type languageModel struct { - provider string - modelID string - opts options -} - -// GenerateObject implements fantasy.LanguageModel. -func (m *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { - return object.GenerateWithTool(ctx, m, call) -} - -// StreamObject implements fantasy.LanguageModel. -func (m *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) { - return object.StreamWithTool(ctx, m, call) -} - -func (m *languageModel) Provider() string { return m.provider } -func (m *languageModel) Model() string { return m.modelID } - -// Generate implements fantasy.LanguageModel by calling the proxy JSON endpoint. -func (m *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { - resp, err := m.doRequest(ctx, false, call) - if err != nil { - return nil, err - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - b, _ := ioReadAllLimit(resp.Body, 64*1024) - return nil, fmt.Errorf("proxy generate error: %s", strings.TrimSpace(string(b))) - } - var out fantasy.Response - if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { - return nil, err - } - return &out, nil -} - -// Stream implements fantasy.LanguageModel using SSE from the proxy. -func (m *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - // Prefer explicit /stream endpoint - resp, err := m.doRequest(ctx, true, call) - if err != nil { - return nil, err - } - switch resp.StatusCode { - case http.StatusTooManyRequests: - _ = resp.Body.Close() - return nil, toProviderError(resp, retryAfter(resp)) - case http.StatusUnauthorized: - _ = resp.Body.Close() - return nil, ErrUnauthorized - case http.StatusPaymentRequired: - _ = resp.Body.Close() - return nil, ErrNoCredits - } - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - defer func() { _ = resp.Body.Close() }() - b, _ := ioReadAllLimit(resp.Body, 64*1024) - return nil, &fantasy.ProviderError{ - Title: "Stream Error", - Message: strings.TrimSpace(string(b)), - StatusCode: resp.StatusCode, - } - } - - return func(yield func(fantasy.StreamPart) bool) { - defer func() { _ = resp.Body.Close() }() - scanner := bufio.NewScanner(resp.Body) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 4*1024*1024) - - var ( - event string - dataBuf bytes.Buffer - sawFinish bool - dispatch = func() bool { - if dataBuf.Len() == 0 || event == "" { - dataBuf.Reset() - event = "" - return true - } - var part fantasy.StreamPart - if err := json.Unmarshal(dataBuf.Bytes(), &part); err != nil { - return yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err}) - } - if part.Type == fantasy.StreamPartTypeFinish { - sawFinish = true - } - ok := yield(part) - dataBuf.Reset() - event = "" - return ok - } - ) - - for scanner.Scan() { - line := scanner.Text() - if line == "" { // event boundary - if !dispatch() { - return - } - continue - } - if strings.HasPrefix(line, ":") { // comment / ping - continue - } - if strings.HasPrefix(line, "event: ") { - event = strings.TrimSpace(line[len("event: "):]) - continue - } - if strings.HasPrefix(line, "data: ") { - if dataBuf.Len() > 0 { - dataBuf.WriteByte('\n') - } - dataBuf.WriteString(line[len("data: "):]) - continue - } - } - if err := scanner.Err(); err != nil { - if sawFinish && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) { - // If we already saw an explicit finish event, treat cancellation as a no-op. - } else { - _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err}) - return - } - } - if err := ctx.Err(); err != nil && !sawFinish { - _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err}) - return - } - // flush any pending data - _ = dispatch() - if !sawFinish { - _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish}) - } - }, nil -} - -func (m *languageModel) doRequest(ctx context.Context, stream bool, call fantasy.Call) (*http.Response, error) { - addr, err := url.Parse(m.opts.baseURL) - if err != nil { - return nil, err - } - addr = addr.JoinPath(m.modelID) - if stream { - addr = addr.JoinPath("stream") - } else { - addr = addr.JoinPath("generate") - } - - body, err := json.Marshal(call) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, addr.String(), bytes.NewReader(body)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - if stream { - req.Header.Set("Accept", "text/event-stream") - } else { - req.Header.Set("Accept", "application/json") - } - for k, v := range m.opts.headers { - req.Header.Set(k, v) - } - - if m.opts.apiKey != "" { - req.Header.Set("Authorization", m.opts.apiKey) - } - return m.opts.client.Do(req) -} - -// ioReadAllLimit reads up to n bytes. -func ioReadAllLimit(r io.Reader, n int64) ([]byte, error) { - var b bytes.Buffer - if n <= 0 { - n = 1 << 20 - } - lr := &io.LimitedReader{R: r, N: n} - _, err := b.ReadFrom(lr) - return b.Bytes(), err -} - -func toProviderError(resp *http.Response, message string) error { - return &fantasy.ProviderError{ - Title: fantasy.ErrorTitleForStatusCode(resp.StatusCode), - Message: message, - StatusCode: resp.StatusCode, - } -} - -func retryAfter(resp *http.Response) string { - after, err := strconv.Atoi(resp.Header.Get("Retry-After")) - if err == nil && after > 0 { - d := time.Duration(after) * time.Second - return "Try again in " + d.String() - } - return "Try again later" -}