1package copilot
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "net/http"
10 "net/url"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/crush/internal/oauth"
15)
16
// clientID is the OAuth client ID sent as "client_id" to GitHub's device
// flow endpoints.
const clientID = "Iv1.b507a08c87ecfe98"

// deviceCodeURL is the endpoint that starts the device authorization flow.
const deviceCodeURL = "https://github.com/login/device/code"

// accessTokenURL is polled to exchange a device code for a GitHub token.
const accessTokenURL = "https://github.com/login/oauth/access_token"

// copilotTokenURL exchanges a GitHub token for a Copilot token.
const copilotTokenURL = "https://api.github.com/copilot_internal/v2/token"

// ErrNotAvailable is returned when GitHub answers the Copilot token request
// with HTTP 403 Forbidden.
var ErrNotAvailable = errors.New("github copilot not available")
26
// DeviceCode is the decoded JSON response of GitHub's device code endpoint,
// as returned by RequestDeviceCode and consumed by PollForToken.
type DeviceCode struct {
	// DeviceCode identifies this authorization attempt; it is sent as
	// "device_code" on every poll of the access token endpoint.
	DeviceCode string `json:"device_code"`
	// UserCode is not used in this file; presumably it is shown to the user
	// together with VerificationURI — confirm with callers.
	UserCode string `json:"user_code"`
	// VerificationURI is not used in this file; presumably the page where the
	// user enters UserCode — confirm with callers.
	VerificationURI string `json:"verification_uri"`
	// ExpiresIn is the lifetime in seconds; PollForToken stops polling once
	// this many seconds have elapsed.
	ExpiresIn int `json:"expires_in"`
	// Interval is the polling interval in seconds; PollForToken raises it to
	// at least 5.
	Interval int `json:"interval"`
}
34
35// RequestDeviceCode initiates the device code flow with GitHub.
36func RequestDeviceCode(ctx context.Context) (*DeviceCode, error) {
37 data := url.Values{}
38 data.Set("client_id", clientID)
39 data.Set("scope", "read:user")
40
41 req, err := http.NewRequestWithContext(ctx, "POST", deviceCodeURL, strings.NewReader(data.Encode()))
42 if err != nil {
43 return nil, err
44 }
45 req.Header.Set("Accept", "application/json")
46 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
47 req.Header.Set("User-Agent", userAgent)
48
49 client := &http.Client{Timeout: 30 * time.Second}
50 resp, err := client.Do(req)
51 if err != nil {
52 return nil, err
53 }
54 defer resp.Body.Close()
55
56 if resp.StatusCode != http.StatusOK {
57 body, _ := io.ReadAll(resp.Body)
58 return nil, fmt.Errorf("device code request failed: %s - %s", resp.Status, string(body))
59 }
60
61 var dc DeviceCode
62 if err := json.NewDecoder(resp.Body).Decode(&dc); err != nil {
63 return nil, err
64 }
65 return &dc, nil
66}
67
68// PollForToken polls GitHub for the access token after user authorization.
69func PollForToken(ctx context.Context, dc *DeviceCode) (*oauth.Token, error) {
70 interval := max(dc.Interval, 5)
71 deadline := time.Now().Add(time.Duration(dc.ExpiresIn) * time.Second)
72 ticker := time.NewTicker(time.Duration(interval) * time.Second)
73 defer ticker.Stop()
74
75 for time.Now().Before(deadline) {
76 select {
77 case <-ctx.Done():
78 return nil, ctx.Err()
79 case <-ticker.C:
80 }
81
82 token, err := tryGetToken(ctx, dc.DeviceCode)
83 if err == errPending {
84 continue
85 }
86 if err == errSlowDown {
87 interval += 5
88 ticker.Reset(time.Duration(interval) * time.Second)
89 continue
90 }
91 if err != nil {
92 return nil, err
93 }
94 return token, nil
95 }
96
97 return nil, fmt.Errorf("authorization timed out")
98}
99
// Sentinel results of tryGetToken that tell PollForToken to keep polling.
// Compare them with errors.Is.
var (
	// errPending: the server reported "authorization_pending", or answered
	// with neither a token nor an error.
	errPending = errors.New("pending")
	// errSlowDown: the server reported "slow_down" and wants a longer interval.
	errSlowDown = errors.New("slow_down")
)
104
105func tryGetToken(ctx context.Context, deviceCode string) (*oauth.Token, error) {
106 data := url.Values{}
107 data.Set("client_id", clientID)
108 data.Set("device_code", deviceCode)
109 data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
110
111 req, err := http.NewRequestWithContext(ctx, "POST", accessTokenURL, strings.NewReader(data.Encode()))
112 if err != nil {
113 return nil, err
114 }
115 req.Header.Set("Accept", "application/json")
116 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
117 req.Header.Set("User-Agent", userAgent)
118
119 client := &http.Client{Timeout: 30 * time.Second}
120 resp, err := client.Do(req)
121 if err != nil {
122 return nil, err
123 }
124 defer resp.Body.Close()
125
126 var result struct {
127 AccessToken string `json:"access_token"`
128 Error string `json:"error"`
129 }
130 if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
131 return nil, err
132 }
133
134 switch result.Error {
135 case "":
136 if result.AccessToken == "" {
137 return nil, errPending
138 }
139 return getCopilotToken(ctx, result.AccessToken)
140 case "authorization_pending":
141 return nil, errPending
142 case "slow_down":
143 return nil, errSlowDown
144 default:
145 return nil, fmt.Errorf("authorization failed: %s", result.Error)
146 }
147}
148
149func getCopilotToken(ctx context.Context, githubToken string) (*oauth.Token, error) {
150 req, err := http.NewRequestWithContext(ctx, "GET", copilotTokenURL, nil)
151 if err != nil {
152 return nil, err
153 }
154 req.Header.Set("Accept", "application/json")
155 req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", githubToken))
156 for k, v := range Headers() {
157 req.Header.Set(k, v)
158 }
159
160 client := &http.Client{Timeout: 30 * time.Second}
161 resp, err := client.Do(req)
162 if err != nil {
163 return nil, err
164 }
165 defer resp.Body.Close()
166
167 body, err := io.ReadAll(resp.Body)
168 if err != nil {
169 return nil, err
170 }
171
172 if resp.StatusCode == http.StatusForbidden {
173 return nil, ErrNotAvailable
174 }
175 if resp.StatusCode != http.StatusOK {
176 return nil, fmt.Errorf("copilot token request failed: %s - %s", resp.Status, string(body))
177 }
178
179 var result struct {
180 Token string `json:"token"`
181 ExpiresAt int64 `json:"expires_at"`
182 }
183 if err := json.Unmarshal(body, &result); err != nil {
184 return nil, err
185 }
186
187 copilotToken := &oauth.Token{
188 AccessToken: result.Token,
189 RefreshToken: githubToken,
190 ExpiresAt: result.ExpiresAt,
191 }
192 copilotToken.SetExpiresIn()
193
194 return copilotToken, nil
195}
196
197// RefreshToken refreshes the Copilot token using the GitHub token.
198func RefreshToken(ctx context.Context, githubToken string) (*oauth.Token, error) {
199 return getCopilotToken(ctx, githubToken)
200}