login.go

  1package cmd
  2
import (
	"cmp"
	"context"
	"errors"
	"fmt"
	"os"
	"os/signal"

	"charm.land/lipgloss/v2"
	"github.com/atotto/clipboard"
	"github.com/charmbracelet/crush/internal/client"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/oauth"
	"github.com/charmbracelet/crush/internal/oauth/copilot"
	"github.com/charmbracelet/crush/internal/oauth/hyper"
	"github.com/charmbracelet/x/ansi"
	"github.com/pkg/browser"
	"github.com/spf13/cobra"
)
 21
 22var loginCmd = &cobra.Command{
 23	Aliases: []string{"auth"},
 24	Use:     "login [platform]",
 25	Short:   "Login Crush to a platform",
 26	Long: `Login Crush to a specified platform.
 27The platform should be provided as an argument.
 28Available platforms are: hyper, copilot.`,
 29	Example: `
 30# Authenticate with Charm Hyper
 31crush login
 32
 33# Authenticate with GitHub Copilot
 34crush login copilot
 35
 36# Force re-authentication even if already logged in
 37crush login -f copilot
 38  `,
 39	ValidArgs: []cobra.Completion{
 40		"hyper",
 41		"copilot",
 42		"github",
 43		"github-copilot",
 44	},
 45	Args: cobra.MaximumNArgs(1),
 46	RunE: func(cmd *cobra.Command, args []string) error {
 47		c, ws, cleanup, err := connectToServer(cmd)
 48		if err != nil {
 49			return err
 50		}
 51		defer cleanup()
 52
 53		progressEnabled := ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
 54		if progressEnabled && supportsProgressBar() {
 55			_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
 56			defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
 57		}
 58
 59		provider := "hyper"
 60		if len(args) > 0 {
 61			provider = args[0]
 62		}
 63		force, _ := cmd.Flags().GetBool("force")
 64		switch provider {
 65		case "hyper":
 66			return loginHyper(c, ws.ID, force)
 67		case "copilot", "github", "github-copilot":
 68			return loginCopilot(c, ws.ID, force)
 69		default:
 70			return fmt.Errorf("unknown platform: %s", args[0])
 71		}
 72	},
 73}
 74
 75func init() {
 76	loginCmd.Flags().BoolP("force", "f", false, "Force re-authentication even if already logged in")
 77}
 78
 79func loginHyper(c *client.Client, wsID string, force bool) error {
 80	ctx := getLoginContext()
 81
 82	if !force {
 83		cfg, err := c.GetConfig(ctx, wsID)
 84		if err == nil && cfg != nil {
 85			if pc, ok := cfg.Providers.Get("hyper"); ok && pc.OAuthToken != nil {
 86				fmt.Println("You are already logged in to Hyper.")
 87				fmt.Println("Use --force to re-authenticate.")
 88				return nil
 89			}
 90		}
 91	}
 92
 93	resp, err := hyper.InitiateDeviceAuth(ctx)
 94	if err != nil {
 95		return err
 96	}
 97
 98	if clipboard.WriteAll(resp.UserCode) == nil {
 99		fmt.Println("The following code should be on clipboard already:")
100	} else {
101		fmt.Println("Copy the following code:")
102	}
103
104	fmt.Println()
105	fmt.Println(lipgloss.NewStyle().Bold(true).Render(resp.UserCode))
106	fmt.Println()
107	fmt.Println("Press enter to open this URL, and then paste it there:")
108	fmt.Println()
109	fmt.Println(lipgloss.NewStyle().Hyperlink(resp.VerificationURL, "id=hyper").Render(resp.VerificationURL))
110	fmt.Println()
111	waitEnter()
112	if err := browser.OpenURL(resp.VerificationURL); err != nil {
113		fmt.Println("Could not open the URL. You'll need to manually open the URL in your browser.")
114	}
115
116	fmt.Println("Exchanging authorization code...")
117	refreshToken, err := hyper.PollForToken(ctx, resp.DeviceCode, resp.ExpiresIn)
118	if err != nil {
119		return err
120	}
121
122	fmt.Println("Exchanging refresh token for access token...")
123	token, err := hyper.ExchangeToken(ctx, refreshToken)
124	if err != nil {
125		return err
126	}
127
128	fmt.Println("Verifying access token...")
129	introspect, err := hyper.IntrospectToken(ctx, token.AccessToken)
130	if err != nil {
131		return fmt.Errorf("token introspection failed: %w", err)
132	}
133	if !introspect.Active {
134		return fmt.Errorf("access token is not active")
135	}
136
137	if err := cmp.Or(
138		c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.api_key", token.AccessToken),
139		c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.oauth", token),
140	); err != nil {
141		return err
142	}
143
144	fmt.Println()
145	fmt.Println("You're now authenticated with Hyper!")
146	return nil
147}
148
149func loginCopilot(c *client.Client, wsID string, force bool) error {
150	loginCtx := getLoginContext()
151
152	if !force {
153		cfg, err := c.GetConfig(loginCtx, wsID)
154		if err == nil && cfg != nil {
155			if pc, ok := cfg.Providers.Get("copilot"); ok && pc.OAuthToken != nil {
156				fmt.Println("You are already logged in to GitHub Copilot.")
157				fmt.Println("Use --force to re-authenticate.")
158				return nil
159			}
160		}
161	}
162
163	diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
164	var token *oauth.Token
165
166	switch {
167	case hasDiskToken:
168		fmt.Println("Found existing GitHub Copilot token on disk. Using it to authenticate...")
169
170		t, err := copilot.RefreshToken(loginCtx, diskToken)
171		if err != nil {
172			return fmt.Errorf("unable to refresh token from disk: %w", err)
173		}
174		token = t
175	default:
176		fmt.Println("Requesting device code from GitHub...")
177		dc, err := copilot.RequestDeviceCode(loginCtx)
178		if err != nil {
179			return err
180		}
181
182		fmt.Println()
183		fmt.Println("Open the following URL and follow the instructions to authenticate with GitHub Copilot:")
184		fmt.Println()
185		fmt.Println(lipgloss.NewStyle().Hyperlink(dc.VerificationURI, "id=copilot").Render(dc.VerificationURI))
186		fmt.Println()
187		fmt.Println("Code:", lipgloss.NewStyle().Bold(true).Render(dc.UserCode))
188		fmt.Println()
189		fmt.Println("Waiting for authorization...")
190
191		t, err := copilot.PollForToken(loginCtx, dc)
192		if err == copilot.ErrNotAvailable {
193			fmt.Println()
194			fmt.Println("GitHub Copilot is unavailable for this account. To signup, go to the following page:")
195			fmt.Println()
196			fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.SignupURL, "id=copilot-signup").Render(copilot.SignupURL))
197			fmt.Println()
198			fmt.Println("You may be able to request free access if eligible. For more information, see:")
199			fmt.Println()
200			fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.FreeURL, "id=copilot-free").Render(copilot.FreeURL))
201		}
202		if err != nil {
203			return err
204		}
205		token = t
206	}
207
208	if err := cmp.Or(
209		c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
210		c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.oauth", token),
211	); err != nil {
212		return err
213	}
214
215	fmt.Println()
216	fmt.Println("You're now authenticated with GitHub Copilot!")
217	return nil
218}
219
// getLoginContext returns a context that is cancelled when the process
// receives an interrupt signal (e.g. Ctrl-C). A background goroutine waits
// for that cancellation, releases the signal registration, and terminates
// the process with exit status 1.
//
// Only os.Interrupt is registered: os.Kill (SIGKILL) can never be caught by
// a program, so listing it had no effect.
//
// NOTE: os.Exit does not run deferred functions, so any defers in the caller
// are skipped when the login is interrupted.
func getLoginContext() context.Context {
	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
	go func() {
		<-ctx.Done()
		cancel()
		os.Exit(1)
	}()
	return ctx
}
229
// waitEnter blocks until a line (or end of input) has been read from
// standard input. It is used as a "press enter to continue" pause.
func waitEnter() {
	if _, err := fmt.Scanln(); err != nil {
		// Best effort: whatever was read (extra text, EOF, ...) is
		// irrelevant here; the call only serves as a pause.
		return
	}
}