utils.go

  1package mcp
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6
  7	"github.com/spf13/cast"
  8)
  9
 10// ClientRequest types
 11var _ ClientRequest = &PingRequest{}
 12var _ ClientRequest = &InitializeRequest{}
 13var _ ClientRequest = &CompleteRequest{}
 14var _ ClientRequest = &SetLevelRequest{}
 15var _ ClientRequest = &GetPromptRequest{}
 16var _ ClientRequest = &ListPromptsRequest{}
 17var _ ClientRequest = &ListResourcesRequest{}
 18var _ ClientRequest = &ReadResourceRequest{}
 19var _ ClientRequest = &SubscribeRequest{}
 20var _ ClientRequest = &UnsubscribeRequest{}
 21var _ ClientRequest = &CallToolRequest{}
 22var _ ClientRequest = &ListToolsRequest{}
 23
 24// ClientNotification types
 25var _ ClientNotification = &CancelledNotification{}
 26var _ ClientNotification = &ProgressNotification{}
 27var _ ClientNotification = &InitializedNotification{}
 28var _ ClientNotification = &RootsListChangedNotification{}
 29
 30// ClientResult types
 31var _ ClientResult = &EmptyResult{}
 32var _ ClientResult = &CreateMessageResult{}
 33var _ ClientResult = &ListRootsResult{}
 34
 35// ServerRequest types
 36var _ ServerRequest = &PingRequest{}
 37var _ ServerRequest = &CreateMessageRequest{}
 38var _ ServerRequest = &ListRootsRequest{}
 39
 40// ServerNotification types
 41var _ ServerNotification = &CancelledNotification{}
 42var _ ServerNotification = &ProgressNotification{}
 43var _ ServerNotification = &LoggingMessageNotification{}
 44var _ ServerNotification = &ResourceUpdatedNotification{}
 45var _ ServerNotification = &ResourceListChangedNotification{}
 46var _ ServerNotification = &ToolListChangedNotification{}
 47var _ ServerNotification = &PromptListChangedNotification{}
 48
 49// ServerResult types
 50var _ ServerResult = &EmptyResult{}
 51var _ ServerResult = &InitializeResult{}
 52var _ ServerResult = &CompleteResult{}
 53var _ ServerResult = &GetPromptResult{}
 54var _ ServerResult = &ListPromptsResult{}
 55var _ ServerResult = &ListResourcesResult{}
 56var _ ServerResult = &ReadResourceResult{}
 57var _ ServerResult = &CallToolResult{}
 58var _ ServerResult = &ListToolsResult{}
 59
 60// Helper functions for type assertions
 61
 62// asType attempts to cast the given interface to the given type
 63func asType[T any](content any) (*T, bool) {
 64	tc, ok := content.(T)
 65	if !ok {
 66		return nil, false
 67	}
 68	return &tc, true
 69}
 70
 71// AsTextContent attempts to cast the given interface to TextContent
 72func AsTextContent(content any) (*TextContent, bool) {
 73	return asType[TextContent](content)
 74}
 75
 76// AsImageContent attempts to cast the given interface to ImageContent
 77func AsImageContent(content any) (*ImageContent, bool) {
 78	return asType[ImageContent](content)
 79}
 80
 81// AsAudioContent attempts to cast the given interface to AudioContent
 82func AsAudioContent(content any) (*AudioContent, bool) {
 83	return asType[AudioContent](content)
 84}
 85
 86// AsEmbeddedResource attempts to cast the given interface to EmbeddedResource
 87func AsEmbeddedResource(content any) (*EmbeddedResource, bool) {
 88	return asType[EmbeddedResource](content)
 89}
 90
 91// AsTextResourceContents attempts to cast the given interface to TextResourceContents
 92func AsTextResourceContents(content any) (*TextResourceContents, bool) {
 93	return asType[TextResourceContents](content)
 94}
 95
 96// AsBlobResourceContents attempts to cast the given interface to BlobResourceContents
 97func AsBlobResourceContents(content any) (*BlobResourceContents, bool) {
 98	return asType[BlobResourceContents](content)
 99}
