@@ -15,6 +15,7 @@ import (
"errors"
"fmt"
"log/slog"
+ "net/http"
"os"
"regexp"
"strconv"
@@ -461,6 +462,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
if err != nil {
+ isHyper := largeModel.ModelCfg.Provider == hyper.Name
isCancelErr := errors.Is(err, context.Canceled)
isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
if currentAssistant == nil {
@@ -532,7 +534,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
} else if isPermissionErr {
currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
- } else if errors.Is(err, hyper.ErrUnauthorized) {
+ } else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized {
currentAssistant.AddFinish(message.FinishReasonError, "Unauthorized", `Please re-authenticate with Hyper. You can also run "crush auth" to re-authenticate.`)
if a.notify != nil {
a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
@@ -542,7 +544,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
ProviderID: largeModel.ModelCfg.Provider,
})
}
- } else if errors.Is(err, hyper.ErrNoCredits) {
+ } else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusPaymentRequired {
url := hyper.BaseURL()
link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
@@ -23,6 +23,7 @@ import (
"github.com/charmbracelet/crush/internal/agent/prompt"
"github.com/charmbracelet/crush/internal/agent/tools"
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/event"
"github.com/charmbracelet/crush/internal/filetracker"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/home"
@@ -797,17 +798,6 @@ func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, optio
return google.New(opts...)
}
-func (c *coordinator) buildHyperProvider(apiKey string) (fantasy.Provider, error) {
- opts := []hyper.Option{
- hyper.WithAPIKey(apiKey),
- }
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, hyper.WithHTTPClient(httpClient))
- }
- return hyper.New(opts...)
-}
-
func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
if model.Think {
return true
@@ -851,16 +841,18 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con
return c.buildGoogleProvider(baseURL, apiKey, headers)
case "google-vertex":
return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
- case openaicompat.Name:
- if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
+ case openaicompat.Name, hyper.Name:
+ switch providerCfg.ID {
+ case hyper.Name:
+ baseURL = hyper.BaseURL() + "/v1"
+ headers["x-crush-id"] = event.GetID()
+ case string(catwalk.InferenceProviderZAI):
if providerCfg.ExtraBody == nil {
providerCfg.ExtraBody = map[string]any{}
}
providerCfg.ExtraBody["tool_stream"] = true
}
return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
- case hyper.Name:
- return c.buildHyperProvider(apiKey)
default:
return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
}
@@ -940,8 +932,7 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
func (c *coordinator) isUnauthorized(err error) bool {
var providerErr *fantasy.ProviderError
- return (errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized) ||
- errors.Is(err, hyper.ErrUnauthorized)
+ return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
}
func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
@@ -2,29 +2,15 @@
package hyper
import (
- "bufio"
- "bytes"
"cmp"
- "context"
_ "embed"
"encoding/json"
- "errors"
- "fmt"
- "io"
"log/slog"
- "maps"
- "net/http"
- "net/url"
"os"
"strconv"
- "strings"
"sync"
- "time"
"charm.land/catwalk/pkg/catwalk"
- "charm.land/fantasy"
- "charm.land/fantasy/object"
- "github.com/charmbracelet/crush/internal/event"
)
//go:generate wget -O provider.json https://hyper.charm.land/api/v1/provider
@@ -68,274 +54,3 @@ const (
var BaseURL = sync.OnceValue(func() string {
return cmp.Or(os.Getenv("HYPER_URL"), defaultBaseURL)
})
-
-var (
- ErrNoCredits = errors.New("you're out of credits")
- ErrUnauthorized = errors.New("unauthorized")
-)
-
-type options struct {
- baseURL string
- apiKey string
- name string
- headers map[string]string
- client *http.Client
-}
-
-// Option configures the proxy provider.
-type Option = func(*options)
-
-// New creates a new proxy provider.
-func New(opts ...Option) (fantasy.Provider, error) {
- o := options{
- baseURL: BaseURL() + "/api/v1/fantasy",
- name: Name,
- headers: map[string]string{
- "x-crush-id": event.GetID(),
- },
- client: &http.Client{Timeout: 0}, // stream-safe
- }
- for _, opt := range opts {
- opt(&o)
- }
- return &provider{options: o}, nil
-}
-
-// WithBaseURL sets the proxy base URL (e.g. http://localhost:8080).
-func WithBaseURL(url string) Option { return func(o *options) { o.baseURL = url } }
-
-// WithName sets the provider name.
-func WithName(name string) Option { return func(o *options) { o.name = name } }
-
-// WithHeaders sets extra headers sent to the proxy.
-func WithHeaders(headers map[string]string) Option {
- return func(o *options) {
- maps.Copy(o.headers, headers)
- }
-}
-
-// WithHTTPClient sets custom HTTP client.
-func WithHTTPClient(c *http.Client) Option { return func(o *options) { o.client = c } }
-
-// WithAPIKey sets the API key.
-func WithAPIKey(key string) Option {
- return func(o *options) {
- o.apiKey = key
- }
-}
-
-type provider struct{ options options }
-
-func (p *provider) Name() string { return p.options.name }
-
-// LanguageModel implements fantasy.Provider.
-func (p *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) {
- if modelID == "" {
- return nil, errors.New("missing model id")
- }
- return &languageModel{modelID: modelID, provider: p.options.name, opts: p.options}, nil
-}
-
-type languageModel struct {
- provider string
- modelID string
- opts options
-}
-
-// GenerateObject implements fantasy.LanguageModel.
-func (m *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
- return object.GenerateWithTool(ctx, m, call)
-}
-
-// StreamObject implements fantasy.LanguageModel.
-func (m *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
- return object.StreamWithTool(ctx, m, call)
-}
-
-func (m *languageModel) Provider() string { return m.provider }
-func (m *languageModel) Model() string { return m.modelID }
-
-// Generate implements fantasy.LanguageModel by calling the proxy JSON endpoint.
-func (m *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
- resp, err := m.doRequest(ctx, false, call)
- if err != nil {
- return nil, err
- }
- defer func() { _ = resp.Body.Close() }()
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
- b, _ := ioReadAllLimit(resp.Body, 64*1024)
- return nil, fmt.Errorf("proxy generate error: %s", strings.TrimSpace(string(b)))
- }
- var out fantasy.Response
- if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
- return nil, err
- }
- return &out, nil
-}
-
-// Stream implements fantasy.LanguageModel using SSE from the proxy.
-func (m *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
- // Prefer explicit /stream endpoint
- resp, err := m.doRequest(ctx, true, call)
- if err != nil {
- return nil, err
- }
- switch resp.StatusCode {
- case http.StatusTooManyRequests:
- _ = resp.Body.Close()
- return nil, toProviderError(resp, retryAfter(resp))
- case http.StatusUnauthorized:
- _ = resp.Body.Close()
- return nil, ErrUnauthorized
- case http.StatusPaymentRequired:
- _ = resp.Body.Close()
- return nil, ErrNoCredits
- }
-
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
- defer func() { _ = resp.Body.Close() }()
- b, _ := ioReadAllLimit(resp.Body, 64*1024)
- return nil, &fantasy.ProviderError{
- Title: "Stream Error",
- Message: strings.TrimSpace(string(b)),
- StatusCode: resp.StatusCode,
- }
- }
-
- return func(yield func(fantasy.StreamPart) bool) {
- defer func() { _ = resp.Body.Close() }()
- scanner := bufio.NewScanner(resp.Body)
- buf := make([]byte, 0, 64*1024)
- scanner.Buffer(buf, 4*1024*1024)
-
- var (
- event string
- dataBuf bytes.Buffer
- sawFinish bool
- dispatch = func() bool {
- if dataBuf.Len() == 0 || event == "" {
- dataBuf.Reset()
- event = ""
- return true
- }
- var part fantasy.StreamPart
- if err := json.Unmarshal(dataBuf.Bytes(), &part); err != nil {
- return yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
- }
- if part.Type == fantasy.StreamPartTypeFinish {
- sawFinish = true
- }
- ok := yield(part)
- dataBuf.Reset()
- event = ""
- return ok
- }
- )
-
- for scanner.Scan() {
- line := scanner.Text()
- if line == "" { // event boundary
- if !dispatch() {
- return
- }
- continue
- }
- if strings.HasPrefix(line, ":") { // comment / ping
- continue
- }
- if strings.HasPrefix(line, "event: ") {
- event = strings.TrimSpace(line[len("event: "):])
- continue
- }
- if strings.HasPrefix(line, "data: ") {
- if dataBuf.Len() > 0 {
- dataBuf.WriteByte('\n')
- }
- dataBuf.WriteString(line[len("data: "):])
- continue
- }
- }
- if err := scanner.Err(); err != nil {
- if sawFinish && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) {
- // If we already saw an explicit finish event, treat cancellation as a no-op.
- } else {
- _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
- return
- }
- }
- if err := ctx.Err(); err != nil && !sawFinish {
- _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
- return
- }
- // flush any pending data
- _ = dispatch()
- if !sawFinish {
- _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish})
- }
- }, nil
-}
-
-func (m *languageModel) doRequest(ctx context.Context, stream bool, call fantasy.Call) (*http.Response, error) {
- addr, err := url.Parse(m.opts.baseURL)
- if err != nil {
- return nil, err
- }
- addr = addr.JoinPath(m.modelID)
- if stream {
- addr = addr.JoinPath("stream")
- } else {
- addr = addr.JoinPath("generate")
- }
-
- body, err := json.Marshal(call)
- if err != nil {
- return nil, err
- }
-
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, addr.String(), bytes.NewReader(body))
- if err != nil {
- return nil, err
- }
- req.Header.Set("Content-Type", "application/json")
- if stream {
- req.Header.Set("Accept", "text/event-stream")
- } else {
- req.Header.Set("Accept", "application/json")
- }
- for k, v := range m.opts.headers {
- req.Header.Set(k, v)
- }
-
- if m.opts.apiKey != "" {
- req.Header.Set("Authorization", m.opts.apiKey)
- }
- return m.opts.client.Do(req)
-}
-
-// ioReadAllLimit reads up to n bytes.
-func ioReadAllLimit(r io.Reader, n int64) ([]byte, error) {
- var b bytes.Buffer
- if n <= 0 {
- n = 1 << 20
- }
- lr := &io.LimitedReader{R: r, N: n}
- _, err := b.ReadFrom(lr)
- return b.Bytes(), err
-}
-
-func toProviderError(resp *http.Response, message string) error {
- return &fantasy.ProviderError{
- Title: fantasy.ErrorTitleForStatusCode(resp.StatusCode),
- Message: message,
- StatusCode: resp.StatusCode,
- }
-}
-
-func retryAfter(resp *http.Response) string {
- after, err := strconv.Atoi(resp.Header.Get("Retry-After"))
- if err == nil && after > 0 {
- d := time.Duration(after) * time.Second
- return "Try again in " + d.String()
- }
- return "Try again later"
-}