mcp-tools.go

  1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7
  8	"github.com/kujtimiihoxha/termai/internal/config"
  9	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 10	"github.com/kujtimiihoxha/termai/internal/logging"
 11	"github.com/kujtimiihoxha/termai/internal/permission"
 12	"github.com/kujtimiihoxha/termai/internal/version"
 13
 14	"github.com/mark3labs/mcp-go/client"
 15	"github.com/mark3labs/mcp-go/mcp"
 16)
 17
 18type mcpTool struct {
 19	mcpName     string
 20	tool        mcp.Tool
 21	mcpConfig   config.MCPServer
 22	permissions permission.Service
 23}
 24
 25var logger = logging.Get()
 26
 27type MCPClient interface {
 28	Initialize(
 29		ctx context.Context,
 30		request mcp.InitializeRequest,
 31	) (*mcp.InitializeResult, error)
 32	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
 33	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
 34	Close() error
 35}
 36
 37func (b *mcpTool) Info() tools.ToolInfo {
 38	return tools.ToolInfo{
 39		Name:        fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
 40		Description: b.tool.Description,
 41		Parameters:  b.tool.InputSchema.Properties,
 42		Required:    b.tool.InputSchema.Required,
 43	}
 44}
 45
 46func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 47	defer c.Close()
 48	initRequest := mcp.InitializeRequest{}
 49	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 50	initRequest.Params.ClientInfo = mcp.Implementation{
 51		Name:    "termai",
 52		Version: version.Version,
 53	}
 54
 55	_, err := c.Initialize(ctx, initRequest)
 56	if err != nil {
 57		return tools.NewTextErrorResponse(err.Error()), nil
 58	}
 59
 60	toolRequest := mcp.CallToolRequest{}
 61	toolRequest.Params.Name = toolName
 62	var args map[string]any
 63	if err = json.Unmarshal([]byte(input), &input); err != nil {
 64		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 65	}
 66	toolRequest.Params.Arguments = args
 67	result, err := c.CallTool(ctx, toolRequest)
 68	if err != nil {
 69		return tools.NewTextErrorResponse(err.Error()), nil
 70	}
 71
 72	output := ""
 73	for _, v := range result.Content {
 74		if v, ok := v.(mcp.TextContent); ok {
 75			output = v.Text
 76		} else {
 77			output = fmt.Sprintf("%v", v)
 78		}
 79	}
 80
 81	return tools.NewTextResponse(output), nil
 82}
 83
 84func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 85	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
 86	p := b.permissions.Request(
 87		permission.CreatePermissionRequest{
 88			Path:        config.WorkingDirectory(),
 89			ToolName:    b.Info().Name,
 90			Action:      "execute",
 91			Description: permissionDescription,
 92			Params:      params.Input,
 93		},
 94	)
 95	if !p {
 96		return tools.NewTextErrorResponse("permission denied"), nil
 97	}
 98
 99	switch b.mcpConfig.Type {
100	case config.MCPStdio:
101		c, err := client.NewStdioMCPClient(
102			b.mcpConfig.Command,
103			b.mcpConfig.Env,
104			b.mcpConfig.Args...,
105		)
106		if err != nil {
107			return tools.NewTextErrorResponse(err.Error()), nil
108		}
109		return runTool(ctx, c, b.tool.Name, params.Input)
110	case config.MCPSse:
111		c, err := client.NewSSEMCPClient(
112			b.mcpConfig.URL,
113			client.WithHeaders(b.mcpConfig.Headers),
114		)
115		if err != nil {
116			return tools.NewTextErrorResponse(err.Error()), nil
117		}
118		return runTool(ctx, c, b.tool.Name, params.Input)
119	}
120
121	return tools.NewTextErrorResponse("invalid mcp type"), nil
122}
123
124func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPServer) tools.BaseTool {
125	return &mcpTool{
126		mcpName:     name,
127		tool:        tool,
128		mcpConfig:   mcpConfig,
129		permissions: permissions,
130	}
131}
132
133var mcpTools []tools.BaseTool
134
135func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool {
136	var stdioTools []tools.BaseTool
137	initRequest := mcp.InitializeRequest{}
138	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
139	initRequest.Params.ClientInfo = mcp.Implementation{
140		Name:    "termai",
141		Version: version.Version,
142	}
143
144	_, err := c.Initialize(ctx, initRequest)
145	if err != nil {
146		logger.Error("error initializing mcp client", "error", err)
147		return stdioTools
148	}
149	toolsRequest := mcp.ListToolsRequest{}
150	tools, err := c.ListTools(ctx, toolsRequest)
151	if err != nil {
152		logger.Error("error listing tools", "error", err)
153		return stdioTools
154	}
155	for _, t := range tools.Tools {
156		stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m))
157	}
158	defer c.Close()
159	return stdioTools
160}
161
162func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
163	if len(mcpTools) > 0 {
164		return mcpTools
165	}
166	for name, m := range config.Get().MCPServers {
167		switch m.Type {
168		case config.MCPStdio:
169			c, err := client.NewStdioMCPClient(
170				m.Command,
171				m.Env,
172				m.Args...,
173			)
174			if err != nil {
175				logger.Error("error creating mcp client", "error", err)
176				continue
177			}
178
179			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
180		case config.MCPSse:
181			c, err := client.NewSSEMCPClient(
182				m.URL,
183				client.WithHeaders(m.Headers),
184			)
185			if err != nil {
186				logger.Error("error creating mcp client", "error", err)
187				continue
188			}
189			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
190		}
191	}
192
193	return mcpTools
194}