1package transport
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "net/http"
11 "net/url"
12 "strings"
13 "sync"
14 "time"
15)
16
17// OAuthConfig holds the OAuth configuration for the client
18type OAuthConfig struct {
19 // ClientID is the OAuth client ID
20 ClientID string
21 // ClientSecret is the OAuth client secret (for confidential clients)
22 ClientSecret string
23 // RedirectURI is the redirect URI for the OAuth flow
24 RedirectURI string
25 // Scopes is the list of OAuth scopes to request
26 Scopes []string
27 // TokenStore is the storage for OAuth tokens
28 TokenStore TokenStore
29 // AuthServerMetadataURL is the URL to the OAuth server metadata
30 // If empty, the client will attempt to discover it from the base URL
31 AuthServerMetadataURL string
32 // PKCEEnabled enables PKCE for the OAuth flow (recommended for public clients)
33 PKCEEnabled bool
34}
35
36// TokenStore is an interface for storing and retrieving OAuth tokens
37type TokenStore interface {
38 // GetToken returns the current token
39 GetToken() (*Token, error)
40 // SaveToken saves a token
41 SaveToken(token *Token) error
42}
43
// Token is an OAuth token as exchanged with the token endpoint, plus the
// absolute expiry time derived from it.
type Token struct {
	// AccessToken is the OAuth access token.
	AccessToken string `json:"access_token"`
	// TokenType is the kind of token (usually "Bearer").
	TokenType string `json:"token_type"`
	// RefreshToken is the OAuth refresh token, if any.
	RefreshToken string `json:"refresh_token,omitempty"`
	// ExpiresIn is the token lifetime in seconds.
	ExpiresIn int64 `json:"expires_in,omitempty"`
	// Scope is the scope associated with the token.
	Scope string `json:"scope,omitempty"`
	// ExpiresAt is the instant at which the token expires. The zero value
	// means "no known expiry".
	ExpiresAt time.Time `json:"expires_at,omitempty"`
}

