oauth.go

  1package transport
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"io"
 10	"net/http"
 11	"net/url"
 12	"strings"
 13	"sync"
 14	"time"
 15)
 16
 17// OAuthConfig holds the OAuth configuration for the client
 18type OAuthConfig struct {
 19	// ClientID is the OAuth client ID
 20	ClientID string
 21	// ClientSecret is the OAuth client secret (for confidential clients)
 22	ClientSecret string
 23	// RedirectURI is the redirect URI for the OAuth flow
 24	RedirectURI string
 25	// Scopes is the list of OAuth scopes to request
 26	Scopes []string
 27	// TokenStore is the storage for OAuth tokens
 28	TokenStore TokenStore
 29	// AuthServerMetadataURL is the URL to the OAuth server metadata
 30	// If empty, the client will attempt to discover it from the base URL
 31	AuthServerMetadataURL string
 32	// PKCEEnabled enables PKCE for the OAuth flow (recommended for public clients)
 33	PKCEEnabled bool
 34}
 35
 36// TokenStore is an interface for storing and retrieving OAuth tokens
 37type TokenStore interface {
 38	// GetToken returns the current token
 39	GetToken() (*Token, error)
 40	// SaveToken saves a token
 41	SaveToken(token *Token) error
 42}
 43
 // Token is an OAuth 2.0 token as returned by a token endpoint, together with
// the locally computed ExpiresAt.
type Token struct {
	// AccessToken is the credential placed in the Authorization header.
	AccessToken string `json:"access_token"`
	// TokenType is the scheme reported by the server, typically "Bearer".
	TokenType string `json:"token_type"`
	// RefreshToken, when present, can be exchanged for a new access token.
	RefreshToken string `json:"refresh_token,omitempty"`
	// ExpiresIn is the token lifetime in seconds as reported by the server.
	ExpiresIn int64 `json:"expires_in,omitempty"`
	// Scope is the granted scope string, if the server reported one.
	Scope string `json:"scope,omitempty"`
	// ExpiresAt is the absolute expiry derived from ExpiresIn when the token
	// is received; the zero value means "no known expiry".
	// NOTE: encoding/json's omitempty never omits a struct such as time.Time,
	// so a zero ExpiresAt is still marshalled ("0001-01-01T00:00:00Z").
	ExpiresAt time.Time `json:"expires_at,omitempty"`
}

// IsExpired reports whether t has a known expiry time that has already
// passed. A zero ExpiresAt (for example when the token response carried no
// expires_in) is treated as never expiring.
func (t *Token) IsExpired() bool {
	hasExpiry := !t.ExpiresAt.IsZero()
	return hasExpiry && time.Now().After(t.ExpiresAt)
}
 67
 68// MemoryTokenStore is a simple in-memory token store
 69type MemoryTokenStore struct {
 70	token *Token
 71	mu    sync.RWMutex
 72}
 73
 74// NewMemoryTokenStore creates a new in-memory token store
 75func NewMemoryTokenStore() *MemoryTokenStore {
 76	return &MemoryTokenStore{}
 77}
 78
 79// GetToken returns the current token
 80func (s *MemoryTokenStore) GetToken() (*Token, error) {
 81	s.mu.RLock()
 82	defer s.mu.RUnlock()
 83	if s.token == nil {
 84		return nil, errors.New("no token available")
 85	}
 86	return s.token, nil
 87}
 88
 89// SaveToken saves a token
 90func (s *MemoryTokenStore) SaveToken(token *Token) error {
 91	s.mu.Lock()
 92	defer s.mu.Unlock()
 93	s.token = token
 94	return nil
 95}
 96
 // AuthServerMetadata is the subset of OAuth 2.0 Authorization Server Metadata
