1package mcp
2
3import (
4 "encoding/json"
5 "fmt"
6
7 "github.com/spf13/cast"
8)
9
10// ClientRequest types
11var _ ClientRequest = &PingRequest{}
12var _ ClientRequest = &InitializeRequest{}
13var _ ClientRequest = &CompleteRequest{}
14var _ ClientRequest = &SetLevelRequest{}
15var _ ClientRequest = &GetPromptRequest{}
16var _ ClientRequest = &ListPromptsRequest{}
17var _ ClientRequest = &ListResourcesRequest{}
18var _ ClientRequest = &ReadResourceRequest{}
19var _ ClientRequest = &SubscribeRequest{}
20var _ ClientRequest = &UnsubscribeRequest{}
21var _ ClientRequest = &CallToolRequest{}
22var _ ClientRequest = &ListToolsRequest{}
23
24// ClientNotification types
25var _ ClientNotification = &CancelledNotification{}
26var _ ClientNotification = &ProgressNotification{}
27var _ ClientNotification = &InitializedNotification{}
28var _ ClientNotification = &RootsListChangedNotification{}
29
30// ClientResult types
31var _ ClientResult = &EmptyResult{}
32var _ ClientResult = &CreateMessageResult{}
33var _ ClientResult = &ListRootsResult{}
34
35// ServerRequest types
36var _ ServerRequest = &PingRequest{}
37var _ ServerRequest = &CreateMessageRequest{}
38var _ ServerRequest = &ListRootsRequest{}
39
40// ServerNotification types
41var _ ServerNotification = &CancelledNotification{}
42var _ ServerNotification = &ProgressNotification{}
43var _ ServerNotification = &LoggingMessageNotification{}
44var _ ServerNotification = &ResourceUpdatedNotification{}
45var _ ServerNotification = &ResourceListChangedNotification{}
46var _ ServerNotification = &ToolListChangedNotification{}
47var _ ServerNotification = &PromptListChangedNotification{}
48
49// ServerResult types
50var _ ServerResult = &EmptyResult{}
51var _ ServerResult = &InitializeResult{}
52var _ ServerResult = &CompleteResult{}
53var _ ServerResult = &GetPromptResult{}
54var _ ServerResult = &ListPromptsResult{}
55var _ ServerResult = &ListResourcesResult{}
56var _ ServerResult = &ReadResourceResult{}
57var _ ServerResult = &CallToolResult{}
58var _ ServerResult = &ListToolsResult{}
59
60// Helper functions for type assertions
61
// asType performs the comma-ok type assertion content.(T). On success it
// returns a pointer to a (shallow) copy of the asserted value. On failure it
// returns nil, false. Failure includes content being nil or holding a *T
// rather than a T.
func asType[T any](content any) (*T, bool) {
	if value, ok := content.(T); ok {
		return &value, true
	}
	return nil, false
}
70
71// AsTextContent attempts to cast the given interface to TextContent
72func AsTextContent(content any) (*TextContent, bool) {
73 return asType[TextContent](content)
74}
75
76// AsImageContent attempts to cast the given interface to ImageContent
77func AsImageContent(content any) (*ImageContent, bool) {
78 return asType[ImageContent](content)
79}
80
81// AsAudioContent attempts to cast the given interface to AudioContent
82func AsAudioContent(content any) (*AudioContent, bool) {
83 return asType[AudioContent](content)
84}
85
86// AsEmbeddedResource attempts to cast the given interface to EmbeddedResource
87func AsEmbeddedResource(content any) (*EmbeddedResource, bool) {
88 return asType[EmbeddedResource](content)
89}
90
91// AsTextResourceContents attempts to cast the given interface to TextResourceContents
92func AsTextResourceContents(content any) (*TextResourceContents, bool) {
93 return asType[TextResourceContents](content)
94}
95
96// AsBlobResourceContents attempts to cast the given interface to BlobResourceContents
97func AsBlobResourceContents(content any) (*BlobResourceContents, bool) {
98 return asType[BlobResourceContents](content)
99}
100
// Helper functions for JSON-RPC
102
103// NewJSONRPCResponse creates a new JSONRPCResponse with the given id and result
104func NewJSONRPCResponse(id RequestId, result Result) JSONRPCResponse {
105 return JSONRPCResponse{
106 JSONRPC: JSONRPC_VERSION,
107 ID: id,
108 Result: result,
109 }
110}
111
112// NewJSONRPCError creates a new JSONRPCResponse with the given id, code, and message
113func NewJSONRPCError(
114 id RequestId,
115 code int,
116 message string,
117 data any,
118) JSONRPCError {
119 return JSONRPCError{
120 JSONRPC: JSONRPC_VERSION,
121 ID: id,
122 Error: struct {
123 Code int `json:"code"`
124 Message string `json:"message"`
125 Data any `json:"data,omitempty"`
126 }{
127 Code: code,
128 Message: message,
129 Data: data,
130 },
131 }
132}
133
134// NewProgressNotification
135// Helper function for creating a progress notification
136func NewProgressNotification(
137 token ProgressToken,
138 progress float64,
139 total *float64,
140 message *string,
141) ProgressNotification {
142 notification := ProgressNotification{
143 Notification: Notification{
144 Method: "notifications/progress",
145 },
146 Params: struct {
147 ProgressToken ProgressToken `json:"progressToken"`
148 Progress float64 `json:"progress"`
149 Total float64 `json:"total,omitempty"`
150 Message string `json:"message,omitempty"`
151 }{
152 ProgressToken: token,
153 Progress: progress,
154 },
155 }
156 if total != nil {
157 notification.Params.Total = *total
158 }
159 if message != nil {
160 notification.Params.Message = *message
161 }
162 return notification
163}
164
165// NewLoggingMessageNotification
166// Helper function for creating a logging message notification
167func NewLoggingMessageNotification(
168 level LoggingLevel,
169 logger string,
170 data any,
171) LoggingMessageNotification {
172 return LoggingMessageNotification{
173 Notification: Notification{
174 Method: "notifications/message",
175 },
176 Params: struct {
177 Level LoggingLevel `json:"level"`
178 Logger string `json:"logger,omitempty"`
179 Data any `json:"data"`
180 }{
181 Level: level,
182 Logger: logger,
183 Data: data,
184 },
185 }
186}
187
188// NewPromptMessage
189// Helper function to create a new PromptMessage
190func NewPromptMessage(role Role, content Content) PromptMessage {
191 return PromptMessage{
192 Role: role,
193 Content: content,
194 }
195}
196
197// NewTextContent
198// Helper function to create a new TextContent
199func NewTextContent(text string) TextContent {
200 return TextContent{
201 Type: "text",
202 Text: text,
203 }
204}
205
206// NewImageContent
207// Helper function to create a new ImageContent
208func NewImageContent(data, mimeType string) ImageContent {
209 return ImageContent{
210 Type: "image",
211 Data: data,
212 MIMEType: mimeType,
213 }
214}
215
216// Helper function to create a new AudioContent
217func NewAudioContent(data, mimeType string) AudioContent {
218 return AudioContent{
219 Type: "audio",
220 Data: data,
221 MIMEType: mimeType,
222 }
223}
224
225// Helper function to create a new EmbeddedResource
226func NewEmbeddedResource(resource ResourceContents) EmbeddedResource {
227 return EmbeddedResource{
228 Type: "resource",
229 Resource: resource,
230 }
231}
232
233// NewToolResultText creates a new CallToolResult with a text content
234func NewToolResultText(text string) *CallToolResult {
235 return &CallToolResult{
236 Content: []Content{
237 TextContent{
238 Type: "text",
239 Text: text,
240 },
241 },
242 }
243}
244
245// NewToolResultImage creates a new CallToolResult with both text and image content
246func NewToolResultImage(text, imageData, mimeType string) *CallToolResult {
247 return &CallToolResult{
248 Content: []Content{
249 TextContent{
250 Type: "text",
251 Text: text,
252 },
253 ImageContent{
254 Type: "image",
255 Data: imageData,
256 MIMEType: mimeType,
257 },
258 },
259 }
260}
261
262// NewToolResultAudio creates a new CallToolResult with both text and audio content
263func NewToolResultAudio(text, imageData, mimeType string) *CallToolResult {
264 return &CallToolResult{
265 Content: []Content{
266 TextContent{
267 Type: "text",
268 Text: text,
269 },
270 AudioContent{
271 Type: "audio",
272 Data: imageData,
273 MIMEType: mimeType,
274 },
275 },
276 }
277}
278
279// NewToolResultResource creates a new CallToolResult with an embedded resource
280func NewToolResultResource(
281 text string,
282 resource ResourceContents,
283) *CallToolResult {
284 return &CallToolResult{
285 Content: []Content{
286 TextContent{
287 Type: "text",
288 Text: text,
289 },
290 EmbeddedResource{
291 Type: "resource",
292 Resource: resource,
293 },
294 },
295 }
296}
297
298// NewToolResultError creates a new CallToolResult with an error message.
299// Any errors that originate from the tool SHOULD be reported inside the result object.
300func NewToolResultError(text string) *CallToolResult {
301 return &CallToolResult{
302 Content: []Content{
303 TextContent{
304 Type: "text",
305 Text: text,
306 },
307 },
308 IsError: true,
309 }
310}
311
312// NewToolResultErrorFromErr creates a new CallToolResult with an error message.
313// If an error is provided, its details will be appended to the text message.
314// Any errors that originate from the tool SHOULD be reported inside the result object.
315func NewToolResultErrorFromErr(text string, err error) *CallToolResult {
316 if err != nil {
317 text = fmt.Sprintf("%s: %v", text, err)
318 }
319 return &CallToolResult{
320 Content: []Content{
321 TextContent{
322 Type: "text",
323 Text: text,
324 },
325 },
326 IsError: true,
327 }
328}
329
330// NewToolResultErrorf creates a new CallToolResult with an error message.
331// The error message is formatted using the fmt package.
332// Any errors that originate from the tool SHOULD be reported inside the result object.
333func NewToolResultErrorf(format string, a ...any) *CallToolResult {
334 return &CallToolResult{
335 Content: []Content{
336 TextContent{
337 Type: "text",
338 Text: fmt.Sprintf(format, a...),
339 },
340 },
341 IsError: true,
342 }
343}
344
345// NewListResourcesResult creates a new ListResourcesResult
346func NewListResourcesResult(
347 resources []Resource,
348 nextCursor Cursor,
349) *ListResourcesResult {
350 return &ListResourcesResult{
351 PaginatedResult: PaginatedResult{
352 NextCursor: nextCursor,
353 },
354 Resources: resources,
355 }
356}
357
358// NewListResourceTemplatesResult creates a new ListResourceTemplatesResult
359func NewListResourceTemplatesResult(
360 templates []ResourceTemplate,
361 nextCursor Cursor,
362) *ListResourceTemplatesResult {
363 return &ListResourceTemplatesResult{
364 PaginatedResult: PaginatedResult{
365 NextCursor: nextCursor,
366 },
367 ResourceTemplates: templates,
368 }
369}
370
371// NewReadResourceResult creates a new ReadResourceResult with text content
372func NewReadResourceResult(text string) *ReadResourceResult {
373 return &ReadResourceResult{
374 Contents: []ResourceContents{
375 TextResourceContents{
376 Text: text,
377 },
378 },
379 }
380}
381
382// NewListPromptsResult creates a new ListPromptsResult
383func NewListPromptsResult(
384 prompts []Prompt,
385 nextCursor Cursor,
386) *ListPromptsResult {
387 return &ListPromptsResult{
388 PaginatedResult: PaginatedResult{
389 NextCursor: nextCursor,
390 },
391 Prompts: prompts,
392 }
393}
394
395// NewGetPromptResult creates a new GetPromptResult
396func NewGetPromptResult(
397 description string,
398 messages []PromptMessage,
399) *GetPromptResult {
400 return &GetPromptResult{
401 Description: description,
402 Messages: messages,
403 }
404}
405
406// NewListToolsResult creates a new ListToolsResult
407func NewListToolsResult(tools []Tool, nextCursor Cursor) *ListToolsResult {
408 return &ListToolsResult{
409 PaginatedResult: PaginatedResult{
410 NextCursor: nextCursor,
411 },
412 Tools: tools,
413 }
414}
415
416// NewInitializeResult creates a new InitializeResult
417func NewInitializeResult(
418 protocolVersion string,
419 capabilities ServerCapabilities,
420 serverInfo Implementation,
421 instructions string,
422) *InitializeResult {
423 return &InitializeResult{
424 ProtocolVersion: protocolVersion,
425 Capabilities: capabilities,
426 ServerInfo: serverInfo,
427 Instructions: instructions,
428 }
429}
430
431// FormatNumberResult
432// Helper for formatting numbers in tool results
433func FormatNumberResult(value float64) *CallToolResult {
434 return NewToolResultText(fmt.Sprintf("%.2f", value))
435}
436
// ExtractString returns data[key] when it is present and holds a string. It
// returns "" otherwise: for a missing key, a nil map, or a non-string value.
func ExtractString(data map[string]any, key string) string {
	s, _ := data[key].(string)
	return s
}
445
// ExtractMap returns data[key] when it is present and holds a map[string]any
// (the shape json.Unmarshal produces for JSON objects). It returns nil
// otherwise: for a missing key, a nil map, or a value of any other type.
func ExtractMap(data map[string]any, key string) map[string]any {
	m, _ := data[key].(map[string]any)
	return m
}
454
455func ParseContent(contentMap map[string]any) (Content, error) {
456 contentType := ExtractString(contentMap, "type")
457
458 switch contentType {
459 case "text":
460 text := ExtractString(contentMap, "text")
461 return NewTextContent(text), nil
462
463 case "image":
464 data := ExtractString(contentMap, "data")
465 mimeType := ExtractString(contentMap, "mimeType")
466 if data == "" || mimeType == "" {
467 return nil, fmt.Errorf("image data or mimeType is missing")
468 }
469 return NewImageContent(data, mimeType), nil
470
471 case "audio":
472 data := ExtractString(contentMap, "data")
473 mimeType := ExtractString(contentMap, "mimeType")
474 if data == "" || mimeType == "" {
475 return nil, fmt.Errorf("audio data or mimeType is missing")
476 }
477 return NewAudioContent(data, mimeType), nil
478
479 case "resource":
480 resourceMap := ExtractMap(contentMap, "resource")
481 if resourceMap == nil {
482 return nil, fmt.Errorf("resource is missing")
483 }
484
485 resourceContents, err := ParseResourceContents(resourceMap)
486 if err != nil {
487 return nil, err
488 }
489
490 return NewEmbeddedResource(resourceContents), nil
491 }
492
493 return nil, fmt.Errorf("unsupported content type: %s", contentType)
494}
495
496func ParseGetPromptResult(rawMessage *json.RawMessage) (*GetPromptResult, error) {
497 if rawMessage == nil {
498 return nil, fmt.Errorf("response is nil")
499 }
500
501 var jsonContent map[string]any
502 if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil {
503 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
504 }
505
506 result := GetPromptResult{}
507
508 meta, ok := jsonContent["_meta"]
509 if ok {
510 if metaMap, ok := meta.(map[string]any); ok {
511 result.Meta = metaMap
512 }
513 }
514
515 description, ok := jsonContent["description"]
516 if ok {
517 if descriptionStr, ok := description.(string); ok {
518 result.Description = descriptionStr
519 }
520 }
521
522 messages, ok := jsonContent["messages"]
523 if ok {
524 messagesArr, ok := messages.([]any)
525 if !ok {
526 return nil, fmt.Errorf("messages is not an array")
527 }
528
529 for _, message := range messagesArr {
530 messageMap, ok := message.(map[string]any)
531 if !ok {
532 return nil, fmt.Errorf("message is not an object")
533 }
534
535 // Extract role
536 roleStr := ExtractString(messageMap, "role")
537 if roleStr == "" || (roleStr != string(RoleAssistant) && roleStr != string(RoleUser)) {
538 return nil, fmt.Errorf("unsupported role: %s", roleStr)
539 }
540
541 // Extract content
542 contentMap, ok := messageMap["content"].(map[string]any)
543 if !ok {
544 return nil, fmt.Errorf("content is not an object")
545 }
546
547 // Process content
548 content, err := ParseContent(contentMap)
549 if err != nil {
550 return nil, err
551 }
552
553 // Append processed message
554 result.Messages = append(result.Messages, NewPromptMessage(Role(roleStr), content))
555
556 }
557 }
558
559 return &result, nil
560}
561
562func ParseCallToolResult(rawMessage *json.RawMessage) (*CallToolResult, error) {
563 if rawMessage == nil {
564 return nil, fmt.Errorf("response is nil")
565 }
566
567 var jsonContent map[string]any
568 if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil {
569 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
570 }
571
572 var result CallToolResult
573
574 meta, ok := jsonContent["_meta"]
575 if ok {
576 if metaMap, ok := meta.(map[string]any); ok {
577 result.Meta = metaMap
578 }
579 }
580
581 isError, ok := jsonContent["isError"]
582 if ok {
583 if isErrorBool, ok := isError.(bool); ok {
584 result.IsError = isErrorBool
585 }
586 }
587
588 contents, ok := jsonContent["content"]
589 if !ok {
590 return nil, fmt.Errorf("content is missing")
591 }
592
593 contentArr, ok := contents.([]any)
594 if !ok {
595 return nil, fmt.Errorf("content is not an array")
596 }
597
598 for _, content := range contentArr {
599 // Extract content
600 contentMap, ok := content.(map[string]any)
601 if !ok {
602 return nil, fmt.Errorf("content is not an object")
603 }
604
605 // Process content
606 content, err := ParseContent(contentMap)
607 if err != nil {
608 return nil, err
609 }
610
611 result.Content = append(result.Content, content)
612 }
613
614 return &result, nil
615}
616
617func ParseResourceContents(contentMap map[string]any) (ResourceContents, error) {
618 uri := ExtractString(contentMap, "uri")
619 if uri == "" {
620 return nil, fmt.Errorf("resource uri is missing")
621 }
622
623 mimeType := ExtractString(contentMap, "mimeType")
624
625 if text := ExtractString(contentMap, "text"); text != "" {
626 return TextResourceContents{
627 URI: uri,
628 MIMEType: mimeType,
629 Text: text,
630 }, nil
631 }
632
633 if blob := ExtractString(contentMap, "blob"); blob != "" {
634 return BlobResourceContents{
635 URI: uri,
636 MIMEType: mimeType,
637 Blob: blob,
638 }, nil
639 }
640
641 return nil, fmt.Errorf("unsupported resource type")
642}
643
644func ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult, error) {
645 if rawMessage == nil {
646 return nil, fmt.Errorf("response is nil")
647 }
648
649 var jsonContent map[string]any
650 if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil {
651 return nil, fmt.Errorf("failed to unmarshal response: %w", err)
652 }
653
654 var result ReadResourceResult
655
656 meta, ok := jsonContent["_meta"]
657 if ok {
658 if metaMap, ok := meta.(map[string]any); ok {
659 result.Meta = metaMap
660 }
661 }
662
663 contents, ok := jsonContent["contents"]
664 if !ok {
665 return nil, fmt.Errorf("contents is missing")
666 }
667
668 contentArr, ok := contents.([]any)
669 if !ok {
670 return nil, fmt.Errorf("contents is not an array")
671 }
672
673 for _, content := range contentArr {
674 // Extract content
675 contentMap, ok := content.(map[string]any)
676 if !ok {
677 return nil, fmt.Errorf("content is not an object")
678 }
679
680 // Process content
681 content, err := ParseResourceContents(contentMap)
682 if err != nil {
683 return nil, err
684 }
685
686 result.Contents = append(result.Contents, content)
687 }
688
689 return &result, nil
690}
691
692func ParseArgument(request CallToolRequest, key string, defaultVal any) any {
693 args := request.GetArguments()
694 if _, ok := args[key]; !ok {
695 return defaultVal
696 } else {
697 return args[key]
698 }
699}
700
701// ParseBoolean extracts and converts a boolean parameter from a CallToolRequest.
702// If the key is not found in the Arguments map, the defaultValue is returned.
703// The function uses cast.ToBool for conversion which handles various string representations
704// such as "true", "yes", "1", etc.
705func ParseBoolean(request CallToolRequest, key string, defaultValue bool) bool {
706 v := ParseArgument(request, key, defaultValue)
707 return cast.ToBool(v)
708}
709
710// ParseInt64 extracts and converts an int64 parameter from a CallToolRequest.
711// If the key is not found in the Arguments map, the defaultValue is returned.
712func ParseInt64(request CallToolRequest, key string, defaultValue int64) int64 {
713 v := ParseArgument(request, key, defaultValue)
714 return cast.ToInt64(v)
715}
716
717// ParseInt32 extracts and converts an int32 parameter from a CallToolRequest.
718func ParseInt32(request CallToolRequest, key string, defaultValue int32) int32 {
719 v := ParseArgument(request, key, defaultValue)
720 return cast.ToInt32(v)
721}
722
723// ParseInt16 extracts and converts an int16 parameter from a CallToolRequest.
724func ParseInt16(request CallToolRequest, key string, defaultValue int16) int16 {
725 v := ParseArgument(request, key, defaultValue)
726 return cast.ToInt16(v)
727}
728
729// ParseInt8 extracts and converts an int8 parameter from a CallToolRequest.
730func ParseInt8(request CallToolRequest, key string, defaultValue int8) int8 {
731 v := ParseArgument(request, key, defaultValue)
732 return cast.ToInt8(v)
733}
734
735// ParseInt extracts and converts an int parameter from a CallToolRequest.
736func ParseInt(request CallToolRequest, key string, defaultValue int) int {
737 v := ParseArgument(request, key, defaultValue)
738 return cast.ToInt(v)
739}
740
741// ParseUInt extracts and converts an uint parameter from a CallToolRequest.
742func ParseUInt(request CallToolRequest, key string, defaultValue uint) uint {
743 v := ParseArgument(request, key, defaultValue)
744 return cast.ToUint(v)
745}
746
747// ParseUInt64 extracts and converts an uint64 parameter from a CallToolRequest.
748func ParseUInt64(request CallToolRequest, key string, defaultValue uint64) uint64 {
749 v := ParseArgument(request, key, defaultValue)
750 return cast.ToUint64(v)
751}
752
753// ParseUInt32 extracts and converts an uint32 parameter from a CallToolRequest.
754func ParseUInt32(request CallToolRequest, key string, defaultValue uint32) uint32 {
755 v := ParseArgument(request, key, defaultValue)
756 return cast.ToUint32(v)
757}
758
759// ParseUInt16 extracts and converts an uint16 parameter from a CallToolRequest.
760func ParseUInt16(request CallToolRequest, key string, defaultValue uint16) uint16 {
761 v := ParseArgument(request, key, defaultValue)
762 return cast.ToUint16(v)
763}
764
765// ParseUInt8 extracts and converts an uint8 parameter from a CallToolRequest.
766func ParseUInt8(request CallToolRequest, key string, defaultValue uint8) uint8 {
767 v := ParseArgument(request, key, defaultValue)
768 return cast.ToUint8(v)
769}
770
771// ParseFloat32 extracts and converts a float32 parameter from a CallToolRequest.
772func ParseFloat32(request CallToolRequest, key string, defaultValue float32) float32 {
773 v := ParseArgument(request, key, defaultValue)
774 return cast.ToFloat32(v)
775}
776
777// ParseFloat64 extracts and converts a float64 parameter from a CallToolRequest.
778func ParseFloat64(request CallToolRequest, key string, defaultValue float64) float64 {
779 v := ParseArgument(request, key, defaultValue)
780 return cast.ToFloat64(v)
781}
782
783// ParseString extracts and converts a string parameter from a CallToolRequest.
784func ParseString(request CallToolRequest, key string, defaultValue string) string {
785 v := ParseArgument(request, key, defaultValue)
786 return cast.ToString(v)
787}
788
789// ParseStringMap extracts and converts a string map parameter from a CallToolRequest.
790func ParseStringMap(request CallToolRequest, key string, defaultValue map[string]any) map[string]any {
791 v := ParseArgument(request, key, defaultValue)
792 return cast.ToStringMap(v)
793}
794
// ToBoolPtr returns a pointer to a freshly allocated bool holding b. This is
// handy for populating optional *bool fields inline.
func ToBoolPtr(b bool) *bool {
	p := new(bool)
	*p = b
	return p
}