// IsExpired reports whether the token's expiry time has passed. A token whose
// ExpiresAt is the zero time is never considered expired.
func (t *Token) IsExpired() bool {
	return !t.ExpiresAt.IsZero() && time.Now().After(t.ExpiresAt)
}
67
68// MemoryTokenStore is a simple in-memory token store
69type MemoryTokenStore struct {
70 token *Token
71 mu sync.RWMutex
72}
73
74// NewMemoryTokenStore creates a new in-memory token store
75func NewMemoryTokenStore() *MemoryTokenStore {
76 return &MemoryTokenStore{}
77}
78
79// GetToken returns the current token
80func (s *MemoryTokenStore) GetToken() (*Token, error) {
81 s.mu.RLock()
82 defer s.mu.RUnlock()
83 if s.token == nil {
84 return nil, errors.New("no token available")
85 }
86 return s.token, nil
87}
88
89// SaveToken saves a token
90func (s *MemoryTokenStore) SaveToken(token *Token) error {
91 s.mu.Lock()
92 defer s.mu.Unlock()
93 s.token = token
94 return nil
95}
96
// AuthServerMetadata mirrors the OAuth 2.0 Authorization Server Metadata
// document that is decoded from the discovery endpoints.
type AuthServerMetadata struct {
	// Issuer is the "issuer" value of the document.
	Issuer string `json:"issuer"`
	// AuthorizationEndpoint is the URL users are sent to for authorization.
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	// TokenEndpoint is the URL used for code exchange and token refresh.
	TokenEndpoint string `json:"token_endpoint"`
	// RegistrationEndpoint is the dynamic client registration URL, if any.
	RegistrationEndpoint string `json:"registration_endpoint,omitempty"`
	// JwksURI is the "jwks_uri" value of the document.
	JwksURI string `json:"jwks_uri,omitempty"`
	// ScopesSupported is the "scopes_supported" list.
	ScopesSupported []string `json:"scopes_supported,omitempty"`
	// ResponseTypesSupported is the "response_types_supported" list.
	ResponseTypesSupported []string `json:"response_types_supported"`
	// GrantTypesSupported is the "grant_types_supported" list.
	GrantTypesSupported []string `json:"grant_types_supported,omitempty"`
	// TokenEndpointAuthMethodsSupported is the
	// "token_endpoint_auth_methods_supported" list.
	TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"`
}
109
110// OAuthHandler handles OAuth authentication for HTTP requests
111type OAuthHandler struct {
112 config OAuthConfig
113 httpClient *http.Client
114 serverMetadata *AuthServerMetadata
115 metadataFetchErr error
116 metadataOnce sync.Once
117 baseURL string
118 expectedState string // Expected state value for CSRF protection
119}
120
121// NewOAuthHandler creates a new OAuth handler
122func NewOAuthHandler(config OAuthConfig) *OAuthHandler {
123 if config.TokenStore == nil {
124 config.TokenStore = NewMemoryTokenStore()
125 }
126
127 return &OAuthHandler{
128 config: config,
129 httpClient: &http.Client{Timeout: 30 * time.Second},
130 }
131}
132
133// GetAuthorizationHeader returns the Authorization header value for a request
134func (h *OAuthHandler) GetAuthorizationHeader(ctx context.Context) (string, error) {
135 token, err := h.getValidToken(ctx)
136 if err != nil {
137 return "", err
138 }
139
140 // Some auth implementations are strict about token type
141 tokenType := token.TokenType
142 if tokenType == "bearer" {
143 tokenType = "Bearer"
144 }
145
146 return fmt.Sprintf("%s %s", tokenType, token.AccessToken), nil
147}
148
149// getValidToken returns a valid token, refreshing if necessary
150func (h *OAuthHandler) getValidToken(ctx context.Context) (*Token, error) {
151 token, err := h.config.TokenStore.GetToken()
152 if err == nil && !token.IsExpired() && token.AccessToken != "" {
153 return token, nil
154 }
155
156 // If we have a refresh token, try to use it
157 if err == nil && token.RefreshToken != "" {
158 newToken, err := h.refreshToken(ctx, token.RefreshToken)
159 if err == nil {
160 return newToken, nil
161 }
162 // If refresh fails, continue to authorization flow
163 }
164
165 // We need to get a new token through the authorization flow
166 return nil, ErrOAuthAuthorizationRequired
167}
168
169// refreshToken refreshes an OAuth token
170func (h *OAuthHandler) refreshToken(ctx context.Context, refreshToken string) (*Token, error) {
171 metadata, err := h.getServerMetadata(ctx)
172 if err != nil {
173 return nil, fmt.Errorf("failed to get server metadata: %w", err)
174 }
175
176 data := url.Values{}
177 data.Set("grant_type", "refresh_token")
178 data.Set("refresh_token", refreshToken)
179 data.Set("client_id", h.config.ClientID)
180 if h.config.ClientSecret != "" {
181 data.Set("client_secret", h.config.ClientSecret)
182 }
183
184 req, err := http.NewRequestWithContext(
185 ctx,
186 http.MethodPost,
187 metadata.TokenEndpoint,
188 strings.NewReader(data.Encode()),
189 )
190 if err != nil {
191 return nil, fmt.Errorf("failed to create refresh token request: %w", err)
192 }
193
194 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
195 req.Header.Set("Accept", "application/json")
196
197 resp, err := h.httpClient.Do(req)
198 if err != nil {
199 return nil, fmt.Errorf("failed to send refresh token request: %w", err)
200 }
201 defer resp.Body.Close()
202
203 if resp.StatusCode != http.StatusOK {
204 body, _ := io.ReadAll(resp.Body)
205 return nil, extractOAuthError(body, resp.StatusCode, "refresh token request failed")
206 }
207
208 var tokenResp Token
209 if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
210 return nil, fmt.Errorf("failed to decode token response: %w", err)
211 }
212
213 // Set expiration time
214 if tokenResp.ExpiresIn > 0 {
215 tokenResp.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
216 }
217
218 // If no new refresh token is provided, keep the old one
219 oldToken, _ := h.config.TokenStore.GetToken()
220 if tokenResp.RefreshToken == "" && oldToken != nil {
221 tokenResp.RefreshToken = oldToken.RefreshToken
222 }
223
224 // Save the token
225 if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil {
226 return nil, fmt.Errorf("failed to save token: %w", err)
227 }
228
229 return &tokenResp, nil
230}
231
232// RefreshToken is a public wrapper for refreshToken
233func (h *OAuthHandler) RefreshToken(ctx context.Context, refreshToken string) (*Token, error) {
234 return h.refreshToken(ctx, refreshToken)
235}
236
237// GetClientID returns the client ID
238func (h *OAuthHandler) GetClientID() string {
239 return h.config.ClientID
240}
241
242// extractOAuthError attempts to parse an OAuth error response from the response body
243func extractOAuthError(body []byte, statusCode int, context string) error {
244 // Try to parse the error as an OAuth error response
245 var oauthErr OAuthError
246 if err := json.Unmarshal(body, &oauthErr); err == nil && oauthErr.ErrorCode != "" {
247 return fmt.Errorf("%s: %w", context, oauthErr)
248 }
249
250 // If not a valid OAuth error, return the raw response
251 return fmt.Errorf("%s with status %d: %s", context, statusCode, body)
252}
253
254// GetClientSecret returns the client secret
255func (h *OAuthHandler) GetClientSecret() string {
256 return h.config.ClientSecret
257}
258
259// SetBaseURL sets the base URL for the API server
260func (h *OAuthHandler) SetBaseURL(baseURL string) {
261 h.baseURL = baseURL
262}
263
264// GetExpectedState returns the expected state value (for testing purposes)
265func (h *OAuthHandler) GetExpectedState() string {
266 return h.expectedState
267}
268
// OAuthError is the standard OAuth 2.0 error response body.
type OAuthError struct {
	// ErrorCode is the "error" member of the response.
	ErrorCode string `json:"error"`
	// ErrorDescription is the optional "error_description" member.
	ErrorDescription string `json:"error_description,omitempty"`
	// ErrorURI is the optional "error_uri" member.
	ErrorURI string `json:"error_uri,omitempty"`
}

