1package cmd
2
import (
	"cmp"
	"context"
	"errors"
	"fmt"
	"os"
	"os/signal"

	"charm.land/lipgloss/v2"
	"github.com/atotto/clipboard"
	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
	"github.com/charmbracelet/crush/internal/client"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/oauth"
	"github.com/charmbracelet/crush/internal/oauth/copilot"
	"github.com/charmbracelet/crush/internal/oauth/hyper"
	"github.com/charmbracelet/x/ansi"
	"github.com/pkg/browser"
	"github.com/spf13/cobra"
)
22
23var loginCmd = &cobra.Command{
24 Aliases: []string{"auth"},
25 Use: "login [platform]",
26 Short: "Login Crush to a platform",
27 Long: `Login Crush to a specified platform.
28The platform should be provided as an argument.
29Available platforms are: hyper, copilot.`,
30 Example: `
31# Authenticate with Charm Hyper
32crush login
33
34# Authenticate with GitHub Copilot
35crush login copilot
36 `,
37 ValidArgs: []cobra.Completion{
38 "hyper",
39 "copilot",
40 "github",
41 "github-copilot",
42 },
43 Args: cobra.MaximumNArgs(1),
44 RunE: func(cmd *cobra.Command, args []string) error {
45 c, ws, cleanup, err := connectToServer(cmd)
46 if err != nil {
47 return err
48 }
49 defer cleanup()
50
51 progressEnabled := ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
52 if progressEnabled && supportsProgressBar() {
53 _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
54 defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
55 }
56
57 provider := "hyper"
58 if len(args) > 0 {
59 provider = args[0]
60 }
61 switch provider {
62 case "hyper":
63 return loginHyper(c, ws.ID)
64 case "copilot", "github", "github-copilot":
65 return loginCopilot(cmd.Context(), c, ws.ID)
66 default:
67 return fmt.Errorf("unknown platform: %s", args[0])
68 }
69 },
70}
71
72func loginHyper(c *client.Client, wsID string) error {
73 if !hyperp.Enabled() {
74 return fmt.Errorf("hyper not enabled")
75 }
76 ctx := getLoginContext()
77
78 resp, err := hyper.InitiateDeviceAuth(ctx)
79 if err != nil {
80 return err
81 }
82
83 if clipboard.WriteAll(resp.UserCode) == nil {
84 fmt.Println("The following code should be on clipboard already:")
85 } else {
86 fmt.Println("Copy the following code:")
87 }
88
89 fmt.Println()
90 fmt.Println(lipgloss.NewStyle().Bold(true).Render(resp.UserCode))
91 fmt.Println()
92 fmt.Println("Press enter to open this URL, and then paste it there:")
93 fmt.Println()
94 fmt.Println(lipgloss.NewStyle().Hyperlink(resp.VerificationURL, "id=hyper").Render(resp.VerificationURL))
95 fmt.Println()
96 waitEnter()
97 if err := browser.OpenURL(resp.VerificationURL); err != nil {
98 fmt.Println("Could not open the URL. You'll need to manually open the URL in your browser.")
99 }
100
101 fmt.Println("Exchanging authorization code...")
102 refreshToken, err := hyper.PollForToken(ctx, resp.DeviceCode, resp.ExpiresIn)
103 if err != nil {
104 return err
105 }
106
107 fmt.Println("Exchanging refresh token for access token...")
108 token, err := hyper.ExchangeToken(ctx, refreshToken)
109 if err != nil {
110 return err
111 }
112
113 fmt.Println("Verifying access token...")
114 introspect, err := hyper.IntrospectToken(ctx, token.AccessToken)
115 if err != nil {
116 return fmt.Errorf("token introspection failed: %w", err)
117 }
118 if !introspect.Active {
119 return fmt.Errorf("access token is not active")
120 }
121
122 if err := cmp.Or(
123 c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.api_key", token.AccessToken),
124 c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.oauth", token),
125 ); err != nil {
126 return err
127 }
128
129 fmt.Println()
130 fmt.Println("You're now authenticated with Hyper!")
131 return nil
132}
133
134func loginCopilot(ctx context.Context, c *client.Client, wsID string) error {
135 loginCtx := getLoginContext()
136
137 cfg, err := c.GetConfig(ctx, wsID)
138 if err == nil && cfg != nil {
139 if pc, ok := cfg.Providers.Get("copilot"); ok && pc.OAuthToken != nil {
140 fmt.Println("You are already logged in to GitHub Copilot.")
141 return nil
142 }
143 }
144
145 diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
146 var token *oauth.Token
147
148 switch {
149 case hasDiskToken:
150 fmt.Println("Found existing GitHub Copilot token on disk. Using it to authenticate...")
151
152 t, err := copilot.RefreshToken(loginCtx, diskToken)
153 if err != nil {
154 return fmt.Errorf("unable to refresh token from disk: %w", err)
155 }
156 token = t
157 default:
158 fmt.Println("Requesting device code from GitHub...")
159 dc, err := copilot.RequestDeviceCode(loginCtx)
160 if err != nil {
161 return err
162 }
163
164 fmt.Println()
165 fmt.Println("Open the following URL and follow the instructions to authenticate with GitHub Copilot:")
166 fmt.Println()
167 fmt.Println(lipgloss.NewStyle().Hyperlink(dc.VerificationURI, "id=copilot").Render(dc.VerificationURI))
168 fmt.Println()
169 fmt.Println("Code:", lipgloss.NewStyle().Bold(true).Render(dc.UserCode))
170 fmt.Println()
171 fmt.Println("Waiting for authorization...")
172
173 t, err := copilot.PollForToken(loginCtx, dc)
174 if err == copilot.ErrNotAvailable {
175 fmt.Println()
176 fmt.Println("GitHub Copilot is unavailable for this account. To signup, go to the following page:")
177 fmt.Println()
178 fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.SignupURL, "id=copilot-signup").Render(copilot.SignupURL))
179 fmt.Println()
180 fmt.Println("You may be able to request free access if eligible. For more information, see:")
181 fmt.Println()
182 fmt.Println(lipgloss.NewStyle().Hyperlink(copilot.FreeURL, "id=copilot-free").Render(copilot.FreeURL))
183 }
184 if err != nil {
185 return err
186 }
187 token = t
188 }
189
190 if err := cmp.Or(
191 c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
192 c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.oauth", token),
193 ); err != nil {
194 return err
195 }
196
197 fmt.Println()
198 fmt.Println("You're now authenticated with GitHub Copilot!")
199 return nil
200}
201
// getLoginContext returns a context that is cancelled when the process
// receives an interrupt signal.
//
// A background goroutine waits for that cancellation, releases the signal
// registration, and then terminates the process with exit status 1. Because
// it exits via os.Exit, deferred functions in the callers do not run on
// interrupt. If no interrupt arrives, the goroutine stays parked until the
// process ends.
//
// Only os.Interrupt is registered: os.Kill (SIGKILL) cannot be caught by a
// program, so asking to be notified about it has no effect.
func getLoginContext() context.Context {
	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
	go func() {
		<-ctx.Done()
		cancel()
		os.Exit(1)
	}()
	return ctx
}
211
// waitEnter pauses until fmt.Scanln returns after reading from standard
// input, which is how the login flow waits for the user to press Enter.
//
// The result is intentionally discarded: whether a newline was read, input
// ended, or something else was typed first, the caller simply continues.
func waitEnter() {
	if _, err := fmt.Scanln(); err != nil {
		// Best-effort pause only; there is nothing to recover from here.
		return
	}
}