1package cmd
2
import (
	"cmp"
	"context"
	"errors"
	"fmt"
	"os"
	"os/signal"

	"charm.land/lipgloss/v2"
	"github.com/atotto/clipboard"
	"github.com/charmbracelet/crush/internal/client"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/oauth"
	"github.com/charmbracelet/crush/internal/oauth/copilot"
	"github.com/charmbracelet/crush/internal/oauth/hyper"
	"github.com/charmbracelet/x/ansi"
	"github.com/pkg/browser"
	"github.com/spf13/cobra"
)
21
22var loginCmd = &cobra.Command{
23 Aliases: []string{"auth"},
24 Use: "login [platform]",
25 Short: "Login Crush to a platform",
26 Long: `Login Crush to a specified platform.
27The platform should be provided as an argument.
28Available platforms are: hyper, copilot.`,
29 Example: `
30# Authenticate with Charm Hyper
31crush login
32
33# Authenticate with GitHub Copilot
34crush login copilot
35 `,
36 ValidArgs: []cobra.Completion{
37 "hyper",
38 "copilot",
39 "github",
40 "github-copilot",
41 },
42 Args: cobra.MaximumNArgs(1),
43 RunE: func(cmd *cobra.Command, args []string) error {
44 c, ws, cleanup, err := connectToServer(cmd)
45 if err != nil {
46 return err
47 }
48 defer cleanup()
49
50 progressEnabled := ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
51 if progressEnabled && supportsProgressBar() {
52 _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
53 defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
54 }
55
56 provider := "hyper"
57 if len(args) > 0 {
58 provider = args[0]
59 }
60 switch provider {
61 case "hyper":
62 return loginHyper(c, ws.ID)
63 case "copilot", "github", "github-copilot":
64 return loginCopilot(cmd.Context(), c, ws.ID)
65 default:
66 return fmt.Errorf("unknown platform: %s", args[0])
67 }
68 },
69}
70
71func loginHyper(c *client.Client, wsID string) error {
72 ctx := getLoginContext()
73
74 resp, err := hyper.InitiateDeviceAuth(ctx)
75 if err != nil {
76 return err
77 }
78
79 if clipboard.WriteAll(resp.UserCode) == nil {
80 fmt.Println("The following code should be on clipboard already:")
81 } else {
82 fmt.Println("Copy the following code:")
83 }
84
85 fmt.Println()
86 fmt.Println(lipgloss.NewStyle().Bold(true).Render(resp.UserCode))
87 fmt.Println()
88 fmt.Println("Press enter to open this URL, and then paste it there:")
89 fmt.Println()
90 fmt.Println(lipgloss.NewStyle().Hyperlink(resp.VerificationURL, "id=hyper").Render(resp.VerificationURL))
91 fmt.Println()
92 waitEnter()
93 if err := browser.OpenURL(resp.VerificationURL); err != nil {
94 fmt.Println("Could not open the URL. You'll need to manually open the URL in your browser.")
95 }
96
97 fmt.Println("Exchanging authorization code...")
98 refreshToken, err := hyper.PollForToken(ctx, resp.DeviceCode, resp.ExpiresIn)
99 if err != nil {
100 return err
101 }
102
103 fmt.Println("Exchanging refresh token for access token...")
104 token, err := hyper.ExchangeToken(ctx, refreshToken)
105 if err != nil {
106 return err
107 }
108
109 fmt.Println("Verifying access token...")
110 introspect, err := hyper.IntrospectToken(ctx, token.AccessToken)
111 if err != nil {
112 return fmt.Errorf("token introspection failed: %w", err)
113 }
114 if !introspect.Active {
115 return fmt.Errorf("access token is not active")
116 }
117
118 if err := cmp.Or(
119 c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.api_key", token.AccessToken),
120 c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.oauth", token),
121 ); err != nil {
122 return err
123 }
124
125 fmt.Println()
126 fmt.Println("You're now authenticated with Hyper!")
127 return nil
128}
129
130func loginCopilot(ctx context.Context, c *client.Client, wsID string) error {
131 loginCtx := getLoginContext()
132
133 cfg, err := c.GetConfig(ctx, wsID)
134 if err == nil && cfg != nil {
135 if pc, ok := cfg.Providers.Get("copilot"); ok && pc.OAuthToken != nil {
136 fmt.Println("You are already logged in to GitHub Copilot.")
137 return nil
138 }
139 }
140
141 diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
142 var token *oauth.Token
143
144 switch {
145 case hasDiskToken:
146 fmt.Println("Found existing GitHub Copilot token on disk. Using it to authenticate...")
147
148 t, err := copilot.RefreshToken(loginCtx, diskToken)
149 if err != nil {
150 return fmt.Errorf("unable to refresh token from disk: %w", err)
151 }
152 token = t
153 default:
154 fmt.Println("Requesting device code from GitHub...")
155 dc, err := copilot.RequestDeviceCode(loginCtx)
156 if err != nil {
157 return err
158 }
159
160 fmt.Println()
161 fmt.Println("Open the following URL and follow the instructions to authenticate with GitHub Copilot:")
162 fmt.Println()
163 fmt.Println(lipgloss.NewStyle().Hyperlink(dc.VerificationURI, "id=copilot").Render(dc.VerificationURI))
164 fmt.Println()
165 fmt.Println("Code:", lipgloss.NewStyle().Bold(true).Render(dc.UserCode))
166 fmt.Println()
167 fmt.Println("Waiting for authorization...")
168
169 t, err := copilot.PollForToken(loginCtx, dc)
170 if err == copilot.ErrNotAvailable {
171 fmt.Println()
172 fmt.Println("GitHub Copilot is unavailable for this account. To signup, go to the following page:")
173 fmt.Println()
174 fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.SignupURL, "id=copilot-signup").Render(copilot.SignupURL))
175 fmt.Println()
176 fmt.Println("You may be able to request free access if eligible. For more information, see:")
177 fmt.Println()
178 fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.FreeURL, "id=copilot-free").Render(copilot.FreeURL))
179 }
180 if err != nil {
181 return err
182 }
183 token = t
184 }
185
186 if err := cmp.Or(
187 c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
188 c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.oauth", token),
189 ); err != nil {
190 return err
191 }
192
193 fmt.Println()
194 fmt.Println("You're now authenticated with GitHub Copilot!")
195 return nil
196}
197
// getLoginContext returns a context for the interactive login flows that is
// cancelled when the process receives an interrupt signal (e.g. Ctrl+C).
//
// A background goroutine waits for that cancellation, releases the signal
// registration and then terminates the process with exit status 1. Because
// os.Exit is used, deferred functions in callers do not run on interrupt.
// The goroutine lives until the process exits.
func getLoginContext() context.Context {
	// Only os.Interrupt is registered: os.Kill (SIGKILL) cannot be trapped by
	// a program, so listing it would have no effect.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
	go func() {
		<-ctx.Done()
		stop()
		os.Exit(1)
	}()
	return ctx
}
207
// waitEnter blocks until a scan of standard input finishes, which in the
// interactive case means the user pressed Enter. The scan count and any error
// (such as end of input) are deliberately discarded: the call is only used as
// a pause.
func waitEnter() {
	_, _ = fmt.Fscanln(os.Stdin)
}