.gitattributes 🔗
@@ -1,3 +1,4 @@
*.golden linguist-generated=true -text
.github/crush-schema.json linguist-generated=true
+internal/agent/hyper/provider.json linguist-generated=true
internal/agent/testdata/**/*.yaml -diff linguist-generated=true
Ayman Bagabas created
.gitattributes | 1
.github/workflows/schema-update.yml | 6
README.md | 5
Taskfile.yaml | 7
cspell.json | 0
go.mod | 6
go.sum | 4
internal/agent/agent.go | 18
internal/agent/coordinator.go | 15
internal/agent/hyper/provider.go | 330 ++++++++++++
internal/agent/hyper/provider.json | 0
internal/cmd/login.go | 195 ++++++
internal/cmd/projects.go | 7
internal/cmd/update_providers.go | 38 +
internal/config/catwalk.go | 82 ++
internal/config/catwalk_test.go | 221 ++++++++
internal/config/config.go | 85 ++
internal/config/copilot.go | 43 +
internal/config/hyper.go | 124 ++++
internal/config/hyper_test.go | 205 +++++++
internal/config/load.go | 39 +
internal/config/provider.go | 204 ++++--
internal/config/provider_empty_test.go | 28
internal/config/provider_test.go | 316 ++++++++--
internal/event/event.go | 16
internal/oauth/copilot/disk.go | 36 +
internal/oauth/copilot/http.go | 17
internal/oauth/copilot/oauth.go | 200 +++++++
internal/oauth/copilot/urls.go | 6
internal/oauth/hyper/device.go | 251 +++++++++
internal/oauth/token.go | 7
internal/tui/components/dialogs/commands/commands.go | 3
internal/tui/components/dialogs/copilot/device_flow.go | 281 ++++++++++
internal/tui/components/dialogs/hyper/device_flow.go | 267 +++++++++
internal/tui/components/dialogs/models/keys.go | 26
internal/tui/components/dialogs/models/list.go | 45
internal/tui/components/dialogs/models/models.go | 167 +++++
internal/tui/components/mcp/mcp.go | 12
internal/tui/exp/list/items.go | 6
internal/ui/dialog/sessions_item.go | 11
internal/ui/list/filterable.go | 12
internal/ui/list/item.go | 19
internal/ui/list/list.go | 21
internal/ui/model/chat.go | 16
internal/ui/model/mcp.go | 6
internal/ui/model/ui.go | 46
46 files changed, 3,144 insertions(+), 306 deletions(-)
@@ -1,3 +1,4 @@
*.golden linguist-generated=true -text
.github/crush-schema.json linguist-generated=true
+internal/agent/hyper/provider.json linguist-generated=true
internal/agent/testdata/**/*.yaml -diff linguist-generated=true
@@ -1,10 +1,11 @@
-name: Update Schema
+name: Update files
on:
push:
branches: [main]
paths:
- "internal/config/**"
+ - "internal/agent/hyper/**"
jobs:
update-schema:
@@ -17,9 +18,10 @@ jobs:
with:
go-version-file: go.mod
- run: go run . schema > ./schema.json
+ - run: go generate ./internal/agent/hyper/...
- uses: stefanzweifel/git-auto-commit-action@28e16e81777b558cc906c8750092100bbb34c5e3 # v5
with:
- commit_message: "chore: auto-update generated files"
+ commit_message: "chore: auto-update files"
branch: main
commit_user_name: Charm
commit_user_email: 124303983+charmcli@users.noreply.github.com
@@ -232,6 +232,11 @@ $HOME/.local/share/crush/crush.json
%LOCALAPPDATA%\crush\crush.json
```
+> [!TIP]
+> You can override the user and data config locations by setting:
+> * `CRUSH_GLOBAL_CONFIG`
+> * `CRUSH_GLOBAL_DATA`
+
### LSPs
Crush can use LSPs for additional context to help inform its decisions, just
@@ -101,6 +101,13 @@ tasks:
generates:
- schema.json
+ hyper:
+ desc: Update Hyper embedded provider.json
+ cmds:
+ - go generate ./internal/agent/hyper/...
+ generates:
+ - ./internal/agent/hyper/provider.json
+
release:
desc: Create and push a new tag following semver
vars:
@@ -1 +0,0 @@
@@ -18,11 +18,12 @@ require (
github.com/aymanbagabas/go-udiff v0.3.1
github.com/bmatcuk/doublestar/v4 v4.9.1
github.com/charlievieth/fastwalk v1.0.14
- github.com/charmbracelet/catwalk v0.10.2
+ github.com/charmbracelet/catwalk v0.11.0
github.com/charmbracelet/colorprofile v0.4.1
github.com/charmbracelet/fang v0.4.4
github.com/charmbracelet/ultraviolet v0.0.0-20251212194010-b927aa605560
github.com/charmbracelet/x/ansi v0.11.3
+ github.com/charmbracelet/x/etag v0.2.0
github.com/charmbracelet/x/exp/charmtone v0.0.0-20250708181618-a60a724ba6c3
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f
github.com/charmbracelet/x/exp/ordered v0.1.0
@@ -53,6 +54,7 @@ require (
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef
github.com/stretchr/testify v1.11.1
+ github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/zeebo/xxh3 v1.0.2
golang.org/x/mod v0.31.0
@@ -92,7 +94,6 @@ require (
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 // indirect
- github.com/charmbracelet/x/etag v0.2.0 // indirect
github.com/charmbracelet/x/json v0.2.0 // indirect
github.com/charmbracelet/x/termios v0.1.1 // indirect
github.com/charmbracelet/x/windows v0.2.2 // indirect
@@ -145,7 +146,6 @@ require (
github.com/sourcegraph/jsonrpc2 v0.2.1 // indirect
github.com/spf13/pflag v1.0.9 // indirect
github.com/tetratelabs/wazero v1.10.1 // indirect
- github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/u-root/u-root v0.14.1-0.20250807200646-5e7721023dc7 // indirect
@@ -92,8 +92,8 @@ github.com/charlievieth/fastwalk v1.0.14 h1:3Eh5uaFGwHZd8EGwTjJnSpBkfwfsak9h6ICg
github.com/charlievieth/fastwalk v1.0.14/go.mod h1:diVcUreiU1aQ4/Wu3NbxxH4/KYdKpLDojrQ1Bb2KgNY=
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 h1:rwLdEpG9wE6kL69KkEKDiWprO8pQOZHZXeod6+9K+mw=
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904/go.mod h1:8TIYxZxsuCqqeJ0lga/b91tBwrbjoHDC66Sq5t8N2R4=
-github.com/charmbracelet/catwalk v0.10.2 h1:Ps6IeGu0ArKE3l3OYv+HwIwbnzZrAl1C3AuwXiOf1G0=
-github.com/charmbracelet/catwalk v0.10.2/go.mod h1:qg+Yl9oaZTkTvRscqbxfttzOFQ4v0pOT5XwC7b5O0NQ=
+github.com/charmbracelet/catwalk v0.11.0 h1:PU3rkc4h4YVJEn9Iyb/1rQAaF4hEd04fuG4tj3vv4dg=
+github.com/charmbracelet/catwalk v0.11.0/go.mod h1:qg+Yl9oaZTkTvRscqbxfttzOFQ4v0pOT5XwC7b5O0NQ=
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/fang v0.4.4 h1:G4qKxF6or/eTPgmAolwPuRNyuci3hTUGGX1rj1YkHJY=
@@ -27,7 +27,9 @@ import (
"charm.land/fantasy/providers/google"
"charm.land/fantasy/providers/openai"
"charm.land/fantasy/providers/openrouter"
+ "charm.land/lipgloss/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/agent/tools"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
@@ -454,8 +456,22 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
} else if isPermissionErr {
currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
+ } else if errors.Is(err, hyper.ErrNoCredits) {
+ url := hyper.BaseURL()
+ link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
+ currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
} else if errors.As(err, &providerErr) {
- currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
+ if providerErr.Message == "The requested model is not supported." {
+ url := "https://github.com/settings/copilot/features"
+ link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
+ currentAssistant.AddFinish(
+ message.FinishReasonError,
+ "Copilot model not enabled",
+ fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
+ )
+ } else {
+ currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
+ }
} else if errors.As(err, &fantasyErr) {
currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
} else {
@@ -17,6 +17,7 @@ import (
"charm.land/fantasy"
"github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/agent/prompt"
"github.com/charmbracelet/crush/internal/agent/tools"
"github.com/charmbracelet/crush/internal/config"
@@ -668,6 +669,18 @@ func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, optio
return google.New(opts...)
}
+func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
+ opts := []hyper.Option{
+ hyper.WithBaseURL(baseURL),
+ hyper.WithAPIKey(apiKey),
+ }
+ if c.cfg.Options.Debug {
+ httpClient := log.NewHTTPClient()
+ opts = append(opts, hyper.WithHTTPClient(httpClient))
+ }
+ return hyper.New(opts...)
+}
+
func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
if model.Think {
return true
@@ -728,6 +741,8 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con
providerCfg.ExtraBody["tool_stream"] = true
}
return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody)
+ case hyper.Name:
+ return c.buildHyperProvider(baseURL, apiKey)
default:
return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
}
@@ -0,0 +1,330 @@
+// Package hyper provides a fantasy.Provider that proxies requests to Hyper.
+package hyper
+
+import (
+ "bufio"
+ "bytes"
+ "cmp"
+ "context"
+ _ "embed"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "maps"
+ "net/http"
+ "net/url"
+ "os"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "charm.land/fantasy"
+ "charm.land/fantasy/object"
+ "github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/event"
+)
+
+//go:generate wget -O provider.json https://console.charm.land/api/v1/provider
+
+//go:embed provider.json
+var embedded []byte
+
+// Enabled returns true if hyper is enabled.
+var Enabled = sync.OnceValue(func() bool {
+ b, _ := strconv.ParseBool(
+ cmp.Or(
+ os.Getenv("HYPER"),
+ os.Getenv("HYPERCRUSH"),
+ os.Getenv("HYPER_ENABLE"),
+ os.Getenv("HYPER_ENABLED"),
+ ),
+ )
+ return b
+})
+
+// Embedded returns the embedded Hyper provider.
+var Embedded = sync.OnceValue(func() catwalk.Provider {
+ var provider catwalk.Provider
+ if err := json.Unmarshal(embedded, &provider); err != nil {
+ slog.Error("could not use embedded provider data", "err", err)
+ }
+ return provider
+})
+
+const (
+ // Name is the default name of this meta provider.
+ Name = "hyper"
+ // defaultBaseURL is the default proxy URL.
+ // TODO: change this to production URL when ready.
+ defaultBaseURL = "https://console.charm.land"
+)
+
+// BaseURL returns the base URL, which is either $HYPER_URL or the default.
+var BaseURL = sync.OnceValue(func() string {
+ return cmp.Or(os.Getenv("HYPER_URL"), defaultBaseURL)
+})
+
+var ErrNoCredits = errors.New("you're out of credits")
+
+type options struct {
+ baseURL string
+ apiKey string
+ name string
+ headers map[string]string
+ client *http.Client
+}
+
+// Option configures the proxy provider.
+type Option = func(*options)
+
+// New creates a new proxy provider.
+func New(opts ...Option) (fantasy.Provider, error) {
+ o := options{
+ baseURL: BaseURL() + "/api/v1/fantasy",
+ name: Name,
+ headers: map[string]string{
+ "x-crush-id": event.GetID(),
+ },
+ client: &http.Client{Timeout: 0}, // stream-safe
+ }
+ for _, opt := range opts {
+ opt(&o)
+ }
+ return &provider{options: o}, nil
+}
+
+// WithBaseURL sets the proxy base URL (e.g. http://localhost:8080).
+func WithBaseURL(url string) Option { return func(o *options) { o.baseURL = url } }
+
+// WithName sets the provider name.
+func WithName(name string) Option { return func(o *options) { o.name = name } }
+
+// WithHeaders sets extra headers sent to the proxy.
+func WithHeaders(headers map[string]string) Option {
+ return func(o *options) {
+ maps.Copy(o.headers, headers)
+ }
+}
+
+// WithHTTPClient sets custom HTTP client.
+func WithHTTPClient(c *http.Client) Option { return func(o *options) { o.client = c } }
+
+// WithAPIKey sets the API key.
+func WithAPIKey(key string) Option {
+ return func(o *options) {
+ o.apiKey = key
+ }
+}
+
+type provider struct{ options options }
+
+func (p *provider) Name() string { return p.options.name }
+
+// LanguageModel implements fantasy.Provider.
+func (p *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) {
+ if modelID == "" {
+ return nil, errors.New("missing model id")
+ }
+ return &languageModel{modelID: modelID, provider: p.options.name, opts: p.options}, nil
+}
+
+type languageModel struct {
+ provider string
+ modelID string
+ opts options
+}
+
+// GenerateObject implements fantasy.LanguageModel.
+func (m *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
+ return object.GenerateWithTool(ctx, m, call)
+}
+
+// StreamObject implements fantasy.LanguageModel.
+func (m *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
+ return object.StreamWithTool(ctx, m, call)
+}
+
+func (m *languageModel) Provider() string { return m.provider }
+func (m *languageModel) Model() string { return m.modelID }
+
+// Generate implements fantasy.LanguageModel by calling the proxy JSON endpoint.
+func (m *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
+ resp, err := m.doRequest(ctx, false, call)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ b, _ := ioReadAllLimit(resp.Body, 64*1024)
+ return nil, fmt.Errorf("proxy generate error: %s", strings.TrimSpace(string(b)))
+ }
+ var out fantasy.Response
+ if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
+ return nil, err
+ }
+ return &out, nil
+}
+
+// Stream implements fantasy.LanguageModel using SSE from the proxy.
+func (m *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
+ // Prefer explicit /stream endpoint
+ resp, err := m.doRequest(ctx, true, call)
+ if err != nil {
+ return nil, err
+ }
+ switch resp.StatusCode {
+ case http.StatusTooManyRequests:
+ _ = resp.Body.Close()
+ return nil, toProviderError(resp, retryAfter(resp))
+ case http.StatusUnauthorized:
+ _ = resp.Body.Close()
+ return nil, toProviderError(resp, "")
+ case http.StatusPaymentRequired:
+ _ = resp.Body.Close()
+ return nil, ErrNoCredits
+ }
+
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ defer func() { _ = resp.Body.Close() }()
+ b, _ := ioReadAllLimit(resp.Body, 64*1024)
+ return nil, &fantasy.ProviderError{
+ Title: "Stream Error",
+ Message: strings.TrimSpace(string(b)),
+ StatusCode: resp.StatusCode,
+ }
+ }
+
+ return func(yield func(fantasy.StreamPart) bool) {
+ defer func() { _ = resp.Body.Close() }()
+ scanner := bufio.NewScanner(resp.Body)
+ buf := make([]byte, 0, 64*1024)
+ scanner.Buffer(buf, 4*1024*1024)
+
+ var (
+ event string
+ dataBuf bytes.Buffer
+ sawFinish bool
+ dispatch = func() bool {
+ if dataBuf.Len() == 0 || event == "" {
+ dataBuf.Reset()
+ event = ""
+ return true
+ }
+ var part fantasy.StreamPart
+ if err := json.Unmarshal(dataBuf.Bytes(), &part); err != nil {
+ return yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
+ }
+ if part.Type == fantasy.StreamPartTypeFinish {
+ sawFinish = true
+ }
+ ok := yield(part)
+ dataBuf.Reset()
+ event = ""
+ return ok
+ }
+ )
+
+ for scanner.Scan() {
+ line := scanner.Text()
+ if line == "" { // event boundary
+ if !dispatch() {
+ return
+ }
+ continue
+ }
+ if strings.HasPrefix(line, ":") { // comment / ping
+ continue
+ }
+ if strings.HasPrefix(line, "event: ") {
+ event = strings.TrimSpace(line[len("event: "):])
+ continue
+ }
+ if strings.HasPrefix(line, "data: ") {
+ if dataBuf.Len() > 0 {
+ dataBuf.WriteByte('\n')
+ }
+ dataBuf.WriteString(line[len("data: "):])
+ continue
+ }
+ }
+ if err := scanner.Err(); err != nil &&
+ !errors.Is(err, context.Canceled) &&
+ !errors.Is(err, context.DeadlineExceeded) {
+ yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
+ return
+ }
+ // flush any pending data
+ _ = dispatch()
+ if !sawFinish {
+ _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish})
+ }
+ }, nil
+}
+
+func (m *languageModel) doRequest(ctx context.Context, stream bool, call fantasy.Call) (*http.Response, error) {
+ addr, err := url.Parse(m.opts.baseURL)
+ if err != nil {
+ return nil, err
+ }
+ addr = addr.JoinPath(m.modelID)
+ if stream {
+ addr = addr.JoinPath("stream")
+ } else {
+ addr = addr.JoinPath("generate")
+ }
+
+ body, err := json.Marshal(call)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, addr.String(), bytes.NewReader(body))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ if stream {
+ req.Header.Set("Accept", "text/event-stream")
+ } else {
+ req.Header.Set("Accept", "application/json")
+ }
+ for k, v := range m.opts.headers {
+ req.Header.Set(k, v)
+ }
+
+ if m.opts.apiKey != "" {
+ req.Header.Set("Authorization", m.opts.apiKey)
+ }
+ return m.opts.client.Do(req)
+}
+
+// ioReadAllLimit reads up to n bytes.
+func ioReadAllLimit(r io.Reader, n int64) ([]byte, error) {
+ var b bytes.Buffer
+ if n <= 0 {
+ n = 1 << 20
+ }
+ lr := &io.LimitedReader{R: r, N: n}
+ _, err := b.ReadFrom(lr)
+ return b.Bytes(), err
+}
+
+func toProviderError(resp *http.Response, message string) error {
+ return &fantasy.ProviderError{
+ Title: fantasy.ErrorTitleForStatusCode(resp.StatusCode),
+ Message: message,
+ StatusCode: resp.StatusCode,
+ }
+}
+
+func retryAfter(resp *http.Response) string {
+ after, err := strconv.Atoi(resp.Header.Get("Retry-After"))
+ if err == nil && after > 0 {
+ d := time.Duration(after) * time.Second
+ return "Try again in " + d.String()
+ }
+ return "Try again in later"
+}
@@ -0,0 +1 @@
@@ -9,8 +9,14 @@ import (
"strings"
"charm.land/lipgloss/v2"
+ "github.com/atotto/clipboard"
+ hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/oauth"
"github.com/charmbracelet/crush/internal/oauth/claude"
+ "github.com/charmbracelet/crush/internal/oauth/copilot"
+ "github.com/charmbracelet/crush/internal/oauth/hyper"
+ "github.com/pkg/browser"
"github.com/spf13/cobra"
)
@@ -20,46 +26,122 @@ var loginCmd = &cobra.Command{
Short: "Login Crush to a platform",
Long: `Login Crush to a specified platform.
The platform should be provided as an argument.
-Available platforms are: claude.`,
+Available platforms are: hyper, claude, copilot.`,
Example: `
+# Authenticate with Charm Hyper
+crush login
+
# Authenticate with Claude Code Max
crush login claude
+
+# Authenticate with GitHub Copilot
+crush login copilot
`,
ValidArgs: []cobra.Completion{
+ "hyper",
"claude",
"anthropic",
+ "copilot",
+ "github",
+ "github-copilot",
},
- Args: cobra.ExactArgs(1),
+ Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
- if len(args) > 1 {
- return fmt.Errorf("wrong number of arguments")
- }
- if len(args) == 0 || args[0] == "" {
- return cmd.Help()
- }
-
app, err := setupAppWithProgressBar(cmd)
if err != nil {
return err
}
defer app.Shutdown()
- switch args[0] {
+ provider := "hyper"
+ if len(args) > 0 {
+ provider = args[0]
+ }
+ switch provider {
+ case "hyper":
+ return loginHyper()
case "anthropic", "claude":
return loginClaude()
+ case "copilot", "github", "github-copilot":
+ return loginCopilot()
default:
return fmt.Errorf("unknown platform: %s", args[0])
}
},
}
-func loginClaude() error {
+func loginHyper() error {
+ cfg := config.Get()
+ if !hyperp.Enabled() {
+ return fmt.Errorf("hyper not enabled")
+ }
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
- go func() {
- <-ctx.Done()
- cancel()
- os.Exit(1)
- }()
+ defer cancel()
+
+ resp, err := hyper.InitiateDeviceAuth(ctx)
+ if err != nil {
+ return err
+ }
+
+ if clipboard.WriteAll(resp.UserCode) == nil {
+ fmt.Println("The following code should be on clipboard already:")
+ } else {
+ fmt.Println("Copy the following code:")
+ }
+
+ fmt.Println()
+ fmt.Println(lipgloss.NewStyle().Bold(true).Render(resp.UserCode))
+ fmt.Println()
+ fmt.Println("Press enter to open this URL, and then paste it there:")
+ fmt.Println()
+ fmt.Println(lipgloss.NewStyle().Hyperlink(resp.VerificationURL, "id=hyper").Render(resp.VerificationURL))
+ fmt.Println()
+ waitEnter()
+ if err := browser.OpenURL(resp.VerificationURL); err != nil {
+ fmt.Println("Could not open the URL. You'll need to manually open the URL in your browser.")
+ }
+
+ fmt.Println("Exchanging authorization code...")
+ refreshToken, err := hyper.PollForToken(ctx, resp.DeviceCode, resp.ExpiresIn)
+ if err != nil {
+ return err
+ }
+
+ fmt.Println("Exchanging refresh token for access token...")
+ token, err := hyper.ExchangeToken(ctx, refreshToken)
+ if err != nil {
+ return err
+ }
+
+ fmt.Println("Verifying access token...")
+ introspect, err := hyper.IntrospectToken(ctx, token.AccessToken)
+ if err != nil {
+ return fmt.Errorf("token introspection failed: %w", err)
+ }
+ if !introspect.Active {
+ return fmt.Errorf("access token is not active")
+ }
+
+ if err := cmp.Or(
+ cfg.SetConfigField("providers.hyper.api_key", token.AccessToken),
+ cfg.SetConfigField("providers.hyper.oauth", token),
+ ); err != nil {
+ return err
+ }
+
+ fmt.Println()
+ fmt.Println("You're now authenticated with Hyper!")
+ return nil
+}
+
+func loginClaude() error {
+ ctx := getLoginContext()
+
+ cfg := config.Get()
+ if cfg.HasConfigField("providers.anthropic.oauth") {
+ fmt.Println("You are already logged in to Claude.")
+ return nil
+ }
verifier, challenge, err := claude.GetChallenge()
if err != nil {
@@ -94,7 +176,6 @@ func loginClaude() error {
return err
}
- cfg := config.Get()
if err := cmp.Or(
cfg.SetConfigField("providers.anthropic.api_key", token.AccessToken),
cfg.SetConfigField("providers.anthropic.oauth", token),
@@ -106,3 +187,83 @@ func loginClaude() error {
fmt.Println("You're now authenticated with Claude Code Max!")
return nil
}
+
+func loginCopilot() error {
+ ctx := getLoginContext()
+
+ cfg := config.Get()
+ if cfg.HasConfigField("providers.copilot.oauth") {
+ fmt.Println("You are already logged in to GitHub Copilot.")
+ return nil
+ }
+
+ diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
+ var token *oauth.Token
+
+ switch {
+ case hasDiskToken:
+ fmt.Println("Found existing GitHub Copilot token on disk. Using it to authenticate...")
+
+ t, err := copilot.RefreshToken(ctx, diskToken)
+ if err != nil {
+ return fmt.Errorf("unable to refresh token from disk: %w", err)
+ }
+ token = t
+ default:
+ fmt.Println("Requesting device code from GitHub...")
+ dc, err := copilot.RequestDeviceCode(ctx)
+ if err != nil {
+ return err
+ }
+
+ fmt.Println()
+ fmt.Println("Open the following URL and follow the instructions to authenticate with GitHub Copilot:")
+ fmt.Println()
+ fmt.Println(lipgloss.NewStyle().Hyperlink(dc.VerificationURI, "id=copilot").Render(dc.VerificationURI))
+ fmt.Println()
+ fmt.Println("Code:", lipgloss.NewStyle().Bold(true).Render(dc.UserCode))
+ fmt.Println()
+ fmt.Println("Waiting for authorization...")
+
+ t, err := copilot.PollForToken(ctx, dc)
+ if err == copilot.ErrNotAvailable {
+ fmt.Println()
+ fmt.Println("GitHub Copilot is unavailable for this account. To signup, go to the following page:")
+ fmt.Println()
+ fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.SignupURL, "id=copilot-signup").Render(copilot.SignupURL))
+ fmt.Println()
+ fmt.Println("You may be able to request free access if elegible. For more information, see:")
+ fmt.Println()
+ fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.FreeURL, "id=copilot-free").Render(copilot.FreeURL))
+ }
+ if err != nil {
+ return err
+ }
+ token = t
+ }
+
+ if err := cmp.Or(
+ cfg.SetConfigField("providers.copilot.api_key", token.AccessToken),
+ cfg.SetConfigField("providers.copilot.oauth", token),
+ ); err != nil {
+ return err
+ }
+
+ fmt.Println()
+ fmt.Println("You're now authenticated with GitHub Copilot!")
+ return nil
+}
+
+func getLoginContext() context.Context {
+ ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
+ go func() {
+ <-ctx.Done()
+ cancel()
+ os.Exit(1)
+ }()
+ return ctx
+}
+
+func waitEnter() {
+ _, _ = fmt.Scanln()
+}
@@ -13,14 +13,13 @@ import (
var projectsCmd = &cobra.Command{
Use: "projects",
- Short: "List all tracked project directories",
- Long: `List all directories where Crush has been used.
-This includes the working directory, data directory path, and last accessed time.`,
+ Short: "List project directories",
+ Long: "List directories where Crush project data is known to exist",
Example: `
# List all projects in a table
crush projects
-# Output as JSON
+# Output projects data as JSON
crush projects --json
`,
RunE: func(cmd *cobra.Command, args []string) error {
@@ -10,22 +10,30 @@ import (
"github.com/spf13/cobra"
)
+var updateProvidersSource string
+
var updateProvidersCmd = &cobra.Command{
Use: "update-providers [path-or-url]",
Short: "Update providers",
- Long: `Update the list of providers from a specified local path or remote URL.`,
+ Long: `Update provider information from a specified local path or remote URL.`,
Example: `
-# Update providers remotely from Catwalk
+# Update Catwalk providers remotely (default)
crush update-providers
-# Update providers from a custom URL
-crush update-providers https://example.com/
+# Update Catwalk providers from a custom URL
+crush update-providers https://example.com/providers.json
-# Update providers from a local file
+# Update Catwalk providers from a local file
crush update-providers /path/to/local-providers.json
-# Update providers from embedded version
+# Update Catwalk providers from embedded version
crush update-providers embedded
+
+# Update Hyper provider information
+crush update-providers --source=hyper
+
+# Update Hyper from a custom URL
+crush update-providers --source=hyper https://hyper.example.com
`,
RunE: func(cmd *cobra.Command, args []string) error {
// NOTE(@andreynering): We want to skip logging output do stdout here.
@@ -36,7 +44,17 @@ crush update-providers embedded
pathOrURL = args[0]
}
- if err := config.UpdateProviders(pathOrURL); err != nil {
+ var err error
+ switch updateProvidersSource {
+ case "catwalk":
+ err = config.UpdateProviders(pathOrURL)
+ case "hyper":
+ err = config.UpdateHyper(pathOrURL)
+ default:
+ return fmt.Errorf("invalid source %q, must be 'catwalk' or 'hyper'", updateProvidersSource)
+ }
+
+ if err != nil {
return err
}
@@ -52,9 +70,13 @@ crush update-providers embedded
SetString("SUCCESS")
textStyle := lipgloss.NewStyle().
MarginLeft(2).
- SetString("Providers updated successfully.")
+ SetString(fmt.Sprintf("%s provider updated successfully.", updateProvidersSource))
fmt.Printf("%s\n%s\n\n", headerStyle.Render(), textStyle.Render())
return nil
},
}
+
+func init() {
+ updateProvidersCmd.Flags().StringVar(&updateProvidersSource, "source", "catwalk", "Provider source to update (catwalk or hyper)")
+}
@@ -0,0 +1,82 @@
+package config
+
+import (
+ "context"
+ "errors"
+ "log/slog"
+ "sync"
+ "sync/atomic"
+
+ "github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/catwalk/pkg/embedded"
+)
+
+type catwalkClient interface {
+ GetProviders(context.Context, string) ([]catwalk.Provider, error)
+}
+
+var _ syncer[[]catwalk.Provider] = (*catwalkSync)(nil)
+
+type catwalkSync struct {
+ once sync.Once
+ result []catwalk.Provider
+ cache cache[[]catwalk.Provider]
+ client catwalkClient
+ autoupdate bool
+ init atomic.Bool
+}
+
+func (s *catwalkSync) Init(client catwalkClient, path string, autoupdate bool) {
+ s.client = client
+ s.cache = newCache[[]catwalk.Provider](path)
+ s.autoupdate = autoupdate
+ s.init.Store(true)
+}
+
+func (s *catwalkSync) Get(ctx context.Context) ([]catwalk.Provider, error) {
+ if !s.init.Load() {
+ panic("called Get before Init")
+ }
+
+ var throwErr error
+ s.once.Do(func() {
+ if !s.autoupdate {
+ slog.Info("Using embedded Catwalk providers")
+ s.result = embedded.GetAll()
+ return
+ }
+
+ cached, etag, cachedErr := s.cache.Get()
+ if len(cached) == 0 || cachedErr != nil {
+ // if cached file is empty, default to embedded providers
+ cached = embedded.GetAll()
+ }
+
+ slog.Info("Fetching providers from Catwalk")
+ result, err := s.client.GetProviders(ctx, etag)
+ if errors.Is(err, context.DeadlineExceeded) {
+ slog.Warn("Catwalk providers not updated in time")
+ s.result = cached
+ return
+ }
+ if errors.Is(err, catwalk.ErrNotModified) {
+ slog.Info("Catwalk providers not modified")
+ s.result = cached
+ return
+ }
+ if err != nil {
+ // On error, fall back to cached (which defaults to embedded if empty).
+ s.result = cached
+ return
+ }
+ if len(result) == 0 {
+ s.result = cached
+ throwErr = errors.New("empty providers list from catwalk")
+ return
+ }
+
+ s.result = result
+ throwErr = s.cache.Store(result)
+ })
+ return s.result, throwErr
+}
@@ -0,0 +1,221 @@
+package config
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "os"
+ "testing"
+
+ "github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/stretchr/testify/require"
+)
+
+type mockCatwalkClient struct {
+ providers []catwalk.Provider
+ err error
+ callCount int
+}
+
+func (m *mockCatwalkClient) GetProviders(ctx context.Context, etag string) ([]catwalk.Provider, error) {
+ m.callCount++
+ return m.providers, m.err
+}
+
+func TestCatwalkSync_Init(t *testing.T) {
+ t.Parallel()
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{}
+ path := "/tmp/test.json"
+
+ syncer.Init(client, path, true)
+
+ require.True(t, syncer.init.Load())
+ require.Equal(t, client, syncer.client)
+ require.Equal(t, path, syncer.cache.path)
+ require.True(t, syncer.autoupdate)
+}
+
+func TestCatwalkSync_GetPanicIfNotInit(t *testing.T) {
+ t.Parallel()
+
+ syncer := &catwalkSync{}
+ require.Panics(t, func() {
+ _, _ = syncer.Get(t.Context())
+ })
+}
+
+func TestCatwalkSync_GetWithAutoUpdateDisabled(t *testing.T) {
+ t.Parallel()
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{
+ providers: []catwalk.Provider{{Name: "should-not-be-used"}},
+ }
+ path := t.TempDir() + "/providers.json"
+
+ syncer.Init(client, path, false)
+
+ providers, err := syncer.Get(t.Context())
+ require.NoError(t, err)
+ require.NotEmpty(t, providers)
+ require.Equal(t, 0, client.callCount, "Client should not be called when autoupdate is disabled")
+
+ // Should return embedded providers.
+ for _, p := range providers {
+ require.NotEqual(t, "should-not-be-used", p.Name)
+ }
+}
+
+func TestCatwalkSync_GetFreshProviders(t *testing.T) {
+ t.Parallel()
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{
+ providers: []catwalk.Provider{
+ {Name: "Fresh Provider", ID: "fresh"},
+ },
+ }
+ path := t.TempDir() + "/providers.json"
+
+ syncer.Init(client, path, true)
+
+ providers, err := syncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Len(t, providers, 1)
+ require.Equal(t, "Fresh Provider", providers[0].Name)
+ require.Equal(t, 1, client.callCount)
+
+ // Verify cache was written.
+ fileInfo, err := os.Stat(path)
+ require.NoError(t, err)
+ require.False(t, fileInfo.IsDir())
+}
+
+func TestCatwalkSync_GetNotModifiedUsesCached(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/providers.json"
+
+ // Create cache file.
+ cachedProviders := []catwalk.Provider{
+ {Name: "Cached Provider", ID: "cached"},
+ }
+ data, err := json.Marshal(cachedProviders)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(path, data, 0o644))
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{
+ err: catwalk.ErrNotModified,
+ }
+
+ syncer.Init(client, path, true)
+
+ providers, err := syncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Len(t, providers, 1)
+ require.Equal(t, "Cached Provider", providers[0].Name)
+ require.Equal(t, 1, client.callCount)
+}
+
+func TestCatwalkSync_GetEmptyResultFallbackToCached(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/providers.json"
+
+ // Create cache file.
+ cachedProviders := []catwalk.Provider{
+ {Name: "Cached Provider", ID: "cached"},
+ }
+ data, err := json.Marshal(cachedProviders)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(path, data, 0o644))
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{
+ providers: []catwalk.Provider{}, // Empty result.
+ }
+
+ syncer.Init(client, path, true)
+
+ providers, err := syncer.Get(t.Context())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "empty providers list from catwalk")
+ require.Len(t, providers, 1)
+ require.Equal(t, "Cached Provider", providers[0].Name)
+}
+
+func TestCatwalkSync_GetEmptyCacheDefaultsToEmbedded(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/providers.json"
+
+ // Create empty cache file.
+ emptyProviders := []catwalk.Provider{}
+ data, err := json.Marshal(emptyProviders)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(path, data, 0o644))
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{
+ err: errors.New("network error"),
+ }
+
+ syncer.Init(client, path, true)
+
+ providers, err := syncer.Get(t.Context())
+ require.NoError(t, err)
+ require.NotEmpty(t, providers, "Should fall back to embedded providers")
+
+ // Verify it's embedded providers by checking we have multiple common ones.
+ require.Greater(t, len(providers), 5)
+}
+
+func TestCatwalkSync_GetClientError(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/providers.json"
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{
+ err: errors.New("network error"),
+ }
+
+ syncer.Init(client, path, true)
+
+ providers, err := syncer.Get(t.Context())
+ require.NoError(t, err) // Should fall back to embedded.
+ require.NotEmpty(t, providers)
+}
+
+func TestCatwalkSync_GetCalledMultipleTimesUsesOnce(t *testing.T) {
+ t.Parallel()
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{
+ providers: []catwalk.Provider{
+ {Name: "Provider", ID: "test"},
+ },
+ }
+ path := t.TempDir() + "/providers.json"
+
+ syncer.Init(client, path, true)
+
+ // Call Get multiple times.
+ providers1, err1 := syncer.Get(t.Context())
+ require.NoError(t, err1)
+ require.Len(t, providers1, 1)
+
+ providers2, err2 := syncer.Get(t.Context())
+ require.NoError(t, err2)
+ require.Len(t, providers2, 1)
+
+ // Client should only be called once due to sync.Once.
+ require.Equal(t, 1, client.callCount)
+}
@@ -5,6 +5,7 @@ import (
"context"
"fmt"
"log/slog"
+ "maps"
"net/http"
"net/url"
"os"
@@ -13,11 +14,15 @@ import (
"time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
+ hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/oauth"
"github.com/charmbracelet/crush/internal/oauth/claude"
+ "github.com/charmbracelet/crush/internal/oauth/copilot"
+ "github.com/charmbracelet/crush/internal/oauth/hyper"
"github.com/invopop/jsonschema"
+ "github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -119,8 +124,37 @@ type ProviderConfig struct {
Models []catwalk.Model `json:"models,omitempty" jsonschema:"description=List of models available from this provider"`
}
+// ToProvider converts the [ProviderConfig] to a [catwalk.Provider].
+func (pc *ProviderConfig) ToProvider() catwalk.Provider {
+ // Convert config provider to provider.Provider format
+ provider := catwalk.Provider{
+ Name: pc.Name,
+ ID: catwalk.InferenceProvider(pc.ID),
+ Models: make([]catwalk.Model, len(pc.Models)),
+ }
+
+ // Convert models
+ for i, model := range pc.Models {
+ provider.Models[i] = catwalk.Model{
+ ID: model.ID,
+ Name: model.Name,
+ CostPer1MIn: model.CostPer1MIn,
+ CostPer1MOut: model.CostPer1MOut,
+ CostPer1MInCached: model.CostPer1MInCached,
+ CostPer1MOutCached: model.CostPer1MOutCached,
+ ContextWindow: model.ContextWindow,
+ DefaultMaxTokens: model.DefaultMaxTokens,
+ CanReason: model.CanReason,
+ ReasoningLevels: model.ReasoningLevels,
+ DefaultReasoningEffort: model.DefaultReasoningEffort,
+ SupportsImages: model.SupportsImages,
+ }
+ }
+
+ return provider
+}
+
func (pc *ProviderConfig) SetupClaudeCode() {
- pc.APIKey = fmt.Sprintf("Bearer %s", pc.OAuthToken.AccessToken)
pc.SystemPromptPrefix = "You are Claude Code, Anthropic's official CLI for Claude."
pc.ExtraHeaders["anthropic-version"] = "2023-06-01"
@@ -135,6 +169,10 @@ func (pc *ProviderConfig) SetupClaudeCode() {
pc.ExtraHeaders["anthropic-beta"] = value
}
+func (pc *ProviderConfig) SetupGitHubCopilot() {
+ maps.Copy(pc.ExtraHeaders, copilot.Headers())
+}
+
type MCPType string
const (
@@ -451,6 +489,14 @@ func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model Selecte
return nil
}
+func (c *Config) HasConfigField(key string) bool {
+ data, err := os.ReadFile(c.dataConfigDir)
+ if err != nil {
+ return false
+ }
+ return gjson.Get(string(data), key).Exists()
+}
+
func (c *Config) SetConfigField(key string, value any) error {
// read the data
data, err := os.ReadFile(c.dataConfigDir)
@@ -483,20 +529,33 @@ func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error
return fmt.Errorf("provider %s does not have an OAuth token", providerID)
}
- // Only Anthropic provider uses OAuth for now.
- if providerID != string(catwalk.InferenceProviderAnthropic) {
+ var newToken *oauth.Token
+ var refreshErr error
+ switch providerID {
+ case string(catwalk.InferenceProviderAnthropic):
+ newToken, refreshErr = claude.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
+ case string(catwalk.InferenceProviderCopilot):
+ newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
+ case hyperp.Name:
+ newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
+ default:
return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
}
-
- newToken, err := claude.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
- if err != nil {
- return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, err)
+ if refreshErr != nil {
+ return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
}
slog.Info("Successfully refreshed OAuth token", "provider", providerID)
providerConfig.OAuthToken = newToken
- providerConfig.APIKey = fmt.Sprintf("Bearer %s", newToken.AccessToken)
- providerConfig.SetupClaudeCode()
+
+ switch providerID {
+ case string(catwalk.InferenceProviderAnthropic):
+ providerConfig.APIKey = fmt.Sprintf("Bearer %s", newToken.AccessToken)
+ providerConfig.SetupClaudeCode()
+ case string(catwalk.InferenceProviderCopilot):
+ providerConfig.APIKey = newToken.AccessToken
+ providerConfig.SetupGitHubCopilot()
+ }
c.Providers.Set(providerID, providerConfig)
@@ -531,7 +590,13 @@ func (c *Config) SetProviderAPIKey(providerID string, apiKey any) error {
setKeyOrToken = func() {
providerConfig.APIKey = v.AccessToken
providerConfig.OAuthToken = v
- providerConfig.SetupClaudeCode()
+ switch providerID {
+ case string(catwalk.InferenceProviderAnthropic):
+ providerConfig.APIKey = fmt.Sprintf("Bearer %s", v.AccessToken)
+ providerConfig.SetupClaudeCode()
+ case string(catwalk.InferenceProviderCopilot):
+ providerConfig.SetupGitHubCopilot()
+ }
}
}
@@ -0,0 +1,43 @@
+package config
+
+import (
+ "cmp"
+ "context"
+ "log/slog"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/oauth"
+ "github.com/charmbracelet/crush/internal/oauth/copilot"
+)
+
+func (c *Config) importCopilot() (*oauth.Token, bool) {
+ if testing.Testing() {
+ return nil, false
+ }
+
+ if c.HasConfigField("providers.copilot.api_key") || c.HasConfigField("providers.copilot.oauth") {
+ return nil, false
+ }
+
+ diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
+ if !hasDiskToken {
+ return nil, false
+ }
+
+ slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
+ token, err := copilot.RefreshToken(context.TODO(), diskToken)
+ if err != nil {
+ slog.Error("Unable to import GitHub Copilot token", "error", err)
+ return nil, false
+ }
+
+ if err := cmp.Or(
+ c.SetConfigField("providers.copilot.api_key", token.AccessToken),
+ c.SetConfigField("providers.copilot.oauth", token),
+ ); err != nil {
+ slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
+ }
+
+ slog.Info("GitHub Copilot successfully imported")
+ return token, true
+}
@@ -0,0 +1,124 @@
+package config
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/agent/hyper"
+ xetag "github.com/charmbracelet/x/etag"
+)
+
+type hyperClient interface {
+ Get(context.Context, string) (catwalk.Provider, error)
+}
+
+var _ syncer[catwalk.Provider] = (*hyperSync)(nil)
+
+type hyperSync struct {
+ once sync.Once
+ result catwalk.Provider
+ cache cache[catwalk.Provider]
+ client hyperClient
+ autoupdate bool
+ init atomic.Bool
+}
+
+func (s *hyperSync) Init(client hyperClient, path string, autoupdate bool) {
+ s.client = client
+ s.cache = newCache[catwalk.Provider](path)
+ s.autoupdate = autoupdate
+ s.init.Store(true)
+}
+
+func (s *hyperSync) Get(ctx context.Context) (catwalk.Provider, error) {
+ if !s.init.Load() {
+ panic("called Get before Init")
+ }
+
+ var throwErr error
+ s.once.Do(func() {
+ if !s.autoupdate {
+ slog.Info("Using embedded Hyper provider")
+ s.result = hyper.Embedded()
+ return
+ }
+
+ cached, etag, cachedErr := s.cache.Get()
+ if cached.ID == "" || cachedErr != nil {
+ // if cached file is empty, default to embedded provider
+ cached = hyper.Embedded()
+ }
+
+ slog.Info("Fetching Hyper provider")
+ result, err := s.client.Get(ctx, etag)
+ if errors.Is(err, context.DeadlineExceeded) {
+ slog.Warn("Hyper provider not updated in time")
+ s.result = cached
+ return
+ }
+ if errors.Is(err, catwalk.ErrNotModified) {
+ slog.Info("Hyper provider not modified")
+ s.result = cached
+ return
+ }
+ if len(result.Models) == 0 {
+ slog.Warn("Hyper did not return any models")
+ s.result = cached
+ return
+ }
+
+ s.result = result
+ throwErr = s.cache.Store(result)
+ })
+ return s.result, throwErr
+}
+
+var _ hyperClient = realHyperClient{}
+
+type realHyperClient struct {
+ baseURL string
+}
+
+// Get implements hyperClient.
+func (r realHyperClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) {
+ var result catwalk.Provider
+ req, err := http.NewRequestWithContext(
+ ctx,
+ http.MethodGet,
+ r.baseURL+"/api/v1/provider",
+ nil,
+ )
+ if err != nil {
+ return result, fmt.Errorf("could not create request: %w", err)
+ }
+ xetag.Request(req, etag)
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return result, fmt.Errorf("failed to make request: %w", err)
+ }
+ defer resp.Body.Close() //nolint:errcheck
+
+ if resp.StatusCode == http.StatusNotModified {
+ return result, catwalk.ErrNotModified
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return result, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+ }
+
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ return result, fmt.Errorf("failed to decode response: %w", err)
+ }
+
+ return result, nil
+}
@@ -0,0 +1,205 @@
+package config
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "os"
+ "testing"
+
+ "github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/stretchr/testify/require"
+)
+
+type mockHyperClient struct {
+ provider catwalk.Provider
+ err error
+ callCount int
+}
+
+func (m *mockHyperClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) {
+ m.callCount++
+ return m.provider, m.err
+}
+
+func TestHyperSync_Init(t *testing.T) {
+ t.Parallel()
+
+ syncer := &hyperSync{}
+ client := &mockHyperClient{}
+ path := "/tmp/hyper.json"
+
+ syncer.Init(client, path, true)
+
+ require.True(t, syncer.init.Load())
+ require.Equal(t, client, syncer.client)
+ require.Equal(t, path, syncer.cache.path)
+}
+
+func TestHyperSync_GetPanicIfNotInit(t *testing.T) {
+ t.Parallel()
+
+ syncer := &hyperSync{}
+ require.Panics(t, func() {
+ _, _ = syncer.Get(t.Context())
+ })
+}
+
+func TestHyperSync_GetFreshProvider(t *testing.T) {
+ t.Parallel()
+
+ syncer := &hyperSync{}
+ client := &mockHyperClient{
+ provider: catwalk.Provider{
+ Name: "Hyper",
+ ID: "hyper",
+ Models: []catwalk.Model{
+ {ID: "model-1", Name: "Model 1"},
+ },
+ },
+ }
+ path := t.TempDir() + "/hyper.json"
+
+ syncer.Init(client, path, true)
+
+ provider, err := syncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Equal(t, "Hyper", provider.Name)
+ require.Equal(t, 1, client.callCount)
+
+ // Verify cache was written.
+ fileInfo, err := os.Stat(path)
+ require.NoError(t, err)
+ require.False(t, fileInfo.IsDir())
+}
+
+func TestHyperSync_GetNotModifiedUsesCached(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/hyper.json"
+
+ // Create cache file.
+ cachedProvider := catwalk.Provider{
+ Name: "Cached Hyper",
+ ID: "hyper",
+ }
+ data, err := json.Marshal(cachedProvider)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(path, data, 0o644))
+
+ syncer := &hyperSync{}
+ client := &mockHyperClient{
+ err: catwalk.ErrNotModified,
+ }
+
+ syncer.Init(client, path, true)
+
+ provider, err := syncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Equal(t, "Cached Hyper", provider.Name)
+ require.Equal(t, 1, client.callCount)
+}
+
+func TestHyperSync_GetClientError(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/hyper.json"
+
+ syncer := &hyperSync{}
+ client := &mockHyperClient{
+ err: errors.New("network error"),
+ }
+
+ syncer.Init(client, path, true)
+
+ provider, err := syncer.Get(t.Context())
+ require.NoError(t, err) // Should fall back to embedded.
+ require.Equal(t, "Charm Hyper", provider.Name)
+ require.Equal(t, catwalk.InferenceProvider("hyper"), provider.ID)
+}
+
+func TestHyperSync_GetEmptyCache(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/hyper.json"
+
+ syncer := &hyperSync{}
+ client := &mockHyperClient{
+ provider: catwalk.Provider{
+ Name: "Fresh Hyper",
+ ID: "hyper",
+ Models: []catwalk.Model{
+ {ID: "model-1", Name: "Model 1"},
+ },
+ },
+ }
+
+ syncer.Init(client, path, true)
+
+ provider, err := syncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Equal(t, "Fresh Hyper", provider.Name)
+}
+
+func TestHyperSync_GetCalledMultipleTimesUsesOnce(t *testing.T) {
+ t.Parallel()
+
+ syncer := &hyperSync{}
+ client := &mockHyperClient{
+ provider: catwalk.Provider{
+ Name: "Hyper",
+ ID: "hyper",
+ Models: []catwalk.Model{
+ {ID: "model-1", Name: "Model 1"},
+ },
+ },
+ }
+ path := t.TempDir() + "/hyper.json"
+
+ syncer.Init(client, path, true)
+
+ // Call Get multiple times.
+ provider1, err1 := syncer.Get(t.Context())
+ require.NoError(t, err1)
+ require.Equal(t, "Hyper", provider1.Name)
+
+ provider2, err2 := syncer.Get(t.Context())
+ require.NoError(t, err2)
+ require.Equal(t, "Hyper", provider2.Name)
+
+ // Client should only be called once due to sync.Once.
+ require.Equal(t, 1, client.callCount)
+}
+
+func TestHyperSync_GetCacheStoreError(t *testing.T) {
+ t.Parallel()
+
+ // Create a file where we want a directory, causing mkdir to fail.
+ tmpDir := t.TempDir()
+ blockingFile := tmpDir + "/blocking"
+ require.NoError(t, os.WriteFile(blockingFile, []byte("block"), 0o644))
+
+ // Try to create cache in a subdirectory under the blocking file.
+ path := blockingFile + "/subdir/hyper.json"
+
+ syncer := &hyperSync{}
+ client := &mockHyperClient{
+ provider: catwalk.Provider{
+ Name: "Hyper",
+ ID: "hyper",
+ Models: []catwalk.Model{
+ {ID: "model-1", Name: "Model 1"},
+ },
+ },
+ }
+
+ syncer.Init(client, path, true)
+
+ provider, err := syncer.Get(t.Context())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "failed to create directory for provider cache")
+ require.Equal(t, "Hyper", provider.Name) // Provider is still returned.
+}
@@ -1,6 +1,7 @@
package config
import (
+ "cmp"
"context"
"encoding/json"
"fmt"
@@ -17,6 +18,7 @@ import (
"testing"
"github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/fsext"
@@ -131,6 +133,8 @@ func PushPopCrushEnv() func() {
}
func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error {
+ c.importCopilot()
+
knownProviderNames := make(map[string]bool)
restore := PushPopCrushEnv()
defer restore()
@@ -198,8 +202,18 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
Models: p.Models,
}
- if p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil {
+ switch {
+ case p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil:
prepared.SetupClaudeCode()
+ case p.ID == catwalk.InferenceProviderCopilot:
+ if config.OAuthToken != nil {
+ if token, ok := c.importCopilot(); ok {
+ prepared.OAuthToken = token
+ }
+ }
+ if config.OAuthToken != nil {
+ prepared.SetupGitHubCopilot()
+ }
}
switch p.ID {
@@ -271,7 +285,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
if providerConfig.Type == "" {
providerConfig.Type = catwalk.TypeOpenAICompat
}
- if !slices.Contains(catwalk.KnownProviderTypes(), providerConfig.Type) {
+ if !slices.Contains(catwalk.KnownProviderTypes(), providerConfig.Type) && providerConfig.Type != hyper.Name {
slog.Warn("Skipping custom provider due to unsupported provider type", "provider", id)
c.Providers.Del(id)
continue
@@ -682,19 +696,22 @@ func hasAWSCredentials(env env.Env) bool {
// GlobalConfig returns the global configuration file path for the application.
func GlobalConfig() string {
- xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
- if xdgConfigHome != "" {
+ if crushGlobal := os.Getenv("CRUSH_GLOBAL_CONFIG"); crushGlobal != "" {
+ return filepath.Join(crushGlobal, fmt.Sprintf("%s.json", appName))
+ }
+ if xdgConfigHome := os.Getenv("XDG_CONFIG_HOME"); xdgConfigHome != "" {
return filepath.Join(xdgConfigHome, appName, fmt.Sprintf("%s.json", appName))
}
-
return filepath.Join(home.Dir(), ".config", appName, fmt.Sprintf("%s.json", appName))
}
// GlobalConfigData returns the path to the main data directory for the application.
// this config is used when the app overrides configurations instead of updating the global config.
func GlobalConfigData() string {
- xdgDataHome := os.Getenv("XDG_DATA_HOME")
- if xdgDataHome != "" {
+ if crushData := os.Getenv("CRUSH_GLOBAL_DATA"); crushData != "" {
+ return filepath.Join(crushData, fmt.Sprintf("%s.json", appName))
+ }
+ if xdgDataHome := os.Getenv("XDG_DATA_HOME"); xdgDataHome != "" {
return filepath.Join(xdgDataHome, appName, fmt.Sprintf("%s.json", appName))
}
@@ -702,10 +719,10 @@ func GlobalConfigData() string {
// for windows, it should be in `%LOCALAPPDATA%/crush/`
// for linux and macOS, it should be in `$HOME/.local/share/crush/`
if runtime.GOOS == "windows" {
- localAppData := os.Getenv("LOCALAPPDATA")
- if localAppData == "" {
- localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
- }
+ localAppData := cmp.Or(
+ os.Getenv("LOCALAPPDATA"),
+ filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local"),
+ )
return filepath.Join(localAppData, appName, fmt.Sprintf("%s.json", appName))
}
@@ -10,16 +10,21 @@ import (
"os"
"path/filepath"
"runtime"
+ "slices"
"strings"
"sync"
+ "time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/catwalk/pkg/embedded"
+ "github.com/charmbracelet/crush/internal/agent/hyper"
+ "github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/home"
+ "github.com/charmbracelet/x/etag"
)
-type ProviderClient interface {
- GetProviders(context.Context, string) ([]catwalk.Provider, error)
+type syncer[T any] interface {
+ Get(context.Context) (T, error)
}
var (
@@ -29,10 +34,10 @@ var (
)
// file to cache provider data
-func providerCacheFileData() string {
+func cachePathFor(name string) string {
xdgDataHome := os.Getenv("XDG_DATA_HOME")
if xdgDataHome != "" {
- return filepath.Join(xdgDataHome, appName, "providers.json")
+ return filepath.Join(xdgDataHome, appName, name+".json")
}
// return the path to the main data directory
@@ -43,43 +48,13 @@ func providerCacheFileData() string {
if localAppData == "" {
localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
}
- return filepath.Join(localAppData, appName, "providers.json")
+ return filepath.Join(localAppData, appName, name+".json")
}
- return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json")
-}
-
-func saveProvidersInCache(path string, providers []catwalk.Provider) error {
- slog.Info("Saving provider data to disk", "path", path)
- if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
- return fmt.Errorf("failed to create directory for provider cache: %w", err)
- }
-
- data, err := json.Marshal(providers)
- if err != nil {
- return fmt.Errorf("failed to marshal provider data: %w", err)
- }
-
- if err := os.WriteFile(path, data, 0o644); err != nil {
- return fmt.Errorf("failed to write provider data to cache: %w", err)
- }
- return nil
-}
-
-func loadProvidersFromCache(path string) ([]catwalk.Provider, string, error) {
- data, err := os.ReadFile(path)
- if err != nil {
- return nil, "", fmt.Errorf("failed to read provider cache file: %w", err)
- }
-
- var providers []catwalk.Provider
- if err := json.Unmarshal(data, &providers); err != nil {
- return nil, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
- }
-
- return providers, catwalk.Etag(data), nil
+ return filepath.Join(home.Dir(), ".local", "share", appName, name+".json")
}
+// UpdateProviders updates the Catwalk providers list from a specified source.
func UpdateProviders(pathOrURL string) error {
var providers []catwalk.Provider
pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
@@ -106,15 +81,55 @@ func UpdateProviders(pathOrURL string) error {
}
}
- cachePath := providerCacheFileData()
- if err := saveProvidersInCache(cachePath, providers); err != nil {
+ if err := newCache[[]catwalk.Provider](cachePathFor("providers")).Store(providers); err != nil {
return fmt.Errorf("failed to save providers to cache: %w", err)
}
- slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePath)
+ slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePathFor)
+ return nil
+}
+
+// UpdateHyper updates the Hyper provider information from a specified URL.
+func UpdateHyper(pathOrURL string) error {
+ if !hyper.Enabled() {
+ return fmt.Errorf("hyper not enabled")
+ }
+ var provider catwalk.Provider
+ pathOrURL = cmp.Or(pathOrURL, hyper.BaseURL())
+
+ switch {
+ case pathOrURL == "embedded":
+ provider = hyper.Embedded()
+ case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
+ client := realHyperClient{baseURL: pathOrURL}
+ var err error
+ provider, err = client.Get(context.Background(), "")
+ if err != nil {
+ return fmt.Errorf("failed to fetch provider from Hyper: %w", err)
+ }
+ default:
+ content, err := os.ReadFile(pathOrURL)
+ if err != nil {
+ return fmt.Errorf("failed to read file: %w", err)
+ }
+ if err := json.Unmarshal(content, &provider); err != nil {
+ return fmt.Errorf("failed to unmarshal provider data: %w", err)
+ }
+ }
+
+ if err := newCache[catwalk.Provider](cachePathFor("hyper")).Store(provider); err != nil {
+ return fmt.Errorf("failed to save Hyper provider to cache: %w", err)
+ }
+
+ slog.Info("Hyper provider updated successfully", "from", pathOrURL, "to", cachePathFor("hyper"))
return nil
}
+var (
+ catwalkSyncer = &catwalkSync{}
+ hyperSyncer = &hyperSync{}
+)
+
// Providers returns the list of providers, taking into account cached results
// and whether or not auto update is enabled.
//
@@ -126,46 +141,87 @@ func UpdateProviders(pathOrURL string) error {
// the cached list, or the embedded list if all others fail.
func Providers(cfg *Config) ([]catwalk.Provider, error) {
providerOnce.Do(func() {
- catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
- client := catwalk.NewWithURL(catwalkURL)
- path := providerCacheFileData()
-
- if cfg.Options.DisableProviderAutoUpdate {
- slog.Info("Using embedded Catwalk providers")
- providerList, providerErr = embedded.GetAll(), nil
- return
- }
+ var wg sync.WaitGroup
+ var errs []error
+ providers := csync.NewSlice[catwalk.Provider]()
+ autoupdate := !cfg.Options.DisableProviderAutoUpdate
+
+ ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
+ defer cancel()
+
+ wg.Go(func() {
+ catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
+ client := catwalk.NewWithURL(catwalkURL)
+ path := cachePathFor("providers")
+ catwalkSyncer.Init(client, path, autoupdate)
+
+ items, err := catwalkSyncer.Get(ctx)
+ if err != nil {
+ catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
+ errs = append(errs, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr)) //nolint:staticcheck
+ return
+ }
+ providers.Append(items...)
+ })
+
+ wg.Go(func() {
+ if !hyper.Enabled() {
+ return
+ }
+ path := cachePathFor("hyper")
+ hyperSyncer.Init(realHyperClient{baseURL: hyper.BaseURL()}, path, autoupdate)
+
+ item, err := hyperSyncer.Get(ctx)
+ if err != nil {
+ errs = append(errs, fmt.Errorf("Crush was unable to fetch updated information from Hyper: %w", err)) //nolint:staticcheck
+ return
+ }
+ providers.Append(item)
+ })
+
+ wg.Wait()
+
+ providerList = slices.Collect(providers.Seq())
+ providerErr = errors.Join(errs...)
+ })
+ return providerList, providerErr
+}
- cached, etag, cachedErr := loadProvidersFromCache(path)
- if len(cached) == 0 || cachedErr != nil {
- // if cached file is empty, default to embedded providers
- cached = embedded.GetAll()
- }
+type cache[T any] struct {
+ path string
+}
- providerList, providerErr = loadProviders(client, etag, path)
- if errors.Is(providerErr, catwalk.ErrNotModified) {
- slog.Info("Catwalk providers not modified")
- providerList, providerErr = cached, nil
- }
- })
- if providerErr != nil {
- catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
- return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr) //nolint:staticcheck
- }
- return providerList, nil
+func newCache[T any](path string) cache[T] {
+ return cache[T]{path: path}
}
-func loadProviders(client ProviderClient, etag, path string) ([]catwalk.Provider, error) {
- slog.Info("Fetching providers from Catwalk.", "path", path)
- providers, err := client.GetProviders(context.Background(), etag)
+func (c cache[T]) Get() (T, string, error) {
+ var v T
+ data, err := os.ReadFile(c.path)
if err != nil {
- return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
+ return v, "", fmt.Errorf("failed to read provider cache file: %w", err)
+ }
+
+ if err := json.Unmarshal(data, &v); err != nil {
+ return v, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
+ }
+
+ return v, etag.Of(data), nil
+}
+
+func (c cache[T]) Store(v T) error {
+ slog.Info("Saving provider data to disk", "path", c.path)
+ if err := os.MkdirAll(filepath.Dir(c.path), 0o755); err != nil {
+ return fmt.Errorf("failed to create directory for provider cache: %w", err)
}
- if len(providers) == 0 {
- return nil, errors.New("empty providers list from catwalk")
+
+ data, err := json.Marshal(v)
+ if err != nil {
+ return fmt.Errorf("failed to marshal provider data: %w", err)
}
- if err := saveProvidersInCache(path, providers); err != nil {
- return nil, err
+
+ if err := os.WriteFile(c.path, data, 0o644); err != nil {
+ return fmt.Errorf("failed to write provider data to cache: %w", err)
}
- return providers, nil
+ return nil
}
@@ -2,6 +2,7 @@ package config
import (
"context"
+ "os"
"testing"
"github.com/charmbracelet/catwalk/pkg/catwalk"
@@ -14,18 +15,25 @@ func (m *emptyProviderClient) GetProviders(context.Context, string) ([]catwalk.P
return []catwalk.Provider{}, nil
}
-// TestProvider_loadProvidersEmptyResult tests that loadProviders returns an
-// error when the client returns an empty list. This ensures we don't cache
-// empty provider lists.
-func TestProvider_loadProvidersEmptyResult(t *testing.T) {
+// TestCatwalkSync_GetEmptyResultFromClient tests that when the client returns
+// an empty list, we fall back to cached providers and return an error.
+func TestCatwalkSync_GetEmptyResultFromClient(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/providers.json"
+
+ syncer := &catwalkSync{}
client := &emptyProviderClient{}
- tmpPath := t.TempDir() + "/providers.json"
- providers, err := loadProviders(client, "", tmpPath)
+ syncer.Init(client, path, true)
+
+ providers, err := syncer.Get(t.Context())
+ require.Error(t, err)
require.Contains(t, err.Error(), "empty providers list from catwalk")
- require.Empty(t, providers)
- require.Len(t, providers, 0)
+ require.NotEmpty(t, providers) // Should have embedded providers as fallback.
- // Check that no cache file was created for empty results
- require.NoFileExists(t, tmpPath, "Cache file should not exist for empty results")
+ // Check that no cache file was created for empty results.
+ _, statErr := os.Stat(path)
+ require.True(t, os.IsNotExist(statErr), "Cache file should not exist for empty results")
}
@@ -1,10 +1,9 @@
package config
import (
- "context"
"encoding/json"
- "errors"
"os"
+ "path/filepath"
"sync"
"testing"
@@ -12,47 +11,31 @@ import (
"github.com/stretchr/testify/require"
)
-type mockProviderClient struct {
- shouldFail bool
- shouldReturnErr error
-}
-
-func (m *mockProviderClient) GetProviders(context.Context, string) ([]catwalk.Provider, error) {
- if m.shouldReturnErr != nil {
- return nil, m.shouldReturnErr
- }
- if m.shouldFail {
- return nil, errors.New("failed to load providers")
- }
- return []catwalk.Provider{
- {
- Name: "Mock",
- },
- }, nil
-}
-
func resetProviderState() {
providerOnce = sync.Once{}
providerList = nil
providerErr = nil
+ catwalkSyncer = &catwalkSync{}
+ hyperSyncer = &hyperSync{}
}
-func TestProvider_loadProvidersNoIssues(t *testing.T) {
- client := &mockProviderClient{shouldFail: false}
- tmpPath := t.TempDir() + "/providers.json"
- providers, err := loadProviders(client, "", tmpPath)
- require.NoError(t, err)
- require.NotNil(t, providers)
- require.Len(t, providers, 1)
+func TestProviders_Integration_AutoUpdateDisabled(t *testing.T) {
+ tmpDir := t.TempDir()
+ t.Setenv("XDG_DATA_HOME", tmpDir)
- // check if file got saved
- fileInfo, err := os.Stat(tmpPath)
- require.NoError(t, err)
- require.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
-}
+ // Use a test-specific instance to avoid global state interference.
+ testCatwalkSyncer := &catwalkSync{}
+ testHyperSyncer := &hyperSync{}
+
+ originalCatwalSyncer := catwalkSyncer
+ originalHyperSyncer := hyperSyncer
+ defer func() {
+ catwalkSyncer = originalCatwalSyncer
+ hyperSyncer = originalHyperSyncer
+ }()
-func TestProvider_DisableAutoUpdate(t *testing.T) {
- t.Setenv("XDG_DATA_HOME", t.TempDir())
+ catwalkSyncer = testCatwalkSyncer
+ hyperSyncer = testHyperSyncer
resetProviderState()
defer resetProviderState()
@@ -69,76 +52,257 @@ func TestProvider_DisableAutoUpdate(t *testing.T) {
require.Greater(t, len(providers), 5, "Expected embedded providers")
}
-func TestProvider_WithValidCache(t *testing.T) {
+func TestProviders_Integration_WithMockClients(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_DATA_HOME", tmpDir)
- resetProviderState()
- defer resetProviderState()
+ // Create fresh syncers for this test.
+ testCatwalkSyncer := &catwalkSync{}
+ testHyperSyncer := &hyperSync{}
+
+ // Initialize with mock clients.
+ mockCatwalkClient := &mockCatwalkClient{
+ providers: []catwalk.Provider{
+ {Name: "Provider1", ID: "p1"},
+ {Name: "Provider2", ID: "p2"},
+ },
+ }
+ mockHyperClient := &mockHyperClient{
+ provider: catwalk.Provider{
+ Name: "Hyper",
+ ID: "hyper",
+ Models: []catwalk.Model{
+ {ID: "hyper-1", Name: "Hyper Model"},
+ },
+ },
+ }
+
+ catwalkPath := tmpDir + "/crush/providers.json"
+ hyperPath := tmpDir + "/crush/hyper.json"
+
+ testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
+ testHyperSyncer.Init(mockHyperClient, hyperPath, true)
+
+ // Get providers from each syncer.
+ catwalkProviders, err := testCatwalkSyncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Len(t, catwalkProviders, 2)
+
+ hyperProvider, err := testHyperSyncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Equal(t, "Hyper", hyperProvider.Name)
+
+ // Verify total.
+ allProviders := append(catwalkProviders, hyperProvider)
+ require.Len(t, allProviders, 3)
+}
+
+func TestProviders_Integration_WithCachedData(t *testing.T) {
+ tmpDir := t.TempDir()
+ t.Setenv("XDG_DATA_HOME", tmpDir)
+
+ // Create cache files.
+ catwalkPath := tmpDir + "/crush/providers.json"
+ hyperPath := tmpDir + "/crush/hyper.json"
- cachePath := tmpDir + "/crush/providers.json"
require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755))
- cachedProviders := []catwalk.Provider{
- {Name: "Cached"},
+
+ // Write Catwalk cache.
+ catwalkProviders := []catwalk.Provider{
+ {Name: "Cached1", ID: "c1"},
+ {Name: "Cached2", ID: "c2"},
+ }
+ data, err := json.Marshal(catwalkProviders)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(catwalkPath, data, 0o644))
+
+ // Write Hyper cache.
+ hyperProvider := catwalk.Provider{
+ Name: "Cached Hyper",
+ ID: "hyper",
}
- data, err := json.Marshal(cachedProviders)
+ data, err = json.Marshal(hyperProvider)
require.NoError(t, err)
- require.NoError(t, os.WriteFile(cachePath, data, 0o644))
+ require.NoError(t, os.WriteFile(hyperPath, data, 0o644))
+
+ // Create fresh syncers.
+ testCatwalkSyncer := &catwalkSync{}
+ testHyperSyncer := &hyperSync{}
+
+ // Mock clients that return ErrNotModified.
+ mockCatwalkClient := &mockCatwalkClient{
+ err: catwalk.ErrNotModified,
+ }
+ mockHyperClient := &mockHyperClient{
+ err: catwalk.ErrNotModified,
+ }
- mockClient := &mockProviderClient{shouldFail: false}
+ testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
+ testHyperSyncer.Init(mockHyperClient, hyperPath, true)
- providers, err := loadProviders(mockClient, "", cachePath)
+ // Get providers - should use cached.
+ catwalkResult, err := testCatwalkSyncer.Get(t.Context())
require.NoError(t, err)
- require.NotNil(t, providers)
- require.Len(t, providers, 1)
- require.Equal(t, "Mock", providers[0].Name, "Expected fresh provider from fetch")
+ require.Len(t, catwalkResult, 2)
+ require.Equal(t, "Cached1", catwalkResult[0].Name)
+
+ hyperResult, err := testHyperSyncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Equal(t, "Cached Hyper", hyperResult.Name)
}
-func TestProvider_NotModifiedUsesCached(t *testing.T) {
+func TestProviders_Integration_CatwalkFailsHyperSucceeds(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_DATA_HOME", tmpDir)
- resetProviderState()
- defer resetProviderState()
+ testCatwalkSyncer := &catwalkSync{}
+ testHyperSyncer := &hyperSync{}
- cachePath := tmpDir + "/crush/providers.json"
- require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755))
- cachedProviders := []catwalk.Provider{
- {Name: "Cached"},
+ // Catwalk fails, Hyper succeeds.
+ mockCatwalkClient := &mockCatwalkClient{
+ err: catwalk.ErrNotModified, // Will use embedded.
}
- data, err := json.Marshal(cachedProviders)
+ mockHyperClient := &mockHyperClient{
+ provider: catwalk.Provider{
+ Name: "Hyper",
+ ID: "hyper",
+ Models: []catwalk.Model{
+ {ID: "hyper-1", Name: "Hyper Model"},
+ },
+ },
+ }
+
+ catwalkPath := tmpDir + "/crush/providers.json"
+ hyperPath := tmpDir + "/crush/hyper.json"
+
+ testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
+ testHyperSyncer.Init(mockHyperClient, hyperPath, true)
+
+ catwalkResult, err := testCatwalkSyncer.Get(t.Context())
require.NoError(t, err)
- require.NoError(t, os.WriteFile(cachePath, data, 0o644))
+ require.NotEmpty(t, catwalkResult) // Should have embedded.
- mockClient := &mockProviderClient{shouldReturnErr: catwalk.ErrNotModified}
- providers, err := loadProviders(mockClient, "", cachePath)
- require.ErrorIs(t, err, catwalk.ErrNotModified)
- require.Nil(t, providers)
+ hyperResult, err := testHyperSyncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Equal(t, "Hyper", hyperResult.Name)
}
-func TestProvider_EmptyCacheDefaultsToEmbedded(t *testing.T) {
+func TestProviders_Integration_BothFail(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_DATA_HOME", tmpDir)
- resetProviderState()
- defer resetProviderState()
+ testCatwalkSyncer := &catwalkSync{}
+ testHyperSyncer := &hyperSync{}
- cachePath := tmpDir + "/crush/providers.json"
- require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755))
- emptyProviders := []catwalk.Provider{}
- data, err := json.Marshal(emptyProviders)
+ // Both fail.
+ mockCatwalkClient := &mockCatwalkClient{
+ err: catwalk.ErrNotModified,
+ }
+ mockHyperClient := &mockHyperClient{
+ provider: catwalk.Provider{}, // Empty provider.
+ }
+
+ catwalkPath := tmpDir + "/crush/providers.json"
+ hyperPath := tmpDir + "/crush/hyper.json"
+
+ testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
+ testHyperSyncer.Init(mockHyperClient, hyperPath, true)
+
+ catwalkResult, err := testCatwalkSyncer.Get(t.Context())
require.NoError(t, err)
- require.NoError(t, os.WriteFile(cachePath, data, 0o644))
+ require.NotEmpty(t, catwalkResult) // Should fall back to embedded.
- cached, _, err := loadProvidersFromCache(cachePath)
+ hyperResult, err := testHyperSyncer.Get(t.Context())
require.NoError(t, err)
- require.Empty(t, cached, "Expected empty cache")
+ require.Equal(t, "Charm Hyper", hyperResult.Name) // Falls back to embedded when no models.
}
-func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
- client := &mockProviderClient{shouldFail: true}
- tmpPath := t.TempDir() + "/providers.json"
- providers, err := loadProviders(client, "", tmpPath)
+func TestCache_StoreAndGet(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ cachePath := tmpDir + "/test.json"
+
+ cache := newCache[[]catwalk.Provider](cachePath)
+
+ providers := []catwalk.Provider{
+ {Name: "Provider1", ID: "p1"},
+ {Name: "Provider2", ID: "p2"},
+ }
+
+ // Store.
+ err := cache.Store(providers)
+ require.NoError(t, err)
+
+ // Get.
+ result, etag, err := cache.Get()
+ require.NoError(t, err)
+ require.Len(t, result, 2)
+ require.Equal(t, "Provider1", result[0].Name)
+ require.NotEmpty(t, etag)
+}
+
+func TestCache_GetNonExistent(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ cachePath := tmpDir + "/nonexistent.json"
+
+ cache := newCache[[]catwalk.Provider](cachePath)
+
+ _, _, err := cache.Get()
require.Error(t, err)
- require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
+ require.Contains(t, err.Error(), "failed to read provider cache file")
+}
+
+func TestCache_GetInvalidJSON(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ cachePath := tmpDir + "/invalid.json"
+
+ require.NoError(t, os.WriteFile(cachePath, []byte("invalid json"), 0o644))
+
+ cache := newCache[[]catwalk.Provider](cachePath)
+
+ _, _, err := cache.Get()
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "failed to unmarshal provider data from cache")
+}
+
+func TestCachePathFor(t *testing.T) {
+ tests := []struct {
+ name string
+ xdgDataHome string
+ expected string
+ }{
+ {
+ name: "with XDG_DATA_HOME",
+ xdgDataHome: "/custom/data",
+ expected: "/custom/data/crush/providers.json",
+ },
+ {
+ name: "without XDG_DATA_HOME",
+ xdgDataHome: "",
+ expected: "", // Will use platform-specific default.
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if tt.xdgDataHome != "" {
+ t.Setenv("XDG_DATA_HOME", tt.xdgDataHome)
+ } else {
+ t.Setenv("XDG_DATA_HOME", "")
+ }
+
+ result := cachePathFor("providers")
+ if tt.expected != "" {
+ require.Equal(t, tt.expected, filepath.ToSlash(result))
+ } else {
+ require.Contains(t, result, "crush")
+ require.Contains(t, result, "providers.json")
+ }
+ })
+ }
}
@@ -46,6 +46,22 @@ func Init() {
distinctId = getDistinctId()
}
+func GetID() string { return distinctId }
+
+func Alias(userID string) {
+ if client == nil || distinctId == fallbackId || distinctId == "" || userID == "" {
+ return
+ }
+ if err := client.Enqueue(posthog.Alias{
+ DistinctId: distinctId,
+ Alias: userID,
+ }); err != nil {
+ slog.Error("Failed to enqueue PostHog alias event", "error", err)
+ return
+ }
+ slog.Info("Aliased in PostHog", "machine_id", distinctId, "user_id", userID)
+}
+
// send logs an event to PostHog with the given event name and properties.
func send(event string, props ...any) {
if client == nil {
@@ -0,0 +1,36 @@
+package copilot
+
+import (
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "runtime"
+)
+
+func RefreshTokenFromDisk() (string, bool) {
+ data, err := os.ReadFile(tokenFilePath())
+ if err != nil {
+ return "", false
+ }
+ var content map[string]struct {
+ User string `json:"user"`
+ OAuthToken string `json:"oauth_token"`
+ GitHubAppID string `json:"githubAppId"`
+ }
+ if err := json.Unmarshal(data, &content); err != nil {
+ return "", false
+ }
+ if app, ok := content["github.com:Iv1.b507a08c87ecfe98"]; ok {
+ return app.OAuthToken, true
+ }
+ return "", false
+}
+
+func tokenFilePath() string {
+ switch runtime.GOOS {
+ case "windows":
+ return filepath.Join(os.Getenv("LOCALAPPDATA"), "github-copilot/apps.json")
+ default:
+ return filepath.Join(os.Getenv("HOME"), ".config/github-copilot/apps.json")
+ }
+}
@@ -0,0 +1,17 @@
+package copilot
+
+const (
+ userAgent = "GitHubCopilotChat/0.32.4"
+ editorVersion = "vscode/1.105.1"
+ editorPluginVersion = "copilot-chat/0.32.4"
+ integrationID = "vscode-chat"
+)
+
+func Headers() map[string]string {
+ return map[string]string{
+ "User-Agent": userAgent,
+ "Editor-Version": editorVersion,
+ "Editor-Plugin-Version": editorPluginVersion,
+ "Copilot-Integration-Id": integrationID,
+ }
+}
@@ -0,0 +1,200 @@
+package copilot
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/charmbracelet/crush/internal/oauth"
+)
+
+const (
+ clientID = "Iv1.b507a08c87ecfe98"
+
+ deviceCodeURL = "https://github.com/login/device/code"
+ accessTokenURL = "https://github.com/login/oauth/access_token"
+ copilotTokenURL = "https://api.github.com/copilot_internal/v2/token"
+)
+
+var ErrNotAvailable = errors.New("github copilot not available")
+
+type DeviceCode struct {
+ DeviceCode string `json:"device_code"`
+ UserCode string `json:"user_code"`
+ VerificationURI string `json:"verification_uri"`
+ ExpiresIn int `json:"expires_in"`
+ Interval int `json:"interval"`
+}
+
+// RequestDeviceCode initiates the device code flow with GitHub.
+func RequestDeviceCode(ctx context.Context) (*DeviceCode, error) {
+ data := url.Values{}
+ data.Set("client_id", clientID)
+ data.Set("scope", "read:user")
+
+ req, err := http.NewRequestWithContext(ctx, "POST", deviceCodeURL, strings.NewReader(data.Encode()))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Set("User-Agent", userAgent)
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ return nil, fmt.Errorf("device code request failed: %s - %s", resp.Status, string(body))
+ }
+
+ var dc DeviceCode
+ if err := json.NewDecoder(resp.Body).Decode(&dc); err != nil {
+ return nil, err
+ }
+ return &dc, nil
+}
+
+// PollForToken polls GitHub for the access token after user authorization.
+func PollForToken(ctx context.Context, dc *DeviceCode) (*oauth.Token, error) {
+ interval := max(dc.Interval, 5)
+ deadline := time.Now().Add(time.Duration(dc.ExpiresIn) * time.Second)
+ ticker := time.NewTicker(time.Duration(interval) * time.Second)
+ defer ticker.Stop()
+
+ for time.Now().Before(deadline) {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-ticker.C:
+ }
+
+ token, err := tryGetToken(ctx, dc.DeviceCode)
+ if err == errPending {
+ continue
+ }
+ if err == errSlowDown {
+ interval += 5
+ ticker.Reset(time.Duration(interval) * time.Second)
+ continue
+ }
+ if err != nil {
+ return nil, err
+ }
+ return token, nil
+ }
+
+ return nil, fmt.Errorf("authorization timed out")
+}
+
+var (
+ errPending = fmt.Errorf("pending")
+ errSlowDown = fmt.Errorf("slow_down")
+)
+
+func tryGetToken(ctx context.Context, deviceCode string) (*oauth.Token, error) {
+ data := url.Values{}
+ data.Set("client_id", clientID)
+ data.Set("device_code", deviceCode)
+ data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
+
+ req, err := http.NewRequestWithContext(ctx, "POST", accessTokenURL, strings.NewReader(data.Encode()))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Set("User-Agent", userAgent)
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ var result struct {
+ AccessToken string `json:"access_token"`
+ Error string `json:"error"`
+ }
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ return nil, err
+ }
+
+ switch result.Error {
+ case "":
+ if result.AccessToken == "" {
+ return nil, errPending
+ }
+ return getCopilotToken(ctx, result.AccessToken)
+ case "authorization_pending":
+ return nil, errPending
+ case "slow_down":
+ return nil, errSlowDown
+ default:
+ return nil, fmt.Errorf("authorization failed: %s", result.Error)
+ }
+}
+
+func getCopilotToken(ctx context.Context, githubToken string) (*oauth.Token, error) {
+ req, err := http.NewRequestWithContext(ctx, "GET", copilotTokenURL, nil)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", githubToken))
+ for k, v := range Headers() {
+ req.Header.Set(k, v)
+ }
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ if resp.StatusCode == http.StatusForbidden {
+ return nil, ErrNotAvailable
+ }
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("copilot token request failed: %s - %s", resp.Status, string(body))
+ }
+
+ var result struct {
+ Token string `json:"token"`
+ ExpiresAt int64 `json:"expires_at"`
+ }
+ if err := json.Unmarshal(body, &result); err != nil {
+ return nil, err
+ }
+
+ copilotToken := &oauth.Token{
+ AccessToken: result.Token,
+ RefreshToken: githubToken,
+ ExpiresAt: result.ExpiresAt,
+ }
+ copilotToken.SetExpiresIn()
+
+ return copilotToken, nil
+}
+
+// RefreshToken refreshes the Copilot token using the GitHub token.
+func RefreshToken(ctx context.Context, githubToken string) (*oauth.Token, error) {
+ return getCopilotToken(ctx, githubToken)
+}
@@ -0,0 +1,6 @@
+package copilot
+
+const (
+ SignupURL = "https://github.com/github-copilot/signup?editor=crush"
+ FreeURL = "https://docs.github.com/en/copilot/how-tos/manage-your-account/get-free-access-to-copilot-pro"
+)
@@ -0,0 +1,251 @@
+// Package hyper provides functions to handle Hyper device flow authentication.
+package hyper
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/charmbracelet/crush/internal/agent/hyper"
+ "github.com/charmbracelet/crush/internal/event"
+ "github.com/charmbracelet/crush/internal/oauth"
+)
+
+// DeviceAuthResponse contains the response from the device authorization endpoint.
+type DeviceAuthResponse struct {
+ DeviceCode string `json:"device_code"`
+ UserCode string `json:"user_code"`
+ VerificationURL string `json:"verification_url"`
+ ExpiresIn int `json:"expires_in"`
+}
+
+// TokenResponse contains the response from the polling endpoint.
+type TokenResponse struct {
+ RefreshToken string `json:"refresh_token,omitempty"`
+ UserID string `json:"user_id"`
+ OrganizationID string `json:"organization_id"`
+ OrganizationName string `json:"organization_name"`
+ Error string `json:"error,omitempty"`
+ ErrorDescription string `json:"error_description,omitempty"`
+}
+
+// InitiateDeviceAuth calls the /device/auth endpoint to start the device flow.
+func InitiateDeviceAuth(ctx context.Context) (*DeviceAuthResponse, error) {
+ url := hyper.BaseURL() + "/device/auth"
+
+ req, err := http.NewRequestWithContext(
+ ctx, http.MethodPost, url,
+ strings.NewReader(fmt.Sprintf(`{"device_name":%q}`, deviceName())),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", "crush")
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("execute request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+ if err != nil {
+ return nil, fmt.Errorf("read response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("device auth failed: status %d, body %q", resp.StatusCode, string(body))
+ }
+
+ var authResp DeviceAuthResponse
+ if err := json.Unmarshal(body, &authResp); err != nil {
+ return nil, fmt.Errorf("unmarshal response: %w", err)
+ }
+
+ return &authResp, nil
+}
+
+func deviceName() string {
+ if hostname, err := os.Hostname(); err == nil && hostname != "" {
+ return "Crush (" + hostname + ")"
+ }
+ return "Crush"
+}
+
+// PollForToken polls the /device/token endpoint until authorization is complete.
+// It respects the polling interval and handles various error states.
+func PollForToken(ctx context.Context, deviceCode string, expiresIn int) (string, error) {
+ ctx, cancel := context.WithTimeout(ctx, time.Duration(expiresIn)*time.Second)
+ defer cancel()
+
+ d := 5 * time.Second
+ ticker := time.NewTicker(d)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ return "", ctx.Err()
+ case <-ticker.C:
+ result, err := pollOnce(ctx, deviceCode)
+ if err != nil {
+ return "", err
+ }
+ if result.RefreshToken != "" {
+ event.Alias(result.UserID)
+ return result.RefreshToken, nil
+ }
+ switch result.Error {
+ case "authorization_pending":
+ continue
+ default:
+ return "", errors.New(result.ErrorDescription)
+ }
+ }
+ }
+}
+
+func pollOnce(ctx context.Context, deviceCode string) (TokenResponse, error) {
+ var result TokenResponse
+ url := fmt.Sprintf("%s/device/auth/%s", hyper.BaseURL(), deviceCode)
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return result, fmt.Errorf("create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", "crush")
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return result, fmt.Errorf("execute request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+ if err != nil {
+ return result, fmt.Errorf("read response: %w", err)
+ }
+
+ if err := json.Unmarshal(body, &result); err != nil {
+ return result, fmt.Errorf("unmarshal response: %w: %s", err, string(body))
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return result, fmt.Errorf("token request failed: status %d body %q", resp.StatusCode, string(body))
+ }
+
+ return result, nil
+}
+
+// ExchangeToken exchanges a refresh token for an access token.
+func ExchangeToken(ctx context.Context, refreshToken string) (*oauth.Token, error) {
+ reqBody := map[string]string{
+ "refresh_token": refreshToken,
+ }
+
+ data, err := json.Marshal(reqBody)
+ if err != nil {
+ return nil, fmt.Errorf("marshal request: %w", err)
+ }
+
+ url := hyper.BaseURL() + "/token/exchange"
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
+ if err != nil {
+ return nil, fmt.Errorf("create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", "crush")
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("execute request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+ if err != nil {
+ return nil, fmt.Errorf("read response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("token exchange failed: status %d body %q", resp.StatusCode, string(body))
+ }
+
+ var token oauth.Token
+ if err := json.Unmarshal(body, &token); err != nil {
+ return nil, fmt.Errorf("unmarshal response: %w", err)
+ }
+
+ token.SetExpiresAt()
+ return &token, nil
+}
+
+// IntrospectTokenResponse contains the response from the token introspection endpoint.
+type IntrospectTokenResponse struct {
+ Active bool `json:"active"`
+ Sub string `json:"sub,omitempty"`
+ OrgID string `json:"org_id,omitempty"`
+ Exp int64 `json:"exp,omitempty"`
+ Iat int64 `json:"iat,omitempty"`
+ Iss string `json:"iss,omitempty"`
+ Jti string `json:"jti,omitempty"`
+}
+
+// IntrospectToken validates an access token using the introspection endpoint.
+// Implements OAuth2 Token Introspection (RFC 7662).
+func IntrospectToken(ctx context.Context, accessToken string) (*IntrospectTokenResponse, error) {
+ reqBody := map[string]string{
+ "token": accessToken,
+ }
+
+ data, err := json.Marshal(reqBody)
+ if err != nil {
+ return nil, fmt.Errorf("marshal request: %w", err)
+ }
+
+ url := hyper.BaseURL() + "/token/introspect"
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
+ if err != nil {
+ return nil, fmt.Errorf("create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", "crush")
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("execute request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+ if err != nil {
+ return nil, fmt.Errorf("read response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("token introspection failed: status %d body %q", resp.StatusCode, string(body))
+ }
+
+ var result IntrospectTokenResponse
+ if err := json.Unmarshal(body, &result); err != nil {
+ return nil, fmt.Errorf("unmarshal response: %w", err)
+ }
+
+ return &result, nil
+}
@@ -4,7 +4,7 @@ import (
"time"
)
-// Token represents an OAuth2 token from Claude Code Max.
+// Token represents an OAuth2 token.
type Token struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
@@ -21,3 +21,8 @@ func (t *Token) SetExpiresAt() {
func (t *Token) IsExpired() bool {
return time.Now().Unix() >= (t.ExpiresAt - int64(t.ExpiresIn)/10)
}
+
+// SetExpiresIn calculates and sets the ExpiresIn field based on the ExpiresAt field.
+func (t *Token) SetExpiresIn() {
+ t.ExpiresIn = int(time.Until(time.Unix(t.ExpiresAt, 0)).Seconds())
+}
@@ -13,6 +13,7 @@ import (
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/agent"
+ "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/agent/tools/mcp"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
@@ -363,7 +364,7 @@ func (c *commandDialogCmp) defaultCommands() []Command {
selectedModel := cfg.Models[agentCfg.Model]
// Anthropic models: thinking toggle
- if providerCfg.Type == catwalk.TypeAnthropic {
+ if providerCfg.Type == catwalk.TypeAnthropic || providerCfg.Type == catwalk.Type(hyper.Name) {
status := "Enable"
if selectedModel.Think {
status = "Disable"
@@ -0,0 +1,281 @@
+// Package copilot provides the dialog for Copilot device flow authentication.
+package copilot
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "charm.land/bubbles/v2/spinner"
+ tea "charm.land/bubbletea/v2"
+ "charm.land/lipgloss/v2"
+ "github.com/charmbracelet/crush/internal/oauth"
+ "github.com/charmbracelet/crush/internal/oauth/copilot"
+ "github.com/charmbracelet/crush/internal/tui/styles"
+ "github.com/charmbracelet/crush/internal/tui/util"
+ "github.com/pkg/browser"
+)
+
+// DeviceFlowState represents the current state of the device flow.
+type DeviceFlowState int
+
+const (
+ DeviceFlowStateDisplay DeviceFlowState = iota
+ DeviceFlowStateSuccess
+ DeviceFlowStateError
+ DeviceFlowStateUnavailable
+)
+
+// DeviceAuthInitiatedMsg is sent when the device auth is initiated
+// successfully.
+type DeviceAuthInitiatedMsg struct {
+ deviceCode *copilot.DeviceCode
+}
+
+// DeviceFlowCompletedMsg is sent when the device flow completes successfully.
+type DeviceFlowCompletedMsg struct {
+ Token *oauth.Token
+}
+
+// DeviceFlowErrorMsg is sent when the device flow encounters an error.
+type DeviceFlowErrorMsg struct {
+ Error error
+}
+
+// DeviceFlow handles the Copilot device flow authentication.
+type DeviceFlow struct {
+ State DeviceFlowState
+ width int
+ deviceCode *copilot.DeviceCode
+ token *oauth.Token
+ cancelFunc context.CancelFunc
+ spinner spinner.Model
+}
+
+// NewDeviceFlow creates a new device flow component.
+func NewDeviceFlow() *DeviceFlow {
+ s := spinner.New()
+ s.Spinner = spinner.Dot
+ s.Style = lipgloss.NewStyle().Foreground(styles.CurrentTheme().GreenLight)
+ return &DeviceFlow{
+ State: DeviceFlowStateDisplay,
+ spinner: s,
+ }
+}
+
+// Init initializes the device flow by calling the device auth API and starting polling.
+func (d *DeviceFlow) Init() tea.Cmd {
+ return tea.Batch(d.spinner.Tick, d.initiateDeviceAuth)
+}
+
+// Update handles messages and state transitions.
+func (d *DeviceFlow) Update(msg tea.Msg) (util.Model, tea.Cmd) {
+ var cmd tea.Cmd
+ d.spinner, cmd = d.spinner.Update(msg)
+
+ switch msg := msg.(type) {
+ case DeviceAuthInitiatedMsg:
+ return d, tea.Batch(cmd, d.startPolling(msg.deviceCode))
+ case DeviceFlowCompletedMsg:
+ d.State = DeviceFlowStateSuccess
+ d.token = msg.Token
+ return d, nil
+ case DeviceFlowErrorMsg:
+ switch msg.Error {
+ case copilot.ErrNotAvailable:
+ d.State = DeviceFlowStateUnavailable
+ default:
+ d.State = DeviceFlowStateError
+ }
+ return d, nil
+ }
+
+ return d, cmd
+}
+
+// View renders the device flow dialog.
+func (d *DeviceFlow) View() string {
+ t := styles.CurrentTheme()
+
+ whiteStyle := lipgloss.NewStyle().Foreground(t.White)
+ primaryStyle := lipgloss.NewStyle().Foreground(t.Primary)
+ greenStyle := lipgloss.NewStyle().Foreground(t.GreenLight)
+ linkStyle := lipgloss.NewStyle().Foreground(t.GreenDark).Underline(true)
+ errorStyle := lipgloss.NewStyle().Foreground(t.Error)
+ mutedStyle := lipgloss.NewStyle().Foreground(t.FgMuted)
+
+ switch d.State {
+ case DeviceFlowStateDisplay:
+ if d.deviceCode == nil {
+ return lipgloss.NewStyle().
+ Margin(0, 1).
+ Render(
+ greenStyle.Render(d.spinner.View()) +
+ mutedStyle.Render("Initializing..."),
+ )
+ }
+
+ instructions := lipgloss.NewStyle().
+ Margin(1, 1, 0, 1).
+ Width(d.width - 2).
+ Render(
+ whiteStyle.Render("Press ") +
+ primaryStyle.Render("enter") +
+ whiteStyle.Render(" to copy the code below and open the browser."),
+ )
+
+ codeBox := lipgloss.NewStyle().
+ Width(d.width-2).
+ Height(7).
+ Align(lipgloss.Center, lipgloss.Center).
+ Background(t.BgBaseLighter).
+ Margin(1).
+ Render(
+ lipgloss.NewStyle().
+ Bold(true).
+ Foreground(t.White).
+ Render(d.deviceCode.UserCode),
+ )
+
+ uri := d.deviceCode.VerificationURI
+ link := lipgloss.NewStyle().Hyperlink(uri, "id=copilot-verify").Render(uri)
+ url := mutedStyle.
+ Margin(0, 1).
+ Width(d.width - 2).
+ Render("Browser not opening? Refer to\n" + link)
+
+ waiting := greenStyle.
+ Width(d.width-2).
+ Margin(1, 1, 0, 1).
+ Render(d.spinner.View() + "Verifying...")
+
+ return lipgloss.JoinVertical(
+ lipgloss.Left,
+ instructions,
+ codeBox,
+ url,
+ waiting,
+ )
+
+ case DeviceFlowStateSuccess:
+ return greenStyle.Margin(0, 1).Render("Authentication successful!")
+
+ case DeviceFlowStateError:
+ return lipgloss.NewStyle().
+ Margin(0, 1).
+ Width(d.width - 2).
+ Render(errorStyle.Render("Authentication failed."))
+
+ case DeviceFlowStateUnavailable:
+ message := lipgloss.NewStyle().
+ Margin(0, 1).
+ Width(d.width - 2).
+ Render("GitHub Copilot is unavailable for this account. To signup, go to the following page:")
+ freeMessage := lipgloss.NewStyle().
+ Margin(0, 1).
+ Width(d.width - 2).
+ Render("You may be able to request free access if elegible. For more information, see:")
+ return lipgloss.JoinVertical(
+ lipgloss.Left,
+ message,
+ "",
+ linkStyle.Margin(0, 1).Width(d.width-2).Hyperlink(copilot.SignupURL, "id=copilot-signup").Render(copilot.SignupURL),
+ "",
+ freeMessage,
+ "",
+ linkStyle.Margin(0, 1).Width(d.width-2).Hyperlink(copilot.FreeURL, "id=copilot-free").Render(copilot.FreeURL),
+ )
+
+ default:
+ return ""
+ }
+}
+
+// SetWidth sets the width of the dialog.
+func (d *DeviceFlow) SetWidth(w int) {
+ d.width = w
+}
+
+// Cursor hides the cursor.
+func (d *DeviceFlow) Cursor() *tea.Cursor { return nil }
+
+// CopyCodeAndOpenURL copies the user code to the clipboard and opens the URL.
+func (d *DeviceFlow) CopyCodeAndOpenURL() tea.Cmd {
+ switch d.State {
+ case DeviceFlowStateDisplay:
+ return tea.Sequence(
+ tea.SetClipboard(d.deviceCode.UserCode),
+ func() tea.Msg {
+ if err := browser.OpenURL(d.deviceCode.VerificationURI); err != nil {
+ return DeviceFlowErrorMsg{Error: fmt.Errorf("failed to open browser: %w", err)}
+ }
+ return nil
+ },
+ util.ReportInfo("Code copied and URL opened"),
+ )
+ case DeviceFlowStateUnavailable:
+ return tea.Sequence(
+ func() tea.Msg {
+ if err := browser.OpenURL(copilot.SignupURL); err != nil {
+ return DeviceFlowErrorMsg{Error: fmt.Errorf("failed to open browser: %w", err)}
+ }
+ return nil
+ },
+ util.ReportInfo("Code copied and URL opened"),
+ )
+ default:
+ return nil
+ }
+}
+
+// CopyCode copies just the user code to the clipboard.
+func (d *DeviceFlow) CopyCode() tea.Cmd {
+ if d.State != DeviceFlowStateDisplay {
+ return nil
+ }
+ return tea.Sequence(
+ tea.SetClipboard(d.deviceCode.UserCode),
+ util.ReportInfo("Code copied to clipboard"),
+ )
+}
+
+// Cancel cancels the device flow polling.
+func (d *DeviceFlow) Cancel() {
+ if d.cancelFunc != nil {
+ d.cancelFunc()
+ }
+}
+
+func (d *DeviceFlow) initiateDeviceAuth() tea.Msg {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ deviceCode, err := copilot.RequestDeviceCode(ctx)
+ if err != nil {
+ return DeviceFlowErrorMsg{Error: fmt.Errorf("failed to initiate device auth: %w", err)}
+ }
+
+ d.deviceCode = deviceCode
+
+ return DeviceAuthInitiatedMsg{
+ deviceCode: d.deviceCode,
+ }
+}
+
+// startPolling starts polling for the device token.
+func (d *DeviceFlow) startPolling(deviceCode *copilot.DeviceCode) tea.Cmd {
+ return func() tea.Msg {
+ ctx, cancel := context.WithCancel(context.Background())
+ d.cancelFunc = cancel
+
+ token, err := copilot.PollForToken(ctx, deviceCode)
+ if err != nil {
+ if ctx.Err() != nil {
+ return nil // cancelled, don't report error.
+ }
+ return DeviceFlowErrorMsg{Error: err}
+ }
+
+ return DeviceFlowCompletedMsg{Token: token}
+ }
+}
@@ -0,0 +1,267 @@
+// Package hyper provides the dialog for Hyper device flow authentication.
+package hyper
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "charm.land/bubbles/v2/spinner"
+ tea "charm.land/bubbletea/v2"
+ "charm.land/lipgloss/v2"
+ "github.com/charmbracelet/crush/internal/oauth"
+ "github.com/charmbracelet/crush/internal/oauth/hyper"
+ "github.com/charmbracelet/crush/internal/tui/styles"
+ "github.com/charmbracelet/crush/internal/tui/util"
+ "github.com/pkg/browser"
+)
+
+// DeviceFlowState represents the current state of the device flow.
+type DeviceFlowState int
+
+const (
+ DeviceFlowStateDisplay DeviceFlowState = iota
+ DeviceFlowStateSuccess
+ DeviceFlowStateError
+)
+
+// DeviceAuthInitiatedMsg is sent when the device auth is initiated
+// successfully.
+type DeviceAuthInitiatedMsg struct {
+ deviceCode string
+ expiresIn int
+}
+
+// DeviceFlowCompletedMsg is sent when the device flow completes successfully.
+type DeviceFlowCompletedMsg struct {
+ Token *oauth.Token
+}
+
+// DeviceFlowErrorMsg is sent when the device flow encounters an error.
+type DeviceFlowErrorMsg struct {
+ Error error
+}
+
+// DeviceFlow handles the Hyper device flow authentication.
+type DeviceFlow struct {
+ State DeviceFlowState
+ width int
+ deviceCode string
+ userCode string
+ verificationURL string
+ expiresIn int
+ token *oauth.Token
+ cancelFunc context.CancelFunc
+ spinner spinner.Model
+}
+
+// NewDeviceFlow creates a new device flow component.
+func NewDeviceFlow() *DeviceFlow {
+ s := spinner.New()
+ s.Spinner = spinner.Dot
+ s.Style = lipgloss.NewStyle().Foreground(styles.CurrentTheme().GreenLight)
+ return &DeviceFlow{
+ State: DeviceFlowStateDisplay,
+ spinner: s,
+ }
+}
+
+// Init initializes the device flow by calling the device auth API and starting polling.
+func (d *DeviceFlow) Init() tea.Cmd {
+ return tea.Batch(d.spinner.Tick, d.initiateDeviceAuth)
+}
+
+// Update handles messages and state transitions.
+func (d *DeviceFlow) Update(msg tea.Msg) (util.Model, tea.Cmd) {
+ var cmd tea.Cmd
+ d.spinner, cmd = d.spinner.Update(msg)
+
+ switch msg := msg.(type) {
+ case DeviceAuthInitiatedMsg:
+ // Start polling now that we have the device code.
+ d.expiresIn = msg.expiresIn
+ return d, tea.Batch(cmd, d.startPolling(msg.deviceCode))
+ case DeviceFlowCompletedMsg:
+ d.State = DeviceFlowStateSuccess
+ d.token = msg.Token
+ return d, nil
+ case DeviceFlowErrorMsg:
+ d.State = DeviceFlowStateError
+ return d, util.ReportError(msg.Error)
+ }
+
+ return d, cmd
+}
+
+// View renders the device flow dialog.
+func (d *DeviceFlow) View() string {
+ t := styles.CurrentTheme()
+
+ whiteStyle := lipgloss.NewStyle().Foreground(t.White)
+ primaryStyle := lipgloss.NewStyle().Foreground(t.Primary)
+ greenStyle := lipgloss.NewStyle().Foreground(t.GreenLight)
+ linkStyle := lipgloss.NewStyle().Foreground(t.GreenDark).Underline(true)
+ errorStyle := lipgloss.NewStyle().Foreground(t.Error)
+ mutedStyle := lipgloss.NewStyle().Foreground(t.FgMuted)
+
+ switch d.State {
+ case DeviceFlowStateDisplay:
+ if d.userCode == "" {
+ return lipgloss.NewStyle().
+ Margin(0, 1).
+ Render(
+ greenStyle.Render(d.spinner.View()) +
+ mutedStyle.Render("Initializing..."),
+ )
+ }
+
+ instructions := lipgloss.NewStyle().
+ Margin(1, 1, 0, 1).
+ Width(d.width - 2).
+ Render(
+ whiteStyle.Render("Press ") +
+ primaryStyle.Render("enter") +
+ whiteStyle.Render(" to copy the code below and open the browser."),
+ )
+
+ codeBox := lipgloss.NewStyle().
+ Width(d.width-2).
+ Height(7).
+ Align(lipgloss.Center, lipgloss.Center).
+ Background(t.BgBaseLighter).
+ Margin(1).
+ Render(
+ lipgloss.NewStyle().
+ Bold(true).
+ Foreground(t.White).
+ Render(d.userCode),
+ )
+
+ link := linkStyle.Hyperlink(d.verificationURL, "id=hyper-verify").Render(d.verificationURL)
+ url := mutedStyle.
+ Margin(0, 1).
+ Width(d.width - 2).
+ Render("Browser not opening? Refer to\n" + link)
+
+ waiting := greenStyle.
+ Width(d.width-2).
+ Margin(1, 1, 0, 1).
+ Render(d.spinner.View() + "Verifying...")
+
+ return lipgloss.JoinVertical(
+ lipgloss.Left,
+ instructions,
+ codeBox,
+ url,
+ waiting,
+ )
+
+ case DeviceFlowStateSuccess:
+ return greenStyle.Margin(0, 1).Render("Authentication successful!")
+
+ case DeviceFlowStateError:
+ return lipgloss.NewStyle().
+ Margin(0, 1).
+ Width(d.width - 2).
+ Render(errorStyle.Render("Authentication failed."))
+
+ default:
+ return ""
+ }
+}
+
+// SetWidth sets the width of the dialog.
+func (d *DeviceFlow) SetWidth(w int) {
+ d.width = w
+}
+
+// Cursor hides the cursor.
+func (d *DeviceFlow) Cursor() *tea.Cursor { return nil }
+
+// CopyCodeAndOpenURL copies the user code to the clipboard and opens the URL.
+func (d *DeviceFlow) CopyCodeAndOpenURL() tea.Cmd {
+ if d.State != DeviceFlowStateDisplay {
+ return nil
+ }
+ return tea.Sequence(
+ tea.SetClipboard(d.userCode),
+ func() tea.Msg {
+ if err := browser.OpenURL(d.verificationURL); err != nil {
+ return DeviceFlowErrorMsg{Error: fmt.Errorf("failed to open browser: %w", err)}
+ }
+ return nil
+ },
+ util.ReportInfo("Code copied and URL opened"),
+ )
+}
+
+// CopyCode copies just the user code to the clipboard.
+func (d *DeviceFlow) CopyCode() tea.Cmd {
+ if d.State != DeviceFlowStateDisplay {
+ return nil
+ }
+ return tea.Sequence(
+ tea.SetClipboard(d.userCode),
+ util.ReportInfo("Code copied to clipboard"),
+ )
+}
+
+// Cancel cancels the device flow polling.
+func (d *DeviceFlow) Cancel() {
+ if d.cancelFunc != nil {
+ d.cancelFunc()
+ }
+}
+
+func (d *DeviceFlow) initiateDeviceAuth() tea.Msg {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+ authResp, err := hyper.InitiateDeviceAuth(ctx)
+ if err != nil {
+ return DeviceFlowErrorMsg{Error: fmt.Errorf("failed to initiate device auth: %w", err)}
+ }
+
+ d.deviceCode = authResp.DeviceCode
+ d.userCode = authResp.UserCode
+ d.verificationURL = authResp.VerificationURL
+
+ return DeviceAuthInitiatedMsg{
+ deviceCode: authResp.DeviceCode,
+ expiresIn: authResp.ExpiresIn,
+ }
+}
+
+// startPolling starts polling for the device token.
+func (d *DeviceFlow) startPolling(deviceCode string) tea.Cmd {
+ return func() tea.Msg {
+ ctx, cancel := context.WithCancel(context.Background())
+ d.cancelFunc = cancel
+
+ // Poll for refresh token.
+ refreshToken, err := hyper.PollForToken(ctx, deviceCode, d.expiresIn)
+ if err != nil {
+ if ctx.Err() != nil {
+ // Cancelled, don't report error.
+ return nil
+ }
+ return DeviceFlowErrorMsg{Error: err}
+ }
+
+ // Exchange refresh token for access token.
+ token, err := hyper.ExchangeToken(ctx, refreshToken)
+ if err != nil {
+ return DeviceFlowErrorMsg{Error: fmt.Errorf("token exchange failed: %w", err)}
+ }
+
+ // Verify the access token works.
+ introspect, err := hyper.IntrospectToken(ctx, token.AccessToken)
+ if err != nil {
+ return DeviceFlowErrorMsg{Error: fmt.Errorf("token introspection failed: %w", err)}
+ }
+ if !introspect.Active {
+ return DeviceFlowErrorMsg{Error: fmt.Errorf("access token is not active")}
+ }
+
+ return DeviceFlowCompletedMsg{Token: token}
+ }
+}
@@ -15,6 +15,10 @@ type KeyMap struct {
isAPIKeyHelp bool
isAPIKeyValid bool
+ isHyperDeviceFlow bool
+ isCopilotDeviceFlow bool
+ isCopilotUnavailable bool
+
isClaudeAuthChoiceHelp bool
isClaudeOAuthHelp bool
isClaudeOAuthURLState bool
@@ -74,6 +78,28 @@ func (k KeyMap) FullHelp() [][]key.Binding {
// ShortHelp implements help.KeyMap.
func (k KeyMap) ShortHelp() []key.Binding {
+ if k.isHyperDeviceFlow || k.isCopilotDeviceFlow {
+ return []key.Binding{
+ key.NewBinding(
+ key.WithKeys("c"),
+ key.WithHelp("c", "copy code"),
+ ),
+ key.NewBinding(
+ key.WithKeys("enter"),
+ key.WithHelp("enter", "copy & open"),
+ ),
+ k.Close,
+ }
+ }
+ if k.isCopilotUnavailable {
+ return []key.Binding{
+ key.NewBinding(
+ key.WithKeys("enter"),
+ key.WithHelp("enter", "open signup"),
+ ),
+ k.Close,
+ }
+ }
if k.isClaudeAuthChoiceHelp {
return []key.Binding{
key.NewBinding(
@@ -62,7 +62,9 @@ func (m *ModelListComponent) Init() tea.Cmd {
filteredProviders := []catwalk.Provider{}
for _, p := range providers {
hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
- if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
+ isHyper := p.ID == "hyper"
+ isCopilot := p.ID == catwalk.InferenceProviderCopilot
+ if (hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure) || isHyper || isCopilot {
filteredProviders = append(filteredProviders, p)
}
}
@@ -146,29 +148,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
!slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
// Convert config provider to provider.Provider format
- configProvider := catwalk.Provider{
- Name: providerConfig.Name,
- ID: catwalk.InferenceProvider(providerID),
- Models: make([]catwalk.Model, len(providerConfig.Models)),
- }
-
- // Convert models
- for i, model := range providerConfig.Models {
- configProvider.Models[i] = catwalk.Model{
- ID: model.ID,
- Name: model.Name,
- CostPer1MIn: model.CostPer1MIn,
- CostPer1MOut: model.CostPer1MOut,
- CostPer1MInCached: model.CostPer1MInCached,
- CostPer1MOutCached: model.CostPer1MOutCached,
- ContextWindow: model.ContextWindow,
- DefaultMaxTokens: model.DefaultMaxTokens,
- CanReason: model.CanReason,
- ReasoningLevels: model.ReasoningLevels,
- DefaultReasoningEffort: model.DefaultReasoningEffort,
- SupportsImages: model.SupportsImages,
- }
- }
+ configProvider := providerConfig.ToProvider()
// Add this unknown provider to the list
name := configProvider.Name
@@ -204,8 +184,23 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
}
}
+ // Move "Charm Hyper" to first position
+ // (but still after recent models and custom providers).
+ sortedProviders := make([]catwalk.Provider, len(m.providers))
+ copy(sortedProviders, m.providers)
+ slices.SortStableFunc(sortedProviders, func(a, b catwalk.Provider) int {
+ switch {
+ case a.ID == "hyper":
+ return -1
+ case b.ID == "hyper":
+ return 1
+ default:
+ return 0
+ }
+ })
+
// Then add the known providers from the predefined list
- for _, provider := range m.providers {
+ for _, provider := range sortedProviders {
// Skip if we already added this provider as an unknown provider
if addedProviders[string(provider.ID)] {
continue
@@ -1,3 +1,4 @@
+// Package models provides the model selection dialog for the TUI.
package models
import (
@@ -11,10 +12,13 @@ import (
"charm.land/lipgloss/v2"
"github.com/atotto/clipboard"
"github.com/charmbracelet/catwalk/pkg/catwalk"
+ hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/components/dialogs"
"github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
+ "github.com/charmbracelet/crush/internal/tui/components/dialogs/copilot"
+ "github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper"
"github.com/charmbracelet/crush/internal/tui/exp/list"
"github.com/charmbracelet/crush/internal/tui/styles"
"github.com/charmbracelet/crush/internal/tui/util"
@@ -70,6 +74,14 @@ type modelDialogCmp struct {
isAPIKeyValid bool
apiKeyValue string
+ // Hyper device flow state
+ hyperDeviceFlow *hyper.DeviceFlow
+ showHyperDeviceFlow bool
+
+ // Copilot device flow state
+ copilotDeviceFlow *copilot.DeviceFlow
+ showCopilotDeviceFlow bool
+
// Claude state
claudeAuthMethodChooser *claude.AuthMethodChooser
claudeOAuth2 *claude.OAuth2
@@ -127,6 +139,24 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
u, cmd := m.apiKeyInput.Update(msg)
m.apiKeyInput = u.(*APIKeyInput)
return m, cmd
+ case hyper.DeviceFlowCompletedMsg:
+ return m, m.saveOauthTokenAndContinue(msg.Token, true)
+ case hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg:
+ if m.hyperDeviceFlow != nil {
+ u, cmd := m.hyperDeviceFlow.Update(msg)
+ m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
+ return m, cmd
+ }
+ return m, nil
+ case copilot.DeviceAuthInitiatedMsg, copilot.DeviceFlowErrorMsg:
+ if m.copilotDeviceFlow != nil {
+ u, cmd := m.copilotDeviceFlow.Update(msg)
+ m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
+ return m, cmd
+ }
+ return m, nil
+ case copilot.DeviceFlowCompletedMsg:
+ return m, m.saveOauthTokenAndContinue(msg.Token, true)
case claude.ValidationCompletedMsg:
var cmds []tea.Cmd
u, cmd := m.claudeOAuth2.Update(msg)
@@ -134,7 +164,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
cmds = append(cmds, cmd)
if msg.State == claude.OAuthValidationStateValid {
- cmds = append(cmds, m.saveAPIKeyAndContinue(msg.Token, false))
+ cmds = append(cmds, m.saveOauthTokenAndContinue(msg.Token, false))
m.keyMap.isClaudeOAuthHelpComplete = true
}
@@ -143,6 +173,14 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
return m, util.CmdHandler(dialogs.CloseDialogMsg{})
case tea.KeyPressMsg:
switch {
+ // Handle Hyper device flow keys
+ case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && (m.showHyperDeviceFlow || m.showCopilotDeviceFlow):
+ if m.hyperDeviceFlow != nil {
+ return m, m.hyperDeviceFlow.CopyCode()
+ }
+ if m.copilotDeviceFlow != nil {
+ return m, m.copilotDeviceFlow.CopyCode()
+ }
case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
return m, tea.Sequence(
tea.SetClipboard(m.claudeOAuth2.URL),
@@ -156,6 +194,13 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
m.claudeAuthMethodChooser.ToggleChoice()
return m, nil
case key.Matches(msg, m.keyMap.Select):
+ // If showing device flow, enter copies code and opens URL
+ if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
+ return m, m.hyperDeviceFlow.CopyCodeAndOpenURL()
+ }
+ if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
+ return m, m.copilotDeviceFlow.CopyCodeAndOpenURL()
+ }
selectedItem := m.modelList.SelectedModel()
modelType := config.SelectedModelTypeLarge
@@ -167,6 +212,8 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
m.keyMap.isClaudeAuthChoiceHelp = false
m.keyMap.isClaudeOAuthHelp = false
m.keyMap.isAPIKeyHelp = true
+ m.showHyperDeviceFlow = false
+ m.showCopilotDeviceFlow = false
m.showClaudeAuthMethodChooser = false
m.needsAPIKey = true
m.selectedModel = selectedItem
@@ -194,7 +241,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
return m, cmd2
}
if m.isAPIKeyValid {
- return m, m.saveAPIKeyAndContinue(m.apiKeyValue, true)
+ return m, m.saveOauthTokenAndContinue(m.apiKeyValue, true)
}
if m.needsAPIKey {
// Handle API key submission
@@ -249,15 +296,30 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
ModelType: modelType,
}),
)
- } else {
- if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic {
- m.showClaudeAuthMethodChooser = true
- m.keyMap.isClaudeAuthChoiceHelp = true
- return m, nil
- }
- askForApiKey()
+ }
+ switch selectedItem.Provider.ID {
+ case catwalk.InferenceProviderAnthropic:
+ m.showClaudeAuthMethodChooser = true
+ m.keyMap.isClaudeAuthChoiceHelp = true
return m, nil
+ case hyperp.Name:
+ m.showHyperDeviceFlow = true
+ m.selectedModel = selectedItem
+ m.selectedModelType = modelType
+ m.hyperDeviceFlow = hyper.NewDeviceFlow()
+ m.hyperDeviceFlow.SetWidth(m.width - 2)
+ return m, m.hyperDeviceFlow.Init()
+ case catwalk.InferenceProviderCopilot:
+ m.showCopilotDeviceFlow = true
+ m.selectedModel = selectedItem
+ m.selectedModelType = modelType
+ m.copilotDeviceFlow = copilot.NewDeviceFlow()
+ m.copilotDeviceFlow.SetWidth(m.width - 2)
+ return m, m.copilotDeviceFlow.Init()
}
+ // For other providers, show API key input
+ askForApiKey()
+ return m, nil
case key.Matches(msg, m.keyMap.Tab):
switch {
case m.showClaudeAuthMethodChooser:
@@ -275,6 +337,20 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
return m, m.modelList.SetModelType(LargeModelType)
}
case key.Matches(msg, m.keyMap.Close):
+ if m.showHyperDeviceFlow {
+ if m.hyperDeviceFlow != nil {
+ m.hyperDeviceFlow.Cancel()
+ }
+ m.showHyperDeviceFlow = false
+ m.selectedModel = nil
+ }
+ if m.showCopilotDeviceFlow {
+ if m.copilotDeviceFlow != nil {
+ m.copilotDeviceFlow.Cancel()
+ }
+ m.showCopilotDeviceFlow = false
+ m.selectedModel = nil
+ }
if m.showClaudeAuthMethodChooser {
m.claudeAuthMethodChooser.SetDefaults()
m.showClaudeAuthMethodChooser = false
@@ -329,11 +405,33 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
return m, cmd
}
case spinner.TickMsg:
- if m.showClaudeOAuth2 {
+ u, cmd := m.apiKeyInput.Update(msg)
+ m.apiKeyInput = u.(*APIKeyInput)
+ if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
+ u, cmd = m.hyperDeviceFlow.Update(msg)
+ m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
+ }
+ if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
+ u, cmd = m.copilotDeviceFlow.Update(msg)
+ m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
+ }
+ return m, cmd
+ default:
+ // Pass all other messages to the device flow for spinner animation
+ switch {
+ case m.showHyperDeviceFlow && m.hyperDeviceFlow != nil:
+ u, cmd := m.hyperDeviceFlow.Update(msg)
+ m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
+ return m, cmd
+ case m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil:
+ u, cmd := m.copilotDeviceFlow.Update(msg)
+ m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
+ return m, cmd
+ case m.showClaudeOAuth2:
u, cmd := m.claudeOAuth2.Update(msg)
m.claudeOAuth2 = u.(*claude.OAuth2)
return m, cmd
- } else {
+ default:
u, cmd := m.apiKeyInput.Update(msg)
m.apiKeyInput = u.(*APIKeyInput)
return m, cmd
@@ -345,6 +443,39 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
func (m *modelDialogCmp) View() string {
t := styles.CurrentTheme()
+ if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
+ // Show Hyper device flow
+ m.keyMap.isHyperDeviceFlow = true
+ deviceFlowView := m.hyperDeviceFlow.View()
+ content := lipgloss.JoinVertical(
+ lipgloss.Left,
+ t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with Hyper", m.width-4)),
+ deviceFlowView,
+ "",
+ t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
+ )
+ return m.style().Render(content)
+ }
+ if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
+ // Show Hyper device flow
+ m.keyMap.isCopilotDeviceFlow = m.copilotDeviceFlow.State != copilot.DeviceFlowStateUnavailable
+ m.keyMap.isCopilotUnavailable = m.copilotDeviceFlow.State == copilot.DeviceFlowStateUnavailable
+ deviceFlowView := m.copilotDeviceFlow.View()
+ content := lipgloss.JoinVertical(
+ lipgloss.Left,
+ t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with GitHub Copilot", m.width-4)),
+ deviceFlowView,
+ "",
+ t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
+ )
+ return m.style().Render(content)
+ }
+
+ // Reset the flags when not showing device flow
+ m.keyMap.isHyperDeviceFlow = false
+ m.keyMap.isCopilotDeviceFlow = false
+ m.keyMap.isCopilotUnavailable = false
+
switch {
case m.showClaudeAuthMethodChooser:
chooserView := m.claudeAuthMethodChooser.View()
@@ -397,6 +528,12 @@ func (m *modelDialogCmp) View() string {
}
func (m *modelDialogCmp) Cursor() *tea.Cursor {
+ if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
+ return m.hyperDeviceFlow.Cursor()
+ }
+ if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
+ return m.copilotDeviceFlow.Cursor()
+ }
if m.showClaudeAuthMethodChooser {
return nil
}
@@ -477,10 +614,8 @@ func (m *modelDialogCmp) modelTypeRadio() string {
func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
cfg := config.Get()
- if _, ok := cfg.Providers.Get(providerID); ok {
- return true
- }
- return false
+ _, ok := cfg.Providers.Get(providerID)
+ return ok
}
func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
@@ -497,7 +632,7 @@ func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*cat
return nil, nil
}
-func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd {
+func (m *modelDialogCmp) saveOauthTokenAndContinue(apiKey any, close bool) tea.Cmd {
if m.selectedModel == nil {
return util.ReportError(fmt.Errorf("no model selected"))
}
@@ -69,10 +69,18 @@ func RenderMCPList(opts RenderOptions) []string {
case mcp.StateConnected:
icon = t.ItemOnlineIcon
if count := state.Counts.Tools; count > 0 {
- extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d tools", count)))
+ label := "tools"
+ if count == 1 {
+ label = "tool"
+ }
+ extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d %s", count, label)))
}
if count := state.Counts.Prompts; count > 0 {
- extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d prompts", count)))
+ label := "prompts"
+ if count == 1 {
+ label = "prompt"
+ }
+ extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d %s", count, label)))
}
case mcp.StateError:
icon = t.ItemErrorIcon
@@ -323,6 +323,7 @@ type ItemSection interface {
layout.Sizeable
Indexable
SetInfo(info string)
+ Title() string
}
type itemSectionModel struct {
width int
@@ -337,6 +338,11 @@ func (m *itemSectionModel) ID() string {
return m.id
}
+// Title implements ItemSection.
+func (m *itemSectionModel) Title() string {
+ return m.title
+}
+
func NewItemSection(title string) ItemSection {
return &itemSectionModel{
title: title,
@@ -86,13 +86,20 @@ func renderItem(t *styles.Styles, title string, updatedAt int64, focused bool, w
}
var ageLen int
+ var right string
+ lineWidth := width
if updatedAt > 0 {
ageLen = lipgloss.Width(age)
+ lineWidth -= ageLen
}
- title = ansi.Truncate(title, max(0, width-ageLen), "…")
+ title = ansi.Truncate(title, max(0, lineWidth), "…")
titleLen := lipgloss.Width(title)
- right := lipgloss.NewStyle().AlignHorizontal(lipgloss.Right).Width(width - titleLen).Render(age)
+
+ if updatedAt > 0 {
+ right = lipgloss.NewStyle().AlignHorizontal(lipgloss.Right).Width(width - titleLen).Render(age)
+ }
+
content := title
if matches := len(m.MatchedIndexes); matches > 0 {
var lastPos int
@@ -70,13 +70,17 @@ func (f *FilterableList) SetFilter(q string) {
f.query = q
}
-type filterableItems []FilterableItem
+// FilterableItemsSource is a type that implements [fuzzy.Source] for filtering
+// [FilterableItem]s.
+type FilterableItemsSource []FilterableItem
-func (f filterableItems) Len() int {
+// Len returns the length of the source.
+func (f FilterableItemsSource) Len() int {
return len(f)
}
-func (f filterableItems) String(i int) string {
+// String returns the string representation of the item at index i.
+func (f FilterableItemsSource) String(i int) string {
return f[i].Filter()
}
@@ -94,7 +98,7 @@ func (f *FilterableList) VisibleItems() []Item {
return items
}
- items := filterableItems(f.items)
+ items := FilterableItemsSource(f.items)
matches := fuzzy.FindFrom(f.query, items)
matchedItems := []Item{}
resultSize := len(matches)
@@ -1,6 +1,8 @@
package list
import (
+ "strings"
+
"github.com/charmbracelet/x/ansi"
)
@@ -30,3 +32,20 @@ type MouseClickable interface {
// It returns true if the event was handled, false otherwise.
HandleMouseClick(btn ansi.MouseButton, x, y int) bool
}
+
+// SpacerItem is a spacer item that adds vertical space in the list.
+type SpacerItem struct {
+ Height int
+}
+
+// NewSpacerItem creates a new [SpacerItem] with the specified height.
+func NewSpacerItem(height int) *SpacerItem {
+ return &SpacerItem{
+ Height: max(0, height-1),
+ }
+}
+
+// Render implements the Item interface for [SpacerItem].
+func (s *SpacerItem) Render(width int) string {
+ return strings.Repeat("\n", s.Height)
+}
@@ -390,6 +390,7 @@ func (l *List) SelectedItemInView() bool {
}
// SetSelected sets the selected item index in the list.
+// It returns -1 if the index is out of bounds.
func (l *List) SetSelected(index int) {
if index < 0 || index >= len(l.items) {
l.selectedIdx = -1
@@ -415,31 +416,43 @@ func (l *List) IsSelectedLast() bool {
}
// SelectPrev selects the previous item in the list.
-func (l *List) SelectPrev() {
+// It returns whether the selection changed.
+func (l *List) SelectPrev() bool {
if l.selectedIdx > 0 {
l.selectedIdx--
+ return true
}
+ return false
}
// SelectNext selects the next item in the list.
-func (l *List) SelectNext() {
+// It returns whether the selection changed.
+func (l *List) SelectNext() bool {
if l.selectedIdx < len(l.items)-1 {
l.selectedIdx++
+ return true
}
+ return false
}
// SelectFirst selects the first item in the list.
-func (l *List) SelectFirst() {
+// It returns whether the selection changed.
+func (l *List) SelectFirst() bool {
if len(l.items) > 0 {
l.selectedIdx = 0
+ return true
}
+ return false
}
// SelectLast selects the last item in the list.
-func (l *List) SelectLast() {
+// It returns whether the selection changed.
+func (l *List) SelectLast() bool {
if len(l.items) > 0 {
l.selectedIdx = len(l.items) - 1
+ return true
}
+ return false
}
// SelectedItem returns the currently selected item. It may be nil if no item
@@ -169,30 +169,30 @@ func (m *Chat) Blur() {
m.list.Blur()
}
-// ScrollToTop scrolls the chat view to the top and returns a command to restart
+// ScrollToTopAndAnimate scrolls the chat view to the top and returns a command to restart
// any paused animations that are now visible.
-func (m *Chat) ScrollToTop() tea.Cmd {
+func (m *Chat) ScrollToTopAndAnimate() tea.Cmd {
m.list.ScrollToTop()
return m.RestartPausedVisibleAnimations()
}
-// ScrollToBottom scrolls the chat view to the bottom and returns a command to
+// ScrollToBottomAndAnimate scrolls the chat view to the bottom and returns a command to
// restart any paused animations that are now visible.
-func (m *Chat) ScrollToBottom() tea.Cmd {
+func (m *Chat) ScrollToBottomAndAnimate() tea.Cmd {
m.list.ScrollToBottom()
return m.RestartPausedVisibleAnimations()
}
-// ScrollBy scrolls the chat view by the given number of line deltas and returns
+// ScrollByAndAnimate scrolls the chat view by the given number of line deltas and returns
// a command to restart any paused animations that are now visible.
-func (m *Chat) ScrollBy(lines int) tea.Cmd {
+func (m *Chat) ScrollByAndAnimate(lines int) tea.Cmd {
m.list.ScrollBy(lines)
return m.RestartPausedVisibleAnimations()
}
-// ScrollToSelected scrolls the chat view to the selected item and returns a
+// ScrollToSelectedAndAnimate scrolls the chat view to the selected item and returns a
// command to restart any paused animations that are now visible.
-func (m *Chat) ScrollToSelected() tea.Cmd {
+func (m *Chat) ScrollToSelectedAndAnimate() tea.Cmd {
m.list.ScrollToSelected()
return m.RestartPausedVisibleAnimations()
}
@@ -16,8 +16,10 @@ func (m *UI) mcpInfo(width, maxItems int, isSection bool) string {
var mcps []mcp.ClientInfo
t := m.com.Styles
- for _, state := range m.mcpStates {
- mcps = append(mcps, state)
+ for _, mcp := range m.com.Config().MCP.Sorted() {
+ if state, ok := m.mcpStates[mcp.Name]; ok {
+ mcps = append(mcps, state)
+ }
}
title := t.Subtle.Render("MCPs")
@@ -267,22 +267,22 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch m.state {
case uiChat:
if msg.Y <= 0 {
- if cmd := m.chat.ScrollBy(-1); cmd != nil {
+ if cmd := m.chat.ScrollByAndAnimate(-1); cmd != nil {
cmds = append(cmds, cmd)
}
if !m.chat.SelectedItemInView() {
m.chat.SelectPrev()
- if cmd := m.chat.ScrollToSelected(); cmd != nil {
+ if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil {
cmds = append(cmds, cmd)
}
}
} else if msg.Y >= m.chat.Height()-1 {
- if cmd := m.chat.ScrollBy(1); cmd != nil {
+ if cmd := m.chat.ScrollByAndAnimate(1); cmd != nil {
cmds = append(cmds, cmd)
}
if !m.chat.SelectedItemInView() {
m.chat.SelectNext()
- if cmd := m.chat.ScrollToSelected(); cmd != nil {
+ if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil {
cmds = append(cmds, cmd)
}
}
@@ -309,22 +309,22 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case uiChat:
switch msg.Button {
case tea.MouseWheelUp:
- if cmd := m.chat.ScrollBy(-5); cmd != nil {
+ if cmd := m.chat.ScrollByAndAnimate(-5); cmd != nil {
cmds = append(cmds, cmd)
}
if !m.chat.SelectedItemInView() {
m.chat.SelectPrev()
- if cmd := m.chat.ScrollToSelected(); cmd != nil {
+ if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil {
cmds = append(cmds, cmd)
}
}
case tea.MouseWheelDown:
- if cmd := m.chat.ScrollBy(5); cmd != nil {
+ if cmd := m.chat.ScrollByAndAnimate(5); cmd != nil {
cmds = append(cmds, cmd)
}
if !m.chat.SelectedItemInView() {
m.chat.SelectNext()
- if cmd := m.chat.ScrollToSelected(); cmd != nil {
+ if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil {
cmds = append(cmds, cmd)
}
}
@@ -394,7 +394,7 @@ func (m *UI) setSessionMessages(msgs []message.Message) tea.Cmd {
}
m.chat.SetMessages(items...)
- if cmd := m.chat.ScrollToBottom(); cmd != nil {
+ if cmd := m.chat.ScrollToBottomAndAnimate(); cmd != nil {
cmds = append(cmds, cmd)
}
m.chat.SelectLast()
@@ -416,7 +416,7 @@ func (m *UI) appendSessionMessage(msg message.Message) tea.Cmd {
}
}
m.chat.AppendMessages(items...)
- if cmd := m.chat.ScrollToBottom(); cmd != nil {
+ if cmd := m.chat.ScrollToBottomAndAnimate(); cmd != nil {
cmds = append(cmds, cmd)
}
case message.Tool:
@@ -472,7 +472,7 @@ func (m *UI) updateSessionMessage(msg message.Message) tea.Cmd {
}
}
m.chat.AppendMessages(items...)
- if cmd := m.chat.ScrollToBottom(); cmd != nil {
+ if cmd := m.chat.ScrollToBottomAndAnimate(); cmd != nil {
cmds = append(cmds, cmd)
}
@@ -641,62 +641,62 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
case key.Matches(msg, m.keyMap.Chat.Expand):
m.chat.ToggleExpandedSelectedItem()
case key.Matches(msg, m.keyMap.Chat.Up):
- if cmd := m.chat.ScrollBy(-1); cmd != nil {
+ if cmd := m.chat.ScrollByAndAnimate(-1); cmd != nil {
cmds = append(cmds, cmd)
}
if !m.chat.SelectedItemInView() {
m.chat.SelectPrev()
- if cmd := m.chat.ScrollToSelected(); cmd != nil {
+ if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil {
cmds = append(cmds, cmd)
}
}
case key.Matches(msg, m.keyMap.Chat.Down):
- if cmd := m.chat.ScrollBy(1); cmd != nil {
+ if cmd := m.chat.ScrollByAndAnimate(1); cmd != nil {
cmds = append(cmds, cmd)
}
if !m.chat.SelectedItemInView() {
m.chat.SelectNext()
- if cmd := m.chat.ScrollToSelected(); cmd != nil {
+ if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil {
cmds = append(cmds, cmd)
}
}
case key.Matches(msg, m.keyMap.Chat.UpOneItem):
m.chat.SelectPrev()
- if cmd := m.chat.ScrollToSelected(); cmd != nil {
+ if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil {
cmds = append(cmds, cmd)
}
case key.Matches(msg, m.keyMap.Chat.DownOneItem):
m.chat.SelectNext()
- if cmd := m.chat.ScrollToSelected(); cmd != nil {
+ if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil {
cmds = append(cmds, cmd)
}
case key.Matches(msg, m.keyMap.Chat.HalfPageUp):
- if cmd := m.chat.ScrollBy(-m.chat.Height() / 2); cmd != nil {
+ if cmd := m.chat.ScrollByAndAnimate(-m.chat.Height() / 2); cmd != nil {
cmds = append(cmds, cmd)
}
m.chat.SelectFirstInView()
case key.Matches(msg, m.keyMap.Chat.HalfPageDown):
- if cmd := m.chat.ScrollBy(m.chat.Height() / 2); cmd != nil {
+ if cmd := m.chat.ScrollByAndAnimate(m.chat.Height() / 2); cmd != nil {
cmds = append(cmds, cmd)
}
m.chat.SelectLastInView()
case key.Matches(msg, m.keyMap.Chat.PageUp):
- if cmd := m.chat.ScrollBy(-m.chat.Height()); cmd != nil {
+ if cmd := m.chat.ScrollByAndAnimate(-m.chat.Height()); cmd != nil {
cmds = append(cmds, cmd)
}
m.chat.SelectFirstInView()
case key.Matches(msg, m.keyMap.Chat.PageDown):
- if cmd := m.chat.ScrollBy(m.chat.Height()); cmd != nil {
+ if cmd := m.chat.ScrollByAndAnimate(m.chat.Height()); cmd != nil {
cmds = append(cmds, cmd)
}
m.chat.SelectLastInView()
case key.Matches(msg, m.keyMap.Chat.Home):
- if cmd := m.chat.ScrollToTop(); cmd != nil {
+ if cmd := m.chat.ScrollToTopAndAnimate(); cmd != nil {
cmds = append(cmds, cmd)
}
m.chat.SelectFirst()
case key.Matches(msg, m.keyMap.Chat.End):
- if cmd := m.chat.ScrollToBottom(); cmd != nil {
+ if cmd := m.chat.ScrollToBottomAndAnimate(); cmd != nil {
cmds = append(cmds, cmd)
}
m.chat.SelectLast()