hooks.go

  1// Code generated by `go generate`. DO NOT EDIT.
  2// source: server/internal/gen/hooks.go.tmpl
  3package server
  4
  5import (
  6	"context"
  7
  8	"github.com/mark3labs/mcp-go/mcp"
  9)
 10
// OnRegisterSessionHookFunc is a hook that will be called when a new session is registered.
// Hooks of this type are added with Hooks.AddOnRegisterSession and invoked by
// Hooks.RegisterSession with the context and session it was given.
type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession)
 13
// OnUnregisterSessionHookFunc is a hook that will be called when a session is being unregistered.
// Hooks of this type are added with Hooks.AddOnUnregisterSession and invoked by
// Hooks.UnregisterSession with the context and session it was given.
type OnUnregisterSessionHookFunc func(ctx context.Context, session ClientSession)
 16
// BeforeAnyHookFunc is a function that is called after the request is
// parsed but before the method is called.
//
// Hooks of this type are added with Hooks.AddBeforeAny. The per-method
// before* dispatchers in this file pass the method constant for the request
// (for example mcp.MethodPing) and the typed request pointer as message.
type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any)
 20
// OnSuccessHookFunc is a hook that will be called after the request
// successfully generates a result, but before the result is sent to the client.
//
// Hooks of this type are added with Hooks.AddOnSuccess. The per-method
// after* dispatchers in this file pass the method constant for the request,
// the typed request pointer as message, and the typed result pointer as result.
type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any)
 24
// OnErrorHookFunc is a hook that will be called when an error occurs,
// either during the request parsing or the method execution.
//
// Hooks of this type are added with Hooks.AddOnError and receive the error
// value itself, so it can be inspected with errors.Is and errors.As.
//
// Example usage:
//
//	hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
//	  // Check for specific error types using errors.Is
//	  if errors.Is(err, ErrUnsupported) {
//	    // Handle capability not supported errors
//	    log.Printf("Capability not supported: %v", err)
//	  }
//
//	  // Use errors.As to get specific error types
//	  var parseErr = &UnparsableMessageError{}
//	  if errors.As(err, &parseErr) {
//	    // Access specific methods/fields of the error type
//	    log.Printf("Failed to parse message for method %s: %v",
//	               parseErr.GetMethod(), parseErr.Unwrap())
//	    // Access the raw message that failed to parse
//	    rawMsg := parseErr.GetMessage()
//	  }
//
//	  // Check for specific resource/prompt/tool errors
//	  switch {
//	  case errors.Is(err, ErrResourceNotFound):
//	    log.Printf("Resource not found: %v", err)
//	  case errors.Is(err, ErrPromptNotFound):
//	    log.Printf("Prompt not found: %v", err)
//	  case errors.Is(err, ErrToolNotFound):
//	    log.Printf("Tool not found: %v", err)
//	  }
//	})
type OnErrorHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error)
 59
// OnRequestInitializationFunc is a hook that is called before a request
// method is handled. Hooks of this type are added with
// Hooks.AddOnRequestInitialization.
//
// If a hook returns a non-nil error, onRequestInitialization stops running the
// remaining hooks and returns that error to its caller, so that the service
// can respond with the corresponding error message.
type OnRequestInitializationFunc func(ctx context.Context, id any, message any) error
 63
// OnBeforeInitializeFunc is the signature of hooks stored in
// Hooks.OnBeforeInitialize and run by beforeInitialize with the initialize request.
type OnBeforeInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest)
// OnAfterInitializeFunc is the signature of hooks stored in
// Hooks.OnAfterInitialize and run by afterInitialize with the request and its result.
type OnAfterInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult)
 66
// OnBeforePingFunc is the signature of hooks stored in Hooks.OnBeforePing and
// run by beforePing with the ping request.
type OnBeforePingFunc func(ctx context.Context, id any, message *mcp.PingRequest)
// OnAfterPingFunc is the signature of hooks stored in Hooks.OnAfterPing and
// run by afterPing with the request and its (empty) result.
type OnAfterPingFunc func(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult)
 69
