logout.go

  1package cmd
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"os"
  8	"os/signal"
  9
 10	"charm.land/lipgloss/v2"
 11	"github.com/charmbracelet/crush/internal/client"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/x/ansi"
 14	"github.com/spf13/cobra"
 15)
 16
 17var (
 18	logoutHeaderStyle = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("205"))
 19	logoutItemStyle   = lipgloss.NewStyle().Foreground(lipgloss.Color("252"))
 20	logoutPromptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("215"))
 21)
 22
 23var logoutCmd = &cobra.Command{
 24	Aliases: []string{"signout"},
 25	Use:     "logout [platform]",
 26	Short:   "Logout Crush from a platform",
 27	Long: `Logout Crush from a specified platform, removing stored credentials.
 28The platform should be provided as an argument.
 29If no argument is given, a list of logged-in platforms will be shown.
 30Available platforms are: hyper, copilot.`,
 31	Example: `
 32# Sign out from Charm Hyper
 33crush logout hyper
 34
 35# Sign out from GitHub Copilot
 36crush logout copilot
 37  `,
 38	ValidArgs: []cobra.Completion{
 39		"hyper",
 40		"copilot",
 41		"github",
 42		"github-copilot",
 43	},
 44	Args: cobra.MaximumNArgs(1),
 45	RunE: func(cmd *cobra.Command, args []string) error {
 46		c, ws, cleanup, err := connectToServer(cmd)
 47		if err != nil {
 48			return err
 49		}
 50		defer cleanup()
 51
 52		progressEnabled := ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
 53		if progressEnabled && supportsProgressBar() {
 54			_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
 55			defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
 56		}
 57
 58		var provider string
 59		if len(args) == 0 {
 60			provider, err = pickLoggedInProvider(c, ws.ID)
 61			if err != nil {
 62				return err
 63			}
 64			if provider == "" {
 65				return nil
 66			}
 67		} else {
 68			provider = args[0]
 69		}
 70
 71		force, _ := cmd.Flags().GetBool("force")
 72		if !force {
 73			fmt.Print(logoutPromptStyle.Render(fmt.Sprintf("Are you sure you want to logout %s? (y/N) ", provider)))
 74			var response string
 75			_, err := fmt.Scanln(&response)
 76			if err != nil || (response != "y" && response != "Y" && response != "yes" && response != "Yes" && response != "YES") {
 77				fmt.Println(logoutHeaderStyle.Render("Logout cancelled."))
 78				return nil
 79			}
 80		}
 81
 82		switch provider {
 83		case "hyper":
 84			return logoutHyper(c, ws.ID)
 85		case "copilot", "github", "github-copilot":
 86			return logoutCopilot(c, ws.ID)
 87		default:
 88			return fmt.Errorf("unknown platform: %s", provider)
 89		}
 90	},
 91}
 92
 93func logoutHyper(c *client.Client, wsID string) error {
 94	ctx := getLogoutContext()
 95
 96	if err := cmp.Or(
 97		c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.api_key"),
 98		c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.oauth"),
 99	); err != nil {
100		return err
101	}
102
103	fmt.Println(logoutHeaderStyle.Render("Successfully logged out of Hyper."))
104	return nil
105}
106
107func logoutCopilot(c *client.Client, wsID string) error {
108	ctx := getLogoutContext()
109
110	if err := cmp.Or(
111		c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.copilot.api_key"),
112		c.RemoveConfigField(ctx, wsID, config.ScopeGlobal, "providers.copilot.oauth"),
113	); err != nil {
114		return err
115	}
116
117	fmt.Println(logoutHeaderStyle.Render("Successfully logged out of GitHub Copilot."))
118	return nil
119}
120
121func pickLoggedInProvider(c *client.Client, wsID string) (string, error) {
122	ctx := getLogoutContext()
123
124	cfg, err := c.GetConfig(ctx, wsID)
125	if err != nil {
126		return "", fmt.Errorf("failed to get config: %w", err)
127	}
128
129	type loggedInProvider struct {
130		id   string
131		name string
132	}
133
134	var loggedIn []loggedInProvider
135	for p := range cfg.Providers.Seq() {
136		if p.OAuthToken != nil || p.APIKey != "" {
137			name := p.Name
138			if name == "" {
139				name = p.ID
140			}
141			loggedIn = append(loggedIn, loggedInProvider{id: p.ID, name: name})
142		}
143	}
144
145	if len(loggedIn) == 0 {
146		fmt.Println(logoutPromptStyle.Render("You are not logged in to any platform."))
147		return "", nil
148	}
149
150	if len(loggedIn) == 1 {
151		return loggedIn[0].id, nil
152	}
153
154	fmt.Println(logoutHeaderStyle.Render("Logged-in platforms:"))
155	for i, p := range loggedIn {
156		fmt.Println(logoutItemStyle.Render(fmt.Sprintf("  %d. %s", i+1, p.name)))
157	}
158	fmt.Print(logoutPromptStyle.Render(fmt.Sprintf("Select a platform to logout (1-%d): ", len(loggedIn))))
159
160	var choice int
161	_, err = fmt.Scanln(&choice)
162	if err != nil || choice < 1 || choice > len(loggedIn) {
163		fmt.Println(logoutHeaderStyle.Render("Logout cancelled."))
164		return "", nil
165	}
166
167	return loggedIn[choice-1].id, nil
168}
169
170func init() {
171	logoutCmd.Flags().BoolP("force", "f", false, "Skip logout confirmation prompt")
172}
173
// getLogoutContext returns a context that is cancelled when the process
// receives an interrupt signal. A background goroutine waits for that
// cancellation, releases the signal registration, and terminates the process
// with exit status 1.
//
// Every call registers its own signal handler and starts its own goroutine;
// the goroutine stays blocked until a signal arrives or the process ends.
func getLogoutContext() context.Context {
	// os.Kill is deliberately not listed: SIGKILL cannot be caught, so
	// requesting notification for it has no effect.
	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
	go func() {
		<-ctx.Done()
		cancel()
		os.Exit(1)
	}()
	return ctx
}