// Error implements the error interface. The message contains the error code
// and, when present, the description.
func (e OAuthError) Error() string {
	if e.ErrorDescription == "" {
		return fmt.Sprintf("OAuth error: %s", e.ErrorCode)
	}
	return fmt.Sprintf("OAuth error: %s - %s", e.ErrorCode, e.ErrorDescription)
}
283
// OAuthProtectedResource is the document served at
// /.well-known/oauth-protected-resource.
type OAuthProtectedResource struct {
	// AuthorizationServers is the "authorization_servers" list.
	AuthorizationServers []string `json:"authorization_servers"`
	// Resource is the "resource" member of the document.
	Resource string `json:"resource"`
	// ResourceName is the optional "resource_name" member.
	ResourceName string `json:"resource_name,omitempty"`
}
290
291// getServerMetadata fetches the OAuth server metadata
292func (h *OAuthHandler) getServerMetadata(ctx context.Context) (*AuthServerMetadata, error) {
293 h.metadataOnce.Do(func() {
294 // If AuthServerMetadataURL is explicitly provided, use it directly
295 if h.config.AuthServerMetadataURL != "" {
296 h.fetchMetadataFromURL(ctx, h.config.AuthServerMetadataURL)
297 return
298 }
299
300 // Try to discover the authorization server via OAuth Protected Resource
301 // as per RFC 9728 (https://datatracker.ietf.org/doc/html/rfc9728)
302 baseURL, err := h.extractBaseURL()
303 if err != nil {
304 h.metadataFetchErr = fmt.Errorf("failed to extract base URL: %w", err)
305 return
306 }
307
308 // Try to fetch the OAuth Protected Resource metadata
309 protectedResourceURL := baseURL + "/.well-known/oauth-protected-resource"
310 req, err := http.NewRequestWithContext(ctx, http.MethodGet, protectedResourceURL, nil)
311 if err != nil {
312 h.metadataFetchErr = fmt.Errorf("failed to create protected resource request: %w", err)
313 return
314 }
315
316 req.Header.Set("Accept", "application/json")
317 req.Header.Set("MCP-Protocol-Version", "2025-03-26")
318
319 resp, err := h.httpClient.Do(req)
320 if err != nil {
321 h.metadataFetchErr = fmt.Errorf("failed to send protected resource request: %w", err)
322 return
323 }
324 defer resp.Body.Close()
325
326 // If we can't get the protected resource metadata, fall back to default endpoints
327 if resp.StatusCode != http.StatusOK {
328 metadata, err := h.getDefaultEndpoints(baseURL)
329 if err != nil {
330 h.metadataFetchErr = fmt.Errorf("failed to get default endpoints: %w", err)
331 return
332 }
333 h.serverMetadata = metadata
334 return
335 }
336
337 // Parse the protected resource metadata
338 var protectedResource OAuthProtectedResource
339 if err := json.NewDecoder(resp.Body).Decode(&protectedResource); err != nil {
340 h.metadataFetchErr = fmt.Errorf("failed to decode protected resource response: %w", err)
341 return
342 }
343
344 // If no authorization servers are specified, fall back to default endpoints
345 if len(protectedResource.AuthorizationServers) == 0 {
346 metadata, err := h.getDefaultEndpoints(baseURL)
347 if err != nil {
348 h.metadataFetchErr = fmt.Errorf("failed to get default endpoints: %w", err)
349 return
350 }
351 h.serverMetadata = metadata
352 return
353 }
354
355 // Use the first authorization server
356 authServerURL := protectedResource.AuthorizationServers[0]
357
358 // Try OpenID Connect discovery first
359 h.fetchMetadataFromURL(ctx, authServerURL+"/.well-known/openid-configuration")
360 if h.serverMetadata != nil {
361 return
362 }
363
364 // If OpenID Connect discovery fails, try OAuth Authorization Server Metadata
365 h.fetchMetadataFromURL(ctx, authServerURL+"/.well-known/oauth-authorization-server")
366 if h.serverMetadata != nil {
367 return
368 }
369
370 // If both discovery methods fail, use default endpoints based on the authorization server URL
371 metadata, err := h.getDefaultEndpoints(authServerURL)
372 if err != nil {
373 h.metadataFetchErr = fmt.Errorf("failed to get default endpoints: %w", err)
374 return
375 }
376 h.serverMetadata = metadata
377 })
378
379 if h.metadataFetchErr != nil {
380 return nil, h.metadataFetchErr
381 }
382
383 return h.serverMetadata, nil
384}
385
386// fetchMetadataFromURL fetches and parses OAuth server metadata from a URL
387func (h *OAuthHandler) fetchMetadataFromURL(ctx context.Context, metadataURL string) {
388 req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil)
389 if err != nil {
390 h.metadataFetchErr = fmt.Errorf("failed to create metadata request: %w", err)
391 return
392 }
393
394 req.Header.Set("Accept", "application/json")
395 req.Header.Set("MCP-Protocol-Version", "2025-03-26")
396
397 resp, err := h.httpClient.Do(req)
398 if err != nil {
399 h.metadataFetchErr = fmt.Errorf("failed to send metadata request: %w", err)
400 return
401 }
402 defer resp.Body.Close()
403
404 if resp.StatusCode != http.StatusOK {
405 // If metadata discovery fails, don't set any metadata
406 return
407 }
408
409 var metadata AuthServerMetadata
410 if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
411 h.metadataFetchErr = fmt.Errorf("failed to decode metadata response: %w", err)
412 return
413 }
414
415 h.serverMetadata = &metadata
416}
417
418// extractBaseURL extracts the base URL from the first request
419func (h *OAuthHandler) extractBaseURL() (string, error) {
420 // If we have a base URL from a previous request, use it
421 if h.baseURL != "" {
422 return h.baseURL, nil
423 }
424
425 // Otherwise, we need to infer it from the redirect URI
426 if h.config.RedirectURI == "" {
427 return "", fmt.Errorf("no base URL available and no redirect URI provided")
428 }
429
430 // Parse the redirect URI to extract the authority
431 parsedURL, err := url.Parse(h.config.RedirectURI)
432 if err != nil {
433 return "", fmt.Errorf("failed to parse redirect URI: %w", err)
434 }
435
436 // Use the scheme and host from the redirect URI
437 baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
438 return baseURL, nil
439}
440
441// GetServerMetadata is a public wrapper for getServerMetadata
442func (h *OAuthHandler) GetServerMetadata(ctx context.Context) (*AuthServerMetadata, error) {
443 return h.getServerMetadata(ctx)
444}
445
446// getDefaultEndpoints returns default OAuth endpoints based on the base URL
447func (h *OAuthHandler) getDefaultEndpoints(baseURL string) (*AuthServerMetadata, error) {
448 // Parse the base URL to extract the authority
449 parsedURL, err := url.Parse(baseURL)
450 if err != nil {
451 return nil, fmt.Errorf("failed to parse base URL: %w", err)
452 }
453
454 // Discard any path component to get the authorization base URL
455 parsedURL.Path = ""
456 authBaseURL := parsedURL.String()
457
458 // Validate that the URL has a scheme and host
459 if parsedURL.Scheme == "" || parsedURL.Host == "" {
460 return nil, fmt.Errorf("invalid base URL: missing scheme or host in %q", baseURL)
461 }
462
463 return &AuthServerMetadata{
464 Issuer: authBaseURL,
465 AuthorizationEndpoint: authBaseURL + "/authorize",
466 TokenEndpoint: authBaseURL + "/token",
467 RegistrationEndpoint: authBaseURL + "/register",
468 }, nil
469}
470
471// RegisterClient performs dynamic client registration
472func (h *OAuthHandler) RegisterClient(ctx context.Context, clientName string) error {
473 metadata, err := h.getServerMetadata(ctx)
474 if err != nil {
475 return fmt.Errorf("failed to get server metadata: %w", err)
476 }
477
478 if metadata.RegistrationEndpoint == "" {
479 return errors.New("server does not support dynamic client registration")
480 }
481
482 // Prepare registration request
483 regRequest := map[string]any{
484 "client_name": clientName,
485 "redirect_uris": []string{h.config.RedirectURI},
486 "token_endpoint_auth_method": "none", // For public clients
487 "grant_types": []string{"authorization_code", "refresh_token"},
488 "response_types": []string{"code"},
489 "scope": strings.Join(h.config.Scopes, " "),
490 }
491
492 // Add client_secret if this is a confidential client
493 if h.config.ClientSecret != "" {
494 regRequest["token_endpoint_auth_method"] = "client_secret_basic"
495 }
496
497 reqBody, err := json.Marshal(regRequest)
498 if err != nil {
499 return fmt.Errorf("failed to marshal registration request: %w", err)
500 }
501
502 req, err := http.NewRequestWithContext(
503 ctx,
504 http.MethodPost,
505 metadata.RegistrationEndpoint,
506 bytes.NewReader(reqBody),
507 )
508 if err != nil {
509 return fmt.Errorf("failed to create registration request: %w", err)
510 }
511
512 req.Header.Set("Content-Type", "application/json")
513 req.Header.Set("Accept", "application/json")
514
515 resp, err := h.httpClient.Do(req)
516 if err != nil {
517 return fmt.Errorf("failed to send registration request: %w", err)
518 }
519 defer resp.Body.Close()
520
521 if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
522 body, _ := io.ReadAll(resp.Body)
523 return extractOAuthError(body, resp.StatusCode, "registration request failed")
524 }
525
526 var regResponse struct {
527 ClientID string `json:"client_id"`
528 ClientSecret string `json:"client_secret,omitempty"`
529 }
530
531 if err := json.NewDecoder(resp.Body).Decode(®Response); err != nil {
532 return fmt.Errorf("failed to decode registration response: %w", err)
533 }
534
535 // Update the client configuration
536 h.config.ClientID = regResponse.ClientID
537 if regResponse.ClientSecret != "" {
538 h.config.ClientSecret = regResponse.ClientSecret
539 }
540
541 return nil
542}
543
// ErrInvalidState reports that the state parameter received from the
// authorization server differs from the value the handler expected.
var ErrInvalidState = errors.New(
	"invalid state parameter, possible CSRF attack",
)
546
547// ProcessAuthorizationResponse processes the authorization response and exchanges the code for a token
548func (h *OAuthHandler) ProcessAuthorizationResponse(ctx context.Context, code, state, codeVerifier string) error {
549 // Validate the state parameter to prevent CSRF attacks
550 if h.expectedState == "" {
551 return errors.New("no expected state found, authorization flow may not have been initiated properly")
552 }
553
554 if state != h.expectedState {
555 return ErrInvalidState
556 }
557
558 // Clear the expected state after validation
559 defer func() {
560 h.expectedState = ""
561 }()
562
563 metadata, err := h.getServerMetadata(ctx)
564 if err != nil {
565 return fmt.Errorf("failed to get server metadata: %w", err)
566 }
567
568 data := url.Values{}
569 data.Set("grant_type", "authorization_code")
570 data.Set("code", code)
571 data.Set("client_id", h.config.ClientID)
572 data.Set("redirect_uri", h.config.RedirectURI)
573
574 if h.config.ClientSecret != "" {
575 data.Set("client_secret", h.config.ClientSecret)
576 }
577
578 if h.config.PKCEEnabled && codeVerifier != "" {
579 data.Set("code_verifier", codeVerifier)
580 }
581
582 req, err := http.NewRequestWithContext(
583 ctx,
584 http.MethodPost,
585 metadata.TokenEndpoint,
586 strings.NewReader(data.Encode()),
587 )
588 if err != nil {
589 return fmt.Errorf("failed to create token request: %w", err)
590 }
591
592 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
593 req.Header.Set("Accept", "application/json")
594
595 resp, err := h.httpClient.Do(req)
596 if err != nil {
597 return fmt.Errorf("failed to send token request: %w", err)
598 }
599 defer resp.Body.Close()
600
601 if resp.StatusCode != http.StatusOK {
602 body, _ := io.ReadAll(resp.Body)
603 return extractOAuthError(body, resp.StatusCode, "token request failed")
604 }
605
606 var tokenResp Token
607 if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
608 return fmt.Errorf("failed to decode token response: %w", err)
609 }
610
611 // Set expiration time
612 if tokenResp.ExpiresIn > 0 {
613 tokenResp.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
614 }
615
616 // Save the token
617 if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil {
618 return fmt.Errorf("failed to save token: %w", err)
619 }
620
621 return nil
622}
623
624// GetAuthorizationURL returns the URL for the authorization endpoint
625func (h *OAuthHandler) GetAuthorizationURL(ctx context.Context, state, codeChallenge string) (string, error) {
626 metadata, err := h.getServerMetadata(ctx)
627 if err != nil {
628 return "", fmt.Errorf("failed to get server metadata: %w", err)
629 }
630
631 // Store the state for later validation
632 h.expectedState = state
633
634 params := url.Values{}
635 params.Set("response_type", "code")
636 params.Set("client_id", h.config.ClientID)
637 params.Set("redirect_uri", h.config.RedirectURI)
638 params.Set("state", state)
639
640 if len(h.config.Scopes) > 0 {
641 params.Set("scope", strings.Join(h.config.Scopes, " "))
642 }
643
644 if h.config.PKCEEnabled && codeChallenge != "" {
645 params.Set("code_challenge", codeChallenge)
646 params.Set("code_challenge_method", "S256")
647 }
648
649 return metadata.AuthorizationEndpoint + "?" + params.Encode(), nil
650}