oauth.go

  1package copilot
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"net/http"
 10	"net/url"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/oauth"
 15)
 16
 17const (
 18	clientID = "Iv1.b507a08c87ecfe98"
 19
 20	deviceCodeURL   = "https://github.com/login/device/code"
 21	accessTokenURL  = "https://github.com/login/oauth/access_token"
 22	copilotTokenURL = "https://api.github.com/copilot_internal/v2/token"
 23)
 24
 25var ErrNotAvailable = errors.New("github copilot not available")
 26
// DeviceCode is the decoded JSON response of GitHub's device-code endpoint,
// as returned by RequestDeviceCode.
//
//   - DeviceCode is sent back as the device_code form value when polling the
//     access-token endpoint.
//   - UserCode and VerificationURI are not read by this file's functions;
//     presumably the caller shows them to the user so the device can be
//     authorized — confirm at the call site.
//   - ExpiresIn is a lifetime in seconds; PollForToken stops polling once
//     that much time has elapsed.
//   - Interval is the polling interval in seconds; PollForToken never polls
//     more often than every 5 seconds regardless of this value.
type DeviceCode struct {
	DeviceCode      string `json:"device_code"`
	UserCode        string `json:"user_code"`
	VerificationURI string `json:"verification_uri"`
	ExpiresIn       int    `json:"expires_in"`
	Interval        int    `json:"interval"`
}
 34
 35// RequestDeviceCode initiates the device code flow with GitHub.
 36func RequestDeviceCode(ctx context.Context) (*DeviceCode, error) {
 37	data := url.Values{}
 38	data.Set("client_id", clientID)
 39	data.Set("scope", "read:user")
 40
 41	req, err := http.NewRequestWithContext(ctx, "POST", deviceCodeURL, strings.NewReader(data.Encode()))
 42	if err != nil {
 43		return nil, err
 44	}
 45	req.Header.Set("Accept", "application/json")
 46	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 47	req.Header.Set("User-Agent", userAgent)
 48
 49	client := &http.Client{Timeout: 30 * time.Second}
 50	resp, err := client.Do(req)
 51	if err != nil {
 52		return nil, err
 53	}
 54	defer resp.Body.Close()
 55
 56	if resp.StatusCode != http.StatusOK {
 57		body, _ := io.ReadAll(resp.Body)
 58		return nil, fmt.Errorf("device code request failed: %s - %s", resp.Status, string(body))
 59	}
 60
 61	var dc DeviceCode
 62	if err := json.NewDecoder(resp.Body).Decode(&dc); err != nil {
 63		return nil, err
 64	}
 65	return &dc, nil
 66}
 67
 68// PollForToken polls GitHub for the access token after user authorization.
 69func PollForToken(ctx context.Context, dc *DeviceCode) (*oauth.Token, error) {
 70	interval := max(dc.Interval, 5)
 71	deadline := time.Now().Add(time.Duration(dc.ExpiresIn) * time.Second)
 72	ticker := time.NewTicker(time.Duration(interval) * time.Second)
 73	defer ticker.Stop()
 74
 75	for time.Now().Before(deadline) {
 76		select {
 77		case <-ctx.Done():
 78			return nil, ctx.Err()
 79		case <-ticker.C:
 80		}
 81
 82		token, err := tryGetToken(ctx, dc.DeviceCode)
 83		if err == errPending {
 84			continue
 85		}
 86		if err == errSlowDown {
 87			interval += 5
 88			ticker.Reset(time.Duration(interval) * time.Second)
 89			continue
 90		}
 91		if err != nil {
 92			return nil, err
 93		}
 94		return token, nil
 95	}
 96
 97	return nil, fmt.Errorf("authorization timed out")
 98}
 99