100
// Helper functions for JSON-RPC
102
103// NewJSONRPCResponse creates a new JSONRPCResponse with the given id and result
104func NewJSONRPCResponse(id RequestId, result Result) JSONRPCResponse {
105	return JSONRPCResponse{
106		JSONRPC: JSONRPC_VERSION,
107		ID:      id,
108		Result:  result,
109	}
110}
111
112// NewJSONRPCError creates a new JSONRPCResponse with the given id, code, and message
113func NewJSONRPCError(
114	id RequestId,
115	code int,
116	message string,
117	data any,
118) JSONRPCError {
119	return JSONRPCError{
120		JSONRPC: JSONRPC_VERSION,
121		ID:      id,
122		Error: struct {
123			Code    int    `json:"code"`
124			Message string `json:"message"`
125			Data    any    `json:"data,omitempty"`
126		}{
127			Code:    code,
128			Message: message,
129			Data:    data,
130		},
131	}
132}
133
134// NewProgressNotification
135// Helper function for creating a progress notification
136func NewProgressNotification(
137	token ProgressToken,
138	progress float64,
139	total *float64,
140	message *string,
141) ProgressNotification {
142	notification := ProgressNotification{
143		Notification: Notification{
144			Method: "notifications/progress",
145		},
146		Params: struct {
147			ProgressToken ProgressToken `json:"progressToken"`
148			Progress      float64       `json:"progress"`
149			Total         float64       `json:"total,omitempty"`
150			Message       string        `json:"message,omitempty"`
151		}{
152			ProgressToken: token,
153			Progress:      progress,
154		},
155	}
156	if total != nil {
157		notification.Params.Total = *total
158	}
159	if message != nil {
160		notification.Params.Message = *message
161	}
162	return notification
163}
164
165// NewLoggingMessageNotification
166// Helper function for creating a logging message notification
167func NewLoggingMessageNotification(
168	level LoggingLevel,
169	logger string,
170	data any,
171) LoggingMessageNotification {
172	return LoggingMessageNotification{
173		Notification: Notification{
174			Method: "notifications/message",
175		},
176		Params: struct {
177			Level  LoggingLevel `json:"level"`
178			Logger string       `json:"logger,omitempty"`
179			Data   any          `json:"data"`
180		}{
181			Level:  level,
182			Logger: logger,
183			Data:   data,
184		},
185	}
186}
187
188// NewPromptMessage
189// Helper function to create a new PromptMessage
190func NewPromptMessage(role Role, content Content) PromptMessage {
191	return PromptMessage{
192		Role:    role,
193		Content: content,
194	}
195}
196
197// NewTextContent
198// Helper function to create a new TextContent
199func NewTextContent(text string) TextContent {
200	return TextContent{
201		Type: "text",
202		Text: text,
203	}
204}
205
206// NewImageContent
207// Helper function to create a new ImageContent
208func NewImageContent(data, mimeType string) ImageContent {
209	return ImageContent{
210		Type:     "image",
211		Data:     data,
212		MIMEType: mimeType,
213	}
214}
215
216// Helper function to create a new AudioContent
217func NewAudioContent(data, mimeType string) AudioContent {
218	return AudioContent{
219		Type:     "audio",
220		Data:     data,
221		MIMEType: mimeType,
222	}
223}
224
225// Helper function to create a new EmbeddedResource
226func NewEmbeddedResource(resource ResourceContents) EmbeddedResource {
227	return EmbeddedResource{
228		Type:     "resource",
229		Resource: resource,
230	}
231}
232
233// NewToolResultText creates a new CallToolResult with a text content
234func NewToolResultText(text string) *CallToolResult {
235	return &CallToolResult{
236		Content: []Content{
237			TextContent{
238				Type: "text",
239				Text: text,
240			},
241		},
242	}
243}
244
245// NewToolResultImage creates a new CallToolResult with both text and image content
246func NewToolResultImage(text, imageData, mimeType string) *CallToolResult {
247	return &CallToolResult{
248		Content: []Content{
249			TextContent{
250				Type: "text",
251				Text: text,
252			},
253			ImageContent{
254				Type:     "image",
255				Data:     imageData,
256				MIMEType: mimeType,
257			},
258		},
259	}
260}
261
262// NewToolResultAudio creates a new CallToolResult with both text and audio content
263func NewToolResultAudio(text, imageData, mimeType string) *CallToolResult {
264	return &CallToolResult{
265		Content: []Content{
266			TextContent{
267				Type: "text",
268				Text: text,
269			},
270			AudioContent{
271				Type:     "audio",
272				Data:     imageData,
273				MIMEType: mimeType,
274			},
275		},
276	}
277}
278
279// NewToolResultResource creates a new CallToolResult with an embedded resource
280func NewToolResultResource(
281	text string,
282	resource ResourceContents,
283) *CallToolResult {
284	return &CallToolResult{
285		Content: []Content{
286			TextContent{
287				Type: "text",
288				Text: text,
289			},
290			EmbeddedResource{
291				Type:     "resource",
292				Resource: resource,
293			},
294		},
295	}
296}
297
298// NewToolResultError creates a new CallToolResult with an error message.
299// Any errors that originate from the tool SHOULD be reported inside the result object.
300func NewToolResultError(text string) *CallToolResult {
301	return &CallToolResult{
302		Content: []Content{
303			TextContent{
304				Type: "text",
305				Text: text,
306			},
307		},
308		IsError: true,
309	}
310}
311
312// NewToolResultErrorFromErr creates a new CallToolResult with an error message.
313// If an error is provided, its details will be appended to the text message.
314// Any errors that originate from the tool SHOULD be reported inside the result object.
315func NewToolResultErrorFromErr(text string, err error) *CallToolResult {
316	if err != nil {
317		text = fmt.Sprintf("%s: %v", text, err)
318	}
319	return &CallToolResult{
320		Content: []Content{
321			TextContent{
322				Type: "text",
323				Text: text,
324			},
325		},
326		IsError: true,
327	}
328}
329
330// NewToolResultErrorf creates a new CallToolResult with an error message.
331// The error message is formatted using the fmt package.
332// Any errors that originate from the tool SHOULD be reported inside the result object.
333func NewToolResultErrorf(format string, a ...any) *CallToolResult {
334	return &CallToolResult{
335		Content: []Content{
336			TextContent{
337				Type: "text",
338				Text: fmt.Sprintf(format, a...),
339			},
340		},
341		IsError: true,
342	}
343}
344
345// NewListResourcesResult creates a new ListResourcesResult
346func NewListResourcesResult(
347	resources []Resource,
348	nextCursor Cursor,
349) *ListResourcesResult {
350	return &ListResourcesResult{
351		PaginatedResult: PaginatedResult{
352			NextCursor: nextCursor,
353		},
354		Resources: resources,
355	}
356}
357
358// NewListResourceTemplatesResult creates a new ListResourceTemplatesResult
359func NewListResourceTemplatesResult(
360	templates []ResourceTemplate,
361	nextCursor Cursor,
362) *ListResourceTemplatesResult {
363	return &ListResourceTemplatesResult{
364		PaginatedResult: PaginatedResult{
365			NextCursor: nextCursor,
366		},
367		ResourceTemplates: templates,
368	}
369}
370
371// NewReadResourceResult creates a new ReadResourceResult with text content
372func NewReadResourceResult(text string) *ReadResourceResult {
373	return &ReadResourceResult{
374		Contents: []ResourceContents{
375			TextResourceContents{
376				Text: text,
377			},
378		},
379	}
380}
381
382// NewListPromptsResult creates a new ListPromptsResult
383func NewListPromptsResult(
384	prompts []Prompt,
385	nextCursor Cursor,
386) *ListPromptsResult {
387	return &ListPromptsResult{
388		PaginatedResult: PaginatedResult{
389			NextCursor: nextCursor,
390		},
391		Prompts: prompts,
392	}
393}
394
395// NewGetPromptResult creates a new GetPromptResult
396func NewGetPromptResult(
397	description string,
398	messages []PromptMessage,
399) *GetPromptResult {
400	return &GetPromptResult{
401		Description: description,
402		Messages:    messages,
403	}
404}
405
406// NewListToolsResult creates a new ListToolsResult
407func NewListToolsResult(tools []Tool, nextCursor Cursor) *ListToolsResult {
408	return &ListToolsResult{
409		PaginatedResult: PaginatedResult{
410			NextCursor: nextCursor,
411		},
412		Tools: tools,
413	}
414}
415
416// NewInitializeResult creates a new InitializeResult
417func NewInitializeResult(
418	protocolVersion string,
419	capabilities ServerCapabilities,
420	serverInfo Implementation,
421	instructions string,
422) *InitializeResult {
423	return &InitializeResult{
424		ProtocolVersion: protocolVersion,
425		Capabilities:    capabilities,
426		ServerInfo:      serverInfo,
427		Instructions:    instructions,
428	}
429}
430
431// FormatNumberResult
432// Helper for formatting numbers in tool results
433func FormatNumberResult(value float64) *CallToolResult {
434	return NewToolResultText(fmt.Sprintf("%.2f", value))
435}
436
// ExtractString returns data[key] when it is present and holds a string. It
// returns "" for a missing key, a nil map, or a value of any other type.
func ExtractString(data map[string]any, key string) string {
	// A failed comma-ok assertion yields the zero value "", which is exactly
	// the fallback we want, so the ok flag is intentionally discarded.
	str, _ := data[key].(string)
	return str
}
445
// ExtractMap returns data[key] when it is present and holds a
// map[string]any (a decoded JSON object). It returns nil for a missing key, a
// nil map, or a value of any other type.
func ExtractMap(data map[string]any, key string) map[string]any {
	// A failed comma-ok assertion yields a nil map, the intended fallback.
	obj, _ := data[key].(map[string]any)
	return obj
}
454
455func ParseContent(contentMap map[string]any) (Content, error) {
456	contentType := ExtractString(contentMap, "type")
457
458	switch contentType {
459	case "text":
460		text := ExtractString(contentMap, "text")
461		return NewTextContent(text), nil
462
463	case "image":
464		data := ExtractString(contentMap, "data")
465		mimeType := ExtractString(contentMap, "mimeType")
466		if data == "" || mimeType == "" {
467			return nil, fmt.Errorf("image data or mimeType is missing")
468		}
469		return NewImageContent(data, mimeType), nil
470
471	case "audio":
472		data := ExtractString(contentMap, "data")
473		mimeType := ExtractString(contentMap, "mimeType")
474		if data == "" || mimeType == "" {
475			return nil, fmt.Errorf("audio data or mimeType is missing")
476		}
477		return NewAudioContent(data, mimeType), nil
478
479	case "resource":
480		resourceMap := ExtractMap(contentMap, "resource")
481		if resourceMap == nil {
482			return nil, fmt.Errorf("resource is missing")
483		}
484
485		resourceContents, err := ParseResourceContents(resourceMap)
486		if err != nil {
487			return nil, err
488		}
489
490		return NewEmbeddedResource(resourceContents), nil
491	}
492
493	return nil, fmt.Errorf("unsupported content type: %s", contentType)
494}
495
496func ParseGetPromptResult(rawMessage *json.RawMessage) (*GetPromptResult, error) {
497	if rawMessage == nil {
498		return nil, fmt.Errorf("response is nil")
499	}
500
501	var jsonContent map[string]any
502	if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil {
503		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
504	}
505
506	result := GetPromptResult{}
507
508	meta, ok := jsonContent["_meta"]
509	if ok {
510		if metaMap, ok := meta.(map[string]any); ok {
511			result.Meta = metaMap
512		}
513	}
514
515	description, ok := jsonContent["description"]
516	if ok {
517		if descriptionStr, ok := description.(string); ok {
518			result.Description = descriptionStr
519		}
520	}
521
522	messages, ok := jsonContent["messages"]
523	if ok {
524		messagesArr, ok := messages.([]any)
525		if !ok {
526			return nil, fmt.Errorf("messages is not an array")
527		}
528
529		for _, message := range messagesArr {
530			messageMap, ok := message.(map[string]any)
531			if !ok {
532				return nil, fmt.Errorf("message is not an object")
533			}
534
535			// Extract role
536			roleStr := ExtractString(messageMap, "role")
537			if roleStr == "" || (roleStr != string(RoleAssistant) && roleStr != string(RoleUser)) {
538				return nil, fmt.Errorf("unsupported role: %s", roleStr)
539			}
540
541			// Extract content
542			contentMap, ok := messageMap["content"].(map[string]any)
543			if !ok {
544				return nil, fmt.Errorf("content is not an object")
545			}
546
547			// Process content
548			content, err := ParseContent(contentMap)
549			if err != nil {
550				return nil, err
551			}
552
553			// Append processed message
554			result.Messages = append(result.Messages, NewPromptMessage(Role(roleStr), content))
555
556		}
557	}
558
559	return &result, nil
560}
561
562func ParseCallToolResult(rawMessage *json.RawMessage) (*CallToolResult, error) {
563	if rawMessage == nil {
564		return nil, fmt.Errorf("response is nil")
565	}
566
567	var jsonContent map[string]any
568	if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil {
569		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
570	}
571
572	var result CallToolResult
573
574	meta, ok := jsonContent["_meta"]
575	if ok {
576		if metaMap, ok := meta.(map[string]any); ok {
577			result.Meta = metaMap
578		}
579	}
580
581	isError, ok := jsonContent["isError"]
582	if ok {
583		if isErrorBool, ok := isError.(bool); ok {
584			result.IsError = isErrorBool
585		}
586	}
587
588	contents, ok := jsonContent["content"]
589	if !ok {
590		return nil, fmt.Errorf("content is missing")
591	}
592
593	contentArr, ok := contents.([]any)
594	if !ok {
595		return nil, fmt.Errorf("content is not an array")
596	}
597
598	for _, content := range contentArr {
599		// Extract content
600		contentMap, ok := content.(map[string]any)
601		if !ok {
602			return nil, fmt.Errorf("content is not an object")
603		}
604
605		// Process content
606		content, err := ParseContent(contentMap)
607		if err != nil {
608			return nil, err
609		}
610
611		result.Content = append(result.Content, content)
612	}
613
614	return &result, nil
615}
616
617func ParseResourceContents(contentMap map[string]any) (ResourceContents, error) {
618	uri := ExtractString(contentMap, "uri")
619	if uri == "" {
620		return nil, fmt.Errorf("resource uri is missing")
621	}
622
623	mimeType := ExtractString(contentMap, "mimeType")
624
625	if text := ExtractString(contentMap, "text"); text != "" {
626		return TextResourceContents{
627			URI:      uri,
628			MIMEType: mimeType,
629			Text:     text,
630		}, nil
631	}
632
633	if blob := ExtractString(contentMap, "blob"); blob != "" {
634		return BlobResourceContents{
635			URI:      uri,
636			MIMEType: mimeType,
637			Blob:     blob,
638		}, nil
639	}
640
641	return nil, fmt.Errorf("unsupported resource type")
642}
643
644func ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult, error) {
645	if rawMessage == nil {
646		return nil, fmt.Errorf("response is nil")
647	}
648
649	var jsonContent map[string]any
650	if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil {
651		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
652	}
653
654	var result ReadResourceResult
655
656	meta, ok := jsonContent["_meta"]
657	if ok {
658		if metaMap, ok := meta.(map[string]any); ok {
659			result.Meta = metaMap
660		}
661	}
662
663	contents, ok := jsonContent["contents"]
664	if !ok {
665		return nil, fmt.Errorf("contents is missing")
666	}
667
668	contentArr, ok := contents.([]any)
669	if !ok {
670		return nil, fmt.Errorf("contents is not an array")
671	}
672
673	for _, content := range contentArr {
674		// Extract content
675		contentMap, ok := content.(map[string]any)
676		if !ok {
677			return nil, fmt.Errorf("content is not an object")
678		}
679
680		// Process content
681		content, err := ParseResourceContents(contentMap)
682		if err != nil {
683			return nil, err
684		}
685
686		result.Contents = append(result.Contents, content)
687	}
688
689	return &result, nil
690}
691
692func ParseArgument(request CallToolRequest, key string, defaultVal any) any {
693	args := request.GetArguments()
694	if _, ok := args[key]; !ok {
695		return defaultVal
696	} else {
697		return args[key]
698	}
699}
700
701// ParseBoolean extracts and converts a boolean parameter from a CallToolRequest.
702// If the key is not found in the Arguments map, the defaultValue is returned.
703// The function uses cast.ToBool for conversion which handles various string representations
704// such as "true", "yes", "1", etc.
705func ParseBoolean(request CallToolRequest, key string, defaultValue bool) bool {
706	v := ParseArgument(request, key, defaultValue)
707	return cast.ToBool(v)
708}
709
710// ParseInt64 extracts and converts an int64 parameter from a CallToolRequest.
711// If the key is not found in the Arguments map, the defaultValue is returned.
712func ParseInt64(request CallToolRequest, key string, defaultValue int64) int64 {
713	v := ParseArgument(request, key, defaultValue)
714	return cast.ToInt64(v)
715}
716
717// ParseInt32 extracts and converts an int32 parameter from a CallToolRequest.
718func ParseInt32(request CallToolRequest, key string, defaultValue int32) int32 {
719	v := ParseArgument(request, key, defaultValue)
720	return cast.ToInt32(v)
721}
722
723// ParseInt16 extracts and converts an int16 parameter from a CallToolRequest.
724func ParseInt16(request CallToolRequest, key string, defaultValue int16) int16 {
725	v := ParseArgument(request, key, defaultValue)
726	return cast.ToInt16(v)
727}
728
729// ParseInt8 extracts and converts an int8 parameter from a CallToolRequest.
730func ParseInt8(request CallToolRequest, key string, defaultValue int8) int8 {
731	v := ParseArgument(request, key, defaultValue)
732	return cast.ToInt8(v)
733}
734
735// ParseInt extracts and converts an int parameter from a CallToolRequest.
736func ParseInt(request CallToolRequest, key string, defaultValue int) int {
737	v := ParseArgument(request, key, defaultValue)
738	return cast.ToInt(v)
739}
740
741// ParseUInt extracts and converts an uint parameter from a CallToolRequest.
742func ParseUInt(request CallToolRequest, key string, defaultValue uint) uint {
743	v := ParseArgument(request, key, defaultValue)
744	return cast.ToUint(v)
745}
746
747// ParseUInt64 extracts and converts an uint64 parameter from a CallToolRequest.
748func ParseUInt64(request CallToolRequest, key string, defaultValue uint64) uint64 {
749	v := ParseArgument(request, key, defaultValue)
750	return cast.ToUint64(v)
751}
752
753// ParseUInt32 extracts and converts an uint32 parameter from a CallToolRequest.
754func ParseUInt32(request CallToolRequest, key string, defaultValue uint32) uint32 {
755	v := ParseArgument(request, key, defaultValue)
756	return cast.ToUint32(v)
757}
758
759// ParseUInt16 extracts and converts an uint16 parameter from a CallToolRequest.
760func ParseUInt16(request CallToolRequest, key string, defaultValue uint16) uint16 {
761	v := ParseArgument(request, key, defaultValue)
762	return cast.ToUint16(v)
763}
764
765// ParseUInt8 extracts and converts an uint8 parameter from a CallToolRequest.
766func ParseUInt8(request CallToolRequest, key string, defaultValue uint8) uint8 {
767	v := ParseArgument(request, key, defaultValue)
768	return cast.ToUint8(v)
769}
770
771// ParseFloat32 extracts and converts a float32 parameter from a CallToolRequest.
772func ParseFloat32(request CallToolRequest, key string, defaultValue float32) float32 {
773	v := ParseArgument(request, key, defaultValue)
774	return cast.ToFloat32(v)
775}
776
777// ParseFloat64 extracts and converts a float64 parameter from a CallToolRequest.
778func ParseFloat64(request CallToolRequest, key string, defaultValue float64) float64 {
779	v := ParseArgument(request, key, defaultValue)
780	return cast.ToFloat64(v)
781}
782
783// ParseString extracts and converts a string parameter from a CallToolRequest.
784func ParseString(request CallToolRequest, key string, defaultValue string) string {
785	v := ParseArgument(request, key, defaultValue)
786	return cast.ToString(v)
787}
788
789// ParseStringMap extracts and converts a string map parameter from a CallToolRequest.
790func ParseStringMap(request CallToolRequest, key string, defaultValue map[string]any) map[string]any {
791	v := ParseArgument(request, key, defaultValue)
792	return cast.ToStringMap(v)
793}
794
// ToBoolPtr returns a pointer to a newly allocated bool holding b. Each call
// returns a distinct pointer, which is handy for filling optional *bool fields
// from a literal.
func ToBoolPtr(b bool) *bool {
	p := new(bool)
	*p = b
	return p
}