// (RFC 8414) that OAuthHandler decodes. The same shape is read from OpenID
// Connect discovery documents and synthesized by getDefaultEndpoints.
type AuthServerMetadata struct {
	// Issuer identifies the authorization server.
	Issuer string `json:"issuer"`
	// AuthorizationEndpoint is where the user agent is sent to authorize.
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	// TokenEndpoint receives code exchanges and refresh requests.
	TokenEndpoint string `json:"token_endpoint"`
	// RegistrationEndpoint, when advertised, accepts dynamic registration.
	RegistrationEndpoint string `json:"registration_endpoint,omitempty"`
	// JwksURI points at the server's signing keys.
	JwksURI string `json:"jwks_uri,omitempty"`
	// The remaining fields advertise server capabilities.
	ScopesSupported                   []string `json:"scopes_supported,omitempty"`
	ResponseTypesSupported            []string `json:"response_types_supported"`
	GrantTypesSupported               []string `json:"grant_types_supported,omitempty"`
	TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"`
}
109
110// OAuthHandler handles OAuth authentication for HTTP requests
111type OAuthHandler struct {
112	config           OAuthConfig
113	httpClient       *http.Client
114	serverMetadata   *AuthServerMetadata
115	metadataFetchErr error
116	metadataOnce     sync.Once
117	baseURL          string
118	expectedState    string // Expected state value for CSRF protection
119}
120
121// NewOAuthHandler creates a new OAuth handler
122func NewOAuthHandler(config OAuthConfig) *OAuthHandler {
123	if config.TokenStore == nil {
124		config.TokenStore = NewMemoryTokenStore()
125	}
126
127	return &OAuthHandler{
128		config:     config,
129		httpClient: &http.Client{Timeout: 30 * time.Second},
130	}
131}
132
133// GetAuthorizationHeader returns the Authorization header value for a request
134func (h *OAuthHandler) GetAuthorizationHeader(ctx context.Context) (string, error) {
135	token, err := h.getValidToken(ctx)
136	if err != nil {
137		return "", err
138	}
139
140	// Some auth implementations are strict about token type
141	tokenType := token.TokenType
142	if tokenType == "bearer" {
143		tokenType = "Bearer"
144	}
145
146	return fmt.Sprintf("%s %s", tokenType, token.AccessToken), nil
147}
148
149// getValidToken returns a valid token, refreshing if necessary
150func (h *OAuthHandler) getValidToken(ctx context.Context) (*Token, error) {
151	token, err := h.config.TokenStore.GetToken()
152	if err == nil && !token.IsExpired() && token.AccessToken != "" {
153		return token, nil
154	}
155
156	// If we have a refresh token, try to use it
157	if err == nil && token.RefreshToken != "" {
158		newToken, err := h.refreshToken(ctx, token.RefreshToken)
159		if err == nil {
160			return newToken, nil
161		}
162		// If refresh fails, continue to authorization flow
163	}
164
165	// We need to get a new token through the authorization flow
166	return nil, ErrOAuthAuthorizationRequired
167}
168
169// refreshToken refreshes an OAuth token
170func (h *OAuthHandler) refreshToken(ctx context.Context, refreshToken string) (*Token, error) {
171	metadata, err := h.getServerMetadata(ctx)
172	if err != nil {
173		return nil, fmt.Errorf("failed to get server metadata: %w", err)
174	}
175
176	data := url.Values{}
177	data.Set("grant_type", "refresh_token")
178	data.Set("refresh_token", refreshToken)
179	data.Set("client_id", h.config.ClientID)
180	if h.config.ClientSecret != "" {
181		data.Set("client_secret", h.config.ClientSecret)
182	}
183
184	req, err := http.NewRequestWithContext(
185		ctx,
186		http.MethodPost,
187		metadata.TokenEndpoint,
188		strings.NewReader(data.Encode()),
189	)
190	if err != nil {
191		return nil, fmt.Errorf("failed to create refresh token request: %w", err)
192	}
193
194	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
195	req.Header.Set("Accept", "application/json")
196
197	resp, err := h.httpClient.Do(req)
198	if err != nil {
199		return nil, fmt.Errorf("failed to send refresh token request: %w", err)
200	}
201	defer resp.Body.Close()
202
203	if resp.StatusCode != http.StatusOK {
204		body, _ := io.ReadAll(resp.Body)
205		return nil, extractOAuthError(body, resp.StatusCode, "refresh token request failed")
206	}
207
208	var tokenResp Token
209	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
210		return nil, fmt.Errorf("failed to decode token response: %w", err)
211	}
212
213	// Set expiration time
214	if tokenResp.ExpiresIn > 0 {
215		tokenResp.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
216	}
217
218	// If no new refresh token is provided, keep the old one
219	oldToken, _ := h.config.TokenStore.GetToken()
220	if tokenResp.RefreshToken == "" && oldToken != nil {
221		tokenResp.RefreshToken = oldToken.RefreshToken
222	}
223
224	// Save the token
225	if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil {
226		return nil, fmt.Errorf("failed to save token: %w", err)
227	}
228
229	return &tokenResp, nil
230}
231
232// RefreshToken is a public wrapper for refreshToken
233func (h *OAuthHandler) RefreshToken(ctx context.Context, refreshToken string) (*Token, error) {
234	return h.refreshToken(ctx, refreshToken)
235}
236
237// GetClientID returns the client ID
238func (h *OAuthHandler) GetClientID() string {
239	return h.config.ClientID
240}
241
242// extractOAuthError attempts to parse an OAuth error response from the response body
243func extractOAuthError(body []byte, statusCode int, context string) error {
244	// Try to parse the error as an OAuth error response
245	var oauthErr OAuthError
246	if err := json.Unmarshal(body, &oauthErr); err == nil && oauthErr.ErrorCode != "" {
247		return fmt.Errorf("%s: %w", context, oauthErr)
248	}
249
250	// If not a valid OAuth error, return the raw response
251	return fmt.Errorf("%s with status %d: %s", context, statusCode, body)
252}
253
254// GetClientSecret returns the client secret
255func (h *OAuthHandler) GetClientSecret() string {
256	return h.config.ClientSecret
257}
258
259// SetBaseURL sets the base URL for the API server
260func (h *OAuthHandler) SetBaseURL(baseURL string) {
261	h.baseURL = baseURL
262}
263
264// GetExpectedState returns the expected state value (for testing purposes)
265func (h *OAuthHandler) GetExpectedState() string {
266	return h.expectedState
267}
268
// OAuthError is the JSON error body an OAuth 2.0 endpoint returns on failure
// (RFC 6749 §5.2). extractOAuthError wraps it so callers can match it with
// errors.As.
type OAuthError struct {
	// ErrorCode is the machine-readable code, e.g. "invalid_grant".
	ErrorCode string `json:"error"`
	// ErrorDescription is optional human-readable detail.
	ErrorDescription string `json:"error_description,omitempty"`
	// ErrorURI optionally points at a page describing the error.
	ErrorURI string `json:"error_uri,omitempty"`
}