100var (
101	errPending  = fmt.Errorf("pending")
102	errSlowDown = fmt.Errorf("slow_down")
103)
104
// tryGetToken performs a single poll of GitHub's access-token endpoint for
// the given device code.
//
// The return value depends on the response:
//   - errPending while authorization is still pending, or when the response
//     carries neither a token nor an error.
//   - errSlowDown when the server answers slow_down.
//   - An "authorization failed" error for any other error code. It includes
//     GitHub's error_description when one is provided.
//   - The result of getCopilotToken once a GitHub access token is issued.
//
// A response body that is not valid JSON is reported with the HTTP status
// when the status is not 200.
func tryGetToken(ctx context.Context, deviceCode string) (*oauth.Token, error) {
	data := url.Values{}
	data.Set("client_id", clientID)
	data.Set("device_code", deviceCode)
	data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, accessTokenURL, strings.NewReader(data.Encode()))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("User-Agent", userAgent)

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var result struct {
		AccessToken      string `json:"access_token"`
		Error            string `json:"error"`
		ErrorDescription string `json:"error_description"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		// A non-JSON body (e.g. an HTML error page) is more usefully
		// reported by its HTTP status than by the decoder error.
		if resp.StatusCode != http.StatusOK {
			return nil, fmt.Errorf("access token request failed: %s", resp.Status)
		}
		return nil, fmt.Errorf("decoding access token response: %w", err)
	}

	switch result.Error {
	case "":
		if result.AccessToken == "" {
			return nil, errPending
		}
		return getCopilotToken(ctx, result.AccessToken)
	case "authorization_pending":
		return nil, errPending
	case "slow_down":
		return nil, errSlowDown
	default:
		if result.ErrorDescription != "" {
			return nil, fmt.Errorf("authorization failed: %s: %s", result.Error, result.ErrorDescription)
		}
		return nil, fmt.Errorf("authorization failed: %s", result.Error)
	}
}
148
// getCopilotToken exchanges a GitHub OAuth token for a Copilot API token.
//
// The returned oauth.Token has these fields:
//   - AccessToken: the Copilot token.
//   - RefreshToken: the GitHub token, so RefreshToken can repeat the exchange.
//   - ExpiresAt: the response's expires_at value, after which SetExpiresIn is
//     called.
//
// A 403 response yields ErrNotAvailable. Any other non-200 status is an error
// that includes the status and body. A 200 response without a token is also
// an error.
func getCopilotToken(ctx context.Context, githubToken string) (*oauth.Token, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotTokenURL, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "application/json")
	for k, v := range Headers() {
		req.Header.Set(k, v)
	}
	// Set after the shared headers so the caller's GitHub token is always the
	// credential sent, even if Headers() ever carries an Authorization entry.
	req.Header.Set("Authorization", "Bearer "+githubToken)

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// The token response is small; cap the read so a misbehaving server
	// cannot make us buffer an unbounded body.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return nil, err
	}

	if resp.StatusCode == http.StatusForbidden {
		return nil, ErrNotAvailable
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("copilot token request failed: %s - %s", resp.Status, string(body))
	}

	var result struct {
		Token     string `json:"token"`
		ExpiresAt int64  `json:"expires_at"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return nil, fmt.Errorf("decoding copilot token response: %w", err)
	}
	// Without this check an empty token would be returned as a "valid"
	// credential and only fail on first use.
	if result.Token == "" {
		return nil, errors.New("copilot token response did not contain a token")
	}

	copilotToken := &oauth.Token{
		AccessToken:  result.Token,
		RefreshToken: githubToken,
		ExpiresAt:    result.ExpiresAt,
	}
	copilotToken.SetExpiresIn()

	return copilotToken, nil
}
196
// RefreshToken obtains a new Copilot token by repeating the token exchange
// with the given GitHub token. Tokens produced by this package carry that
// GitHub token in their RefreshToken field.
func RefreshToken(ctx context.Context, githubToken string) (*oauth.Token, error) {
	token, err := getCopilotToken(ctx, githubToken)
	if err != nil {
		return nil, err
	}
	return token, nil
}