1// Package hyper provides functions to handle Hyper device flow authentication.
2package hyper
3
4import (
5 "bytes"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "net/http"
12 "os"
13 "strings"
14 "time"
15
16 "github.com/charmbracelet/crush/internal/agent/hyper"
17 "github.com/charmbracelet/crush/internal/event"
18 "github.com/charmbracelet/crush/internal/oauth"
19)
20
// DeviceAuthResponse contains the response from the device authorization endpoint.
//
// InitiateDeviceAuth decodes it from the JSON body of a successful (HTTP 200)
// POST to /device/auth.
type DeviceAuthResponse struct {
	// DeviceCode identifies this authorization attempt; it is the value to
	// hand to PollForToken, which embeds it in the polling URL.
	DeviceCode      string `json:"device_code"`
	// UserCode is presumably the short code the user types in at
	// VerificationURL — confirm against the server contract.
	UserCode        string `json:"user_code"`
	// VerificationURL is presumably the page where the user approves this
	// device — confirm against the server contract.
	VerificationURL string `json:"verification_url"`
	// ExpiresIn is how long DeviceCode stays valid; PollForToken interprets
	// it as a number of seconds when bounding its polling.
	ExpiresIn       int    `json:"expires_in"`
}
28
// TokenResponse contains the response from the polling endpoint.
//
// pollOnce decodes it from the JSON body of GET /device/auth/{deviceCode}.
// PollForToken treats a non-empty RefreshToken as success; otherwise it
// inspects Error and ErrorDescription.
type TokenResponse struct {
	// RefreshToken is set once the user has authorized the device.
	RefreshToken     string `json:"refresh_token,omitempty"`
	// UserID is forwarded to event.Alias by PollForToken on success.
	UserID           string `json:"user_id"`
	// OrganizationID and OrganizationName are decoded but not used by the
	// code in this file.
	OrganizationID   string `json:"organization_id"`
	OrganizationName string `json:"organization_name"`
	// Error is an OAuth-style error code such as "authorization_pending".
	Error            string `json:"error,omitempty"`
	// ErrorDescription is a human-readable explanation accompanying Error.
	ErrorDescription string `json:"error_description,omitempty"`
}
38
39// InitiateDeviceAuth calls the /device/auth endpoint to start the device flow.
40func InitiateDeviceAuth(ctx context.Context) (*DeviceAuthResponse, error) {
41 url := hyper.BaseURL() + "/device/auth"
42
43 req, err := http.NewRequestWithContext(
44 ctx, http.MethodPost, url,
45 strings.NewReader(fmt.Sprintf(`{"device_name":%q}`, deviceName())),
46 )
47 if err != nil {
48 return nil, fmt.Errorf("create request: %w", err)
49 }
50
51 req.Header.Set("Content-Type", "application/json")
52 req.Header.Set("User-Agent", "crush")
53
54 client := &http.Client{Timeout: 30 * time.Second}
55 resp, err := client.Do(req)
56 if err != nil {
57 return nil, fmt.Errorf("execute request: %w", err)
58 }
59 defer resp.Body.Close()
60
61 body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
62 if err != nil {
63 return nil, fmt.Errorf("read response: %w", err)
64 }
65
66 if resp.StatusCode != http.StatusOK {
67 return nil, fmt.Errorf("device auth failed: status %d, body %q", resp.StatusCode, string(body))
68 }
69
70 var authResp DeviceAuthResponse
71 if err := json.Unmarshal(body, &authResp); err != nil {
72 return nil, fmt.Errorf("unmarshal response: %w", err)
73 }
74
75 return &authResp, nil
76}
77
// deviceName returns the label this client registers under during the device
// flow: "Crush (<hostname>)" when the hostname is available and non-empty,
// otherwise just "Crush".
func deviceName() string {
	const product = "Crush"

	host, err := os.Hostname()
	if err != nil || host == "" {
		return product
	}
	return product + " (" + host + ")"
}
84
85// PollForToken polls the /device/token endpoint until authorization is complete.
86// It respects the polling interval and handles various error states.
87func PollForToken(ctx context.Context, deviceCode string, expiresIn int) (string, error) {
88 ctx, cancel := context.WithTimeout(ctx, time.Duration(expiresIn)*time.Second)
89 defer cancel()
90
91 d := 5 * time.Second
92 ticker := time.NewTicker(d)
93 defer ticker.Stop()
94
95 for {
96 select {
97 case <-ctx.Done():
98 return "", ctx.Err()
99 case <-ticker.C:
100 result, err := pollOnce(ctx, deviceCode)
101 if err != nil {
102 return "", err
103 }
104 if result.RefreshToken != "" {
105 event.Alias(result.UserID)
106 return result.RefreshToken, nil
107 }
108 switch result.Error {
109 case "authorization_pending":
110 continue
111 default:
112 return "", errors.New(result.ErrorDescription)
113 }
114 }
115 }
116}
117
118func pollOnce(ctx context.Context, deviceCode string) (TokenResponse, error) {
119 var result TokenResponse
120 url := fmt.Sprintf("%s/device/auth/%s", hyper.BaseURL(), deviceCode)
121 req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
122 if err != nil {
123 return result, fmt.Errorf("create request: %w", err)
124 }
125
126 req.Header.Set("Content-Type", "application/json")
127 req.Header.Set("User-Agent", "crush")
128
129 client := &http.Client{Timeout: 30 * time.Second}
130 resp, err := client.Do(req)
131 if err != nil {
132 return result, fmt.Errorf("execute request: %w", err)
133 }
134 defer resp.Body.Close()
135
136 body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
137 if err != nil {
138 return result, fmt.Errorf("read response: %w", err)
139 }
140
141 if err := json.Unmarshal(body, &result); err != nil {
142 return result, fmt.Errorf("unmarshal response: %w: %s", err, string(body))
143 }
144
145 if resp.StatusCode != http.StatusOK {
146 return result, fmt.Errorf("token request failed: status %d body %q", resp.StatusCode, string(body))
147 }
148
149 return result, nil
150}
151
152// ExchangeToken exchanges a refresh token for an access token.
153func ExchangeToken(ctx context.Context, refreshToken string) (*oauth.Token, error) {
154 reqBody := map[string]string{
155 "refresh_token": refreshToken,
156 }
157
158 data, err := json.Marshal(reqBody)
159 if err != nil {
160 return nil, fmt.Errorf("marshal request: %w", err)
161 }
162
163 url := hyper.BaseURL() + "/token/exchange"
164 req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
165 if err != nil {
166 return nil, fmt.Errorf("create request: %w", err)
167 }
168
169 req.Header.Set("Content-Type", "application/json")
170 req.Header.Set("User-Agent", "crush")
171
172 client := &http.Client{Timeout: 30 * time.Second}
173 resp, err := client.Do(req)
174 if err != nil {
175 return nil, fmt.Errorf("execute request: %w", err)
176 }
177 defer resp.Body.Close()
178
179 body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
180 if err != nil {
181 return nil, fmt.Errorf("read response: %w", err)
182 }
183
184 if resp.StatusCode != http.StatusOK {
185 return nil, fmt.Errorf("token exchange failed: status %d body %q", resp.StatusCode, string(body))
186 }
187
188 var token oauth.Token
189 if err := json.Unmarshal(body, &token); err != nil {
190 return nil, fmt.Errorf("unmarshal response: %w", err)
191 }
192
193 token.SetExpiresAt()
194 return &token, nil
195}
196
// IntrospectTokenResponse contains the response from the token introspection endpoint.
//
// IntrospectToken decodes it from the JSON body of a successful (HTTP 200)
// POST to /token/introspect. Apart from OrgID, the field names follow the
// RFC 7662 introspection response members.
type IntrospectTokenResponse struct {
	// Active reports whether the token is currently valid (RFC 7662 "active").
	Active bool   `json:"active"`
	// Sub is the subject of the token (RFC 7662 "sub").
	Sub    string `json:"sub,omitempty"`
	// OrgID is not an RFC 7662 member; presumably the organization the token
	// belongs to — confirm against the server contract.
	OrgID  string `json:"org_id,omitempty"`
	// Exp and Iat are the expiry and issued-at times; RFC 7662 defines them
	// as seconds since the Unix epoch.
	Exp    int64  `json:"exp,omitempty"`
	Iat    int64  `json:"iat,omitempty"`
	// Iss is the token issuer (RFC 7662 "iss").
	Iss    string `json:"iss,omitempty"`
	// Jti is the token's unique identifier (RFC 7662 "jti").
	Jti    string `json:"jti,omitempty"`
}
207
208// IntrospectToken validates an access token using the introspection endpoint.
209// Implements OAuth2 Token Introspection (RFC 7662).
210func IntrospectToken(ctx context.Context, accessToken string) (*IntrospectTokenResponse, error) {
211 reqBody := map[string]string{
212 "token": accessToken,
213 }
214
215 data, err := json.Marshal(reqBody)
216 if err != nil {
217 return nil, fmt.Errorf("marshal request: %w", err)
218 }
219
220 url := hyper.BaseURL() + "/token/introspect"
221 req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
222 if err != nil {
223 return nil, fmt.Errorf("create request: %w", err)
224 }
225
226 req.Header.Set("Content-Type", "application/json")
227 req.Header.Set("User-Agent", "crush")
228
229 client := &http.Client{Timeout: 30 * time.Second}
230 resp, err := client.Do(req)
231 if err != nil {
232 return nil, fmt.Errorf("execute request: %w", err)
233 }
234 defer resp.Body.Close()
235
236 body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
237 if err != nil {
238 return nil, fmt.Errorf("read response: %w", err)
239 }
240
241 if resp.StatusCode != http.StatusOK {
242 return nil, fmt.Errorf("token introspection failed: status %d body %q", resp.StatusCode, string(body))
243 }
244
245 var result IntrospectTokenResponse
246 if err := json.Unmarshal(body, &result); err != nil {
247 return nil, fmt.Errorf("unmarshal response: %w", err)
248 }
249
250 return &result, nil
251}