// Error implements the error interface. The description, when present, is
// appended after " - ".
func (e OAuthError) Error() string {
	if e.ErrorDescription == "" {
		return fmt.Sprintf("OAuth error: %s", e.ErrorCode)
	}
	return fmt.Sprintf("OAuth error: %s - %s", e.ErrorCode, e.ErrorDescription)
}
283
// OAuthProtectedResource is the subset of OAuth 2.0 Protected Resource
// Metadata (RFC 9728), served at /.well-known/oauth-protected-resource, that
// this package reads.
type OAuthProtectedResource struct {
	// AuthorizationServers lists the authorization servers that can issue
	// tokens for this resource; getServerMetadata uses only the first entry.
	AuthorizationServers []string `json:"authorization_servers"`
	// Resource is the resource identifier.
	Resource string `json:"resource"`
	// ResourceName is an optional human-readable name.
	ResourceName string `json:"resource_name,omitempty"`
}
290
291// getServerMetadata fetches the OAuth server metadata
292func (h *OAuthHandler) getServerMetadata(ctx context.Context) (*AuthServerMetadata, error) {
293	h.metadataOnce.Do(func() {
294		// If AuthServerMetadataURL is explicitly provided, use it directly
295		if h.config.AuthServerMetadataURL != "" {
296			h.fetchMetadataFromURL(ctx, h.config.AuthServerMetadataURL)
297			return
298		}
299
300		// Try to discover the authorization server via OAuth Protected Resource
301		// as per RFC 9728 (https://datatracker.ietf.org/doc/html/rfc9728)
302		baseURL, err := h.extractBaseURL()
303		if err != nil {
304			h.metadataFetchErr = fmt.Errorf("failed to extract base URL: %w", err)
305			return
306		}
307
308		// Try to fetch the OAuth Protected Resource metadata
309		protectedResourceURL := baseURL + "/.well-known/oauth-protected-resource"
310		req, err := http.NewRequestWithContext(ctx, http.MethodGet, protectedResourceURL, nil)
311		if err != nil {
312			h.metadataFetchErr = fmt.Errorf("failed to create protected resource request: %w", err)
313			return
314		}
315
316		req.Header.Set("Accept", "application/json")
317		req.Header.Set("MCP-Protocol-Version", "2025-03-26")
318
319		resp, err := h.httpClient.Do(req)
320		if err != nil {
321			h.metadataFetchErr = fmt.Errorf("failed to send protected resource request: %w", err)
322			return
323		}
324		defer resp.Body.Close()
325
326		// If we can't get the protected resource metadata, fall back to default endpoints
327		if resp.StatusCode != http.StatusOK {
328			metadata, err := h.getDefaultEndpoints(baseURL)
329			if err != nil {
330				h.metadataFetchErr = fmt.Errorf("failed to get default endpoints: %w", err)
331				return
332			}
333			h.serverMetadata = metadata
334			return
335		}
336
337		// Parse the protected resource metadata
338		var protectedResource OAuthProtectedResource
339		if err := json.NewDecoder(resp.Body).Decode(&protectedResource); err != nil {
340			h.metadataFetchErr = fmt.Errorf("failed to decode protected resource response: %w", err)
341			return
342		}
343
344		// If no authorization servers are specified, fall back to default endpoints
345		if len(protectedResource.AuthorizationServers) == 0 {
346			metadata, err := h.getDefaultEndpoints(baseURL)
347			if err != nil {
348				h.metadataFetchErr = fmt.Errorf("failed to get default endpoints: %w", err)
349				return
350			}
351			h.serverMetadata = metadata
352			return
353		}
354
355		// Use the first authorization server
356		authServerURL := protectedResource.AuthorizationServers[0]
357
358		// Try OpenID Connect discovery first
359		h.fetchMetadataFromURL(ctx, authServerURL+"/.well-known/openid-configuration")
360		if h.serverMetadata != nil {
361			return
362		}
363
364		// If OpenID Connect discovery fails, try OAuth Authorization Server Metadata
365		h.fetchMetadataFromURL(ctx, authServerURL+"/.well-known/oauth-authorization-server")
366		if h.serverMetadata != nil {
367			return
368		}
369
370		// If both discovery methods fail, use default endpoints based on the authorization server URL
371		metadata, err := h.getDefaultEndpoints(authServerURL)
372		if err != nil {
373			h.metadataFetchErr = fmt.Errorf("failed to get default endpoints: %w", err)
374			return
375		}
376		h.serverMetadata = metadata
377	})
378
379	if h.metadataFetchErr != nil {
380		return nil, h.metadataFetchErr
381	}
382
383	return h.serverMetadata, nil
384}
385
386// fetchMetadataFromURL fetches and parses OAuth server metadata from a URL
387func (h *OAuthHandler) fetchMetadataFromURL(ctx context.Context, metadataURL string) {
388	req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil)
389	if err != nil {
390		h.metadataFetchErr = fmt.Errorf("failed to create metadata request: %w", err)
391		return
392	}
393
394	req.Header.Set("Accept", "application/json")
395	req.Header.Set("MCP-Protocol-Version", "2025-03-26")
396
397	resp, err := h.httpClient.Do(req)
398	if err != nil {
399		h.metadataFetchErr = fmt.Errorf("failed to send metadata request: %w", err)
400		return
401	}
402	defer resp.Body.Close()
403
404	if resp.StatusCode != http.StatusOK {
405		// If metadata discovery fails, don't set any metadata
406		return
407	}
408
409	var metadata AuthServerMetadata
410	if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
411		h.metadataFetchErr = fmt.Errorf("failed to decode metadata response: %w", err)
412		return
413	}
414
415	h.serverMetadata = &metadata
416}
417
418// extractBaseURL extracts the base URL from the first request
419func (h *OAuthHandler) extractBaseURL() (string, error) {
420	// If we have a base URL from a previous request, use it
421	if h.baseURL != "" {
422		return h.baseURL, nil
423	}
424
425	// Otherwise, we need to infer it from the redirect URI
426	if h.config.RedirectURI == "" {
427		return "", fmt.Errorf("no base URL available and no redirect URI provided")
428	}
429
430	// Parse the redirect URI to extract the authority
431	parsedURL, err := url.Parse(h.config.RedirectURI)
432	if err != nil {
433		return "", fmt.Errorf("failed to parse redirect URI: %w", err)
434	}
435
436	// Use the scheme and host from the redirect URI
437	baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
438	return baseURL, nil
439}
440
441// GetServerMetadata is a public wrapper for getServerMetadata
442func (h *OAuthHandler) GetServerMetadata(ctx context.Context) (*AuthServerMetadata, error) {
443	return h.getServerMetadata(ctx)
444}
445
446// getDefaultEndpoints returns default OAuth endpoints based on the base URL
447func (h *OAuthHandler) getDefaultEndpoints(baseURL string) (*AuthServerMetadata, error) {
448	// Parse the base URL to extract the authority
449	parsedURL, err := url.Parse(baseURL)
450	if err != nil {
451		return nil, fmt.Errorf("failed to parse base URL: %w", err)
452	}
453
454	// Discard any path component to get the authorization base URL
455	parsedURL.Path = ""
456	authBaseURL := parsedURL.String()
457
458	// Validate that the URL has a scheme and host
459	if parsedURL.Scheme == "" || parsedURL.Host == "" {
460		return nil, fmt.Errorf("invalid base URL: missing scheme or host in %q", baseURL)
461	}
462
463	return &AuthServerMetadata{
464		Issuer:                authBaseURL,
465		AuthorizationEndpoint: authBaseURL + "/authorize",
466		TokenEndpoint:         authBaseURL + "/token",
467		RegistrationEndpoint:  authBaseURL + "/register",
468	}, nil
469}
470
471// RegisterClient performs dynamic client registration
472func (h *OAuthHandler) RegisterClient(ctx context.Context, clientName string) error {
473	metadata, err := h.getServerMetadata(ctx)
474	if err != nil {
475		return fmt.Errorf("failed to get server metadata: %w", err)
476	}
477
478	if metadata.RegistrationEndpoint == "" {
479		return errors.New("server does not support dynamic client registration")
480	}
481
482	// Prepare registration request
483	regRequest := map[string]any{
484		"client_name":                clientName,
485		"redirect_uris":              []string{h.config.RedirectURI},
486		"token_endpoint_auth_method": "none", // For public clients
487		"grant_types":                []string{"authorization_code", "refresh_token"},
488		"response_types":             []string{"code"},
489		"scope":                      strings.Join(h.config.Scopes, " "),
490	}
491
492	// Add client_secret if this is a confidential client
493	if h.config.ClientSecret != "" {
494		regRequest["token_endpoint_auth_method"] = "client_secret_basic"
495	}
496
497	reqBody, err := json.Marshal(regRequest)
498	if err != nil {
499		return fmt.Errorf("failed to marshal registration request: %w", err)
500	}
501
502	req, err := http.NewRequestWithContext(
503		ctx,
504		http.MethodPost,
505		metadata.RegistrationEndpoint,
506		bytes.NewReader(reqBody),
507	)
508	if err != nil {
509		return fmt.Errorf("failed to create registration request: %w", err)
510	}
511
512	req.Header.Set("Content-Type", "application/json")
513	req.Header.Set("Accept", "application/json")
514
515	resp, err := h.httpClient.Do(req)
516	if err != nil {
517		return fmt.Errorf("failed to send registration request: %w", err)
518	}
519	defer resp.Body.Close()
520
521	if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
522		body, _ := io.ReadAll(resp.Body)
523		return extractOAuthError(body, resp.StatusCode, "registration request failed")
524	}
525
526	var regResponse struct {
527		ClientID     string `json:"client_id"`
528		ClientSecret string `json:"client_secret,omitempty"`
529	}
530
531	if err := json.NewDecoder(resp.Body).Decode(&regResponse); err != nil {
532		return fmt.Errorf("failed to decode registration response: %w", err)
533	}
534
535	// Update the client configuration
536	h.config.ClientID = regResponse.ClientID
537	if regResponse.ClientSecret != "" {
538		h.config.ClientSecret = regResponse.ClientSecret
539	}
540
541	return nil
542}
543
var (
	// ErrInvalidState is returned by ProcessAuthorizationResponse when the
	// state received with the authorization response differs from the one
	// issued by GetAuthorizationURL, which indicates a possible CSRF attempt.
	// Callers can detect it with errors.Is.
	ErrInvalidState = errors.New("invalid state parameter, possible CSRF attack")
)
546
547// ProcessAuthorizationResponse processes the authorization response and exchanges the code for a token
548func (h *OAuthHandler) ProcessAuthorizationResponse(ctx context.Context, code, state, codeVerifier string) error {
549	// Validate the state parameter to prevent CSRF attacks
550	if h.expectedState == "" {
551		return errors.New("no expected state found, authorization flow may not have been initiated properly")
552	}
553
554	if state != h.expectedState {
555		return ErrInvalidState
556	}
557
558	// Clear the expected state after validation
559	defer func() {
560		h.expectedState = ""
561	}()
562
563	metadata, err := h.getServerMetadata(ctx)
564	if err != nil {
565		return fmt.Errorf("failed to get server metadata: %w", err)
566	}
567
568	data := url.Values{}
569	data.Set("grant_type", "authorization_code")
570	data.Set("code", code)
571	data.Set("client_id", h.config.ClientID)
572	data.Set("redirect_uri", h.config.RedirectURI)
573
574	if h.config.ClientSecret != "" {
575		data.Set("client_secret", h.config.ClientSecret)
576	}
577
578	if h.config.PKCEEnabled && codeVerifier != "" {
579		data.Set("code_verifier", codeVerifier)
580	}
581
582	req, err := http.NewRequestWithContext(
583		ctx,
584		http.MethodPost,
585		metadata.TokenEndpoint,
586		strings.NewReader(data.Encode()),
587	)
588	if err != nil {
589		return fmt.Errorf("failed to create token request: %w", err)
590	}
591
592	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
593	req.Header.Set("Accept", "application/json")
594
595	resp, err := h.httpClient.Do(req)
596	if err != nil {
597		return fmt.Errorf("failed to send token request: %w", err)
598	}
599	defer resp.Body.Close()
600
601	if resp.StatusCode != http.StatusOK {
602		body, _ := io.ReadAll(resp.Body)
603		return extractOAuthError(body, resp.StatusCode, "token request failed")
604	}
605
606	var tokenResp Token
607	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
608		return fmt.Errorf("failed to decode token response: %w", err)
609	}
610
611	// Set expiration time
612	if tokenResp.ExpiresIn > 0 {
613		tokenResp.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
614	}
615
616	// Save the token
617	if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil {
618		return fmt.Errorf("failed to save token: %w", err)
619	}
620
621	return nil
622}
623
624// GetAuthorizationURL returns the URL for the authorization endpoint
625func (h *OAuthHandler) GetAuthorizationURL(ctx context.Context, state, codeChallenge string) (string, error) {
626	metadata, err := h.getServerMetadata(ctx)
627	if err != nil {
628		return "", fmt.Errorf("failed to get server metadata: %w", err)
629	}
630
631	// Store the state for later validation
632	h.expectedState = state
633
634	params := url.Values{}
635	params.Set("response_type", "code")
636	params.Set("client_id", h.config.ClientID)
637	params.Set("redirect_uri", h.config.RedirectURI)
638	params.Set("state", state)
639
640	if len(h.config.Scopes) > 0 {
641		params.Set("scope", strings.Join(h.config.Scopes, " "))
642	}
643
644	if h.config.PKCEEnabled && codeChallenge != "" {
645		params.Set("code_challenge", codeChallenge)
646		params.Set("code_challenge_method", "S256")
647	}
648
649	return metadata.AuthorizationEndpoint + "?" + params.Encode(), nil
650}