oauth.go

 1package client
 2
 3import (
 4	"errors"
 5	"fmt"
 6
 7	"github.com/mark3labs/mcp-go/client/transport"
 8)
 9
10// OAuthConfig is a convenience type that wraps transport.OAuthConfig
11type OAuthConfig = transport.OAuthConfig
12
13// Token is a convenience type that wraps transport.Token
14type Token = transport.Token
15
16// TokenStore is a convenience type that wraps transport.TokenStore
17type TokenStore = transport.TokenStore
18
19// MemoryTokenStore is a convenience type that wraps transport.MemoryTokenStore
20type MemoryTokenStore = transport.MemoryTokenStore
21
22// NewMemoryTokenStore is a convenience function that wraps transport.NewMemoryTokenStore
23var NewMemoryTokenStore = transport.NewMemoryTokenStore
24
25// NewOAuthStreamableHttpClient creates a new streamable-http-based MCP client with OAuth support.
26// Returns an error if the URL is invalid.
27func NewOAuthStreamableHttpClient(baseURL string, oauthConfig OAuthConfig, options ...transport.StreamableHTTPCOption) (*Client, error) {
28	// Add OAuth option to the list of options
29	options = append(options, transport.WithHTTPOAuth(oauthConfig))
30
31	trans, err := transport.NewStreamableHTTP(baseURL, options...)
32	if err != nil {
33		return nil, fmt.Errorf("failed to create HTTP transport: %w", err)
34	}
35	return NewClient(trans), nil
36}
37
38// NewOAuthStreamableHttpClient creates a new streamable-http-based MCP client with OAuth support.
39// Returns an error if the URL is invalid.
40func NewOAuthSSEClient(baseURL string, oauthConfig OAuthConfig, options ...transport.ClientOption) (*Client, error) {
41	// Add OAuth option to the list of options
42	options = append(options, transport.WithOAuth(oauthConfig))
43
44	trans, err := transport.NewSSE(baseURL, options...)
45	if err != nil {
46		return nil, fmt.Errorf("failed to create SSE transport: %w", err)
47	}
48	return NewClient(trans), nil
49}
50
51// GenerateCodeVerifier generates a code verifier for PKCE
52var GenerateCodeVerifier = transport.GenerateCodeVerifier
53
54// GenerateCodeChallenge generates a code challenge from a code verifier
55var GenerateCodeChallenge = transport.GenerateCodeChallenge
56
57// GenerateState generates a state parameter for OAuth
58var GenerateState = transport.GenerateState
59
60// OAuthAuthorizationRequiredError is returned when OAuth authorization is required
61type OAuthAuthorizationRequiredError = transport.OAuthAuthorizationRequiredError
62
63// IsOAuthAuthorizationRequiredError checks if an error is an OAuthAuthorizationRequiredError
64func IsOAuthAuthorizationRequiredError(err error) bool {
65	var target *OAuthAuthorizationRequiredError
66	return errors.As(err, &target)
67}
68
69// GetOAuthHandler extracts the OAuthHandler from an OAuthAuthorizationRequiredError
70func GetOAuthHandler(err error) *transport.OAuthHandler {
71	var oauthErr *OAuthAuthorizationRequiredError
72	if errors.As(err, &oauthErr) {
73		return oauthErr.Handler
74	}
75	return nil
76}