session.go

  1package server
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/mark3labs/mcp-go/mcp"
  8)
  9
// ClientSession represents an active session that MCPServer uses to interact
// with a client.
type ClientSession interface {
	// Initialize marks the session as fully initialized and ready for notifications.
	Initialize()
	// Initialized reports whether the session is ready to accept notifications.
	Initialized() bool
	// NotificationChannel returns a send-only channel used to deliver
	// notifications to the client.
	NotificationChannel() chan<- mcp.JSONRPCNotification
	// SessionID returns the unique identifier used to track the user session.
	SessionID() string
}
 21
// SessionWithLogging is an extension of ClientSession that can receive log
// message notifications and carries a minimum log level.
type SessionWithLogging interface {
	ClientSession
	// SetLogLevel sets the minimum log level for this session.
	SetLogLevel(level mcp.LoggingLevel)
	// GetLogLevel returns the minimum log level for this session.
	GetLogLevel() mcp.LoggingLevel
}
 30
// SessionWithTools is an extension of ClientSession that can store
// session-specific tool data.
type SessionWithTools interface {
	ClientSession
	// GetSessionTools returns the tools specific to this session, if any,
	// keyed by a string (AddSessionTools in this file keys entries by tool name).
	// Implementations must be safe for concurrent use.
	GetSessionTools() map[string]ServerTool
	// SetSessionTools replaces the tools specific to this session.
	// Implementations must be safe for concurrent use.
	SetSessionTools(tools map[string]ServerTool)
}
 41
// SessionWithClientInfo is an extension of ClientSession that can store
// information about the connected client.
type SessionWithClientInfo interface {
	ClientSession
	// GetClientInfo returns the client information stored for this session.
	GetClientInfo() mcp.Implementation
	// SetClientInfo stores the client information for this session.
	SetClientInfo(clientInfo mcp.Implementation)
}
 50
// SessionWithStreamableHTTPConfig extends ClientSession with configuration
// hooks for the streamable HTTP transport.
type SessionWithStreamableHTTPConfig interface {
	ClientSession
	// UpgradeToSSEWhenReceiveNotification upgrades the client-server
	// communication to an SSE stream when the server sends notifications to
	// the client. The notification senders in this file call it just before
	// queuing a notification for a single session.
	//
	// The protocol specification:
	// - If the server response contains any JSON-RPC notifications, it MUST either:
	//   - Return Content-Type: text/event-stream to initiate an SSE stream, OR
	//   - Return Content-Type: application/json for a single JSON object
	// - The client MUST support both response types.
	//
	// Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#sending-messages-to-the-server
	UpgradeToSSEWhenReceiveNotification()
}
 66
// clientSessionKey is the context key under which the current ClientSession
// is stored (see WithContext and ClientSessionFromContext).
type clientSessionKey struct{}
 69
 70// ClientSessionFromContext retrieves current client notification context from context.
 71func ClientSessionFromContext(ctx context.Context) ClientSession {
 72	if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok {
 73		return session
 74	}
 75	return nil
 76}
 77
 78// WithContext sets the current client session and returns the provided context
 79func (s *MCPServer) WithContext(
 80	ctx context.Context,
 81	session ClientSession,
 82) context.Context {
 83	return context.WithValue(ctx, clientSessionKey{}, session)
 84}
 85
 86// RegisterSession saves session that should be notified in case if some server attributes changed.
 87func (s *MCPServer) RegisterSession(
 88	ctx context.Context,
 89	session ClientSession,
 90) error {
 91	sessionID := session.SessionID()
 92	if _, exists := s.sessions.LoadOrStore(sessionID, session); exists {
 93		return ErrSessionExists
 94	}
 95	s.hooks.RegisterSession(ctx, session)
 96	return nil
 97}
 98
 99// UnregisterSession removes from storage session that is shut down.
