diff --git a/internal/cmd/login.go b/internal/cmd/login.go index 1d99468f5027690505f6e8248470ff28871389f4..1e78cb2462f21557bc864c02241f459624809253 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -32,6 +32,9 @@ crush login # Authenticate with GitHub Copilot crush login copilot + +# Force re-authentication even if already logged in +crush login -f copilot `, ValidArgs: []cobra.Completion{ "hyper", @@ -57,20 +60,36 @@ crush login copilot if len(args) > 0 { provider = args[0] } + force, _ := cmd.Flags().GetBool("force") switch provider { case "hyper": - return loginHyper(c, ws.ID) + return loginHyper(c, ws.ID, force) case "copilot", "github", "github-copilot": - return loginCopilot(cmd.Context(), c, ws.ID) + return loginCopilot(c, ws.ID, force) default: return fmt.Errorf("unknown platform: %s", args[0]) } }, } -func loginHyper(c *client.Client, wsID string) error { +func init() { + loginCmd.Flags().BoolP("force", "f", false, "Force re-authentication even if already logged in") +} + +func loginHyper(c *client.Client, wsID string, force bool) error { ctx := getLoginContext() + if !force { + cfg, err := c.GetConfig(ctx, wsID) + if err == nil && cfg != nil { + if pc, ok := cfg.Providers.Get("hyper"); ok && pc.OAuthToken != nil { + fmt.Println("You are already logged in to Hyper.") + fmt.Println("Use --force to re-authenticate.") + return nil + } + } + } + resp, err := hyper.InitiateDeviceAuth(ctx) if err != nil { return err @@ -127,14 +146,17 @@ func loginHyper(c *client.Client, wsID string) error { return nil } -func loginCopilot(ctx context.Context, c *client.Client, wsID string) error { +func loginCopilot(c *client.Client, wsID string, force bool) error { loginCtx := getLoginContext() - cfg, err := c.GetConfig(ctx, wsID) - if err == nil && cfg != nil { - if pc, ok := cfg.Providers.Get("copilot"); ok && pc.OAuthToken != nil { - fmt.Println("You are already logged in to GitHub Copilot.") - return nil + if !force { + cfg, err := c.GetConfig(loginCtx, wsID) + if err == nil && cfg != nil { + if pc, ok := cfg.Providers.Get("copilot"); ok && pc.OAuthToken != nil { + fmt.Println("You are already logged in to GitHub Copilot.") + fmt.Println("Use --force to re-authenticate.") + return nil + } } } diff --git a/internal/cmd/login_test.go b/internal/cmd/login_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c764b50d899b0b646c00eff2b6e53237f0674cda --- /dev/null +++ b/internal/cmd/login_test.go @@ -0,0 +1,21 @@ +package cmd + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLoginCmd_Aliases(t *testing.T) { + t.Parallel() + + require.Equal(t, "auth", loginCmd.Aliases[0]) +} + +func TestLoginCmd_ForceFlag(t *testing.T) { + t.Parallel() + + flag := loginCmd.Flags().Lookup("force") + require.NotNil(t, flag) + require.Equal(t, "f", flag.Shorthand) +} diff --git a/internal/cmd/logout.go b/internal/cmd/logout.go new file mode 100644 index 0000000000000000000000000000000000000000..f7191457ffd5bbbe97cd28f82210024bbfd6032f --- /dev/null +++ b/internal/cmd/logout.go @@ -0,0 +1,182 @@ +package cmd + +import ( + "cmp" + "context" + "fmt" + "os" + "os/signal" + + "charm.land/lipgloss/v2" + "github.com/charmbracelet/crush/internal/client" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/x/ansi" + "github.com/spf13/cobra" +) + +var ( + logoutHeaderStyle = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("205")) + logoutItemStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("252")) + logoutPromptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("215")) +) + +var logoutCmd = &cobra.Command{ + Aliases: []string{"signout"}, + Use: "logout [platform]", + Short: "Logout Crush from a platform", + Long: `Logout Crush from a specified platform, removing stored credentials. +The platform should be provided as an argument. +If no argument is given, a list of logged-in platforms will be shown. +Available platforms are: hyper, copilot.`, + Example: ` +# Sign out from Charm Hyper +crush logout hyper + +# Sign out from GitHub Copilot +crush logout copilot + `, + ValidArgs: []cobra.Completion{ + "hyper", + "copilot", + "github", + "github-copilot", + }, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + c, ws, cleanup, err := connectToServer(cmd) + if err != nil { + return err + } + defer cleanup() + + progressEnabled := ws.Config.Options.Progress == nil || *ws.Config.Options.Progress + if progressEnabled && supportsProgressBar() { + _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar) + defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }() + } + + var provider string + if len(args) == 0 { + provider, err = pickLoggedInProvider(c, ws.ID) + if err != nil { + return err + } + if provider == "" { + return nil + } + } else { + provider = args[0] + } + + force, _ := cmd.Flags().GetBool("force") + if !force { + fmt.Print(logoutPromptStyle.Render(fmt.Sprintf("Are you sure you want to logout %s? (y/N) ", provider))) + var response string + _, err := fmt.Scanln(&response) + if err != nil || (response != "y" && response != "Y" && response != "yes" && response != "Yes" && response != "YES") { + fmt.Println(logoutHeaderStyle.Render("Logout cancelled.")) + return nil + } + } + + switch provider { + case "hyper": + return logoutHyper(c, ws.ID) + case "copilot", "github", "github-copilot": + return logoutCopilot(c, ws.ID) + default: + return fmt.Errorf("unknown platform: %s", provider) + } + }, +} + +func logoutHyper(c *client.Client, wsID string) error { + ctx := getLogoutContext() + + if err := cmp.Or( + c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.api_key"), + c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.oauth"), + ); err != nil { + return err + } + + fmt.Println(logoutHeaderStyle.Render("Successfully logged out of Hyper.")) + return nil +} + +func logoutCopilot(c *client.Client, wsID string) error { + ctx := getLogoutContext() + + if err := cmp.Or( + c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.copilot.api_key"), + c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.copilot.oauth"), + ); err != nil { + return err + } + + fmt.Println(logoutHeaderStyle.Render("Successfully logged out of GitHub Copilot.")) + return nil +} + +func pickLoggedInProvider(c *client.Client, wsID string) (string, error) { + ctx := getLogoutContext() + + cfg, err := c.GetConfig(ctx, wsID) + if err != nil { + return "", fmt.Errorf("failed to get config: %w", err) + } + + type loggedInProvider struct { + id string + name string + } + + var loggedIn []loggedInProvider + for p := range cfg.Providers.Seq() { + if p.OAuthToken != nil || p.APIKey != "" { + name := p.Name + if name == "" { + name = p.ID + } + loggedIn = append(loggedIn, loggedInProvider{id: p.ID, name: name}) + } + } + + if len(loggedIn) == 0 { + fmt.Println(logoutPromptStyle.Render("You are not logged in to any platform.")) + return "", nil + } + + if len(loggedIn) == 1 { + return loggedIn[0].id, nil + } + + fmt.Println(logoutHeaderStyle.Render("Logged-in platforms:")) + for i, p := range loggedIn { + fmt.Println(logoutItemStyle.Render(fmt.Sprintf(" %d. %s", i+1, p.name))) + } + fmt.Print(logoutPromptStyle.Render(fmt.Sprintf("Select a platform to logout (1-%d): ", len(loggedIn)))) + + var choice int + _, err = fmt.Scanln(&choice) + if err != nil || choice < 1 || choice > len(loggedIn) { + fmt.Println(logoutHeaderStyle.Render("Logout cancelled.")) + return "", nil + } + + return loggedIn[choice-1].id, nil +} + +func init() { + logoutCmd.Flags().BoolP("force", "f", false, "Skip logout confirmation prompt") +} + +func getLogoutContext() context.Context { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) + go func() { + <-ctx.Done() + cancel() + os.Exit(1) + }() + return ctx +} diff --git a/internal/cmd/logout_test.go b/internal/cmd/logout_test.go new file mode 100644 index 0000000000000000000000000000000000000000..b049c5cffde30360e9cc9d2a6018d9d934af4cd8 --- /dev/null +++ b/internal/cmd/logout_test.go @@ -0,0 +1,41 @@ +package cmd + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLogoutCmd_Aliases(t *testing.T) { + t.Parallel() + + require.Equal(t, "signout", logoutCmd.Aliases[0]) +} + +func TestLogoutCmd_HasForceFlag(t *testing.T) { + t.Parallel() + + flag := logoutCmd.Flags().Lookup("force") + require.NotNil(t, flag) + require.Equal(t, "f", flag.Shorthand) + require.Equal(t, "false", flag.DefValue) +} + +func TestLogoutCmd_ValidArgs(t *testing.T) { + t.Parallel() + + validPlatforms := map[string]bool{} + for _, p := range logoutCmd.ValidArgs { + validPlatforms[p] = true + } + require.True(t, validPlatforms["hyper"]) + require.True(t, validPlatforms["copilot"]) + require.True(t, validPlatforms["github"]) + require.True(t, validPlatforms["github-copilot"]) +} + +func TestLogoutContext_CreatesValidContext(t *testing.T) { + ctx := getLogoutContext() + require.NotNil(t, ctx) + require.NoError(t, ctx.Err()) +} diff --git a/internal/cmd/root.go b/internal/cmd/root.go index d53e75ddd2e9294e3b7dd8e03b13012fac036d69..8cc9303b312943c75715057c70c47f47fb50e65b 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -63,6 +63,7 @@ func init() { projectsCmd, updateProvidersCmd, logsCmd, + logoutCmd, schemaCmd, loginCmd, statsCmd,