diff --git a/.gitattributes b/.gitattributes index d0a96600fea2558f4f8be8840c048780e48c9a7e..2e533aabaaa6a7168816003b0da820804778fd4e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,4 @@ *.golden linguist-generated=true -text .github/crush-schema.json linguist-generated=true +internal/agent/hyper/provider.json linguist-generated=true internal/agent/testdata/**/*.yaml -diff linguist-generated=true diff --git a/.github/workflows/schema-update.yml b/.github/workflows/schema-update.yml index d2366652b8c21c4606605db9fc17d92031d3ad5e..949dac5c260497344969c182302baa0d968113bf 100644 --- a/.github/workflows/schema-update.yml +++ b/.github/workflows/schema-update.yml @@ -1,10 +1,11 @@ -name: Update Schema +name: Update files on: push: branches: [main] paths: - "internal/config/**" + - "internal/agent/hyper/**" jobs: update-schema: @@ -17,9 +18,10 @@ jobs: with: go-version-file: go.mod - run: go run . schema > ./schema.json + - run: go generate ./internal/agent/hyper/... - uses: stefanzweifel/git-auto-commit-action@28e16e81777b558cc906c8750092100bbb34c5e3 # v5 with: - commit_message: "chore: auto-update generated files" + commit_message: "chore: auto-update files" branch: main commit_user_name: Charm commit_user_email: 124303983+charmcli@users.noreply.github.com diff --git a/README.md b/README.md index e25c99a5cb84372414d68a63a511eba824ac9b76..c268cb7cedf4b80632dbd75458ad3db90900edf0 100644 --- a/README.md +++ b/README.md @@ -232,6 +232,11 @@ $HOME/.local/share/crush/crush.json %LOCALAPPDATA%\crush\crush.json ``` +> [!TIP] +> You can override the user and data config locations by setting: +> * `CRUSH_GLOBAL_CONFIG` +> * `CRUSH_GLOBAL_DATA` + ### LSPs Crush can use LSPs for additional context to help inform its decisions, just diff --git a/Taskfile.yaml b/Taskfile.yaml index f9c20cd63e64fea25ce2bbd7f8a6982170dbe02e..2f5574f7ab1f07a03f47e8534d477afd293d9248 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -101,6 +101,13 @@ tasks: generates: - schema.json + hyper: + desc: Update Hyper embedded provider.json + cmds: + - go generate ./internal/agent/hyper/... + generates: + - ./internal/agent/hyper/provider.json + release: desc: Create and push a new tag following semver vars: diff --git a/cspell.json b/cspell.json deleted file mode 100644 index 368a9d5094dc6ebd1850f118533dc41050b0eb90..0000000000000000000000000000000000000000 --- a/cspell.json +++ /dev/null @@ -1 +0,0 @@ -{"words":["afero","agentic","alecthomas","anthropics","aymanbagabas","azidentity","bmatcuk","bubbletea","charlievieth","charmbracelet","charmtone","Charple","chkconfig","crush","curlie","cursorrules","diffview","doas","Dockerfiles","doublestar","dpkg","Emph","fastwalk","fdisk","filepicker","Focusable","fseventsd","fsext","genai","goquery","GROQ","Guac","imageorient","Inex","jetta","jsons","jsonschema","jspm","Kaufmann","killall","Lanczos","lipgloss","LOCALAPPDATA","lsps","lucasb","makepkg","mcps","MSYS","mvdan","natefinch","nfnt","noctx","nohup","nolint","nslookup","oksvg","Oneshot","openrouter","opkg","pacman","paru","pfctl","postamble","postambles","preconfigured","Preproc","Proactiveness","Puerkito","pycache","pytest","qjebbs","rasterx","rivo","sabhiram","sess","shortlog","sjson","Sourcegraph","srwiley","SSEMCP","Streamable","stretchr","Strikethrough","substrs","Suscriber","systeminfo","tasklist","termenv","textinput","tidwall","timedout","trashhalo","udiff","uniseg","Unticked","urllib","USERPROFILE","VERTEXAI","webp","whatis","whereis","sahilm","csync","Highlightable","Highlightable","prerendered","prerender","kujtim","animatable"],"version":"0.2","flagWords":[],"language":"en"} \ No newline at end of file diff --git a/go.mod b/go.mod index 60e1a78fbac9248a6f346386d5e26e9b5a054516..f475ca4d4bc2a6f889b7c76b8f24ed3e8f1e5e75 100644 --- a/go.mod +++ b/go.mod @@ -18,11 +18,12 @@ require ( github.com/aymanbagabas/go-udiff v0.3.1 github.com/bmatcuk/doublestar/v4 v4.9.1 github.com/charlievieth/fastwalk v1.0.14 - github.com/charmbracelet/catwalk v0.10.2 + github.com/charmbracelet/catwalk v0.11.0 github.com/charmbracelet/colorprofile v0.4.1 github.com/charmbracelet/fang v0.4.4 github.com/charmbracelet/ultraviolet v0.0.0-20251212194010-b927aa605560 github.com/charmbracelet/x/ansi v0.11.3 + github.com/charmbracelet/x/etag v0.2.0 github.com/charmbracelet/x/exp/charmtone v0.0.0-20250708181618-a60a724ba6c3 github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f github.com/charmbracelet/x/exp/ordered v0.1.0 @@ -53,6 +54,7 @@ require ( github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef github.com/stretchr/testify v1.11.1 + github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/zeebo/xxh3 v1.0.2 golang.org/x/mod v0.31.0 @@ -92,7 +94,6 @@ require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 // indirect - github.com/charmbracelet/x/etag v0.2.0 // indirect github.com/charmbracelet/x/json v0.2.0 // indirect github.com/charmbracelet/x/termios v0.1.1 // indirect github.com/charmbracelet/x/windows v0.2.2 // indirect @@ -145,7 +146,6 @@ require ( github.com/sourcegraph/jsonrpc2 v0.2.1 // indirect github.com/spf13/pflag v1.0.9 // indirect github.com/tetratelabs/wazero v1.10.1 // indirect - github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/u-root/u-root v0.14.1-0.20250807200646-5e7721023dc7 // indirect diff --git a/go.sum b/go.sum index 8d17bc6ef2ec5a3c6b03717e7b4d76674f8e6b65..d107ae21a26cdfe2061d697813ad7a018f86177b 100644 --- a/go.sum +++ b/go.sum @@ -92,8 +92,8 @@ github.com/charlievieth/fastwalk v1.0.14 h1:3Eh5uaFGwHZd8EGwTjJnSpBkfwfsak9h6ICg github.com/charlievieth/fastwalk v1.0.14/go.mod h1:diVcUreiU1aQ4/Wu3NbxxH4/KYdKpLDojrQ1Bb2KgNY= github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 h1:rwLdEpG9wE6kL69KkEKDiWprO8pQOZHZXeod6+9K+mw= github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904/go.mod h1:8TIYxZxsuCqqeJ0lga/b91tBwrbjoHDC66Sq5t8N2R4= -github.com/charmbracelet/catwalk v0.10.2 h1:Ps6IeGu0ArKE3l3OYv+HwIwbnzZrAl1C3AuwXiOf1G0= -github.com/charmbracelet/catwalk v0.10.2/go.mod h1:qg+Yl9oaZTkTvRscqbxfttzOFQ4v0pOT5XwC7b5O0NQ= +github.com/charmbracelet/catwalk v0.11.0 h1:PU3rkc4h4YVJEn9Iyb/1rQAaF4hEd04fuG4tj3vv4dg= +github.com/charmbracelet/catwalk v0.11.0/go.mod h1:qg+Yl9oaZTkTvRscqbxfttzOFQ4v0pOT5XwC7b5O0NQ= github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk= github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk= github.com/charmbracelet/fang v0.4.4 h1:G4qKxF6or/eTPgmAolwPuRNyuci3hTUGGX1rj1YkHJY= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 5f48ee1c7b1434af7453fa19d567d8a194c377d1..62025b1943af245e94da6da744036e8040029c65 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -27,7 +27,9 @@ import ( "charm.land/fantasy/providers/google" "charm.land/fantasy/providers/openai" "charm.land/fantasy/providers/openrouter" + "charm.land/lipgloss/v2" "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" @@ -454,8 +456,22 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "") } else if isPermissionErr { currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "") + } else if errors.Is(err, hyper.ErrNoCredits) { + url := hyper.BaseURL() + link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url) + currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link) } else if errors.As(err, &providerErr) { - currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message) + if providerErr.Message == "The requested model is not supported." { + url := "https://github.com/settings/copilot/features" + link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url) + currentAssistant.AddFinish( + message.FinishReasonError, + "Copilot model not enabled", + fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link), + ) + } else { + currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message) + } } else if errors.As(err, &fantasyErr) { currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message) } else { diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 5690b1311a4fd0d0290277d01e734b17ab3d2a66..ea214a5bcfb8e65e6d5dee826854345168a2eb86 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -17,6 +17,7 @@ import ( "charm.land/fantasy" "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/agent/prompt" "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/config" @@ -668,6 +669,18 @@ func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, optio return google.New(opts...) } +func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) { + opts := []hyper.Option{ + hyper.WithBaseURL(baseURL), + hyper.WithAPIKey(apiKey), + } + if c.cfg.Options.Debug { + httpClient := log.NewHTTPClient() + opts = append(opts, hyper.WithHTTPClient(httpClient)) + } + return hyper.New(opts...) +} + func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool { if model.Think { return true @@ -728,6 +741,8 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con providerCfg.ExtraBody["tool_stream"] = true } return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody) + case hyper.Name: + return c.buildHyperProvider(baseURL, apiKey) default: return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type) } diff --git a/internal/agent/hyper/provider.go b/internal/agent/hyper/provider.go new file mode 100644 index 0000000000000000000000000000000000000000..ea2f4a18eaeec017f5f3f02576504a424f50bbf1 --- /dev/null +++ b/internal/agent/hyper/provider.go @@ -0,0 +1,330 @@ +// Package hyper provides a fantasy.Provider that proxies requests to Hyper. +package hyper + +import ( + "bufio" + "bytes" + "cmp" + "context" + _ "embed" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "maps" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "sync" + "time" + + "charm.land/fantasy" + "charm.land/fantasy/object" + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/event" +) + +//go:generate wget -O provider.json https://console.charm.land/api/v1/provider + +//go:embed provider.json +var embedded []byte + +// Enabled returns true if hyper is enabled. +var Enabled = sync.OnceValue(func() bool { + b, _ := strconv.ParseBool( + cmp.Or( + os.Getenv("HYPER"), + os.Getenv("HYPERCRUSH"), + os.Getenv("HYPER_ENABLE"), + os.Getenv("HYPER_ENABLED"), + ), + ) + return b +}) + +// Embedded returns the embedded Hyper provider. +var Embedded = sync.OnceValue(func() catwalk.Provider { + var provider catwalk.Provider + if err := json.Unmarshal(embedded, &provider); err != nil { + slog.Error("could not use embedded provider data", "err", err) + } + return provider +}) + +const ( + // Name is the default name of this meta provider. + Name = "hyper" + // defaultBaseURL is the default proxy URL. + // TODO: change this to production URL when ready. + defaultBaseURL = "https://console.charm.land" +) + +// BaseURL returns the base URL, which is either $HYPER_URL or the default. +var BaseURL = sync.OnceValue(func() string { + return cmp.Or(os.Getenv("HYPER_URL"), defaultBaseURL) +}) + +var ErrNoCredits = errors.New("you're out of credits") + +type options struct { + baseURL string + apiKey string + name string + headers map[string]string + client *http.Client +} + +// Option configures the proxy provider. +type Option = func(*options) + +// New creates a new proxy provider. +func New(opts ...Option) (fantasy.Provider, error) { + o := options{ + baseURL: BaseURL() + "/api/v1/fantasy", + name: Name, + headers: map[string]string{ + "x-crush-id": event.GetID(), + }, + client: &http.Client{Timeout: 0}, // stream-safe + } + for _, opt := range opts { + opt(&o) + } + return &provider{options: o}, nil +} + +// WithBaseURL sets the proxy base URL (e.g. http://localhost:8080). +func WithBaseURL(url string) Option { return func(o *options) { o.baseURL = url } } + +// WithName sets the provider name. +func WithName(name string) Option { return func(o *options) { o.name = name } } + +// WithHeaders sets extra headers sent to the proxy. +func WithHeaders(headers map[string]string) Option { + return func(o *options) { + maps.Copy(o.headers, headers) + } +} + +// WithHTTPClient sets custom HTTP client. +func WithHTTPClient(c *http.Client) Option { return func(o *options) { o.client = c } } + +// WithAPIKey sets the API key. +func WithAPIKey(key string) Option { + return func(o *options) { + o.apiKey = key + } +} + +type provider struct{ options options } + +func (p *provider) Name() string { return p.options.name } + +// LanguageModel implements fantasy.Provider. +func (p *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) { + if modelID == "" { + return nil, errors.New("missing model id") + } + return &languageModel{modelID: modelID, provider: p.options.name, opts: p.options}, nil +} + +type languageModel struct { + provider string + modelID string + opts options +} + +// GenerateObject implements fantasy.LanguageModel. +func (m *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + return object.GenerateWithTool(ctx, m, call) +} + +// StreamObject implements fantasy.LanguageModel. +func (m *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) { + return object.StreamWithTool(ctx, m, call) +} + +func (m *languageModel) Provider() string { return m.provider } +func (m *languageModel) Model() string { return m.modelID } + +// Generate implements fantasy.LanguageModel by calling the proxy JSON endpoint. +func (m *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { + resp, err := m.doRequest(ctx, false, call) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + b, _ := ioReadAllLimit(resp.Body, 64*1024) + return nil, fmt.Errorf("proxy generate error: %s", strings.TrimSpace(string(b))) + } + var out fantasy.Response + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, err + } + return &out, nil +} + +// Stream implements fantasy.LanguageModel using SSE from the proxy. +func (m *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { + // Prefer explicit /stream endpoint + resp, err := m.doRequest(ctx, true, call) + if err != nil { + return nil, err + } + switch resp.StatusCode { + case http.StatusTooManyRequests: + _ = resp.Body.Close() + return nil, toProviderError(resp, retryAfter(resp)) + case http.StatusUnauthorized: + _ = resp.Body.Close() + return nil, toProviderError(resp, "") + case http.StatusPaymentRequired: + _ = resp.Body.Close() + return nil, ErrNoCredits + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + defer func() { _ = resp.Body.Close() }() + b, _ := ioReadAllLimit(resp.Body, 64*1024) + return nil, &fantasy.ProviderError{ + Title: "Stream Error", + Message: strings.TrimSpace(string(b)), + StatusCode: resp.StatusCode, + } + } + + return func(yield func(fantasy.StreamPart) bool) { + defer func() { _ = resp.Body.Close() }() + scanner := bufio.NewScanner(resp.Body) + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 4*1024*1024) + + var ( + event string + dataBuf bytes.Buffer + sawFinish bool + dispatch = func() bool { + if dataBuf.Len() == 0 || event == "" { + dataBuf.Reset() + event = "" + return true + } + var part fantasy.StreamPart + if err := json.Unmarshal(dataBuf.Bytes(), &part); err != nil { + return yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err}) + } + if part.Type == fantasy.StreamPartTypeFinish { + sawFinish = true + } + ok := yield(part) + dataBuf.Reset() + event = "" + return ok + } + ) + + for scanner.Scan() { + line := scanner.Text() + if line == "" { // event boundary + if !dispatch() { + return + } + continue + } + if strings.HasPrefix(line, ":") { // comment / ping + continue + } + if strings.HasPrefix(line, "event: ") { + event = strings.TrimSpace(line[len("event: "):]) + continue + } + if strings.HasPrefix(line, "data: ") { + if dataBuf.Len() > 0 { + dataBuf.WriteByte('\n') + } + dataBuf.WriteString(line[len("data: "):]) + continue + } + } + if err := scanner.Err(); err != nil && + !errors.Is(err, context.Canceled) && + !errors.Is(err, context.DeadlineExceeded) { + yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err}) + return + } + // flush any pending data + _ = dispatch() + if !sawFinish { + _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish}) + } + }, nil +} + +func (m *languageModel) doRequest(ctx context.Context, stream bool, call fantasy.Call) (*http.Response, error) { + addr, err := url.Parse(m.opts.baseURL) + if err != nil { + return nil, err + } + addr = addr.JoinPath(m.modelID) + if stream { + addr = addr.JoinPath("stream") + } else { + addr = addr.JoinPath("generate") + } + + body, err := json.Marshal(call) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, addr.String(), bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + if stream { + req.Header.Set("Accept", "text/event-stream") + } else { + req.Header.Set("Accept", "application/json") + } + for k, v := range m.opts.headers { + req.Header.Set(k, v) + } + + if m.opts.apiKey != "" { + req.Header.Set("Authorization", m.opts.apiKey) + } + return m.opts.client.Do(req) +} + +// ioReadAllLimit reads up to n bytes. +func ioReadAllLimit(r io.Reader, n int64) ([]byte, error) { + var b bytes.Buffer + if n <= 0 { + n = 1 << 20 + } + lr := &io.LimitedReader{R: r, N: n} + _, err := b.ReadFrom(lr) + return b.Bytes(), err +} + +func toProviderError(resp *http.Response, message string) error { + return &fantasy.ProviderError{ + Title: fantasy.ErrorTitleForStatusCode(resp.StatusCode), + Message: message, + StatusCode: resp.StatusCode, + } +} + +func retryAfter(resp *http.Response) string { + after, err := strconv.Atoi(resp.Header.Get("Retry-After")) + if err == nil && after > 0 { + d := time.Duration(after) * time.Second + return "Try again in " + d.String() + } + return "Try again in later" +} diff --git a/internal/agent/hyper/provider.json b/internal/agent/hyper/provider.json new file mode 100644 index 0000000000000000000000000000000000000000..5558750e38e35024615b41b71243888a1a1ebd6c --- /dev/null +++ b/internal/agent/hyper/provider.json @@ -0,0 +1 @@ +{"name":"Charm Hyper","id":"hyper","api_endpoint":"https://console.charm.land/api/v1/fantasy","type":"hyper","default_large_model_id":"claude-sonnet-4-5","default_small_model_id":"claude-3-5-haiku","models":[{"id":"Kimi-K2-0905","name":"Kimi K2 0905","cost_per_1m_in":0.55,"cost_per_1m_out":2.19,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0,"context_window":256000,"default_max_tokens":10000,"can_reason":true,"default_reasoning_effort":"medium","supports_attachments":false,"options":{}},{"id":"claude-3-5-haiku","name":"Claude 3.5 Haiku","cost_per_1m_in":0.7999999999999999,"cost_per_1m_out":4,"cost_per_1m_in_cached":1,"cost_per_1m_out_cached":0.08,"context_window":200000,"default_max_tokens":5000,"can_reason":false,"supports_attachments":true,"options":{}},{"id":"claude-3-5-sonnet","name":"Claude 3.5 Sonnet (New)","cost_per_1m_in":3,"cost_per_1m_out":15,"cost_per_1m_in_cached":3.75,"cost_per_1m_out_cached":0.3,"context_window":200000,"default_max_tokens":5000,"can_reason":false,"supports_attachments":true,"options":{}},{"id":"claude-3-7-sonnet","name":"Claude 3.7 Sonnet","cost_per_1m_in":3,"cost_per_1m_out":15,"cost_per_1m_in_cached":3.75,"cost_per_1m_out_cached":0.3,"context_window":200000,"default_max_tokens":50000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"claude-haiku-4-5","name":"Claude 4.5 Haiku","cost_per_1m_in":1,"cost_per_1m_out":5,"cost_per_1m_in_cached":1.25,"cost_per_1m_out_cached":0.09999999999999999,"context_window":200000,"default_max_tokens":32000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"claude-opus-4","name":"Claude Opus 4","cost_per_1m_in":15,"cost_per_1m_out":75,"cost_per_1m_in_cached":18.75,"cost_per_1m_out_cached":1.5,"context_window":200000,"default_max_tokens":32000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"claude-opus-4-1","name":"Claude Opus 4.1","cost_per_1m_in":15,"cost_per_1m_out":75,"cost_per_1m_in_cached":18.75,"cost_per_1m_out_cached":1.5,"context_window":200000,"default_max_tokens":32000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"claude-opus-4-5","name":"Claude Opus 4.5","cost_per_1m_in":5,"cost_per_1m_out":25,"cost_per_1m_in_cached":6.25,"cost_per_1m_out_cached":0.5,"context_window":200000,"default_max_tokens":50000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"claude-sonnet-4","name":"Claude Sonnet 4","cost_per_1m_in":3,"cost_per_1m_out":15,"cost_per_1m_in_cached":3.75,"cost_per_1m_out_cached":0.3,"context_window":200000,"default_max_tokens":50000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"claude-sonnet-4-5","name":"Claude Sonnet 4.5","cost_per_1m_in":3,"cost_per_1m_out":15,"cost_per_1m_in_cached":3.75,"cost_per_1m_out_cached":0.3,"context_window":200000,"default_max_tokens":50000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"gemini-2.5-flash","name":"Gemini 2.5 Flash","cost_per_1m_in":0.3,"cost_per_1m_out":2.5,"cost_per_1m_in_cached":0.3833,"cost_per_1m_out_cached":0.075,"context_window":1048576,"default_max_tokens":50000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"gemini-2.5-pro","name":"Gemini 2.5 Pro","cost_per_1m_in":1.25,"cost_per_1m_out":10,"cost_per_1m_in_cached":1.625,"cost_per_1m_out_cached":0.31,"context_window":1048576,"default_max_tokens":50000,"can_reason":true,"supports_attachments":true,"options":{}},{"id":"glm-4.6","name":"GLM-4.6","cost_per_1m_in":0.6,"cost_per_1m_out":2.2,"cost_per_1m_in_cached":0.11,"cost_per_1m_out_cached":0,"context_window":204800,"default_max_tokens":131072,"can_reason":true,"default_reasoning_effort":"medium","supports_attachments":false,"options":{}},{"id":"gpt-4.1","name":"GPT-4.1","cost_per_1m_in":2,"cost_per_1m_out":8,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0.5,"context_window":1047576,"default_max_tokens":16384,"can_reason":false,"supports_attachments":true,"options":{}},{"id":"gpt-4.1-mini","name":"GPT-4.1 Mini","cost_per_1m_in":0.39999999999999997,"cost_per_1m_out":1.5999999999999999,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0.09999999999999999,"context_window":1047576,"default_max_tokens":16384,"can_reason":false,"supports_attachments":true,"options":{}},{"id":"gpt-4.1-nano","name":"GPT-4.1 Nano","cost_per_1m_in":0.09999999999999999,"cost_per_1m_out":0.39999999999999997,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0.024999999999999998,"context_window":1047576,"default_max_tokens":16384,"can_reason":false,"supports_attachments":true,"options":{}},{"id":"gpt-4o","name":"GPT-4o","cost_per_1m_in":2.5,"cost_per_1m_out":10,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":1.25,"context_window":128000,"default_max_tokens":8192,"can_reason":false,"supports_attachments":true,"options":{}},{"id":"gpt-4o-mini","name":"GPT-4o-mini","cost_per_1m_in":0.15,"cost_per_1m_out":0.6,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0.075,"context_window":128000,"default_max_tokens":8192,"can_reason":false,"supports_attachments":true,"options":{}},{"id":"gpt-5","name":"GPT-5","cost_per_1m_in":1.25,"cost_per_1m_out":10,"cost_per_1m_in_cached":0.125,"cost_per_1m_out_cached":0.125,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["minimal","low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5-codex","name":"GPT-5 Codex","cost_per_1m_in":1.25,"cost_per_1m_out":10,"cost_per_1m_in_cached":0.125,"cost_per_1m_out_cached":0.125,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["minimal","low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5-mini","name":"GPT-5 Mini","cost_per_1m_in":0.25,"cost_per_1m_out":2,"cost_per_1m_in_cached":0.025,"cost_per_1m_out_cached":0.025,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5-nano","name":"GPT-5 Nano","cost_per_1m_in":0.05,"cost_per_1m_out":0.4,"cost_per_1m_in_cached":0.005,"cost_per_1m_out_cached":0.005,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5.1","name":"GPT-5.1","cost_per_1m_in":1.25,"cost_per_1m_out":10,"cost_per_1m_in_cached":0.125,"cost_per_1m_out_cached":0.125,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["minimal","low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5.1-codex","name":"GPT-5.1 Codex","cost_per_1m_in":1.25,"cost_per_1m_out":10,"cost_per_1m_in_cached":0.125,"cost_per_1m_out_cached":0.125,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["minimal","low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5.1-codex-max","name":"GPT-5.1 Codex Max","cost_per_1m_in":1.25,"cost_per_1m_out":10,"cost_per_1m_in_cached":0.125,"cost_per_1m_out_cached":0.125,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["minimal","low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5.1-codex-mini","name":"GPT-5.1 Codex Mini","cost_per_1m_in":0.25,"cost_per_1m_out":2,"cost_per_1m_in_cached":0.025,"cost_per_1m_out_cached":0.025,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"gpt-5.2","name":"GPT-5.2","cost_per_1m_in":1.75,"cost_per_1m_out":14,"cost_per_1m_in_cached":0.175,"cost_per_1m_out_cached":0.175,"context_window":400000,"default_max_tokens":128000,"can_reason":true,"reasoning_levels":["minimal","low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"o3","name":"o3","cost_per_1m_in":2,"cost_per_1m_out":8,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0.5,"context_window":200000,"default_max_tokens":50000,"can_reason":true,"reasoning_levels":["low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"o3-mini","name":"o3 Mini","cost_per_1m_in":1.1,"cost_per_1m_out":4.4,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0.55,"context_window":200000,"default_max_tokens":50000,"can_reason":true,"reasoning_levels":["low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":false,"options":{}},{"id":"o4-mini","name":"o4 Mini","cost_per_1m_in":1.1,"cost_per_1m_out":4.4,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0.275,"context_window":200000,"default_max_tokens":50000,"can_reason":true,"reasoning_levels":["low","medium","high"],"default_reasoning_effort":"medium","supports_attachments":true,"options":{}},{"id":"qwen3-coder-480b-a35b-instruct","name":"Qwen 3 480B Coder","cost_per_1m_in":0.82,"cost_per_1m_out":3.29,"cost_per_1m_in_cached":0,"cost_per_1m_out_cached":0,"context_window":131072,"default_max_tokens":65536,"can_reason":false,"supports_attachments":false,"options":{}}]} \ No newline at end of file diff --git a/internal/cmd/login.go b/internal/cmd/login.go index 00c155cbfd23a95005da10a60f9ea36ab50cbd7d..334f32c7067d6820323363ea80a8978ae9229685 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -9,8 +9,14 @@ import ( "strings" "charm.land/lipgloss/v2" + "github.com/atotto/clipboard" + hyperp "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/oauth/claude" + "github.com/charmbracelet/crush/internal/oauth/copilot" + "github.com/charmbracelet/crush/internal/oauth/hyper" + "github.com/pkg/browser" "github.com/spf13/cobra" ) @@ -20,46 +26,122 @@ var loginCmd = &cobra.Command{ Short: "Login Crush to a platform", Long: `Login Crush to a specified platform. The platform should be provided as an argument. -Available platforms are: claude.`, +Available platforms are: hyper, claude, copilot.`, Example: ` +# Authenticate with Charm Hyper +crush login + # Authenticate with Claude Code Max crush login claude + +# Authenticate with GitHub Copilot +crush login copilot `, ValidArgs: []cobra.Completion{ + "hyper", "claude", "anthropic", + "copilot", + "github", + "github-copilot", }, - Args: cobra.ExactArgs(1), + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - if len(args) > 1 { - return fmt.Errorf("wrong number of arguments") - } - if len(args) == 0 || args[0] == "" { - return cmd.Help() - } - app, err := setupAppWithProgressBar(cmd) if err != nil { return err } defer app.Shutdown() - switch args[0] { + provider := "hyper" + if len(args) > 0 { + provider = args[0] + } + switch provider { + case "hyper": + return loginHyper() case "anthropic", "claude": return loginClaude() + case "copilot", "github", "github-copilot": + return loginCopilot() default: return fmt.Errorf("unknown platform: %s", args[0]) } }, } -func loginClaude() error { +func loginHyper() error { + cfg := config.Get() + if !hyperp.Enabled() { + return fmt.Errorf("hyper not enabled") + } ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) - go func() { - <-ctx.Done() - cancel() - os.Exit(1) - }() + defer cancel() + + resp, err := hyper.InitiateDeviceAuth(ctx) + if err != nil { + return err + } + + if clipboard.WriteAll(resp.UserCode) == nil { + fmt.Println("The following code should be on clipboard already:") + } else { + fmt.Println("Copy the following code:") + } + + fmt.Println() + fmt.Println(lipgloss.NewStyle().Bold(true).Render(resp.UserCode)) + fmt.Println() + fmt.Println("Press enter to open this URL, and then paste it there:") + fmt.Println() + fmt.Println(lipgloss.NewStyle().Hyperlink(resp.VerificationURL, "id=hyper").Render(resp.VerificationURL)) + fmt.Println() + waitEnter() + if err := browser.OpenURL(resp.VerificationURL); err != nil { + fmt.Println("Could not open the URL. You'll need to manually open the URL in your browser.") + } + + fmt.Println("Exchanging authorization code...") + refreshToken, err := hyper.PollForToken(ctx, resp.DeviceCode, resp.ExpiresIn) + if err != nil { + return err + } + + fmt.Println("Exchanging refresh token for access token...") + token, err := hyper.ExchangeToken(ctx, refreshToken) + if err != nil { + return err + } + + fmt.Println("Verifying access token...") + introspect, err := hyper.IntrospectToken(ctx, token.AccessToken) + if err != nil { + return fmt.Errorf("token introspection failed: %w", err) + } + if !introspect.Active { + return fmt.Errorf("access token is not active") + } + + if err := cmp.Or( + cfg.SetConfigField("providers.hyper.api_key", token.AccessToken), + cfg.SetConfigField("providers.hyper.oauth", token), + ); err != nil { + return err + } + + fmt.Println() + fmt.Println("You're now authenticated with Hyper!") + return nil +} + +func loginClaude() error { + ctx := getLoginContext() + + cfg := config.Get() + if cfg.HasConfigField("providers.anthropic.oauth") { + fmt.Println("You are already logged in to Claude.") + return nil + } verifier, challenge, err := claude.GetChallenge() if err != nil { @@ -94,7 +176,6 @@ func loginClaude() error { return err } - cfg := config.Get() if err := cmp.Or( cfg.SetConfigField("providers.anthropic.api_key", token.AccessToken), cfg.SetConfigField("providers.anthropic.oauth", token), @@ -106,3 +187,83 @@ func loginClaude() error { fmt.Println("You're now authenticated with Claude Code Max!") return nil } + +func loginCopilot() error { + ctx := getLoginContext() + + cfg := config.Get() + if cfg.HasConfigField("providers.copilot.oauth") { + fmt.Println("You are already logged in to GitHub Copilot.") + return nil + } + + diskToken, hasDiskToken := copilot.RefreshTokenFromDisk() + var token *oauth.Token + + switch { + case hasDiskToken: + fmt.Println("Found existing GitHub Copilot token on disk. Using it to authenticate...") + + t, err := copilot.RefreshToken(ctx, diskToken) + if err != nil { + return fmt.Errorf("unable to refresh token from disk: %w", err) + } + token = t + default: + fmt.Println("Requesting device code from GitHub...") + dc, err := copilot.RequestDeviceCode(ctx) + if err != nil { + return err + } + + fmt.Println() + fmt.Println("Open the following URL and follow the instructions to authenticate with GitHub Copilot:") + fmt.Println() + fmt.Println(lipgloss.NewStyle().Hyperlink(dc.VerificationURI, "id=copilot").Render(dc.VerificationURI)) + fmt.Println() + fmt.Println("Code:", lipgloss.NewStyle().Bold(true).Render(dc.UserCode)) + fmt.Println() + fmt.Println("Waiting for authorization...") + + t, err := copilot.PollForToken(ctx, dc) + if err == copilot.ErrNotAvailable { + fmt.Println() + fmt.Println("GitHub Copilot is unavailable for this account. To signup, go to the following page:") + fmt.Println() + fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.SignupURL, "id=copilot-signup").Render(copilot.SignupURL)) + fmt.Println() + fmt.Println("You may be able to request free access if elegible. For more information, see:") + fmt.Println() + fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.FreeURL, "id=copilot-free").Render(copilot.FreeURL)) + } + if err != nil { + return err + } + token = t + } + + if err := cmp.Or( + cfg.SetConfigField("providers.copilot.api_key", token.AccessToken), + cfg.SetConfigField("providers.copilot.oauth", token), + ); err != nil { + return err + } + + fmt.Println() + fmt.Println("You're now authenticated with GitHub Copilot!") + return nil +} + +func getLoginContext() context.Context { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) + go func() { + <-ctx.Done() + cancel() + os.Exit(1) + }() + return ctx +} + +func waitEnter() { + _, _ = fmt.Scanln() +} diff --git a/internal/cmd/projects.go b/internal/cmd/projects.go index 15c747834129b06829fc46832e3e1a09538de3d5..45a18384b7531be10206348ffcce6418979984e0 100644 --- a/internal/cmd/projects.go +++ b/internal/cmd/projects.go @@ -13,14 +13,13 @@ import ( var projectsCmd = &cobra.Command{ Use: "projects", - Short: "List all tracked project directories", - Long: `List all directories where Crush has been used. -This includes the working directory, data directory path, and last accessed time.`, + Short: "List project directories", + Long: "List directories where Crush project data is known to exist", Example: ` # List all projects in a table crush projects -# Output as JSON +# Output projects data as JSON crush projects --json `, RunE: func(cmd *cobra.Command, args []string) error { diff --git a/internal/cmd/update_providers.go b/internal/cmd/update_providers.go index 599d2c90954ca43888197961ec3bea4372285071..3b4b35b681fb003ff4d942c8052101253b364a1c 100644 --- a/internal/cmd/update_providers.go +++ b/internal/cmd/update_providers.go @@ -10,22 +10,30 @@ import ( "github.com/spf13/cobra" ) +var updateProvidersSource string + var updateProvidersCmd = &cobra.Command{ Use: "update-providers [path-or-url]", Short: "Update providers", - Long: `Update the list of providers from a specified local path or remote URL.`, + Long: `Update provider information from a specified local path or remote URL.`, Example: ` -# Update providers remotely from Catwalk +# Update Catwalk providers remotely (default) crush update-providers -# Update providers from a custom URL -crush update-providers https://example.com/ +# Update Catwalk providers from a custom URL +crush update-providers https://example.com/providers.json -# Update providers from a local file +# Update Catwalk providers from a local file crush update-providers /path/to/local-providers.json -# Update providers from embedded version +# Update Catwalk providers from embedded version crush update-providers embedded + +# Update Hyper provider information +crush update-providers --source=hyper + +# Update Hyper from a custom URL +crush update-providers --source=hyper https://hyper.example.com `, RunE: func(cmd *cobra.Command, args []string) error { // NOTE(@andreynering): We want to skip logging output do stdout here. @@ -36,7 +44,17 @@ crush update-providers embedded pathOrURL = args[0] } - if err := config.UpdateProviders(pathOrURL); err != nil { + var err error + switch updateProvidersSource { + case "catwalk": + err = config.UpdateProviders(pathOrURL) + case "hyper": + err = config.UpdateHyper(pathOrURL) + default: + return fmt.Errorf("invalid source %q, must be 'catwalk' or 'hyper'", updateProvidersSource) + } + + if err != nil { return err } @@ -52,9 +70,13 @@ crush update-providers embedded SetString("SUCCESS") textStyle := lipgloss.NewStyle(). MarginLeft(2). - SetString("Providers updated successfully.") + SetString(fmt.Sprintf("%s provider updated successfully.", updateProvidersSource)) fmt.Printf("%s\n%s\n\n", headerStyle.Render(), textStyle.Render()) return nil }, } + +func init() { + updateProvidersCmd.Flags().StringVar(&updateProvidersSource, "source", "catwalk", "Provider source to update (catwalk or hyper)") +} diff --git a/internal/config/catwalk.go b/internal/config/catwalk.go new file mode 100644 index 0000000000000000000000000000000000000000..c3cc2eb69d47e1a85e35164fda09d0f73761b820 --- /dev/null +++ b/internal/config/catwalk.go @@ -0,0 +1,82 @@ +package config + +import ( + "context" + "errors" + "log/slog" + "sync" + "sync/atomic" + + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/catwalk/pkg/embedded" +) + +type catwalkClient interface { + GetProviders(context.Context, string) ([]catwalk.Provider, error) +} + +var _ syncer[[]catwalk.Provider] = (*catwalkSync)(nil) + +type catwalkSync struct { + once sync.Once + result []catwalk.Provider + cache cache[[]catwalk.Provider] + client catwalkClient + autoupdate bool + init atomic.Bool +} + +func (s *catwalkSync) Init(client catwalkClient, path string, autoupdate bool) { + s.client = client + s.cache = newCache[[]catwalk.Provider](path) + s.autoupdate = autoupdate + s.init.Store(true) +} + +func (s *catwalkSync) Get(ctx context.Context) ([]catwalk.Provider, error) { + if !s.init.Load() { + panic("called Get before Init") + } + + var throwErr error + s.once.Do(func() { + if !s.autoupdate { + slog.Info("Using embedded Catwalk providers") + s.result = embedded.GetAll() + return + } + + cached, etag, cachedErr := s.cache.Get() + if len(cached) == 0 || cachedErr != nil { + // if cached file is empty, default to embedded providers + cached = embedded.GetAll() + } + + slog.Info("Fetching providers from Catwalk") + result, err := s.client.GetProviders(ctx, etag) + if errors.Is(err, context.DeadlineExceeded) { + slog.Warn("Catwalk providers not updated in time") + s.result = cached + return + } + if errors.Is(err, catwalk.ErrNotModified) { + slog.Info("Catwalk providers not modified") + s.result = cached + return + } + if err != nil { + // On error, fall back to cached (which defaults to embedded if empty). + s.result = cached + return + } + if len(result) == 0 { + s.result = cached + throwErr = errors.New("empty providers list from catwalk") + return + } + + s.result = result + throwErr = s.cache.Store(result) + }) + return s.result, throwErr +} diff --git a/internal/config/catwalk_test.go b/internal/config/catwalk_test.go new file mode 100644 index 0000000000000000000000000000000000000000..55322b34eb7252f8cae75fb46996f45bd31abe5e --- /dev/null +++ b/internal/config/catwalk_test.go @@ -0,0 +1,221 @@ +package config + +import ( + "context" + "encoding/json" + "errors" + "os" + "testing" + + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/stretchr/testify/require" +) + +type mockCatwalkClient struct { + providers []catwalk.Provider + err error + callCount int +} + +func (m *mockCatwalkClient) GetProviders(ctx context.Context, etag string) ([]catwalk.Provider, error) { + m.callCount++ + return m.providers, m.err +} + +func TestCatwalkSync_Init(t *testing.T) { + t.Parallel() + + syncer := &catwalkSync{} + client := &mockCatwalkClient{} + path := "/tmp/test.json" + + syncer.Init(client, path, true) + + require.True(t, syncer.init.Load()) + require.Equal(t, client, syncer.client) + require.Equal(t, path, syncer.cache.path) + require.True(t, syncer.autoupdate) +} + +func TestCatwalkSync_GetPanicIfNotInit(t *testing.T) { + t.Parallel() + + syncer := &catwalkSync{} + require.Panics(t, func() { + _, _ = syncer.Get(t.Context()) + }) +} + +func TestCatwalkSync_GetWithAutoUpdateDisabled(t *testing.T) { + t.Parallel() + + syncer := &catwalkSync{} + client := &mockCatwalkClient{ + providers: []catwalk.Provider{{Name: "should-not-be-used"}}, + } + path := t.TempDir() + "/providers.json" + + syncer.Init(client, path, false) + + providers, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.NotEmpty(t, providers) + require.Equal(t, 0, client.callCount, "Client should not be called when autoupdate is disabled") + + // Should return embedded providers. + for _, p := range providers { + require.NotEqual(t, "should-not-be-used", p.Name) + } +} + +func TestCatwalkSync_GetFreshProviders(t *testing.T) { + t.Parallel() + + syncer := &catwalkSync{} + client := &mockCatwalkClient{ + providers: []catwalk.Provider{ + {Name: "Fresh Provider", ID: "fresh"}, + }, + } + path := t.TempDir() + "/providers.json" + + syncer.Init(client, path, true) + + providers, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Len(t, providers, 1) + require.Equal(t, "Fresh Provider", providers[0].Name) + require.Equal(t, 1, client.callCount) + + // Verify cache was written. + fileInfo, err := os.Stat(path) + require.NoError(t, err) + require.False(t, fileInfo.IsDir()) +} + +func TestCatwalkSync_GetNotModifiedUsesCached(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/providers.json" + + // Create cache file. + cachedProviders := []catwalk.Provider{ + {Name: "Cached Provider", ID: "cached"}, + } + data, err := json.Marshal(cachedProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(path, data, 0o644)) + + syncer := &catwalkSync{} + client := &mockCatwalkClient{ + err: catwalk.ErrNotModified, + } + + syncer.Init(client, path, true) + + providers, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Len(t, providers, 1) + require.Equal(t, "Cached Provider", providers[0].Name) + require.Equal(t, 1, client.callCount) +} + +func TestCatwalkSync_GetEmptyResultFallbackToCached(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/providers.json" + + // Create cache file. + cachedProviders := []catwalk.Provider{ + {Name: "Cached Provider", ID: "cached"}, + } + data, err := json.Marshal(cachedProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(path, data, 0o644)) + + syncer := &catwalkSync{} + client := &mockCatwalkClient{ + providers: []catwalk.Provider{}, // Empty result. + } + + syncer.Init(client, path, true) + + providers, err := syncer.Get(t.Context()) + require.Error(t, err) + require.Contains(t, err.Error(), "empty providers list from catwalk") + require.Len(t, providers, 1) + require.Equal(t, "Cached Provider", providers[0].Name) +} + +func TestCatwalkSync_GetEmptyCacheDefaultsToEmbedded(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/providers.json" + + // Create empty cache file. + emptyProviders := []catwalk.Provider{} + data, err := json.Marshal(emptyProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(path, data, 0o644)) + + syncer := &catwalkSync{} + client := &mockCatwalkClient{ + err: errors.New("network error"), + } + + syncer.Init(client, path, true) + + providers, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.NotEmpty(t, providers, "Should fall back to embedded providers") + + // Verify it's embedded providers by checking we have multiple common ones. + require.Greater(t, len(providers), 5) +} + +func TestCatwalkSync_GetClientError(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/providers.json" + + syncer := &catwalkSync{} + client := &mockCatwalkClient{ + err: errors.New("network error"), + } + + syncer.Init(client, path, true) + + providers, err := syncer.Get(t.Context()) + require.NoError(t, err) // Should fall back to embedded. + require.NotEmpty(t, providers) +} + +func TestCatwalkSync_GetCalledMultipleTimesUsesOnce(t *testing.T) { + t.Parallel() + + syncer := &catwalkSync{} + client := &mockCatwalkClient{ + providers: []catwalk.Provider{ + {Name: "Provider", ID: "test"}, + }, + } + path := t.TempDir() + "/providers.json" + + syncer.Init(client, path, true) + + // Call Get multiple times. + providers1, err1 := syncer.Get(t.Context()) + require.NoError(t, err1) + require.Len(t, providers1, 1) + + providers2, err2 := syncer.Get(t.Context()) + require.NoError(t, err2) + require.Len(t, providers2, 1) + + // Client should only be called once due to sync.Once. + require.Equal(t, 1, client.callCount) +} diff --git a/internal/config/config.go b/internal/config/config.go index af1222e55c862f06a3ca6eaf6a69c6f632aab134..a91350b5bc894161bc5fdbd44720c27b46fc1063 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "log/slog" + "maps" "net/http" "net/url" "os" @@ -13,11 +14,15 @@ import ( "time" "github.com/charmbracelet/catwalk/pkg/catwalk" + hyperp "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/oauth/claude" + "github.com/charmbracelet/crush/internal/oauth/copilot" + "github.com/charmbracelet/crush/internal/oauth/hyper" "github.com/invopop/jsonschema" + "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -119,8 +124,37 @@ type ProviderConfig struct { Models []catwalk.Model `json:"models,omitempty" jsonschema:"description=List of models available from this provider"` } +// ToProvider converts the [ProviderConfig] to a [catwalk.Provider]. +func (pc *ProviderConfig) ToProvider() catwalk.Provider { + // Convert config provider to provider.Provider format + provider := catwalk.Provider{ + Name: pc.Name, + ID: catwalk.InferenceProvider(pc.ID), + Models: make([]catwalk.Model, len(pc.Models)), + } + + // Convert models + for i, model := range pc.Models { + provider.Models[i] = catwalk.Model{ + ID: model.ID, + Name: model.Name, + CostPer1MIn: model.CostPer1MIn, + CostPer1MOut: model.CostPer1MOut, + CostPer1MInCached: model.CostPer1MInCached, + CostPer1MOutCached: model.CostPer1MOutCached, + ContextWindow: model.ContextWindow, + DefaultMaxTokens: model.DefaultMaxTokens, + CanReason: model.CanReason, + ReasoningLevels: model.ReasoningLevels, + DefaultReasoningEffort: model.DefaultReasoningEffort, + SupportsImages: model.SupportsImages, + } + } + + return provider +} + func (pc *ProviderConfig) SetupClaudeCode() { - pc.APIKey = fmt.Sprintf("Bearer %s", pc.OAuthToken.AccessToken) pc.SystemPromptPrefix = "You are Claude Code, Anthropic's official CLI for Claude." pc.ExtraHeaders["anthropic-version"] = "2023-06-01" @@ -135,6 +169,10 @@ func (pc *ProviderConfig) SetupClaudeCode() { pc.ExtraHeaders["anthropic-beta"] = value } +func (pc *ProviderConfig) SetupGitHubCopilot() { + maps.Copy(pc.ExtraHeaders, copilot.Headers()) +} + type MCPType string const ( @@ -451,6 +489,14 @@ func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model Selecte return nil } +func (c *Config) HasConfigField(key string) bool { + data, err := os.ReadFile(c.dataConfigDir) + if err != nil { + return false + } + return gjson.Get(string(data), key).Exists() +} + func (c *Config) SetConfigField(key string, value any) error { // read the data data, err := os.ReadFile(c.dataConfigDir) @@ -483,20 +529,33 @@ func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error return fmt.Errorf("provider %s does not have an OAuth token", providerID) } - // Only Anthropic provider uses OAuth for now. - if providerID != string(catwalk.InferenceProviderAnthropic) { + var newToken *oauth.Token + var refreshErr error + switch providerID { + case string(catwalk.InferenceProviderAnthropic): + newToken, refreshErr = claude.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) + case string(catwalk.InferenceProviderCopilot): + newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) + case hyperp.Name: + newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken) + default: return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) } - - newToken, err := claude.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) - if err != nil { - return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, err) + if refreshErr != nil { + return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr) } slog.Info("Successfully refreshed OAuth token", "provider", providerID) providerConfig.OAuthToken = newToken - providerConfig.APIKey = fmt.Sprintf("Bearer %s", newToken.AccessToken) - providerConfig.SetupClaudeCode() + + switch providerID { + case string(catwalk.InferenceProviderAnthropic): + providerConfig.APIKey = fmt.Sprintf("Bearer %s", newToken.AccessToken) + providerConfig.SetupClaudeCode() + case string(catwalk.InferenceProviderCopilot): + providerConfig.APIKey = newToken.AccessToken + providerConfig.SetupGitHubCopilot() + } c.Providers.Set(providerID, providerConfig) @@ -531,7 +590,13 @@ func (c *Config) SetProviderAPIKey(providerID string, apiKey any) error { setKeyOrToken = func() { providerConfig.APIKey = v.AccessToken providerConfig.OAuthToken = v - providerConfig.SetupClaudeCode() + switch providerID { + case string(catwalk.InferenceProviderAnthropic): + providerConfig.APIKey = fmt.Sprintf("Bearer %s", v.AccessToken) + providerConfig.SetupClaudeCode() + case string(catwalk.InferenceProviderCopilot): + providerConfig.SetupGitHubCopilot() + } } } diff --git a/internal/config/copilot.go b/internal/config/copilot.go new file mode 100644 index 0000000000000000000000000000000000000000..f9ebc2f4fbddf602c67ae6fc81f5e6ca02d57b27 --- /dev/null +++ b/internal/config/copilot.go @@ -0,0 +1,43 @@ +package config + +import ( + "cmp" + "context" + "log/slog" + "testing" + + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/oauth/copilot" +) + +func (c *Config) importCopilot() (*oauth.Token, bool) { + if testing.Testing() { + return nil, false + } + + if c.HasConfigField("providers.copilot.api_key") || c.HasConfigField("providers.copilot.oauth") { + return nil, false + } + + diskToken, hasDiskToken := copilot.RefreshTokenFromDisk() + if !hasDiskToken { + return nil, false + } + + slog.Info("Found existing GitHub Copilot token on disk. Authenticating...") + token, err := copilot.RefreshToken(context.TODO(), diskToken) + if err != nil { + slog.Error("Unable to import GitHub Copilot token", "error", err) + return nil, false + } + + if err := cmp.Or( + c.SetConfigField("providers.copilot.api_key", token.AccessToken), + c.SetConfigField("providers.copilot.oauth", token), + ); err != nil { + slog.Error("Unable to save GitHub Copilot token to disk", "error", err) + } + + slog.Info("GitHub Copilot successfully imported") + return token, true +} diff --git a/internal/config/hyper.go b/internal/config/hyper.go new file mode 100644 index 0000000000000000000000000000000000000000..5fe6fc5a1ee54bd19902ef4c9cc6034a6b294b6f --- /dev/null +++ b/internal/config/hyper.go @@ -0,0 +1,124 @@ +package config + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/agent/hyper" + xetag "github.com/charmbracelet/x/etag" +) + +type hyperClient interface { + Get(context.Context, string) (catwalk.Provider, error) +} + +var _ syncer[catwalk.Provider] = (*hyperSync)(nil) + +type hyperSync struct { + once sync.Once + result catwalk.Provider + cache cache[catwalk.Provider] + client hyperClient + autoupdate bool + init atomic.Bool +} + +func (s *hyperSync) Init(client hyperClient, path string, autoupdate bool) { + s.client = client + s.cache = newCache[catwalk.Provider](path) + s.autoupdate = autoupdate + s.init.Store(true) +} + +func (s *hyperSync) Get(ctx context.Context) (catwalk.Provider, error) { + if !s.init.Load() { + panic("called Get before Init") + } + + var throwErr error + s.once.Do(func() { + if !s.autoupdate { + slog.Info("Using embedded Hyper provider") + s.result = hyper.Embedded() + return + } + + cached, etag, cachedErr := s.cache.Get() + if cached.ID == "" || cachedErr != nil { + // if cached file is empty, default to embedded provider + cached = hyper.Embedded() + } + + slog.Info("Fetching Hyper provider") + result, err := s.client.Get(ctx, etag) + if errors.Is(err, context.DeadlineExceeded) { + slog.Warn("Hyper provider not updated in time") + s.result = cached + return + } + if errors.Is(err, catwalk.ErrNotModified) { + slog.Info("Hyper provider not modified") + s.result = cached + return + } + if len(result.Models) == 0 { + slog.Warn("Hyper did not return any models") + s.result = cached + return + } + + s.result = result + throwErr = s.cache.Store(result) + }) + return s.result, throwErr +} + +var _ hyperClient = realHyperClient{} + +type realHyperClient struct { + baseURL string +} + +// Get implements hyperClient. +func (r realHyperClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) { + var result catwalk.Provider + req, err := http.NewRequestWithContext( + ctx, + http.MethodGet, + r.baseURL+"/api/v1/provider", + nil, + ) + if err != nil { + return result, fmt.Errorf("could not create request: %w", err) + } + xetag.Request(req, etag) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return result, fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode == http.StatusNotModified { + return result, catwalk.ErrNotModified + } + + if resp.StatusCode != http.StatusOK { + return result, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return result, fmt.Errorf("failed to decode response: %w", err) + } + + return result, nil +} diff --git a/internal/config/hyper_test.go b/internal/config/hyper_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7141eaa1e97888b5ee6f84afc8e9658825547b46 --- /dev/null +++ b/internal/config/hyper_test.go @@ -0,0 +1,205 @@ +package config + +import ( + "context" + "encoding/json" + "errors" + "os" + "testing" + + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/stretchr/testify/require" +) + +type mockHyperClient struct { + provider catwalk.Provider + err error + callCount int +} + +func (m *mockHyperClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) { + m.callCount++ + return m.provider, m.err +} + +func TestHyperSync_Init(t *testing.T) { + t.Parallel() + + syncer := &hyperSync{} + client := &mockHyperClient{} + path := "/tmp/hyper.json" + + syncer.Init(client, path, true) + + require.True(t, syncer.init.Load()) + require.Equal(t, client, syncer.client) + require.Equal(t, path, syncer.cache.path) +} + +func TestHyperSync_GetPanicIfNotInit(t *testing.T) { + t.Parallel() + + syncer := &hyperSync{} + require.Panics(t, func() { + _, _ = syncer.Get(t.Context()) + }) +} + +func TestHyperSync_GetFreshProvider(t *testing.T) { + t.Parallel() + + syncer := &hyperSync{} + client := &mockHyperClient{ + provider: catwalk.Provider{ + Name: "Hyper", + ID: "hyper", + Models: []catwalk.Model{ + {ID: "model-1", Name: "Model 1"}, + }, + }, + } + path := t.TempDir() + "/hyper.json" + + syncer.Init(client, path, true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, "Hyper", provider.Name) + require.Equal(t, 1, client.callCount) + + // Verify cache was written. + fileInfo, err := os.Stat(path) + require.NoError(t, err) + require.False(t, fileInfo.IsDir()) +} + +func TestHyperSync_GetNotModifiedUsesCached(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/hyper.json" + + // Create cache file. + cachedProvider := catwalk.Provider{ + Name: "Cached Hyper", + ID: "hyper", + } + data, err := json.Marshal(cachedProvider) + require.NoError(t, err) + require.NoError(t, os.WriteFile(path, data, 0o644)) + + syncer := &hyperSync{} + client := &mockHyperClient{ + err: catwalk.ErrNotModified, + } + + syncer.Init(client, path, true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, "Cached Hyper", provider.Name) + require.Equal(t, 1, client.callCount) +} + +func TestHyperSync_GetClientError(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/hyper.json" + + syncer := &hyperSync{} + client := &mockHyperClient{ + err: errors.New("network error"), + } + + syncer.Init(client, path, true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) // Should fall back to embedded. + require.Equal(t, "Charm Hyper", provider.Name) + require.Equal(t, catwalk.InferenceProvider("hyper"), provider.ID) +} + +func TestHyperSync_GetEmptyCache(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/hyper.json" + + syncer := &hyperSync{} + client := &mockHyperClient{ + provider: catwalk.Provider{ + Name: "Fresh Hyper", + ID: "hyper", + Models: []catwalk.Model{ + {ID: "model-1", Name: "Model 1"}, + }, + }, + } + + syncer.Init(client, path, true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, "Fresh Hyper", provider.Name) +} + +func TestHyperSync_GetCalledMultipleTimesUsesOnce(t *testing.T) { + t.Parallel() + + syncer := &hyperSync{} + client := &mockHyperClient{ + provider: catwalk.Provider{ + Name: "Hyper", + ID: "hyper", + Models: []catwalk.Model{ + {ID: "model-1", Name: "Model 1"}, + }, + }, + } + path := t.TempDir() + "/hyper.json" + + syncer.Init(client, path, true) + + // Call Get multiple times. + provider1, err1 := syncer.Get(t.Context()) + require.NoError(t, err1) + require.Equal(t, "Hyper", provider1.Name) + + provider2, err2 := syncer.Get(t.Context()) + require.NoError(t, err2) + require.Equal(t, "Hyper", provider2.Name) + + // Client should only be called once due to sync.Once. + require.Equal(t, 1, client.callCount) +} + +func TestHyperSync_GetCacheStoreError(t *testing.T) { + t.Parallel() + + // Create a file where we want a directory, causing mkdir to fail. + tmpDir := t.TempDir() + blockingFile := tmpDir + "/blocking" + require.NoError(t, os.WriteFile(blockingFile, []byte("block"), 0o644)) + + // Try to create cache in a subdirectory under the blocking file. + path := blockingFile + "/subdir/hyper.json" + + syncer := &hyperSync{} + client := &mockHyperClient{ + provider: catwalk.Provider{ + Name: "Hyper", + ID: "hyper", + Models: []catwalk.Model{ + {ID: "model-1", Name: "Model 1"}, + }, + }, + } + + syncer.Init(client, path, true) + + provider, err := syncer.Get(t.Context()) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to create directory for provider cache") + require.Equal(t, "Hyper", provider.Name) // Provider is still returned. +} diff --git a/internal/config/load.go b/internal/config/load.go index 14dd0f8792bcbabaa865efe47e0db5e721cf3827..0d16702dcdd35eb7d431ddfe4a0b35ab48e4debc 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -1,6 +1,7 @@ package config import ( + "cmp" "context" "encoding/json" "fmt" @@ -17,6 +18,7 @@ import ( "testing" "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/fsext" @@ -131,6 +133,8 @@ func PushPopCrushEnv() func() { } func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { + c.importCopilot() + knownProviderNames := make(map[string]bool) restore := PushPopCrushEnv() defer restore() @@ -198,8 +202,18 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know Models: p.Models, } - if p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil { + switch { + case p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil: prepared.SetupClaudeCode() + case p.ID == catwalk.InferenceProviderCopilot: + if config.OAuthToken != nil { + if token, ok := c.importCopilot(); ok { + prepared.OAuthToken = token + } + } + if config.OAuthToken != nil { + prepared.SetupGitHubCopilot() + } } switch p.ID { @@ -271,7 +285,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if providerConfig.Type == "" { providerConfig.Type = catwalk.TypeOpenAICompat } - if !slices.Contains(catwalk.KnownProviderTypes(), providerConfig.Type) { + if !slices.Contains(catwalk.KnownProviderTypes(), providerConfig.Type) && providerConfig.Type != hyper.Name { slog.Warn("Skipping custom provider due to unsupported provider type", "provider", id) c.Providers.Del(id) continue @@ -682,19 +696,22 @@ func hasAWSCredentials(env env.Env) bool { // GlobalConfig returns the global configuration file path for the application. func GlobalConfig() string { - xdgConfigHome := os.Getenv("XDG_CONFIG_HOME") - if xdgConfigHome != "" { + if crushGlobal := os.Getenv("CRUSH_GLOBAL_CONFIG"); crushGlobal != "" { + return filepath.Join(crushGlobal, fmt.Sprintf("%s.json", appName)) + } + if xdgConfigHome := os.Getenv("XDG_CONFIG_HOME"); xdgConfigHome != "" { return filepath.Join(xdgConfigHome, appName, fmt.Sprintf("%s.json", appName)) } - return filepath.Join(home.Dir(), ".config", appName, fmt.Sprintf("%s.json", appName)) } // GlobalConfigData returns the path to the main data directory for the application. // this config is used when the app overrides configurations instead of updating the global config. func GlobalConfigData() string { - xdgDataHome := os.Getenv("XDG_DATA_HOME") - if xdgDataHome != "" { + if crushData := os.Getenv("CRUSH_GLOBAL_DATA"); crushData != "" { + return filepath.Join(crushData, fmt.Sprintf("%s.json", appName)) + } + if xdgDataHome := os.Getenv("XDG_DATA_HOME"); xdgDataHome != "" { return filepath.Join(xdgDataHome, appName, fmt.Sprintf("%s.json", appName)) } @@ -702,10 +719,10 @@ func GlobalConfigData() string { // for windows, it should be in `%LOCALAPPDATA%/crush/` // for linux and macOS, it should be in `$HOME/.local/share/crush/` if runtime.GOOS == "windows" { - localAppData := os.Getenv("LOCALAPPDATA") - if localAppData == "" { - localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") - } + localAppData := cmp.Or( + os.Getenv("LOCALAPPDATA"), + filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local"), + ) return filepath.Join(localAppData, appName, fmt.Sprintf("%s.json", appName)) } diff --git a/internal/config/provider.go b/internal/config/provider.go index e9d4dfcc9d0947eebe51d5ae303106404ab47292..253d6f658a567ed5302887ecb87415de0a89c504 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -10,16 +10,21 @@ import ( "os" "path/filepath" "runtime" + "slices" "strings" "sync" + "time" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/catwalk/pkg/embedded" + "github.com/charmbracelet/crush/internal/agent/hyper" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/home" + "github.com/charmbracelet/x/etag" ) -type ProviderClient interface { - GetProviders(context.Context, string) ([]catwalk.Provider, error) +type syncer[T any] interface { + Get(context.Context) (T, error) } var ( @@ -29,10 +34,10 @@ var ( ) // file to cache provider data -func providerCacheFileData() string { +func cachePathFor(name string) string { xdgDataHome := os.Getenv("XDG_DATA_HOME") if xdgDataHome != "" { - return filepath.Join(xdgDataHome, appName, "providers.json") + return filepath.Join(xdgDataHome, appName, name+".json") } // return the path to the main data directory @@ -43,43 +48,13 @@ func providerCacheFileData() string { if localAppData == "" { localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") } - return filepath.Join(localAppData, appName, "providers.json") + return filepath.Join(localAppData, appName, name+".json") } - return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json") -} - -func saveProvidersInCache(path string, providers []catwalk.Provider) error { - slog.Info("Saving provider data to disk", "path", path) - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return fmt.Errorf("failed to create directory for provider cache: %w", err) - } - - data, err := json.Marshal(providers) - if err != nil { - return fmt.Errorf("failed to marshal provider data: %w", err) - } - - if err := os.WriteFile(path, data, 0o644); err != nil { - return fmt.Errorf("failed to write provider data to cache: %w", err) - } - return nil -} - -func loadProvidersFromCache(path string) ([]catwalk.Provider, string, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, "", fmt.Errorf("failed to read provider cache file: %w", err) - } - - var providers []catwalk.Provider - if err := json.Unmarshal(data, &providers); err != nil { - return nil, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err) - } - - return providers, catwalk.Etag(data), nil + return filepath.Join(home.Dir(), ".local", "share", appName, name+".json") } +// UpdateProviders updates the Catwalk providers list from a specified source. func UpdateProviders(pathOrURL string) error { var providers []catwalk.Provider pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL) @@ -106,15 +81,55 @@ func UpdateProviders(pathOrURL string) error { } } - cachePath := providerCacheFileData() - if err := saveProvidersInCache(cachePath, providers); err != nil { + if err := newCache[[]catwalk.Provider](cachePathFor("providers")).Store(providers); err != nil { return fmt.Errorf("failed to save providers to cache: %w", err) } - slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePath) + slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePathFor) + return nil +} + +// UpdateHyper updates the Hyper provider information from a specified URL. +func UpdateHyper(pathOrURL string) error { + if !hyper.Enabled() { + return fmt.Errorf("hyper not enabled") + } + var provider catwalk.Provider + pathOrURL = cmp.Or(pathOrURL, hyper.BaseURL()) + + switch { + case pathOrURL == "embedded": + provider = hyper.Embedded() + case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"): + client := realHyperClient{baseURL: pathOrURL} + var err error + provider, err = client.Get(context.Background(), "") + if err != nil { + return fmt.Errorf("failed to fetch provider from Hyper: %w", err) + } + default: + content, err := os.ReadFile(pathOrURL) + if err != nil { + return fmt.Errorf("failed to read file: %w", err) + } + if err := json.Unmarshal(content, &provider); err != nil { + return fmt.Errorf("failed to unmarshal provider data: %w", err) + } + } + + if err := newCache[catwalk.Provider](cachePathFor("hyper")).Store(provider); err != nil { + return fmt.Errorf("failed to save Hyper provider to cache: %w", err) + } + + slog.Info("Hyper provider updated successfully", "from", pathOrURL, "to", cachePathFor("hyper")) return nil } +var ( + catwalkSyncer = &catwalkSync{} + hyperSyncer = &hyperSync{} +) + // Providers returns the list of providers, taking into account cached results // and whether or not auto update is enabled. // @@ -126,46 +141,87 @@ func UpdateProviders(pathOrURL string) error { // the cached list, or the embedded list if all others fail. func Providers(cfg *Config) ([]catwalk.Provider, error) { providerOnce.Do(func() { - catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL) - client := catwalk.NewWithURL(catwalkURL) - path := providerCacheFileData() - - if cfg.Options.DisableProviderAutoUpdate { - slog.Info("Using embedded Catwalk providers") - providerList, providerErr = embedded.GetAll(), nil - return - } + var wg sync.WaitGroup + var errs []error + providers := csync.NewSlice[catwalk.Provider]() + autoupdate := !cfg.Options.DisableProviderAutoUpdate + + ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second) + defer cancel() + + wg.Go(func() { + catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL) + client := catwalk.NewWithURL(catwalkURL) + path := cachePathFor("providers") + catwalkSyncer.Init(client, path, autoupdate) + + items, err := catwalkSyncer.Get(ctx) + if err != nil { + catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)) + errs = append(errs, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr)) //nolint:staticcheck + return + } + providers.Append(items...) + }) + + wg.Go(func() { + if !hyper.Enabled() { + return + } + path := cachePathFor("hyper") + hyperSyncer.Init(realHyperClient{baseURL: hyper.BaseURL()}, path, autoupdate) + + item, err := hyperSyncer.Get(ctx) + if err != nil { + errs = append(errs, fmt.Errorf("Crush was unable to fetch updated information from Hyper: %w", err)) //nolint:staticcheck + return + } + providers.Append(item) + }) + + wg.Wait() + + providerList = slices.Collect(providers.Seq()) + providerErr = errors.Join(errs...) + }) + return providerList, providerErr +} - cached, etag, cachedErr := loadProvidersFromCache(path) - if len(cached) == 0 || cachedErr != nil { - // if cached file is empty, default to embedded providers - cached = embedded.GetAll() - } +type cache[T any] struct { + path string +} - providerList, providerErr = loadProviders(client, etag, path) - if errors.Is(providerErr, catwalk.ErrNotModified) { - slog.Info("Catwalk providers not modified") - providerList, providerErr = cached, nil - } - }) - if providerErr != nil { - catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)) - return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr) //nolint:staticcheck - } - return providerList, nil +func newCache[T any](path string) cache[T] { + return cache[T]{path: path} } -func loadProviders(client ProviderClient, etag, path string) ([]catwalk.Provider, error) { - slog.Info("Fetching providers from Catwalk.", "path", path) - providers, err := client.GetProviders(context.Background(), etag) +func (c cache[T]) Get() (T, string, error) { + var v T + data, err := os.ReadFile(c.path) if err != nil { - return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err) + return v, "", fmt.Errorf("failed to read provider cache file: %w", err) + } + + if err := json.Unmarshal(data, &v); err != nil { + return v, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err) + } + + return v, etag.Of(data), nil +} + +func (c cache[T]) Store(v T) error { + slog.Info("Saving provider data to disk", "path", c.path) + if err := os.MkdirAll(filepath.Dir(c.path), 0o755); err != nil { + return fmt.Errorf("failed to create directory for provider cache: %w", err) } - if len(providers) == 0 { - return nil, errors.New("empty providers list from catwalk") + + data, err := json.Marshal(v) + if err != nil { + return fmt.Errorf("failed to marshal provider data: %w", err) } - if err := saveProvidersInCache(path, providers); err != nil { - return nil, err + + if err := os.WriteFile(c.path, data, 0o644); err != nil { + return fmt.Errorf("failed to write provider data to cache: %w", err) } - return providers, nil + return nil } diff --git a/internal/config/provider_empty_test.go b/internal/config/provider_empty_test.go index 5e889f82a4997d9f5761e6fc4a01ec6d54a0623a..7c37a9afb9694f0ea4352faee1b11d7e40d9480e 100644 --- a/internal/config/provider_empty_test.go +++ b/internal/config/provider_empty_test.go @@ -2,6 +2,7 @@ package config import ( "context" + "os" "testing" "github.com/charmbracelet/catwalk/pkg/catwalk" @@ -14,18 +15,25 @@ func (m *emptyProviderClient) GetProviders(context.Context, string) ([]catwalk.P return []catwalk.Provider{}, nil } -// TestProvider_loadProvidersEmptyResult tests that loadProviders returns an -// error when the client returns an empty list. This ensures we don't cache -// empty provider lists. -func TestProvider_loadProvidersEmptyResult(t *testing.T) { +// TestCatwalkSync_GetEmptyResultFromClient tests that when the client returns +// an empty list, we fall back to cached providers and return an error. +func TestCatwalkSync_GetEmptyResultFromClient(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := tmpDir + "/providers.json" + + syncer := &catwalkSync{} client := &emptyProviderClient{} - tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(client, "", tmpPath) + syncer.Init(client, path, true) + + providers, err := syncer.Get(t.Context()) + require.Error(t, err) require.Contains(t, err.Error(), "empty providers list from catwalk") - require.Empty(t, providers) - require.Len(t, providers, 0) + require.NotEmpty(t, providers) // Should have embedded providers as fallback. - // Check that no cache file was created for empty results - require.NoFileExists(t, tmpPath, "Cache file should not exist for empty results") + // Check that no cache file was created for empty results. + _, statErr := os.Stat(path) + require.True(t, os.IsNotExist(statErr), "Cache file should not exist for empty results") } diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index f101dd8de8ef624ed041ff88115c6c286f902659..e8790e286c3ffc8db77edb0ef8353e54ad519458 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -1,10 +1,9 @@ package config import ( - "context" "encoding/json" - "errors" "os" + "path/filepath" "sync" "testing" @@ -12,47 +11,31 @@ import ( "github.com/stretchr/testify/require" ) -type mockProviderClient struct { - shouldFail bool - shouldReturnErr error -} - -func (m *mockProviderClient) GetProviders(context.Context, string) ([]catwalk.Provider, error) { - if m.shouldReturnErr != nil { - return nil, m.shouldReturnErr - } - if m.shouldFail { - return nil, errors.New("failed to load providers") - } - return []catwalk.Provider{ - { - Name: "Mock", - }, - }, nil -} - func resetProviderState() { providerOnce = sync.Once{} providerList = nil providerErr = nil + catwalkSyncer = &catwalkSync{} + hyperSyncer = &hyperSync{} } -func TestProvider_loadProvidersNoIssues(t *testing.T) { - client := &mockProviderClient{shouldFail: false} - tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(client, "", tmpPath) - require.NoError(t, err) - require.NotNil(t, providers) - require.Len(t, providers, 1) +func TestProviders_Integration_AutoUpdateDisabled(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) - // check if file got saved - fileInfo, err := os.Stat(tmpPath) - require.NoError(t, err) - require.False(t, fileInfo.IsDir(), "Expected a file, not a directory") -} + // Use a test-specific instance to avoid global state interference. + testCatwalkSyncer := &catwalkSync{} + testHyperSyncer := &hyperSync{} + + originalCatwalSyncer := catwalkSyncer + originalHyperSyncer := hyperSyncer + defer func() { + catwalkSyncer = originalCatwalSyncer + hyperSyncer = originalHyperSyncer + }() -func TestProvider_DisableAutoUpdate(t *testing.T) { - t.Setenv("XDG_DATA_HOME", t.TempDir()) + catwalkSyncer = testCatwalkSyncer + hyperSyncer = testHyperSyncer resetProviderState() defer resetProviderState() @@ -69,76 +52,257 @@ func TestProvider_DisableAutoUpdate(t *testing.T) { require.Greater(t, len(providers), 5, "Expected embedded providers") } -func TestProvider_WithValidCache(t *testing.T) { +func TestProviders_Integration_WithMockClients(t *testing.T) { tmpDir := t.TempDir() t.Setenv("XDG_DATA_HOME", tmpDir) - resetProviderState() - defer resetProviderState() + // Create fresh syncers for this test. + testCatwalkSyncer := &catwalkSync{} + testHyperSyncer := &hyperSync{} + + // Initialize with mock clients. + mockCatwalkClient := &mockCatwalkClient{ + providers: []catwalk.Provider{ + {Name: "Provider1", ID: "p1"}, + {Name: "Provider2", ID: "p2"}, + }, + } + mockHyperClient := &mockHyperClient{ + provider: catwalk.Provider{ + Name: "Hyper", + ID: "hyper", + Models: []catwalk.Model{ + {ID: "hyper-1", Name: "Hyper Model"}, + }, + }, + } + + catwalkPath := tmpDir + "/crush/providers.json" + hyperPath := tmpDir + "/crush/hyper.json" + + testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true) + testHyperSyncer.Init(mockHyperClient, hyperPath, true) + + // Get providers from each syncer. + catwalkProviders, err := testCatwalkSyncer.Get(t.Context()) + require.NoError(t, err) + require.Len(t, catwalkProviders, 2) + + hyperProvider, err := testHyperSyncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, "Hyper", hyperProvider.Name) + + // Verify total. + allProviders := append(catwalkProviders, hyperProvider) + require.Len(t, allProviders, 3) +} + +func TestProviders_Integration_WithCachedData(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + // Create cache files. + catwalkPath := tmpDir + "/crush/providers.json" + hyperPath := tmpDir + "/crush/hyper.json" - cachePath := tmpDir + "/crush/providers.json" require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755)) - cachedProviders := []catwalk.Provider{ - {Name: "Cached"}, + + // Write Catwalk cache. + catwalkProviders := []catwalk.Provider{ + {Name: "Cached1", ID: "c1"}, + {Name: "Cached2", ID: "c2"}, + } + data, err := json.Marshal(catwalkProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(catwalkPath, data, 0o644)) + + // Write Hyper cache. + hyperProvider := catwalk.Provider{ + Name: "Cached Hyper", + ID: "hyper", } - data, err := json.Marshal(cachedProviders) + data, err = json.Marshal(hyperProvider) require.NoError(t, err) - require.NoError(t, os.WriteFile(cachePath, data, 0o644)) + require.NoError(t, os.WriteFile(hyperPath, data, 0o644)) + + // Create fresh syncers. + testCatwalkSyncer := &catwalkSync{} + testHyperSyncer := &hyperSync{} + + // Mock clients that return ErrNotModified. + mockCatwalkClient := &mockCatwalkClient{ + err: catwalk.ErrNotModified, + } + mockHyperClient := &mockHyperClient{ + err: catwalk.ErrNotModified, + } - mockClient := &mockProviderClient{shouldFail: false} + testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true) + testHyperSyncer.Init(mockHyperClient, hyperPath, true) - providers, err := loadProviders(mockClient, "", cachePath) + // Get providers - should use cached. + catwalkResult, err := testCatwalkSyncer.Get(t.Context()) require.NoError(t, err) - require.NotNil(t, providers) - require.Len(t, providers, 1) - require.Equal(t, "Mock", providers[0].Name, "Expected fresh provider from fetch") + require.Len(t, catwalkResult, 2) + require.Equal(t, "Cached1", catwalkResult[0].Name) + + hyperResult, err := testHyperSyncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, "Cached Hyper", hyperResult.Name) } -func TestProvider_NotModifiedUsesCached(t *testing.T) { +func TestProviders_Integration_CatwalkFailsHyperSucceeds(t *testing.T) { tmpDir := t.TempDir() t.Setenv("XDG_DATA_HOME", tmpDir) - resetProviderState() - defer resetProviderState() + testCatwalkSyncer := &catwalkSync{} + testHyperSyncer := &hyperSync{} - cachePath := tmpDir + "/crush/providers.json" - require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755)) - cachedProviders := []catwalk.Provider{ - {Name: "Cached"}, + // Catwalk fails, Hyper succeeds. + mockCatwalkClient := &mockCatwalkClient{ + err: catwalk.ErrNotModified, // Will use embedded. } - data, err := json.Marshal(cachedProviders) + mockHyperClient := &mockHyperClient{ + provider: catwalk.Provider{ + Name: "Hyper", + ID: "hyper", + Models: []catwalk.Model{ + {ID: "hyper-1", Name: "Hyper Model"}, + }, + }, + } + + catwalkPath := tmpDir + "/crush/providers.json" + hyperPath := tmpDir + "/crush/hyper.json" + + testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true) + testHyperSyncer.Init(mockHyperClient, hyperPath, true) + + catwalkResult, err := testCatwalkSyncer.Get(t.Context()) require.NoError(t, err) - require.NoError(t, os.WriteFile(cachePath, data, 0o644)) + require.NotEmpty(t, catwalkResult) // Should have embedded. - mockClient := &mockProviderClient{shouldReturnErr: catwalk.ErrNotModified} - providers, err := loadProviders(mockClient, "", cachePath) - require.ErrorIs(t, err, catwalk.ErrNotModified) - require.Nil(t, providers) + hyperResult, err := testHyperSyncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, "Hyper", hyperResult.Name) } -func TestProvider_EmptyCacheDefaultsToEmbedded(t *testing.T) { +func TestProviders_Integration_BothFail(t *testing.T) { tmpDir := t.TempDir() t.Setenv("XDG_DATA_HOME", tmpDir) - resetProviderState() - defer resetProviderState() + testCatwalkSyncer := &catwalkSync{} + testHyperSyncer := &hyperSync{} - cachePath := tmpDir + "/crush/providers.json" - require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755)) - emptyProviders := []catwalk.Provider{} - data, err := json.Marshal(emptyProviders) + // Both fail. + mockCatwalkClient := &mockCatwalkClient{ + err: catwalk.ErrNotModified, + } + mockHyperClient := &mockHyperClient{ + provider: catwalk.Provider{}, // Empty provider. + } + + catwalkPath := tmpDir + "/crush/providers.json" + hyperPath := tmpDir + "/crush/hyper.json" + + testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true) + testHyperSyncer.Init(mockHyperClient, hyperPath, true) + + catwalkResult, err := testCatwalkSyncer.Get(t.Context()) require.NoError(t, err) - require.NoError(t, os.WriteFile(cachePath, data, 0o644)) + require.NotEmpty(t, catwalkResult) // Should fall back to embedded. - cached, _, err := loadProvidersFromCache(cachePath) + hyperResult, err := testHyperSyncer.Get(t.Context()) require.NoError(t, err) - require.Empty(t, cached, "Expected empty cache") + require.Equal(t, "Charm Hyper", hyperResult.Name) // Falls back to embedded when no models. } -func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) { - client := &mockProviderClient{shouldFail: true} - tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(client, "", tmpPath) +func TestCache_StoreAndGet(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + cachePath := tmpDir + "/test.json" + + cache := newCache[[]catwalk.Provider](cachePath) + + providers := []catwalk.Provider{ + {Name: "Provider1", ID: "p1"}, + {Name: "Provider2", ID: "p2"}, + } + + // Store. + err := cache.Store(providers) + require.NoError(t, err) + + // Get. + result, etag, err := cache.Get() + require.NoError(t, err) + require.Len(t, result, 2) + require.Equal(t, "Provider1", result[0].Name) + require.NotEmpty(t, etag) +} + +func TestCache_GetNonExistent(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + cachePath := tmpDir + "/nonexistent.json" + + cache := newCache[[]catwalk.Provider](cachePath) + + _, _, err := cache.Get() require.Error(t, err) - require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists") + require.Contains(t, err.Error(), "failed to read provider cache file") +} + +func TestCache_GetInvalidJSON(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + cachePath := tmpDir + "/invalid.json" + + require.NoError(t, os.WriteFile(cachePath, []byte("invalid json"), 0o644)) + + cache := newCache[[]catwalk.Provider](cachePath) + + _, _, err := cache.Get() + require.Error(t, err) + require.Contains(t, err.Error(), "failed to unmarshal provider data from cache") +} + +func TestCachePathFor(t *testing.T) { + tests := []struct { + name string + xdgDataHome string + expected string + }{ + { + name: "with XDG_DATA_HOME", + xdgDataHome: "/custom/data", + expected: "/custom/data/crush/providers.json", + }, + { + name: "without XDG_DATA_HOME", + xdgDataHome: "", + expected: "", // Will use platform-specific default. + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.xdgDataHome != "" { + t.Setenv("XDG_DATA_HOME", tt.xdgDataHome) + } else { + t.Setenv("XDG_DATA_HOME", "") + } + + result := cachePathFor("providers") + if tt.expected != "" { + require.Equal(t, tt.expected, filepath.ToSlash(result)) + } else { + require.Contains(t, result, "crush") + require.Contains(t, result, "providers.json") + } + }) + } } diff --git a/internal/event/event.go b/internal/event/event.go index 462e3f3ba53a0fc10d77822cf404a12d66b0bdec..1793c283f6a79cf9e6ff8c5fd5c533756e66df78 100644 --- a/internal/event/event.go +++ b/internal/event/event.go @@ -46,6 +46,22 @@ func Init() { distinctId = getDistinctId() } +func GetID() string { return distinctId } + +func Alias(userID string) { + if client == nil || distinctId == fallbackId || distinctId == "" || userID == "" { + return + } + if err := client.Enqueue(posthog.Alias{ + DistinctId: distinctId, + Alias: userID, + }); err != nil { + slog.Error("Failed to enqueue PostHog alias event", "error", err) + return + } + slog.Info("Aliased in PostHog", "machine_id", distinctId, "user_id", userID) +} + // send logs an event to PostHog with the given event name and properties. func send(event string, props ...any) { if client == nil { diff --git a/internal/oauth/copilot/disk.go b/internal/oauth/copilot/disk.go new file mode 100644 index 0000000000000000000000000000000000000000..bbb4957804767828c9a999062501983abf74216c --- /dev/null +++ b/internal/oauth/copilot/disk.go @@ -0,0 +1,36 @@ +package copilot + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" +) + +func RefreshTokenFromDisk() (string, bool) { + data, err := os.ReadFile(tokenFilePath()) + if err != nil { + return "", false + } + var content map[string]struct { + User string `json:"user"` + OAuthToken string `json:"oauth_token"` + GitHubAppID string `json:"githubAppId"` + } + if err := json.Unmarshal(data, &content); err != nil { + return "", false + } + if app, ok := content["github.com:Iv1.b507a08c87ecfe98"]; ok { + return app.OAuthToken, true + } + return "", false +} + +func tokenFilePath() string { + switch runtime.GOOS { + case "windows": + return filepath.Join(os.Getenv("LOCALAPPDATA"), "github-copilot/apps.json") + default: + return filepath.Join(os.Getenv("HOME"), ".config/github-copilot/apps.json") + } +} diff --git a/internal/oauth/copilot/http.go b/internal/oauth/copilot/http.go new file mode 100644 index 0000000000000000000000000000000000000000..482d9a4cd4819a586cbb7fc66c6a4f0b1d431ffb --- /dev/null +++ b/internal/oauth/copilot/http.go @@ -0,0 +1,17 @@ +package copilot + +const ( + userAgent = "GitHubCopilotChat/0.32.4" + editorVersion = "vscode/1.105.1" + editorPluginVersion = "copilot-chat/0.32.4" + integrationID = "vscode-chat" +) + +func Headers() map[string]string { + return map[string]string{ + "User-Agent": userAgent, + "Editor-Version": editorVersion, + "Editor-Plugin-Version": editorPluginVersion, + "Copilot-Integration-Id": integrationID, + } +} diff --git a/internal/oauth/copilot/oauth.go b/internal/oauth/copilot/oauth.go new file mode 100644 index 0000000000000000000000000000000000000000..40ec7b1b9a3f1c30376aed39321c7a83a40ba03c --- /dev/null +++ b/internal/oauth/copilot/oauth.go @@ -0,0 +1,200 @@ +package copilot + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/charmbracelet/crush/internal/oauth" +) + +const ( + clientID = "Iv1.b507a08c87ecfe98" + + deviceCodeURL = "https://github.com/login/device/code" + accessTokenURL = "https://github.com/login/oauth/access_token" + copilotTokenURL = "https://api.github.com/copilot_internal/v2/token" +) + +var ErrNotAvailable = errors.New("github copilot not available") + +type DeviceCode struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +// RequestDeviceCode initiates the device code flow with GitHub. +func RequestDeviceCode(ctx context.Context) (*DeviceCode, error) { + data := url.Values{} + data.Set("client_id", clientID) + data.Set("scope", "read:user") + + req, err := http.NewRequestWithContext(ctx, "POST", deviceCodeURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("User-Agent", userAgent) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("device code request failed: %s - %s", resp.Status, string(body)) + } + + var dc DeviceCode + if err := json.NewDecoder(resp.Body).Decode(&dc); err != nil { + return nil, err + } + return &dc, nil +} + +// PollForToken polls GitHub for the access token after user authorization. +func PollForToken(ctx context.Context, dc *DeviceCode) (*oauth.Token, error) { + interval := max(dc.Interval, 5) + deadline := time.Now().Add(time.Duration(dc.ExpiresIn) * time.Second) + ticker := time.NewTicker(time.Duration(interval) * time.Second) + defer ticker.Stop() + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ticker.C: + } + + token, err := tryGetToken(ctx, dc.DeviceCode) + if err == errPending { + continue + } + if err == errSlowDown { + interval += 5 + ticker.Reset(time.Duration(interval) * time.Second) + continue + } + if err != nil { + return nil, err + } + return token, nil + } + + return nil, fmt.Errorf("authorization timed out") +} + +var ( + errPending = fmt.Errorf("pending") + errSlowDown = fmt.Errorf("slow_down") +) + +func tryGetToken(ctx context.Context, deviceCode string) (*oauth.Token, error) { + data := url.Values{} + data.Set("client_id", clientID) + data.Set("device_code", deviceCode) + data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code") + + req, err := http.NewRequestWithContext(ctx, "POST", accessTokenURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("User-Agent", userAgent) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var result struct { + AccessToken string `json:"access_token"` + Error string `json:"error"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + + switch result.Error { + case "": + if result.AccessToken == "" { + return nil, errPending + } + return getCopilotToken(ctx, result.AccessToken) + case "authorization_pending": + return nil, errPending + case "slow_down": + return nil, errSlowDown + default: + return nil, fmt.Errorf("authorization failed: %s", result.Error) + } +} + +func getCopilotToken(ctx context.Context, githubToken string) (*oauth.Token, error) { + req, err := http.NewRequestWithContext(ctx, "GET", copilotTokenURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", githubToken)) + for k, v := range Headers() { + req.Header.Set(k, v) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode == http.StatusForbidden { + return nil, ErrNotAvailable + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("copilot token request failed: %s - %s", resp.Status, string(body)) + } + + var result struct { + Token string `json:"token"` + ExpiresAt int64 `json:"expires_at"` + } + if err := json.Unmarshal(body, &result); err != nil { + return nil, err + } + + copilotToken := &oauth.Token{ + AccessToken: result.Token, + RefreshToken: githubToken, + ExpiresAt: result.ExpiresAt, + } + copilotToken.SetExpiresIn() + + return copilotToken, nil +} + +// RefreshToken refreshes the Copilot token using the GitHub token. +func RefreshToken(ctx context.Context, githubToken string) (*oauth.Token, error) { + return getCopilotToken(ctx, githubToken) +} diff --git a/internal/oauth/copilot/urls.go b/internal/oauth/copilot/urls.go new file mode 100644 index 0000000000000000000000000000000000000000..a61535b4d2afa75133690574440073a6282f94a6 --- /dev/null +++ b/internal/oauth/copilot/urls.go @@ -0,0 +1,6 @@ +package copilot + +const ( + SignupURL = "https://github.com/github-copilot/signup?editor=crush" + FreeURL = "https://docs.github.com/en/copilot/how-tos/manage-your-account/get-free-access-to-copilot-pro" +) diff --git a/internal/oauth/hyper/device.go b/internal/oauth/hyper/device.go new file mode 100644 index 0000000000000000000000000000000000000000..90d115f76c197c376eb61d8a9676966eb12a272c --- /dev/null +++ b/internal/oauth/hyper/device.go @@ -0,0 +1,251 @@ +// Package hyper provides functions to handle Hyper device flow authentication. +package hyper + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" + + "github.com/charmbracelet/crush/internal/agent/hyper" + "github.com/charmbracelet/crush/internal/event" + "github.com/charmbracelet/crush/internal/oauth" +) + +// DeviceAuthResponse contains the response from the device authorization endpoint. +type DeviceAuthResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURL string `json:"verification_url"` + ExpiresIn int `json:"expires_in"` +} + +// TokenResponse contains the response from the polling endpoint. +type TokenResponse struct { + RefreshToken string `json:"refresh_token,omitempty"` + UserID string `json:"user_id"` + OrganizationID string `json:"organization_id"` + OrganizationName string `json:"organization_name"` + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` +} + +// InitiateDeviceAuth calls the /device/auth endpoint to start the device flow. +func InitiateDeviceAuth(ctx context.Context) (*DeviceAuthResponse, error) { + url := hyper.BaseURL() + "/device/auth" + + req, err := http.NewRequestWithContext( + ctx, http.MethodPost, url, + strings.NewReader(fmt.Sprintf(`{"device_name":%q}`, deviceName())), + ) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "crush") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("device auth failed: status %d, body %q", resp.StatusCode, string(body)) + } + + var authResp DeviceAuthResponse + if err := json.Unmarshal(body, &authResp); err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + + return &authResp, nil +} + +func deviceName() string { + if hostname, err := os.Hostname(); err == nil && hostname != "" { + return "Crush (" + hostname + ")" + } + return "Crush" +} + +// PollForToken polls the /device/token endpoint until authorization is complete. +// It respects the polling interval and handles various error states. +func PollForToken(ctx context.Context, deviceCode string, expiresIn int) (string, error) { + ctx, cancel := context.WithTimeout(ctx, time.Duration(expiresIn)*time.Second) + defer cancel() + + d := 5 * time.Second + ticker := time.NewTicker(d) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-ticker.C: + result, err := pollOnce(ctx, deviceCode) + if err != nil { + return "", err + } + if result.RefreshToken != "" { + event.Alias(result.UserID) + return result.RefreshToken, nil + } + switch result.Error { + case "authorization_pending": + continue + default: + return "", errors.New(result.ErrorDescription) + } + } + } +} + +func pollOnce(ctx context.Context, deviceCode string) (TokenResponse, error) { + var result TokenResponse + url := fmt.Sprintf("%s/device/auth/%s", hyper.BaseURL(), deviceCode) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return result, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "crush") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return result, fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return result, fmt.Errorf("read response: %w", err) + } + + if err := json.Unmarshal(body, &result); err != nil { + return result, fmt.Errorf("unmarshal response: %w: %s", err, string(body)) + } + + if resp.StatusCode != http.StatusOK { + return result, fmt.Errorf("token request failed: status %d body %q", resp.StatusCode, string(body)) + } + + return result, nil +} + +// ExchangeToken exchanges a refresh token for an access token. +func ExchangeToken(ctx context.Context, refreshToken string) (*oauth.Token, error) { + reqBody := map[string]string{ + "refresh_token": refreshToken, + } + + data, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + url := hyper.BaseURL() + "/token/exchange" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "crush") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token exchange failed: status %d body %q", resp.StatusCode, string(body)) + } + + var token oauth.Token + if err := json.Unmarshal(body, &token); err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + + token.SetExpiresAt() + return &token, nil +} + +// IntrospectTokenResponse contains the response from the token introspection endpoint. +type IntrospectTokenResponse struct { + Active bool `json:"active"` + Sub string `json:"sub,omitempty"` + OrgID string `json:"org_id,omitempty"` + Exp int64 `json:"exp,omitempty"` + Iat int64 `json:"iat,omitempty"` + Iss string `json:"iss,omitempty"` + Jti string `json:"jti,omitempty"` +} + +// IntrospectToken validates an access token using the introspection endpoint. +// Implements OAuth2 Token Introspection (RFC 7662). +func IntrospectToken(ctx context.Context, accessToken string) (*IntrospectTokenResponse, error) { + reqBody := map[string]string{ + "token": accessToken, + } + + data, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + url := hyper.BaseURL() + "/token/introspect" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "crush") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token introspection failed: status %d body %q", resp.StatusCode, string(body)) + } + + var result IntrospectTokenResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + + return &result, nil +} diff --git a/internal/oauth/token.go b/internal/oauth/token.go index 29d4791b5fd416e65995698dcc4665ea96fbb090..381eb3e3110d01db1944fcd98659d87ac7055e2a 100644 --- a/internal/oauth/token.go +++ b/internal/oauth/token.go @@ -4,7 +4,7 @@ import ( "time" ) -// Token represents an OAuth2 token from Claude Code Max. +// Token represents an OAuth2 token. type Token struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` @@ -21,3 +21,8 @@ func (t *Token) SetExpiresAt() { func (t *Token) IsExpired() bool { return time.Now().Unix() >= (t.ExpiresAt - int64(t.ExpiresIn)/10) } + +// SetExpiresIn calculates and sets the ExpiresIn field based on the ExpiresAt field. +func (t *Token) SetExpiresIn() { + t.ExpiresIn = int(time.Until(time.Unix(t.ExpiresAt, 0)).Seconds()) +} diff --git a/internal/tui/components/dialogs/commands/commands.go b/internal/tui/components/dialogs/commands/commands.go index 7769849bb9b4d6e16e6644ae658d9f0394ac2289..cde5b203ca985f81c390d02725ef04d11a5cd518 100644 --- a/internal/tui/components/dialogs/commands/commands.go +++ b/internal/tui/components/dialogs/commands/commands.go @@ -13,6 +13,7 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/agent/tools/mcp" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" @@ -363,7 +364,7 @@ func (c *commandDialogCmp) defaultCommands() []Command { selectedModel := cfg.Models[agentCfg.Model] // Anthropic models: thinking toggle - if providerCfg.Type == catwalk.TypeAnthropic { + if providerCfg.Type == catwalk.TypeAnthropic || providerCfg.Type == catwalk.Type(hyper.Name) { status := "Enable" if selectedModel.Think { status = "Disable" diff --git a/internal/tui/components/dialogs/copilot/device_flow.go b/internal/tui/components/dialogs/copilot/device_flow.go new file mode 100644 index 0000000000000000000000000000000000000000..d3f792291e4bce77dc5ceacb1aa1200a111981dc --- /dev/null +++ b/internal/tui/components/dialogs/copilot/device_flow.go @@ -0,0 +1,281 @@ +// Package copilot provides the dialog for Copilot device flow authentication. +package copilot + +import ( + "context" + "fmt" + "time" + + "charm.land/bubbles/v2/spinner" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/oauth/copilot" + "github.com/charmbracelet/crush/internal/tui/styles" + "github.com/charmbracelet/crush/internal/tui/util" + "github.com/pkg/browser" +) + +// DeviceFlowState represents the current state of the device flow. +type DeviceFlowState int + +const ( + DeviceFlowStateDisplay DeviceFlowState = iota + DeviceFlowStateSuccess + DeviceFlowStateError + DeviceFlowStateUnavailable +) + +// DeviceAuthInitiatedMsg is sent when the device auth is initiated +// successfully. +type DeviceAuthInitiatedMsg struct { + deviceCode *copilot.DeviceCode +} + +// DeviceFlowCompletedMsg is sent when the device flow completes successfully. +type DeviceFlowCompletedMsg struct { + Token *oauth.Token +} + +// DeviceFlowErrorMsg is sent when the device flow encounters an error. +type DeviceFlowErrorMsg struct { + Error error +} + +// DeviceFlow handles the Copilot device flow authentication. +type DeviceFlow struct { + State DeviceFlowState + width int + deviceCode *copilot.DeviceCode + token *oauth.Token + cancelFunc context.CancelFunc + spinner spinner.Model +} + +// NewDeviceFlow creates a new device flow component. +func NewDeviceFlow() *DeviceFlow { + s := spinner.New() + s.Spinner = spinner.Dot + s.Style = lipgloss.NewStyle().Foreground(styles.CurrentTheme().GreenLight) + return &DeviceFlow{ + State: DeviceFlowStateDisplay, + spinner: s, + } +} + +// Init initializes the device flow by calling the device auth API and starting polling. +func (d *DeviceFlow) Init() tea.Cmd { + return tea.Batch(d.spinner.Tick, d.initiateDeviceAuth) +} + +// Update handles messages and state transitions. +func (d *DeviceFlow) Update(msg tea.Msg) (util.Model, tea.Cmd) { + var cmd tea.Cmd + d.spinner, cmd = d.spinner.Update(msg) + + switch msg := msg.(type) { + case DeviceAuthInitiatedMsg: + return d, tea.Batch(cmd, d.startPolling(msg.deviceCode)) + case DeviceFlowCompletedMsg: + d.State = DeviceFlowStateSuccess + d.token = msg.Token + return d, nil + case DeviceFlowErrorMsg: + switch msg.Error { + case copilot.ErrNotAvailable: + d.State = DeviceFlowStateUnavailable + default: + d.State = DeviceFlowStateError + } + return d, nil + } + + return d, cmd +} + +// View renders the device flow dialog. +func (d *DeviceFlow) View() string { + t := styles.CurrentTheme() + + whiteStyle := lipgloss.NewStyle().Foreground(t.White) + primaryStyle := lipgloss.NewStyle().Foreground(t.Primary) + greenStyle := lipgloss.NewStyle().Foreground(t.GreenLight) + linkStyle := lipgloss.NewStyle().Foreground(t.GreenDark).Underline(true) + errorStyle := lipgloss.NewStyle().Foreground(t.Error) + mutedStyle := lipgloss.NewStyle().Foreground(t.FgMuted) + + switch d.State { + case DeviceFlowStateDisplay: + if d.deviceCode == nil { + return lipgloss.NewStyle(). + Margin(0, 1). + Render( + greenStyle.Render(d.spinner.View()) + + mutedStyle.Render("Initializing..."), + ) + } + + instructions := lipgloss.NewStyle(). + Margin(1, 1, 0, 1). + Width(d.width - 2). + Render( + whiteStyle.Render("Press ") + + primaryStyle.Render("enter") + + whiteStyle.Render(" to copy the code below and open the browser."), + ) + + codeBox := lipgloss.NewStyle(). + Width(d.width-2). + Height(7). + Align(lipgloss.Center, lipgloss.Center). + Background(t.BgBaseLighter). + Margin(1). + Render( + lipgloss.NewStyle(). + Bold(true). + Foreground(t.White). + Render(d.deviceCode.UserCode), + ) + + uri := d.deviceCode.VerificationURI + link := lipgloss.NewStyle().Hyperlink(uri, "id=copilot-verify").Render(uri) + url := mutedStyle. + Margin(0, 1). + Width(d.width - 2). + Render("Browser not opening? Refer to\n" + link) + + waiting := greenStyle. + Width(d.width-2). + Margin(1, 1, 0, 1). + Render(d.spinner.View() + "Verifying...") + + return lipgloss.JoinVertical( + lipgloss.Left, + instructions, + codeBox, + url, + waiting, + ) + + case DeviceFlowStateSuccess: + return greenStyle.Margin(0, 1).Render("Authentication successful!") + + case DeviceFlowStateError: + return lipgloss.NewStyle(). + Margin(0, 1). + Width(d.width - 2). + Render(errorStyle.Render("Authentication failed.")) + + case DeviceFlowStateUnavailable: + message := lipgloss.NewStyle(). + Margin(0, 1). + Width(d.width - 2). + Render("GitHub Copilot is unavailable for this account. To signup, go to the following page:") + freeMessage := lipgloss.NewStyle(). + Margin(0, 1). + Width(d.width - 2). + Render("You may be able to request free access if elegible. For more information, see:") + return lipgloss.JoinVertical( + lipgloss.Left, + message, + "", + linkStyle.Margin(0, 1).Width(d.width-2).Hyperlink(copilot.SignupURL, "id=copilot-signup").Render(copilot.SignupURL), + "", + freeMessage, + "", + linkStyle.Margin(0, 1).Width(d.width-2).Hyperlink(copilot.FreeURL, "id=copilot-free").Render(copilot.FreeURL), + ) + + default: + return "" + } +} + +// SetWidth sets the width of the dialog. +func (d *DeviceFlow) SetWidth(w int) { + d.width = w +} + +// Cursor hides the cursor. +func (d *DeviceFlow) Cursor() *tea.Cursor { return nil } + +// CopyCodeAndOpenURL copies the user code to the clipboard and opens the URL. +func (d *DeviceFlow) CopyCodeAndOpenURL() tea.Cmd { + switch d.State { + case DeviceFlowStateDisplay: + return tea.Sequence( + tea.SetClipboard(d.deviceCode.UserCode), + func() tea.Msg { + if err := browser.OpenURL(d.deviceCode.VerificationURI); err != nil { + return DeviceFlowErrorMsg{Error: fmt.Errorf("failed to open browser: %w", err)} + } + return nil + }, + util.ReportInfo("Code copied and URL opened"), + ) + case DeviceFlowStateUnavailable: + return tea.Sequence( + func() tea.Msg { + if err := browser.OpenURL(copilot.SignupURL); err != nil { + return DeviceFlowErrorMsg{Error: fmt.Errorf("failed to open browser: %w", err)} + } + return nil + }, + util.ReportInfo("Code copied and URL opened"), + ) + default: + return nil + } +} + +// CopyCode copies just the user code to the clipboard. +func (d *DeviceFlow) CopyCode() tea.Cmd { + if d.State != DeviceFlowStateDisplay { + return nil + } + return tea.Sequence( + tea.SetClipboard(d.deviceCode.UserCode), + util.ReportInfo("Code copied to clipboard"), + ) +} + +// Cancel cancels the device flow polling. +func (d *DeviceFlow) Cancel() { + if d.cancelFunc != nil { + d.cancelFunc() + } +} + +func (d *DeviceFlow) initiateDeviceAuth() tea.Msg { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + deviceCode, err := copilot.RequestDeviceCode(ctx) + if err != nil { + return DeviceFlowErrorMsg{Error: fmt.Errorf("failed to initiate device auth: %w", err)} + } + + d.deviceCode = deviceCode + + return DeviceAuthInitiatedMsg{ + deviceCode: d.deviceCode, + } +} + +// startPolling starts polling for the device token. +func (d *DeviceFlow) startPolling(deviceCode *copilot.DeviceCode) tea.Cmd { + return func() tea.Msg { + ctx, cancel := context.WithCancel(context.Background()) + d.cancelFunc = cancel + + token, err := copilot.PollForToken(ctx, deviceCode) + if err != nil { + if ctx.Err() != nil { + return nil // cancelled, don't report error. + } + return DeviceFlowErrorMsg{Error: err} + } + + return DeviceFlowCompletedMsg{Token: token} + } +} diff --git a/internal/tui/components/dialogs/hyper/device_flow.go b/internal/tui/components/dialogs/hyper/device_flow.go new file mode 100644 index 0000000000000000000000000000000000000000..b88d3aae8a2d1a826d5827c9f4112911602db2a2 --- /dev/null +++ b/internal/tui/components/dialogs/hyper/device_flow.go @@ -0,0 +1,267 @@ +// Package hyper provides the dialog for Hyper device flow authentication. +package hyper + +import ( + "context" + "fmt" + "time" + + "charm.land/bubbles/v2/spinner" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/oauth/hyper" + "github.com/charmbracelet/crush/internal/tui/styles" + "github.com/charmbracelet/crush/internal/tui/util" + "github.com/pkg/browser" +) + +// DeviceFlowState represents the current state of the device flow. +type DeviceFlowState int + +const ( + DeviceFlowStateDisplay DeviceFlowState = iota + DeviceFlowStateSuccess + DeviceFlowStateError +) + +// DeviceAuthInitiatedMsg is sent when the device auth is initiated +// successfully. +type DeviceAuthInitiatedMsg struct { + deviceCode string + expiresIn int +} + +// DeviceFlowCompletedMsg is sent when the device flow completes successfully. +type DeviceFlowCompletedMsg struct { + Token *oauth.Token +} + +// DeviceFlowErrorMsg is sent when the device flow encounters an error. +type DeviceFlowErrorMsg struct { + Error error +} + +// DeviceFlow handles the Hyper device flow authentication. +type DeviceFlow struct { + State DeviceFlowState + width int + deviceCode string + userCode string + verificationURL string + expiresIn int + token *oauth.Token + cancelFunc context.CancelFunc + spinner spinner.Model +} + +// NewDeviceFlow creates a new device flow component. +func NewDeviceFlow() *DeviceFlow { + s := spinner.New() + s.Spinner = spinner.Dot + s.Style = lipgloss.NewStyle().Foreground(styles.CurrentTheme().GreenLight) + return &DeviceFlow{ + State: DeviceFlowStateDisplay, + spinner: s, + } +} + +// Init initializes the device flow by calling the device auth API and starting polling. +func (d *DeviceFlow) Init() tea.Cmd { + return tea.Batch(d.spinner.Tick, d.initiateDeviceAuth) +} + +// Update handles messages and state transitions. +func (d *DeviceFlow) Update(msg tea.Msg) (util.Model, tea.Cmd) { + var cmd tea.Cmd + d.spinner, cmd = d.spinner.Update(msg) + + switch msg := msg.(type) { + case DeviceAuthInitiatedMsg: + // Start polling now that we have the device code. + d.expiresIn = msg.expiresIn + return d, tea.Batch(cmd, d.startPolling(msg.deviceCode)) + case DeviceFlowCompletedMsg: + d.State = DeviceFlowStateSuccess + d.token = msg.Token + return d, nil + case DeviceFlowErrorMsg: + d.State = DeviceFlowStateError + return d, util.ReportError(msg.Error) + } + + return d, cmd +} + +// View renders the device flow dialog. +func (d *DeviceFlow) View() string { + t := styles.CurrentTheme() + + whiteStyle := lipgloss.NewStyle().Foreground(t.White) + primaryStyle := lipgloss.NewStyle().Foreground(t.Primary) + greenStyle := lipgloss.NewStyle().Foreground(t.GreenLight) + linkStyle := lipgloss.NewStyle().Foreground(t.GreenDark).Underline(true) + errorStyle := lipgloss.NewStyle().Foreground(t.Error) + mutedStyle := lipgloss.NewStyle().Foreground(t.FgMuted) + + switch d.State { + case DeviceFlowStateDisplay: + if d.userCode == "" { + return lipgloss.NewStyle(). + Margin(0, 1). + Render( + greenStyle.Render(d.spinner.View()) + + mutedStyle.Render("Initializing..."), + ) + } + + instructions := lipgloss.NewStyle(). + Margin(1, 1, 0, 1). + Width(d.width - 2). + Render( + whiteStyle.Render("Press ") + + primaryStyle.Render("enter") + + whiteStyle.Render(" to copy the code below and open the browser."), + ) + + codeBox := lipgloss.NewStyle(). + Width(d.width-2). + Height(7). + Align(lipgloss.Center, lipgloss.Center). + Background(t.BgBaseLighter). + Margin(1). + Render( + lipgloss.NewStyle(). + Bold(true). + Foreground(t.White). + Render(d.userCode), + ) + + link := linkStyle.Hyperlink(d.verificationURL, "id=hyper-verify").Render(d.verificationURL) + url := mutedStyle. + Margin(0, 1). + Width(d.width - 2). + Render("Browser not opening? Refer to\n" + link) + + waiting := greenStyle. + Width(d.width-2). + Margin(1, 1, 0, 1). + Render(d.spinner.View() + "Verifying...") + + return lipgloss.JoinVertical( + lipgloss.Left, + instructions, + codeBox, + url, + waiting, + ) + + case DeviceFlowStateSuccess: + return greenStyle.Margin(0, 1).Render("Authentication successful!") + + case DeviceFlowStateError: + return lipgloss.NewStyle(). + Margin(0, 1). + Width(d.width - 2). + Render(errorStyle.Render("Authentication failed.")) + + default: + return "" + } +} + +// SetWidth sets the width of the dialog. +func (d *DeviceFlow) SetWidth(w int) { + d.width = w +} + +// Cursor hides the cursor. +func (d *DeviceFlow) Cursor() *tea.Cursor { return nil } + +// CopyCodeAndOpenURL copies the user code to the clipboard and opens the URL. +func (d *DeviceFlow) CopyCodeAndOpenURL() tea.Cmd { + if d.State != DeviceFlowStateDisplay { + return nil + } + return tea.Sequence( + tea.SetClipboard(d.userCode), + func() tea.Msg { + if err := browser.OpenURL(d.verificationURL); err != nil { + return DeviceFlowErrorMsg{Error: fmt.Errorf("failed to open browser: %w", err)} + } + return nil + }, + util.ReportInfo("Code copied and URL opened"), + ) +} + +// CopyCode copies just the user code to the clipboard. +func (d *DeviceFlow) CopyCode() tea.Cmd { + if d.State != DeviceFlowStateDisplay { + return nil + } + return tea.Sequence( + tea.SetClipboard(d.userCode), + util.ReportInfo("Code copied to clipboard"), + ) +} + +// Cancel cancels the device flow polling. +func (d *DeviceFlow) Cancel() { + if d.cancelFunc != nil { + d.cancelFunc() + } +} + +func (d *DeviceFlow) initiateDeviceAuth() tea.Msg { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + authResp, err := hyper.InitiateDeviceAuth(ctx) + if err != nil { + return DeviceFlowErrorMsg{Error: fmt.Errorf("failed to initiate device auth: %w", err)} + } + + d.deviceCode = authResp.DeviceCode + d.userCode = authResp.UserCode + d.verificationURL = authResp.VerificationURL + + return DeviceAuthInitiatedMsg{ + deviceCode: authResp.DeviceCode, + expiresIn: authResp.ExpiresIn, + } +} + +// startPolling starts polling for the device token. +func (d *DeviceFlow) startPolling(deviceCode string) tea.Cmd { + return func() tea.Msg { + ctx, cancel := context.WithCancel(context.Background()) + d.cancelFunc = cancel + + // Poll for refresh token. + refreshToken, err := hyper.PollForToken(ctx, deviceCode, d.expiresIn) + if err != nil { + if ctx.Err() != nil { + // Cancelled, don't report error. + return nil + } + return DeviceFlowErrorMsg{Error: err} + } + + // Exchange refresh token for access token. + token, err := hyper.ExchangeToken(ctx, refreshToken) + if err != nil { + return DeviceFlowErrorMsg{Error: fmt.Errorf("token exchange failed: %w", err)} + } + + // Verify the access token works. + introspect, err := hyper.IntrospectToken(ctx, token.AccessToken) + if err != nil { + return DeviceFlowErrorMsg{Error: fmt.Errorf("token introspection failed: %w", err)} + } + if !introspect.Active { + return DeviceFlowErrorMsg{Error: fmt.Errorf("access token is not active")} + } + + return DeviceFlowCompletedMsg{Token: token} + } +} diff --git a/internal/tui/components/dialogs/models/keys.go b/internal/tui/components/dialogs/models/keys.go index c075a35ac808d7d1093c5a7710bf511f8d0219cb..eda235aebb858fef21c582921cfb9e305a6fed19 100644 --- a/internal/tui/components/dialogs/models/keys.go +++ b/internal/tui/components/dialogs/models/keys.go @@ -15,6 +15,10 @@ type KeyMap struct { isAPIKeyHelp bool isAPIKeyValid bool + isHyperDeviceFlow bool + isCopilotDeviceFlow bool + isCopilotUnavailable bool + isClaudeAuthChoiceHelp bool isClaudeOAuthHelp bool isClaudeOAuthURLState bool @@ -74,6 +78,28 @@ func (k KeyMap) FullHelp() [][]key.Binding { // ShortHelp implements help.KeyMap. func (k KeyMap) ShortHelp() []key.Binding { + if k.isHyperDeviceFlow || k.isCopilotDeviceFlow { + return []key.Binding{ + key.NewBinding( + key.WithKeys("c"), + key.WithHelp("c", "copy code"), + ), + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "copy & open"), + ), + k.Close, + } + } + if k.isCopilotUnavailable { + return []key.Binding{ + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "open signup"), + ), + k.Close, + } + } if k.isClaudeAuthChoiceHelp { return []key.Binding{ key.NewBinding( diff --git a/internal/tui/components/dialogs/models/list.go b/internal/tui/components/dialogs/models/list.go index 2383f749de277e7fe915b57aac17fa0e7928756e..9640b894d8e5bfb8659440f18f4cf04fb413bf02 100644 --- a/internal/tui/components/dialogs/models/list.go +++ b/internal/tui/components/dialogs/models/list.go @@ -62,7 +62,9 @@ func (m *ModelListComponent) Init() tea.Cmd { filteredProviders := []catwalk.Provider{} for _, p := range providers { hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$") - if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure { + isHyper := p.ID == "hyper" + isCopilot := p.ID == catwalk.InferenceProviderCopilot + if (hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure) || isHyper || isCopilot { filteredProviders = append(filteredProviders, p) } } @@ -146,29 +148,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) || !slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) { // Convert config provider to provider.Provider format - configProvider := catwalk.Provider{ - Name: providerConfig.Name, - ID: catwalk.InferenceProvider(providerID), - Models: make([]catwalk.Model, len(providerConfig.Models)), - } - - // Convert models - for i, model := range providerConfig.Models { - configProvider.Models[i] = catwalk.Model{ - ID: model.ID, - Name: model.Name, - CostPer1MIn: model.CostPer1MIn, - CostPer1MOut: model.CostPer1MOut, - CostPer1MInCached: model.CostPer1MInCached, - CostPer1MOutCached: model.CostPer1MOutCached, - ContextWindow: model.ContextWindow, - DefaultMaxTokens: model.DefaultMaxTokens, - CanReason: model.CanReason, - ReasoningLevels: model.ReasoningLevels, - DefaultReasoningEffort: model.DefaultReasoningEffort, - SupportsImages: model.SupportsImages, - } - } + configProvider := providerConfig.ToProvider() // Add this unknown provider to the list name := configProvider.Name @@ -204,8 +184,23 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { } } + // Move "Charm Hyper" to first position + // (but still after recent models and custom providers). + sortedProviders := make([]catwalk.Provider, len(m.providers)) + copy(sortedProviders, m.providers) + slices.SortStableFunc(sortedProviders, func(a, b catwalk.Provider) int { + switch { + case a.ID == "hyper": + return -1 + case b.ID == "hyper": + return 1 + default: + return 0 + } + }) + // Then add the known providers from the predefined list - for _, provider := range m.providers { + for _, provider := range sortedProviders { // Skip if we already added this provider as an unknown provider if addedProviders[string(provider.ID)] { continue diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index 6248f1f0440441f1b5d0924b581283835ca2c294..8ed2ffbf0bf0ddd4641fbbda6f0e2b20a1967e07 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -1,3 +1,4 @@ +// Package models provides the model selection dialog for the TUI. package models import ( @@ -11,10 +12,13 @@ import ( "charm.land/lipgloss/v2" "github.com/atotto/clipboard" "github.com/charmbracelet/catwalk/pkg/catwalk" + hyperp "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/dialogs" "github.com/charmbracelet/crush/internal/tui/components/dialogs/claude" + "github.com/charmbracelet/crush/internal/tui/components/dialogs/copilot" + "github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper" "github.com/charmbracelet/crush/internal/tui/exp/list" "github.com/charmbracelet/crush/internal/tui/styles" "github.com/charmbracelet/crush/internal/tui/util" @@ -70,6 +74,14 @@ type modelDialogCmp struct { isAPIKeyValid bool apiKeyValue string + // Hyper device flow state + hyperDeviceFlow *hyper.DeviceFlow + showHyperDeviceFlow bool + + // Copilot device flow state + copilotDeviceFlow *copilot.DeviceFlow + showCopilotDeviceFlow bool + // Claude state claudeAuthMethodChooser *claude.AuthMethodChooser claudeOAuth2 *claude.OAuth2 @@ -127,6 +139,24 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { u, cmd := m.apiKeyInput.Update(msg) m.apiKeyInput = u.(*APIKeyInput) return m, cmd + case hyper.DeviceFlowCompletedMsg: + return m, m.saveOauthTokenAndContinue(msg.Token, true) + case hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg: + if m.hyperDeviceFlow != nil { + u, cmd := m.hyperDeviceFlow.Update(msg) + m.hyperDeviceFlow = u.(*hyper.DeviceFlow) + return m, cmd + } + return m, nil + case copilot.DeviceAuthInitiatedMsg, copilot.DeviceFlowErrorMsg: + if m.copilotDeviceFlow != nil { + u, cmd := m.copilotDeviceFlow.Update(msg) + m.copilotDeviceFlow = u.(*copilot.DeviceFlow) + return m, cmd + } + return m, nil + case copilot.DeviceFlowCompletedMsg: + return m, m.saveOauthTokenAndContinue(msg.Token, true) case claude.ValidationCompletedMsg: var cmds []tea.Cmd u, cmd := m.claudeOAuth2.Update(msg) @@ -134,7 +164,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { cmds = append(cmds, cmd) if msg.State == claude.OAuthValidationStateValid { - cmds = append(cmds, m.saveAPIKeyAndContinue(msg.Token, false)) + cmds = append(cmds, m.saveOauthTokenAndContinue(msg.Token, false)) m.keyMap.isClaudeOAuthHelpComplete = true } @@ -143,6 +173,14 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return m, util.CmdHandler(dialogs.CloseDialogMsg{}) case tea.KeyPressMsg: switch { + // Handle Hyper device flow keys + case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && (m.showHyperDeviceFlow || m.showCopilotDeviceFlow): + if m.hyperDeviceFlow != nil { + return m, m.hyperDeviceFlow.CopyCode() + } + if m.copilotDeviceFlow != nil { + return m, m.copilotDeviceFlow.CopyCode() + } case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL: return m, tea.Sequence( tea.SetClipboard(m.claudeOAuth2.URL), @@ -156,6 +194,13 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { m.claudeAuthMethodChooser.ToggleChoice() return m, nil case key.Matches(msg, m.keyMap.Select): + // If showing device flow, enter copies code and opens URL + if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil { + return m, m.hyperDeviceFlow.CopyCodeAndOpenURL() + } + if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil { + return m, m.copilotDeviceFlow.CopyCodeAndOpenURL() + } selectedItem := m.modelList.SelectedModel() modelType := config.SelectedModelTypeLarge @@ -167,6 +212,8 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { m.keyMap.isClaudeAuthChoiceHelp = false m.keyMap.isClaudeOAuthHelp = false m.keyMap.isAPIKeyHelp = true + m.showHyperDeviceFlow = false + m.showCopilotDeviceFlow = false m.showClaudeAuthMethodChooser = false m.needsAPIKey = true m.selectedModel = selectedItem @@ -194,7 +241,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return m, cmd2 } if m.isAPIKeyValid { - return m, m.saveAPIKeyAndContinue(m.apiKeyValue, true) + return m, m.saveOauthTokenAndContinue(m.apiKeyValue, true) } if m.needsAPIKey { // Handle API key submission @@ -249,15 +296,30 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { ModelType: modelType, }), ) - } else { - if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic { - m.showClaudeAuthMethodChooser = true - m.keyMap.isClaudeAuthChoiceHelp = true - return m, nil - } - askForApiKey() + } + switch selectedItem.Provider.ID { + case catwalk.InferenceProviderAnthropic: + m.showClaudeAuthMethodChooser = true + m.keyMap.isClaudeAuthChoiceHelp = true return m, nil + case hyperp.Name: + m.showHyperDeviceFlow = true + m.selectedModel = selectedItem + m.selectedModelType = modelType + m.hyperDeviceFlow = hyper.NewDeviceFlow() + m.hyperDeviceFlow.SetWidth(m.width - 2) + return m, m.hyperDeviceFlow.Init() + case catwalk.InferenceProviderCopilot: + m.showCopilotDeviceFlow = true + m.selectedModel = selectedItem + m.selectedModelType = modelType + m.copilotDeviceFlow = copilot.NewDeviceFlow() + m.copilotDeviceFlow.SetWidth(m.width - 2) + return m, m.copilotDeviceFlow.Init() } + // For other providers, show API key input + askForApiKey() + return m, nil case key.Matches(msg, m.keyMap.Tab): switch { case m.showClaudeAuthMethodChooser: @@ -275,6 +337,20 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return m, m.modelList.SetModelType(LargeModelType) } case key.Matches(msg, m.keyMap.Close): + if m.showHyperDeviceFlow { + if m.hyperDeviceFlow != nil { + m.hyperDeviceFlow.Cancel() + } + m.showHyperDeviceFlow = false + m.selectedModel = nil + } + if m.showCopilotDeviceFlow { + if m.copilotDeviceFlow != nil { + m.copilotDeviceFlow.Cancel() + } + m.showCopilotDeviceFlow = false + m.selectedModel = nil + } if m.showClaudeAuthMethodChooser { m.claudeAuthMethodChooser.SetDefaults() m.showClaudeAuthMethodChooser = false @@ -329,11 +405,33 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return m, cmd } case spinner.TickMsg: - if m.showClaudeOAuth2 { + u, cmd := m.apiKeyInput.Update(msg) + m.apiKeyInput = u.(*APIKeyInput) + if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil { + u, cmd = m.hyperDeviceFlow.Update(msg) + m.hyperDeviceFlow = u.(*hyper.DeviceFlow) + } + if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil { + u, cmd = m.copilotDeviceFlow.Update(msg) + m.copilotDeviceFlow = u.(*copilot.DeviceFlow) + } + return m, cmd + default: + // Pass all other messages to the device flow for spinner animation + switch { + case m.showHyperDeviceFlow && m.hyperDeviceFlow != nil: + u, cmd := m.hyperDeviceFlow.Update(msg) + m.hyperDeviceFlow = u.(*hyper.DeviceFlow) + return m, cmd + case m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil: + u, cmd := m.copilotDeviceFlow.Update(msg) + m.copilotDeviceFlow = u.(*copilot.DeviceFlow) + return m, cmd + case m.showClaudeOAuth2: u, cmd := m.claudeOAuth2.Update(msg) m.claudeOAuth2 = u.(*claude.OAuth2) return m, cmd - } else { + default: u, cmd := m.apiKeyInput.Update(msg) m.apiKeyInput = u.(*APIKeyInput) return m, cmd @@ -345,6 +443,39 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { func (m *modelDialogCmp) View() string { t := styles.CurrentTheme() + if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil { + // Show Hyper device flow + m.keyMap.isHyperDeviceFlow = true + deviceFlowView := m.hyperDeviceFlow.View() + content := lipgloss.JoinVertical( + lipgloss.Left, + t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with Hyper", m.width-4)), + deviceFlowView, + "", + t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)), + ) + return m.style().Render(content) + } + if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil { + // Show Hyper device flow + m.keyMap.isCopilotDeviceFlow = m.copilotDeviceFlow.State != copilot.DeviceFlowStateUnavailable + m.keyMap.isCopilotUnavailable = m.copilotDeviceFlow.State == copilot.DeviceFlowStateUnavailable + deviceFlowView := m.copilotDeviceFlow.View() + content := lipgloss.JoinVertical( + lipgloss.Left, + t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with GitHub Copilot", m.width-4)), + deviceFlowView, + "", + t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)), + ) + return m.style().Render(content) + } + + // Reset the flags when not showing device flow + m.keyMap.isHyperDeviceFlow = false + m.keyMap.isCopilotDeviceFlow = false + m.keyMap.isCopilotUnavailable = false + switch { case m.showClaudeAuthMethodChooser: chooserView := m.claudeAuthMethodChooser.View() @@ -397,6 +528,12 @@ func (m *modelDialogCmp) View() string { } func (m *modelDialogCmp) Cursor() *tea.Cursor { + if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil { + return m.hyperDeviceFlow.Cursor() + } + if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil { + return m.copilotDeviceFlow.Cursor() + } if m.showClaudeAuthMethodChooser { return nil } @@ -477,10 +614,8 @@ func (m *modelDialogCmp) modelTypeRadio() string { func (m *modelDialogCmp) isProviderConfigured(providerID string) bool { cfg := config.Get() - if _, ok := cfg.Providers.Get(providerID); ok { - return true - } - return false + _, ok := cfg.Providers.Get(providerID) + return ok } func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) { @@ -497,7 +632,7 @@ func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*cat return nil, nil } -func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd { +func (m *modelDialogCmp) saveOauthTokenAndContinue(apiKey any, close bool) tea.Cmd { if m.selectedModel == nil { return util.ReportError(fmt.Errorf("no model selected")) } diff --git a/internal/tui/components/mcp/mcp.go b/internal/tui/components/mcp/mcp.go index 782a776c5eefb946e0b858f6711bc5ec0ac705fd..78763ac85fdbb5b75e281ef39289f490e6bde949 100644 --- a/internal/tui/components/mcp/mcp.go +++ b/internal/tui/components/mcp/mcp.go @@ -69,10 +69,18 @@ func RenderMCPList(opts RenderOptions) []string { case mcp.StateConnected: icon = t.ItemOnlineIcon if count := state.Counts.Tools; count > 0 { - extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d tools", count))) + label := "tools" + if count == 1 { + label = "tool" + } + extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d %s", count, label))) } if count := state.Counts.Prompts; count > 0 { - extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d prompts", count))) + label := "prompts" + if count == 1 { + label = "prompt" + } + extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d %s", count, label))) } case mcp.StateError: icon = t.ItemErrorIcon diff --git a/internal/tui/exp/list/items.go b/internal/tui/exp/list/items.go index fa89fb6e7a58a5fc0d9e6bcab36a979130c482e9..3db5635b044d9845915d005dd5f7cdac233fe53f 100644 --- a/internal/tui/exp/list/items.go +++ b/internal/tui/exp/list/items.go @@ -323,6 +323,7 @@ type ItemSection interface { layout.Sizeable Indexable SetInfo(info string) + Title() string } type itemSectionModel struct { width int @@ -337,6 +338,11 @@ func (m *itemSectionModel) ID() string { return m.id } +// Title implements ItemSection. +func (m *itemSectionModel) Title() string { + return m.title +} + func NewItemSection(title string) ItemSection { return &itemSectionModel{ title: title, diff --git a/internal/ui/dialog/sessions_item.go b/internal/ui/dialog/sessions_item.go index 080de32dba34a8fa50f2b40db2d8335fdcb911d9..9860581eb3c967154b09b42217e2283192962373 100644 --- a/internal/ui/dialog/sessions_item.go +++ b/internal/ui/dialog/sessions_item.go @@ -86,13 +86,20 @@ func renderItem(t *styles.Styles, title string, updatedAt int64, focused bool, w } var ageLen int + var right string + lineWidth := width if updatedAt > 0 { ageLen = lipgloss.Width(age) + lineWidth -= ageLen } - title = ansi.Truncate(title, max(0, width-ageLen), "…") + title = ansi.Truncate(title, max(0, lineWidth), "…") titleLen := lipgloss.Width(title) - right := lipgloss.NewStyle().AlignHorizontal(lipgloss.Right).Width(width - titleLen).Render(age) + + if updatedAt > 0 { + right = lipgloss.NewStyle().AlignHorizontal(lipgloss.Right).Width(width - titleLen).Render(age) + } + content := title if matches := len(m.MatchedIndexes); matches > 0 { var lastPos int diff --git a/internal/ui/list/filterable.go b/internal/ui/list/filterable.go index c45db41da2cc6be8bd61fba57818e5a7d902f5cd..de78041e3c2666830b6f5ce695472d46448abf0f 100644 --- a/internal/ui/list/filterable.go +++ b/internal/ui/list/filterable.go @@ -70,13 +70,17 @@ func (f *FilterableList) SetFilter(q string) { f.query = q } -type filterableItems []FilterableItem +// FilterableItemsSource is a type that implements [fuzzy.Source] for filtering +// [FilterableItem]s. +type FilterableItemsSource []FilterableItem -func (f filterableItems) Len() int { +// Len returns the length of the source. +func (f FilterableItemsSource) Len() int { return len(f) } -func (f filterableItems) String(i int) string { +// String returns the string representation of the item at index i. +func (f FilterableItemsSource) String(i int) string { return f[i].Filter() } @@ -94,7 +98,7 @@ func (f *FilterableList) VisibleItems() []Item { return items } - items := filterableItems(f.items) + items := FilterableItemsSource(f.items) matches := fuzzy.FindFrom(f.query, items) matchedItems := []Item{} resultSize := len(matches) diff --git a/internal/ui/list/item.go b/internal/ui/list/item.go index a544b85b37dedf889cdc1ecb6ae77388040907f2..62b31a696eee11b5dc11f0228d82ccfa8a0c91e5 100644 --- a/internal/ui/list/item.go +++ b/internal/ui/list/item.go @@ -1,6 +1,8 @@ package list import ( + "strings" + "github.com/charmbracelet/x/ansi" ) @@ -30,3 +32,20 @@ type MouseClickable interface { // It returns true if the event was handled, false otherwise. HandleMouseClick(btn ansi.MouseButton, x, y int) bool } + +// SpacerItem is a spacer item that adds vertical space in the list. +type SpacerItem struct { + Height int +} + +// NewSpacerItem creates a new [SpacerItem] with the specified height. +func NewSpacerItem(height int) *SpacerItem { + return &SpacerItem{ + Height: max(0, height-1), + } +} + +// Render implements the Item interface for [SpacerItem]. +func (s *SpacerItem) Render(width int) string { + return strings.Repeat("\n", s.Height) +} diff --git a/internal/ui/list/list.go b/internal/ui/list/list.go index 3e9fe124b0ddf1e55b3e920bc2828d4efabbc996..fddf0538a13b9adfce781ded62d1179d9fb609a5 100644 --- a/internal/ui/list/list.go +++ b/internal/ui/list/list.go @@ -390,6 +390,7 @@ func (l *List) SelectedItemInView() bool { } // SetSelected sets the selected item index in the list. +// It returns -1 if the index is out of bounds. func (l *List) SetSelected(index int) { if index < 0 || index >= len(l.items) { l.selectedIdx = -1 @@ -415,31 +416,43 @@ func (l *List) IsSelectedLast() bool { } // SelectPrev selects the previous item in the list. -func (l *List) SelectPrev() { +// It returns whether the selection changed. +func (l *List) SelectPrev() bool { if l.selectedIdx > 0 { l.selectedIdx-- + return true } + return false } // SelectNext selects the next item in the list. -func (l *List) SelectNext() { +// It returns whether the selection changed. +func (l *List) SelectNext() bool { if l.selectedIdx < len(l.items)-1 { l.selectedIdx++ + return true } + return false } // SelectFirst selects the first item in the list. -func (l *List) SelectFirst() { +// It returns whether the selection changed. +func (l *List) SelectFirst() bool { if len(l.items) > 0 { l.selectedIdx = 0 + return true } + return false } // SelectLast selects the last item in the list. -func (l *List) SelectLast() { +// It returns whether the selection changed. +func (l *List) SelectLast() bool { if len(l.items) > 0 { l.selectedIdx = len(l.items) - 1 + return true } + return false } // SelectedItem returns the currently selected item. It may be nil if no item diff --git a/internal/ui/model/chat.go b/internal/ui/model/chat.go index 7c19f4d6c49e7f33c57979a0f9f4a2230a780670..f506d469b6b90292af6373d20f2a296fa7d0ac6a 100644 --- a/internal/ui/model/chat.go +++ b/internal/ui/model/chat.go @@ -169,30 +169,30 @@ func (m *Chat) Blur() { m.list.Blur() } -// ScrollToTop scrolls the chat view to the top and returns a command to restart +// ScrollToTopAndAnimate scrolls the chat view to the top and returns a command to restart // any paused animations that are now visible. -func (m *Chat) ScrollToTop() tea.Cmd { +func (m *Chat) ScrollToTopAndAnimate() tea.Cmd { m.list.ScrollToTop() return m.RestartPausedVisibleAnimations() } -// ScrollToBottom scrolls the chat view to the bottom and returns a command to +// ScrollToBottomAndAnimate scrolls the chat view to the bottom and returns a command to // restart any paused animations that are now visible. -func (m *Chat) ScrollToBottom() tea.Cmd { +func (m *Chat) ScrollToBottomAndAnimate() tea.Cmd { m.list.ScrollToBottom() return m.RestartPausedVisibleAnimations() } -// ScrollBy scrolls the chat view by the given number of line deltas and returns +// ScrollByAndAnimate scrolls the chat view by the given number of line deltas and returns // a command to restart any paused animations that are now visible. -func (m *Chat) ScrollBy(lines int) tea.Cmd { +func (m *Chat) ScrollByAndAnimate(lines int) tea.Cmd { m.list.ScrollBy(lines) return m.RestartPausedVisibleAnimations() } -// ScrollToSelected scrolls the chat view to the selected item and returns a +// ScrollToSelectedAndAnimate scrolls the chat view to the selected item and returns a // command to restart any paused animations that are now visible. -func (m *Chat) ScrollToSelected() tea.Cmd { +func (m *Chat) ScrollToSelectedAndAnimate() tea.Cmd { m.list.ScrollToSelected() return m.RestartPausedVisibleAnimations() } diff --git a/internal/ui/model/mcp.go b/internal/ui/model/mcp.go index 2a58e15ac10175f29d6180aa7e98d954a644b34b..4100907d2c58f4238eb080356a069cf9bd0a2da6 100644 --- a/internal/ui/model/mcp.go +++ b/internal/ui/model/mcp.go @@ -16,8 +16,10 @@ func (m *UI) mcpInfo(width, maxItems int, isSection bool) string { var mcps []mcp.ClientInfo t := m.com.Styles - for _, state := range m.mcpStates { - mcps = append(mcps, state) + for _, mcp := range m.com.Config().MCP.Sorted() { + if state, ok := m.mcpStates[mcp.Name]; ok { + mcps = append(mcps, state) + } } title := t.Subtle.Render("MCPs") diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index e1cfcf0e66eb4c407678e9b65b080827989aea56..5fe8eeb3689a6a185c0e2cd6af7a1758120a5178 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -267,22 +267,22 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch m.state { case uiChat: if msg.Y <= 0 { - if cmd := m.chat.ScrollBy(-1); cmd != nil { + if cmd := m.chat.ScrollByAndAnimate(-1); cmd != nil { cmds = append(cmds, cmd) } if !m.chat.SelectedItemInView() { m.chat.SelectPrev() - if cmd := m.chat.ScrollToSelected(); cmd != nil { + if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil { cmds = append(cmds, cmd) } } } else if msg.Y >= m.chat.Height()-1 { - if cmd := m.chat.ScrollBy(1); cmd != nil { + if cmd := m.chat.ScrollByAndAnimate(1); cmd != nil { cmds = append(cmds, cmd) } if !m.chat.SelectedItemInView() { m.chat.SelectNext() - if cmd := m.chat.ScrollToSelected(); cmd != nil { + if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil { cmds = append(cmds, cmd) } } @@ -309,22 +309,22 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case uiChat: switch msg.Button { case tea.MouseWheelUp: - if cmd := m.chat.ScrollBy(-5); cmd != nil { + if cmd := m.chat.ScrollByAndAnimate(-5); cmd != nil { cmds = append(cmds, cmd) } if !m.chat.SelectedItemInView() { m.chat.SelectPrev() - if cmd := m.chat.ScrollToSelected(); cmd != nil { + if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil { cmds = append(cmds, cmd) } } case tea.MouseWheelDown: - if cmd := m.chat.ScrollBy(5); cmd != nil { + if cmd := m.chat.ScrollByAndAnimate(5); cmd != nil { cmds = append(cmds, cmd) } if !m.chat.SelectedItemInView() { m.chat.SelectNext() - if cmd := m.chat.ScrollToSelected(); cmd != nil { + if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil { cmds = append(cmds, cmd) } } @@ -394,7 +394,7 @@ func (m *UI) setSessionMessages(msgs []message.Message) tea.Cmd { } m.chat.SetMessages(items...) - if cmd := m.chat.ScrollToBottom(); cmd != nil { + if cmd := m.chat.ScrollToBottomAndAnimate(); cmd != nil { cmds = append(cmds, cmd) } m.chat.SelectLast() @@ -416,7 +416,7 @@ func (m *UI) appendSessionMessage(msg message.Message) tea.Cmd { } } m.chat.AppendMessages(items...) - if cmd := m.chat.ScrollToBottom(); cmd != nil { + if cmd := m.chat.ScrollToBottomAndAnimate(); cmd != nil { cmds = append(cmds, cmd) } case message.Tool: @@ -472,7 +472,7 @@ func (m *UI) updateSessionMessage(msg message.Message) tea.Cmd { } } m.chat.AppendMessages(items...) - if cmd := m.chat.ScrollToBottom(); cmd != nil { + if cmd := m.chat.ScrollToBottomAndAnimate(); cmd != nil { cmds = append(cmds, cmd) } @@ -641,62 +641,62 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { case key.Matches(msg, m.keyMap.Chat.Expand): m.chat.ToggleExpandedSelectedItem() case key.Matches(msg, m.keyMap.Chat.Up): - if cmd := m.chat.ScrollBy(-1); cmd != nil { + if cmd := m.chat.ScrollByAndAnimate(-1); cmd != nil { cmds = append(cmds, cmd) } if !m.chat.SelectedItemInView() { m.chat.SelectPrev() - if cmd := m.chat.ScrollToSelected(); cmd != nil { + if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil { cmds = append(cmds, cmd) } } case key.Matches(msg, m.keyMap.Chat.Down): - if cmd := m.chat.ScrollBy(1); cmd != nil { + if cmd := m.chat.ScrollByAndAnimate(1); cmd != nil { cmds = append(cmds, cmd) } if !m.chat.SelectedItemInView() { m.chat.SelectNext() - if cmd := m.chat.ScrollToSelected(); cmd != nil { + if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil { cmds = append(cmds, cmd) } } case key.Matches(msg, m.keyMap.Chat.UpOneItem): m.chat.SelectPrev() - if cmd := m.chat.ScrollToSelected(); cmd != nil { + if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil { cmds = append(cmds, cmd) } case key.Matches(msg, m.keyMap.Chat.DownOneItem): m.chat.SelectNext() - if cmd := m.chat.ScrollToSelected(); cmd != nil { + if cmd := m.chat.ScrollToSelectedAndAnimate(); cmd != nil { cmds = append(cmds, cmd) } case key.Matches(msg, m.keyMap.Chat.HalfPageUp): - if cmd := m.chat.ScrollBy(-m.chat.Height() / 2); cmd != nil { + if cmd := m.chat.ScrollByAndAnimate(-m.chat.Height() / 2); cmd != nil { cmds = append(cmds, cmd) } m.chat.SelectFirstInView() case key.Matches(msg, m.keyMap.Chat.HalfPageDown): - if cmd := m.chat.ScrollBy(m.chat.Height() / 2); cmd != nil { + if cmd := m.chat.ScrollByAndAnimate(m.chat.Height() / 2); cmd != nil { cmds = append(cmds, cmd) } m.chat.SelectLastInView() case key.Matches(msg, m.keyMap.Chat.PageUp): - if cmd := m.chat.ScrollBy(-m.chat.Height()); cmd != nil { + if cmd := m.chat.ScrollByAndAnimate(-m.chat.Height()); cmd != nil { cmds = append(cmds, cmd) } m.chat.SelectFirstInView() case key.Matches(msg, m.keyMap.Chat.PageDown): - if cmd := m.chat.ScrollBy(m.chat.Height()); cmd != nil { + if cmd := m.chat.ScrollByAndAnimate(m.chat.Height()); cmd != nil { cmds = append(cmds, cmd) } m.chat.SelectLastInView() case key.Matches(msg, m.keyMap.Chat.Home): - if cmd := m.chat.ScrollToTop(); cmd != nil { + if cmd := m.chat.ScrollToTopAndAnimate(); cmd != nil { cmds = append(cmds, cmd) } m.chat.SelectFirst() case key.Matches(msg, m.keyMap.Chat.End): - if cmd := m.chat.ScrollToBottom(); cmd != nil { + if cmd := m.chat.ScrollToBottomAndAnimate(); cmd != nil { cmds = append(cmds, cmd) } m.chat.SelectLast()