100func (s *MCPServer) UnregisterSession(
101	ctx context.Context,
102	sessionID string,
103) {
104	sessionValue, ok := s.sessions.LoadAndDelete(sessionID)
105	if !ok {
106		return
107	}
108	if session, ok := sessionValue.(ClientSession); ok {
109		s.hooks.UnregisterSession(ctx, session)
110	}
111}
112
113// SendNotificationToAllClients sends a notification to all the currently active clients.
114func (s *MCPServer) SendNotificationToAllClients(
115	method string,
116	params map[string]any,
117) {
118	notification := mcp.JSONRPCNotification{
119		JSONRPC: mcp.JSONRPC_VERSION,
120		Notification: mcp.Notification{
121			Method: method,
122			Params: mcp.NotificationParams{
123				AdditionalFields: params,
124			},
125		},
126	}
127
128	s.sessions.Range(func(k, v any) bool {
129		if session, ok := v.(ClientSession); ok && session.Initialized() {
130			select {
131			case session.NotificationChannel() <- notification:
132				// Successfully sent notification
133			default:
134				// Channel is blocked, if there's an error hook, use it
135				if s.hooks != nil && len(s.hooks.OnError) > 0 {
136					err := ErrNotificationChannelBlocked
137					// Copy hooks pointer to local variable to avoid race condition
138					hooks := s.hooks
139					go func(sessionID string, hooks *Hooks) {
140						ctx := context.Background()
141						// Use the error hook to report the blocked channel
142						hooks.onError(ctx, nil, "notification", map[string]any{
143							"method":    method,
144							"sessionID": sessionID,
145						}, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err))
146					}(session.SessionID(), hooks)
147				}
148			}
149		}
150		return true
151	})
152}
153
154// SendNotificationToClient sends a notification to the current client
155func (s *MCPServer) SendNotificationToClient(
156	ctx context.Context,
157	method string,
158	params map[string]any,
159) error {
160	session := ClientSessionFromContext(ctx)
161	if session == nil || !session.Initialized() {
162		return ErrNotificationNotInitialized
163	}
164
165	// upgrades the client-server communication to SSE stream when the server sends notifications to the client
166	if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
167		sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
168	}
169
170	notification := mcp.JSONRPCNotification{
171		JSONRPC: mcp.JSONRPC_VERSION,
172		Notification: mcp.Notification{
173			Method: method,
174			Params: mcp.NotificationParams{
175				AdditionalFields: params,
176			},
177		},
178	}
179
180	select {
181	case session.NotificationChannel() <- notification:
182		return nil
183	default:
184		// Channel is blocked, if there's an error hook, use it
185		if s.hooks != nil && len(s.hooks.OnError) > 0 {
186			err := ErrNotificationChannelBlocked
187			// Copy hooks pointer to local variable to avoid race condition
188			hooks := s.hooks
189			go func(sessionID string, hooks *Hooks) {
190				// Use the error hook to report the blocked channel
191				hooks.onError(ctx, nil, "notification", map[string]any{
192					"method":    method,
193					"sessionID": sessionID,
194				}, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err))
195			}(session.SessionID(), hooks)
196		}
197		return ErrNotificationChannelBlocked
198	}
199}
200
201// SendNotificationToSpecificClient sends a notification to a specific client by session ID
202func (s *MCPServer) SendNotificationToSpecificClient(
203	sessionID string,
204	method string,
205	params map[string]any,
206) error {
207	sessionValue, ok := s.sessions.Load(sessionID)
208	if !ok {
209		return ErrSessionNotFound
210	}
211
212	session, ok := sessionValue.(ClientSession)
213	if !ok || !session.Initialized() {
214		return ErrSessionNotInitialized
215	}
216
217	// upgrades the client-server communication to SSE stream when the server sends notifications to the client
218	if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
219		sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
220	}
221
222	notification := mcp.JSONRPCNotification{
223		JSONRPC: mcp.JSONRPC_VERSION,
224		Notification: mcp.Notification{
225			Method: method,
226			Params: mcp.NotificationParams{
227				AdditionalFields: params,
228			},
229		},
230	}
231
232	select {
233	case session.NotificationChannel() <- notification:
234		return nil
235	default:
236		// Channel is blocked, if there's an error hook, use it
237		if s.hooks != nil && len(s.hooks.OnError) > 0 {
238			err := ErrNotificationChannelBlocked
239			ctx := context.Background()
240			// Copy hooks pointer to local variable to avoid race condition
241			hooks := s.hooks
242			go func(sID string, hooks *Hooks) {
243				// Use the error hook to report the blocked channel
244				hooks.onError(ctx, nil, "notification", map[string]any{
245					"method":    method,
246					"sessionID": sID,
247				}, fmt.Errorf("notification channel blocked for session %s: %w", sID, err))
248			}(sessionID, hooks)
249		}
250		return ErrNotificationChannelBlocked
251	}
252}
253
254// AddSessionTool adds a tool for a specific session
255func (s *MCPServer) AddSessionTool(sessionID string, tool mcp.Tool, handler ToolHandlerFunc) error {
256	return s.AddSessionTools(sessionID, ServerTool{Tool: tool, Handler: handler})
257}
258
259// AddSessionTools adds tools for a specific session
260func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error {
261	sessionValue, ok := s.sessions.Load(sessionID)
262	if !ok {
263		return ErrSessionNotFound
264	}
265
266	session, ok := sessionValue.(SessionWithTools)
267	if !ok {
268		return ErrSessionDoesNotSupportTools
269	}
270
271	s.implicitlyRegisterToolCapabilities()
272
273	// Get existing tools (this should return a thread-safe copy)
274	sessionTools := session.GetSessionTools()
275
276	// Create a new map to avoid concurrent modification issues
277	newSessionTools := make(map[string]ServerTool, len(sessionTools)+len(tools))
278
279	// Copy existing tools
280	for k, v := range sessionTools {
281		newSessionTools[k] = v
282	}
283
284	// Add new tools
285	for _, tool := range tools {
286		newSessionTools[tool.Tool.Name] = tool
287	}
288
289	// Set the tools (this should be thread-safe)
290	session.SetSessionTools(newSessionTools)
291
292	// It only makes sense to send tool notifications to initialized sessions --
293	// if we're not initialized yet the client can't possibly have sent their
294	// initial tools/list message.
295	//
296	// For initialized sessions, honor tools.listChanged, which is specifically
297	// about whether notifications will be sent or not.
298	// see <https://modelcontextprotocol.io/specification/2025-03-26/server/tools#capabilities>
299	if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged {
300		// Send notification only to this session
301		if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil {
302			// Log the error but don't fail the operation
303			// The tools were successfully added, but notification failed
304			if s.hooks != nil && len(s.hooks.OnError) > 0 {
305				hooks := s.hooks
306				go func(sID string, hooks *Hooks) {
307					ctx := context.Background()
308					hooks.onError(ctx, nil, "notification", map[string]any{
309						"method":    "notifications/tools/list_changed",
310						"sessionID": sID,
311					}, fmt.Errorf("failed to send notification after adding tools: %w", err))
312				}(sessionID, hooks)
313			}
314		}
315	}
316
317	return nil
318}
319
320// DeleteSessionTools removes tools from a specific session
321func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error {
322	sessionValue, ok := s.sessions.Load(sessionID)
323	if !ok {
324		return ErrSessionNotFound
325	}
326
327	session, ok := sessionValue.(SessionWithTools)
328	if !ok {
329		return ErrSessionDoesNotSupportTools
330	}
331
332	// Get existing tools (this should return a thread-safe copy)
333	sessionTools := session.GetSessionTools()
334	if sessionTools == nil {
335		return nil
336	}
337
338	// Create a new map to avoid concurrent modification issues
339	newSessionTools := make(map[string]ServerTool, len(sessionTools))
340
341	// Copy existing tools except those being deleted
342	for k, v := range sessionTools {
343		newSessionTools[k] = v
344	}
345
346	// Remove specified tools
347	for _, name := range names {
348		delete(newSessionTools, name)
349	}
350
351	// Set the tools (this should be thread-safe)
352	session.SetSessionTools(newSessionTools)
353
354	// It only makes sense to send tool notifications to initialized sessions --
355	// if we're not initialized yet the client can't possibly have sent their
356	// initial tools/list message.
357	//
358	// For initialized sessions, honor tools.listChanged, which is specifically
359	// about whether notifications will be sent or not.
360	// see <https://modelcontextprotocol.io/specification/2025-03-26/server/tools#capabilities>
361	if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged {
362		// Send notification only to this session
363		if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil {
364			// Log the error but don't fail the operation
365			// The tools were successfully deleted, but notification failed
366			if s.hooks != nil && len(s.hooks.OnError) > 0 {
367				hooks := s.hooks
368				go func(sID string, hooks *Hooks) {
369					ctx := context.Background()
370					hooks.onError(ctx, nil, "notification", map[string]any{
371						"method":    "notifications/tools/list_changed",
372						"sessionID": sID,
373					}, fmt.Errorf("failed to send notification after deleting tools: %w", err))
374				}(sessionID, hooks)
375			}
376		}
377	}
378
379	return nil
380}