1package copilot
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "net/http"
9 "net/url"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/oauth"
14)
15
const (
	// clientID is sent as the "client_id" form value in both the device code
	// request and the access token request.
	clientID = "Iv1.b507a08c87ecfe98"

	// deviceCodeURL is the endpoint RequestDeviceCode POSTs to in order to
	// start the device code flow.
	deviceCodeURL = "https://github.com/login/device/code"
	// accessTokenURL is the endpoint polled for the GitHub access token once
	// a device code has been issued.
	accessTokenURL = "https://github.com/login/oauth/access_token"
	// copilotTokenURL is the endpoint queried with the GitHub token as a
	// bearer credential to obtain the Copilot token.
	copilotTokenURL = "https://api.github.com/copilot_internal/v2/token"
)
23
// DeviceCode is the JSON payload decoded from the device code endpoint's
// response. It is the input to PollForToken.
type DeviceCode struct {
	// DeviceCode is sent back as the "device_code" form value on every poll.
	DeviceCode string `json:"device_code"`
	// UserCode is not read in this file; presumably the code the user types
	// in at VerificationURI — confirm against callers.
	UserCode string `json:"user_code"`
	// VerificationURI is not read in this file; presumably the page where the
	// user enters UserCode — confirm against callers.
	VerificationURI string `json:"verification_uri"`
	// ExpiresIn is the number of seconds, counted from the start of polling,
	// after which PollForToken gives up.
	ExpiresIn int `json:"expires_in"`
	// Interval is the polling interval in seconds; PollForToken never polls
	// more often than every 5 seconds regardless of this value.
	Interval int `json:"interval"`
}
31
32// RequestDeviceCode initiates the device code flow with GitHub.
33func RequestDeviceCode(ctx context.Context) (*DeviceCode, error) {
34 data := url.Values{}
35 data.Set("client_id", clientID)
36 data.Set("scope", "read:user")
37
38 req, err := http.NewRequestWithContext(ctx, "POST", deviceCodeURL, strings.NewReader(data.Encode()))
39 if err != nil {
40 return nil, err
41 }
42 req.Header.Set("Accept", "application/json")
43 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
44 req.Header.Set("User-Agent", userAgent)
45
46 client := &http.Client{Timeout: 30 * time.Second}
47 resp, err := client.Do(req)
48 if err != nil {
49 return nil, err
50 }
51 defer resp.Body.Close()
52
53 if resp.StatusCode != http.StatusOK {
54 body, _ := io.ReadAll(resp.Body)
55 return nil, fmt.Errorf("device code request failed: %s - %s", resp.Status, string(body))
56 }
57
58 var dc DeviceCode
59 if err := json.NewDecoder(resp.Body).Decode(&dc); err != nil {
60 return nil, err
61 }
62 return &dc, nil
63}
64
65// PollForToken polls GitHub for the access token after user authorization.
66func PollForToken(ctx context.Context, dc *DeviceCode) (*oauth.Token, error) {
67 interval := max(dc.Interval, 5)
68 deadline := time.Now().Add(time.Duration(dc.ExpiresIn) * time.Second)
69 ticker := time.NewTicker(time.Duration(interval) * time.Second)
70 defer ticker.Stop()
71
72 for time.Now().Before(deadline) {
73 select {
74 case <-ctx.Done():
75 return nil, ctx.Err()
76 case <-ticker.C:
77 }
78
79 token, err := tryGetToken(ctx, dc.DeviceCode)
80 if err == errPending {
81 continue
82 }
83 if err == errSlowDown {
84 interval += 5
85 ticker.Reset(time.Duration(interval) * time.Second)
86 continue
87 }
88 if err != nil {
89 return nil, err
90 }
91 return token, nil
92 }
93
94 return nil, fmt.Errorf("authorization timed out")
95}
96
// Sentinel errors returned by tryGetToken and interpreted by PollForToken.
var (
	// errPending means no token is available yet; PollForToken keeps polling.
	errPending = fmt.Errorf("pending")
	// errSlowDown means the server answered "slow_down"; PollForToken adds
	// 5 seconds to its polling interval and keeps polling.
	errSlowDown = fmt.Errorf("slow_down")
)
101
102func tryGetToken(ctx context.Context, deviceCode string) (*oauth.Token, error) {
103 data := url.Values{}
104 data.Set("client_id", clientID)
105 data.Set("device_code", deviceCode)
106 data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
107
108 req, err := http.NewRequestWithContext(ctx, "POST", accessTokenURL, strings.NewReader(data.Encode()))
109 if err != nil {
110 return nil, err
111 }
112 req.Header.Set("Accept", "application/json")
113 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
114 req.Header.Set("User-Agent", userAgent)
115
116 client := &http.Client{Timeout: 30 * time.Second}
117 resp, err := client.Do(req)
118 if err != nil {
119 return nil, err
120 }
121 defer resp.Body.Close()
122
123 var result struct {
124 AccessToken string `json:"access_token"`
125 Error string `json:"error"`
126 }
127 if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
128 return nil, err
129 }
130
131 switch result.Error {
132 case "":
133 if result.AccessToken == "" {
134 return nil, errPending
135 }
136 return getCopilotToken(ctx, result.AccessToken)
137 case "authorization_pending":
138 return nil, errPending
139 case "slow_down":
140 return nil, errSlowDown
141 default:
142 return nil, fmt.Errorf("authorization failed: %s", result.Error)
143 }
144}
145
146func getCopilotToken(ctx context.Context, githubToken string) (*oauth.Token, error) {
147 req, err := http.NewRequestWithContext(ctx, "GET", copilotTokenURL, nil)
148 if err != nil {
149 return nil, err
150 }
151 req.Header.Set("Accept", "application/json")
152 req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", githubToken))
153 for k, v := range Headers() {
154 req.Header.Set(k, v)
155 }
156
157 client := &http.Client{Timeout: 30 * time.Second}
158 resp, err := client.Do(req)
159 if err != nil {
160 return nil, err
161 }
162 defer resp.Body.Close()
163
164 body, err := io.ReadAll(resp.Body)
165 if err != nil {
166 return nil, err
167 }
168
169 if resp.StatusCode != http.StatusOK {
170 if resp.StatusCode == http.StatusNotFound {
171 return nil, fmt.Errorf("copilot not available for this account\n\n" +
172 "Please ensure you have GitHub Copilot enabled at:\n" +
173 "https://github.com/settings/copilot")
174 }
175 return nil, fmt.Errorf("copilot token request failed: %s - %s", resp.Status, string(body))
176 }
177
178 var result struct {
179 Token string `json:"token"`
180 ExpiresAt int64 `json:"expires_at"`
181 }
182 if err := json.Unmarshal(body, &result); err != nil {
183 return nil, err
184 }
185
186 copilotToken := &oauth.Token{
187 AccessToken: result.Token,
188 RefreshToken: githubToken,
189 ExpiresAt: result.ExpiresAt,
190 }
191 copilotToken.SetExpiresIn()
192
193 return copilotToken, nil
194}
195
// RefreshToken refreshes the Copilot token using the GitHub token.
//
// It delegates to getCopilotToken, so the returned token's RefreshToken field
// is the githubToken that was passed in, and any error from that exchange is
// returned unchanged.
func RefreshToken(ctx context.Context, githubToken string) (*oauth.Token, error) {
	return getCopilotToken(ctx, githubToken)
}