diff --git a/.gitattributes b/.gitattributes index d0a96600fea2558f4f8be8840c048780e48c9a7e..2e533aabaaa6a7168816003b0da820804778fd4e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,4 @@ *.golden linguist-generated=true -text .github/crush-schema.json linguist-generated=true +internal/agent/hyper/provider.json linguist-generated=true internal/agent/testdata/**/*.yaml -diff linguist-generated=true diff --git a/.github/workflows/schema-update.yml b/.github/workflows/schema-update.yml index d2366652b8c21c4606605db9fc17d92031d3ad5e..949dac5c260497344969c182302baa0d968113bf 100644 --- a/.github/workflows/schema-update.yml +++ b/.github/workflows/schema-update.yml @@ -1,10 +1,11 @@ -name: Update Schema +name: Update files on: push: branches: [main] paths: - "internal/config/**" + - "internal/agent/hyper/**" jobs: update-schema: @@ -17,9 +18,10 @@ jobs: with: go-version-file: go.mod - run: go run . schema > ./schema.json + - run: go generate ./internal/agent/hyper/... - uses: stefanzweifel/git-auto-commit-action@28e16e81777b558cc906c8750092100bbb34c5e3 # v5 with: - commit_message: "chore: auto-update generated files" + commit_message: "chore: auto-update files" branch: main commit_user_name: Charm commit_user_email: 124303983+charmcli@users.noreply.github.com diff --git a/Taskfile.yaml b/Taskfile.yaml index f9c20cd63e64fea25ce2bbd7f8a6982170dbe02e..2f5574f7ab1f07a03f47e8534d477afd293d9248 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -101,6 +101,13 @@ tasks: generates: - schema.json + hyper: + desc: Update Hyper embedded provider.json + cmds: + - go generate ./internal/agent/hyper/... + generates: + - ./internal/agent/hyper/provider.json + release: desc: Create and push a new tag following semver vars: diff --git a/go.mod b/go.mod index 9fba0cbb3a04775bc020946a2349d469b517391b..a8ef92e75e00e7108055c457685752f4495d6807 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/charmbracelet/fang v0.4.4 github.com/charmbracelet/ultraviolet v0.0.0-20251212194010-b927aa605560 github.com/charmbracelet/x/ansi v0.11.3 + github.com/charmbracelet/x/etag v0.2.0 github.com/charmbracelet/x/exp/charmtone v0.0.0-20250708181618-a60a724ba6c3 github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f github.com/charmbracelet/x/exp/ordered v0.1.0 @@ -91,7 +92,6 @@ require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 // indirect - github.com/charmbracelet/x/etag v0.2.0 // indirect github.com/charmbracelet/x/json v0.2.0 // indirect github.com/charmbracelet/x/termios v0.1.1 // indirect github.com/charmbracelet/x/windows v0.2.2 // indirect diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 5f48ee1c7b1434af7453fa19d567d8a194c377d1..0707ec98abed9fdf1599a13b081b5958551c83ed 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -27,7 +27,9 @@ import ( "charm.land/fantasy/providers/google" "charm.land/fantasy/providers/openai" "charm.land/fantasy/providers/openrouter" + "charm.land/lipgloss/v2" "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" @@ -454,6 +456,10 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "") } else if isPermissionErr { currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "") + } else if errors.Is(err, hyper.ErrNoCredits) { + url := hyper.BaseURL() + link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url) + currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link) } else if errors.As(err, &providerErr) { currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message) } else if errors.As(err, &fantasyErr) { diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 5690b1311a4fd0d0290277d01e734b17ab3d2a66..ea214a5bcfb8e65e6d5dee826854345168a2eb86 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -17,6 +17,7 @@ import ( "charm.land/fantasy" "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/agent/prompt" "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/config" @@ -668,6 +669,18 @@ func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, optio return google.New(opts...) } +func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) { + opts := []hyper.Option{ + hyper.WithBaseURL(baseURL), + hyper.WithAPIKey(apiKey), + } + if c.cfg.Options.Debug { + httpClient := log.NewHTTPClient() + opts = append(opts, hyper.WithHTTPClient(httpClient)) + } + return hyper.New(opts...) +} + func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool { if model.Think { return true @@ -728,6 +741,8 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con providerCfg.ExtraBody["tool_stream"] = true } return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody) + case hyper.Name: + return c.buildHyperProvider(baseURL, apiKey) default: return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type) } diff --git a/internal/agent/hyper/provider.go b/internal/agent/hyper/provider.go new file mode 100644 index 0000000000000000000000000000000000000000..ea2f4a18eaeec017f5f3f02576504a424f50bbf1 --- /dev/null +++ b/internal/agent/hyper/provider.go @@ -0,0 +1,330 @@ +// Package hyper provides a fantasy.Provider that proxies requests to Hyper. +package hyper + +import ( + "bufio" + "bytes" + "cmp" + "context" + _ "embed" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "maps" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "sync" + "time" + + "charm.land/fantasy" + "charm.land/fantasy/object" + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/event" +) + +//go:generate wget -O provider.json https://console.charm.land/api/v1/provider + +//go:embed provider.json +var embedded []byte + +// Enabled returns true if hyper is enabled. +var Enabled = sync.OnceValue(func() bool { + b, _ := strconv.ParseBool( + cmp.Or( + os.Getenv("HYPER"), + os.Getenv("HYPERCRUSH"), + os.Getenv("HYPER_ENABLE"), + os.Getenv("HYPER_ENABLED"), + ), + ) + return b +}) + +// Embedded returns the embedded Hyper provider. +var Embedded = sync.OnceValue(func() catwalk.Provider { + var provider catwalk.Provider + if err := json.Unmarshal(embedded, &provider); err != nil { + slog.Error("could not use embedded provider data", "err", err) + } + return provider +}) + +const ( + // Name is the default name of this meta provider. + Name = "hyper" + // defaultBaseURL is the default proxy URL. + // TODO: change this to production URL when ready. + defaultBaseURL = "https://console.charm.land" +) + +// BaseURL returns the base URL, which is either $HYPER_URL or the default. +var BaseURL = sync.OnceValue(func() string { + return cmp.Or(os.Getenv("HYPER_URL"), defaultBaseURL) +}) + +var ErrNoCredits = errors.New("you're out of credits") + +type options struct { + baseURL string + apiKey string + name string + headers map[string]string + client *http.Client +} + +// Option configures the proxy provider. +type Option = func(*options) + +// New creates a new proxy provider. +func New(opts ...Option) (fantasy.Provider, error) { + o := options{ + baseURL: BaseURL() + "/api/v1/fantasy", + name: Name, + headers: map[string]string{ + "x-crush-id": event.GetID(), + }, + client: &http.Client{Timeout: 0}, // stream-safe + } + for _, opt := range opts { + opt(&o) + } + return &provider{options: o}, nil +} + +// WithBaseURL sets the proxy base URL (e.g. http://localhost:8080). +func WithBaseURL(url string) Option { return func(o *options) { o.baseURL = url } } + +// WithName sets the provider name. +func WithName(name string) Option { return func(o *options) { o.name = name } } + +// WithHeaders sets extra headers sent to the proxy. +func WithHeaders(headers map[string]string) Option { + return func(o *options) { + maps.Copy(o.headers, headers) + } +} + +// WithHTTPClient sets custom HTTP client. +func WithHTTPClient(c *http.Client) Option { return func(o *options) { o.client = c } } + +// WithAPIKey sets the API key. +func WithAPIKey(key string) Option { + return func(o *options) { + o.apiKey = key + } +} + +type provider struct{ options options } + +func (p *provider) Name() string { return p.options.name } + +// LanguageModel implements fantasy.Provider. +func (p *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) { + if modelID == "" { + return nil, errors.New("missing model id") + } + return &languageModel{modelID: modelID, provider: p.options.name, opts: p.options}, nil +} + +type languageModel struct { + provider string + modelID string + opts options +} + +// GenerateObject implements fantasy.LanguageModel. +func (m *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + return object.GenerateWithTool(ctx, m, call) +} + +// StreamObject implements fantasy.LanguageModel. +func (m *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) { + return object.StreamWithTool(ctx, m, call) +} + +func (m *languageModel) Provider() string { return m.provider } +func (m *languageModel) Model() string { return m.modelID } + +// Generate implements fantasy.LanguageModel by calling the proxy JSON endpoint. +func (m *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { + resp, err := m.doRequest(ctx, false, call) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + b, _ := ioReadAllLimit(resp.Body, 64*1024) + return nil, fmt.Errorf("proxy generate error: %s", strings.TrimSpace(string(b))) + } + var out fantasy.Response + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, err + } + return &out, nil +} + +// Stream implements fantasy.LanguageModel using SSE from the proxy. +func (m *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { + // Prefer explicit /stream endpoint + resp, err := m.doRequest(ctx, true, call) + if err != nil { + return nil, err + } + switch resp.StatusCode { + case http.StatusTooManyRequests: + _ = resp.Body.Close() + return nil, toProviderError(resp, retryAfter(resp)) + case http.StatusUnauthorized: + _ = resp.Body.Close() + return nil, toProviderError(resp, "") + case http.StatusPaymentRequired: + _ = resp.Body.Close() + return nil, ErrNoCredits + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + defer func() { _ = resp.Body.Close() }() + b, _ := ioReadAllLimit(resp.Body, 64*1024) + return nil, &fantasy.ProviderError{ + Title: "Stream Error", + Message: strings.TrimSpace(string(b)), + StatusCode: resp.StatusCode, + } + } + + return func(yield func(fantasy.StreamPart) bool) { + defer func() { _ = resp.Body.Close() }() + scanner := bufio.NewScanner(resp.Body) + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 4*1024*1024) + + var ( + event string + dataBuf bytes.Buffer + sawFinish bool + dispatch = func() bool { + if dataBuf.Len() == 0 || event == "" { + dataBuf.Reset() + event = "" + return true + } + var part fantasy.StreamPart + if err := json.Unmarshal(dataBuf.Bytes(), &part); err != nil { + return yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err}) + } + if part.Type == fantasy.StreamPartTypeFinish { + sawFinish = true + } + ok := yield(part) + dataBuf.Reset() + event = "" + return ok + } + ) + + for scanner.Scan() { + line := scanner.Text() + if line == "" { // event boundary + if !dispatch() { + return + } + continue + } + if strings.HasPrefix(line, ":") { // comment / ping + continue + } + if strings.HasPrefix(line, "event: ") { + event = strings.TrimSpace(line[len("event: "):]) + continue + } + if strings.HasPrefix(line, "data: ") { + if dataBuf.Len() > 0 { + dataBuf.WriteByte('\n') + } + dataBuf.WriteString(line[len("data: "):]) + continue + } + } + if err := scanner.Err(); err != nil && + !errors.Is(err, context.Canceled) && + !errors.Is(err, context.DeadlineExceeded) { + yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err}) + return + } + // flush any pending data + _ = dispatch() + if !sawFinish { + _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish}) + } + }, nil +} + +func (m *languageModel) doRequest(ctx context.Context, stream bool, call fantasy.Call) (*http.Response, error) { + addr, err := url.Parse(m.opts.baseURL) + if err != nil { + return nil, err + } + addr = addr.JoinPath(m.modelID) + if stream { + addr = addr.JoinPath("stream") + } else { + addr = addr.JoinPath("generate") + } + + body, err := json.Marshal(call) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, addr.String(), bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + if stream { + req.Header.Set("Accept", "text/event-stream") + } else { + req.Header.Set("Accept", "application/json") + } + for k, v := range m.opts.headers { + req.Header.Set(k, v) + } + + if m.opts.apiKey != "" { + req.Header.Set("Authorization", m.opts.apiKey) + } + return m.opts.client.Do(req) +} + +// ioReadAllLimit reads up to n bytes. +func ioReadAllLimit(r io.Reader, n int64) ([]byte, error) { + var b bytes.Buffer + if n <= 0 { + n = 1 << 20 + } + lr := &io.LimitedReader{R: r, N: n} + _, err := b.ReadFrom(lr) + return b.Bytes(), err +} + +func toProviderError(resp *http.Response, message string) error { + return &fantasy.ProviderError{ + Title: fantasy.ErrorTitleForStatusCode(resp.StatusCode), + Message: message, + StatusCode: resp.StatusCode, + } +} + +func retryAfter(resp *http.Response) string { + after, err := strconv.Atoi(resp.Header.Get("Retry-After")) + if err == nil && after > 0 { + d := time.Duration(after) * time.Second + return "Try again in " + d.String() + } + return "Try again in later" +} diff --git a/internal/agent/hyper/provider.json b/internal/agent/hyper/provider.json new file mode 100644 index 0000000000000000000000000000000000000000..5558750e38e35024615b41b71243888a1a1ebd6c --- /dev/null +++ b/internal/agent/hyper/provider.json @@ -0,0 +1 @@ +{"name":"Charm Hyper","id":"hyper","api_endpoint":"https://console.charm.land/api/v1/fantasy","type":"hyper","default_large_model_id":"claude-sonnet-4-5","default_small_model_id":"claude-3-5-haiku","models":[{"id":"Kimi-K2-0905","name":"Kimi K2 0905","cost_per_1m_in":0.55,"cost_per_1m_out":2.19,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0,"context_window":256000,"default_max_tokens":10000,"can_reason":true,"default_reasoning_effort":"medium","supports_attachments":false,"options":{}},{"id":"claude-3-5-haiku","name":"Claude 3.5 Haiku","cost_per_1m_in":0.7999999999999999,"cost_per_1m_out":4,"cost_per_1m_in_cached":1,"cost_per_1m_out_cached":0.08,"context_window":200000,"default_max_tokens":5000,"can_reason":false,"supports_attachments":true,"options":{}},{"id":"claude-3-5-sonnet","name":"Claude 3.5 Sonnet (New)","cost_per_1m_in":3,"cost_per_1m_out":15,"cost_per_1m_in_cached":3.75,"cost_per_1m_out_cached":0.3,"context_window":200000,"default_max_tokens":5000,"can_reason":false,"supports_attachments":true,"options":{}},{"id":"claude-3-7-sonnet","name":"Claude 3.7 Sonnet","cost_per_1m_in":3,"cost_per_1m_out":15,"cost_per_1m_in_cached":3.75,"cost_per_1m_out_cached":0.3,"context_window":200000,"default_max_tokens":50000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"claude-haiku-4-5","name":"Claude 4.5 Haiku","cost_per_1m_in":1,"cost_per_1m_out":5,"cost_per_1m_in_cached":1.25,"cost_per_1m_out_cached":0.09999999999999999,"context_window":200000,"default_max_tokens":32000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"claude-opus-4","name":"Claude Opus 4","cost_per_1m_in":15,"cost_per_1m_out":75,"cost_per_1m_in_cached":18.75,"cost_per_1m_out_cached":1.5,"context_window":200000,"default_max_tokens":32000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"claude-opus-4-1","name":"Claude Opus 4.1","cost_per_1m_in":15,"cost_per_1m_out":75,"cost_per_1m_in_cached":18.75,"cost_per_1m_out_cached":1.5,"context_window":200000,"default_max_tokens":32000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"claude-opus-4-5","name":"Claude Opus 4.5","cost_per_1m_in":5,"cost_per_1m_out":25,"cost_per_1m_in_cached":6.25,"cost_per_1m_out_cached":0.5,"context_window":200000,"default_max_tokens":50000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"claude-sonnet-4","name":"Claude Sonnet 4","cost_per_1m_in":3,"cost_per_1m_out":15,"cost_per_1m_in_cached":3.75,"cost_per_1m_out_cached":0.3,"context_window":200000,"default_max_tokens":50000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"claude-sonnet-4-5","name":"Claude Sonnet 4.5","cost_per_1m_in":3,"cost_per_1m_out":15,"cost_per_1m_in_cached":3.75,"cost_per_1m_out_cached":0.3,"context_window":200000,"default_max_tokens":50000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"gemini-2.5-flash","name":"Gemini 2.5 Flash","cost_per_1m_in":0.3,"cost_per_1m_out":2.5,"cost_per_1m_in_cached":0.3833,"cost_per_1m_out_cached":0.075,"context_window":1048576,"default_max_tokens":50000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"gemini-2.5-pro","name":"Gemini 2.5 Pro","cost_per_1m_in":1.25,"cost_per_1m_out":10,"cost_per_1m_in_cached":1.625,"cost_per_1m_out_cached":0.31,"context_window":1048576,"default_max_tokens":50000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"glm-4.6","name":"GLM-4.6","cost_per_1m_in":0.6,"cost_per_1m_out":2.2,"cost_per_1m_in_cached":0.11,"cost_per_1m_out_cached":0,"context_window":204800,"default_max_tokens":131072,"can_reason":true,"default_reasoning_effort":"medium","supports_attachments":false,"options":{}},{"id":"gpt-4.1","name":"GPT-4.1","cost_per_1m_in":2,"cost_per_1m_out":8,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0.5,"context_window":1047576,"default_max_tokens":16384,"can_reason":false,"supports_attachments":true,"options":{}},{"id":"gpt-4.1-mini","name":"GPT-4.1 Mini","cost_per_1m_in":0.39999999999999997,"cost_per_1m_out":1.5999999999999999,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0.09999999999999999,"context_window":1047576,"default_max_tokens":16384,"can_reason":false,"supports_attachments":true,"options":{}},{"id":"gpt-4.1-nano","name":"GPT-4.1 Nano","cost_per_1m_in":0.09999999999999999,"cost_per_1m_out":0.39999999999999997,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0.024999999999999998,"context_window":1047576,"default_max_tokens":16384,"can_reason":false,"supports_attachments":true,"options":{}},{"id":"gpt-4o","name":"GPT-4o","cost_per_1m_in":2.5,"cost_per_1m_out":10,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":1.25,"context_window":128000,"default_max_tokens":8192,"can_reason":false,"supports_attachments":true,"options":{}},{"id":"gpt-4o-mini","name":"GPT-4o-mini","cost_per_1m_in":0.15,"cost_per_1m_out":0.6,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0.075,"context_window":128000,"default_max_tokens":8192,"can_reason":false,"supports_attachments":true,"options":{}},{"id":"gpt-5","name":"GPT-5","cost_per_1m_in":1.25,"cost_per_1m_out":10,"cost_per_1m_in_cached":0.125,"cost_per_1m_out_cached":0.125,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["minimal","low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5-codex","name":"GPT-5 Codex","cost_per_1m_in":1.25,"cost_per_1m_out":10,"cost_per_1m_in_cached":0.125,"cost_per_1m_out_cached":0.125,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["minimal","low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5-mini","name":"GPT-5 Mini","cost_per_1m_in":0.25,"cost_per_1m_out":2,"cost_per_1m_in_cached":0.025,"cost_per_1m_out_cached":0.025,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5-nano","name":"GPT-5 Nano","cost_per_1m_in":0.05,"cost_per_1m_out":0.4,"cost_per_1m_in_cached":0.005,"cost_per_1m_out_cached":0.005,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5.1","name":"GPT-5.1","cost_per_1m_in":1.25,"cost_per_1m_out":10,"cost_per_1m_in_cached":0.125,"cost_per_1m_out_cached":0.125,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["minimal","low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5.1-codex","name":"GPT-5.1 Codex","cost_per_1m_in":1.25,"cost_per_1m_out":10,"cost_per_1m_in_cached":0.125,"cost_per_1m_out_cached":0.125,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["minimal","low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5.1-codex-max","name":"GPT-5.1 Codex Max","cost_per_1m_in":1.25,"cost_per_1m_out":10,"cost_per_1m_in_cached":0.125,"cost_per_1m_out_cached":0.125,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["minimal","low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5.1-codex-mini","name":"GPT-5.1 Codex Mini","cost_per_1m_in":0.25,"cost_per_1m_out":2,"cost_per_1m_in_cached":0.025,"cost_per_1m_out_cached":0.025,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5.2","name":"GPT-5.2","cost_per_1m_in":1.75,"cost_per_1m_out":14,"cost_per_1m_in_cached":0.175,"cost_per_1m_out_cached":0.175,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["minimal","low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"o3","name":"o3","cost_per_1m_in":2,"cost_per_1m_out":8,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0.5,"context_window":200000,"default_max_tokens":50000,"can_reason":true,"reasoning_levels":["low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"o3-mini","name":"o3 Mini","cost_per_1m_in":1.1,"cost_per_1m_out":4.4,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0.55,"context_window":200000,"default_max_tokens":50000,"can_reason":true,"reasoning_levels":["low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":false,"options":{}},{"id":"o4-mini","name":"o4 Mini","cost_per_1m_in":1.1,"cost_per_1m_out":4.4,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0.275,"context_window":200000,"default_max_tokens":50000,"can_reason":true,"reasoning_levels":["low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"qwen3-coder-480b-a35b-instruct","name":"Qwen 3 480B Coder","cost_per_1m_in":0.82,"cost_per_1m_out":3.29,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0,"context_window":131072,"default_max_tokens":65536,"can_reason":false,"supports_attachments":false,"options":{}}]} \ No newline at end of file diff --git a/internal/cmd/login.go b/internal/cmd/login.go index 00c155cbfd23a95005da10a60f9ea36ab50cbd7d..22fde94985917999f5664696d81a4aa205503640 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -9,8 +9,12 @@ import ( "strings" "charm.land/lipgloss/v2" + "github.com/atotto/clipboard" + hyperp "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/oauth/claude" + "github.com/charmbracelet/crush/internal/oauth/hyper" + "github.com/pkg/browser" "github.com/spf13/cobra" ) @@ -20,31 +24,34 @@ var loginCmd = &cobra.Command{ Short: "Login Crush to a platform", Long: `Login Crush to a specified platform. The platform should be provided as an argument. -Available platforms are: claude.`, +Available platforms are: hyper, claude.`, Example: ` +# Authenticate with Charm Hyper +crush login + # Authenticate with Claude Code Max crush login claude `, ValidArgs: []cobra.Completion{ + "hyper", "claude", "anthropic", }, - Args: cobra.ExactArgs(1), + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - if len(args) > 1 { - return fmt.Errorf("wrong number of arguments") - } - if len(args) == 0 || args[0] == "" { - return cmd.Help() - } - app, err := setupAppWithProgressBar(cmd) if err != nil { return err } defer app.Shutdown() - switch args[0] { + provider := "hyper" + if len(args) > 0 { + provider = args[0] + } + switch provider { + case "hyper": + return loginHyper() case "anthropic", "claude": return loginClaude() default: @@ -53,13 +60,73 @@ crush login claude }, } +func loginHyper() error { + cfg := config.Get() + if !hyperp.Enabled() { + return fmt.Errorf("hyper not enabled") + } + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) + defer cancel() + + resp, err := hyper.InitiateDeviceAuth(ctx) + if err != nil { + return err + } + + if clipboard.WriteAll(resp.UserCode) == nil { + fmt.Println("The following code should be on clipboard already:") + } else { + fmt.Println("Copy the following code:") + } + + fmt.Println() + fmt.Println(lipgloss.NewStyle().Bold(true).Render(resp.UserCode)) + fmt.Println() + fmt.Println("Press enter to open this URL, and then paste it there:") + fmt.Println() + fmt.Println(lipgloss.NewStyle().Hyperlink(resp.VerificationURL, "id=hyper").Render(resp.VerificationURL)) + fmt.Println() + waitEnter() + if err := browser.OpenURL(resp.VerificationURL); err != nil { + fmt.Println("Could not open the URL. You'll need to manually open the URL in your browser.") + } + + fmt.Println("Exchanging authorization code...") + refreshToken, err := hyper.PollForToken(ctx, resp.DeviceCode, resp.ExpiresIn) + if err != nil { + return err + } + + fmt.Println("Exchanging refresh token for access token...") + token, err := hyper.ExchangeToken(ctx, refreshToken) + if err != nil { + return err + } + + fmt.Println("Verifying access token...") + introspect, err := hyper.IntrospectToken(ctx, token.AccessToken) + if err != nil { + return fmt.Errorf("token introspection failed: %w", err) + } + if !introspect.Active { + return fmt.Errorf("access token is not active") + } + + if err := cmp.Or( + cfg.SetConfigField("providers.hyper.api_key", token.AccessToken), + cfg.SetConfigField("providers.hyper.oauth", token), + ); err != nil { + return err + } + + fmt.Println() + fmt.Println("You're now authenticated with Hyper!") + return nil +} + func loginClaude() error { ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) - go func() { - <-ctx.Done() - cancel() - os.Exit(1) - }() + defer cancel() verifier, challenge, err := claude.GetChallenge() if err != nil { @@ -106,3 +173,7 @@ func loginClaude() error { fmt.Println("You're now authenticated with Claude Code Max!") return nil } + +func waitEnter() { + _, _ = fmt.Scanln() +} diff --git a/internal/cmd/update_providers.go b/internal/cmd/update_providers.go index 599d2c90954ca43888197961ec3bea4372285071..3b4b35b681fb003ff4d942c8052101253b364a1c 100644 --- a/internal/cmd/update_providers.go +++ b/internal/cmd/update_providers.go @@ -10,22 +10,30 @@ import ( "github.com/spf13/cobra" ) +var updateProvidersSource string + var updateProvidersCmd = &cobra.Command{ Use: "update-providers [path-or-url]", Short: "Update providers", - Long: `Update the list of providers from a specified local path or remote URL.`, + Long: `Update provider information from a specified local path or remote URL.`, Example: ` -# Update providers remotely from Catwalk +# Update Catwalk providers remotely (default) crush update-providers -# Update providers from a custom URL -crush update-providers https://example.com/ +# Update Catwalk providers from a custom URL +crush update-providers https://example.com/providers.json -# Update providers from a local file +# Update Catwalk providers from a local file crush update-providers /path/to/local-providers.json -# Update providers from embedded version +# Update Catwalk providers from embedded version crush update-providers embedded + +# Update Hyper provider information +crush update-providers --source=hyper + +# Update Hyper from a custom URL +crush update-providers --source=hyper https://hyper.example.com `, RunE: func(cmd *cobra.Command, args []string) error { // NOTE(@andreynering): We want to skip logging output do stdout here. @@ -36,7 +44,17 @@ crush update-providers embedded pathOrURL = args[0] } - if err := config.UpdateProviders(pathOrURL); err != nil { + var err error + switch updateProvidersSource { + case "catwalk": + err = config.UpdateProviders(pathOrURL) + case "hyper": + err = config.UpdateHyper(pathOrURL) + default: + return fmt.Errorf("invalid source %q, must be 'catwalk' or 'hyper'", updateProvidersSource) + } + + if err != nil { return err } @@ -52,9 +70,13 @@ crush update-providers embedded SetString("SUCCESS") textStyle := lipgloss.NewStyle(). MarginLeft(2). - SetString("Providers updated successfully.") + SetString(fmt.Sprintf("%s provider updated successfully.", updateProvidersSource)) fmt.Printf("%s\n%s\n\n", headerStyle.Render(), textStyle.Render()) return nil }, } + +func init() { + updateProvidersCmd.Flags().StringVar(&updateProvidersSource, "source", "catwalk", "Provider source to update (catwalk or hyper)") +} diff --git a/internal/config/catwalk.go b/internal/config/catwalk.go new file mode 100644 index 0000000000000000000000000000000000000000..c3cc2eb69d47e1a85e35164fda09d0f73761b820 --- /dev/null +++ b/internal/config/catwalk.go @@ -0,0 +1,82 @@ +package config + +import ( + "context" + "errors" + "log/slog" + "sync" + "sync/atomic" + + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/catwalk/pkg/embedded" +) + +type catwalkClient interface { + GetProviders(context.Context, string) ([]catwalk.Provider, error) +} + +var _ syncer[[]catwalk.Provider] = (*catwalkSync)(nil) + +type catwalkSync struct { + once sync.Once + result []catwalk.Provider + cache cache[[]catwalk.Provider] + client catwalkClient + autoupdate bool + init atomic.Bool +} + +func (s *catwalkSync) Init(client catwalkClient, path string, autoupdate bool) { + s.client = client + s.cache = newCache[[]catwalk.Provider](path) + s.autoupdate = autoupdate + s.init.Store(true) +} + +func (s *catwalkSync) Get(ctx context.Context) ([]catwalk.Provider, error) { + if !s.init.Load() { + panic("called Get before Init") + } + + var throwErr error + s.once.Do(func() { + if !s.autoupdate { + slog.Info("Using embedded Catwalk providers") + s.result = embedded.GetAll() + return + } + + cached, etag, cachedErr := s.cache.Get() + if len(cached) == 0 || cachedErr != nil { + // if cached file is empty, default to embedded providers + cached = embedded.GetAll() + } + + slog.Info("Fetching providers from Catwalk") + result, err := s.client.GetProviders(ctx, etag) + if errors.Is(err, context.DeadlineExceeded) { + slog.Warn("Catwalk providers not updated in time") + s.result = cached + return + } + if errors.Is(err, catwalk.ErrNotModified) { + slog.Info("Catwalk providers not modified") + s.result = cached + return + } + if err != nil { + // On error, fall back to cached (which defaults to embedded if empty). + s.result = cached + return + } + if len(result) == 0 { + s.result = cached + throwErr = errors.New("empty providers list from catwalk") + return + } + + s.result = result + throwErr = s.cache.Store(result) + }) + return s.result, throwErr +} diff --git a/internal/config/catwalk_test.go b/internal/config/catwalk_test.go new file mode 100644 index 0000000000000000000000000000000000000000..55322b34eb7252f8cae75fb46996f45bd31abe5e --- /dev/null +++ b/internal/config/catwalk_test.go @@ -0,0 +1,221 @@ +package config + +import ( + "context" + "encoding/json" + "errors" + "os" + "testing" + + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/stretchr/testify/require" +) + +type mockCatwalkClient struct { + providers []catwalk.Provider + err error + callCount int +} + +func (m *mockCatwalkClient) GetProviders(ctx context.Context, etag string) ([]catwalk.Provider, error) { + m.callCount++ + return m.providers, m.err +} + +func TestCatwalkSync_Init(t *testing.T) { + t.Parallel() + + syncer := &catwalkSync{} + client := &mockCatwalkClient{} + path := "/tmp/test.json" + + syncer.Init(client, path, true) + + require.True(t, syncer.init.Load()) + require.Equal(t, client, syncer.client) + require.Equal(t, path, syncer.cache.path) + require.True(t, syncer.autoupdate) +} + +func TestCatwalkSync_GetPanicIfNotInit(t *testing.T) { + t.Parallel() + + syncer := &catwalkSync{} + require.Panics(t, func() { + _, _ = syncer.Get(t.Context()) + }) +} + +func TestCatwalkSync_GetWithAutoUpdateDisabled(t *testing.T) { + t.Parallel() + + syncer := &catwalkSync{} + client := &mockCatwalkClient{ + providers: []catwalk.Provider{{Name: "should-not-be-used"}}, + } + path := t.TempDir() + "/providers.json" + + syncer.Init(client, path, false) + + providers, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.NotEmpty(t, providers) + require.Equal(t, 0, client.callCount, "Client should not be called when autoupdate is disabled") + + // Should return embedded providers. + for _, p := range providers { + require.NotEqual(t, "should-not-be-used", p.Name) + } +} + +func TestCatwalkSync_GetFreshProviders(t *testing.T) { + t.Parallel() + + syncer := &catwalkSync{} + client := &mockCatwalkClient{ + providers: []catwalk.Provider{ + {Name: "Fresh Provider", ID: "fresh"}, + }, + } + path := t.TempDir() + "/providers.json" + + syncer.Init(client, path, true) + + providers, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Len(t, providers, 1) + require.Equal(t, "Fresh Provider", providers[0].Name) + require.Equal(t, 1, client.callCount) + + // Verify cache was written. + fileInfo, err := os.Stat(path) + require.NoError(t, err) + require.False(t, fileInfo.IsDir()) +} + +func TestCatwalkSync_GetNotModifiedUsesCached(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/providers.json" + + // Create cache file. + cachedProviders := []catwalk.Provider{ + {Name: "Cached Provider", ID: "cached"}, + } + data, err := json.Marshal(cachedProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(path, data, 0o644)) + + syncer := &catwalkSync{} + client := &mockCatwalkClient{ + err: catwalk.ErrNotModified, + } + + syncer.Init(client, path, true) + + providers, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Len(t, providers, 1) + require.Equal(t, "Cached Provider", providers[0].Name) + require.Equal(t, 1, client.callCount) +} + +func TestCatwalkSync_GetEmptyResultFallbackToCached(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/providers.json" + + // Create cache file. + cachedProviders := []catwalk.Provider{ + {Name: "Cached Provider", ID: "cached"}, + } + data, err := json.Marshal(cachedProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(path, data, 0o644)) + + syncer := &catwalkSync{} + client := &mockCatwalkClient{ + providers: []catwalk.Provider{}, // Empty result. + } + + syncer.Init(client, path, true) + + providers, err := syncer.Get(t.Context()) + require.Error(t, err) + require.Contains(t, err.Error(), "empty providers list from catwalk") + require.Len(t, providers, 1) + require.Equal(t, "Cached Provider", providers[0].Name) +} + +func TestCatwalkSync_GetEmptyCacheDefaultsToEmbedded(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/providers.json" + + // Create empty cache file. + emptyProviders := []catwalk.Provider{} + data, err := json.Marshal(emptyProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(path, data, 0o644)) + + syncer := &catwalkSync{} + client := &mockCatwalkClient{ + err: errors.New("network error"), + } + + syncer.Init(client, path, true) + + providers, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.NotEmpty(t, providers, "Should fall back to embedded providers") + + // Verify it's embedded providers by checking we have multiple common ones. + require.Greater(t, len(providers), 5) +} + +func TestCatwalkSync_GetClientError(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/providers.json" + + syncer := &catwalkSync{} + client := &mockCatwalkClient{ + err: errors.New("network error"), + } + + syncer.Init(client, path, true) + + providers, err := syncer.Get(t.Context()) + require.NoError(t, err) // Should fall back to embedded. + require.NotEmpty(t, providers) +} + +func TestCatwalkSync_GetCalledMultipleTimesUsesOnce(t *testing.T) { + t.Parallel() + + syncer := &catwalkSync{} + client := &mockCatwalkClient{ + providers: []catwalk.Provider{ + {Name: "Provider", ID: "test"}, + }, + } + path := t.TempDir() + "/providers.json" + + syncer.Init(client, path, true) + + // Call Get multiple times. + providers1, err1 := syncer.Get(t.Context()) + require.NoError(t, err1) + require.Len(t, providers1, 1) + + providers2, err2 := syncer.Get(t.Context()) + require.NoError(t, err2) + require.Len(t, providers2, 1) + + // Client should only be called once due to sync.Once. + require.Equal(t, 1, client.callCount) +} diff --git a/internal/config/config.go b/internal/config/config.go index af1222e55c862f06a3ca6eaf6a69c6f632aab134..495c8b32394f37eac43fa4f497c7e69f2df63515 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,10 +13,12 @@ import ( "time" "github.com/charmbracelet/catwalk/pkg/catwalk" + hyperp "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/oauth/claude" + "github.com/charmbracelet/crush/internal/oauth/hyper" "github.com/invopop/jsonschema" "github.com/tidwall/sjson" ) @@ -120,7 +122,6 @@ type ProviderConfig struct { } func (pc *ProviderConfig) SetupClaudeCode() { - pc.APIKey = fmt.Sprintf("Bearer %s", pc.OAuthToken.AccessToken) pc.SystemPromptPrefix = "You are Claude Code, Anthropic's official CLI for Claude." pc.ExtraHeaders["anthropic-version"] = "2023-06-01" @@ -483,12 +484,16 @@ func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error return fmt.Errorf("provider %s does not have an OAuth token", providerID) } - // Only Anthropic provider uses OAuth for now. - if providerID != string(catwalk.InferenceProviderAnthropic) { + var newToken *oauth.Token + var err error + switch providerID { + case string(catwalk.InferenceProviderAnthropic): + newToken, err = claude.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) + case hyperp.Name: + newToken, err = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken) + default: return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) } - - newToken, err := claude.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) if err != nil { return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, err) } @@ -529,9 +534,11 @@ func (c *Config) SetProviderAPIKey(providerID string, apiKey any) error { return err } setKeyOrToken = func() { - providerConfig.APIKey = v.AccessToken + providerConfig.APIKey = fmt.Sprintf("Bearer %s", v.AccessToken) providerConfig.OAuthToken = v - providerConfig.SetupClaudeCode() + if providerID == string(catwalk.InferenceProviderAnthropic) { + providerConfig.SetupClaudeCode() + } } } diff --git a/internal/config/hyper.go b/internal/config/hyper.go new file mode 100644 index 0000000000000000000000000000000000000000..5fe6fc5a1ee54bd19902ef4c9cc6034a6b294b6f --- /dev/null +++ b/internal/config/hyper.go @@ -0,0 +1,124 @@ +package config + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/agent/hyper" + xetag "github.com/charmbracelet/x/etag" +) + +type hyperClient interface { + Get(context.Context, string) (catwalk.Provider, error) +} + +var _ syncer[catwalk.Provider] = (*hyperSync)(nil) + +type hyperSync struct { + once sync.Once + result catwalk.Provider + cache cache[catwalk.Provider] + client hyperClient + autoupdate bool + init atomic.Bool +} + +func (s *hyperSync) Init(client hyperClient, path string, autoupdate bool) { + s.client = client + s.cache = newCache[catwalk.Provider](path) + s.autoupdate = autoupdate + s.init.Store(true) +} + +func (s *hyperSync) Get(ctx context.Context) (catwalk.Provider, error) { + if !s.init.Load() { + panic("called Get before Init") + } + + var throwErr error + s.once.Do(func() { + if !s.autoupdate { + slog.Info("Using embedded Hyper provider") + s.result = hyper.Embedded() + return + } + + cached, etag, cachedErr := s.cache.Get() + if cached.ID == "" || cachedErr != nil { + // if cached file is empty, default to embedded provider + cached = hyper.Embedded() + } + + slog.Info("Fetching Hyper provider") + result, err := s.client.Get(ctx, etag) + if errors.Is(err, context.DeadlineExceeded) { + slog.Warn("Hyper provider not updated in time") + s.result = cached + return + } + if errors.Is(err, catwalk.ErrNotModified) { + slog.Info("Hyper provider not modified") + s.result = cached + return + } + if len(result.Models) == 0 { + slog.Warn("Hyper did not return any models") + s.result = cached + return + } + + s.result = result + throwErr = s.cache.Store(result) + }) + return s.result, throwErr +} + +var _ hyperClient = realHyperClient{} + +type realHyperClient struct { + baseURL string +} + +// Get implements hyperClient. +func (r realHyperClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) { + var result catwalk.Provider + req, err := http.NewRequestWithContext( + ctx, + http.MethodGet, + r.baseURL+"/api/v1/provider", + nil, + ) + if err != nil { + return result, fmt.Errorf("could not create request: %w", err) + } + xetag.Request(req, etag) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return result, fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode == http.StatusNotModified { + return result, catwalk.ErrNotModified + } + + if resp.StatusCode != http.StatusOK { + return result, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return result, fmt.Errorf("failed to decode response: %w", err) + } + + return result, nil +} diff --git a/internal/config/hyper_test.go b/internal/config/hyper_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7141eaa1e97888b5ee6f84afc8e9658825547b46 --- /dev/null +++ b/internal/config/hyper_test.go @@ -0,0 +1,205 @@ +package config + +import ( + "context" + "encoding/json" + "errors" + "os" + "testing" + + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/stretchr/testify/require" +) + +type mockHyperClient struct { + provider catwalk.Provider + err error + callCount int +} + +func (m *mockHyperClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) { + m.callCount++ + return m.provider, m.err +} + +func TestHyperSync_Init(t *testing.T) { + t.Parallel() + + syncer := &hyperSync{} + client := &mockHyperClient{} + path := "/tmp/hyper.json" + + syncer.Init(client, path, true) + + require.True(t, syncer.init.Load()) + require.Equal(t, client, syncer.client) + require.Equal(t, path, syncer.cache.path) +} + +func TestHyperSync_GetPanicIfNotInit(t *testing.T) { + t.Parallel() + + syncer := &hyperSync{} + require.Panics(t, func() { + _, _ = syncer.Get(t.Context()) + }) +} + +func TestHyperSync_GetFreshProvider(t *testing.T) { + t.Parallel() + + syncer := &hyperSync{} + client := &mockHyperClient{ + provider: catwalk.Provider{ + Name: "Hyper", + ID: "hyper", + Models: []catwalk.Model{ + {ID: "model-1", Name: "Model 1"}, + }, + }, + } + path := t.TempDir() + "/hyper.json" + + syncer.Init(client, path, true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, "Hyper", provider.Name) + require.Equal(t, 1, client.callCount) + + // Verify cache was written. + fileInfo, err := os.Stat(path) + require.NoError(t, err) + require.False(t, fileInfo.IsDir()) +} + +func TestHyperSync_GetNotModifiedUsesCached(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/hyper.json" + + // Create cache file. + cachedProvider := catwalk.Provider{ + Name: "Cached Hyper", + ID: "hyper", + } + data, err := json.Marshal(cachedProvider) + require.NoError(t, err) + require.NoError(t, os.WriteFile(path, data, 0o644)) + + syncer := &hyperSync{} + client := &mockHyperClient{ + err: catwalk.ErrNotModified, + } + + syncer.Init(client, path, true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, "Cached Hyper", provider.Name) + require.Equal(t, 1, client.callCount) +} + +func TestHyperSync_GetClientError(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/hyper.json" + + syncer := &hyperSync{} + client := &mockHyperClient{ + err: errors.New("network error"), + } + + syncer.Init(client, path, true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) // Should fall back to embedded. + require.Equal(t, "Charm Hyper", provider.Name) + require.Equal(t, catwalk.InferenceProvider("hyper"), provider.ID) +} + +func TestHyperSync_GetEmptyCache(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/hyper.json" + + syncer := &hyperSync{} + client := &mockHyperClient{ + provider: catwalk.Provider{ + Name: "Fresh Hyper", + ID: "hyper", + Models: []catwalk.Model{ + {ID: "model-1", Name: "Model 1"}, + }, + }, + } + + syncer.Init(client, path, true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, "Fresh Hyper", provider.Name) +} + +func TestHyperSync_GetCalledMultipleTimesUsesOnce(t *testing.T) { + t.Parallel() + + syncer := &hyperSync{} + client := &mockHyperClient{ + provider: catwalk.Provider{ + Name: "Hyper", + ID: "hyper", + Models: []catwalk.Model{ + {ID: "model-1", Name: "Model 1"}, + }, + }, + } + path := t.TempDir() + "/hyper.json" + + syncer.Init(client, path, true) + + // Call Get multiple times. + provider1, err1 := syncer.Get(t.Context()) + require.NoError(t, err1) + require.Equal(t, "Hyper", provider1.Name) + + provider2, err2 := syncer.Get(t.Context()) + require.NoError(t, err2) + require.Equal(t, "Hyper", provider2.Name) + + // Client should only be called once due to sync.Once. + require.Equal(t, 1, client.callCount) +} + +func TestHyperSync_GetCacheStoreError(t *testing.T) { + t.Parallel() + + // Create a file where we want a directory, causing mkdir to fail. + tmpDir := t.TempDir() + blockingFile := tmpDir + "/blocking" + require.NoError(t, os.WriteFile(blockingFile, []byte("block"), 0o644)) + + // Try to create cache in a subdirectory under the blocking file. + path := blockingFile + "/subdir/hyper.json" + + syncer := &hyperSync{} + client := &mockHyperClient{ + provider: catwalk.Provider{ + Name: "Hyper", + ID: "hyper", + Models: []catwalk.Model{ + {ID: "model-1", Name: "Model 1"}, + }, + }, + } + + syncer.Init(client, path, true) + + provider, err := syncer.Get(t.Context()) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to create directory for provider cache") + require.Equal(t, "Hyper", provider.Name) // Provider is still returned. +} diff --git a/internal/config/load.go b/internal/config/load.go index 14dd0f8792bcbabaa865efe47e0db5e721cf3827..5a26104a4f0894652efd8d8172ea208b09424f67 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -17,6 +17,7 @@ import ( "testing" "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/fsext" @@ -271,7 +272,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if providerConfig.Type == "" { providerConfig.Type = catwalk.TypeOpenAICompat } - if !slices.Contains(catwalk.KnownProviderTypes(), providerConfig.Type) { + if !slices.Contains(catwalk.KnownProviderTypes(), providerConfig.Type) && providerConfig.Type != hyper.Name { slog.Warn("Skipping custom provider due to unsupported provider type", "provider", id) c.Providers.Del(id) continue diff --git a/internal/config/provider.go b/internal/config/provider.go index e9d4dfcc9d0947eebe51d5ae303106404ab47292..253d6f658a567ed5302887ecb87415de0a89c504 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -10,16 +10,21 @@ import ( "os" "path/filepath" "runtime" + "slices" "strings" "sync" + "time" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/catwalk/pkg/embedded" + "github.com/charmbracelet/crush/internal/agent/hyper" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/home" + "github.com/charmbracelet/x/etag" ) -type ProviderClient interface { - GetProviders(context.Context, string) ([]catwalk.Provider, error) +type syncer[T any] interface { + Get(context.Context) (T, error) } var ( @@ -29,10 +34,10 @@ var ( ) // file to cache provider data -func providerCacheFileData() string { +func cachePathFor(name string) string { xdgDataHome := os.Getenv("XDG_DATA_HOME") if xdgDataHome != "" { - return filepath.Join(xdgDataHome, appName, "providers.json") + return filepath.Join(xdgDataHome, appName, name+".json") } // return the path to the main data directory @@ -43,43 +48,13 @@ func providerCacheFileData() string { if localAppData == "" { localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") } - return filepath.Join(localAppData, appName, "providers.json") + return filepath.Join(localAppData, appName, name+".json") } - return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json") -} - -func saveProvidersInCache(path string, providers []catwalk.Provider) error { - slog.Info("Saving provider data to disk", "path", path) - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return fmt.Errorf("failed to create directory for provider cache: %w", err) - } - - data, err := json.Marshal(providers) - if err != nil { - return fmt.Errorf("failed to marshal provider data: %w", err) - } - - if err := os.WriteFile(path, data, 0o644); err != nil { - return fmt.Errorf("failed to write provider data to cache: %w", err) - } - return nil -} - -func loadProvidersFromCache(path string) ([]catwalk.Provider, string, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, "", fmt.Errorf("failed to read provider cache file: %w", err) - } - - var providers []catwalk.Provider - if err := json.Unmarshal(data, &providers); err != nil { - return nil, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err) - } - - return providers, catwalk.Etag(data), nil + return filepath.Join(home.Dir(), ".local", "share", appName, name+".json") } +// UpdateProviders updates the Catwalk providers list from a specified source. func UpdateProviders(pathOrURL string) error { var providers []catwalk.Provider pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL) @@ -106,15 +81,55 @@ func UpdateProviders(pathOrURL string) error { } } - cachePath := providerCacheFileData() - if err := saveProvidersInCache(cachePath, providers); err != nil { + if err := newCache[[]catwalk.Provider](cachePathFor("providers")).Store(providers); err != nil { return fmt.Errorf("failed to save providers to cache: %w", err) } - slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePath) + slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePathFor) + return nil +} + +// UpdateHyper updates the Hyper provider information from a specified URL. +func UpdateHyper(pathOrURL string) error { + if !hyper.Enabled() { + return fmt.Errorf("hyper not enabled") + } + var provider catwalk.Provider + pathOrURL = cmp.Or(pathOrURL, hyper.BaseURL()) + + switch { + case pathOrURL == "embedded": + provider = hyper.Embedded() + case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"): + client := realHyperClient{baseURL: pathOrURL} + var err error + provider, err = client.Get(context.Background(), "") + if err != nil { + return fmt.Errorf("failed to fetch provider from Hyper: %w", err) + } + default: + content, err := os.ReadFile(pathOrURL) + if err != nil { + return fmt.Errorf("failed to read file: %w", err) + } + if err := json.Unmarshal(content, &provider); err != nil { + return fmt.Errorf("failed to unmarshal provider data: %w", err) + } + } + + if err := newCache[catwalk.Provider](cachePathFor("hyper")).Store(provider); err != nil { + return fmt.Errorf("failed to save Hyper provider to cache: %w", err) + } + + slog.Info("Hyper provider updated successfully", "from", pathOrURL, "to", cachePathFor("hyper")) return nil } +var ( + catwalkSyncer = &catwalkSync{} + hyperSyncer = &hyperSync{} +) + // Providers returns the list of providers, taking into account cached results // and whether or not auto update is enabled. // @@ -126,46 +141,87 @@ func UpdateProviders(pathOrURL string) error { // the cached list, or the embedded list if all others fail. func Providers(cfg *Config) ([]catwalk.Provider, error) { providerOnce.Do(func() { - catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL) - client := catwalk.NewWithURL(catwalkURL) - path := providerCacheFileData() - - if cfg.Options.DisableProviderAutoUpdate { - slog.Info("Using embedded Catwalk providers") - providerList, providerErr = embedded.GetAll(), nil - return - } + var wg sync.WaitGroup + var errs []error + providers := csync.NewSlice[catwalk.Provider]() + autoupdate := !cfg.Options.DisableProviderAutoUpdate + + ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second) + defer cancel() + + wg.Go(func() { + catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL) + client := catwalk.NewWithURL(catwalkURL) + path := cachePathFor("providers") + catwalkSyncer.Init(client, path, autoupdate) + + items, err := catwalkSyncer.Get(ctx) + if err != nil { + catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)) + errs = append(errs, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr)) //nolint:staticcheck + return + } + providers.Append(items...) + }) + + wg.Go(func() { + if !hyper.Enabled() { + return + } + path := cachePathFor("hyper") + hyperSyncer.Init(realHyperClient{baseURL: hyper.BaseURL()}, path, autoupdate) + + item, err := hyperSyncer.Get(ctx) + if err != nil { + errs = append(errs, fmt.Errorf("Crush was unable to fetch updated information from Hyper: %w", err)) //nolint:staticcheck + return + } + providers.Append(item) + }) + + wg.Wait() + + providerList = slices.Collect(providers.Seq()) + providerErr = errors.Join(errs...) + }) + return providerList, providerErr +} - cached, etag, cachedErr := loadProvidersFromCache(path) - if len(cached) == 0 || cachedErr != nil { - // if cached file is empty, default to embedded providers - cached = embedded.GetAll() - } +type cache[T any] struct { + path string +} - providerList, providerErr = loadProviders(client, etag, path) - if errors.Is(providerErr, catwalk.ErrNotModified) { - slog.Info("Catwalk providers not modified") - providerList, providerErr = cached, nil - } - }) - if providerErr != nil { - catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)) - return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr) //nolint:staticcheck - } - return providerList, nil +func newCache[T any](path string) cache[T] { + return cache[T]{path: path} } -func loadProviders(client ProviderClient, etag, path string) ([]catwalk.Provider, error) { - slog.Info("Fetching providers from Catwalk.", "path", path) - providers, err := client.GetProviders(context.Background(), etag) +func (c cache[T]) Get() (T, string, error) { + var v T + data, err := os.ReadFile(c.path) if err != nil { - return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err) + return v, "", fmt.Errorf("failed to read provider cache file: %w", err) + } + + if err := json.Unmarshal(data, &v); err != nil { + return v, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err) + } + + return v, etag.Of(data), nil +} + +func (c cache[T]) Store(v T) error { + slog.Info("Saving provider data to disk", "path", c.path) + if err := os.MkdirAll(filepath.Dir(c.path), 0o755); err != nil { + return fmt.Errorf("failed to create directory for provider cache: %w", err) } - if len(providers) == 0 { - return nil, errors.New("empty providers list from catwalk") + + data, err := json.Marshal(v) + if err != nil { + return fmt.Errorf("failed to marshal provider data: %w", err) } - if err := saveProvidersInCache(path, providers); err != nil { - return nil, err + + if err := os.WriteFile(c.path, data, 0o644); err != nil { + return fmt.Errorf("failed to write provider data to cache: %w", err) } - return providers, nil + return nil } diff --git a/internal/config/provider_empty_test.go b/internal/config/provider_empty_test.go index 5e889f82a4997d9f5761e6fc4a01ec6d54a0623a..7c37a9afb9694f0ea4352faee1b11d7e40d9480e 100644 --- a/internal/config/provider_empty_test.go +++ b/internal/config/provider_empty_test.go @@ -2,6 +2,7 @@ package config import ( "context" + "os" "testing" "github.com/charmbracelet/catwalk/pkg/catwalk" @@ -14,18 +15,25 @@ func (m *emptyProviderClient) GetProviders(context.Context, string) ([]catwalk.P return []catwalk.Provider{}, nil } -// TestProvider_loadProvidersEmptyResult tests that loadProviders returns an -// error when the client returns an empty list. This ensures we don't cache -// empty provider lists. -func TestProvider_loadProvidersEmptyResult(t *testing.T) { +// TestCatwalkSync_GetEmptyResultFromClient tests that when the client returns +// an empty list, we fall back to cached providers and return an error. +func TestCatwalkSync_GetEmptyResultFromClient(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/providers.json" + + syncer := &catwalkSync{} client := &emptyProviderClient{} - tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(client, "", tmpPath) + syncer.Init(client, path, true) + + providers, err := syncer.Get(t.Context()) + require.Error(t, err) require.Contains(t, err.Error(), "empty providers list from catwalk") - require.Empty(t, providers) - require.Len(t, providers, 0) + require.NotEmpty(t, providers) // Should have embedded providers as fallback. - // Check that no cache file was created for empty results - require.NoFileExists(t, tmpPath, "Cache file should not exist for empty results") + // Check that no cache file was created for empty results. + _, statErr := os.Stat(path) + require.True(t, os.IsNotExist(statErr), "Cache file should not exist for empty results") } diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index f101dd8de8ef624ed041ff88115c6c286f902659..e8790e286c3ffc8db77edb0ef8353e54ad519458 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -1,10 +1,9 @@ package config import ( - "context" "encoding/json" - "errors" "os" + "path/filepath" "sync" "testing" @@ -12,47 +11,31 @@ import ( "github.com/stretchr/testify/require" ) -type mockProviderClient struct { - shouldFail bool - shouldReturnErr error -} - -func (m *mockProviderClient) GetProviders(context.Context, string) ([]catwalk.Provider, error) { - if m.shouldReturnErr != nil { - return nil, m.shouldReturnErr - } - if m.shouldFail { - return nil, errors.New("failed to load providers") - } - return []catwalk.Provider{ - { - Name: "Mock", - }, - }, nil -} - func resetProviderState() { providerOnce = sync.Once{} providerList = nil providerErr = nil + catwalkSyncer = &catwalkSync{} + hyperSyncer = &hyperSync{} } -func TestProvider_loadProvidersNoIssues(t *testing.T) { - client := &mockProviderClient{shouldFail: false} - tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(client, "", tmpPath) - require.NoError(t, err) - require.NotNil(t, providers) - require.Len(t, providers, 1) +func TestProviders_Integration_AutoUpdateDisabled(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) - // check if file got saved - fileInfo, err := os.Stat(tmpPath) - require.NoError(t, err) - require.False(t, fileInfo.IsDir(), "Expected a file, not a directory") -} + // Use a test-specific instance to avoid global state interference. + testCatwalkSyncer := &catwalkSync{} + testHyperSyncer := &hyperSync{} + + originalCatwalSyncer := catwalkSyncer + originalHyperSyncer := hyperSyncer + defer func() { + catwalkSyncer = originalCatwalSyncer + hyperSyncer = originalHyperSyncer + }() -func TestProvider_DisableAutoUpdate(t *testing.T) { - t.Setenv("XDG_DATA_HOME", t.TempDir()) + catwalkSyncer = testCatwalkSyncer + hyperSyncer = testHyperSyncer resetProviderState() defer resetProviderState() @@ -69,76 +52,257 @@ func TestProvider_DisableAutoUpdate(t *testing.T) { require.Greater(t, len(providers), 5, "Expected embedded providers") } -func TestProvider_WithValidCache(t *testing.T) { +func TestProviders_Integration_WithMockClients(t *testing.T) { tmpDir := t.TempDir() t.Setenv("XDG_DATA_HOME", tmpDir) - resetProviderState() - defer resetProviderState() + // Create fresh syncers for this test. + testCatwalkSyncer := &catwalkSync{} + testHyperSyncer := &hyperSync{} + + // Initialize with mock clients. + mockCatwalkClient := &mockCatwalkClient{ + providers: []catwalk.Provider{ + {Name: "Provider1", ID: "p1"}, + {Name: "Provider2", ID: "p2"}, + }, + } + mockHyperClient := &mockHyperClient{ + provider: catwalk.Provider{ + Name: "Hyper", + ID: "hyper", + Models: []catwalk.Model{ + {ID: "hyper-1", Name: "Hyper Model"}, + }, + }, + } + + catwalkPath := tmpDir + "/crush/providers.json" + hyperPath := tmpDir + "/crush/hyper.json" + + testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true) + testHyperSyncer.Init(mockHyperClient, hyperPath, true) + + // Get providers from each syncer. + catwalkProviders, err := testCatwalkSyncer.Get(t.Context()) + require.NoError(t, err) + require.Len(t, catwalkProviders, 2) + + hyperProvider, err := testHyperSyncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, "Hyper", hyperProvider.Name) + + // Verify total. + allProviders := append(catwalkProviders, hyperProvider) + require.Len(t, allProviders, 3) +} + +func TestProviders_Integration_WithCachedData(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + // Create cache files. + catwalkPath := tmpDir + "/crush/providers.json" + hyperPath := tmpDir + "/crush/hyper.json" - cachePath := tmpDir + "/crush/providers.json" require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755)) - cachedProviders := []catwalk.Provider{ - {Name: "Cached"}, + + // Write Catwalk cache. + catwalkProviders := []catwalk.Provider{ + {Name: "Cached1", ID: "c1"}, + {Name: "Cached2", ID: "c2"}, + } + data, err := json.Marshal(catwalkProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(catwalkPath, data, 0o644)) + + // Write Hyper cache. + hyperProvider := catwalk.Provider{ + Name: "Cached Hyper", + ID: "hyper", } - data, err := json.Marshal(cachedProviders) + data, err = json.Marshal(hyperProvider) require.NoError(t, err) - require.NoError(t, os.WriteFile(cachePath, data, 0o644)) + require.NoError(t, os.WriteFile(hyperPath, data, 0o644)) + + // Create fresh syncers. + testCatwalkSyncer := &catwalkSync{} + testHyperSyncer := &hyperSync{} + + // Mock clients that return ErrNotModified. + mockCatwalkClient := &mockCatwalkClient{ + err: catwalk.ErrNotModified, + } + mockHyperClient := &mockHyperClient{ + err: catwalk.ErrNotModified, + } - mockClient := &mockProviderClient{shouldFail: false} + testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true) + testHyperSyncer.Init(mockHyperClient, hyperPath, true) - providers, err := loadProviders(mockClient, "", cachePath) + // Get providers - should use cached. + catwalkResult, err := testCatwalkSyncer.Get(t.Context()) require.NoError(t, err) - require.NotNil(t, providers) - require.Len(t, providers, 1) - require.Equal(t, "Mock", providers[0].Name, "Expected fresh provider from fetch") + require.Len(t, catwalkResult, 2) + require.Equal(t, "Cached1", catwalkResult[0].Name) + + hyperResult, err := testHyperSyncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, "Cached Hyper", hyperResult.Name) } -func TestProvider_NotModifiedUsesCached(t *testing.T) { +func TestProviders_Integration_CatwalkFailsHyperSucceeds(t *testing.T) { tmpDir := t.TempDir() t.Setenv("XDG_DATA_HOME", tmpDir) - resetProviderState() - defer resetProviderState() + testCatwalkSyncer := &catwalkSync{} + testHyperSyncer := &hyperSync{} - cachePath := tmpDir + "/crush/providers.json" - require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755)) - cachedProviders := []catwalk.Provider{ - {Name: "Cached"}, + // Catwalk fails, Hyper succeeds. + mockCatwalkClient := &mockCatwalkClient{ + err: catwalk.ErrNotModified, // Will use embedded. } - data, err := json.Marshal(cachedProviders) + mockHyperClient := &mockHyperClient{ + provider: catwalk.Provider{ + Name: "Hyper", + ID: "hyper", + Models: []catwalk.Model{ + {ID: "hyper-1", Name: "Hyper Model"}, + }, + }, + } + + catwalkPath := tmpDir + "/crush/providers.json" + hyperPath := tmpDir + "/crush/hyper.json" + + testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true) + testHyperSyncer.Init(mockHyperClient, hyperPath, true) + + catwalkResult, err := testCatwalkSyncer.Get(t.Context()) require.NoError(t, err) - require.NoError(t, os.WriteFile(cachePath, data, 0o644)) + require.NotEmpty(t, catwalkResult) // Should have embedded. - mockClient := &mockProviderClient{shouldReturnErr: catwalk.ErrNotModified} - providers, err := loadProviders(mockClient, "", cachePath) - require.ErrorIs(t, err, catwalk.ErrNotModified) - require.Nil(t, providers) + hyperResult, err := testHyperSyncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, "Hyper", hyperResult.Name) } -func TestProvider_EmptyCacheDefaultsToEmbedded(t *testing.T) { +func TestProviders_Integration_BothFail(t *testing.T) { tmpDir := t.TempDir() t.Setenv("XDG_DATA_HOME", tmpDir) - resetProviderState() - defer resetProviderState() + testCatwalkSyncer := &catwalkSync{} + testHyperSyncer := &hyperSync{} - cachePath := tmpDir + "/crush/providers.json" - require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755)) - emptyProviders := []catwalk.Provider{} - data, err := json.Marshal(emptyProviders) + // Both fail. + mockCatwalkClient := &mockCatwalkClient{ + err: catwalk.ErrNotModified, + } + mockHyperClient := &mockHyperClient{ + provider: catwalk.Provider{}, // Empty provider. + } + + catwalkPath := tmpDir + "/crush/providers.json" + hyperPath := tmpDir + "/crush/hyper.json" + + testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true) + testHyperSyncer.Init(mockHyperClient, hyperPath, true) + + catwalkResult, err := testCatwalkSyncer.Get(t.Context()) require.NoError(t, err) - require.NoError(t, os.WriteFile(cachePath, data, 0o644)) + require.NotEmpty(t, catwalkResult) // Should fall back to embedded. - cached, _, err := loadProvidersFromCache(cachePath) + hyperResult, err := testHyperSyncer.Get(t.Context()) require.NoError(t, err) - require.Empty(t, cached, "Expected empty cache") + require.Equal(t, "Charm Hyper", hyperResult.Name) // Falls back to embedded when no models. } -func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) { - client := &mockProviderClient{shouldFail: true} - tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(client, "", tmpPath) +func TestCache_StoreAndGet(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + cachePath := tmpDir + "/test.json" + + cache := newCache[[]catwalk.Provider](cachePath) + + providers := []catwalk.Provider{ + {Name: "Provider1", ID: "p1"}, + {Name: "Provider2", ID: "p2"}, + } + + // Store. + err := cache.Store(providers) + require.NoError(t, err) + + // Get. + result, etag, err := cache.Get() + require.NoError(t, err) + require.Len(t, result, 2) + require.Equal(t, "Provider1", result[0].Name) + require.NotEmpty(t, etag) +} + +func TestCache_GetNonExistent(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + cachePath := tmpDir + "/nonexistent.json" + + cache := newCache[[]catwalk.Provider](cachePath) + + _, _, err := cache.Get() require.Error(t, err) - require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists") + require.Contains(t, err.Error(), "failed to read provider cache file") +} + +func TestCache_GetInvalidJSON(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + cachePath := tmpDir + "/invalid.json" + + require.NoError(t, os.WriteFile(cachePath, []byte("invalid json"), 0o644)) + + cache := newCache[[]catwalk.Provider](cachePath) + + _, _, err := cache.Get() + require.Error(t, err) + require.Contains(t, err.Error(), "failed to unmarshal provider data from cache") +} + +func TestCachePathFor(t *testing.T) { + tests := []struct { + name string + xdgDataHome string + expected string + }{ + { + name: "with XDG_DATA_HOME", + xdgDataHome: "/custom/data", + expected: "/custom/data/crush/providers.json", + }, + { + name: "without XDG_DATA_HOME", + xdgDataHome: "", + expected: "", // Will use platform-specific default. + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.xdgDataHome != "" { + t.Setenv("XDG_DATA_HOME", tt.xdgDataHome) + } else { + t.Setenv("XDG_DATA_HOME", "") + } + + result := cachePathFor("providers") + if tt.expected != "" { + require.Equal(t, tt.expected, filepath.ToSlash(result)) + } else { + require.Contains(t, result, "crush") + require.Contains(t, result, "providers.json") + } + }) + } } diff --git a/internal/event/event.go b/internal/event/event.go index 462e3f3ba53a0fc10d77822cf404a12d66b0bdec..1793c283f6a79cf9e6ff8c5fd5c533756e66df78 100644 --- a/internal/event/event.go +++ b/internal/event/event.go @@ -46,6 +46,22 @@ func Init() { distinctId = getDistinctId() } +func GetID() string { return distinctId } + +func Alias(userID string) { + if client == nil || distinctId == fallbackId || distinctId == "" || userID == "" { + return + } + if err := client.Enqueue(posthog.Alias{ + DistinctId: distinctId, + Alias: userID, + }); err != nil { + slog.Error("Failed to enqueue PostHog alias event", "error", err) + return + } + slog.Info("Aliased in PostHog", "machine_id", distinctId, "user_id", userID) +} + // send logs an event to PostHog with the given event name and properties. func send(event string, props ...any) { if client == nil { diff --git a/internal/oauth/hyper/device.go b/internal/oauth/hyper/device.go new file mode 100644 index 0000000000000000000000000000000000000000..90d115f76c197c376eb61d8a9676966eb12a272c --- /dev/null +++ b/internal/oauth/hyper/device.go @@ -0,0 +1,251 @@ +// Package hyper provides functions to handle Hyper device flow authentication. +package hyper + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" + + "github.com/charmbracelet/crush/internal/agent/hyper" + "github.com/charmbracelet/crush/internal/event" + "github.com/charmbracelet/crush/internal/oauth" +) + +// DeviceAuthResponse contains the response from the device authorization endpoint. +type DeviceAuthResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURL string `json:"verification_url"` + ExpiresIn int `json:"expires_in"` +} + +// TokenResponse contains the response from the polling endpoint. +type TokenResponse struct { + RefreshToken string `json:"refresh_token,omitempty"` + UserID string `json:"user_id"` + OrganizationID string `json:"organization_id"` + OrganizationName string `json:"organization_name"` + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` +} + +// InitiateDeviceAuth calls the /device/auth endpoint to start the device flow. +func InitiateDeviceAuth(ctx context.Context) (*DeviceAuthResponse, error) { + url := hyper.BaseURL() + "/device/auth" + + req, err := http.NewRequestWithContext( + ctx, http.MethodPost, url, + strings.NewReader(fmt.Sprintf(`{"device_name":%q}`, deviceName())), + ) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "crush") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("device auth failed: status %d, body %q", resp.StatusCode, string(body)) + } + + var authResp DeviceAuthResponse + if err := json.Unmarshal(body, &authResp); err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + + return &authResp, nil +} + +func deviceName() string { + if hostname, err := os.Hostname(); err == nil && hostname != "" { + return "Crush (" + hostname + ")" + } + return "Crush" +} + +// PollForToken polls the /device/token endpoint until authorization is complete. +// It respects the polling interval and handles various error states. +func PollForToken(ctx context.Context, deviceCode string, expiresIn int) (string, error) { + ctx, cancel := context.WithTimeout(ctx, time.Duration(expiresIn)*time.Second) + defer cancel() + + d := 5 * time.Second + ticker := time.NewTicker(d) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-ticker.C: + result, err := pollOnce(ctx, deviceCode) + if err != nil { + return "", err + } + if result.RefreshToken != "" { + event.Alias(result.UserID) + return result.RefreshToken, nil + } + switch result.Error { + case "authorization_pending": + continue + default: + return "", errors.New(result.ErrorDescription) + } + } + } +} + +func pollOnce(ctx context.Context, deviceCode string) (TokenResponse, error) { + var result TokenResponse + url := fmt.Sprintf("%s/device/auth/%s", hyper.BaseURL(), deviceCode) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return result, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "crush") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return result, fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return result, fmt.Errorf("read response: %w", err) + } + + if err := json.Unmarshal(body, &result); err != nil { + return result, fmt.Errorf("unmarshal response: %w: %s", err, string(body)) + } + + if resp.StatusCode != http.StatusOK { + return result, fmt.Errorf("token request failed: status %d body %q", resp.StatusCode, string(body)) + } + + return result, nil +} + +// ExchangeToken exchanges a refresh token for an access token. +func ExchangeToken(ctx context.Context, refreshToken string) (*oauth.Token, error) { + reqBody := map[string]string{ + "refresh_token": refreshToken, + } + + data, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + url := hyper.BaseURL() + "/token/exchange" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "crush") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token exchange failed: status %d body %q", resp.StatusCode, string(body)) + } + + var token oauth.Token + if err := json.Unmarshal(body, &token); err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + + token.SetExpiresAt() + return &token, nil +} + +// IntrospectTokenResponse contains the response from the token introspection endpoint. +type IntrospectTokenResponse struct { + Active bool `json:"active"` + Sub string `json:"sub,omitempty"` + OrgID string `json:"org_id,omitempty"` + Exp int64 `json:"exp,omitempty"` + Iat int64 `json:"iat,omitempty"` + Iss string `json:"iss,omitempty"` + Jti string `json:"jti,omitempty"` +} + +// IntrospectToken validates an access token using the introspection endpoint. +// Implements OAuth2 Token Introspection (RFC 7662). +func IntrospectToken(ctx context.Context, accessToken string) (*IntrospectTokenResponse, error) { + reqBody := map[string]string{ + "token": accessToken, + } + + data, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + url := hyper.BaseURL() + "/token/introspect" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "crush") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token introspection failed: status %d body %q", resp.StatusCode, string(body)) + } + + var result IntrospectTokenResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + + return &result, nil +} diff --git a/internal/oauth/token.go b/internal/oauth/token.go index 29d4791b5fd416e65995698dcc4665ea96fbb090..7c06ee35bd98b5bc7968766aecda46fd617a2521 100644 --- a/internal/oauth/token.go +++ b/internal/oauth/token.go @@ -4,7 +4,7 @@ import ( "time" ) -// Token represents an OAuth2 token from Claude Code Max. +// Token represents an OAuth2 token. type Token struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` diff --git a/internal/tui/components/dialogs/commands/commands.go b/internal/tui/components/dialogs/commands/commands.go index 7769849bb9b4d6e16e6644ae658d9f0394ac2289..cde5b203ca985f81c390d02725ef04d11a5cd518 100644 --- a/internal/tui/components/dialogs/commands/commands.go +++ b/internal/tui/components/dialogs/commands/commands.go @@ -13,6 +13,7 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/agent/tools/mcp" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" @@ -363,7 +364,7 @@ func (c *commandDialogCmp) defaultCommands() []Command { selectedModel := cfg.Models[agentCfg.Model] // Anthropic models: thinking toggle - if providerCfg.Type == catwalk.TypeAnthropic { + if providerCfg.Type == catwalk.TypeAnthropic || providerCfg.Type == catwalk.Type(hyper.Name) { status := "Enable" if selectedModel.Think { status = "Disable" diff --git a/internal/tui/components/dialogs/hyper/device_flow.go b/internal/tui/components/dialogs/hyper/device_flow.go new file mode 100644 index 0000000000000000000000000000000000000000..796798e51a04f48605e4935b98b753d73c1aca84 --- /dev/null +++ b/internal/tui/components/dialogs/hyper/device_flow.go @@ -0,0 +1,264 @@ +// Package hyper provides the dialog for Hyper device flow authentication. +package hyper + +import ( + "context" + "fmt" + "time" + + "charm.land/bubbles/v2/spinner" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + hyperp "github.com/charmbracelet/crush/internal/agent/hyper" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/oauth/hyper" + "github.com/charmbracelet/crush/internal/tui/styles" + "github.com/charmbracelet/crush/internal/tui/util" + "github.com/pkg/browser" +) + +// DeviceFlowState represents the current state of the device flow. +type DeviceFlowState int + +const ( + DeviceFlowStateDisplay DeviceFlowState = iota + DeviceFlowStateSuccess + DeviceFlowStateError +) + +// DeviceAuthInitiatedMsg is sent when the device auth is initiated +// successfully. +type DeviceAuthInitiatedMsg struct { + deviceCode string + expiresIn int +} + +// DeviceFlowCompletedMsg is sent when the device flow completes successfully. +type DeviceFlowCompletedMsg struct { + Token *oauth.Token +} + +// DeviceFlowErrorMsg is sent when the device flow encounters an error. +type DeviceFlowErrorMsg struct { + Error error +} + +// DeviceFlow handles the Hyper device flow authentication. +type DeviceFlow struct { + State DeviceFlowState + width int + baseURL string + deviceCode string + userCode string + verificationURL string + expiresIn int + token *oauth.Token + cancelFunc context.CancelFunc + spinner spinner.Model +} + +// NewDeviceFlow creates a new device flow component. +func NewDeviceFlow() *DeviceFlow { + s := spinner.New() + s.Spinner = spinner.Dot + s.Style = lipgloss.NewStyle().Foreground(styles.CurrentTheme().GreenLight) + return &DeviceFlow{ + State: DeviceFlowStateDisplay, + baseURL: hyperp.BaseURL(), + spinner: s, + } +} + +// Init initializes the device flow by calling the device auth API and starting polling. +func (d *DeviceFlow) Init() tea.Cmd { + return tea.Batch(d.spinner.Tick, d.initiateDeviceAuth) +} + +// Update handles messages and state transitions. +func (d *DeviceFlow) Update(msg tea.Msg) (util.Model, tea.Cmd) { + var cmd tea.Cmd + d.spinner, cmd = d.spinner.Update(msg) + + switch msg := msg.(type) { + case DeviceAuthInitiatedMsg: + // Start polling now that we have the device code. + d.expiresIn = msg.expiresIn + return d, tea.Batch(cmd, d.startPolling(msg.deviceCode)) + case DeviceFlowCompletedMsg: + d.State = DeviceFlowStateSuccess + d.token = msg.Token + return d, nil + case DeviceFlowErrorMsg: + d.State = DeviceFlowStateError + return d, util.ReportError(msg.Error) + } + + return d, cmd +} + +// View renders the device flow dialog. +func (d *DeviceFlow) View() string { + t := styles.CurrentTheme() + + whiteStyle := lipgloss.NewStyle().Foreground(t.White) + primaryStyle := lipgloss.NewStyle().Foreground(t.Primary) + greenStyle := lipgloss.NewStyle().Foreground(t.GreenLight) + errorStyle := lipgloss.NewStyle().Foreground(t.Error) + mutedStyle := lipgloss.NewStyle().Foreground(t.FgMuted) + + switch d.State { + case DeviceFlowStateDisplay: + if d.userCode == "" { + return lipgloss.NewStyle(). + Margin(0, 1). + Render( + greenStyle.Render(d.spinner.View()) + + mutedStyle.Render("Initializing..."), + ) + } + + instructions := lipgloss.NewStyle(). + Margin(1, 1, 0, 1). + Width(d.width - 2). + Render( + + whiteStyle.Render("Press ") + + primaryStyle.Render("enter") + + whiteStyle.Render(" to copy the code below and open the browser."), + ) + + codeBox := lipgloss.NewStyle(). + Width(d.width-2). + Height(7). + Align(lipgloss.Center, lipgloss.Center). + Background(t.BgBaseLighter). + Margin(1). + Render( + lipgloss.NewStyle(). + Bold(true). + Foreground(t.White). + Render(d.userCode), + ) + + link := lipgloss.NewStyle().Hyperlink(d.verificationURL, "id=hyper-verify").Render(d.verificationURL) + url := mutedStyle. + Margin(0, 1). + Width(d.width - 2). + Render("Browser not opening? Refer to\n" + link) + + waiting := greenStyle. + Width(d.width-2). + Margin(1, 1, 0, 1). + Render(d.spinner.View() + "Verifying...") + + return lipgloss.JoinVertical( + lipgloss.Left, + instructions, + codeBox, + url, + waiting, + ) + + case DeviceFlowStateSuccess: + return greenStyle.Margin(0, 1).Render("Authentication successful!") + + case DeviceFlowStateError: + return lipgloss.NewStyle(). + Margin(0, 1). + Width(d.width). + Render(errorStyle.Render("Authentication failed.")) + + default: + return "" + } +} + +// SetWidth sets the width of the dialog. +func (d *DeviceFlow) SetWidth(w int) { + d.width = w +} + +// Cursor hides the cursor. +func (d *DeviceFlow) Cursor() *tea.Cursor { return nil } + +// CopyCodeAndOpenURL copies the user code to the clipboard and opens the URL. +func (d *DeviceFlow) CopyCodeAndOpenURL() tea.Cmd { + return tea.Sequence( + tea.SetClipboard(d.userCode), + func() tea.Msg { + if err := browser.OpenURL(d.verificationURL); err != nil { + return DeviceFlowErrorMsg{Error: fmt.Errorf("failed to open browser: %w", err)} + } + return nil + }, + util.ReportInfo("Code copied and URL opened"), + ) +} + +// CopyCode copies just the user code to the clipboard. +func (d *DeviceFlow) CopyCode() tea.Cmd { + return tea.Sequence( + tea.SetClipboard(d.userCode), + util.ReportInfo("Code copied to clipboard"), + ) +} + +// Cancel cancels the device flow polling. +func (d *DeviceFlow) Cancel() { + if d.cancelFunc != nil { + d.cancelFunc() + } +} + +func (d *DeviceFlow) initiateDeviceAuth() tea.Msg { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + authResp, err := hyper.InitiateDeviceAuth(ctx) + if err != nil { + return DeviceFlowErrorMsg{Error: fmt.Errorf("failed to initiate device auth: %w", err)} + } + + d.deviceCode = authResp.DeviceCode + d.userCode = authResp.UserCode + d.verificationURL = authResp.VerificationURL + + return DeviceAuthInitiatedMsg{ + deviceCode: authResp.DeviceCode, + expiresIn: authResp.ExpiresIn, + } +} + +// startPolling starts polling for the device token. +func (d *DeviceFlow) startPolling(deviceCode string) tea.Cmd { + return func() tea.Msg { + ctx, cancel := context.WithCancel(context.Background()) + d.cancelFunc = cancel + + // Poll for refresh token. + refreshToken, err := hyper.PollForToken(ctx, deviceCode, d.expiresIn) + if err != nil { + if ctx.Err() != nil { + // Cancelled, don't report error. + return nil + } + return DeviceFlowErrorMsg{Error: err} + } + + // Exchange refresh token for access token. + token, err := hyper.ExchangeToken(ctx, refreshToken) + if err != nil { + return DeviceFlowErrorMsg{Error: fmt.Errorf("token exchange failed: %w", err)} + } + + // Verify the access token works. + introspect, err := hyper.IntrospectToken(ctx, token.AccessToken) + if err != nil { + return DeviceFlowErrorMsg{Error: fmt.Errorf("token introspection failed: %w", err)} + } + if !introspect.Active { + return DeviceFlowErrorMsg{Error: fmt.Errorf("access token is not active")} + } + + return DeviceFlowCompletedMsg{Token: token} + } +} diff --git a/internal/tui/components/dialogs/models/keys.go b/internal/tui/components/dialogs/models/keys.go index c075a35ac808d7d1093c5a7710bf511f8d0219cb..b0d737f6e205c260c33d758c43ec5ad210f8db3f 100644 --- a/internal/tui/components/dialogs/models/keys.go +++ b/internal/tui/components/dialogs/models/keys.go @@ -15,6 +15,8 @@ type KeyMap struct { isAPIKeyHelp bool isAPIKeyValid bool + isHyperDeviceFlow bool + isClaudeAuthChoiceHelp bool isClaudeOAuthHelp bool isClaudeOAuthURLState bool @@ -74,6 +76,19 @@ func (k KeyMap) FullHelp() [][]key.Binding { // ShortHelp implements help.KeyMap. func (k KeyMap) ShortHelp() []key.Binding { + if k.isHyperDeviceFlow { + return []key.Binding{ + key.NewBinding( + key.WithKeys("c"), + key.WithHelp("c", "copy code"), + ), + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "copy & open"), + ), + k.Close, + } + } if k.isClaudeAuthChoiceHelp { return []key.Binding{ key.NewBinding( diff --git a/internal/tui/components/dialogs/models/list.go b/internal/tui/components/dialogs/models/list.go index 2383f749de277e7fe915b57aac17fa0e7928756e..eb0ad1b983f86d892c38300ae79296b73a392da8 100644 --- a/internal/tui/components/dialogs/models/list.go +++ b/internal/tui/components/dialogs/models/list.go @@ -62,7 +62,8 @@ func (m *ModelListComponent) Init() tea.Cmd { filteredProviders := []catwalk.Provider{} for _, p := range providers { hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$") - if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure { + isHyper := p.ID == "hyper" + if (hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure) || isHyper { filteredProviders = append(filteredProviders, p) } } @@ -204,8 +205,23 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { } } + // Move "Charm Hyper" to first position + // (but still after recent models and custom providers). + sortedProviders := make([]catwalk.Provider, len(m.providers)) + copy(sortedProviders, m.providers) + slices.SortStableFunc(sortedProviders, func(a, b catwalk.Provider) int { + switch { + case a.ID == "hyper": + return -1 + case b.ID == "hyper": + return 1 + default: + return 0 + } + }) + // Then add the known providers from the predefined list - for _, provider := range m.providers { + for _, provider := range sortedProviders { // Skip if we already added this provider as an unknown provider if addedProviders[string(provider.ID)] { continue diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index 6248f1f0440441f1b5d0924b581283835ca2c294..0fb3dcabfa7e3a42068e5540dc7bd4edaee3c2db 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -1,3 +1,4 @@ +// Package models provides the model selection dialog for the TUI. package models import ( @@ -11,10 +12,12 @@ import ( "charm.land/lipgloss/v2" "github.com/atotto/clipboard" "github.com/charmbracelet/catwalk/pkg/catwalk" + hyperp "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/dialogs" "github.com/charmbracelet/crush/internal/tui/components/dialogs/claude" + "github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper" "github.com/charmbracelet/crush/internal/tui/exp/list" "github.com/charmbracelet/crush/internal/tui/styles" "github.com/charmbracelet/crush/internal/tui/util" @@ -70,6 +73,10 @@ type modelDialogCmp struct { isAPIKeyValid bool apiKeyValue string + // Hyper device flow state + hyperDeviceFlow *hyper.DeviceFlow + showHyperDeviceFlow bool + // Claude state claudeAuthMethodChooser *claude.AuthMethodChooser claudeOAuth2 *claude.OAuth2 @@ -127,6 +134,15 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { u, cmd := m.apiKeyInput.Update(msg) m.apiKeyInput = u.(*APIKeyInput) return m, cmd + case hyper.DeviceFlowCompletedMsg: + return m, m.saveOauthTokenAndContinue(msg.Token, true) + case hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg: + if m.hyperDeviceFlow != nil { + u, cmd := m.hyperDeviceFlow.Update(msg) + m.hyperDeviceFlow = u.(*hyper.DeviceFlow) + return m, cmd + } + return m, nil case claude.ValidationCompletedMsg: var cmds []tea.Cmd u, cmd := m.claudeOAuth2.Update(msg) @@ -134,7 +150,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { cmds = append(cmds, cmd) if msg.State == claude.OAuthValidationStateValid { - cmds = append(cmds, m.saveAPIKeyAndContinue(msg.Token, false)) + cmds = append(cmds, m.saveOauthTokenAndContinue(msg.Token, false)) m.keyMap.isClaudeOAuthHelpComplete = true } @@ -143,6 +159,11 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return m, util.CmdHandler(dialogs.CloseDialogMsg{}) case tea.KeyPressMsg: switch { + // Handle Hyper device flow keys + case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showHyperDeviceFlow: + if m.hyperDeviceFlow != nil { + return m, m.hyperDeviceFlow.CopyCode() + } case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL: return m, tea.Sequence( tea.SetClipboard(m.claudeOAuth2.URL), @@ -156,6 +177,10 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { m.claudeAuthMethodChooser.ToggleChoice() return m, nil case key.Matches(msg, m.keyMap.Select): + // If showing device flow, enter copies code and opens URL + if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil { + return m, m.hyperDeviceFlow.CopyCodeAndOpenURL() + } selectedItem := m.modelList.SelectedModel() modelType := config.SelectedModelTypeLarge @@ -167,6 +192,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { m.keyMap.isClaudeAuthChoiceHelp = false m.keyMap.isClaudeOAuthHelp = false m.keyMap.isAPIKeyHelp = true + m.showHyperDeviceFlow = false m.showClaudeAuthMethodChooser = false m.needsAPIKey = true m.selectedModel = selectedItem @@ -194,7 +220,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return m, cmd2 } if m.isAPIKeyValid { - return m, m.saveAPIKeyAndContinue(m.apiKeyValue, true) + return m, m.saveOauthTokenAndContinue(m.apiKeyValue, true) } if m.needsAPIKey { // Handle API key submission @@ -249,15 +275,23 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { ModelType: modelType, }), ) - } else { - if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic { - m.showClaudeAuthMethodChooser = true - m.keyMap.isClaudeAuthChoiceHelp = true - return m, nil - } - askForApiKey() + } + switch selectedItem.Provider.ID { + case catwalk.InferenceProviderAnthropic: + m.showClaudeAuthMethodChooser = true + m.keyMap.isClaudeAuthChoiceHelp = true return m, nil + case hyperp.Name: + m.showHyperDeviceFlow = true + m.selectedModel = selectedItem + m.selectedModelType = modelType + m.hyperDeviceFlow = hyper.NewDeviceFlow() + m.hyperDeviceFlow.SetWidth(m.width - 2) + return m, m.hyperDeviceFlow.Init() } + // For other providers, show API key input + askForApiKey() + return m, nil case key.Matches(msg, m.keyMap.Tab): switch { case m.showClaudeAuthMethodChooser: @@ -275,6 +309,14 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return m, m.modelList.SetModelType(LargeModelType) } case key.Matches(msg, m.keyMap.Close): + if m.showHyperDeviceFlow { + // Cancel device flow and go back to model selection + if m.hyperDeviceFlow != nil { + m.hyperDeviceFlow.Cancel() + } + m.showHyperDeviceFlow = false + m.selectedModel = nil + } if m.showClaudeAuthMethodChooser { m.claudeAuthMethodChooser.SetDefaults() m.showClaudeAuthMethodChooser = false @@ -329,7 +371,20 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return m, cmd } case spinner.TickMsg: - if m.showClaudeOAuth2 { + u, cmd := m.apiKeyInput.Update(msg) + m.apiKeyInput = u.(*APIKeyInput) + if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil { + u, cmd = m.hyperDeviceFlow.Update(msg) + m.hyperDeviceFlow = u.(*hyper.DeviceFlow) + } + return m, cmd + default: + // Pass all other messages to the device flow for spinner animation + if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil { + u, cmd := m.hyperDeviceFlow.Update(msg) + m.hyperDeviceFlow = u.(*hyper.DeviceFlow) + return m, cmd + } else if m.showClaudeOAuth2 { u, cmd := m.claudeOAuth2.Update(msg) m.claudeOAuth2 = u.(*claude.OAuth2) return m, cmd @@ -345,6 +400,23 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { func (m *modelDialogCmp) View() string { t := styles.CurrentTheme() + if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil { + // Show Hyper device flow + m.keyMap.isHyperDeviceFlow = true + deviceFlowView := m.hyperDeviceFlow.View() + content := lipgloss.JoinVertical( + lipgloss.Left, + t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with Hyper", m.width-4)), + deviceFlowView, + "", + t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)), + ) + return m.style().Render(content) + } + + // Reset the flags when not showing device flow + m.keyMap.isHyperDeviceFlow = false + switch { case m.showClaudeAuthMethodChooser: chooserView := m.claudeAuthMethodChooser.View() @@ -397,6 +469,9 @@ func (m *modelDialogCmp) View() string { } func (m *modelDialogCmp) Cursor() *tea.Cursor { + if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil { + return m.hyperDeviceFlow.Cursor() + } if m.showClaudeAuthMethodChooser { return nil } @@ -477,10 +552,8 @@ func (m *modelDialogCmp) modelTypeRadio() string { func (m *modelDialogCmp) isProviderConfigured(providerID string) bool { cfg := config.Get() - if _, ok := cfg.Providers.Get(providerID); ok { - return true - } - return false + _, ok := cfg.Providers.Get(providerID) + return ok } func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) { @@ -497,7 +570,7 @@ func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*cat return nil, nil } -func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd { +func (m *modelDialogCmp) saveOauthTokenAndContinue(apiKey any, close bool) tea.Cmd { if m.selectedModel == nil { return util.ReportError(fmt.Errorf("no model selected")) } diff --git a/internal/tui/exp/list/items.go b/internal/tui/exp/list/items.go index fa89fb6e7a58a5fc0d9e6bcab36a979130c482e9..3db5635b044d9845915d005dd5f7cdac233fe53f 100644 --- a/internal/tui/exp/list/items.go +++ b/internal/tui/exp/list/items.go @@ -323,6 +323,7 @@ type ItemSection interface { layout.Sizeable Indexable SetInfo(info string) + Title() string } type itemSectionModel struct { width int @@ -337,6 +338,11 @@ func (m *itemSectionModel) ID() string { return m.id } +// Title implements ItemSection. +func (m *itemSectionModel) Title() string { + return m.title +} + func NewItemSection(title string) ItemSection { return &itemSectionModel{ title: title,