login.go

  1package cmd
  2
import (
	"cmp"
	"context"
	"errors"
	"fmt"
	"os"
	"os/signal"

	"charm.land/lipgloss/v2"
	"github.com/atotto/clipboard"
	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
	"github.com/charmbracelet/crush/internal/client"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/oauth"
	"github.com/charmbracelet/crush/internal/oauth/copilot"
	"github.com/charmbracelet/crush/internal/oauth/hyper"
	"github.com/charmbracelet/x/ansi"
	"github.com/pkg/browser"
	"github.com/spf13/cobra"
)
 22
 23var loginCmd = &cobra.Command{
 24	Aliases: []string{"auth"},
 25	Use:     "login [platform]",
 26	Short:   "Login Crush to a platform",
 27	Long: `Login Crush to a specified platform.
 28The platform should be provided as an argument.
 29Available platforms are: hyper, copilot.`,
 30	Example: `
 31# Authenticate with Charm Hyper
 32crush login
 33
 34# Authenticate with GitHub Copilot
 35crush login copilot
 36  `,
 37	ValidArgs: []cobra.Completion{
 38		"hyper",
 39		"copilot",
 40		"github",
 41		"github-copilot",
 42	},
 43	Args: cobra.MaximumNArgs(1),
 44	RunE: func(cmd *cobra.Command, args []string) error {
 45		c, ws, cleanup, err := connectToServer(cmd)
 46		if err != nil {
 47			return err
 48		}
 49		defer cleanup()
 50
 51		progressEnabled := ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
 52		if progressEnabled && supportsProgressBar() {
 53			_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
 54			defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
 55		}
 56
 57		provider := "hyper"
 58		if len(args) > 0 {
 59			provider = args[0]
 60		}
 61		switch provider {
 62		case "hyper":
 63			return loginHyper(c, ws.ID)
 64		case "copilot", "github", "github-copilot":
 65			return loginCopilot(cmd.Context(), c, ws.ID)
 66		default:
 67			return fmt.Errorf("unknown platform: %s", args[0])
 68		}
 69	},
 70}
 71
 72func loginHyper(c *client.Client, wsID string) error {
 73	if !hyperp.Enabled() {
 74		return fmt.Errorf("hyper not enabled")
 75	}
 76	ctx := getLoginContext()
 77
 78	resp, err := hyper.InitiateDeviceAuth(ctx)
 79	if err != nil {
 80		return err
 81	}
 82
 83	if clipboard.WriteAll(resp.UserCode) == nil {
 84		fmt.Println("The following code should be on clipboard already:")
 85	} else {
 86		fmt.Println("Copy the following code:")
 87	}
 88
 89	fmt.Println()
 90	fmt.Println(lipgloss.NewStyle().Bold(true).Render(resp.UserCode))
 91	fmt.Println()
 92	fmt.Println("Press enter to open this URL, and then paste it there:")
 93	fmt.Println()
 94	fmt.Println(lipgloss.NewStyle().Hyperlink(resp.VerificationURL, "id=hyper").Render(resp.VerificationURL))
 95	fmt.Println()
 96	waitEnter()
 97	if err := browser.OpenURL(resp.VerificationURL); err != nil {
 98		fmt.Println("Could not open the URL. You'll need to manually open the URL in your browser.")
 99	}
100
101	fmt.Println("Exchanging authorization code...")
102	refreshToken, err := hyper.PollForToken(ctx, resp.DeviceCode, resp.ExpiresIn)
103	if err != nil {
104		return err
105	}
106
107	fmt.Println("Exchanging refresh token for access token...")
108	token, err := hyper.ExchangeToken(ctx, refreshToken)
109	if err != nil {
110		return err
111	}
112
113	fmt.Println("Verifying access token...")
114	introspect, err := hyper.IntrospectToken(ctx, token.AccessToken)
115	if err != nil {
116		return fmt.Errorf("token introspection failed: %w", err)
117	}
118	if !introspect.Active {
119		return fmt.Errorf("access token is not active")
120	}
121
122	if err := cmp.Or(
123		c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.api_key", token.AccessToken),
124		c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.oauth", token),
125	); err != nil {
126		return err
127	}
128
129	fmt.Println()
130	fmt.Println("You're now authenticated with Hyper!")
131	return nil
132}
133
134func loginCopilot(ctx context.Context, c *client.Client, wsID string) error {
135	loginCtx := getLoginContext()
136
137	cfg, err := c.GetConfig(ctx, wsID)
138	if err == nil && cfg != nil {
139		if pc, ok := cfg.Providers.Get("copilot"); ok && pc.OAuthToken != nil {
140			fmt.Println("You are already logged in to GitHub Copilot.")
141			return nil
142		}
143	}
144
145	diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
146	var token *oauth.Token
147
148	switch {
149	case hasDiskToken:
150		fmt.Println("Found existing GitHub Copilot token on disk. Using it to authenticate...")
151
152		t, err := copilot.RefreshToken(loginCtx, diskToken)
153		if err != nil {
154			return fmt.Errorf("unable to refresh token from disk: %w", err)
155		}
156		token = t
157	default:
158		fmt.Println("Requesting device code from GitHub...")
159		dc, err := copilot.RequestDeviceCode(loginCtx)
160		if err != nil {
161			return err
162		}
163
164		fmt.Println()
165		fmt.Println("Open the following URL and follow the instructions to authenticate with GitHub Copilot:")
166		fmt.Println()
167		fmt.Println(lipgloss.NewStyle().Hyperlink(dc.VerificationURI, "id=copilot").Render(dc.VerificationURI))
168		fmt.Println()
169		fmt.Println("Code:", lipgloss.NewStyle().Bold(true).Render(dc.UserCode))
170		fmt.Println()
171		fmt.Println("Waiting for authorization...")
172
173		t, err := copilot.PollForToken(loginCtx, dc)
174		if err == copilot.ErrNotAvailable {
175			fmt.Println()
176			fmt.Println("GitHub Copilot is unavailable for this account. To signup, go to the following page:")
177			fmt.Println()
178			fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.SignupURL, "id=copilot-signup").Render(copilot.SignupURL))
179			fmt.Println()
180			fmt.Println("You may be able to request free access if eligible. For more information, see:")
181			fmt.Println()
182			fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.FreeURL, "id=copilot-free").Render(copilot.FreeURL))
183		}
184		if err != nil {
185			return err
186		}
187		token = t
188	}
189
190	if err := cmp.Or(
191		c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
192		c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.oauth", token),
193	); err != nil {
194		return err
195	}
196
197	fmt.Println()
198	fmt.Println("You're now authenticated with GitHub Copilot!")
199	return nil
200}
201
// getLoginContext returns a context for the interactive login flows. It is
// derived from context.Background and cancelled when the process receives
// os.Interrupt (Ctrl+C).
//
// On interrupt, the background goroutine cancels the context and exits the
// process immediately with status 1. Because os.Exit is used, deferred
// functions in the callers do not run. The goroutine and the signal
// registration stay alive for the rest of the process.
func getLoginContext() context.Context {
	// os.Kill (SIGKILL) is intentionally not registered: it cannot be caught
	// or handled, so passing it to signal.NotifyContext has no effect.
	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
	go func() {
		<-ctx.Done()
		cancel()
		os.Exit(1)
	}()
	return ctx
}
211
// waitEnter blocks until a line is read from standard input, typically the
// user pressing Enter. It also returns on EOF or unexpected input. The scan
// result and error are intentionally discarded, because this is only a pause.
func waitEnter() {
	_, _ = fmt.Fscanln(os.Stdin)
}