1package server
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/mark3labs/mcp-go/mcp"
8)
9
// ClientSession represents an active session that MCPServer can use to interact with a client.
//
// MCPServer keys registered sessions by SessionID and only delivers
// notifications to sessions for which Initialized reports true.
type ClientSession interface {
	// Initialize marks the session as fully initialized and ready for notifications.
	Initialize()
	// Initialized reports whether the session is ready to accept notifications.
	Initialized() bool
	// NotificationChannel returns a send-only channel suitable for sending notifications to the client.
	// MCPServer performs non-blocking sends on it, so a notification is not delivered
	// when the channel cannot accept a value immediately.
	NotificationChannel() chan<- mcp.JSONRPCNotification
	// SessionID returns the unique identifier used to track the session.
	SessionID() string
}
21
// SessionWithLogging is an extension of ClientSession that can receive log message
// notifications and carries a per-session minimum log level.
type SessionWithLogging interface {
	ClientSession
	// SetLogLevel sets the minimum log level for this session.
	SetLogLevel(level mcp.LoggingLevel)
	// GetLogLevel returns the minimum log level currently set for this session.
	GetLogLevel() mcp.LoggingLevel
}
30
// SessionWithTools is an extension of ClientSession that can store session-specific tool data.
//
// MCPServer updates the tool set by reading it with GetSessionTools, building a
// new map, and handing that map to SetSessionTools; it never mutates the map
// returned by GetSessionTools in place.
type SessionWithTools interface {
	ClientSession
	// GetSessionTools returns the tools specific to this session, keyed by tool name, if any.
	// A nil map means the session has no session-specific tools.
	// This method must be thread-safe for concurrent access.
	GetSessionTools() map[string]ServerTool
	// SetSessionTools replaces the tools specific to this session.
	// This method must be thread-safe for concurrent access.
	SetSessionTools(tools map[string]ServerTool)
}
41
// SessionWithClientInfo is an extension of ClientSession that can store
// information about the connected client.
type SessionWithClientInfo interface {
	ClientSession
	// GetClientInfo returns the client information stored for this session.
	GetClientInfo() mcp.Implementation
	// SetClientInfo stores the client information for this session.
	SetClientInfo(clientInfo mcp.Implementation)
}
50
// SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations.
type SessionWithStreamableHTTPConfig interface {
	ClientSession
	// UpgradeToSSEWhenReceiveNotification upgrades the client-server communication to an SSE stream
	// when the server sends notifications to the client. MCPServer calls it just before
	// queuing a notification addressed to a single session.
	//
	// The protocol specification:
	//   - If the server response contains any JSON-RPC notifications, it MUST either:
	//       - Return Content-Type: text/event-stream to initiate an SSE stream, OR
	//       - Return Content-Type: application/json for a single JSON object
	//   - The client MUST support both response types.
	//
	// Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#sending-messages-to-the-server
	UpgradeToSSEWhenReceiveNotification()
}
66
// clientSessionKey is the context key under which the current ClientSession is stored.
// It is an unexported empty struct type, so only this package can set or read the value.
type clientSessionKey struct{}
69
70// ClientSessionFromContext retrieves current client notification context from context.
71func ClientSessionFromContext(ctx context.Context) ClientSession {
72 if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok {
73 return session
74 }
75 return nil
76}
77
78// WithContext sets the current client session and returns the provided context
79func (s *MCPServer) WithContext(
80 ctx context.Context,
81 session ClientSession,
82) context.Context {
83 return context.WithValue(ctx, clientSessionKey{}, session)
84}
85
86// RegisterSession saves session that should be notified in case if some server attributes changed.
87func (s *MCPServer) RegisterSession(
88 ctx context.Context,
89 session ClientSession,
90) error {
91 sessionID := session.SessionID()
92 if _, exists := s.sessions.LoadOrStore(sessionID, session); exists {
93 return ErrSessionExists
94 }
95 s.hooks.RegisterSession(ctx, session)
96 return nil
97}
98
99// UnregisterSession removes from storage session that is shut down.
100func (s *MCPServer) UnregisterSession(
101 ctx context.Context,
102 sessionID string,
103) {
104 sessionValue, ok := s.sessions.LoadAndDelete(sessionID)
105 if !ok {
106 return
107 }
108 if session, ok := sessionValue.(ClientSession); ok {
109 s.hooks.UnregisterSession(ctx, session)
110 }
111}
112
113// SendNotificationToAllClients sends a notification to all the currently active clients.
114func (s *MCPServer) SendNotificationToAllClients(
115 method string,
116 params map[string]any,
117) {
118 notification := mcp.JSONRPCNotification{
119 JSONRPC: mcp.JSONRPC_VERSION,
120 Notification: mcp.Notification{
121 Method: method,
122 Params: mcp.NotificationParams{
123 AdditionalFields: params,
124 },
125 },
126 }
127
128 s.sessions.Range(func(k, v any) bool {
129 if session, ok := v.(ClientSession); ok && session.Initialized() {
130 select {
131 case session.NotificationChannel() <- notification:
132 // Successfully sent notification
133 default:
134 // Channel is blocked, if there's an error hook, use it
135 if s.hooks != nil && len(s.hooks.OnError) > 0 {
136 err := ErrNotificationChannelBlocked
137 // Copy hooks pointer to local variable to avoid race condition
138 hooks := s.hooks
139 go func(sessionID string, hooks *Hooks) {
140 ctx := context.Background()
141 // Use the error hook to report the blocked channel
142 hooks.onError(ctx, nil, "notification", map[string]any{
143 "method": method,
144 "sessionID": sessionID,
145 }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err))
146 }(session.SessionID(), hooks)
147 }
148 }
149 }
150 return true
151 })
152}
153
154// SendNotificationToClient sends a notification to the current client
155func (s *MCPServer) SendNotificationToClient(
156 ctx context.Context,
157 method string,
158 params map[string]any,
159) error {
160 session := ClientSessionFromContext(ctx)
161 if session == nil || !session.Initialized() {
162 return ErrNotificationNotInitialized
163 }
164
165 // upgrades the client-server communication to SSE stream when the server sends notifications to the client
166 if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
167 sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
168 }
169
170 notification := mcp.JSONRPCNotification{
171 JSONRPC: mcp.JSONRPC_VERSION,
172 Notification: mcp.Notification{
173 Method: method,
174 Params: mcp.NotificationParams{
175 AdditionalFields: params,
176 },
177 },
178 }
179
180 select {
181 case session.NotificationChannel() <- notification:
182 return nil
183 default:
184 // Channel is blocked, if there's an error hook, use it
185 if s.hooks != nil && len(s.hooks.OnError) > 0 {
186 err := ErrNotificationChannelBlocked
187 // Copy hooks pointer to local variable to avoid race condition
188 hooks := s.hooks
189 go func(sessionID string, hooks *Hooks) {
190 // Use the error hook to report the blocked channel
191 hooks.onError(ctx, nil, "notification", map[string]any{
192 "method": method,
193 "sessionID": sessionID,
194 }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err))
195 }(session.SessionID(), hooks)
196 }
197 return ErrNotificationChannelBlocked
198 }
199}
200
201// SendNotificationToSpecificClient sends a notification to a specific client by session ID
202func (s *MCPServer) SendNotificationToSpecificClient(
203 sessionID string,
204 method string,
205 params map[string]any,
206) error {
207 sessionValue, ok := s.sessions.Load(sessionID)
208 if !ok {
209 return ErrSessionNotFound
210 }
211
212 session, ok := sessionValue.(ClientSession)
213 if !ok || !session.Initialized() {
214 return ErrSessionNotInitialized
215 }
216
217 // upgrades the client-server communication to SSE stream when the server sends notifications to the client
218 if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
219 sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
220 }
221
222 notification := mcp.JSONRPCNotification{
223 JSONRPC: mcp.JSONRPC_VERSION,
224 Notification: mcp.Notification{
225 Method: method,
226 Params: mcp.NotificationParams{
227 AdditionalFields: params,
228 },
229 },
230 }
231
232 select {
233 case session.NotificationChannel() <- notification:
234 return nil
235 default:
236 // Channel is blocked, if there's an error hook, use it
237 if s.hooks != nil && len(s.hooks.OnError) > 0 {
238 err := ErrNotificationChannelBlocked
239 ctx := context.Background()
240 // Copy hooks pointer to local variable to avoid race condition
241 hooks := s.hooks
242 go func(sID string, hooks *Hooks) {
243 // Use the error hook to report the blocked channel
244 hooks.onError(ctx, nil, "notification", map[string]any{
245 "method": method,
246 "sessionID": sID,
247 }, fmt.Errorf("notification channel blocked for session %s: %w", sID, err))
248 }(sessionID, hooks)
249 }
250 return ErrNotificationChannelBlocked
251 }
252}
253
254// AddSessionTool adds a tool for a specific session
255func (s *MCPServer) AddSessionTool(sessionID string, tool mcp.Tool, handler ToolHandlerFunc) error {
256 return s.AddSessionTools(sessionID, ServerTool{Tool: tool, Handler: handler})
257}
258
259// AddSessionTools adds tools for a specific session
260func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error {
261 sessionValue, ok := s.sessions.Load(sessionID)
262 if !ok {
263 return ErrSessionNotFound
264 }
265
266 session, ok := sessionValue.(SessionWithTools)
267 if !ok {
268 return ErrSessionDoesNotSupportTools
269 }
270
271 s.implicitlyRegisterToolCapabilities()
272
273 // Get existing tools (this should return a thread-safe copy)
274 sessionTools := session.GetSessionTools()
275
276 // Create a new map to avoid concurrent modification issues
277 newSessionTools := make(map[string]ServerTool, len(sessionTools)+len(tools))
278
279 // Copy existing tools
280 for k, v := range sessionTools {
281 newSessionTools[k] = v
282 }
283
284 // Add new tools
285 for _, tool := range tools {
286 newSessionTools[tool.Tool.Name] = tool
287 }
288
289 // Set the tools (this should be thread-safe)
290 session.SetSessionTools(newSessionTools)
291
292 // It only makes sense to send tool notifications to initialized sessions --
293 // if we're not initialized yet the client can't possibly have sent their
294 // initial tools/list message.
295 //
296 // For initialized sessions, honor tools.listChanged, which is specifically
297 // about whether notifications will be sent or not.
298 // see <https://modelcontextprotocol.io/specification/2025-03-26/server/tools#capabilities>
299 if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged {
300 // Send notification only to this session
301 if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil {
302 // Log the error but don't fail the operation
303 // The tools were successfully added, but notification failed
304 if s.hooks != nil && len(s.hooks.OnError) > 0 {
305 hooks := s.hooks
306 go func(sID string, hooks *Hooks) {
307 ctx := context.Background()
308 hooks.onError(ctx, nil, "notification", map[string]any{
309 "method": "notifications/tools/list_changed",
310 "sessionID": sID,
311 }, fmt.Errorf("failed to send notification after adding tools: %w", err))
312 }(sessionID, hooks)
313 }
314 }
315 }
316
317 return nil
318}
319
320// DeleteSessionTools removes tools from a specific session
321func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error {
322 sessionValue, ok := s.sessions.Load(sessionID)
323 if !ok {
324 return ErrSessionNotFound
325 }
326
327 session, ok := sessionValue.(SessionWithTools)
328 if !ok {
329 return ErrSessionDoesNotSupportTools
330 }
331
332 // Get existing tools (this should return a thread-safe copy)
333 sessionTools := session.GetSessionTools()
334 if sessionTools == nil {
335 return nil
336 }
337
338 // Create a new map to avoid concurrent modification issues
339 newSessionTools := make(map[string]ServerTool, len(sessionTools))
340
341 // Copy existing tools except those being deleted
342 for k, v := range sessionTools {
343 newSessionTools[k] = v
344 }
345
346 // Remove specified tools
347 for _, name := range names {
348 delete(newSessionTools, name)
349 }
350
351 // Set the tools (this should be thread-safe)
352 session.SetSessionTools(newSessionTools)
353
354 // It only makes sense to send tool notifications to initialized sessions --
355 // if we're not initialized yet the client can't possibly have sent their
356 // initial tools/list message.
357 //
358 // For initialized sessions, honor tools.listChanged, which is specifically
359 // about whether notifications will be sent or not.
360 // see <https://modelcontextprotocol.io/specification/2025-03-26/server/tools#capabilities>
361 if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged {
362 // Send notification only to this session
363 if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil {
364 // Log the error but don't fail the operation
365 // The tools were successfully deleted, but notification failed
366 if s.hooks != nil && len(s.hooks.OnError) > 0 {
367 hooks := s.hooks
368 go func(sID string, hooks *Hooks) {
369 ctx := context.Background()
370 hooks.onError(ctx, nil, "notification", map[string]any{
371 "method": "notifications/tools/list_changed",
372 "sessionID": sID,
373 }, fmt.Errorf("failed to send notification after deleting tools: %w", err))
374 }(sessionID, hooks)
375 }
376 }
377 }
378
379 return nil
380}