login.go

  1package cmd
  2
import (
	"cmp"
	"context"
	"errors"
	"fmt"
	"os"
	"os/signal"

	"charm.land/lipgloss/v2"
	"github.com/atotto/clipboard"
	"github.com/charmbracelet/crush/internal/client"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/oauth"
	"github.com/charmbracelet/crush/internal/oauth/copilot"
	"github.com/charmbracelet/crush/internal/oauth/hyper"
	"github.com/charmbracelet/x/ansi"
	"github.com/pkg/browser"
	"github.com/spf13/cobra"
)
 21
 22var loginCmd = &cobra.Command{
 23	Aliases: []string{"auth"},
 24	Use:     "login [platform]",
 25	Short:   "Login Crush to a platform",
 26	Long: `Login Crush to a specified platform.
 27The platform should be provided as an argument.
 28Available platforms are: hyper, copilot.`,
 29	Example: `
 30# Authenticate with Charm Hyper
 31crush login
 32
 33# Authenticate with GitHub Copilot
 34crush login copilot
 35  `,
 36	ValidArgs: []cobra.Completion{
 37		"hyper",
 38		"copilot",
 39		"github",
 40		"github-copilot",
 41	},
 42	Args: cobra.MaximumNArgs(1),
 43	RunE: func(cmd *cobra.Command, args []string) error {
 44		c, ws, cleanup, err := connectToServer(cmd)
 45		if err != nil {
 46			return err
 47		}
 48		defer cleanup()
 49
 50		progressEnabled := ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
 51		if progressEnabled && supportsProgressBar() {
 52			_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
 53			defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
 54		}
 55
 56		provider := "hyper"
 57		if len(args) > 0 {
 58			provider = args[0]
 59		}
 60		switch provider {
 61		case "hyper":
 62			return loginHyper(c, ws.ID)
 63		case "copilot", "github", "github-copilot":
 64			return loginCopilot(cmd.Context(), c, ws.ID)
 65		default:
 66			return fmt.Errorf("unknown platform: %s", args[0])
 67		}
 68	},
 69}
 70
 71func loginHyper(c *client.Client, wsID string) error {
 72	ctx := getLoginContext()
 73
 74	resp, err := hyper.InitiateDeviceAuth(ctx)
 75	if err != nil {
 76		return err
 77	}
 78
 79	if clipboard.WriteAll(resp.UserCode) == nil {
 80		fmt.Println("The following code should be on clipboard already:")
 81	} else {
 82		fmt.Println("Copy the following code:")
 83	}
 84
 85	fmt.Println()
 86	fmt.Println(lipgloss.NewStyle().Bold(true).Render(resp.UserCode))
 87	fmt.Println()
 88	fmt.Println("Press enter to open this URL, and then paste it there:")
 89	fmt.Println()
 90	fmt.Println(lipgloss.NewStyle().Hyperlink(resp.VerificationURL, "id=hyper").Render(resp.VerificationURL))
 91	fmt.Println()
 92	waitEnter()
 93	if err := browser.OpenURL(resp.VerificationURL); err != nil {
 94		fmt.Println("Could not open the URL. You'll need to manually open the URL in your browser.")
 95	}
 96
 97	fmt.Println("Exchanging authorization code...")
 98	refreshToken, err := hyper.PollForToken(ctx, resp.DeviceCode, resp.ExpiresIn)
 99	if err != nil {
100		return err
101	}
102
103	fmt.Println("Exchanging refresh token for access token...")
104	token, err := hyper.ExchangeToken(ctx, refreshToken)
105	if err != nil {
106		return err
107	}
108
109	fmt.Println("Verifying access token...")
110	introspect, err := hyper.IntrospectToken(ctx, token.AccessToken)
111	if err != nil {
112		return fmt.Errorf("token introspection failed: %w", err)
113	}
114	if !introspect.Active {
115		return fmt.Errorf("access token is not active")
116	}
117
118	if err := cmp.Or(
119		c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.api_key", token.AccessToken),
120		c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.oauth", token),
121	); err != nil {
122		return err
123	}
124
125	fmt.Println()
126	fmt.Println("You're now authenticated with Hyper!")
127	return nil
128}
129
130func loginCopilot(ctx context.Context, c *client.Client, wsID string) error {
131	loginCtx := getLoginContext()
132
133	cfg, err := c.GetConfig(ctx, wsID)
134	if err == nil && cfg != nil {
135		if pc, ok := cfg.Providers.Get("copilot"); ok && pc.OAuthToken != nil {
136			fmt.Println("You are already logged in to GitHub Copilot.")
137			return nil
138		}
139	}
140
141	diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
142	var token *oauth.Token
143
144	switch {
145	case hasDiskToken:
146		fmt.Println("Found existing GitHub Copilot token on disk. Using it to authenticate...")
147
148		t, err := copilot.RefreshToken(loginCtx, diskToken)
149		if err != nil {
150			return fmt.Errorf("unable to refresh token from disk: %w", err)
151		}
152		token = t
153	default:
154		fmt.Println("Requesting device code from GitHub...")
155		dc, err := copilot.RequestDeviceCode(loginCtx)
156		if err != nil {
157			return err
158		}
159
160		fmt.Println()
161		fmt.Println("Open the following URL and follow the instructions to authenticate with GitHub Copilot:")
162		fmt.Println()
163		fmt.Println(lipgloss.NewStyle().Hyperlink(dc.VerificationURI, "id=copilot").Render(dc.VerificationURI))
164		fmt.Println()
165		fmt.Println("Code:", lipgloss.NewStyle().Bold(true).Render(dc.UserCode))
166		fmt.Println()
167		fmt.Println("Waiting for authorization...")
168
169		t, err := copilot.PollForToken(loginCtx, dc)
170		if err == copilot.ErrNotAvailable {
171			fmt.Println()
172			fmt.Println("GitHub Copilot is unavailable for this account. To signup, go to the following page:")
173			fmt.Println()
174			fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.SignupURL, "id=copilot-signup").Render(copilot.SignupURL))
175			fmt.Println()
176			fmt.Println("You may be able to request free access if eligible. For more information, see:")
177			fmt.Println()
178			fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.FreeURL, "id=copilot-free").Render(copilot.FreeURL))
179		}
180		if err != nil {
181			return err
182		}
183		token = t
184	}
185
186	if err := cmp.Or(
187		c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
188		c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.oauth", token),
189	); err != nil {
190		return err
191	}
192
193	fmt.Println()
194	fmt.Println("You're now authenticated with GitHub Copilot!")
195	return nil
196}
197
// getLoginContext returns a background-derived context that is cancelled
// when the process receives an interrupt (Ctrl+C). When that happens, the
// process exits immediately with status 1. This lets a user abort a login
// even while it is blocked reading the terminal or polling the network.
// Deferred functions in callers do not run on that path.
//
// os.Kill is deliberately not registered: SIGKILL cannot be caught, so
// passing it to signal.NotifyContext would have no effect.
func getLoginContext() context.Context {
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
	go func() {
		<-ctx.Done()
		// Unregister the handler before terminating.
		stop()
		os.Exit(1)
	}()
	return ctx
}
207
// waitEnter pauses until the user presses enter on standard input.
// The scan result is intentionally discarded: EOF or stray non-space input
// also ends the wait, which is acceptable for a "press enter" prompt.
func waitEnter() {
	_, _ = fmt.Fscanln(os.Stdin) // equivalent to fmt.Scanln()
}