feat(oauth): add logout command (#2838)

Kieran Klukas created

Change summary

internal/cmd/login.go       |  40 ++++++-
internal/cmd/login_test.go  |  21 ++++
internal/cmd/logout.go      | 182 +++++++++++++++++++++++++++++++++++++++
internal/cmd/logout_test.go |  41 ++++++++
internal/cmd/root.go        |   1 
5 files changed, 276 insertions(+), 9 deletions(-)

Detailed changes

internal/cmd/login.go 🔗

@@ -32,6 +32,9 @@ crush login
 
 # Authenticate with GitHub Copilot
 crush login copilot
+
+# Force re-authentication even if already logged in
+crush login -f copilot
   `,
 	ValidArgs: []cobra.Completion{
 		"hyper",
@@ -57,20 +60,36 @@ crush login copilot
 		if len(args) > 0 {
 			provider = args[0]
 		}
+		force, _ := cmd.Flags().GetBool("force")
 		switch provider {
 		case "hyper":
-			return loginHyper(c, ws.ID)
+			return loginHyper(c, ws.ID, force)
 		case "copilot", "github", "github-copilot":
-			return loginCopilot(cmd.Context(), c, ws.ID)
+			return loginCopilot(c, ws.ID, force)
 		default:
 			return fmt.Errorf("unknown platform: %s", args[0])
 		}
 	},
 }
 
-func loginHyper(c *client.Client, wsID string) error {
+func init() {
+	loginCmd.Flags().BoolP("force", "f", false, "Force re-authentication even if already logged in")
+}
+
+func loginHyper(c *client.Client, wsID string, force bool) error {
 	ctx := getLoginContext()
 
+	if !force {
+		cfg, err := c.GetConfig(ctx, wsID)
+		if err == nil && cfg != nil {
+			if pc, ok := cfg.Providers.Get("hyper"); ok && pc.OAuthToken != nil {
+				fmt.Println("You are already logged in to Hyper.")
+				fmt.Println("Use --force to re-authenticate.")
+				return nil
+			}
+		}
+	}
+
 	resp, err := hyper.InitiateDeviceAuth(ctx)
 	if err != nil {
 		return err
@@ -127,14 +146,17 @@ func loginHyper(c *client.Client, wsID string) error {
 	return nil
 }
 
-func loginCopilot(ctx context.Context, c *client.Client, wsID string) error {
+func loginCopilot(c *client.Client, wsID string, force bool) error {
 	loginCtx := getLoginContext()
 
-	cfg, err := c.GetConfig(ctx, wsID)
-	if err == nil && cfg != nil {
-		if pc, ok := cfg.Providers.Get("copilot"); ok && pc.OAuthToken != nil {
-			fmt.Println("You are already logged in to GitHub Copilot.")
-			return nil
+	if !force {
+		cfg, err := c.GetConfig(loginCtx, wsID)
+		if err == nil && cfg != nil {
+			if pc, ok := cfg.Providers.Get("copilot"); ok && pc.OAuthToken != nil {
+				fmt.Println("You are already logged in to GitHub Copilot.")
+				fmt.Println("Use --force to re-authenticate.")
+				return nil
+			}
 		}
 	}
 

internal/cmd/login_test.go 🔗

@@ -0,0 +1,21 @@
+package cmd
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestLoginCmd_Aliases(t *testing.T) {
+	t.Parallel()
+
+	require.Equal(t, "auth", loginCmd.Aliases[0])
+}
+
+func TestLoginCmd_ForceFlag(t *testing.T) {
+	t.Parallel()
+
+	flag := loginCmd.Flags().Lookup("force")
+	require.NotNil(t, flag)
+	require.Equal(t, "f", flag.Shorthand)
+}

internal/cmd/logout.go 🔗

@@ -0,0 +1,182 @@
+package cmd
+
+import (
+	"cmp"
+	"context"
+	"fmt"
+	"os"
+	"os/signal"
+
+	"charm.land/lipgloss/v2"
+	"github.com/charmbracelet/crush/internal/client"
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/x/ansi"
+	"github.com/spf13/cobra"
+)
+
+var (
+	logoutHeaderStyle = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("205"))
+	logoutItemStyle   = lipgloss.NewStyle().Foreground(lipgloss.Color("252"))
+	logoutPromptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("215"))
+)
+
+var logoutCmd = &cobra.Command{
+	Aliases: []string{"signout"},
+	Use:     "logout [platform]",
+	Short:   "Logout Crush from a platform",
+	Long: `Logout Crush from a specified platform, removing stored credentials.
+The platform should be provided as an argument.
+If no argument is given, a list of logged-in platforms will be shown.
+Available platforms are: hyper, copilot.`,
+	Example: `
+# Sign out from Charm Hyper
+crush logout hyper
+
+# Sign out from GitHub Copilot
+crush logout copilot
+  `,
+	ValidArgs: []cobra.Completion{
+		"hyper",
+		"copilot",
+		"github",
+		"github-copilot",
+	},
+	Args: cobra.MaximumNArgs(1),
+	RunE: func(cmd *cobra.Command, args []string) error {
+		c, ws, cleanup, err := connectToServer(cmd)
+		if err != nil {
+			return err
+		}
+		defer cleanup()
+
+		progressEnabled := ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
+		if progressEnabled && supportsProgressBar() {
+			_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
+			defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
+		}
+
+		var provider string
+		if len(args) == 0 {
+			provider, err = pickLoggedInProvider(c, ws.ID)
+			if err != nil {
+				return err
+			}
+			if provider == "" {
+				return nil
+			}
+		} else {
+			provider = args[0]
+		}
+
+		force, _ := cmd.Flags().GetBool("force")
+		if !force {
+			fmt.Print(logoutPromptStyle.Render(fmt.Sprintf("Are you sure you want to logout %s? (y/N) ", provider)))
+			var response string
+			_, err := fmt.Scanln(&response)
+			if err != nil || (response != "y" && response != "Y" && response != "yes" && response != "Yes" && response != "YES") {
+				fmt.Println(logoutHeaderStyle.Render("Logout cancelled."))
+				return nil
+			}
+		}
+
+		switch provider {
+		case "hyper":
+			return logoutHyper(c, ws.ID)
+		case "copilot", "github", "github-copilot":
+			return logoutCopilot(c, ws.ID)
+		default:
+			return fmt.Errorf("unknown platform: %s", provider)
+		}
+	},
+}
+
+func logoutHyper(c *client.Client, wsID string) error {
+	ctx := getLogoutContext()
+
+	if err := cmp.Or(
+		c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.api_key"),
+		c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.oauth"),
+	); err != nil {
+		return err
+	}
+
+	fmt.Println(logoutHeaderStyle.Render("Successfully logged out of Hyper."))
+	return nil
+}
+
+func logoutCopilot(c *client.Client, wsID string) error {
+	ctx := getLogoutContext()
+
+	if err := cmp.Or(
+		c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.copilot.api_key"),
+		c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.copilot.oauth"),
+	); err != nil {
+		return err
+	}
+
+	fmt.Println(logoutHeaderStyle.Render("Successfully logged out of GitHub Copilot."))
+	return nil
+}
+
+func pickLoggedInProvider(c *client.Client, wsID string) (string, error) {
+	ctx := getLogoutContext()
+
+	cfg, err := c.GetConfig(ctx, wsID)
+	if err != nil {
+		return "", fmt.Errorf("failed to get config: %w", err)
+	}
+
+	type loggedInProvider struct {
+		id   string
+		name string
+	}
+
+	var loggedIn []loggedInProvider
+	for p := range cfg.Providers.Seq() {
+		if p.OAuthToken != nil || p.APIKey != "" {
+			name := p.Name
+			if name == "" {
+				name = p.ID
+			}
+			loggedIn = append(loggedIn, loggedInProvider{id: p.ID, name: name})
+		}
+	}
+
+	if len(loggedIn) == 0 {
+		fmt.Println(logoutPromptStyle.Render("You are not logged in to any platform."))
+		return "", nil
+	}
+
+	if len(loggedIn) == 1 {
+		return loggedIn[0].id, nil
+	}
+
+	fmt.Println(logoutHeaderStyle.Render("Logged-in platforms:"))
+	for i, p := range loggedIn {
+		fmt.Println(logoutItemStyle.Render(fmt.Sprintf("  %d. %s", i+1, p.name)))
+	}
+	fmt.Print(logoutPromptStyle.Render(fmt.Sprintf("Select a platform to logout (1-%d): ", len(loggedIn))))
+
+	var choice int
+	_, err = fmt.Scanln(&choice)
+	if err != nil || choice < 1 || choice > len(loggedIn) {
+		fmt.Println(logoutHeaderStyle.Render("Logout cancelled."))
+		return "", nil
+	}
+
+	return loggedIn[choice-1].id, nil
+}
+
+func init() {
+	logoutCmd.Flags().BoolP("force", "f", false, "Skip logout confirmation prompt")
+}
+
+func getLogoutContext() context.Context {
+	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
+	go func() {
+		<-ctx.Done()
+		cancel()
+		os.Exit(1)
+	}()
+	return ctx
+}

internal/cmd/logout_test.go 🔗

@@ -0,0 +1,41 @@
+package cmd
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestLogoutCmd_Aliases(t *testing.T) {
+	t.Parallel()
+
+	require.Equal(t, "signout", logoutCmd.Aliases[0])
+}
+
+func TestLogoutCmd_HasForceFlag(t *testing.T) {
+	t.Parallel()
+
+	flag := logoutCmd.Flags().Lookup("force")
+	require.NotNil(t, flag)
+	require.Equal(t, "f", flag.Shorthand)
+	require.Equal(t, "false", flag.DefValue)
+}
+
+func TestLogoutCmd_ValidArgs(t *testing.T) {
+	t.Parallel()
+
+	validPlatforms := map[string]bool{}
+	for _, p := range logoutCmd.ValidArgs {
+		validPlatforms[p] = true
+	}
+	require.True(t, validPlatforms["hyper"])
+	require.True(t, validPlatforms["copilot"])
+	require.True(t, validPlatforms["github"])
+	require.True(t, validPlatforms["github-copilot"])
+}
+
+func TestLogoutContext_CreatesValidContext(t *testing.T) {
+	ctx := getLogoutContext()
+	require.NotNil(t, ctx)
+	require.NoError(t, ctx.Err())
+}

internal/cmd/root.go 🔗

@@ -63,6 +63,7 @@ func init() {
 		projectsCmd,
 		updateProvidersCmd,
 		logsCmd,
+		logoutCmd,
 		schemaCmd,
 		loginCmd,
 		statsCmd,