diff --git a/internal/agent/agent.go b/internal/agent/agent.go index aee8a930ba8c2d4174c9ec249a6f4dbc253b3683..1632e2fa752da1d37bcd198a71d1177e96f93f37 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -846,6 +846,10 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) + modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens) + if a.isClaudeCode() { + cost = 0 + } + // Use override cost if available (e.g., from OpenRouter). if openrouterCost != nil { cost = *openrouterCost @@ -883,6 +887,10 @@ func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) + modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens) + if a.isClaudeCode() { + cost = 0 + } + a.eventTokensUsed(session.ID, model, usage, cost) if overrideCost != nil { @@ -990,9 +998,19 @@ func (a *sessionAgent) Model() Model { } func (a *sessionAgent) promptPrefix() string { + if a.isClaudeCode() { + return "You are Claude Code, Anthropic's official CLI for Claude." + } return a.systemPromptPrefix } +// XXX: this should be generalized to cover other subscription plans, like Copilot. +func (a *sessionAgent) isClaudeCode() bool { + cfg := config.Get() + pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider) + return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil +} + // convertToToolResult converts a fantasy tool result to a message tool result. func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult { baseResult := message.ToolResult{ diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index bfe25df384e1075e4f672be34452eaeb589a55f1..54efc6c8427c233a55ac3818b3079631f09792f1 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -518,13 +518,13 @@ func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Mo }, nil } -func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) { +func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, isOauth bool) (fantasy.Provider, error) { var opts []anthropic.Option - if strings.HasPrefix(apiKey, "Bearer ") { + if isOauth { // NOTE: Prevent the SDK from picking up the API key from env. os.Setenv("ANTHROPIC_API_KEY", "") - headers["Authorization"] = apiKey + headers["Authorization"] = fmt.Sprintf("Bearer %s", apiKey) } else if apiKey != "" { // X-Api-Key header opts = append(opts, anthropic.WithAPIKey(apiKey)) @@ -731,7 +731,7 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con case openai.Name: return c.buildOpenaiProvider(baseURL, apiKey, headers) case anthropic.Name: - return c.buildAnthropicProvider(baseURL, apiKey, headers) + return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.OAuthToken != nil) case openrouter.Name: return c.buildOpenrouterProvider(baseURL, apiKey, headers) case azure.Name: diff --git a/internal/cmd/login.go b/internal/cmd/login.go index 1d368fafb0669147ee4b2fc28fa2494099d4e7c9..77d44e502126cb6b1d3382c41f950dfff78e4562 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -6,11 +6,13 @@ import ( "fmt" "os" "os/signal" + "strings" "charm.land/lipgloss/v2" hyperp "git.secluded.site/crush/internal/agent/hyper" "git.secluded.site/crush/internal/config" "git.secluded.site/crush/internal/oauth" + "git.secluded.site/crush/internal/oauth/claude" "git.secluded.site/crush/internal/oauth/copilot" "git.secluded.site/crush/internal/oauth/hyper" "github.com/atotto/clipboard" @@ -24,16 +26,21 @@ var loginCmd = &cobra.Command{ Short: "Login Crush to a platform", Long: `Login Crush to a specified platform. The platform should be provided as an argument. -Available platforms are: hyper, copilot.`, +Available platforms are: hyper, claude, copilot.`, Example: ` # Authenticate with Charm Hyper crush login +# Authenticate with Claude Code Max +crush login claude + # Authenticate with GitHub Copilot crush login copilot `, ValidArgs: []cobra.Completion{ "hyper", + "claude", + "anthropic", "copilot", "github", "github-copilot", @@ -53,6 +60,8 @@ crush login copilot switch provider { case "hyper": return loginHyper() + case "anthropic", "claude": + return loginClaude() case "copilot", "github", "github-copilot": return loginCopilot() default: @@ -124,6 +133,60 @@ func loginHyper() error { return nil } +func loginClaude() error { + ctx := getLoginContext() + + cfg := config.Get() + if cfg.HasConfigField("providers.anthropic.oauth") { + fmt.Println("You are already logged in to Claude.") + return nil + } + + verifier, challenge, err := claude.GetChallenge() + if err != nil { + return err + } + url, err := claude.AuthorizeURL(verifier, challenge) + if err != nil { + return err + } + fmt.Println("Open the following URL and follow the instructions to authenticate with Claude Code Max:") + fmt.Println() + fmt.Println(lipgloss.NewStyle().Hyperlink(url, "id=claude").Render(url)) + fmt.Println() + fmt.Println("Press enter to continue...") + if _, err := fmt.Scanln(); err != nil { + return err + } + + fmt.Println("Now paste and code from Anthropic and press enter...") + fmt.Println() + fmt.Print("> ") + var code string + for code == "" { + _, _ = fmt.Scanln(&code) + code = strings.TrimSpace(code) + } + + fmt.Println() + fmt.Println("Exchanging authorization code...") + token, err := claude.ExchangeToken(ctx, code, verifier) + if err != nil { + return err + } + + if err := cmp.Or( + cfg.SetConfigField("providers.anthropic.api_key", token.AccessToken), + cfg.SetConfigField("providers.anthropic.oauth", token), + ); err != nil { + return err + } + + fmt.Println() + fmt.Println("You're now authenticated with Claude Code Max!") + return nil +} + func loginCopilot() error { ctx := getLoginContext() diff --git a/internal/config/config.go b/internal/config/config.go index 89c25c1b0ab59b07ed53ed2970bb9783c90d3c5a..d74e24a54575ba8abd48966528f2e8248ca60ddc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -18,6 +18,7 @@ import ( "git.secluded.site/crush/internal/csync" "git.secluded.site/crush/internal/env" "git.secluded.site/crush/internal/oauth" + "git.secluded.site/crush/internal/oauth/claude" "git.secluded.site/crush/internal/oauth/copilot" "git.secluded.site/crush/internal/oauth/hyper" "github.com/charmbracelet/catwalk/pkg/catwalk" @@ -154,6 +155,21 @@ func (pc *ProviderConfig) ToProvider() catwalk.Provider { return provider } +func (pc *ProviderConfig) SetupClaudeCode() { + pc.SystemPromptPrefix = "You are Claude Code, Anthropic's official CLI for Claude." + pc.ExtraHeaders["anthropic-version"] = "2023-06-01" + + value := pc.ExtraHeaders["anthropic-beta"] + const want = "oauth-2025-04-20" + if !strings.Contains(value, want) { + if value != "" { + value += "," + } + value += want + } + pc.ExtraHeaders["anthropic-beta"] = value +} + func (pc *ProviderConfig) SetupGitHubCopilot() { maps.Copy(pc.ExtraHeaders, copilot.Headers()) } @@ -508,25 +524,6 @@ func (c *Config) SetConfigField(key string, value any) error { return nil } -func (c *Config) RemoveConfigField(key string) error { - data, err := os.ReadFile(c.dataConfigDir) - if err != nil { - return fmt.Errorf("failed to read config file: %w", err) - } - - newValue, err := sjson.Delete(string(data), key) - if err != nil { - return fmt.Errorf("failed to delete config field %s: %w", key, err) - } - if err := os.MkdirAll(filepath.Dir(c.dataConfigDir), 0o755); err != nil { - return fmt.Errorf("failed to create config directory %q: %w", c.dataConfigDir, err) - } - if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o600); err != nil { - return fmt.Errorf("failed to write config file: %w", err) - } - return nil -} - // RefreshOAuthToken refreshes the OAuth token for the given provider. func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error { providerConfig, exists := c.Providers.Get(providerID) @@ -541,6 +538,8 @@ func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error var newToken *oauth.Token var refreshErr error switch providerID { + case string(catwalk.InferenceProviderAnthropic): + newToken, refreshErr = claude.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) case string(catwalk.InferenceProviderCopilot): newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) case hyperp.Name: @@ -557,6 +556,8 @@ func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error providerConfig.APIKey = newToken.AccessToken switch providerID { + case string(catwalk.InferenceProviderAnthropic): + providerConfig.SetupClaudeCode() case string(catwalk.InferenceProviderCopilot): providerConfig.SetupGitHubCopilot() } @@ -595,6 +596,8 @@ func (c *Config) SetProviderAPIKey(providerID string, apiKey any) error { providerConfig.APIKey = v.AccessToken providerConfig.OAuthToken = v switch providerID { + case string(catwalk.InferenceProviderAnthropic): + providerConfig.SetupClaudeCode() case string(catwalk.InferenceProviderCopilot): providerConfig.SetupGitHubCopilot() } diff --git a/internal/config/load.go b/internal/config/load.go index 01db9bd99f57cb9b5d740c5c27806aa948251a85..a3f91e9ddb0c9e1c2871be13638717fc17fe7dd1 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -202,12 +202,11 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know switch { case p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil: - // Claude Code subscription is not supported anymore. Remove to show onboarding. - c.RemoveConfigField("providers.anthropic") - c.Providers.Del(string(p.ID)) - continue - case p.ID == catwalk.InferenceProviderCopilot && config.OAuthToken != nil: - prepared.SetupGitHubCopilot() + prepared.SetupClaudeCode() + case p.ID == catwalk.InferenceProviderCopilot: + if config.OAuthToken != nil { + prepared.SetupGitHubCopilot() + } } switch p.ID { diff --git a/internal/oauth/claude/challenge.go b/internal/oauth/claude/challenge.go new file mode 100644 index 0000000000000000000000000000000000000000..ec9ed3c5d17e91fc5dc8c33f44f3d6a4ce4aa244 --- /dev/null +++ b/internal/oauth/claude/challenge.go @@ -0,0 +1,28 @@ +package claude + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "strings" +) + +// GetChallenge generates a PKCE verifier and its corresponding challenge. +func GetChallenge() (verifier string, challenge string, err error) { + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return "", "", err + } + verifier = encodeBase64(bytes) + hash := sha256.Sum256([]byte(verifier)) + challenge = encodeBase64(hash[:]) + return verifier, challenge, nil +} + +func encodeBase64(input []byte) (encoded string) { + encoded = base64.StdEncoding.EncodeToString(input) + encoded = strings.ReplaceAll(encoded, "=", "") + encoded = strings.ReplaceAll(encoded, "+", "-") + encoded = strings.ReplaceAll(encoded, "/", "_") + return encoded +} diff --git a/internal/oauth/claude/oauth.go b/internal/oauth/claude/oauth.go new file mode 100644 index 0000000000000000000000000000000000000000..2a733220dcc3b6ee7b95bdf96b2808953b19f4f3 --- /dev/null +++ b/internal/oauth/claude/oauth.go @@ -0,0 +1,126 @@ +package claude + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "git.secluded.site/crush/internal/oauth" +) + +const clientId = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" + +// AuthorizeURL returns the Claude Code Max OAuth2 authorization URL. +func AuthorizeURL(verifier, challenge string) (string, error) { + u, err := url.Parse("https://claude.ai/oauth/authorize") + if err != nil { + return "", err + } + q := u.Query() + q.Set("response_type", "code") + q.Set("client_id", clientId) + q.Set("redirect_uri", "https://console.anthropic.com/oauth/code/callback") + q.Set("scope", "org:create_api_key user:profile user:inference") + q.Set("code_challenge", challenge) + q.Set("code_challenge_method", "S256") + q.Set("state", verifier) + u.RawQuery = q.Encode() + return u.String(), nil +} + +// ExchangeToken exchanges the authorization code for an OAuth2 token. +func ExchangeToken(ctx context.Context, code, verifier string) (*oauth.Token, error) { + code = strings.TrimSpace(code) + parts := strings.SplitN(code, "#", 2) + pure := parts[0] + state := "" + if len(parts) > 1 { + state = parts[1] + } + + reqBody := map[string]string{ + "code": pure, + "state": state, + "grant_type": "authorization_code", + "client_id": clientId, + "redirect_uri": "https://console.anthropic.com/oauth/code/callback", + "code_verifier": verifier, + } + + resp, err := request(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", reqBody) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("claude code max: failed to exchange token: status %d body %q", resp.StatusCode, string(body)) + } + + var token oauth.Token + if err := json.Unmarshal(body, &token); err != nil { + return nil, err + } + token.SetExpiresAt() + return &token, nil +} + +// RefreshToken refreshes the OAuth2 token using the provided refresh token. +func RefreshToken(ctx context.Context, refreshToken string) (*oauth.Token, error) { + reqBody := map[string]string{ + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "client_id": clientId, + } + + resp, err := request(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", reqBody) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("claude code max: failed to refresh token: status %d body %q", resp.StatusCode, string(body)) + } + + var token oauth.Token + if err := json.Unmarshal(body, &token); err != nil { + return nil, err + } + token.SetExpiresAt() + return &token, nil +} + +func request(ctx context.Context, method, url string, body any) (*http.Response, error) { + date, err := json.Marshal(body) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewReader(date)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "anthropic") + + client := &http.Client{Timeout: 30 * time.Second} + return client.Do(req) +} diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index c1a61722833c31f2fbe15e0623dc8c3ed72f5e8f..48d10ca2953fa958d04aa601285105f6ff946acd 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -16,6 +16,7 @@ import ( "git.secluded.site/crush/internal/tui/components/chat" "git.secluded.site/crush/internal/tui/components/core" "git.secluded.site/crush/internal/tui/components/core/layout" + "git.secluded.site/crush/internal/tui/components/dialogs/claude" "git.secluded.site/crush/internal/tui/components/dialogs/copilot" "git.secluded.site/crush/internal/tui/components/dialogs/hyper" "git.secluded.site/crush/internal/tui/components/dialogs/models" @@ -26,6 +27,7 @@ import ( "git.secluded.site/crush/internal/tui/styles" "git.secluded.site/crush/internal/tui/util" "git.secluded.site/crush/internal/version" + "github.com/atotto/clipboard" "github.com/charmbracelet/catwalk/pkg/catwalk" ) @@ -45,6 +47,18 @@ type Splash interface { // IsAPIKeyValid returns whether the API key is valid IsAPIKeyValid() bool + // IsShowingClaudeAuthMethodChooser returns whether showing Claude auth method chooser + IsShowingClaudeAuthMethodChooser() bool + + // IsShowingClaudeOAuth2 returns whether showing Claude OAuth2 flow + IsShowingClaudeOAuth2() bool + + // IsClaudeOAuthURLState returns whether in OAuth URL state + IsClaudeOAuthURLState() bool + + // IsClaudeOAuthComplete returns whether Claude OAuth flow is complete + IsClaudeOAuthComplete() bool + // IsShowingClaudeOAuth2 returns whether showing Hyper OAuth2 flow IsShowingHyperOAuth2() bool @@ -89,6 +103,12 @@ type splashCmp struct { // Copilot device flow state copilotDeviceFlow *copilot.DeviceFlow showCopilotDeviceFlow bool + + // Claude state + claudeAuthMethodChooser *claude.AuthMethodChooser + claudeOAuth2 *claude.OAuth2 + showClaudeAuthMethodChooser bool + showClaudeOAuth2 bool } func New() Splash { @@ -114,6 +134,9 @@ func New() Splash { modelList: modelList, apiKeyInput: apiKeyInput, selectedNo: false, + + claudeAuthMethodChooser: claude.NewAuthMethodChooser(), + claudeOAuth2: claude.NewOAuth2(), } } @@ -135,6 +158,8 @@ func (s *splashCmp) Init() tea.Cmd { return tea.Batch( s.modelList.Init(), s.apiKeyInput.Init(), + s.claudeAuthMethodChooser.Init(), + s.claudeOAuth2.Init(), ) } @@ -151,6 +176,7 @@ func (s *splashCmp) SetSize(width int, height int) tea.Cmd { s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2 listWidth := min(60, width) s.apiKeyInput.SetWidth(width - 2) + s.claudeAuthMethodChooser.SetWidth(min(width-2, 60)) return s.modelList.SetSize(listWidth, s.listHeight) } @@ -159,6 +185,24 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: return s, s.SetSize(msg.Width, msg.Height) + case claude.ValidationCompletedMsg: + var cmds []tea.Cmd + u, cmd := s.claudeOAuth2.Update(msg) + s.claudeOAuth2 = u.(*claude.OAuth2) + cmds = append(cmds, cmd) + + if msg.State == claude.OAuthValidationStateValid { + cmds = append( + cmds, + s.saveAPIKeyAndContinue(msg.Token, false), + func() tea.Msg { + time.Sleep(5 * time.Second) + return claude.AuthenticationCompleteMsg{} + }, + ) + } + + return s, tea.Batch(cmds...) case hyper.DeviceFlowCompletedMsg: s.showHyperDeviceFlow = false return s, s.saveAPIKeyAndContinue(msg.Token, true) @@ -179,6 +223,10 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { case copilot.DeviceFlowCompletedMsg: s.showCopilotDeviceFlow = false return s, s.saveAPIKeyAndContinue(msg.Token, true) + case claude.AuthenticationCompleteMsg: + s.showClaudeAuthMethodChooser = false + s.showClaudeOAuth2 = false + return s, util.CmdHandler(OnboardingCompleteMsg{}) case models.APIKeyStateChangeMsg: u, cmd := s.apiKeyInput.Update(msg) s.apiKeyInput = u.(*models.APIKeyInput) @@ -198,8 +246,34 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return s, s.hyperDeviceFlow.CopyCode() case key.Matches(msg, s.keyMap.Copy) && s.showCopilotDeviceFlow: return s, s.copilotDeviceFlow.CopyCode() + case key.Matches(msg, s.keyMap.Copy) && s.showClaudeOAuth2 && s.claudeOAuth2.State == claude.OAuthStateURL: + return s, tea.Sequence( + tea.SetClipboard(s.claudeOAuth2.URL), + func() tea.Msg { + _ = clipboard.WriteAll(s.claudeOAuth2.URL) + return nil + }, + util.ReportInfo("URL copied to clipboard"), + ) + case key.Matches(msg, s.keyMap.Copy) && s.showClaudeAuthMethodChooser: + u, cmd := s.claudeAuthMethodChooser.Update(msg) + s.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser) + return s, cmd + case key.Matches(msg, s.keyMap.Copy) && s.showClaudeOAuth2: + u, cmd := s.claudeOAuth2.Update(msg) + s.claudeOAuth2 = u.(*claude.OAuth2) + return s, cmd case key.Matches(msg, s.keyMap.Back): switch { + case s.showClaudeAuthMethodChooser: + s.claudeAuthMethodChooser.SetDefaults() + s.showClaudeAuthMethodChooser = false + return s, nil + case s.showClaudeOAuth2: + s.claudeOAuth2.SetDefaults() + s.showClaudeOAuth2 = false + s.showClaudeAuthMethodChooser = true + return s, nil case s.showHyperDeviceFlow: s.hyperDeviceFlow = nil s.showHyperDeviceFlow = false @@ -211,6 +285,9 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { case s.isAPIKeyValid: return s, nil case s.needsAPIKey: + if s.selectedModel.Provider.ID == catwalk.InferenceProviderAnthropic { + s.showClaudeAuthMethodChooser = true + } s.needsAPIKey = false s.selectedModel = nil s.isAPIKeyValid = false @@ -220,6 +297,28 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { } case key.Matches(msg, s.keyMap.Select): switch { + case s.showClaudeAuthMethodChooser: + selectedItem := s.modelList.SelectedModel() + if selectedItem == nil { + return s, nil + } + + switch s.claudeAuthMethodChooser.State { + case claude.AuthMethodAPIKey: + s.showClaudeAuthMethodChooser = false + s.needsAPIKey = true + s.selectedModel = selectedItem + s.apiKeyInput.SetProviderName(selectedItem.Provider.Name) + case claude.AuthMethodOAuth2: + s.selectedModel = selectedItem + s.showClaudeAuthMethodChooser = false + s.showClaudeOAuth2 = true + } + return s, nil + case s.showClaudeOAuth2: + m2, cmd2 := s.claudeOAuth2.ValidationConfirm() + s.claudeOAuth2 = m2.(*claude.OAuth2) + return s, cmd2 case s.showHyperDeviceFlow: return s, s.hyperDeviceFlow.CopyCodeAndOpenURL() case s.showCopilotDeviceFlow: @@ -237,6 +336,9 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{})) } else { switch selectedItem.Provider.ID { + case catwalk.InferenceProviderAnthropic: + s.showClaudeAuthMethodChooser = true + return s, nil case hyperp.Name: s.selectedModel = selectedItem s.showHyperDeviceFlow = true @@ -305,6 +407,10 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return s, s.initializeProject() } case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight): + if s.showClaudeAuthMethodChooser { + s.claudeAuthMethodChooser.ToggleChoice() + return s, nil + } if s.needsAPIKey { u, cmd := s.apiKeyInput.Update(msg) s.apiKeyInput = u.(*models.APIKeyInput) @@ -346,6 +452,14 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { } default: switch { + case s.showClaudeAuthMethodChooser: + u, cmd := s.claudeAuthMethodChooser.Update(msg) + s.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser) + return s, cmd + case s.showClaudeOAuth2: + u, cmd := s.claudeOAuth2.Update(msg) + s.claudeOAuth2 = u.(*claude.OAuth2) + return s, cmd case s.showHyperDeviceFlow: u, cmd := s.hyperDeviceFlow.Update(msg) s.hyperDeviceFlow = u.(*hyper.DeviceFlow) @@ -366,6 +480,10 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { } case tea.PasteMsg: switch { + case s.showClaudeOAuth2: + u, cmd := s.claudeOAuth2.Update(msg) + s.claudeOAuth2 = u.(*claude.OAuth2) + return s, cmd case s.showHyperDeviceFlow: u, cmd := s.hyperDeviceFlow.Update(msg) s.hyperDeviceFlow = u.(*hyper.DeviceFlow) @@ -385,6 +503,10 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { } case spinner.TickMsg: switch { + case s.showClaudeOAuth2: + u, cmd := s.claudeOAuth2.Update(msg) + s.claudeOAuth2 = u.(*claude.OAuth2) + return s, cmd case s.showHyperDeviceFlow: u, cmd := s.hyperDeviceFlow.Update(msg) s.hyperDeviceFlow = u.(*hyper.DeviceFlow) @@ -533,6 +655,38 @@ func (s *splashCmp) View() string { var content string switch { + case s.showClaudeAuthMethodChooser: + remainingHeight := s.height - lipgloss.Height(s.logoRendered) - SplashScreenPaddingY + chooserView := s.claudeAuthMethodChooser.View() + authMethodSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render( + lipgloss.JoinVertical( + lipgloss.Left, + t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Let's Auth Anthropic"), + "", + chooserView, + ), + ) + content = lipgloss.JoinVertical( + lipgloss.Left, + s.logoRendered, + authMethodSelector, + ) + case s.showClaudeOAuth2: + remainingHeight := s.height - lipgloss.Height(s.logoRendered) - SplashScreenPaddingY + oauth2View := s.claudeOAuth2.View() + oauthSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render( + lipgloss.JoinVertical( + lipgloss.Left, + t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Let's Auth Anthropic"), + "", + oauth2View, + ), + ) + content = lipgloss.JoinVertical( + lipgloss.Left, + s.logoRendered, + oauthSelector, + ) case s.showHyperDeviceFlow: remainingHeight := s.height - lipgloss.Height(s.logoRendered) - SplashScreenPaddingY hyperView := s.hyperDeviceFlow.View() @@ -662,6 +816,14 @@ func (s *splashCmp) View() string { func (s *splashCmp) Cursor() *tea.Cursor { switch { + case s.showClaudeAuthMethodChooser: + return nil + case s.showClaudeOAuth2: + if cursor := s.claudeOAuth2.CodeInput.Cursor(); cursor != nil { + cursor.Y += 2 // FIXME(@andreynering): Why do we need this? + return s.moveCursor(cursor) + } + return nil case s.needsAPIKey: cursor := s.apiKeyInput.Cursor() if cursor != nil { @@ -732,10 +894,16 @@ func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor { } // Calculate the correct Y offset based on current state logoHeight := lipgloss.Height(s.logoRendered) - if s.needsAPIKey { + if s.needsAPIKey || s.showClaudeOAuth2 { + var view string + if s.needsAPIKey { + view = s.apiKeyInput.View() + } else { + view = s.claudeOAuth2.View() + } infoSectionHeight := lipgloss.Height(s.infoSection()) baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight - remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY + remainingHeight := s.height - baseOffset - lipgloss.Height(view) - SplashScreenPaddingY offset := baseOffset + remainingHeight cursor.Y += offset cursor.X += 1 @@ -758,6 +926,20 @@ func (s *splashCmp) logoGap() int { // Bindings implements SplashPage. func (s *splashCmp) Bindings() []key.Binding { switch { + case s.showClaudeAuthMethodChooser: + return []key.Binding{ + s.keyMap.Select, + s.keyMap.Tab, + s.keyMap.Back, + } + case s.showClaudeOAuth2: + bindings := []key.Binding{ + s.keyMap.Select, + } + if s.claudeOAuth2.State == claude.OAuthStateURL { + bindings = append(bindings, s.keyMap.Copy) + } + return bindings case s.needsAPIKey: return []key.Binding{ s.keyMap.Select, @@ -865,6 +1047,22 @@ func (s *splashCmp) IsAPIKeyValid() bool { return s.isAPIKeyValid } +func (s *splashCmp) IsShowingClaudeAuthMethodChooser() bool { + return s.showClaudeAuthMethodChooser +} + +func (s *splashCmp) IsShowingClaudeOAuth2() bool { + return s.showClaudeOAuth2 +} + +func (s *splashCmp) IsClaudeOAuthURLState() bool { + return s.showClaudeOAuth2 && s.claudeOAuth2.State == claude.OAuthStateURL +} + +func (s *splashCmp) IsClaudeOAuthComplete() bool { + return s.showClaudeOAuth2 && s.claudeOAuth2.State == claude.OAuthStateCode && s.claudeOAuth2.ValidationState == claude.OAuthValidationStateValid +} + func (s *splashCmp) IsShowingHyperOAuth2() bool { return s.showHyperDeviceFlow } diff --git a/internal/tui/components/dialogs/claude/method.go b/internal/tui/components/dialogs/claude/method.go new file mode 100644 index 0000000000000000000000000000000000000000..4c89ac0900088098b53ddba8d240809c299d262d --- /dev/null +++ b/internal/tui/components/dialogs/claude/method.go @@ -0,0 +1,115 @@ +package claude + +import ( + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + "git.secluded.site/crush/internal/tui/styles" + "git.secluded.site/crush/internal/tui/util" +) + +type AuthMethod int + +const ( + AuthMethodAPIKey AuthMethod = iota + AuthMethodOAuth2 +) + +type AuthMethodChooser struct { + State AuthMethod + width int + isOnboarding bool +} + +func NewAuthMethodChooser() *AuthMethodChooser { + return &AuthMethodChooser{ + State: AuthMethodOAuth2, + } +} + +func (a *AuthMethodChooser) Init() tea.Cmd { + return nil +} + +func (a *AuthMethodChooser) Update(msg tea.Msg) (util.Model, tea.Cmd) { + return a, nil +} + +func (a *AuthMethodChooser) View() string { + t := styles.CurrentTheme() + + white := lipgloss.NewStyle().Foreground(t.White) + primary := lipgloss.NewStyle().Foreground(t.Primary) + success := lipgloss.NewStyle().Foreground(t.Success) + + titleStyle := white + if a.isOnboarding { + titleStyle = primary + } + + question := lipgloss. + NewStyle(). + Margin(0, 1). + Render(titleStyle.Render("How would you like to authenticate with ") + success.Render("Anthropic") + titleStyle.Render("?")) + + squareWidth := (a.width - 2) / 2 + squareHeight := squareWidth / 3 + if isOdd(squareHeight) { + squareHeight++ + } + + square := lipgloss.NewStyle(). + Width(squareWidth). + Height(squareHeight). + Margin(0, 0). + Border(lipgloss.RoundedBorder()) + + squareText := lipgloss.NewStyle(). + Width(squareWidth - 2). + Height(squareHeight). + Align(lipgloss.Center). + AlignVertical(lipgloss.Center) + + oauthBorder := t.AuthBorderSelected + oauthText := t.AuthTextSelected + apiKeyBorder := t.AuthBorderUnselected + apiKeyText := t.AuthTextUnselected + + if a.State == AuthMethodAPIKey { + oauthBorder, apiKeyBorder = apiKeyBorder, oauthBorder + oauthText, apiKeyText = apiKeyText, oauthText + } + + return lipgloss.JoinVertical( + lipgloss.Left, + question, + "", + lipgloss.JoinHorizontal( + lipgloss.Center, + square.MarginLeft(1). + Inherit(oauthBorder).Render(squareText.Inherit(oauthText).Render("Claude Account\nwith Subscription")), + square.MarginRight(1). + Inherit(apiKeyBorder).Render(squareText.Inherit(apiKeyText).Render("API Key")), + ), + ) +} + +func (a *AuthMethodChooser) SetDefaults() { + a.State = AuthMethodOAuth2 +} + +func (a *AuthMethodChooser) SetWidth(w int) { + a.width = w +} + +func (a *AuthMethodChooser) ToggleChoice() { + switch a.State { + case AuthMethodAPIKey: + a.State = AuthMethodOAuth2 + case AuthMethodOAuth2: + a.State = AuthMethodAPIKey + } +} + +func isOdd(n int) bool { + return n%2 != 0 +} diff --git a/internal/tui/components/dialogs/claude/oauth.go b/internal/tui/components/dialogs/claude/oauth.go new file mode 100644 index 0000000000000000000000000000000000000000..95d790aa99499a5f2f6a8a2ec3b9e92466135710 --- /dev/null +++ b/internal/tui/components/dialogs/claude/oauth.go @@ -0,0 +1,267 @@ +package claude + +import ( + "context" + "fmt" + "net/url" + + "charm.land/bubbles/v2/spinner" + "charm.land/bubbles/v2/textinput" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + "git.secluded.site/crush/internal/oauth" + "git.secluded.site/crush/internal/oauth/claude" + "git.secluded.site/crush/internal/tui/styles" + "git.secluded.site/crush/internal/tui/util" + "github.com/pkg/browser" + "github.com/zeebo/xxh3" +) + +type OAuthState int + +const ( + OAuthStateURL OAuthState = iota + OAuthStateCode +) + +type OAuthValidationState int + +const ( + OAuthValidationStateNone OAuthValidationState = iota + OAuthValidationStateVerifying + OAuthValidationStateValid + OAuthValidationStateError +) + +type ValidationCompletedMsg struct { + State OAuthValidationState + Token *oauth.Token +} + +type AuthenticationCompleteMsg struct{} + +type OAuth2 struct { + State OAuthState + ValidationState OAuthValidationState + width int + isOnboarding bool + + // URL page + err error + verifier string + challenge string + URL string + urlId string + token *oauth.Token + + // Code input page + CodeInput textinput.Model + spinner spinner.Model +} + +func NewOAuth2() *OAuth2 { + return &OAuth2{ + State: OAuthStateURL, + } +} + +func (o *OAuth2) Init() tea.Cmd { + t := styles.CurrentTheme() + + verifier, challenge, err := claude.GetChallenge() + if err != nil { + o.err = err + return nil + } + + url, err := claude.AuthorizeURL(verifier, challenge) + if err != nil { + o.err = err + return nil + } + + o.verifier = verifier + o.challenge = challenge + o.URL = url + + h := xxh3.New() + _, _ = h.WriteString(o.URL) + o.urlId = fmt.Sprintf("id=%x", h.Sum(nil)) + + o.CodeInput = textinput.New() + o.CodeInput.Placeholder = "Paste or type" + o.CodeInput.SetVirtualCursor(false) + o.CodeInput.Prompt = "> " + o.CodeInput.SetStyles(t.S().TextInput) + o.CodeInput.SetWidth(50) + + o.spinner = spinner.New( + spinner.WithSpinner(spinner.Dot), + spinner.WithStyle(t.S().Base.Foreground(t.Green)), + ) + + return nil +} + +func (o *OAuth2) Update(msg tea.Msg) (util.Model, tea.Cmd) { + var cmds []tea.Cmd + + switch msg := msg.(type) { + case ValidationCompletedMsg: + o.ValidationState = msg.State + o.token = msg.Token + switch o.ValidationState { + case OAuthValidationStateError: + o.CodeInput.Focus() + } + o.updatePrompt() + } + + if o.ValidationState == OAuthValidationStateVerifying { + var cmd tea.Cmd + o.spinner, cmd = o.spinner.Update(msg) + cmds = append(cmds, cmd) + o.updatePrompt() + } + { + var cmd tea.Cmd + o.CodeInput, cmd = o.CodeInput.Update(msg) + cmds = append(cmds, cmd) + } + + return o, tea.Batch(cmds...) +} + +func (o *OAuth2) ValidationConfirm() (util.Model, tea.Cmd) { + var cmds []tea.Cmd + + switch { + case o.State == OAuthStateURL: + _ = browser.OpenURL(o.URL) + o.State = OAuthStateCode + cmds = append(cmds, o.CodeInput.Focus()) + case o.ValidationState == OAuthValidationStateNone || o.ValidationState == OAuthValidationStateError: + o.CodeInput.Blur() + o.ValidationState = OAuthValidationStateVerifying + cmds = append(cmds, o.spinner.Tick, o.validateCode) + case o.ValidationState == OAuthValidationStateValid: + cmds = append(cmds, func() tea.Msg { return AuthenticationCompleteMsg{} }) + } + + o.updatePrompt() + return o, tea.Batch(cmds...) +} + +func (o *OAuth2) View() string { + t := styles.CurrentTheme() + + whiteStyle := lipgloss.NewStyle().Foreground(t.White) + primaryStyle := lipgloss.NewStyle().Foreground(t.Primary) + successStyle := lipgloss.NewStyle().Foreground(t.Success) + errorStyle := lipgloss.NewStyle().Foreground(t.Error) + + titleStyle := whiteStyle + if o.isOnboarding { + titleStyle = primaryStyle + } + + switch { + case o.err != nil: + return lipgloss.NewStyle(). + Margin(0, 1). + Foreground(t.Error). + Render(o.err.Error()) + case o.State == OAuthStateURL: + heading := lipgloss. + NewStyle(). + Margin(0, 1). + Render(titleStyle.Render("Press enter key to open the following ") + successStyle.Render("URL") + titleStyle.Render(":")) + + return lipgloss.JoinVertical( + lipgloss.Left, + heading, + "", + lipgloss.NewStyle(). + Margin(0, 1). + Foreground(t.FgMuted). + Hyperlink(o.URL, o.urlId). + Render(o.displayUrl()), + ) + case o.State == OAuthStateCode: + var heading string + + switch o.ValidationState { + case OAuthValidationStateNone: + st := lipgloss.NewStyle().Margin(0, 1) + heading = st.Render(titleStyle.Render("Enter the ") + successStyle.Render("code") + titleStyle.Render(" you received.")) + case OAuthValidationStateVerifying: + heading = titleStyle.Margin(0, 1).Render("Verifying...") + case OAuthValidationStateValid: + heading = successStyle.Margin(0, 1).Render("Validated.") + case OAuthValidationStateError: + heading = errorStyle.Margin(0, 1).Render("Invalid. Try again?") + } + + return lipgloss.JoinVertical( + lipgloss.Left, + heading, + "", + " "+o.CodeInput.View(), + ) + default: + panic("claude oauth2: invalid state") + } +} + +func (o *OAuth2) SetDefaults() { + o.State = OAuthStateURL + o.ValidationState = OAuthValidationStateNone + o.CodeInput.SetValue("") + o.err = nil +} + +func (o *OAuth2) SetWidth(w int) { + o.width = w + o.CodeInput.SetWidth(w - 4) +} + +func (o *OAuth2) SetError(err error) { + o.err = err +} + +func (o *OAuth2) validateCode() tea.Msg { + token, err := claude.ExchangeToken(context.Background(), o.CodeInput.Value(), o.verifier) + if err != nil || token == nil { + return ValidationCompletedMsg{State: OAuthValidationStateError} + } + return ValidationCompletedMsg{State: OAuthValidationStateValid, Token: token} +} + +func (o *OAuth2) updatePrompt() { + switch o.ValidationState { + case OAuthValidationStateNone: + o.CodeInput.Prompt = "> " + case OAuthValidationStateVerifying: + o.CodeInput.Prompt = o.spinner.View() + " " + case OAuthValidationStateValid: + o.CodeInput.Prompt = styles.CheckIcon + " " + case OAuthValidationStateError: + o.CodeInput.Prompt = styles.ErrorIcon + " " + } +} + +// Remove query params for display +// e.g., "https://claude.ai/oauth/authorize?..." -> "https://claude.ai/oauth/authorize..." +func (o *OAuth2) displayUrl() string { + parsed, err := url.Parse(o.URL) + if err != nil { + return o.URL + } + + if parsed.RawQuery != "" { + parsed.RawQuery = "" + return parsed.String() + "..." + } + + return o.URL +} diff --git a/internal/tui/components/dialogs/models/keys.go b/internal/tui/components/dialogs/models/keys.go index ff81404b1f1937fff09d917bf3a9e3b24f4d38c9..eda235aebb858fef21c582921cfb9e305a6fed19 100644 --- a/internal/tui/components/dialogs/models/keys.go +++ b/internal/tui/components/dialogs/models/keys.go @@ -18,6 +18,11 @@ type KeyMap struct { isHyperDeviceFlow bool isCopilotDeviceFlow bool isCopilotUnavailable bool + + isClaudeAuthChoiceHelp bool + isClaudeOAuthHelp bool + isClaudeOAuthURLState bool + isClaudeOAuthHelpComplete bool } func DefaultKeyMap() KeyMap { @@ -95,6 +100,58 @@ func (k KeyMap) ShortHelp() []key.Binding { k.Close, } } + if k.isClaudeAuthChoiceHelp { + return []key.Binding{ + key.NewBinding( + key.WithKeys("left", "right", "h", "l"), + key.WithHelp("←→", "choose"), + ), + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "accept"), + ), + key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "back"), + ), + } + } + if k.isClaudeOAuthHelp { + if k.isClaudeOAuthHelpComplete { + return []key.Binding{ + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "close"), + ), + } + } + + enterHelp := "submit" + if k.isClaudeOAuthURLState { + enterHelp = "open" + } + + bindings := []key.Binding{ + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", enterHelp), + ), + } + + if k.isClaudeOAuthURLState { + bindings = append(bindings, key.NewBinding( + key.WithKeys("c"), + key.WithHelp("c", "copy url"), + )) + } + + bindings = append(bindings, key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "back"), + )) + + return bindings + } if k.isAPIKeyHelp && !k.isAPIKeyValid { return []key.Binding{ key.NewBinding( diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index 10b30f747c0ea893c2e31f2d36fe073c1e866c1d..c3eb44581db685feadb47f31e94aa040a097837f 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -14,11 +14,13 @@ import ( "git.secluded.site/crush/internal/config" "git.secluded.site/crush/internal/tui/components/core" "git.secluded.site/crush/internal/tui/components/dialogs" + "git.secluded.site/crush/internal/tui/components/dialogs/claude" "git.secluded.site/crush/internal/tui/components/dialogs/copilot" "git.secluded.site/crush/internal/tui/components/dialogs/hyper" "git.secluded.site/crush/internal/tui/exp/list" "git.secluded.site/crush/internal/tui/styles" "git.secluded.site/crush/internal/tui/util" + "github.com/atotto/clipboard" "github.com/charmbracelet/catwalk/pkg/catwalk" ) @@ -79,6 +81,12 @@ type modelDialogCmp struct { // Copilot device flow state copilotDeviceFlow *copilot.DeviceFlow showCopilotDeviceFlow bool + + // Claude state + claudeAuthMethodChooser *claude.AuthMethodChooser + claudeOAuth2 *claude.OAuth2 + showClaudeAuthMethodChooser bool + showClaudeOAuth2 bool } func NewModelDialogCmp() ModelDialog { @@ -103,6 +111,9 @@ func NewModelDialogCmp() ModelDialog { width: defaultWidth, keyMap: DefaultKeyMap(), help: help, + + claudeAuthMethodChooser: claude.NewAuthMethodChooser(), + claudeOAuth2: claude.NewOAuth2(), } } @@ -110,6 +121,8 @@ func (m *modelDialogCmp) Init() tea.Cmd { return tea.Batch( m.modelList.Init(), m.apiKeyInput.Init(), + m.claudeAuthMethodChooser.Init(), + m.claudeOAuth2.Init(), ) } @@ -120,6 +133,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { m.wHeight = msg.Height m.apiKeyInput.SetWidth(m.width - 2) m.help.SetWidth(m.width - 2) + m.claudeAuthMethodChooser.SetWidth(m.width - 2) return m, m.modelList.SetSize(m.listWidth(), m.listHeight()) case APIKeyStateChangeMsg: u, cmd := m.apiKeyInput.Update(msg) @@ -143,6 +157,20 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return m, nil case copilot.DeviceFlowCompletedMsg: return m, m.saveOauthTokenAndContinue(msg.Token, true) + case claude.ValidationCompletedMsg: + var cmds []tea.Cmd + u, cmd := m.claudeOAuth2.Update(msg) + m.claudeOAuth2 = u.(*claude.OAuth2) + cmds = append(cmds, cmd) + + if msg.State == claude.OAuthValidationStateValid { + cmds = append(cmds, m.saveOauthTokenAndContinue(msg.Token, false)) + m.keyMap.isClaudeOAuthHelpComplete = true + } + + return m, tea.Batch(cmds...) + case claude.AuthenticationCompleteMsg: + return m, util.CmdHandler(dialogs.CloseDialogMsg{}) case tea.KeyPressMsg: switch { // Handle Hyper device flow keys @@ -150,6 +178,18 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return m, m.hyperDeviceFlow.CopyCode() case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showCopilotDeviceFlow: return m, m.copilotDeviceFlow.CopyCode() + case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL: + return m, tea.Sequence( + tea.SetClipboard(m.claudeOAuth2.URL), + func() tea.Msg { + _ = clipboard.WriteAll(m.claudeOAuth2.URL) + return nil + }, + util.ReportInfo("URL copied to clipboard"), + ) + case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser: + m.claudeAuthMethodChooser.ToggleChoice() + return m, nil case key.Matches(msg, m.keyMap.Select): // If showing device flow, enter copies code and opens URL if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil { @@ -169,15 +209,37 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { } askForApiKey := func() { + m.keyMap.isClaudeAuthChoiceHelp = false + m.keyMap.isClaudeOAuthHelp = false m.keyMap.isAPIKeyHelp = true m.showHyperDeviceFlow = false m.showCopilotDeviceFlow = false + m.showClaudeAuthMethodChooser = false m.needsAPIKey = true m.selectedModel = selectedItem m.selectedModelType = modelType m.apiKeyInput.SetProviderName(selectedItem.Provider.Name) } + if m.showClaudeAuthMethodChooser { + switch m.claudeAuthMethodChooser.State { + case claude.AuthMethodAPIKey: + askForApiKey() + case claude.AuthMethodOAuth2: + m.selectedModel = selectedItem + m.selectedModelType = modelType + m.showClaudeAuthMethodChooser = false + m.showClaudeOAuth2 = true + m.keyMap.isClaudeAuthChoiceHelp = false + m.keyMap.isClaudeOAuthHelp = true + } + return m, nil + } + if m.showClaudeOAuth2 { + m2, cmd2 := m.claudeOAuth2.ValidationConfirm() + m.claudeOAuth2 = m2.(*claude.OAuth2) + return m, cmd2 + } if m.isAPIKeyValid { return m, m.saveOauthTokenAndContinue(m.apiKeyValue, true) } @@ -236,6 +298,10 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { ) } switch selectedItem.Provider.ID { + case catwalk.InferenceProviderAnthropic: + m.showClaudeAuthMethodChooser = true + m.keyMap.isClaudeAuthChoiceHelp = true + return m, nil case hyperp.Name: m.showHyperDeviceFlow = true m.selectedModel = selectedItem @@ -261,6 +327,9 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return m, nil case key.Matches(msg, m.keyMap.Tab): switch { + case m.showClaudeAuthMethodChooser: + m.claudeAuthMethodChooser.ToggleChoice() + return m, nil case m.needsAPIKey: u, cmd := m.apiKeyInput.Update(msg) m.apiKeyInput = u.(*APIKeyInput) @@ -286,6 +355,12 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { } m.showCopilotDeviceFlow = false m.selectedModel = nil + case m.showClaudeAuthMethodChooser: + m.claudeAuthMethodChooser.SetDefaults() + m.showClaudeAuthMethodChooser = false + m.keyMap.isClaudeAuthChoiceHelp = false + m.keyMap.isClaudeOAuthHelp = false + return m, nil case m.needsAPIKey: if m.isAPIKeyValid { return m, nil @@ -302,6 +377,14 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { } default: switch { + case m.showClaudeAuthMethodChooser: + u, cmd := m.claudeAuthMethodChooser.Update(msg) + m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser) + return m, cmd + case m.showClaudeOAuth2: + u, cmd := m.claudeOAuth2.Update(msg) + m.claudeOAuth2 = u.(*claude.OAuth2) + return m, cmd case m.needsAPIKey: u, cmd := m.apiKeyInput.Update(msg) m.apiKeyInput = u.(*APIKeyInput) @@ -314,6 +397,10 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { } case tea.PasteMsg: switch { + case m.showClaudeOAuth2: + u, cmd := m.claudeOAuth2.Update(msg) + m.claudeOAuth2 = u.(*claude.OAuth2) + return m, cmd case m.needsAPIKey: u, cmd := m.apiKeyInput.Update(msg) m.apiKeyInput = u.(*APIKeyInput) @@ -346,6 +433,10 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { u, cmd := m.copilotDeviceFlow.Update(msg) m.copilotDeviceFlow = u.(*copilot.DeviceFlow) return m, cmd + case m.showClaudeOAuth2: + u, cmd := m.claudeOAuth2.Update(msg) + m.claudeOAuth2 = u.(*claude.OAuth2) + return m, cmd default: u, cmd := m.apiKeyInput.Update(msg) m.apiKeyInput = u.(*APIKeyInput) @@ -392,6 +483,27 @@ func (m *modelDialogCmp) View() string { m.keyMap.isCopilotUnavailable = false switch { + case m.showClaudeAuthMethodChooser: + chooserView := m.claudeAuthMethodChooser.View() + content := lipgloss.JoinVertical( + lipgloss.Left, + t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)), + chooserView, + "", + t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)), + ) + return m.style().Render(content) + case m.showClaudeOAuth2: + m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL + oauth2View := m.claudeOAuth2.View() + content := lipgloss.JoinVertical( + lipgloss.Left, + t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)), + oauth2View, + "", + t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)), + ) + return m.style().Render(content) case m.needsAPIKey: // Show API key input m.keyMap.isAPIKeyHelp = true @@ -428,6 +540,16 @@ func (m *modelDialogCmp) Cursor() *tea.Cursor { if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil { return m.copilotDeviceFlow.Cursor() } + if m.showClaudeAuthMethodChooser { + return nil + } + if m.showClaudeOAuth2 { + if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil { + cursor.Y += 2 // FIXME(@andreynering): Why do we need this? + return m.moveCursor(cursor) + } + return nil + } if m.needsAPIKey { cursor := m.apiKeyInput.Cursor() if cursor != nil { diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index 82b8644638981b88801a233cf33bfa9701564fb1..92a0647832037114c5f67d29806a5668ef163f4d 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -29,6 +29,7 @@ import ( "git.secluded.site/crush/internal/tui/components/core" "git.secluded.site/crush/internal/tui/components/core/layout" "git.secluded.site/crush/internal/tui/components/dialogs" + "git.secluded.site/crush/internal/tui/components/dialogs/claude" "git.secluded.site/crush/internal/tui/components/dialogs/commands" "git.secluded.site/crush/internal/tui/components/dialogs/copilot" "git.secluded.site/crush/internal/tui/components/dialogs/filepicker" @@ -336,7 +337,9 @@ func (p *chatPage) Update(msg tea.Msg) (util.Model, tea.Cmd) { cmds = append(cmds, cmd) return p, tea.Batch(cmds...) - case hyper.DeviceFlowCompletedMsg, + case claude.ValidationCompletedMsg, + claude.AuthenticationCompleteMsg, + hyper.DeviceFlowCompletedMsg, hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg, copilot.DeviceAuthInitiatedMsg, @@ -1034,8 +1037,53 @@ func (p *chatPage) Help() help.KeyMap { var shortList []key.Binding var fullList [][]key.Binding switch { - case p.isOnboarding: + case p.isOnboarding && p.splash.IsShowingClaudeAuthMethodChooser(): + shortList = append(shortList, + // Choose auth method + key.NewBinding( + key.WithKeys("left", "right", "tab"), + key.WithHelp("←→/tab", "choose"), + ), + // Accept selection + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "accept"), + ), + // Go back + key.NewBinding( + key.WithKeys("esc", "alt+esc"), + key.WithHelp("esc", "back"), + ), + // Quit + key.NewBinding( + key.WithKeys("ctrl+c"), + key.WithHelp("ctrl+c", "quit"), + ), + ) + // keep them the same + for _, v := range shortList { + fullList = append(fullList, []key.Binding{v}) + } + case p.isOnboarding && p.splash.IsShowingClaudeOAuth2(): switch { + case p.splash.IsClaudeOAuthURLState(): + shortList = append(shortList, + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "open"), + ), + key.NewBinding( + key.WithKeys("c"), + key.WithHelp("c", "copy url"), + ), + ) + case p.splash.IsClaudeOAuthComplete(): + shortList = append(shortList, + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "continue"), + ), + ) case p.splash.IsShowingHyperOAuth2() || p.splash.IsShowingCopilotOAuth2(): shortList = append(shortList, key.NewBinding(