.gitattributes 🔗
@@ -1,3 +1,4 @@
*.golden linguist-generated=true -text
.github/crush-schema.json linguist-generated=true
+internal/agent/hyper/provider.json linguist-generated=true
internal/agent/testdata/**/*.yaml -diff linguist-generated=true
Carlos Alexandro Becker created
.gitattributes | 1
.github/workflows/schema-update.yml | 6
Taskfile.yaml | 7
go.mod | 2
internal/agent/agent.go | 6
internal/agent/coordinator.go | 15
internal/agent/hyper/provider.go | 330 +++++++++++++
internal/agent/hyper/provider.json | 0
internal/cmd/login.go | 101 +++
internal/cmd/update_providers.go | 38 +
internal/config/catwalk.go | 82 +++
internal/config/catwalk_test.go | 221 +++++++++
internal/config/config.go | 21
internal/config/hyper.go | 124 +++++
internal/config/hyper_test.go | 205 ++++++++
internal/config/load.go | 3
internal/config/provider.go | 204 +++++---
internal/config/provider_empty_test.go | 28
internal/config/provider_test.go | 316 ++++++++++---
internal/event/event.go | 16
internal/oauth/hyper/device.go | 251 ++++++++++
internal/oauth/token.go | 2
internal/tui/components/dialogs/commands/commands.go | 3
internal/tui/components/dialogs/hyper/device_flow.go | 264 +++++++++++
internal/tui/components/dialogs/models/keys.go | 15
internal/tui/components/dialogs/models/list.go | 20
internal/tui/components/dialogs/models/models.go | 103 +++
internal/tui/exp/list/items.go | 6
28 files changed, 2,177 insertions(+), 213 deletions(-)
@@ -1,3 +1,4 @@
*.golden linguist-generated=true -text
.github/crush-schema.json linguist-generated=true
+internal/agent/hyper/provider.json linguist-generated=true
internal/agent/testdata/**/*.yaml -diff linguist-generated=true
@@ -1,10 +1,11 @@
-name: Update Schema
+name: Update files
on:
push:
branches: [main]
paths:
- "internal/config/**"
+ - "internal/agent/hyper/**"
jobs:
update-schema:
@@ -17,9 +18,10 @@ jobs:
with:
go-version-file: go.mod
- run: go run . schema > ./schema.json
+ - run: go generate ./internal/agent/hyper/...
- uses: stefanzweifel/git-auto-commit-action@28e16e81777b558cc906c8750092100bbb34c5e3 # v5
with:
- commit_message: "chore: auto-update generated files"
+ commit_message: "chore: auto-update files"
branch: main
commit_user_name: Charm
commit_user_email: 124303983+charmcli@users.noreply.github.com
@@ -101,6 +101,13 @@ tasks:
generates:
- schema.json
+ hyper:
+ desc: Update Hyper embedded provider.json
+ cmds:
+ - go generate ./internal/agent/hyper/...
+ generates:
+ - ./internal/agent/hyper/provider.json
+
release:
desc: Create and push a new tag following semver
vars:
@@ -23,6 +23,7 @@ require (
github.com/charmbracelet/fang v0.4.4
github.com/charmbracelet/ultraviolet v0.0.0-20251212194010-b927aa605560
github.com/charmbracelet/x/ansi v0.11.3
+ github.com/charmbracelet/x/etag v0.2.0
github.com/charmbracelet/x/exp/charmtone v0.0.0-20250708181618-a60a724ba6c3
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f
github.com/charmbracelet/x/exp/ordered v0.1.0
@@ -91,7 +92,6 @@ require (
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 // indirect
- github.com/charmbracelet/x/etag v0.2.0 // indirect
github.com/charmbracelet/x/json v0.2.0 // indirect
github.com/charmbracelet/x/termios v0.1.1 // indirect
github.com/charmbracelet/x/windows v0.2.2 // indirect
@@ -27,7 +27,9 @@ import (
"charm.land/fantasy/providers/google"
"charm.land/fantasy/providers/openai"
"charm.land/fantasy/providers/openrouter"
+ "charm.land/lipgloss/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/agent/tools"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
@@ -454,6 +456,10 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
} else if isPermissionErr {
currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
+ } else if errors.Is(err, hyper.ErrNoCredits) {
+ url := hyper.BaseURL()
+ link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
+ currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
} else if errors.As(err, &providerErr) {
currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
} else if errors.As(err, &fantasyErr) {
@@ -17,6 +17,7 @@ import (
"charm.land/fantasy"
"github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/agent/prompt"
"github.com/charmbracelet/crush/internal/agent/tools"
"github.com/charmbracelet/crush/internal/config"
@@ -668,6 +669,18 @@ func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, optio
return google.New(opts...)
}
+func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
+ opts := []hyper.Option{
+ hyper.WithBaseURL(baseURL),
+ hyper.WithAPIKey(apiKey),
+ }
+ if c.cfg.Options.Debug {
+ httpClient := log.NewHTTPClient()
+ opts = append(opts, hyper.WithHTTPClient(httpClient))
+ }
+ return hyper.New(opts...)
+}
+
func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
if model.Think {
return true
@@ -728,6 +741,8 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con
providerCfg.ExtraBody["tool_stream"] = true
}
return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody)
+ case hyper.Name:
+ return c.buildHyperProvider(baseURL, apiKey)
default:
return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
}
@@ -0,0 +1,330 @@
+// Package hyper provides a fantasy.Provider that proxies requests to Hyper.
+package hyper
+
+import (
+ "bufio"
+ "bytes"
+ "cmp"
+ "context"
+ _ "embed"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "maps"
+ "net/http"
+ "net/url"
+ "os"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "charm.land/fantasy"
+ "charm.land/fantasy/object"
+ "github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/event"
+)
+
+//go:generate wget -O provider.json https://console.charm.land/api/v1/provider
+
+//go:embed provider.json
+var embedded []byte
+
+// Enabled returns true if hyper is enabled.
+var Enabled = sync.OnceValue(func() bool {
+ b, _ := strconv.ParseBool(
+ cmp.Or(
+ os.Getenv("HYPER"),
+ os.Getenv("HYPERCRUSH"),
+ os.Getenv("HYPER_ENABLE"),
+ os.Getenv("HYPER_ENABLED"),
+ ),
+ )
+ return b
+})
+
+// Embedded returns the embedded Hyper provider.
+var Embedded = sync.OnceValue(func() catwalk.Provider {
+ var provider catwalk.Provider
+ if err := json.Unmarshal(embedded, &provider); err != nil {
+ slog.Error("could not use embedded provider data", "err", err)
+ }
+ return provider
+})
+
+const (
+ // Name is the default name of this meta provider.
+ Name = "hyper"
+ // defaultBaseURL is the default proxy URL.
+ // TODO: change this to production URL when ready.
+ defaultBaseURL = "https://console.charm.land"
+)
+
+// BaseURL returns the base URL, which is either $HYPER_URL or the default.
+var BaseURL = sync.OnceValue(func() string {
+ return cmp.Or(os.Getenv("HYPER_URL"), defaultBaseURL)
+})
+
+var ErrNoCredits = errors.New("you're out of credits")
+
+type options struct {
+ baseURL string
+ apiKey string
+ name string
+ headers map[string]string
+ client *http.Client
+}
+
+// Option configures the proxy provider.
+type Option = func(*options)
+
+// New creates a new proxy provider.
+func New(opts ...Option) (fantasy.Provider, error) {
+ o := options{
+ baseURL: BaseURL() + "/api/v1/fantasy",
+ name: Name,
+ headers: map[string]string{
+ "x-crush-id": event.GetID(),
+ },
+ client: &http.Client{Timeout: 0}, // stream-safe
+ }
+ for _, opt := range opts {
+ opt(&o)
+ }
+ return &provider{options: o}, nil
+}
+
+// WithBaseURL sets the proxy base URL (e.g. http://localhost:8080).
+func WithBaseURL(url string) Option { return func(o *options) { o.baseURL = url } }
+
+// WithName sets the provider name.
+func WithName(name string) Option { return func(o *options) { o.name = name } }
+
+// WithHeaders sets extra headers sent to the proxy.
+func WithHeaders(headers map[string]string) Option {
+ return func(o *options) {
+ maps.Copy(o.headers, headers)
+ }
+}
+
+// WithHTTPClient sets custom HTTP client.
+func WithHTTPClient(c *http.Client) Option { return func(o *options) { o.client = c } }
+
+// WithAPIKey sets the API key.
+func WithAPIKey(key string) Option {
+ return func(o *options) {
+ o.apiKey = key
+ }
+}
+
+type provider struct{ options options }
+
+func (p *provider) Name() string { return p.options.name }
+
+// LanguageModel implements fantasy.Provider.
+func (p *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) {
+ if modelID == "" {
+ return nil, errors.New("missing model id")
+ }
+ return &languageModel{modelID: modelID, provider: p.options.name, opts: p.options}, nil
+}
+
+type languageModel struct {
+ provider string
+ modelID string
+ opts options
+}
+
+// GenerateObject implements fantasy.LanguageModel.
+func (m *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
+ return object.GenerateWithTool(ctx, m, call)
+}
+
+// StreamObject implements fantasy.LanguageModel.
+func (m *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
+ return object.StreamWithTool(ctx, m, call)
+}
+
+func (m *languageModel) Provider() string { return m.provider }
+func (m *languageModel) Model() string { return m.modelID }
+
+// Generate implements fantasy.LanguageModel by calling the proxy JSON endpoint.
+func (m *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
+ resp, err := m.doRequest(ctx, false, call)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ b, _ := ioReadAllLimit(resp.Body, 64*1024)
+ return nil, fmt.Errorf("proxy generate error: %s", strings.TrimSpace(string(b)))
+ }
+ var out fantasy.Response
+ if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
+ return nil, err
+ }
+ return &out, nil
+}
+
+// Stream implements fantasy.LanguageModel using SSE from the proxy.
+func (m *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
+ // Prefer explicit /stream endpoint
+ resp, err := m.doRequest(ctx, true, call)
+ if err != nil {
+ return nil, err
+ }
+ switch resp.StatusCode {
+ case http.StatusTooManyRequests:
+ _ = resp.Body.Close()
+ return nil, toProviderError(resp, retryAfter(resp))
+ case http.StatusUnauthorized:
+ _ = resp.Body.Close()
+ return nil, toProviderError(resp, "")
+ case http.StatusPaymentRequired:
+ _ = resp.Body.Close()
+ return nil, ErrNoCredits
+ }
+
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ defer func() { _ = resp.Body.Close() }()
+ b, _ := ioReadAllLimit(resp.Body, 64*1024)
+ return nil, &fantasy.ProviderError{
+ Title: "Stream Error",
+ Message: strings.TrimSpace(string(b)),
+ StatusCode: resp.StatusCode,
+ }
+ }
+
+ return func(yield func(fantasy.StreamPart) bool) {
+ defer func() { _ = resp.Body.Close() }()
+ scanner := bufio.NewScanner(resp.Body)
+ buf := make([]byte, 0, 64*1024)
+ scanner.Buffer(buf, 4*1024*1024)
+
+ var (
+ event string
+ dataBuf bytes.Buffer
+ sawFinish bool
+ dispatch = func() bool {
+ if dataBuf.Len() == 0 || event == "" {
+ dataBuf.Reset()
+ event = ""
+ return true
+ }
+ var part fantasy.StreamPart
+ if err := json.Unmarshal(dataBuf.Bytes(), &part); err != nil {
+ return yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
+ }
+ if part.Type == fantasy.StreamPartTypeFinish {
+ sawFinish = true
+ }
+ ok := yield(part)
+ dataBuf.Reset()
+ event = ""
+ return ok
+ }
+ )
+
+ for scanner.Scan() {
+ line := scanner.Text()
+ if line == "" { // event boundary
+ if !dispatch() {
+ return
+ }
+ continue
+ }
+ if strings.HasPrefix(line, ":") { // comment / ping
+ continue
+ }
+ if strings.HasPrefix(line, "event: ") {
+ event = strings.TrimSpace(line[len("event: "):])
+ continue
+ }
+ if strings.HasPrefix(line, "data: ") {
+ if dataBuf.Len() > 0 {
+ dataBuf.WriteByte('\n')
+ }
+ dataBuf.WriteString(line[len("data: "):])
+ continue
+ }
+ }
+ if err := scanner.Err(); err != nil &&
+ !errors.Is(err, context.Canceled) &&
+ !errors.Is(err, context.DeadlineExceeded) {
+ yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: err})
+ return
+ }
+ // flush any pending data
+ _ = dispatch()
+ if !sawFinish {
+ _ = yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish})
+ }
+ }, nil
+}
+
+func (m *languageModel) doRequest(ctx context.Context, stream bool, call fantasy.Call) (*http.Response, error) {
+ addr, err := url.Parse(m.opts.baseURL)
+ if err != nil {
+ return nil, err
+ }
+ addr = addr.JoinPath(m.modelID)
+ if stream {
+ addr = addr.JoinPath("stream")
+ } else {
+ addr = addr.JoinPath("generate")
+ }
+
+ body, err := json.Marshal(call)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, addr.String(), bytes.NewReader(body))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ if stream {
+ req.Header.Set("Accept", "text/event-stream")
+ } else {
+ req.Header.Set("Accept", "application/json")
+ }
+ for k, v := range m.opts.headers {
+ req.Header.Set(k, v)
+ }
+
+ if m.opts.apiKey != "" {
+ req.Header.Set("Authorization", m.opts.apiKey)
+ }
+ return m.opts.client.Do(req)
+}
+
+// ioReadAllLimit reads up to n bytes.
+func ioReadAllLimit(r io.Reader, n int64) ([]byte, error) {
+ var b bytes.Buffer
+ if n <= 0 {
+ n = 1 << 20
+ }
+ lr := &io.LimitedReader{R: r, N: n}
+ _, err := b.ReadFrom(lr)
+ return b.Bytes(), err
+}
+
+func toProviderError(resp *http.Response, message string) error {
+ return &fantasy.ProviderError{
+ Title: fantasy.ErrorTitleForStatusCode(resp.StatusCode),
+ Message: message,
+ StatusCode: resp.StatusCode,
+ }
+}
+
+func retryAfter(resp *http.Response) string {
+ after, err := strconv.Atoi(resp.Header.Get("Retry-After"))
+ if err == nil && after > 0 {
+ d := time.Duration(after) * time.Second
+ return "Try again in " + d.String()
+ }
+ return "Try again in later"
+}
@@ -0,0 +1 @@
@@ -9,8 +9,12 @@ import (
"strings"
"charm.land/lipgloss/v2"
+ "github.com/atotto/clipboard"
+ hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/oauth/claude"
+ "github.com/charmbracelet/crush/internal/oauth/hyper"
+ "github.com/pkg/browser"
"github.com/spf13/cobra"
)
@@ -20,31 +24,34 @@ var loginCmd = &cobra.Command{
Short: "Login Crush to a platform",
Long: `Login Crush to a specified platform.
The platform should be provided as an argument.
-Available platforms are: claude.`,
+Available platforms are: hyper, claude.`,
Example: `
+# Authenticate with Charm Hyper
+crush login
+
# Authenticate with Claude Code Max
crush login claude
`,
ValidArgs: []cobra.Completion{
+ "hyper",
"claude",
"anthropic",
},
- Args: cobra.ExactArgs(1),
+ Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
- if len(args) > 1 {
- return fmt.Errorf("wrong number of arguments")
- }
- if len(args) == 0 || args[0] == "" {
- return cmd.Help()
- }
-
app, err := setupAppWithProgressBar(cmd)
if err != nil {
return err
}
defer app.Shutdown()
- switch args[0] {
+ provider := "hyper"
+ if len(args) > 0 {
+ provider = args[0]
+ }
+ switch provider {
+ case "hyper":
+ return loginHyper()
case "anthropic", "claude":
return loginClaude()
default:
@@ -53,13 +60,73 @@ crush login claude
},
}
+func loginHyper() error {
+ cfg := config.Get()
+ if !hyperp.Enabled() {
+ return fmt.Errorf("hyper not enabled")
+ }
+ ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
+ defer cancel()
+
+ resp, err := hyper.InitiateDeviceAuth(ctx)
+ if err != nil {
+ return err
+ }
+
+ if clipboard.WriteAll(resp.UserCode) == nil {
+ fmt.Println("The following code should be on clipboard already:")
+ } else {
+ fmt.Println("Copy the following code:")
+ }
+
+ fmt.Println()
+ fmt.Println(lipgloss.NewStyle().Bold(true).Render(resp.UserCode))
+ fmt.Println()
+ fmt.Println("Press enter to open this URL, and then paste it there:")
+ fmt.Println()
+ fmt.Println(lipgloss.NewStyle().Hyperlink(resp.VerificationURL, "id=hyper").Render(resp.VerificationURL))
+ fmt.Println()
+ waitEnter()
+ if err := browser.OpenURL(resp.VerificationURL); err != nil {
+ fmt.Println("Could not open the URL. You'll need to manually open the URL in your browser.")
+ }
+
+ fmt.Println("Exchanging authorization code...")
+ refreshToken, err := hyper.PollForToken(ctx, resp.DeviceCode, resp.ExpiresIn)
+ if err != nil {
+ return err
+ }
+
+ fmt.Println("Exchanging refresh token for access token...")
+ token, err := hyper.ExchangeToken(ctx, refreshToken)
+ if err != nil {
+ return err
+ }
+
+ fmt.Println("Verifying access token...")
+ introspect, err := hyper.IntrospectToken(ctx, token.AccessToken)
+ if err != nil {
+ return fmt.Errorf("token introspection failed: %w", err)
+ }
+ if !introspect.Active {
+ return fmt.Errorf("access token is not active")
+ }
+
+ if err := cmp.Or(
+ cfg.SetConfigField("providers.hyper.api_key", token.AccessToken),
+ cfg.SetConfigField("providers.hyper.oauth", token),
+ ); err != nil {
+ return err
+ }
+
+ fmt.Println()
+ fmt.Println("You're now authenticated with Hyper!")
+ return nil
+}
+
func loginClaude() error {
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
- go func() {
- <-ctx.Done()
- cancel()
- os.Exit(1)
- }()
+ defer cancel()
verifier, challenge, err := claude.GetChallenge()
if err != nil {
@@ -106,3 +173,7 @@ func loginClaude() error {
fmt.Println("You're now authenticated with Claude Code Max!")
return nil
}
+
+func waitEnter() {
+ _, _ = fmt.Scanln()
+}
@@ -10,22 +10,30 @@ import (
"github.com/spf13/cobra"
)
+var updateProvidersSource string
+
var updateProvidersCmd = &cobra.Command{
Use: "update-providers [path-or-url]",
Short: "Update providers",
- Long: `Update the list of providers from a specified local path or remote URL.`,
+ Long: `Update provider information from a specified local path or remote URL.`,
Example: `
-# Update providers remotely from Catwalk
+# Update Catwalk providers remotely (default)
crush update-providers
-# Update providers from a custom URL
-crush update-providers https://example.com/
+# Update Catwalk providers from a custom URL
+crush update-providers https://example.com/providers.json
-# Update providers from a local file
+# Update Catwalk providers from a local file
crush update-providers /path/to/local-providers.json
-# Update providers from embedded version
+# Update Catwalk providers from embedded version
crush update-providers embedded
+
+# Update Hyper provider information
+crush update-providers --source=hyper
+
+# Update Hyper from a custom URL
+crush update-providers --source=hyper https://hyper.example.com
`,
RunE: func(cmd *cobra.Command, args []string) error {
// NOTE(@andreynering): We want to skip logging output do stdout here.
@@ -36,7 +44,17 @@ crush update-providers embedded
pathOrURL = args[0]
}
- if err := config.UpdateProviders(pathOrURL); err != nil {
+ var err error
+ switch updateProvidersSource {
+ case "catwalk":
+ err = config.UpdateProviders(pathOrURL)
+ case "hyper":
+ err = config.UpdateHyper(pathOrURL)
+ default:
+ return fmt.Errorf("invalid source %q, must be 'catwalk' or 'hyper'", updateProvidersSource)
+ }
+
+ if err != nil {
return err
}
@@ -52,9 +70,13 @@ crush update-providers embedded
SetString("SUCCESS")
textStyle := lipgloss.NewStyle().
MarginLeft(2).
- SetString("Providers updated successfully.")
+ SetString(fmt.Sprintf("%s provider updated successfully.", updateProvidersSource))
fmt.Printf("%s\n%s\n\n", headerStyle.Render(), textStyle.Render())
return nil
},
}
+
+func init() {
+ updateProvidersCmd.Flags().StringVar(&updateProvidersSource, "source", "catwalk", "Provider source to update (catwalk or hyper)")
+}
@@ -0,0 +1,82 @@
+package config
+
+import (
+ "context"
+ "errors"
+ "log/slog"
+ "sync"
+ "sync/atomic"
+
+ "github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/catwalk/pkg/embedded"
+)
+
+type catwalkClient interface {
+ GetProviders(context.Context, string) ([]catwalk.Provider, error)
+}
+
+var _ syncer[[]catwalk.Provider] = (*catwalkSync)(nil)
+
+type catwalkSync struct {
+ once sync.Once
+ result []catwalk.Provider
+ cache cache[[]catwalk.Provider]
+ client catwalkClient
+ autoupdate bool
+ init atomic.Bool
+}
+
+func (s *catwalkSync) Init(client catwalkClient, path string, autoupdate bool) {
+ s.client = client
+ s.cache = newCache[[]catwalk.Provider](path)
+ s.autoupdate = autoupdate
+ s.init.Store(true)
+}
+
+func (s *catwalkSync) Get(ctx context.Context) ([]catwalk.Provider, error) {
+ if !s.init.Load() {
+ panic("called Get before Init")
+ }
+
+ var throwErr error
+ s.once.Do(func() {
+ if !s.autoupdate {
+ slog.Info("Using embedded Catwalk providers")
+ s.result = embedded.GetAll()
+ return
+ }
+
+ cached, etag, cachedErr := s.cache.Get()
+ if len(cached) == 0 || cachedErr != nil {
+ // if cached file is empty, default to embedded providers
+ cached = embedded.GetAll()
+ }
+
+ slog.Info("Fetching providers from Catwalk")
+ result, err := s.client.GetProviders(ctx, etag)
+ if errors.Is(err, context.DeadlineExceeded) {
+ slog.Warn("Catwalk providers not updated in time")
+ s.result = cached
+ return
+ }
+ if errors.Is(err, catwalk.ErrNotModified) {
+ slog.Info("Catwalk providers not modified")
+ s.result = cached
+ return
+ }
+ if err != nil {
+ // On error, fall back to cached (which defaults to embedded if empty).
+ s.result = cached
+ return
+ }
+ if len(result) == 0 {
+ s.result = cached
+ throwErr = errors.New("empty providers list from catwalk")
+ return
+ }
+
+ s.result = result
+ throwErr = s.cache.Store(result)
+ })
+ return s.result, throwErr
+}
@@ -0,0 +1,221 @@
+package config
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "os"
+ "testing"
+
+ "github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/stretchr/testify/require"
+)
+
+type mockCatwalkClient struct {
+ providers []catwalk.Provider
+ err error
+ callCount int
+}
+
+func (m *mockCatwalkClient) GetProviders(ctx context.Context, etag string) ([]catwalk.Provider, error) {
+ m.callCount++
+ return m.providers, m.err
+}
+
+func TestCatwalkSync_Init(t *testing.T) {
+ t.Parallel()
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{}
+ path := "/tmp/test.json"
+
+ syncer.Init(client, path, true)
+
+ require.True(t, syncer.init.Load())
+ require.Equal(t, client, syncer.client)
+ require.Equal(t, path, syncer.cache.path)
+ require.True(t, syncer.autoupdate)
+}
+
+func TestCatwalkSync_GetPanicIfNotInit(t *testing.T) {
+ t.Parallel()
+
+ syncer := &catwalkSync{}
+ require.Panics(t, func() {
+ _, _ = syncer.Get(t.Context())
+ })
+}
+
+func TestCatwalkSync_GetWithAutoUpdateDisabled(t *testing.T) {
+ t.Parallel()
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{
+ providers: []catwalk.Provider{{Name: "should-not-be-used"}},
+ }
+ path := t.TempDir() + "/providers.json"
+
+ syncer.Init(client, path, false)
+
+ providers, err := syncer.Get(t.Context())
+ require.NoError(t, err)
+ require.NotEmpty(t, providers)
+ require.Equal(t, 0, client.callCount, "Client should not be called when autoupdate is disabled")
+
+ // Should return embedded providers.
+ for _, p := range providers {
+ require.NotEqual(t, "should-not-be-used", p.Name)
+ }
+}
+
+func TestCatwalkSync_GetFreshProviders(t *testing.T) {
+ t.Parallel()
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{
+ providers: []catwalk.Provider{
+ {Name: "Fresh Provider", ID: "fresh"},
+ },
+ }
+ path := t.TempDir() + "/providers.json"
+
+ syncer.Init(client, path, true)
+
+ providers, err := syncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Len(t, providers, 1)
+ require.Equal(t, "Fresh Provider", providers[0].Name)
+ require.Equal(t, 1, client.callCount)
+
+ // Verify cache was written.
+ fileInfo, err := os.Stat(path)
+ require.NoError(t, err)
+ require.False(t, fileInfo.IsDir())
+}
+
+func TestCatwalkSync_GetNotModifiedUsesCached(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/providers.json"
+
+ // Create cache file.
+ cachedProviders := []catwalk.Provider{
+ {Name: "Cached Provider", ID: "cached"},
+ }
+ data, err := json.Marshal(cachedProviders)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(path, data, 0o644))
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{
+ err: catwalk.ErrNotModified,
+ }
+
+ syncer.Init(client, path, true)
+
+ providers, err := syncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Len(t, providers, 1)
+ require.Equal(t, "Cached Provider", providers[0].Name)
+ require.Equal(t, 1, client.callCount)
+}
+
+func TestCatwalkSync_GetEmptyResultFallbackToCached(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/providers.json"
+
+ // Create cache file.
+ cachedProviders := []catwalk.Provider{
+ {Name: "Cached Provider", ID: "cached"},
+ }
+ data, err := json.Marshal(cachedProviders)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(path, data, 0o644))
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{
+ providers: []catwalk.Provider{}, // Empty result.
+ }
+
+ syncer.Init(client, path, true)
+
+ providers, err := syncer.Get(t.Context())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "empty providers list from catwalk")
+ require.Len(t, providers, 1)
+ require.Equal(t, "Cached Provider", providers[0].Name)
+}
+
+func TestCatwalkSync_GetEmptyCacheDefaultsToEmbedded(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/providers.json"
+
+ // Create empty cache file.
+ emptyProviders := []catwalk.Provider{}
+ data, err := json.Marshal(emptyProviders)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(path, data, 0o644))
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{
+ err: errors.New("network error"),
+ }
+
+ syncer.Init(client, path, true)
+
+ providers, err := syncer.Get(t.Context())
+ require.NoError(t, err)
+ require.NotEmpty(t, providers, "Should fall back to embedded providers")
+
+ // Verify it's embedded providers by checking we have multiple common ones.
+ require.Greater(t, len(providers), 5)
+}
+
+func TestCatwalkSync_GetClientError(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/providers.json"
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{
+ err: errors.New("network error"),
+ }
+
+ syncer.Init(client, path, true)
+
+ providers, err := syncer.Get(t.Context())
+ require.NoError(t, err) // Should fall back to embedded.
+ require.NotEmpty(t, providers)
+}
+
+func TestCatwalkSync_GetCalledMultipleTimesUsesOnce(t *testing.T) {
+ t.Parallel()
+
+ syncer := &catwalkSync{}
+ client := &mockCatwalkClient{
+ providers: []catwalk.Provider{
+ {Name: "Provider", ID: "test"},
+ },
+ }
+ path := t.TempDir() + "/providers.json"
+
+ syncer.Init(client, path, true)
+
+ // Call Get multiple times.
+ providers1, err1 := syncer.Get(t.Context())
+ require.NoError(t, err1)
+ require.Len(t, providers1, 1)
+
+ providers2, err2 := syncer.Get(t.Context())
+ require.NoError(t, err2)
+ require.Len(t, providers2, 1)
+
+ // Client should only be called once due to sync.Once.
+ require.Equal(t, 1, client.callCount)
+}
@@ -13,10 +13,12 @@ import (
"time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
+ hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/oauth"
"github.com/charmbracelet/crush/internal/oauth/claude"
+ "github.com/charmbracelet/crush/internal/oauth/hyper"
"github.com/invopop/jsonschema"
"github.com/tidwall/sjson"
)
@@ -120,7 +122,6 @@ type ProviderConfig struct {
}
func (pc *ProviderConfig) SetupClaudeCode() {
- pc.APIKey = fmt.Sprintf("Bearer %s", pc.OAuthToken.AccessToken)
pc.SystemPromptPrefix = "You are Claude Code, Anthropic's official CLI for Claude."
pc.ExtraHeaders["anthropic-version"] = "2023-06-01"
@@ -483,12 +484,16 @@ func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error
return fmt.Errorf("provider %s does not have an OAuth token", providerID)
}
- // Only Anthropic provider uses OAuth for now.
- if providerID != string(catwalk.InferenceProviderAnthropic) {
+ var newToken *oauth.Token
+ var err error
+ switch providerID {
+ case string(catwalk.InferenceProviderAnthropic):
+ newToken, err = claude.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
+ case hyperp.Name:
+ newToken, err = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
+ default:
return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
}
-
- newToken, err := claude.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
if err != nil {
return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, err)
}
@@ -529,9 +534,11 @@ func (c *Config) SetProviderAPIKey(providerID string, apiKey any) error {
return err
}
setKeyOrToken = func() {
- providerConfig.APIKey = v.AccessToken
+ providerConfig.APIKey = fmt.Sprintf("Bearer %s", v.AccessToken)
providerConfig.OAuthToken = v
- providerConfig.SetupClaudeCode()
+ if providerID == string(catwalk.InferenceProviderAnthropic) {
+ providerConfig.SetupClaudeCode()
+ }
}
}
@@ -0,0 +1,124 @@
+package config
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/agent/hyper"
+ xetag "github.com/charmbracelet/x/etag"
+)
+
+type hyperClient interface {
+ Get(context.Context, string) (catwalk.Provider, error)
+}
+
+var _ syncer[catwalk.Provider] = (*hyperSync)(nil)
+
+type hyperSync struct {
+ once sync.Once
+ result catwalk.Provider
+ cache cache[catwalk.Provider]
+ client hyperClient
+ autoupdate bool
+ init atomic.Bool
+}
+
+func (s *hyperSync) Init(client hyperClient, path string, autoupdate bool) {
+ s.client = client
+ s.cache = newCache[catwalk.Provider](path)
+ s.autoupdate = autoupdate
+ s.init.Store(true)
+}
+
+func (s *hyperSync) Get(ctx context.Context) (catwalk.Provider, error) {
+ if !s.init.Load() {
+ panic("called Get before Init")
+ }
+
+ var throwErr error
+ s.once.Do(func() {
+ if !s.autoupdate {
+ slog.Info("Using embedded Hyper provider")
+ s.result = hyper.Embedded()
+ return
+ }
+
+ cached, etag, cachedErr := s.cache.Get()
+ if cached.ID == "" || cachedErr != nil {
+ // if cached file is empty, default to embedded provider
+ cached = hyper.Embedded()
+ }
+
+ slog.Info("Fetching Hyper provider")
+ result, err := s.client.Get(ctx, etag)
+ if errors.Is(err, context.DeadlineExceeded) {
+ slog.Warn("Hyper provider not updated in time")
+ s.result = cached
+ return
+ }
+ if errors.Is(err, catwalk.ErrNotModified) {
+ slog.Info("Hyper provider not modified")
+ s.result = cached
+ return
+ }
+ if len(result.Models) == 0 {
+ slog.Warn("Hyper did not return any models")
+ s.result = cached
+ return
+ }
+
+ s.result = result
+ throwErr = s.cache.Store(result)
+ })
+ return s.result, throwErr
+}
+
+var _ hyperClient = realHyperClient{}
+
+type realHyperClient struct {
+ baseURL string
+}
+
+// Get implements hyperClient.
+func (r realHyperClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) {
+ var result catwalk.Provider
+ req, err := http.NewRequestWithContext(
+ ctx,
+ http.MethodGet,
+ r.baseURL+"/api/v1/provider",
+ nil,
+ )
+ if err != nil {
+ return result, fmt.Errorf("could not create request: %w", err)
+ }
+ xetag.Request(req, etag)
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return result, fmt.Errorf("failed to make request: %w", err)
+ }
+ defer resp.Body.Close() //nolint:errcheck
+
+ if resp.StatusCode == http.StatusNotModified {
+ return result, catwalk.ErrNotModified
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return result, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+ }
+
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ return result, fmt.Errorf("failed to decode response: %w", err)
+ }
+
+ return result, nil
+}
@@ -0,0 +1,205 @@
+package config
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "os"
+ "testing"
+
+ "github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/stretchr/testify/require"
+)
+
+type mockHyperClient struct {
+ provider catwalk.Provider
+ err error
+ callCount int
+}
+
+func (m *mockHyperClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) {
+ m.callCount++
+ return m.provider, m.err
+}
+
+func TestHyperSync_Init(t *testing.T) {
+ t.Parallel()
+
+ syncer := &hyperSync{}
+ client := &mockHyperClient{}
+ path := "/tmp/hyper.json"
+
+ syncer.Init(client, path, true)
+
+ require.True(t, syncer.init.Load())
+ require.Equal(t, client, syncer.client)
+ require.Equal(t, path, syncer.cache.path)
+}
+
+func TestHyperSync_GetPanicIfNotInit(t *testing.T) {
+ t.Parallel()
+
+ syncer := &hyperSync{}
+ require.Panics(t, func() {
+ _, _ = syncer.Get(t.Context())
+ })
+}
+
+func TestHyperSync_GetFreshProvider(t *testing.T) {
+ t.Parallel()
+
+ syncer := &hyperSync{}
+ client := &mockHyperClient{
+ provider: catwalk.Provider{
+ Name: "Hyper",
+ ID: "hyper",
+ Models: []catwalk.Model{
+ {ID: "model-1", Name: "Model 1"},
+ },
+ },
+ }
+ path := t.TempDir() + "/hyper.json"
+
+ syncer.Init(client, path, true)
+
+ provider, err := syncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Equal(t, "Hyper", provider.Name)
+ require.Equal(t, 1, client.callCount)
+
+ // Verify cache was written.
+ fileInfo, err := os.Stat(path)
+ require.NoError(t, err)
+ require.False(t, fileInfo.IsDir())
+}
+
+func TestHyperSync_GetNotModifiedUsesCached(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/hyper.json"
+
+ // Create cache file.
+ cachedProvider := catwalk.Provider{
+ Name: "Cached Hyper",
+ ID: "hyper",
+ }
+ data, err := json.Marshal(cachedProvider)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(path, data, 0o644))
+
+ syncer := &hyperSync{}
+ client := &mockHyperClient{
+ err: catwalk.ErrNotModified,
+ }
+
+ syncer.Init(client, path, true)
+
+ provider, err := syncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Equal(t, "Cached Hyper", provider.Name)
+ require.Equal(t, 1, client.callCount)
+}
+
+func TestHyperSync_GetClientError(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/hyper.json"
+
+ syncer := &hyperSync{}
+ client := &mockHyperClient{
+ err: errors.New("network error"),
+ }
+
+ syncer.Init(client, path, true)
+
+ provider, err := syncer.Get(t.Context())
+ require.NoError(t, err) // Should fall back to embedded.
+ require.Equal(t, "Charm Hyper", provider.Name)
+ require.Equal(t, catwalk.InferenceProvider("hyper"), provider.ID)
+}
+
+func TestHyperSync_GetEmptyCache(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/hyper.json"
+
+ syncer := &hyperSync{}
+ client := &mockHyperClient{
+ provider: catwalk.Provider{
+ Name: "Fresh Hyper",
+ ID: "hyper",
+ Models: []catwalk.Model{
+ {ID: "model-1", Name: "Model 1"},
+ },
+ },
+ }
+
+ syncer.Init(client, path, true)
+
+ provider, err := syncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Equal(t, "Fresh Hyper", provider.Name)
+}
+
+func TestHyperSync_GetCalledMultipleTimesUsesOnce(t *testing.T) {
+ t.Parallel()
+
+ syncer := &hyperSync{}
+ client := &mockHyperClient{
+ provider: catwalk.Provider{
+ Name: "Hyper",
+ ID: "hyper",
+ Models: []catwalk.Model{
+ {ID: "model-1", Name: "Model 1"},
+ },
+ },
+ }
+ path := t.TempDir() + "/hyper.json"
+
+ syncer.Init(client, path, true)
+
+ // Call Get multiple times.
+ provider1, err1 := syncer.Get(t.Context())
+ require.NoError(t, err1)
+ require.Equal(t, "Hyper", provider1.Name)
+
+ provider2, err2 := syncer.Get(t.Context())
+ require.NoError(t, err2)
+ require.Equal(t, "Hyper", provider2.Name)
+
+ // Client should only be called once due to sync.Once.
+ require.Equal(t, 1, client.callCount)
+}
+
+func TestHyperSync_GetCacheStoreError(t *testing.T) {
+ t.Parallel()
+
+ // Create a file where we want a directory, causing mkdir to fail.
+ tmpDir := t.TempDir()
+ blockingFile := tmpDir + "/blocking"
+ require.NoError(t, os.WriteFile(blockingFile, []byte("block"), 0o644))
+
+ // Try to create cache in a subdirectory under the blocking file.
+ path := blockingFile + "/subdir/hyper.json"
+
+ syncer := &hyperSync{}
+ client := &mockHyperClient{
+ provider: catwalk.Provider{
+ Name: "Hyper",
+ ID: "hyper",
+ Models: []catwalk.Model{
+ {ID: "model-1", Name: "Model 1"},
+ },
+ },
+ }
+
+ syncer.Init(client, path, true)
+
+ provider, err := syncer.Get(t.Context())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "failed to create directory for provider cache")
+ require.Equal(t, "Hyper", provider.Name) // Provider is still returned.
+}
@@ -17,6 +17,7 @@ import (
"testing"
"github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/fsext"
@@ -271,7 +272,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
if providerConfig.Type == "" {
providerConfig.Type = catwalk.TypeOpenAICompat
}
- if !slices.Contains(catwalk.KnownProviderTypes(), providerConfig.Type) {
+ if !slices.Contains(catwalk.KnownProviderTypes(), providerConfig.Type) && providerConfig.Type != hyper.Name {
slog.Warn("Skipping custom provider due to unsupported provider type", "provider", id)
c.Providers.Del(id)
continue
@@ -10,16 +10,21 @@ import (
"os"
"path/filepath"
"runtime"
+ "slices"
"strings"
"sync"
+ "time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/catwalk/pkg/embedded"
+ "github.com/charmbracelet/crush/internal/agent/hyper"
+ "github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/home"
+ "github.com/charmbracelet/x/etag"
)
-type ProviderClient interface {
- GetProviders(context.Context, string) ([]catwalk.Provider, error)
+type syncer[T any] interface {
+ Get(context.Context) (T, error)
}
var (
@@ -29,10 +34,10 @@ var (
)
// file to cache provider data
-func providerCacheFileData() string {
+func cachePathFor(name string) string {
xdgDataHome := os.Getenv("XDG_DATA_HOME")
if xdgDataHome != "" {
- return filepath.Join(xdgDataHome, appName, "providers.json")
+ return filepath.Join(xdgDataHome, appName, name+".json")
}
// return the path to the main data directory
@@ -43,43 +48,13 @@ func providerCacheFileData() string {
if localAppData == "" {
localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
}
- return filepath.Join(localAppData, appName, "providers.json")
+ return filepath.Join(localAppData, appName, name+".json")
}
- return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json")
-}
-
-func saveProvidersInCache(path string, providers []catwalk.Provider) error {
- slog.Info("Saving provider data to disk", "path", path)
- if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
- return fmt.Errorf("failed to create directory for provider cache: %w", err)
- }
-
- data, err := json.Marshal(providers)
- if err != nil {
- return fmt.Errorf("failed to marshal provider data: %w", err)
- }
-
- if err := os.WriteFile(path, data, 0o644); err != nil {
- return fmt.Errorf("failed to write provider data to cache: %w", err)
- }
- return nil
-}
-
-func loadProvidersFromCache(path string) ([]catwalk.Provider, string, error) {
- data, err := os.ReadFile(path)
- if err != nil {
- return nil, "", fmt.Errorf("failed to read provider cache file: %w", err)
- }
-
- var providers []catwalk.Provider
- if err := json.Unmarshal(data, &providers); err != nil {
- return nil, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
- }
-
- return providers, catwalk.Etag(data), nil
+ return filepath.Join(home.Dir(), ".local", "share", appName, name+".json")
}
+// UpdateProviders updates the Catwalk providers list from a specified source.
func UpdateProviders(pathOrURL string) error {
var providers []catwalk.Provider
pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
@@ -106,15 +81,55 @@ func UpdateProviders(pathOrURL string) error {
}
}
- cachePath := providerCacheFileData()
- if err := saveProvidersInCache(cachePath, providers); err != nil {
+ if err := newCache[[]catwalk.Provider](cachePathFor("providers")).Store(providers); err != nil {
return fmt.Errorf("failed to save providers to cache: %w", err)
}
- slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePath)
+ slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePathFor)
+ return nil
+}
+
+// UpdateHyper updates the Hyper provider information from a specified URL.
+func UpdateHyper(pathOrURL string) error {
+ if !hyper.Enabled() {
+ return fmt.Errorf("hyper not enabled")
+ }
+ var provider catwalk.Provider
+ pathOrURL = cmp.Or(pathOrURL, hyper.BaseURL())
+
+ switch {
+ case pathOrURL == "embedded":
+ provider = hyper.Embedded()
+ case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
+ client := realHyperClient{baseURL: pathOrURL}
+ var err error
+ provider, err = client.Get(context.Background(), "")
+ if err != nil {
+ return fmt.Errorf("failed to fetch provider from Hyper: %w", err)
+ }
+ default:
+ content, err := os.ReadFile(pathOrURL)
+ if err != nil {
+ return fmt.Errorf("failed to read file: %w", err)
+ }
+ if err := json.Unmarshal(content, &provider); err != nil {
+ return fmt.Errorf("failed to unmarshal provider data: %w", err)
+ }
+ }
+
+ if err := newCache[catwalk.Provider](cachePathFor("hyper")).Store(provider); err != nil {
+ return fmt.Errorf("failed to save Hyper provider to cache: %w", err)
+ }
+
+ slog.Info("Hyper provider updated successfully", "from", pathOrURL, "to", cachePathFor("hyper"))
return nil
}
+var (
+ catwalkSyncer = &catwalkSync{}
+ hyperSyncer = &hyperSync{}
+)
+
// Providers returns the list of providers, taking into account cached results
// and whether or not auto update is enabled.
//
@@ -126,46 +141,87 @@ func UpdateProviders(pathOrURL string) error {
// the cached list, or the embedded list if all others fail.
func Providers(cfg *Config) ([]catwalk.Provider, error) {
providerOnce.Do(func() {
- catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
- client := catwalk.NewWithURL(catwalkURL)
- path := providerCacheFileData()
-
- if cfg.Options.DisableProviderAutoUpdate {
- slog.Info("Using embedded Catwalk providers")
- providerList, providerErr = embedded.GetAll(), nil
- return
- }
+ var wg sync.WaitGroup
+ var errs []error
+ providers := csync.NewSlice[catwalk.Provider]()
+ autoupdate := !cfg.Options.DisableProviderAutoUpdate
+
+ ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
+ defer cancel()
+
+ wg.Go(func() {
+ catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
+ client := catwalk.NewWithURL(catwalkURL)
+ path := cachePathFor("providers")
+ catwalkSyncer.Init(client, path, autoupdate)
+
+ items, err := catwalkSyncer.Get(ctx)
+ if err != nil {
+ catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
+ errs = append(errs, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr)) //nolint:staticcheck
+ return
+ }
+ providers.Append(items...)
+ })
+
+ wg.Go(func() {
+ if !hyper.Enabled() {
+ return
+ }
+ path := cachePathFor("hyper")
+ hyperSyncer.Init(realHyperClient{baseURL: hyper.BaseURL()}, path, autoupdate)
+
+ item, err := hyperSyncer.Get(ctx)
+ if err != nil {
+ errs = append(errs, fmt.Errorf("Crush was unable to fetch updated information from Hyper: %w", err)) //nolint:staticcheck
+ return
+ }
+ providers.Append(item)
+ })
+
+ wg.Wait()
+
+ providerList = slices.Collect(providers.Seq())
+ providerErr = errors.Join(errs...)
+ })
+ return providerList, providerErr
+}
- cached, etag, cachedErr := loadProvidersFromCache(path)
- if len(cached) == 0 || cachedErr != nil {
- // if cached file is empty, default to embedded providers
- cached = embedded.GetAll()
- }
+type cache[T any] struct {
+ path string
+}
- providerList, providerErr = loadProviders(client, etag, path)
- if errors.Is(providerErr, catwalk.ErrNotModified) {
- slog.Info("Catwalk providers not modified")
- providerList, providerErr = cached, nil
- }
- })
- if providerErr != nil {
- catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
- return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr) //nolint:staticcheck
- }
- return providerList, nil
+func newCache[T any](path string) cache[T] {
+ return cache[T]{path: path}
}
-func loadProviders(client ProviderClient, etag, path string) ([]catwalk.Provider, error) {
- slog.Info("Fetching providers from Catwalk.", "path", path)
- providers, err := client.GetProviders(context.Background(), etag)
+func (c cache[T]) Get() (T, string, error) {
+ var v T
+ data, err := os.ReadFile(c.path)
if err != nil {
- return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
+ return v, "", fmt.Errorf("failed to read provider cache file: %w", err)
+ }
+
+ if err := json.Unmarshal(data, &v); err != nil {
+ return v, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
+ }
+
+ return v, etag.Of(data), nil
+}
+
+func (c cache[T]) Store(v T) error {
+ slog.Info("Saving provider data to disk", "path", c.path)
+ if err := os.MkdirAll(filepath.Dir(c.path), 0o755); err != nil {
+ return fmt.Errorf("failed to create directory for provider cache: %w", err)
}
- if len(providers) == 0 {
- return nil, errors.New("empty providers list from catwalk")
+
+ data, err := json.Marshal(v)
+ if err != nil {
+ return fmt.Errorf("failed to marshal provider data: %w", err)
}
- if err := saveProvidersInCache(path, providers); err != nil {
- return nil, err
+
+ if err := os.WriteFile(c.path, data, 0o644); err != nil {
+ return fmt.Errorf("failed to write provider data to cache: %w", err)
}
- return providers, nil
+ return nil
}
@@ -2,6 +2,7 @@ package config
import (
"context"
+ "os"
"testing"
"github.com/charmbracelet/catwalk/pkg/catwalk"
@@ -14,18 +15,25 @@ func (m *emptyProviderClient) GetProviders(context.Context, string) ([]catwalk.P
return []catwalk.Provider{}, nil
}
-// TestProvider_loadProvidersEmptyResult tests that loadProviders returns an
-// error when the client returns an empty list. This ensures we don't cache
-// empty provider lists.
-func TestProvider_loadProvidersEmptyResult(t *testing.T) {
+// TestCatwalkSync_GetEmptyResultFromClient tests that when the client returns
+// an empty list, we fall back to cached providers and return an error.
+func TestCatwalkSync_GetEmptyResultFromClient(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ path := tmpDir + "/providers.json"
+
+ syncer := &catwalkSync{}
client := &emptyProviderClient{}
- tmpPath := t.TempDir() + "/providers.json"
- providers, err := loadProviders(client, "", tmpPath)
+ syncer.Init(client, path, true)
+
+ providers, err := syncer.Get(t.Context())
+ require.Error(t, err)
require.Contains(t, err.Error(), "empty providers list from catwalk")
- require.Empty(t, providers)
- require.Len(t, providers, 0)
+ require.NotEmpty(t, providers) // Should have embedded providers as fallback.
- // Check that no cache file was created for empty results
- require.NoFileExists(t, tmpPath, "Cache file should not exist for empty results")
+ // Check that no cache file was created for empty results.
+ _, statErr := os.Stat(path)
+ require.True(t, os.IsNotExist(statErr), "Cache file should not exist for empty results")
}
@@ -1,10 +1,9 @@
package config
import (
- "context"
"encoding/json"
- "errors"
"os"
+ "path/filepath"
"sync"
"testing"
@@ -12,47 +11,31 @@ import (
"github.com/stretchr/testify/require"
)
-type mockProviderClient struct {
- shouldFail bool
- shouldReturnErr error
-}
-
-func (m *mockProviderClient) GetProviders(context.Context, string) ([]catwalk.Provider, error) {
- if m.shouldReturnErr != nil {
- return nil, m.shouldReturnErr
- }
- if m.shouldFail {
- return nil, errors.New("failed to load providers")
- }
- return []catwalk.Provider{
- {
- Name: "Mock",
- },
- }, nil
-}
-
func resetProviderState() {
providerOnce = sync.Once{}
providerList = nil
providerErr = nil
+ catwalkSyncer = &catwalkSync{}
+ hyperSyncer = &hyperSync{}
}
-func TestProvider_loadProvidersNoIssues(t *testing.T) {
- client := &mockProviderClient{shouldFail: false}
- tmpPath := t.TempDir() + "/providers.json"
- providers, err := loadProviders(client, "", tmpPath)
- require.NoError(t, err)
- require.NotNil(t, providers)
- require.Len(t, providers, 1)
+func TestProviders_Integration_AutoUpdateDisabled(t *testing.T) {
+ tmpDir := t.TempDir()
+ t.Setenv("XDG_DATA_HOME", tmpDir)
- // check if file got saved
- fileInfo, err := os.Stat(tmpPath)
- require.NoError(t, err)
- require.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
-}
+ // Use a test-specific instance to avoid global state interference.
+ testCatwalkSyncer := &catwalkSync{}
+ testHyperSyncer := &hyperSync{}
+
+ originalCatwalSyncer := catwalkSyncer
+ originalHyperSyncer := hyperSyncer
+ defer func() {
+ catwalkSyncer = originalCatwalSyncer
+ hyperSyncer = originalHyperSyncer
+ }()
-func TestProvider_DisableAutoUpdate(t *testing.T) {
- t.Setenv("XDG_DATA_HOME", t.TempDir())
+ catwalkSyncer = testCatwalkSyncer
+ hyperSyncer = testHyperSyncer
resetProviderState()
defer resetProviderState()
@@ -69,76 +52,257 @@ func TestProvider_DisableAutoUpdate(t *testing.T) {
require.Greater(t, len(providers), 5, "Expected embedded providers")
}
-func TestProvider_WithValidCache(t *testing.T) {
+func TestProviders_Integration_WithMockClients(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_DATA_HOME", tmpDir)
- resetProviderState()
- defer resetProviderState()
+ // Create fresh syncers for this test.
+ testCatwalkSyncer := &catwalkSync{}
+ testHyperSyncer := &hyperSync{}
+
+ // Initialize with mock clients.
+ mockCatwalkClient := &mockCatwalkClient{
+ providers: []catwalk.Provider{
+ {Name: "Provider1", ID: "p1"},
+ {Name: "Provider2", ID: "p2"},
+ },
+ }
+ mockHyperClient := &mockHyperClient{
+ provider: catwalk.Provider{
+ Name: "Hyper",
+ ID: "hyper",
+ Models: []catwalk.Model{
+ {ID: "hyper-1", Name: "Hyper Model"},
+ },
+ },
+ }
+
+ catwalkPath := tmpDir + "/crush/providers.json"
+ hyperPath := tmpDir + "/crush/hyper.json"
+
+ testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
+ testHyperSyncer.Init(mockHyperClient, hyperPath, true)
+
+ // Get providers from each syncer.
+ catwalkProviders, err := testCatwalkSyncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Len(t, catwalkProviders, 2)
+
+ hyperProvider, err := testHyperSyncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Equal(t, "Hyper", hyperProvider.Name)
+
+ // Verify total.
+ allProviders := append(catwalkProviders, hyperProvider)
+ require.Len(t, allProviders, 3)
+}
+
+func TestProviders_Integration_WithCachedData(t *testing.T) {
+ tmpDir := t.TempDir()
+ t.Setenv("XDG_DATA_HOME", tmpDir)
+
+ // Create cache files.
+ catwalkPath := tmpDir + "/crush/providers.json"
+ hyperPath := tmpDir + "/crush/hyper.json"
- cachePath := tmpDir + "/crush/providers.json"
require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755))
- cachedProviders := []catwalk.Provider{
- {Name: "Cached"},
+
+ // Write Catwalk cache.
+ catwalkProviders := []catwalk.Provider{
+ {Name: "Cached1", ID: "c1"},
+ {Name: "Cached2", ID: "c2"},
+ }
+ data, err := json.Marshal(catwalkProviders)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(catwalkPath, data, 0o644))
+
+ // Write Hyper cache.
+ hyperProvider := catwalk.Provider{
+ Name: "Cached Hyper",
+ ID: "hyper",
}
- data, err := json.Marshal(cachedProviders)
+ data, err = json.Marshal(hyperProvider)
require.NoError(t, err)
- require.NoError(t, os.WriteFile(cachePath, data, 0o644))
+ require.NoError(t, os.WriteFile(hyperPath, data, 0o644))
+
+ // Create fresh syncers.
+ testCatwalkSyncer := &catwalkSync{}
+ testHyperSyncer := &hyperSync{}
+
+ // Mock clients that return ErrNotModified.
+ mockCatwalkClient := &mockCatwalkClient{
+ err: catwalk.ErrNotModified,
+ }
+ mockHyperClient := &mockHyperClient{
+ err: catwalk.ErrNotModified,
+ }
- mockClient := &mockProviderClient{shouldFail: false}
+ testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
+ testHyperSyncer.Init(mockHyperClient, hyperPath, true)
- providers, err := loadProviders(mockClient, "", cachePath)
+ // Get providers - should use cached.
+ catwalkResult, err := testCatwalkSyncer.Get(t.Context())
require.NoError(t, err)
- require.NotNil(t, providers)
- require.Len(t, providers, 1)
- require.Equal(t, "Mock", providers[0].Name, "Expected fresh provider from fetch")
+ require.Len(t, catwalkResult, 2)
+ require.Equal(t, "Cached1", catwalkResult[0].Name)
+
+ hyperResult, err := testHyperSyncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Equal(t, "Cached Hyper", hyperResult.Name)
}
-func TestProvider_NotModifiedUsesCached(t *testing.T) {
+func TestProviders_Integration_CatwalkFailsHyperSucceeds(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_DATA_HOME", tmpDir)
- resetProviderState()
- defer resetProviderState()
+ testCatwalkSyncer := &catwalkSync{}
+ testHyperSyncer := &hyperSync{}
- cachePath := tmpDir + "/crush/providers.json"
- require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755))
- cachedProviders := []catwalk.Provider{
- {Name: "Cached"},
+ // Catwalk fails, Hyper succeeds.
+ mockCatwalkClient := &mockCatwalkClient{
+ err: catwalk.ErrNotModified, // Will use embedded.
}
- data, err := json.Marshal(cachedProviders)
+ mockHyperClient := &mockHyperClient{
+ provider: catwalk.Provider{
+ Name: "Hyper",
+ ID: "hyper",
+ Models: []catwalk.Model{
+ {ID: "hyper-1", Name: "Hyper Model"},
+ },
+ },
+ }
+
+ catwalkPath := tmpDir + "/crush/providers.json"
+ hyperPath := tmpDir + "/crush/hyper.json"
+
+ testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
+ testHyperSyncer.Init(mockHyperClient, hyperPath, true)
+
+ catwalkResult, err := testCatwalkSyncer.Get(t.Context())
require.NoError(t, err)
- require.NoError(t, os.WriteFile(cachePath, data, 0o644))
+ require.NotEmpty(t, catwalkResult) // Should have embedded.
- mockClient := &mockProviderClient{shouldReturnErr: catwalk.ErrNotModified}
- providers, err := loadProviders(mockClient, "", cachePath)
- require.ErrorIs(t, err, catwalk.ErrNotModified)
- require.Nil(t, providers)
+ hyperResult, err := testHyperSyncer.Get(t.Context())
+ require.NoError(t, err)
+ require.Equal(t, "Hyper", hyperResult.Name)
}
-func TestProvider_EmptyCacheDefaultsToEmbedded(t *testing.T) {
+func TestProviders_Integration_BothFail(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_DATA_HOME", tmpDir)
- resetProviderState()
- defer resetProviderState()
+ testCatwalkSyncer := &catwalkSync{}
+ testHyperSyncer := &hyperSync{}
- cachePath := tmpDir + "/crush/providers.json"
- require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755))
- emptyProviders := []catwalk.Provider{}
- data, err := json.Marshal(emptyProviders)
+ // Both fail.
+ mockCatwalkClient := &mockCatwalkClient{
+ err: catwalk.ErrNotModified,
+ }
+ mockHyperClient := &mockHyperClient{
+ provider: catwalk.Provider{}, // Empty provider.
+ }
+
+ catwalkPath := tmpDir + "/crush/providers.json"
+ hyperPath := tmpDir + "/crush/hyper.json"
+
+ testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
+ testHyperSyncer.Init(mockHyperClient, hyperPath, true)
+
+ catwalkResult, err := testCatwalkSyncer.Get(t.Context())
require.NoError(t, err)
- require.NoError(t, os.WriteFile(cachePath, data, 0o644))
+ require.NotEmpty(t, catwalkResult) // Should fall back to embedded.
- cached, _, err := loadProvidersFromCache(cachePath)
+ hyperResult, err := testHyperSyncer.Get(t.Context())
require.NoError(t, err)
- require.Empty(t, cached, "Expected empty cache")
+ require.Equal(t, "Charm Hyper", hyperResult.Name) // Falls back to embedded when no models.
}
-func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
- client := &mockProviderClient{shouldFail: true}
- tmpPath := t.TempDir() + "/providers.json"
- providers, err := loadProviders(client, "", tmpPath)
+func TestCache_StoreAndGet(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ cachePath := tmpDir + "/test.json"
+
+ cache := newCache[[]catwalk.Provider](cachePath)
+
+ providers := []catwalk.Provider{
+ {Name: "Provider1", ID: "p1"},
+ {Name: "Provider2", ID: "p2"},
+ }
+
+ // Store.
+ err := cache.Store(providers)
+ require.NoError(t, err)
+
+ // Get.
+ result, etag, err := cache.Get()
+ require.NoError(t, err)
+ require.Len(t, result, 2)
+ require.Equal(t, "Provider1", result[0].Name)
+ require.NotEmpty(t, etag)
+}
+
+func TestCache_GetNonExistent(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ cachePath := tmpDir + "/nonexistent.json"
+
+ cache := newCache[[]catwalk.Provider](cachePath)
+
+ _, _, err := cache.Get()
require.Error(t, err)
- require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
+ require.Contains(t, err.Error(), "failed to read provider cache file")
+}
+
+func TestCache_GetInvalidJSON(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+ cachePath := tmpDir + "/invalid.json"
+
+ require.NoError(t, os.WriteFile(cachePath, []byte("invalid json"), 0o644))
+
+ cache := newCache[[]catwalk.Provider](cachePath)
+
+ _, _, err := cache.Get()
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "failed to unmarshal provider data from cache")
+}
+
+func TestCachePathFor(t *testing.T) {
+ tests := []struct {
+ name string
+ xdgDataHome string
+ expected string
+ }{
+ {
+ name: "with XDG_DATA_HOME",
+ xdgDataHome: "/custom/data",
+ expected: "/custom/data/crush/providers.json",
+ },
+ {
+ name: "without XDG_DATA_HOME",
+ xdgDataHome: "",
+ expected: "", // Will use platform-specific default.
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if tt.xdgDataHome != "" {
+ t.Setenv("XDG_DATA_HOME", tt.xdgDataHome)
+ } else {
+ t.Setenv("XDG_DATA_HOME", "")
+ }
+
+ result := cachePathFor("providers")
+ if tt.expected != "" {
+ require.Equal(t, tt.expected, filepath.ToSlash(result))
+ } else {
+ require.Contains(t, result, "crush")
+ require.Contains(t, result, "providers.json")
+ }
+ })
+ }
}
@@ -46,6 +46,22 @@ func Init() {
distinctId = getDistinctId()
}
+func GetID() string { return distinctId }
+
+func Alias(userID string) {
+ if client == nil || distinctId == fallbackId || distinctId == "" || userID == "" {
+ return
+ }
+ if err := client.Enqueue(posthog.Alias{
+ DistinctId: distinctId,
+ Alias: userID,
+ }); err != nil {
+ slog.Error("Failed to enqueue PostHog alias event", "error", err)
+ return
+ }
+ slog.Info("Aliased in PostHog", "machine_id", distinctId, "user_id", userID)
+}
+
// send logs an event to PostHog with the given event name and properties.
func send(event string, props ...any) {
if client == nil {
@@ -0,0 +1,251 @@
+// Package hyper provides functions to handle Hyper device flow authentication.
+package hyper
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/charmbracelet/crush/internal/agent/hyper"
+ "github.com/charmbracelet/crush/internal/event"
+ "github.com/charmbracelet/crush/internal/oauth"
+)
+
+// DeviceAuthResponse contains the response from the device authorization endpoint.
+type DeviceAuthResponse struct {
+ DeviceCode string `json:"device_code"`
+ UserCode string `json:"user_code"`
+ VerificationURL string `json:"verification_url"`
+ ExpiresIn int `json:"expires_in"`
+}
+
+// TokenResponse contains the response from the polling endpoint.
+type TokenResponse struct {
+ RefreshToken string `json:"refresh_token,omitempty"`
+ UserID string `json:"user_id"`
+ OrganizationID string `json:"organization_id"`
+ OrganizationName string `json:"organization_name"`
+ Error string `json:"error,omitempty"`
+ ErrorDescription string `json:"error_description,omitempty"`
+}
+
+// InitiateDeviceAuth calls the /device/auth endpoint to start the device flow.
+func InitiateDeviceAuth(ctx context.Context) (*DeviceAuthResponse, error) {
+ url := hyper.BaseURL() + "/device/auth"
+
+ req, err := http.NewRequestWithContext(
+ ctx, http.MethodPost, url,
+ strings.NewReader(fmt.Sprintf(`{"device_name":%q}`, deviceName())),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", "crush")
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("execute request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+ if err != nil {
+ return nil, fmt.Errorf("read response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("device auth failed: status %d, body %q", resp.StatusCode, string(body))
+ }
+
+ var authResp DeviceAuthResponse
+ if err := json.Unmarshal(body, &authResp); err != nil {
+ return nil, fmt.Errorf("unmarshal response: %w", err)
+ }
+
+ return &authResp, nil
+}
+
+func deviceName() string {
+ if hostname, err := os.Hostname(); err == nil && hostname != "" {
+ return "Crush (" + hostname + ")"
+ }
+ return "Crush"
+}
+
+// PollForToken polls the /device/token endpoint until authorization is complete.
+// It respects the polling interval and handles various error states.
+func PollForToken(ctx context.Context, deviceCode string, expiresIn int) (string, error) {
+ ctx, cancel := context.WithTimeout(ctx, time.Duration(expiresIn)*time.Second)
+ defer cancel()
+
+ d := 5 * time.Second
+ ticker := time.NewTicker(d)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ return "", ctx.Err()
+ case <-ticker.C:
+ result, err := pollOnce(ctx, deviceCode)
+ if err != nil {
+ return "", err
+ }
+ if result.RefreshToken != "" {
+ event.Alias(result.UserID)
+ return result.RefreshToken, nil
+ }
+ switch result.Error {
+ case "authorization_pending":
+ continue
+ default:
+ return "", errors.New(result.ErrorDescription)
+ }
+ }
+ }
+}
+
+func pollOnce(ctx context.Context, deviceCode string) (TokenResponse, error) {
+ var result TokenResponse
+ url := fmt.Sprintf("%s/device/auth/%s", hyper.BaseURL(), deviceCode)
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return result, fmt.Errorf("create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", "crush")
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return result, fmt.Errorf("execute request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+ if err != nil {
+ return result, fmt.Errorf("read response: %w", err)
+ }
+
+ if err := json.Unmarshal(body, &result); err != nil {
+ return result, fmt.Errorf("unmarshal response: %w: %s", err, string(body))
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return result, fmt.Errorf("token request failed: status %d body %q", resp.StatusCode, string(body))
+ }
+
+ return result, nil
+}
+
+// ExchangeToken exchanges a refresh token for an access token.
+func ExchangeToken(ctx context.Context, refreshToken string) (*oauth.Token, error) {
+ reqBody := map[string]string{
+ "refresh_token": refreshToken,
+ }
+
+ data, err := json.Marshal(reqBody)
+ if err != nil {
+ return nil, fmt.Errorf("marshal request: %w", err)
+ }
+
+ url := hyper.BaseURL() + "/token/exchange"
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
+ if err != nil {
+ return nil, fmt.Errorf("create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", "crush")
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("execute request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+ if err != nil {
+ return nil, fmt.Errorf("read response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("token exchange failed: status %d body %q", resp.StatusCode, string(body))
+ }
+
+ var token oauth.Token
+ if err := json.Unmarshal(body, &token); err != nil {
+ return nil, fmt.Errorf("unmarshal response: %w", err)
+ }
+
+ token.SetExpiresAt()
+ return &token, nil
+}
+
+// IntrospectTokenResponse contains the response from the token introspection endpoint.
+type IntrospectTokenResponse struct {
+ Active bool `json:"active"`
+ Sub string `json:"sub,omitempty"`
+ OrgID string `json:"org_id,omitempty"`
+ Exp int64 `json:"exp,omitempty"`
+ Iat int64 `json:"iat,omitempty"`
+ Iss string `json:"iss,omitempty"`
+ Jti string `json:"jti,omitempty"`
+}
+
+// IntrospectToken validates an access token using the introspection endpoint.
+// Implements OAuth2 Token Introspection (RFC 7662).
+func IntrospectToken(ctx context.Context, accessToken string) (*IntrospectTokenResponse, error) {
+ reqBody := map[string]string{
+ "token": accessToken,
+ }
+
+ data, err := json.Marshal(reqBody)
+ if err != nil {
+ return nil, fmt.Errorf("marshal request: %w", err)
+ }
+
+ url := hyper.BaseURL() + "/token/introspect"
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
+ if err != nil {
+ return nil, fmt.Errorf("create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", "crush")
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("execute request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+ if err != nil {
+ return nil, fmt.Errorf("read response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("token introspection failed: status %d body %q", resp.StatusCode, string(body))
+ }
+
+ var result IntrospectTokenResponse
+ if err := json.Unmarshal(body, &result); err != nil {
+ return nil, fmt.Errorf("unmarshal response: %w", err)
+ }
+
+ return &result, nil
+}
@@ -4,7 +4,7 @@ import (
"time"
)
-// Token represents an OAuth2 token from Claude Code Max.
+// Token represents an OAuth2 token.
type Token struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
@@ -13,6 +13,7 @@ import (
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/agent"
+ "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/agent/tools/mcp"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
@@ -363,7 +364,7 @@ func (c *commandDialogCmp) defaultCommands() []Command {
selectedModel := cfg.Models[agentCfg.Model]
// Anthropic models: thinking toggle
- if providerCfg.Type == catwalk.TypeAnthropic {
+ if providerCfg.Type == catwalk.TypeAnthropic || providerCfg.Type == catwalk.Type(hyper.Name) {
status := "Enable"
if selectedModel.Think {
status = "Disable"
@@ -0,0 +1,264 @@
+// Package hyper provides the dialog for Hyper device flow authentication.
+package hyper
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "charm.land/bubbles/v2/spinner"
+ tea "charm.land/bubbletea/v2"
+ "charm.land/lipgloss/v2"
+ hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
+ "github.com/charmbracelet/crush/internal/oauth"
+ "github.com/charmbracelet/crush/internal/oauth/hyper"
+ "github.com/charmbracelet/crush/internal/tui/styles"
+ "github.com/charmbracelet/crush/internal/tui/util"
+ "github.com/pkg/browser"
+)
+
+// DeviceFlowState represents the current state of the device flow.
+type DeviceFlowState int
+
+const (
+ DeviceFlowStateDisplay DeviceFlowState = iota
+ DeviceFlowStateSuccess
+ DeviceFlowStateError
+)
+
+// DeviceAuthInitiatedMsg is sent when the device auth is initiated
+// successfully.
+type DeviceAuthInitiatedMsg struct {
+ deviceCode string
+ expiresIn int
+}
+
+// DeviceFlowCompletedMsg is sent when the device flow completes successfully.
+type DeviceFlowCompletedMsg struct {
+ Token *oauth.Token
+}
+
+// DeviceFlowErrorMsg is sent when the device flow encounters an error.
+type DeviceFlowErrorMsg struct {
+ Error error
+}
+
+// DeviceFlow handles the Hyper device flow authentication.
+type DeviceFlow struct {
+ State DeviceFlowState
+ width int
+ baseURL string
+ deviceCode string
+ userCode string
+ verificationURL string
+ expiresIn int
+ token *oauth.Token
+ cancelFunc context.CancelFunc
+ spinner spinner.Model
+}
+
+// NewDeviceFlow creates a new device flow component.
+func NewDeviceFlow() *DeviceFlow {
+ s := spinner.New()
+ s.Spinner = spinner.Dot
+ s.Style = lipgloss.NewStyle().Foreground(styles.CurrentTheme().GreenLight)
+ return &DeviceFlow{
+ State: DeviceFlowStateDisplay,
+ baseURL: hyperp.BaseURL(),
+ spinner: s,
+ }
+}
+
+// Init initializes the device flow by calling the device auth API and starting polling.
+func (d *DeviceFlow) Init() tea.Cmd {
+ return tea.Batch(d.spinner.Tick, d.initiateDeviceAuth)
+}
+
+// Update handles messages and state transitions.
+func (d *DeviceFlow) Update(msg tea.Msg) (util.Model, tea.Cmd) {
+ var cmd tea.Cmd
+ d.spinner, cmd = d.spinner.Update(msg)
+
+ switch msg := msg.(type) {
+ case DeviceAuthInitiatedMsg:
+ // Start polling now that we have the device code.
+ d.expiresIn = msg.expiresIn
+ return d, tea.Batch(cmd, d.startPolling(msg.deviceCode))
+ case DeviceFlowCompletedMsg:
+ d.State = DeviceFlowStateSuccess
+ d.token = msg.Token
+ return d, nil
+ case DeviceFlowErrorMsg:
+ d.State = DeviceFlowStateError
+ return d, util.ReportError(msg.Error)
+ }
+
+ return d, cmd
+}
+
+// View renders the device flow dialog.
+func (d *DeviceFlow) View() string {
+ t := styles.CurrentTheme()
+
+ whiteStyle := lipgloss.NewStyle().Foreground(t.White)
+ primaryStyle := lipgloss.NewStyle().Foreground(t.Primary)
+ greenStyle := lipgloss.NewStyle().Foreground(t.GreenLight)
+ errorStyle := lipgloss.NewStyle().Foreground(t.Error)
+ mutedStyle := lipgloss.NewStyle().Foreground(t.FgMuted)
+
+ switch d.State {
+ case DeviceFlowStateDisplay:
+ if d.userCode == "" {
+ return lipgloss.NewStyle().
+ Margin(0, 1).
+ Render(
+ greenStyle.Render(d.spinner.View()) +
+ mutedStyle.Render("Initializing..."),
+ )
+ }
+
+ instructions := lipgloss.NewStyle().
+ Margin(1, 1, 0, 1).
+ Width(d.width - 2).
+ Render(
+
+ whiteStyle.Render("Press ") +
+ primaryStyle.Render("enter") +
+ whiteStyle.Render(" to copy the code below and open the browser."),
+ )
+
+ codeBox := lipgloss.NewStyle().
+ Width(d.width-2).
+ Height(7).
+ Align(lipgloss.Center, lipgloss.Center).
+ Background(t.BgBaseLighter).
+ Margin(1).
+ Render(
+ lipgloss.NewStyle().
+ Bold(true).
+ Foreground(t.White).
+ Render(d.userCode),
+ )
+
+ link := lipgloss.NewStyle().Hyperlink(d.verificationURL, "id=hyper-verify").Render(d.verificationURL)
+ url := mutedStyle.
+ Margin(0, 1).
+ Width(d.width - 2).
+ Render("Browser not opening? Refer to\n" + link)
+
+ waiting := greenStyle.
+ Width(d.width-2).
+ Margin(1, 1, 0, 1).
+ Render(d.spinner.View() + "Verifying...")
+
+ return lipgloss.JoinVertical(
+ lipgloss.Left,
+ instructions,
+ codeBox,
+ url,
+ waiting,
+ )
+
+ case DeviceFlowStateSuccess:
+ return greenStyle.Margin(0, 1).Render("Authentication successful!")
+
+ case DeviceFlowStateError:
+ return lipgloss.NewStyle().
+ Margin(0, 1).
+ Width(d.width).
+ Render(errorStyle.Render("Authentication failed."))
+
+ default:
+ return ""
+ }
+}
+
+// SetWidth sets the width of the dialog.
+func (d *DeviceFlow) SetWidth(w int) {
+ d.width = w
+}
+
+// Cursor hides the cursor.
+func (d *DeviceFlow) Cursor() *tea.Cursor { return nil }
+
+// CopyCodeAndOpenURL copies the user code to the clipboard and opens the URL.
+func (d *DeviceFlow) CopyCodeAndOpenURL() tea.Cmd {
+ return tea.Sequence(
+ tea.SetClipboard(d.userCode),
+ func() tea.Msg {
+ if err := browser.OpenURL(d.verificationURL); err != nil {
+ return DeviceFlowErrorMsg{Error: fmt.Errorf("failed to open browser: %w", err)}
+ }
+ return nil
+ },
+ util.ReportInfo("Code copied and URL opened"),
+ )
+}
+
+// CopyCode copies just the user code to the clipboard.
+func (d *DeviceFlow) CopyCode() tea.Cmd {
+ return tea.Sequence(
+ tea.SetClipboard(d.userCode),
+ util.ReportInfo("Code copied to clipboard"),
+ )
+}
+
+// Cancel cancels the device flow polling.
+func (d *DeviceFlow) Cancel() {
+ if d.cancelFunc != nil {
+ d.cancelFunc()
+ }
+}
+
+func (d *DeviceFlow) initiateDeviceAuth() tea.Msg {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+ authResp, err := hyper.InitiateDeviceAuth(ctx)
+ if err != nil {
+ return DeviceFlowErrorMsg{Error: fmt.Errorf("failed to initiate device auth: %w", err)}
+ }
+
+ d.deviceCode = authResp.DeviceCode
+ d.userCode = authResp.UserCode
+ d.verificationURL = authResp.VerificationURL
+
+ return DeviceAuthInitiatedMsg{
+ deviceCode: authResp.DeviceCode,
+ expiresIn: authResp.ExpiresIn,
+ }
+}
+
+// startPolling starts polling for the device token.
+func (d *DeviceFlow) startPolling(deviceCode string) tea.Cmd {
+ return func() tea.Msg {
+ ctx, cancel := context.WithCancel(context.Background())
+ d.cancelFunc = cancel
+
+ // Poll for refresh token.
+ refreshToken, err := hyper.PollForToken(ctx, deviceCode, d.expiresIn)
+ if err != nil {
+ if ctx.Err() != nil {
+ // Cancelled, don't report error.
+ return nil
+ }
+ return DeviceFlowErrorMsg{Error: err}
+ }
+
+ // Exchange refresh token for access token.
+ token, err := hyper.ExchangeToken(ctx, refreshToken)
+ if err != nil {
+ return DeviceFlowErrorMsg{Error: fmt.Errorf("token exchange failed: %w", err)}
+ }
+
+ // Verify the access token works.
+ introspect, err := hyper.IntrospectToken(ctx, token.AccessToken)
+ if err != nil {
+ return DeviceFlowErrorMsg{Error: fmt.Errorf("token introspection failed: %w", err)}
+ }
+ if !introspect.Active {
+ return DeviceFlowErrorMsg{Error: fmt.Errorf("access token is not active")}
+ }
+
+ return DeviceFlowCompletedMsg{Token: token}
+ }
+}
@@ -15,6 +15,8 @@ type KeyMap struct {
isAPIKeyHelp bool
isAPIKeyValid bool
+ isHyperDeviceFlow bool
+
isClaudeAuthChoiceHelp bool
isClaudeOAuthHelp bool
isClaudeOAuthURLState bool
@@ -74,6 +76,19 @@ func (k KeyMap) FullHelp() [][]key.Binding {
// ShortHelp implements help.KeyMap.
func (k KeyMap) ShortHelp() []key.Binding {
+ if k.isHyperDeviceFlow {
+ return []key.Binding{
+ key.NewBinding(
+ key.WithKeys("c"),
+ key.WithHelp("c", "copy code"),
+ ),
+ key.NewBinding(
+ key.WithKeys("enter"),
+ key.WithHelp("enter", "copy & open"),
+ ),
+ k.Close,
+ }
+ }
if k.isClaudeAuthChoiceHelp {
return []key.Binding{
key.NewBinding(
@@ -62,7 +62,8 @@ func (m *ModelListComponent) Init() tea.Cmd {
filteredProviders := []catwalk.Provider{}
for _, p := range providers {
hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
- if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
+ isHyper := p.ID == "hyper"
+ if (hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure) || isHyper {
filteredProviders = append(filteredProviders, p)
}
}
@@ -204,8 +205,23 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
}
}
+ // Move "Charm Hyper" to first position
+ // (but still after recent models and custom providers).
+ sortedProviders := make([]catwalk.Provider, len(m.providers))
+ copy(sortedProviders, m.providers)
+ slices.SortStableFunc(sortedProviders, func(a, b catwalk.Provider) int {
+ switch {
+ case a.ID == "hyper":
+ return -1
+ case b.ID == "hyper":
+ return 1
+ default:
+ return 0
+ }
+ })
+
// Then add the known providers from the predefined list
- for _, provider := range m.providers {
+ for _, provider := range sortedProviders {
// Skip if we already added this provider as an unknown provider
if addedProviders[string(provider.ID)] {
continue
@@ -1,3 +1,4 @@
+// Package models provides the model selection dialog for the TUI.
package models
import (
@@ -11,10 +12,12 @@ import (
"charm.land/lipgloss/v2"
"github.com/atotto/clipboard"
"github.com/charmbracelet/catwalk/pkg/catwalk"
+ hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/components/dialogs"
"github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
+ "github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper"
"github.com/charmbracelet/crush/internal/tui/exp/list"
"github.com/charmbracelet/crush/internal/tui/styles"
"github.com/charmbracelet/crush/internal/tui/util"
@@ -70,6 +73,10 @@ type modelDialogCmp struct {
isAPIKeyValid bool
apiKeyValue string
+ // Hyper device flow state
+ hyperDeviceFlow *hyper.DeviceFlow
+ showHyperDeviceFlow bool
+
// Claude state
claudeAuthMethodChooser *claude.AuthMethodChooser
claudeOAuth2 *claude.OAuth2
@@ -127,6 +134,15 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
u, cmd := m.apiKeyInput.Update(msg)
m.apiKeyInput = u.(*APIKeyInput)
return m, cmd
+ case hyper.DeviceFlowCompletedMsg:
+ return m, m.saveOauthTokenAndContinue(msg.Token, true)
+ case hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg:
+ if m.hyperDeviceFlow != nil {
+ u, cmd := m.hyperDeviceFlow.Update(msg)
+ m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
+ return m, cmd
+ }
+ return m, nil
case claude.ValidationCompletedMsg:
var cmds []tea.Cmd
u, cmd := m.claudeOAuth2.Update(msg)
@@ -134,7 +150,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
cmds = append(cmds, cmd)
if msg.State == claude.OAuthValidationStateValid {
- cmds = append(cmds, m.saveAPIKeyAndContinue(msg.Token, false))
+ cmds = append(cmds, m.saveOauthTokenAndContinue(msg.Token, false))
m.keyMap.isClaudeOAuthHelpComplete = true
}
@@ -143,6 +159,11 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
return m, util.CmdHandler(dialogs.CloseDialogMsg{})
case tea.KeyPressMsg:
switch {
+ // Handle Hyper device flow keys
+ case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showHyperDeviceFlow:
+ if m.hyperDeviceFlow != nil {
+ return m, m.hyperDeviceFlow.CopyCode()
+ }
case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
return m, tea.Sequence(
tea.SetClipboard(m.claudeOAuth2.URL),
@@ -156,6 +177,10 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
m.claudeAuthMethodChooser.ToggleChoice()
return m, nil
case key.Matches(msg, m.keyMap.Select):
+ // If showing device flow, enter copies code and opens URL
+ if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
+ return m, m.hyperDeviceFlow.CopyCodeAndOpenURL()
+ }
selectedItem := m.modelList.SelectedModel()
modelType := config.SelectedModelTypeLarge
@@ -167,6 +192,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
m.keyMap.isClaudeAuthChoiceHelp = false
m.keyMap.isClaudeOAuthHelp = false
m.keyMap.isAPIKeyHelp = true
+ m.showHyperDeviceFlow = false
m.showClaudeAuthMethodChooser = false
m.needsAPIKey = true
m.selectedModel = selectedItem
@@ -194,7 +220,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
return m, cmd2
}
if m.isAPIKeyValid {
- return m, m.saveAPIKeyAndContinue(m.apiKeyValue, true)
+ return m, m.saveOauthTokenAndContinue(m.apiKeyValue, true)
}
if m.needsAPIKey {
// Handle API key submission
@@ -249,15 +275,23 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
ModelType: modelType,
}),
)
- } else {
- if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic {
- m.showClaudeAuthMethodChooser = true
- m.keyMap.isClaudeAuthChoiceHelp = true
- return m, nil
- }
- askForApiKey()
+ }
+ switch selectedItem.Provider.ID {
+ case catwalk.InferenceProviderAnthropic:
+ m.showClaudeAuthMethodChooser = true
+ m.keyMap.isClaudeAuthChoiceHelp = true
return m, nil
+ case hyperp.Name:
+ m.showHyperDeviceFlow = true
+ m.selectedModel = selectedItem
+ m.selectedModelType = modelType
+ m.hyperDeviceFlow = hyper.NewDeviceFlow()
+ m.hyperDeviceFlow.SetWidth(m.width - 2)
+ return m, m.hyperDeviceFlow.Init()
}
+ // For other providers, show API key input
+ askForApiKey()
+ return m, nil
case key.Matches(msg, m.keyMap.Tab):
switch {
case m.showClaudeAuthMethodChooser:
@@ -275,6 +309,14 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
return m, m.modelList.SetModelType(LargeModelType)
}
case key.Matches(msg, m.keyMap.Close):
+ if m.showHyperDeviceFlow {
+ // Cancel device flow and go back to model selection
+ if m.hyperDeviceFlow != nil {
+ m.hyperDeviceFlow.Cancel()
+ }
+ m.showHyperDeviceFlow = false
+ m.selectedModel = nil
+ }
if m.showClaudeAuthMethodChooser {
m.claudeAuthMethodChooser.SetDefaults()
m.showClaudeAuthMethodChooser = false
@@ -329,7 +371,20 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
return m, cmd
}
case spinner.TickMsg:
- if m.showClaudeOAuth2 {
+ u, cmd := m.apiKeyInput.Update(msg)
+ m.apiKeyInput = u.(*APIKeyInput)
+ if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
+ u, cmd = m.hyperDeviceFlow.Update(msg)
+ m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
+ }
+ return m, cmd
+ default:
+ // Pass all other messages to the device flow for spinner animation
+ if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
+ u, cmd := m.hyperDeviceFlow.Update(msg)
+ m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
+ return m, cmd
+ } else if m.showClaudeOAuth2 {
u, cmd := m.claudeOAuth2.Update(msg)
m.claudeOAuth2 = u.(*claude.OAuth2)
return m, cmd
@@ -345,6 +400,23 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
func (m *modelDialogCmp) View() string {
t := styles.CurrentTheme()
+ if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
+ // Show Hyper device flow
+ m.keyMap.isHyperDeviceFlow = true
+ deviceFlowView := m.hyperDeviceFlow.View()
+ content := lipgloss.JoinVertical(
+ lipgloss.Left,
+ t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with Hyper", m.width-4)),
+ deviceFlowView,
+ "",
+ t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
+ )
+ return m.style().Render(content)
+ }
+
+ // Reset the flags when not showing device flow
+ m.keyMap.isHyperDeviceFlow = false
+
switch {
case m.showClaudeAuthMethodChooser:
chooserView := m.claudeAuthMethodChooser.View()
@@ -397,6 +469,9 @@ func (m *modelDialogCmp) View() string {
}
func (m *modelDialogCmp) Cursor() *tea.Cursor {
+ if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
+ return m.hyperDeviceFlow.Cursor()
+ }
if m.showClaudeAuthMethodChooser {
return nil
}
@@ -477,10 +552,8 @@ func (m *modelDialogCmp) modelTypeRadio() string {
func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
cfg := config.Get()
- if _, ok := cfg.Providers.Get(providerID); ok {
- return true
- }
- return false
+ _, ok := cfg.Providers.Get(providerID)
+ return ok
}
func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
@@ -497,7 +570,7 @@ func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*cat
return nil, nil
}
-func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd {
+func (m *modelDialogCmp) saveOauthTokenAndContinue(apiKey any, close bool) tea.Cmd {
if m.selectedModel == nil {
return util.ReportError(fmt.Errorf("no model selected"))
}
@@ -323,6 +323,7 @@ type ItemSection interface {
layout.Sizeable
Indexable
SetInfo(info string)
+ Title() string
}
type itemSectionModel struct {
width int
@@ -337,6 +338,11 @@ func (m *itemSectionModel) ID() string {
return m.id
}
+// Title implements ItemSection.
+func (m *itemSectionModel) Title() string {
+ return m.title
+}
+
func NewItemSection(title string) ItemSection {
return &itemSectionModel{
title: title,