1package cmd
2
import (
	"cmp"
	"context"
	"errors"
	"fmt"
	"os"
	"os/signal"

	"charm.land/lipgloss/v2"
	"github.com/atotto/clipboard"
	"github.com/charmbracelet/crush/internal/client"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/oauth"
	"github.com/charmbracelet/crush/internal/oauth/copilot"
	"github.com/charmbracelet/crush/internal/oauth/hyper"
	"github.com/charmbracelet/x/ansi"
	"github.com/pkg/browser"
	"github.com/spf13/cobra"
)
21
22var loginCmd = &cobra.Command{
23 Aliases: []string{"auth"},
24 Use: "login [platform]",
25 Short: "Login Crush to a platform",
26 Long: `Login Crush to a specified platform.
27The platform should be provided as an argument.
28Available platforms are: hyper, copilot.`,
29 Example: `
30# Authenticate with Charm Hyper
31crush login
32
33# Authenticate with GitHub Copilot
34crush login copilot
35
36# Force re-authentication even if already logged in
37crush login -f copilot
38 `,
39 ValidArgs: []cobra.Completion{
40 "hyper",
41 "copilot",
42 "github",
43 "github-copilot",
44 },
45 Args: cobra.MaximumNArgs(1),
46 RunE: func(cmd *cobra.Command, args []string) error {
47 c, ws, cleanup, err := connectToServer(cmd)
48 if err != nil {
49 return err
50 }
51 defer cleanup()
52
53 progressEnabled := ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
54 if progressEnabled && supportsProgressBar() {
55 _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
56 defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
57 }
58
59 provider := "hyper"
60 if len(args) > 0 {
61 provider = args[0]
62 }
63 force, _ := cmd.Flags().GetBool("force")
64 switch provider {
65 case "hyper":
66 return loginHyper(c, ws.ID, force)
67 case "copilot", "github", "github-copilot":
68 return loginCopilot(c, ws.ID, force)
69 default:
70 return fmt.Errorf("unknown platform: %s", args[0])
71 }
72 },
73}
74
75func init() {
76 loginCmd.Flags().BoolP("force", "f", false, "Force re-authentication even if already logged in")
77}
78
79func loginHyper(c *client.Client, wsID string, force bool) error {
80 ctx := getLoginContext()
81
82 if !force {
83 cfg, err := c.GetConfig(ctx, wsID)
84 if err == nil && cfg != nil {
85 if pc, ok := cfg.Providers.Get("hyper"); ok && pc.OAuthToken != nil {
86 fmt.Println("You are already logged in to Hyper.")
87 fmt.Println("Use --force to re-authenticate.")
88 return nil
89 }
90 }
91 }
92
93 resp, err := hyper.InitiateDeviceAuth(ctx)
94 if err != nil {
95 return err
96 }
97
98 if clipboard.WriteAll(resp.UserCode) == nil {
99 fmt.Println("The following code should be on clipboard already:")
100 } else {
101 fmt.Println("Copy the following code:")
102 }
103
104 fmt.Println()
105 fmt.Println(lipgloss.NewStyle().Bold(true).Render(resp.UserCode))
106 fmt.Println()
107 fmt.Println("Press enter to open this URL, and then paste it there:")
108 fmt.Println()
109 fmt.Println(lipgloss.NewStyle().Hyperlink(resp.VerificationURL, "id=hyper").Render(resp.VerificationURL))
110 fmt.Println()
111 waitEnter()
112 if err := browser.OpenURL(resp.VerificationURL); err != nil {
113 fmt.Println("Could not open the URL. You'll need to manually open the URL in your browser.")
114 }
115
116 fmt.Println("Exchanging authorization code...")
117 refreshToken, err := hyper.PollForToken(ctx, resp.DeviceCode, resp.ExpiresIn)
118 if err != nil {
119 return err
120 }
121
122 fmt.Println("Exchanging refresh token for access token...")
123 token, err := hyper.ExchangeToken(ctx, refreshToken)
124 if err != nil {
125 return err
126 }
127
128 fmt.Println("Verifying access token...")
129 introspect, err := hyper.IntrospectToken(ctx, token.AccessToken)
130 if err != nil {
131 return fmt.Errorf("token introspection failed: %w", err)
132 }
133 if !introspect.Active {
134 return fmt.Errorf("access token is not active")
135 }
136
137 if err := cmp.Or(
138 c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.api_key", token.AccessToken),
139 c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.oauth", token),
140 ); err != nil {
141 return err
142 }
143
144 fmt.Println()
145 fmt.Println("You're now authenticated with Hyper!")
146 return nil
147}
148
149func loginCopilot(c *client.Client, wsID string, force bool) error {
150 loginCtx := getLoginContext()
151
152 if !force {
153 cfg, err := c.GetConfig(loginCtx, wsID)
154 if err == nil && cfg != nil {
155 if pc, ok := cfg.Providers.Get("copilot"); ok && pc.OAuthToken != nil {
156 fmt.Println("You are already logged in to GitHub Copilot.")
157 fmt.Println("Use --force to re-authenticate.")
158 return nil
159 }
160 }
161 }
162
163 diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
164 var token *oauth.Token
165
166 switch {
167 case hasDiskToken:
168 fmt.Println("Found existing GitHub Copilot token on disk. Using it to authenticate...")
169
170 t, err := copilot.RefreshToken(loginCtx, diskToken)
171 if err != nil {
172 return fmt.Errorf("unable to refresh token from disk: %w", err)
173 }
174 token = t
175 default:
176 fmt.Println("Requesting device code from GitHub...")
177 dc, err := copilot.RequestDeviceCode(loginCtx)
178 if err != nil {
179 return err
180 }
181
182 fmt.Println()
183 fmt.Println("Open the following URL and follow the instructions to authenticate with GitHub Copilot:")
184 fmt.Println()
185 fmt.Println(lipgloss.NewStyle().Hyperlink(dc.VerificationURI, "id=copilot").Render(dc.VerificationURI))
186 fmt.Println()
187 fmt.Println("Code:", lipgloss.NewStyle().Bold(true).Render(dc.UserCode))
188 fmt.Println()
189 fmt.Println("Waiting for authorization...")
190
191 t, err := copilot.PollForToken(loginCtx, dc)
192 if err == copilot.ErrNotAvailable {
193 fmt.Println()
194 fmt.Println("GitHub Copilot is unavailable for this account. To signup, go to the following page:")
195 fmt.Println()
196 fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.SignupURL, "id=copilot-signup").Render(copilot.SignupURL))
197 fmt.Println()
198 fmt.Println("You may be able to request free access if eligible. For more information, see:")
199 fmt.Println()
200 fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.FreeURL, "id=copilot-free").Render(copilot.FreeURL))
201 }
202 if err != nil {
203 return err
204 }
205 token = t
206 }
207
208 if err := cmp.Or(
209 c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
210 c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.oauth", token),
211 ); err != nil {
212 return err
213 }
214
215 fmt.Println()
216 fmt.Println("You're now authenticated with GitHub Copilot!")
217 return nil
218}
219
// getLoginContext returns a context that is cancelled when the process
// receives an interrupt (Ctrl-C).
//
// Parts of the login flows do not observe the context (for example waitEnter
// blocks reading stdin). So on interrupt, the watcher goroutine terminates the
// process with exit status 1 instead of relying on callers to unwind. As a
// consequence, deferred functions in callers do not run on interrupt.
//
// os.Kill is deliberately not registered: SIGKILL cannot be caught or
// handled, so listing it had no effect. The watcher goroutine lives until an
// interrupt arrives or the process exits. That is presumably acceptable
// because this is used by one-shot CLI commands.
func getLoginContext() context.Context {
	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
	go func() {
		<-ctx.Done()
		cancel()
		os.Exit(1)
	}()
	return ctx
}
229
// waitEnter blocks until a line is read from standard input, i.e. until the
// user presses Enter, or stdin reaches EOF or fails.
func waitEnter() {
	// Read errors (EOF, closed stdin, stray text on the line) are deliberately
	// ignored: this is only a pause before the login flow continues.
	_, _ = fmt.Fscanln(os.Stdin)
}