@@ -32,6 +32,9 @@ crush login
# Authenticate with GitHub Copilot
crush login copilot
+
+# Force re-authentication even if already logged in
+crush login -f copilot
`,
ValidArgs: []cobra.Completion{
"hyper",
@@ -57,20 +60,36 @@ crush login copilot
if len(args) > 0 {
provider = args[0]
}
+ force, _ := cmd.Flags().GetBool("force")
switch provider {
case "hyper":
- return loginHyper(c, ws.ID)
+ return loginHyper(c, ws.ID, force)
case "copilot", "github", "github-copilot":
- return loginCopilot(cmd.Context(), c, ws.ID)
+ return loginCopilot(c, ws.ID, force)
default:
return fmt.Errorf("unknown platform: %s", args[0])
}
},
}
-func loginHyper(c *client.Client, wsID string) error {
+func init() {
+ loginCmd.Flags().BoolP("force", "f", false, "Force re-authentication even if already logged in")
+}
+
+func loginHyper(c *client.Client, wsID string, force bool) error {
ctx := getLoginContext()
+ if !force {
+ cfg, err := c.GetConfig(ctx, wsID)
+ if err == nil && cfg != nil {
+ if pc, ok := cfg.Providers.Get("hyper"); ok && pc.OAuthToken != nil {
+ fmt.Println("You are already logged in to Hyper.")
+ fmt.Println("Use --force to re-authenticate.")
+ return nil
+ }
+ }
+ }
+
resp, err := hyper.InitiateDeviceAuth(ctx)
if err != nil {
return err
@@ -127,14 +146,17 @@ func loginHyper(c *client.Client, wsID string) error {
return nil
}
-func loginCopilot(ctx context.Context, c *client.Client, wsID string) error {
+func loginCopilot(c *client.Client, wsID string, force bool) error {
loginCtx := getLoginContext()
- cfg, err := c.GetConfig(ctx, wsID)
- if err == nil && cfg != nil {
- if pc, ok := cfg.Providers.Get("copilot"); ok && pc.OAuthToken != nil {
- fmt.Println("You are already logged in to GitHub Copilot.")
- return nil
+ if !force {
+ cfg, err := c.GetConfig(loginCtx, wsID)
+ if err == nil && cfg != nil {
+ if pc, ok := cfg.Providers.Get("copilot"); ok && pc.OAuthToken != nil {
+ fmt.Println("You are already logged in to GitHub Copilot.")
+ fmt.Println("Use --force to re-authenticate.")
+ return nil
+ }
}
}
@@ -0,0 +1,182 @@
+package cmd
+
+import (
+ "cmp"
+ "context"
+ "fmt"
+ "os"
+ "os/signal"
+
+ "charm.land/lipgloss/v2"
+ "github.com/charmbracelet/crush/internal/client"
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/x/ansi"
+ "github.com/spf13/cobra"
+)
+
+var (
+ logoutHeaderStyle = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("205"))
+ logoutItemStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("252"))
+ logoutPromptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("215"))
+)
+
+var logoutCmd = &cobra.Command{
+ Aliases: []string{"signout"},
+ Use: "logout [platform]",
+ Short: "Logout Crush from a platform",
+ Long: `Logout Crush from a specified platform, removing stored credentials.
+The platform should be provided as an argument.
+If no argument is given, a list of logged-in platforms will be shown.
+Available platforms are: hyper, copilot.`,
+ Example: `
+# Sign out from Charm Hyper
+crush logout hyper
+
+# Sign out from GitHub Copilot
+crush logout copilot
+ `,
+ ValidArgs: []cobra.Completion{
+ "hyper",
+ "copilot",
+ "github",
+ "github-copilot",
+ },
+ Args: cobra.MaximumNArgs(1),
+ RunE: func(cmd *cobra.Command, args []string) error {
+ c, ws, cleanup, err := connectToServer(cmd)
+ if err != nil {
+ return err
+ }
+ defer cleanup()
+
+ progressEnabled := ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
+ if progressEnabled && supportsProgressBar() {
+ _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
+ defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
+ }
+
+ var provider string
+ if len(args) == 0 {
+ provider, err = pickLoggedInProvider(c, ws.ID)
+ if err != nil {
+ return err
+ }
+ if provider == "" {
+ return nil
+ }
+ } else {
+ provider = args[0]
+ }
+
+ force, _ := cmd.Flags().GetBool("force")
+ if !force {
+ fmt.Print(logoutPromptStyle.Render(fmt.Sprintf("Are you sure you want to logout %s? (y/N) ", provider)))
+ var response string
+ _, err := fmt.Scanln(&response)
+ if err != nil || (response != "y" && response != "Y" && response != "yes" && response != "Yes" && response != "YES") {
+ fmt.Println(logoutHeaderStyle.Render("Logout cancelled."))
+ return nil
+ }
+ }
+
+ switch provider {
+ case "hyper":
+ return logoutHyper(c, ws.ID)
+ case "copilot", "github", "github-copilot":
+ return logoutCopilot(c, ws.ID)
+ default:
+ return fmt.Errorf("unknown platform: %s", provider)
+ }
+ },
+}
+
+func logoutHyper(c *client.Client, wsID string) error {
+ ctx := getLogoutContext()
+
+ if err := cmp.Or(
+ c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.api_key"),
+ c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.oauth"),
+ ); err != nil {
+ return err
+ }
+
+ fmt.Println(logoutHeaderStyle.Render("Successfully logged out of Hyper."))
+ return nil
+}
+
+func logoutCopilot(c *client.Client, wsID string) error {
+ ctx := getLogoutContext()
+
+ if err := cmp.Or(
+ c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.copilot.api_key"),
+ c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.copilot.oauth"),
+ ); err != nil {
+ return err
+ }
+
+ fmt.Println(logoutHeaderStyle.Render("Successfully logged out of GitHub Copilot."))
+ return nil
+}
+
+func pickLoggedInProvider(c *client.Client, wsID string) (string, error) {
+ ctx := getLogoutContext()
+
+ cfg, err := c.GetConfig(ctx, wsID)
+ if err != nil {
+ return "", fmt.Errorf("failed to get config: %w", err)
+ }
+
+ type loggedInProvider struct {
+ id string
+ name string
+ }
+
+ var loggedIn []loggedInProvider
+ for p := range cfg.Providers.Seq() {
+ if p.OAuthToken != nil || p.APIKey != "" {
+ name := p.Name
+ if name == "" {
+ name = p.ID
+ }
+ loggedIn = append(loggedIn, loggedInProvider{id: p.ID, name: name})
+ }
+ }
+
+ if len(loggedIn) == 0 {
+ fmt.Println(logoutPromptStyle.Render("You are not logged in to any platform."))
+ return "", nil
+ }
+
+ if len(loggedIn) == 1 {
+ return loggedIn[0].id, nil
+ }
+
+ fmt.Println(logoutHeaderStyle.Render("Logged-in platforms:"))
+ for i, p := range loggedIn {
+ fmt.Println(logoutItemStyle.Render(fmt.Sprintf(" %d. %s", i+1, p.name)))
+ }
+ fmt.Print(logoutPromptStyle.Render(fmt.Sprintf("Select a platform to logout (1-%d): ", len(loggedIn))))
+
+ var choice int
+ _, err = fmt.Scanln(&choice)
+ if err != nil || choice < 1 || choice > len(loggedIn) {
+ fmt.Println(logoutHeaderStyle.Render("Logout cancelled."))
+ return "", nil
+ }
+
+ return loggedIn[choice-1].id, nil
+}
+
+func init() {
+ logoutCmd.Flags().BoolP("force", "f", false, "Skip logout confirmation prompt")
+}
+
+func getLogoutContext() context.Context {
+ ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
+ go func() {
+ <-ctx.Done()
+ cancel()
+ os.Exit(1)
+ }()
+ return ctx
+}
@@ -0,0 +1,41 @@
+package cmd
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestLogoutCmd_Aliases(t *testing.T) {
+ t.Parallel()
+
+ require.Equal(t, "signout", logoutCmd.Aliases[0])
+}
+
+func TestLogoutCmd_HasForceFlag(t *testing.T) {
+ t.Parallel()
+
+ flag := logoutCmd.Flags().Lookup("force")
+ require.NotNil(t, flag)
+ require.Equal(t, "f", flag.Shorthand)
+ require.Equal(t, "false", flag.DefValue)
+}
+
+func TestLogoutCmd_ValidArgs(t *testing.T) {
+ t.Parallel()
+
+ validPlatforms := map[string]bool{}
+ for _, p := range logoutCmd.ValidArgs {
+ validPlatforms[p] = true
+ }
+ require.True(t, validPlatforms["hyper"])
+ require.True(t, validPlatforms["copilot"])
+ require.True(t, validPlatforms["github"])
+ require.True(t, validPlatforms["github-copilot"])
+}
+
+func TestLogoutContext_CreatesValidContext(t *testing.T) {
+ ctx := getLogoutContext()
+ require.NotNil(t, ctx)
+ require.NoError(t, ctx.Err())
+}