oauth.go

  1package copilot
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"net/url"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/oauth"
 14)
 15
 16const (
 17	clientID = "Iv1.b507a08c87ecfe98"
 18
 19	deviceCodeURL   = "https://github.com/login/device/code"
 20	accessTokenURL  = "https://github.com/login/oauth/access_token"
 21	copilotTokenURL = "https://api.github.com/copilot_internal/v2/token"
 22)
 23
// DeviceCode is the JSON payload that RequestDeviceCode decodes from the
// device code endpoint's response.
type DeviceCode struct {
	// DeviceCode is sent back as "device_code" when polling for the token.
	DeviceCode      string `json:"device_code"`
	// UserCode is presumably the code the user types in at VerificationURI;
	// it is not read in this file — confirm against callers.
	UserCode        string `json:"user_code"`
	// VerificationURI is presumably the page where the user enters UserCode;
	// it is not read in this file — confirm against callers.
	VerificationURI string `json:"verification_uri"`
	// ExpiresIn is the device code lifetime in seconds; PollForToken stops
	// polling once it has elapsed.
	ExpiresIn       int    `json:"expires_in"`
	// Interval is the polling interval in seconds; PollForToken uses at
	// least 5 regardless of this value.
	Interval        int    `json:"interval"`
}
 31
 32// RequestDeviceCode initiates the device code flow with GitHub.
 33func RequestDeviceCode(ctx context.Context) (*DeviceCode, error) {
 34	data := url.Values{}
 35	data.Set("client_id", clientID)
 36	data.Set("scope", "read:user")
 37
 38	req, err := http.NewRequestWithContext(ctx, "POST", deviceCodeURL, strings.NewReader(data.Encode()))
 39	if err != nil {
 40		return nil, err
 41	}
 42	req.Header.Set("Accept", "application/json")
 43	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 44	req.Header.Set("User-Agent", userAgent)
 45
 46	client := &http.Client{Timeout: 30 * time.Second}
 47	resp, err := client.Do(req)
 48	if err != nil {
 49		return nil, err
 50	}
 51	defer resp.Body.Close()
 52
 53	if resp.StatusCode != http.StatusOK {
 54		body, _ := io.ReadAll(resp.Body)
 55		return nil, fmt.Errorf("device code request failed: %s - %s", resp.Status, string(body))
 56	}
 57
 58	var dc DeviceCode
 59	if err := json.NewDecoder(resp.Body).Decode(&dc); err != nil {
 60		return nil, err
 61	}
 62	return &dc, nil
 63}
 64
 65// PollForToken polls GitHub for the access token after user authorization.
 66func PollForToken(ctx context.Context, dc *DeviceCode) (*oauth.Token, error) {
 67	interval := max(dc.Interval, 5)
 68	deadline := time.Now().Add(time.Duration(dc.ExpiresIn) * time.Second)
 69	ticker := time.NewTicker(time.Duration(interval) * time.Second)
 70	defer ticker.Stop()
 71
 72	for time.Now().Before(deadline) {
 73		select {
 74		case <-ctx.Done():
 75			return nil, ctx.Err()
 76		case <-ticker.C:
 77		}
 78
 79		token, err := tryGetToken(ctx, dc.DeviceCode)
 80		if err == errPending {
 81			continue
 82		}
 83		if err == errSlowDown {
 84			interval += 5
 85			ticker.Reset(time.Duration(interval) * time.Second)
 86			continue
 87		}
 88		if err != nil {
 89			return nil, err
 90		}
 91		return token, nil
 92	}
 93
 94	return nil, fmt.Errorf("authorization timed out")
 95}
 96
 97var (
 98	errPending  = fmt.Errorf("pending")
 99	errSlowDown = fmt.Errorf("slow_down")
100)
101
102func tryGetToken(ctx context.Context, deviceCode string) (*oauth.Token, error) {
103	data := url.Values{}
104	data.Set("client_id", clientID)
105	data.Set("device_code", deviceCode)
106	data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
107
108	req, err := http.NewRequestWithContext(ctx, "POST", accessTokenURL, strings.NewReader(data.Encode()))
109	if err != nil {
110		return nil, err
111	}
112	req.Header.Set("Accept", "application/json")
113	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
114	req.Header.Set("User-Agent", userAgent)
115
116	client := &http.Client{Timeout: 30 * time.Second}
117	resp, err := client.Do(req)
118	if err != nil {
119		return nil, err
120	}
121	defer resp.Body.Close()
122
123	var result struct {
124		AccessToken string `json:"access_token"`
125		Error       string `json:"error"`
126	}
127	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
128		return nil, err
129	}
130
131	switch result.Error {
132	case "":
133		if result.AccessToken == "" {
134			return nil, errPending
135		}
136		return getCopilotToken(ctx, result.AccessToken)
137	case "authorization_pending":
138		return nil, errPending
139	case "slow_down":
140		return nil, errSlowDown
141	default:
142		return nil, fmt.Errorf("authorization failed: %s", result.Error)
143	}
144}
145
146func getCopilotToken(ctx context.Context, githubToken string) (*oauth.Token, error) {
147	req, err := http.NewRequestWithContext(ctx, "GET", copilotTokenURL, nil)
148	if err != nil {
149		return nil, err
150	}
151	req.Header.Set("Accept", "application/json")
152	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", githubToken))
153	for k, v := range Headers() {
154		req.Header.Set(k, v)
155	}
156
157	client := &http.Client{Timeout: 30 * time.Second}
158	resp, err := client.Do(req)
159	if err != nil {
160		return nil, err
161	}
162	defer resp.Body.Close()
163
164	body, err := io.ReadAll(resp.Body)
165	if err != nil {
166		return nil, err
167	}
168
169	if resp.StatusCode != http.StatusOK {
170		if resp.StatusCode == http.StatusNotFound {
171			return nil, fmt.Errorf("copilot not available for this account\n\n" +
172				"Please ensure you have GitHub Copilot enabled at:\n" +
173				"https://github.com/settings/copilot")
174		}
175		return nil, fmt.Errorf("copilot token request failed: %s - %s", resp.Status, string(body))
176	}
177
178	var result struct {
179		Token     string `json:"token"`
180		ExpiresAt int64  `json:"expires_at"`
181	}
182	if err := json.Unmarshal(body, &result); err != nil {
183		return nil, err
184	}
185
186	copilotToken := &oauth.Token{
187		AccessToken:  result.Token,
188		RefreshToken: githubToken,
189		ExpiresAt:    result.ExpiresAt,
190	}
191	copilotToken.SetExpiresIn()
192
193	return copilotToken, nil
194}
195
196// RefreshToken refreshes the Copilot token using the GitHub token.
197func RefreshToken(ctx context.Context, githubToken string) (*oauth.Token, error) {
198	return getCopilotToken(ctx, githubToken)
199}