// OnBeforeSetLevelFunc is the signature of hooks stored in
// Hooks.OnBeforeSetLevel and run by beforeSetLevel with the set-level request.
type OnBeforeSetLevelFunc func(ctx context.Context, id any, message *mcp.SetLevelRequest)
// OnAfterSetLevelFunc is the signature of hooks stored in
// Hooks.OnAfterSetLevel and run by afterSetLevel with the request and its (empty) result.
type OnAfterSetLevelFunc func(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult)
 72
// OnBeforeListResourcesFunc is the signature of hooks stored in
// Hooks.OnBeforeListResources and run by beforeListResources with the list request.
type OnBeforeListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest)
// OnAfterListResourcesFunc is the signature of hooks stored in
// Hooks.OnAfterListResources and run by afterListResources with the request and its result.
type OnAfterListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult)
 75
// OnBeforeListResourceTemplatesFunc is the signature of hooks stored in
// Hooks.OnBeforeListResourceTemplates and run by beforeListResourceTemplates
// with the list request.
type OnBeforeListResourceTemplatesFunc func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest)
// OnAfterListResourceTemplatesFunc is the signature of hooks stored in
// Hooks.OnAfterListResourceTemplates and run by afterListResourceTemplates
// with the request and its result.
type OnAfterListResourceTemplatesFunc func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest, result *mcp.ListResourceTemplatesResult)
 78
// OnBeforeReadResourceFunc is the signature of hooks stored in
// Hooks.OnBeforeReadResource and run by beforeReadResource with the read request.
type OnBeforeReadResourceFunc func(ctx context.Context, id any, message *mcp.ReadResourceRequest)
// OnAfterReadResourceFunc is the signature of hooks stored in
// Hooks.OnAfterReadResource and run by afterReadResource with the request and its result.
type OnAfterReadResourceFunc func(ctx context.Context, id any, message *mcp.ReadResourceRequest, result *mcp.ReadResourceResult)
 81
// OnBeforeListPromptsFunc is the signature of hooks stored in
// Hooks.OnBeforeListPrompts and run by beforeListPrompts with the list request.
type OnBeforeListPromptsFunc func(ctx context.Context, id any, message *mcp.ListPromptsRequest)
// OnAfterListPromptsFunc is the signature of hooks stored in
// Hooks.OnAfterListPrompts and run by afterListPrompts with the request and its result.
type OnAfterListPromptsFunc func(ctx context.Context, id any, message *mcp.ListPromptsRequest, result *mcp.ListPromptsResult)
 84
// OnBeforeGetPromptFunc is the signature of hooks stored in
// Hooks.OnBeforeGetPrompt and run by beforeGetPrompt with the get-prompt request.
type OnBeforeGetPromptFunc func(ctx context.Context, id any, message *mcp.GetPromptRequest)
// OnAfterGetPromptFunc is the signature of hooks stored in
// Hooks.OnAfterGetPrompt and run by afterGetPrompt with the request and its result.
type OnAfterGetPromptFunc func(ctx context.Context, id any, message *mcp.GetPromptRequest, result *mcp.GetPromptResult)
 87
// OnBeforeListToolsFunc is the signature of hooks stored in
// Hooks.OnBeforeListTools and run by beforeListTools with the list request.
type OnBeforeListToolsFunc func(ctx context.Context, id any, message *mcp.ListToolsRequest)
// OnAfterListToolsFunc is the signature of hooks stored in
// Hooks.OnAfterListTools and run by afterListTools with the request and its result.
type OnAfterListToolsFunc func(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult)
 90
// OnBeforeCallToolFunc is the signature of hooks stored in
// Hooks.OnBeforeCallTool and run by beforeCallTool with the call-tool request.
type OnBeforeCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest)
// OnAfterCallToolFunc is the signature of hooks stored in
// Hooks.OnAfterCallTool and run by afterCallTool with the request and its result.
type OnAfterCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult)
 93
