1package server
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/mark3labs/mcp-go/mcp"
8)
9
// EnableSampling enables sampling capabilities for the server.
// This allows the server to send sampling requests to clients that support it.
//
// NOTE(review): as written, the body only acquires capabilitiesMu and releases
// it on return; no capability field is set. Confirm whether a sampling flag on
// the server's capabilities should be recorded here while the lock is held.
func (s *MCPServer) EnableSampling() {
	s.capabilitiesMu.Lock()
	defer s.capabilitiesMu.Unlock()
}
16
17// RequestSampling sends a sampling request to the client.
18// The client must have declared sampling capability during initialization.
19func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
20 session := ClientSessionFromContext(ctx)
21 if session == nil {
22 return nil, fmt.Errorf("no active session")
23 }
24
25 // Check if the session supports sampling requests
26 if samplingSession, ok := session.(SessionWithSampling); ok {
27 return samplingSession.RequestSampling(ctx, request)
28 }
29
30 return nil, fmt.Errorf("session does not support sampling")
31}
32
33// SessionWithSampling extends ClientSession to support sampling requests.
34type SessionWithSampling interface {
35 ClientSession
36 RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error)
37}