diff --git a/README.md b/README.md index 8be0e3b0f8e23e0fa59b5095fe3892801aa17b35..32d7ebd50a8dc7762a52539831948be9e19e46ed 100644 --- a/README.md +++ b/README.md @@ -664,17 +664,6 @@ Or by setting the following in your config: Crush also respects the [`DO_NOT_TRACK`](https://consoledonottrack.com) convention which can be enabled via `export DO_NOT_TRACK=1`. -## A Note on Claude Max and GitHub Copilot - -Crush only supports model providers through official, compliant APIs. We do not -support or endorse any methods that rely on personal Claude Max and GitHub -Copilot accounts or OAuth workarounds, which violate Anthropic and -Microsoft’s Terms of Service. - -We’re committed to building sustainable, trusted integrations with model -providers. If you’re a provider interested in working with us, -[reach out](mailto:vt100@charm.sh). - ## Contributing See the [contributing guide](https://github.com/charmbracelet/crush?tab=contributing-ov-file#contributing). diff --git a/go.mod b/go.mod index 5aa63752b54c09e785b9666af580b7b22e72651e..9173a82b02bed31ae88a2d1d5dc2bc4490f79fe0 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( charm.land/bubbles/v2 v2.0.0-rc.1 charm.land/bubbletea/v2 v2.0.0-rc.2.0.20251124184313-5de0f1f67562 charm.land/fantasy v0.3.2 - charm.land/lipgloss/v2 v2.0.0-beta.3.0.20251106193318-19329a3e8410 + charm.land/lipgloss/v2 v2.0.0-beta.3.0.20251119143523-0334bb4562ca charm.land/x/vcr v0.1.1 github.com/JohannesKaufmann/html-to-markdown v1.6.0 github.com/MakeNowJust/heredoc v1.0.0 @@ -41,6 +41,7 @@ require ( github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/nxadm/tail v1.4.11 github.com/openai/openai-go/v2 v2.7.1 + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/posthog/posthog-go v1.6.12 github.com/pressly/goose/v3 v3.26.0 github.com/qjebbs/go-jsons v1.0.0-alpha.4 diff --git a/go.sum b/go.sum index 63016cd0b1f010133572013e8b27098acadf49a2..ee9ff39a055a5739c637107bfa7fc15ee3f2f71a 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,8 @@ charm.land/bubbletea/v2 v2.0.0-rc.2.0.20251124184313-5de0f1f67562 h1:61aovinon0n charm.land/bubbletea/v2 v2.0.0-rc.2.0.20251124184313-5de0f1f67562/go.mod h1:IXFmnCnMLTWw/KQ9rEatSYqbAPAYi8kA3Yqwa1SFnLk= charm.land/fantasy v0.3.2 h1:yHTsSZ25LcICMRw3xzdz3OkaZtDQch+B5ljJo17HxgU= charm.land/fantasy v0.3.2/go.mod h1:sV8Ns/JTJHOaYOHPgVRDugMheAyxsW/nmdpVGrycYEk= -charm.land/lipgloss/v2 v2.0.0-beta.3.0.20251106193318-19329a3e8410 h1:D9PbaszZYpB4nj+d6HTWr1onlmlyuGVNfL9gAi8iB3k= -charm.land/lipgloss/v2 v2.0.0-beta.3.0.20251106193318-19329a3e8410/go.mod h1:1qZyvvVCenJO2M1ac2mX0yyiIZJoZmDM4DG4s0udJkU= +charm.land/lipgloss/v2 v2.0.0-beta.3.0.20251119143523-0334bb4562ca h1:6bVc8OFotCS4sS7HKqxTudP7yn8Y0ODR6df2pdlY/+s= +charm.land/lipgloss/v2 v2.0.0-beta.3.0.20251119143523-0334bb4562ca/go.mod h1:XSJjv7DaH4zd1Y27kZis295RkEj9OFR9zh2WffQQsKQ= charm.land/x/vcr v0.1.1 h1:PXCFMUG0rPtyk35rhfzYCJEduOzWXCIbrXTFq4OF/9Q= charm.land/x/vcr v0.1.1/go.mod h1:eByq2gqzWvcct/8XE2XO5KznoWEBiXH56+y2gphbltM= cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= @@ -408,6 +408,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index ec5bc19ba4efaf0cc15f46620711621a92dff2b9..9b7266c7c358865bcfa58520b477ae7a2dcfb22b 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -238,8 +238,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy } } - if a.systemPromptPrefix != "" { - prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...) + if promptPrefix := a.promptPrefix(); promptPrefix != "" { + prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...) } var assistantMsg message.Message @@ -789,6 +789,10 @@ func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) + modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens) + if a.isClaudeCode() { + cost = 0 + } + a.eventTokensUsed(session.ID, model, usage, cost) if overrideCost != nil { @@ -882,3 +886,16 @@ func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) { func (a *sessionAgent) Model() Model { return a.largeModel } + +func (a *sessionAgent) promptPrefix() string { + if a.isClaudeCode() { + return "You are Claude Code, Anthropic's official CLI for Claude." + } + return a.systemPromptPrefix +} + +func (a *sessionAgent) isClaudeCode() bool { + cfg := config.Get() + pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider) + return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil +} diff --git a/internal/config/config.go b/internal/config/config.go index 4b441b0b18563814ca58b9d48cf2e8ffbd7e782f..66e90e062cad91fb088d8ad97970a4f790960d92 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,7 @@ package config import ( + "cmp" "context" "fmt" "log/slog" @@ -14,6 +15,7 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" + "github.com/charmbracelet/crush/internal/oauth" "github.com/invopop/jsonschema" "github.com/tidwall/sjson" ) @@ -92,6 +94,8 @@ type ProviderConfig struct { Type catwalk.Type `json:"type,omitempty" jsonschema:"description=Provider type that determines the API format,enum=openai,enum=openai-compat,enum=anthropic,enum=gemini,enum=azure,enum=vertexai,default=openai"` // The provider's API key. APIKey string `json:"api_key,omitempty" jsonschema:"description=API key for authentication with the provider,example=$OPENAI_API_KEY"` + // OAuthToken for providers that use OAuth2 authentication. + OAuthToken *oauth.Token `json:"oauth,omitempty" jsonschema:"description=OAuth2 token for authentication with the provider"` // Marks the provider as disabled. Disable bool `json:"disable,omitempty" jsonschema:"description=Whether this provider is disabled,default=false"` @@ -112,6 +116,24 @@ type ProviderConfig struct { Models []catwalk.Model `json:"models,omitempty" jsonschema:"description=List of models available from this provider"` } +func (pc *ProviderConfig) SetupClaudeCode() { + if !strings.HasPrefix(pc.APIKey, "Bearer ") { + pc.APIKey = fmt.Sprintf("Bearer %s", pc.APIKey) + } + pc.SystemPromptPrefix = "You are Claude Code, Anthropic's official CLI for Claude." + pc.ExtraHeaders["anthropic-version"] = "2023-06-01" + + value := pc.ExtraHeaders["anthropic-beta"] + const want = "oauth-2025-04-20" + if !strings.Contains(value, want) { + if value != "" { + value += "," + } + value += want + } + pc.ExtraHeaders["anthropic-beta"] = value +} + type MCPType string const ( @@ -448,16 +470,34 @@ func (c *Config) SetConfigField(key string, value any) error { return nil } -func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { - // First save to the config file - err := c.SetConfigField("providers."+providerID+".api_key", apiKey) - if err != nil { - return fmt.Errorf("failed to save API key to config file: %w", err) +func (c *Config) SetProviderAPIKey(providerID string, apiKey any) error { + var providerConfig ProviderConfig + var exists bool + var setKeyOrToken func() + + switch v := apiKey.(type) { + case string: + if err := c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil { + return fmt.Errorf("failed to save api key to config file: %w", err) + } + setKeyOrToken = func() { providerConfig.APIKey = v } + case *oauth.Token: + if err := cmp.Or( + c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken), + c.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), v), + ); err != nil { + return err + } + setKeyOrToken = func() { + providerConfig.APIKey = v.AccessToken + providerConfig.OAuthToken = v + providerConfig.SetupClaudeCode() + } } - providerConfig, exists := c.Providers.Get(providerID) + providerConfig, exists = c.Providers.Get(providerID) if exists { - providerConfig.APIKey = apiKey + setKeyOrToken() c.Providers.Set(providerID, providerConfig) return nil } @@ -477,12 +517,12 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { Name: foundProvider.Name, BaseURL: foundProvider.APIEndpoint, Type: foundProvider.Type, - APIKey: apiKey, Disable: false, ExtraHeaders: make(map[string]string), ExtraParams: make(map[string]string), Models: foundProvider.Models, } + setKeyOrToken() } else { return fmt.Errorf("provider with ID %s not found in known providers", providerID) } diff --git a/internal/config/load.go b/internal/config/load.go index f4aec7bfbe3a08b3af3f78e8f1ebc6ce94e3328b..7645861198eefbceb1e283ee7815d3f130b0b868 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -1,6 +1,7 @@ package config import ( + "cmp" "context" "encoding/json" "fmt" @@ -18,9 +19,11 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" + "github.com/charmbracelet/crush/internal/event" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/home" "github.com/charmbracelet/crush/internal/log" + "github.com/charmbracelet/crush/internal/oauth/claude" powernapConfig "github.com/charmbracelet/x/powernap/pkg/config" ) @@ -133,6 +136,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know knownProviderNames := make(map[string]bool) restore := PushPopCrushEnv() defer restore() + for _, p := range knownProviders { knownProviderNames[string(p.ID)] = true config, configExists := c.Providers.Get(string(p.ID)) @@ -185,6 +189,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know Name: p.Name, BaseURL: p.APIEndpoint, APIKey: p.APIKey, + OAuthToken: config.OAuthToken, Type: p.Type, Disable: config.Disable, SystemPromptPrefix: config.SystemPromptPrefix, @@ -194,6 +199,29 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know Models: p.Models, } + if p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil { + if config.OAuthToken.IsExpired() { + newToken, err := claude.RefreshToken(context.TODO(), config.OAuthToken.RefreshToken) + if err == nil { + slog.Info("Successfully refreshed Anthropic OAuth token") + config.OAuthToken = newToken + prepared.OAuthToken = newToken + if err := cmp.Or( + c.SetConfigField("providers.anthropic.api_key", newToken.AccessToken), + c.SetConfigField("providers.anthropic.oauth", newToken), + ); err != nil { + return err + } + } else { + slog.Error("Failed to refresh Anthropic OAuth token", "error", err) + event.Error(err) + } + } else { + slog.Info("Using existing non-expired Anthropic OAuth token") + } + prepared.SetupClaudeCode() + } + switch p.ID { // Handle specific providers that require additional configuration case catwalk.InferenceProviderVertexAI: diff --git a/internal/oauth/claude/challenge.go b/internal/oauth/claude/challenge.go new file mode 100644 index 0000000000000000000000000000000000000000..ec9ed3c5d17e91fc5dc8c33f44f3d6a4ce4aa244 --- /dev/null +++ b/internal/oauth/claude/challenge.go @@ -0,0 +1,28 @@ +package claude + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "strings" +) + +// GetChallenge generates a PKCE verifier and its corresponding challenge. +func GetChallenge() (verifier string, challenge string, err error) { + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return "", "", err + } + verifier = encodeBase64(bytes) + hash := sha256.Sum256([]byte(verifier)) + challenge = encodeBase64(hash[:]) + return verifier, challenge, nil +} + +func encodeBase64(input []byte) (encoded string) { + encoded = base64.StdEncoding.EncodeToString(input) + encoded = strings.ReplaceAll(encoded, "=", "") + encoded = strings.ReplaceAll(encoded, "+", "-") + encoded = strings.ReplaceAll(encoded, "/", "_") + return encoded +} diff --git a/internal/oauth/claude/oauth.go b/internal/oauth/claude/oauth.go new file mode 100644 index 0000000000000000000000000000000000000000..b3c47960453385395ec2b6988229d0d6e5e3eae4 --- /dev/null +++ b/internal/oauth/claude/oauth.go @@ -0,0 +1,126 @@ +package claude + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/charmbracelet/crush/internal/oauth" +) + +const clientId = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" + +// AuthorizeURL returns the Claude Code Max OAuth2 authorization URL. +func AuthorizeURL(verifier, challenge string) (string, error) { + u, err := url.Parse("https://claude.ai/oauth/authorize") + if err != nil { + return "", err + } + q := u.Query() + q.Set("response_type", "code") + q.Set("client_id", clientId) + q.Set("redirect_uri", "https://console.anthropic.com/oauth/code/callback") + q.Set("scope", "org:create_api_key user:profile user:inference") + q.Set("code_challenge", challenge) + q.Set("code_challenge_method", "S256") + q.Set("state", verifier) + u.RawQuery = q.Encode() + return u.String(), nil +} + +// ExchangeToken exchanges the authorization code for an OAuth2 token. +func ExchangeToken(ctx context.Context, code, verifier string) (*oauth.Token, error) { + code = strings.TrimSpace(code) + parts := strings.SplitN(code, "#", 2) + pure := parts[0] + state := "" + if len(parts) > 1 { + state = parts[1] + } + + reqBody := map[string]string{ + "code": pure, + "state": state, + "grant_type": "authorization_code", + "client_id": clientId, + "redirect_uri": "https://console.anthropic.com/oauth/code/callback", + "code_verifier": verifier, + } + + resp, err := request(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", reqBody) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("claude code max: failed to exchange token: status %d body %q", resp.StatusCode, string(body)) + } + + var token oauth.Token + if err := json.Unmarshal(body, &token); err != nil { + return nil, err + } + token.SetExpiresAt() + return &token, nil +} + +// RefreshToken refreshes the OAuth2 token using the provided refresh token. +func RefreshToken(ctx context.Context, refreshToken string) (*oauth.Token, error) { + reqBody := map[string]string{ + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "client_id": clientId, + } + + resp, err := request(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", reqBody) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("claude code max: failed to refresh token: status %d body %q", resp.StatusCode, string(body)) + } + + var token oauth.Token + if err := json.Unmarshal(body, &token); err != nil { + return nil, err + } + token.SetExpiresAt() + return &token, nil +} + +func request(ctx context.Context, method, url string, body any) (*http.Response, error) { + date, err := json.Marshal(body) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewReader(date)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "anthropic") + + client := &http.Client{Timeout: 30 * time.Second} + return client.Do(req) +} diff --git a/internal/oauth/token.go b/internal/oauth/token.go new file mode 100644 index 0000000000000000000000000000000000000000..29d4791b5fd416e65995698dcc4665ea96fbb090 --- /dev/null +++ b/internal/oauth/token.go @@ -0,0 +1,23 @@ +package oauth + +import ( + "time" +) + +// Token represents an OAuth2 token from Claude Code Max. +type Token struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + ExpiresAt int64 `json:"expires_at"` +} + +// SetExpiresAt calculates and sets the ExpiresAt field based on the current time and ExpiresIn. +func (t *Token) SetExpiresAt() { + t.ExpiresAt = time.Now().Add(time.Duration(t.ExpiresIn) * time.Second).Unix() +} + +// IsExpired checks if the token is expired or about to expire (within 10% of its lifetime). +func (t *Token) IsExpired() bool { + return time.Now().Unix() >= (t.ExpiresAt - int64(t.ExpiresIn)/10) +} diff --git a/internal/tui/components/chat/splash/keys.go b/internal/tui/components/chat/splash/keys.go index 5bd9e333395ae6e379c95e997f5eb12860b280a8..fc8fc373498feea584e75701010762ac66db7879 100644 --- a/internal/tui/components/chat/splash/keys.go +++ b/internal/tui/components/chat/splash/keys.go @@ -12,7 +12,8 @@ type KeyMap struct { No, Tab, LeftRight, - Back key.Binding + Back, + Copy key.Binding } func DefaultKeyMap() KeyMap { @@ -49,5 +50,9 @@ func DefaultKeyMap() KeyMap { key.WithKeys("esc", "alt+esc"), key.WithHelp("esc", "back"), ), + Copy: key.NewBinding( + key.WithKeys("c"), + key.WithHelp("c", "copy url"), + ), } } diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index ffa6dbc8584f55be3e4dc5983c6da10bbc89251f..6a7db9440453b8d1f7751bd5ca7eb66ecd339828 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -9,6 +9,7 @@ import ( "charm.land/bubbles/v2/spinner" tea "charm.land/bubbletea/v2" "charm.land/lipgloss/v2" + "github.com/atotto/clipboard" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/agent" "github.com/charmbracelet/crush/internal/config" @@ -16,6 +17,7 @@ import ( "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/core/layout" + "github.com/charmbracelet/crush/internal/tui/components/dialogs/claude" "github.com/charmbracelet/crush/internal/tui/components/dialogs/models" "github.com/charmbracelet/crush/internal/tui/components/logo" lspcomponent "github.com/charmbracelet/crush/internal/tui/components/lsp" @@ -41,6 +43,18 @@ type Splash interface { // IsAPIKeyValid returns whether the API key is valid IsAPIKeyValid() bool + + // IsShowingClaudeAuthMethodChooser returns whether showing Claude auth method chooser + IsShowingClaudeAuthMethodChooser() bool + + // IsShowingClaudeOAuth2 returns whether showing Claude OAuth2 flow + IsShowingClaudeOAuth2() bool + + // IsClaudeOAuthURLState returns whether in OAuth URL state + IsClaudeOAuthURLState() bool + + // IsClaudeOAuthComplete returns whether Claude OAuth flow is complete + IsClaudeOAuthComplete() bool } const ( @@ -72,6 +86,12 @@ type splashCmp struct { selectedModel *models.ModelOption isAPIKeyValid bool apiKeyValue string + + // Claude state + claudeAuthMethodChooser *claude.AuthMethodChooser + claudeOAuth2 *claude.OAuth2 + showClaudeAuthMethodChooser bool + showClaudeOAuth2 bool } func New() Splash { @@ -97,6 +117,9 @@ func New() Splash { modelList: modelList, apiKeyInput: apiKeyInput, selectedNo: false, + + claudeAuthMethodChooser: claude.NewAuthMethodChooser(), + claudeOAuth2: claude.NewOAuth2(), } } @@ -115,7 +138,12 @@ func (s *splashCmp) GetSize() (int, int) { // Init implements SplashPage. func (s *splashCmp) Init() tea.Cmd { - return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init()) + return tea.Batch( + s.modelList.Init(), + s.apiKeyInput.Init(), + s.claudeAuthMethodChooser.Init(), + s.claudeOAuth2.Init(), + ) } // SetSize implements SplashPage. @@ -131,6 +159,7 @@ func (s *splashCmp) SetSize(width int, height int) tea.Cmd { s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2 listWidth := min(60, width) s.apiKeyInput.SetWidth(width - 2) + s.claudeAuthMethodChooser.SetWidth(min(width-2, 60)) return s.modelList.SetSize(listWidth, s.listHeight) } @@ -139,6 +168,28 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: return s, s.SetSize(msg.Width, msg.Height) + case claude.ValidationCompletedMsg: + var cmds []tea.Cmd + u, cmd := s.claudeOAuth2.Update(msg) + s.claudeOAuth2 = u.(*claude.OAuth2) + cmds = append(cmds, cmd) + + if msg.State == claude.OAuthValidationStateValid { + cmds = append( + cmds, + s.saveAPIKeyAndContinue(msg.Token, false), + func() tea.Msg { + time.Sleep(5 * time.Second) + return claude.AuthenticationCompleteMsg{} + }, + ) + } + + return s, tea.Batch(cmds...) + case claude.AuthenticationCompleteMsg: + s.showClaudeAuthMethodChooser = false + s.showClaudeOAuth2 = false + return s, util.CmdHandler(OnboardingCompleteMsg{}) case models.APIKeyStateChangeMsg: u, cmd := s.apiKeyInput.Update(msg) s.apiKeyInput = u.(*models.APIKeyInput) @@ -150,16 +201,48 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return s, cmd case SubmitAPIKeyMsg: if s.isAPIKeyValid { - return s, s.saveAPIKeyAndContinue(s.apiKeyValue) + return s, s.saveAPIKeyAndContinue(s.apiKeyValue, true) } case tea.KeyPressMsg: switch { + case key.Matches(msg, s.keyMap.Copy): + if s.showClaudeOAuth2 && s.claudeOAuth2.State == claude.OAuthStateURL { + return s, tea.Sequence( + tea.SetClipboard(s.claudeOAuth2.URL), + func() tea.Msg { + _ = clipboard.WriteAll(s.claudeOAuth2.URL) + return nil + }, + util.ReportInfo("URL copied to clipboard"), + ) + } else if s.showClaudeAuthMethodChooser { + u, cmd := s.claudeAuthMethodChooser.Update(msg) + s.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser) + return s, cmd + } else if s.showClaudeOAuth2 { + u, cmd := s.claudeOAuth2.Update(msg) + s.claudeOAuth2 = u.(*claude.OAuth2) + return s, cmd + } case key.Matches(msg, s.keyMap.Back): + if s.showClaudeAuthMethodChooser { + s.claudeAuthMethodChooser.SetDefaults() + s.showClaudeAuthMethodChooser = false + return s, nil + } + if s.showClaudeOAuth2 { + s.claudeOAuth2.SetDefaults() + s.showClaudeOAuth2 = false + s.showClaudeAuthMethodChooser = true + return s, nil + } if s.isAPIKeyValid { return s, nil } if s.needsAPIKey { - // Go back to model selection + if s.selectedModel.Provider.ID == catwalk.InferenceProviderAnthropic { + s.showClaudeAuthMethodChooser = true + } s.needsAPIKey = false s.selectedModel = nil s.isAPIKeyValid = false @@ -168,8 +251,32 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return s, nil } case key.Matches(msg, s.keyMap.Select): + if s.showClaudeAuthMethodChooser { + selectedItem := s.modelList.SelectedModel() + if selectedItem == nil { + return s, nil + } + + switch s.claudeAuthMethodChooser.State { + case claude.AuthMethodAPIKey: + s.showClaudeAuthMethodChooser = false + s.needsAPIKey = true + s.selectedModel = selectedItem + s.apiKeyInput.SetProviderName(selectedItem.Provider.Name) + case claude.AuthMethodOAuth2: + s.selectedModel = selectedItem + s.showClaudeAuthMethodChooser = false + s.showClaudeOAuth2 = true + } + return s, nil + } + if s.showClaudeOAuth2 { + m2, cmd2 := s.claudeOAuth2.ValidationConfirm() + s.claudeOAuth2 = m2.(*claude.OAuth2) + return s, cmd2 + } if s.isAPIKeyValid { - return s, s.saveAPIKeyAndContinue(s.apiKeyValue) + return s, s.saveAPIKeyAndContinue(s.apiKeyValue, true) } if s.isOnboarding && !s.needsAPIKey { selectedItem := s.modelList.SelectedModel() @@ -181,6 +288,10 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { s.isOnboarding = false return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{})) } else { + if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic { + s.showClaudeAuthMethodChooser = true + return s, nil + } // Provider not configured, show API key input s.needsAPIKey = true s.selectedModel = selectedItem @@ -232,6 +343,10 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return s, s.initializeProject() } case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight): + if s.showClaudeAuthMethodChooser { + s.claudeAuthMethodChooser.ToggleChoice() + return s, nil + } if s.needsAPIKey { u, cmd := s.apiKeyInput.Update(msg) s.apiKeyInput = u.(*models.APIKeyInput) @@ -272,7 +387,15 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return s, s.initializeProject() } default: - if s.needsAPIKey { + if s.showClaudeAuthMethodChooser { + u, cmd := s.claudeAuthMethodChooser.Update(msg) + s.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser) + return s, cmd + } else if s.showClaudeOAuth2 { + u, cmd := s.claudeOAuth2.Update(msg) + s.claudeOAuth2 = u.(*claude.OAuth2) + return s, cmd + } else if s.needsAPIKey { u, cmd := s.apiKeyInput.Update(msg) s.apiKeyInput = u.(*models.APIKeyInput) return s, cmd @@ -283,7 +406,11 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { } } case tea.PasteMsg: - if s.needsAPIKey { + if s.showClaudeOAuth2 { + u, cmd := s.claudeOAuth2.Update(msg) + s.claudeOAuth2 = u.(*claude.OAuth2) + return s, cmd + } else if s.needsAPIKey { u, cmd := s.apiKeyInput.Update(msg) s.apiKeyInput = u.(*models.APIKeyInput) return s, cmd @@ -293,14 +420,20 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return s, cmd } case spinner.TickMsg: - u, cmd := s.apiKeyInput.Update(msg) - s.apiKeyInput = u.(*models.APIKeyInput) - return s, cmd + if s.showClaudeOAuth2 { + u, cmd := s.claudeOAuth2.Update(msg) + s.claudeOAuth2 = u.(*claude.OAuth2) + return s, cmd + } else { + u, cmd := s.apiKeyInput.Update(msg) + s.apiKeyInput = u.(*models.APIKeyInput) + return s, cmd + } } return s, nil } -func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd { +func (s *splashCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd { if s.selectedModel == nil { return nil } @@ -318,7 +451,10 @@ func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd { s.selectedModel = nil s.isAPIKeyValid = false - return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{})) + if close { + return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{})) + } + return cmd } func (s *splashCmp) initializeProject() tea.Cmd { @@ -426,7 +562,39 @@ func (s *splashCmp) isProviderConfigured(providerID string) bool { func (s *splashCmp) View() string { t := styles.CurrentTheme() var content string - if s.needsAPIKey { + if s.showClaudeAuthMethodChooser { + remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) + chooserView := s.claudeAuthMethodChooser.View() + authMethodSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render( + lipgloss.JoinVertical( + lipgloss.Left, + t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Let's Auth Anthropic"), + "", + chooserView, + ), + ) + content = lipgloss.JoinVertical( + lipgloss.Left, + s.logoRendered, + authMethodSelector, + ) + } else if s.showClaudeOAuth2 { + remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) + oauth2View := s.claudeOAuth2.View() + oauthSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render( + lipgloss.JoinVertical( + lipgloss.Left, + t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Let's Auth Anthropic"), + "", + oauth2View, + ), + ) + content = lipgloss.JoinVertical( + lipgloss.Left, + s.logoRendered, + oauthSelector, + ) + } else if s.needsAPIKey { remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View()) apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render( @@ -524,6 +692,16 @@ func (s *splashCmp) View() string { } func (s *splashCmp) Cursor() *tea.Cursor { + if s.showClaudeAuthMethodChooser { + return nil + } + if s.showClaudeOAuth2 { + if cursor := s.claudeOAuth2.CodeInput.Cursor(); cursor != nil { + cursor.Y += 2 // FIXME(@andreynering): Why do we need this? + return s.moveCursor(cursor) + } + return nil + } if s.needsAPIKey { cursor := s.apiKeyInput.Cursor() if cursor != nil { @@ -596,17 +774,23 @@ func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor { } // Calculate the correct Y offset based on current state logoHeight := lipgloss.Height(s.logoRendered) - if s.needsAPIKey { + if s.needsAPIKey || s.showClaudeOAuth2 { + var view string + if s.needsAPIKey { + view = s.apiKeyInput.View() + } else { + view = s.claudeOAuth2.View() + } infoSectionHeight := lipgloss.Height(s.infoSection()) baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight - remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY + remainingHeight := s.height - baseOffset - lipgloss.Height(view) - SplashScreenPaddingY offset := baseOffset + remainingHeight cursor.Y += offset - cursor.X = cursor.X + 1 + cursor.X += 1 } else if s.isOnboarding { offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 2 cursor.Y += offset - cursor.X = cursor.X + 1 + cursor.X += 1 } return cursor @@ -621,7 +805,21 @@ func (s *splashCmp) logoGap() int { // Bindings implements SplashPage. func (s *splashCmp) Bindings() []key.Binding { - if s.needsAPIKey { + if s.showClaudeAuthMethodChooser { + return []key.Binding{ + s.keyMap.Select, + s.keyMap.Tab, + s.keyMap.Back, + } + } else if s.showClaudeOAuth2 { + bindings := []key.Binding{ + s.keyMap.Select, + } + if s.claudeOAuth2.State == claude.OAuthStateURL { + bindings = append(bindings, s.keyMap.Copy) + } + return bindings + } else if s.needsAPIKey { return []key.Binding{ s.keyMap.Select, s.keyMap.Back, @@ -726,3 +924,19 @@ func (s *splashCmp) IsShowingAPIKey() bool { func (s *splashCmp) IsAPIKeyValid() bool { return s.isAPIKeyValid } + +func (s *splashCmp) IsShowingClaudeAuthMethodChooser() bool { + return s.showClaudeAuthMethodChooser +} + +func (s *splashCmp) IsShowingClaudeOAuth2() bool { + return s.showClaudeOAuth2 +} + +func (s *splashCmp) IsClaudeOAuthURLState() bool { + return s.showClaudeOAuth2 && s.claudeOAuth2.State == claude.OAuthStateURL +} + +func (s *splashCmp) IsClaudeOAuthComplete() bool { + return s.showClaudeOAuth2 && s.claudeOAuth2.State == claude.OAuthStateCode && s.claudeOAuth2.ValidationState == claude.OAuthValidationStateValid +} diff --git a/internal/tui/components/dialogs/claude/method.go b/internal/tui/components/dialogs/claude/method.go new file mode 100644 index 0000000000000000000000000000000000000000..071d437799dcd2e3d5b9e60c33c7173c18577016 --- /dev/null +++ b/internal/tui/components/dialogs/claude/method.go @@ -0,0 +1,115 @@ +package claude + +import ( + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + "github.com/charmbracelet/crush/internal/tui/styles" + "github.com/charmbracelet/crush/internal/tui/util" +) + +type AuthMethod int + +const ( + AuthMethodAPIKey AuthMethod = iota + AuthMethodOAuth2 +) + +type AuthMethodChooser struct { + State AuthMethod + width int + isOnboarding bool +} + +func NewAuthMethodChooser() *AuthMethodChooser { + return &AuthMethodChooser{ + State: AuthMethodOAuth2, + } +} + +func (a *AuthMethodChooser) Init() tea.Cmd { + return nil +} + +func (a *AuthMethodChooser) Update(msg tea.Msg) (util.Model, tea.Cmd) { + return a, nil +} + +func (a *AuthMethodChooser) View() string { + t := styles.CurrentTheme() + + white := lipgloss.NewStyle().Foreground(t.White) + primary := lipgloss.NewStyle().Foreground(t.Primary) + success := lipgloss.NewStyle().Foreground(t.Success) + + titleStyle := white + if a.isOnboarding { + titleStyle = primary + } + + question := lipgloss. + NewStyle(). + Margin(0, 1). + Render(titleStyle.Render("How would you like to authenticate with ") + success.Render("Anthropic") + titleStyle.Render("?")) + + squareWidth := (a.width - 2) / 2 + squareHeight := squareWidth / 3 + if isOdd(squareHeight) { + squareHeight++ + } + + square := lipgloss.NewStyle(). + Width(squareWidth). + Height(squareHeight). + Margin(0, 0). + Border(lipgloss.RoundedBorder()) + + squareText := lipgloss.NewStyle(). + Width(squareWidth - 2). + Height(squareHeight). + Align(lipgloss.Center). + AlignVertical(lipgloss.Center) + + oauthBorder := t.AuthBorderSelected + oauthText := t.AuthTextSelected + apiKeyBorder := t.AuthBorderUnselected + apiKeyText := t.AuthTextUnselected + + if a.State == AuthMethodAPIKey { + oauthBorder, apiKeyBorder = apiKeyBorder, oauthBorder + oauthText, apiKeyText = apiKeyText, oauthText + } + + return lipgloss.JoinVertical( + lipgloss.Left, + question, + "", + lipgloss.JoinHorizontal( + lipgloss.Center, + square.MarginLeft(1). + Inherit(oauthBorder).Render(squareText.Inherit(oauthText).Render("Claude Account\nwith Subscription")), + square.MarginRight(1). + Inherit(apiKeyBorder).Render(squareText.Inherit(apiKeyText).Render("API Key")), + ), + ) +} + +func (a *AuthMethodChooser) SetDefaults() { + a.State = AuthMethodOAuth2 +} + +func (a *AuthMethodChooser) SetWidth(w int) { + a.width = w +} + +func (a *AuthMethodChooser) ToggleChoice() { + switch a.State { + case AuthMethodAPIKey: + a.State = AuthMethodOAuth2 + case AuthMethodOAuth2: + a.State = AuthMethodAPIKey + } +} + +func isOdd(n int) bool { + return n%2 != 0 +} diff --git a/internal/tui/components/dialogs/claude/oauth.go b/internal/tui/components/dialogs/claude/oauth.go new file mode 100644 index 0000000000000000000000000000000000000000..f8da5b4fffbc75708676a1545f9a6719b7e2f198 --- /dev/null +++ b/internal/tui/components/dialogs/claude/oauth.go @@ -0,0 +1,267 @@ +package claude + +import ( + "context" + "fmt" + "net/url" + + "charm.land/bubbles/v2/spinner" + "charm.land/bubbles/v2/textinput" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/oauth/claude" + "github.com/charmbracelet/crush/internal/tui/styles" + "github.com/charmbracelet/crush/internal/tui/util" + "github.com/pkg/browser" + "github.com/zeebo/xxh3" +) + +type OAuthState int + +const ( + OAuthStateURL OAuthState = iota + OAuthStateCode +) + +type OAuthValidationState int + +const ( + OAuthValidationStateNone OAuthValidationState = iota + OAuthValidationStateVerifying + OAuthValidationStateValid + OAuthValidationStateError +) + +type ValidationCompletedMsg struct { + State OAuthValidationState + Token *oauth.Token +} + +type AuthenticationCompleteMsg struct{} + +type OAuth2 struct { + State OAuthState + ValidationState OAuthValidationState + width int + isOnboarding bool + + // URL page + err error + verifier string + challenge string + URL string + urlId string + token *oauth.Token + + // Code input page + CodeInput textinput.Model + spinner spinner.Model +} + +func NewOAuth2() *OAuth2 { + return &OAuth2{ + State: OAuthStateURL, + } +} + +func (o *OAuth2) Init() tea.Cmd { + t := styles.CurrentTheme() + + verifier, challenge, err := claude.GetChallenge() + if err != nil { + o.err = err + return nil + } + + url, err := claude.AuthorizeURL(verifier, challenge) + if err != nil { + o.err = err + return nil + } + + o.verifier = verifier + o.challenge = challenge + o.URL = url + + h := xxh3.New() + _, _ = h.WriteString(o.URL) + o.urlId = fmt.Sprintf("id=%x", h.Sum(nil)) + + o.CodeInput = textinput.New() + o.CodeInput.Placeholder = "Paste or type" + o.CodeInput.SetVirtualCursor(false) + o.CodeInput.Prompt = "> " + o.CodeInput.SetStyles(t.S().TextInput) + o.CodeInput.SetWidth(50) + + o.spinner = spinner.New( + spinner.WithSpinner(spinner.Dot), + spinner.WithStyle(t.S().Base.Foreground(t.Green)), + ) + + return nil +} + +func (o *OAuth2) Update(msg tea.Msg) (util.Model, tea.Cmd) { + var cmds []tea.Cmd + + switch msg := msg.(type) { + case ValidationCompletedMsg: + o.ValidationState = msg.State + o.token = msg.Token + switch o.ValidationState { + case OAuthValidationStateError: + o.CodeInput.Focus() + } + o.updatePrompt() + } + + if o.ValidationState == OAuthValidationStateVerifying { + var cmd tea.Cmd + o.spinner, cmd = o.spinner.Update(msg) + cmds = append(cmds, cmd) + o.updatePrompt() + } + { + var cmd tea.Cmd + o.CodeInput, cmd = o.CodeInput.Update(msg) + cmds = append(cmds, cmd) + } + + return o, tea.Batch(cmds...) +} + +func (o *OAuth2) ValidationConfirm() (util.Model, tea.Cmd) { + var cmds []tea.Cmd + + switch { + case o.State == OAuthStateURL: + _ = browser.OpenURL(o.URL) + o.State = OAuthStateCode + cmds = append(cmds, o.CodeInput.Focus()) + case o.ValidationState == OAuthValidationStateNone || o.ValidationState == OAuthValidationStateError: + o.CodeInput.Blur() + o.ValidationState = OAuthValidationStateVerifying + cmds = append(cmds, o.spinner.Tick, o.validateCode) + case o.ValidationState == OAuthValidationStateValid: + cmds = append(cmds, func() tea.Msg { return AuthenticationCompleteMsg{} }) + } + + o.updatePrompt() + return o, tea.Batch(cmds...) +} + +func (o *OAuth2) View() string { + t := styles.CurrentTheme() + + whiteStyle := lipgloss.NewStyle().Foreground(t.White) + primaryStyle := lipgloss.NewStyle().Foreground(t.Primary) + successStyle := lipgloss.NewStyle().Foreground(t.Success) + errorStyle := lipgloss.NewStyle().Foreground(t.Error) + + titleStyle := whiteStyle + if o.isOnboarding { + titleStyle = primaryStyle + } + + switch { + case o.err != nil: + return lipgloss.NewStyle(). + Margin(0, 1). + Foreground(t.Error). + Render(o.err.Error()) + case o.State == OAuthStateURL: + heading := lipgloss. + NewStyle(). + Margin(0, 1). + Render(titleStyle.Render("Press enter key to open the following ") + successStyle.Render("URL") + titleStyle.Render(":")) + + return lipgloss.JoinVertical( + lipgloss.Left, + heading, + "", + lipgloss.NewStyle(). + Margin(0, 1). + Foreground(t.FgMuted). + Hyperlink(o.URL, o.urlId). + Render(o.displayUrl()), + ) + case o.State == OAuthStateCode: + var heading string + + switch o.ValidationState { + case OAuthValidationStateNone: + st := lipgloss.NewStyle().Margin(0, 1) + heading = st.Render(titleStyle.Render("Enter the ") + successStyle.Render("code") + titleStyle.Render(" you received.")) + case OAuthValidationStateVerifying: + heading = titleStyle.Margin(0, 1).Render("Verifying...") + case OAuthValidationStateValid: + heading = successStyle.Margin(0, 1).Render("Validated.") + case OAuthValidationStateError: + heading = errorStyle.Margin(0, 1).Render("Invalid. Try again?") + } + + return lipgloss.JoinVertical( + lipgloss.Left, + heading, + "", + " "+o.CodeInput.View(), + ) + default: + panic("claude oauth2: invalid state") + } +} + +func (o *OAuth2) SetDefaults() { + o.State = OAuthStateURL + o.ValidationState = OAuthValidationStateNone + o.CodeInput.SetValue("") + o.err = nil +} + +func (o *OAuth2) SetWidth(w int) { + o.width = w + o.CodeInput.SetWidth(w - 4) +} + +func (o *OAuth2) SetError(err error) { + o.err = err +} + +func (o *OAuth2) validateCode() tea.Msg { + token, err := claude.ExchangeToken(context.Background(), o.CodeInput.Value(), o.verifier) + if err != nil || token == nil { + return ValidationCompletedMsg{State: OAuthValidationStateError} + } + return ValidationCompletedMsg{State: OAuthValidationStateValid, Token: token} +} + +func (o *OAuth2) updatePrompt() { + switch o.ValidationState { + case OAuthValidationStateNone: + o.CodeInput.Prompt = "> " + case OAuthValidationStateVerifying: + o.CodeInput.Prompt = o.spinner.View() + " " + case OAuthValidationStateValid: + o.CodeInput.Prompt = styles.CheckIcon + " " + case OAuthValidationStateError: + o.CodeInput.Prompt = styles.ErrorIcon + " " + } +} + +// Remove query params for display +// e.g., "https://claude.ai/oauth/authorize?..." -> "https://claude.ai/oauth/authorize..." +func (o *OAuth2) displayUrl() string { + parsed, err := url.Parse(o.URL) + if err != nil { + return o.URL + } + + if parsed.RawQuery != "" { + parsed.RawQuery = "" + return parsed.String() + "..." + } + + return o.URL +} diff --git a/internal/tui/components/dialogs/models/keys.go b/internal/tui/components/dialogs/models/keys.go index e36a18d7299a172486423749464f898954dcb1f2..088a999ef2d37c8a67f6d5d0c7490a4eeb29fac5 100644 --- a/internal/tui/components/dialogs/models/keys.go +++ b/internal/tui/components/dialogs/models/keys.go @@ -8,11 +8,17 @@ type KeyMap struct { Select, Next, Previous, + Choose, Tab, Close key.Binding isAPIKeyHelp bool isAPIKeyValid bool + + isClaudeAuthChoiseHelp bool + isClaudeOAuthHelp bool + isClaudeOAuthURLState bool + isClaudeOAuthHelpComplete bool } func DefaultKeyMap() KeyMap { @@ -29,6 +35,10 @@ func DefaultKeyMap() KeyMap { key.WithKeys("up", "ctrl+p"), key.WithHelp("↑", "previous item"), ), + Choose: key.NewBinding( + key.WithKeys("left", "right", "h", "l"), + key.WithHelp("←→", "choose"), + ), Tab: key.NewBinding( key.WithKeys("tab"), key.WithHelp("tab", "toggle type"), @@ -64,8 +74,64 @@ func (k KeyMap) FullHelp() [][]key.Binding { // ShortHelp implements help.KeyMap. func (k KeyMap) ShortHelp() []key.Binding { + if k.isClaudeAuthChoiseHelp { + return []key.Binding{ + key.NewBinding( + key.WithKeys("left", "right", "h", "l"), + key.WithHelp("←→", "choose"), + ), + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "accept"), + ), + key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "back"), + ), + } + } + if k.isClaudeOAuthHelp { + if k.isClaudeOAuthHelpComplete { + return []key.Binding{ + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "close"), + ), + } + } + + enterHelp := "submit" + if k.isClaudeOAuthURLState { + enterHelp = "open" + } + + bindings := []key.Binding{ + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", enterHelp), + ), + } + + if k.isClaudeOAuthURLState { + bindings = append(bindings, key.NewBinding( + key.WithKeys("c"), + key.WithHelp("c", "copy url"), + )) + } + + bindings = append(bindings, key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "back"), + )) + + return bindings + } if k.isAPIKeyHelp && !k.isAPIKeyValid { return []key.Binding{ + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "submit"), + ), k.Close, } } else if k.isAPIKeyValid { diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index 2ab4bfc605898174057ab1fca515ee9914ccdb53..406f4a5853a88bfb22914071e7e53d19662b7ecd 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -9,10 +9,12 @@ import ( "charm.land/bubbles/v2/spinner" tea "charm.land/bubbletea/v2" "charm.land/lipgloss/v2" + "github.com/atotto/clipboard" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/dialogs" + "github.com/charmbracelet/crush/internal/tui/components/dialogs/claude" "github.com/charmbracelet/crush/internal/tui/exp/list" "github.com/charmbracelet/crush/internal/tui/styles" "github.com/charmbracelet/crush/internal/tui/util" @@ -67,6 +69,12 @@ type modelDialogCmp struct { selectedModelType config.SelectedModelType isAPIKeyValid bool apiKeyValue string + + // Claude state + claudeAuthMethodChooser *claude.AuthMethodChooser + claudeOAuth2 *claude.OAuth2 + showClaudeAuthMethodChooser bool + showClaudeOAuth2 bool } func NewModelDialogCmp() ModelDialog { @@ -91,11 +99,19 @@ func NewModelDialogCmp() ModelDialog { width: defaultWidth, keyMap: DefaultKeyMap(), help: help, + + claudeAuthMethodChooser: claude.NewAuthMethodChooser(), + claudeOAuth2: claude.NewOAuth2(), } } func (m *modelDialogCmp) Init() tea.Cmd { - return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init()) + return tea.Batch( + m.modelList.Init(), + m.apiKeyInput.Init(), + m.claudeAuthMethodChooser.Init(), + m.claudeOAuth2.Init(), + ) } func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { @@ -105,16 +121,84 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { m.wHeight = msg.Height m.apiKeyInput.SetWidth(m.width - 2) m.help.SetWidth(m.width - 2) + m.claudeAuthMethodChooser.SetWidth(m.width - 2) return m, m.modelList.SetSize(m.listWidth(), m.listHeight()) case APIKeyStateChangeMsg: u, cmd := m.apiKeyInput.Update(msg) m.apiKeyInput = u.(*APIKeyInput) return m, cmd + case claude.ValidationCompletedMsg: + var cmds []tea.Cmd + u, cmd := m.claudeOAuth2.Update(msg) + m.claudeOAuth2 = u.(*claude.OAuth2) + cmds = append(cmds, cmd) + + if msg.State == claude.OAuthValidationStateValid { + cmds = append(cmds, m.saveAPIKeyAndContinue(msg.Token, false)) + m.keyMap.isClaudeOAuthHelpComplete = true + } + + return m, tea.Batch(cmds...) + case claude.AuthenticationCompleteMsg: + return m, util.CmdHandler(dialogs.CloseDialogMsg{}) case tea.KeyPressMsg: switch { + case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))): + if m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL { + return m, tea.Sequence( + tea.SetClipboard(m.claudeOAuth2.URL), + func() tea.Msg { + _ = clipboard.WriteAll(m.claudeOAuth2.URL) + return nil + }, + util.ReportInfo("URL copied to clipboard"), + ) + } + case key.Matches(msg, m.keyMap.Choose): + if m.showClaudeAuthMethodChooser { + m.claudeAuthMethodChooser.ToggleChoice() + return m, nil + } case key.Matches(msg, m.keyMap.Select): + selectedItem := m.modelList.SelectedModel() + + modelType := config.SelectedModelTypeLarge + if m.modelList.GetModelType() == SmallModelType { + modelType = config.SelectedModelTypeSmall + } + + askForApiKey := func() { + m.keyMap.isClaudeAuthChoiseHelp = false + m.keyMap.isClaudeOAuthHelp = false + m.keyMap.isAPIKeyHelp = true + m.showClaudeAuthMethodChooser = false + m.needsAPIKey = true + m.selectedModel = selectedItem + m.selectedModelType = modelType + m.apiKeyInput.SetProviderName(selectedItem.Provider.Name) + } + + if m.showClaudeAuthMethodChooser { + switch m.claudeAuthMethodChooser.State { + case claude.AuthMethodAPIKey: + askForApiKey() + case claude.AuthMethodOAuth2: + m.selectedModel = selectedItem + m.selectedModelType = modelType + m.showClaudeAuthMethodChooser = false + m.showClaudeOAuth2 = true + m.keyMap.isClaudeAuthChoiseHelp = false + m.keyMap.isClaudeOAuthHelp = true + } + return m, nil + } + if m.showClaudeOAuth2 { + m2, cmd2 := m.claudeOAuth2.ValidationConfirm() + m.claudeOAuth2 = m2.(*claude.OAuth2) + return m, cmd2 + } if m.isAPIKeyValid { - return m, m.saveAPIKeyAndContinue(m.apiKeyValue) + return m, m.saveAPIKeyAndContinue(m.apiKeyValue, true) } if m.needsAPIKey { // Handle API key submission @@ -154,15 +238,6 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { }, ) } - // Normal model selection - selectedItem := m.modelList.SelectedModel() - - var modelType config.SelectedModelType - if m.modelList.GetModelType() == LargeModelType { - modelType = config.SelectedModelTypeLarge - } else { - modelType = config.SelectedModelTypeSmall - } // Check if provider is configured if m.isProviderConfigured(string(selectedItem.Provider.ID)) { @@ -179,27 +254,38 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { }), ) } else { - // Provider not configured, show API key input - m.needsAPIKey = true - m.selectedModel = selectedItem - m.selectedModelType = modelType - m.apiKeyInput.SetProviderName(selectedItem.Provider.Name) + if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic { + m.showClaudeAuthMethodChooser = true + m.keyMap.isClaudeAuthChoiseHelp = true + return m, nil + } + askForApiKey() return m, nil } case key.Matches(msg, m.keyMap.Tab): - if m.needsAPIKey { + switch { + case m.showClaudeAuthMethodChooser: + m.claudeAuthMethodChooser.ToggleChoice() + return m, nil + case m.needsAPIKey: u, cmd := m.apiKeyInput.Update(msg) m.apiKeyInput = u.(*APIKeyInput) return m, cmd - } - if m.modelList.GetModelType() == LargeModelType { + case m.modelList.GetModelType() == LargeModelType: m.modelList.SetInputPlaceholder(smallModelInputPlaceholder) return m, m.modelList.SetModelType(SmallModelType) - } else { + default: m.modelList.SetInputPlaceholder(largeModelInputPlaceholder) return m, m.modelList.SetModelType(LargeModelType) } case key.Matches(msg, m.keyMap.Close): + if m.showClaudeAuthMethodChooser { + m.claudeAuthMethodChooser.SetDefaults() + m.showClaudeAuthMethodChooser = false + m.keyMap.isClaudeAuthChoiseHelp = false + m.keyMap.isClaudeOAuthHelp = false + return m, nil + } if m.needsAPIKey { if m.isAPIKeyValid { return m, nil @@ -214,7 +300,15 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { } return m, util.CmdHandler(dialogs.CloseDialogMsg{}) default: - if m.needsAPIKey { + if m.showClaudeAuthMethodChooser { + u, cmd := m.claudeAuthMethodChooser.Update(msg) + m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser) + return m, cmd + } else if m.showClaudeOAuth2 { + u, cmd := m.claudeOAuth2.Update(msg) + m.claudeOAuth2 = u.(*claude.OAuth2) + return m, cmd + } else if m.needsAPIKey { u, cmd := m.apiKeyInput.Update(msg) m.apiKeyInput = u.(*APIKeyInput) return m, cmd @@ -225,7 +319,11 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { } } case tea.PasteMsg: - if m.needsAPIKey { + if m.showClaudeOAuth2 { + u, cmd := m.claudeOAuth2.Update(msg) + m.claudeOAuth2 = u.(*claude.OAuth2) + return m, cmd + } else if m.needsAPIKey { u, cmd := m.apiKeyInput.Update(msg) m.apiKeyInput = u.(*APIKeyInput) return m, cmd @@ -235,9 +333,15 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return m, cmd } case spinner.TickMsg: - u, cmd := m.apiKeyInput.Update(msg) - m.apiKeyInput = u.(*APIKeyInput) - return m, cmd + if m.showClaudeOAuth2 { + u, cmd := m.claudeOAuth2.Update(msg) + m.claudeOAuth2 = u.(*claude.OAuth2) + return m, cmd + } else { + u, cmd := m.apiKeyInput.Update(msg) + m.apiKeyInput = u.(*APIKeyInput) + return m, cmd + } } return m, nil } @@ -245,7 +349,29 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { func (m *modelDialogCmp) View() string { t := styles.CurrentTheme() - if m.needsAPIKey { + switch { + case m.showClaudeAuthMethodChooser: + chooserView := m.claudeAuthMethodChooser.View() + content := lipgloss.JoinVertical( + lipgloss.Left, + t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)), + chooserView, + "", + t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)), + ) + return m.style().Render(content) + case m.showClaudeOAuth2: + m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL + oauth2View := m.claudeOAuth2.View() + content := lipgloss.JoinVertical( + lipgloss.Left, + t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)), + oauth2View, + "", + t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)), + ) + return m.style().Render(content) + case m.needsAPIKey: // Show API key input m.keyMap.isAPIKeyHelp = true m.keyMap.isAPIKeyValid = m.isAPIKeyValid @@ -275,6 +401,16 @@ func (m *modelDialogCmp) View() string { } func (m *modelDialogCmp) Cursor() *tea.Cursor { + if m.showClaudeAuthMethodChooser { + return nil + } + if m.showClaudeOAuth2 { + if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil { + cursor.Y += 2 // FIXME(@andreynering): Why do we need this? + return m.moveCursor(cursor) + } + return nil + } if m.needsAPIKey { cursor := m.apiKeyInput.Cursor() if cursor != nil { @@ -365,7 +501,7 @@ func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*cat return nil, nil } -func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd { +func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd { if m.selectedModel == nil { return util.ReportError(fmt.Errorf("no model selected")) } @@ -378,8 +514,12 @@ func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd { // Reset API key state and continue with model selection selectedModel := *m.selectedModel - return tea.Sequence( - util.CmdHandler(dialogs.CloseDialogMsg{}), + var cmds []tea.Cmd + if close { + cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{})) + } + cmds = append( + cmds, util.CmdHandler(ModelSelectedMsg{ Model: config.SelectedModel{ Model: selectedModel.Model.ID, @@ -390,4 +530,5 @@ func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd { ModelType: m.selectedModelType, }), ) + return tea.Sequence(cmds...) } diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index f951de8677271dfbf034377afaa492f0d8824889..f09f9782b0c77207fc9b96209f5714263236c11f 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -29,6 +29,7 @@ import ( "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/core/layout" "github.com/charmbracelet/crush/internal/tui/components/dialogs" + "github.com/charmbracelet/crush/internal/tui/components/dialogs/claude" "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands" "github.com/charmbracelet/crush/internal/tui/components/dialogs/filepicker" "github.com/charmbracelet/crush/internal/tui/components/dialogs/models" @@ -293,6 +294,13 @@ func (p *chatPage) Update(msg tea.Msg) (util.Model, tea.Cmd) { cmds = append(cmds, cmd) return p, tea.Batch(cmds...) + case claude.ValidationCompletedMsg, claude.AuthenticationCompleteMsg: + if p.focusedPane == PanelTypeSplash { + u, cmd := p.splash.Update(msg) + p.splash = u.(splash.Splash) + cmds = append(cmds, cmd) + } + return p, tea.Batch(cmds...) case models.APIKeyStateChangeMsg: if p.focusedPane == PanelTypeSplash { u, cmd := p.splash.Update(msg) @@ -816,6 +824,71 @@ func (p *chatPage) Help() help.KeyMap { var shortList []key.Binding var fullList [][]key.Binding switch { + case p.isOnboarding && p.splash.IsShowingClaudeAuthMethodChooser(): + shortList = append(shortList, + // Choose auth method + key.NewBinding( + key.WithKeys("left", "right", "tab"), + key.WithHelp("←→/tab", "choose"), + ), + // Accept selection + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "accept"), + ), + // Go back + key.NewBinding( + key.WithKeys("esc", "alt+esc"), + key.WithHelp("esc", "back"), + ), + // Quit + key.NewBinding( + key.WithKeys("ctrl+c"), + key.WithHelp("ctrl+c", "quit"), + ), + ) + // keep them the same + for _, v := range shortList { + fullList = append(fullList, []key.Binding{v}) + } + case p.isOnboarding && p.splash.IsShowingClaudeOAuth2(): + if p.splash.IsClaudeOAuthURLState() { + shortList = append(shortList, + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "open"), + ), + key.NewBinding( + key.WithKeys("c"), + key.WithHelp("c", "copy url"), + ), + ) + } else if p.splash.IsClaudeOAuthComplete() { + shortList = append(shortList, + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "continue"), + ), + ) + } else { + shortList = append(shortList, + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "submit"), + ), + ) + } + shortList = append(shortList, + // Quit + key.NewBinding( + key.WithKeys("ctrl+c"), + key.WithHelp("ctrl+c", "quit"), + ), + ) + // keep them the same + for _, v := range shortList { + fullList = append(fullList, []key.Binding{v}) + } case p.isOnboarding && !p.splash.IsShowingAPIKey(): shortList = append(shortList, // Choose model diff --git a/internal/tui/styles/charmtone.go b/internal/tui/styles/charmtone.go index b0f41e7f92469ee203e0fb9df651c4a311f245f8..44508e5a24e68ea0507af0f2649ddc372711104d 100644 --- a/internal/tui/styles/charmtone.go +++ b/internal/tui/styles/charmtone.go @@ -67,10 +67,17 @@ func NewCharmtoneTheme() *Theme { t.ItemErrorIcon = t.ItemOfflineIcon.Foreground(charmtone.Coral) t.ItemOnlineIcon = t.ItemOfflineIcon.Foreground(charmtone.Guac) + // Editor: Yolo Mode. t.YoloIconFocused = lipgloss.NewStyle().Foreground(charmtone.Oyster).Background(charmtone.Citron).Bold(true).SetString(" ! ") t.YoloIconBlurred = t.YoloIconFocused.Foreground(charmtone.Pepper).Background(charmtone.Squid) t.YoloDotsFocused = lipgloss.NewStyle().Foreground(charmtone.Zest).SetString(":::") t.YoloDotsBlurred = t.YoloDotsFocused.Foreground(charmtone.Squid) + // oAuth Chooser. + t.AuthBorderSelected = lipgloss.NewStyle().BorderForeground(charmtone.Guac) + t.AuthTextSelected = lipgloss.NewStyle().Foreground(charmtone.Julep) + t.AuthBorderUnselected = lipgloss.NewStyle().BorderForeground(charmtone.Iron) + t.AuthTextUnselected = lipgloss.NewStyle().Foreground(charmtone.Squid) + return t } diff --git a/internal/tui/styles/theme.go b/internal/tui/styles/theme.go index 50de5df1104856f5937de2d75e1f97b3f95c1f79..9605dd2d4ed8ffcdf35a2b8e524747d0cc983bc9 100644 --- a/internal/tui/styles/theme.go +++ b/internal/tui/styles/theme.go @@ -85,12 +85,18 @@ type Theme struct { ItemErrorIcon lipgloss.Style ItemOnlineIcon lipgloss.Style - // Editor: Yolo Mode + // Editor: Yolo Mode. YoloIconFocused lipgloss.Style YoloIconBlurred lipgloss.Style YoloDotsFocused lipgloss.Style YoloDotsBlurred lipgloss.Style + // oAuth Chooser. + AuthBorderSelected lipgloss.Style + AuthTextSelected lipgloss.Style + AuthBorderUnselected lipgloss.Style + AuthTextUnselected lipgloss.Style + styles *Styles }