// Hooks collects the callbacks that can be attached to a server. Each field is
// a slice of hook functions; the matching Add* method appends to it and the
// matching dispatch method in this file calls every element in slice order.
//
// The dispatch methods (beforeAny, onSuccess, onError, RegisterSession,
// UnregisterSession, onRequestInitialization and the per-method before*/after*
// methods) are no-ops on a nil *Hooks. The Add* methods write to a field and
// therefore need a non-nil receiver.
type Hooks struct {
	// Session lifecycle hooks.
	OnRegisterSession             []OnRegisterSessionHookFunc
	OnUnregisterSession           []OnUnregisterSessionHookFunc
	// Hooks that are not tied to a single request type. The per-method
	// before*/after* dispatchers run OnBeforeAny/OnSuccess ahead of the
	// method-specific hooks.
	OnBeforeAny                   []BeforeAnyHookFunc
	OnSuccess                     []OnSuccessHookFunc
	OnError                       []OnErrorHookFunc
	OnRequestInitialization       []OnRequestInitializationFunc
	// Method-specific hooks, one Before/After pair per request type.
	OnBeforeInitialize            []OnBeforeInitializeFunc
	OnAfterInitialize             []OnAfterInitializeFunc
	OnBeforePing                  []OnBeforePingFunc
	OnAfterPing                   []OnAfterPingFunc
	OnBeforeSetLevel              []OnBeforeSetLevelFunc
	OnAfterSetLevel               []OnAfterSetLevelFunc
	OnBeforeListResources         []OnBeforeListResourcesFunc
	OnAfterListResources          []OnAfterListResourcesFunc
	OnBeforeListResourceTemplates []OnBeforeListResourceTemplatesFunc
	OnAfterListResourceTemplates  []OnAfterListResourceTemplatesFunc
	OnBeforeReadResource          []OnBeforeReadResourceFunc
	OnAfterReadResource           []OnAfterReadResourceFunc
	OnBeforeListPrompts           []OnBeforeListPromptsFunc
	OnAfterListPrompts            []OnAfterListPromptsFunc
	OnBeforeGetPrompt             []OnBeforeGetPromptFunc
	OnAfterGetPrompt              []OnAfterGetPromptFunc
	OnBeforeListTools             []OnBeforeListToolsFunc
	OnAfterListTools              []OnAfterListToolsFunc
	OnBeforeCallTool              []OnBeforeCallToolFunc
	OnAfterCallTool               []OnAfterCallToolFunc
}
122
123func (c *Hooks) AddBeforeAny(hook BeforeAnyHookFunc) {
124	c.OnBeforeAny = append(c.OnBeforeAny, hook)
125}
126
127func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) {
128	c.OnSuccess = append(c.OnSuccess, hook)
129}
130
131// AddOnError registers a hook function that will be called when an error occurs.
132// The error parameter contains the actual error object, which can be interrogated
133// using Go's error handling patterns like errors.Is and errors.As.
134//
135// Example:
136// ```
137// // Create a channel to receive errors for testing
138// errChan := make(chan error, 1)
139//
140// // Register hook to capture and inspect errors
141// hooks := &Hooks{}
142//
143//	hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
144//	    // For capability-related errors
145//	    if errors.Is(err, ErrUnsupported) {
146//	        // Handle capability not supported
147//	        errChan <- err
148//	        return
149//	    }
150//
151//	    // For parsing errors
152//	    var parseErr = &UnparsableMessageError{}
153//	    if errors.As(err, &parseErr) {
154//	        // Handle unparsable message errors
155//	        fmt.Printf("Failed to parse %s request: %v\n",
156//	                   parseErr.GetMethod(), parseErr.Unwrap())
157//	        errChan <- parseErr
158//	        return
159//	    }
160//
161//	    // For resource/prompt/tool not found errors
162//	    if errors.Is(err, ErrResourceNotFound) ||
163//	       errors.Is(err, ErrPromptNotFound) ||
164//	       errors.Is(err, ErrToolNotFound) {
165//	        // Handle not found errors
166//	        errChan <- err
167//	        return
168//	    }
169//
170//	    // For other errors
171//	    errChan <- err
172//	})
173//
174// server := NewMCPServer("test-server", "1.0.0", WithHooks(hooks))
175// ```
176func (c *Hooks) AddOnError(hook OnErrorHookFunc) {
177	c.OnError = append(c.OnError, hook)
178}
179
180func (c *Hooks) beforeAny(ctx context.Context, id any, method mcp.MCPMethod, message any) {
181	if c == nil {
182		return
183	}
184	for _, hook := range c.OnBeforeAny {
185		hook(ctx, id, method, message)
186	}
187}
188
189func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) {
190	if c == nil {
191		return
192	}
193	for _, hook := range c.OnSuccess {
194		hook(ctx, id, method, message, result)
195	}
196}
197
198// onError calls all registered error hooks with the error object.
199// The err parameter contains the actual error that occurred, which implements
200// the standard error interface and may be a wrapped error or custom error type.
201//
202// This allows consumer code to use Go's error handling patterns:
203// - errors.Is(err, ErrUnsupported) to check for specific sentinel errors
204// - errors.As(err, &customErr) to extract custom error types
205//
206// Common error types include:
207// - ErrUnsupported: When a capability is not enabled
208// - UnparsableMessageError: When request parsing fails
209// - ErrResourceNotFound: When a resource is not found
210// - ErrPromptNotFound: When a prompt is not found
211// - ErrToolNotFound: When a tool is not found
212func (c *Hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
213	if c == nil {
214		return
215	}
216	for _, hook := range c.OnError {
217		hook(ctx, id, method, message, err)
218	}
219}
220
221func (c *Hooks) AddOnRegisterSession(hook OnRegisterSessionHookFunc) {
222	c.OnRegisterSession = append(c.OnRegisterSession, hook)
223}
224
225func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) {
226	if c == nil {
227		return
228	}
229	for _, hook := range c.OnRegisterSession {
230		hook(ctx, session)
231	}
232}
233
234func (c *Hooks) AddOnUnregisterSession(hook OnUnregisterSessionHookFunc) {
235	c.OnUnregisterSession = append(c.OnUnregisterSession, hook)
236}
237
238func (c *Hooks) UnregisterSession(ctx context.Context, session ClientSession) {
239	if c == nil {
240		return
241	}
242	for _, hook := range c.OnUnregisterSession {
243		hook(ctx, session)
244	}
245}
246
247func (c *Hooks) AddOnRequestInitialization(hook OnRequestInitializationFunc) {
248	c.OnRequestInitialization = append(c.OnRequestInitialization, hook)
249}
250
251func (c *Hooks) onRequestInitialization(ctx context.Context, id any, message any) error {
252	if c == nil {
253		return nil
254	}
255	for _, hook := range c.OnRequestInitialization {
256		err := hook(ctx, id, message)
257		if err != nil {
258			return err
259		}
260	}
261	return nil
262}
263func (c *Hooks) AddBeforeInitialize(hook OnBeforeInitializeFunc) {
264	c.OnBeforeInitialize = append(c.OnBeforeInitialize, hook)
265}
266
267func (c *Hooks) AddAfterInitialize(hook OnAfterInitializeFunc) {
268	c.OnAfterInitialize = append(c.OnAfterInitialize, hook)
269}
270
271func (c *Hooks) beforeInitialize(ctx context.Context, id any, message *mcp.InitializeRequest) {
272	c.beforeAny(ctx, id, mcp.MethodInitialize, message)
273	if c == nil {
274		return
275	}
276	for _, hook := range c.OnBeforeInitialize {
277		hook(ctx, id, message)
278	}
279}
280
281func (c *Hooks) afterInitialize(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) {
282	c.onSuccess(ctx, id, mcp.MethodInitialize, message, result)
283	if c == nil {
284		return
285	}
286	for _, hook := range c.OnAfterInitialize {
287		hook(ctx, id, message, result)
288	}
289}
290func (c *Hooks) AddBeforePing(hook OnBeforePingFunc) {
291	c.OnBeforePing = append(c.OnBeforePing, hook)
292}
293
294func (c *Hooks) AddAfterPing(hook OnAfterPingFunc) {
295	c.OnAfterPing = append(c.OnAfterPing, hook)
296}
297
298func (c *Hooks) beforePing(ctx context.Context, id any, message *mcp.PingRequest) {
299	c.beforeAny(ctx, id, mcp.MethodPing, message)
300	if c == nil {
301		return
302	}
303	for _, hook := range c.OnBeforePing {
304		hook(ctx, id, message)
305	}
306}
307
308func (c *Hooks) afterPing(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult) {
309	c.onSuccess(ctx, id, mcp.MethodPing, message, result)
310	if c == nil {
311		return
312	}
313	for _, hook := range c.OnAfterPing {
314		hook(ctx, id, message, result)
315	}
316}
317func (c *Hooks) AddBeforeSetLevel(hook OnBeforeSetLevelFunc) {
318	c.OnBeforeSetLevel = append(c.OnBeforeSetLevel, hook)
319}
320
321func (c *Hooks) AddAfterSetLevel(hook OnAfterSetLevelFunc) {
322	c.OnAfterSetLevel = append(c.OnAfterSetLevel, hook)
323}
324
325func (c *Hooks) beforeSetLevel(ctx context.Context, id any, message *mcp.SetLevelRequest) {
326	c.beforeAny(ctx, id, mcp.MethodSetLogLevel, message)
327	if c == nil {
328		return
329	}
330	for _, hook := range c.OnBeforeSetLevel {
331		hook(ctx, id, message)
332	}
333}
334
335func (c *Hooks) afterSetLevel(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult) {
336	c.onSuccess(ctx, id, mcp.MethodSetLogLevel, message, result)
337	if c == nil {
338		return
339	}
340	for _, hook := range c.OnAfterSetLevel {
341		hook(ctx, id, message, result)
342	}
343}
344func (c *Hooks) AddBeforeListResources(hook OnBeforeListResourcesFunc) {
345	c.OnBeforeListResources = append(c.OnBeforeListResources, hook)
346}
347
348func (c *Hooks) AddAfterListResources(hook OnAfterListResourcesFunc) {
349	c.OnAfterListResources = append(c.OnAfterListResources, hook)
350}
351
352func (c *Hooks) beforeListResources(ctx context.Context, id any, message *mcp.ListResourcesRequest) {
353	c.beforeAny(ctx, id, mcp.MethodResourcesList, message)
354	if c == nil {
355		return
356	}
357	for _, hook := range c.OnBeforeListResources {
358		hook(ctx, id, message)
359	}
360}
361
362func (c *Hooks) afterListResources(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult) {
363	c.onSuccess(ctx, id, mcp.MethodResourcesList, message, result)
364	if c == nil {
365		return
366	}
367	for _, hook := range c.OnAfterListResources {
368		hook(ctx, id, message, result)
369	}
370}
371func (c *Hooks) AddBeforeListResourceTemplates(hook OnBeforeListResourceTemplatesFunc) {
372	c.OnBeforeListResourceTemplates = append(c.OnBeforeListResourceTemplates, hook)
373}
374
375func (c *Hooks) AddAfterListResourceTemplates(hook OnAfterListResourceTemplatesFunc) {
376	c.OnAfterListResourceTemplates = append(c.OnAfterListResourceTemplates, hook)
377}
378
379func (c *Hooks) beforeListResourceTemplates(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest) {
380	c.beforeAny(ctx, id, mcp.MethodResourcesTemplatesList, message)
381	if c == nil {
382		return
383	}
384	for _, hook := range c.OnBeforeListResourceTemplates {
385		hook(ctx, id, message)
386	}
387}
388
389func (c *Hooks) afterListResourceTemplates(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest, result *mcp.ListResourceTemplatesResult) {
390	c.onSuccess(ctx, id, mcp.MethodResourcesTemplatesList, message, result)
391	if c == nil {
392		return
393	}
394	for _, hook := range c.OnAfterListResourceTemplates {
395		hook(ctx, id, message, result)
396	}
397}
398func (c *Hooks) AddBeforeReadResource(hook OnBeforeReadResourceFunc) {
399	c.OnBeforeReadResource = append(c.OnBeforeReadResource, hook)
400}
401
402func (c *Hooks) AddAfterReadResource(hook OnAfterReadResourceFunc) {
403	c.OnAfterReadResource = append(c.OnAfterReadResource, hook)
404}
405
406func (c *Hooks) beforeReadResource(ctx context.Context, id any, message *mcp.ReadResourceRequest) {
407	c.beforeAny(ctx, id, mcp.MethodResourcesRead, message)
408	if c == nil {
409		return
410	}
411	for _, hook := range c.OnBeforeReadResource {
412		hook(ctx, id, message)
413	}
414}
415
416func (c *Hooks) afterReadResource(ctx context.Context, id any, message *mcp.ReadResourceRequest, result *mcp.ReadResourceResult) {
417	c.onSuccess(ctx, id, mcp.MethodResourcesRead, message, result)
418	if c == nil {
419		return
420	}
421	for _, hook := range c.OnAfterReadResource {
422		hook(ctx, id, message, result)
423	}
424}
425func (c *Hooks) AddBeforeListPrompts(hook OnBeforeListPromptsFunc) {
426	c.OnBeforeListPrompts = append(c.OnBeforeListPrompts, hook)
427}
428
429func (c *Hooks) AddAfterListPrompts(hook OnAfterListPromptsFunc) {
430	c.OnAfterListPrompts = append(c.OnAfterListPrompts, hook)
431}
432
433func (c *Hooks) beforeListPrompts(ctx context.Context, id any, message *mcp.ListPromptsRequest) {
434	c.beforeAny(ctx, id, mcp.MethodPromptsList, message)
435	if c == nil {
436		return
437	}
438	for _, hook := range c.OnBeforeListPrompts {
439		hook(ctx, id, message)
440	}
441}
442
443func (c *Hooks) afterListPrompts(ctx context.Context, id any, message *mcp.ListPromptsRequest, result *mcp.ListPromptsResult) {
444	c.onSuccess(ctx, id, mcp.MethodPromptsList, message, result)
445	if c == nil {
446		return
447	}
448	for _, hook := range c.OnAfterListPrompts {
449		hook(ctx, id, message, result)
450	}
451}
452func (c *Hooks) AddBeforeGetPrompt(hook OnBeforeGetPromptFunc) {
453	c.OnBeforeGetPrompt = append(c.OnBeforeGetPrompt, hook)
454}
455
456func (c *Hooks) AddAfterGetPrompt(hook OnAfterGetPromptFunc) {
457	c.OnAfterGetPrompt = append(c.OnAfterGetPrompt, hook)
458}
459
460func (c *Hooks) beforeGetPrompt(ctx context.Context, id any, message *mcp.GetPromptRequest) {
461	c.beforeAny(ctx, id, mcp.MethodPromptsGet, message)
462	if c == nil {
463		return
464	}
465	for _, hook := range c.OnBeforeGetPrompt {
466		hook(ctx, id, message)
467	}
468}
469
470func (c *Hooks) afterGetPrompt(ctx context.Context, id any, message *mcp.GetPromptRequest, result *mcp.GetPromptResult) {
471	c.onSuccess(ctx, id, mcp.MethodPromptsGet, message, result)
472	if c == nil {
473		return
474	}
475	for _, hook := range c.OnAfterGetPrompt {
476		hook(ctx, id, message, result)
477	}
478}
479func (c *Hooks) AddBeforeListTools(hook OnBeforeListToolsFunc) {
480	c.OnBeforeListTools = append(c.OnBeforeListTools, hook)
481}
482
483func (c *Hooks) AddAfterListTools(hook OnAfterListToolsFunc) {
484	c.OnAfterListTools = append(c.OnAfterListTools, hook)
485}
486
487func (c *Hooks) beforeListTools(ctx context.Context, id any, message *mcp.ListToolsRequest) {
488	c.beforeAny(ctx, id, mcp.MethodToolsList, message)
489	if c == nil {
490		return
491	}
492	for _, hook := range c.OnBeforeListTools {
493		hook(ctx, id, message)
494	}
495}
496
497func (c *Hooks) afterListTools(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult) {
498	c.onSuccess(ctx, id, mcp.MethodToolsList, message, result)
499	if c == nil {
500		return
501	}
502	for _, hook := range c.OnAfterListTools {
503		hook(ctx, id, message, result)
504	}
505}
506func (c *Hooks) AddBeforeCallTool(hook OnBeforeCallToolFunc) {
507	c.OnBeforeCallTool = append(c.OnBeforeCallTool, hook)
508}
509
510func (c *Hooks) AddAfterCallTool(hook OnAfterCallToolFunc) {
511	c.OnAfterCallTool = append(c.OnAfterCallTool, hook)
512}
513
514func (c *Hooks) beforeCallTool(ctx context.Context, id any, message *mcp.CallToolRequest) {
515	c.beforeAny(ctx, id, mcp.MethodToolsCall, message)
516	if c == nil {
517		return
518	}
519	for _, hook := range c.OnBeforeCallTool {
520		hook(ctx, id, message)
521	}
522}
523
524func (c *Hooks) afterCallTool(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) {
525	c.onSuccess(ctx, id, mcp.MethodToolsCall, message, result)
526	if c == nil {
527		return
528	}
529	for _, hook := range c.OnAfterCallTool {
530		hook(ctx, id, message, result)
531	}
532}