1package client
2
3import (
4 "errors"
5 "fmt"
6
7 "github.com/mark3labs/mcp-go/client/transport"
8)
9
10// OAuthConfig is a convenience type that wraps transport.OAuthConfig
11type OAuthConfig = transport.OAuthConfig
12
13// Token is a convenience type that wraps transport.Token
14type Token = transport.Token
15
16// TokenStore is a convenience type that wraps transport.TokenStore
17type TokenStore = transport.TokenStore
18
19// MemoryTokenStore is a convenience type that wraps transport.MemoryTokenStore
20type MemoryTokenStore = transport.MemoryTokenStore
21
22// NewMemoryTokenStore is a convenience function that wraps transport.NewMemoryTokenStore
23var NewMemoryTokenStore = transport.NewMemoryTokenStore
24
25// NewOAuthStreamableHttpClient creates a new streamable-http-based MCP client with OAuth support.
26// Returns an error if the URL is invalid.
27func NewOAuthStreamableHttpClient(baseURL string, oauthConfig OAuthConfig, options ...transport.StreamableHTTPCOption) (*Client, error) {
28 // Add OAuth option to the list of options
29 options = append(options, transport.WithHTTPOAuth(oauthConfig))
30
31 trans, err := transport.NewStreamableHTTP(baseURL, options...)
32 if err != nil {
33 return nil, fmt.Errorf("failed to create HTTP transport: %w", err)
34 }
35 return NewClient(trans), nil
36}
37
38// NewOAuthStreamableHttpClient creates a new streamable-http-based MCP client with OAuth support.
39// Returns an error if the URL is invalid.
40func NewOAuthSSEClient(baseURL string, oauthConfig OAuthConfig, options ...transport.ClientOption) (*Client, error) {
41 // Add OAuth option to the list of options
42 options = append(options, transport.WithOAuth(oauthConfig))
43
44 trans, err := transport.NewSSE(baseURL, options...)
45 if err != nil {
46 return nil, fmt.Errorf("failed to create SSE transport: %w", err)
47 }
48 return NewClient(trans), nil
49}
50
51// GenerateCodeVerifier generates a code verifier for PKCE
52var GenerateCodeVerifier = transport.GenerateCodeVerifier
53
54// GenerateCodeChallenge generates a code challenge from a code verifier
55var GenerateCodeChallenge = transport.GenerateCodeChallenge
56
57// GenerateState generates a state parameter for OAuth
58var GenerateState = transport.GenerateState
59
60// OAuthAuthorizationRequiredError is returned when OAuth authorization is required
61type OAuthAuthorizationRequiredError = transport.OAuthAuthorizationRequiredError
62
63// IsOAuthAuthorizationRequiredError checks if an error is an OAuthAuthorizationRequiredError
64func IsOAuthAuthorizationRequiredError(err error) bool {
65 var target *OAuthAuthorizationRequiredError
66 return errors.As(err, &target)
67}
68
69// GetOAuthHandler extracts the OAuthHandler from an OAuthAuthorizationRequiredError
70func GetOAuthHandler(err error) *transport.OAuthHandler {
71 var oauthErr *OAuthAuthorizationRequiredError
72 if errors.As(err, &oauthErr) {
73 return oauthErr.Handler
74 }
75 return nil
76}